主席树
Eva.Q Lv9

有些事情,回头看,才发现是多么幸运;
有些事情,回头看,才发现是多么可惜。

“期末考试”考完啦!!一个寒假的小摆烂,在期末的赶脚中结束了,趁着还没彻底放飞自我,Co 老师教了我主席树。

主席树

主席树是一种基于线段树的数据结构,相比线段树不同的地方在于主席树支持 可持久化,也就是可以支持回退,访问历史版本的数据结构。

如果我们要让线段树支持可持久化,那么我们首先会直接想到每进行一次操作都将线段树复制一遍然后对复制后的线段树进行修改操作。但显然这样时间复杂度会变为 ,而空间复杂度会变为 ,显然这样的时空复杂度是我们完全不能够接受的。

可以发现,对于每一次单点修改操作,由于线段树的特性,这个点和这个点的所以直到根节点的祖先节点都需要更新,而除了这些节点之外的其它节点并不受影响。

所以每次单点修改操作最多需要修改 次,所以我们可以考虑每次只修改这 个节点即可。

那么对于每一次修改,我们都需要创造出一个新的树根 Root ,然后对于新的线段树中的任意节点,如果以该节点代表的区间中不包含被修改的节点,那么也就意味着 以这个节点为根的子树不需要修改

于是我们就可以让这个节点的父亲节点的与这个节点 对应 的儿子直接指向 修改前 的线段树上的该节点。

若该节点即对应被修改的节点,那么就可以直接建立新节点。

若该节点对应的区间中包含该节点,那么我们再对该节点的左右儿子分别处理即可。

问题的关键是确定版本是什么

例 1:求区间第 小的值

给定一个序列 次询问,每次询问区间 的第 小值。

小的值,也就是说该区间里有 个小于它的数,因为要比较数值大小,所以采用值域线段树。在值域线段数中,根据左儿子的值域范围内出现的数字的个数可以判断第 小的值,应该在左儿子值域范围里,还是在右儿子值域范围里。

观察到区间里不超过 的数的个数可以转换成单个数字出现次数的前缀和进行维护,也就是要维护一个二维的数组 表示前 个数中数字 出现的次数。因此对于这道题,版本是前 个数。

每次查询转换成 ,在线段树上二分找到

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
#include<bits/stdc++.h>  
using namespace std;

inline int rd() {
    int x = 0;
    bool f = 0;
    char c = getchar();
    for (; !isdigit(c); c = getchar()) f |= (c == '-');
    for (; isdigit(c); c = getchar()) x = x * 10 + (c ^ 48);
    return f ? -x : x;
}

#define N 100007
int a[N], rt[N], totnode;

struct node {int ls, rs, val;} c[N << 5];

inline int newnode() {return ++totnode;}

inline void pushup(int nw) {
    c[nw].val = c[c[nw].ls].val + c[c[nw].rs].val;
}

void update(int &nw, int l, int r, int p, int val) {
    int t = newnode();
    c[t] = c[nw]; nw = t;
    if (l == r) {c[nw].val += val; return;}
    int mid = (l + r) >> 1;
    if (p <= mid) update(c[nw].ls, l, mid, p, val);
    else update(c[nw].rs, mid + 1, r, p, val);
    pushup(nw);
}

int query(int nwl, int nwr, int l, int r, int k) {
    if (l == r) return l;
    int mid = (l + r) >> 1;
    int nwcnt = c[c[nwr].ls].val - c[c[nwl].ls].val;
    if (nwcnt >= k) return query(c[nwl].ls, c[nwr].ls, l, mid, k);
    else return query(c[nwl].rs, c[nwr].rs, mid + 1, r, k - nwcnt);
}

int main() {
    int n = rd(), m = rd(), mx = 0;
    for (int i = 1; i <= n; ++i) {a[i] = rd(); mx = max(mx, a[i]);}
    for (int i = 1; i <= n; ++i) {
        rt[i] = rt[i - 1];
        update(rt[i], 1, mx, a[i], 1);
    }
    for (int i = 1; i <= m; ++i) {
        int l = rd(), r = rd(), k = rd();
        printf("%d\n", query(rt[l - 1], rt[r], 1, mx, k));
    }
    return 0;
}

例2:求区间

首先,可以观察到答案不会超过 , 因此大于 的数可以忽略。

假设每次询问的右端点都是 ,那只需要知道每个数字 最后一次出现的地方 ,就可以解决这个问题,每次询问依次扫描 ,若 ,则 没有出现在区间 。复杂度是

用线段树优化每次询问时候的复杂度,用线段树维护区间 的最小值。如果
的最小值都小于 那么 一定是 其中的一个,所以一定在左子树中,否则一定在右子树中,这样复杂度就是

所以可以先把询问离线,再依次把每个 修改为

如果强制在线的话,就需要用主席树了,主席树的版本是区间的右端点,左端点作为判断的条件。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
#include<bits/stdc++.h>  
using namespace std;

