主席树初探

主席树初探

主席树初探

概述

很多问题如果用线段树处理的话需要采用离线思想,若用主席树则可直接在线处理。故很多时候离线线段树求解可以转化为在线主席树求解。注意,主席树本质上就是线段树,区别在于它实现了可持久化:后一时刻的版本可以复用前一时刻的状态,二者的公共部分很多。一棵线段树的节点维护的是当前节点对应区间的信息,倘若每次询问的区间都不一样,就会给处理带来一些困难。有时可以直接细分区间然后合并,此种情况线段树可以直接搞定;但有时无法通过直接划分区间来求解,如频繁询问区间第 $k$ 小元素,这时主席树就出场了。

主席树的结构

主席树的每个节点对应一棵线段树,此处有点抽象。

在我们的印象中,线段树的每个节点维护的是左右子树的下标以及当前节点对应区间的信息(信息视具体问题而定)。对于一个待处理的序列 $a[1], a[2], …, a[n]$,共有 $n$ 个前缀。每个前缀可以看做一棵线段树,共有 $n$ 棵线段树;若不采用可持久化结构,带来的严重后果就是会 MLE。

根据可持久化数据结构的定义,由于相邻线段树(即相邻前缀)的公共部分很多,可以充分利用,达到优化目的;同时每棵线段树在逻辑上仍然保留所有的叶节点,只是与之前的版本共用了很多节点。主席树很重要的操作就是如何寻找可以共用的节点,这些节点可能出现在靠近根的位置,也可能出现在叶节点。

对原来的数列$[1, 2, …, n]$的每一个前缀$[1, 2, …, i]$$(1 \leq i \leq n)$建立一棵线段树,线段树的每一个节点存某个前缀$[1, 2, …, i]$中属于区间$[L, L + 1, …, R]$的数一共有多少个

比如根节点是$[1, 2, …, n]$,一共$i$个数,sum[root] = i;根节点的左儿子是$[1, 2, …, \frac {L+R} {2}]$,若不大于$\frac {L+R} {2}$的数有$x$个,那么sum[root.left] = x

若要查找 $[i, i + 1, …, j]$ 中第 $k$ 小的数,设某结点 $x$,那么 x.sum[j] - x.sum[i - 1] 就是 $[i, i + 1, …, j]$ 中落在结点 $x$ 对应值域内的数字总数。由于前缀 $[1, 2, …, i]$ 和 $[1, 2, …, i-1]$ 对应的两棵树只有一条从根到叶的路径不同,其他结点直接沿用前一棵树的结点即可,时空复杂度为 $O(n \log n)$。也就是说,更新一个值时只更新了一条链,每次只需要新增 $\log n$ 个节点就可以保留原来的状态,其余节点都相同,可以共用。

由于主席树的每个版本都是一棵线段树,它们维护的区间相同、结构相同,只是保存的信息不同,因此具有可加减性。所以在求区间的时候,若要处理区间 $[l, r]$,只需要处理 rt[r] - rt[l-1] 就可以了。

rt[l-1]处理的是$[1,l-1]$的数,rt[r]处理的是$[1,r]$的数,相减即为$[l, r]$这个区间里面的数。

主席树的功能实现

建树

1
2
3
4
5
6
7
8
9
10
// Persistent point update: allocates a fresh copy of every node on the
// root-to-leaf path that covers index `place`, adding `val` to each copy's sum.
//   nl, nr : inclusive index range covered by the current node
//   p      : receives the index of the new node for this range
//   pre    : node of the previous version covering the same range
//   val    : amount added to `sum` along the path
//   place  : leaf index being updated
void build(int nl, int nr, int &p, int pre, int val, int place)
{
    int *slot = &p;   // where the index of the next fresh node is stored
    while (true)
    {
        const int cur = ++ tot;
        *slot = cur;
        tree[cur] = tree[pre];   // inherit both child links and the old sum
        tree[cur].sum += val;
        if (nl == nr) break;
        const int mid = midf(nl, nr);
        if (place <= mid)
        {
            slot = &tree[cur].l;
            pre = tree[pre].l;
            nr = mid;
        }
        else
        {
            slot = &tree[cur].r;
            pre = tree[pre].r;
            nl = mid + 1;
        }
    }
}

在两棵树上查询差值

1
2
3
4
5
6
7
int query(int l, int r, int nl, int nr, int place)
{
if(nl == nr) return nl;
int sum = tree[tree[r].l].sum - tree[tree[l].l].sum, mid = midf(nl, nr);
if(sum >= place) return query(tree[l].l, tree[r].l, nl, mid, place);
return query(tree[l].r, tree[r].r, mid + 1, nr, place - sum);
}

在多棵树上查询差值

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
// Excerpt: k-th smallest over the sum of several (ansr[i] - ansl[i]) tree
// differences. `ppp`, `k`, `siz`, `ansl`, `ansr` come from the enclosing
// function (not shown here); presumably ansl[i]/ansr[i] are the two roots
// bounding the i-th range - confirm against the caller.
int sl = 0, sr = siz - 1;
while(sl < sr)
{
// sum = number of elements, over all ranges, that lie in the left half [sl, mid]
int sum = 0, mid = midf(sl, sr);
for(int i = 0; i < ppp; ++ i)
sum += tree[tree[ansr[i]].l].sum - tree[tree[ansl[i]].l].sum;
if(sum < k)
{
// Not enough elements on the left: skip them and move every cursor right.
k -= sum;
sl = mid + 1;
for(int i = 0; i < ppp; ++ i)
{
ansr[i] = tree[ansr[i]].r;
ansl[i] = tree[ansl[i]].r;
}
}
else
{
// The answer is in the left half: move every cursor to its left child.
sr = mid;
for(int i = 0; i < ppp; ++ i)
{
ansr[i] = tree[ansr[i]].l;
ansl[i] = tree[ansl[i]].l;
}
}
}

主席树模板

