Skip to content

线段树学习笔记

笔记约 3.4 千字
检测到 KaTeX 加载失败,可能会导致文中的数学公式无法正常渲染。

线段树(Segment Tree)是一种用来维护区间的数据结构。

与树状数组相比,线段树可以实现时间复杂度在 O(logn)O(\log n) 级别的区间修改,还可以同时支持多种操作(加、乘、最值等)。

#操作列表

  • 上传(pushup)
  • 建树(build)
  • 下放懒标记(pushdown)
  • 区间查询(query)
  • 区间修改(modify)

#通用操作

#存储线段树

线段树是一个典型的二叉树,因此我们可以使用一个数组来存储线段树。

分析:很容易就知道线段树的深度为 logn\lceil\log n\rceil ,可得线段树的节点个数为 2logn+112^{\left\lceil\log{n}\right\rceil+1}-1,粗略估计开大小为 4n4n 的数组即可(可以使用位运算写成 n << 2)。

C++
1
2
3
4
5
6
7
8
9
10
11
12
13
struct node {
int l, r;
long long s, d;

node() {
l = r = s = d = 0;
}
node(int _l, int _r) {
l = _l;
r = _r;
s = d = 0;
}
} tr[100005 << 2];
变量名 用途
l 区间的左端点
r 区间的右端点
s 区间和
d 懒标记

#上传(pushup)

之所以把上传放在建树前面说,是因为建树的时候要用到它。

C++
1
2
3
4
5
6
7
/**
* 上传信息
* @param u 父节点下标
*/
inline void pushup(int u) {
tr[u].s = tr[u << 1].s + tr[u << 1 | 1].s;
}

将两个子节点所代表的区间的和相加即为父区间的和。

#建树(build)

C++
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
/**
* 建立线段树
* @param u 根节点下标
* @param l 左端点
* @param r 右端点
*/
void build(int u, int l, int r) {
tr[u] = node(l, r);
if (l == r) {
tr[u].s = a[l];
return;
}
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
pushup(u);
}

先初始化当前区间,接下来分两种情况:

  1. 若当前区间长度等于 1 (l=r)1\ \ (l = r) ,则直接将当前区间的区间和赋值为 a[l] 即可。
  2. 若当前区间长度大于 1 (l<r)1\ \ (l < r) ,则将区间平均分成两部分(即从 (l+r)/2\lfloor(l+r)/2\rfloor 处断开分为两个区间,可写作 l + r >> 1),继续向下递归建立左右子树即可。

需要注意的是两个子区间没有交集,因此左子树的左端点是 ll 、右端点是 midmid ,右子树的左端点是 mid+1mid+1 、右端点是 rr

#区间查询(query)

C++
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
/**
* 区间查询
* @param u 父节点
* @param l 左端点
* @param r 右端点
*/
long long query(int u, int l, int r) {
if (tr[u].l >= l && tr[u].r <= r) { // 被包含直接返回当前区间和
return tr[u].s;
}
int mid = tr[u].l + tr[u].r >> 1;
long long s = 0;
pushdown(u); // 下放懒标记
if (l <= mid) s += query(u << 1, l, r); // 和左侧有交集
if (r > mid) s += query(u << 1 | 1, l, r); // 和右侧有交集
return s;
}
  1. 如果这个区间被包含,直接返回该区间的和。
  2. 如果和左儿子区间有交集,则继续向左儿子区间递归查询。
  3. 如果和右儿子区间有交集,则继续向右儿子区间递归查询。

需要注意的是在递归查询左右儿子区间之前要先下放懒标记(pushdown),否则会出问题。

#区间加

本部分以 洛谷 P3372 【模板】线段树 1 为例子来简述一下线段树区间加的实现。

#下放懒标记(pushdown)

