Yuhang's Blog

插头DP

2020-12-19 Coding

看了很多插头DP的教程,感觉都说的不是很清楚。但我还是搞明白了它是什么。

说例题,吃树。矩阵地图,有障碍,要求用若干条回路把地图铺满,问方案数。方阵在$12\times12$以内,状态可以用二进制数压缩。

状态压缩DP往往用一个二进制数表示一行的状态,因而状态转移是以行为单位的。但在这里,以行为单位的转移显然不可行,我们以每个小格为单位进行转移。假如我们考虑到了第$i$行第$j$列的方格,其中$0\le i < n, 0 \le j < m$,那么我们有多少种方式填充这个方格呢?不难意识到任何一个非障碍方格都只能有以下六种情况:左上、左下、右上、右下、上下、左右(表示方格内的路径连通的方格的两个边)。是不是可以随便选择这六种情形呢?显然不是,因为我们要保证在这一方格进行的选择和之前方格的选择不发生冲突,也就是说不能和它上面、左边相邻的方格的选择冲突。比如左边的方格选了“左右”,那我们就不能选“上下”,否则就把路径破坏了。从更大的视角来说,这就表明“已经考虑过的方格”会对“尚未考虑的方格”的决策产生影响。我们还能观察到,这种影响其实仅限于“已考虑的方格”中所有和“尚未考虑的方格”有公共边的那些。更进一步,所有的影响其实都是通过这些公共边产生的:如果一个已考虑的方格当时选择的路径连通了一个公共边(和这个已考虑的方格的另一无论哪个边),那么这个公共边的另一侧的(尚未考虑的)方格就必须也连通这个公共边;反之,如果一个已考虑的方格选择的路径没有连通一个公共边,那么这个公共边另一侧的(尚未考虑的)方格也不能连通这个公共边。最终,我们发现将每条公共边的一个01状态——是否被已考虑的方格当时选择的路径连通——合起来,就能完全体现已经考虑过的方格对我们尚未考虑的方格的决策的所有影响。这时,我们就可以自然地引出“轮廓线”,即一条长度是$m+1$的、分隔已考虑的方格和尚未考虑的方格的折线(或已考虑的方格和未考虑的方格的公共边的集合),而其中每一条公共边的那个01状态就是所谓插头(往往称某公共边“有无插头”)。总而言之,将轮廓线每一段公共边的是否存在插头的状态总体用一个$m+1$位二进制数表示,就可以确定在每个方格进行决策时受到的限制。(“轮廓线”和“插头”的称呼都是形象化的,但对于不理解它们的人来说,直接使用反而会制造困惑——不仅仅是对于它们的所指的困惑,也包括对引入它们的原因的困惑。)

确定了状态参量,说状态转移。记当前轮廓线的状态为st,那么我们只需取st的第$j$位即可表示当前考虑方格左侧的公共边的状态,取第$j+1$位即可表示其上侧公共边的状态。在进行转移之后,st的第$j$位会变成进行选择后当前方格下侧公共边的状态,第$j+1$位则表示其右侧公共边的状态。几种情形:

  1. 当前方格已超过最后一行的行末,那么整个地图都考虑完了,所以整个轮廓线不能有任何插头。满足则记录1种方案。
  2. 当前方格已超过行末(一行刚刚被考虑完),那么当前方格左侧必须没有插头,然后转移到下一行第0列,并在st的第0位添一个0(左移一位)。这相当于轮廓线的那条竖线从当前行最右侧变成下一行最左侧。
  3. 当前考虑方格是障碍物。那么左侧公共边和上侧公共边必须同时没有插头,且只能转移到右侧和下侧公共边也没有插头的状态。
  4. 当前考虑方格不是障碍物,那么分左侧公共边和上侧公共边分别有无插头的四种情况进行讨论,然后转移。如下:
    1. 左有上有:转移到下无右无。
    2. 左无上无:转移到下有右有。
    3. 左有上无,或左无上有:转移到下有右无,或下无右有。

然后就做完了。下面代码是记忆化搜索,因为我思维比较习惯这种方式。实现的时候懒得去想太多二进制运算,就直接写了两个函数,用来获取和更改某一整数的某一二进制位。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#include <bits/stdc++.h>
using namespace std;
#define rep(i,from,to) for(register int i=(int)(from);i<=(int)(to);++i)
#define For(i,to) for(register int i=0;i<(int)(to);++i)
typedef long long ll;
inline ll read(){
ll x=0; ll sign=1; char c=getchar();
while(c>'9' || c<'0') {if (c=='-') sign=-1;c=getchar();}
while(c>='0' && c<='9'){x=(x<<3)+(x<<1)+c-'0';c=getchar();}
return x*sign;
}

#define N 14
int n, m;
int game[N];

inline bool gb(int s, int j) {
// get bit
return ((s >> j) & 1);
}
int sb(int s, int j, bool b) {
// set bit
return b ? (s | (1 << j)) : (s & (~(1 << j)));
}

int sbs(int s, int j1, bool b1, int j2, bool b2) {
// set two bits
return sb(sb(s, j1, b1), j2, b2);
}

void output_bin(int s){
// for debug
if (s < 0) return;
if (s == 1 || s == 0) cout << s;
else {
output_bin(s >> 1);
output_bin(s & 1);
}
}

ll f[N][N][1<<N];
ll dp(int i, int j, int st) {
// we are considering i-th row, j-th column, and the state is now st.
// we want to output count of possible solutions
// i in [0, n-1], j in [0, m-1].

// printf("i=%d, j=%d, st=", i, j); output_bin(st); cout << endl;
// assert(j >= 0 && j < m);
// assert(st >= 0 && (st < (1 << (m+1))));

ll &mem = f[i][j][st];
if (mem > -1) return mem;

int stl = gb(st, j);
int stu = gb(st, j + 1);

if (j == m) {
if (i == n - 1) return st == 0; // end of game
// end of line
else {
if (stl) return 0;
else return dp(i + 1, 0, st << 1);
}
}

// obstacle
if (!gb(game[i], j)) {
if (stl || stu) return mem = 0;
else return mem = dp(i, j + 1, st);
}
// not obstacle
else {
if (stl && stu) {
return mem = dp(i, j + 1, sbs(st, j, 0, j + 1, 0));
}
else if (!stl && !stu) {
return mem = dp(i, j + 1, sbs(st, j, 1, j + 1, 1));
}
else {
mem = 0;
mem += dp(i, j + 1, sbs(st, j, 0, j + 1, 1));
mem += dp(i, j + 1, sbs(st, j, 1, j + 1, 0));
return mem;
}
}

}

int main() {
#ifdef D
freopen("5074.in", "r", stdin);
double TIMEA = clock();
#endif
int T = read();
while (T--) {
n = read(); m = read();
memset(f, -1, sizeof(f));
memset(game, 0, sizeof(game));
For(i, n) {
For(j, m) {
int k = read();
game[i] <<= 1; game[i] |= k;
}
}
cout << dp(0, 0, 0) << endl;
}

#ifdef D
double TIMEB=clock();
printf("\n# Time consumed: %.3fs.\n", (TIMEB-TIMEA)/1000);
#endif
return 0;
}
This article was last updated on days ago, and the information described in the article may have changed.