阿狸的打字机 AC自动机+fail树+DFS序+树状数组

n个字符串,m次询问,每次询问(x, y),输出字符串x在字符串y中出现的次数。对于字符串x和y:

若x在y中出现过,则在Trie树上从根到y的终结点的路径上的所有结点(即y的所有前缀)中,至少存在1个结点,满足:从该结点出发沿着fail边不断跳(跳0次或多次),会走到x对应的终结点。x在y中出现的次数,即为满足这个条件的结点的个数。
x在y中出现的次数也等于:把fail边反向,构成的以0为根的树,以x的终结点为根的子树中,有多少个点属于字符串y,x在y中就出现了多少次。(以x为根的子树中的所有孩子结点,所各自对应的字符串中,x都作为一个后缀出现)。

有了如上前提,我们可以这样做:
构造AC自动机,连接fail边。写出fail树的DFS序,以x的终结点为根的子树中的所有结点,在DFS序中恰好构成从x的终结点位置开始的一段连续区间,我们只需统计该区间内有多少个结点属于y即可。如何统计?遍历Trie树,每向下访问一个结点,这个结点在DFS序中对应的位置的权值就+1,结束访问(回溯)时就-1,这样的话,当访问到y的终结点时,DFS序中权值为1的点恰好就是y的所有前缀对应的结点,那么我们求一下x子树的DFS区间内的权值和即可,这个过程可以用树状数组或线段树来完成。

#include <cstdio>
#include <cstring>
#include <queue>
#include <algorithm>
#include <vector>
// Capacity shared by every fixed-size array of this program.
const int maxn = 1e5+5;
using namespace std;

// Key sequence of the typewriter; main() stores it starting at s[1].
char s[maxn];
// Number of queries that solve() will process.
int m;
// ans[i] is the answer to the i-th query in input order (1-based).
int ans[maxn];
// Isolates the lowest set bit of i (0 when i is 0); this is the step width
// used when walking a Fenwick tree.
inline int lowbit(int i){
    const int negated = -i;
    return negated & i;
}
// One offline query "how many times does string x occur in string y",
// together with its position in the input so answers can be printed in order.
struct query{
    int x;      // index of the pattern string
    int y;      // index of the text string
    int idx;    // 1-based position of the query in the input
}query[maxn];
// Aho-Corasick automaton built from a "typewriter" key sequence, plus the
// offline machinery answering "how many times does printed string x occur in
// printed string y".
//
// Key sequence (stored 1-based in the buffer): 'P' records the current buffer
// as the next printed string, 'B' removes the last character, and any other
// character is treated as a lowercase letter that is appended.
//
// Method: the fail links are reversed into a tree rooted at node 0 and
// numbered in DFS preorder, so the fail-subtree of a node is one contiguous
// range of positions.  While the key sequence is replayed, the nodes on the
// current root-to-node trie path carry weight 1 in a Fenwick tree indexed by
// DFS position; the answer to (x, y) is the weight sum over the range of x's
// end node at the moment the replay stands on y's end node.
struct AC_Auto{
    int Trie[maxn][27], fail[maxn], cnt, root;   // Trie: child/goto table (letters use slots 1..26); cnt: number of trie nodes; root: node the replay currently stands on
    int num, cur;                                // num: number of printed strings; cur: DFS counter in build(), query cursor in solve()
    vector<int> G[maxn];                     // fail tree: G[f] lists the nodes whose fail link is f; the tree root is node 0
    int Sons[maxn];                          // number of proper descendants (all nodes below, not only direct children) in the fail tree
    int fa[maxn], mp1[maxn], mp2[maxn];      // fa[v]: trie parent of v; mp1[i]: end node of the i-th printed string; mp2[v]: DFS position of node v
    int DFS[maxn];                           // DFS[p]: node sitting at DFS position p
    int c[maxn];                // Fenwick tree over DFS positions 1..cnt

    AC_Auto(){
        root = 0;
        cnt = 0;
        num = 0;    cur = 0;
    }

    // Builds the trie from the key sequence s[1..len] and records the end node
    // of every printed string in mp1.  Must be called before build(), because
    // build() fills the empty child slots with goto transitions.
    void insert(char *s, int len){
        root = num = cnt = 0;
        for (int i=1; i<=len; i++){
            const char key = s[i];
            if (key=='P'){
                mp1[++num] = root;
            }else if (key=='B'){
                root = fa[root];            // fa[0] is 0, so 'B' on an empty buffer stays at the root
            }else{
                const int slot = key-'a'+1; // assumes key is a lowercase letter
                if (!Trie[root][slot]){
                    Trie[root][slot] = ++cnt;
                    fa[cnt] = root;
                }
                root = Trie[root][slot];
            }
        }
    }