POJ 2104

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#define NOSTDCPP
#ifndef NOSTDCPP
#include <bits/stdc++.h>
#else
#include <algorithm>
#include <bitset>
#include <complex>
#include <cstring>
#include <deque>
#include <exception>
#include <fstream>
#include <functional>
#include <iomanip>
#include <iostream>
#include <istream>
#include <iterator>
#include <list>
#include <map>
#include <ostream>
#include <queue>
#include <set>
#include <sstream>
#include <stack>
#include <string>
#include <typeinfo>
#include <utility>
#include <valarray>
#include <vector>
#endif
// Helper macros: fill an array with zero / with a byte value, integer
// midpoint of two indices, and short names for pair members.
# define RESET(_) memset(_, 0, sizeof(_))
# define RESET_(_, val) memset(_, val, sizeof(_))
# define midf(x, y) (((x) + (y)) >> 1)
# define fi first
# define se second
using namespace std;
typedef long long ll;
typedef vector <int> vi;
typedef set <int> si;
const int MOD = 1e9 + 7;
// maxn sizes the per-element arrays; maxm (40 * maxn) sizes the node pool.
const int maxn = 100009;
const int maxm = maxn * 40;
const double pi = acos(-1.0);
const double eps = 1e-6;
// Reads one decimal integer (optionally preceded by '-') from stdin into
// `ret`, skipping any leading characters that are neither a digit nor '-'.
// Returns false, leaving `ret` untouched, when the input ends before a number
// starts; returns true otherwise.
template <class T>
inline bool scan_d(T & ret)
{
    // getchar() returns int: keeping it in an int is the only reliable way to
    // tell EOF apart from a data byte such as 0xFF.
    int c = getchar();
    // Stop skipping at EOF as well, otherwise trailing whitespace loops forever.
    while (c != EOF && c != '-' && (c < '0' || c > '9')) c = getchar();
    if (c == EOF) return false;
    const bool negative = (c == '-');
    ret = negative ? 0 : (c - '0');
    while (c = getchar(), c >= '0' && c <= '9')
        ret = ret * 10 + (c - '0');
    if (negative) ret = -ret;
    return true;
}
// Writes the decimal representation of `x` to stdout with putchar.
// Digits are extracted without negating `x`, so the most negative value
// (whose negation does not fit in ll) is printed correctly too.
inline void out(ll x)
{
    char digits[24];
    int len = 0;
    const bool negative = x < 0;
    do
    {
        const int d = static_cast<int>(x % 10);   // in [-9, 0] when x < 0
        digits[len ++] = static_cast<char>('0' + (negative ? -d : d));
        x /= 10;
    } while (x != 0);
    if (negative) putchar('-');
    while (len > 0) putchar(digits[-- len]);
}
// Sequence length and the sequence itself (filled for indices 1..n).
int n;
int A[maxn];
// One node of the persistent segment tree: indices of both children in the
// pool and the number of elements stored below this node.
struct node
{
    int l;
    int r;
    int sum;
};
node tree[maxm];   // node pool; node 0 stays all-zero and acts as the empty tree
int root[maxn];    // root[i]: root of the version built from the first i elements
int tot;           // index of the most recently allocated node
vi dis;            // sorted distinct values, used for coordinate compression
int id(int x){return lower_bound(dis.begin(), dis.end(), x) - dis.begin();}
void build(int nl, int nr, int &p, int pre, int val, int place)
{
p = ++ tot;
tree[p] = tree[pre];
tree[p].sum += val;
if(nl == nr) return ;
int mid = midf(nl, nr);
if(mid >= place) build(nl, mid, tree[p].l, tree[pre].l, val, place);
else build(mid + 1, nr, tree[p].r, tree[pre].r, val, place);
}
int query(int l, int r, int nl, int nr, int place)
{
if(nl == nr) return nl;
int sum = tree[tree[r].l].sum - tree[tree[l].l].sum, mid = midf(nl, nr);
if(sum >= place) return query(tree[l].l, tree[r].l, nl, mid, place);
return query(tree[l].r, tree[r].r, mid + 1, nr, place - sum);
}
// Reads n values and q queries (a, b, c) and prints, for each query, the
// c-th smallest value among A[a..b].
int main()
{
    int q;
    tot = 1;
    scanf("%d %d", &n, &q);
    for (int i = 1; i <= n; ++ i)
    {
        scanf("%d", &A[i]);
        dis.push_back(A[i]);
    }
    // Coordinate compression: ranks in `dis` become the tree's leaf indices.
    sort(dis.begin(), dis.end());
    dis.erase(unique(dis.begin(), dis.end()), dis.end());
    int siz = dis.size();
    // Version i contains the first i elements.
    for (int i = 1; i <= n; ++ i)
    {
        const int rank = id(A[i]);
        build(0, siz - 1, root[i], root[i - 1], 1, rank);
    }
    while (q --)
    {
        int a, b, c;
        scanf("%d %d %d", &a, &b, &c);
        const int leaf = query(root[a - 1], root[b], 0, siz - 1, c);
        printf("%d\n", dis[leaf]);
    }
    return 0;
}

SPOJ COT

