문제

동호와 규완이는 212호에서 문자열에 대해 공부하고 있다. 김진영 조교는 동호와 규완이에게 특별 과제를 주었다. 특별 과제는 특별한 문자열로 이루어 진 사전을 만드는 것이다. 사전에 수록되어 있는 모든 문자열은 N개의 “a”와 M개의 “z”로 이루어져 있다. 그리고 다른 문자는 없다. 사전에는 알파벳 순서대로 수록되어 있다.

규완이는 사전을 완성했지만, 동호는 사전을 완성하지 못했다. 동호는 자신의 과제를 끝내기 위해서 규완이의 사전을 몰래 참조하기로 했다. 동호는 규완이가 자리를 비운 사이에 몰래 사전을 보려고 하기 때문에, 문자열 하나만 찾을 여유밖에 없다.

N과 M이 주어졌을 때, 규완이의 사전에서 K번째 문자열이 무엇인지 구하는 프로그램을 작성하시오.

www.acmicpc.net/problem/1256

  • 입력

첫째 줄에 N, M, K가 순서대로 주어진다. N과 M은 100보다 작거나 같은 자연수이고, K는 1,000,000,000보다 작거나 같은 자연수이다.

  • 출력

첫째 줄에 규완이의 사전에서 K번째 문자열을 출력한다. 만약 규완이의 사전에 수록되어 있는 문자열의 개수가 K보다 작으면 -1을 출력한다.

과정

문제를 보고 왠지 DP가 아닐까 생각했었는데 조합을 구할 때 DP를 사용하는 방법이 있긴 있었다.

. 0 1 2 3 4
0 1 1 1 1 1
1 1 2 3 4 5
2 1 3 6 10 15
3 1 4 10 20 35
4 1 5 15 35 70

n+mCm = dp[n][m] = dp[n-1][m] + dp[n][m-1]

위와 같이 dp를 이용해 표를 채워서 조합의 값을 얻을 수 있다.

import sys
#조합 계산
def cal(n, m):
    global dp
    if n == 0 or m == 0:
        return 1
    if dp[n][m]:
        return dp[n][m]
    else:
        dp[n][m] = cal(n-1, m) + cal(n, m-1)
        return dp[n][m]

n, m, k = map(int, input().split())

dp = [[0]*(m+1) for _ in range(n+1)]

# k가 범위 밖이면 -1
if cal(n, m) < k:
    print(-1)
    sys.exit(0)

res = ""
k -= 1
while True:
    if n == 0 or m == 0:
        res = res + 'z'*m + 'a'*n
        break

    if cal(n-1, m) > k:
        res = res + 'a'
        n -= 1
    else:
        res = res + 'z'
        # m 먼저 계산하면 안 됨!
        k -= cal(n-1, m)
        m -= 1

print(res)

이 코드도 처음엔 실패했었는데 마지막쯤에서 m -= 1 을 먼저 하고 k -= cal(n-1, m)을 하는 바람에 m이 이미 -1이 되어버려서 계산이 꼬인 것이었다. 세세한 부분까지 더 신경쓰자!

최종 코드

조합을 구할 때 dp를 사용해도 되지만 그냥 factorial로 계산해도 아무 문제 없었다. 시간 제약만 없다면 이게 더 간단해서 좋은 것 같다.

import sys
from math import factorial

def cal(n, m):
    return factorial(n)/(factorial(n-m)*factorial(m))

n, m, k = map(int, input().split())

if cal(n+m, m) < k:
    print(-1)
    sys.exit(0)

res = ""

while True:
    if n == 0 or m == 0:
        res = res + 'z'*m + 'a'*n
        break

    if cal(n + m - 1, m) >= k:
        res = res + 'a'
        n -= 1
    else:
        res = res + 'z'
        k -= cal(n + m - 1, m)
        m -= 1

print(res)

참고 사이트