LOJ #6183 看无可看
优秀的数学题,Orz samjia2000
这个题显然需要将和转积来处理,这个时候就要用到特征方程的一些知识了!
其实就是这个样子:$f[x]=a\times f[x-1]+b\times f[x-2]$
那么必然可以写作:$f[x]-t\times f[x-1]=k\times (f[x-1]-t\times f[x-2])$
化简:$f[x]=(k+t)\times f[x-1]-k\times t\times f[x-2]$
即:$a=k+t,b=-kt$
消去$K$,得到:$k^2=a\times k+b$
这个东西就是这个二阶递推式的特征方程。
那么显然,你可以得到:$f[x]-k_2\times f[x-1]=k_1^{x-2}\times (f[2]-k_2\times f[1]),f[x]-k_1\times f[x-1]=k_2^{x-2}\times (f[2]-k_1\times f[1])$
那么对于上述两个方程,可以得到:
$(k_1-k_2)f[x]=k_1^{x-2}\times (f[2]-k_2\times f[1])-k_2^{x-2}\times (f[2]-k_1\times f[1])$
即:$f[x]=\frac{k_1^{x-2}\times (f[2]-k_2\times f[1])-k_2^{x-2}\times (f[2]-k_1\times f[1])}{k_1-k_2}$
设:$A=\frac{f[2]-k_2\times f[1]}{k_1-k_2},B=\frac{f[2]-k_1\times f[1]}{k_1-k_2}$
故可得:$f[x]=A \times k_1^{x-2}+B\times k_2^{x-2}$,的通项公式如上
那么根据通项公式,就可以把题目里的原始式子化为$\sum\limits_{S'\subset S ,|S'|=k}A\times k_1^{\prod\limits_{x\in S'} x}+B\times k_2^{\prod\limits_{x\in S'} x}$
那么接下来的就是裸的分治FFT了!
#include#include #include #include #include #include #include #include using namespace std;#define N 100005#define mod 99991#define ll long longconst long double pi=acos(-1);int a[N],n,k;struct cp{ long double x,y; cp(){} cp(long double a,long double b){x=a,y=b;} cp operator + (const cp &a) const {return cp(x+a.x,y+a.y);} cp operator - (const cp &a) const {return cp(x-a.x,y-a.y);} cp operator * (const cp &a) const {return cp(x*a.x-y*a.y,x*a.y+y*a.x);}}A[N<<2],B[N<<2];void FFT(cp *a,int len,int flag){ int i,j,k,t;cp w,x,tmp;long long tt; for(i=k=0;i k)swap(a[i],a[k]); for(j=len>>1;(k^=j) >=1); } for(k=2;k<=len;k<<=1) { x=cp(cos(2*pi*flag/k),sin(2*pi*flag/k));t=k>>1; for(i=0;i a;int len; ploy(){a.clear();len=0;} ploy(int x){a.resize(3);a[0]=cp(1,0);a[1]=cp(x,0);len=2;} void print(){printf("%d\n",len);for(int i=0;i >=1,x=(ll)x*x%mod)if(n&1)ret=(ll)ret*x%mod;return ret;}int get(int x){return q_pow(3,x);}ploy solve(int l,int r){ if(l==r)return ploy(get(a[l]));int m=(l+r)>>1; return solve(l,m)*solve(m+1,r);}int get1(int x){return q_pow(mod-1,x);}ploy solve1(int l,int r){ if(l==r)return ploy(get1(a[l]));int m=(l+r)>>1; return solve1(l,m)*solve1(m+1,r);}int main(){ // freopen("see.in","r",stdin); // freopen("see.out","w",stdout); scanf("%d%d",&n,&k); for(int i=1;i<=n;i++)scanf("%d",&a[i]); scanf("%d%d",&f1,&f2);c1=(f1+f2)*41663ll%mod; c2=(3ll*c1-f1+mod)%mod; c1=c1*3ll%mod;c2=((ll)c2*(mod-1))%mod; // c1=c1*(ll)q_pow(66661ll,k-1)%mod;c2=(ll)c2*q_pow(mod-1,k-1)%mod; ret=solve(1,n);ret1=solve1(1,n); printf("%lld\n",(((long long)(ret.a[k].x+0.1)*c1+(long long)(ret1.a[k].x+0.1)*c2)%mod+mod)%mod);}