树链剖分 + 主席树。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
//#define NOSTDCPP
//#define Cpp11
//#define Linux_System
#ifndef NOSTDCPP
#include <bits/stdc++.h>
#else
#include <algorithm>
#include <bitset>
#include <cassert>
#include <climits>
#include <complex>
#include <cstring>
#include <cstdio>
#include <deque>
#include <exception>
#include <functional>
#include <iomanip>
#include <iostream>
#include <istream>
#include <iterator>
#include <list>
#include <map>
#include <ostream>
#include <queue>
#include <set>
#include <sstream>
#include <stack>
#include <string>
#include <typeinfo>
#include <utility>
#include <valarray>
#include <vector>
#endif
# ifdef Linux_System
# define getchar getchar_unlocked
# define putchar putchar_unlocked
# endif
// Helper macros shared by the program below.
# define RESET(_) memset(_, 0, sizeof(_))
# define RESET_(_, val) memset(_, val, sizeof(_))
# define fi first
# define se second
# define pb push_back
// Every argument is parenthesised so that expressions such as `a | b` or
// `c ? x : y` keep their meaning after expansion.
# define midf(x, y) (((x) + (y)) >> 1)
# define DXA(_) (((_) << 1))
# define DXB(_) (((_) << 1) | 1)
// Renames identifiers that may clash with names from the standard headers.
# define next __Chtholly__
# define x1 __Mercury__
# define y1 __bbtl04__
# define index __ikooo__
using namespace std;
// Short aliases used throughout the program.
typedef long long ll;
typedef vector <int> vi;
typedef set <int> si;
typedef pair <int, int> pii;
typedef long double ld;
const int MOD = 1e9 + 7;
// maxn sizes the per-vertex arrays and (times 2) the edge array.
const int maxn = 100009;
const int maxm = 300009;
const double pi = acos(-1.0);
const double eps = 1e-6;
// Returns a pseudo-random value reduced modulo `mod`, built by XOR-ing three
// rand() results shifted left by 32, 16 and 0 bits. `mod` must be non-zero.
ll myrand(ll mod)
{
    // The three draws are made in separate statements: inside one expression
    // the order in which the rand() calls run is unspecified, which makes the
    // result differ between compilers for the same seed.
    const ll high = (ll)rand() << 32;
    const ll middle = (ll)rand() << 16;
    const ll low = rand();
    return (high ^ middle ^ low) % mod;
}
// Fast integer reader: stores the next decimal integer (with optional '-')
// found on stdin in `ret`. Characters before the number that are neither a
// digit nor '-' are discarded.
// Returns true if a number was started, false if the input ran out first
// (in which case `ret` is not modified).
template <class T>
inline bool scan_d(T & ret)
{
    int c = getchar();   // int, not char: EOF must stay distinguishable
    for (;;)
    {
        if (c == EOF) return false;   // also covers EOF after trailing blanks
        if (c == '-' || (c >= '0' && c <= '9')) break;
        c = getchar();
    }
    const int sgn = (c == '-') ? -1 : 1;
    ret = (c == '-') ? 0 : (c - '0');
    for (c = getchar(); c >= '0' && c <= '9'; c = getchar())
        ret = ret * 10 + (c - '0');
    ret *= sgn;
    return true;
}
#ifdef Cpp11
// Variadic convenience overload: reads several integers in order.
// Returns true only if every value was read; reading stops at the first
// failure. (The original fell off the end of a bool function without a
// return statement.)
template <class T, class ... Args>
inline bool scan_d(T & ret, Args & ... args)
{
    return scan_d(ret) && scan_d(args...);
}
// NOTE(review): the two `#define cin.tie(0); ...` / `#define cout.tie(0); ...`
// lines that used to follow were dropped: the preprocessor reads them as
// macros named `cin` and `cout`, which breaks any later use of those streams.
#endif
// Reads the next character from stdin that is neither a space nor a newline
// and stores it in `ch`.
// Returns false when the input ends before such a character is found; the
// original only detected EOF on the very first read and reported success
// when EOF was reached after skipping blanks.
inline bool scan_ch(char &ch)
{
    int c = getchar();   // int so that EOF is not confused with a data byte
    while (c == ' ' || c == '\n') c = getchar();
    ch = static_cast<char>(c);
    return c != EOF;
}
// Prints `x` in decimal on stdout using putchar.
// The sign is handled digit by digit instead of negating `x`, because the
// negation of the smallest ll value is not representable.
inline void out_number(ll x)
{
    char buf[24];
    int used = 0;
    const bool negative = x < 0;
    do
    {
        int digit = static_cast<int>(x % 10);   // non-positive for negative x
        if (negative) digit = -digit;
        buf[used ++] = static_cast<char>('0' + digit);
        x /= 10;
    } while (x != 0);
    if (negative) putchar('-');
    while (used > 0) putchar(buf[-- used]);
}
// Number of vertices and number of queries.
int n, m;
// Vertex weights, indexed 1..n.
int A[maxn];
// Adjacency lists stored as linked edges: `next` is the index of the next
// edge leaving the same vertex, -1 terminates the list.
struct Edge
{
    int to;
    int next;
};
Edge edge[maxn << 1];
int head[maxn];   // first edge of each vertex
int tot;          // number of edges stored so far
// Heavy-light decomposition data.
int top[maxn];    // top vertex of the chain containing the vertex
int fa[maxn];     // parent
int deep[maxn];   // depth
int num[maxn];    // subtree size
int p[maxn];      // position of the vertex in the chain order
int fp[maxn];     // inverse of p: vertex at a given position
int son[maxn];    // heavy child, -1 if none
int pos;          // next free position
// Resets the graph and the decomposition state: no edges, every adjacency
// list and heavy child set to -1, positions numbered from 1.
void init()
{
    pos = 1;
    tot = 0;
    memset(head, -1, sizeof(head));
    memset(son, -1, sizeof(son));
}
void addedge(int u,int v)
{
edge[tot].to = v;edge[tot].next = head[u];head[u] = tot++;
}
// First pass of the decomposition: records parent, depth and subtree size of
// every vertex below `u`, and picks the child with the largest subtree as
// the heavy child (the first one found wins ties).
void dfs1(int u,int pre,int d)
{
    fa[u] = pre;
    deep[u] = d;
    num[u] = 1;
    int e = head[u];
    while (e != -1)
    {
        const int child = edge[e].to;
        e = edge[e].next;
        if (child == pre) continue;   // do not walk back to the parent
        dfs1(child, u, d + 1);
        num[u] += num[child];
        if (son[u] == -1 || num[son[u]] < num[child])
            son[u] = child;
    }
}
// Second pass: assigns consecutive positions so that each heavy chain is
// contiguous. `sp` is the top vertex of the chain that `u` belongs to.
void getpos(int u,int sp)
{
    top[u] = sp;
    p[u] = pos;
    fp[pos] = u;
    ++ pos;
    const int heavy = son[u];
    if (heavy == -1) return;   // leaf: nothing below
    // The heavy child continues the current chain and is numbered first.
    getpos(heavy, sp);
    // Every other child starts a chain of its own.
    for (int e = head[u]; e != -1; e = edge[e].next)
    {
        const int v = edge[e].to;
        if (v == heavy || v == fa[u]) continue;
        getpos(v, v);
    }
}
// Node of the persistent segment tree: pool indices of both children and the
// number of values stored in the node's range.
struct node
{
    int l;
    int r;
    int sum;
};
node tree[maxn * 40];   // node pool; node 0 is the all-zero empty tree
vi dis;                 // sorted distinct weights for coordinate compression
int id(int x) {return lower_bound(dis.begin(), dis.end(), x) - dis.begin();}
// tot_tree: last allocated tree node; root[i]: version after the first i
// chain positions; ansl/ansr: per-segment left/right roots used by ans();
// siz: number of distinct weights.
int tot_tree, root[maxn], ansl[maxn], ansr[maxn], siz;
// Persistent insertion: copies the path to leaf `place` from version `pre`,
// adding `val` to the sum of every copied node. `p` receives the new node
// covering [nl, nr].
void build(int nl, int nr, int &p, int pre, int val, int place)
{
    int *link = &p;   // location that must point at the next new node
    for (;;)
    {
        const int cur = ++ tot_tree;
        *link = cur;
        tree[cur] = tree[pre];
        tree[cur].sum += val;
        if (nl == nr) return;
        const int mid = midf(nl, nr);
        if (mid >= place)
        {
            link = &tree[cur].l;
            pre = tree[pre].l;
            nr = mid;
        }
        else
        {
            link = &tree[cur].r;
            pre = tree[pre].r;
            nl = mid + 1;
        }
    }
}
// Returns the leaf index of the `place`-th smallest value that is counted in
// version `r` but not in version `l`.
//   nl, nr : inclusive range covered by the current pair of nodes
int query(int l, int r, int nl, int nr, int place)
{
    if(nl == nr) return nl;
    // sum = how many of the counted values fall into the left half.
    int mid = midf(nl, nr), sum = tree[tree[r].l].sum - tree[tree[l].l].sum;
    // The direction depends on that count, not on the midpoint: go left only
    // when the left half holds at least `place` values (the original compared
    // `mid >= place`, mixing an index with a rank).
    if(sum >= place) return query(tree[l].l, tree[r].l, nl, mid, place);
    return query(tree[l].r, tree[r].r, mid + 1, nr, place - sum);
}
// Returns the k-th smallest weight on the tree path between vertices f and t.
// The path is cut into heavy-chain segments; each segment is a contiguous
// range of positions and is described by a pair of roots (ansl, ansr). The
// answer is then found by descending all those tree pairs together.
int ans(int f, int t, int k)
{
    int ff = top[f], tt = top[t];
    int ppp = 0;   // number of segments collected so far
    while(ff != tt)
    {
        // Always lift the endpoint whose chain top is deeper.
        if(deep[ff] < deep[tt]) swap(ff, tt), swap(f, t);
        // Segment of positions p[ff] .. p[f].
        ansr[ppp] = root[p[f]], ansl[ppp] = root[p[ff] - 1];
        ppp ++;
        f = fa[ff];
        ff = top[f];
    }
    // Both endpoints are now on one chain; make f the shallower one.
    // (The original line read `swap f, t);`, which does not compile.)
    if(deep[f] > deep[t]) swap(f, t);
    ansr[ppp] = root[p[t]], ansl[ppp] = root[p[f] - 1];
    ppp ++;
    int sl = 0, sr = siz - 1;
    while(sl < sr)
    {
        // sum = number of path weights whose rank lies in [sl, mid].
        int sum = 0, mid = midf(sl, sr);
        for(int i = 0; i < ppp; ++ i)
            sum += tree[tree[ansr[i]].l].sum - tree[tree[ansl[i]].l].sum;
        if(sum < k)
        {
            // Skip the left half and move every cursor to its right child.
            k -= sum;
            sl = mid + 1;
            for(int i = 0; i < ppp; ++ i)
            {
                ansr[i] = tree[ansr[i]].r;
                ansl[i] = tree[ansl[i]].r;
            }
        }
        else
        {
            sr = mid;
            for(int i = 0; i < ppp; ++ i)
            {
                ansr[i] = tree[ansr[i]].l;
                ansl[i] = tree[ansl[i]].l;
            }
        }
    }
    return dis[sl];
}
// Reads a weighted tree and m queries (f, t, k); prints the k-th smallest
// weight on each path, one per line.
int main()
{
    int f, t, k;
    init();
    tot_tree = 0;
    scan_d(n);
    scan_d(m);
    for (int i = 1; i <= n; ++ i)
    {
        scan_d(A[i]);
        dis.push_back(A[i]);
    }
    for (int e = 1; e < n; ++ e)
    {
        scan_d(f);
        scan_d(t);
        addedge(f, t);
        addedge(t, f);
    }
    // Decompose the tree rooted at vertex 1.
    dfs1(1, 0, 0);
    getpos(1, 1);
    // Compress the weights, then build one version per chain position.
    sort(dis.begin(), dis.end());
    dis.erase(unique(dis.begin(), dis.end()), dis.end());
    siz = dis.size();
    for (int i = 1; i <= n; ++ i)
    {
        const int rank = id(A[fp[i]]);
        build(0, siz - 1, root[i], root[i - 1], 1, rank);
    }
    while (m --)
    {
        scan_d(f);
        scan_d(t);
        scan_d(k);
        out_number(ans(f, t, k));
        puts("");
    }
    return 0;
}
/*
7 7
4 10 2 2 5 5 4
1 2
2 3
3 4
4 5
5 6
6 7
1 7 1
1 7 2
1 7 3
1 7 4
1 7 5
1 7 6
1 7 7
*/