    // Computes the fail links by BFS, completes the goto table, builds the
    // fail tree and numbers it in DFS preorder (node 0 gets position 0).
    void build(){
        queue<int> q;
        for (int i=1; i<27; i++)    if (Trie[0][i])
            q.push(Trie[0][i]);
        while (!q.empty()){
            const int node = q.front();    q.pop();
            for (int i=1; i<27; i++){
                if (Trie[node][i]){
                    fail[Trie[node][i]] = Trie[fail[node]][i];
                    q.push(Trie[node][i]);
                }else{
                    Trie[node][i] = Trie[fail[node]][i];
                }
            }
        }
        for (int i=1; i<=cnt; i++)  G[fail[i]].push_back(i);
        cur = -1;
        getSons(0);
    }

    // Numbers the fail-subtree of rot in preorder (continuing from cur) and
    // fills Sons[] for every node in it.  Returns the subtree size including
    // rot itself.  Uses an explicit stack instead of recursion because the fail
    // tree can be a chain as long as the number of trie nodes.
    int getSons(int rot){
        vector<int> path;       // nodes from rot down to the node being expanded
        vector<int> nextChild;  // for each node on the path, index of its next unvisited child
        Sons[rot] = 0;
        DFS[++cur] = rot;
        mp2[rot] = cur;
        path.push_back(rot);
        nextChild.push_back(0);
        while (!path.empty()){
            const int node = path.back();
            const int k = nextChild.back();
            if (k < static_cast<int>(G[node].size())){
                nextChild.back() = k+1;
                const int child = G[node][k];
                Sons[child] = 0;
                DFS[++cur] = child;
                mp2[child] = cur;
                path.push_back(child);
                nextChild.push_back(0);
            }else{
                path.pop_back();
                nextChild.pop_back();
                if (!path.empty())  Sons[path.back()] += Sons[node]+1;
            }
        }
        return Sons[rot]+1;
    }

    // Adds k to Fenwick position i.  Positions below 1 are ignored: position 0
    // belongs to the root, and lowbit(0)==0 would otherwise never advance.
    void update(int i, int k){
        if (i < 1)  return;
        while (i<=cnt){
            c[i] += k;
            i += lowbit(i);
        }
    }

    // Returns the sum of the Fenwick weights on positions [l, r]; positions
    // below 1 contribute nothing.
    int getSum(int l, int r){
        int sum = 0;
        for (int i=r; i>0; i-=lowbit(i))    sum += c[i];
        for (int i=l-1; i>0; i-=lowbit(i))  sum -= c[i];
        return sum;
    }

    // Replays the key sequence s[1..len] and answers the queries, which must be
    // sorted by y in non-decreasing order.  Results go to ans[query[..].idx].
    void solve(char *s, int len){
        root = 0;
        cur = 1;
        for (int i=1; i<=len; i++){
            const char key = s[i];
            if (key=='B'){
                if (root != 0){                 // nothing to undo on an empty buffer
                    update(mp2[root], -1);
                    root = fa[root];
                }
            }else if (key=='P'){
                // Check the cursor before touching query[cur].
                while (cur<=m && root==mp1[query[cur].y]){
                    const int top = mp1[query[cur].x];
                    const int first = mp2[top];
                    ans[query[cur].idx] = getSum(first, first + Sons[top]);
                    cur++;
                }
            }else{
                root = Trie[root][key-'a'+1];
                update(mp2[root], 1);
            }
        }
    }
}AC;
// Ordering predicate for sort(): a query with a smaller y goes first.
inline bool cmp1(struct query a, struct query b){
    return b.y > a.y;
}
int main(){
    scanf ("%s", s+1);
    int len = strlen(s+1);
    AC.insert(s, len);
    AC.build();
    scanf ("%d", &m);
    for (int i=1; i<=m; i++){
        scanf ("%d %d", &query[i].x, &query[i].y);
        query[i].idx = i;
    }
    sort (query+1, query+m+1, cmp1);
    AC.solve(s, len);
    for (int i=1; i<=m; i++)
        printf ("%d\n", ans[i]);
    return 0;
}

