문제

세준이는 N/M크기로 직사각형에 수를 N/M개 써놓았다.

세준이는 이 직사각형을 겹치지 않는 3개의 작은 직사각형으로 나누려고 한다. 각각의 칸은 단 하나의 작은 직사각형에 포함되어야 하고, 각각의 작은 직사각형은 적어도 하나의 숫자를 포함해야 한다.

어떤 작은 직사각형의 합은 그 속에 있는 수의 합이다. 입력으로 주어진 직사각형을 3개의 작은 직사각형으로 나누었을 때, 각각의 작은 직사각형의 합의 곱을 최대로 하는 프로그램을 작성하시오.

www.acmicpc.net/problem/1451

  • 입력

첫째 줄에 직사각형의 세로 크기 N과 가로 크기 M이 주어진다. 둘째 줄부터 직사각형에 들어가는 수가 가장 윗 줄부터 한 줄에 하나씩 M개의 수가 주어진다. N과 M은 100보다 작거나 같은 자연수이고, 직사각형엔 적어도 3개의 수가 있다. 또, 직사각형에 들어가는 수는 한 자리의 숫자이다.

  • 출력

세 개의 작은 직사각형의 합의 곱의 최댓값을 출력한다.

과정

사진 사진 출처

위 사진에서 나눠놓은 것처럼 직사각형을 3개의 직사각형으로 나누는 방법에는 6가지가 있었다. 따라서 경우를 6가지로 나눠서 나눌 수 있는 모든 방법을 탐색하여 최댓값을 구했다.

pypy3에선 통과했지만 python3에서는 시간 초과가 떴다.

import sys

n, m = map(int, sys.stdin.readline().split())
arr = [list(map(int, list(sys.stdin.readline().strip()))) for _ in range(n)]
res = 0

for i in range(m-1):
    # 세로, 세로 분할
    for j in range(i+1, m-1):
        tmp = [0, 0, 0]
        for k in range(n):
            for l in range(m):
                if l <= i:
                    tmp[0] += arr[k][l]
                elif l <= j:
                    tmp[1] += arr[k][l]
                else:
                    tmp[2] += arr[k][l]
        res = max(res, tmp[0]*tmp[1]*tmp[2])

    for j in range(n-1):
        # 세로, 좌측 가로 분할
        tmp = [0, 0, 0]
        for k in range(n):
            for l in range(m):
                if l <= i and k <= j:
                    tmp[0] += arr[k][l]
                elif l <= i and k > j:
                    tmp[1] += arr[k][l]
                else:
                    tmp[2] += arr[k][l]
        res = max(res, tmp[0] * tmp[1] * tmp[2])

        # 세로, 우측 가로 분할
        tmp = [0, 0, 0]
        for k in range(n):
            for l in range(m):
                if l > i and k <= j:
                    tmp[0] += arr[k][l]
                elif l > i and k > j:
                    tmp[1] += arr[k][l]
                else:
                    tmp[2] += arr[k][l]
        res = max(res, tmp[0] * tmp[1] * tmp[2])

for i in range(n-1):
    # 가로, 가로 분할
    for j in range(i+1, n-1):
        tmp = [0, 0, 0]
        for k in range(n):
            for l in range(m):
                if k <= i:
                    tmp[0] += arr[k][l]
                elif k <= j:
                    tmp[1] += arr[k][l]
                else:
                    tmp[2] += arr[k][l]
        res = max(res, tmp[0]*tmp[1]*tmp[2])

    for j in range(m-1):
        # 가로, 상단 세로 분할
        tmp = [0, 0, 0]
        for k in range(n):
            for l in range(m):
                if k <= i and l <= j:
                    tmp[0] += arr[k][l]
                elif k <= i and l > j:
                    tmp[1] += arr[k][l]
                else:
                    tmp[2] += arr[k][l]
        res = max(res, tmp[0] * tmp[1] * tmp[2])

        # 가로, 하단 세로 분할
        tmp = [0, 0, 0]
        for k in range(n):
            for l in range(m):
                if k > i and l <= j:
                    tmp[0] += arr[k][l]
                elif k > i and l > j:
                    tmp[1] += arr[k][l]
                else:
                    tmp[2] += arr[k][l]
        res = max(res, tmp[0] * tmp[1] * tmp[2])