HDU 4417

主席树的查询可以写的像线段树一样。。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#include <bits/stdc++.h>
// Helper macros: array fill, integer midpoint, pair member shortcuts.
# define RESET(_) memset(_, 0, sizeof(_))
# define RESET_(_, val) memset(_, val, sizeof(_))
# define midf(x, y) (((x) + (y)) >> 1)
# define fi first
# define se second
using namespace std;
typedef long long ll;
typedef vector <int> vi;
typedef set <int> si;
const int MOD = 1e9 + 7;
// maxn sizes the per-element arrays; maxm (40 * maxn) sizes the node pool.
const int maxn = 100009;
const int maxm = maxn * 40;
// Used in main() as the "not decided yet" initial value of an answer.
const int inf = 1e9 + 300;
const double pi = acos(-1.0);
const double eps = 1e-6;
// Parses the next decimal integer from stdin into `ret`. Anything in front
// of the number that is not a digit or '-' is skipped.
// Returns false without touching `ret` if the input is exhausted first.
template <class T>
inline bool scan_d(T & ret)
{
    // Keep the result of getchar() as int so that EOF can be recognised, and
    // test for it while skipping: otherwise trailing blanks hang the reader.
    int c;
    do
    {
        c = getchar();
        if (c == EOF) return false;
    } while (c != '-' && (c < '0' || c > '9'));
    const bool minus = (c == '-');
    ret = minus ? 0 : (c - '0');
    while (c = getchar(), c >= '0' && c <= '9')
        ret = ret * 10 + (c - '0');
    if (minus) ret = -ret;
    return true;
}
// Prints `x` in decimal with putchar. Works for every ll value, including
// the smallest one, because `x` itself is never negated.
inline void out(ll x)
{
    char text[24];
    int n_digits = 0;
    const bool below_zero = x < 0;
    do
    {
        const int rem = static_cast<int>(x % 10);   // <= 0 when x is negative
        text[n_digits ++] = static_cast<char>('0' + (below_zero ? -rem : rem));
        x /= 10;
    } while (x != 0);
    if (below_zero) putchar('-');
    while (n_digits > 0) putchar(text[-- n_digits]);
}
// Number of elements and the elements themselves (indices 1..n).
int n;
int A[maxn];
// Persistent segment tree node: child indices and element count.
struct node
{
    int l;
    int r;
    int sum;
};
node tree[maxm];   // node pool, node 0 is the empty tree
int root[maxn];    // root[i]: version holding the first i elements
int tot;           // last allocated node
vi dis;            // sorted distinct values
int id(int x){int p = lower_bound(dis.begin(), dis.end(), x) - dis.begin();if(dis[p] == x) return p + 1; return p;}
void build(int nl, int nr, int &p, int pre, int val, int place)
{
p = ++ tot;
tree[p] = tree[pre];
tree[p].sum += val;
if(nl == nr) return ;
int mid = midf(nl, nr);
if(mid >= place) build(nl, mid, tree[p].l, tree[pre].l, val, place);
else build(mid + 1, nr, tree[p].r, tree[pre].r, val, place);
}
// Query bounds, set by the caller before query() runs.
int l, r;
// Counts the elements with index in [l, r] that are present in version `pr`
// but not in version `pl`; [nl, nr] is the range of the current nodes.
int query(int pl, int pr, int nl, int nr)
{
    // Node range fully inside the query range: take the whole difference.
    if (nl >= l && nr <= r)
        return tree[pr].sum - tree[pl].sum;
    const int mid = midf(nl, nr);
    int total = 0;
    if (mid >= l)
        total += query(tree[pl].l, tree[pr].l, nl, mid);
    if (r > mid)
        total += query(tree[pl].r, tree[pr].r, mid + 1, nr);
    return total;
}
// For every test case: reads n values and q queries (a, b, c) with 0-based
// bounds and prints how many of A[a+1..b+1] are <= c.
int main()
{
    int q, siz;
    int T;
    scan_d(T);
    for (int Casen = 1; Casen <= T; ++ Casen)
    {
        tot = 1;
        root[0] = 0;
        scan_d(n);
        scan_d(q);
        dis.clear();
        for (int i = 1; i <= n; ++ i)
        {
            scan_d(A[i]);
            dis.push_back(A[i]);
        }
        sort(dis.begin(), dis.end());
        dis.erase(unique(dis.begin(), dis.end()), dis.end());
        siz = dis.size();
        // Leaves are numbered 1..siz; id() gives the 1-based rank.
        for (int i = 1; i <= n; ++ i)
            build(1, siz, root[i], root[i - 1], 1, id(A[i]));
        printf("Case %d:\n", Casen);
        while (q --)
        {
            int a, b, c;
            scan_d(a);
            scan_d(b);
            scan_d(c);
            l = 1;
            int ans;
            if (c < dis[0])
            {
                // Nothing can be <= c; the tree is not consulted at all.
                ans = 0;
            }
            else
            {
                r = (c >= dis[siz - 1]) ? siz : id(c);
                ans = query(root[a], root[b + 1], 1, siz);
            }
            out(ans);
            puts("");
        }
    }
    return 0;
}

BZOJ 3932

每个权值出现在 $[a, b]$ 之间。

询问权值出现在 $[l, r]$ 之间,前 $\min (k, r - l + 1)$ 个权值的和。

其中 $k_i = (A_i \cdot ans_{i-1} + B_i) \bmod C_i + 1$,$ans_0 = 1$(与下面代码中的 `pre = 1`、`k = (pre * a + b) % c + 1` 一致)。

把时间看做主席树下标。

把每个权值 $x$ 出现 $a$ 和消失 $b$ 看做两个事件,将主席树的信息更新。给时间为 $[a, +\infty)$ 的主席树上加 $x$ ,给时间为 $(b, +\infty)$ 的主席树减去 $x$ 。