inline int rd() {
    int x = 0;
    bool f = 0;
    char c = getchar();
    for (; !isdigit(c); c = getchar()) f |= (c == '-');
    for (; isdigit(c); c = getchar()) x = x * 10 + (c ^ 48);
    return f ? -x : x;
}

#define N 200007
int a[N], rt[N], totnode;

struct node {int ls, rs, mn;} c[N << 5];

inline int newnode() {return ++totnode;}

inline void pushup(int nw) {
    c[nw].mn = min(c[c[nw].ls].mn, c[c[nw].rs].mn);
}

void update(int &nw, int l, int r, int p, int val) {
    int t = newnode();
    c[t] = c[nw]; nw = t;
    if (l == r) {c[nw].mn = val; return;}
    int mid = (l + r) >> 1;
    if (p <= mid) update(c[nw].ls, l, mid, p, val);
    else update(c[nw].rs, mid + 1, r, p, val);
    pushup(nw);
}

int query(int nw, int l, int r, int lim) {
    if (l == r) return l;
    int mid = (l + r) >> 1;
    if (c[c[nw].ls].mn < lim) return query(c[nw].ls, l, mid, lim);
    else return query(c[nw].rs, mid + 1, r, lim);
}

int main() {
    int n = rd(), m = rd();
    for (int i = 1; i <= n; ++i) {
        rt[i] = rt[i - 1];
        update(rt[i], 0, n, rd(), i);
    }
    for (int i = 1; i <= m; ++i) {
        int l = rd(), r = rd();
        printf("%d\n", query(rt[r], 0, n, l));
    }
    return 0;
}

例3:树上路径中位数

https://codeforces.com/gym/101161 E

求树上任意两点 路径上边权的中位数。

首先因为是树上路径,所以很自然会想到 ;又因为求的是中位数,与数字的大小关系有关,所以要用值域线段树。设路径长度为 ,根据路径长度的奇偶性,将问题转换成路径上第 小的数或第 和第 小的平均值。

版本 维护的是根到 这条链上的信息,这样版本间的关系就是原树的关系。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#include<bits/stdc++.h>
using namespace std;

#define pii pair<int, int>
#define mp make_pair
#define pb push_back
#define fr first
#define sc second

inline int rd() {
int x = 0;
bool f = 0;
char c = getchar();
for (; !isdigit(c); c = getchar()) f |= (c == '-');
for (; isdigit(c); c = getchar()) x = x * 10 + (c ^ 48);
return f ? -x : x;
}

#define N 100007
int a[N], rt[N], totnode;

struct node {int ls, rs, val;} c[N << 5];

inline int newnode() {return ++totnode;}

inline void pushup(int nw) {
c[nw].val = c[c[nw].ls].val + c[c[nw].rs].val;
}

void update(int &nw, int l, int r, int p, int val) {
int t = newnode();
c[t] = c[nw]; nw = t;
if (l == r) {c[nw].val += val; return;}
int mid = (l + r) >> 1;
if (p <= mid) update(c[nw].ls, l, mid, p, val);
else update(c[nw].rs, mid + 1, r, p, val);
pushup(nw);
}

int query(int nwu, int nwv, int nwlca, int l, int r, int k) {
if (l == r) return l;
int mid = (l + r) >> 1;
int cnt = c[c[nwu].ls].val + c[c[nwv].ls].val - 2 * c[c[nwlca].ls].val;
if (cnt >= k) return query(c[nwu].ls, c[nwv].ls, c[nwlca].ls, l, mid, k);
else return query(c[nwu].rs, c[nwv].rs, c[nwlca].rs, mid + 1, r, k - cnt);
}

#define M 50007

vector<pii> e[M];
int dep[M], fa[M][18], t, mx;

inline void dfs(int u, int w) {
rt[u] = rt[fa[u][0]];
if (w != 0) update(rt[u], 1, mx, w, 1);
for (int i = 1; i <= t; ++i)
fa[u][i] = fa[fa[u][i - 1]][i - 1];
for (auto cur : e[u]) {
int v = cur.fr;
if (v != fa[u][0]) {
fa[v][0] = u;
dep[v] = dep[u] + 1;
dfs(v, cur.sc);
}
}
}

inline int lca(int u, int v) {
if (dep[u] > dep[v]) swap(u, v);
for (int i = t; i >= 0; --i)
if (dep[fa[v][i]] >= dep[u]) v = fa[v][i];
if (u == v) return u;
for (int i = t; i >= 0; --i)
if (fa[u][i] != fa[v][i]) {
u = fa[u][i]; v = fa[v][i];
}
return fa[u][0];
}

