阿狸的打字机 AC自动机+fail树+DFS序+树状数组
n个字符串,m次询问,每次询问(x, y),输出字符串x在字符串y中出现的次数。对于字符串x和y:
若x在y中出现过,则y在Trie树上从根到其终结点的路径上的所有结点中,至少存在1个结点,满足:从该结点出发沿着fail边走(走0步或多步),会走到x对应的终结点。x在y中出现的次数,即为满足这个条件的结点的个数。
x在y中出现的次数也等于:把fail边反向,构成的以0为根的树,以x的终结点为根的子树中,有多少个点属于字符串y,x在y中就出现了多少次。(以x为根的子树中的所有孩子结点,所各自对应的字符串中,x都作为一个后缀出现)。
有了如上前提,我们可以这样做:
构造AC自动机,连接fail边。写出fail树的DFS序,x的终结点为根的子树的所有孩子结点,都在x后面的某个长度的区间内,我们只需统计该区间内有多少个结点属于y即可。如何统计?遍历Trie树,每向下访问一个结点,这个结点在DFS序中对应的位置的权值就+1,结束访问时就-1,这样的话,当访问到y的终结点时,DFS序中权值为1的点恰好是y在Trie树上从根到终结点路径上的结点(即y的所有前缀),那么我们求一下这些点中有多少个落在x子树的DFS区间内即可,这个过程可以用树状数组或线段树来完成。
#include <cstdio>
#include <cstring>
#include <queue>
#include <algorithm>
#include <vector>
// Capacity of every fixed-size array in this program (key log, trie nodes, queries).
const int maxn = 1e5+5;
using namespace std;
// Key log read by main(); the characters are stored starting at index 1.
char s[maxn];
// Number of queries, read by main().
int m;
// ans[i] receives the answer to the i-th query in input order.
int ans[maxn];
// Isolates the lowest set bit of i; yields 0 when i is 0.
inline int lowbit(int i){
    const int negated = -i;
    return negated & i;
}
// One question "how many times does string x occur in string y",
// together with its position in the input.
struct query{
    int x;    // number of the pattern string (strings are numbered from 1)
    int y;    // number of the text string
    int idx;  // input position, so answers can be printed in input order
}query[maxn];
// Aho-Corasick automaton built from the typewriter key log, combined with a
// Fenwick tree that is indexed by the DFS order of the fail tree.
//
// Usage (see main): insert() builds the trie from the key log, build() adds
// the fail links and flattens the fail tree, solve() replays the key log and
// answers the queries (which must already be sorted by y) into the global ans[].
struct AC_Auto{
    // Trie[u][1..26] are the transitions for 'a'..'z' (slot 0 is unused).
    // root is not a constant: it is the cursor that walks the trie.
    int Trie[maxn][27], fail[maxn], cnt, root;
    int num, cur;              // num: strings printed so far; cur: DFS counter / next query
    vector<int> G[maxn];       // fail tree (fail edges reversed), rooted at node 0
    int Sons[maxn];            // number of descendants (nodes, not direct children) in the fail tree
    int fa[maxn], mp1[maxn], mp2[maxn]; // fa: trie parent; mp1[i]: end node of string i; mp2[u]: DFS index of node u
    int DFS[maxn];             // DFS[k]: node visited k-th in the fail-tree DFS
    int c[maxn];               // Fenwick tree over DFS indices 1..cnt

    AC_Auto(){
        root = 0;
        cnt = 0;
        num = 0; cur = 0;
    }

    // Builds the trie from the key log s[1..len]:
    // 'P' records the current node as the end node of the next string,
    // 'B' steps back to the parent, any other character appends a letter.
    void insert(char *s, int len){
        root = num = cnt = 0;
        for (int i=1; i<=len; i++){
            if (s[i]=='P'){
                mp1[++num] = root;
            }else if (s[i]=='B'){
                root = fa[root];
            }else{
                const int slot = s[i]-'a'+1;
                if (!Trie[root][slot]){
                    Trie[root][slot] = ++cnt;
                    fa[cnt] = root;
                }
                root = Trie[root][slot];
            }
        }
        return;
    }