具体做法是:将事件按照时间排序,如果事件与上一个事件时间相同,就在这棵主席树上加或减 $x$ ,否则复制根,再到事件发生的时间去搞。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#include <bits/stdc++.h>
using namespace std;
// This program uses `ll`, `vi` and `midf` below but never defined them, so
// it could not be compiled on its own; they are declared here with the same
// meaning as in the other programs of this collection.
typedef long long ll;
typedef vector <int> vi;
# define midf(x, y) (((x) + (y)) >> 1)
// maxn bounds the number of tasks; maxm is the size of the node pool.
const int maxn = 100009;
const int maxm = 100009 * 128;
// One event on the time axis: at `time`, the value `event` is added
// (chg = 1) or removed (chg = -1). Events are ordered by time only.
// Kept as a plain aggregate so that it can be brace-initialised.
struct timer
{
    int time;
    int event;
    int chg;
    bool operator < (const timer &b) const
    {
        return b.time > time;
    }
};
timer eve[maxn << 1];
// m: number of tasks (later doubled to the number of events); n: number of queries.
int n, m;
vi h;                     // sorted distinct values
int root[maxn << 1];      // root[t]: version valid at time t
int tot;                  // last allocated node
// Tree node: child indices, number of active values and their total.
struct node
{
    int l;
    int r;
    int siz;
    ll sum;
};
node tree[maxm];
void pre_(int nl, int nr, int &p)
{
p = ++ tot;
tree[p].siz = tree[p].sum = 0;
if(nl == nr) return ;
int mid = midf(nl, nr);
pre_(nl, mid, tree[p].l);
pre_(mid + 1, nr, tree[p].r);
}
void build(int pre, int &now, int nl, int nr, int chg, int vpl)
{
now = ++ tot;
tree[now] = tree[pre];
tree[now].siz += chg;
tree[now].sum += h[vpl] * chg;
if(nl == nr) return ;
int mid = midf(nl, nr);
if(vpl <= mid) build(tree[pre].l, tree[now].l, nl, mid, chg, vpl);
else build(tree[pre].r, tree[now].r, mid + 1, nr, chg, vpl);
}
// Returns the sum of the k smallest active values in the subtree of node `p`
// (or of all of them when fewer than k are active).
ll query(int nl, int nr, int p, int k)
{
    const node &cur = tree[p];
    if (cur.siz <= k) return cur.sum;
    // A leaf holding more than k copies of one value: take k of them.
    if (nl == nr) return cur.sum / cur.siz * k;
    const int mid = midf(nl, nr);
    const node &left = tree[cur.l];
    if (k <= left.siz)
        return query(nl, mid, cur.l, k);
    // Take the whole left subtree and the remainder from the right one.
    return left.sum + query(mid + 1, nr, cur.r, k - left.siz);
}
int main()
{
//freopen("query4.in", "r", stdin);
//freopen("query4.ans", "w", stdout);
scanf("%d %d", &m, &n);
int treesiz;
tot = 0;
for(int i = 1, f, t, x; i <= m; ++ i)
{
scanf("%d %d %d", &f, &t, &x);
eve[i * 2 - 1] = timer{f, x, 1};
eve[i * 2] = timer{t + 1, x, -1};
h.push_back(x);
}
sort(h.begin(), h.end());
h.erase(unique(h.begin(), h.end()), h.end());
treesiz = h.size();
m <<= 1;
sort(eve + 1, eve + 1 + m);
ll pre = 1, x, a, b, c, k;
root[0] = 0;
for(int i = 1, nowtime; i <= m; ++ i)
{
x = lower_bound(h.begin(), h.end(), eve[i].event) - h.begin();
nowtime = eve[i].time;
if(i == 1 || eve[i - 1].time != nowtime)
{
if(i != 1)
{
for(int j = eve[i - 1].time + 1; j < nowtime; ++ j)
root[j] = root[j - 1];
}
build(root[nowtime - 1], root[nowtime], 0, treesiz - 1, eve[i].chg, x);
}
else build(root[nowtime], root[nowtime], 0, treesiz - 1, eve[i].chg, x);
}
while(n --)
{
scanf("%lld %lld %lld %lld", &x, &a, &b, &c);
k = (pre * a + b) % c + 1;
k = query(0, treesiz - 1, root[x], (int)k);
printf("%lld\n", k);
pre = k;
}
return 0;
}
/* in
4 3
1 2 6
2 3 3
1 3 2
3 3 4
3 1 3 2
1 1 3 4
2 2 4 3
*/
/* out
2
8
11
*/

主席树套树状数组

树状数组的每个节点都是一棵线段树,但这棵线段树不再保存每个前缀的信息,而是由树状数组的 $\text{sum}$ 函数计算出这个前缀的信息。因此这棵线段树保存的是辅助数组 $B$ 的值,即 $B[i]=A[i-\text{lowbit}(i)+1]+…+A[i]$,其中 $A[i]$ 表示第 $i$ 个位置上的修改信息(一棵记录该位置各权值出现次数的线段树)。

对于每次修改,我们要修改树状数组上的 $\log n$ 棵树,对于每棵树,我们要修改 $\log n$ 个结点,所以时空复杂度为 $O((n+q)\log n \log n)$。如果初始时先建立一棵静态的主席树,树状数组只保存每次修改的信息,那么时空复杂度降为 $O(n \log n+q \log n \log n)$。