C++
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
/**
* 下放懒标记
* @param u 父节点下标
*/
inline void pushdown(int u) {
if (!tr[u].d) return;
// 处理左子树
tr[u << 1].d += tr[u].d;
tr[u << 1].s += (tr[u << 1].r - tr[u << 1].l + 1) * tr[u].d;
// 处理右子树
tr[u << 1 | 1].d += tr[u].d;
tr[u << 1 | 1].s += (tr[u << 1 | 1].r - tr[u << 1 | 1].l + 1) * tr[u].d;
// 清除懒标记
tr[u].d = 0;
}

这部分代码其实很简单。

将左、右子树的懒标记加上父节点的懒标记,区间和加上 (rl+1)×d(r - l + 1)\times dr,lr, l 分别表示儿子区间的左、右端点,dd表示父节点的懒标记),最后清空父节点的懒标记即可。

#区间修改(modify)

C++
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
/**
* 区间修改
* @param u 父节点下标
* @param l 左端点
* @param r 右端点
* @param d 增加的值
*/
void modify(int u, int l, int r, int d) {
if (tr[u].l >= l && tr[u].r <= r) { // 被包含直接修改
tr[u].d += d;
tr[u].s += (tr[u].r - tr[u].l + 1) * d;
return;
}
int mid = tr[u].l + tr[u].r >> 1;
pushdown(u); // 下放懒标记
if (l <= mid) modify(u << 1, l, r, d); // 和左侧有交集
if (r > mid) modify(u << 1 | 1, l, r, d); // 和右侧有交集
pushup(u); // 上传新信息
}

区间修改和区间查询的实现相似。

  1. 如果当前区间被包含,直接添加懒标记并修改区间和。
  2. 如果和左儿子区间有交集,则继续向左儿子区间递归修改。
  3. 如果和右儿子区间有交集,则继续向右儿子区间递归修改。

需要注意的是在递归修改左右儿子区间之前要先下放懒标记(pushdown),修改完成以后要上传新信息(pushup),否则会出问题。

#区间加、乘

本部分以 洛谷 P3373 【模板】线段树 2 为例子来简述一下线段树区间加、乘的实现。

在编写之前,结构体中需要先添加一个乘法的懒标记 x ,并将其赋初值为 11 ,修改之后的结构体如下所示。

C++
1
2
3
4
5
6
7
8
9
10
11
12
13
14
struct node {
int l, r;
long long s, d, x;

node() {
l = r = s = d = 0;
x = 1;
}
node(int _l, int _r) {
l = _l, r = _r;
s = d = 0;
x = 1;
}
} tr[100005 << 2];

#下放懒标记(pushdown)

C++
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
/**
* 下放懒标记
* @param u 父节点下标
* @attention 先乘后加
*/
void pushdown(int u) {
// 左子树
tr[u << 1].s = ((tr[u << 1].s * tr[u].x) + (tr[u << 1].r - tr[u << 1].l + 1) * tr[u].d) % p;
tr[u << 1].x = tr[u << 1].x * tr[u].x % p;
tr[u << 1].d = (tr[u << 1].d * tr[u].x + tr[u].d) % p;
// 右子树
tr[u << 1 | 1].s = ((tr[u << 1 | 1].s * tr[u].x) + (tr[u << 1 | 1].r - tr[u << 1 | 1].l + 1) * tr[u].d) % p;
tr[u << 1 | 1].x = tr[u << 1 | 1].x * tr[u].x % p;
tr[u << 1 | 1].d = (tr[u << 1 | 1].d * tr[u].x + tr[u].d) % p;
// 清除懒标记
tr[u].d = 0;
tr[u].x = 1;
}

此处遵循先乘后加的原则,先修改区间和,再修改乘法懒标记,最后修改加法懒标记,不要忘记 mod p\bmod\ p

注意:此处清除懒标记的时候,乘法懒标记应修改为 11

#区间修改(modify)