Mike and Friends

题意:n个字符串,m次询问,每次询问l, r, k。输出字符串k在字符串[l, r]中出现的次数。

问出现次数,那肯定AC自动机没跑了。仍然利用性质:在把fail边反向建立的树中,以字符串k的终结点为根的子树中的结点所对应的字符串都以k为后缀。
建立AC自动机,fail边反向,写出fail树的DFS序。在DFS序上建立主席树。按字符串插入顺序跑一遍自动机,每跑到一个结点就在前一棵线段树的基础上,把该结点在DFS序中的位置的权值+1,新建一棵线段树。线段树每个结点表示DFS序里区间[l, r]内的权值和,要求k的出现次数只用求线段树里相应区间的权值和即可。要求在[l,r]内的出现次数只需两棵线段树相减得到的权值和即可。

#include <cstdio>
#include <iostream>
#include <queue>
#include <string>
#include <vector>
using namespace std;
// Capacity shared by the fixed-size arrays of this program.
const int maxn = 2e5+5;
// Node of the persistent segment tree: left/right are indices of the child
// nodes and val is the weight stored for the node's range.
struct Presistial_SegmentTree{
    int left;
    int right;
    int val;
};
typedef Presistial_SegmentTree PST;
// n: number of strings, m: number of queries (both read first in main()).
int n, m;
// len[i] is the length of s[i].
int len[maxn];
// The input strings, 1-based.
string s[maxn];

// Aho-Corasick automaton over the input strings combined with a persistent
// segment tree, answering "how many times does string k occur in strings
// l..r".
//
// Method: the fail links are reversed into a tree rooted at node 0 and
// numbered in DFS preorder, so the fail-subtree of a node is one contiguous
// range of positions.  Walking every string through the trie, each visited
// node adds 1 at its DFS position in a new version of the segment tree.  The
// answer is the range sum over k's subtree range in version(r) minus the same
// sum in version(l-1).
struct AC_auto{
    int Trie[maxn][27], fail[maxn], cnt;        // child/goto table (letters use slots 1..26), fail links, number of trie nodes
    int mp1[maxn], mp2[maxn], mp3[maxn];        // mp1[i]: end node of string i; mp2[v]: DFS position of node v; mp3[i]: version index after string i
    vector<int> G[maxn];                        // fail tree: G[f] lists the nodes whose fail link is f
    int Sons[maxn], DFS[maxn], cur;             // Sons[v]: proper descendants of v in the fail tree; DFS[p]: node at position p; cur: DFS counter

    PST tree[maxn*30];                          // node pool of the persistent segment tree; index 0 is never allocated
    int top, root[maxn], cnt2;                  // top: last used pool index; root[t]: root node of version t; cnt2: number of versions created

    AC_auto(){
        top = cnt = cnt2 = 0;
        cur = -1;
    }

    // Inserts the first len characters of s into the trie and records the end
    // node as the end of string number idx.  Assumes lowercase letters.
    void insert(const string &s, int len, int idx){
        int node = 0;
        for (int i=0; i<len; i++){
            const int slot = s[i]-'a'+1;
            if (!Trie[node][slot])    Trie[node][slot] = ++cnt;
            node = Trie[node][slot];
        }
        mp1[idx] = node;
    }

    // Computes the fail links by BFS, completes the goto table, builds the
    // fail tree and numbers it in DFS preorder (node 0 gets position 0).
    void build(){
        queue<int>  q;
        for (int i=1; i<27; i++)    if (Trie[0][i])
            q.push(Trie[0][i]);
        while (!q.empty()){
            const int node = q.front();    q.pop();
            for (int i=1; i<27; i++){
                if (Trie[node][i]){
                    fail[Trie[node][i]] = Trie[fail[node]][i];
                    q.push(Trie[node][i]);
                }else   Trie[node][i] = Trie[fail[node]][i];
            }
        }
        for (int i=1; i<=cnt; i++)
            G[fail[i]].push_back(i);
        cur = -1;
        GetSons(0);
    }