ZOJ 2112

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
#include <algorithm>
#include <bitset>
#include <cassert>
#include <complex>
#include <cstring>
#include <deque>
#include <exception>
#include <fstream>
#include <functional>
#include <iomanip>
#include <iostream>
#include <istream>
#include <iterator>
#include <list>
#include <map>
#include <ostream>
#include <queue>
#include <set>
#include <sstream>
#include <stack>
#include <string>
#include <typeinfo>
#include <utility>
#include <valarray>
#include <vector>
// Helper macros: array fill, integer midpoint, pair member shortcuts.
# define RESET(_) memset(_, 0, sizeof(_))
# define RESET_(_, val) memset(_, val, sizeof(_))
# define midf(x, y) (((x) + (y)) >> 1)
# define fi first
# define se second
using namespace std;
typedef long long ll;
typedef vector <int> vi;
typedef set <int> si;
const int MOD = 1e9 + 7;
// maxn sizes the per-element and per-query arrays; maxm (40 * maxn) sizes
// the node pool.
const int maxn = 50009;
const int maxm = maxn * 40;
const double pi = acos(-1.0);
const double eps = 1e-6;
// Reads the next decimal integer (optional leading '-') from stdin into
// `ret`, discarding any other characters in front of it.
// Returns false, with `ret` unchanged, when the input ends before a number.
template <class T>
inline bool scan_d(T & ret)
{
    int c = getchar();   // int keeps EOF distinct from every byte value
    while (c != EOF && c != '-' && !(c >= '0' && c <= '9'))
        c = getchar();   // EOF is checked here too, so trailing blanks cannot hang
    if (c == EOF) return false;
    const int sgn = (c == '-') ? -1 : 1;
    ret = (sgn < 0) ? 0 : (c - '0');
    for (;;)
    {
        c = getchar();
        if (c < '0' || c > '9') break;
        ret = ret * 10 + (c - '0');
    }
    ret *= sgn;
    return true;
}
// Writes `x` in decimal to stdout via putchar. The digits are taken from the
// (possibly negative) remainders so that `x` never has to be negated, which
// would overflow for the smallest ll value.
inline void out(ll x)
{
    char rev[24];   // digits in reverse order
    int cnt = 0;
    const bool neg = x < 0;
    do
    {
        int r = static_cast<int>(x % 10);
        if (r < 0) r = -r;
        rev[cnt ++] = static_cast<char>('0' + r);
        x /= 10;
    } while (x != 0);
    if (neg) putchar('-');
    while (cnt > 0) putchar(rev[-- cnt]);
}
// Sequence length, current values (1..n) and number of distinct values.
int n;
int A[maxn];
int siz;
// Tree node: child indices and element count.
struct node
{
    int l;
    int r;
    int sum;
};
node tree[maxm];    // node pool, node 0 is the empty tree
int root[maxn];     // per Fenwick index: tree holding the modifications
int tot;            // last allocated node
int yroot[maxn];    // yroot[i]: static version built from the initial first i values
void pre(int l, int r, int &p)
{
p = ++ tot;
tree[p].sum = 0;
if(l == r) return;
int mid = midf(l, r);
pre(l, mid, tree[p].l);
pre(mid + 1, r, tree[p].r);
}
/* Non-recursive version of the point update (kept for reference, disabled)
void updat(int pos, int v, int &p, int pre)
{
int tmp = ++ tot, l = 1, r = siz;
p = tmp;
tree[tmp] = tree[pre];
tree[tmp].sum += v;
while(l < r)
{
int mid = midf(l, r);
if(pos <= mid)
{
tree[tmp].l = ++ tot; tree[tmp].r = tree[pre].r;
tree[tree[tmp].l].sum = tree[tree[pre].l].sum + v;
tmp = tree[tmp].l;
pre = tree[pre].l;
r = mid;
}
else
{
tree[tmp].r = ++ tot; tree[tmp].l = tree[pre].l;
tree[tree[tmp].r].sum = tree[tree[pre].r].sum + v;
tmp = tree[tmp].r;
pre = tree[pre].r;
l = mid + 1;
}
}
}
*/
// Persistent point update: starting from version `pre`, allocates new nodes
// along the path to leaf `place` and adds `val` to each of them. `p` is set
// to the new node for [nl, nr]; `p` may refer to the same root that `pre`
// was read from, because `pre` is taken by value.
void build(int nl, int nr, int &p, int pre, int val, int place)
{
    int *target = &p;
    while (true)
    {
        const int made = ++ tot;
        *target = made;
        tree[made] = tree[pre];
        tree[made].sum += val;
        if (nl == nr) break;
        const int mid = midf(nl, nr);
        if (place <= mid)
        {
            target = &tree[made].l;
            pre = tree[pre].l;
            nr = mid;
        }
        else
        {
            target = &tree[made].r;
            pre = tree[pre].r;
            nl = mid + 1;
        }
    }
}
// Fenwick tree whose cells are segment trees (root[i]) recording the
// modifications made to the sequence.
namespace BIT
{
    // Per Fenwick index: the node currently visited during a descent.
    int B[maxn];
    // Lowest set bit of x.
    int lowbit(int x)
    {
        return x & (-x);
    }
    // Adds `v` to the count of value rank `pos` in every Fenwick cell that
    // covers sequence position `x`.
    void update(int x, int pos, int v)
    {
        int i = x;
        while (i <= n)
        {
            build(1, siz, root[i], root[i], v, pos);
            i += lowbit(i);
        }
    }
    // Sum, over the Fenwick cells forming the prefix 1..x, of the element
    // count in the left child of each cell's current node B[i].
    int sum(int x)
    {
        int total = 0;
        while (x)
        {
            total += tree[tree[B[x]].l].sum;
            x -= lowbit(x);
        }
        return total;
    }
};
using namespace BIT;
vi dis;
int id(int x){return lower_bound(dis.begin(), dis.end(), x) - dis.begin();}
// One stored operation. type 1: k-th smallest in positions f..t;
// type 2: set position f to value t (k unused).
struct query
{
    int type;
    int f;
    int t;
    int k;
};
query qqq[maxn];
int ans(int f, int t, int k)
{
//将结果缓存到树状数组里面
for(int i = f - 1; i; i -= lowbit(i)) B[i] = root[i];
for(int i = t; i; i -= lowbit(i)) B[i] = root[i];
int sl = 1, sr = siz, lrt = yroot[f - 1], rrt = yroot[t];
while(sl < sr) // 二分查找第 k 大
{
int s = sum(t) - sum(f - 1) + tree[tree[rrt].l].sum - tree[tree[lrt].l].sum;
int mid = midf(sl, sr);
if(s < k)
{
//去右子树查找
k -= s;
sl = mid + 1;
for(int i = f - 1; i; i -= lowbit(i)) B[i] = tree[B[i]].r;
for(int i = t; i; i -= lowbit(i)) B[i] = tree[B[i]].r;
rrt = tree[rrt].r, lrt = tree[lrt].r;
}
else
{
//去左子树查找
for(int i = f - 1; i; i -= lowbit(i)) B[i] = tree[B[i]].l;
for(int i = t; i; i -= lowbit(i)) B[i] = tree[B[i]].l;
rrt = tree[rrt].l, lrt = tree[lrt].l;
sr = mid;
}
}
return sl - 1;
}
int main()
{
int q, tt, T;
char ch;
scanf("%d", &T);
while(T --)
{
tot = 0;
scanf("%d %d", &n, &q);
dis.clear();
for(int i = 1; i <= n; ++ i)
scanf("%d", A + i),
dis.push_back(A[i]);
for(int i = 1; i <= q; ++ i)
{
scanf("\n%c", &ch);
switch(ch)
{
case 'Q':
qqq[i].type = 1;
scanf("%d %d %d", &qqq[i].f, &qqq[i].t, &qqq[i].k);
break;
case 'C':
qqq[i].type = 2;
scanf("%d %d", &qqq[i].f, &qqq[i].t);
dis.push_back(qqq[i].t);
//对查询离散化,离线处理答案
break;
default:
assert(ch == 'Q' || ch == 'C');
}
}
sort(dis.begin(), dis.end());
dis.erase(unique(dis.begin(), dis.end()), dis.end());
siz = dis.size();
for(int i = 1; i <= n; ++ i)
build(1, siz, yroot[i], yroot[i - 1], 1, id(A[i]) + 1);
RESET(root);
for(int i = 1; i <= q; ++ i)
{
if(qqq[i].type == 1)
printf("%d\n", dis[ans(qqq[i].f, qqq[i].t, qqq[i].k)]);
else
{
BIT :: update(qqq[i].f, id(A[qqq[i].f]) + 1, -1);
BIT :: update(qqq[i].f, id(qqq[i].t) + 1, 1);
A[qqq[i].f] = qqq[i].t;
}
}
}
return 0;
}

HDU 5919 询问每个数的后继位置 & 不相同的数字个数

给一个长为 $n$ 的数列,$m$ 次查询,每次查询一个区间中每一种数在区间中第一次出现的位置的中位数,强制在线

$1\leq n,m\leq 2*10^5$

设区间内有 $K$ 个不同数字,二分一个 $mid$,使得 $[l, mid]$ 中恰好有 $K$ 个不同数字

