SPOJ GSS7

树剖跑得比 LCT 快啊。。

SPOJ GSS7

题目描述

给你一棵树,需要支持以下操作:

  • 将树上一条链的权值修改为同一个值
  • 求树上一条链的权值连续最大和

题解

树链剖分

根据树剖的原理,只要两个点不在一条链上,我们需要向上跳,直到属于一条链为止。

线段树维护区间和最大的时候,对于每个结点,需要保存以下信息:

  • maxsum:区间内的最大连续子段和
  • lsum:从区间左端点开始的最大连续和(最大前缀和)
  • rsum:以区间右端点结尾的最大连续和(最大后缀和)
  • sum:区间内所有元素之和

这样区间信息就可以合并了。

对于每次查询 $Q(a, b)$,记 $a$ 为在树上相对位于左边的结果,$b$ 为在树上相对位于右边的结果。保存 $a, b$ 点向上跳的区间结点,然后在最后进行合并即可。

这里有两个问题:

如何保存区间结点?

用栈保存。为什么呢?看下面的图

图1

每一次是从下往上跳的,所以我们要将结果“先进后出”,最后通过合并得到一条链的答案。

如何合并两条链的结果?

将左边链的结果反向就行。

图2

根据上图,我们得到了两条链的结果。两条链都是从 LCA 往下合并出来的,所以它们的 lsum 都位于靠近 LCA 的一端(rsum 同理)。因此需要把 $a$ 一侧结果的 lsum 和 rsum 交换,使它变成从 $a$ 指向 LCA 的方向,然后再将两个结果合并即可。

总结一下:更新直接用树剖的更新方法就行。查询时需要从两个结点分别向上跳,每跳一次就把得到的结果 node 压入该结点对应的栈中,直到两点位于同一条链上。两点在同一条链上时,最后这一段也一定要保存,并且要压入较深的那个端点所在一侧的栈里,因为这一段是从较浅的端点(也就是 LCA)指向较深端点的。node 都保存完以后,把每个栈里的 node 依次出栈合并。这样得到两个 node,将 $a$ 一侧结果的 lsum 与 rsum 交换,再把这两个结果合并即可。

AC Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
//#define NOSTDCPP
//#define Cpp11
//#define Linux_System
#ifndef NOSTDCPP
#include <bits/stdc++.h>
#else
#include <algorithm>
#include <bitset>
#include <cassert>
#include <climits>
#include <complex>
#include <cstring>
#include <cstdio>
#include <deque>
#include <exception>
#include <functional>
#include <iomanip>
#include <iostream>
#include <istream>
#include <iterator>
#include <list>
#include <map>
#include <ostream>
#include <queue>
#include <set>
#include <sstream>
#include <stack>
#include <string>
#include <typeinfo>
#include <utility>
#include <valarray>
#include <vector>
#endif
# ifdef Linux_System
# define getchar getchar_unlocked
# define putchar putchar_unlocked
# endif
// Fill a whole array with zero / with a byte value (only valid on real
// arrays, since it relies on sizeof(_)).
# define RESET(_) memset(_, 0, sizeof(_))
# define RESET_(_, val) memset(_, val, sizeof(_))
# define fi first
# define se second
# define pb push_back
// Segment-tree index helpers: midpoint, left child, right child.
// Every argument is parenthesized so expressions such as midf(a | b, c) or
// DXA(a | b) expand with the intended precedence.
# define midf(x, y) (((x) + (y)) >> 1)
# define DXA(_) ((_) << 1)
# define DXB(_) (((_) << 1) | 1)
// Rename identifiers that may collide with library names once
// `using namespace std;` is in effect (e.g. std::next, POSIX y1/index).
# define next __Chtholly__
# define x1 __Mercury__
# define y1 __bbtl04__
# define index __ikooo__
using namespace std;
// Short type aliases used throughout the solution.
typedef long long ll;
typedef long double ld;
typedef std::vector<int> vi;
typedef std::set<int> si;
typedef std::pair<int, int> pii;

// Array bounds.
const int maxn = 100000 + 9;
const int maxm = 300000 + 9;