C++
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
/**
* 区间修改
* @details 修改区间 [l, r] 中的每一个数
* @param u 父节点下标
* @param l 左端点
* @param r 右端点
* @param x 乘上的数
* @param d 增加的值
*/
void modify(int u, int l, int r, long long x, long long d) {
// 被包含直接修改
if (tr[u].l >= l && tr[u].r <= r) {
tr[u].s = ((tr[u].s * x) + (tr[u].r - tr[u].l + 1) * d) % p;
tr[u].x = tr[u].x * x % p;
tr[u].d = (tr[u].d * x + d) % p;
return;
}
int mid = tr[u].l + tr[u].r >> 1;
pushdown(u); // 下放懒标记
if (l <= mid) modify(u << 1, l, r, x, d); // 和左侧有交集
if (r > mid) modify(u << 1 | 1, l, r, x, d); // 和右侧有交集
pushup(u); // 上传新信息
}

大体上和加法的修改函数一样,而在修改时与下放懒标记做法相同,遵循先乘后加的原则。

调用的时候若只需要使用乘法部分,加数设置为 00 即可。若只需要使用加法部分,乘数设置为 11 即可。

#全部代码

到这里基本操作就说完了,下面是全部的 AC 代码。

#区间加

C++
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#include <bits/stdc++.h>

using namespace std;

/**
* 线段树节点
*/
struct node {
int l, r;
long long s, d;

node() {
l = r = s = d = 0;
}
node(int _l, int _r) {
l = _l;
r = _r;
s = d = 0;
}
} tr[100005 << 2];
int n, m, op, x, y, k, a[100005];

/**
* 上传区间和
* @param u 父节点下标
*/
void pushup(int u) {
tr[u].s = tr[u << 1].s + tr[u << 1 | 1].s;
}

/**
* 下放懒标记
* @param u 父节点下标
*/
void pushdown(int u) {
if (!tr[u].d) return;
// 处理左子树
tr[u << 1].d += tr[u].d;
tr[u << 1].s += (tr[u << 1].r - tr[u << 1].l + 1) * tr[u].d;
// 处理右子树
tr[u << 1 | 1].d += tr[u].d;
tr[u << 1 | 1].s += (tr[u << 1 | 1].r - tr[u << 1 | 1].l + 1) * tr[u].d;
// 清除懒标记
tr[u].d = 0;
}

/**
* 建立线段树
* @param u 根节点下标
* @param l 左端点
* @param r 右端点
*/
void build(int u, int l, int r) {
tr[u] = node(l, r);
if (l == r) {
tr[u].s = a[l];
return;
}
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
pushup(u);
}

/**
* 区间修改
* @param u 父节点下标
* @param l 左端点
* @param r 右端点
* @param d 增加的值
*/
void modify(int u, int l, int r, int d) {
if (tr[u].l >= l && tr[u].r <= r) { // 被包含直接修改
tr[u].d += d;
tr[u].s += (tr[u].r - tr[u].l + 1) * d;
return;
}
int mid = tr[u].l + tr[u].r >> 1;
pushdown(u); // 下放懒标记
if (l <= mid) modify(u << 1, l, r, d); // 和左侧有交集
if (r > mid) modify(u << 1 | 1, l, r, d); // 和右侧有交集
pushup(u); // 上传新信息
}

/**
* 区间查询
* @param u 父节点
* @param l 左端点
* @param r 右端点
*/
long long query(int u, int l, int r) {
if (tr[u].l >= l && tr[u].r <= r) { // 被包含直接返回
return tr[u].s;
}
int mid = tr[u].l + tr[u].r >> 1;
long long s = 0;
pushdown(u); // 下放懒标记
if (l <= mid) s += query(u << 1, l, r); // 和左侧有交集
if (r > mid) s += query(u << 1 | 1, l, r); // 和右侧有交集
return s;
}

int main() {
cin >> n >> m;
for (int i = 1; i <= n; i++) {
cin >> a[i];
}
build(1, 1, n);
for (int i = 0; i < m; i++) {
cin >> op >> x >> y;
if (op == 1) {
cin >> k;
modify(1, x, y, k);
}
else if (op == 2) {
cout << query(1, x, y) << endl;
}
}
return 0;
}

