문제

알파벳 대문자가 한 칸에 한 개씩 적혀있는 N×M 크기의 문자판이 있다. 편의상 모든 문자는 대문자라 생각하자. 예를 들어 아래와 같은 문자판을 보자.

K A K T
X E A S
Y R W U
Z B Q P

이 문자판의 한 칸(아무 칸이나 상관없음)에서 시작하여 움직이면서, 그 칸에 적혀 있는 문자들을 차례대로 모으면 하나의 단어를 만들 수 있다. 움직일 때는 상하좌우로 K개의 칸까지만 이동할 수 있다. 예를 들어 K=2일 때 아래의 그림의 가운데에서는 ‘X’ 표시된 곳으로 이동할 수 있다.

  X      
  X      
X X   X X
  X      
  X      

반드시 한 칸 이상 이동을 해야 하고, 같은 자리에 머물러 있을 수 없다. 또, 같은 칸을 여러 번 방문할 수 있다.

이와 같은 문자판과 K, 그리고 하나의 영단어가 주어졌을 때, 이와 같은 영단어를 만들 수 있는 경로가 총 몇 개 존재하는지 알아내는 프로그램을 작성하시오.

위의 예에서 영단어가 BREAK인 경우에는 다음과 같이 3개의 경로가 존재한다. 앞의 수는 행 번호, 뒤의 수는 열 번호를 나타낸다.

  • (4, 2) (3, 2) (2, 2) (1, 2) (1, 1)
  • (4, 2) (3, 2) (2, 2) (1, 2) (1, 3)
  • (4, 2) (3, 2) (2, 2) (2, 3) (1, 3)

www.acmicpc.net/problem/2186

  • 입력

첫째 줄에 N(1 ≤ N ≤ 100), M(1 ≤ M ≤ 100), K(1 ≤ K ≤ 5)가 주어진다. 다음 N개의 줄에는 M개의 알파벳 대문자가 주어지는데, 이는 N×M 크기의 문자판을 나타낸다. 다음 줄에는 1자 이상 80자 이하의 영단어가 주어진다. 모든 문자들은 알파벳 대문자이며, 공백 없이 주어진다.

  • 출력

첫째 줄에 경로의 개수를 출력한다. 이 값은 int 범위이다.

과정

처음에는 다른 문제들처럼 bfs를 가지고 풀었는데 시간 초과가 떴다.

import sys
from collections import deque

def bfs(sx, sy):
    global cnt, arr, target, n, m, k
    q = deque()
    q.append([[sx, sy], 1])

    while q:
        now, idx = q.popleft()

        for dx, dy in zip([-1, 0, 1, 0], [0, 1, 0, -1]):
            for i in range(1, k+1):
                nx = now[0] + dx * i
                ny = now[1] + dy * i

                # 다음 문자가 target 문자와 같으면 계속 탐색
                if 0 <= nx < n and 0 <= ny < m and arr[nx][ny] == target[idx]:
                    if idx == len(target) - 1:
                        cnt += 1
                    else:
                        q.append([[nx, ny], idx+1])

n, m, k = map(int, sys.stdin.readline().split())
arr = [list(sys.stdin.readline().strip()) for _ in range(n)]
target = list(sys.stdin.readline().strip())
cnt = 0
for i in range(n):
    for j in range(m):
        # 첫 글자가 같은 곳만 bfs 탐색
        if arr[i][j] == target[0]:
            bfs(i, j)
print(cnt)

최종 코드

확인해보니 dfs와 dp를 이용해서 메모이제이션을 해줘야 시간 안에 완료할 수 있었다. 그런데 이 문제에서는 한 칸을 여러번 방문할 수 있어서 dp 값을 단순히 칸마다 계산하면 안 되고 index까지 고려해줘야 해서 더 복잡했다.

