1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124
|
// Reads two whitespace-delimited strings from standard input and prints one
// integer: the sum, over every pair (suffix of the first string, suffix of the
// second string), of the length of their longest common prefix.
//
// Method: build a suffix array for "first$second" and for each input alone.
// suffix_array::calc() sums the LCP over all pairs of suffixes of one text, so
// subtracting the two single-string sums from the combined sum leaves only the
// pairs that take one suffix from each input.

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

using i64 = long long;

/// Suffix array with a per-position LCP ("height") table, built by prefix
/// doubling with a counting sort. Every table is 1-indexed; slot 0 is unused.
struct suffix_array {
  const char *str = nullptr;  ///< Text handed to build(); only read during build().
  i64 n = 0;                  ///< Length of the text.
  std::vector<i64> sa;        ///< sa[r] = start position of the r-th smallest suffix.
  std::vector<i64> rk;        ///< rk[i] = rank of the suffix starting at position i.
  std::vector<i64> tp;        ///< Scratch: second-key order, then the previous ranks.
  std::vector<i64> ht;        ///< ht[i] = LCP of suffix i with the suffix ranked just before it.
  std::vector<i64> buk;       ///< Counting-sort buckets.

  /// Stable counting sort. Visits the positions tp[1..n] in order and places
  /// them into sa[1..n] sorted by rk (keys must lie in [1, m]); positions with
  /// equal keys keep their tp order.
  void radix_sort(i64 m) {
    buk.assign(m + 1, 0);
    for (i64 i = 1; i <= n; ++i) ++buk[rk[i]];
    for (i64 i = 1; i <= m; ++i) buk[i] += buk[i - 1];
    // Walk backwards so that equal keys stay in tp order (stability).
    for (i64 i = n; i >= 1; --i) sa[buk[rk[tp[i]]]--] = tp[i];
  }

  /// Builds sa, rk and ht for the text s[1..n].
  ///
  /// @param s  Buffer whose characters are read at indices 1..n only.
  /// @param n  Number of characters in the text (0 is allowed).
  void build(const char *s, i64 n) {
    this->str = s;
    this->n = n;

    const i64 kAlphabet = 256;  // one bucket per possible byte value
    sa.assign(n + 2, 0);
    rk.assign(n + 2, 0);
    tp.assign(n + 2, 0);
    ht.assign(n + 2, 0);

    // Round 0: rank by the first byte. Going through unsigned char keeps the
    // key positive for bytes >= 0x80; the +1 keeps every real key above 0.
    i64 m = kAlphabet;
    for (i64 i = 1; i <= n; ++i) {
      rk[i] = static_cast<unsigned char>(s[i]) + 1;
      tp[i] = i;
    }
    radix_sort(m);

    // Previous-round rank of the suffix at pos; anything past the end of the
    // text counts as rank 0, i.e. smaller than every real suffix.
    auto prev_rank = [&](i64 pos) -> i64 { return pos <= n ? tp[pos] : 0; };

    // Each round doubles the compared prefix length from w to 2w. It stops
    // once all n ranks are distinct (p == n).
    for (i64 p = 0, w = 1; p < n; m = p, w <<= 1) {
      // Order positions by their second key (rank of the suffix w further on).
      // The last w suffixes have an empty second half, so they come first.
      p = 0;
      for (i64 i = n - w + 1; i <= n; ++i) tp[++p] = i;
      for (i64 i = 1; i <= n; ++i) {
        if (sa[i] > w) tp[++p] = sa[i] - w;
      }
      radix_sort(m);

      // tp now keeps the previous ranks while rk is rewritten.
      std::copy(rk.begin() + 1, rk.begin() + n + 1, tp.begin() + 1);
      p = 1;
      rk[sa[1]] = p;
      for (i64 i = 2; i <= n; ++i) {
        const i64 a = sa[i - 1];
        const i64 b = sa[i];
        const bool same_key =
            tp[a] == tp[b] && prev_rank(a + w) == prev_rank(b + w);
        if (!same_key) ++p;
        rk[b] = p;
      }
    }

    // LCP table by text position: moving from position i-1 to i, the match
    // length can shrink by at most one, so k is carried over between steps.
    for (i64 i = 1, k = 0; i <= n; ++i) {
      if (rk[i] == 1) {  // the smallest suffix has no predecessor
        k = 0;
        ht[i] = 0;
        continue;
      }
      if (k) --k;
      const i64 j = sa[rk[i] - 1];
      while (i + k <= n && j + k <= n && s[i + k] == s[j + k]) ++k;
      ht[i] = k;
    }
  }

  /// Returns the sum of LCP lengths over all unordered pairs of distinct
  /// suffixes of the built text.
  ///
  /// The LCP of two suffixes is the minimum of the adjacent-rank LCP values
  /// between them, so the total equals the sum of range minima over all rank
  /// ranges. For each rank i, pre[i] * nxt[i] counts the ranges in which
  /// ht[sa[i]] is the chosen minimum (ties are resolved towards the right).
  i64 calc() const {
    std::vector<i64> pre(n + 2, 0);
    std::vector<i64> nxt(n + 2, 0);
    std::vector<i64> sta(n + 2, 0);
    i64 top = 0;

    // pre[i]: distance back to the nearest rank with a strictly smaller value.
    sta[0] = 0;
    for (i64 i = 1; i <= n; ++i) {
      while (top && ht[sa[sta[top]]] >= ht[sa[i]]) --top;
      pre[i] = i - sta[top];
      sta[++top] = i;
    }

    // nxt[i]: distance forward to the nearest rank with a smaller-or-equal value.
    top = 0;
    sta[0] = n + 1;
    for (i64 i = n; i >= 1; --i) {
      while (top && ht[sa[sta[top]]] > ht[sa[i]]) --top;
      nxt[i] = sta[top] - i;
      sta[++top] = i;
    }

    i64 ret = 0;
    for (i64 i = 1; i <= n; ++i) ret += pre[i] * nxt[i] * ht[sa[i]];
    return ret;
  }
};

int main() {
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);

  // A token that is missing from the input simply leaves its string empty.
  std::string first;
  std::string second;
  std::cin >> first >> second;

  // suffix_array reads its text from index 1, so prepend one padding byte.
  const auto one_indexed = [](const std::string &text) {
    return std::string(1, '\0') + text;
  };
  const std::string s1 = one_indexed(first);
  const std::string s2 = one_indexed(second);
  // '$' keeps a common prefix from running across the join.
  // NOTE(review): this assumes neither input contains '$' -- confirm against
  // the input specification.
  const std::string s0 = one_indexed(first + '$' + second);

  suffix_array sa0;
  suffix_array sa1;
  suffix_array sa2;
  sa0.build(s0.c_str(), static_cast<i64>(s0.size()) - 1);
  sa1.build(s1.c_str(), static_cast<i64>(s1.size()) - 1);
  sa2.build(s2.c_str(), static_cast<i64>(s2.size()) - 1);

  const i64 ans = sa0.calc() - sa1.calc() - sa2.calc();
  std::cout << ans << '\n';
  return 0;
}
|