[백준] 파일 합치기 (11066번)/(Knuth optimization) - Python
✅문제 - 파일 합치기 (11066번)
✅필요 알고리즘 개념 - DP
🔵특정 경로를 지나는 DP
◼ 예제 입력 40 30 30 50의 경우
위의 경우에 3가지 케이스가 존재할 수 있다.
1) (1) + (2 + 3 + 4)
2) (1+2) + (3+4)
3) (1+2+3) + (4)
즉 중간 경로로 k를 선택해 지나가는 것이다.
그러므로 dp를 1차원 배열이 아닌 2차원 배열로 만들어 준다.
dp[start][end] = min(dp[start][mid] + dp[mid+1][end]) + cost[start][end]
를 만족하는 dp를 만든다.
✅필요 알고리즘 개념 - Knuth optimization
dp[start][end] = min(dp[start][mid] + dp[mid+1][end]) + cost[start][end]
이러한 점화식을 만족시키는 dp 문제에서는 특별히 Knuth's optimization을 사용할 수 있다.
이번 포스팅에서는 어떤 조건 하에서 어떻게 사용하는지만 알아보자
(우선 cost 배열은
(1) cost(b, c) ≤ cost(a, d)
(2) cost(a, c) + cost(b, d) ≤ cost(b, c) + cost(a, d) 을 만족시켜야 한다.)
이런 조건을 만족시키는 문제의 경우 dp와 함께 optimization table도 함께 만든다
opt[start][end] 에는 cost 값을 최소로 만들기 위한 mid 값을 저장한다.
( dp[start][end] = min(dp[start][mid] + dp[mid+1][end]) + cost[start][end] 을 만족시키는 mid 값을 opt[start][end]에 저장 )
opt의 초기값으로 opt[i][i] = i 을 설정한다.
여기까지가 준비단계이다.
❗❗Knuth Optimiztion의 사용법
위의 조건을 만족시키면 아래의 식을 만족시킨다.
opt[start][end-1] <= opt[start][end] <= opt[start+1][end]
이렇게 하면 기존에는 mid의 값을 start~end 까지 모두 탐색해야 하는 반면
Knuth optimization을 사용하면 탐색 범위를 줄일 수 있게 되어
시간 복잡도를 N^3 에서 N^2 까지 대폭 감소 시킬 수 있다.
✅코드 ( Knuth Optimization 사용 X )
import sys
input = sys.stdin.readline
test_num = int(input().rstrip())
for _ in range(test_num):
number = int(input().rstrip())
num_list = [0] + list(map(int,input().rstrip().split()))
psum = [0 for _ in range(number+1)]
for i in range(1,number+1):
psum[i] = psum[i-1] + num_list[i]
dp = [[0 for _ in range(number+1)] for _ in range(number+1)]
for i in range(1,number):
dp[i][i+1] = psum[i+1] - psum[i-1]
for distance in range(2,number):
for start in range(1,number+1-distance):
cost = psum[start+distance] - psum[start-1]
minimum = dp[start][start] + dp[start+1][start+distance]
for mid in range(start+1,start+distance):
minimum = min(minimum,dp[start][mid]+dp[mid+1][start+distance])
dp[start][start+distance] = cost + minimum
print(dp[1][number])
위의 방법은 Knuth optimization을 적용하지 않고
mid 값을 아래의 표와 같은 방법으로 채워서 dp[1][number]를 구하는 것이다
✅코드 ( Knuth Optimization 사용 O )
import sys
input = sys.stdin.readline
test_num = int(input().rstrip())
mx = 1e8
for _ in range(test_num):
number = int(input().rstrip())
num_list = [0] + list(map(int,input().rstrip().split()))
psum = [0 for _ in range(number+1)]
for i in range(1,number+1):
psum[i] = psum[i-1] + num_list[i]
dp = [[0 for _ in range(number+1)] for _ in range(number+1)]
opt = [[0 for _ in range(number+1)] for _ in range(number+1)]
for i in range(1,number):
dp[i][i+1] = psum[i+1]-psum[i-1]
# optimization array / 초기값 설정
for i in range(1,number+1):
opt[i][i] = i
for start in range(number-1,0,-1):
for end in range(start+1,number+1):
cost = psum[end] - psum[start-1]
minimum = mx
# Knuth Optimization 을 사용해 mid 의 범위를 축소 시킴 !!!!!
# mid의 값은 end 보다는 작아야 하므로 min()을 사용해줌
for mid in range(opt[start][end-1],min(end,opt[start+1][end]+1)):
tmp = dp[start][mid]+dp[mid+1][end]
# 최단 경로의 값과 그 때의 중간 경로 지점을 opt에 저장
if tmp <= minimum:
minimum = tmp
opt[start][end] = mid
dp[start][end] = cost + minimum
print(dp[1][number])
이 코드에서 start 와 end 의 값을 이해하기 힘들 수 있다.
우선 Knuth optimization 을 사용하는 방법은 mid의 범위를 구할 때 아래의 식을 적용하는 것이다.
opt[start][end-1] <= opt[start][end] <= opt[start+1][end]
❗❗❗ 여기서 우리의 타겟은 opt[start][end]의 범위를 구하는 것이다.
그러기 위해서는 opt[start][end-1]과 opt[start+1][end]를 알아야 한다. 아래의 표를 보자
opt[3][4]를 구하기 위해서는 opt[3][3] 과 opt[4][4]를 알아야한다.
그러므로 우리는 opt의 초기 상태에서는 opt[2][4]와 같은 값은 알 수 없다는 것이다.
그러므로 우리는 opt와 dp의 탐색 방향을 아래의 표와 같이 가기 원하는 것이다.
위의 표 방향대로 채워가면
opt[start][end-1] <= opt[start][end] <= opt[start+1][end]
위의 식을 사용함에 있어서 모든 opt 값을 채워가며 찾을 수 있기 때문에
for start in range(number-1,0,-1):
for end in range(start+1,number+1):
loop에서 start와 end의 범위가 위와 같이 설정되는 것이다.
##참고
https://www.geeksforgeeks.org/knuths-optimization-in-dynamic-programming/