In the previous module, we discovered why segment trees exist—their ability to answer range queries and handle point updates in O(log n) time, something neither simple arrays nor prefix sums can achieve when both query and update operations are required. Now comes the moment of truth: how do we actually build one?
The construction of a segment tree is where theory meets implementation. It's the process of transforming a flat, linear array into a hierarchical tree structure where each node captures aggregate information about a range of the original array. This transformation is not merely mechanical—it embodies a profound divide-and-conquer strategy that forms the backbone of segment tree efficiency.
By the end of this page, you will understand the complete construction algorithm for segment trees: the recursive decomposition strategy, how internal node values are computed, the relationship between tree structure and array indices, and how an O(n) build step makes O(log n) queries possible.
Before diving into code, let's crystallize the core insight that makes segment tree construction possible:
Every range can be recursively decomposed into two halves.
This might seem trivial, but its implications are profound. Consider a range [0, 7] covering 8 elements:
[0, 7] = [0, 3] ∪ [4, 7]
[0, 3] = [0, 1] ∪ [2, 3]
[4, 7] = [4, 5] ∪ [6, 7]
[0, 1] = [0, 0] ∪ [1, 1]
... and so on
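The halving shown above can be reproduced in a few lines of Python. This is a minimal sketch; the helper name decompose is an illustrative assumption and is not part of the module's own code.

```python
def decompose(start, end, depth=0, out=None):
    """Collect (depth, start, end) for every range produced by repeated halving."""
    if out is None:
        out = []
    out.append((depth, start, end))
    if start < end:
        # Same midpoint rule the build algorithm uses later on this page
        mid = (start + end) // 2
        decompose(start, mid, depth + 1, out)
        decompose(mid + 1, end, depth + 1, out)
    return out


# Print the decomposition of [0, 7], indented by depth
for depth, start, end in decompose(0, 7):
    print("  " * depth + f"[{start}, {end}]")
```

For 8 elements this yields 15 ranges: one root, six intermediate ranges, and eight single-element ranges.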
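This binary decomposition creates a full binary tree (every node has either zero or two children) where:
- The root represents the entire range [0, n-1].
- Each internal node's range is the union of its two children's ranges.
- Each leaf represents a single element [i, i].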
Since we're dividing by 2 at each level, a segment tree for n elements has height ⌈log₂(n)⌉. This logarithmic height is the source of O(log n) query time—we never need to traverse more than this many levels.
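The height claim can be checked empirically. The sketch below assumes the midpoint split used throughout this page; tree_height is a hypothetical helper written only for this check.

```python
import math


def tree_height(n):
    """Height (edges on the longest root-to-leaf path) of the tree over n >= 1 elements."""
    def height(start, end):
        if start == end:
            return 0
        mid = (start + end) // 2
        return 1 + max(height(start, mid), height(mid + 1, end))

    return height(0, n - 1)


# Compare the measured height with ceil(log2(n)) for a few sizes
for n in (1, 2, 6, 8, 9, 1000):
    print(n, tree_height(n), math.ceil(math.log2(n)))
```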
What does each node store?
Each node in the segment tree stores the aggregate value for its range, determined by the specific operation:
| Query Type | Node Value | Combination Function |
|---|---|---|
| Range Sum | Sum of elements in range | node = left + right |
| Range Min | Minimum element in range | node = min(left, right) |
| Range Max | Maximum element in range | node = max(left, right) |
| Range GCD | GCD of elements in range | node = gcd(left, right) |
| Range XOR | XOR of elements in range | node = left ^ right |
The construction process is operation-agnostic in its structure—only the combination function changes. This makes segment trees remarkably versatile.
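To make the table concrete, here is a small sketch that expresses each combination function as a Python callable and folds an example list with it. The names COMBINERS and aggregate are assumptions made for illustration; the result is the value the root node would store for that operation.

```python
from functools import reduce
from math import gcd
from operator import add, xor

# One callable per row of the table above
COMBINERS = {
    "sum": add,
    "min": min,
    "max": max,
    "gcd": gcd,
    "xor": xor,
}


def aggregate(values, op_name):
    """Fold a non-empty list with the named combination function."""
    return reduce(COMBINERS[op_name], values)


values = [12, 6, 18, 9, 15, 3]
for name in COMBINERS:
    print(name, aggregate(values, name))
```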
Let's trace through the construction of a segment tree for Range Sum queries. Consider the input array:
array = [1, 3, 5, 7, 9, 11]
indices: 0 1 2 3 4 5
We'll build a tree where each node stores the sum of its range.
| Step | Range | Computation | Node Value |
|---|---|---|---|
| Leaf | [0,0] | array[0] | 1 |
| Leaf | [1,1] | array[1] | 3 |
| Leaf | [2,2] | array[2] | 5 |
| Leaf | [3,3] | array[3] | 7 |
| Leaf | [4,4] | array[4] | 9 |
| Leaf | [5,5] | array[5] | 11 |
| Internal | [0,1] | 1 + 3 | 4 |
| Leaf (reused) | [2,2] | 5 (single element, already a leaf) | 5 |
| Internal | [3,4] | 7 + 9 | 16 |
| Leaf (reused) | [5,5] | 11 (single element, already a leaf) | 11 |
| Internal | [0,2] | 4 + 5 | 9 |
| Internal | [3,5] | 16 + 11 | 27 |
| Root | [0,5] | 9 + 27 | 36 |
The resulting tree structure (shown with ranges and values):
```
                     [0,5]=36
                   /          \
           [0,2]=9              [3,5]=27
           /     \              /      \
     [0,1]=4    [2,2]=5   [3,4]=16    [5,5]=11
     /     \              /     \
[0,0]=1  [1,1]=3     [3,3]=7   [4,4]=9
```
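The values in the walkthrough can be cross-checked against plain slice sums. This is a sketch only; node_values is a hypothetical helper that records one sum per range rather than using the array layout introduced later.

```python
array = [1, 3, 5, 7, 9, 11]


def node_values(start, end, out=None):
    """Return {(start, end): sum} for every node of the tree over array[start..end]."""
    if out is None:
        out = {}
    if start == end:
        out[(start, end)] = array[start]
    else:
        mid = (start + end) // 2
        node_values(start, mid, out)
        node_values(mid + 1, end, out)
        # An internal node is the sum of its two children
        out[(start, end)] = out[(start, mid)] + out[(mid + 1, end)]
    return out


values = node_values(0, len(array) - 1)
for (start, end), total in sorted(values.items()):
    print(f"[{start},{end}] = {total}")
```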
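Notice the patterns:
- Every leaf holds exactly one array element.
- Every internal node holds the sum of its two children.
- The root holds the sum of the entire array (36).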
When the array size isn't a power of 2, the tree becomes slightly unbalanced. Some leaf nodes appear at the second-to-last level while others are at the last level. The algorithms handle this naturally through their recursive structure—no special cases needed.
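The depth of each leaf can be listed directly to see this imbalance. The sketch below assumes the same midpoint split as the rest of the page; leaf_depths is an illustrative helper name.

```python
def leaf_depths(start, end, depth=0):
    """Map each array index to the depth of its leaf in the tree over [start, end]."""
    if start == end:
        return {start: depth}
    mid = (start + end) // 2
    depths = leaf_depths(start, mid, depth + 1)
    depths.update(leaf_depths(mid + 1, end, depth + 1))
    return depths


# For 6 elements, indices 2 and 5 sit one level above the other leaves
print(leaf_depths(0, 5))
# For 8 elements (a power of 2), every leaf is at the same depth
print(leaf_depths(0, 7))
```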
The construction algorithm follows directly from the divide-and-conquer insight. We define a recursive function build(node, start, end) that:
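- If start == end (single element), the node is a leaf: store array[start].
- Otherwise, compute mid = (start + end) // 2 and recursively build the left child over [start, mid] and the right child over [mid+1, end].
- Store the combination of the two children's values in the current node.

The algorithm naturally mirrors the tree structure it creates.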
```python
class SegmentTree:
    """
    A segment tree for range sum queries.

    This implementation uses a recursive build approach.
    """

    def __init__(self, arr):
        """
        Initialize and build the segment tree from the input array.

        Args:
            arr: The input array (list of numbers)

        Time Complexity: O(n) where n = len(arr)
        Space Complexity: O(n) for the tree storage
        """
        self.n = len(arr)
        self.arr = arr

        # Allocate space for the segment tree
        # A segment tree for n elements needs at most 4n space
        # (This is explained in detail in the next page)
        self.tree = [0] * (4 * self.n)

        # Build the tree starting from root (node index 1)
        # covering the full range [0, n-1]
        if self.n > 0:
            self._build(1, 0, self.n - 1)

    def _build(self, node, start, end):
        """
        Recursively build the segment tree.

        Args:
            node: Current node index in the tree array
            start: Left boundary of the range (inclusive)
            end: Right boundary of the range (inclusive)

        This function fills tree[node] with the aggregate
        value for the range [start, end].
        """
        # Base case: leaf node (single element)
        if start == end:
            # A leaf node stores the array element itself
            self.tree[node] = self.arr[start]
            return

        # Recursive case: internal node
        # Find the midpoint to divide the range
        mid = (start + end) // 2

        # Left child covers [start, mid]
        # Right child covers [mid+1, end]
        left_child = 2 * node
        right_child = 2 * node + 1

        # Recursively build left and right subtrees
        self._build(left_child, start, mid)
        self._build(right_child, mid + 1, end)

        # Internal node value = combination of children
        # For range sum queries, this is addition
        self.tree[node] = self.tree[left_child] + self.tree[right_child]


# Example usage demonstrating construction
def demonstrate_construction():
    """
    Walk through the construction process step by step.
    """
    arr = [1, 3, 5, 7, 9, 11]
    print(f"Input array: {arr}")
    print(f"Array length: {len(arr)}")

    # Build the segment tree
    st = SegmentTree(arr)

    # Display the tree contents
    print("\nSegment tree array (non-zero values):")
    for i, val in enumerate(st.tree):
        if val != 0:
            print(f"  tree[{i}] = {val}")

    # The root (tree[1]) should contain the sum of all elements
    print(f"\nRoot value (tree[1]): {st.tree[1]}")
    print(f"Sum of array: {sum(arr)}")
    print(f"Match: {st.tree[1] == sum(arr)}")


if __name__ == "__main__":
    demonstrate_construction()
```

Understanding the recursion flow:
Let's trace _build(1, 0, 5) for our example array [1, 3, 5, 7, 9, 11]:
_build(1, 0, 5) → mid=2
├── _build(2, 0, 2) → mid=1
│ ├── _build(4, 0, 1) → mid=0
│ │ ├── _build(8, 0, 0) → leaf, tree[8]=1
│ │ └── _build(9, 1, 1) → leaf, tree[9]=3
│ │ → tree[4] = 1 + 3 = 4
│ └── _build(5, 2, 2) → leaf, tree[5]=5
│ → tree[2] = 4 + 5 = 9
└── _build(3, 3, 5) → mid=4
├── _build(6, 3, 4) → mid=3
│ ├── _build(12, 3, 3) → leaf, tree[12]=7
│ └── _build(13, 4, 4) → leaf, tree[13]=9
│ → tree[6] = 7 + 9 = 16
└── _build(7, 5, 5) → leaf, tree[7]=11
→ tree[3] = 16 + 11 = 27
→ tree[1] = 9 + 27 = 36
Each node is visited exactly once, and the work at each node (excluding recursive calls) is O(1). With n leaves and exactly n - 1 internal nodes (2n - 1 nodes in total), the total time is O(n).
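The node count behind this bound can be confirmed by counting recursive calls. The sketch below uses a hypothetical count_nodes helper that mirrors the shape of _build without storing any values.

```python
def count_nodes(start, end):
    """Number of build calls (one per node) for the range [start, end]."""
    if start == end:
        return 1
    mid = (start + end) // 2
    return 1 + count_nodes(start, mid) + count_nodes(mid + 1, end)


# The count should equal 2n - 1 for every n >= 1
for n in (1, 2, 6, 7, 8, 1000):
    print(n, count_nodes(0, n - 1), 2 * n - 1)
```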
The line self.tree[node] = self.tree[left_child] + self.tree[right_child] is the soul of the segment tree. This combination function determines what queries the tree can answer.
A segment tree can support any operation that is associative, meaning (a ⊕ b) ⊕ c = a ⊕ (b ⊕ c). This property is crucial because it lets us group adjacent sub-ranges in any way, as long as their left-to-right order is preserved, and still get the same result.
- combine(a, b) = a + b: the most common use case
- combine(a, b) = min(a, b): range minimum queries (RMQ)
- combine(a, b) = max(a, b): range maximum queries
- combine(a, b) = gcd(a, b): greatest common divisor queries
- combine(a, b) = lcm(a, b): least common multiple queries
- combine(a, b) = a | b: useful for flag combinations
- combine(a, b) = a & b: intersection of bit patterns
- combine(a, b) = a ^ b: parity and toggle queries
- combine(a, b) = a × b: product queries (watch for overflow)

```python
"""
Segment Tree Variants: Different Combination Functions

This module demonstrates how to build segment trees for
various query types by simply changing the combination function.
"""

from functools import reduce
from math import gcd


class SegmentTreeMin:
    """Segment tree for Range Minimum Queries (RMQ)."""

    def __init__(self, arr):
        self.n = len(arr)
        self.arr = arr
        self.tree = [float('inf')] * (4 * self.n)  # Identity for min
        if self.n > 0:
            self._build(1, 0, self.n - 1)

    def _build(self, node, start, end):
        if start == end:
            self.tree[node] = self.arr[start]
            return

        mid = (start + end) // 2
        left_child, right_child = 2 * node, 2 * node + 1

        self._build(left_child, start, mid)
        self._build(right_child, mid + 1, end)

        # Combination function: minimum
        self.tree[node] = min(self.tree[left_child], self.tree[right_child])


class SegmentTreeMax:
    """Segment tree for Range Maximum Queries."""

    def __init__(self, arr):
        self.n = len(arr)
        self.arr = arr
        self.tree = [float('-inf')] * (4 * self.n)  # Identity for max
        if self.n > 0:
            self._build(1, 0, self.n - 1)

    def _build(self, node, start, end):
        if start == end:
            self.tree[node] = self.arr[start]
            return

        mid = (start + end) // 2
        left_child, right_child = 2 * node, 2 * node + 1

        self._build(left_child, start, mid)
        self._build(right_child, mid + 1, end)

        # Combination function: maximum
        self.tree[node] = max(self.tree[left_child], self.tree[right_child])


class SegmentTreeGCD:
    """Segment tree for Range GCD Queries."""

    def __init__(self, arr):
        self.n = len(arr)
        self.arr = arr
        self.tree = [0] * (4 * self.n)  # Identity for gcd (gcd(x, 0) = x)
        if self.n > 0:
            self._build(1, 0, self.n - 1)

    def _build(self, node, start, end):
        if start == end:
            self.tree[node] = self.arr[start]
            return

        mid = (start + end) // 2
        left_child, right_child = 2 * node, 2 * node + 1

        self._build(left_child, start, mid)
        self._build(right_child, mid + 1, end)

        # Combination function: GCD
        self.tree[node] = gcd(self.tree[left_child], self.tree[right_child])


class SegmentTreeXOR:
    """Segment tree for Range XOR Queries."""

    def __init__(self, arr):
        self.n = len(arr)
        self.arr = arr
        self.tree = [0] * (4 * self.n)  # Identity for XOR
        if self.n > 0:
            self._build(1, 0, self.n - 1)

    def _build(self, node, start, end):
        if start == end:
            self.tree[node] = self.arr[start]
            return

        mid = (start + end) // 2
        left_child, right_child = 2 * node, 2 * node + 1

        self._build(left_child, start, mid)
        self._build(right_child, mid + 1, end)

        # Combination function: XOR
        self.tree[node] = self.tree[left_child] ^ self.tree[right_child]


# Demonstration
def demonstrate_variants():
    arr = [12, 6, 18, 9, 15, 3]
    print(f"Input array: {arr}")

    st_min = SegmentTreeMin(arr)
    st_max = SegmentTreeMax(arr)
    st_gcd = SegmentTreeGCD(arr)
    st_xor = SegmentTreeXOR(arr)

    print("\nRange [0, 5] results:")
    print(f"  Min: {st_min.tree[1]} (expected: {min(arr)})")
    print(f"  Max: {st_max.tree[1]} (expected: {max(arr)})")
    print(f"  GCD: {st_gcd.tree[1]} (expected: {reduce(gcd, arr)})")
    print(f"  XOR: {st_xor.tree[1]} (expected: {reduce(lambda x, y: x ^ y, arr)})")


if __name__ == "__main__":
    demonstrate_variants()
```

Operations like subtraction and division are NOT associative. (a - b) - c ≠ a - (b - c). Segment trees cannot directly support range subtraction or division queries. If you need these, you must reformulate the problem (e.g., subtraction can be done via sum queries of carefully designed ranges).
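A quick brute-force check illustrates which operations qualify. This is a sketch under a stated limitation: testing a handful of sample values can refute associativity (as it does for subtraction) but cannot prove it in general. The helper name is_associative is an assumption for illustration.

```python
from itertools import product
from math import gcd
from operator import add, sub


def is_associative(op, samples):
    """Check (a op b) op c == a op (b op c) for every triple drawn from samples."""
    return all(
        op(op(a, b), c) == op(a, op(b, c))
        for a, b, c in product(samples, repeat=3)
    )


samples = [0, 1, 2, 3, 5, 8, 12]
print("add:", is_associative(add, samples))
print("min:", is_associative(min, samples))
print("gcd:", is_associative(gcd, samples))
print("sub:", is_associative(sub, samples))
```

Subtraction fails on the very first non-trivial triple, for example (5 - 3) - 2 = 0 while 5 - (3 - 2) = 4.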
You'll encounter two common indexing conventions for segment trees:
1-Based Indexing (Root at index 1)
- Left child of i is at 2*i
- Right child of i is at 2*i + 1
- Parent of i is at i // 2

0-Based Indexing (Root at index 0)
- Left child of i is at 2*i + 1
- Right child of i is at 2*i + 2
- Parent of i is at (i - 1) // 2

```python
"""
Comparison of 1-based and 0-based indexing in segment trees.
Both implementations produce functionally identical trees.
"""


class SegmentTree1Based:
    """Segment tree using 1-based indexing (root at index 1)."""

    def __init__(self, arr):
        self.n = len(arr)
        self.tree = [0] * (4 * self.n)
        if self.n > 0:
            self._build(arr, 1, 0, self.n - 1)

    def _build(self, arr, node, start, end):
        if start == end:
            self.tree[node] = arr[start]
            return

        mid = (start + end) // 2
        # 1-based: left = 2*node, right = 2*node + 1
        self._build(arr, 2 * node, start, mid)
        self._build(arr, 2 * node + 1, mid + 1, end)
        self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]


class SegmentTree0Based:
    """Segment tree using 0-based indexing (root at index 0)."""

    def __init__(self, arr):
        self.n = len(arr)
        self.tree = [0] * (4 * self.n)
        if self.n > 0:
            self._build(arr, 0, 0, self.n - 1)

    def _build(self, arr, node, start, end):
        if start == end:
            self.tree[node] = arr[start]
            return

        mid = (start + end) // 2
        # 0-based: left = 2*node + 1, right = 2*node + 2
        self._build(arr, 2 * node + 1, start, mid)
        self._build(arr, 2 * node + 2, mid + 1, end)
        self.tree[node] = self.tree[2 * node + 1] + self.tree[2 * node + 2]


# Verify both produce the same results
arr = [1, 3, 5, 7, 9, 11]
st1 = SegmentTree1Based(arr)
st0 = SegmentTree0Based(arr)

print("1-Based (root at tree[1]):", st1.tree[1])
print("0-Based (root at tree[0]):", st0.tree[0])
print("Both equal sum:", sum(arr))
```

Throughout this module, we use 1-based indexing for the tree array. This matches most competitive programming resources and produces cleaner code. The root is at tree[1], its children at tree[2] and tree[3], and so on.
Robust segment tree construction must handle several edge cases gracefully:
1. Empty Array
An empty array produces an empty tree. We should check for n == 0 before building.
2. Single Element
With n == 1, the tree has only one node (the root), which is also a leaf. The range is [0, 0].
3. Array Size Not a Power of 2
This is the common case. The tree is slightly unbalanced, but the algorithms handle it naturally.
4. Large Values and Overflow
For sum queries, internal nodes can hold values much larger than any individual element. In languages with fixed-width integers, use a wider type such as long long in C++; Python's arbitrary-precision integers do not overflow.
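The following sketch illustrates the risk by modeling a signed 32-bit node value in Python. The helper as_int32 is a hypothetical stand-in that applies two's-complement wraparound; real C or C++ signed overflow is undefined behavior, so treat this purely as an illustration of why a wider type is needed.

```python
def as_int32(value):
    """Reinterpret an integer as a signed 32-bit value (two's-complement wraparound)."""
    value &= 0xFFFFFFFF
    return value - 2**32 if value >= 2**31 else value


leaves = [2**30, 2**30, 2**30]    # each element fits comfortably in 32 bits
root_exact = sum(leaves)          # Python int: exact result
root_int32 = as_int32(root_exact) # what a wrapped 32-bit node would hold

print("exact root:  ", root_exact)
print("32-bit root: ", root_int32)
```

Every leaf is valid on its own, yet the root no longer fits, which is why the overflow check has to consider the largest possible aggregate rather than the largest element.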
```python
"""
A robust segment tree implementation handling all edge cases.
"""


class RobustSegmentTree:
    """
    Production-quality segment tree with comprehensive edge case handling.
    """

    def __init__(self, arr, combine=lambda a, b: a + b, identity=0):
        """
        Initialize segment tree with custom operation.

        Args:
            arr: Input array (can be empty)
            combine: Binary function to combine two values
            identity: Identity element for the operation
                      (0 for sum, inf for min, -inf for max)
        """
        self.combine = combine
        self.identity = identity

        # Handle empty array
        if not arr:
            self.n = 0
            self.tree = []
            return

        self.n = len(arr)
        self.arr = arr[:]  # Create a copy

        # Allocate tree with safety margin
        self.tree = [identity] * (4 * self.n)

        # Build the tree
        self._build(1, 0, self.n - 1)

    def _build(self, node, start, end):
        """Build segment tree recursively."""
        # Base case: single element (leaf)
        if start == end:
            self.tree[node] = self.arr[start]
            return

        mid = (start + end) // 2
        left_child = 2 * node
        right_child = 2 * node + 1

        # Build children
        self._build(left_child, start, mid)
        self._build(right_child, mid + 1, end)

        # Combine children values
        self.tree[node] = self.combine(
            self.tree[left_child],
            self.tree[right_child]
        )

    def is_empty(self):
        """Check if the tree is empty."""
        return self.n == 0

    def size(self):
        """Return the number of elements in the original array."""
        return self.n

    def root_value(self):
        """Return the aggregate of the entire array."""
        if self.is_empty():
            return self.identity
        return self.tree[1]


# Edge case demonstrations
def test_edge_cases():
    print("=== Edge Case Testing ===\n")

    # Empty array
    empty_tree = RobustSegmentTree([])
    print("Empty array:")
    print(f"  Size: {empty_tree.size()}")
    print(f"  Is empty: {empty_tree.is_empty()}")
    print(f"  Root value: {empty_tree.root_value()}")

    # Single element
    single = RobustSegmentTree([42])
    print("\nSingle element [42]:")
    print(f"  Size: {single.size()}")
    print(f"  Root value: {single.root_value()}")

    # Non-power-of-2 size
    odd_size = RobustSegmentTree([1, 2, 3, 4, 5])
    print("\nOdd size array [1,2,3,4,5]:")
    print(f"  Size: {odd_size.size()}")
    print(f"  Root value: {odd_size.root_value()} (sum = 15)")

    # Large values (Python handles this automatically)
    large_values = RobustSegmentTree([10**18, 10**18, 10**18])
    print("\nLarge values [10^18, 10^18, 10^18]:")
    print(f"  Root value: {large_values.root_value()}")
    print(f"  Expected: {3 * 10**18}")

    # Custom operation: minimum
    min_tree = RobustSegmentTree(
        [5, 2, 8, 1, 9],
        combine=min,
        identity=float('inf')
    )
    print("\nMin tree [5,2,8,1,9]:")
    print(f"  Root value (min): {min_tree.root_value()}")


if __name__ == "__main__":
    test_edge_cases()
```

Understanding the construction complexity is essential for making informed decisions about when to use segment trees.
Time Complexity: O(n)
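The build function visits each node exactly once. For a segment tree with n leaves:
- There are n leaf nodes, one per array element.
- There are n - 1 internal nodes, because every internal node has exactly two children.
- The total is 2n - 1 nodes.

At each node, we perform O(1) work (one combine operation), so the total time is O(n).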
Space Complexity: O(n)
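We allocate a tree array of size 4n. Why 4n? The detailed analysis is in the next page, but briefly:
- The tree has height h = ⌈log₂(n)⌉, and with children at 2*i and 2*i + 1 the largest index that can be used is below 2^(h+1).
- Since 2^h < 2n, we get 2^(h+1) < 4n, so an array of size 4n always has room for every node.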
| Aspect | Complexity | Notes |
|---|---|---|
| Time (Build) | O(n) | Each node visited once |
| Space (Tree) | O(n) | 4n allocation is common |
| Auxiliary Space | O(log n) | Recursion stack depth |
| Combine Operation | O(1) | Assumed constant-time |
The factor of 4 might seem wasteful, but it's a practical choice that avoids complex calculations. In the next page, we'll explore the array representation in depth and see exactly why 4n suffices and how to calculate tighter bounds if memory is critical.
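As a sanity check on the 4n allocation, the sketch below finds the largest tree index the 1-based recursive build would touch for a given n. The helper max_index is hypothetical and only follows the index arithmetic; it does not store any values.

```python
def max_index(node, start, end):
    """Largest tree index used when building [start, end] rooted at `node` (1-based)."""
    if start == end:
        return node
    mid = (start + end) // 2
    return max(
        max_index(2 * node, start, mid),
        max_index(2 * node + 1, mid + 1, end),
    )


# Compare the largest used index with the 4n allocation
for n in (1, 2, 5, 6, 8, 9, 1000):
    print(n, max_index(1, 0, n - 1), 4 * n)
```

For the six-element example the largest index is 13, matching the recursion trace earlier on this page, and it stays below 4n for every size tried here.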
Let's consolidate everything into a complete, well-documented implementation:
```python
"""
Complete Segment Tree Implementation
====================================

A production-ready segment tree supporting:
- Range Sum Queries (default)
- Custom combine operations
- Construction from an input array (this page covers construction only)

Author: DSA Learning Module
"""

from math import gcd
from typing import List, Callable, TypeVar

T = TypeVar('T')


class SegmentTree:
    """
    A flexible segment tree implementation.

    The segment tree is a binary tree where each node stores
    aggregate information about a range of the original array.

    Attributes:
        n: Size of the original array
        tree: Array storing the segment tree nodes
        combine: Binary function to combine two values
        identity: Identity element for the combine operation
    """

    def __init__(
        self,
        arr: List[T],
        combine: Callable[[T, T], T] = lambda a, b: a + b,
        identity: T = 0
    ):
        """
        Build a segment tree from the input array.

        Args:
            arr: Input array
            combine: Binary associative operation
            identity: Identity element (e.g., 0 for +, inf for min)

        Time Complexity: O(n)
        Space Complexity: O(n)
        """
        self.combine = combine
        self.identity = identity
        self.n = len(arr)

        if self.n == 0:
            self.tree = []
            return

        # Allocate tree storage (4n is a safe upper bound)
        self.tree = [identity] * (4 * self.n)

        # Build the tree (root at index 1)
        self._build(arr, 1, 0, self.n - 1)

    def _build(self, arr: List[T], node: int, start: int, end: int) -> None:
        """
        Recursively construct the segment tree.

        Args:
            arr: Original input array
            node: Current node index in tree
            start: Left boundary of current range
            end: Right boundary of current range
        """
        # Base case: leaf node
        if start == end:
            self.tree[node] = arr[start]
            return

        # Recursive case: build children first
        mid = (start + end) // 2
        left_child = 2 * node
        right_child = 2 * node + 1

        self._build(arr, left_child, start, mid)
        self._build(arr, right_child, mid + 1, end)

        # Current node = combination of children
        self.tree[node] = self.combine(
            self.tree[left_child],
            self.tree[right_child]
        )

    def __repr__(self) -> str:
        if self.n == 0:
            return "SegmentTree(empty)"
        return f"SegmentTree(n={self.n}, root={self.tree[1]})"


# ============================================================
# Test Suite for Segment Tree Construction
# ============================================================

def test_construction():
    """Comprehensive tests for segment tree construction."""
    print("=" * 60)
    print("SEGMENT TREE CONSTRUCTION TESTS")
    print("=" * 60)

    # Test 1: Basic sum tree
    print("\n[Test 1] Basic Sum Tree")
    arr = [1, 3, 5, 7, 9, 11]
    st = SegmentTree(arr)
    expected_sum = sum(arr)
    assert st.tree[1] == expected_sum, f"Root should be {expected_sum}"
    print(f"  Array: {arr}")
    print(f"  Root (total sum): {st.tree[1]} ✓")

    # Test 2: Power of 2 size
    print("\n[Test 2] Power of 2 Size")
    arr_pow2 = [1, 2, 3, 4, 5, 6, 7, 8]
    st_pow2 = SegmentTree(arr_pow2)
    assert st_pow2.tree[1] == 36
    print(f"  Array (n=8): {arr_pow2}")
    print(f"  Root: {st_pow2.tree[1]} ✓")

    # Test 3: Non-power of 2 size
    print("\n[Test 3] Non-Power of 2 Size")
    arr_npow2 = [1, 2, 3, 4, 5, 6, 7]
    st_npow2 = SegmentTree(arr_npow2)
    assert st_npow2.tree[1] == 28
    print(f"  Array (n=7): {arr_npow2}")
    print(f"  Root: {st_npow2.tree[1]} ✓")

    # Test 4: Single element
    print("\n[Test 4] Single Element")
    st_single = SegmentTree([42])
    assert st_single.tree[1] == 42
    print("  Array: [42]")
    print(f"  Root: {st_single.tree[1]} ✓")

    # Test 5: Empty array
    print("\n[Test 5] Empty Array")
    st_empty = SegmentTree([])
    assert st_empty.n == 0
    print("  Array: []")
    print(f"  Tree is empty: {len(st_empty.tree) == 0} ✓")

    # Test 6: Min tree
    print("\n[Test 6] Minimum Query Tree")
    arr_min = [5, 2, 8, 1, 9, 3]
    st_min = SegmentTree(arr_min, combine=min, identity=float('inf'))
    assert st_min.tree[1] == 1
    print(f"  Array: {arr_min}")
    print(f"  Root (min): {st_min.tree[1]} ✓")

    # Test 7: Max tree
    print("\n[Test 7] Maximum Query Tree")
    arr_max = [5, 2, 8, 1, 9, 3]
    st_max = SegmentTree(arr_max, combine=max, identity=float('-inf'))
    assert st_max.tree[1] == 9
    print(f"  Array: {arr_max}")
    print(f"  Root (max): {st_max.tree[1]} ✓")

    # Test 8: GCD tree
    print("\n[Test 8] GCD Query Tree")
    arr_gcd = [12, 18, 24, 6]
    st_gcd = SegmentTree(arr_gcd, combine=gcd, identity=0)
    assert st_gcd.tree[1] == 6
    print(f"  Array: {arr_gcd}")
    print(f"  Root (GCD): {st_gcd.tree[1]} ✓")

    # Test 9: Large values
    print("\n[Test 9] Large Values")
    arr_large = [10**15, 10**15, 10**15]
    st_large = SegmentTree(arr_large)
    assert st_large.tree[1] == 3 * 10**15
    print("  Array: [10^15, 10^15, 10^15]")
    print(f"  Root: {st_large.tree[1]} ✓")

    # Test 10: Negative values
    print("\n[Test 10] Negative Values")
    arr_neg = [-5, 3, -2, 8, -1]
    st_neg = SegmentTree(arr_neg)
    assert st_neg.tree[1] == 3
    print(f"  Array: {arr_neg}")
    print(f"  Root: {st_neg.tree[1]} ✓")

    print("\n" + "=" * 60)
    print("ALL CONSTRUCTION TESTS PASSED!")
    print("=" * 60)


if __name__ == "__main__":
    test_construction()
```

We've explored the construction of segment trees from first principles. Let's consolidate the key insights:
- Construction is divide and conquer: split the range at its midpoint, build both halves, then combine.
- Leaves store individual array elements; internal nodes store the combination of their children.
- Only the combination function changes between sum, min, max, GCD, and XOR trees, and it must be associative.
- Building takes O(n) time and O(n) space (a 4n allocation is a safe bound), with O(log n) recursion depth.
What's Next:
Now that we can build a segment tree, we need to understand the array representation in detail. The next page explores how the tree maps to a flat array, why we need 4n space, and how to navigate between parent and child nodes efficiently.
You now understand how to build a segment tree from an array. The recursive algorithm transforms an O(n) preprocessing step into the foundation for O(log n) queries and updates. Next, we'll explore the array-based representation that makes this efficiency possible.