본문 바로가기
백준

백준 1647: 도시 분할 계획(python)

by unhyepnhj 2025. 1. 8.

문제


풀이

 

MST를 구하고, MST 중 가중치가 가장 큰 간선을 절단해 마을을 2개로 분리하면 된다.

MST를 구하기 위해 Kruskal 알고리즘과 Prim 알고리즘을 사용할 수 있는데, 2가지 방법으로 모두 구현해 보았다.


1. Kruskal 알고리즘

N, M = map(int, input().split())
edges = []  #간선 리스트에 저장
for _ in range(M):
    u, v, w = map(int, input().split())
    edges.append((u, v, w))

edges 배열에 출발 노드, 도착 노드(무향 그래프지만 편의상 출발-도착으로 지칭), 가중치를 저장한다. Kruskal 알고리즘을 사용하기 위해 이와 같이 저장하였고, 추후 Prim 알고리즘으로 풀이할 때는 연결 리스트로 저장하므로 차이가 있다.

parent = [-1]*(N+1)	#union-find 구현을 위한 parent 배열

def union(a, b):
    root1 = find(a)
    root2 = find(b)

    if root1 != root2:  #같은 그룹이 아니면
        parent[root1] = root2	#a의 부모 노드를 b로 설정하여 같은 그룹에 편입
        
def find(a):
    if parent[a] == -1: #루트 노드이면
        return a
    parent[a] = find(parent[a])
    return parent[a]

이후 Kruskal 알고리즘 구현에 사용할 union-find 알고리즘을 작성한다.

def kruskal():
    edges.sort(key=lambda x: x[2])
    accepted, res_sum, res_max = 0, 0, 0

    for u, v, w in edges:
        if find(u) != find(v):  #u와 v가 다른 집합이라면 MST에 포함시킬 수 있음
            union(u, v)
            accepted += 1
            res_sum += w
            res_max = max(res_max, w)

            if accepted == N - 1:
                break

    return res_sum - res_max	#완성된 MST에서 가중치가 가장 큰 간선을 제외

위에서 작성한 코드를 바탕으로 Kruskal 알고리즘 코드를 작성하면 된다.

 

전체 코드

import sys
sys.setrecursionlimit(10**6)
input = sys.stdin.readline

N, M = map(int, input().split())
edges = []
for _ in range(M):
    u, v, w = map(int, input().split())
    edges.append((u, v, w))

parent = [-1]*(N+1)

def union(a, b):
    root1 = find(a)
    root2 = find(b)

    if root1 != root2:
        parent[root1] = root2

def find(a):
    if parent[a] == -1:
        return a
    parent[a] = find(parent[a])
    return parent[a]

def kruskal():
    edges.sort(key=lambda x: x[2])
    accepted, res_sum, res_max = 0, 0, 0

    for u, v, w in edges:
        if find(u) != find(v):
            union(u, v)
            accepted += 1
            res_sum += w
            res_max = max(res_max, w)

            if accepted == N - 1:
                break

    return res_sum - res_max

print(kruskal())

2. Prim 알고리즘

import sys
import heapq	#minheap 사용
input = sys.stdin.readline

N, M = map(int, input().split())
graph = [[] for _ in range(N+1)]

for _ in range(M):	#인접 리스트에 저장
    u, v, w = map(int, input().split())
    graph[u].append((w, v))
    graph[v].append((w, u))

동일하게 u, v, w를 입력받지만 인접 리스트에 저장한다. 최소 가중치를 가진 간선을 추출하기 위해 최소 히프를 사용하므로 heapq를 import해 주어야 한다. 최소 히프를 사용하지 않을 경우 구현은 가능하지만 시간 초과로 표시된다.

def prim():
    visited = [False] * (N+1)
    min_heap = [(0, 1)]
    accepted, res_sum, res_max = 0, 0, 0

    while min_heap:
        w, v = heapq.heappop(min_heap)

        if visited[v]:
            continue

        visited[v] = True
        res_sum += w
        res_max = max(res_max, w)
        accepted += 1

        if accepted == N:
            break

        for next_w, next_n in graph[v]:
            if not visited[next_n]:
                heapq.heappush(min_heap, (next_w, next_n))

    return res_sum - res_max

Prim 알고리즘을 구현한다.

 

전체 코드

import sys
import heapq
input = sys.stdin.readline


N, M = map(int, input().split())
graph = [[] for _ in range(N + 1)]

for _ in range(M):
    u, v, w = map(int, input().split())
    graph[u].append((w, v))
    graph[v].append((w, u))

def prim():
    visited = [False] * (N+1)
    min_heap = [(0, 1)]
    accepted, res_sum, res_max = 0, 0, 0

    while min_heap:
        w, v = heapq.heappop(min_heap)

        if visited[v]:
            continue

        visited[v] = True
        res_sum += w
        res_max = max(res_max, w)
        accepted += 1

        if accepted == N:
            break

        for next_w, next_n in graph[v]:
            if not visited[next_n]:
                heapq.heappush(min_heap, (next_w, next_n))

    return res_sum - res_max

print(prim())

 

+ 최소 히프를 사용하지 않은 Prim 알고리즘

import sys
input = sys.stdin.readline

N, M = map(int, input().split())
graph = [[] for _ in range(N + 1)]

for _ in range(M):
    u, v, w = map(int, input().split())
    graph[u].append((v, w))
    graph[v].append((u, w))

def prim():
    visited = [False] * (N + 1)
    min_weight = [10001] * (N + 1)
    min_weight[1] = 0
    res_sum, res_max = 0, 0

    for _ in range(N):
        v = -1
        curr_min = 10001
        for i in range(1, N + 1):
            if not visited[i] and min_weight[i] < curr_min:
                curr_min = min_weight[i]
                v = i

        visited[v] = True
        res_sum += curr_min
        res_max = max(res_max, curr_min)

        for v, weight in graph[v]:
            if not visited[v] and weight < min_weight[v]:
                min_weight[v] = weight

    return res_sum - res_max

print(prim())

Kruskal 알고리즘으로 풀이한 방식(위)가 훨씬 빠른 것을 확인할 수 있다. 

'백준' 카테고리의 다른 글

백준 16724: 피리 부는 사나이(python)  (0) 2025.02.24
백준 1266: 다각형의 면적(python)  (0) 2025.01.08
백준 2343: 기타 레슨(python)  (0) 2024.11.18
백준 17298: 오큰수(java)  (0) 2024.10.13
백준 2473: 세 용액(java)  (0) 2024.10.11