DSU On Tree

只能支持子树查询,且不支持修改操作的一种优秀的暴力数据结构。

DSU On Tree

DSU On Tree用来解决这样一类问题:统计树上一个节点的子树中具有某种特征的节点数,且无修改操作。 也有人把它叫做 静态树上众数。

例如子树中颜色为 $x$ 的个数, 这种方法可以做到 $O(n \log n)$ 的复杂度。

算法原理

其实很简单,就是枚举这个点,然后这棵子树扫一遍得到答案,然后清空 hsh 数组。

这就是笨蛋的想法,我们会发现它做了一些无用功,比如说最后一次清空,其实可以用于它的父节点,这样父节点就可以少算一个子节点。

我们想让尽量大的子树不擦除,那么就树剖剖出重儿子,重儿子不擦除就可以了!

算法实现步骤

  1. 利用树链剖分剖出重儿子。
  2. DFS 遍历子树。先遍历轻儿子,然后再遍历重儿子。
  3. 如果当前根节点是重儿子,那么就跳过,否则将答案清空。

应用

CodeForces 600E

给你一棵树,每个节点有一种颜色,问你每个子树 $x$ 的颜色数最多的那种颜色。

如果颜色数相同,那么种类数相加。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
//#define NOSTDCPP
//#define Cpp11
//#define Linux_System
#ifndef NOSTDCPP
#include <bits/stdc++.h>
#else
#include <algorithm>
#include <bitset>
#include <cassert>
#include <climits>
#include <complex>
#include <cstring>
#include <cstdio>
#include <deque>
#include <exception>
#include <functional>
#include <iomanip>
#include <iostream>
#include <istream>
#include <iterator>
#include <list>
#include <map>
#include <ostream>
#include <queue>
#include <set>
#include <sstream>
#include <stack>
#include <string>
#include <typeinfo>
#include <utility>
#include <valarray>
#include <vector>
#endif
# ifdef Linux_System
# define getchar getchar_unlocked
# define putchar putchar_unlocked
# endif
# define RESET(_) memset(_, 0, sizeof(_))
# define RESET_(_, val) memset(_, val, sizeof(_))
# define fi first
# define se second
# define pb push_back
# define midf(x, y) ((x + y) >> 1)
# define DXA(_) ((_ << 1))
# define DXB(_) ((_ << 1) | 1)
# define next __Chtholly__
# define x1 __Mercury__
# define y1 __bbtl04__
# define index __ikooo__
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef vector <int> vi;
typedef set <int> si;
typedef pair <int, int> pii;
typedef long double ld;
const int MOD = 1e9 + 7;
const int maxn = 500009;
const int maxm = maxn << 1;
const ll inf = 1e18;
const double pi = acos(-1.0);
const double eps = 1e-6;
ll myrand(ll mod){return ((ll)rand() << 32 ^ (ll)rand() << 16 ^ rand()) % mod;}
template <class T>
inline bool scan_d(T & ret)
{
char c;
int sgn;
if(c = getchar(), c == EOF)return false;
while(c != '-' && (c < '0' || c > '9'))c = getchar();
sgn = (c == '-') ? -1 : 1;
ret = (c == '-') ? 0 : (c - '0');
while(c = getchar(), c >= '0' && c <= '9')
ret = ret * 10 + (c - '0');
ret *= sgn;
return true;
}
#ifdef Cpp11
template <class T, class ... Args>
inline bool scan_d(T & ret, Args & ... args)
{
scan_d(ret);
scan_d(args...);
}
#define cin.tie(0); cin.tie(nullptr);
#define cout.tie(0); cout.tie(nullptr);
#endif
inline bool scan_ch(char &ch)
{
if(ch = getchar(), ch == EOF)return false;
while(ch == ' ' || ch == '\n')ch = getchar();
return true;
}
template <class T>
inline void out_number(T x)
{
if(x < 0)
{
putchar('-');
out_number(- x);
return ;
}
if(x > 9)out_number(x / 10);
putchar(x % 10 + '0');
}
int n, m, A[maxn], hsh[maxn];
ll C, ans[maxn], AnsMax;
struct edge
{
int to, next;
}edg[maxm];
int fir[maxn], edg_cnt;
void addedge(int f, int t)
{
edg[++ edg_cnt].to = t;
edg[edg_cnt].next = fir[f];
fir[f] = edg_cnt;
}
int father[maxn], size[maxn], son[maxn];
bool vis[maxn];
void dfs1(int u, int fa)
{
father[u] = fa;
size[u] = 1;
for(int i = fir[u], v; i; i = edg[i].next)
{
v = edg[i].to;
if(v != fa)
{
dfs1(v, u);
size[u] += size[v];
if(! son[u] || size[son[u]] < size[v])
son[u] = v;
}
}
}
void cal(int u, int val)
{
//统计子树的结果
hsh[A[u]] += val;
if(val > 0 && hsh[A[u]] == AnsMax) C += A[u];
if(val > 0 && hsh[A[u]] > AnsMax) C = A[u], AnsMax = hsh[A[u]];
for(int i = fir[u], v; i; i = edg[i].next)
{
v = edg[i].to;
if(v != father[u] && ! vis[v]) //这里的判断条件是避免重复判断
cal(v, val);
}
}
void dfs2(int u, bool preferred_son)
{
//先遍历轻儿子
for(int i = fir[u], v; i; i = edg[i].next)
{
v = edg[i].to;
if(v != father[u] && v != son[u])
dfs2(v, false);
}
//再遍历重儿子
if(son[u]) dfs2(son[u], true), vis[son[u]] = true;
cal(u, 1);
ans[u] = C;
if(son[u]) vis[son[u]] = false;
//如果是轻儿子,要擦除结果
if(! preferred_son) cal(u, -1), C = AnsMax = 0;
}
int main()
{
edg_cnt = 0;
scanf("%d", &n);
for(register int i = 1; i <= n; ++ i) scanf("%d", A + i);
for(register int i = 1, f, t; i < n; ++ i)
{
scanf("%d %d", &f, &t);
addedge(f, t);
addedge(t, f);
}
dfs1(1, -1);
dfs2(1, false);
for(int i = 1; i <= n; ++ i) printf("%I64d ", ans[i]);
puts("");
return 0;
}

