0.前情提要:

0.1.为什么写这篇博客:

神仙老师在测试的时候放了这道神仙题,然而神仙的是,神仙老师也没有实现过这份神仙代码,本小蒟蒻和众大佬研究了神仙SFN1036的代码,懵逼数日之后,本蒟蒻终于A掉了这道题,一路上感触良多,觉得应该好好补充一下这道题的题解了,神仙可以忽略

0.2.简述题意:

加速一个K进制下树状数组的实现

SRE实战 互联网时代守护先锋,助力企业售后服务体系运筹帷幄!一键直达领取阿里云限量特价优惠。

重要概念:lowbit(x),表示x在K进制下从低位往高位数,第一个非零位的位值

为了表述清晰,我们再定义low(x)表示x在k进制下第一个非零位的数值,lowbitv(x)为该位的位权。

举个栗子:

x=(123456000)
lowbit(x)=(6000)
lowbitv(x)=(1000)
low(x)=(6)

0.3.题目链接:https://loj.ac/problem/510

1.思路题解:

1.1官方题解:https://loj.ac/article/87

1.2重点难点关键点:一个神仙规律

1.2.1.k为奇数

我们非常随便取几个数试试,看看add操作中x=x+lowbit(x)的走向

k=7 1,2,4,11,12,14,21,22,24...
   3,6,15,23,26,35,43,46,55...
k=9 1,2,4,8,17,25,31,32,34,38,47,55,61,62,64,68,77,85...
   3,6,13,16,23,26...

 不难发现,lowbit(x)形成了若干条互不相交循环链,这个规律是做题的核心。

1.2.2.k为偶数

我们很期待这个规律可以推广到任意情况,然而聪明的你看到我竟然分了两类讨论,应该就知道接下可能要发生一些不好的事情了。

k=6 1,2,4,12,14,22,24,..
3,10,20,40,120,140,220,240...
5,14,22,24,32,34,42...
k=8 1,2,4,10,20,40,100,200,400...
3,6,14,20,40,100... 5,12,14,20,40,100,200,400... 7,16,24,30,60,140,200,400...

然而,细心的你会发现,虽然没有奇数那么明显的规律,我们依然可以从数据中看到,偶数情况下有暗含的规律,我们可以从中剥离出一条正统的链。

比如k=6,正统的链是2,4,12,14,22...

我们发现这些链上的lowbit(x)的变化是有迹可循的,更加可喜可贺的是,非正统链上的数总是可以跳到正统链上。

于是接下来我们要解决一个问题:如何判定当前的x是否在正统链上。

这需要先证明这个奇妙的规律,写在文末,因为考虑到有些神犇一眼看穿。总之,结论是:

k进制下,当且仅当low(x)中含的2的幂次大于等于k中包含的2的幂次,x在正统链上。

 当然,这个结论的逆命题也成立,详细看文末的证明吧。

值得一提的是,正统的链不止一条,有时候大体一样,但数位不一样,也要分开考虑。

比如k=6时,2,4,12,14,22...与20,40,120,140,220...就是两条互不相交的正统链,要注意分开,这也解释了为什么这道题线段树的空间需求有点大。

2.代码呈现

