Yuhang's Blog

Luogu P1272:树形背包

2020-12-19 Coding

树上, 状态必然有一个维度是“子树的根节点编号”, 记作$u$. 需要完全理解题意, 可构造出本题用$f(u,j)$来表示“将以$u$为根节点的子树拆成一棵大小是$j$的子树至少需要切断多少条边”为好. 注意为了让问题被等价地分割, 我们这里暂不考虑$u$和它的父亲的连边, 在最后的答案中记得将非根节点与其父亲的连边考虑在内即可:

$$
\min_{1 \le u \le n}
\begin{cases} f(u,p) &\text{if } u = 1 \
f(u,p)+1 &\text{otherwise}\end{cases}
$$

下面考虑状态转移. 首先容易得到:
$$
f(u,1) = \text{number of sons of } u
$$
对应的就是把$u$的所有儿子全部切断, 只剩下$u$自己一个节点.

虽然这里还可以有另一种初始状态, 即$f(u, s(u))=0$, 其中$s(u)$是$u$的子树大小, 但这种初始状态似不方便下面的递推.

我们下面需要做的是依次将$u$的儿子考虑进来, 并计算对我们答案的影响. 我们的初始值是所有儿子被切断的情形, 而被切断意味着, 对于每个儿子来说, 它的子问题的最优解不能对我们的当前节点产生影响. 因此我们要做的其实是依次把儿子如果和当前节点连通对答案的影响计算出来, 或者, 可以形象地把这一过程理解为将儿子逐个挂到当前节点.

假定我们已经考虑了$u$的若干儿子, 正在考虑$u$的某个儿子$v$. 我们枚举所有的分割子树大小$j$, 对于所有的$f(u, j)$, $v$的子树最优解都有可能对其产生优化. 这种优化的方式如何体现? 我们再枚举所有的$l$, 表示$v$子树对大小为$j$的(以$u$为根的)分割子树贡献了大小为$l$的(以$v$为根的)分割子树, 而剩下的$j-l$大小的(以$u$为根的)分割子树, 我们只需使用”考虑过$v$之前的所有儿子”的$u$的当前最优解即可. 因为我们决定连通$u$和$v$, 所以我们需要删去的边减去了$u$到$v$的那一条. 注意, 以$v$为根的子树至少有一个节点要进入最后生成的子树, 否则我们不能不切断连通这个儿子的边. 因为每考虑完一个儿子, 所有的$f(u, j)$都会被更新到当前最优解的状态, 所以当考虑完所有的儿子, 我们就得到要求的答案.

记$f_v(u,j)$为考虑了$v$及其之前的儿子后, $f(u,j)$的值. 下面这个式子应该理解为对于每个$v$, 同时更新所有的$j$:
$$
f_v(u,j) \leftarrow \min_{u \to v, l \ge 1} f(v,l) + f_{v-1}(u,j-l) - 1
$$
当然”同时更新”只是一种理论说法, 但注意到要更新$f_v(u,j)$, 要利用的只有第二个维度小于$j$的那些$f_{v-1}(u,j-l)$, 因此采用倒序循环即可.

这里没提背包两个字, 但已经把背包的思想说出来了. 背包的思想其实就在于”逐个考虑”: 当前最优解已经考虑了若干个子问题的最优解的情况下, 我们增加一个子问题最优解, 看这个子问题最优解会对当前最优解产生什么影响.

下面放代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#include <bits/stdc++.h>
using namespace std;
#define rep(i,from,to) for(register int i=(int)(from);i<=(int)(to);++i)
#define For(i,to) for(register int i=0;i<(int)(to);++i)
typedef long long ll;
inline ll read(){
ll x=0; ll sign=1; char c=getchar();
while(c>'9' || c<'0') {if (c=='-') sign=-1;c=getchar();}
while(c>='0' && c<='9'){x=(x<<3)+(x<<1)+c-'0';c=getchar();}
return x*sign;
}
#define N 300
int n, p;
vector<int> son[N];
int sz[N]; // size of subtree
int f[N][N];

void calsz(int u) {
sz[u] = 1;
for(int v : son[u]) {
calsz(v);
sz[u] += sz[v];
}
}

void dp(int u) {
f[u][1] = son[u].size();
for(int v : son[u]) {
dp(v);
for(int j = sz[u]; j >= 2; j--) {
for(int l = 1; l <= j - 1; l++) {
f[u][j] = min(f[u][j], f[v][l] + f[u][j-l] - 1);
}
}
}
}

int main() {
n = read(); p = read();
For(i, n - 1) {
int ff = read(), j = read();
son[ff].push_back(j);
}
memset(f, 0x3f3f3f3f, sizeof(f));
calsz(1);
dp(1);
int ans = f[1][p];
for(int i = 2; i <= n; i++) {
ans = min(ans, f[i][p] + 1);
}
cout << ans << endl;
return 0;
}
This article was last updated on days ago, and the information described in the article may have changed.