    // Computes the fail links by BFS (missing transitions are filled in with the
    // fail target's transition), then builds the fail tree and its DFS order.
    void build(){
        queue<int> q;
        for (int i=1; i<27; i++) if (Trie[0][i])
            q.push(Trie[0][i]);
        while (!q.empty()){
            int tmp = q.front(); q.pop();
            for (int i=1; i<27; i++){
                if (Trie[tmp][i]){
                    fail[Trie[tmp][i]] = Trie[fail[tmp]][i];
                    q.push(Trie[tmp][i]);
                }else{
                    Trie[tmp][i] = Trie[fail[tmp]][i];
                }
            }
        }
        for (int i=1; i<=cnt; i++) G[fail[i]].push_back(i);
        cur = -1; // node 0 gets DFS index 0, every other node an index in 1..cnt
        getSons(0);
        return;
    }

    // DFS over the fail tree starting at rot: assigns DFS indices and fills
    // Sons[]. Returns the size of rot's subtree including rot itself.
    int getSons(int rot){
        Sons[rot] = 0;
        DFS[++cur] = rot;
        mp2[rot] = cur;
        for (int i=0; i<(int)G[rot].size(); i++)
            Sons[rot] += getSons(G[rot][i]);
        return Sons[rot]+1;
    }

    // Fenwick point update: adds k at DFS index i.
    // Index 0 belongs to node 0, which is never counted; it is skipped here
    // because lowbit(0) == 0 would otherwise make the loop spin forever.
    void update(int i, int k){
        if (i <= 0) return;
        while (i<=cnt){
            c[i] += k;
            i += lowbit(i);
        }
        return;
    }

    // Fenwick range sum over the DFS indices [l, r].
    // The loops stop at any non-positive index, so l == 0 is safe.
    int getSum(int l, int r){
        int total = 0;
        for (; r > 0; r -= lowbit(r))
            total += c[r];
        for (l--; l > 0; l -= lowbit(l))
            total -= c[l];
        return total;
    }