// Template constants; MOD, maxm, pi and eps are not referenced elsewhere in
// this program.
const int MOD = 1000000007;
const double pi = std::acos(-1.0);
const double eps = 1e-6;

// Segment-tree sentinel meaning "no pending assignment" in node::tag.
const ll TAG_MAX = 100000;
const ll inf = 0x3fffffffffffffffLL;
// Returns a pseudo-random value in [0, mod) assembled from three rand()
// calls. The bits are combined in unsigned arithmetic: shifting a rand()
// result (up to 2^31 - 1 with glibc) left by 32 overflows a signed long long,
// which is undefined and could also make the result negative.
// Precondition: mod > 0.
long long myrand(long long mod)
{
    const unsigned long long hi = static_cast<unsigned long long>(rand()) << 32;
    const unsigned long long mid = static_cast<unsigned long long>(rand()) << 16;
    const unsigned long long lo = static_cast<unsigned long long>(rand());
    return static_cast<long long>((hi ^ mid ^ lo) % static_cast<unsigned long long>(mod));
}
// Reads the next (optionally '-'-prefixed) decimal integer from stdin into
// ret, skipping any leading characters that are neither digits nor '-'.
// Returns false if input ends before a number starts.
// The character is held in an int so that EOF stays distinguishable from a
// real byte, and the skip loop stops at EOF. Previously, trailing whitespace
// at the end of the input made it spin forever.
template <class T>
inline bool scan_d(T & ret)
{
    int c = getchar();
    while(c != EOF && c != '-' && (c < '0' || c > '9')) c = getchar();
    if(c == EOF) return false;
    const bool negative = (c == '-');
    ret = negative ? 0 : (c - '0');
    while(c = getchar(), c >= '0' && c <= '9')
        ret = ret * 10 + (c - '0');
    if(negative) ret = -ret;
    return true;
}
#ifdef Cpp11
// Variadic convenience overload: scan_d(a, b, c) reads a, then b, then c
// with the single-argument scan_d. Returns true only if every read
// succeeded and stops at the first failure.
// For fast iostream I/O, call std::ios::sync_with_stdio(false) and
// std::cin.tie(nullptr) inside main() rather than via macros named cin/cout.
template <class T, class ... Args>
inline bool scan_d(T & ret, Args & ... args)
{
    return scan_d(ret) && scan_d(args...);
}
#endif
// Reads the next character from stdin that is not a space, tab, CR or LF
// into ch. Returns false, leaving ch unchanged, if the input ends first.
// The character is kept in an int so EOF is never mistaken for a real byte.
inline bool scan_ch(char &ch)
{
    int c = getchar();
    while(c == ' ' || c == '\n' || c == '\r' || c == '\t') c = getchar();
    if(c == EOF) return false;
    ch = static_cast<char>(c);
    return true;
}
// Writes x in decimal to stdout with no separator.
// The magnitude is computed in unsigned arithmetic, so LLONG_MIN prints
// correctly instead of overflowing on -x. Digits are emitted from a small
// buffer rather than through recursion.
inline void out_number(ll x)
{
    unsigned long long mag = static_cast<unsigned long long>(x);
    if(x < 0)
    {
        putchar('-');
        mag = 0ULL - mag;   // modular negation is well-defined for unsigned
    }
    char digits[20];        // 2^64 - 1 has 20 decimal digits
    int count = 0;
    do
    {
        digits[count++] = static_cast<char>('0' + mag % 10);
        mag /= 10;
    } while(mag != 0);
    while(count > 0) putchar(digits[--count]);
}
int n;  // number of vertices (ids 1..n)

// Adjacency list: head[u] is the first edge index leaving u, edge[i].next
// is the following one, and -1 ends the list.
struct Edge
{
    int to;
    int next;
} edge[maxn << 1];
int head[maxn];
int tot;

