[组合 + FFT] #295. 「湖南省队集训2018 Day3」admirable

不算太难。每条边要么被覆盖小于等于 \(1\) 次,要么被覆盖 \(K\) 次。

容易想到对于每条 \(K\) 覆盖的链求方案数。考虑 \(K\) 条路径从链的端点处延伸出去。

需要满足这些路径在 \(i\) 的子树中没有被覆盖超过 \(1\) 次的边。令 \(f_i\) 表示这个方案数。

肯定是一个儿子子树最多去一个,有些路径留在 \(i\)

容易得到: \[ H_i=\prod_{j \in son_i}(1+xsize_j) \\ f_i=\sum_{j=0}^K A(K,j)H_i[j] \] \(H_i\) 长度是 \(i\) 的度数,我们直接哈夫曼树合并那样搞,复杂度是两个 \(\log\) 。这样就求出 \(f\) 了。

\(K\) 覆盖路径可能是返祖路径,这样还需要求一个 \(g_i\) ,表示 \(fa_i\) 是路径上端,延伸的方案数。

同理 考虑求 \(i\) 的某个儿子 \(j\)\[ G=H_i\frac{1+(n-size_i)x}{1+size_jx} \\ f_j=\sum_{k=0}^K A(K,k)G_i[k] \]

直接算复杂度是萎的,每个点需要度数平方的时间。

怎么办呢,你发现 \(size\) 相同儿子的 \(g\) 是一样的,而 \(size\) 不同的点的个数是 \(\sqrt n\)

一条路径的贡献是 \(f_u f_v\)\(f_ug_v\) , 很容易一趟 \(dfs\) 记一些 \(sum\),统计总答案。

