문제
N개의 정수로 이루어진 배열 A가 주어진다. 이때, 배열에 들어있는 정수의 순서를 적절히 바꿔서 다음 식의 최댓값을 구하는 프로그램을 작성하시오.
|A[0] - A[1]| + |A[1] - A[2]| + ... + |A[N-2] - A[N-1]|
- 입력
첫째 줄에 N (3 ≤ N ≤ 8)이 주어진다. 둘째 줄에는 배열 A에 들어있는 정수가 주어진다. 배열에 들어있는 정수는 -100보다 크거나 같고, 100보다 작거나 같다.
- 출력
첫째 줄에 배열에 들어있는 수의 순서를 적절히 바꿔서 얻을 수 있는 식의 최댓값을 출력한다.
과정
3 ≤ N ≤ 8 이라는 조건을 보고 순열을 전부 구해서 푸는 완전탐색 문제일 수도 있겠다 생각했지만 그래도 뭔가 규칙이 있지 않을까 해서 나름의 규칙을 고안했다.
아래 코드를 보면 알겠지만 index i는 (n-1)//2부터 시작해서 +2, -4, +6, -8, … 로 이동한다. index i에는 큰 수부터 배치한다.
index j는 i+1부터 시작해서 -2, +4, -6, +8, … 로 이동한다. index i와 동시에 동작하고 반대로 작은 수부터 배치한다.
예를 들어 [20, 1, 15, 8, 4, 10]이라는 순열이 있을 때 i는 (n-1)//2 = 2부터 시작하고 j는 3부터 시작한다.
- i = 2, j = 3 [-, -, 20, 1, -, -]
- i = 4, j = 1 [-, 4, 20, 1, 15, -]
- i = 0, j = 5 [10, 4, 20, 1, 15, 8]
위 과정을 거치면 62가 나오므로 이런 규칙이 아닐까 생각했다.
하지만 반례가 존재했다. 만약 배열이 [1, 15, 2, 8, 3]인 경우 내 방식대로라면 [3, 2, 15, 1, 8]의 순서로 배치되어 35가 출력되지만 정답은 [1, 15, 2, 8, 3]으로 최댓값은 38이다.
주저리 주저리 설명이 길었지만 그래서 규칙이 존재하는 것이 아니라 완전 탐색을 해야한다는 것이다. 아래는 틀린 코드이다.
import sys
n = int(sys.stdin.readline())
arr = sorted(list(map(int, sys.stdin.readline().split())))
newArr = [0]*n
i = (n-1)//2
j = i + 1
idx = 0
flag = 1
while True:
if 0 <= i < n and 0 <= j < n:
# index i는 큰 수부터 배치
newArr[i] = arr[-idx-1]
# index j는 작은 수부터 배치
newArr[j] = arr[idx]
i += 2 * (idx+1) * flag
j -= 2 * (idx+1) * flag
flag *= -1
idx += 1
else:
if (i < 0 or i >= n) and 0 <= j < n:
newArr[j] = arr[idx]
elif (j < 0 or j >= n) and 0 <= i < n:
newArr[i] = arr[-idx-1]
break
res = 0
for i in range(n-1):
res += abs(newArr[i]-newArr[i+1])
print(res)
최종코드
앞서 깨달았듯이 완전탐색을 통해 문제를 풀어야하기 때문에 주어진 배열의 순열을 모두 구해야 한다.
하지만 나는 순열 구하는 법을 모르기 때문에^^ 역시 찾아봤다. 라이브러리를 쓰면 간단하지만 과정이 궁금해서 라이브러리를 쓰지 않고도 구현해보았다.
아래는 라이브러리를 사용하지 않은 코드이다. 조금 어려운데 이렇게 순열을 구할 수 있다는 게 신기했다.
import sys
def next_permutation(a):
k = m = -1
length = len(a)
# 증가하는 마지막 부분을 가리키는 index k 찾기
for i in range(length-1):
if a[i] < a[i+1]:
k = i
# 전체 내림차순일 경우 반환한다.
if k == -1:
return False
# index k 이후 부분 중 값이 k보다 크면서 가장 멀리 있는 index m 찾기
for i in range(k, length):
if a[k] < a[i]:
m = i
# k와 m의 값을 바꾼다
a[k], a[m] = a[m], a[k]
# index k 이후 부분을 오름차순으로 정렬한다
a = a[:k+1] + sorted(a[k+1:])
return a
n = int(sys.stdin.readline())
arr = sorted(list(map(int, sys.stdin.readline().split())))
res = 0
# 원래 배열 계산
for i in range(n-1):
res += abs(arr[i] - arr[i+1])
# 순열마다 값을 계산
while True:
arr = next_permutation(arr)
if arr == False:
break
tmp = 0
for i in range(n-1):
tmp += abs(arr[i] - arr[i+1])
res = max(res, tmp)
print(res)
마지막으로 라이브러리를 사용한 코드이다. 훨씬 간단하기 때문에 실전에서는 라이브러리를 사용해야할 것 같다.
import sys
from itertools import permutations
n = int(sys.stdin.readline())
arr = sorted(list(map(int, sys.stdin.readline().split())))
# permutation 저장
per = permutations(arr)
res = 0
# 순열마다 값을 계산
for p in per:
tmp = 0
for i in range(n-1):
tmp += abs(p[i] - p[i+1])
res = max(res, tmp)
print(res)
참고 사이트
- [백준 알고리즘] 10819번 Python 풀이 sdesigner.tistory.com/51