0%

线段树学习笔记

摘要

线段树是一种常用的数据结构,这是一篇关于本蒟蒻学习线段树的一些笔记,希望对大家有帮助。

引言

线段树是一种基于分治思想的二叉树数据结构,主要用于高效处理区间查询和区间更新问题。
所以在学习线段树之前,要理解分治思想,同时还要看得懂递归以及了解树这种数据结构。
关于线段树,b占董晓老师的视频讲的很详细,视频地址: https://www.bilibili.com/video/BV1G34y1L7b3?vd_source=7fb744401da62aff10b87abf48fbcbf7
这篇文章主要说我在学习线段树遇到的一些笔记。
线段树示例线段树示意图

各种操作实现

在代码开头,我们应该定义结构体,或者数组,来管理线段树的相关变量,我更喜欢结构体,所以以结构体来说明。
定义

1
2
3
4
5
6
7
8
#define lc p<<1
#define rc p<<1|1
#define N 100005
#define int long long
int w[N],n,m;
struct node{
int l,r,sum,add;
}tr[4*N];

这里结构体大小,为什么开4*N呢?
因为我们知道线段树其实是一颗平衡二叉树,假设一颗树一共有m层,第m层没有排满,m层之前都已经排满了。那么m-1
层节点个数假设为n,那么m-2层到第一层节点个数一共是n-1个,m层排满的话有2n个节点。
所以开4n空间完全足够了。
向上更新:

1
2
3
4
void pushup(ll p)//向上更新 
{
tr[p].sum=tr[lc].sum+tr[rc].sum;
}

每次区间修改完成,要回溯,往上走的时候,要pushup。
向下更新:

1
2
3
4
5
6
7
8
9
10
void pushdown(ll p)//向下更新,更新之前要处理懒标记,但是只用向下处理一层,因为每次直接递归全部处理会浪费时间
{
if(tr[p].add){//加法懒标记不为0
tr[lc].sum+=tr[p].add*(tr[lc].r-tr[lc].l+1);
tr[rc].sum+=tr[p].add*(tr[rc].r-tr[rc].l+1);
tr[lc].add+=tr[p].add;//懒标记向下传递
tr[rc].add+=tr[p].add;
tr[p].add=0;
}
}

应该在每次裂开之前调用向下更新
建树:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
   void build(int p,int l,int r)
{
tr[p]={l,r,0,0};//将我括号里面的数值赋值给结构体里面的参数,参数的顺序和我上面结构体内部定义的顺序一致
if(l==r)//左孩子数值和有孩子一致说明是叶子节点
{
tr[p].sum=w[l];//将我初始数组里面的值赋给这个节点
return ;//是叶子节点就返回
}
int m=l+r>>1;//不是就裂开
build(lc,l,m);//递归建立左子树
build(rc,m+1,r);//递归建立右子树
pushup(p);//向上更新
return ;
}

区间修改:这里是加法,即给区间内每个元素加上k。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
void updatesect(ll p, ll x, ll y,ll k)//区间修改 
{
if(x<=tr[p].l&&tr[p].r<=y)//覆盖则修改 就是这个树的区间完全被要修改的区间包含
{
tr[p].sum+=(tr[p].r-tr[p].l+1)*k;//加上区间节点个k,这样就不用向下搜索了,而是更新懒标记。下次需要向下搜索再更新
tr[p].add+=k;
return;
}
int m=tr[p].l+tr[p].r>>1;//不覆盖就裂开
pushdown(p);
if(x<=m)updatesect(lc,x,y,k);//区间没有包含完就修改没有被修改的左右子树部分
if(y>m)updatesect(rc,x,y,k);
pushup(p);//向上更新
}

点修改:

1
2
3
4
5
6
7
8
9
10
11
12
13
void updatedot(ll p, ll x,ll k)//点修改 
{
if(tr[p].l==x&&tr[p].r==x)//到达叶子节点
{
tr[p].sum+=k;
return;
}
pushdown(p);
ll m=tr[p].l+tr[p].r>>1;//不是叶子节点就裂开
if(x<=m) updatedot(lc,x,k);
if(x>m) updatedot(rc,x,k);
tr[p].sum=tr[lc].sum+tr[rc].sum;
}

