[HEOI2013]SAO题解

[HEOI2013]SAO

标签

拓扑排序 + 计数 + dp + 树形dp

思路

题意是给一个树状有向图,求有多少种拓扑序。

\starf[i][j]f[i][j] 表示在 ii 的子树中,ii 的拓扑序为 jj 的方案数。

那么考虑转移,合并 uuuu 的子树 vv。分两种情况。

  1. uu 指向 vv

f[u][p3]=p1=1sizxp2=1sizyf[u][p1]f[v][p2](p31p11)(sizu+sizvp3sizup1)[p1p3p2+p11]f[u][p3] = \sum_{p1 = 1}^{siz_x} \sum_{p2=1}^{siz_y} f[u][p1] * f[v][p2] * \tbinom{p3-1}{p1-1} * \tbinom{siz_u+siz_v-p3}{siz_u-p1} [p1 \leq p3 \leq p2 + p1 - 1]

p1p3p1 \leq p3 因为原来拓扑序在 p1p1 左侧的点,合并后拓扑序一定还在 p1p1 左侧, 因为原来拓扑序在 p1p1 右侧的点,合并后拓扑序一定还在 p1p1 右侧

p3p2+p11p3 \leq p2 + p1 - 1 , 因为合并后 vv 的拓扑序一定比 uu 大,所以 [p2,sizv][p2,siz_v] 一定在 uu 右侧, [1,p2][1,p2] 既可能在 uu 左,又可能在 uu 右,所以 p3max=p2+p11p3_{max} = p2 + p1 - 1

合并后,如果 uu 的拓扑序是 p3p3 , 那么左边的 p31p3 - 1 个点中,有 p11p1 - 1 个是原来的,右边的 sizu+sizvp3siz_u+siz_v - p3 个点中,有 sizup1siz_u - p1 个点是原来的,所以要乘上 (p31p11)(sizu+sizvp3sizup1)\tbinom{p3-1}{p1-1} * \tbinom{siz_u+siz_v-p3}{siz_u-p1}

但是这样转移是 O(n3)O(n^3) 的,考虑优化。

p3p2+p11p2p3p1+1f[u][p3]=p1=1sizxp2=1sizyf[u][p1]f[v][p2](p31p11)(sizu+sizvp3sizup1)[p1p3p2+p11]f[u][p3]=p1=1sizxp2=1sizyf[u][p1]f[v][p2](p31p11)(sizu+sizvp3sizup1)[p1p3p2p3p1+1]f[u][p3]=p1=1p3p2=1p3p1+1f[u][p1]f[v][p2](p31p11)(sizu+sizvp3sizup1)f[u][p3]=p1=1p3f[u][p1](p31p11)(sizu+sizvp3sizup1)(p2=1p3p1+1f[v][p2])\because p3 \leq p2 + p1 - 1\\ \therefore p2 \geq p3 - p1 + 1\\ 又 \because f[u][p3] = \sum_{p1 = 1}^{siz_x} \sum_{p2=1}^{siz_y} f[u][p1] * f[v][p2] * \tbinom{p3-1}{p1-1} * \tbinom{siz_u+siz_v-p3}{siz_u-p1} [p1 \leq p3 \leq p2 + p1 - 1]\\ \therefore f[u][p3] = \sum_{p1 = 1}^{siz_x} \sum_{p2=1}^{siz_y} f[u][p1] * f[v][p2] * \tbinom{p3-1}{p1-1} * \tbinom{siz_u+siz_v-p3}{siz_u-p1} [p1 \leq p3 \land p2 \geq p3 - p1 + 1]\\ \therefore f[u][p3] = \sum_{p1 = 1}^{p3} \sum_{p2=1}^{p3 - p1 + 1} f[u][p1] * f[v][p2] * \tbinom{p3-1}{p1-1} * \tbinom{siz_u+siz_v-p3}{siz_u-p1}\\ \therefore f[u][p3] = \sum_{p1 = 1}^{p3} f[u][p1] * \tbinom{p3-1}{p1-1} * \tbinom{siz_u+siz_v-p3}{siz_u-p1} * (\sum_{p2=1}^{p3 - p1 + 1} f[v][p2])\\

p2=1p3p1+1f[v][p2]\sum_{p2=1}^{p3 - p1 + 1} f[v][p2] 可以用前缀和优化.

这样复杂度就是 O(n2)O(n^2).

  1. vv 指向 uu

    和 1 差不多。

最后 ans=i=1nf[1][i]ans = \sum_{i=1}^{n}f[1][i]

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
#include <iostream>
#include <cstdio>
#include <cstring>
#define ll long long

using namespace std;

const int N = 1010;
const ll mod = 1e9 + 7;
int T;
int n;
int head[N], nxt[2 * N], to[2 * N], op[2 * N], e_tot;
int siz[N];
ll C[N][N];
ll f[N][N], g[N], s[N][N];

void link(int x, int y, int z)
{
nxt[++e_tot] = head[x], head[x] = e_tot, to[e_tot] = y, op[e_tot] = z;
}

ll add(ll x, ll y)
{
return x + y >= mod ? x + y - mod : x + y;
}

ll suf(ll x, ll y)
{
return x - y < 0 ? x - y + mod : x - y;
}

void dfs(int u, int _fa)
{
siz[u] = 1, f[u][1] = 1;
for (int i = head[u]; i; i = nxt[i])
{
int v = to[i];
if (v == _fa) continue;
dfs(v, u);
memcpy(g, f[u], sizeof(f[u]));
memset(f[u], 0, sizeof(f[u]));
if (op[i] == 1)
{
for (int p1 = 1; p1 <= siz[u]; ++p1)
{
for (int p3 = p1; p3 < p1 + siz[v]; ++p3)
{
f[u][p3] = add(f[u][p3], C[p3 - 1][p1 - 1] * C[siz[u] + siz[v] - p3][siz[u] - p1] % mod * g[p1] % mod * suf(s[v][siz[v]], s[v][p3 - p1]) % mod);
}
}
}
else
{
for (int p1 = 1; p1 <= siz[u]; ++p1)
{
for (int p3 = p1 + 1; p3 <= p1 + siz[v]; ++p3)
{
f[u][p3] = add(f[u][p3], C[p3 - 1][p1 - 1] * C[siz[u] + siz[v] - p3][siz[u] - p1] % mod * g[p1] % mod * s[v][p3 - p1] % mod);
}
}
}
siz[u] += siz[v];
}
for (int i = 1; i <= siz[u]; ++i)
{
s[u][i] = add(s[u][i - 1], f[u][i]);
}
}

int main()
{
scanf("%d", &T);
for (int i = 0; i <= 1005; ++i) C[i][0] = 1;
for (int i = 1; i <= 1005; ++i)
{
for (int j = 1; j <= i; ++j)
{
C[i][j] = add(C[i - 1][j - 1], C[i - 1][j]);
}
}
while (T--)
{
scanf("%d", &n);
memset(f, 0, sizeof(f));
memset(head, 0, sizeof(head));
memset(nxt, 0, sizeof(nxt));
memset(to, 0, sizeof(to));
memset(op, 0, sizeof(op));
e_tot = 0;
for (int i = 1; i < n; ++i)
{
int x, y;
char c;
scanf("%d %c %d", &x, &c, &y);
++x, ++y;
link(x, y, c == '<');
link(y, x, c == '>');
}
dfs(1, 0);
cout << s[1][n] << endl;
}
return 0;
}