yaoxi-std 的博客

$\text{开}\mathop{\text{卷}}\limits^{ju\check{a}n}\text{有益}$

0%

P1117 [NOI2016] 优秀的拆分

P1117 [NOI2016] 优秀的拆分

题面

题目链接

解法

$O(n^2)$暴力+$hash$能拿$95pts$。

但显然我们不满足于这$95pts$,考虑如何优化这个$O(n^2)$,即更快地求出每个位置对应的AA串数量。

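考虑枚举A串的长度$len$,并且每隔$len$个字符设置一个分隔点,这样每一条长度为$len\times 2$的AA串都恰好经过$2$个相邻的分隔点。于是对每一对相邻的分隔点,求出它们向前的最长公共后缀和向后的最长公共前缀(都与$len$取$\min$),二者之和超过$len$时就得到一段连续的AA串起点(终点),用差分统计即可;公共前后缀用后缀数组配合ST表计算。枚举的总次数是调和级数,时间复杂度$O(n\log n)$。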

鬼知道我调了多久(具体看代码吧)。SA多次使用一定要清空:倍增时比较第二关键字会访问到下标$n+1$处的数组元素,如果那里残留着上一组(更长的)数据的值,就可能把本来不同的后缀判成排名相同。

AC代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
/**
* @file: P1117.cpp
* @author: yaoxi-std
* @url: https://www.luogu.com.cn/problem/P1117
*/
// #pragma GCC optimize ("O2")
// #pragma GCC optimize ("Ofast", "inline", "-ffast-math")
// #pragma GCC target ("avx,sse2,sse3,sse4,mmx")
#include <bits/stdc++.h>
using namespace std;
#define int long long
// resetIO(name): redirects stdin to "name.in" and stdout to "name.out".
// The argument is stringified, so pass a bare identifier, not a string literal.
#define resetIO(x) \
freopen(#x ".in", "r", stdin), freopen(#x ".out", "w", stdout)
// debug(fmt, ...): printf-style message written to stderr, prefixed with the
// source file name and line number and followed by a newline.
// `##__VA_ARGS__` is a compiler extension that drops the comma when no
// variadic arguments are given.
#define debug(fmt, ...) \
fprintf(stderr, "[%s:%d] " fmt "\n", __FILE__, __LINE__, ##__VA_ARGS__)
/**
 * Reads one number from stdin into x and returns a reference to x.
 *
 * Skips every character up to the first digit (remembering whether a '-' was
 * seen on the way), then accumulates the decimal digits. If the digits are
 * followed by '.', the fractional digits are accumulated as well, which is
 * only meaningful when _Tp is a floating-point type.
 *
 * If the input ends before any digit is found, x is set to 0 instead of
 * spinning forever.
 */
template <class _Tp>
inline _Tp& read(_Tp& x) {
    bool sign = false;
    // getchar() returns an int so that EOF can be told apart from real bytes.
    // Storing it in a char loses that distinction and may hand a negative
    // value to isdigit(). `signed` is used because this file redefines `int`.
    signed ch = getchar();
    long double tmp = 1;
    for (; ch != EOF && !isdigit(ch); ch = getchar())
        sign |= (ch == '-');
    for (x = 0; ch != EOF && isdigit(ch); ch = getchar())
        x = x * 10 + (ch ^ 48);
    if (ch == '.')
        for (ch = getchar(); ch != EOF && isdigit(ch); ch = getchar())
            tmp /= 10.0, x += tmp * (ch ^ 48);
    return sign ? (x = -x) : x;
}
/**
 * Prints an integer to stdout in decimal, with a leading '-' for negative
 * values. No separator or newline is emitted.
 */
template <class _Tp>
inline void write(_Tp x) {
    if (x < 0) {
        putchar('-');
        x = -x;
    }
    // Collect the digits least-significant first, then emit them in reverse.
    char digits[64];
    int len = 0;
    do {
        digits[len++] = static_cast<char>((x % 10) ^ 48);
        x /= 10;
    } while (x > 0);
    while (len > 0)
        putchar(digits[--len]);
}
// Capacity of every per-position array in this file (presumably the problem's
// length bound of 3e4 plus slack for 1-based indexing).
const int MAXN = 3e4 + 10;
// Number of levels kept in the sparse tables (2^15 already exceeds MAXN).
const int LOGN = 16;
// Large sentinel value; it only fits because this file redefines `int` as
// `long long`. Not referenced by the code visible here.
const int INF = 0x3f3f3f3f3f3f3f3f;

/**
 * Suffix array of a 1-indexed string, built by prefix doubling with a
 * counting sort, together with the height array.
 *
 *   sa[i]  start position of the suffix whose rank is i (1 <= i <= n)
 *   rk[p]  rank of the suffix that starts at position p
 *   he[i]  length of the longest common prefix of the suffixes ranked
 *          i - 1 and i; he[1] is 0
 *   tp[]   scratch: second-key order during sorting, then the ranks of the
 *          previous round
 */
struct SuffixArray {
    int n, sa[MAXN], rk[MAXN], tp[MAXN], he[MAXN];

    /** Zeroes entries 0..n of all four arrays and resets n to 0. */
    void clear() {
        std::fill(sa, sa + n + 1, 0);
        std::fill(rk, rk + n + 1, 0);
        std::fill(tp, tp + n + 1, 0);
        std::fill(he, he + n + 1, 0);
        n = 0;
    }

    /**
     * Stable counting sort of the positions listed in tp[1..n] by their key
     * rk[.], written to sa[1..n]. Keys must lie in [0, m] and m must be
     * smaller than MAXN.
     */
    void radix_sort(int m) {
        static int buc[MAXN];
        for (int i = 0; i <= m; ++i)
            buc[i] = 0;
        for (int i = 1; i <= n; ++i)
            buc[rk[i]]++;
        for (int i = 1; i <= m; ++i)
            buc[i] += buc[i - 1];
        // Walking tp backwards keeps equal keys in their tp order.
        for (int i = n; i >= 1; --i)
            sa[buc[rk[tp[i]]]--] = tp[i];
    }

    /**
     * Rank from the previous doubling round of the suffix starting at pos.
     * Positions past the end of the string rank as 0, below every real
     * suffix. Going through this accessor instead of reading tp[pos]
     * directly means a value left in tp[n + 1] by an earlier, longer string
     * can never be mistaken for a rank.
     */
    int prev_rank(int pos) const { return pos <= n ? tp[pos] : 0; }

    /**
     * Builds sa, rk and he for s[1..n].
     *
     * Requirements: s[n + 1] is the terminating NUL, every s[i] has a code
     * in [0, 199] (the first sort uses 200 buckets for the keys s[i] + 1),
     * and n < MAXN - 1. The result does not depend on what a previous call
     * left in the arrays, so calling clear() between uses is optional.
     */
    void init(int n, char* s) {
        this->n = n;
        int m = 200;
        for (int i = 1; i <= n; ++i)
            rk[i] = s[i] + 1, tp[i] = i;
        radix_sort(m);
        // Each round sorts by the first 2 * w characters; stop as soon as
        // all n ranks are distinct.
        for (int w = 1, p = 0; p < n; m = p, w <<= 1) {
            p = 0;
            // Suffixes shorter than w + 1 have an empty second key: first.
            for (int i = 1; i <= w; ++i)
                tp[++p] = n - w + i;
            // The rest follow in the order of their second half.
            for (int i = 1; i <= n; ++i)
                if (sa[i] > w)
                    tp[++p] = sa[i] - w;
            radix_sort(m);
            std::copy(rk + 1, rk + n + 1, tp + 1);  // tp now holds the old ranks
            rk[sa[1]] = p = 1;
            for (int i = 2; i <= n; ++i) {
                const int prv = sa[i - 1];
                const int cur = sa[i];
                if (tp[prv] == tp[cur] && prev_rank(prv + w) == prev_rank(cur + w))
                    rk[cur] = p;
                else
                    rk[cur] = ++p;
            }
        }
        // Height array: moving from position i to i + 1 shrinks the LCP with
        // the predecessor in rank order by at most one, so k is reused.
        for (int i = 1, k = 0; i <= n; ++i) {
            if (rk[i] == 1) {
                // The smallest suffix has no predecessor.
                he[1] = 0;
                k = 0;
                continue;
            }
            if (k)
                k--;
            // Terminates at the NUL in s[n + 1] at the latest.
            while (s[i + k] == s[sa[rk[i] - 1] + k])
                k++;
            he[rk[i]] = k;
        }
    }
} sa, pa;
// n: length of the current string.
// a[i] / b[i]: difference arrays that main turns into the number of AA
// substrings ending at i / starting at i.
// lg2[i]: floor(log2(i)), filled in by main for i >= 2 (lg2[0] and lg2[1] stay 0).
int n, a[MAXN], b[MAXN], lg2[MAXN];
// Sparse tables over the height arrays: smn for `sa` (the string itself) and
// pmn for `pa` (the reversed string). Entry [i][j] is the minimum height over
// the ranks i .. i + 2^j - 1.
int smn[MAXN][LOGN], pmn[MAXN][LOGN];
// Current input string, stored from index 1; s[0] is never written.
char s[MAXN];
/**
 * Minimum of the `sa` height array over the ranks (min(l, r), max(l, r)],
 * answered from the sparse table smn; for two distinct ranks this is the
 * longest common prefix of the corresponding suffixes of the string.
 * The arguments may be given in either order but are expected to differ.
 */
int getlcs(int l, int r) {
    const int lo = l < r ? l : r;
    const int hi = l < r ? r : l;
    const int level = lg2[hi - lo];
    // Two overlapping power-of-two windows cover lo + 1 .. hi.
    const int head = smn[lo + 1][level];
    const int tail = smn[hi - (1 << level) + 1][level];
    return tail < head ? tail : head;
}
/**
 * Minimum of the `pa` height array over the ranks (min(l, r), max(l, r)],
 * answered from the sparse table pmn; for two distinct ranks this is the
 * longest common prefix of the corresponding suffixes of the reversed string.
 * The arguments may be given in either order but are expected to differ.
 */
int getlcp(int l, int r) {
    const int lo = l < r ? l : r;
    const int hi = l < r ? r : l;
    const int level = lg2[hi - lo];
    // Two overlapping power-of-two windows cover lo + 1 .. hi.
    const int head = pmn[lo + 1][level];
    const int tail = pmn[hi - (1 << level) + 1][level];
    return tail < head ? tail : head;
}
/**
 * Reads the number of test cases, then for each string prints
 * sum over i of (AA substrings ending at i) * (AA substrings starting at i + 1),
 * where an "AA substring" is some non-empty string written twice in a row.
 *
 * For every half-length x, anchors are placed x apart. For each adjacent
 * pair of anchors the common extension to the left and to the right decides
 * which AA substrings of length 2x span that pair; they are recorded with
 * difference arrays.
 */
signed main() {
    // The field width in the scanf call below leaves room for the unused
    // s[0] and for the terminating NUL.
    static_assert(sizeof(s) >= 30010, "scanf width assumes a buffer of at least 30010 bytes");
    int tt;
    read(tt);
    while (tt--) {
        // Bounded read into s[1..]; stop instead of reprocessing stale data
        // when the input ends early.
        if (scanf("%30008s", s + 1) != 1)
            break;
        n = strlen(s + 1);
        // sa describes the string, pa the reversed string.
        sa.init(n, s);
        reverse(s + 1, s + n + 1);
        pa.init(n, s);
        reverse(s + 1, s + n + 1);
        fill(a, a + n + 2, 0);
        fill(b, b + n + 2, 0);
        // lg2[i] = floor(log2(i)): it grows by one exactly at powers of two.
        for (int i = 2; i <= n; ++i)
            lg2[i] = (i & (i - 1)) ? lg2[i - 1] : lg2[i - 1] + 1;
        // Sparse tables for range-minimum queries on both height arrays.
        for (int i = 1; i <= n; ++i) {
            smn[i][0] = sa.he[i];
            pmn[i][0] = pa.he[i];
        }
        for (int j = 1; j < LOGN; ++j) {
            for (int i = 1; i + (1 << j) - 1 <= n; ++i) {
                smn[i][j] = min(smn[i][j - 1], smn[i + (1 << (j - 1))][j - 1]);
                pmn[i][j] = min(pmn[i][j - 1], pmn[i + (1 << (j - 1))][j - 1]);
            }
        }
        for (int x = 1; x + x <= n; ++x) {
            for (int sr = x + 1; sr <= n; sr += x) {
                int sl = sr - x;
                // Position p of the string is position n - p + 1 of the
                // reversed string.
                int pl = n - sr + 1, pr = n - sl + 1;
                // pre: how far s[..sl] and s[..sr] agree going left;
                // nxt: how far s[sl..] and s[sr..] agree going right.
                int pre = getlcp(pa.rk[pl], pa.rk[pr]);
                int nxt = getlcs(sa.rk[sl], sa.rk[sr]);
                // Capping at x attributes every AA substring to exactly one
                // anchor pair.
                pre = min(pre, x), nxt = min(nxt, x);
                if (pre + nxt > x) {
                    // Ends sr - pre + x .. sr + nxt - 1 (prefix-summed later).
                    a[sr - pre + x]++, a[sr + nxt]--;
                    // Starts sl - pre + 1 .. sl + nxt - x (suffix-summed later).
                    b[sl - pre]--, b[sl + nxt - x]++;
                }
            }
        }
        for (int i = 1; i <= n; ++i)
            a[i] += a[i - 1];
        for (int i = n; i >= 1; --i)
            b[i] += b[i + 1];
        int ans = 0;
        for (int i = 1; i < n; ++i)
            ans += a[i] * b[i + 1];
        write(ans), putchar('\n');
        sa.clear(), pa.clear();
    }
    return 0;
}