CodeForces 570D

给你一棵树,询问以一个点 $u$ 为根,深度为 $d$ 的所有节点是否能够成一个回文串。

DFS 遍历整棵树,然后统计每个点的子树有哪些字符(二维数组 $cnt[i][26]$ ),然后在 $u$ 点记录有多少个点是出现奇数次的,如果出现奇数次的点少于 $1$ ,就是 YES ,否则就是 NO

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
//#define NOSTDCPP
//#define Cpp11
//#define Linux_System
#ifndef NOSTDCPP
#include <bits/stdc++.h>
#else
#include <algorithm>
#include <bitset>
#include <cassert>
#include <climits>
#include <complex>
#include <cstring>
#include <cstdio>
#include <deque>
#include <exception>
#include <functional>
#include <iomanip>
#include <iostream>
#include <istream>
#include <iterator>
#include <list>
#include <map>
#include <ostream>
#include <queue>
#include <set>
#include <sstream>
#include <stack>
#include <string>
#include <typeinfo>
#include <utility>
#include <valarray>
#include <vector>
#endif
# ifdef Linux_System
# define getchar getchar_unlocked
# define putchar putchar_unlocked
# endif
# define RESET(_) memset(_, 0, sizeof(_))
# define RESET_(_, val) memset(_, val, sizeof(_))
# define fi first
# define se second
# define pb push_back
# define midf(x, y) ((x + y) >> 1)
# define DXA(_) ((_ << 1))
# define DXB(_) ((_ << 1) | 1)
# define next __Chtholly__
# define x1 __Mercury__
# define y1 __bbtl04__
# define index __ikooo__
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef vector <int> vi;
typedef set <int> si;
typedef pair <int, int> pii;
typedef long double ld;
const int MOD = 1e9 + 7;
const int maxn = 500009;
const int maxm = maxn << 1;
const ll inf = 1e18;
const double pi = acos(-1.0);
const double eps = 1e-6;
ll myrand(ll mod){return ((ll)rand() << 32 ^ (ll)rand() << 16 ^ rand()) % mod;}
template <class T>
inline bool scan_d(T & ret)
{
char c;
int sgn;
if(c = getchar(), c == EOF)return false;
while(c != '-' && (c < '0' || c > '9'))c = getchar();
sgn = (c == '-') ? -1 : 1;
ret = (c == '-') ? 0 : (c - '0');
while(c = getchar(), c >= '0' && c <= '9')
ret = ret * 10 + (c - '0');
ret *= sgn;
return true;
}
#ifdef Cpp11
template <class T, class ... Args>
inline bool scan_d(T & ret, Args & ... args)
{
scan_d(ret);
scan_d(args...);
}
#define cin.tie(0); cin.tie(nullptr);
#define cout.tie(0); cout.tie(nullptr);
#endif
inline bool scan_ch(char &ch)
{
if(ch = getchar(), ch == EOF)return false;
while(ch == ' ' || ch == '\n')ch = getchar();
return true;
}
template <class T>
inline void out_number(T x)
{
if(x < 0)
{
putchar('-');
out_number(- x);
return ;
}
if(x > 9)out_number(x / 10);
putchar(x % 10 + '0');
}
int n, m;
char A[maxn];
vector <pii> query[maxn];
int cnt2[maxn];
bool ans[maxn], cnt[maxn][26];
struct edge
{
int to, next;
}edg[maxm];
int fir[maxn], edg_cnt;
void addedge(int f, int t)
{
edg[++ edg_cnt].to = t;
edg[edg_cnt].next = fir[f];
fir[f] = edg_cnt;
}
int father[maxn], size[maxn], son[maxn], deep[maxn];
bool vis[maxn];
void dfs1(int u, int fa, int d)
{
father[u] = fa;
size[u] = 1;
deep[u] = d;
for(int i = fir[u], v; i; i = edg[i].next)
{
v = edg[i].to;
if(v != fa)
{
dfs1(v, u, d + 1);
size[u] += size[v];
if(! son[u] || size[son[u]] < size[v])
son[u] = v;
}
}
}
void cal(int u, int val)
{
//cnt2[deep[u]] -= cnt[deep[u]][A[u] - 'a'];
cnt[deep[u]][A[u] - 'a'] ^= true;
if(cnt[deep[u]][A[u] - 'a']) ++ cnt2[deep[u]];
else -- cnt2[deep[u]];
for(int i = fir[u], v; i; i = edg[i].next)
{
v = edg[i].to;
if(v != father[u] && ! vis[v])
cal(v, val);
}
}
void dfs2(int u, bool preferred_son)
{
for(int i = fir[u], v; i; i = edg[i].next)
{
v = edg[i].to;
if(v != father[u] && v != son[u])
dfs2(v, false);
}
if(son[u]) dfs2(son[u], true), vis[son[u]] = true;
cal(u, 1);
for(register int i = 0; i < query[u].size(); ++ i)
ans[query[u][i].second] = cnt2[query[u][i].first] <= 1;
if(son[u]) vis[son[u]] = false;
if(! preferred_son) cal(u, -1);
}
int main()
{
scanf("%d %d", &n, &m);
for(register int i = 2, f; i <= n; ++ i)
{
scanf("%d", &f);
addedge(f, i);
addedge(i, f);
}
scanf(" %s", A + 1);
for(register int i = 1, x, d; i <= m; ++ i)
{
scanf("%d %d", &x, &d);
query[x].push_back(make_pair(d, i));
}
RESET(ans);
dfs1(1, -1, 1);
dfs2(1, false);
for(register int i = 1; i <= m; ++ i) ans[i] ? puts("Yes") : puts("No");
puts("");
return 0;
}
/*
6 5
1 1 1 3 3
zacccd
1 1
3 3
4 1
6 1
1 2
*/