区间查询:这个和区间修改有点像,可以类比思考

1
2
3
4
5
6
7
8
9
10
11
ll query(ll p, ll x,ll y)
{
if(x<=tr[p].l&&tr[p].r<=y)//覆盖则返回
return tr[p].sum;
ll m=tr[p].l+tr[p].r>>1;//不覆盖则裂开
pushdown(p);
ll sum=0;
if(x<=m) sum+=query(lc,x,y);//调用递归前要一定要pushdown
if(y>m) sum+=query(rc,x,y);
return sum;
}

练一练

看看是否能独立实现。
学习后可以做一下洛谷题:https://www.luogu.com.cn/problem/P3372
这里附上ac代码,由于这个代码是我二敲的,试了一下signed main

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
#include<bits/stdc++.h>
using namespace std;
#define lc p<<1
#define rc p<<1|1
#define N 100005
#define int long long
int w[N],n,m;
struct node{
int l,r,sum,add;
}tr[4*N];
void pushup(int p)
{
tr[p].sum=tr[lc].sum+tr[rc].sum;
return ;
}
void pushdown(int p)
{
if(tr[p].add)//向下传递
{
tr[lc].sum+=tr[p].add*(tr[lc].r-tr[lc].l+1);
tr[rc].sum+=tr[p].add*(tr[rc].r-tr[rc].l+1);
tr[lc].add+=tr[p].add;
tr[rc].add+=tr[p].add;
tr[p].add=0;
}
return;
}
void build(int p,int l,int r)
{
tr[p]={l,r,0,0};
if(l==r)
{
tr[p].sum=w[l];
return ;//是叶子节点就返回
}
int m=l+r>>1;//不是就裂开
build(lc,l,m);
build(rc,m+1,r);
pushup(p);
return ;
}

void update_sect(int p,int x,int y,int k)
{
if(tr[p].l>=x&&tr[p].r<=y)//全覆盖
{
tr[p].sum+=(tr[p].r-tr[p].l+1)*k;
tr[p].add+=k;
return ;//这里相当于是暂存懒标记,如果这些修改会用到,到时候再传下去
}
pushdown(p);
int m=tr[p].l+tr[p].r>>1;
if(x<=m)
{
update_sect(lc,x,y,k);
}
if(y>m)
{
update_sect(rc,x,y,k);
}
pushup(p);
return ;
}
int query(int p,int x,int y)
{
if(x<=tr[p].l&&y>=tr[p].r)
{
return tr[p].sum;//覆盖则返回
}
pushdown(p);
int sum=0;
int m=tr[p].l+tr[p].r>>1;
if(x<=m)
{
sum+=query(lc,x,y);
}
if(y>m)
{
sum+=query(rc,x,y);
}
return sum;
}
signed main()
{
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
cin>>n>>m;
for(int i=1;i<=n;i++)
{
cin>>w[i];
}
build(1,1,n);
while(m--)
{
int f,x,y,k;
cin>>f;
if(f==1)
{
cin>>x>>y>>k;
update_sect(1,x,y,k);
}
else
{
cin>>x>>y;
cout<<query(1,x,y)<<'\n';
}
}
return 0;
}

