[点分治 + 分块] UOJ33. [UR #2] 树上GCD

首先是一个套路容斥,转化为求 \(ans(i)\) 等于 \(gcd\)\(i\) 的倍数的点对数。

然后就可做了,路径相关的东西容易想到点分治。

每次分治求出重心,重心有若干个儿子子树和一个父亲子树。

儿子子树之间的贡献比较简单:

枚举 \(d\) ,设第 \(i\) 个子树有 \(a_i\) 个深度是 \(d\) 的倍数的节点,则对 $ ans(d)$ 贡献为 \(\sum a_ia_j\)

由于调和级数,是可以做到 \(O(n \log n)\) 的。

考虑父亲子树,直接往重心祖先跳,枚举 \(lca\) ,然后还是枚举 \(d\) , 父亲子树计数是 \(O(n \log n)\) 的。

我们每次还要对于所有儿子子树,求深度模 \(d\)\(-dis(G,lca)\) 的节点数。

看似要 \(O(n^2)\) ,然而注意到,若我们把重复询问记忆化,模 \(d\) 的询问只有 \(d\) 种。

这就比较明显的是分块了, 对于 \(d < \sqrt n\) ,每种 \(d\) 复杂度 \(O(d\ \frac{n}{d})=O(n)\) ,共 \(O(n \sqrt n)\)

对于 \(d>\sqrt n\) ,每次询问小于 \(\sqrt n\) ,直接做就好了。

这题一不小心复杂度容易写萎,需要注意细节,不然就会像我一样卡了半天常数,才发现是复杂度错了...

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#include<cstdio>
#include<algorithm>
#include<vector>
#include<cctype>
#include<cmath>
using namespace std;
typedef long long LL;
inline char gc(){
static char buf[100000],*p1=buf,*p2=buf;
return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++;
}
inline int getint(){
char ch=gc(); int res=0,ff=1;
while(!isdigit(ch)) ch=='-'?ff=-1:0, ch=gc();
while(isdigit(ch)) res=(res<<3)+(res<<1)+ch-'0', ch=gc();
return res*ff;
}
const int maxn=200005,maxe=200005;
int fir[maxn],nxt[maxe],son[maxe],tot,pre[maxn];
void add(int x,int y){
son[++tot]=y; nxt[tot]=fir[x]; fir[x]=tot;
}
int sz[maxn];
bool vis[maxn];
void dfs_info(int x){
sz[x]=1;
for(int j=fir[x];j;j=nxt[j])
if(!vis[son[j]]) dfs_info(son[j]), sz[x]+=sz[son[j]];
}
int max(int x,int y){ return x>y?x:y; }
int min(int x,int y){ return x<y?x:y; }
int allsz,_G,now_min;
void dfs_G(int x){
int _max=0,sum=0;
for(int j=fir[x];j;j=nxt[j])
if(!vis[son[j]])
_max=max(_max,sz[son[j]]), sum+=sz[son[j]], dfs_G(son[j]);
_max=max(_max,allsz-sum-1);
if(_max<now_min) now_min=_max, _G=x;
}
int n,dep[maxn];
vector<int> v[maxn];
void dfs_son(int x,int id){
if(dep[x]+1>v[id].size()) v[id].push_back(1); else v[id][dep[x]]++;
for(int j=fir[x];j;j=nxt[j])
if(!vis[son[j]]) dep[son[j]]=dep[x]+1, dfs_son(son[j],id);
}
inline LL Calc(int id,int d,int k=0){
LL res=0; int len=v[id].size();
for(;k<len;k+=d) res+=v[id][k];
return res;
}
int clk[605][605],Tim;
LL g[605][605];
inline LL Get(int d,int k){
if(d>600) return Calc(n+1,d,k);
if(clk[d][k]==Tim) return g[d][k];
clk[d][k]=Tim; return g[d][k]=Calc(n+1,d,k);
}
LL ans[maxn],cnt[maxn];
int q[maxn];

bool my_cmp(const int &A,const int &B){
return v[A].size()<v[B].size();
}
void Divide(int rt){
dfs_info(rt);
allsz=sz[rt]; now_min=1e9; dfs_G(rt); int G=_G;
v[n+1].clear(); v[n+1].push_back(1); // G
if(allsz==1) return;
q[0]=0;
for(int j=fir[G];j;j=nxt[j])
if(!vis[son[j]]){
v[son[j]].clear(); v[son[j]].push_back(0);
dep[son[j]]=1; dfs_son(son[j],son[j]); dfs_son(son[j],n+1);
q[++q[0]]=son[j];
}
sort(q+1,q+1+q[0],my_cmp); int now=1;
for(int d=1;d<v[n+1].size();d++){
while(now<=q[0]&&v[q[now]].size()<d+1) now++;
LL sum1=1,sum2=1; // G
for(int i=now;i<=q[0];i++)
if(!vis[q[i]]){
LL res=Calc(q[i],d);
sum1+=res; sum2+=res*res;
}
ans[d]+=(sum1*sum1-sum2)>>1;
}

Tim++;
if(!vis[pre[G]]){
int dis=1;
for(int x=pre[G],lst=G;!vis[x];lst=x,x=pre[x],dis++){
v[0].clear(); v[0].push_back(0); // without x
for(int j=fir[x];j;j=nxt[j])
if(!vis[son[j]]&&son[j]!=lst) dep[son[j]]=1, dfs_son(son[j],0);
int len_now=min(v[0].size(),dis+v[n+1].size());
for(int d=1;d<len_now;d++)
ans[d]+=Calc(0,d)*Get(d,(d-dis%d)%d);
}
for(int i=0;i<v[n+1].size();i++) // x
cnt[i+1]+=v[n+1][i], cnt[i+dis]-=v[n+1][i];
}

vis[G]=true;
for(int j=fir[G];j;j=nxt[j]) if(!vis[son[j]]) Divide(son[j]);
if(!vis[pre[G]]) Divide(rt);
}
int main(){
freopen("uoj33.in","r",stdin);
freopen("uoj33.out","w",stdout);
n=getint();
for(int i=2;i<=n;i++){
int x=getint();
add(x,i); pre[i]=x;
}
vis[0]=true; Divide(1);
for(int i=n-1;i>=1;i--)
for(int j=(i<<1);j<=n-1;j+=i) ans[i]-=ans[j];
for(int i=1;i<=n-1;i++) cnt[i]+=cnt[i-1];
for(int i=1;i<=n-1;i++) printf("%lld\n",ans[i]+cnt[i]);
return 0;
}