SGU 507

给你一棵树,叶子结点定义权值,其他结点(X 类结点)没有定义,询问所有 X 类结点的叶子结点最相近的权值差。

利用了 DSU 的思想,每个点用 set 表示权值,然后每次做启发式合并即可。

注意写法!一定要从小的往大的合并才能保证复杂度!

时间复杂度 $O(n \log^2 n)$ (提示:考虑合并的复杂度)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
//#define NOSTDCPP
//#define Cpp11
//#define Linux_System
#ifndef NOSTDCPP
#include <bits/stdc++.h>
#else
#include <algorithm>
#include <bitset>
#include <cassert>
#include <climits>
#include <complex>
#include <cstring>
#include <cstdio>
#include <deque>
#include <exception>
#include <functional>
#include <iomanip>
#include <iostream>
#include <istream>
#include <iterator>
#include <list>
#include <map>
#include <ostream>
#include <queue>
#include <set>
#include <sstream>
#include <stack>
#include <string>
#include <typeinfo>
#include <utility>
#include <valarray>
#include <vector>
#endif
# ifdef Linux_System
# define getchar getchar_unlocked
# define putchar putchar_unlocked
# endif
# define RESET(_) memset(_, 0, sizeof(_))
# define RESET_(_, val) memset(_, val, sizeof(_))
# define fi first
# define se second
# define pb push_back
# define midf(x, y) ((x + y) >> 1)
# define DXA(_) ((_ << 1))
# define DXB(_) ((_ << 1) | 1)
# define next __Chtholly__
# define x1 __Mercury__
# define y1 __bbtl04__
# define index __ikooo__
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef vector <int> vi;
typedef set <int> si;
typedef pair <int, int> pii;
typedef long double ld;
const int MOD = 1e9 + 7;
const int maxn = 50009;
const int maxm = maxn << 1;
const int inf = 2147483647;
const double pi = acos(-1.0);
const double eps = 1e-6;
ll myrand(ll mod){return ((ll)rand() << 32 ^ (ll)rand() << 16 ^ rand()) % mod;}
template <class T>
inline bool scan_d(T & ret)
{
char c;
int sgn;
if(c = getchar(), c == EOF)return false;
while(c != '-' && (c < '0' || c > '9'))c = getchar();
sgn = (c == '-') ? -1 : 1;
ret = (c == '-') ? 0 : (c - '0');
while(c = getchar(), c >= '0' && c <= '9')
ret = ret * 10 + (c - '0');
ret *= sgn;
return true;
}
#ifdef Cpp11
template <class T, class ... Args>
inline bool scan_d(T & ret, Args & ... args)
{
scan_d(ret);
scan_d(args...);
}
#define cin.tie(0); cin.tie(nullptr);
#define cout.tie(0); cout.tie(nullptr);
#endif
inline bool scan_ch(char &ch)
{
if(ch = getchar(), ch == EOF)return false;
while(ch == ' ' || ch == '\n')ch = getchar();
return true;
}
template <class T>
inline void out_number(T x)
{
if(x < 0)
{
putchar('-');
out_number(- x);
return ;
}
if(x > 9)out_number(x / 10);
putchar(x % 10 + '0');
}
int n, m;
int ans[maxn], A[maxn];
set <int> st[maxn];
set <int> :: iterator itx, ity, pre, suf;
bool leaf[maxn];
struct edge
{
int to, next;
}edg[maxm];
int fir[maxn], edg_cnt;
void addedge(int f, int t)
{
edg[++ edg_cnt].to = t;
edg[edg_cnt].next = fir[f];
fir[f] = edg_cnt;
}
int __merge__(set <int> &x, set <int> &y)
{
if(y.size() > x.size()) swap(x, y);
int ans = inf;
for(ity = y.begin(); ity != y.end(); ++ ity)
{
pre = suf = x.lower_bound(*ity);
if(pre != x.begin()) -- pre;
if(pre != x.end()) ans = min(ans, abs(*ity - *pre));
if(suf != x.end()) ans = min(ans, abs(*ity - *suf));
x.insert(*ity);
}
return ans;
}
void dfs1(int u, int fa)
{
ans[u] = inf;
if(u >= n - m + 1)
{
st[u].insert(A[u]);
return ;
}
for(int i = fir[u], v; i; i = edg[i].next)
{
v = edg[i].to;
if(v != fa)
{
dfs1(v, u);
ans[u] = min(ans[v], ans[u]);
ans[u] = min(ans[u], __merge__(st[u], st[v]));
}
}
}
int main()
{
scanf("%d %d", &n, &m);
for(register int i = 1; i <= n; ++ i) st[i].clear();
for(register int i = 2, f; i <= n; ++ i)
{
scanf("%d", &f);
addedge(f, i);
//addedge(i, f);
}
for(register int i = n - m + 1, x, d; i <= n; ++ i) scanf("%d", A + i);
RESET(ans);
RESET_(leaf, true);
dfs1(1, -1);
for(register int i = 1; i <= n - m; ++ i) printf("%d ", ans[i]);
puts("");
return 0;
}
/*
5 4
1 1 1 1
1 4 7 9
5 4
1 1 1 1
1 4 7 10
7 4
1 2 1 2 3 3
2 10 7 15
2 1
1
100
*/

