题目链接

洛谷:https://www.luogu.org/problemnew/show/P4389

Solution

挺巧妙的题。

SRE实战 互联网时代守护先锋,助力企业售后服务体系运筹帷幄!一键直达领取阿里云限量特价优惠。

对于每件物品可以看成无穷多个,背包转移可以写成卷积的形式,对于质量为\(v\)的物品,写成生成函数就是:
\[ F(x)=\sum_{i=0}^{+\infty}x^{vi} \]
然后有\(1e5\)个这样的东西,乘起来就是答案,复杂度\(O(mn\log n)\)

但这样显然过不了,我们把上面的函数变一下:
\[ F(x)=\sum_{i=0}^{+\infty}x^{vi}=\frac{1}{1-x^v} \]
然后若干个乘起来就是:
\[ ans=\prod_{i=1}^{n}\frac{1}{1-x^{v_i}} \]
考虑用\(\ln\)化成为加:
\[ \ln ans=\sum_{i=1}^n\ln \frac{1}{1-x^{v_i}} \]
由于:
\[ \ln F(x)=\int F'(x)F^{-1}(x)~{\rm{d}}x \]
带进去可得:
\[ \ln \frac{1}{1-x^{v}}=\int \frac{vx^{v-1}}{1-x^v}~{\rm{d}}x \]
展开再积分:
\[ \ln \frac{1}{1-x^{v}}=\int \sum_{i=1}^{+\infty}vx^{vi-1}~{\rm{d}}x=\sum_{i=1}^{+\infty}\frac{1}{i}x^{vi} \]
那么这个就可以直接算了,最后\(\exp\)一下就好了,注意处理出这个函数可以用记个桶然后调和级数的小技巧。

时间复杂度\(O(n\log n)\)

#include<bits/stdc++.h>
using namespace std;

void read(int &x) {
    x=0;int f=1;char ch=getchar();
    for(;!isdigit(ch);ch=getchar()) if(ch=='-') f=-f;
    for(;isdigit(ch);ch=getchar()) x=x*10+ch-'0';x*=f;
}

void print(int x) {
    if(x<0) putchar('-'),x=-x;
    if(!x) return ;print(x/10),putchar(x%10+48);
}
void write(int x) {if(!x) putchar('0');else print(x);putchar('\n');}

#define lf double
#define ll long long 

const int maxn = 8e5+10;
const int inf = 1e9;
const lf eps = 1e-8;
const int mod = 998244353;

int add(int x,int y) {return x+y>mod?x+y-mod:x+y;}
int del(int x,int y) {return x-y<0?x-y+mod:x-y;}
int mul(int x,int y) {return 1ll*x*y-1ll*x*y/mod*mod;}

int w[maxn],rw[maxn],pos[maxn],N,bit,f[maxn];
int tmp[10][maxn],n,m,cnt[maxn],inv[maxn],g[maxn],mxn;

int qpow(int a,int x) {
    int res=1;
    for(;x;x>>=1,a=mul(a,a)) if(x&1) res=mul(res,a);
    return res;
}

void ntt_init() {
    w[0]=1,w[1]=qpow(3,(mod-1)/mxn);
    for(int i=2;i<=mxn;i++) w[i]=mul(w[i-1],w[1]);
    rw[0]=1,rw[1]=qpow(w[1],mod-2);
    for(int i=2;i<=mxn;i++) rw[i]=mul(rw[i-1],rw[1]);
}

void ntt_get(int len) {
    for(N=1,bit=0;N<=len;N<<=1,bit++);
    for(int i=0;i<N;i++) pos[i]=pos[i>>1]>>1|((i&1)<<(bit-1));
}

void ntt(int *r,int op) {
    for(int i=0;i<N;i++) if(pos[i]>i) swap(r[i],r[pos[i]]);
    for(int i=1,d=mxn>>1;i<N;i<<=1,d>>=1)
        for(int j=0;j<N;j+=i<<1)
            for(int k=0;k<i;k++) {
                int x=r[j+k],y=mul((op==-1?rw:w)[k*d],r[i+j+k]);
                r[j+k]=add(x,y),r[i+j+k]=del(x,y);
            }
    if(op==-1) {int d=qpow(N,mod-2);for(int i=0;i<N;i++) r[i]=mul(r[i],d);}
}

void poly_inv(int *r,int *t,int len) {
    if(len==1) return t[0]=qpow(r[0],mod-2),void();
    poly_inv(r,t,len>>1);ntt_get(len);
    for(int i=0;i<len>>1;i++) tmp[0][i]=r[i],tmp[1][i]=t[i];
    for(int i=len>>1;i<len;i++) tmp[0][i]=r[i],tmp[1][i]=0;
    for(int i=len;i<N;i++) tmp[0][i]=tmp[1][i]=0;
    ntt(tmp[0],1),ntt(tmp[1],1);
    for(int i=0;i<N;i++) t[i]=del(mul(2,tmp[1][i]),mul(mul(tmp[1][i],tmp[1][i]),tmp[0][i]));
    ntt(t,-1);for(int i=len;i<N;i++) t[i]=0;
}

void poly_der(int *r,int *t,int len) {
    for(int i=1;i<len;i++) t[i-1]=mul(i,r[i]);
    for(int i=len-1;i<len<<1;i++) t[i]=0;
}

void poly_int(int *r,int *t,int len) {
    for(int i=0;i<len;i++) t[i+1]=mul(inv[i+1],r[i]);t[0]=0;
    for(int i=len+1;i<len<<1;i++) t[i]=0;
}

void poly_ln(int *r,int *t,int len) {
    poly_der(r,tmp[2],len);
    poly_inv(r,tmp[3],len);
    ntt_get(len),ntt(tmp[2],1),ntt(tmp[3],1);
    for(int i=0;i<N;i++) tmp[2][i]=mul(tmp[2][i],tmp[3][i]);
    ntt(tmp[2],-1);poly_int(tmp[2],t,len);
    for(int i=0;i<len<<1;i++) tmp[3][0]=0;
}

void poly_exp(int *r,int *t,int len) {
    if(len==1) return t[0]=1,void();
    poly_exp(r,t,len>>1);
    poly_ln(t,tmp[4],len);
    for(int i=0;i<len;i++) tmp[4][i]=del(r[i],tmp[4][i]);tmp[4][0]=add(tmp[4][0],1);
    for(int i=0;i<len;i++) tmp[5][i]=t[i];
    for(int i=len;i<len<<1;i++) tmp[4][i]=tmp[5][i]=0;
    ntt_get(len);ntt(tmp[4],1),ntt(tmp[5],1);
    for(int i=0;i<N;i++) t[i]=mul(tmp[4][i],tmp[5][i]);
    ntt(t,-1);for(int i=len;i<N;i++) t[i]=0;
}

int main() {
    read(n),read(m);
    for(int i=1,x;i<=n;i++) read(x),cnt[x]++;
    inv[0]=inv[1]=1;for(int i=2;i<=m<<2;i++) inv[i]=mul(mod-mod/i,inv[mod%i]);
    for(int i=1;i<=m;i++)
        if(cnt[i]) for(int j=i;j<=m;j+=i) f[j]=add(f[j],mul(cnt[i],inv[j/i]));
    ntt_get(m);mxn=N<<1;ntt_init();poly_exp(f,g,N);
    for(int i=1;i<=m;i++) write(g[i]);
    return 0;
}
扫码关注我们
微信号:SRE实战
拒绝背锅 运筹帷幄