Skip to content

Commit 3ea09f2

Browse files
committed
Implement Wavelet Tree with rank and kthSmallest methods
1 parent 2616e09 commit 3ea09f2

2 files changed

Lines changed: 345 additions & 0 deletions

File tree

Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
package com.thealgorithms.datastructures.trees;
2+
3+
import java.util.ArrayList;
4+
import java.util.List;
5+
6+
/**
7+
* A Wavelet Tree is a highly efficient data structure used to store sequences
8+
* and answer queries like rank, select, and quantile in O(log(max_val - min_val)) time.
9+
* This structure is particularly useful in competitive programming and text compression.
10+
*/
11+
public class WaveletTree {
12+
13+
private class Node {
14+
int low, high;
15+
Node left, right;
16+
List<Integer> leftCount; // Prefix sums of elements going to the left child
17+
18+
/**
19+
* Recursively constructs the tree nodes by partitioning the array.
20+
*
21+
* @param arr the subarray for the current node
22+
* @param low the minimum possible value in the current node
23+
* @param high the maximum possible value in the current node
24+
*/
25+
Node(int[] arr, int low, int high) {
26+
this.low = low;
27+
this.high = high;
28+
29+
if (arr.length == 0 || low == high) {
30+
return;
31+
}
32+
33+
int mid = low + (high - low) / 2;
34+
leftCount = new ArrayList<>(arr.length + 1);
35+
leftCount.add(0);
36+
37+
List<Integer> leftArr = new ArrayList<>();
38+
List<Integer> rightArr = new ArrayList<>();
39+
40+
for (int x : arr) {
41+
if (x <= mid) {
42+
leftArr.add(x);
43+
leftCount.add(leftCount.get(leftCount.size() - 1) + 1);
44+
} else {
45+
rightArr.add(x);
46+
leftCount.add(leftCount.get(leftCount.size() - 1));
47+
}
48+
}
49+
50+
if (!leftArr.isEmpty()) {
51+
this.left = new Node(leftArr.stream().mapToInt(i -> i).toArray(), low, mid);
52+
}
53+
if (!rightArr.isEmpty()) {
54+
this.right = new Node(rightArr.stream().mapToInt(i -> i).toArray(), mid + 1, high);
55+
}
56+
}
57+
}
58+
59+
private Node root;
60+
private final int n;
61+
62+
/**
63+
* Constructs a Wavelet Tree from the given array.
64+
* The min and max values are determined dynamically from the array.
65+
*
66+
* @param arr the input array
67+
*/
68+
public WaveletTree(int[] arr) {
69+
if (arr == null || arr.length == 0) {
70+
this.n = 0;
71+
return;
72+
}
73+
this.n = arr.length;
74+
int min = arr[0];
75+
int max = arr[0];
76+
for (int x : arr) {
77+
if (x < min) {
78+
min = x;
79+
}
80+
if (x > max) {
81+
max = x;
82+
}
83+
}
84+
root = new Node(arr, min, max);
85+
}
86+
87+
/**
88+
* Constructs a Wavelet Tree from the given array with specific min and max values.
89+
*
90+
* @param arr the input array
91+
* @param minValue the minimum possible value
92+
* @param maxValue the maximum possible value
93+
*/
94+
public WaveletTree(int[] arr, int minValue, int maxValue) {
95+
if (arr == null || arr.length == 0) {
96+
this.n = 0;
97+
return;
98+
}
99+
this.n = arr.length;
100+
root = new Node(arr, minValue, maxValue);
101+
}
102+
103+
/**
104+
* How many times does the number x appear in the array from index 0 to i (inclusive)?
105+
*
106+
* @param x the number to search for
107+
* @param i the end index (0-based, inclusive)
108+
* @return the number of occurrences of x in arr[0...i]
109+
*/
110+
public int rank(int x, int i) {
111+
if (root == null || x < root.low || x > root.high || i < 0) {
112+
return 0;
113+
}
114+
// If i is out of bounds, cap it at n - 1
115+
int endIdx = Math.min(i, n - 1);
116+
return rank(root, x, endIdx + 1);
117+
}
118+
119+
private int rank(Node node, int x, int count) {
120+
if (node == null || count == 0) {
121+
return 0;
122+
}
123+
if (node.low == node.high) {
124+
return count;
125+
}
126+
int mid = node.low + (node.high - node.low) / 2;
127+
int leftC = node.leftCount.get(count);
128+
if (x <= mid) {
129+
return rank(node.left, x, leftC);
130+
} else {
131+
return rank(node.right, x, count - leftC);
132+
}
133+
}
134+
135+
/**
136+
* What is the 0-based index of the k-th occurrence of the number x in the array?
137+
*
138+
* @param x the number to search for
139+
* @param k the occurrence count (1-based)
140+
* @return the 0-based index in the original array, or -1 if x occurs less than k times
141+
*/
142+
public int select(int x, int k) {
143+
if (root == null || x < root.low || x > root.high || k <= 0) {
144+
return -1;
145+
}
146+
if (rank(x, n - 1) < k) {
147+
return -1;
148+
}
149+
return select(root, x, k);
150+
}
151+
152+
private int select(Node node, int x, int k) {
153+
if (node == null) {
154+
return -1;
155+
}
156+
if (node.low == node.high) {
157+
return k - 1; // 0-based index within the imaginary array at the leaf
158+
}
159+
int mid = node.low + (node.high - node.low) / 2;
160+
if (x <= mid) {
161+
int posInLeft = select(node.left, x, k);
162+
if (posInLeft == -1) {
163+
return -1;
164+
}
165+
return binarySearchLeft(node.leftCount, posInLeft + 1);
166+
} else {
167+
int posInRight = select(node.right, x, k);
168+
if (posInRight == -1) {
169+
return -1;
170+
}
171+
return binarySearchRight(node.leftCount, posInRight + 1);
172+
}
173+
}
174+
175+
private int binarySearchLeft(List<Integer> prefixSums, int k) {
176+
int l = 1, r = prefixSums.size() - 1;
177+
int ans = -1;
178+
while (l <= r) {
179+
int mid = l + (r - l) / 2;
180+
if (prefixSums.get(mid) >= k) {
181+
ans = mid;
182+
r = mid - 1;
183+
} else {
184+
l = mid + 1;
185+
}
186+
}
187+
return ans == -1 ? -1 : ans - 1; // Convert to 0-based index
188+
}
189+
190+
private int binarySearchRight(List<Integer> prefixSums, int k) {
191+
int l = 1, r = prefixSums.size() - 1;
192+
int ans = -1;
193+
while (l <= r) {
194+
int mid = l + (r - l) / 2;
195+
if (mid - prefixSums.get(mid) >= k) {
196+
ans = mid;
197+
r = mid - 1;
198+
} else {
199+
l = mid + 1;
200+
}
201+
}
202+
return ans == -1 ? -1 : ans - 1; // Convert to 0-based index
203+
}
204+
205+
/**
206+
* If you sort the subarray from index left to right, what would be the k-th smallest element?
207+
* This query is also commonly known as the quantile query.
208+
*
209+
* @param left the start index of the subarray (0-based, inclusive)
210+
* @param right the end index of the subarray (0-based, inclusive)
211+
* @param k the rank of the smallest element (1-based, e.g., k=1 is the minimum)
212+
* @return the k-th smallest element in the subarray, or -1 if invalid parameters
213+
*/
214+
public int kthSmallest(int left, int right, int k) {
215+
if (root == null || left > right || left < 0 || k < 1 || k > right - left + 1) {
216+
return -1;
217+
}
218+
return kthSmallest(root, left, right, k);
219+
}
220+
221+
private int kthSmallest(Node node, int left, int right, int k) {
222+
if (node == null) {
223+
return -1;
224+
}
225+
if (node.low == node.high) {
226+
return node.low;
227+
}
228+
229+
int countLeftInLMinus1 = (left == 0) ? 0 : node.leftCount.get(left);
230+
int countLeftInR = node.leftCount.get(right + 1);
231+
int elementsToLeft = countLeftInR - countLeftInLMinus1;
232+
233+
if (k <= elementsToLeft) {
234+
int newL = countLeftInLMinus1;
235+
int newR = countLeftInR - 1;
236+
return kthSmallest(node.left, newL, newR, k);
237+
} else {
238+
int newL = left - countLeftInLMinus1;
239+
int newR = right - countLeftInR;
240+
return kthSmallest(node.right, newL, newR, k - elementsToLeft);
241+
}
242+
}
243+
}
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
package com.thealgorithms.datastructures.trees;
2+
3+
import static org.junit.jupiter.api.Assertions.assertEquals;
4+
5+
import org.junit.jupiter.api.Test;
6+
7+
public class WaveletTreeTest {
8+
9+
@Test
10+
public void testRank() {
11+
int[] arr = {5, 1, 2, 5, 1};
12+
WaveletTree wt = new WaveletTree(arr);
13+
14+
// x = 1
15+
assertEquals(1, wt.rank(1, 1)); // In [5, 1], '1' appears 1 time
16+
assertEquals(2, wt.rank(1, 4)); // In [5, 1, 2, 5, 1], '1' appears 2 times
17+
assertEquals(0, wt.rank(1, 0)); // In [5], '1' appears 0 times
18+
19+
// x = 5
20+
assertEquals(1, wt.rank(5, 0)); // In [5], '5' appears 1 time
21+
assertEquals(1, wt.rank(5, 2)); // In [5, 1, 2], '5' appears 1 time
22+
assertEquals(2, wt.rank(5, 4)); // In [5, 1, 2, 5, 1], '5' appears 2 times
23+
24+
// Out of bounds / invalid value
25+
assertEquals(0, wt.rank(10, 4)); // '10' is not in the array
26+
assertEquals(0, wt.rank(5, -1)); // Invalid end index
27+
}
28+
29+
@Test
30+
public void testSelect() {
31+
int[] arr = {5, 1, 2, 5, 1};
32+
WaveletTree wt = new WaveletTree(arr);
33+
34+
assertEquals(1, wt.select(1, 1)); // 1st '1' is at index 1
35+
assertEquals(4, wt.select(1, 2)); // 2nd '1' is at index 4
36+
37+
assertEquals(0, wt.select(5, 1)); // 1st '5' is at index 0
38+
assertEquals(3, wt.select(5, 2)); // 2nd '5' is at index 3
39+
40+
assertEquals(2, wt.select(2, 1)); // 1st '2' is at index 2
41+
42+
assertEquals(-1, wt.select(5, 3)); // 3rd '5' doesn't exist
43+
assertEquals(-1, wt.select(10, 1)); // '10' doesn't exist
44+
assertEquals(-1, wt.select(5, 0)); // invalid k
45+
}
46+
47+
@Test
48+
public void testKthSmallest() {
49+
int[] arr = {5, 1, 2, 5, 1};
50+
WaveletTree wt = new WaveletTree(arr);
51+
52+
// Array: [5, 1, 2, 5, 1] -> Sorted: [1, 1, 2, 5, 5]
53+
assertEquals(1, wt.kthSmallest(0, 4, 1)); // 1st smallest in [5, 1, 2, 5, 1] is 1
54+
assertEquals(1, wt.kthSmallest(0, 4, 2)); // 2nd smallest in [5, 1, 2, 5, 1] is 1
55+
assertEquals(2, wt.kthSmallest(0, 4, 3)); // 3rd smallest in [5, 1, 2, 5, 1] is 2
56+
assertEquals(5, wt.kthSmallest(0, 4, 4)); // 4th smallest in [5, 1, 2, 5, 1] is 5
57+
assertEquals(5, wt.kthSmallest(0, 4, 5)); // 5th smallest in [5, 1, 2, 5, 1] is 5
58+
59+
// Subarray: arr[1..3] = [1, 2, 5] -> Sorted: [1, 2, 5]
60+
assertEquals(1, wt.kthSmallest(1, 3, 1)); // 1st smallest in [1, 2, 5] is 1
61+
assertEquals(2, wt.kthSmallest(1, 3, 2)); // 2nd smallest in [1, 2, 5] is 2
62+
assertEquals(5, wt.kthSmallest(1, 3, 3)); // 3rd smallest in [1, 2, 5] is 5
63+
64+
// Invalid ranges / arguments
65+
assertEquals(-1, wt.kthSmallest(4, 2, 1)); // Invalid range (left > right)
66+
assertEquals(-1, wt.kthSmallest(0, 4, 10)); // k > range length
67+
assertEquals(-1, wt.kthSmallest(0, 4, 0)); // k < 1
68+
}
69+
70+
@Test
71+
public void testEmptyAndSingleElementArray() {
72+
WaveletTree wtEmpty = new WaveletTree(new int[]{});
73+
assertEquals(0, wtEmpty.rank(1, 0));
74+
assertEquals(-1, wtEmpty.select(1, 1));
75+
assertEquals(-1, wtEmpty.kthSmallest(0, 0, 1));
76+
77+
WaveletTree wtSingle = new WaveletTree(new int[]{42});
78+
assertEquals(1, wtSingle.rank(42, 0));
79+
assertEquals(0, wtSingle.rank(42, -1));
80+
assertEquals(0, wtSingle.select(42, 1));
81+
assertEquals(-1, wtSingle.select(42, 2));
82+
assertEquals(42, wtSingle.kthSmallest(0, 0, 1));
83+
}
84+
85+
@Test
86+
public void testNegativeValues() {
87+
int[] arr = {-5, 10, -2, 0, -5};
88+
WaveletTree wt = new WaveletTree(arr);
89+
90+
assertEquals(2, wt.rank(-5, 4));
91+
assertEquals(1, wt.rank(0, 3));
92+
93+
assertEquals(0, wt.select(-5, 1));
94+
assertEquals(4, wt.select(-5, 2));
95+
assertEquals(3, wt.select(0, 1));
96+
97+
// Sorted: [-5, -5, -2, 0, 10]
98+
assertEquals(-5, wt.kthSmallest(0, 4, 1));
99+
assertEquals(-2, wt.kthSmallest(0, 4, 3));
100+
assertEquals(10, wt.kthSmallest(0, 4, 5));
101+
}
102+
}

0 commit comments

Comments
 (0)