然后就做完了,复杂度 \(O(n\log^2 n + n\sqrt n)\)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#include<cstdio>
#include<cmath>
#include<cctype>
#include<cstring>
#include<algorithm>
using namespace std;
inline char gc(){
static char buf[100000],*p1=buf,*p2=buf;
return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++;
}
inline int getint(){
char ch=gc(); int res=0,ff=1;
while(!isdigit(ch)) ch=='-'?ff=-1:0, ch=gc();
while(isdigit(ch)) res=(res<<3)+(res<<1)+ch-'0', ch=gc();
return res*ff;
}
typedef long long LL;
typedef long double Ld;
const Ld PI=acos(-1);
#define eps (0.5)
const int maxn=500005,maxe=200005,P=1e9+9;
struct E{
Ld r,i;
E(Ld t1=0,Ld t2=0){ r=t1; i=t2; }
} W[2][maxn];
E operator + (const E &A,const E &B){ return E(A.r+B.r,A.i+B.i); }
E operator - (const E &A,const E &B){ return E(A.r-B.r,A.i-B.i); }
E operator * (const E &A,const E &B){ return E(A.r*B.r-A.i*B.i,A.r*B.i+A.i*B.r); }
int Wnum;
void PreW(int n){
Wnum=n; W[1][0]=W[0][0]=1;
for(int i=1;i<=Wnum-1;i++) W[1][i]=E(cos(PI*2*i/Wnum),sin(PI*2*i/Wnum));
for(int i=1;i<=Wnum-1;i++) W[0][i]=W[1][Wnum-i];
}
int rev[maxn];
void get_rev(int n){
int t=0; while((1<<t)<n) t++;
for(int i=1;i<=n-1;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<t-1);
}
void FFT(E a[],int n,int k){
for(int i=0;i<=n-1;i++) if(i<rev[i]) swap(a[i],a[rev[i]]);
for(int m=2;m<=n;m<<=1)
for(int i=0;i<=n-1;i+=m)
for(int j=0;j<=(m>>1)-1;j++){
E t0=a[i+j], t1=a[i+j+(m>>1)]*W[k][Wnum/m*j];
a[i+j]=t0+t1; a[i+j+(m>>1)]=t0-t1;
}
if(k==0) for(int i=0;i<=n-1;i++) a[i].r/=n;
}
int n,K,sz[maxn],f[maxn],g[maxn],fsum[maxn];
int fir[maxn],nxt[maxe],son[maxe],tot;
void add(int x,int y){
son[++tot]=y; nxt[tot]=fir[x]; fir[x]=tot;
}
int v[maxn];
void Solve(int L,int R,int *a){
if(L==R) return a[0]=1, a[1]=sz[v[L]], void();
int mid=(L+R)>>1;
Solve(L,mid,a); Solve(mid+1,R,a+(mid-L+2));
int N=1; while(N<R-L+2) N<<=1; get_rev(N);
static E A[maxn],B[maxn],C[maxn],D[maxn],t1[maxn],t2[maxn],t3[maxn];
for(int i=0;i<=mid-L+1;i++) A[i]=a[i]>>16, B[i]=a[i]&65535;
for(int i=mid-L+2;i<=N-1;i++) A[i]=B[i]=0;
for(int i=0;i<=R-mid;i++) C[i]=a[mid-L+2+i]>>16, D[i]=a[mid-L+2+i]&65535;
for(int i=R-mid+1;i<=N-1;i++) C[i]=D[i]=0;
FFT(A,N,1); FFT(B,N,1); FFT(C,N,1); FFT(D,N,1);
for(int i=0;i<=N-1;i++)
t1[i]=A[i]*C[i], t2[i]=A[i]*D[i]+B[i]*C[i], t3[i]=B[i]*D[i];
FFT(t1,N,0); FFT(t2,N,0); FFT(t3,N,0);
for(int i=0;i<=N-1;i++){
a[i]=((LL)(t1[i].r+eps)%P*294967260%P+(LL)(t2[i].r+eps)%P*65536%P+(LL)(t3[i].r+eps)%P)%P;
}
}
LL fac[maxn],inv[maxn],fac_inv[maxn];
LL AR(int n,int m){
return fac[n]*fac_inv[n-m]%P;
}
bool my_cmp(int A,int B){ return sz[A]<sz[B]; }
int H[maxn];
void dfs(int x,int pre){
sz[x]=1;
for(int j=fir[x];j;j=nxt[j])
if(son[j]!=pre) dfs(son[j],x), sz[x]+=sz[son[j]], (fsum[x]+=fsum[son[j]])%=P;
v[0]=0; for(int j=fir[x];j;j=nxt[j]) if(son[j]!=pre) v[++v[0]]=son[j];
if(!v[0]) return f[x]=fsum[x]=1, void();
sort(v+1,v+1+v[0],my_cmp);
Solve(1,v[0],H);
for(int i=0;i<=min(K,v[0]);i++) (f[x]+=AR(K,i)*H[i]%P)%=P;
H[v[0]+1]=0; for(int i=v[0]+1;i>=1;i--) H[i]=(H[i]+(LL)H[i-1]*(n-sz[x])%P)%P;
static int tmp[maxn];
for(int i=1;i<=v[0];i++)
if(i==1||sz[v[i]]!=sz[v[i-1]]){
for(int j=0;j<=v[0]+1;j++) tmp[j]=H[j];
for(int j=0;j<=v[0];j++) tmp[j+1]=(tmp[j+1]-(LL)tmp[j]*sz[v[i]]%P)%P;
for(int j=0;j<=min(K,v[0]);j++) (g[v[i]]+=AR(K,j)*tmp[j]%P)%=P;
} else g[v[i]]=g[v[i-1]];
(fsum[x]+=f[x])%=P;
}
int ans=0;
void dfs_calc(int x,int pre,int gs,int fs){
(gs+=g[x])%=P;
(ans+=(LL)f[x]*(((LL)gs*2+fsum[1]-fs-fsum[x])%P)%P)%=P;
(fs+=f[x])%=P;
for(int j=fir[x];j;j=nxt[j])
if(son[j]!=pre) dfs_calc(son[j],x,gs,fs);
(gs-=g[x])%=P; (fs-=f[x])%=P;
}
int main(){
n=getint(); K=getint();
fac[0]=1; for(int i=1;i<=K+5;i++) fac[i]=fac[i-1]*i%P;
inv[1]=1; for(int i=2;i<=K+5;i++) inv[i]=(LL)(P-P/i)*inv[P%i]%P;
fac_inv[0]=1; for(int i=1;i<=K+5;i++) fac_inv[i]=fac_inv[i-1]*inv[i]%P;
int _n=1; while(_n<n) _n<<=1; PreW(_n<<1);
for(int i=1;i<=n-1;i++){
int x=getint(),y=getint();
add(x,y); add(y,x);
}
dfs(1,1); g[1]=0;
dfs_calc(1,1,0,0);
ans=(LL)ans*inv[2]%P;
printf("%d\n",(ans%P+P)%P);
return 0;
}