BZOJ 2599

求一条简单路径,权值和等于 $K$ ,且边的数量最小。

题目分析

我们可以知道,结果一定是 (A)

1
2
3
/
/
/

或 (B)

1
2
3
/\
/ \
/ \

我们可以在遍历 DFS 的时候,维护了一个桶,以权值和为下标,记录着从这个点往下走的所有路径中,构造每种 权值和 最少需要的边数。

对于 (A),建 mp 后直接用 mp[m] 更新即可。

对于 (B),如何构建 mp

考虑节点 $u$,它的儿子们为 $v$,构建 mp
对于每个儿子的子树中的点,计算出其到 $u$ 的相对带权距离 $d_w$ 及深度差 $d$,并用 $d$ 与 $f[k−d_w]$ 的和来更新答案。 用刚才遍历到的每个 $d_w$ 对应的 $d$ 来更新 $f[d_w]$ 的值。

这样就相当于,对于每个儿子内的点到 $u$ 的链,用与之带权距离和恰为 $k$ 的,之前遍历过的另一条链与它一起,拼成一条路径并更新答案。然后,再把它们自己也加入桶中等待后面的子树被匹配。

对于每个点,继承其重儿子的桶以保证复杂度。

这样做能保证在暴力遍历轻儿子以更新的情况下,复杂度为 $O(n \log^2 n)$ 。