有兴趣的话还可以做一下区间乘法修改线段树: https://www.luogu.com.cn/problem/P3373
下面附上ac代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
#include<bits/stdc++.h>
using namespace std;
#define lc p<<1
#define rc p<<1|1//2*p+1
#define N 100005
using ll=long long;
ll n,mod,w[N],q;
struct node{
ll l,r,sum,add,mul;
}tr[4*N];
void pushup(ll p)//向上更新
{
tr[p].sum=tr[lc].sum+tr[rc].sum%mod;
}
void pushdown(ll p)//向下更新
{
// 先处理乘法,再处理加法(重要顺序!)
if(tr[p].mul!=1){//懒标记不为0
tr[lc].sum=(tr[lc].sum*tr[p].mul)%mod;
tr[rc].sum=(tr[rc].sum*tr[p].mul)%mod;
tr[lc].mul=(tr[lc].mul*tr[p].mul)%mod;
tr[rc].mul=(tr[rc].mul*tr[p].mul)%mod;
tr[lc].add=(tr[lc].add*tr[p].mul)%mod;// 加法标记也要乘
tr[rc].add=(tr[rc].add*tr[p].mul)%mod;
tr[p].mul=1;
}
if (tr[p].add) {
tr[lc].sum = (tr[lc].sum + tr[p].add * (tr[lc].r - tr[lc].l + 1)) % mod;
tr[rc].sum = (tr[rc].sum + tr[p].add * (tr[rc].r - tr[rc].l + 1)) % mod;
tr[lc].add = (tr[lc].add + tr[p].add) % mod;
tr[rc].add = (tr[rc].add + tr[p].add) % mod;
tr[p].add = 0;
}
}
void build(ll p,ll l,ll r)//建树
{
tr[p]={l,r,w[l]%mod,0,1};// 初始化mul为1
if(l==r) return ;//是叶子节点就返回
int m=l+r>>1;//不是叶子节点就裂开
build(lc,l,m);
build(rc,m+1,r);
pushup(p);
}
void updatesect(ll p, ll x, ll y,ll k)//区间修改加法
{
if(x<=tr[p].l&&tr[p].r<=y)//覆盖则修改
{
tr[p].sum+=(tr[p].r-tr[p].l+1)*k;
tr[p].sum%=mod;
tr[p].add+=k;
tr[p].add%=mod;
return;
}
int m=tr[p].l+tr[p].r>>1;//不覆盖就裂开
pushdown(p);
if(x<=m)updatesect(lc,x,y,k);
if(y>m)updatesect(rc,x,y,k);
pushup(p);
}
void updatesectmul(ll p, ll x, ll y,ll k)//区间修改乘法
{
if(x<=tr[p].l&&tr[p].r<=y)//覆盖则修改
{
tr[p].sum = (tr[p].sum * k) % mod;
tr[p].mul = (tr[p].mul * k) % mod;
tr[p].add = (tr[p].add * k) % mod; // 加法标记也要乘
return;
}
int m=tr[p].l+tr[p].r>>1;//不覆盖就裂开
pushdown(p);
if(x<=m)updatesectmul(lc,x,y,k);
if(y>m)updatesectmul(rc,x,y,k);
pushup(p);
}
void updatedot(ll p, ll x,ll k)//点修改
{
if(tr[p].l==x&&tr[p].r==x)//到达叶子节点
{
tr[p].sum+=k;
return;
}
pushdown(p);
ll m=tr[p].l+tr[p].r>>1;//不是叶子节点就裂开
if(x<=m) updatedot(lc,x,k);
if(x>m) updatedot(rc,x,k);
tr[p].sum=tr[lc].sum+tr[rc].sum;
}

ll query(ll p, ll x,ll y)
{
if(x<=tr[p].l&&tr[p].r<=y)//覆盖则返回
return tr[p].sum;
ll m=tr[p].l+tr[p].r>>1;//不覆盖则裂开
pushdown(p);
ll sum=0;
if(x<=m)
{
sum+=query(lc,x,y);
sum%=mod;
}
if(y>m)
{
sum+=query(rc,x,y);
sum%=mod;
}
return sum%mod;
}
int main()
{
ios::sync_with_stdio(false);
cin>>n>>q>>mod;

for(int i=1;i<=n;i++)
{
cin>>w[i];
}
build(1,1,n);
while(q--)
{
int flag,x,y,k;
cin>>flag;
if(flag==1)
{
cin>>x>>y>>k;
updatesectmul(1,x,y,k);
}
if(flag==2)
{
cin>>x>>y>>k;
updatesect(1,x,y,k);
}
if(flag==3)
{
cin>>x>>y;
cout<<query(1,x,y)<<'\n';

}
}
return 0;
}

请注意、每次运算后都要取模、不能最后再取、否则答案错误。

结语

先说说我的感悟吧,线段树其实是一种思想,充分利用了递归、分治,以及树的数据结构,同时得益于懒标记,使得对区间修改时间复杂度大大减小,便于对区间进行修改和查询。
同时我文章也会同步更新到csdn,欢迎大家来看哦。我的主页: https://blog.csdn.net/2503_91354377?spm=1011.2266.3001.5343