[后缀数组 + 二分] BZOJ4310 跳蚤

记一下改过之后的后缀数组模板,以前写的太鬼畜......

这题就是利用后缀数组搞出排名第几的字串是什么,然后就能二分答案了。

验证时从后往前,比较时用后缀数组。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long LL;
const int maxn=100005;
int n,K,M;
char s[maxn];
int sa[maxn],rk[maxn],tp[maxn],Hht[maxn];
void Dsort(){
static int cnt[maxn];
for(int i=1;i<=M;i++) cnt[i]=0;
for(int i=1;i<=n;i++) cnt[rk[i]]++;
for(int i=1;i<=M;i++) cnt[i]+=cnt[i-1];
for(int i=n;i>=1;i--) sa[cnt[rk[tp[i]]]--]=tp[i];
}
int f[maxn][18];
bool Build_SA(){
M=60; for(int i=1;i<=n;i++) rk[i]=s[i]-'a'+1, tp[i]=i; Dsort();
for(int w=1,p=0;p<n;M=p,w<<=1){
p=0; for(int i=n-w+1;i<=n;i++) tp[++p]=i;
for(int i=1;i<=n;i++) if(sa[i]>w) tp[++p]=sa[i]-w;
Dsort(); memcpy(tp,rk,sizeof(rk));
rk[sa[1]]=p=1;
for(int i=2;i<=n;i++)
rk[sa[i]]=(tp[sa[i]]==tp[sa[i-1]]&&tp[sa[i]+w]==tp[sa[i-1]+w])?p:++p;
}
for(int i=1,k=0;i<=n;i++){
if(k) k--;
while(s[i+k]==s[sa[rk[i]-1]+k]) k++;
Hht[rk[i]]=k;
}
for(int i=1;i<=n;i++) f[i][0]=Hht[i];
for(int j=1;j<=17;j++)
for(int i=1;i+(1<<j)-1<=n;i++)
f[i][j]=min(f[i][j-1],f[i+(1<<j-1)][j-1]);
}
int nL,nR;
LL ans;
int LCP(int x,int y){
if(x==y) return n;
x=rk[x]; y=rk[y]; if(x>y) swap(x,y); x++;
int t=0; while((1<<t+1)<=y-x+1) t++;
return min(f[x][t],f[y-(1<<t)+1][t]);
}
void Getpos(LL k){
for(int i=1;i<=n;i++){
int t=n-sa[i]+1-Hht[i];
if(k>t) k-=t; else return nL=sa[i],nR=sa[i]+Hht[i]+k-1,void();
}
}
bool Cmp(int L1,int R1,int L2,int R2){
int len1=R1-L1+1,len2=R2-L2+1,lcp=LCP(L1,L2);
if(lcp>=len1&&lcp>=len2) return len1>len2?0:1;
if(len1<=len2&&lcp>=len1) return 1;
if(len1>len2&&lcp>=len2) return 0;
return s[L1+lcp]>s[L2+lcp]?0:1;
}
bool check(LL mid){
Getpos(mid);
int res=1,lst=n;
for(int i=n;i>=1;i--){
if(!Cmp(i,lst,nL,nR)) res++, lst=i;
if(s[i]>s[nL]) return false;
if(res>K) return false;
}
return true;
}
int main(){
scanf("%d%s",&K,s+1); n=strlen(s+1);
Build_SA();
//for(int i=1;i<=n;i++) printf("%s Hht=%d\n",s+sa[i],Hht[i]);
LL L=1,R=(LL)(1+n)*n/2; for(int i=1;i<=n;i++) R-=Hht[i];
while(L<=R){
LL mid=(L+R)>>1;
if(check(mid)) R=mid-1, ans=mid; else L=mid+1;
}
Getpos(ans);
for(int i=nL;i<=nR;i++) putchar(s[i]);
return 0;
}