짜잘한 기록

[나만 몰랐던 알고리즘] 세그먼트 트리

슨민 2021. 8. 30. 22:30

트리 구조를 띄는 자료 구조인 세그먼트 트리는, 일렬로 나열되어 있는 자료 중에 어느 부분(세그먼트)의 연산(합, 최소, 최대, 곱 등)을 구하는데 많이 쓰인다.

 

크게 다음과 같은 문제 조건인지 확인해 보면 된다.

1. 구간 l, r (l <= r) 이 주어졌을때, A[l] + A[l+1] + ... + A[r-1] + A[r]을 구하기. (여기서 곱이나 최소, 최대의 연산을 할 수도 있음)
2. i 번째 수를 v로 바꾸기. A[i] = v

그냥 배열에 두고, 각 배열을 돌면서 l부터 r까지 연산을 해도 답은 나온다.

양쪽 방향으로 누적합을 구해두고 연산하는 투 포인터 알고리즘을 써도 괜찮지만 연산 횟수가 많아지면 많아질수록, 데이터의 수가 커지면 커질수록 이 방법은 한계가 온다.

 

그것을 극복하기 위해 세그먼트 트리를 사용한다.

세그먼트 트리의 구조를 한번 보자.

일단, 트리의 각 노드는 말단 노드(terminal node, outer node)와 내부 노드(internal node, branch node)로 나눌 수 있다.

세그먼트 트리에서 말단 노드는 각 데이터의 값을 가진다. 그림에서 첫번째 말단 노드부터 77, 24, 6 .. 의 값을 가지는 것을 볼 수 있다.

내부 노드는 각 자식들의 합을 값으로 가진다. 그림에서 77, 24를 자식으로 가지는 노드가 101의 값을 가지는 것을 볼 수 있다.

 

구조를 딱 보면, 노드의 위치와 범위를 잘 주무르다 보면 우리가 원하는 구간의 값을 빠르게 구할 수 있는 감이온다.

이제 코드를 보고 어떻게 돌아가는지 보자.

Init

long long init(int start, int end, int node, long long tree[], long long numbers[]) {
    if (start == end)
        return tree[node] = numbers[start];
    int mid = (start + end) / 2;
    tree[node] = init(start, mid, node * 2, tree, numbers) + init(mid+1, end, node*2+1, tree, numbers);
    return tree[node];
}

시작 범위, 종료 범위, 노드 인덱스를 받고, numbers[] 배열에 있는 데이터를 토대로 tree[]에 세그먼트 트리를 초기화한다.

말단 노드가 나올 때까지 init(start, mid) + init(mid + 1, end)를 구해서 해당 노드에 넣어준다. 리턴값은 초기화된 노드의 값.

말단 노드는 start == end 일 경우를 조건으로 잡았다. 이 경우, 해당 노드 인덱스 트리 값에 numbers[]에서 값을 가져와 저장한다.

getSum

long long getSum(int start, int end, int node, int left, int right, long long tree[]) {
    if (left > end || right < start)
        return 0;
    if (left <= start && end <= right)
        return tree[node];
    int mid = (start + end) / 2;
    return getSum(start, mid, node * 2, left, right, tree) + getSum(mid+1, end, node * 2 + 1, left, right, tree);
}

(탐색의) 시작 범위, 종료 범위, 노드 인덱스, (값을 얻어내고자 하는) 왼쪽 위치, 오른쪽 위치를 받아서 tree[] 에서 검색을 한다.

탐색의 범위는 값을 얻어내고자 하는 범위를 찾는 위치이다. 예를 들면 건초더미(탐색의 범위)에서 바늘(값을 얻어내고자 하는 범위)을 찾는다랄까? 다르게 보면, start와 end는 현재 노드가 가지고 있는 데이터 범위라고 생각할 수 있다.

이 범위가 값을 얻어내고자 하는 범위와 어떤 상대적 위치를 띄는지에 따라 if 분기로 3개의 행동을 한다.

그림에서 검색 범위가 left, right라고 생각하자. 색으로 표시되어 있는 범위는 start, end로 표시되는 내 노드가 가지고 있는 데이터 범위이다.

 

1. 의 경우는 내 범위가 검색 범위 안에 완전히 포함되는 경우이다. 첫번째의 경우 두 범위가 일치하는 경우이고, 두번째의 경우 검색 범위 안에 포함되어 있는 경우이다. 이 두 경우 모두, 지금 노드의 값을 리턴한다. -> 두번째 if 구문

2. 의 경우는 내 범위가 검색 범위에 포함은 되지만 일부 걸쳐있는 경우나 검색 범위가 내 범위에 속해있는 경우이다. 이 경우 getSum(start, mid) + getSum(mid + 1, end)로 나누어 밑 노드에서 검색을 계속하게 된다 -> if 가 걸리지 않고 쭉 실행되는 구문

3. 의 경우는 내 범위와 검색 범위가 아예 교집합을 가지지 않는 경우이다. 이 경우 0을 리턴한다 -> 첫 번째 if 구문.

 

범위가 어떤지에 따라 다르지만, 대부분의 흐름에선 두번째 경우의 함수로 시작해서, 각 노드나 검색 범위의 중심부 노드들을 첫번째 경우로가지는 함수에서 값을 뱉고, 가장자리나 범위 바깥은 세번째 경우로 0을 리턴해 올라온것을 두번째 경우로 시작했던 함수에서 최종 합을 연산하게 된다.

updateValue

void updateValue(int start, int end, int node, int index, long long diff, long long tree[]) {
    if (index < start || index > end)
        return;
    tree[node] += diff;
    if (start == end)
        return ;
    int mid = (start + end) / 2;
    updateValue(start, mid, node * 2, index, diff, tree);
    updateValue(mid + 1, end, node * 2 + 1, index, diff, tree);
}

트리 초기화, 합의 경우를 보았으니 이제 값을 변경하는 경우를 보자.

함수의 프로토타입은 유사하다. 시작 범위, 끝 범위, 노드 인덱스와 변경하고자 하는 인덱스를 받아 diff만큼 값을 변경한다.

첫 if 문에서는 해당 인덱스가 현재 노드에 속한지 확인한다. 만약 속하지 않는 노드라면 값을 수정할 필요가 없으니 바로 나온다.

이후 로직을 실행한다는 것은, 해당 인덱스가 현재 노드에 속한다는 것을 의미한다. 그래서 일단 현재 노드의 값을 수정해준다.

그리고 현재 노드가 말단 노드이면(start == end) 밑 자식들 값을 변경해줄 필요가 없으니 종료.

그렇지 않으면 updateValue(start, mid)와 updateValue(mid+1, end)를 실행해 자식들도 갱신시켜준다.


세그먼트 트리는 대체적으로 이런 로직으로 돌아간다. "Segment"라는 이름에 맞게 "범위"를 중심으로 연산이 이루어 지는 것을 볼 수 있었다. 트리는 진짜 별의 별 응용이 다 있는것 같다.

 

참고:

https://visualgo.net/en/segmenttree

 

VisuAlgo - Segment Tree

VisuAlgo is free of charge for Computer Science community on earth. If you like VisuAlgo, the only payment that we ask of you is for you to tell the existence of VisuAlgo to other Computer Science students/instructors that you know =) via Facebook, Twitter

visualgo.net

https://www.acmicpc.net/blog/view/9