Hdu 2825(AC自动机+状压DP)

hdu 2825
题意:给出$m$个模式串,求至少包含$k$个模式串长为$n$的主串个数。
用模式串建立AC自动机,设$dp(i,j,S)$为主串前$i$个字符,在AC自动机上$j$点,当前存在模式串状态的方案数。
$$dp(i,v,S|val_v)=dp(i-1,j,S)$$
因为有用的只有两层,所以其中可以运用滚动数组节省空间

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<queue>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db double
using namespace std;
const int MO = 20090717, MAXN = 10 * 10 + 5;
int n, m, k, sz = 10 * 10;
char s[15];
int ch[MAXN][35], val[MAXN], f[MAXN], dp[2][MAXN][(1 << 10) + 100];
void insert(char *s, int ith) {
int now = 0, len = strlen(s);
for (int i = 0; i < len; i++) {
int c = s[i] - 'a';
if (!ch[now][c]) ch[now][c] = ++sz;
now = ch[now][c];
if (i == len - 1) val[now] += (1 << (ith - 1));
}
}
void getFail() {
queue<int> q;
f[0] = 0;
for (int c = 0; c < 26; c++) {
int v = ch[0][c];
if (v) q.push(v), f[v] = 0;
}
while (!q.empty()) {
int u = q.front(); q.pop();
for (int c = 0; c < 26; c++) {
int v = ch[u][c];
if (!v) {ch[u][c] = ch[f[u]][c]; continue;}
q.push(v);
int j = f[u]; while (j && !ch[j][c]) j = f[j];
f[v] = ch[j][c];
val[v] |= val[f[v]];//注意传递
}
}
}
void cal() {
dp[0][0][0] = 1;
int x = 1;//滚动数组当前位置
for (int i = 1; i <= n; i++) {
for (int j = 0; j <= sz; j++)
for (int S = 0; S < (1 << m); S++) dp[x][j][S] = 0;
for (int j = 0; j <= sz; j++) {
for (int S = 0; S < (1 << m); S++) {
if (!dp[x ^ 1][j][S]) continue;
for (int c = 0; c < 26; c++) {
int v = ch[j][c];
dp[x][v][S | val[v]] = (dp[x][v][S | val[v]] + dp[x ^ 1][j][S]) % MO;//方程不要写错了
}
}
}
x ^= 1;
}
}
void clean() {
for (int i = 0; i <= sz; i++) {
for (int j = 0; j < 28; j++) ch[i][j] = 0;
for (int j = 0; j < 1030; j++) dp[1][i][j] = dp[0][i][j] = 0;
val[i] = f[i] = 0;
}
sz = 0;
}
bool check(int x) {
int ret = 0, tmp = x;
do {
ret += tmp & 1;
tmp >>= 1;
} while (tmp != 0);
return ret >= k;
}
void solve() {
clean();
for (int i = 1; i <= m; i++) {
scanf("%s", s);
insert(s, i);
}
getFail(), cal();
int taki = 0;
for (int j = 0; j <= sz; j++) {
for (int S = 0; S < (1 << m); S++) {
if (check(S)) taki = (taki + dp[n % 2][j][S]) % MO;
}
}
printf("%d\n", taki);
}
int main() {
while (scanf("%d%d%d", &n, &m, &k) == 3 && (n || m || k)) solve();
return 0;
}

------ 本文结束 ------