nakashiiiの自由帳

自由に書きます

蟻本 p103 Conscription(最小全域木 プリム法とクラスカル法)

蟻本でプリム法とクラスカル法を勉強したのでメモ。 クラスカル法の方がかなりスッキリ書ける。

(注)2ケースくらいでしかテストしていないので、バグらせていたらごめんなさい

問題(ざっくり)

N人の女とM人の男を全員徴兵する際の最小コストを求めよ。

  • 徴兵コストは一人あたり10000
  • 女と男には親密度が設定されており、徴兵する際、徴兵済みの人と親密度がある場合はコストをその分差し引く
  • 同じ男女が複数通りの親密度を持つ場合もある

原文はこちら 3723 -- Conscription

解法

連結成分ごとに最小重み森を計算して、最大の徴兵コストから差し引く。

最大重み森とは、最小全域木の逆で、無向グラフからコストが最大になるように頂点を選んだ全域木を指す。 要は最小全域木の逆の意味。
今回は差し引くコストを最大にしたいので、最大重み森を求める。

また、頂点N+Mは連結でない頂点もあるので、プリム法を使う場合は連結成分ごとに計算しなければいけないことに注意。Union Findを使って連結判定しながら最大重み森を求める。

プリム法

プリム法でやると閉路判定のためにUnion Findを使う必要がある。

#include <bits/stdc++.h>
using namespace std;
using pint = pair<int, int>;
struct Edge {
  int to, cost;
  Edge(int t, int c) : to(t), cost(c) {}
};
using Graph = vector<vector<Edge>>;

// プリム法
// prim(無向グラフ, 利用したかどうか, 木のサイズ, 開始位置)
int prim(Graph &g, int used[], int tree_size, int start) {
  // コスト順に取り出したいので, Edgeではなくpairで優先キューに入れる
  priority_queue<pint, vector<pint>, greater<pint>> que;

  que.push(make_pair(0, start));  // 任意の点からスタートする(距離, 頂点)
  int ans = 0;                    // 最小全域木の合計コスト

  while (tree_size > 0) {
    pint cur = que.top();  // 現在の木に隣接する頂点で, 一番コストが小さいもの
    que.pop();
    int v = cur.second;
    int cost = cur.first;

    // 使用済みならスキップ
    if (used[v]) {
      continue;
    }

    // コストをプラスして, 使用済みにする
    ans += cost;
    used[v] = 1;  // 使用済みにする
    tree_size--;  // 使用済み頂点数を更新

    // 未探索の頂点をすべて優先キューに入れる
    for (auto next : g[v]) {
      if (!used[next.to]) {
        que.push(make_pair(next.cost, next.to));
      }
    }
  }

  return ans;
}

// union find
class UnionFind {
 public:
  // 親の番号を格納する。親だった場合は-1
  vector<int> Parent;
  UnionFind(int N) { Parent = vector<int>(N, -1); }

  // Aがどのグループに属しているか調べる
  int root(int A) {
    if (Parent[A] < 0) return A;
    return Parent[A] = root(Parent[A]);
  }

  // 自分のいるグループの頂点数を調べる
  int size(int A) {
    return -Parent[root(A)];  //親をとってきたい
  }

  // AとBをくっ付ける
  bool unite(int A, int B) {
    // AとBを直接つなぐのではなく、root(A)にroot(B)をくっつける
    A = root(A);
    B = root(B);
    if (A == B) {
      //すでにくっついてるからくっ付けない
      return false;
    }
    // 大きい方(A)に小さいほう(B)をくっ付ける
    // 大小が逆だったらひっくり返す
    if (size(A) < size(B)) {
      swap(A, B);
    }
    // Aのサイズを更新する
    Parent[A] += Parent[B];
    // Bの親をAに変更する
    Parent[B] = A;
    return true;
  }
};

int main() {
  int n, m, r;
  cin >> n >> m >> r;
  int k = n + m;

  Graph g(k);
  UnionFind uni(k);

  // 入力
  for (int i = 0; i < r; i++) {
    int x, y, d;
    cin >> x >> y >> d;
    g[x].push_back(Edge(y + n, -d));  // -d:最大重み森のため、コスト符号を反転
    g[y + n].push_back(Edge(x, -d));
    uni.unite(x, y + n);
  }

  int reduced_cost = 0;  //関係を使って削減できるコスト

  // 連結成分ごとにプリム法でコストを求める
  int used[k]{};
  for (int i = 0; i < k; i++) {
    if (!used[i]) {
      int tree_size = uni.size(i);
      reduced_cost += prim(g, used, tree_size, i);
    }
  }

  cout << (10000 * k) + reduced_cost << endl;
}

クラスカル

#include <bits/stdc++.h>
using namespace std;
using pint = pair<int, int>;
struct Edge {
  int from, to, cost;
  Edge(int f, int t, int c) : from(f), to(t), cost(c) {}
};
struct EdgeLess {  // 大小比較用の関数オブジェクトを定義することもできる
  bool operator()(const Edge& a, const Edge& b) const noexcept {
    // キーとして比較したい要素を列挙する
    return std::tie(a.cost) < std::tie(b.cost);
  }
};
using Graph = vector<vector<Edge>>;

class UnionFind {
 public:
  // 親の番号を格納する。親だった場合は-1
  vector<int> Parent;
  UnionFind(int N) { Parent = vector<int>(N, -1); }

  // Aがどのグループに属しているか調べる
  int root(int A) {
    if (Parent[A] < 0) return A;
    return Parent[A] = root(Parent[A]);
  }

  // 自分のいるグループの頂点数を調べる
  int size(int A) {
    return -Parent[root(A)];  //親をとってきたい
  }

  // AとBをくっ付ける
  bool unite(int A, int B) {
    // AとBを直接つなぐのではなく、root(A)にroot(B)をくっつける
    A = root(A);
    B = root(B);
    if (A == B) {
      //すでにくっついてるからくっ付けない
      return false;
    }
    // 大きい方(A)に小さいほう(B)をくっ付ける
    // 大小が逆だったらひっくり返す
    if (size(A) < size(B)) {
      swap(A, B);
    }
    // Aのサイズを更新する
    Parent[A] += Parent[B];
    // Bの親をAに変更する
    Parent[B] = A;
    return true;
  }
};

// クラスカル法
// (辺の集合, 頂点の数)
int krs(vector<Edge>& es, int n) {
  UnionFind uni(n);
  int ans = 0;

  sort(es.begin(), es.end(), EdgeLess{});  // コストが小さい順にソート

  for (auto e : es) {
    // 閉路になる場合はスキップ
    if (uni.root(e.from) == uni.root(e.to)) {
      continue;
    }
    ans += e.cost;
    uni.unite(e.from, e.to);
  }

  return ans;
}

int main() {
  int n, m, r;
  cin >> n >> m >> r;
  int k = n + m;

  vector<Edge> es(r, Edge(0, 0, 0));

  // 入力
  for (int i = 0; i < r; i++) {
    int x, y, d;
    cin >> x >> y >> d;
    es[i] = Edge(x, y + n, -d);
  }

  cout << (10000 * k) + krs(es, k) << endl;
}