【Python3】AtCoder ABC 157 D – Friend Suggestions を Union-Find を理解しつつ解く

Union-Find の勉強をしながら解いたのでメモ。

とあるSNSに、人 $1$ 、人 $2$ 、……、人 $N$ が登録しています。

この $N$ 人の間には、 $M$ 組の「友達関係」と、 $K$ 組の「ブロック関係」が存在します。
$i = 1, 2, …, M$ について、人 $A_i$ と人 $B_i$ は友達関係にあります。この関係は双方向的です。
$i = 1, 2, …, K$ について、人 $C_i$ と人 $D_i$ はブロック関係にあります。この関係は双方向的です。

以下の $4$ つの条件が満たされるとき、人 $a$ は人 $b$ の「友達候補」であると定義します。

・ $a \neq b$ である。
・人 $a$ と人 $b$ はブロック関係に無い。
・人 $a$ と人 $b$ は友達関係に無い。
・$1$ 以上 $N$ 以下の整数から成るある数列 $c_0,c_1,c_2,⋯,c_L$ が存在し、$c_0=a$ であり、 $c_L=b$ であり、 $i=0,1,⋯,L−1$ について、人 $c_i$ と人 $c_{i+1}$ は友達関係にある。

$i=1,2,…N$ について、友達候補の数を答えてください。

D – Friend Suggestions

Union-Find (素集合データ構造)とは

このスライドの解説が何よりもわかりやすいです。

集合の分割をモデル化したものと捉えるとわかりやすそうな気がします。

Union-Find は基本的な操作が Union (まとめる)Find (どの集合に属するか判定)の 2 つになっています。また、グループをまとめていく操作はできても、分割する操作はできないことに注意が必要です。

サトゥー

Union-Find は要素をグループ分けする上で役に立つデータ構造。

Union-Find の Python での実装

とりあえず UnionFInd クラスとしてまとめた結論を。一通り調べてでてきた便利そうな機能をてんこ盛りにしてあります。GitHubのソースコード

class UnionFind:
    def __init__(self, n):
        self.n = n
        self.parent = [i for i in range(n)]  # 親
        self.rank = [1] * n  # 木の高さ
        self.size = [1] * n  # size[i] は i を根とするグループのサイズ

    def find(self, x):  # x の根を返す
        if self.parent[x] == x:
            return x
        else:
            self.parent[x] = self.find(self.parent[x])  # 経路圧縮
            return self.parent[x]

    def unite(self, x, y):  # x, y の属する集合を併合する
        x = self.find(x)
        y = self.find(y)
        if x != y:
            if self.rank[x] < self.rank[y]:
                self.parent[x] = y
                self.size[y] += self.size[x]
            else:
                self.parent[y] = x
                self.size[x] += self.size[y]
                if self.rank[x] == self.rank[y]:
                    self.rank[x] += 1

    def is_same(self, x, y):  # x, y が同じ集合に属するか判定する
        return self.find(x) == self.find(y)

    def group_size(self, x):  # x が属する集合の大きさを返す
        return self.size[self.find(x)]

    def group_members(self, x):  # x が属する集合の要素を返す
        root = self.find(x)
        return [i for i in range(self.n) if self.find(i) == root]

    def roots(self):  # すべての根をリストで返す
        return [i for i, x in enumerate(self.parent) if i == x]

    def group_count(self):  # 木の数を返す
        return len(self.roots())

    def all_group_members(self):  # すべての木の要素を辞書で返す
        return {r: self.group_members(r) for r in self.roots()}

    def __str__(self):  # print 表示用
        return '\n'.join('{}: {}'.format(r, self.group_members(r)) for r in self.roots())


これで多分正しく動くハズ(間違ってたら教えてほしいです 🙇‍♂️)

経路圧縮を用いた効率化

スライド 11 枚目を参照。具体的には find() 内で経路圧縮を実装しています。

    def find(self, x):  # x の根を返す
        if self.parent[x] == x:
            return x
        else:
            self.parent[x] = self.find(self.parent[x])  # 経路圧縮
            return self.parent[x]

