tree-dp入门2

tree dp

tutorial

下面主要是对上面文章的复现

问题1

给一棵树,每个节点都有一个权值,要求选择一个节点的集合,使得它们不相邻,并且权值的和最大

我们先想一想在线性的数组上面进行,很容易列出下面的dp方程:
$dp[i] = max(dp[i-1], dp[i-2]+val[i]) $

下面我们考虑树形的结构
$ dp[v] = max(\sum _{i=1}^n dp[v_i], val[v]+dp[all child of v_i])$

为了减少寻找的次数,我们使用两个数组记录变量:
dp1[i]表示的是选择i节点的最优值
dp2[i]表示的是不选择i节点的最优值
具体的细节可以看看代码

问题2

先确定一个节点,然后求经过这个节点的最长链
要求O(1)的查询经过每一个节点的最长链
我们保存经过这个节点的很多条链,然后最长的两条相加就可以了

问题3

Given a tree T of N nodes and an integer K, find number of different sub trees of size less than or equal to K.
给定一棵树,然后统计有多少种划分的方法,使得每一个子树的大小不超过k。
每一棵子树要求都是连通的

我们先考虑这样一个问题:
我们不考虑子树的大小,直接统计有多少种划分的方法。

回到原来的问题
(下面的还没看懂。。。

问题4

问以哪个节点为根,使得花费的期望最小
这种转移的设计非常的巧妙
假设v的父亲节点是v’,那么g(v’)表示的是父亲节点对v节点代价的贡献

问题5

给两棵树,将节点进行相应的映射,并且可以添加节点,最后使得两棵树同构,问最少需要插入多少个节点。
时间复杂度$ O(n^3) $

下面是刷题环节

#

题解

设计下面的状态:
dp[i][0]:以i节点为根的子树,有0个黑色节点的方案数
dp[i][1]:以i节点为根的子树,有1个黑色节点的方案数

ac code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<iostream>
#include<queue>
#include<cmath>
#include<map>
#include<stack>
#include<set>
#include<bitset>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;
#define ls (rt<<1)
#define rs (rt<<1|1)
#define mid (l+r>>1)
#define pb(x) push_back(x)
#define cls(x, val) memset(x, val, sizeof(x))
#define fi first
#define se second
#define mp(x, y) make_pair(x, y)
#define lowbit(x) (x&(-x))
#define inc(i, l, r) for(int i=l; i<=r; i++)
#define dec(i, r, l) for(int i=r; i>=l; i--)
const int inf = 0x3f3f3f3f;
const int maxn = 1e5+10;
const double pi = acos(-1.0);
const double eps = 1e-7;
const ll mod = 1e9+7;
int readint()
{
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
ll readll(){
ll x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
int n, m;
int p[maxn];
int color[maxn];
ll DP[maxn][2];
vector<int> ve[maxn];
void dfs(int u, int fa){
//初始化的条件很神奇,一开始假设为白色节点
DP[u][0] = 1;
DP[u][1] = 0;
for(int i=0; i<ve[u].size(); i++){
int v=ve[u][i];
dfs(v, u);
DP[u][1] *= DP[v][0], DP[u][1]%=mod;
DP[u][1] += DP[u][0]*DP[v][1], DP[u][1] %= mod;
DP[u][0] *= DP[v][0], DP[u][0] %= mod;
}
if(color[u] == 1)//必定是要有一个黑的
DP[u][1] = DP[u][0], DP[u][1]%=mod;
else//i节点的子树可以看成一个整体,也可以切成两部分
DP[u][0] += DP[u][1], DP[u][0] %= mod;
}
int main()
{
n=readint();
inc(i, 1, n-1) p[i]=readint(), ve[p[i]].pb(i);
inc(i ,0, n-1) color[i]=readint();
dfs(0, -1);
printf("%lld\n", DP[0][1]%mod);
return 0;
}

巡逻

里面有用dp求树的直径。
初始化ans=2(n-1)
若k=1, ans -= (dia-1);
若k=2, 将最长链上面的权值全部置为-1,再求一次最长链ans -= (dia-1);

ac code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<iostream>
#include<queue>
#include<cmath>
#include<map>
#include<stack>
#include<set>
#include<bitset>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;
#define ls (rt<<1)
#define rs (rt<<1|1)
#define mid (l+r>>1)
#define pb(x) push_back(x)
#define cls(x, val) memset(x, val, sizeof(x))
#define fi first
#define se second
#define mp(x, y) make_pair(x, y)
#define lowbit(x) (x&(-x))
#define inc(i, l, r) for(int i=l; i<=r; i++)
#define dec(i, r, l) for(int i=r; i>=l; i--)
const int inf = 0x3f3f3f3f;
const int maxn = 1e5+10;
const double pi = acos(-1.0);
const double eps = 1e-7;
const int mod = 1e9+7;
int readint()
{
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
ll readll(){
ll x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
int n, k;
struct Node{
int v, next, w;
}node[maxn<<1];
int tot, head[maxn];
void init(){
cls(head, -1);
tot = 0;
}
void add_edge(int u, int v, int w){
node[tot].v = v, node[tot].w = w, node[tot].next = head[u], head[u] = tot++;
node[tot].v=u, node[tot].w = w, node[tot].next = head[v], head[v] = tot++;
}
int dia;
int id1[maxn], id2[maxn];
int rt = -1;
int dfs(int u, int fa){
int mx1=0, mx2=0;
for(int i=head[u]; ~i; i=node[i].next){
int v=node[i].v;
if(v == fa) continue;
int len=node[i].w+dfs(v, u);
if(len>mx1) mx2=mx1, id2[u]=id1[u], mx1=len, id1[u]=i;
else if(len>mx2) mx2=len, id2[u]=i;
}
if(mx1+mx2>dia) dia = mx1+mx2, rt = u;
return mx1;
}
int main()
{
n=readint(), k=readint();
int u, v;
init();
inc(i, 1, n-1){
u=readint(), v=readint();
add_edge(u ,v, 1);
}
dia = 0;
int ans = 2*(n-1);
cls(id1, -1), cls(id2, -1);
dfs(1, 0);
ans -= (dia-1);
if(k == 2){
//cout<<"in"<<endl;
for(int i=id1[rt]; ~i; i=id1[node[i].v]) node[i].w = node[i^1].w = -1;
//根节点的次小边+孩子的最长链!!!
for(int i=id2[rt]; ~i; i=id1[node[i].v]) node[i].w = node[i^1].w = -1;
dia = -1;
//cout<<"out"<<endl;
dfs(1, 0);
ans -= (dia-1);
}
printf("%d\n", ans);
return 0;
}
/*
8 2
1 2
3 1
3 4
5 3
7 5
8 5
5 6
*/

未解决的问题

文章目录
  1. 1. 问题1
  2. 2. 问题2
  3. 3. 问题3
  4. 4. 问题4
  5. 5. 问题5
  6. 6. 下面是刷题环节
    1. 6.1. 题解
    2. 6.2. ac code
  7. 7. 巡逻
    1. 7.1. ac code
  8. 8. 未解决的问题
{{ live2d() }}