    // Replays the key log s[1..len]. While walking, every node on the current
    // root-to-cursor path holds +1 in the Fenwick tree. At each 'P' the pending
    // queries whose text string ends at the cursor are answered by summing over
    // the DFS interval of the pattern's fail-tree subtree.
    // Requires query[1..m] to be sorted by y.
    void solve(char *s, int len){
        root = 0;
        cur = 1;
        for (int i=1; i<=len; i++){
            if (s[i]=='B'){
                update(mp2[root], -1);
                root = fa[root];
            }else if (s[i]=='P'){
                // Check the bound first so query[cur] is only read for cur <= m.
                while (cur<=m && mp1[query[cur].y]==root){
                    const int x_end = mp1[query[cur].x];
                    const int first = mp2[x_end];
                    ans[query[cur].idx] = getSum(first, first + Sons[x_end]);
                    cur++;
                }
            }else{
                root = Trie[root][s[i]-'a'+1];
                update(mp2[root], 1);
            }
        }
        return;
    }
}AC;
// Ordering predicate for sort(): ascending by the text-string number y.
inline bool cmp1(struct query a, struct query b){
    const bool a_goes_first = b.y > a.y;
    return a_goes_first;
}
// Reads the key log and the queries, answers the queries offline (sorted by
// text string) and prints the answers in input order.
int main(){
    // The key log is stored from index 1.
    scanf ("%s", s+1);
    int len = strlen(s+1);

    AC.insert(s, len);
    AC.build();

    scanf ("%d", &m);
    int i = 1;
    while (i <= m){
        scanf ("%d %d", &query[i].x, &query[i].y);
        query[i].idx = i;
        ++i;
    }

    // solve() expects the queries grouped by ascending y.
    sort (query+1, query+m+1, cmp1);
    AC.solve(s, len);

    for (int k=1; k<=m; ++k){
        printf ("%d\n", ans[k]);
    }
    return 0;
}
Mike and Friends
题意:n个字符串,m次询问,每次询问l, r, k。输出字符串k在字符串[l, r]中出现的次数。
问出现次数,那肯定AC自动机没跑了。仍然利用性质:把fail边反向建立的树中,以字符串k的终结点为根的子树内的结点,所对应的字符串都以k为后缀。
建立AC自动机,fail边反向,写出fail树的DFS序。在DFS序上建立主席树。按字符串插入顺序跑一遍自动机,每跑到一个结点,就在前一棵线段树的基础上,把该结点在DFS序中的位置的权值+1,新建一个版本的线段树。线段树每个结点表示DFS序里区间[l, r]内的权值和,要求k的出现次数只用求线段树里相应区间的权值和即可。要求在[l,r]内的出现次数只需两棵线段树相减得到的权值即可。
#include <cstdio>
#include <iostream>
#include <queue>
#include <string>
#include <vector>
using namespace std;
// Capacity used for the string, trie-node and tree-version arrays of this program.
const int maxn = 2e5+5;
// Node of the persistent segment tree: indices of the two children inside the
// node pool and the sum stored for the node's interval.
typedef struct Presistial_SegmentTree{
    int left;   // pool index of the left child
    int right;  // pool index of the right child
    int val;    // sum of the weights in this node's interval
}PST;
// n: number of strings, m: number of queries (both read in main()).
int n, m;
// len[i]: length of s[i].
int len[maxn];
// The input strings, stored from index 1.
string s[maxn];
// Aho-Corasick automaton whose fail tree is flattened into a DFS order, plus a
// persistent segment tree over that order. One tree version is created per
// character of the input strings, so a range of strings corresponds to the
// difference of two versions.
struct AC_auto{
    int Trie[maxn][27];   // transitions; slots 1..26 stand for 'a'..'z'
    int fail[maxn];       // fail link of every node
    int cnt;              // number of trie nodes besides node 0
    int mp1[maxn];        // mp1[i]: end node of string i
    int mp2[maxn];        // mp2[u]: DFS index of node u in the fail tree
    int mp3[maxn];        // mp3[i]: last tree version created for strings 1..i
    vector<int> G[maxn];  // fail tree (fail edges reversed), rooted at node 0
    int Sons[maxn];       // number of descendants in the fail tree
    int DFS[maxn];        // DFS[k]: node visited k-th
    int cur;              // DFS counter
    PST tree[maxn*30];    // node pool of the persistent segment tree
    int top;              // last used pool index
    int root[maxn];       // root[v]: pool index of the root of version v
    int cnt2;             // number of versions created so far

    AC_auto() : cnt(0), cur(-1), top(0), cnt2(0) {}

    // Adds the first len characters of s to the trie and records the end node
    // as belonging to string number idx.
    void insert(string s, int len, int idx){
        int node = 0;
        int pos = 0;
        while (pos < len){
            int &child = Trie[node][s[pos]-'a'+1];
            if (child == 0) child = ++cnt;
            node = child;
            ++pos;
        }
        mp1[idx] = node;
    }

    // BFS that computes the fail links and completes missing transitions, then
    // builds the fail tree and runs the DFS that fills DFS[], mp2[] and Sons[].
    void build(){
        queue<int> pending;
        for (int ch=1; ch<27; ch++){
            if (Trie[0][ch] != 0) pending.push(Trie[0][ch]);
        }
        while (!pending.empty()){
            const int u = pending.front();
            pending.pop();
            for (int ch=1; ch<27; ch++){
                const int fallback = Trie[fail[u]][ch];
                if (Trie[u][ch] == 0){
                    Trie[u][ch] = fallback;
                    continue;
                }
                fail[Trie[u][ch]] = fallback;
                pending.push(Trie[u][ch]);
            }
        }
        for (int v=1; v<=cnt; v++){
            G[fail[v]].push_back(v);
        }
        cur = -1;
        const int whole_tree = GetSons(0);
        Sons[0] = whole_tree - 1;
    }

