문제
풀이
MST를 구하고, MST 중 가중치가 가장 큰 간선을 절단해 마을을 2개로 분리하면 된다.
MST를 구하기 위해 Kruskal 알고리즘과 Prim 알고리즘을 사용할 수 있는데, 2가지 방법으로 모두 구현해 보았다.
1. Kruskal 알고리즘
N, M = map(int, input().split())
edges = [] #간선 리스트에 저장
for _ in range(M):
u, v, w = map(int, input().split())
edges.append((u, v, w))
edges 배열에 출발 노드, 도착 노드(무향 그래프지만 편의상 출발-도착으로 지칭), 가중치를 저장한다. Kruskal 알고리즘을 사용하기 위해 이와 같이 저장하였고, 추후 Prim 알고리즘으로 풀이할 때는 연결 리스트로 저장하므로 차이가 있다.
parent = [-1]*(N+1) #union-find 구현을 위한 parent 배열
def union(a, b):
root1 = find(a)
root2 = find(b)
if root1 != root2: #같은 그룹이 아니면
parent[root1] = root2 #a의 부모 노드를 b로 설정하여 같은 그룹에 편입
def find(a):
if parent[a] == -1: #루트 노드이면
return a
parent[a] = find(parent[a])
return parent[a]
이후 Kruskal 알고리즘 구현에 사용할 union-find 알고리즘을 작성한다.
def kruskal():
edges.sort(key=lambda x: x[2])
accepted, res_sum, res_max = 0, 0, 0
for u, v, w in edges:
if find(u) != find(v): #u와 v가 다른 집합이라면 MST에 포함시킬 수 있음
union(u, v)
accepted += 1
res_sum += w
res_max = max(res_max, w)
if accepted == N - 1:
break
return res_sum - res_max #완성된 MST에서 가중치가 가장 큰 간선을 제외
위에서 작성한 코드를 바탕으로 Kruskal 알고리즘 코드를 작성하면 된다.
전체 코드
import sys
sys.setrecursionlimit(10**6)
input = sys.stdin.readline
N, M = map(int, input().split())
edges = []
for _ in range(M):
u, v, w = map(int, input().split())
edges.append((u, v, w))
parent = [-1]*(N+1)
def union(a, b):
root1 = find(a)
root2 = find(b)
if root1 != root2:
parent[root1] = root2
def find(a):
if parent[a] == -1:
return a
parent[a] = find(parent[a])
return parent[a]
def kruskal():
edges.sort(key=lambda x: x[2])
accepted, res_sum, res_max = 0, 0, 0
for u, v, w in edges:
if find(u) != find(v):
union(u, v)
accepted += 1
res_sum += w
res_max = max(res_max, w)
if accepted == N - 1:
break
return res_sum - res_max
print(kruskal())
2. Prim 알고리즘
import sys
import heapq #minheap 사용
input = sys.stdin.readline
N, M = map(int, input().split())
graph = [[] for _ in range(N+1)]
for _ in range(M): #인접 리스트에 저장
u, v, w = map(int, input().split())
graph[u].append((w, v))
graph[v].append((w, u))
동일하게 u, v, w를 입력받지만 인접 리스트에 저장한다. 최소 가중치를 가진 간선을 추출하기 위해 최소 히프를 사용하므로 heapq를 import해 주어야 한다. 최소 히프를 사용하지 않을 경우 구현은 가능하지만 시간 초과로 표시된다.
def prim():
visited = [False] * (N+1)
min_heap = [(0, 1)]
accepted, res_sum, res_max = 0, 0, 0
while min_heap:
w, v = heapq.heappop(min_heap)
if visited[v]:
continue
visited[v] = True
res_sum += w
res_max = max(res_max, w)
accepted += 1
if accepted == N:
break
for next_w, next_n in graph[v]:
if not visited[next_n]:
heapq.heappush(min_heap, (next_w, next_n))
return res_sum - res_max
Prim 알고리즘을 구현한다.
전체 코드
import sys
import heapq
input = sys.stdin.readline
N, M = map(int, input().split())
graph = [[] for _ in range(N + 1)]
for _ in range(M):
u, v, w = map(int, input().split())
graph[u].append((w, v))
graph[v].append((w, u))
def prim():
visited = [False] * (N+1)
min_heap = [(0, 1)]
accepted, res_sum, res_max = 0, 0, 0
while min_heap:
w, v = heapq.heappop(min_heap)
if visited[v]:
continue
visited[v] = True
res_sum += w
res_max = max(res_max, w)
accepted += 1
if accepted == N:
break
for next_w, next_n in graph[v]:
if not visited[next_n]:
heapq.heappush(min_heap, (next_w, next_n))
return res_sum - res_max
print(prim())
+ 최소 히프를 사용하지 않은 Prim 알고리즘
import sys
input = sys.stdin.readline
N, M = map(int, input().split())
graph = [[] for _ in range(N + 1)]
for _ in range(M):
u, v, w = map(int, input().split())
graph[u].append((v, w))
graph[v].append((u, w))
def prim():
visited = [False] * (N + 1)
min_weight = [10001] * (N + 1)
min_weight[1] = 0
res_sum, res_max = 0, 0
for _ in range(N):
v = -1
curr_min = 10001
for i in range(1, N + 1):
if not visited[i] and min_weight[i] < curr_min:
curr_min = min_weight[i]
v = i
visited[v] = True
res_sum += curr_min
res_max = max(res_max, curr_min)
for v, weight in graph[v]:
if not visited[v] and weight < min_weight[v]:
min_weight[v] = weight
return res_sum - res_max
print(prim())
Kruskal 알고리즘으로 풀이한 방식(위)가 훨씬 빠른 것을 확인할 수 있다.
'백준' 카테고리의 다른 글
백준 16724: 피리 부는 사나이(python) (0) | 2025.02.24 |
---|---|
백준 1266: 다각형의 면적(python) (0) | 2025.01.08 |
백준 2343: 기타 레슨(python) (0) | 2024.11.18 |
백준 17298: 오큰수(java) (0) | 2024.10.13 |
백준 2473: 세 용액(java) (0) | 2024.10.11 |