We've built a segment tree. We've understood how it's stored in memory. But the construction was just preparation—the real magic happens in queries.
A query asks: "What is the aggregate (sum, min, max, etc.) of elements in the range [L, R]?" The naive approach would scan every element from L to R, taking O(n) time. With a segment tree, we answer in O(log n) time by cleverly combining precomputed ranges.
How? We decompose the query range [L, R] into a minimal set of precomputed ranges stored in tree nodes, then combine their values. The segment tree structure ensures this decomposition uses at most O(log n) nodes.
By the end of this page, you will understand the query algorithm in complete detail, how query ranges are decomposed into tree nodes, why this decomposition uses at most O(log n) nodes, and how to implement queries for various operations (sum, min, max, etc.).
Let's formalize what we're trying to solve:
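Given:
- An array A of n elements, indexed from 0 to n-1
- A segment tree built over A for an associative operation (sum, min, max, etc.)
- A query range [L, R] with 0 ≤ L ≤ R ≤ n-1

Find:
- The aggregate of A[L], A[L+1], ..., A[R] under that operation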
For a sum query, we want A[L] + A[L+1] + ... + A[R]. For a min query, we want min(A[L], A[L+1], ..., A[R]). And so on.
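For reference, the naive scan can be sketched as below. This is a baseline for checking segment tree answers, not part of the tree itself; the name `naive_query` is a hypothetical helper introduced here for illustration.

```python
from functools import reduce


def naive_query(arr, L, R, combine):
    """Aggregate arr[L..R] (inclusive) by scanning every element: O(R - L + 1)."""
    return reduce(combine, arr[L:R + 1])


A = [1, 3, 5, 7, 9, 11]
print(naive_query(A, 1, 4, lambda a, b: a + b))  # 3 + 5 + 7 + 9 = 24
print(naive_query(A, 1, 4, min))                 # min(3, 5, 7, 9) = 3
```

A brute-force function like this is handy as an oracle when testing the O(log n) implementation.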
The key insight: Each node in our segment tree already stores the aggregate for some range. If we can identify which nodes together cover exactly [L, R] without overlap, we can combine their values to get our answer.
Example Setup:
Segment tree for array [1, 3, 5, 7, 9, 11]:
                     [0,5]=36
                    /        \
            [0,2]=9            [3,5]=27
            /     \            /      \
      [0,1]=4   [2,2]=5   [3,4]=16   [5,5]=11
       /    \              /     \
    [0,0]=1 [1,1]=3   [3,3]=7  [4,4]=9
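Queries we might ask:
- Sum of [0, 5] = 36 (the root alone)
- Sum of [1, 4] = 3 + 5 + 7 + 9 = 24
- Sum of [2, 3] = 5 + 7 = 12
- Sum of [0, 0] = 1 (a single leaf)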
The query algorithm is recursive and follows a simple decision tree at each node:
At each node covering range [start, end], one of three cases applies:
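1. No Overlap: The query range [L, R] doesn't intersect [start, end] at all (R < start or L > end). → Return the identity element (0 for sum, +∞ for min, -∞ for max).
2. Complete Overlap: The node's range [start, end] lies entirely within [L, R] (L ≤ start and end ≤ R). → Return this node's value directly; no descent is needed.
3. Partial Overlap: The ranges intersect, but [start, end] is not entirely within [L, R]. This includes the case where the node's range strictly contains the query range, as happens at the root for most queries. → Query both children and combine their results.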
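The three-way decision can be isolated in a few lines. The sketch below uses a hypothetical helper, `classify_overlap`, that is not part of the implementations on this page; it only names the case a node falls into.

```python
def classify_overlap(start, end, L, R):
    """Classify node range [start, end] against query range [L, R]."""
    if R < start or L > end:
        return "none"      # ranges are disjoint
    if L <= start and end <= R:
        return "complete"  # node lies entirely inside the query
    return "partial"       # ranges intersect, node sticks out of the query


# Nodes of the example tree, classified against the query [1, 4]
for node_range in [(0, 5), (0, 0), (1, 1), (3, 4), (5, 5)]:
    print(node_range, classify_overlap(*node_range, 1, 4))
```

Note that the root (0, 5) is classified as partial even though it contains the whole query: what matters is whether the node fits inside the query, not the other way around.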
```python
"""
Segment Tree Query Implementation
=================================

The query function decomposes [L, R] into precomputed ranges
and combines their values in O(log n) time.
"""


class SegmentTree:
    """
    Segment tree with query support.
    """

    def __init__(self, arr):
        self.n = len(arr)
        self.tree = [0] * (4 * self.n) if self.n > 0 else []
        self.arr = arr
        if self.n > 0:
            self._build(1, 0, self.n - 1)

    def _build(self, node, start, end):
        if start == end:
            self.tree[node] = self.arr[start]
            return
        mid = (start + end) // 2
        self._build(2 * node, start, mid)
        self._build(2 * node + 1, mid + 1, end)
        self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]

    def query(self, L: int, R: int) -> int:
        """
        Query the sum of elements in range [L, R].

        Args:
            L: Left boundary (inclusive), 0-indexed
            R: Right boundary (inclusive), 0-indexed

        Returns:
            Sum of elements A[L] + A[L+1] + ... + A[R]

        Time Complexity: O(log n)
        """
        if self.n == 0 or L > R or L < 0 or R >= self.n:
            return 0  # Invalid query
        return self._query(1, 0, self.n - 1, L, R)

    def _query(self, node: int, start: int, end: int, L: int, R: int) -> int:
        """
        Recursive query helper.

        Args:
            node: Current node index in tree
            start, end: Range this node covers
            L, R: Query range

        Returns:
            Sum of elements in [L, R] that fall within [start, end]
        """
        # Case 1: No overlap
        # The query range [L, R] is completely outside [start, end]
        if R < start or L > end:
            return 0  # Identity element for sum

        # Case 2: Complete overlap
        # The node's range [start, end] is entirely within query range [L, R]
        if L <= start and end <= R:
            return self.tree[node]  # Use precomputed value

        # Case 3: Partial overlap
        # Need to check both children
        mid = (start + end) // 2
        left_sum = self._query(2 * node, start, mid, L, R)
        right_sum = self._query(2 * node + 1, mid + 1, end, L, R)
        return left_sum + right_sum


# Demonstration with step-by-step tracing
def trace_query():
    """Trace a query step by step."""
    arr = [1, 3, 5, 7, 9, 11]
    st = SegmentTree(arr)

    print("Array:", arr)
    print("\n" + "=" * 60)
    print("QUERY TRACING: Sum of [1, 4]")
    print("=" * 60)

    # Manual trace
    print("""
Query: sum([1, 4]) = 3 + 5 + 7 + 9 = 24

Starting at root: node=1, [0,5], query=[1,4]
├── Partial overlap (0 < 1, 5 > 4)
│
├── Left child: node=2, [0,2], query=[1,4]
│   ├── Partial overlap (0 < 1)
│   │
│   ├── Left child: node=4, [0,1], query=[1,4]
│   │   ├── Partial overlap (0 < 1)
│   │   │
│   │   ├── Left child: node=8, [0,0], query=[1,4]
│   │   │   └── NO OVERLAP (0 < 1) → return 0
│   │   │
│   │   └── Right child: node=9, [1,1], query=[1,4]
│   │       └── COMPLETE OVERLAP (1 ≤ 1 ≤ 1 ≤ 4) → return 3
│   │
│   │   → return 0 + 3 = 3
│   │
│   └── Right child: node=5, [2,2], query=[1,4]
│       └── COMPLETE OVERLAP (1 ≤ 2 ≤ 2 ≤ 4) → return 5
│
│   → return 3 + 5 = 8
│
└── Right child: node=3, [3,5], query=[1,4]
    ├── Partial overlap (5 > 4)
    │
    ├── Left child: node=6, [3,4], query=[1,4]
    │   └── COMPLETE OVERLAP (1 ≤ 3 ≤ 4 ≤ 4) → return 16
    │
    └── Right child: node=7, [5,5], query=[1,4]
        └── NO OVERLAP (5 > 4) → return 0

    → return 16 + 0 = 16

Final: 8 + 16 = 24 ✓
""")

    result = st.query(1, 4)
    print(f"Computed result: {result}")
    print(f"Expected: {sum(arr[1:5])}")
    print(f"Match: {result == sum(arr[1:5])}")


if __name__ == "__main__":
    trace_query()
```

The magic of the query algorithm lies in how it decomposes any query range [L, R] into tree nodes. Let's understand this deeply.
What makes a "useful" node?
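A node with range [start, end] is useful for query [L, R] if:
- It is fully contained in the query range (L ≤ start and end ≤ R), and
- Its parent is not fully contained in the query range, so the node is as large as possible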
These maximal, fully-contained nodes form the canonical decomposition of [L, R].
Example: Query [1, 4] on array [1, 3, 5, 7, 9, 11]
                  [0,5]
                 /     \
            [0,2]       [3,5]
            /   \       /    \
        [0,1]  [2,2]* [3,4]*  [5,5]
        /   \
    [0,0]  [1,1]*
Nodes marked with * are the canonical decomposition:
- [1,1] = 3
- [2,2] = 5
- [3,4] = 16
Total: 3 + 5 + 16 = 24 ✓
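Notice:
- The three ranges are disjoint and together cover exactly [1, 4].
- Each is the largest node that fits inside the query: [3,4] is used as a whole rather than as [3,3] and [4,4].
- The answer combines 3 stored values instead of scanning 4 elements.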
The segment tree decomposes [L, R] into O(log n) nodes because each level of the tree contributes at most 2 nodes to the canonical decomposition. This is because L and R each "block" at most one node per level from being fully contained.
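To see the decomposition directly, the query recursion can be run on ranges alone, recording each fully contained node instead of combining values. This is a minimal sketch; `canonical_decomposition` is a hypothetical helper name, and it assumes the same midpoint split, `(start + end) // 2`, used throughout this page.

```python
def canonical_decomposition(n, L, R):
    """Return the node ranges a query on [L, R] combines, left to right."""
    pieces = []

    def visit(start, end):
        if R < start or L > end:        # no overlap: contributes nothing
            return
        if L <= start and end <= R:     # complete overlap: take the whole node
            pieces.append((start, end))
            return
        mid = (start + end) // 2        # partial overlap: split and descend
        visit(start, mid)
        visit(mid + 1, end)

    if n > 0 and L <= R:
        visit(0, n - 1)
    return pieces


print(canonical_decomposition(6, 1, 4))  # [(1, 1), (2, 2), (3, 4)]
print(canonical_decomposition(6, 0, 5))  # [(0, 5)]
```

The first call reproduces the starred nodes in the diagram above; the second shows that a full-range query uses the root alone.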
Why is the query O(log n)? Let's prove this rigorously.
Claim: A query visits at most O(log n) nodes.
Proof Strategy: At each level of the tree, we visit at most 4 nodes. Since the tree has O(log n) levels, total nodes visited is O(log n).
Detailed Proof:
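Consider what happens at any level of the tree. The query range [L, R] intersects some nodes at this level. These nodes can be categorized as:
- No overlap: the node's range lies entirely outside [L, R]
- Complete overlap: the node's range lies entirely inside [L, R]
- Partial overlap: the node's range intersects [L, R] but is not entirely inside it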
Key Observation: At any level, there are at most 2 nodes with partial overlap.
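Why? The nodes that partially overlap [L, R] are exactly:
- The node whose range contains L, if that range starts before L
- The node whose range contains R, if that range ends after R

All other nodes at that level are either:
- Completely inside [L, R] (complete overlap), or
- Completely outside [L, R] (no overlap)

Counting nodes visited:
- Only partial-overlap nodes make recursive calls, and each makes exactly 2.
- With at most 2 partial-overlap nodes per level, at most 4 nodes are visited on the next level.
- The tree has ⌈log₂ n⌉ + 1 levels, so at most about 4(⌈log₂ n⌉ + 1) nodes are visited in total.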
Total nodes visited: O(log n). Total time: O(log n).
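The per-level bound can also be checked empirically. The sketch below replays the query recursion on ranges only and counts visits at each depth; `visits_per_level` is a hypothetical helper, and the random test is a sanity check under the assumption of the standard midpoint split, not a substitute for the proof.

```python
import random
from collections import Counter


def visits_per_level(n, L, R):
    """Count how many nodes the recursive query touches at each depth."""
    counts = Counter()

    def visit(start, end, depth):
        counts[depth] += 1
        if R < start or L > end:        # no overlap: stop here
            return
        if L <= start and end <= R:     # complete overlap: stop here
            return
        mid = (start + end) // 2        # partial overlap: visit both children
        visit(start, mid, depth + 1)
        visit(mid + 1, end, depth + 1)

    visit(0, n - 1, 0)
    return counts


print(dict(visits_per_level(6, 1, 4)))  # {0: 1, 1: 2, 2: 4, 3: 2}

# Random sanity check: the proof predicts no level ever exceeds 4 visits.
random.seed(0)
worst = 0
for _ in range(2000):
    n = random.randint(1, 500)
    L = random.randint(0, n - 1)
    R = random.randint(L, n - 1)
    worst = max(worst, max(visits_per_level(n, L, R).values()))
print("Worst per-level visit count:", worst)
```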
```python
"""
Demonstration of O(log n) query complexity.
Count the nodes visited during queries.
"""

import math
import random


class SegmentTreeWithStats:
    """Segment tree that tracks query statistics."""

    def __init__(self, arr):
        self.n = len(arr)
        self.tree = [0] * (4 * self.n) if self.n > 0 else []
        if self.n > 0:
            self._build(arr, 1, 0, self.n - 1)
        self.nodes_visited = 0
        self.visit_log = []

    def _build(self, arr, node, start, end):
        if start == end:
            self.tree[node] = arr[start]
            return
        mid = (start + end) // 2
        self._build(arr, 2 * node, start, mid)
        self._build(arr, 2 * node + 1, mid + 1, end)
        self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]

    def query(self, L: int, R: int) -> int:
        """Query with statistics tracking."""
        self.nodes_visited = 0
        self.visit_log = []
        return self._query(1, 0, self.n - 1, L, R)

    def _query(self, node, start, end, L, R) -> int:
        self.nodes_visited += 1

        # No overlap
        if R < start or L > end:
            self.visit_log.append(
                f"Node {node} [{start},{end}]: No overlap → 0"
            )
            return 0

        # Complete overlap
        if L <= start and end <= R:
            self.visit_log.append(
                f"Node {node} [{start},{end}]: Complete overlap → {self.tree[node]}"
            )
            return self.tree[node]

        # Partial overlap
        self.visit_log.append(
            f"Node {node} [{start},{end}]: Partial overlap → descend"
        )
        mid = (start + end) // 2
        left = self._query(2 * node, start, mid, L, R)
        right = self._query(2 * node + 1, mid + 1, end, L, R)
        return left + right

    def print_stats(self, L, R, result):
        """Print query statistics."""
        # Upper bound from the proof: at most 4 nodes on each of the
        # ceil(log2 n) + 1 levels
        expected_max = 4 * (math.ceil(math.log2(self.n)) + 1) if self.n > 1 else 1

        print(f"\nQuery [{L}, {R}] = {result}")
        print(f"  Nodes visited: {self.nodes_visited}")
        print(f"  Array size n: {self.n}")
        print(f"  log₂(n): {math.log2(self.n):.2f}")
        print(f"  Theoretical max: ~{expected_max} nodes")
        print(f"\n  Visit log:")
        for log in self.visit_log:
            print(f"    {log}")


def demonstrate_query_complexity():
    """Show that queries are O(log n)."""
    print("=" * 70)
    print("QUERY COMPLEXITY DEMONSTRATION")
    print("=" * 70)

    # Small example with detailed trace
    arr = [1, 3, 5, 7, 9, 11]
    st = SegmentTreeWithStats(arr)

    print(f"\nArray: {arr}")
    print(f"Size n = {len(arr)}, log₂(n) = {math.log2(len(arr)):.2f}")

    # Trace a few queries
    for (L, R) in [(0, 5), (1, 4), (2, 3), (0, 0)]:
        result = st.query(L, R)
        st.print_stats(L, R, result)

    # Statistics for larger arrays
    print("\n" + "=" * 70)
    print("SCALING BEHAVIOR")
    print("=" * 70)

    print(f"\n{'n':>10} {'log₂(n)':>10} {'Max nodes visited':>20}")
    print("-" * 42)

    for size in [10, 100, 1000, 10000, 100000]:
        arr = list(range(size))
        st = SegmentTreeWithStats(arr)

        max_visited = 0
        for _ in range(100):  # 100 random queries
            L = random.randint(0, size - 1)
            R = random.randint(L, size - 1)
            st.query(L, R)
            max_visited = max(max_visited, st.nodes_visited)

        log_n = math.log2(size)
        print(f"{size:>10} {log_n:>10.2f} {max_visited:>20}")

    print("\nObservation: Nodes visited grows as O(log n) ✓")


if __name__ == "__main__":
    demonstrate_query_complexity()
```

The query algorithm adapts to different operations by changing:
- The identity element returned when a node does not overlap the query
- The combine function used to merge the results from the two children
The structure remains identical; only the identity element and the combine function change.
```python
"""
Query implementations for various segment tree operations.
"""

from functools import reduce
from math import gcd
from typing import Callable, TypeVar

T = TypeVar('T')


class GenericSegmentTree:
    """
    A generic segment tree supporting any associative operation.
    """

    def __init__(
        self,
        arr: list,
        combine: Callable[[T, T], T],
        identity: T
    ):
        """
        Initialize segment tree.

        Args:
            arr: Input array
            combine: Binary function to combine two values
            identity: Identity element for the operation
        """
        self.n = len(arr)
        self.combine = combine
        self.identity = identity
        self.tree = [identity] * (4 * self.n) if self.n > 0 else []
        if self.n > 0:
            self._build(arr, 1, 0, self.n - 1)

    def _build(self, arr, node, start, end):
        if start == end:
            self.tree[node] = arr[start]
            return
        mid = (start + end) // 2
        self._build(arr, 2 * node, start, mid)
        self._build(arr, 2 * node + 1, mid + 1, end)
        self.tree[node] = self.combine(
            self.tree[2 * node],
            self.tree[2 * node + 1]
        )

    def query(self, L: int, R: int) -> T:
        """Query the aggregate of [L, R]."""
        if self.n == 0 or L > R or L < 0 or R >= self.n:
            return self.identity
        return self._query(1, 0, self.n - 1, L, R)

    def _query(self, node, start, end, L, R) -> T:
        # No overlap
        if R < start or L > end:
            return self.identity

        # Complete overlap
        if L <= start and end <= R:
            return self.tree[node]

        # Partial overlap
        mid = (start + end) // 2
        left = self._query(2 * node, start, mid, L, R)
        right = self._query(2 * node + 1, mid + 1, end, L, R)
        return self.combine(left, right)


# Specialized implementations for common operations

class SumSegmentTree(GenericSegmentTree):
    """Range Sum Queries."""
    def __init__(self, arr):
        super().__init__(arr, lambda a, b: a + b, 0)


class MinSegmentTree(GenericSegmentTree):
    """Range Minimum Queries."""
    def __init__(self, arr):
        super().__init__(arr, min, float('inf'))


class MaxSegmentTree(GenericSegmentTree):
    """Range Maximum Queries."""
    def __init__(self, arr):
        super().__init__(arr, max, float('-inf'))


class GCDSegmentTree(GenericSegmentTree):
    """Range GCD Queries."""
    def __init__(self, arr):
        super().__init__(arr, gcd, 0)


class XORSegmentTree(GenericSegmentTree):
    """Range XOR Queries."""
    def __init__(self, arr):
        super().__init__(arr, lambda a, b: a ^ b, 0)


class ProductSegmentTree(GenericSegmentTree):
    """Range Product Queries (watch for overflow!)."""
    def __init__(self, arr):
        super().__init__(arr, lambda a, b: a * b, 1)


# Demonstration
def demonstrate_query_variants():
    """Show queries with different operations."""
    arr = [12, 6, 18, 9, 3, 15, 24, 8]
    print("Array:", arr)
    print()

    queries = [(0, 7), (2, 5), (0, 3), (4, 7)]

    # Sum queries
    st_sum = SumSegmentTree(arr)
    print("=== SUM QUERIES ===")
    for L, R in queries:
        result = st_sum.query(L, R)
        expected = sum(arr[L:R+1])
        print(f"  Sum[{L},{R}] = {result} (expected: {expected}) ✓")

    # Min queries
    print("\n=== MIN QUERIES ===")
    st_min = MinSegmentTree(arr)
    for L, R in queries:
        result = st_min.query(L, R)
        expected = min(arr[L:R+1])
        print(f"  Min[{L},{R}] = {result} (expected: {expected}) ✓")

    # Max queries
    print("\n=== MAX QUERIES ===")
    st_max = MaxSegmentTree(arr)
    for L, R in queries:
        result = st_max.query(L, R)
        expected = max(arr[L:R+1])
        print(f"  Max[{L},{R}] = {result} (expected: {expected}) ✓")

    # GCD queries
    print("\n=== GCD QUERIES ===")
    st_gcd = GCDSegmentTree(arr)
    for L, R in queries:
        result = st_gcd.query(L, R)
        expected = reduce(gcd, arr[L:R+1])
        print(f"  GCD[{L},{R}] = {result} (expected: {expected}) ✓")

    # XOR queries
    print("\n=== XOR QUERIES ===")
    st_xor = XORSegmentTree(arr)
    for L, R in queries:
        result = st_xor.query(L, R)
        expected = reduce(lambda x, y: x ^ y, arr[L:R+1])
        print(f"  XOR[{L},{R}] = {result} (expected: {expected}) ✓")


if __name__ == "__main__":
    demonstrate_query_variants()
```

| Operation | Identity | Combine Function | Return Type |
|---|---|---|---|
| Sum | 0 | a + b | Same as input |
| Min | +∞ | min(a, b) | Same as input |
| Max | -∞ | max(a, b) | Same as input |
| GCD | 0 | gcd(a, b) | Same as input |
| LCM | 1 | lcm(a, b) | Can overflow! |
| XOR | 0 | a ^ b | Same as input |
| AND | ~0 (all 1s) | a & b | Same as input |
| OR | 0 | a \| b | Same as input |
| Product | 1 | a × b | Can overflow! |
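The table rows without a dedicated class above follow the same pattern. As a sketch, the closure-based helper below (`make_range_query`, a hypothetical name introduced here, not one of the classes on this page) builds a tree for any combine/identity pair and is exercised with bitwise AND and OR. In Python the "all 1s" identity for AND is `-1`, which equals `~0`.

```python
from functools import reduce
from operator import and_, or_


def make_range_query(arr, combine, identity):
    """Build a segment tree over arr and return a query(L, R) function."""
    n = len(arr)
    tree = [identity] * (4 * n)

    def build(node, start, end):
        if start == end:
            tree[node] = arr[start]
            return
        mid = (start + end) // 2
        build(2 * node, start, mid)
        build(2 * node + 1, mid + 1, end)
        tree[node] = combine(tree[2 * node], tree[2 * node + 1])

    def query(L, R, node=1, start=0, end=n - 1):
        if R < start or L > end:          # no overlap
            return identity
        if L <= start and end <= R:       # complete overlap
            return tree[node]
        mid = (start + end) // 2          # partial overlap
        return combine(query(L, R, 2 * node, start, mid),
                       query(L, R, 2 * node + 1, mid + 1, end))

    if n > 0:
        build(1, 0, n - 1)
    return query


arr = [12, 6, 18, 9, 3, 15, 24, 8]
query_and = make_range_query(arr, and_, -1)  # -1 == ~0: every bit set
query_or = make_range_query(arr, or_, 0)

# Compare against a direct left-to-right fold for a few ranges
for L, R in [(0, 7), (2, 5), (0, 3), (4, 7)]:
    print(f"AND[{L},{R}] = {query_and(L, R)}, fold = {reduce(and_, arr[L:R + 1])}")
    print(f"OR[{L},{R}]  = {query_or(L, R)}, fold = {reduce(or_, arr[L:R + 1])}")
```

The identity matters because a no-overlap node returns it into `combine`: `-1 & x == x` and `0 | x == x`, so those nodes never change the result.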
Robust query implementations must handle edge cases:
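1. Empty Array (n = 0): There is no tree to descend, so return the identity element (0 for sum) without recursing.
2. Single Element Query (L = R): No special handling is needed; the recursion descends to the leaf for that index and returns its value.
3. Full Range Query (L = 0, R = n-1): The root is a complete overlap, so the answer is returned after visiting a single node.
4. Invalid Ranges: If L > R, return the identity element. Out-of-bounds indices can either be rejected or clamped to [0, n-1]; the RobustSegmentTree implementation that follows clamps them.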
```python
"""
Testing edge cases in segment tree queries.
"""


class RobustSegmentTree:
    """Segment tree with robust edge case handling."""

    def __init__(self, arr):
        self.n = len(arr)
        self.tree = [0] * (4 * self.n) if self.n > 0 else []
        if self.n > 0:
            self._build(arr, 1, 0, self.n - 1)

    def _build(self, arr, node, start, end):
        if start == end:
            self.tree[node] = arr[start]
            return
        mid = (start + end) // 2
        self._build(arr, 2 * node, start, mid)
        self._build(arr, 2 * node + 1, mid + 1, end)
        self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]

    def query(self, L: int, R: int) -> int:
        """
        Query with comprehensive edge case handling.

        Returns:
            - Sum of [L, R] if valid
            - 0 for invalid/empty queries
        """
        # Edge case: Empty array
        if self.n == 0:
            return 0

        # Edge case: Invalid range (L > R)
        if L > R:
            return 0

        # Edge case: Out of bounds
        L = max(0, L)           # Clamp L to valid range
        R = min(self.n - 1, R)  # Clamp R to valid range

        if L > R:  # Check again after clamping
            return 0

        return self._query(1, 0, self.n - 1, L, R)

    def _query(self, node, start, end, L, R):
        if R < start or L > end:
            return 0
        if L <= start and end <= R:
            return self.tree[node]
        mid = (start + end) // 2
        return (self._query(2 * node, start, mid, L, R) +
                self._query(2 * node + 1, mid + 1, end, L, R))


def test_edge_cases():
    """Test all edge cases."""
    print("=" * 50)
    print("EDGE CASE TESTS")
    print("=" * 50)

    # Test 1: Empty array
    print("\n[Test 1] Empty array")
    empty_st = RobustSegmentTree([])
    result = empty_st.query(0, 0)
    print(f"  Query on empty: {result}")
    assert result == 0, "Empty array query should return 0"
    print("  ✓ Passed")

    # Test 2: Single element
    print("\n[Test 2] Single element array")
    single_st = RobustSegmentTree([42])
    result = single_st.query(0, 0)
    print(f"  Query [0,0]: {result}")
    assert result == 42, "Single element query should return that element"
    print("  ✓ Passed")

    # Test 3: Full range
    print("\n[Test 3] Full range query")
    arr = [1, 2, 3, 4, 5]
    st = RobustSegmentTree(arr)
    result = st.query(0, 4)
    print(f"  Query [0,4] on {arr}: {result}")
    assert result == 15, "Full range should equal sum of all elements"
    print("  ✓ Passed")

    # Test 4: L > R (invalid range)
    print("\n[Test 4] Invalid range (L > R)")
    result = st.query(3, 1)
    print(f"  Query [3,1]: {result}")
    assert result == 0, "Invalid range should return 0"
    print("  ✓ Passed")

    # Test 5: Out of bounds (negative L)
    print("\n[Test 5] Out of bounds (L < 0)")
    result = st.query(-5, 2)
    print(f"  Query [-5,2]: {result} (clamped to [0,2])")
    assert result == 6, "Should clamp and return sum of [0,2]"
    print("  ✓ Passed")

    # Test 6: Out of bounds (R >= n)
    print("\n[Test 6] Out of bounds (R >= n)")
    result = st.query(3, 100)
    print(f"  Query [3,100]: {result} (clamped to [3,4])")
    assert result == 9, "Should clamp and return sum of [3,4]"
    print("  ✓ Passed")

    # Test 7: Both bounds out of range, but valid after clamping
    print("\n[Test 7] Both bounds out of range")
    result = st.query(-10, 100)
    print(f"  Query [-10,100]: {result} (clamped to [0,4])")
    assert result == 15, "Should clamp and return full sum"
    print("  ✓ Passed")

    # Test 8: Completely out of bounds (no overlap after clamping)
    print("\n[Test 8] Completely out of bounds")
    result = st.query(-10, -5)
    print(f"  Query [-10,-5]: {result}")
    assert result == 0, "Completely out of bounds should return 0"
    print("  ✓ Passed")

    print("\n" + "=" * 50)
    print("ALL EDGE CASE TESTS PASSED!")
    print("=" * 50)


if __name__ == "__main__":
    test_edge_cases()
```

While the recursive implementation is clear and correct, there are optimizations worth knowing:
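1. Iterative Query (for specific cases): Some segment tree variants support bottom-up iterative queries, which avoid function call overhead.
2. Early Termination: When querying the full range, return the root immediately. The recursive version already answers this case at the root through the complete-overlap check, so the shortcut saves only a function call.
3. Skipping Unneeded Recursive Calls: If the query lies entirely in one half of a node's range, the call into the other child would only return the identity, so it can be skipped. (This is pruning rather than tail recursion in the strict sense.)
4. Short-Circuit for No Overlap: Check the no-overlap condition first, so that calls into irrelevant subtrees return immediately.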
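As an illustration of the first item, here is a minimal sketch of a bottom-up iterative sum query. It assumes a different layout from the recursive trees on this page: a `2 * n` array with leaves stored at indices `n` to `2n - 1`. The names `build_iterative` and `query_iterative` are hypothetical, and the sketch relies on the operation being commutative (as sum is), because pieces are not combined strictly left to right.

```python
def build_iterative(arr):
    """Build a bottom-up sum tree: leaves at [n, 2n), parent i = 2i + (2i + 1)."""
    n = len(arr)
    tree = [0] * (2 * n)
    tree[n:] = arr
    for i in range(n - 1, 0, -1):
        tree[i] = tree[2 * i] + tree[2 * i + 1]
    return tree


def query_iterative(tree, n, L, R):
    """Sum of arr[L..R] (inclusive) without recursion."""
    total = 0
    lo, hi = L + n, R + n + 1      # half-open leaf interval [lo, hi)
    while lo < hi:
        if lo & 1:                 # lo is a right child: take it, step right
            total += tree[lo]
            lo += 1
        if hi & 1:                 # hi - 1 is a left child: step left, take it
            hi -= 1
            total += tree[hi]
        lo >>= 1                   # move both boundaries up one level
        hi >>= 1
    return total


arr = [1, 3, 5, 7, 9, 11]
tree = build_iterative(arr)
print(query_iterative(tree, len(arr), 1, 4))  # 3 + 5 + 7 + 9 = 24
```

The loop runs once per level, so the cost is still O(log n); the gain is the absence of recursive calls, which tends to matter more in Python than in compiled languages.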
```python
"""
Optimized segment tree query implementations.
"""

import random
import time


class OptimizedSegmentTree:
    """
    Segment tree with optimized query implementation.
    """

    def __init__(self, arr):
        self.n = len(arr)
        self.tree = [0] * (4 * self.n) if self.n > 0 else []
        if self.n > 0:
            self._build(arr, 1, 0, self.n - 1)

    def _build(self, arr, node, start, end):
        if start == end:
            self.tree[node] = arr[start]
            return
        mid = (start + end) // 2
        self._build(arr, 2 * node, start, mid)
        self._build(arr, 2 * node + 1, mid + 1, end)
        self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]

    def query(self, L: int, R: int) -> int:
        """Optimized query with early termination."""
        # Guard: empty tree or inverted range (the root shortcut below
        # would otherwise index an empty list)
        if self.n == 0 or L > R:
            return 0

        # Early termination: full range query
        if L == 0 and R == self.n - 1:
            return self.tree[1]

        return self._query(1, 0, self.n - 1, L, R)

    def _query(self, node: int, start: int, end: int, L: int, R: int) -> int:
        """
        Optimized recursive query.

        Optimizations:
        1. No overlap check comes first (most likely to trigger)
        2. Complete overlap doesn't recurse further
        3. Bit operations for child calculation (rarely a measurable gain)
        """
        # Optimization: Check no overlap first (most common early termination)
        if end < L or start > R:
            return 0

        # Complete overlap - use precomputed value
        if L <= start and end <= R:
            return self.tree[node]

        # Partial overlap - must check children
        mid = (start + end) >> 1        # Bit shift instead of // 2
        left_child = node << 1          # 2 * node
        right_child = (node << 1) | 1   # 2 * node + 1

        # Only recurse into children that have possible overlap
        left_result = 0
        right_result = 0

        if L <= mid:  # Query overlaps with left child
            left_result = self._query(left_child, start, mid, L, R)

        if R > mid:   # Query overlaps with right child
            right_result = self._query(right_child, mid + 1, end, L, R)

        return left_result + right_result


class MinimalRecursionSegmentTree:
    """
    Segment tree that minimizes unnecessary recursive calls.
    """

    def __init__(self, arr):
        self.n = len(arr)
        self.tree = [0] * (4 * self.n) if self.n > 0 else []
        if self.n > 0:
            self._build(arr, 1, 0, self.n - 1)

    def _build(self, arr, node, start, end):
        if start == end:
            self.tree[node] = arr[start]
            return
        mid = (start + end) // 2
        self._build(arr, 2 * node, start, mid)
        self._build(arr, 2 * node + 1, mid + 1, end)
        self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]

    def query(self, L: int, R: int) -> int:
        """Query with smarter recursion."""
        # _query assumes 0 <= L <= R < n; without this guard an inverted
        # range such as (3, 1) would recurse forever on a leaf
        if self.n == 0 or L > R or L < 0 or R >= self.n:
            return 0
        return self._query(1, 0, self.n - 1, L, R)

    def _query(self, node, start, end, L, R):
        # Complete overlap
        if L <= start and end <= R:
            return self.tree[node]

        mid = (start + end) // 2

        # Determine which children to visit
        left = 2 * node
        right = 2 * node + 1

        # If query is entirely in left half
        if R <= mid:
            return self._query(left, start, mid, L, R)

        # If query is entirely in right half
        if L > mid:
            return self._query(right, mid + 1, end, L, R)

        # Query spans both halves: narrow the range passed to each child
        return (self._query(left, start, mid, L, mid) +
                self._query(right, mid + 1, end, mid + 1, R))


# Performance comparison (basic benchmark)
def compare_performance():
    """Compare the two optimized implementations."""
    print("=" * 60)
    print("QUERY PERFORMANCE COMPARISON")
    print("=" * 60)

    sizes = [1000, 10000, 100000]

    for n in sizes:
        arr = [random.randint(1, 100) for _ in range(n)]
        st_optimized = OptimizedSegmentTree(arr)
        st_minimal = MinimalRecursionSegmentTree(arr)

        # Generate random queries
        queries = [(random.randint(0, n-1), random.randint(0, n-1))
                   for _ in range(10000)]
        queries = [(min(L, R), max(L, R)) for L, R in queries]

        # Benchmark optimized
        start_time = time.perf_counter()
        for L, R in queries:
            st_optimized.query(L, R)
        optimized_time = time.perf_counter() - start_time

        # Benchmark minimal recursion
        start_time = time.perf_counter()
        for L, R in queries:
            st_minimal.query(L, R)
        minimal_time = time.perf_counter() - start_time

        print(f"\nn = {n:,}, 10,000 queries:")
        print(f"  Optimized: {optimized_time:.4f}s")
        print(f"  Minimal recursion: {minimal_time:.4f}s")


if __name__ == "__main__":
    compare_performance()
```

The standard recursive implementation is usually fast enough. Micro-optimizations like bit shifts and early termination provide marginal gains. Focus on correctness first; optimize only if profiling shows the segment tree is a bottleneck.
Let's put everything together with a complete, tested implementation:
```python
"""
Complete Segment Tree with Query Operations
============================================

A production-ready implementation supporting:
- Construction in O(n)
- Range queries in O(log n)
- Multiple operation types

Usage:
    st = SegmentTree([1, 3, 5, 7, 9, 11])
    print(st.query(1, 4))  # Output: 24
"""

import random
import time
from typing import List, Callable, TypeVar

T = TypeVar('T')


class SegmentTree:
    """
    A flexible segment tree supporting various range queries.

    Time Complexity:
        - Build: O(n)
        - Query: O(log n)

    Space Complexity: O(n)
    """

    def __init__(
        self,
        arr: List[T],
        combine: Callable[[T, T], T] = lambda a, b: a + b,
        identity: T = 0
    ):
        """
        Build a segment tree from the input array.

        Args:
            arr: Input array
            combine: Binary associative function to combine values
            identity: Identity element for the combine function

        Example:
            # Sum tree (default)
            st = SegmentTree([1, 2, 3, 4, 5])

            # Min tree
            st = SegmentTree([1, 2, 3], min, float('inf'))

            # Max tree
            st = SegmentTree([1, 2, 3], max, float('-inf'))
        """
        self.n = len(arr)
        self.combine = combine
        self.identity = identity

        if self.n == 0:
            self.tree = []
            return

        self.tree = [identity] * (4 * self.n)
        self._build(arr, 1, 0, self.n - 1)

    def _build(self, arr: List[T], node: int, start: int, end: int) -> None:
        """Recursively build the segment tree."""
        if start == end:
            self.tree[node] = arr[start]
            return

        mid = (start + end) // 2
        left, right = 2 * node, 2 * node + 1

        self._build(arr, left, start, mid)
        self._build(arr, right, mid + 1, end)

        self.tree[node] = self.combine(self.tree[left], self.tree[right])

    def query(self, L: int, R: int) -> T:
        """
        Query the aggregate of elements in range [L, R].

        Args:
            L: Left index (inclusive), 0-indexed
            R: Right index (inclusive), 0-indexed

        Returns:
            Aggregate of elements from index L to R

        Example:
            st = SegmentTree([1, 3, 5, 7, 9, 11])
            print(st.query(1, 4))  # 3 + 5 + 7 + 9 = 24
        """
        if self.n == 0:
            return self.identity

        # Clamp out-of-bounds indices to the valid range
        if L < 0:
            L = 0
        if R >= self.n:
            R = self.n - 1
        if L > R:
            return self.identity

        return self._query(1, 0, self.n - 1, L, R)

    def _query(self, node: int, start: int, end: int, L: int, R: int) -> T:
        """Recursive query helper."""
        # No overlap
        if R < start or L > end:
            return self.identity

        # Complete overlap
        if L <= start and end <= R:
            return self.tree[node]

        # Partial overlap
        mid = (start + end) // 2
        left_result = self._query(2 * node, start, mid, L, R)
        right_result = self._query(2 * node + 1, mid + 1, end, L, R)

        return self.combine(left_result, right_result)

    def __repr__(self) -> str:
        """String representation of the segment tree."""
        if self.n == 0:
            return "SegmentTree(empty)"
        return f"SegmentTree(n={self.n}, root={self.tree[1]})"


# ============================================================
# Comprehensive Test Suite
# ============================================================

def run_all_tests():
    """Run comprehensive tests on segment tree queries."""
    print("=" * 60)
    print("SEGMENT TREE QUERY - COMPREHENSIVE TESTS")
    print("=" * 60)

    # Test 1: Basic Sum Queries
    print("\n[1] Basic Sum Queries")
    arr = [1, 3, 5, 7, 9, 11]
    st = SegmentTree(arr)

    test_cases = [
        ((0, 5), 36),  # Full range
        ((0, 0), 1),   # Single element (first)
        ((5, 5), 11),  # Single element (last)
        ((1, 4), 24),  # Middle range
        ((0, 2), 9),   # Prefix
        ((3, 5), 27),  # Suffix
    ]

    for (L, R), expected in test_cases:
        result = st.query(L, R)
        status = "✓" if result == expected else "✗"
        print(f"  Sum[{L},{R}] = {result} (expected: {expected}) {status}")
        assert result == expected

    # Test 2: Min Queries
    print("\n[2] Min Queries")
    st_min = SegmentTree(arr, min, float('inf'))

    min_cases = [
        ((0, 5), 1),
        ((2, 4), 5),
        ((4, 5), 9),
    ]

    for (L, R), expected in min_cases:
        result = st_min.query(L, R)
        status = "✓" if result == expected else "✗"
        print(f"  Min[{L},{R}] = {result} (expected: {expected}) {status}")
        assert result == expected

    # Test 3: Max Queries
    print("\n[3] Max Queries")
    st_max = SegmentTree(arr, max, float('-inf'))

    max_cases = [
        ((0, 5), 11),
        ((0, 3), 7),
        ((1, 2), 5),
    ]

    for (L, R), expected in max_cases:
        result = st_max.query(L, R)
        status = "✓" if result == expected else "✗"
        print(f"  Max[{L},{R}] = {result} (expected: {expected}) {status}")
        assert result == expected

    # Test 4: Edge Cases
    print("\n[4] Edge Cases")
    assert SegmentTree([]).query(0, 0) == 0, "Empty array"
    assert SegmentTree([42]).query(0, 0) == 42, "Single element"
    assert st.query(-1, 2) == 9, "Negative L (clamped)"
    assert st.query(4, 100) == 20, "R out of bounds (clamped)"
    print("  All edge cases passed ✓")

    # Test 5: Large Array Performance
    print("\n[5] Large Array Performance")
    large_arr = [random.randint(1, 1000) for _ in range(100000)]

    build_start = time.perf_counter()
    large_st = SegmentTree(large_arr)
    build_time = time.perf_counter() - build_start

    queries = [(random.randint(0, 99999), random.randint(0, 99999))
               for _ in range(10000)]
    queries = [(min(a, b), max(a, b)) for a, b in queries]

    query_start = time.perf_counter()
    for L, R in queries:
        large_st.query(L, R)
    query_time = time.perf_counter() - query_start

    print(f"  Build time (n=100,000): {build_time:.4f}s")
    print(f"  Query time (10,000 queries): {query_time:.4f}s")
    print(f"  Avg query time: {query_time/10000*1000000:.2f} µs")

    print("\n" + "=" * 60)
    print("ALL TESTS PASSED!")
    print("=" * 60)


if __name__ == "__main__":
    run_all_tests()
```

We've explored the query operation in complete depth. Here are the essential insights:
- Every node visited falls into one of three cases: no overlap (return the identity), complete overlap (return the stored value), or partial overlap (recurse into both children).
- Any range [L, R] decomposes into O(log n) maximal, fully contained nodes: the canonical decomposition.
- At most 4 nodes are visited per level, so a query runs in O(log n) time.
- Switching operations (sum, min, max, GCD, XOR, and so on) only requires changing the identity element and the combine function.
What's Next:
Now that we can query efficiently, we need to handle changes. The next page covers the update operation—how to modify a single element in O(log n) time while maintaining the segment tree invariant.
You now understand how segment tree queries work from first principles. The ability to decompose any range into O(log n) precomputed ranges is what makes segment trees powerful. Next, we'll see how updates maintain this structure efficiently.