    // DFS over the fail tree from node i. Returns the size of i's subtree
    // including i; Sons[i] ends up holding that size minus one.
    int GetSons(int i){
        cur += 1;
        DFS[cur] = i;
        mp2[i] = cur;
        for (vector<int>::const_iterator it = G[i].begin(); it != G[i].end(); ++it){
            Sons[i] += GetSons(*it);
        }
        return 1 + Sons[i];
    }

    // Allocates an all-zero segment tree over [l, r]; returns its root index.
    int build(int l, int r){
        const int node = ++top;
        tree[node].val = 0;
        if (l != r){
            const int mid = (l+r)>>1;
            tree[node].left = build(l, mid);
            tree[node].right = build(mid+1, r);
        }
        return node;
    }

    // Creates a new version from the tree rooted at src (covering [l, r]) with
    // position target_idx incremented by one; returns the new root index.
    // Only the nodes on the path to target_idx are copied.
    int add(int src, int l, int r, int target_idx){
        const int node = ++top;
        if (l == r){
            tree[node].val = tree[src].val + 1;
            return node;
        }
        const int mid = (l+r)>>1;
        // Share both children with the old version, then replace the one that
        // contains target_idx.
        tree[node].left = tree[src].left;
        tree[node].right = tree[src].right;
        if (target_idx <= mid){
            tree[node].left = add(tree[src].left, l, mid, target_idx);
        }else{
            tree[node].right = add(tree[src].right, mid+1, r, target_idx);
        }
        tree[node].val = tree[tree[node].left].val + tree[tree[node].right].val;
        return node;
    }

    // Walks every input string through the trie and creates one tree version
    // per character, incrementing the DFS position of the node reached.
    void solve(){
        root[0] = build(1, cnt);
        for (int i=1; i<=n; i++){
            int node = 0;
            for (int j=0; j<len[i]; j++){
                node = Trie[node][s[i][j]-'a'+1];
                const int previous = root[cnt2];
                cnt2++;
                root[cnt2] = add(previous, 1, cnt, mp2[node]);
            }
            mp3[i] = cnt2;
        }
    }

    // Sum over [target_l, target_r] of (version rt2 - version rt1); both roots
    // cover the interval [l, r].
    int search(int rt1, int rt2, int l, int r, int target_l, int target_r){
        const bool covered = target_l <= l && r <= target_r;
        if (covered) return tree[rt2].val - tree[rt1].val;
        const int mid = (l+r)>>1;
        int total = 0;
        if (target_l <= mid){
            total += search(tree[rt1].left, tree[rt2].left, l, mid, target_l, target_r);
        }
        if (target_r >= mid+1){
            total += search(tree[rt1].right, tree[rt2].right, mid+1, r, target_l, target_r);
        }
        return total;
    }

    // Number of occurrences of string k inside strings l..r.
    int getans(int l, int r, int k){
        const int end_node = mp1[k];
        const int first = mp2[end_node];
        const int last = first + Sons[end_node];
        return search(root[mp3[l-1]], root[mp3[r]], 1, cnt, first, last);
    }
}ACP;
// Reads the next non-negative decimal integer from stdin, skipping any
// non-digit characters in front of it.
// Returns 0 if end of input is reached before a digit is found.
inline int read(){
    // getchar() returns int; storing it in a char would make EOF impossible to
    // detect reliably and the skip loop would never end at end of input.
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) ch = getchar();
    int x = 0;
    while (ch >= '0' && ch <= '9'){
        // Parenthesised so the digit value is added, not the raw character code.
        x = x*10 + (ch - '0');
        ch = getchar();
    }
    return x;
}
int main(){
int l, r, k;
n = read(), m = read();
for (int i=1; i<=n; i++){
cin >> s[i];
len[i] = s[i].length();
ACP.insert(s[i], len[i], i);
}
ACP.build();
ACP.solve();
for (int i=1; i<=m; i++){
l = read(), r = read(), k = read();
printf ("%d\n", ACP.getans(l, r, k));
}
return 0;
}
0 条评论