Cod sursă (job #786106)

Utilizator avatar AndiR Tanasescu Andrei Rares AndiR IP ascuns
Problemă Déjà vu Compilator cpp-32 | 2,32 kb
Rundă Arhiva de probleme Status evaluat
Dată 14 sept. 2024 00:01:59 Scor 0
/// Very close: the DP works, but I get the accumulation into crt wrong and don't know why...

#include <iostream>
#include <fstream>

using namespace std;

// Fixed input / output files of the problem.
ifstream fin ("dejavu.in");
ofstream fout ("dejavu.out");

typedef long long ll;

// Table bound (presumably the maximum n of 5000 plus a little slack).
const ll Nmax = 5005;
// Modulus for every count.
const ll MOD = 1000000007;

// Counting tables, filled before the queries are read; sum_comb() reads dp[0].
// Footprint: 2 * 5005 * 5005 * 8 bytes, roughly 400 MB of static storage.
// NOTE(review): likely above typical judge memory limits; dp[1] is only used
// one row back in its own recurrence and could become a rolling row.
ll dp[2][Nmax][Nmax];

// fact[i] = i! and invfact[i] = (i!)^(-1), both modulo MOD.
ll fact[Nmax];
ll invfact[Nmax];

// Per-value occurrence counters for the query currently being scanned.
int vis[Nmax];

// Input sizes (n, q) and scratch values.
int n;
int q;
int nr;
int x;
int y;
int z;

/// Arrangements A(n, m) = n! / (n-m)! modulo MOD: the number of ordered
/// selections of m distinct items out of n, read from the fact/invfact tables.
/// Returns 0 when m is outside [0, n] instead of indexing invfact at a
/// negative position. Requires n < Nmax (the size of the tables).
ll aran(ll n, ll m){
    if (m < 0 || m > n)
        return 0;
    return fact[n]*invfact[n-m]%MOD;
}
/// Binomial coefficient C(n, m) = n! / (m! (n-m)!) modulo MOD, read from the
/// fact/invfact tables.
/// Returns 0 when m is outside [0, n] instead of indexing invfact at a
/// negative position. Requires n < Nmax (the size of the tables).
ll comb(ll n, ll m){
    if (m < 0 || m > n)
        return 0;
    return fact[n]*invfact[m]%MOD*invfact[n-m]%MOD;
}
/// Modular exponentiation: b raised to e, reduced modulo MOD (1 when e == 0).
/// Iterative square-and-multiply over the bits of e: O(log e) multiplications.
ll exp(ll b, ll e){
    ll result = 1;
    ll base = b;
    for (ll k = e; k > 0; k /= 2){
        if (k % 2 == 1)
            result = result * base % MOD;
        // The square is only needed if another bit of e remains.
        if (k > 1)
            base = base * base % MOD;
    }
    return result;
}

/// Number of ways (modulo MOD) to fill the last `sz` positions given `y`
/// values that may still be used one more time and `z` values not used yet.
/// r = how many of those positions take the second copy of a once-used value:
///   - aran(y, r) ordered choices of distinct once-used values,
///   - comb(sz, r) choices of which positions they occupy,
///   - the other sz - r positions are counted by dp[0][sz - r][z]
///     (an empty remainder counts as exactly 1 way).
/// Only r <= y is possible. `x` does not take part in the computation.
ll sum_comb(ll sz, ll x, ll y, ll z){
    ll total = 0;
    ll r_max = (sz < y) ? sz : y;
    for (ll r = 0; r <= r_max; ++r){
        ll rest = sz - r;
        ll fill_rest = (rest == 0) ? 1 : dp[0][rest][z];
        ll term = fill_rest * aran(y, r) % MOD * comb(sz, r) % MOD;
        total = (total + term) % MOD;
    }
    return total;
}

int main(){
    
    fact[0]=1;
    for (int i=1; i<Nmax; i++)
        fact[i]=fact[i-1]*i%MOD;
    
    invfact[Nmax-1]=exp(fact[Nmax-1], MOD-2);
    for (int i=Nmax-2; i>=0; i--){
        invfact[i]=invfact[i+1]*(i+1)%MOD;
    }
    
    for (int j=1; j<Nmax; j++){
        dp[0][1][j]=j;
        dp[1][1][j]=1;
    }

    for (int i=2; i<Nmax; i++){
        // j=1
        if (i==2)
            dp[0][i][1]=1;
        
        for (int j=2; j<Nmax; j++){
            dp[1][i][j]=dp[0][i-1][j-1];
            dp[0][i][j]=j*(dp[1][i-1][j]*(i-1)+dp[0][i-1][j-1])%MOD;
        }
    }
    
    fin>>n>>q;
    
    for (int i=0; i<q; i++){
        x=y=0;
        z=n;
        ll crt=0;

        for (int j=1; j<=n; j++){
            fin>>nr;

            ll s1=sum_comb(n-j, x, y+1, z-1);
            ll s2=0;
            if (y!=0)
                s2=sum_comb(n-j, x+1, y-1, z);


            //cout<<s1<<'\n';

            for (int k=1; k<nr; k++)
                if (vis[k]==0)
                    crt=(crt+s1)%MOD;
                else if (vis[k]==1)
                    crt=(crt+s2)%MOD;

            //fout<<crt<<' ';

            vis[nr]++;
            if (vis[nr]==1){
                z--;
                y++;
            }
            else if (vis[nr]==2){
                y--;
                x++;
            }
        }
        fout<<crt<<'\n';

        for (int i=1; i<=n; i++)
            vis[i]=0;
    }

    return 0;
}