문제

어떤 자연수 N은 그보다 작거나 같은 제곱수들의 합으로 나타낼 수 있다. 예를 들어 \(11=3^2+1^2+1^2\)(3개 항)이다. 이런 표현방법은 여러 가지가 될 수 있는데, 11의 경우 \(11=2^2+2^2+1^2+1^2+1^2\)(5개 항)도 가능하다. 이 경우, 수학자 숌크라테스는 “11은 3개 항의 제곱수 합으로 표현할 수 있다.”라고 말한다. 또한 11은 그보다 적은 항의 제곱수 합으로 표현할 수 없으므로, 11을 그 합으로써 표현할 수 있는 제곱수 항의 최소 개수는 3이다.

주어진 자연수 N을 이렇게 제곱수들의 합으로 표현할 때에 그 항의 최소개수를 구하는 프로그램을 작성하시오.

www.acmicpc.net/problem/1699

  • 입력

첫째 줄에 자연수 N이 주어진다. (1 ≤ N ≤ 100,000)

  • 출력

주어진 자연수를 제곱수의 합으로 나타낼 때에 그 제곱수 항의 최소 개수를 출력한다.

과정

N 항의 최소개수
1 \(1^2\) 1
2 \(1^2+1^2\) 2
3 \(1^2+1^2+1^2\) 3
4 \(2^2\) 1
5 \(2^2+1^2\) 2
6 \(2^2+1^2+1^2\)4 3
7 \(2^2+1^2+1^2+1^2\) 4
8 \(2^2+2^2\) 2
9 \(3^2\) 1
10 \(3^2+1^2\) 2
11 \(3^2+1^2+1^2\) 3
12 \(2^2+2^2+2^2\) 3
13 \(3^2+2^2\) 2

위 표는 1부터 13까지의 답을 나타낸 표이다. 이렇게 쭉 답을 적어보니 어떤 규칙을 발견할 수 있었다.

11을 예로 들면 dp[1] + dp[10], dp[2] + dp[9], … , dp[5] + dp[6] 중에서 최소값인 3이 답이 되었다.

12의 경우는 이 중에 dp[4] + dp[9] 가 최소값이므로 2가 답이 되었다.

이를 점화식으로 나타내면 dp[n] = min(dp[n], dp[i] + dp[n-i]) (i는 1부터 int((n+1)/2)까지) 이다.

아래는 직접 구현한 코드이다. 제곱수를 판별하기 위해 int(math.sqrt(i+1)) == math.sqrt(i+1) 이러한 코드를 사용하였다.

import math

N = int(input())
dp = [100000]*N
dp[0] = 1

for i in range(1, N):
    if int(math.sqrt(i+1)) == math.sqrt(i+1):
        dp[i] = 1
    else:
        for j in range(i//2+1):
            dp[i] = min(dp[i], dp[j]+dp[i-1-j])

print(dp[-1])

하지만 이 코드는 시간복잡도가 \(O(N^2)\)으로 시간이 초과되었다 ㅠㅠ.

최종 코드

찾아보니 굳이 모든 dp[i] + dp[n-i]를 비교하여 최소값을 찾아낼 필요가 없었다.

12를 예로 들면 dp[1] + dp[11]부터 dp[6] + dp[6]까지 모든 값을 비교하는 것이 아니라, dp[12-1^2^] + dp[1^2^], dp[12-2^2^] + dp[2^2^], dp[12-3^2^] + dp[3^2^] 세 값만 비교해서 최소값을 구하는 것이다. 이렇게 하면 시간복잡도를 \(O(N\sqrt{N})\)으로 줄일 수 있다.

따라서 점화식은 dp[n] = min(dp[n], dp[n - i*i] + 1) (단, i >= 1이며, n은 1부터 제곱했을 때 i에 가장 가까운 정수 까지)가 되는 것이다. 여기서 dp[i^2^]은 무조건 1이므로 1로 표기했다.

import math

N = int(input())
dp = [x for x in range(N+1)]

for i in range(2, N+1):
    for j in range(1, int(math.sqrt(i)+1)):
        if dp[i]>dp[i-j*j]:
            dp[i] = dp[i-j*j]+1

print(dp[N])

여기서 주의할 점은 if문 대신 min을 쓰면 시간이 초과된다는 것이다. min도 함부로 쓰면 안 되구나하는 교훈을 얻었다.

참고 사이트