Leetcode : 629. K Inverse Pairs Array

leetcode.com

問題概要

長さ n の順列を考える.反転数が k となる数列を数えなさい.

考察

一番愚直だと思われる解法は n! 通りの数列の反転数を計算し,kとなるものを数え上げるというものである. この計算量は O(nlogn*n!) であるがこれは明らかにTLEする.

部分問題に切り分けることを考える.1 から n-1 までの数字を用いて作る数列に n を追加することを考えると漸化式を立てれそうな気がする.

dp[n][k] := 数列の長さが n で 1 ~ n の数字で構成される数列において反転数が k となる数

とすると

dp[n][k] += sum_{i=0}^{n} dp[n-1][k-i]

となる.

これでは O(n3) となるが,まだTLEする.そこで上のdpを高速に計算する方法を考える.

遷移を見ると,規則性が見えて累積和を取ることによって O(n2) に落ちる.

ソースコード

コメントの部分が O(n3) の解法です.

class Solution {
    
public:
    int kInversePairs(int n, int k) {
        static long long dp[1024][1024];
        static long long conv[1024];
        
        for(int i = 0; i < 1024; i++) {
            for(int j = 0; j < 1024; j++) {
                dp[i][j] = 0;
            }
        }
        
        // 長さ1の数列で反転数=0の数
        dp[1][0] = 1;
        
        // 右からiを挿入したときの反転数.
        for(int i = 0; i < 1024; i++) {
            conv[i] = (i*(i+1)/2);
        }
        
        long long mod = 1e9 + 7LL; 
        
        // i: 長さiの数列
        for(int i = 1; i <= n; i++) {
            // a: 今の反転数(上限はk)
            /*
            for(int a = 0; a < min(conv[i], 1001LL); a++) {
                // j: 右からj番目に(i+1)を挿入する
                // 例) 
                // [1 2 3 4]
                // j = 0: [1 2 3 4 5], j = 1 : [1 2 3 5 4], ...
                for(int j = 0; j <= i; j++) {
                    if(a+j > k) break;
                    //printf("[i=%d] : next: %d: j= %d, conv[j] = %d, dp[i][a=%d] = %d\n", i, a+conv[j], j, conv[j], a, dp[i][a]);
                    
                    dp[i+1][a+j] += dp[i][a];
                    dp[i+1][a+j] %= mod;
                    
                }
            }
            */
            long long cnt = 0LL;
            for(int j = 0; j < min(conv[i], 1001LL); j++) {
                cnt += dp[i-1][j];
                if(j-i >= 0) {
                    cnt -= dp[i-1][j-i];
                }
                //printf("[i=%d] : j=%d, conv[j] = %d, cnt = %d\n", i, j, conv[j], cnt);

                dp[i][j] += cnt;
                dp[i][j] %= mod;
            }
        }

        return dp[n][k];
    }
};

所感

こういう問題をバグなく高速に通せるようにしておきたい.