查询 $[l,r]$ 中不相同的数字个数:记录一下每个数字下一次出现的位置 ,问题转化成该区间有多少个 $nxt$ 大于 $r$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
//#define NOSTDCPP
//#define Cpp11
//#define Linux_System
#ifndef NOSTDCPP
#include <bits/stdc++.h>
#else
#include <algorithm>
#include <bitset>
#include <cassert>
#include <climits>
#include <complex>
#include <cstring>
#include <cstdio>
#include <deque>
#include <exception>
#include <functional>
#include <iomanip>
#include <iostream>
#include <istream>
#include <iterator>
#include <list>
#include <map>
#include <ostream>
#include <queue>
#include <set>
#include <sstream>
#include <stack>
#include <string>
#include <typeinfo>
#include <utility>
#include <valarray>
#include <vector>
#endif
# ifdef Linux_System
# define getchar getchar_unlocked
# define putchar putchar_unlocked
# endif
// Helper macros shared by the program below.
# define RESET(_) memset(_, 0, sizeof(_))
# define RESET_(_, val) memset(_, val, sizeof(_))
# define fi first
# define se second
# define pb push_back
// Arguments are parenthesised so that operands with lower precedence than
// `+` (for example `c ? a : b`) are not split apart by the expansion.
# define midf(x, y) (((x) + (y)) >> 1)
# define DXA(_) (((_) << 1))
# define DXB(_) (((_) << 1) | 1)
// Renames identifiers that may clash with names from the standard headers.
# define next __Chtholly__
# define x1 __Mercury__
# define y1 __bbtl04__
# define index __ikooo__
using namespace std;
// Short aliases used throughout the program.
typedef long long ll;
typedef unsigned long long ull;
typedef vector <int> vi;
typedef set <int> si;
typedef pair <int, int> pii;
typedef long double ld;
const int MOD = 1e9 + 7;
// maxn sizes the per-position arrays; the node pool is sized from it too.
const int maxn = 200009;
const int maxm = 300009;
const ll inf = 1e18;
const double pi = acos(-1.0);
const double eps = 1e-6;
// Combines three rand() results (shifted by 32, 16 and 0 bits) with XOR and
// reduces the value modulo `mod`. `mod` must be non-zero.
ll myrand(ll mod)
{
    // One call per statement fixes the order of the draws; written as a
    // single expression, the order of the three rand() calls is unspecified
    // and the result for a given seed depends on the compiler.
    ll value = (ll)rand() << 32;
    value ^= (ll)rand() << 16;
    value ^= rand();
    return value % mod;
}
// Reads one decimal integer from stdin into `ret`; a leading '-' makes it
// negative, and any other characters before the number are skipped.
// Returns false (leaving `ret` alone) when no number is left in the input.
template <class T>
inline bool scan_d(T & ret)
{
    int c = getchar();   // must be int: a char cannot represent EOF reliably
    while (c != '-' && (c < '0' || c > '9'))
    {
        // Without this check, input that ends in whitespace never terminates.
        if (c == EOF) return false;
        c = getchar();
    }
    const bool negative = (c == '-');
    ret = negative ? 0 : (c - '0');
    while (c = getchar(), c >= '0' && c <= '9')
        ret = ret * 10 + (c - '0');
    if (negative) ret = -ret;
    return true;
}
#ifdef Cpp11
// Reads several integers, left to right, with the single-value scan_d.
// Returns true only when all of them were read; the original had no return
// statement at all, so its result was undefined.
template <class T, class ... Args>
inline bool scan_d(T & ret, Args & ... args)
{
    if (!scan_d(ret)) return false;
    return scan_d(args...);
}
// NOTE(review): two lines of the form `#define cin.tie(0); cin.tie(nullptr);`
// (and the same for cout) were removed here. The preprocessor treats them as
// definitions of macros called `cin` and `cout`, so they could only break
// code that uses the streams.
#endif
// Stores in `ch` the next character of stdin that is not a space or a
// newline. Returns false when the input ends first, including the case
// where only blanks were left (which the original reported as success).
inline bool scan_ch(char &ch)
{
    int c;
    do
    {
        c = getchar();   // int, so EOF is never mistaken for a character
    } while (c == ' ' || c == '\n');
    ch = static_cast<char>(c);
    return c != EOF;
}
// Prints the integer `x` in decimal on stdout with putchar.
// Works for any integer type T: the value is never negated (negating the
// smallest value of a signed type overflows); the sign is removed from each
// digit instead.
template <class T>
inline void out_number(T x)
{
    char digits[3 * sizeof(T) + 2];   // more than enough decimal digits for T
    int len = 0;
    const bool negative = x < 0;
    do
    {
        const int d = static_cast<int>(x % 10);   // in [-9, 0] for negative x
        digits[len ++] = static_cast<char>('0' + (negative ? -d : d));
        x /= 10;
    } while (x != 0);
    if (negative) putchar('-');
    while (len > 0) putchar(digits[-- len]);
}
// Sequence length and number of queries.
int n, m;
int A[maxn];      // the sequence, indices 1..n
int root[maxn];   // root[i]: version holding nxt[1..i]
int cnt;          // last allocated node
int nxt[maxn];    // next position with the same value, n + 1 if none
int last[maxn];   // per value: nearest position seen so far while scanning from the right
// Tree node: child indices and number of stored entries.
struct node
{
    int l;
    int r;
    int val;
};
node tree[maxn * 50];
#define lson(_) (tree[_].l)
#define rson(_) (tree[_].r)
// Implicit argument of update() and query(): the leaf to insert, or the
// threshold to compare against.
int v;
void update(int pre, int &now, int nl, int nr)
{
now = ++ cnt;
tree[now] = tree[pre];
++ tree[now].val;
if(nl == nr) return ;
int mid = midf(nl, nr);
if(v <= mid) update(lson(pre), lson(now), nl, mid);
else update(rson(pre), rson(now), mid + 1, nr);
}
// Counts the entries present in version `y` but not in version `x` whose
// leaf index is greater than `v` (the global).
int query(int x, int y, int nl, int nr)
{
    int above = 0;
    while (nl != nr)
    {
        const int mid = midf(nl, nr);
        if (v <= mid)
        {
            // Everything in the right half is greater than v.
            above += tree[tree[y].r].val - tree[tree[x].r].val;
            x = tree[x].l;
            y = tree[y].l;
            nr = mid;
        }
        else
        {
            x = tree[x].r;
            y = tree[y].r;
            nl = mid + 1;
        }
    }
    return above;
}
#undef lson
#undef rson
// Working variables of main(): test count, last answer, query bounds and the
// binary-search state.
int T, ans;
int f, t;
int l, r, mid, now;
// Number of positions i in [f, x] with nxt[i] > x (f is the global left bound).
int ff(int x)
{
    v = x;
    const int older = root[f - 1];
    const int newer = root[x];
    return query(older, newer, 1, n + 1);
}
// For every test case: builds one version per position over the "next
// occurrence" array, then answers each (online-decoded) query by counting
// the distinct values in [f, t] and binary-searching the position at which
// half of them (rounded up) have appeared.
// The `register` storage class used by the original loops was removed: it is
// no longer part of the language since C++17 and is rejected or warned about
// by current compilers; it never affected the program's behaviour.
int main()
{
    scanf("%d", &T);
    //scan_d(T);
    for(int kssn = 1; kssn <= T; ++ kssn)
    {
        scanf("%d %d", &n, &m);
        //scan_d(n);
        //scan_d(m);
        ans = cnt = 0;
        for(int i = 1; i <= n; ++ i) scan_d(A[i]);
        // nxt[i] = next position holding the same value, n + 1 if there is none.
        for(int i = 0; i <= 200000; ++ i) last[i] = n + 1;
        for(int i = n; i >= 1; -- i)
        {
            nxt[i] = last[A[i]];
            last[A[i]] = i;
        }
        for(int i = 1; i <= n; ++ i)
        {
            v = nxt[i];
            update(root[i - 1], root[i], 1, n + 1);
        }
        printf("Case #%d:", kssn);
        while(m --)
        {
            scanf("%d %d", &f, &t);
            //scan_d(f);
            //scan_d(t);
            // Queries are encoded with the previous answer.
            f = (f + ans) % n + 1, t = (t + ans) % n + 1;
            if(f > t) swap(f, t);
            l = f, r = t;
            now = ff(t);
            now = (now + 1) >> 1;
            // Smallest mid such that ff(mid) reaches the target count.
            while(l <= r)
            {
                mid = midf(l, r);
                if(ff(mid) >= now) ans = mid, r = mid - 1;
                else l = mid + 1;
            }
            printf(" %d", ans);
        }
        puts("");
    }
    return 0;
}

主席树与 LCA

