Loj 2320.「清华集训 2017」生成树计数

题目描述

在一个 \(s\) 个点的图中,存在 \(s-n\) 条边,使图中形成了 \(n\) 个连通块,第 \(i\) 个连通块中有 \(a_i\) 个点。

现在我们需要再连接 \(n-1\) 条边,使该图变成一棵树。对一种连边方案,设原图中第 \(i\) 个连通块连出了 \(d_i\) 条边,那么这棵树 \(T\) 的价值为:

SRE实战 互联网时代守护先锋,助力企业售后服务体系运筹帷幄!一键直达领取阿里云限量特价优惠。

\[ \mathrm{val}(T) = \left(\prod_{i=1}^{n} {d_i}^m\right)\left(\sum_{i=1}^{n} {d_i}^m\right) \]

你的任务是求出所有可能的生成树的价值之和,对 \(998244353\) 取模。

输入格式

输入的第一行包含两个整数 \(n,m\),意义见题目描述。

接下来一行有 \(n\) 个整数,第 \(i\) 个整数表示 \(a_i\) \((1\le a_i< 998244353)\)

* 你可以由 \(a_i\) 计算出图的总点数 \(s\),所以在输入中不再给出 \(s\) 的值。

输出格式

输出包含一行一个整数,表示答案。

数据范围与提示

本题共有 \(20\) 个测试点,每个测试点 \(5\) 分。

- \(20\%\) 的数据中,\(n\le500\)

- 另外 \(20\%\) 的数据中,\(n \le 3000\)

- 另外 \(10\%\) 的数据中,\(n \le 10010, m = 1\)

- 另外 \(10\%\)的数据中,\(n \le 10015,m = 2\)

-另外 \(20\%\) 的数据中,所有 \(a_i\) 相等。

\(\\\)

好神的题啊!

假设我们知道了每个点的度数,考虑计算此时的生成树的个数。这个用\(prufer\)序列非常好解决:

假设第\(i\)个点在\(prufer\)序列中出现次数为\(d_i\),(则其度数为\(d_i+1\)
\[ Ans=(n-2)!\prod_{i=1}^n\frac{{a_i}^{d_i+1}}{d_i!} \]
先考虑对式子进行变形
\[ \begin{align} Ans&= \sum_{\sum d_i==n-2} (n-2)! \sum_{i=1}^n\frac{{{a_i}^{d_i+1}d_i}^{2m}}{d_i!} \prod_{j=1,j\neq i}^n\frac{{d_j}^m}{d_j!}\\ &=(n-2)!\prod_{i=1}^na_i \sum_{\sum_{d_i==n-2}}\sum_{i=1}^n\frac{{{a_i}^{d_i}d_i}^{2m}}{d_i!} \prod_{j=1,j\neq i}^n\frac{{d_j}^m}{d_j!}\\ \end{align} \]

\[ Ans'=\sum_{\sum_{d_i==n-2}}\sum_{i=1}^n\frac{{{a_i}^{d_i}d_i}^{2m}}{d_i!}\prod_{j=1,j\neq i}^n\frac{{d_j}^m}{d_j!} \]
考虑用生成函数解决:
\[ A(x)=\sum_{i=0}^n\frac{i^{2m}}{i!}x^i\\ B(x)=\sum_{i=0}^n\frac{i^m}{i!}x^i \]
\(Ans'\)的生成函数为
\[ \sum_{i=1}^nA(a_i)\prod_{j=1,j\neq i}^nB(a_j)\\ =\sum_{i=1}^n\frac{A(a_i)}{B(a_i)}\prod_{j=1}^nB(a_j) \]
对于\(\prod_{j=1}^nB(a_j)\),我们的一般套路是将其写成
\[ \exp(\ln(\prod_{j=1}^nB(a_j)))\\ =\exp(\sum_{j=1}^n\ln(B(a_j))) \]
这样做的好处是我们只需要求出\(\ln(B(x))\),然后对第\(i\)项系数乘上\(\displaystyle \sum_{j=1}^n{a_j}^i\)就可以得到\(\displaystyle \sum_{j=1}^n\ln(B(a_j))\)了。对于\(\displaystyle \sum_{i=1}^n\frac{A(a_i)}{B(a_i)}\)我们也 用相同的处理方式。

所以:
\[ Ans'=\sum_{i=1}^n\frac{A}{B}(a_i)\exp(\sum_{j=1}^n\ln(B(a_j))) \]
现在的问题是如何求出
\[ \sum_{i=1}^n{a_i}^k \]
考虑\(\ln(x)\)的取\(x_0=1\)时的泰勒展开形式
\[ \ln(x)=\sum_{i=0}\frac{\ln^{[i](1)}}{i!}(x-1)^i\\ =\sum_{i=1}\frac{(-1)^{i-1}}{i}(x-1)^i \]
所以:
\[ \ln(1+a_jx)=\sum_{i=1}\frac{(-1)^{i-1}{a_j}^i}{i}x^i \]
那么我们只需要求出
\[ \sum_{i=1}^n\ln(a_i) \]
就行了。
\[ \sum_{i=1}^n\ln(a_i)=\ln(\prod_{i=1}^n(1+a_ix)) \]
\(\prod_{i=1}^n(1+a_ix)\)可以用分治\(NTT\)求出。

代码:

#include<bits/stdc++.h>
#define ll long long
#define N 200005

using namespace std;
inline int Get() {int x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9') {if(ch=='-') f=-1;ch=getchar();}while('0'<=ch&&ch<='9') {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}return x*f;}

const ll mod=998244353;
ll ksm(ll t,ll x) {
    ll ans=1;
    for(;x;x>>=1,t=t*t%mod)
        if(x&1) ans=ans*t%mod;
    return ans;
}

int n,m;
int a[N];

void NTT(ll *a,int d,int flag) {
    int n=1<<d;
    static int rev[N<<2];
    static ll G=3;
    for(int i=0;i<n;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<d-1);
    for(int i=0;i<n;i++) if(i<rev[i]) swap(a[i],a[rev[i]]);
    
    for(int s=1;s<=d;s++) {
        int len=1<<s,mid=len>>1;
        ll w=flag==1?ksm(G,(mod-1)/len):ksm(G,mod-1-(mod-1)/len);
        for(int i=0;i<n;i+=len) {
            ll t=1;
            for(int j=0;j<mid;j++,t=t*w%mod) {
                ll u=a[i+j],v=a[i+j+mid]*t%mod;
                a[i+j]=(u+v)%mod;
                a[i+j+mid]=(u-v+mod)%mod;
            }
        }
    }
    
    if(flag==-1) {
        ll inv=ksm(n,mod-2);
        for(int i=0;i<n;i++) a[i]=a[i]*inv%mod;
    }
}

ll A[N<<2],B[N<<2];
ll inv[N<<2];
ll f[N<<2],g[N<<2];
void Inv(ll *inv,ll *a,int d) {
    static ll A[N<<3];
    if(d==0) {
        inv[0]=ksm(a[0],mod-2);
        return ;
    }
    Inv(inv,a,d-1);
    for(int i=0;i<1<<d;i++) A[i]=a[i];
    for(int i=1<<d;i<1<<d+1;i++) inv[i]=A[i]=0;
    NTT(A,d+1,1);
    NTT(inv,d+1,1);
    for(int i=0;i<1<<d+1;i++) {
        inv[i]=(2*inv[i]-A[i]*inv[i]%mod*inv[i]%mod+mod)%mod;
    }
    NTT(inv,d+1,-1);
    for(int i=1<<d;i<1<<d+1;i++) inv[i]=0;
}

void Der(ll *a,int d) {
    int n=1<<d;
    for(int i=0;i<n-1;i++) a[i]=(i+1)*a[i+1]%mod;
    a[n-1]=0;
}

void Int(ll *a,int d) {
    int n=1<<d;
    for(int i=n-1;i>0;i--) a[i]=ksm(i,mod-2)*a[i-1]%mod;
    a[0]=0;
}

ll ln[N<<2];
void Ln(ll *ln,ll *a,int d) {
    static ll der[N<<2];
    for(int i=0;i<1<<d+1;i++) der[i]=0;
    for(int i=0;i<1<<d;i++) der[i]=a[i];
    Inv(inv,a,d);
    Der(der,d);
    NTT(inv,d+1,1),NTT(der,d+1,1);
    for(int i=0;i<1<<d+1;i++) ln[i]=inv[i]*der[i]%mod;
    NTT(ln,d+1,-1);
    for(int i=1<<d;i<1<<d+1;i++) ln[i]=0;
    Int(ln,d);
    for(int i=1<<d;i<1<<d+1;i++) ln[i]=0;
}

ll ex[N<<2];

void Exp(ll *exp,ll *a,int d) {
    static ll A[N<<2],B[N<<2];
    if(d==0) {
        exp[0]=1;
        return ;
    }
    Exp(exp,a,d-1);
    for(int i=0;i<1<<d;i++) A[i]=a[i];
    for(int i=1<<d;i<1<<d+1;i++) exp[i]=A[i]=0;
    Ln(B,exp,d);
    NTT(exp,d+1,1);
    NTT(B,d+1,1);
    NTT(A,d+1,1);
    for(int i=0;i<1<<d+1;i++) {
        exp[i]=exp[i]*(1-B[i]+A[i]+mod)%mod;
    }
    NTT(exp,d+1,-1);
    for(int i=1<<d;i<1<<d+1;i++) exp[i]=0;
}

void solve(int l,int r,ll *a) {
    static ll A[N<<2],B[N<<2];
    if(l==r) return ;
    int mid=l+r>>1;
    solve(l,mid,a),solve(mid+1,r,a);
    int d=ceil(log2(r-l+2));
    for(int i=0;i<1<<d;i++) A[i]=B[i]=0;
    for(int i=l;i<=mid;i++) A[i-l+1]=a[i];
    for(int i=mid+1;i<=r;i++) B[i-mid]=a[i];
    A[0]=B[0]=1;
    NTT(A,d,1),NTT(B,d,1);
    for(int i=0;i<1<<d;i++) A[i]=A[i]*B[i]%mod;
    NTT(A,d,-1);
    for(int i=l;i<=r;i++) a[i]=A[i-l+1];
}

ll summ[N];
ll cal(int k) {
    ll ans=0;
    for(int i=1;i<=n;i++) (ans+=ksm(a[i],k))%=mod;
    return ans;
}

ll tem[N<<2];
ll fac[N],ifac[N];
int main() {
    n=Get(),m=Get();
    fac[0]=1;
    for(int i=1;i<=n;i++) fac[i]=fac[i-1]*i%mod;
    ifac[n]=ksm(fac[n],mod-2);
    for(int i=n-1;i>=0;i--) ifac[i]=ifac[i+1]*(i+1)%mod;
    for(int i=1;i<=n;i++) a[i]=Get();
    for(int i=1;i<=n;i++) summ[i]=a[i];
    int d=ceil(log2(2*n+1));
    solve(1,n,summ);
    summ[0]=1;
    Ln(ln,summ,d);
    memcpy(summ,ln,sizeof(summ));
    
    summ[0]=n;
    for(int i=1;i<=n;i++) {
        if(!(i&1)) summ[i]=summ[i]*(mod-1)%mod;
        summ[i]=summ[i]*i%mod;
    }
    
    for(int i=0;i<=n;i++) {
        A[i]=ksm(i+1,2*m)*ifac[i]%mod;
        B[i]=ksm(i+1,m)*ifac[i]%mod;
    }
    
    
    Ln(ln,B,d);
    for(int i=0;i<1<<d;i++) ln[i]=ln[i]*summ[i]%mod;
    for(int i=n;i<1<<d;i++) ln[i]=0;
    Exp(g,ln,d);
    for(int i=n;i<=1<<d;i++) g[i]=0;
    
    Inv(inv,B,d);
    for(int i=n;i<1<<d;i++) inv[i]=0;

    NTT(inv,d,1),NTT(A,d,1);
    for(int i=0;i<1<<d;i++) f[i]=inv[i]*A[i]%mod;
    NTT(f,d,-1);
    for(int i=0;i<1<<d;i++) f[i]=f[i]*summ[i]%mod;
    for(int i=n;i<1<<d;i++) f[i]=0;
    
    
    NTT(f,d,1),NTT(g,d,1);
    for(int i=0;i<1<<d;i++) f[i]=f[i]*g[i]%mod;
    NTT(f,d,-1);
    
    ll ans=fac[n-2];
    for(int i=1;i<=n;i++) ans=ans*a[i]%mod;
    
    ans=ans*f[n-2]%mod;
    cout<<ans;
    return 0;
}
扫码关注我们
微信号:SRE实战
拒绝背锅 运筹帷幄