문제

방향 없는 그래프가 주어졌을 때, 연결 요소 (Connected Component)의 개수를 구하는 프로그램을 작성하시오.

www.acmicpc.net/problem/11724

  • 입력

첫째 줄에 정점의 개수 N과 간선의 개수 M이 주어진다. (1 ≤ N ≤ 1,000, 0 ≤ M ≤ N×(N-1)/2) 둘째 줄부터 M개의 줄에 간선의 양 끝점 u와 v가 주어진다. (1 ≤ u, v ≤ N, u ≠ v) 같은 간선은 한 번만 주어진다.

  • 출력

첫째 줄에 연결 요소의 개수를 출력한다.

과정

그래프 문제에서 dfs와 bfs를 둘 다 쓸 수 있는 경우에는 주로 bfs를 쓰게 된다. bfs는 보통 재귀를 사용하지 않고 큐를 사용하니까 이해가 더 쉽고, 구현도 더 쉬워서 손이 가는 것 같다.

일단 그래프 문제 경험이 많이 없어서 연결 요소를 구하는 간단한 문제인데 참 복잡하게 푼 것 같다. 다행히(?) ‘틀렸습니다’는 아니었지만 시간초과가 떴다.

from collections import deque
import sys

def bfs(start):
    global n, g
    q = deque()
    q.append(start)
    visit = [False] * (n+1)
    visit[start] = True
    cnt = 0

    while q:
        cnt += 1
        while q:
            v = q.popleft()
            for x in g[v]:
                if visit[x] == False:
                    visit[x] = True
                    q.append(x)
        # 아직 방문하지 않은 노드를 찾아 큐에 넣는다
        for i in range(1, n+1):
            if visit[i] == False:
                q.append(i)
                break
    return cnt

n, m = map(int, sys.stdin.readline().split())
g = [[] for _ in range(n+1)]
for _ in range(m):
    u, v = map(int, sys.stdin.readline().split())
    g[u].append(v)
    g[v].append(u)

print(bfs(1))

while문을 이중으로 돌리면서 visit 배열을 계속 처음부터 다시 탐색하니까 시간이 초과된 것 같다.

최종 코드

수정한 코드는 다음과 같다. visit 배열을 탐색하는 for 문을 제일 바깥으로 빼서 방문한 노드는 바로바로 넘어갈 수 있게끔 만들었다.

from collections import deque
import sys

def bfs():
    global n, g
    q = deque()
    visit = [False] * (n+1)
    cnt = 0

    # 1부터 확인해서 방문하지 않은 노드가 있으면 큐에 넣고 탐색 시작
    for i in range(1, n+1):
        if visit[i] == False:
            visit[i] = True
            q.append(i)
            cnt += 1
            while q:
                v = q.popleft()
                for x in g[v]:
                    if visit[x] == False:
                        visit[x] = True
                        q.append(x)

    return cnt

n, m = map(int, sys.stdin.readline().split())
g = [[] for _ in range(n+1)]
for _ in range(m):
    u, v = map(int, sys.stdin.readline().split())
    g[u].append(v)
    g[v].append(u)

print(bfs())

나는 bfs로 풀었는데 검색해보니 대부분 dfs로 풀었길래 dfs로도 구현해보았다. 코드 자체는 훨씬 간단해졌다.

dfs를 구현할 때 주의해야 할 점은 recursion limit을 설정해주어야 한다는 것이다. 백준에서는 제한이 낮아서 제한을 높여주지 않으면 에러가 뜨게 된다.

import sys
sys.setrecursionlimit(10000)

def dfs(start):
    visit[start] = True
    for x in g[start]:
        if visit[x] == False:
            dfs(x)

n, m = map(int, sys.stdin.readline().split())
g = [[] for _ in range(n+1)]
for _ in range(m):
    u, v = map(int, sys.stdin.readline().split())
    g[u].append(v)
    g[v].append(u)

visit = [False] * (n+1)
cnt = 0
for i in range(1, n+1):
    if visit[i] == False:
        dfs(i)
        cnt += 1

print(cnt)

참고 사이트