ランクを用いた効率化

スライド 12 枚目を参照。 unite() でマージするときにランクを比較し、低い方を高い方に繋げることで木が深くなることを避けています。

    def unite(self, x, y):  # x, y の属する集合を併合する
        x = self.find(x)
        y = self.find(y)
        if x != y:
            if self.rank[x] < self.rank[y]:
                self.parent[x] = y
                self.size[y] += self.size[x]
            else:
                self.parent[y] = x
                self.size[x] += self.size[y]
                if self.rank[x] == self.rank[y]:
                    self.rank[x] += 1

実装する上で参考にした記事

以下に実装する上で参考にさせていただいた記事を載せておきます。Python ですと note.mkmk.me さんの記事はめちゃくちゃ参考になりますし、他にもこの界隈は偉大な先人の方々がわかりやすい記事を書いてくださっているので色々見てみると勉強になると思います。

問題の回答

さて、これを用いて本問を解いてみます。

本問の「友達候補」とは、「友達の友達の友達の友達」のように、友達でつながっている人のことを指します。
ということは、各々の友達候補のネットワークを 1 つにまとめて、ここから「友達」と「ブロック」を引けば答えが得られそう。

具体的には、

  • $X$ : 人 $i$ が属する集合のサイズ
  • $Y$ : 人 $i$ と同じ集合に属する人のうち、人 $i$ と友達関係もしくはブロック関係にある人数

としたとき、 $X – Y – 1$ が人 $i$ の答えになります。

これをふまえて、先程の UnionFind クラスを活用して回答を書いてみます。(余計な関数は省いた)

class UnionFind:
    def __init__(self, n):
        self.n = n
        self.parent = [i for i in range(n)]  # 親
        self.rank = [1] * n  # 木の高さ
        self.size = [1] * n  # size[i] は i を根とするグループのサイズ

    def find(self, x):  # x の根を返す
        if self.parent[x] == x:
            return x
        else:
            self.parent[x] = self.find(self.parent[x])  # 経路圧縮
            return self.parent[x]

    def unite(self, x, y):  # x, y の属する集合を併合する
        x = self.find(x)
        y = self.find(y)
        if x != y:
            if self.rank[x] < self.rank[y]:
                self.parent[x] = y
                self.size[y] += self.size[x]
            else:
                self.parent[y] = x
                self.size[x] += self.size[y]
                if self.rank[x] == self.rank[y]:
                    self.rank[x] += 1

    def is_same(self, x, y):  # x, y が同じ集合に属するか判定する
        return self.find(x) == self.find(y)

    def group_size(self, x):  # x が属する集合の大きさを返す
        return self.size[self.find(x)]


N, M, K = map(int, input().split())
A, B = [0] * M, [0] * M  # 入力時に 0-index に合わせる
C, D = [0] * K, [0] * K  # 入力時に 0-index に合わせる

direct =[[] for _ in range(N)]  # 同じ集合の友達もしくはブロック関係の人
uf = UnionFind(N)  # 友達関係をまとめた UnionFind 木

for i in range(M):
    A[i], B[i] = map(int, input().split())
    A[i] -= 1
    B[i] -= 1
    direct[A[i]].append(B[i])
    direct[B[i]].append(A[i])
    uf.unite(A[i], B[i])

for i in range(K):
    C[i], D[i] = map(int, input().split())
    C[i] -= 1
    D[i] -= 1
    if uf.is_same(C[i], D[i]):
        direct[C[i]].append(D[i])
        direct[D[i]].append(C[i])

ans = [0] * N
for i in range(N):
    ans[i] = uf.group_size(i) - len(direct[i]) - 1
print(*ans)

こちらのコードで 1511 ms で AC でした。

この記事を書いた人

サトゥー

東大学際情報学府M1。情報科学と教養の海に溺れています。面白いことをやるのがすきです。