Yuhang's Blog

树的启发式合并

2021-01-18 Coding

  1. 1. 例题
    1. 1.1. CF375D

对于一个树上问题,我们递归求解时需要将子问题的答案进行合并。如果求解时,我们需要利用若干大型数据结构(包括数组、mapset等)才能获得以某节点为根的子树的解,单纯的暴力做法会导致时间或空间复杂度难以承受。此时,我们考虑利用树的启发式合并。其直观理解是:当合并两堆石子时,将小的那堆合并到大的那堆,要比反过来做更加省力。这里我们考虑节点u的重儿子hson[u]:在结束对hson[u]的求解后,求解所利用的大型数据结构中保留着所有的演算结果,我们不要清除它们,而是在它的基础上将u所有的轻儿子合并进来,由此可以获得u的求解结果。由于每次都尽可能多地继承了已有的运算结果,这样的做法可以获得较大的时间优势。以上是启发式合并的主要精神。 具体操作如下:

  1. 用一个深搜统计出所有节点的重儿子。
  2. 实现一个求解结构体,内含所有求解需要的大型数据结构,并要求含有以下三个成员函数:
    1. add(u, skip, ...)u表示当前的遍历的节点,skip是一个需要跳过的节点编号,即当u==skip时,直接终止函数。对于其他值,利用大型数据结构,递归对u子树的问题答案进行求解。
    2. get_ans(u):对节点u,返回求解后的答案或进行查询操作。
    3. del(u):对节点u,递归进行add的逆向操作,其结果是初始化所有大型数据结构。注意此时不需要跳过任何节点,因为我们需要初始化整个大型数据结构,这也包括抹去hson[u]作出的贡献。这里一般不宜对数组直接用memset,原因在于通过遍历子树来抹去数据才能保证时间复杂度;但对STL的数据结构可以使用clear()(而且更快)。
  3. 用第二个深搜dfs2(u, keep)实现启发式合并逻辑:
    1. 对所有u的轻儿子v,不加保留地用dfs2(v, 0)进行答案求解。
    2. u的重儿子hson[u](如果存在),保留演算过程地利用dfs2(hson[u], 1)进行答案求解。
    3. 调用结构体的add函数,要求skip传值为hson[u],原因在于此时遍历u的子树时不需要遍历hson[u]的子树,因为它的答案已经被保留在大型数据结构中。
    4. 记录u的答案get_ans(u)
    5. 如果keep为非,调用del(u),清除演算过程。

其中1和3步骤均是模板化完成即可,2步在每个问题中才有差别。

例题

CF375D

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#include<bits/stdc++.h>
using namespace std;
#define rep(i,from,to) for(int i=(int)(from);i<=(int)(to);++i)
#define rev(i,from,to) for(int i=(int)(from);i>=(int)(to);--i)
#define For(i,to) for(int i=0;i<(int)(to);++i)
#define see(x) (cerr<<(#x)<<'='<<(x)<<endl)
void dbg() {cout << "\n";}
template<typename T, typename... A> void dbg(T a, A... x) {cout << a << ' '; dbg(x...);}
#define logs(x...) {cout << #x << " -> "; dbg(x);}
typedef long long ll;
typedef long double ld;
inline ll read(){
ll x=0; ll sign=1; char c=getchar();
while(c>'9' c<'0') {if (c=='-') sign=-1;c=getchar();}
while(c>='0' && c<='9'){x=(x<<3)+(x<<1)+c-'0';c=getchar();}
return x*sign;
}
const int N = 212345;
int n, m;
struct Query{
int q, a;

void output(){
printf("%d\n", a);
}
};

vector<int> son[N]; int fa[N];
int hson[N], sz[N];
int c[N];
vector<Query> query;
vector<int> q[N];

void dfs1(int u){
sz[u]=1;
for(int v:son[u]) if(v!=fa[u]){
fa[v]=u;
dfs1(v);
sz[u]+=sz[v];
if(sz[v]>sz[hson[u]])hson[u]=v;
}
}

struct TreeSolve{
map<int, int> cnt;
map<int, int> cntt;

void add(int u, int skip, int val) {
if(u == skip) return;

// need some consideration here
if (val == -1) cntt[cnt[c[u]]]--;
cnt[c[u]] += val;
if (val == 1) cntt[cnt[c[u]]]++;

for(int v : son[u]) if (v!=fa[u]){
add(v, skip, val);
}
}

void get_ans(int u) {
for(int i : q[u]) {
int k = query[i].q;
query[i].a = cntt[k];
}
}

void del(int u) {
add(u, 0, -1);
}

}TS;

void dfs2(int u, bool keep) {
for(int v : son[u]) if (v != hson[u] && v != fa[u]) {
dfs2(v, 0);
}
if (hson[u]) dfs2(hson[u], 1);
TS.add(u, hson[u], 1);
TS.get_ans(u);
if (!keep) TS.del(u);
}

int main() {
n = read(), m = read();
rep(i, 1, n) c[i] = read();
rep(i, 1, n - 1) {
int x = read(), y = read();
son[x].push_back(y);
son[y].push_back(x);
}
rep(i, 1, m) {
int v = read(), k = read();
query.push_back((Query){k, 0});
q[v].push_back(query.size() - 1);
}
dfs1(1);
dfs2(1, 1);

for(auto &q : query) {
q.output();
}

return 0;
}
This article was last updated on days ago, and the information described in the article may have changed.