yoongrammer

세그먼트 트리(Segment Tree) 개념 및 구현 본문

자료구조 (Data structure)

세그먼트 트리(Segment Tree) 개념 및 구현

yoongrammer 2022. 10. 9. 16:32
728x90

목차

    세그먼트 트리(Segment Tree)


    세그먼트 트리(Segment Tree)는 배열 간격에 대한 정보를 이진 트리에 저장하는 자료구조입니다.

     

    다음 예를 보겠습니다.

    A = {1, 2, 3, 4, 5 … ,N} 라는 배열에 아래 연산을 M번 수행한다고 생각해봅시다.

    1. 배열의 범위 합을 구하는 Query 연산
      • A[0] + A[1] + A[2] + … + A[N-1]
    2. i번째 배열 값을 v로 변경하는 Update 연산
      • A[i] = v

    단순한 방법으로는 각각 배열에 접근하여 연산을 한다면 시간 복잡도는 1번 연산 O(N), 2번 연산 O(1)이 됩니다.

    이 두 연산을 M번 수행한다면 총 시간 복잡도는 O(MN)+O(M) = O(MN)이 됩니다.

     

    세그먼트 트리를 사용하면 범위 최소/최대 및 합계 Query 및 범위 Update를 O(logN) 시간에 해결할 수 있습니다.

    위 두 연산에 세그먼트 트리를 사용한다면 O(MlogN) + O(MlogN) = O(MlogN) 시간으로 개선시킬 수 있습니다.

     

    세그먼트 트리 구성


    세그먼트 트리는 이진트리이므로 간단한 선형 배열을 사용하여 나타낼 수 있습니다.

    부모 노드의 인덱스가 i라면 2i은 왼쪽 자식 노드이고 2i+1은 오른쪽 자식 노드입니다.

     

    세그먼트 트리의 각 노드에는 구간 정보가 저장되어 있습니다.

    예) 구간 합, 구간 최대/최소 값

     

    다음은 배열 A = {1,2,3,4,5} 에 대한 세그먼트 트리를 시각화한 그림입니다.

    Segment Tree

    각 노드에는 배열의 구간 합이 저장되어 있습니다.

    리프 노드에는 주어진 배열(A) 값들이 저장되고 내부 노드에는 자식 노드의 합이 저장됩니다.

     

    세그먼트 트리 구현


    세그먼트 트리에는 세 가지(Build, Query, Update) 작업이 있습니다.

    여기서는 구간 합 정보를 저장하는 세그먼트 트리를 구현하는 방법을 알아보겠습니다.

     

    세그먼트 트리는 전 이진트리(Full Binary Tree)입니다.

    그렇기 때문에 크기가 n인 배열을 가지고 리프 노드가 n개인 세그먼트 트리를 만들 때 필요한 노드 수는 다음과 같습니다.

    $$ 1 + 2 + 4 + \dots + 2^{\lceil\log_2 n\rceil} \lt 2^{\lceil\log_2 n\rceil +1} \lt 4n  $$

     

    높이가 \(  h = \lceil\log_2 n\rceil  \) 라면

    필요한 배열 크기는 \( 2^{(h+1)} - 1 \) 이며 편의를 위해 \( 2^{(h+1)} \) 또는 \( 4n \)으로 크기를 정하기도 합니다.

    from math import ceil, log2
    
    height = ceil(log2(n))
    tree_size = 1 << (height+1)
    tree = [0] * tree_size
    
    # or
    tree_size = 4 * n
    tree = [0] * tree_size

     

    Build


    세그먼트 트리를 만드는 작업을 합니다.

    구현은 다음과 같습니다.

    • 세그먼트 트리는 재귀를 사용하여 구축할 수 있습니다.
    • 트리를 순회하며 작업을 진행합니다.
      • 리프 노드라면 배열(A)의 요소를 저장합니다.
      • 내부 노드라면 구간 정보를 저장합니다. (예를 들어 구간 정보가 구간 합이라면 두 자식 노드의 합을 저장합니다.)

    구현

    • 여기서 start, end는 배열 A에 대한 범위입니다.
    def build(node, start, end):
    	if start == end:
    		# 리프 노드라면 원소를 저장한다.
    		tree[node] = A[start]
    		return
    	
    	mid = (start + end) // 2
    	# 왼쪽 자식으로 재귀
    	build(node*2, start, mid)
    	# 오른쪽 자식으로 재귀
    	build(node*2+1, mid+1, end)
    	# 내부 노드라면 두 자식 노드의 합을 저장한다.
    	tree[node] = tree[node*2] + tree[node*2+1]
    	return

     

    Query


    세그먼트 트리에서 구간 정보를 가져옵니다.

     

    구현은 다음과 같습니다.

    • 트리를 순회합니다.
      • 노드가 나타내는 범위가 지정된 범위 밖에 있다면 0을 반환합니다.
      • 노드가 나타내는 범위가 지정된 범위 내에 있다면 값을 반환합니다.
      • 노드가 나타내는 범위가 지정된 범위 일부만 포함한다면 왼쪽 자식과 오른쪽 자식의 합을 반환합니다.

    예를 들어 배열 A = {1,2,3,4,5}에서 2부터 4까지 구간 합을 구한다면 아래 그림처럼 컬러 노드들을 방문하게 되고 그중 초록색 노드 값을 사용하여 결과를 반환하게 됩니다. (3+ 9 = 12)

    Segment Tree Query

    구현

    • 여기서 노드가 나타내는 범위는 start, end이고 지정된 범위는 left, right입니다.
    • 위 그림에서 지정된 범위는 left = 2, right = 4가 됩니다.
    def query(node, start, end, left, right):
    	if right < start or end < left:
    		# 노드가 지정된 범위 밖에 있는 경우
    		return 0
    
    	if left <= start and end <= right:
    		# 노드가 지정된 범위 안에 있는 경우
    		return tree[node]
    
    	# 노드가 지정된 범위 일부에 있는 경우
    	mid = (start + end)//2
    	left_child = query(node*2, start, mid, left, right)
    	right_child = query(node*2+1, mid+1, end, left, right)
    	
    	return left_child + right_child

     

    Update


    index번째 배열 A를 var값으로 변경하는 작업을 합니다. (A[index] = var)

    update는 index가 포함된 구간을 담당하는 노드들만 변경합니다.

     

    예를 들어 A = {1,2,3,4,5} 배열에서 A[2] = 5로 변경할 때, 변경해야 하는 구간은 다음과 같습니다. 색깔 노드는 방문하여 변경하는 노드입니다.

    Segment Tree Update

    구현은 다음과 같습니다.

    • 트리를 순회합니다.
    • index가 포함된 구간을 가지고 있는 자식 노드로 재귀합니다.
      • 리프 노드라면 배열 값을 변경합니다.
      • 내부 노드라면 구간 정보를 저장합니다. (예를 들어 구간 정보가 구간 합이라면 두 자식 노드의 합을 저장합니다.)
    def update(node, start, end, index, val):
    	if start == end:
    		# 리프 노드라면 배열 값을 변경한다.
    		tree[node] = val
    		return
    	
    	mid = (start + end) // 2
    	# index가 포함된 구간을 가진 자식 노드로 재귀한다.
    	if start <= index and index <= mid:
    		update(2*node, start, mid, index, val)
    	else:
    		updaet(2*node+1, mid+1, end, index, val)
    	# 내부 노드라면 두 자식 노드의 합을 저장한다.
    	tree[node] = tree[node*2] + tree[node*2+1]
    	return

     

    변화량을 사용한 구현

    • 변화량(diff)을 사용하여 업데이트하는 방법도 있습니다.
    • index 번째 배열 A를 val값으로 변경할 때 변화량은 diff = val - A[index]가 됩니다.
    • index가 포함된 구간을 가진 노드들만 diff만큼 증가시키는 방식으로 구현합니다.

    예를 들어 A = {1,2,3,4,5} 배열에서 A[2] = 5로 변경할 때, diff는 5 - 3 = 2가 되고 아래 색깔 노드만 diff만큼 증가시키면 됩니다.

    def update(node, start, end, index, diff):
    	if index < start or end < index:
    		#index가 노드 범위 밖이면 탐색을 중단한다.
    		return
    	
    	# 노드를 diff 만큼 증가시킨다.
    	tree[node] += diff
    
    	if start != end:
    		# 리프 노드가 아닌 경우 자식 노드를 update해준다.
    		mid = (start + end) // 2
    		update(node*2, start, mid, index, diff)
    		update(node*2+1, mid+1, end, index, diff)
    	return

     

     

    728x90
    Comments