본문 바로가기
백준

백준 1074: Z(python)

by unhyepnhj 2025. 7. 1.

문제


풀이

 

2x2 배열이 될 때까지 계속 분할하며 순서를 카운트하는 전형적인 분할 정복 문제이다. 아래와 같이 nxn 크기 배열 원소들의 방문 순서를 계산하는 visit() 함수를 정의한다.

def visit(n, x, y):
    global count
    
    if n == 2:  # 최소 단위
        if c not in (x, x + 1) or r not in (y, y + 1):
            count += 4
        else:
            if c == x and r == y: print(count)
            elif c == x + 1 and r == y: print(count + 1)
            elif c == x and r == y + 1: print(count + 2)
            else: print(count + 3)
            return

 

이때 n은 2차원 배열 한 변의 크기, x와 y는 배열 좌측 상단 칸의 인덱스이다.

현재 탐색하는 부분 배열이 2x2일 때 목표 칸 인덱스 r과 c가 현재 배열에 포함되지 않으면 배열 내부를 탐색을 건너뛰고 count를 4 증가하는 것으로 대신하며, r과 c가 현재 배열 내에 있으면 (x, y) -> (x+1, y) -> (x, y+1) -> (x+1, y+1) 순으로 count를 1씩 증가하며 r과 c의 방문 순서를 계산해 함수를 조기 종료한다.

    else:
        dist = n // 2
        if c < x + dist and r < y + dist: visit(dist, x, y) # 1사분면
        else: count += dist * dist  # skip

        if c >= x + dist and r < y + dist: visit(dist, x + dist, y) # 2사분면
        else: count += dist * dist  # skip

        if c < x + dist and r >= y + dist: visit(dist, x, y + dist) # 3사분면
        else: count += dist * dist  # skip

        if c >= x + dist and r >= y + dist: visit(dist, x + dist, y + dist) # 4사분면
        else: count += dist * dist  # skip

부분 배열을 더 분할할 수 있을 때는 위와 같이 재귀 호출로 처리한다. 이때 현재 배열의 중심 (x + dist, y + dist) 기준 r, c의 위치를 고려함으로써 visit()을 매번 호출하지 않도록 구현해야 하며 그렇지 않으면 시간이 초과된다. 해당되는 사분면에서만 visit()을 호출하고, 그렇지 않을 경우 실제 재귀 호출 없이 count를 dist*dist 크기만큼, 즉 탐색할 부분 배열 원소 수만큼 증가함으로써 탐색을 수행한 것과 동일하게 처리한다.

 

전체 코드

import sys
input = sys.stdin.readline
sys.setrecursionlimit(10 ** 6)

N, r, c = map(int, input().split())
count = 0   # 방문 순서

def visit(n, x, y):   # n = 탐색할 배열 한 변의 크기 / x, y = 배열 좌측 상단
    global count

    if n == 2:  # 최소 단위
        if c not in (x, x + 1) or r not in (y, y + 1):
            count += 4
        else:
            if c == x and r == y: print(count)
            elif c == x + 1 and r == y: print(count + 1)
            elif c == x and r == y + 1: print(count + 2)
            else: print(count + 3)
            return
    else:
        dist = n // 2
        if c < x + dist and r < y + dist: visit(dist, x, y) # 1사분면
        else: count += dist * dist  # skip

        if c >= x + dist and r < y + dist: visit(dist, x + dist, y) # 2사분면
        else: count += dist * dist  # skip

        if c < x + dist and r >= y + dist: visit(dist, x, y + dist) # 3사분면
        else: count += dist * dist  # skip

        if c >= x + dist and r >= y + dist: visit(dist, x + dist, y + dist) # 4사분면
        else: count += dist * dist  # skip

visit(2 ** N, 0, 0)