index까지 고려하기 위해 dp를 3차원 배열로 선언하고 bfs에서 dfs로 알고리즘을 수정했다. dp 값이 존재하는 경우 dp 값을 활용하고, 존재하지 않으면 dfs 탐색을 계속하는 방식으로 구현했다.

하지만 아래 코드는 pypy3에서만 통과하고 python3에서는 시간초과가 떴다.

import sys
sys.setrecursionlimit(100000)

def dfs(sx, sy, idx):
    global arr, target, dp, n, m, k

    dp[sx][sy][idx] = 0 # 방문처리
    for dx, dy in zip([-1, 0, 1, 0], [0, 1, 0, -1]):
        for i in range(1, k+1):
            nx = sx + dx * i
            ny = sy + dy * i

            # 다음 문자가 target 문자와 같으면 계속 탐색
            if 0 <= nx < n and 0 <= ny < m and arr[nx][ny] == target[idx]:
                # 단어가 완성되었다면 +1
                if idx == len(target) - 1:
                    tmp = 1
                # dp 값이 없다면 계속 dfs 탐색 진행
                elif dp[nx][ny][idx+1] == -1:
                    tmp = dfs(nx, ny, idx+1)
                # dp 값이 있다면 탐색 중단하고 해당 값을 사용
                else:
                    tmp = dp[nx][ny][idx+1]
                dp[sx][sy][idx] += tmp
    return dp[sx][sy][idx]

n, m, k = map(int, sys.stdin.readline().split())
arr = [list(sys.stdin.readline().strip()) for _ in range(n)]
target = list(sys.stdin.readline().strip())
# dp는 3차원 배열로 -1로 초기화
dp = [[[-1]*len(target) for _ in range(m)] for _ in range(n)]
cnt = 0
for i in range(n):
    for j in range(m):
        # 첫 글자가 같은 곳만 bfs 탐색
        if arr[i][j] == target[0]:
            cnt += dfs(i, j, 1)
print(cnt)

Enhanced

python3에서도 통과하려면 어떻게 코드를 구현해야하나 찾아봤더니 두 가지가 달랐다. 솔직히 저렇게 한다고 왜 시간이 거의 두 배나 빨라지는지 잘 모르겠다ㅠㅠ. 그래도 python3에서 통과는 되더라.

  1. 이동할 방향을 구할 때 이중 for loop 대신 다른 방식 사용
  2. 가지치기를 dfs 호출 전에 하지 않고 dfs 호출 후 함수 맨 처음에서 진행
import sys
sys.setrecursionlimit(100000)

def dfs(sx, sy, idx):
    global arr, target, dp, n, m, k

    # 이미 탐색한 값이면 바로 반환
    if dp[sx][sy][idx] != -1:
        return dp[sx][sy][idx]
    # 문자가 일치하지 않으면 방문처리하고 반환
    if arr[sx][sy] != target[idx]:
        dp[sx][sy][idx] = 0
        return 0
    # 단어가 완성되었다면 1 반환
    if idx == len(target) - 1:
        return 1

    tmp = 0
    for i in range(-k, k+1):
        if i == 0:
            continue
        nx = sx + i
        ny = sy + i

        if 0 <= nx < n:
            tmp += dfs(nx, sy, idx + 1)
        if 0 <= ny < m:
            tmp += dfs(sx, ny, idx + 1)

    dp[sx][sy][idx] = tmp
    return dp[sx][sy][idx]

n, m, k = map(int, sys.stdin.readline().split())
arr = [list(sys.stdin.readline().strip()) for _ in range(n)]
target = list(sys.stdin.readline().strip())
# dp는 3차원 배열로 -1로 초기화
dp = [[[-1]*len(target) for _ in range(m)] for _ in range(n)]
cnt = 0
for i in range(n):
    for j in range(m):
        # 첫 글자가 같은 곳만 bfs 탐색
        if arr[i][j] == target[0]:
            cnt += dfs(i, j, 0)
print(cnt)

참고 사이트