loj#510北校门外的回忆,详解,真的详解 随笔 第1张
  1 #include<bits/stdc++.h>
  2 using namespace std;
  3 
  4 const int N=2e5+7;
  5 
  6 struct seg{
  7     int ls,rs,vl;
  8 }tr[N<<6];
  9 
 10 int opt,n,k,Q,RT,trip,tot;
 11 int num[N],pr[N],q[N],pj[N],bz1[N][20],bz2[N][20];
 12 bool vis[N];
 13 
 14 int lowbit(int m){
 15     while(m%k==0) m/=k;
 16     return m%k;
 17 }
 18 
 19 int lowbitv(int m){
 20     int ret=1;
 21     while(m%k==0) ret*=k,m/=k;
 22     return ret*(m%k);
 23 }
 24 
 25 void prework(){
 26     for(int i=1;i<=k;i++)
 27         for(int x=i;x%2==0;x>>=1,pr[i]++);
 28     for(int i=1;i<k;i++){
 29         if(vis[i]||pr[i]<pr[k]) continue;
 30         int len=0;
 31         q[++len]=i,vis[i]=1;
 32         for(int c=i*2%k;c^i;c=(c<<1)%k)
 33             q[++len]=c,vis[c]=1;
 34         int cnt=0,lm=k/2;
 35         for(int c=1;c<=len;c++)
 36             cnt+=(q[c]>lm);
 37         for(int c=1;c<=len;c++) {
 38             int u=q[c],v=(c-1)?q[c-1]:q[len];
 39             pj[u]=cnt,num[u]=len;
 40             bz1[u][0]=v;
 41             bz2[u][0]=(v>lm);
 42         }
 43         for(int s=1;s<=18;s++)
 44             for(int k=1;k<=len;k++){
 45                 int u=q[k];
 46                 bz1[u][s]=bz1[bz1[u][s-1]][s-1];
 47                 bz2[u][s]=bz2[bz1[u][s-1]][s-1]+bz2[u][s-1];
 48             }
 49     }
 50 }
 51 
 52 void ins(int &o,int l,int r,int x,int y,int v){
 53     if(!o) o=++tot;
 54     if(x<=l&&y>=r) {tr[o].vl^=v;return;}
 55     int mid=(l+r)>>1;
 56     if(x<=mid) ins(tr[o].ls,l,mid,x,y,v);
 57     if(y>mid) ins(tr[o].rs,mid+1,r,x,y,v);
 58 }
 59 
 60 int query(int o,int l,int r,int x){
 61     if(l==r||!o) return tr[o].vl;
 62     int mid=(l+r)>>1;
 63     if(x<=mid) return query(tr[o].ls,l,mid,x)^tr[o].vl;
 64     else return query(tr[o].rs,mid+1,r,x)^tr[o].vl;
 65 }
 66 
 67 int fnd_top(int x){
 68     int lg=0,top,jw;
 69     for(x;x%k==0;x/=k,lg++);
 70     top=x%k,jw=(x-top)/k;
 71     trip=1+jw/pj[top]*num[top];
 72     jw%=pj[top];
 73     for(int i=18;i>=0;i--)
 74         if(bz2[top][i]<=jw){
 75             jw-=bz2[top][i];
 76             trip+=(1<<i);
 77             top=bz1[top][i];
 78         }
 79     while(lg--) top*=k;
 80     return top;
 81 }
 82 
 83 void add(int o,int v){
 84     while(o<=n&&pr[lowbit(o)]<pr[k]){
 85         ins(RT,1,n,o,o,v);
 86         o+=lowbitv(o);
 87     }
 88     if(o>n) return;
 89     int st=fnd_top(o);
 90     int rt=query(RT,1,n,st),tmp=rt;
 91     ins(rt,1,n,trip,n,v);
 92     if(!tmp) ins(RT,1,n,st,st,rt);
 93 }
 94 
 95 int fnd_ans(int o){
 96     int ret=0;
 97     while(o){
 98         if(pr[lowbit(o)]<pr[k]) ret^=query(RT,1,n,o);
 99         else{
100             int st=fnd_top(o),rt=query(RT,1,n,st);
101             if(rt) ret^=query(rt,1,n,trip);
102         }
103         o-=lowbitv(o);
104     }
105     return ret;
106 }
107 
108 int main(){
109     int x,v;
110     scanf("%d%d%d",&n,&Q,&k);
111     prework();
112     while(Q--){
113         scanf("%d%d",&opt,&x);
114         if(opt==1){
115             scanf("%d",&v);
116             add(x,v);
117         }
118         else printf("%d\n",fnd_ans(x));
119     }
120 }
View Code

不得不说,自己写了一遍,发现了原神仙的写法十分精炼,显示出其神仙级的压行和代码优化操作,毕竟让我们死磕了好久,emmmmm

所以,这里我简述分模块讲一下。

2.1.变量说明

const int N=2e5+7;

struct seg{
    int ls,rs,vl;
}tr[N<<6];

int opt,n,k,Q,RT,trip,tot;
int num[N],pr[N],q[N],pj[N],bz1[N][20],bz2[N][20];
bool vis[N];

其中,trip记录的是当前x是从所属链的从头数第几位,top记录的是当前x所属链的链头(类似并查集的代表元),tot与seg共同构成看似一棵神仙动态开点线段树,rt,RT是看似两组线段树的跟(实际上有三组,只不过有两组合体了),num是每条链的长度,pr是每个小于k的正整数中包含的2的幂次,q是暂存每个链的lowbit(x),pj是没经历一次链的循环会进多少位,bz1是倍增第一维走2的i次方步走到的lowbit值,bz2是在此过程中一共进了多少位。没错,我们,很在意进位的事情,因为这是我们寻找链头的依据。

2.2.函数说明

void prework(){
    for(int i=1;i<=k;i++)
        for(int x=i;x%2==0;x>>=1,pr[i]++);
    for(int i=1;i<k;i++){
        if(vis[i]||pr[i]<pr[k]) continue;
        int len=0;
        q[++len]=i,vis[i]=1;
        for(int c=i*2%k;c^i;c=(c<<1)%k)
            q[++len]=c,vis[c]=1;
        int cnt=0,lm=k/2;
        for(int c=1;c<=len;c++)
            cnt+=(q[c]>lm);
        for(int c=1;c<=len;c++) {
            int u=q[c],v=(c-1)?q[c-1]:q[len];
            pj[u]=cnt,num[u]=len;
            bz1[u][0]=v;
            bz2[u][0]=(v>lm);
        }
        for(int s=1;s<=18;s++)
            for(int k=1;k<=len;k++){
                int u=q[k];
                bz1[u][s]=bz1[bz1[u][s-1]][s-1];
                bz2[u][s]=bz2[bz1[u][s-1]][s-1]+bz2[u][s-1];
            }
    }
}