#区间加、乘

C++
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#include <bits/stdc++.h>

using namespace std;

struct node {
int l, r;
long long s, d, x;

node() {
l = r = s = d = 0;
x = 1;
}
node(int _l, int _r) {
l = _l, r = _r;
s = d = 0;
x = 1;
}
} tr[100005 << 2];
int n, m, p, op, x, y;
long long k, a[100005];

/**
* 上传信息
* @param u 父节点下标
*/
void pushup(int u) {
tr[u].s = (tr[u << 1].s + tr[u << 1 | 1].s) % p;
}

/**
* 下放懒标记
* @param u 父节点下标
* @attention 先乘后加
*/
void pushdown(int u) {
// 左子树
tr[u << 1].s = ((tr[u << 1].s * tr[u].x) + (tr[u << 1].r - tr[u << 1].l + 1) * tr[u].d) % p;
tr[u << 1].x = tr[u << 1].x * tr[u].x % p;
tr[u << 1].d = (tr[u << 1].d * tr[u].x + tr[u].d) % p;
// 右子树
tr[u << 1 | 1].s = ((tr[u << 1 | 1].s * tr[u].x) + (tr[u << 1 | 1].r - tr[u << 1 | 1].l + 1) * tr[u].d) % p;
tr[u << 1 | 1].x = tr[u << 1 | 1].x * tr[u].x % p;
tr[u << 1 | 1].d = (tr[u << 1 | 1].d * tr[u].x + tr[u].d) % p;
// 清除懒标记
tr[u].d = 0;
tr[u].x = 1;
}

/**
* 建立线段树
* @param u 根节点下标
* @param l 左端点
* @param r 右端点
*/
void build(int u, int l, int r) {
tr[u] = node(l, r);
if (l == r) {
tr[u].s = a[l] % p;
return;
}
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
pushup(u);
}

/**
* 区间修改
* @details 将区间 [l, r] 中的每一个数加上 d
* @param u 父节点下标
* @param l 左端点
* @param r 右端点
* @param x 乘上的数
* @param d 增加的值
*/
void modify(int u, int l, int r, long long x, long long d) {
// 被包含直接修改
if (tr[u].l >= l && tr[u].r <= r) {
tr[u].s = ((tr[u].s * x) + (tr[u].r - tr[u].l + 1) * d) % p;
tr[u].x = tr[u].x * x % p;
tr[u].d = (tr[u].d * x + d) % p;
return;
}
int mid = tr[u].l + tr[u].r >> 1;
pushdown(u); // 下放懒标记
if (l <= mid) modify(u << 1, l, r, x, d); // 和左侧有交集
if (r > mid) modify(u << 1 | 1, l, r, x, d); // 和右侧有交集
pushup(u); // 上传新信息
}

/**
* 区间查询
* @param u
* @param l
* @param r
* @return int
*/
long long query(int u, int l, int r) {
if (tr[u].l >= l && tr[u].r <= r) { // 被包含直接返回
return tr[u].s;
}
int mid = tr[u].l + tr[u].r >> 1;
long long s = 0;
pushdown(u); // 下放懒标记
if (l <= mid) s = query(u << 1, l, r); // 和左侧有交集
if (r > mid) s = (s + query(u << 1 | 1, l, r)) % p; // 和右侧有交集
return s;
}

int main() {
cin >> n >> m >> p;
for (int i = 1; i <= n; i++) {
cin >> a[i];
}
build(1, 1, n);
while (m--) {
cin >> op >> x >> y;
if (op == 1) {
cin >> k;
modify(1, x, y, k, 0);
}
else if (op == 2) {
cin >> k;
modify(1, x, y, 1, k);
}
else if (op == 3) {
cout << query(1, x, y) % p << endl;
}
}
return 0;
}