// Heavy-light decomposition data, filled by dfs1() and getpos():
// top = chain head, fa = parent, deep = depth, num = subtree size,
// p = segment-tree position, fp = inverse of p, son = heavy child.
int top[maxn];
int fa[maxn];
int deep[maxn];
int num[maxn];
int p[maxn];
int fp[maxn];
int son[maxn];
int pos;  // next unused segment-tree position

// Summary of one segment-tree range.
typedef struct
{
    ll lsum;    // best prefix sum
    ll rsum;    // best suffix sum
    ll sum;     // total sum
    ll maxsum;  // best contiguous sum
    ll tag;     // pending assignment value, TAG_MAX when none
} node;
node tree[maxn << 2];
ll A[maxn];  // initial vertex values, indexed by vertex id
void __tag(int p, ll val, int len)
{
tree[p].tag = val;
tree[p].lsum = tree[p].rsum = tree[p].maxsum = val * len;
tree[p].sum = val * len;
}
// Concatenates two adjacent segment summaries: a describes b followed by c.
void __max_node(node &a, const node &b, const node &c)
{
    a.sum = b.sum + c.sum;
    // Best run lies entirely in b, entirely in c, or crosses the boundary.
    const ll across = b.rsum + c.lsum;
    a.maxsum = max(across, max(b.maxsum, c.maxsum));
    // Best prefix stays inside b, or covers all of b and continues into c.
    a.lsum = max(b.lsum, b.sum + c.lsum);
    // Best suffix stays inside c, or covers all of c and reaches back into b.
    a.rsum = max(c.rsum, b.rsum + c.sum);
}
void pushup(int p, int l, int r)
{
if(l < r)
{
tree[p].sum = tree[DXA(p)].sum + tree[DXB(p)].sum;
tree[p].lsum = max(tree[DXA(p)].lsum, tree[DXA(p)].sum + tree[DXB(p)].lsum);
tree[p].rsum = max(tree[DXB(p)].rsum, tree[DXB(p)].sum + tree[DXA(p)].rsum);
tree[p].maxsum = max(max(tree[DXA(p)].maxsum, tree[DXB(p)].maxsum), tree[DXA(p)].rsum + tree[DXB(p)].lsum);
}
}
void pushdown(int p, int l, int r)
{
if(tree[p].tag != TAG_MAX)
{
int mid = midf(l, r);
__tag(DXA(p), tree[p].tag, mid - l + 1);
__tag(DXB(p), tree[p].tag, r - mid);
tree[p].tag = TAG_MAX;
}
}
void pre(int l, int r, int pp)
{
tree[pp].tag = TAG_MAX;
//tree[pp].lsum = tree[pp].rsum = tree[pp].maxsum = -inf;
tree[pp].sum = 0;
if(l == r)
{
tree[pp].lsum = tree[pp].rsum = tree[pp].maxsum = tree[pp].sum = A[fp[l]];
return ;
}
int mid = midf(l, r);
pre(l, mid, DXA(pp));
pre(mid + 1, r, DXB(pp));
pushup(pp, l, r);
}
// Target range [l, r] (segment-tree positions) read by query() and update().
// main() also reads vertex ids into these before calling the path routines.
int l;
int r;
// Value assigned by update().
ll val;
// Returns the summary of the global range [l, r] intersected with node p's
// range [nl, nr]; callers only descend into ranges that overlap [l, r].
node query(int nl, int nr, int p)
{
    if(l <= nl && nr <= r) return tree[p];
    const int mid = midf(nl, nr);
    pushdown(p, nl, nr);
    const bool needLeft = l <= mid;
    const bool needRight = mid < r;
    if(!needLeft) return query(mid + 1, nr, DXB(p));
    if(!needRight) return query(nl, mid, DXA(p));
    const node rightPart = query(mid + 1, nr, DXB(p));
    const node leftPart = query(nl, mid, DXA(p));
    node merged;
    __max_node(merged, leftPart, rightPart);
    return merged;
}
// Assigns the global `val` to every position of [l, r] inside node p's range
// [nl, nr], using lazy tags for fully covered nodes.
void update(int nl, int nr, int p)
{
    if(!(l <= nl && nr <= r))
    {
        pushdown(p, nl, nr);
        const int mid = midf(nl, nr);
        if(l <= mid) update(nl, mid, DXA(p));
        if(r > mid) update(mid + 1, nr, DXB(p));
        pushup(p, nl, nr);
        return;
    }
    __tag(p, val, nr - nl + 1);
}
// Resets the adjacency list and the HLD arrays before a new tree is read.
void init()
{
    RESET_(head, -1);   // -1 marks an empty edge list
    RESET_(son, -1);    // -1 marks "no heavy child yet"
    RESET_(deep, -1);
    tot = 0;
    pos = 1;            // segment-tree positions are 1-based
}
void addedge(int u,int v)
{
edge[tot].to = v;edge[tot].next = head[u];head[u] = tot++;
}
void dfs1(int u,int pre,int d)
{
deep[u] = d;
fa[u] = pre;
num[u] = 1;
for(int i = head[u];i != -1; i = edge[i].next)
{
int v = edge[i].to;
if(v != pre)
{
dfs1(v,u,d+1);
num[u] += num[v];
if(son[u] == -1 || num[v] > num[son[u]])
son[u] = v;
}
}
}
void getpos(int u, int sp)
{
top[u] = sp;
if(son[u] != -1)
{
p[u] = pos++;
fp[p[u]] = u;
getpos(son[u],sp);
}
else
{
p[u] = pos++;
fp[p[u]] = u;
return;
}
for(int i = head[u] ; ~i; i = edge[i].next)
{
int v = edge[i].to;
if(v != son[u] && v != fa[u])
getpos(v, v);
}
}
void __update__(int u, int v)
{
int f1 = top[u], f2 = top[v];
while(f1 != f2)
{
if(deep[f1] < deep[f2])
{
swap(f1, f2);
swap(u, v);
}
l = p[f1], r = p[u];
update(1, n, 1);
u = fa[f1]; f1 = top[u];
}
if(deep[u] > deep[v]) swap(u, v);
l = p[u], r = p[v];
//cout << ">>" << l << ' ' << r << endl;
update(1, n, 1);
}
// Scratch state for __query__: segment summaries collected on each side of
// the path, plus temporaries used while folding them together.
stack <node> dq1;
stack <node> dq2;
node ans1;
node ttmp;
node ans2;
node ans3;
ll __query__(int u, int v)
{
int f1 = top[u], f2 = top[v];
bool rev = false;
while(! dq1.empty()) dq1.pop();
while(! dq2.empty()) dq2.pop();
ll ans = 0;
while(f1 != f2)
{
if(deep[f1] < deep[f2])
{
swap(f1, f2);
swap(u, v);
rev ^= true;
}
l = p[f1], r = p[u];
if(! rev) dq1.push(query(1, n, 1));
else dq2.push(query(1, n, 1));
u = fa[f1]; f1 = top[u];
}
if(deep[u] > deep[v]) swap(u, v), rev ^= true;
l = p[u], r = p[v];
if(rev) dq1.push(query(1, n, 1));
else dq2.push(query(1, n, 1));
//cout << l << ' ' << r << endl;
bool u1 = false, u2 = false;
//ans1.maxsum = ans1.lsum = ans1.rsum = -inf;
//ans1.sum = 0;
//ans3 = ans1;
//cout << "size = " << dq1.size() << ' ' << dq2.size() << endl;
while(! dq1.empty())
{
ttmp = dq1.top();
dq1.pop();
//cout << "Lsum = " << ttmp.lsum << ", Rsum = " << ttmp.rsum << ", Maxsum = " << ttmp.maxsum << ", Sum = " << ttmp.sum << endl;
ans2 = ans1;
if(u1) __max_node(ans1, ans2, ttmp);
else u1 = true, ans1 = ttmp;
}
swap(ans1.lsum, ans1.rsum);
//cout << "ans1: Lsum = " << ans1.lsum << ", Rsum = " << ans1.rsum << ", MaxSum = " << ans1.maxsum << ", Sum = " << ans1.sum << endl;
//assert(u1);
while(! dq2.empty())
{
ttmp = dq2.top();
dq2.pop();
ans2 = ans3;
//cout << "Lsum = " << ttmp.lsum << ", Rsum = " << ttmp.rsum << ", Maxsum = " << ttmp.maxsum << ", Sum = " << ttmp.sum << endl;
if(u2) __max_node(ans3, ans2, ttmp);
else u2 = true, ans3 = ttmp;
}
//cout << "ans3: Lsum = " << ans3.lsum << ", Rsum = " << ans3.rsum << ", MaxSum = " << ans3.maxsum << ", Sum = " << ans3.sum << endl;
//assert(u2);
if(u1 && u2)
{
//cout << "ans2: Lsum = " << ans2.lsum << ", Rsum = " << ans2.rsum << ", MaxSum = " << ans2.maxsum << ", Sum = " << ans2.sum << endl;
//cout << ans1.lsum << ' ' << ans1.rsum + ans3.lsum << ' ' << ans1.maxsum << ' ' << ans3.maxsum << endl;
__max_node(ans2, ans1, ans3);
//cout << "ans2: Lsum = " << ans2.lsum << ", Rsum = " << ans2.rsum << ", MaxSum = " << ans2.maxsum << ", Sum = " << ans2.sum << endl;
return ans2.maxsum;
}
if(u1) return ans1.maxsum;
if(u2) return ans3.maxsum;
assert(false);
}
// Reads the tree and its vertex values, builds the HLD + segment tree, then
// answers the operations:
// type 1 prints max(best path sum, 0); type 2 assigns a value along a path.
int main()
{
    int f, t, q, type;
    scanf("%d", &n);
    for(int i = 1; i <= n; ++ i) scanf("%lld", A + i);
    init();
    for(int e = 1; e < n; ++ e)
    {
        scanf("%d %d", &f, &t);
        addedge(f, t);
        addedge(t, f);
    }
    dfs1(1, -1, 0);
    getpos(1, 1);
    pre(1, n, 1);
    scanf("%d", &q);
    while(q --)
    {
        scanf("%d %d %d", &type, &l, &r);
        if(deep[l] > deep[r]) swap(l, r);
        if(type == 1)
        {
            printf("%lld\n", max(__query__(l, r), 0ll));
        }
        else if(type == 2)
        {
            scanf("%lld", &val);
            __update__(l, r);
        }
        else
        {
            assert(type == 1 || type == 2);
        }
    }
    return 0;
}
/*
12
2354 8995 6660 9648 8119 7591 4462 1208 478 7230 815 7824
1 2
1 3
2 4
4 5
5 6
2 7
7 8
2 9
8 10
8 11
7 12
1
1 3 11
5
-1102 -899 7058 -8459 -9769
1 2
2 3
2 4
4 5
7
1 2 4
1 3 1
2 4 2 600
1 5 4
1 3 1
1 4 1
1 4 4
20
8633 3974 -1389 2846 4052 -8806 -8266 -3153 -7191 4863 8921 -8321 -6339 6107 -7701 6753 -4060 -155 9036 9418
1 2
1 3
3 4
4 5
2 6
4 7
4 8
1 9
9 10
10 11
8 12
7 13
5 14
4 15
5 16
13 17
11 18
15 19
18 20
20
1 5 3
2 8 15 3674
2 11 19 2522
2 3 18 1755
2 17 16 809
2 8 3 6923
2 12 5 1236
2 5 4 858
1 19 7
2 19 10 8903
1 6 8
2 4 6 8006
1 8 15
1 19 13
1 12 11
2 2 2 4294
1 11 20
1 14 7
2 5 16 6346
2 6 10 9009
15
-1588 -8841 1640 -4769 -4831 6793 6309 7604 8338 -8194 -7844 772 1262 -3374 2992
1 2
1 3
3 4
4 5
3 6
3 7
3 8
8 9
9 10
4 11
11 12
4 13
1 14
5 15
1
1 11 10
*/

