c++ 세그먼트 트리
[참조]
[ 세그먼트 트리(Segment Tree) ] 개념과 구현방법 (C++) :: 얍문's Coding World.. (tistory.com
(1)구간 합과 (2)요소 변경 후 구간 합을 요구하는 문제에서 사용하는 트리.
각 노드에는 "구간 (start, end)" 과 "합(sum)" 이 저장되어 있다.
Array = {1,2,3,4,5}
루트 노드에는 전체 구간과 그 합이 저장. (start , end)
왼쪽 노드에는 왼쪽 반 만큼의 구간과 그 합. (start, (start+end)/2)
오른쪽 노드에는 오른쪽 반 만큼의 구간과 그 합. ((start+end)/2+1 , end)
이렇게 내려가면서 리프 노드에는 단일 구간. 즉, 문제 배열의 요소의 인덱스와 그 값이 저장 된다.
( 그림은 구간이 배열 index 기준이 아니지만 실제로는 배열 index 기준으로 구간 값 할당. [시작 0] )
이진트리 구조이므로 배열로 만들어 인덱스 접근으로 구현 할 수 있으며, 총 트리 크기는 문제 요소가 N개일 때,
- 1<< Ceil(log2(N)) + 1 또는 N*4
- Ceil(log2(N) 은 요소가 N개 일때 트리의 높이, N*4 는 간단히 구현하고 싶을 때 러프하게 잡는 방법.
또한, 자식 인덱스 접근의 간편함을 위해 루트 노드는 인덱스가 1이며,
왼쪽 자식 인덱스 = cur*2
오른쪽 자식 인덱스 = cur*2 +1
로 접근.
만들어야 될 함수는 "세그먼트 트리 만들기", "구간 합 구하기", "요소 변경으로 세그먼트 트리 갱신" 3개 이다.
함수 모두 자식으로 점점 깊게 들어가는 "재귀 함수" 로 구현한다.
노드는 구조체로 정의해서 구현이 조금 더 직관적으로 되게 끔 해 보았다.
코드 구현
struct SegNode
{
int start, end;
int value;
explicit SegNode()
:start(0), end(0), value(0) {}
explicit SegNode(int _start, int _end, int _value)
:start(_start), end(_end), value(_value){}
};
// idx는 처음은 무조건 루트(1) 로 시작
int MakeSegTree(vector<SegNode>& SegTree, int idx, int start, int end, vector<int>& Array)
{
if (start == end)
{
//리프 노드 도달 시, 바로 값 적용
SegTree[idx] = SegNode(start, end, Array[start]);
}
else
{
int Sum = 0;
int Mid = (start + end) / 2;
int leftSum = MakeSegTree(SegTree, idx * 2, start, Mid, Array);
int rightSum = MakeSegTree(SegTree, idx * 2 + 1, Mid + 1, end, Array);
SegTree[idx] = SegNode(start, end, leftSum + rightSum);
}
return SegTree[idx].value;
}
void ChangeValueSegTree(vector<SegNode>& SegTree, int idx, int Array_idx, int Value)
{
//목표 요소를 포함하고 있는 노드들은 갱신
if (SegTree[idx].start <= Array_idx && Array_idx <= SegTree[idx].end)
{
SegTree[idx].value += Value;
//리프면 종료.
if (SegTree[idx].start == SegTree[idx].end)
return;
ChangeValueSegTree(SegTree, idx * 2, Array_idx, Value);
ChangeValueSegTree(SegTree, idx * 2+1, Array_idx, Value);
}
}
int GetSumSegTree(vector<SegNode>& SegTree, int idx, int start, int end)
{
if (start <= SegTree[idx].start && SegTree[idx].end <= end) //현재 노드가 목표 구간 내에 포함 시, 노드의 구간 합 리턴.
return SegTree[idx].value;
else if (end < SegTree[idx].start || SegTree[idx].end < start) //현재 노드가 목표 구간 밖에 있을 시, 0 리턴.
return 0;
else
{
//걸쳐져 있을 시, 자식들의 구간 합을 구하고 합하여 리턴
int left = GetSumSegTree(SegTree, idx * 2, start, end);
int right = GetSumSegTree(SegTree, idx * 2+1, start, end);
return left + right;
}
}
main.cpp
vector<int> arrays = { 1,2,3,4,5 };
vector<SegNode> segment(arrays.size() * 4, SegNode());
MakeSegTree(segment, 1, 0, 4, arrays);
int ques1 = GetSumSegTree(segment, 1, 2, 4); //2~4 합
int ques2 = GetSumSegTree(segment, 1, 1, 3); //1~3 합
ChangeValueSegTree(segment, 1, 2, 5 - arrays[2]); //3번째 수를 5로 교체
int ques3 = GetSumSegTree(segment, 1, 0, 2);
cout << ques1 << ", " << ques2 << ", " << ques3 << endl;