nakashiiiの自由帳

自由に書きます

ABC212 E - Safety Journey(配るDP → 貰うDPにして高速化する)

今日参加したあさかつ で解いた問題が貰うDPを利用した勉強になる問題だったので、そのメモ

(最後にお世話になった参考記事、動画のリンクを載せてあるので、そちらを先に見てもらった方が分かりやすいかもです。)

問題

atcoder.jp

考察(間に合わずTLEするバージョン)

まずは、愚直なDPの考察から。こんな感じで考えて、DPを組んでいった。

  • 頂点数5000、日数5000なので O(N * K) が間に合いそうだな
  • となると、i日目に、頂点jにいるような 組み合わせdp[i][j] として、dp[k][0] (0-indexed) が答えだ!
  • なので、とりあえず壊れた辺を除外してグラフ構築して、DPやろう!!

で、組んだDPが以下。めでたく O(N2 * K) となりTLEしました。
配るDPで組んでいて、配る度に壊れてない橋でつながった頂点分(= O(N))だけ計算量がかかるのでダメだった。

#include <bits/stdc++.h>

#include <atcoder/all>

using namespace atcoder;
using namespace std;
using ll = long long;
const int IINF = 0x3f3f3f3f;  // 2倍しても負にならない
const long long LINF = 0x3f3f3f3f3f3f3f3fLL;
long long MOD = 1000000007;
using plint = pair<ll, ll>;
using pint = pair<int, int>;
#define all(obj) (obj).begin(), (obj).end()
using Graph = vector<vector<int>>;

// 変数宣言------------------

// 関数定義------------------

// メイン------------------
int main() {
  // デバッグ用
  ifstream in("input.txt");
  if (in.is_open()) {
    cin.rdbuf(in.rdbuf());
  }

  using mint = modint998244353;

  int n, m, k;
  cin >> n >> m >> k;

  Graph g(n);

  set<pint> broken;  // 壊れた辺
  for (int i = 0; i < m; i++) {
    int u, v;
    cin >> u >> v;
    u--, v--;
    if (v < u) {
      swap(u, v);
    }
    broken.insert({u, v});
  }

  // 壊れてない辺だけ, 張る
  for (int u = 0; u < n - 1; u++) {
    for (int v = u + 1; v < n; v++) {
      if (broken.find({u, v}) == broken.end()) {
        g[u].push_back(v);
        g[v].push_back(u);
      }
    }
  }

  // i日目に, 都市j にいるような組み合わせ dp[i][j]
  vector dp(5100, vector<mint>(5100, 0));
  dp[0][0] = 1;
  for (int i = 0; i < k; i++) {
    for (int j = 0; j < n; j++) {
      for (auto v : g[j]) {
        dp[i + 1][v] += dp[i][j];
      }
    }
  }

  cout << dp[k][0].val() << endl;
}

考察(ACするバージョン)

(結局自力ではダメだったので、以下は解説読んだあとの考察。)
何かしら考察をする必要があるが、壊れた橋が少ない(max 5000)ことがポイント。
これをうまく使ってDPを構成する。自力ACに必要だと感じた考察は以下。

  • 壊れた橋が少ない(5000)のでうまく使えないか?
  • 橋が全てかかっている場合から、壊れた橋の分だけうまく引けないか?
  • 橋が全てかかっている場合は高速に計算できないか?

更に考察していくと、以下のようになる

  • i日目のdp[i][j]に関して、(i-1)日目の全ての頂点の組み合わせの総和sum(dp[i-1][0] + ... + dp[i-1][n-1]) はO(N) で求められる
  • dp[i][j] に足したくないのは、以下のふたつ
    • 壊れた橋からの移動 → i日目に引きたいものは、O(M)で計算できる (ここが一番理解が難しかった、、)
    • 自分自身(j)からの移動 → O(1) で分かる
  • したがって、各iに対してやることは以下。
    • 全ての橋が正常な状態の組み合わせの総和を求める O(N)
    • 各頂点について壊れた辺からの組み合わせを引く O(M)
    • 自分自身から移動してきた組み合わせを引く O(1)
  • よって、全体の計算量はO(K * (N+M)) で間に合う

補足

最初の愚直なDPは配るDPで書いていたけど、しれっと貰うDPに変更している。
これは、貰うほうじゃないとうまく高速化ができないから。(もしかして配る方でも高速化できるかもだけど、私は分かんないです)
「なんで貰うの方がいいの?」という部分は、自分もまだ理解が曖昧だが、現状はこんなイメージ。
伝わる、、かな、、伝わらないかも、、

配るDPより貰うDPの方が高速化しやすいイメージ

図にあるように、配る側はまとめるのが難しそうだが、貰う側は配る方の総和を計算しておけば、あとは各頂点に足すだけ。

というわけで実装。

#include <bits/stdc++.h>

#include <atcoder/all>
using namespace atcoder;
using namespace std;
using ll = long long;
const int IINF = 0x3f3f3f3f;  // 2倍しても負にならない
const long long LINF = 0x3f3f3f3f3f3f3f3fLL;
long long MOD = 1000000007;
using plint = pair<ll, ll>;
using pint = pair<int, int>;
#define all(obj) (obj).begin(), (obj).end()

// メイン------------------
int main() {
  // デバッグ用
  ifstream in("input.txt");
  if (in.is_open()) {
    cin.rdbuf(in.rdbuf());
  }

  using mint = modint998244353;

  int n, m, k;
  cin >> n >> m >> k;

  // 使えない辺
  vector<int> a(m);
  vector<int> b(m);
  for (int i = 0; i < m; i++) {
    cin >> a[i] >> b[i];
    a[i]--, b[i]--;
  }

  // i日目に 都市j にいるような組み合わせ dp[i][j]
  // 高速化したいので, 貰うDPで考える
  vector dp(5100, vector<mint>(5100, 0));
  dp[0][0] = 1;

  for (int i = 1; i <= k; i++) {
    // 一つ前のすべての頂点から経路
    mint sum_all = 0;
    for (int j = 0; j < n; j++) {
      sum_all += dp[i - 1][j];
    }

    // 壊れた橋の頂点からの経路を引いていく
    for (int j = 0; j < m; j++) {
      dp[i][a[j]] -= dp[i - 1][b[j]];
      dp[i][b[j]] -= dp[i - 1][a[j]];
    }

    // dp[i][j] = すべての頂点からの経路 - 壊れた橋の頂点からの経路 - 自分からの経路
    // すべての頂点を足して、自分自身からの経路を引く
    for (int j = 0; j < n; j++) {
      dp[i][j] += (sum_all - dp[i - 1][j]);
    }
  }

  cout << dp[k][0].val() << endl;
}

お世話になった記事、動画