c++ 자료구조, 알고리즘

c++ 세그먼트 트리

무한 나무 2023. 9. 23. 19:02

[참조]

[ 세그먼트 트리(Segment Tree) ] 개념과 구현방법 (C++) :: 얍문's Coding World.. (tistory.com


 

(1)구간 합(2)요소 변경 후 구간 합을 요구하는 문제에서 사용하는 트리.

 

노드에는 "구간 (start, end)" 과 "합(sum)" 이 저장되어 있다.

Array = {1,2,3,4,5}

그림은 구간이 배열 인덱스 기준이 아니지만 실제로는 배열의 index 기준으로 구간 값 할당. [시작 0]

 

루트 노드에는 전체 구간과 그 합이 저장.  (start , end)

왼쪽 노드에는 왼쪽 반 만큼의 구간과 그 합.  (start, (start+end)/2)

오른쪽 노드에는 오른쪽 반 만큼의 구간과 그 합.  ((start+end)/2+1 , end)

이렇게 내려가면서 리프 노드에는 단일 구간. 즉, 문제 배열의 요소의 인덱스와 그 값이 저장 된다.

( 그림은 구간이 배열 index 기준이 아니지만 실제로는 배열 index 기준으로 구간 값 할당. [시작 0] )

 

이진트리 구조이므로 배열로 만들어 인덱스 접근으로 구현 할 수 있으며, 총 트리 크기는 문제 요소가 N개일 때, 

  • 1<< Ceil(log2(N)) + 1  또는 N*4
  • Ceil(log2(N) 은 요소가 N개 일때 트리의 높이,   N*4 는 간단히 구현하고 싶을 때 러프하게 잡는 방법.

 

또한, 자식 인덱스 접근의 간편함을 위해 루트 노드인덱스가 1이며,

왼쪽 자식 인덱스 = cur*2

오른쪽 자식 인덱스 = cur*2 +1

로 접근.

 

만들어야 될 함수는 "세그먼트 트리 만들기", "구간 합 구하기", "요소 변경으로 세그먼트 트리 갱신" 3개 이다.

함수 모두 자식으로 점점 깊게 들어가는 "재귀 함수" 로 구현한다.

 

노드구조체로 정의해서 구현이 조금 더 직관적으로 되게 끔 해 보았다.

 


코드 구현

struct SegNode
{
    int start, end;
    int value;
    explicit SegNode()
        :start(0), end(0), value(0) {}
    explicit SegNode(int _start, int _end, int _value)
        :start(_start), end(_end), value(_value){}
};

// idx는 처음은 무조건 루트(1) 로 시작
int MakeSegTree(vector<SegNode>& SegTree, int idx, int start, int end, vector<int>& Array)
{
    if (start == end)
    {
        //리프 노드 도달 시, 바로 값 적용
        SegTree[idx] = SegNode(start, end, Array[start]);
    }
    else
    {
        int Sum = 0;
        int Mid = (start + end) / 2;
        int leftSum = MakeSegTree(SegTree, idx * 2, start, Mid, Array);
        int rightSum = MakeSegTree(SegTree, idx * 2 + 1, Mid + 1, end, Array);

        SegTree[idx] = SegNode(start, end, leftSum + rightSum);
    }
    
    return SegTree[idx].value;
}

void ChangeValueSegTree(vector<SegNode>& SegTree, int idx, int Array_idx, int Value)
{
    //목표 요소를 포함하고 있는 노드들은 갱신
    if (SegTree[idx].start <= Array_idx && Array_idx <= SegTree[idx].end)
    {
        SegTree[idx].value += Value;
        
        //리프면 종료.
        if (SegTree[idx].start == SegTree[idx].end)
            return;

        ChangeValueSegTree(SegTree, idx * 2, Array_idx, Value);
        ChangeValueSegTree(SegTree, idx * 2+1, Array_idx, Value);
    }
}

int GetSumSegTree(vector<SegNode>& SegTree, int idx, int start, int end)
{
    if (start <= SegTree[idx].start && SegTree[idx].end <= end)     //현재 노드가 목표 구간 내에 포함 시, 노드의 구간 합 리턴.
        return SegTree[idx].value;
    else if (end < SegTree[idx].start || SegTree[idx].end < start)  //현재 노드가 목표 구간 밖에 있을 시, 0 리턴.
        return 0;
    else
    {
        //걸쳐져 있을 시, 자식들의 구간 합을 구하고 합하여 리턴
        int left = GetSumSegTree(SegTree, idx * 2, start, end);
        int right = GetSumSegTree(SegTree, idx * 2+1, start, end);
        return left + right;
    }
}
main.cpp

    vector<int> arrays = { 1,2,3,4,5 };

    vector<SegNode> segment(arrays.size() * 4, SegNode());

    MakeSegTree(segment, 1, 0, 4, arrays);

    int ques1 = GetSumSegTree(segment, 1, 2, 4);   //2~4 합
    int ques2 = GetSumSegTree(segment, 1, 1, 3);   //1~3 합 
    ChangeValueSegTree(segment, 1, 2, 5 - arrays[2]);    //3번째 수를 5로 교체
    int ques3 = GetSumSegTree(segment, 1, 0, 2);
    cout << ques1 << ", " << ques2 << ", " << ques3 << endl;