预处理应该很容易懂,只要你好好看懂每个变量的意义。注意其中bz1数组是往前倒着跳的。

void ins(int &o,int l,int r,int x,int y,int v){
    if(!o) o=++tot;
    if(x<=l&&y>=r) {tr[o].vl^=v;return;}
    int mid=(l+r)>>1;
    if(x<=mid) ins(tr[o].ls,l,mid,x,y,v);
    if(y>mid) ins(tr[o].rs,mid+1,r,x,y,v);
}

int query(int o,int l,int r,int x){
    if(l==r||!o) return tr[o].vl;
    int mid=(l+r)>>1;
    if(x<=mid) return query(tr[o].ls,l,mid,x)^tr[o].vl;
    else return query(tr[o].rs,mid+1,r,x)^tr[o].vl;
}

这是常规的线段树函数,只不过是维护异或。神奇在于,这道题一共出现了三堆线段树合成两组后写成一棵的神奇操作。

注意当非叶子节点都保持在0的时候,这棵线段树将变身省空间费时间的数组,这一点在后面很重要。

int fnd_top(int x){
    int lg=0,top,jw;
    for(x;x%k==0;x/=k,lg++);
    top=x%k,jw=(x-top)/k;
    trip=1+jw/pj[top]*num[top];
    jw%=pj[top];
    for(int i=18;i>=0;i--)
        if(bz2[top][i]<=jw){
            jw-=bz2[top][i];
            trip+=(1<<i);
            top=bz1[top][i];
        }
    while(lg--) top*=k;
    return top;
} 

这段抽象的代码其实执行的是很直观的操作,举个例子:

(top=x=1234560000)→(123456)→(12345    6)→(top=6,jw=12345)

比较奇怪的是trip,它通过计算出该数一共进了多少位,用bz2丈量,计算x在链上跳了多少位,同时top通过bz1数组跳到链头。

Caution:我们返回的top要补回去掉的0,前文已经说过,正统链有多条并列存在。

至此我们得到了:x在所属链的位置,所属链的链头。

void add(int o,int v){
    while(o<=n&&pr[lowbit(o)]<pr[k]){
        ins(RT,1,n,o,o,v);
        o+=lowbitv(o);
    }
    if(o>n) return;
    int st=fnd_top(o);
    int rt=query(RT,1,n,st),tmp=rt;
    ins(rt,1,n,trip,n,v);
    if(!tmp) ins(RT,1,n,st,st,rt);
}

加入函数,前面部分判断了当前x是否跳到了正统链上,如果不是,我们暴力更新,用以RT为根的线段树中,o不在正统链上的位置模拟树状数组,如果在,那么我们用以RT为根剩下的部分,就是o在正统链上位置,记录以该链链头的线段树的根,是的,就是rt。至此三组线段树呈现完成。注意以rt为根的线段树维护的区间是链上的下标,而以RT为根的线段树就是省空间数组。

int fnd_ans(int o){
    int ret=0;
    while(o){
        if(pr[lowbit(o)]<pr[k]) ret^=query(RT,1,n,o);
        else{
            int st=fnd_top(o),rt=query(RT,1,n,st);
            if(rt) ret^=query(rt,1,n,trip);
        }
        o-=lowbitv(o);
    }
    return ret;
}

找答案函数,同样分两种情况操作,从末尾往前操作,每操作完一位就删去,注意这里调用的是lowbitv。

3.数学严谨证明

3.1.k为奇数

 我们发现,每一次的操作相当于把最末非0位乘上2,所以我们尝试讨论一下关于2的幂次。

令x=2^k1*d(d为奇数),由于lowbit(x)一定小于k,所以d≠k,而每一次增加的只有2的幂次,所以x不管怎么变都不会变成0.也就是说lowbit的位置不变。再根据欧拉定理可以知道,2^Φ(k) mod k = 1,所以总会跳到同一个点。反过来,若x1=x2(mod k),一定可以得到d1=d2。具体做法是把x1,x2分别写开,根据同余的性质可以得到(x1-x2)|k,提取公因式后得到2^x1*(2^x2*d1-d2)|k,k为奇数,所以得到2^x2*d2=d1。

扫码关注我们
微信号:SRE实战
拒绝背锅 运筹帷幄