Algorithm & Data Structure

[백준] 파일 합치기 (11066번)/(Knuth optimization) - Python

후뿡이 2023. 1. 30. 16:15

✅문제 - 파일 합치기 (11066번)


 

필요 알고리즘 개념 - DP


🔵특정 경로를 지나는 DP

◼ 예제 입력 40 30 30 50의 경우

위의 경우에 3가지 케이스가 존재할 수 있다.

1) (1) + (2 + 3 + 4)

2) (1+2) + (3+4)

3) (1+2+3) + (4)

 

즉 중간 경로로 k를 선택해 지나가는 것이다.

 

그러므로 dp를 1차원 배열이 아닌 2차원 배열로 만들어 준다.

dp[start][end] = min(dp[start][mid] + dp[mid+1][end]) + cost[start][end]

를 만족하는 dp를 만든다.

 

필요 알고리즘 개념 - Knuth optimization


dp[start][end] = min(dp[start][mid] + dp[mid+1][end]) + cost[start][end]

이러한 점화식을 만족시키는 dp 문제에서는 특별히 Knuth's optimization을 사용할 수 있다.

 

이번 포스팅에서는 어떤 조건 하에서 어떻게 사용하는지만 알아보자

 

(우선 cost 배열은

(1) cost(b, c) ≤ cost(a, d)

(2) cost(a, c) + cost(b, d) ≤ cost(b, c) + cost(a, d) 을 만족시켜야 한다.)

 

이런 조건을 만족시키는 문제의 경우 dp와 함께 optimization table도 함께 만든다

opt[start][end] 에는 cost 값을 최소로 만들기 위한 mid 값을 저장한다.

( dp[start][end] = min(dp[start][mid] + dp[mid+1][end]) + cost[start][end] 을 만족시키는 mid 값을 opt[start][end]에 저장 )

 

opt의 초기값으로 opt[i][i] = i 을 설정한다. 

 

여기까지가 준비단계이다.

 

❗❗Knuth Optimiztion의 사용법

위의 조건을 만족시키면 아래의 식을 만족시킨다.

opt[start][end-1] <= opt[start][end] <= opt[start+1][end]

 

이렇게 하면 기존에는 mid의 값을 start~end 까지 모두 탐색해야 하는 반면

Knuth optimization을 사용하면 탐색 범위를 줄일 수 있게 되어

 

시간 복잡도를 N^3 에서 N^2 까지 대폭 감소 시킬 수 있다.

 

 

✅코드 ( Knuth Optimization 사용 X )


import sys

input = sys.stdin.readline
test_num = int(input().rstrip())
for _ in range(test_num):
    number = int(input().rstrip())
    
    num_list = [0] + list(map(int,input().rstrip().split()))
    psum = [0 for _ in range(number+1)]
    
    for i in range(1,number+1):
        psum[i] = psum[i-1] + num_list[i]

    dp = [[0 for _ in range(number+1)] for _ in range(number+1)]
    
    for i in range(1,number):
        dp[i][i+1] = psum[i+1] - psum[i-1]
    
    for distance in range(2,number):
        for start in range(1,number+1-distance):
            cost = psum[start+distance] - psum[start-1]
            minimum = dp[start][start] + dp[start+1][start+distance]
            
            for mid in range(start+1,start+distance):
                minimum = min(minimum,dp[start][mid]+dp[mid+1][start+distance])
            
            dp[start][start+distance] = cost + minimum
    
    print(dp[1][number])

 

위의 방법은 Knuth optimization을 적용하지 않고

mid 값을 아래의 표와 같은 방법으로 채워서 dp[1][number]를 구하는 것이다

 

코드 ( Knuth Optimization 사용 O )


import sys

input = sys.stdin.readline
test_num = int(input().rstrip())
mx = 1e8

for _ in range(test_num):
    number = int(input().rstrip())
    
    num_list = [0] + list(map(int,input().rstrip().split()))
    psum = [0 for _ in range(number+1)]
    
    for i in range(1,number+1):
        psum[i] = psum[i-1] + num_list[i]

    dp = [[0 for _ in range(number+1)] for _ in range(number+1)]
    opt = [[0 for _ in range(number+1)] for _ in range(number+1)]
    
    for i in range(1,number):
        dp[i][i+1] = psum[i+1]-psum[i-1]

    
    # optimization array / 초기값 설정
    for i in range(1,number+1):
        opt[i][i] = i
    
    for start in range(number-1,0,-1):
        for end in range(start+1,number+1):
            cost = psum[end] - psum[start-1]
            minimum = mx
            
            # Knuth Optimization 을 사용해 mid 의 범위를 축소 시킴 !!!!!
            # mid의 값은 end 보다는 작아야 하므로 min()을 사용해줌
            for mid in range(opt[start][end-1],min(end,opt[start+1][end]+1)):
                tmp = dp[start][mid]+dp[mid+1][end]

                # 최단 경로의 값과 그 때의 중간 경로 지점을 opt에 저장
                if tmp <= minimum:
                    minimum = tmp
                    opt[start][end] = mid
            dp[start][end] = cost + minimum

    print(dp[1][number])

 

이 코드에서 start 와 end 의 값을 이해하기 힘들 수 있다. 

우선 Knuth optimization 을 사용하는 방법은 mid의 범위를 구할 때 아래의 식을 적용하는 것이다.

opt[start][end-1] <= opt[start][end] <= opt[start+1][end]

❗❗❗ 여기서 우리의 타겟은 opt[start][end]의 범위를 구하는 것이다.

그러기 위해서는 opt[start][end-1]과 opt[start+1][end]를 알아야 한다. 아래의 표를 보자

opt[3][4]를 구하기 위해서는 opt[3][3] 과 opt[4][4]를 알아야한다.

그러므로 우리는 opt의 초기 상태에서는 opt[2][4]와 같은 값은 알 수 없다는 것이다.

그러므로 우리는 opt와 dp의 탐색 방향을 아래의 표와 같이 가기 원하는 것이다.

위의 표 방향대로 채워가면

opt[start][end-1] <= opt[start][end] <= opt[start+1][end]

위의 식을 사용함에 있어서 모든 opt 값을 채워가며 찾을 수 있기 때문에

 

    for start in range(number-1,0,-1):
        for end in range(start+1,number+1):

loop에서 start와 end의 범위가 위와 같이 설정되는 것이다.

 

 

##참고

https://www.geeksforgeeks.org/knuths-optimization-in-dynamic-programming/