本题做法

记两个增量 offset[u]delta[u] ,分别表示当前节点相对于当前重链的链底的带权距离和距离。这能通过简单的回溯操作计算。

在全局开一个桶 mp

与前面例题相似,每到一个节点,优先递归其轻儿子,最后递归重儿子,每次结束一个节点的递归时,如果这个点是轻儿子,则清空,否则不做出操作 。
那么当回溯回当前节点时,mp 中存放着重儿子的信息。

先更新好两个增量。 于是在更新和查询时,下标值统一减去 offset[u]mp 内部的值则在查询时加上delta[u],更新时减去 delta[u]

可以发现这样即可实现 $O(1)$ 继承重儿子。

对于轻儿子,暴力更新即可。

考虑到每次查询和修改都是 $O(\log {n})$ ,于是这么做复杂度为$O(n \log^2 n)$ 。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
//#define NOSTDCPP
//#define Cpp11
#define Linux_System
#ifndef NOSTDCPP
#include <bits/stdc++.h>
#else
#include <algorithm>
#include <bitset>
#include <cassert>
#include <climits>
#include <complex>
#include <cstring>
#include <cstdio>
#include <deque>
#include <exception>
#include <functional>
#include <iomanip>
#include <iostream>
#include <istream>
#include <iterator>
#include <list>
#include <map>
#include <ostream>
#include <queue>
#include <set>
#include <sstream>
#include <stack>
#include <string>
#include <typeinfo>
#include <utility>
#include <valarray>
#include <vector>
#endif
# ifdef Linux_System
# define getchar getchar_unlocked
# define putchar putchar_unlocked
# endif
# define RESET(_) memset(_, 0, sizeof(_))
# define RESET_(_, val) memset(_, val, sizeof(_))
# define fi first
# define se second
# define pb push_back
# define midf(x, y) ((x + y) >> 1)
# define DXA(_) ((_ << 1))
# define DXB(_) ((_ << 1) | 1)
# define next __Chtholly__
# define x1 __Mercury__
# define y1 __bbtl04__
# define index __ikooo__
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef vector <int> vi;
typedef set <int> si;
typedef pair <int, int> pii;
typedef long double ld;
const int MOD = 1e9 + 7;
const int maxn = 500009;
const int maxm = maxn << 1;
const int inf = 1e9;
const double pi = acos(-1.0);
const double eps = 1e-6;
ll myrand(ll mod){return ((ll)rand() << 32 ^ (ll)rand() << 16 ^ rand()) % mod;}
template <class T>
inline bool scan_d(T & ret)
{
char c;
int sgn;
if(c = getchar(), c == EOF)return false;
while(c != '-' && (c < '0' || c > '9'))c = getchar();
sgn = (c == '-') ? -1 : 1;
ret = (c == '-') ? 0 : (c - '0');
while(c = getchar(), c >= '0' && c <= '9')
ret = ret * 10 + (c - '0');
ret *= sgn;
return true;
}
#ifdef Cpp11
template <class T, class ... Args>
inline bool scan_d(T & ret, Args & ... args)
{
scan_d(ret);
scan_d(args...);
}
#define cin.tie(0); cin.tie(nullptr);
#define cout.tie(0); cout.tie(nullptr);
#endif
inline bool scan_ch(char &ch)
{
if(ch = getchar(), ch == EOF)return false;
while(ch == ' ' || ch == '\n')ch = getchar();
return true;
}
template <class T>
inline void out_number(T x)
{
if(x < 0)
{
putchar('-');
out_number(- x);
return ;
}
if(x > 9)out_number(x / 10);
putchar(x % 10 + '0');
}
int n, m, A[maxn], hsh[maxn];
int ans, offset[maxn], delta[maxn];
//offset[]: 当前节点相对于当前重链的链底的带权距离
//delta []: 当前节点相对于当前重链的链底的距离
struct edge
{
int to, next, val;
}edg[maxm];
int fir[maxn], edg_cnt;
void addedge(int f, int t, int val)
{
edg[++ edg_cnt].to = t;
edg[edg_cnt].next = fir[f];
edg[edg_cnt].val = val;
fir[f] = edg_cnt;
}
int father[maxn], size[maxn], son[maxn], deep[maxn], fae[maxn];
int seg[maxn], id[maxn], ed[maxn], tot;
ll sm[maxn];
bool vis[maxn];
void dfs1(int u, int fa)
{
seg[id[u] = ++ tot] = u;
father[u] = fa;
size[u] = 1;
son[u] = 0;
for(int i = fir[u], v; i; i = edg[i].next)
{
v = edg[i].to;
if(v != fa)
{
deep[v] = deep[u] + 1;
fae[v] = edg[i].val;
sm[v] = sm[u] + fae[v];
dfs1(v, u);
size[u] += size[v];
if(! son[u] || size[son[u]] < size[v])
son[u] = v;
}
}
ed[u] = tot;
}
map <ll, int> mp;
void dfs2(int u, bool preferred_son)
{
if(! son[u])
{
offset[u] = delta[u] = 0;
return ;
}
//先遍历轻儿子
for(int i = fir[u], v; i; i = edg[i].next)
{
v = edg[i].to;
if(v != father[u] && v != son[u])
dfs2(v, false);
}
//再遍历重儿子
if(son[u]) dfs2(son[u], true), vis[son[u]] = true;
int &dlt = delta[u], &ofs = offset[u];
ofs = offset[son[u]] + fae[son[u]];
dlt = delta[son[u]] + 1;
if(mp.count(mp[fae[son[u]] - offset[u]]))
mp[fae[son[u]] - offset[u]] = 1 - dlt;
else mp[fae[son[u]] - offset[u]] = 1 - dlt;
register ll tmp;
for(register int i = fir[u], v; i; i = edg[i].next)
{
v = edg[i].to;
if(v != father[u] && v != son[u])
{
for(register int j = id[v]; j <= ed[v]; ++ j)
{
tmp = m - sm[seg[j]] + sm[u];
if(tmp > 0 && mp.count(tmp - ofs))
ans = min(ans, mp[tmp - ofs] + deep[seg[j]] - deep[u] + dlt);
}
for(register int j = id[v]; j <= ed[v]; ++ j)
{
tmp = sm[seg[j]] - sm[u];
if(tmp <= m)
{
if(mp.count(tmp - ofs))
mp[tmp - ofs] = min(mp[tmp - ofs], deep[seg[j]] - deep[u] - dlt);
else
mp[tmp - ofs] = deep[seg[j]] - deep[u] - dlt;
}
}
}
}
if(mp.count(m - ofs))
ans = min(ans, mp[m - ofs] + dlt);
//如果是轻儿子,要擦除结果
if(! preferred_son)
mp.clear();
}
int main()
{
edg_cnt = 0;
//scanf("%d %d", &n, &m);
scan_d(n);
scan_d(m);
for(register int i = 1, f, t, val; i < n; ++ i)
{
//scanf("%d %d %d", &f, &t, &val);
scan_d(f);
scan_d(t);
scan_d(val);
++ f, ++ t;
addedge(f, t, val);
addedge(t, f, val);
}
dfs1(1, -1);
ans = inf;
dfs2(1, false);
printf("%d\n", ans == inf ? -1 : ans);
return 0;
}
/*
4 3
0 1 1
1 2 2
1 3 4
*/
文章目录
  1. 1. DSU On Tree
    1. 1.1. 算法原理
    2. 1.2. 算法实现步骤
    3. 1.3. 应用
      1. 1.3.1. CodeForces 600E
      2. 1.3.2. CodeForces 570D
      3. 1.3.3. SGU 507
      4. 1.3.4. BZOJ 2599
        1. 1.3.4.1. 题目分析
        2. 1.3.4.2. 本题做法
{{ live2d() }}