print(res)

최종코드

더 나은 풀이를 찾아보니 다행히 경우를 6가지로 나눠서 모든 경우의 수를 찾는 접근 방법은 맞았다. 하지만 모든 경우의 수를 구하는 과정에서 계속 새롭게 직사각형의 합을 구하다보니 시간이 초과된 것이었다.

이 문제를 해결하기 위해 이차원배열 s를 만들었다. 배열 s는 직사각형의 합을 저장해놓는 용도로 s[i][j]는 [1, 1] 부터 [i, j]까지 직사각형의 합을 의미한다. 점화식은 아래와 같다.

s[i][j] = s[i-1][j] + s[i][j-1] - s[i-1][j-1] + arr[i][j]

만약 [a, b] 부터 [c, d]까지 직사각형의 합을 구하려면 s[c][d] - s[c][b-1] - s[a-1][d] + s[a-1][b-1] 과 같은 식을 통해 반복문을 사용하지 않고 O(1)만에 합을 구할 수 있다.

그림도 있는 더 자세한 설명은 여기를 참고하자.

import sys

n, m = map(int, sys.stdin.readline().split())
arr = [[0]*(m+1)] + [[0] + list(map(int, list(sys.stdin.readline().strip()))) for _ in range(n)]
# 직사각형의 합을 계산하기 위한 s 배열, s[i][j] = [1, 1] 부터 [i, j]까지의 합
s = [[0]*(m+1) for _ in range(n+1)]
for i in range(1, n+1):
    for j in range(1, m+1):
        s[i][j] = s[i][j-1] + s[i-1][j] - s[i-1][j-1] + arr[i][j]
res = 0

# [x1+1, y1+1] 부터 [x2, y2] 까지의 합을 계산하는 함수
def sum(x1, y1, x2, y2):
    return s[x2][y2] - s[x2][y1] - s[x1][y2] + s[x1][y1]

# 가로, 가로 분할
for i in range(1, n):
    for j in range(i+1, n):
        tmp1 = sum(0, 0, i, m)
        tmp2 = sum(i, 0, j, m)
        tmp3 = sum(j, 0, n, m)
        res = max(res, tmp1 * tmp2 * tmp3)

for i in range(1, n):
    for j in range(1, m):
        # 가로, 상단 세로 분할
        tmp1 = sum(0, 0, i, j)
        tmp2 = sum(0, j, i, m)
        tmp3 = sum(i, 0, n, m)
        res = max(res, tmp1 * tmp2 * tmp3)
        # 가로, 하단 세로 분할
        tmp1 = sum(0, 0, i, m)
        tmp2 = sum(i, 0, n, j)
        tmp3 = sum(i, j, n, m)
        res = max(res, tmp1 * tmp2 * tmp3)

# 세로, 세로 분할
for i in range(1, m):
    for j in range(i+1, m):
        tmp1 = sum(0, 0, n, i)
        tmp2 = sum(0, i, n, j)
        tmp3 = sum(0, j, n, m)
        res = max(res, tmp1 * tmp2 * tmp3)

for i in range(1, m):
    for j in range(1, n):
        # 세로, 좌측 가로 분할
        tmp1 = sum(0, 0, j, i)
        tmp2 = sum(j, 0, n, i)
        tmp3 = sum(0, i, n, m)
        res = max(res, tmp1 * tmp2 * tmp3)
        # 세로, 우측 가로 분할
        tmp1 = sum(0, 0, n, i)
        tmp2 = sum(0, i, j, m)
        tmp3 = sum(j, i, n, m)
        res = max(res, tmp1 * tmp2 * tmp3)

print(res)

참고 사이트