不难的题。看到所求的式子这么复杂,就要想到可能要把所有东西都求出来。

那么就容易想到,用某种方法合并两个儿子。考虑启发式合并,现在要使合并复杂度只和小的儿子的规模有关。

概率的变化大概是:

\[ D_i \leftarrow LD_i\ (p_x\sum_{j<i}RD_j+(1-p_x)\sum_{i<j}RD_j) \]

来自小的子树的值直接求。注意到大的子树,是一段一段乘以某个相同的数。

所以需要实现单点修改,区间乘,还要能合并。可以 \(splay\) ,或线段树动态开点。

为什么我常数巨大…

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
#include<cstdio>
#include<vector>
#include<cctype>
#include<cstring>
#include<algorithm>
#define mp make_pair
#define X first
#define Y second
using namespace std;
typedef long long LL;
inline char gc(){
static char buf[100000],*p1=buf,*p2=buf;
return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++;
}
inline int getint(){
char ch=gc(); int res=0,ff=1;
while(!isdigit(ch)) ch=='-'?-1:0, ch=gc();
while(isdigit(ch)) res=(res<<3)+(res<<1)+ch-'0', ch=gc();
return res*ff;
}
const int maxn=300005,maxe=300005,P=998244353;
int n,a[maxn],p[maxn],nson[maxn];
int fir[maxn],nxt[maxe],son[maxe],tot;
void add(int x,int y){
son[++tot]=y; nxt[tot]=fir[x]; fir[x]=tot;
}
struct node{
int sum,tag; node* ch[2];
node(int t1=0,int t2=0,node* son=NULL){ sum=t1; tag=t2; ch[0]=ch[1]=son; }
void Mul(int val){ sum=(LL)sum*val%P; tag=(LL)tag*val%P; }
void maintain(){
sum=(ch[0]->sum+ch[1]->sum)%P;
}
void pushdown(){
ch[0]->Mul(tag); ch[1]->Mul(tag);
tag=1;
}
} _poor[maxn*30], *poor_top=_poor, *rt[maxn], nil, *null=&nil;
typedef node* P_node;
void Update(P_node p,int L,int R,int qL,int qR,int val){
if(p==null||qR<L||R<qL||qL>qR||!p->sum) return;
if(qL<=L&&R<=qR){ p->Mul(val); return; }
p->pushdown(); int mid=(L+R)>>1;
Update(p->ch[0],L,mid,qL,qR,val); Update(p->ch[1],mid+1,R,qL,qR,val);
p->maintain();
}
int Query(P_node p,int L,int R,int qL,int qR){
if(p==null||qR<L||R<qL||qL>qR||!p->sum) return 0;
if(qL<=L&&R<=qR) return p->sum;
p->pushdown(); int mid=(L+R)>>1;
return (Query(p->ch[0],L,mid,qL,qR)+Query(p->ch[1],mid+1,R,qL,qR))%P;
}
inline P_node newnode(){
(*poor_top)=node(0,1,null);
return poor_top++;
}
void Update(P_node &p,int L,int R,int pos,int val){
if(p==null) p=newnode();
if(L==R){ p->sum=val; return; }
p->pushdown(); int mid=(L+R)>>1;
if(pos<=mid) Update(p->ch[0],L,mid,pos,val);
else Update(p->ch[1],mid+1,R,pos,val);
p->maintain();
}
vector< pair<int,int> > v;
void Travel(P_node p,int L,int R){
if(p==null) return;
if(L==R) v.push_back(mp(L,p->sum));
p->pushdown(); int mid=(L+R)>>1;
Travel(p->ch[0],L,mid); Travel(p->ch[1],mid+1,R);
}
int cnt[maxn],len_tmp;
pair<int,int> tmp[maxn];
int Find(int x){
return lower_bound(a+1,a+1+a[0],x)-a;
}
inline void addm(int &x,int y){
x+=y; if(x>P) x-=P;
}
void dfs(int x){
if(!nson[x]){
rt[x]=newnode();
Update(rt[x],1,a[0],Find(p[x]),1); cnt[x]=1;
return;
}
if(nson[x]==1){
dfs(son[fir[x]]); rt[x]=rt[son[fir[x]]]; cnt[x]=cnt[son[fir[x]]];
return;
}
int id1=son[fir[x]],id2=son[nxt[fir[x]]];
dfs(id1); dfs(id2);
if(cnt[id1]>cnt[id2]) swap(id1,id2);
rt[x]=rt[id2]; cnt[x]=cnt[id1]+cnt[id2];
v.clear(); Travel(rt[id1],1,a[0]);
P_node rt_2=rt[id2]; int _p=p[x],p_=(1+P-p[x])%P;
len_tmp=0;
int alls=0; int alls_2=Query(rt_2,1,a[0],1,a[0]);
for(int i=0;i<v.size();i++){
int pres_2=Query(rt_2,1,a[0],1,v[i].X-1);
int t=((LL)pres_2*_p%P+(LL)(alls_2+P-pres_2)*p_)%P;
tmp[++len_tmp]=mp(v[i].X,(LL)v[i].Y*t%P);
addm(alls,v[i].Y);
}
for(int i=1;i<=len_tmp;i++) Update(rt_2,1,a[0],tmp[i].X,tmp[i].Y);
if(1<v[0].X) Update(rt_2,1,a[0],1,v[0].X-1,(LL)alls*p_%P);
for(int i=1,pres=v[0].Y;i<v.size();i++){
int t=((LL)pres*_p%P+(LL)(alls+P-pres)*p_%P)%P;
Update(rt_2,1,a[0],v[i-1].X+1,v[i].X-1,t);
addm(pres,v[i].Y);
}
if(v[v.size()-1].X<a[0]) Update(rt_2,1,a[0],v[v.size()-1].X+1,a[0],(LL)alls*_p%P);
}
int Pow(LL a,int b){
LL res=1;
for(;b;b>>=1,a=a*a%P) if(b&1) res=res*a%P;
return res;
}
int ans;
int main(){
scanf("%d%*d",&n);
for(int i=2;i<=n;i++){
int x=getint(); nson[x]++;
add(x,i);
}
int _inv=Pow(10000,P-2);
for(int i=1;i<=n;i++){
p[i]=getint();
if(!nson[i]) a[++a[0]]=p[i]; else p[i]=(LL)p[i]*_inv%P;
}
sort(a+1,a+1+a[0]);
dfs(1);
v.clear(); Travel(rt[1],1,a[0]);
for(int i=0;i<v.size();i++)
addm(ans,(LL)(i+1)*a[v[i].X]%P*v[i].Y%P*v[i].Y%P);
printf("%d\n",(ans+P)%P);
return 0;
}