NOIP2016 Day1 T3(概率期望DP)

设$dp(i,j,0)$为前$i$个课程申请$j$次,第$j$次成功的最小体力期望,$dp(i,j,1)$为前$i$个课程申请$j$次,第$j$次不成功的最小体力期望。
转移方程具体看代码,太长了,不在这里重复打
注意double别用memset并且赋值考虑是否会溢出

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#include<queue>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db double
using namespace std;
const int MAXN = 2000 + 5, MAXV = 300 + 5;
int n, m, v, e, ci[MAXN], di[MAXN], G[MAXV][MAXV];
db ki[MAXN], dp[MAXN][MAXN][2];
void clean() {
for (int i = 1; i <= v; i++)
for (int j = 1; j <= v; j++) if (i == j) G[i][j] = 0; else G[i][j] = 1000000000.0;
}
void solve() {
clean();
for (int i = 1; i <= n; i++) scanf("%d", &ci[i]);
for (int i = 1; i <= n; i++) scanf("%d", &di[i]);
for (int i = 1; i <= n; i++) scanf("%lf", &ki[i]);
for (int x, y, w, i = 1; i <= e; i++) {
scanf("%d%d%d", &x, &y, &w);
G[x][y] = min(G[x][y], w);
G[y][x] = min(G[y][x], w);//注意邻接矩阵重边处理
}
for (int k = 1; k <= v; k++)
for (int i = 1; i <= v; i++)
for (int j = 1; j <= v; j++)
if (i != j && i != k && j != k) G[i][j] = min(G[i][j], G[i][k] + G[k][j]);
for (int i = 0; i <= n; i++) for (int j = 0; j <= m; j++) dp[i][j][0] = dp[i][j][1] = 1000000000.0;
db ans = dp[1][0][0];
dp[1][0][0] = 0, dp[1][1][0] = 0, dp[1][1][1] = 0;
for (int i = 2; i <= n; i++) {
for (int j = 0; j <= m; j++) {
dp[i][j][0] = min(dp[i][j][0], min(dp[i - 1][j][0] + (db)G[ci[i - 1]][ci[i]], dp[i - 1][j][1] + (db)G[di[i - 1]][ci[i]] * ki[i - 1] + G[ci[i - 1]][ci[i]] * (1 - ki[i - 1])));
if (j - 1 >= 0)
dp[i][j][1] = min(dp[i][j][1], min(dp[i - 1][j - 1][0] + (db)G[ci[i - 1]][ci[i]] * (1 - ki[i]) + G[ci[i - 1]][di[i]] * ki[i],dp[i - 1][j - 1][1] + (db)G[ci[i - 1]][ci[i]] * (1 - ki[i]) * (1 - ki[i - 1]) + G[di[i - 1]][ci[i]] * (1 - ki[i]) * ki[i - 1] +G[di[i - 1]][di[i]] * ki[i] * ki[i - 1] + G[ci[i - 1]][di[i]] * ki[i] * (1 - ki[i - 1])));
}
}
for (int i = 0; i <= m; i++) {
ans = min(ans, min(dp[n][i][1], dp[n][i][0]));
}
printf("%.2f\n", ans);
}
int main() {
scanf("%d%d%d%d", &n, &m, &v, &e), solve();
return 0;
}
------ 本文结束 ------