LCT

对于 LCT 而言,对于每个结点,可以直接表示树上的信息。

合并有点麻烦:因为答案要求不小于 0(可以一个点都不选),所以结点上的 maxsum、lsum、rsum 都按“允许取空段”来维护,小于 0 的结果直接当作 0,详见代码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 100000+20;
const ll inf = 1e18;   // "no pending assignment" marker and initial -inf
// Since C++17, <bits/stdc++.h> declares std::size, and `using namespace std;`
// makes it visible, so every unqualified use of a global named `size` is
// ambiguous and fails to compile. Rename it at the preprocessor level so the
// rest of the program can keep writing `size`.
#define size lct_size_
int n,m,top;       // vertex count, operation count, splay() stack pointer
int fa[maxn],c[maxn][2], size[maxn];   // parent/path-parent, children, splay subtree size
ll maxsum[maxn], lsum[maxn], rsum[maxn], sum[maxn], val[maxn], lazy[maxn];
int q[maxn];       // scratch stack used by splay()
bool rev[maxn];    // pending subtree reversal
int id;
int tot, head[maxn];
// Clears links and flags, and gives vertices 1..n neutral summaries.
// Slots 0 and n + 1 get size 0; slot 0 serves as the null child.
void init(int n){
    memset(head, -1, sizeof(head));
    memset(rev, 0, sizeof(rev));
    memset(c, 0, sizeof(c));
    memset(fa, 0, sizeof(fa));
    tot = 0;
    for(int i = n + 1; i >= 0; -- i)
    {
        maxsum[i] = -inf;
        lsum[i] = -inf;
        rsum[i] = -inf;
        lazy[i] = inf;
        size[i] = (i == 0 || i == n + 1) ? 0 : 1;
    }
}
// True when x is the root of its splay tree, i.e. x is not a child of fa[x]
// (fa[x] is then a path-parent link, or 0).
bool isroot(int x)
{
    const int parent = fa[x];
    return !(c[parent][0] == x || c[parent][1] == x);
}
void update(int x)
{
maxsum[0] = lsum[0] = rsum[0] = 0;
sum[0] = 0;
size[0] = 0;
int l = c[x][0], r = c[x][1];
size[x] = size[l] + size[r] + 1;
sum[x] = sum[l] + sum[r] + val[x];
maxsum[x] = max(max(maxsum[l], maxsum[r]), lsum[r] + val[x] + rsum[l]);
lsum[x] = max(lsum[l], val[x] + sum[l] + lsum[r]);
rsum[x] = max(rsum[r], rsum[l] + sum[r] + val[x]);
}
void __tag(int p, ll v)
{
lazy[p] = val[p] = v;
maxsum[p] = lsum[p] = rsum[p] = v > 0 ? v * size[p] : 0ll;
sum[p] = v * size[p];
}
// Reverses p's subtree in place: swaps its children and prefix/suffix sums,
// and toggles the pending-reversal flag for its descendants.
void __rev(int p)
{
    swap(lsum[p], rsum[p]);
    swap(c[p][0], c[p][1]);
    rev[p] = !rev[p];
}
void pushdown(int x)
{
int l=c[x][0],r=c[x][1];
if(rev[x])
{
rev[x]^=1;
__rev(l);
__rev(r);
}
if(lazy[x] != inf)
{
__tag(l, lazy[x]);
__tag(r, lazy[x]);
lazy[x] = inf;
}
}
// Single splay rotation: lifts x above its parent y, preserving in-order
// (path) order. If y was a splay root, x takes over y's parent pointer z.
void rotate(int x)
{
// l = side of y that x hangs on, r = the opposite side.
int y=fa[x],z=fa[y],l,r;
l=(c[y][1]==x);r=l^1;
// Relink z's child slot only when y is a real splay child of z; otherwise
// z is a path-parent that holds no child pointer to y.
if(!isroot(y))c[z][c[z][1]==y]=x;
// x's inner subtree moves under y, and y becomes x's child on side r.
fa[c[x][r]]=y;fa[y]=x;fa[x]=z;
c[y][l]=c[x][r];c[x][r]=y;
// y is now below x, so it must be recomputed first.
update(y);update(x);
}
// Rotates x to the root of its splay tree.
void splay(int x)
{
// Collect x and its splay ancestors on the q[] stack, then push pending
// tags down from the top so every node touched by rotations is up to date.
top = 0;
q[++top]=x;
for(int i=x;!isroot(i);i=fa[i])
q[++top]=fa[i];
while(top)pushdown(q[top--]);
while(!isroot(x))
{
int y=fa[x],z=fa[y];
if(!isroot(y))
{
// Parses as (c[y][0]==x) ^ (c[z][0]==y): zig-zag rotates x twice,
// zig-zig rotates y first.
if(c[y][0]==x^c[z][0]==y)rotate(x);
else rotate(y);
}
rotate(x);
}
}
// Makes the path from the tree root to x a single preferred path: each splay
// tree on the way up is splayed, and its right child is replaced by the
// previously built piece.
void access(int x)
{
    int below = 0;
    while(x)
    {
        splay(x);
        c[x][1] = below;
        update(x);
        below = x;
        x = fa[x];
    }
}
void makeroot(int x)
{
access(x);splay(x);rev[x]^=1;swap(c[x][0], c[x][1]);
}
// Adds the edge (x, y): x becomes the root of its tree and then hangs off y
// through a path-parent pointer. Presumably x and y are in different trees
// (main() links the input tree edges); confirm for other callers.
void link(int x,int y)
{
    makeroot(x);
    fa[x] = y;
}
// Isolates the path x..y: afterwards y is the root of a splay tree that
// holds exactly that path, so y's summary describes it.
void split(int x,int y)
{
    makeroot(x);
    access(y);
    splay(y);
}
// Largest contiguous sum on the path x..y, never below 0 (empty choice).
ll ans(int x, int y)
{
    split(x, y);
    const ll best = maxsum[y];
    return best < 0 ? 0ll : best;
}
// Assigns `val` to every vertex on the path x..y by tagging the splay root
// that represents the isolated path.
void __update(int x, int y, ll val)
{
    split(x, y);
    __tag(y, val);
}
int main()
{
int f, t, type;
ll v;
scanf("%d", &n);
init(n);
for(int i = 1; i <= n; ++ i)
{
scanf("%lld", sum + i);
maxsum[i] = lsum[i] = rsum[i] = sum[i] > 0 ? sum[i] : 0;
val[i] = sum[i];
}
for(int i = 1; i < n; ++ i)
{
scanf("%d %d", &f, &t);
link(f, t);
}
scanf("%d", &m);
while(m --)
{
scanf("%d %d %d", &type, &f, &t);
switch(type)
{
case 1:
printf("%lld\n", ans(f, t));
break;
case 2:
scanf("%lld", &v);
__update(f, t, v);
//for(int i = 1; i <= n; ++ i) cout << val[i] << ' ';cout << endl;
break;
default:
assert(false);
}
}
return 0;
}
/*
5
-3 -2 1 2 3
1 2
2 3
1 4
4 5
3
1 2 5
2 3 4 2
1 2 5
5
-3968 -165 7588 -173 -75
1 2
2 3
1 4
2 5
7
1 1 5
1 5 2
2 4 2 1938
1 2 4
1 2 3
1 3 4
1 5 1
6
3255 8180 -5067 3612 9586 -1412
1 2
1 3
3 4
4 5
2 6
6
2 4 3 7869
2 5 6 -4673
2 6 1 -8963
2 1 6 8043
1 3 3
1 4 5
9
-3606 -9163 362 -9043 -3313 923 3159 5766 740
1 2
1 3
3 4
3 5
2 6
1 7
5 8
1 9
3
2 6 7 588
1 1 8
1 6 1
*/
文章目录
  1. 1. SPOJ GSS7
    1. 1.1. 题目描述
    2. 1.2. 题解
      1. 1.2.1. 树链剖分
        1. 1.2.1.1. 如何保存区间结点?
        2. 1.2.1.2. 如何合并两条链的结果?
        3. 1.2.1.3. AC Code
      2. 1.2.2. LCT
{{ live2d() }}