Algorithm & Data Structure

[백준] 별자리 만들기 (4386번) - Python

후뿡이 2023. 5. 31. 01:30

✅ 문제 - [백준] 별자리 만들기 (4386번)


 

필요 알고리즘 개념 - 프림 알고리즘 ( Prim's Alghrithm )


🎯 프림 알고리즘 ( Prim's Alghrithm ) 이란?

프림 알고리즘이란 MST ( Minimum Spanning Tree , 최소 신장 트리 )를 찾는 방법 중 시작 노드부터 시작하여 방문한 노드 중 cost가 가장 적은 경로를 선택해 가며 최소 신장 트리를 만드는 일종의 그리디 알고리즘이다.

그리디 알고리즘의 경우 순간의 최선의 선택이 결과적으로 최선의 선택임을 증명해야 하는데 

MST를 푸는 알고리즘 중 프림(Prim's) 알고리즘과 크루스칼(Kruskal) 알고리즘은 증명이 된 그리디 알고리즘이므로 해결 가능함이 이미 증명이 된 알고리즘이다.

 

🎯 프림 알고리즘 ( Prim's Alghrithm ) 예시

왼쪽의 그림처럼 6개의 노드와 8개의 엣지가 있다고 하자.

 

이때 프림 알고리즘은 시작 노드를 정해야 하는데 결국 모든 노드를 지나야 하므로 어디로 설정하던 상관없다.

 

우리는 1번 노드에서 시작한다고 가정하자

 

이때 선택 가능한 경로는 2번 노드로 가는 코스트 4 인 경로 또는 3번 노드로 가는 코스트 6인 노드가 있다. 

이 노드들을 ( Cost , Node ) 순으로 힙큐에 넣어준다.

힙 큐의 상태는 [ (4,2), (6,3) ] 이 된다.

 

프림 알고리즘은 항상 최단의 경로만을 가정하므로 

2번 노드를 방문한다.

 

2번 노드를 방문하면 2번 노드와 연결된 엣지들을 힙큐에 추가해 준다.

 

힙큐의 상태는 [ (3,3), (6,3), (7,4) ] 가 된다.

 

그러므로 다음 상태에서는 코스트가 3인 3번 노드를 방문하게 된다. 

 

 

 

 

3번 노드를 방문하면 3번 노드와 연결된 엣지들을 힙큐에 추가해준다.

 

힙큐의 상태는 [ (2,4), (6,3), (7,4), (8,6) ] 가 된다.

 

그러므로 다음 상태에서는 코스트가 2인 4번 노드를 방문하게 된다. 

 

 

 

 

4번 노드를 방문하면 4번 노드와 연결된 엣지들을 힙큐에 추가해 준다.

 

힙큐의 상태는 [ (4,5), (6,3), (7,2), (7,4), (8,6) ] 가 된다.

 

그러므로 다음 상태에서는 코스트가 4인 5번 노드를 방문하게 된다. 

 

 

 

 

5번 노드를 방문하면 5번 노드와 연결된 엣지들을 힙큐에 추가해 준다.

 

힙큐의 상태는 [ (5,6), (6,3), (7,2), (7,4), (8,6) ] 가 된다.

 

그러므로 다음 상태에서는 코스트가 5인 6번 노드를 방문하게 된다. 

 

 

 

 

 

6번 노드를 방문하면 6번 노드와 연결된 엣지들을 힙큐에 추가해 준다.

 

힙큐의 상태는 [ (5,6), (6,3), (7,2), (7,4), (8,3), (8,6) ] 가 된다.

 

이 이후에 heappop을 통해 노드들을 꺼내면 모두 방문한 노드이므로

아무런 동작을 취하지 않고 힙큐가 비게 된다.

 

 

이러한 프림 알고리즘을 통해 MST를 구할 수 있다 !!

 

 

✅ 코드


import sys
import heapq

def getDist(node1,node2):
    dist = ((node1[0]-node2[0])**2 + (node1[1]-node2[1])**2)**0.5
    return round(dist,3)

N = int(sys.stdin.readline().rstrip())

nodes = []
visited = [ False for _ in range(N)]
dists = [ [ 0 for _ in range(N)] for _ in range(N)]


for _ in range(N):
    x,y = map(float,sys.stdin.readline().rstrip().split())
    nodes.append((x,y))

for i in range(N):
    for j in range(i+1,N):
        dists[i][j] = getDist(nodes[i],nodes[j])
        dists[j][i] = dists[i][j]

result = 0

hq = []

for i in range(1,N):
    heapq.heappush(hq,(dists[0][i],i))
visited[0] = True

while hq:
    dist,v = heapq.heappop(hq)
    if visited[v]:
        continue

    result += dist
    visited[v] = True

    for i in range(0,i):
        heapq.heappush(hq,(dists[v][i],i))
    
    for i in range(i+1,N):
        heapq.heappush(hq,(dists[v][i],i))
    
print(round(result,2))