あめんばーどのバーチャル日記

バーチャルな世界で過ごして競プロしてる人の雑談

【AtCoder】F - Lottery / AtCoder Beginner Contest 243

今下書きしてるブログでも言う予定なんですが、御託はそろそろなしにしようかなと。 今回のF問題、さくっとできたので書いておきます。

自分なりの雑な言い方しただけで、解説とほぼ変わりません。

atcoder.jp

問題

くじでN種類の景品がもらえる。iが手に入る確率は\displaystyle{\frac{W_i}{\sum_{i=1}^{n}W_i}}
くじをK回引いた時にM種類の景品が求める確率はいくらか。mod998244353で答えよ。

解説

離散的な確率は基礎的な部分に立ち返ると\displaystyle{\frac{条件を満たす通り数}{全ての通り数}}である。

この問題のそれぞれの言う確率は言い換えると、一回のくじは  \sum_{i=1}^{n} W_i 枚のくじのうちW_i枚のくじを引く確率である。 この分母の「全ての通り数」を求めるのは簡単で(\sum_{i=1}^{n} W_i)^{K}である。
つまりこの分子「条件を満たすもの」を求めるればいいので、確率ではなく下記で言い換えられる。

それぞれの景品が当たるくじがW_i枚ずつ入っている。くじは毎回戻すものとしてM種類のくじが引けるのは何通りか。

ここで具体例を考えて行く。例えば下記としよう

N=3
M=2
K=3
W=\{1,2,3\}

具体例の1つとして、まず1を2枚、2を1枚取るパターンを考えよう。 まず1を2枚取るのは3つのうち2か所取るので_3C_2、あとその場所にW_1が二乗分いる。 次に2を1枚取るのは残り1か所の_1C_1と、その場所にW_2通りのくじの選びがある。 と、コンビネーションとW_iの累乗をかける形になる。。

よって1を2枚、2を1枚取るパターンは以上により _3C_2×(W_1×W_1)× _1C_1×W_2

1を2枚、3を1枚取るパターンならこうだ
_3C_2×(W_1×W_1)× _1C_1×W_3

2を2枚、3を1枚取るパターンでもこう
 _3C_2×(W_2×W_2)×_1C_1×W_3

M=2に反するが、全部1枚取るならこう
_3C_1×W_1×_2C_1× W_2×_1C_1×W_3

こんな形で、後から乗算されてゆくものは「これから引く枚数の累乗」×「残った枚数から引く枚数に対するコンビネーション」である。
この計算は「残り何枚のくじが引けるか(=これまで何枚引いたか)」が決まっていれば勝手に計算できるので、dpでひとまとめにできるという算段。

その上で引いた種類数も必要であることを留意した上で下記のdpを考える。

dp[i種類目まで決定 ][j種類を引いている][k個引いている]

  1. i種類目でk枚から新たに何枚引くかを選択する
  2. 一枚以上選択する場合はjが+1された遷移先にプラスする
  3. 0枚の場合は選択しないのでjが加算されない遷移先にプラスする

これはO(NMK^{2})で済む。(提出コードはO(N^{2}K^{2}logK)

あとは最初で言及した分母で割ろうドットコム。

https://atcoder.jp/contests/abc243/submissions/30088294

int main() {
  ll N, M, K;
  cin >> N >> M >> K;
  VL W = read(N);

  mint2 sumW = 0;
  REP(i, N)
    sumW += W[i];
  vector dp(N + 1, vector(N + 1, vector<mint2>(K + 1, 0)));
  vector<mint2> f;
  Factorical(51, f);

  dp[0][0][0] = 1;
  REP(i, N) {
    REP(j, i + 1) {
      REP(k, K + 1) {
        FOR(k2, k, K + 1) {
          // k個からk2個になるまでWiを引く
          mint2 add = mint2(W[i]).pow(k2 - k) * Combi(K - k, k2 - k, f);
          if (k == k2)
            dp[i + 1][j][k2] += dp[i][j][k] * add;
          else
            dp[i + 1][j + 1][k2] += dp[i][j][k] * add;
        }
      }
    }
  }
  cout << (dp[N][M][K] / sumW.pow(K)).val();
  return 0;
}
追記

なんか割り算減らしたり細かい事O(NMK^{2})になるまでやったら219ms->17msまで削れた 具体的には

  • どうせK!は全部共通で書けてるので、そこの部分は最後にまとめる
  • W_iの累乗は前の奴からかければいいので、別にいらない
  • M種類より上を枝刈り

ギリギリな解法だと意外と馬鹿にならない感じの変化になるんですね……。

https://atcoder.jp/contests/abc243/submissions/30104426

int main() {
  ll N, M, K;
  cin >> N >> M >> K;
  VL W = read(N);

  vector<mint2> f(51, 1);
  REP(i, 50) 
    f[i + 1] = f[i] / (i + 1);

  vector dp(N + 1, vector(M + 1, vector<mint2>(K + 1, 0)));
  dp[0][0][0] = 1;
  REP(i, N) {
    REP(j, min(i, M) + 1) {
      REP(k, K + 1) {
        dp[i + 1][j][k] += dp[i][j][k];
        if (j == M)
          continue;
        mint2 p = dp[i][j][k];
        FOR(k2, k + 1, K + 1) {
          p *= W[i];
          dp[i + 1][j + 1][k2] += p * f[k2 - k];
        }
      }
    }
  };
  cout << (dp[N][M][K] / (f[K] * mint2(accumulate(all(W), 0)).pow(K))).val();
  return 0;
}