inline void work() {
totnode = 0; int n = rd(); t = __lg(n - 1) + 1; mx = 0;
for (int i = 1; i <= n; ++i) {
e[i].clear(); dep[i] = 0;
for (int j = 1; j <= t; ++j) fa[i][j] = 0;
}
for (int i = 1; i < n; ++i) {
int u = rd(), v = rd(), w = rd();
mx = max(mx, w);
e[u].pb(mp(v, w)); e[v].pb(mp(u, w));
}
fa[1][0] = 1; dfs(1, 0);
for (int q = rd(); q; --q) {
int u = rd(), v = rd();
int tlca = lca(u, v), len = dep[u] + dep[v] - 2 * dep[tlca];
if (len % 2) {
printf("%.1lf\n", 1.0 * query(rt[u], rt[v], rt[tlca], 1, mx, len / 2 + 1));
} else {
printf("%.1lf\n", 1.0 * (query(rt[u], rt[v], rt[tlca], 1, mx, len / 2) + query(rt[u], rt[v], rt[tlca], 1, mx, len / 2 + 1)) / 2);
}
}
}

int main() {
for (int t = rd(); t; --t) work();
return 0;
}

例4:可持久化标记

http://acm.hdu.edu.cn/showproblem.php?pid=4348

标记永久化顾名思义,指标记一旦被打上,就不再下传或清空。而是在询问的过程中计算每个遇到结点对当前询问的影响。

因为对于主席树,每次修改操作我们在之前基础上建一棵新树,但是他们的一些儿子是共用的,如果直接用线段树的 标记 一下,之前的版本也会受到影响。而我们的想法就是只修改当前版本的标记。标记不用 ,自然也不用 (开始建树的时候,可能要

其实永久性标记相当于是在每次修改,都会进行更新操作,但也加入了一个 标记,而在区间查询时,不断递归,同时也需要更新一个权值,代表我要找的区间需要增加的标记值。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;

#define pii pair<int, int>
#define mp make_pair
#define pb push_back
#define fr first
#define sc second

inline int rd() {
int x = 0;
bool f = 0;
char c = getchar();
for (; !isdigit(c); c = getchar()) f |= (c == '-');
for (; isdigit(c); c = getchar()) x = x * 10 + (c ^ 48);
return f ? -x : x;
}

#define N 100007
int a[N];

int rt[N], totnode, totrt;

struct node{int ls, rs; ll lazy, sum;} c[N << 5];

inline int newnode() {return ++totnode;}

inline void pushup(int rt) {c[rt].sum = c[c[rt].ls].sum + c[c[rt].rs].sum;}

void build(int &nw, int l, int r) {
if (!nw) nw = newnode();
c[nw].lazy = c[nw].ls = c[nw].rs = 0;
if (l == r) {c[nw].sum = a[l]; return;}
int mid = (l + r) >> 1;
build(c[nw].ls, l, mid);
build(c[nw].rs, mid + 1, r);
pushup(nw);
}

void update(int &nw, int l, int r, int L, int R, int val) {
int t = newnode();
c[t] = c[nw]; nw = t;
c[nw].sum += 1ll * val * (min(R, r) - max(l, L) + 1);
if (L <= l && r <= R) {
c[nw].lazy += val;
return;
}
int mid = (l + r) >> 1;
if (L <= mid) update(c[nw].ls, l, mid, L, R, val);
if (mid < R) update(c[nw].rs, mid + 1, r, L, R, val);
}

ll query(int nw, int l, int r, int L, int R, ll lazy) {
if (L <= l && r <= R) return c[nw].sum + 1ll * lazy * (r - l + 1);
ll ans = 0;
int mid = (l + r) >> 1;
if (L <= mid) ans += query(c[nw].ls, l, mid, L, R, 1ll * lazy + c[nw].lazy);
if (mid < R) ans += query(c[nw].rs, mid + 1, r, L, R, 1ll * lazy + c[nw].lazy);
return ans;
}

int n, m;

inline void work() {
totrt = totnode = 0;
for (int i = 1; i <= n; ++i) a[i] = rd();
build(rt[totrt], 1, n); totrt++;
for (int i = 1; i <= m; ++i) {
char op = getchar();
for (; op != 'C' && op != 'Q' && op != 'H' && op != 'B'; op = getchar());
if (op == 'C') {
int l = rd(), r = rd(), d = rd();
rt[totrt] = rt[totrt - 1];
update(rt[totrt], 1, n, l, r, d);
totrt++;
} else if (op == 'Q') {
int l = rd(), r = rd();
printf("%lld\n", query(rt[totrt - 1], 1, n, l, r, 0));
} else if (op == 'H') {
int l = rd(), r = rd(), t = rd();
printf("%lld\n", query(rt[t], 1, n, l, r, 0));
} else {totrt = rd() + 1;}
}
return;
}

int main() {
while (cin >> n >> m) work();
return 0;
}
  • Post title:主席树
  • Post author:Eva.Q
  • Create time:2023-02-21 19:31:06
  • Post link:https://qyy/2023/02/21/Algorithm/Persistent Segment Tree/
  • Copyright Notice:All articles in this blog are licensed under BY-NC-SA unless stating additionally.