    // Numbers the fail-subtree of i in preorder (continuing from cur) and fills
    // Sons[] for every node in it.  Returns the subtree size including i.
    // Uses an explicit stack instead of recursion because the fail tree can be
    // a chain as long as the number of trie nodes.
    int GetSons(int i){
        vector<int> path;       // nodes from i down to the node being expanded
        vector<int> nextChild;  // for each node on the path, index of its next unvisited child
        Sons[i] = 0;
        DFS[++cur] = i;
        mp2[i] = cur;
        path.push_back(i);
        nextChild.push_back(0);
        while (!path.empty()){
            const int node = path.back();
            const int k = nextChild.back();
            if (k < static_cast<int>(G[node].size())){
                nextChild.back() = k+1;
                const int child = G[node][k];
                Sons[child] = 0;
                DFS[++cur] = child;
                mp2[child] = cur;
                path.push_back(child);
                nextChild.push_back(0);
            }else{
                path.pop_back();
                nextChild.pop_back();
                if (!path.empty())  Sons[path.back()] += Sons[node]+1;
            }
        }
        return Sons[i] + 1;
    }

    // Allocates an all-zero segment tree over positions [l, r] and returns the
    // index of its root.  Requires l <= r.
    int build(int l, int r){
        const int node = ++top;
        tree[node].val = 0;
        if (l==r)   return node;
        const int mid = (l+r)>>1;
        tree[node].left = build(l, mid);
        tree[node].right = build(mid+1, r);
        return node;
    }

    // Returns the root of a new version equal to version src with 1 added at
    // position target_idx; only the nodes on the root-to-leaf path are copied.
    int add(int src, int l, int r, int target_idx){
        const int node = ++top;
        if (l==r){
            tree[node].val = tree[src].val + 1;
            return node;
        }
        const int mid = (l+r)>>1;
        if (mid >= target_idx){
            tree[node].left = add(tree[src].left, l, mid, target_idx);
            tree[node].right = tree[src].right;
        }else{
            tree[node].right = add(tree[src].right, mid+1, r, target_idx);
            tree[node].left = tree[src].left;
        }
        tree[node].val = tree[tree[node].left].val + tree[tree[node].right].val;
        return node;
    }

    // Creates one segment tree version per character of every string, in input
    // order, and remembers in mp3[i] the version reached after string i.
    void solve(){
        if (cnt == 0)   return;     // empty trie: build(1, 0) would never terminate and nothing can be counted
        root[0] = build(1, cnt);
        for (int i=1; i<=n; i++){
            int node = 0;
            for (int j=0; j<len[i]; j++){
                node = Trie[node][s[i][j]-'a'+1];
                cnt2++;
                root[cnt2] = add(root[cnt2-1], 1, cnt, mp2[node]);
            }
            mp3[i] = cnt2;
        }
    }

    // Returns sum(version rt2) - sum(version rt1) over positions
    // [target_l, target_r]; [l, r] is the range covered by the current nodes.
    int search(int rt1, int rt2, int l, int r, int target_l, int target_r){
        if (l>=target_l && r<=target_r)     return tree[rt2].val - tree[rt1].val;
        int total = 0;
        const int mid = (l+r)>>1;
        if (mid>=target_l)  total += search(tree[rt1].left, tree[rt2].left, l, mid, target_l, target_r);
        if (mid+1<=target_r)    total += search(tree[rt1].right, tree[rt2].right, mid+1, r, target_l, target_r);
        return total;
    }

    // Number of occurrences of string k inside strings l..r.
    // Expects 1 <= l <= r <= n and 1 <= k <= n.
    int getans(int l, int r, int k){
        const int top_node = mp1[k];
        const int first = mp2[top_node];
        return search(root[mp3[l-1]], root[mp3[r]], 1, cnt, first, first+Sons[top_node]);
    }

}ACP;

// Reads the next non-negative decimal integer from stdin, skipping every
// non-digit character in front of it.  Returns 0 when the input ends before a
// digit is found, instead of spinning forever on EOF.
inline int read(){
    int x = 0;
    int ch = getchar();     // int, so EOF stays distinguishable from real characters
    while (ch != EOF && (ch<'0' || ch>'9'))    ch = getchar();
    while (ch>='0' && ch<='9'){
        x = x*10 + (ch-'0');
        ch = getchar();
    }
    return x;
}
int main(){
    int l, r, k;
    n = read(), m = read();
    for (int i=1; i<=n; i++){
        cin >> s[i];
        len[i] = s[i].length();
        ACP.insert(s[i], len[i], i);
    }
    ACP.build();
    ACP.solve();
    for (int i=1; i<=m; i++){
        l = read(), r = read(), k = read();
        printf ("%d\n", ACP.getans(l, r, k));
    }
    return 0;
}

0 条评论

发表评论

邮箱地址不会被公开。