XDOJ 1248

询问你在树上一条链上的第一个没有出现的数字 - 1 ( 树链上的 mex)

这个题标准解法是利用 LCA 倍增去做,然而我去写了树链剖分。。

在链上的话,注意建树的位置(在标号以后就根据父亲的信息建树即可)。

询问的时候,询问 [lca - u, falca - v] 上的值即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
#include <bits/stdc++.h>
// Helper macros shared by the program below.
# define RESET(_) memset(_, 0, sizeof(_))
# define RESET_(_, val) memset(_, val, sizeof(_))
# define fi first
# define se second
# define pb push_back
// The arguments are wrapped in parentheses so that the macro stays correct
// for operands that bind more loosely than `+`.
# define midf(x, y) (((x) + (y)) >> 1)
# define DXA(_) (((_) << 1))
# define DXB(_) (((_) << 1) | 1)
// Renames identifiers that may clash with names from the standard headers.
# define next __Chtholly__
# define x1 __Mercury__
# define y1 __bbtl04__
# define index __ikooo__
using namespace std;
// maxm sizes the per-vertex arrays; maxn (twice as large) sizes the edge
// array and, times 25, the node pool.
const int maxm = 200009;
const int maxn = 200009 << 1;
// Tree node: child indices and number of stored values.
struct node
{
    int l;
    int r;
    int sum;
};
node tree[maxn * 25];
// NOTE(review): with `using namespace std` and a C++17 standard library the
// global name `data` may be ambiguous with std::data - confirm on the target
// compiler.
int data;          // implicit argument of update(): the leaf to insert
int tree_tot;      // last allocated node
int root[maxm];    // root[id]: version of the vertex numbered id
void update(int pre, int &now, int nl, int nr)
{
now = ++ tree_tot;
tree[now] = tree[pre];
++ tree[now].sum;
if(nl == nr) return ;
int mid = midf(nl, nr);
if(data <= mid) update(tree[pre].l, tree[now].l, nl, mid);
else update(tree[pre].r, tree[now].r, mid + 1, nr);
}
int query(int x, int y, int lca, int falca, int nl, int nr)
{
if(nl == nr) return nl;
int sum1 = tree[tree[y].l].sum - tree[tree[x].l].sum,
sum2 = tree[tree[falca].l].sum - tree[tree[lca].l].sum, mid = midf(nl, nr);
if(sum1 == 0 && sum2 == 0) return nl;
if(sum1 + sum2 <= mid - nl) return query(tree[x].l, tree[y].l, tree[lca].l, tree[falca].l, nl, mid);
else return query(tree[x].r, tree[y].r, tree[lca].r, tree[falca].r, mid + 1, nr);
}
// Linked adjacency lists; edge index 0 terminates a list.
struct eddy
{
    int to;
    int next;
};
eddy edg_2[maxn];
int fir[maxm];       // first edge of each vertex
int eddy_tot;        // number of edges stored
// Heavy-light decomposition data.
int deep[maxm];      // depth
int ffather[maxm];   // parent
int son[maxm];       // heavy child, 0 if none
int size[maxm];      // subtree size
int head[maxm];      // top vertex of the vertex's chain
int id[maxm];        // number assigned to the vertex
int pid[maxm];       // inverse of id
int id_tot;          // last number handed out
int A[maxm];         // vertex values
// Working variables of main().
int n, m, u, v, ans, lca;
// Adds the directed edge f -> t at the front of f's adjacency list
// (edges are numbered from 1).
void addedge(int f, int t)
{
    const int e = ++ eddy_tot;
    edg_2[e].to = t;
    edg_2[e].next = fir[f];
    fir[f] = e;
}
// First pass: stores depth, parent and subtree size for the subtree of `u`
// and selects the child with the largest subtree as heavy child (the first
// one found wins ties).
void dfs1(int u, int pre, int d)
{
    ffather[u] = pre;
    deep[u] = d;
    size[u] = 1;
    son[u] = 0;
    for (int e = fir[u]; e != 0; e = edg_2[e].next)
    {
        const int child = edg_2[e].to;
        if (child == pre) continue;
        dfs1(child, u, d + 1);
        size[u] += size[child];
        if (son[u] == 0 || size[son[u]] < size[child])
            son[u] = child;
    }
}
// Second pass: numbers the vertices chain by chain and, for each vertex,
// builds its tree version from the version of its parent `pre` by inserting
// A[u]. `sp` is the top vertex of u's chain.
void getpos(int u, int sp, int pre)
{
    head[u] = sp;
    const int my_id = ++ id_tot;
    id[u] = my_id;
    pid[my_id] = u;
    data = A[u];
    // The version of a vertex extends the version of its parent.
    update(root[id[pre]], root[my_id], 1, n + 1);
    if (son[u]) getpos(son[u], sp, u);
    for (int e = fir[u]; e != 0; e = edg_2[e].next)
    {
        const int to = edg_2[e].to;
        if (to == son[u] || to == ffather[u]) continue;
        getpos(to, to, u);   // a light child starts its own chain
    }
}
int LCA(int u, int v)
{
int f1 = head[u], f2 = head[v];
while(f1 != f2)
{
if(deep[f1] < deep[f2])
{
swap(f1, f2);
swap(u, v);
}
u = ffather[f1]; f1 = head[u];
}
if(u == v) return u;
if(deep[u] > deep[v]) swap(u, v);
return u;
}
// Reads a tree with one value per vertex and m path queries (u, v); prints
// for each query the result of query() on the path, minus one.
int main()
{
    scanf("%d", &n);
    eddy_tot = tree_tot = 0;
    for(int i = 1; i < n; ++ i)
    {
        int f, t;
        scanf("%d %d", &f, &t);
        addedge(f, t);
        addedge(t, f);
    }
    // Values are shifted by one so that they fit the 1-based leaves.
    for(int i = 1; i <= n; ++ i) scanf("%d", A + i), ++ A[i];
    // The root's parent is vertex 0, not -1: queries index id[] and root[]
    // with ffather[lca], and id[-1] would read outside the array when the
    // LCA is the root. id[0] and root[0] are 0, i.e. the empty version.
    dfs1(1, 0, 1);
    getpos(1, 1, 0);
    scanf("%d", &m);
    while(m --)
    {
        scanf("%d %d", &u, &v);
        lca = LCA(u, v);
        // Path = (version[u] - version[lca]) + (version[v] - version[parent of lca]).
        ans = query(root[id[lca]], root[id[u]], root[id[ffather[lca]]], root[id[v]], 1, n + 1) - 1;
        printf("%d\n", ans);
    }
    return 0;
}
/*
5
1 2
2 4
2 5
1 3
0 1 4 2 3
5
1 1
1 2
2 3
1 4
4 5
*/
文章目录
  1. 1. 主席树初探
    1. 1.1. 概述
    2. 1.2. 主席树的结构
    3. 1.3. 主席树的功能实现
      1. 1.3.1. 建树
      2. 1.3.2. 在两棵树上查询差值
      3. 1.3.3. 在多棵树上查询差值
    4. 1.4. 主席树模板
      1. 1.4.1. POJ 2104
      2. 1.4.2. SPOJ COT
      3. 1.4.3. HDU 4417
      4. 1.4.4. BZOJ 3932
    5. 1.5. 主席树套树状数组
      1. 1.5.1. ZOJ 2112
      2. 1.5.2. HDU 5919 询问每个数的后继位置 & 不相同的数字个数
    6. 1.6. 主席树与 LCA
      1. 1.6.1. XDOJ 1248
{{ live2d() }}