diff --git a/README.md b/README.md index ed66eabec..83363555f 100644 --- a/README.md +++ b/README.md @@ -271,7 +271,7 @@ $ java -cp classes com.williamfiset.algorithms.search.BinarySearch # Search algorithms - [Binary search (real numbers)](src/main/java/com/williamfiset/algorithms/search/BinarySearch.java) **- O(log(n))** -- [Interpolation search (discrete discrete)](src/main/java/com/williamfiset/algorithms/search/InterpolationSearch.java) **- O(n) or O(log(log(n))) with uniform input** +- [Interpolation search (discrete numbers)](src/main/java/com/williamfiset/algorithms/search/InterpolationSearch.java) **- O(n) or O(log(log(n))) with uniform input** - [Ternary search (real numbers)](src/main/java/com/williamfiset/algorithms/search/TernarySearch.java) **- O(log(n))** - [Ternary search (discrete numbers)](src/main/java/com/williamfiset/algorithms/search/TernarySearchDiscrete.java) **- O(log(n))** diff --git a/src/main/java/com/williamfiset/algorithms/datastructures/segmenttree/BUILD b/src/main/java/com/williamfiset/algorithms/datastructures/segmenttree/BUILD index 61e31a872..84a42bbc0 100644 --- a/src/main/java/com/williamfiset/algorithms/datastructures/segmenttree/BUILD +++ b/src/main/java/com/williamfiset/algorithms/datastructures/segmenttree/BUILD @@ -21,19 +21,6 @@ java_binary( runtime_deps = [":segmenttree"], ) -# bazel run //src/main/java/com/williamfiset/algorithms/datastructures/segmenttree:GenericSegmentTree2 -java_binary( - name = "GenericSegmentTree2", - main_class = "com.williamfiset.algorithms.datastructures.segmenttree.GenericSegmentTree2", - runtime_deps = [":segmenttree"], -) - -# bazel run //src/main/java/com/williamfiset/algorithms/datastructures/segmenttree:GenericSegmentTree3 -java_binary( - name = "GenericSegmentTree3", - main_class = "com.williamfiset.algorithms.datastructures.segmenttree.GenericSegmentTree3", - runtime_deps = [":segmenttree"], -) # bazel run //src/main/java/com/williamfiset/algorithms/datastructures/segmenttree:MaxQuerySumUpdateSegmentTree java_binary( diff --git a/src/main/java/com/williamfiset/algorithms/datastructures/segmenttree/GenericSegmentTree.java b/src/main/java/com/williamfiset/algorithms/datastructures/segmenttree/GenericSegmentTree.java index faba5b78e..8cb7bfb69 100644 --- a/src/main/java/com/williamfiset/algorithms/datastructures/segmenttree/GenericSegmentTree.java +++ b/src/main/java/com/williamfiset/algorithms/datastructures/segmenttree/GenericSegmentTree.java @@ -1,14 +1,32 @@ /** * A generic segment tree implementation that supports several range update and aggregation - * functions. This implementation of the segment tree differs from the `GenericSegmentTree2` impl in - * that it stores the segment tree information inside multiple arrays for node. + * functions. * - *

Run with: ./gradlew run -Palgorithm=datastructures.segmenttree.GenericSegmentTree + *

Run with: bazel run + * //src/main/java/com/williamfiset/algorithms/datastructures/segmenttree:GenericSegmentTree * *

Several thanks to cp-algorithms for their great article on segment trees: * https://cp-algorithms.com/data_structures/segment_tree.html * - *

NOTE: This file is still a WIP + *

Supported combinations of (SegmentCombinationFn, RangeUpdateFn): + * + *

+ * + *

Unsupported (will throw UnsupportedOperationException): + * + *

+ * + *

NOTE: MIN/MAX + MULTIPLICATION may produce incorrect results when multiplying by negative + * values, since the min can become the max and vice versa. * * @author William Fiset, william.alexandre.fiset@gmail.com */ @@ -18,7 +36,8 @@ public class GenericSegmentTree { - // The type of segment combination function to use + // The type of segment combination function to use. + // This determines how child segments are merged to form the parent segment value. public static enum SegmentCombinationFn { SUM, MIN, @@ -27,8 +46,9 @@ public static enum SegmentCombinationFn { PRODUCT } - // When updating the value of a specific index position, or a range of values, - // modify the affected values using the following function: + // When updating a range of values, modify the affected values using one of the following + // functions. The choice of range update function affects how lazy values are applied and + // propagated through the tree. public static enum RangeUpdateFn { // When a range update is issued, assign all the values in the range [l, r] to be `x` ASSIGN, @@ -45,29 +65,36 @@ public static enum RangeUpdateFn { // root node and the left and right children of node i are i*2+1 and i*2+2. private Long[] t; - // The delta values associates with each segment. Used for lazy propagation - // when doing range updates. + // The delta values associated with each segment. Used for lazy propagation. private Long[] lazy; - // The chosen range combination function + // The chosen range combination function used to merge two child segments into a parent. private BinaryOperator combinationFn; // The Range Update Function (RUF) interface. + // + // This functional interface defines how a delta (lazy) value is applied to transform a + // segment's aggregated value. The parameters are: + // base - the existing segment value + // tl,tr - the left/right endpoints of the segment range [tl, tr] + // delta - the pending lazy delta value + // + // The segment size (tr - tl + 1) is needed by some RUFs. For example, when applying + // a SUM + ADDITION update, adding `d` to each of `count` elements increases the sum + // by `count * d`. private interface Ruf { - // base = the existing value - // tl, tr = the index value of the left/right endpoints, i.e: [tl, tr] - // delta = the delta value - // TODO(william): reorder to be base, delta, tl, tr Long apply(Long base, long tl, long tr, Long delta); } - // The Range Update Function (RUF) that chooses how a lazy delta value is - // applied to a segment. + // The Range Update Function (RUF) that determines how a lazy delta value is applied to + // update a segment's value. private Ruf ruf; - // The Lazy Range Update Function (LRUF) associated with the RUF. How you - // propagate the lazy delta values is sometimes different than how you apply - // them to the current segment (but most of the time the RUF = LRUF). + // The Lazy Range Update Function (LRUF) associated with the RUF. This determines how + // lazy delta values are composed when propagated to children. For example: + // - ADDITION lazies compose by summing: new_lazy = old_lazy + delta + // - MULTIPLICATION lazies compose by multiplying: new_lazy = old_lazy * delta + // - ASSIGN lazies compose by overwriting: new_lazy = delta private Ruf lruf; private long safeSum(Long a, Long b) { @@ -94,75 +121,22 @@ private Long safeMax(Long a, Long b) { return Math.max(a, b); } - private BinaryOperator sumCombinationFn = (a, b) -> safeSum(a, b); - private BinaryOperator minCombinationFn = (a, b) -> safeMin(a, b); - private BinaryOperator maxCombinationFn = (a, b) -> safeMax(a, b); - private BinaryOperator productCombinationFn = (a, b) -> safeMul(a, b); - private BinaryOperator gcdCombinationFn = - (a, b) -> { - if (a == null) return b; - if (b == null) return a; - long gcd = a; - while (b != 0) { - gcd = b; - b = a % b; - a = gcd; - } - return Math.abs(gcd); - }; - - // TODO(william): Document the justification for each function below - - // Range update functions - private Ruf minQuerySumUpdate = (b, tl, tr, d) -> safeSum(b, d); - private Ruf lminQuerySumUpdate = (b, tl, tr, d) -> safeSum(b, d); - - // TODO(issue/208): Can negative multiplication updates be supported? - private Ruf minQueryMulUpdate = (b, tl, tr, d) -> safeMul(b, d); - private Ruf lminQueryMulUpdate = (b, tl, tr, d) -> safeMul(b, d); - - private Ruf minQueryAssignUpdate = (b, tl, tr, d) -> d; - private Ruf lminQueryAssignUpdate = (b, tl, tr, d) -> d; - - private Ruf maxQuerySumUpdate = (b, tl, tr, d) -> safeSum(b, d); - private Ruf lmaxQuerySumUpdate = (b, tl, tr, d) -> safeSum(b, d); - - // TODO(issue/208): Can negative multiplication updates be supported? - private Ruf maxQueryMulUpdate = (b, tl, tr, d) -> safeMul(b, d); - private Ruf lmaxQueryMulUpdate = (b, tl, tr, d) -> safeMul(b, d); - - private Ruf maxQueryAssignUpdate = (b, tl, tr, d) -> d; - private Ruf lmaxQueryAssignUpdate = (b, tl, tr, d) -> d; - - private Ruf sumQuerySumUpdate = (b, tl, tr, d) -> b + (tr - tl + 1) * d; - private Ruf lsumQuerySumUpdate = (b, tl, tr, d) -> safeSum(b, d); - - private Ruf sumQueryMulUpdate = (b, tl, tr, d) -> safeMul(b, d); - private Ruf lsumQueryMulUpdate = (b, tl, tr, d) -> safeMul(b, d); - - private Ruf sumQueryAssignUpdate = (b, tl, tr, d) -> (tr - tl + 1) * d; - private Ruf lsumQueryAssignUpdate = (b, tl, tr, d) -> d; - - // TODO(william): confirm this cannot be supported? Can we maintain additional - // information to make it possible? - private Ruf gcdQuerySumUpdate = (b, tl, tr, d) -> null; - private Ruf lgcdQuerySumUpdate = (b, tl, tr, d) -> null; - - private Ruf gcdQueryMulUpdate = (b, tl, tr, d) -> safeMul(b, d); - private Ruf lgcdQueryMulUpdate = (b, tl, tr, d) -> safeMul(b, d); - - private Ruf gcdQueryAssignUpdate = (b, tl, tr, d) -> d; - private Ruf lgcdQueryAssignUpdate = (b, tl, tr, d) -> d; - - private Ruf productQuerySumUpdate = (b, tl, tr, d) -> b + (long) (Math.pow(d, (tr - tl + 1))); - private Ruf lproductQuerySumUpdate = (b, tl, tr, d) -> safeSum(b, d); - - private Ruf productQueryMulUpdate = (b, tl, tr, d) -> b * (long) (Math.pow(d, (tr - tl + 1))); - private Ruf lproductQueryMulUpdate = - (b, tl, tr, d) -> safeMul(b, d); // safeMul(b, (long)(Math.pow(d, (tr - tl + 1)))); + private long gcd(long a, long b) { + while (b != 0) { + long tmp = b; + b = a % b; + a = tmp; + } + return Math.abs(a); + } - private Ruf productQueryAssignUpdate = (b, tl, tr, d) -> d; - private Ruf lproductQueryAssignUpdate = (b, tl, tr, d) -> d; + // Reusable RUF lambdas shared across multiple query/update combos: + // addDelta - composes two additive deltas, or applies an additive delta to a value + // mulDelta - composes two multiplicative deltas, or applies a multiplicative delta + // assignDelta - overwrites with the new delta (used for ASSIGN updates) + private final Ruf addDelta = (b, tl, tr, d) -> safeSum(b, d); + private final Ruf mulDelta = (b, tl, tr, d) -> safeMul(b, d); + private final Ruf assignDelta = (b, tl, tr, d) -> d; public GenericSegmentTree( long[] values, @@ -178,82 +152,137 @@ public GenericSegmentTree( throw new IllegalArgumentException("Please specify a valid range update function."); } n = values.length; + t = new Long[4 * n]; + lazy = new Long[4 * n]; + + // Select the combination function + switch (segmentCombinationFunction) { + case SUM: + combinationFn = (a, b) -> safeSum(a, b); + break; + case MIN: + combinationFn = (a, b) -> safeMin(a, b); + break; + case MAX: + combinationFn = (a, b) -> safeMax(a, b); + break; + case GCD: + combinationFn = + (a, b) -> { + if (a == null) return b; + if (b == null) return a; + return gcd(a, b); + }; + break; + case PRODUCT: + combinationFn = (a, b) -> safeMul(a, b); + break; + } - // The size of the segment tree `t` + // Select the range update function (ruf) and lazy propagation function (lruf). // - // TODO(william): Investigate to reduce this space. There are only 2n-1 segments, so we should - // be able to reduce the space, but may need to reorganize the tree/queries. One idea is to use - // the Eulerian tour structure of the tree to densely pack the segments. - int N = 4 * n; - - t = new Long[N]; - // TODO(william): Change this to be of size n to reduce memory from O(4n) to O(3n) - lazy = new Long[N]; - - // Select the specified combination function - if (segmentCombinationFunction == SegmentCombinationFn.SUM) { - combinationFn = sumCombinationFn; - if (rangeUpdateFunction == RangeUpdateFn.ADDITION) { - ruf = sumQuerySumUpdate; - lruf = lsumQuerySumUpdate; - } else if (rangeUpdateFunction == RangeUpdateFn.ASSIGN) { - ruf = sumQueryAssignUpdate; - lruf = lsumQueryAssignUpdate; - } else if (rangeUpdateFunction == RangeUpdateFn.MULTIPLICATION) { - ruf = sumQueryMulUpdate; - lruf = lsumQueryMulUpdate; - } - } else if (segmentCombinationFunction == SegmentCombinationFn.MIN) { - combinationFn = minCombinationFn; - if (rangeUpdateFunction == RangeUpdateFn.ADDITION) { - ruf = minQuerySumUpdate; - lruf = lminQuerySumUpdate; - } else if (rangeUpdateFunction == RangeUpdateFn.ASSIGN) { - ruf = minQueryAssignUpdate; - lruf = lminQueryAssignUpdate; - } else if (rangeUpdateFunction == RangeUpdateFn.MULTIPLICATION) { - ruf = minQueryMulUpdate; - lruf = lminQueryMulUpdate; - } - } else if (segmentCombinationFunction == SegmentCombinationFn.MAX) { - combinationFn = maxCombinationFn; - if (rangeUpdateFunction == RangeUpdateFn.ADDITION) { - ruf = maxQuerySumUpdate; - lruf = lmaxQuerySumUpdate; - } else if (rangeUpdateFunction == RangeUpdateFn.ASSIGN) { - ruf = maxQueryAssignUpdate; - lruf = lmaxQueryAssignUpdate; - } else if (rangeUpdateFunction == RangeUpdateFn.MULTIPLICATION) { - ruf = maxQueryMulUpdate; - lruf = lmaxQueryMulUpdate; - } - } else if (segmentCombinationFunction == SegmentCombinationFn.GCD) { - combinationFn = gcdCombinationFn; - if (rangeUpdateFunction == RangeUpdateFn.ADDITION) { - ruf = gcdQuerySumUpdate; - lruf = lgcdQuerySumUpdate; - } else if (rangeUpdateFunction == RangeUpdateFn.ASSIGN) { - ruf = gcdQueryAssignUpdate; - lruf = lgcdQueryAssignUpdate; - } else if (rangeUpdateFunction == RangeUpdateFn.MULTIPLICATION) { - ruf = gcdQueryMulUpdate; - lruf = lgcdQueryMulUpdate; - } - } else if (segmentCombinationFunction == SegmentCombinationFn.PRODUCT) { - combinationFn = productCombinationFn; - if (rangeUpdateFunction == RangeUpdateFn.ADDITION) { - ruf = productQuerySumUpdate; - lruf = lproductQuerySumUpdate; - } else if (rangeUpdateFunction == RangeUpdateFn.ASSIGN) { - ruf = productQueryAssignUpdate; - lruf = lproductQueryAssignUpdate; - } else if (rangeUpdateFunction == RangeUpdateFn.MULTIPLICATION) { - ruf = productQueryMulUpdate; - lruf = lproductQueryMulUpdate; - } - } else { - throw new UnsupportedOperationException( - "Combination function not supported: " + segmentCombinationFunction); + // For most combos, the ruf and lruf are one of the three shared lambdas above. + // Special cases arise when the segment size matters: + // + // SUM + ADDITION: ruf applies delta*count to the sum; lruf adds deltas + // SUM + ASSIGN: ruf sets sum to delta*count; lruf overwrites + // SUM + MUL: ruf multiplies the sum; lruf multiplies deltas + // + // MIN/MAX + ADDITION: ruf/lruf both add (min/max shift uniformly) + // MIN/MAX + ASSIGN: ruf/lruf both overwrite + // MIN/MAX + MUL: ruf/lruf both multiply (only correct for non-negative multipliers) + // + // GCD + ASSIGN: ruf/lruf both overwrite + // GCD + MUL: ruf/lruf both multiply (gcd(a*d,b*d) = |d|*gcd(a,b)) + // + // PRODUCT + ASSIGN: ruf sets product to d^count; lruf overwrites + // PRODUCT + MUL: ruf multiplies product by d^count; lruf multiplies deltas + // + switch (segmentCombinationFunction) { + case SUM: + switch (rangeUpdateFunction) { + case ADDITION: + ruf = (b, tl, tr, d) -> b + (tr - tl + 1) * d; + lruf = addDelta; + break; + case ASSIGN: + ruf = (b, tl, tr, d) -> (tr - tl + 1) * d; + lruf = assignDelta; + break; + case MULTIPLICATION: + ruf = mulDelta; + lruf = mulDelta; + break; + } + break; + + case MIN: + switch (rangeUpdateFunction) { + case ADDITION: + ruf = addDelta; + lruf = addDelta; + break; + case ASSIGN: + ruf = assignDelta; + lruf = assignDelta; + break; + case MULTIPLICATION: + ruf = mulDelta; + lruf = mulDelta; + break; + } + break; + + case MAX: + switch (rangeUpdateFunction) { + case ADDITION: + ruf = addDelta; + lruf = addDelta; + break; + case ASSIGN: + ruf = assignDelta; + lruf = assignDelta; + break; + case MULTIPLICATION: + ruf = mulDelta; + lruf = mulDelta; + break; + } + break; + + case GCD: + switch (rangeUpdateFunction) { + case ADDITION: + throw new UnsupportedOperationException( + "Can't use GCD with range addition updates; gcd(a+d, b+d) " + + "cannot be computed from gcd(a,b) and d alone."); + case ASSIGN: + ruf = assignDelta; + lruf = assignDelta; + break; + case MULTIPLICATION: + ruf = mulDelta; + lruf = mulDelta; + break; + } + break; + + case PRODUCT: + switch (rangeUpdateFunction) { + case ADDITION: + throw new UnsupportedOperationException( + "Can't use PRODUCT with range addition updates; product(a_i + d) " + + "cannot be computed from product(a_i) and d alone."); + case ASSIGN: + ruf = (b, tl, tr, d) -> (long) Math.pow(d, tr - tl + 1); + lruf = assignDelta; + break; + case MULTIPLICATION: + ruf = (b, tl, tr, d) -> b * (long) Math.pow(d, tr - tl + 1); + lruf = mulDelta; + break; + } + break; } buildSegmentTree(0, 0, n - 1, values); @@ -275,129 +304,79 @@ private void buildSegmentTree(int i, int tl, int tr, long[] values) { int tm = (tl + tr) / 2; buildSegmentTree(2 * i + 1, tl, tm, values); buildSegmentTree(2 * i + 2, tm + 1, tr, values); - t[i] = combinationFn.apply(t[2 * i + 1], t[2 * i + 2]); } /** - * Returns the query of the range [l, r] on the original `values` array (+ any updates made to it) + * Returns the result of the aggregation function over the range [l, r] on the original `values` + * array (including any updates made to it), O(log(n)). * * @param l the left endpoint of the range query (inclusive) * @param r the right endpoint of the range query (inclusive) */ - public Long rangeQuery1(int l, int r) { - return rangeQuery1(0, 0, n - 1, l, r); + public Long rangeQuery(int l, int r) { + return rangeQuery(0, 0, n - 1, l, r); } - /** - * Returns the range query value of the range [l, r] - * - * @param i the index of the current segment in the tree - * @param tl the left endpoint (inclusive) of the current segment - * @param tr the right endpoint (inclusive) of the current segment - * @param l the target left endpoint (inclusive) for the range query - * @param r the target right endpoint (inclusive) for the range query - */ - private Long rangeQuery1(int i, int tl, int tr, int l, int r) { - // Different segment tree types have different base cases + private Long rangeQuery(int i, int tl, int tr, int l, int r) { if (l > r) { return null; } - propagate1(i, tl, tr); + propagate(i, tl, tr); if (tl == l && tr == r) { return t[i]; } int tm = (tl + tr) / 2; - // Instead of checking if [tl, tm] overlaps [l, r] and [tm+1, tr] overlaps - // [l, r], simply recurse on both segments and let the base case return the - // default value for invalid intervals. return combinationFn.apply( - rangeQuery1(2 * i + 1, tl, tm, l, Math.min(tm, r)), - rangeQuery1(2 * i + 2, tm + 1, tr, Math.max(l, tm + 1), r)); + rangeQuery(2 * i + 1, tl, tm, l, Math.min(tm, r)), + rangeQuery(2 * i + 2, tm + 1, tr, Math.max(l, tm + 1), r)); } - // Apply the lazy delta value to the current node and push it to the child segments - private void propagate1(int i, int tl, int tr) { - // No lazy value to propagate + // Apply the lazy delta value to the current node and push it to the child segments. + private void propagate(int i, int tl, int tr) { if (lazy[i] == null) { return; } - // Apply the lazy delta to the current segment. t[i] = ruf.apply(t[i], tl, tr, lazy[i]); - // Push the lazy delta to left/right segments for non-leaf nodes - propagateLazy1(i, tl, tr, lazy[i]); + propagateLazy(i, tl, tr, lazy[i]); lazy[i] = null; } - private void propagateLazy1(int i, int tl, int tr, long delta) { - // Ignore leaf segments + // Push the lazy delta to the left and right child segments. Leaf nodes are skipped + // since they have no children. + private void propagateLazy(int i, int tl, int tr, long delta) { if (tl == tr) return; lazy[2 * i + 1] = lruf.apply(lazy[2 * i + 1], tl, tr, delta); lazy[2 * i + 2] = lruf.apply(lazy[2 * i + 2], tl, tr, delta); } - public void rangeUpdate1(int l, int r, long x) { - rangeUpdate1(0, 0, n - 1, l, r, x); + /** + * Updates all elements in the range [l, r] using the configured range update function, O(log(n)). + * + * @param l the left endpoint of the range update (inclusive) + * @param r the right endpoint of the range update (inclusive) + * @param x the value to apply (added, multiplied, or assigned depending on the RangeUpdateFn) + */ + public void rangeUpdate(int l, int r, long x) { + rangeUpdate(0, 0, n - 1, l, r, x); } - private void rangeUpdate1(int i, int tl, int tr, int l, int r, long x) { - propagate1(i, tl, tr); + private void rangeUpdate(int i, int tl, int tr, int l, int r, long x) { + propagate(i, tl, tr); if (l > r) { return; } - if (tl == l && tr == r) { t[i] = ruf.apply(t[i], tl, tr, x); - propagateLazy1(i, tl, tr, x); + propagateLazy(i, tl, tr, x); } else { int tm = (tl + tr) / 2; - // Instead of checking if [tl, tm] overlaps [l, r] and [tm+1, tr] overlaps - // [l, r], simply recurse on both segments and let the base case disregard - // invalid intervals. - rangeUpdate1(2 * i + 1, tl, tm, l, Math.min(tm, r), x); - rangeUpdate1(2 * i + 2, tm + 1, tr, Math.max(l, tm + 1), r, x); - + rangeUpdate(2 * i + 1, tl, tm, l, Math.min(tm, r), x); + rangeUpdate(2 * i + 2, tm + 1, tr, Math.max(l, tm + 1), r, x); t[i] = combinationFn.apply(t[2 * i + 1], t[2 * i + 2]); } } - // // Updates the value at index `i` in the original `values` array to be `newValue`. - // public void pointUpdate(int i, long newValue) { - // pointUpdate(0, i, 0, n - 1, newValue); - // } - - // /** - // * Update a point value to a new value and update all affected segments, O(log(n)) - // * - // *

Do this by performing a binary search to find the interval containing the point, then - // update - // * the leaf segment with the new value, and re-compute all affected segment values on the - // * callback. - // * - // * @param i the index of the current segment in the tree - // * @param pos the target position to update - // * @param tl the left segment endpoint (inclusive) - // * @param tr the right segment endpoint (inclusive) - // * @param newValue the new value to update - // */ - // private void pointUpdate(int i, int pos, int tl, int tr, long newValue) { - // if (tl == tr) { // `tl == pos && tr == pos` might be clearer - // t[i] = newValue; - // return; - // } - // int tm = (tl + tr) / 2; - // if (pos <= tm) { - // // The point index `pos` is contained within the left segment [tl, tm] - // pointUpdate(2 * i + 1, pos, tl, tm, newValue); - // } else { - // // The point index `pos` is contained within the right segment [tm+1, tr] - // pointUpdate(2 * i + 2, pos, tm + 1, tr, newValue); - // } - // // Re-compute the segment value of the current segment on the callback - // // t[i] = rangeUpdateFn.apply(t[2 * i + 1], t[2 * i + 2]); - // t[i] = combinationFn.apply(t[2 * i + 1], t[2 * i + 2]); - // } - public void printDebugInfo() { printDebugInfo(0, 0, n - 1); System.out.println(); @@ -418,141 +397,104 @@ private void printDebugInfo(int i, int tl, int tr) { //////////////////////////////////////////////////// public static void main(String[] args) { - t(); - // sumQuerySumUpdateExample(); - // minQueryAssignUpdateExample(); - // gcdQueryMulUpdateExample(); - // gcdQueryAssignUpdateExample(); - // productQueryMulUpdateExample(); + sumQuerySumUpdateExample(); + minQueryAssignUpdateExample(); + gcdQueryMulUpdateExample(); + gcdQueryAssignUpdateExample(); + productQueryMulUpdateExample(); + productQueryAssignUpdateExample(); } - private static void productQueryMulUpdateExample() { - // 0, 1, 2, 3 - long[] v = {3, 2, 2, 1}; + private static void sumQuerySumUpdateExample() { + long[] v = {2, 1, 3, 4, -1}; GenericSegmentTree st = - new GenericSegmentTree(v, SegmentCombinationFn.PRODUCT, RangeUpdateFn.MULTIPLICATION); + new GenericSegmentTree(v, SegmentCombinationFn.SUM, RangeUpdateFn.ADDITION); - int l = 0; - int r = 3; - long q = st.rangeQuery1(l, r); - if (q != 12) System.out.println("Error"); - System.out.printf("The product between indeces [%d, %d] is: %d\n", l, r, q); + int l = 1, r = 3; + long q = st.rangeQuery(l, r); + if (q != 8) System.out.println("Error"); + System.out.printf("The sum between [%d, %d] is: %d\n", l, r, q); + st.rangeUpdate(1, 3, 3); + q = st.rangeQuery(l, r); + if (q != 17) System.out.println("Error"); + System.out.printf("The sum between [%d, %d] is: %d\n", l, r, q); + } - // 3, 8, 8, 1 - // 3 * 8 * 8 * 1 = 192 - st.rangeUpdate1(1, 2, 4); - q = st.rangeQuery1(l, r); - if (q != 192) System.out.println("Error"); - System.out.printf("The product between indeces [%d, %d] is: %d\n", l, r, st.rangeQuery1(l, r)); + private static void minQueryAssignUpdateExample() { + long[] v = {2, 1, 3, 4, -1}; + GenericSegmentTree st = + new GenericSegmentTree(v, SegmentCombinationFn.MIN, RangeUpdateFn.ASSIGN); - // 3, 8, 16, 2 - // 3 * 8 * 16 * 2 = 768 - st.rangeUpdate1(2, 3, 2); - q = st.rangeQuery1(l, r); - if (q != 768) System.out.println("Error"); - System.out.printf("The product between indeces [%d, %d] is: %d\n", l, r, st.rangeQuery1(l, r)); - - // 12, 24, 24, 24, 48 - // st.rangeUpdate1(2, 3, 24); - // l = 0; - // r = 4; - // q = st.rangeQuery1(l, r); - // if (q != 12) System.out.println("Error"); - // System.out.printf("The product between indeces [%d, %d] is: %d\n", l, r, st.rangeQuery1(l, - // r)); + long q = st.rangeQuery(1, 3); + if (q != 1) System.out.println("Error"); + System.out.printf("The min between [1, 3] is: %d\n", q); + st.rangeUpdate(1, 3, 3); + q = st.rangeQuery(0, 1); + if (q != 2) System.out.println("Error"); + System.out.printf("The min between [0, 1] is: %d\n", q); } private static void gcdQueryMulUpdateExample() { - // 0, 1, 2, 3, 4 long[] v = {12, 24, 3, 4, -1}; GenericSegmentTree st = new GenericSegmentTree(v, SegmentCombinationFn.GCD, RangeUpdateFn.MULTIPLICATION); - int l = 0; - int r = 2; - long q = st.rangeQuery1(l, r); + long q = st.rangeQuery(0, 2); if (q != 3) System.out.println("Error"); - System.out.printf("The gcd between indeces [%d, %d] is: %d\n", l, r, q); - st.rangeUpdate1(2, 2, 2); - q = st.rangeQuery1(l, r); + System.out.printf("The gcd between [0, 2] is: %d\n", q); + st.rangeUpdate(2, 2, 2); + q = st.rangeQuery(0, 2); if (q != 6) System.out.println("Error"); - System.out.printf("The gcd between indeces [%d, %d] is: %d\n", l, r, st.rangeQuery1(l, r)); - - r = 1; // [l, r] = [0, 1] - q = st.rangeQuery1(l, r); - if (q != 12) System.out.println("Error"); - System.out.printf("The gcd between indeces [%d, %d] is: %d\n", l, r, st.rangeQuery1(l, r)); + System.out.printf("The gcd between [0, 2] is: %d\n", q); } private static void gcdQueryAssignUpdateExample() { - // 0, 1, 2, 3, 4 long[] v = {12, 24, 3, 12, 48}; GenericSegmentTree st = new GenericSegmentTree(v, SegmentCombinationFn.GCD, RangeUpdateFn.ASSIGN); - int l = 0; - int r = 2; - long q = st.rangeQuery1(l, r); + long q = st.rangeQuery(0, 2); if (q != 3) System.out.println("Error"); - System.out.printf("The gcd between indeces [%d, %d] is: %d\n", l, r, q); - // 12, 24, 48, 12, 48 - st.rangeUpdate1(2, 2, 48); - q = st.rangeQuery1(l, r); + st.rangeUpdate(2, 2, 48); + q = st.rangeQuery(0, 2); if (q != 12) System.out.println("Error"); - System.out.printf("The gcd between indeces [%d, %d] is: %d\n", l, r, st.rangeQuery1(l, r)); - // 12, 24, 24, 24, 48 - st.rangeUpdate1(2, 3, 24); - l = 0; - r = 4; - q = st.rangeQuery1(l, r); + st.rangeUpdate(2, 3, 24); + q = st.rangeQuery(0, 4); if (q != 12) System.out.println("Error"); - System.out.printf("The gcd between indeces [%d, %d] is: %d\n", l, r, st.rangeQuery1(l, r)); + System.out.printf("The gcd between [0, 4] is: %d\n", q); } - private static void sumQuerySumUpdateExample() { - // 0, 1, 2, 3, 4 - long[] v = {2, 1, 3, 4, -1}; + private static void productQueryMulUpdateExample() { + long[] v = {3, 2, 2, 1}; GenericSegmentTree st = - new GenericSegmentTree(v, SegmentCombinationFn.SUM, RangeUpdateFn.ADDITION); + new GenericSegmentTree(v, SegmentCombinationFn.PRODUCT, RangeUpdateFn.MULTIPLICATION); - int l = 1; - int r = 3; - long q = st.rangeQuery1(l, r); - if (q != 8) System.out.println("Error"); - System.out.printf("The sum between indeces [%d, %d] is: %d\n", l, r, q); - st.rangeUpdate1(1, 3, 3); - q = st.rangeQuery1(l, r); - if (q != 17) System.out.println("Error"); - System.out.printf("The sum between indeces [%d, %d] is: %d\n", l, r, st.rangeQuery1(l, r)); - } + long q = st.rangeQuery(0, 3); + if (q != 12) System.out.println("Error"); - private static void t() { - long[] v = {1, 4, 3, 0, 5, 8, -2, 7, 5, 2, 9}; - GenericSegmentTree st = - new GenericSegmentTree(v, SegmentCombinationFn.MIN, RangeUpdateFn.ASSIGN); - st.printDebugInfo(); + st.rangeUpdate(1, 2, 4); // {3, 8, 8, 1} -> 192 + q = st.rangeQuery(0, 3); + if (q != 192) System.out.println("Error"); + + st.rangeUpdate(2, 3, 2); // {3, 8, 16, 2} -> 768 + q = st.rangeQuery(0, 3); + if (q != 768) System.out.println("Error"); + System.out.printf("The product between [0, 3] is: %d\n", q); } - private static void minQueryAssignUpdateExample() { - // 0, 1, 2, 3, 4 - long[] v = {2, 1, 3, 4, -1}; + private static void productQueryAssignUpdateExample() { + long[] v = {2, 3, 1, 5}; GenericSegmentTree st = - new GenericSegmentTree(v, SegmentCombinationFn.MIN, RangeUpdateFn.ASSIGN); + new GenericSegmentTree(v, SegmentCombinationFn.PRODUCT, RangeUpdateFn.ASSIGN); - // System.out.println(java.util.Arrays.toString(st.t)); + long q = st.rangeQuery(0, 3); + if (q != 30) System.out.println("Error"); - int l = 1; - int r = 3; - long q = st.rangeQuery1(l, r); - if (q != 1) System.out.println("Error"); - System.out.printf("The min between indeces [%d, %d] is: %d\n", l, r, q); - st.rangeUpdate1(1, 3, 3); - l = 0; - r = 1; - q = st.rangeQuery1(l, r); - if (q != 2) System.out.println("Error"); - System.out.printf("The min between indeces [%d, %d] is: %d\n", l, r, st.rangeQuery1(l, r)); + st.rangeUpdate(1, 2, 4); // {2, 4, 4, 5} -> 160 + q = st.rangeQuery(0, 3); + if (q != 160) System.out.println("Error"); + System.out.printf("The product between [0, 3] is: %d\n", q); } } diff --git a/src/main/java/com/williamfiset/algorithms/datastructures/segmenttree/GenericSegmentTree2.java b/src/main/java/com/williamfiset/algorithms/datastructures/segmenttree/GenericSegmentTree2.java deleted file mode 100644 index 5870f4318..000000000 --- a/src/main/java/com/williamfiset/algorithms/datastructures/segmenttree/GenericSegmentTree2.java +++ /dev/null @@ -1,491 +0,0 @@ -/** - * A generic segment tree implementation that supports several range update and aggregation - * functions. - * - *

Run with: ./gradlew run -Palgorithm=datastructures.segmenttree.GenericSegmentTree2 - * - *

Several thanks to cp-algorithms for their great article on segment trees: - * https://cp-algorithms.com/data_structures/segment_tree.html - * - *

NOTE: This file is still a WIP - * - * @author William Fiset, william.alexandre.fiset@gmail.com - */ -package com.williamfiset.algorithms.datastructures.segmenttree; - -import java.util.Objects; -import java.util.function.BinaryOperator; - -public class GenericSegmentTree2 { - - // The type of segment combination function to use - public static enum SegmentCombinationFn { - SUM, - MIN, - MAX - } - - // When updating the value of a specific index position, or a range of values, - // modify the affected values using the following function: - public static enum RangeUpdateFn { - // When a range update is issued, assign all the values in the range [l, r] to be `x` - ASSIGN, - // When a range update is issued, add a value of `x` to all the elements in the range [l, r] - ADDITION, - // When a range update is issued, multiply all elements in the range [l, r] by a value of `x` - MULTIPLICATION - } - - private static class Segment { // implements PrintableNode - // TODO(william): investigate if we really need this, it's unlikely that we do since it should - // be able to implicitly determine the index. - int i; - - Long value; - Long lazy; - - // Used only for Min/Max mul queries. Used in an attempt to resolve: - // https://github.com/williamfiset/Algorithms/issues/208 - Long min; - Long max; - - // The range of the segment [tl, tr] - int tl; - int tr; - - public Segment(int i, Long value, Long min, Long max, int tl, int tr) { - this.i = i; - this.value = value; - this.min = min; - this.max = max; - this.tl = tl; - this.tr = tr; - } - - // @Override - // public PrintableNode getLeft() { - // return left; - // } - - // @Override - // public PrintableNode getRight() { - // return right; - // } - - // @Override - // public String getText() { - // return value.toString(); - // } - - @Override - public String toString() { - return String.format("[%d, %d], value = %d, lazy = %d", tl, tr, value, lazy); - } - } - - // The number of elements in the original input values array. - private int n; - - // The segment tree represented as a binary tree of ranges where st[0] is the - // root node and the left and right children of node i are i*2+1 and i*2+2. - private Segment[] st; - - // The chosen range combination function - private BinaryOperator combinationFn; - - private interface Ruf { - Long apply(Segment segment, Long delta); - } - - // The Range Update Function (RUF) that chooses how a lazy delta value is - // applied to a segment. - private Ruf ruf; - - // The Lazy Range Update Function (LRUF) associated with the RUF. How you - // propagate the lazy delta values is sometimes different than how you apply - // them to the current segment (but most of the time the RUF = LRUF). - private Ruf lruf; - - private long safeSum(Long a, Long b) { - if (a == null) a = 0L; - if (b == null) b = 0L; - return a + b; - } - - private Long safeMul(Long a, Long b) { - if (a == null) a = 1L; - if (b == null) b = 1L; - return a * b; - } - - private Long safeMin(Long a, Long b) { - if (a == null) return b; - if (b == null) return a; - return Math.min(a, b); - } - - private Long safeMax(Long a, Long b) { - if (a == null) return b; - if (b == null) return a; - return Math.max(a, b); - } - - private BinaryOperator sumCombinationFn = (a, b) -> safeSum(a, b); - private BinaryOperator minCombinationFn = (a, b) -> safeMin(a, b); - private BinaryOperator maxCombinationFn = (a, b) -> safeMax(a, b); - - // TODO(william): Document the justification for each function below - - // Range update functions - private Ruf minQuerySumUpdate = (s, x) -> safeSum(s.value, x); - private Ruf lminQuerySumUpdate = (s, x) -> safeSum(s.lazy, x); - - // // TODO(issue/208): support this multiplication update - private Ruf minQueryMulUpdate = - (s, x) -> { - if (x == 0) { - return 0L; - } else if (x < 0) { - // s.min was already calculated - if (Objects.equals(safeMul(s.value, x), s.min)) { - return s.max; - } else { - return s.min; - } - } else { - return safeMul(s.value, x); - } - }; - private Ruf lminQueryMulUpdate = (s, x) -> safeMul(s.lazy, x); - - private Ruf minQueryAssignUpdate = (s, x) -> x; - private Ruf lminQueryAssignUpdate = (s, x) -> x; - - private Ruf maxQuerySumUpdate = (s, x) -> safeSum(s.value, x); - private Ruf lmaxQuerySumUpdate = (s, x) -> safeSum(s.lazy, x); - - // TODO(issue/208): support this multiplication update - private Ruf maxQueryMulUpdate = - (s, x) -> { - if (x == 0) { - return 0L; - } else if (x < 0) { - if (Objects.equals(safeMul(s.value, x), s.min)) { - return s.max; - } else { - return s.min; - } - } else { - return safeMul(s.value, x); - } - }; - private Ruf lmaxQueryMulUpdate = (s, x) -> safeMul(s.lazy, x); - - private Ruf maxQueryAssignUpdate = (s, x) -> x; - private Ruf lmaxQueryAssignUpdate = (s, x) -> x; - - private Ruf sumQuerySumUpdate = (s, x) -> s.value + (s.tr - s.tl + 1) * x; - private Ruf lsumQuerySumUpdate = (s, x) -> safeSum(s.lazy, x); - - private Ruf sumQueryMulUpdate = (s, x) -> safeMul(s.value, x); - private Ruf lsumQueryMulUpdate = (s, x) -> safeMul(s.lazy, x); - - private Ruf sumQueryAssignUpdate = (s, x) -> (s.tr - s.tl + 1) * x; - private Ruf lsumQueryAssignUpdate = (s, x) -> x; - - public GenericSegmentTree2( - long[] values, - SegmentCombinationFn segmentCombinationFunction, - RangeUpdateFn rangeUpdateFunction) { - if (values == null) { - throw new IllegalArgumentException("Segment tree values cannot be null."); - } - if (segmentCombinationFunction == null) { - throw new IllegalArgumentException("Please specify a valid segment combination function."); - } - if (rangeUpdateFunction == null) { - throw new IllegalArgumentException("Please specify a valid range update function."); - } - n = values.length; - - // The size of the segment tree `t` - // - // TODO(william): Investigate to reduce this space. There are only 2n-1 segments, so we should - // be able to reduce the space, but may need to reorganize the tree/queries. One idea is to use - // the Eulerian tour structure of the tree to densely pack the segments. - int N = 4 * n; - - st = new Segment[N]; - - // Select the specified combination function - if (segmentCombinationFunction == SegmentCombinationFn.SUM) { - combinationFn = sumCombinationFn; - if (rangeUpdateFunction == RangeUpdateFn.ADDITION) { - ruf = sumQuerySumUpdate; - lruf = lsumQuerySumUpdate; - } else if (rangeUpdateFunction == RangeUpdateFn.ASSIGN) { - ruf = sumQueryAssignUpdate; - lruf = lsumQueryAssignUpdate; - } else if (rangeUpdateFunction == RangeUpdateFn.MULTIPLICATION) { - ruf = sumQueryMulUpdate; - lruf = lsumQueryMulUpdate; - } - } else if (segmentCombinationFunction == SegmentCombinationFn.MIN) { - combinationFn = minCombinationFn; - if (rangeUpdateFunction == RangeUpdateFn.ADDITION) { - ruf = minQuerySumUpdate; - lruf = lminQuerySumUpdate; - } else if (rangeUpdateFunction == RangeUpdateFn.ASSIGN) { - ruf = minQueryAssignUpdate; - lruf = lminQueryAssignUpdate; - } else if (rangeUpdateFunction == RangeUpdateFn.MULTIPLICATION) { - ruf = minQueryMulUpdate; - lruf = lminQueryMulUpdate; - } - } else if (segmentCombinationFunction == SegmentCombinationFn.MAX) { - combinationFn = maxCombinationFn; - if (rangeUpdateFunction == RangeUpdateFn.ADDITION) { - ruf = maxQuerySumUpdate; - lruf = lmaxQuerySumUpdate; - } else if (rangeUpdateFunction == RangeUpdateFn.ASSIGN) { - ruf = maxQueryAssignUpdate; - lruf = lmaxQueryAssignUpdate; - } else if (rangeUpdateFunction == RangeUpdateFn.MULTIPLICATION) { - ruf = maxQueryMulUpdate; - lruf = lmaxQueryMulUpdate; - } - } else { - throw new UnsupportedOperationException( - "Combination function not supported: " + segmentCombinationFunction); - } - - buildSegmentTree(0, 0, n - 1, values); - } - - /** - * Builds a segment tree by starting with the leaf nodes and combining segment values on callback. - * - * @param i the index of the segment in the segment tree - * @param tl the left index (inclusive) of the segment range - * @param tr the right index (inclusive) of the segment range - * @param values the initial values array - */ - private void buildSegmentTree(int i, int tl, int tr, long[] values) { - if (tl == tr) { - st[i] = new Segment(i, values[tl], values[tl], values[tl], tl, tr); - return; - } - int tm = (tl + tr) / 2; - buildSegmentTree(2 * i + 1, tl, tm, values); - buildSegmentTree(2 * i + 2, tm + 1, tr, values); - - Long segmentValue = combinationFn.apply(st[2 * i + 1].value, st[2 * i + 2].value); - Long minValue = Math.min(st[2 * i + 1].min, st[2 * i + 2].min); - Long maxValue = Math.max(st[2 * i + 1].max, st[2 * i + 2].max); - Segment segment = new Segment(i, segmentValue, minValue, maxValue, tl, tr); - - st[i] = segment; - } - - /** - * Returns the query of the range [l, r] on the original `values` array (+ any updates made to it) - * - * @param l the left endpoint of the range query (inclusive) - * @param r the right endpoint of the range query (inclusive) - */ - public Long rangeQuery1(int l, int r) { - return rangeQuery1(0, 0, n - 1, l, r); - } - - /** - * Returns the range query value of the range [l, r] - * - * @param i the index of the current segment in the tree - * @param tl the left endpoint (inclusive) of the current segment - * @param tr the right endpoint (inclusive) of the current segment - * @param l the target left endpoint (inclusive) for the range query - * @param r the target right endpoint (inclusive) for the range query - */ - private Long rangeQuery1(int i, int tl, int tr, int l, int r) { - // Different segment tree types have different base cases - if (l > r) { - return null; - } - propagate1(i, tl, tr); - if (tl == l && tr == r) { - return st[i].value; - } - int tm = (tl + tr) / 2; - // Instead of checking if [tl, tm] overlaps [l, r] and [tm+1, tr] overlaps - // [l, r], simply recurse on both segments and let the base case return the - // default value for invalid intervals. - return combinationFn.apply( - rangeQuery1(2 * i + 1, tl, tm, l, Math.min(tm, r)), - rangeQuery1(2 * i + 2, tm + 1, tr, Math.max(l, tm + 1), r)); - } - - // Apply the delta value to the current node and push it to the child segments - private void propagate1(int i, int tl, int tr) { - if (st[i].lazy != null) { - // Only used for min/max mul queries - st[i].min = st[i].min * st[i].lazy; - st[i].max = st[i].max * st[i].lazy; - - // Apply the delta to the current segment. - st[i].value = ruf.apply(st[i], st[i].lazy); - // Push the delta to left/right segments for non-leaf nodes - propagateLazy1(i, tl, tr, st[i].lazy); - st[i].lazy = null; - } - } - - private void propagateLazy1(int i, int tl, int tr, long delta) { - // Ignore leaf segments - if (tl == tr) return; - st[2 * i + 1].lazy = lruf.apply(st[2 * i + 1], delta); - st[2 * i + 2].lazy = lruf.apply(st[2 * i + 2], delta); - } - - public void rangeUpdate1(int l, int r, long x) { - rangeUpdate1(0, 0, n - 1, l, r, x); - } - - private void rangeUpdate1(int i, int tl, int tr, int l, int r, long x) { - propagate1(i, tl, tr); - if (l > r) { - return; - } - - if (tl == l && tr == r) { - // Only used for min/max mul queries - st[i].min = st[i].min * x; - st[i].max = st[i].max * x; - - st[i].value = ruf.apply(st[i], x); - propagateLazy1(i, tl, tr, x); - } else { - int tm = (tl + tr) / 2; - // Instead of checking if [tl, tm] overlaps [l, r] and [tm+1, tr] overlaps - // [l, r], simply recurse on both segments and let the base case disregard - // invalid intervals. - rangeUpdate1(2 * i + 1, tl, tm, l, Math.min(tm, r), x); - rangeUpdate1(2 * i + 2, tm + 1, tr, Math.max(l, tm + 1), r, x); - - st[i].value = combinationFn.apply(st[2 * i + 1].value, st[2 * i + 2].value); - st[i].max = Math.max(st[2 * i + 1].max, st[2 * i + 2].max); - st[i].min = Math.min(st[2 * i + 1].min, st[2 * i + 2].min); - } - } - - // // Updates the value at index `i` in the original `values` array to be `newValue`. - // public void pointUpdate(int i, long newValue) { - // pointUpdate(0, i, 0, n - 1, newValue); - // } - - // /** - // * Update a point value to a new value and update all affected segments, O(log(n)) - // * - // *

Do this by performing a binary search to find the interval containing the point, then - // update - // * the leaf segment with the new value, and re-compute all affected segment values on the - // * callback. - // * - // * @param i the index of the current segment in the tree - // * @param pos the target position to update - // * @param tl the left segment endpoint (inclusive) - // * @param tr the right segment endpoint (inclusive) - // * @param newValue the new value to update - // */ - // private void pointUpdate(int i, int pos, int tl, int tr, long newValue) { - // if (tl == tr) { // `tl == pos && tr == pos` might be clearer - // t[i] = newValue; - // return; - // } - // int tm = (tl + tr) / 2; - // if (pos <= tm) { - // // The point index `pos` is contained within the left segment [tl, tm] - // pointUpdate(2 * i + 1, pos, tl, tm, newValue); - // } else { - // // The point index `pos` is contained within the right segment [tm+1, tr] - // pointUpdate(2 * i + 2, pos, tm + 1, tr, newValue); - // } - // // Re-compute the segment value of the current segment on the callback - // // t[i] = rangeUpdateFn.apply(t[2 * i + 1], t[2 * i + 2]); - // t[i] = combinationFn.apply(t[2 * i + 1], t[2 * i + 2]); - // } - - public void printDebugInfo() { - printDebugInfo(0); - System.out.println(); - } - - private void printDebugInfo(int i) { - System.out.println(st[i]); - if (st[i].tl == st[i].tr) { - return; - } - printDebugInfo(2 * i + 1); - printDebugInfo(2 * i + 2); - } - - //////////////////////////////////////////////////// - // Example usage: // - //////////////////////////////////////////////////// - - public static void main(String[] args) { - minQuerySumUpdate(); - sumQuerySumUpdateExample(); - minQueryAssignUpdateExample(); - } - - private static void minQuerySumUpdate() { - // 0, 1, 2, 3, 4 - long[] v = {2, 1, 3, 4, -1}; - GenericSegmentTree2 st = - new GenericSegmentTree2(v, SegmentCombinationFn.MIN, RangeUpdateFn.ADDITION); - - int l = 1; - int r = 3; - long q = st.rangeQuery1(l, r); - if (q != 1) System.out.println("Error"); - System.out.printf("The min between indeces [%d, %d] is: %d\n", l, r, q); - - st.printDebugInfo(); - } - - private static void sumQuerySumUpdateExample() { - // 0, 1, 2, 3, 4 - long[] v = {2, 1, 3, 4, -1}; - GenericSegmentTree2 st = - new GenericSegmentTree2(v, SegmentCombinationFn.SUM, RangeUpdateFn.ADDITION); - - int l = 1; - int r = 3; - long q = st.rangeQuery1(l, r); - if (q != 8) System.out.println("Error"); - System.out.printf("The sum between indeces [%d, %d] is: %d\n", l, r, q); - st.rangeUpdate1(1, 3, 3); - q = st.rangeQuery1(l, r); - if (q != 17) System.out.println("Error"); - System.out.printf("The sum between indeces [%d, %d] is: %d\n", l, r, st.rangeQuery1(l, r)); - } - - private static void minQueryAssignUpdateExample() { - // 0, 1, 2, 3, 4 - long[] v = {2, 1, 3, 4, -1}; - GenericSegmentTree2 st = - new GenericSegmentTree2(v, SegmentCombinationFn.MIN, RangeUpdateFn.ASSIGN); - - int l = 1; - int r = 3; - long q = st.rangeQuery1(l, r); - if (q != 1) System.out.println("Error"); - System.out.printf("The min between indeces [%d, %d] is: %d\n", l, r, q); - st.rangeUpdate1(1, 3, 3); - l = 0; - r = 1; - q = st.rangeQuery1(l, r); - if (q != 2) System.out.println("Error"); - System.out.printf("The min between indeces [%d, %d] is: %d\n", l, r, st.rangeQuery1(l, r)); - } -} diff --git a/src/main/java/com/williamfiset/algorithms/datastructures/segmenttree/GenericSegmentTree3.java b/src/main/java/com/williamfiset/algorithms/datastructures/segmenttree/GenericSegmentTree3.java deleted file mode 100644 index db529d742..000000000 --- a/src/main/java/com/williamfiset/algorithms/datastructures/segmenttree/GenericSegmentTree3.java +++ /dev/null @@ -1,500 +0,0 @@ -/** - * A generic segment tree implementation that supports several range update and aggregation - * functions. - * - *

Run with: ./gradlew run -Palgorithm=datastructures.segmenttree.GenericSegmentTree3 - * - *

Several thanks to cp-algorithms for their great article on segment trees: - * https://cp-algorithms.com/data_structures/segment_tree.html - * - *

NOTE: This file is still a WIP - * - * @author William Fiset, william.alexandre.fiset@gmail.com - */ -package com.williamfiset.algorithms.datastructures.segmenttree; - -public class GenericSegmentTree3 { - - // // The type of segment combination function to use - // public static enum SegmentCombinationFn { - // SUM, - // MIN, - // MAX - // } - - // // When updating the value of a specific index position, or a range of values, - // // modify the affected values using the following function: - // public static enum RangeUpdateFn { - // // When a range update is issued, assign all the values in the range [l, r] to be `x` - // ASSIGN, - // // When a range update is issued, add a value of `x` to all the elements in the range [l, r] - // ADDITION, - // // When a range update is issued, multiply all elements in the range [l, r] by a value of `x` - // MULTIPLICATION - // } - - // // TODO(william): Make this class static if possible to avoid sharing members with parent ST - // class - // private class SegmentNode implements PrintableNode { - // // TODO(william): investigate if we really need this, it's unlikely that we do since it - // should - // // be able to implicitly determine the index. - // int i; - - // Long value; - // Long lazy; - - // // Used only for Min/Max mul queries. Used in an attempt to resolve: - // // https://github.com/williamfiset/Algorithms/issues/208 - // Long min; - // Long max; - - // // The range of the segment [l, r] - // int l; - // int r; - - // // The two child segments of this segment (null otherwise). - // // Left segment is [l, m) and right segment is [m, r] where m = (l+r)/2 - // SegmentNode left; - // SegmentNode right; - - // public SegmentNode(int i, Long value, Long min, Long max, int l, int r) { - // this.i = i; - // this.value = value; - // this.min = min; - // this.max = max; - // this.l = l; - // this.r = r; - // } - - // public Long rangeQuery1(int ll, int rr) { - // // Different segment tree types have different base cases - // if (ll > rr) { - // return null; - // } - // propagate1(); - // if (exactOverlap(l, r)) { - // return value; - // } - // int m = (l + r) / 2; - // // Instead of checking if [ll, m] overlaps [l, r] and [m+1, rr] overlaps - // // [l, r], simply recurse on both segments and let the base case return the - // // default value for invalid intervals. - // return combinationFn.apply( - // rangeQuery1(left, l, Math.min(m, rr)), - // rangeQuery1(right, Math.max(ll, m + 1), rr)); - // } - - // // Apply the delta value to the current node and push it to the child segments - // public void propagate1() { - // if (lazy != null) { - // // Only used for min/max mul queries - // min = min * lazy; - // max = max * lazy; - - // // Apply the delta to the current segment. - // value = ruf.apply(node, lazy); - // // Push the delta to left/right segments for non-leaf nodes - // propagateLazy1(lazy); - // lazy = null; - // } - // } - - // public void propagateLazy1(long delta) { - // // Ignore leaf segments since they don't have children. - // if (isLeaf()) { - // return; - // } - // left.lazy = lruf.apply(left, delta); - // right.lazy = lruf.apply(right, delta); - // } - - // public boolean exactOverlap(int ll, int rr) { - // return l == ll && r = rr; - // } - - // public boolean isLeaf() { - // return l == r; - // } - - // @Override - // public PrintableNode getLeft() { - // return left; - // } - - // @Override - // public PrintableNode getRight() { - // return right; - // } - - // @Override - // public String getText() { - // return value.toString(); - // } - - // @Override - // public String toString() { - // return String.format("[%d, %d], value = %d, lazy = %d", tl, tr, value, lazy); - // } - // } - - // // The number of elements in the original input values array. - // private int n; - - // private SegmentNode root; - - // // The chosen range combination function - // private BinaryOperator combinationFn; - - // private interface Ruf { - // Long apply(SegmentNode segment, Long delta); - // } - - // // The Range Update Function (RUF) that chooses how a lazy delta value is - // // applied to a segment. - // private Ruf ruf; - - // // The Lazy Range Update Function (LRUF) associated with the RUF. How you - // // propagate the lazy delta values is sometimes different than how you apply - // // them to the current segment (but most of the time the RUF = LRUF). - // private Ruf lruf; - - // private long safeSum(Long a, Long b) { - // if (a == null) a = 0L; - // if (b == null) b = 0L; - // return a + b; - // } - - // private Long safeMul(Long a, Long b) { - // if (a == null) a = 1L; - // if (b == null) b = 1L; - // return a * b; - // } - - // private Long safeMin(Long a, Long b) { - // if (a == null) return b; - // if (b == null) return a; - // return Math.min(a, b); - // } - - // private Long safeMax(Long a, Long b) { - // if (a == null) return b; - // if (b == null) return a; - // return Math.max(a, b); - // } - - // private BinaryOperator sumCombinationFn = (a, b) -> safeSum(a, b); - // private BinaryOperator minCombinationFn = (a, b) -> safeMin(a, b); - // private BinaryOperator maxCombinationFn = (a, b) -> safeMax(a, b); - - // // TODO(william): Document the justification for each function below - - // // Range update functions - // private Ruf minQuerySumUpdate = (s, x) -> safeSum(s.value, x); - // private Ruf lminQuerySumUpdate = (s, x) -> safeSum(s.lazy, x); - - // // // TODO(issue/208): support this multiplication update - // private Ruf minQueryMulUpdate = - // (s, x) -> { - // if (x == 0) { - // return 0L; - // } else if (x < 0) { - // // s.min was already calculated - // if (safeMul(s.value, x) == s.min) { - // return s.max; - // } else { - // return s.min; - // } - // } else { - // return safeMul(s.value, x); - // } - // }; - // private Ruf lminQueryMulUpdate = (s, x) -> safeMul(s.lazy, x); - - // private Ruf minQueryAssignUpdate = (s, x) -> x; - // private Ruf lminQueryAssignUpdate = (s, x) -> x; - - // private Ruf maxQuerySumUpdate = (s, x) -> safeSum(s.value, x); - // private Ruf lmaxQuerySumUpdate = (s, x) -> safeSum(s.lazy, x); - - // // TODO(issue/208): support this multiplication update - // private Ruf maxQueryMulUpdate = - // (s, x) -> { - // if (x == 0) { - // return 0L; - // } else if (x < 0) { - // if (safeMul(s.value, x) == s.min) { - // return s.max; - // } else { - // return s.min; - // } - // } else { - // return safeMul(s.value, x); - // } - // }; - // private Ruf lmaxQueryMulUpdate = (s, x) -> safeMul(s.lazy, x); - - // private Ruf maxQueryAssignUpdate = (s, x) -> x; - // private Ruf lmaxQueryAssignUpdate = (s, x) -> x; - - // private Ruf sumQuerySumUpdate = (s, x) -> s.value + (s.tr - s.tl + 1) * x; - // private Ruf lsumQuerySumUpdate = (s, x) -> safeSum(s.lazy, x); - - // private Ruf sumQueryMulUpdate = (s, x) -> safeMul(s.value, x); - // private Ruf lsumQueryMulUpdate = (s, x) -> safeMul(s.lazy, x); - - // private Ruf sumQueryAssignUpdate = (s, x) -> (s.tr - s.tl + 1) * x; - // private Ruf lsumQueryAssignUpdate = (s, x) -> x; - - // public GenericSegmentTree3( - // long[] values, - // SegmentCombinationFn segmentCombinationFunction, - // RangeUpdateFn rangeUpdateFunction) { - // if (values == null) { - // throw new IllegalArgumentException("Segment tree values cannot be null."); - // } - // if (segmentCombinationFunction == null) { - // throw new IllegalArgumentException("Please specify a valid segment combination function."); - // } - // if (rangeUpdateFunction == null) { - // throw new IllegalArgumentException("Please specify a valid range update function."); - // } - // n = values.length; - - // // The size of the segment tree `t` - // // - // // TODO(william): Investigate to reduce this space. There are only 2n-1 segments, so we - // should - // // be able to reduce the space, but may need to reorganize the tree/queries. One idea is to - // use - // // the Eulerian tour structure of the tree to densely pack the segments. - // int N = 4 * n; - - // // Select the specified combination function - // if (segmentCombinationFunction == SegmentCombinationFn.SUM) { - // combinationFn = sumCombinationFn; - // if (rangeUpdateFunction == RangeUpdateFn.ADDITION) { - // ruf = sumQuerySumUpdate; - // lruf = lsumQuerySumUpdate; - // } else if (rangeUpdateFunction == RangeUpdateFn.ASSIGN) { - // ruf = sumQueryAssignUpdate; - // lruf = lsumQueryAssignUpdate; - // } else if (rangeUpdateFunction == RangeUpdateFn.MULTIPLICATION) { - // ruf = sumQueryMulUpdate; - // lruf = lsumQueryMulUpdate; - // } - // } else if (segmentCombinationFunction == SegmentCombinationFn.MIN) { - // combinationFn = minCombinationFn; - // if (rangeUpdateFunction == RangeUpdateFn.ADDITION) { - // ruf = minQuerySumUpdate; - // lruf = lminQuerySumUpdate; - // } else if (rangeUpdateFunction == RangeUpdateFn.ASSIGN) { - // ruf = minQueryAssignUpdate; - // lruf = lminQueryAssignUpdate; - // } else if (rangeUpdateFunction == RangeUpdateFn.MULTIPLICATION) { - // ruf = minQueryMulUpdate; - // lruf = lminQueryMulUpdate; - // } - // } else if (segmentCombinationFunction == SegmentCombinationFn.MAX) { - // combinationFn = maxCombinationFn; - // if (rangeUpdateFunction == RangeUpdateFn.ADDITION) { - // ruf = maxQuerySumUpdate; - // lruf = lmaxQuerySumUpdate; - // } else if (rangeUpdateFunction == RangeUpdateFn.ASSIGN) { - // ruf = maxQueryAssignUpdate; - // lruf = lmaxQueryAssignUpdate; - // } else if (rangeUpdateFunction == RangeUpdateFn.MULTIPLICATION) { - // ruf = maxQueryMulUpdate; - // lruf = lmaxQueryMulUpdate; - // } - // } else { - // throw new UnsupportedOperationException( - // "Combination function not supported: " + segmentCombinationFunction); - // } - - // root = buildSegmentTree(0, 0, n - 1, values); - // } - - // /** - // * Builds a segment tree by starting with the leaf nodes and combining segment values on - // callback. - // * - // * @param i the index of the segment in the segment tree - // * @param l the left index (inclusive) of the segment range - // * @param r the right index (inclusive) of the segment range - // * @param values the initial values array - // */ - // private SegmentNode buildSegmentTree(int i, int l, int r, long[] values) { - // if (l == r) { - // return new SegmentNode(i, values[l], values[l], values[l], l, r); - // } - // int tm = (l + r) / 2; - // SegmentNode left = buildSegmentTree(2 * i + 1, l, tm, values); - // SegmentNode right = buildSegmentTree(2 * i + 2, tm + 1, r, values); - - // Long segmentValue = combinationFn.apply(left.value, right.value); - // Long minValue = Math.min(left.min, right.min); - // Long maxValue = Math.max(left.max, right.max); - - // // TODO(william): move assigning children to the constructor? - // Segment segment = new Segment(i, segmentValue, minValue, maxValue, l, r); - // segment.left = left; - // segment.right = right; - - // return segment; - // } - - // /** - // * Returns the query of the range [l, r] on the original `values` array (+ any updates made to - // it) - // * - // * @param l the left endpoint of the range query (inclusive) - // * @param r the right endpoint of the range query (inclusive) - // */ - // public Long rangeQuery1(int l, int r) { - // return root.rangeQuery1(l, r); - // } - - // public void rangeUpdate1(int l, int r, long x) { - // rangeUpdate1(node, l, r, x); - // } - - // private void rangeUpdate1(SegmentNode node, int l, int r, long x) { - // node.propagate1(); - // if (l > r) { - // return; - // } - - // if (node.exactOverlap(l, r)) { - // // Only used for min/max mul queries - // node.min = node.min * x; - // node.max = node.max * x; - - // node.value = ruf.apply(node, x); - // node.propagateLazy1(x); - // } else { - // int m = (l + r) / 2; - // // Instead of checking if [tl, m] overlaps [l, r] and [m+1, tr] overlaps - // // [l, r], simply recurse on both segments and let the base case disregard - // // invalid intervals. - // rangeUpdate1(node.left, l, Math.min(m, r), x); - // rangeUpdate1(node.right, Math.max(l, m + 1), r, x); - - // node.value = combinationFn.apply(node.left.value, node.right.value); - // node.max = Math.max(node.left.max, node.right.max); - // node.min = Math.min(node.left.min, node.right.min); - // } - // } - - // // // Updates the value at index `i` in the original `values` array to be `newValue`. - // // public void pointUpdate(int i, long newValue) { - // // pointUpdate(0, i, 0, n - 1, newValue); - // // } - - // // /** - // // * Update a point value to a new value and update all affected segments, O(log(n)) - // // * - // // *

Do this by performing a binary search to find the interval containing the point, then - // // update - // // * the leaf segment with the new value, and re-compute all affected segment values on the - // // * callback. - // // * - // // * @param i the index of the current segment in the tree - // // * @param pos the target position to update - // // * @param tl the left segment endpoint (inclusive) - // // * @param tr the right segment endpoint (inclusive) - // // * @param newValue the new value to update - // // */ - // // private void pointUpdate(int i, int pos, int tl, int tr, long newValue) { - // // if (tl == tr) { // `tl == pos && tr == pos` might be clearer - // // t[i] = newValue; - // // return; - // // } - // // int tm = (tl + tr) / 2; - // // if (pos <= tm) { - // // // The point index `pos` is contained within the left segment [tl, tm] - // // pointUpdate(2 * i + 1, pos, tl, tm, newValue); - // // } else { - // // // The point index `pos` is contained within the right segment [tm+1, tr] - // // pointUpdate(2 * i + 2, pos, tm + 1, tr, newValue); - // // } - // // // Re-compute the segment value of the current segment on the callback - // // // t[i] = rangeUpdateFn.apply(t[2 * i + 1], t[2 * i + 2]); - // // t[i] = combinationFn.apply(t[2 * i + 1], t[2 * i + 2]); - // // } - - // public void printDebugInfo() { - // printDebugInfo(0); - // System.out.println(); - // } - - // private void printDebugInfo(int i) { - // // System.out.println(st[i]); - // // if (st[i].tl == st[i].tr) { - // // return; - // // } - // // printDebugInfo(2 * i + 1); - // // printDebugInfo(2 * i + 2); - // } - - // //////////////////////////////////////////////////// - // // Example usage: // - // //////////////////////////////////////////////////// - - // public static void main(String[] args) { - // minQuerySumUpdate(); - // sumQuerySumUpdateExample(); - // minQueryAssignUpdateExample(); - // } - - // private static void minQuerySumUpdate() { - // // 0, 1, 2, 3, 4 - // long[] v = {2, 1, 3, 4, -1}; - // GenericSegmentTree3 st = - // new GenericSegmentTree3(v, SegmentCombinationFn.MIN, RangeUpdateFn.ADDITION); - - // int l = 1; - // int r = 3; - // long q = st.rangeQuery1(l, r); - // if (q != 1) System.out.println("Error"); - // System.out.printf("The min between indeces [%d, %d] is: %d\n", l, r, q); - - // st.printDebugInfo(); - // } - - // private static void sumQuerySumUpdateExample() { - // // 0, 1, 2, 3, 4 - // long[] v = {2, 1, 3, 4, -1}; - // GenericSegmentTree3 st = - // new GenericSegmentTree3(v, SegmentCombinationFn.SUM, RangeUpdateFn.ADDITION); - - // int l = 1; - // int r = 3; - // long q = st.rangeQuery1(l, r); - // if (q != 8) System.out.println("Error"); - // System.out.printf("The sum between indeces [%d, %d] is: %d\n", l, r, q); - // st.rangeUpdate1(1, 3, 3); - // q = st.rangeQuery1(l, r); - // if (q != 17) System.out.println("Error"); - // System.out.printf("The sum between indeces [%d, %d] is: %d\n", l, r, st.rangeQuery1(l, r)); - // } - - // private static void minQueryAssignUpdateExample() { - // // 0, 1, 2, 3, 4 - // long[] v = {2, 1, 3, 4, -1}; - // GenericSegmentTree3 st = - // new GenericSegmentTree3(v, SegmentCombinationFn.MIN, RangeUpdateFn.ASSIGN); - - // int l = 1; - // int r = 3; - // long q = st.rangeQuery1(l, r); - // if (q != 1) System.out.println("Error"); - // System.out.printf("The min between indeces [%d, %d] is: %d\n", l, r, q); - // st.rangeUpdate1(1, 3, 3); - // l = 0; - // r = 1; - // q = st.rangeQuery1(l, r); - // if (q != 2) System.out.println("Error"); - // System.out.printf("The min between indeces [%d, %d] is: %d\n", l, r, st.rangeQuery1(l, r)); - // } -} diff --git a/src/test/java/com/williamfiset/algorithms/datastructures/segmenttree/BUILD b/src/test/java/com/williamfiset/algorithms/datastructures/segmenttree/BUILD index caee6deb3..7bafacfb9 100644 --- a/src/test/java/com/williamfiset/algorithms/datastructures/segmenttree/BUILD +++ b/src/test/java/com/williamfiset/algorithms/datastructures/segmenttree/BUILD @@ -19,16 +19,6 @@ TEST_DEPS = [ "@maven//:junit_junit", ] + JUNIT5_DEPS -# bazel test //src/test/java/com/williamfiset/algorithms/datastructures/segmenttree:GenericSegmentTree2Test -java_test( - name = "GenericSegmentTree2Test", - srcs = ["GenericSegmentTree2Test.java"], - main_class = "org.junit.platform.console.ConsoleLauncher", - use_testrunner = False, - args = ["--select-class=com.williamfiset.algorithms.datastructures.segmenttree.GenericSegmentTree2Test"], - runtime_deps = JUNIT5_RUNTIME_DEPS, - deps = TEST_DEPS, -) # bazel test //src/test/java/com/williamfiset/algorithms/datastructures/segmenttree:GenericSegmentTreeTest java_test( diff --git a/src/test/java/com/williamfiset/algorithms/datastructures/segmenttree/GenericSegmentTree2Test.java b/src/test/java/com/williamfiset/algorithms/datastructures/segmenttree/GenericSegmentTree2Test.java deleted file mode 100644 index aa36954ab..000000000 --- a/src/test/java/com/williamfiset/algorithms/datastructures/segmenttree/GenericSegmentTree2Test.java +++ /dev/null @@ -1,498 +0,0 @@ -/** - * gradle test --info --tests - * "com.williamfiset.algorithms.datastructures.segmenttree.GenericSegmentTree2Test" - */ -package com.williamfiset.algorithms.datastructures.segmenttree; - -import static com.google.common.truth.Truth.assertThat; - -import com.williamfiset.algorithms.utils.TestUtils; -import org.junit.jupiter.api.*; - -public class GenericSegmentTree2Test { - - static int ITERATIONS = 100; - static int MAX_N = 28; - - @BeforeEach - public void setup() {} - - @Test - public void testSumQuerySumUpdate_Simple() { - long[] values = {1, 2, 3, 4, 5}; - GenericSegmentTree2 st = - new GenericSegmentTree2( - values, - GenericSegmentTree2.SegmentCombinationFn.SUM, - GenericSegmentTree2.RangeUpdateFn.ADDITION); - - assertThat(st.rangeQuery1(0, 1)).isEqualTo(3); - assertThat(st.rangeQuery1(2, 2)).isEqualTo(3); - assertThat(st.rangeQuery1(0, 4)).isEqualTo(15); - } - - @Test - public void testSumQuerySumUpdate_RangeUpdate() { - // 0, 1, 2, 3, 4 - long[] ar = {1, 2, 1, 2, 1}; - GenericSegmentTree2 st = - new GenericSegmentTree2( - ar, - GenericSegmentTree2.SegmentCombinationFn.SUM, - GenericSegmentTree2.RangeUpdateFn.ADDITION); - - // Do multiple range updates - st.rangeUpdate1(0, 1, 5); - st.rangeUpdate1(3, 4, 2); - st.rangeUpdate1(0, 4, 3); - - // Point queries - assertThat(st.rangeQuery1(0, 0)).isEqualTo(1 + 3 + 5); - assertThat(st.rangeQuery1(1, 1)).isEqualTo(2 + 3 + 5); - assertThat(st.rangeQuery1(2, 2)).isEqualTo(1 + 3); - assertThat(st.rangeQuery1(3, 3)).isEqualTo(2 + 3 + 2); - assertThat(st.rangeQuery1(4, 4)).isEqualTo(2 + 3 + 1); - - // Range queries - assertThat(st.rangeQuery1(0, 1)).isEqualTo(2 * 5 + 2 * 3 + 1 + 2); - assertThat(st.rangeQuery1(0, 2)).isEqualTo(2 * 5 + 3 * 3 + 1 + 2 + 1); - assertThat(st.rangeQuery1(3, 4)).isEqualTo(2 * 2 + 2 * 3 + 2 + 1); - assertThat(st.rangeQuery1(0, 4)).isEqualTo(2 * 5 + 2 * 2 + 3 * 5 + 1 + 1 + 1 + 2 + 2); - } - - @Test - public void testSumQueryAssignUpdate_simple() { - long[] ar = {2, 1, 3, 4, -1}; - GenericSegmentTree2 st = - new GenericSegmentTree2( - ar, - GenericSegmentTree2.SegmentCombinationFn.SUM, - GenericSegmentTree2.RangeUpdateFn.ASSIGN); - - st.rangeUpdate1(3, 4, 2); - - assertThat(st.rangeQuery1(0, 4)).isEqualTo(10); - assertThat(st.rangeQuery1(3, 4)).isEqualTo(4); - assertThat(st.rangeQuery1(3, 3)).isEqualTo(2); - assertThat(st.rangeQuery1(4, 4)).isEqualTo(2); - - st.rangeUpdate1(1, 3, 4); - - assertThat(st.rangeQuery1(0, 4)).isEqualTo(16); - assertThat(st.rangeQuery1(0, 1)).isEqualTo(6); - assertThat(st.rangeQuery1(3, 4)).isEqualTo(6); - assertThat(st.rangeQuery1(1, 1)).isEqualTo(4); - assertThat(st.rangeQuery1(2, 2)).isEqualTo(4); - assertThat(st.rangeQuery1(3, 3)).isEqualTo(4); - assertThat(st.rangeQuery1(1, 3)).isEqualTo(12); - assertThat(st.rangeQuery1(2, 3)).isEqualTo(8); - assertThat(st.rangeQuery1(1, 2)).isEqualTo(8); - - st.rangeUpdate1(2, 2, 5); - - assertThat(st.rangeQuery1(0, 4)).isEqualTo(17); - assertThat(st.rangeQuery1(0, 2)).isEqualTo(11); - assertThat(st.rangeQuery1(2, 4)).isEqualTo(11); - assertThat(st.rangeQuery1(1, 3)).isEqualTo(13); - assertThat(st.rangeQuery1(2, 2)).isEqualTo(5); - } - - @Test - public void testSumQueryMulUpdate_simple() { - long[] ar = {1, 4, 5, 3, 2}; - GenericSegmentTree2 st = - new GenericSegmentTree2( - ar, - GenericSegmentTree2.SegmentCombinationFn.SUM, - GenericSegmentTree2.RangeUpdateFn.MULTIPLICATION); - - st.rangeUpdate1(1, 3, 3); - - assertThat(st.rangeQuery1(1, 3)).isEqualTo(4 * 3 + 5 * 3 + 3 * 3); - assertThat(st.rangeQuery1(0, 4)).isEqualTo(1 + 4 * 3 + 5 * 3 + 3 * 3 + 2); - assertThat(st.rangeQuery1(0, 2)).isEqualTo(1 + 4 * 3 + 5 * 3); - assertThat(st.rangeQuery1(2, 4)).isEqualTo(5 * 3 + 3 * 3 + 2); - - st.rangeUpdate1(1, 3, 2); - assertThat(st.rangeQuery1(1, 3)).isEqualTo(4 * 3 * 2 + 5 * 3 * 2 + 3 * 3 * 2); - } - - @Test - public void minQuerySumUpdates_simple() { - long[] ar = {2, 1, 3, 4, -1}; - GenericSegmentTree2 st = - new GenericSegmentTree2( - ar, - GenericSegmentTree2.SegmentCombinationFn.MIN, - GenericSegmentTree2.RangeUpdateFn.ADDITION); - - st.rangeUpdate1(0, 4, 1); - - assertThat(st.rangeQuery1(0, 4)).isEqualTo(0); - assertThat(st.rangeQuery1(1, 3)).isEqualTo(2); - assertThat(st.rangeQuery1(2, 4)).isEqualTo(0); - assertThat(st.rangeQuery1(3, 3)).isEqualTo(5); - - st.rangeUpdate1(3, 4, 4); - - assertThat(st.rangeQuery1(0, 4)).isEqualTo(2); - assertThat(st.rangeQuery1(0, 1)).isEqualTo(2); - assertThat(st.rangeQuery1(3, 4)).isEqualTo(4); - assertThat(st.rangeQuery1(1, 1)).isEqualTo(2); - assertThat(st.rangeQuery1(2, 2)).isEqualTo(4); - assertThat(st.rangeQuery1(3, 3)).isEqualTo(9); - assertThat(st.rangeQuery1(1, 3)).isEqualTo(2); - assertThat(st.rangeQuery1(2, 3)).isEqualTo(4); - assertThat(st.rangeQuery1(1, 2)).isEqualTo(2); - - st.rangeUpdate1(1, 3, 3); - - assertThat(st.rangeQuery1(0, 4)).isEqualTo(3); - assertThat(st.rangeQuery1(0, 2)).isEqualTo(3); - assertThat(st.rangeQuery1(2, 4)).isEqualTo(4); - assertThat(st.rangeQuery1(1, 3)).isEqualTo(5); - assertThat(st.rangeQuery1(0, 0)).isEqualTo(3); - assertThat(st.rangeQuery1(1, 1)).isEqualTo(5); - assertThat(st.rangeQuery1(2, 2)).isEqualTo(7); - assertThat(st.rangeQuery1(3, 3)).isEqualTo(12); - assertThat(st.rangeQuery1(4, 4)).isEqualTo(4); - } - - @Test - public void maxQuerySumUpdate_simple() { - long[] ar = {2, 1, 3, 4, -1}; - GenericSegmentTree2 st = - new GenericSegmentTree2( - ar, - GenericSegmentTree2.SegmentCombinationFn.MAX, - GenericSegmentTree2.RangeUpdateFn.ADDITION); - - // st.printDebugInfo(); - st.rangeUpdate1(0, 4, 1); - // st.printDebugInfo(); - - assertThat(st.rangeQuery1(0, 4)).isEqualTo(5); - // st.printDebugInfo(); - assertThat(st.rangeQuery1(0, 1)).isEqualTo(3); - - assertThat(st.rangeQuery1(1, 2)).isEqualTo(4); - assertThat(st.rangeQuery1(1, 3)).isEqualTo(5); - - st.rangeUpdate1(3, 4, 4); - - assertThat(st.rangeQuery1(0, 4)).isEqualTo(9); - assertThat(st.rangeQuery1(0, 1)).isEqualTo(3); - assertThat(st.rangeQuery1(3, 4)).isEqualTo(9); - assertThat(st.rangeQuery1(1, 1)).isEqualTo(2); - assertThat(st.rangeQuery1(2, 2)).isEqualTo(4); - assertThat(st.rangeQuery1(3, 3)).isEqualTo(9); - assertThat(st.rangeQuery1(1, 3)).isEqualTo(9); - assertThat(st.rangeQuery1(2, 3)).isEqualTo(9); - assertThat(st.rangeQuery1(1, 2)).isEqualTo(4); - - st.rangeUpdate1(1, 3, 3); - - assertThat(st.rangeQuery1(0, 4)).isEqualTo(12); - assertThat(st.rangeQuery1(0, 2)).isEqualTo(7); - assertThat(st.rangeQuery1(2, 4)).isEqualTo(12); - assertThat(st.rangeQuery1(1, 3)).isEqualTo(12); - assertThat(st.rangeQuery1(0, 0)).isEqualTo(3); - assertThat(st.rangeQuery1(1, 1)).isEqualTo(5); - assertThat(st.rangeQuery1(2, 2)).isEqualTo(7); - assertThat(st.rangeQuery1(3, 3)).isEqualTo(12); - assertThat(st.rangeQuery1(4, 4)).isEqualTo(4); - } - - @Test - public void maxminQueryMulUpdate_simple() { - long[] ar = {2, 1, 3, 4, -1}; - GenericSegmentTree2 st1 = - new GenericSegmentTree2( - ar, - GenericSegmentTree2.SegmentCombinationFn.MAX, - GenericSegmentTree2.RangeUpdateFn.MULTIPLICATION); - GenericSegmentTree2 st2 = - new GenericSegmentTree2( - ar, - GenericSegmentTree2.SegmentCombinationFn.MIN, - GenericSegmentTree2.RangeUpdateFn.MULTIPLICATION); - - st1.rangeUpdate1(0, 4, 1); - st2.rangeUpdate1(0, 4, 1); - - assertThat(st1.rangeQuery1(0, 4)).isEqualTo(4); - assertThat(st2.rangeQuery1(0, 4)).isEqualTo(-1); - - // TODO(issue/208): Negative numbers are a known issue - st1.rangeUpdate1(0, 4, -2); - st2.rangeUpdate1(0, 4, -2); - - assertThat(st1.rangeQuery1(0, 4)).isEqualTo(2); - assertThat(st2.rangeQuery1(0, 4)).isEqualTo(-8); - - st1.rangeUpdate1(0, 4, -1); - st2.rangeUpdate1(0, 4, -1); - - assertThat(st1.rangeQuery1(0, 4)).isEqualTo(8); - assertThat(st2.rangeQuery1(0, 4)).isEqualTo(-2); - } - - @Test - public void maxQueryMulUpdate_simple() { - long[] ar = {2, 1, 3, 4, -1}; - GenericSegmentTree2 st1 = - new GenericSegmentTree2( - ar, - GenericSegmentTree2.SegmentCombinationFn.MAX, - GenericSegmentTree2.RangeUpdateFn.MULTIPLICATION); - - // [4, 2, 6, 8, -2] - st1.rangeUpdate1(0, 4, 2); - assertThat(st1.rangeQuery1(0, 4)).isEqualTo(8); - assertThat(st1.rangeQuery1(0, 0)).isEqualTo(4); - assertThat(st1.rangeQuery1(0, 1)).isEqualTo(4); - assertThat(st1.rangeQuery1(0, 2)).isEqualTo(6); - assertThat(st1.rangeQuery1(1, 3)).isEqualTo(8); - - // [4, 2, 6, -16, 4] - st1.rangeUpdate1(3, 4, -2); - assertThat(st1.rangeQuery1(0, 4)).isEqualTo(6); - assertThat(st1.rangeQuery1(0, 0)).isEqualTo(4); - assertThat(st1.rangeQuery1(0, 1)).isEqualTo(4); - assertThat(st1.rangeQuery1(0, 2)).isEqualTo(6); - assertThat(st1.rangeQuery1(1, 3)).isEqualTo(6); - assertThat(st1.rangeQuery1(3, 4)).isEqualTo(4); - } - - @Test - public void minQueryMulUpdate_simple() { - long[] ar = {2, 1, 3, 4, -1}; - GenericSegmentTree2 st1 = - new GenericSegmentTree2( - ar, - GenericSegmentTree2.SegmentCombinationFn.MIN, - GenericSegmentTree2.RangeUpdateFn.MULTIPLICATION); - - // [4, 2, 6, 8, -2] - st1.rangeUpdate1(0, 4, 2); - assertThat(st1.rangeQuery1(0, 4)).isEqualTo(-2); - assertThat(st1.rangeQuery1(0, 0)).isEqualTo(4); - assertThat(st1.rangeQuery1(0, 1)).isEqualTo(2); - assertThat(st1.rangeQuery1(0, 2)).isEqualTo(2); - assertThat(st1.rangeQuery1(1, 3)).isEqualTo(2); - - // [4, 2, 6, -16, 4] - st1.rangeUpdate1(3, 4, -2); - assertThat(st1.rangeQuery1(0, 4)).isEqualTo(-16); - assertThat(st1.rangeQuery1(0, 0)).isEqualTo(4); - assertThat(st1.rangeQuery1(0, 1)).isEqualTo(2); - assertThat(st1.rangeQuery1(0, 2)).isEqualTo(2); - assertThat(st1.rangeQuery1(1, 3)).isEqualTo(-16); - assertThat(st1.rangeQuery1(3, 4)).isEqualTo(-16); - } - - // Test segment tree min/max with mul range updates. These tests have smaller - // values to avoid overflow - // @Test - // public void testMinMax_mul() { - // GenericSegmentTree2.SegmentCombinationFn[] combinationFns = { - // GenericSegmentTree2.SegmentCombinationFn.MIN, GenericSegmentTree2.SegmentCombinationFn.MAX - // }; - - // GenericSegmentTree2.RangeUpdateFn[] rangeUpdateFns = { - // GenericSegmentTree2.RangeUpdateFn.MULTIPLICATION - // }; - - // for (GenericSegmentTree2.SegmentCombinationFn combinationFn : combinationFns) { - // for (GenericSegmentTree2.RangeUpdateFn rangeUpdateFn : rangeUpdateFns) { - - // for (int n = 5; n < 20; n++) { - // long[] ar = TestUtils.randomLongArray(n, -5, +5); - // GenericSegmentTree2 st = - // new GenericSegmentTree2( - // ar, GenericSegmentTree2.SegmentCombinationFn.MIN, rangeUpdateFn); - // GenericSegmentTree2 st2 = - // new GenericSegmentTree2( - // ar, GenericSegmentTree2.SegmentCombinationFn.MAX, rangeUpdateFn); - // System.out.println(); - - // for (int i = 0; i < n; i++) { - // int j = TestUtils.randValue(0, n - 1); - // int k = TestUtils.randValue(0, n - 1); - // int i1 = Math.min(j, k); - // int i2 = Math.max(j, k); - - // j = TestUtils.randValue(0, n - 1); - // k = TestUtils.randValue(0, n - 1); - // int i3 = Math.min(j, k); - // int i4 = Math.max(j, k); - - // // Range update - // long randValue = TestUtils.randValue(-10, 10); - // System.out.printf("UPDATE [%d, %d] with %d\n", i3, i4, randValue); - - // if (rangeUpdateFn == GenericSegmentTree2.RangeUpdateFn.ADDITION) { - // bruteForceSumRangeUpdate(ar, i3, i4, randValue); - // } else if (rangeUpdateFn == GenericSegmentTree2.RangeUpdateFn.ASSIGN) { - // bruteForceAssignRangeUpdate(ar, i3, i4, randValue); - // } else if (rangeUpdateFn == GenericSegmentTree2.RangeUpdateFn.MULTIPLICATION) { - // bruteForceMulRangeUpdate(ar, i3, i4, randValue); - // } - - // st.rangeUpdate1(i3, i4, randValue); - // st2.rangeUpdate1(i3, i4, randValue); - - // long bf = 0; - - // if (combinationFn == GenericSegmentTree2.SegmentCombinationFn.SUM) { - // bf = bruteForceSum(ar, i1, i2); - // } else if (combinationFn == GenericSegmentTree2.SegmentCombinationFn.MIN) { - // bf = bruteForceMin(ar, i1, i2); - // } else if (combinationFn == GenericSegmentTree2.SegmentCombinationFn.MAX) { - // bf = bruteForceMax(ar, i1, i2); - // } - - // long segTreeAnswer = st.rangeQuery1(i1, i2); - // long segTreeAnswer2 = st2.rangeQuery1(i1, i2); - // System.out.printf( - // "QUERY [%d, %d] want: %d, got: %d, got2: %d\n", - // i1, i2, bf, segTreeAnswer, segTreeAnswer2); - // // System.out.printf("QUERY [%d, %d] want: %d, got: %d\n", i1, i2, bf, - // segTreeAnswer2); - // if (bf != segTreeAnswer) { - // System.out.printf( - // "(%s query, %s range update) | [%d, %d], want = %d, got = %d, got2 = %d\n", - // combinationFn, rangeUpdateFn, i1, i2, bf, segTreeAnswer, segTreeAnswer2); - // } - // assertThat(bf).isEqualTo(segTreeAnswer); - // } - // } - // } - // } - // } - - @Test - public void testAllFunctionCombinations() { - GenericSegmentTree2.SegmentCombinationFn[] combinationFns = { - GenericSegmentTree2.SegmentCombinationFn.SUM, - GenericSegmentTree2.SegmentCombinationFn.MIN, - GenericSegmentTree2.SegmentCombinationFn.MAX, - }; - - GenericSegmentTree2.RangeUpdateFn[] rangeUpdateFns = { - GenericSegmentTree2.RangeUpdateFn.ADDITION, - GenericSegmentTree2.RangeUpdateFn.ASSIGN, - GenericSegmentTree2.RangeUpdateFn.MULTIPLICATION - }; - - for (GenericSegmentTree2.SegmentCombinationFn combinationFn : combinationFns) { - for (GenericSegmentTree2.RangeUpdateFn rangeUpdateFn : rangeUpdateFns) { - - // TODO(issue/208): The multiplication range update function seems to be suffering - // from overflow issues and not being able to handle negative numbers. - // - // One idea might be to also track the min value for the max query and vice versa - // and swap values when a negative number is found? - if (rangeUpdateFn == GenericSegmentTree2.RangeUpdateFn.MULTIPLICATION - && (combinationFn == GenericSegmentTree2.SegmentCombinationFn.MIN - || combinationFn == GenericSegmentTree2.SegmentCombinationFn.MAX)) { - continue; - } - - for (int n = 5; n < ITERATIONS; n++) { - long[] ar = TestUtils.randomLongArray(n, -100, +100); - GenericSegmentTree2 st = new GenericSegmentTree2(ar, combinationFn, rangeUpdateFn); - - for (int i = 0; i < n; i++) { - int j = TestUtils.randValue(0, n - 1); - int k = TestUtils.randValue(0, n - 1); - int i1 = Math.min(j, k); - int i2 = Math.max(j, k); - - j = TestUtils.randValue(0, n - 1); - k = TestUtils.randValue(0, n - 1); - int i3 = Math.min(j, k); - int i4 = Math.max(j, k); - - // Range update - long randValue = TestUtils.randValue(-10, 10); - // System.out.printf("UPDATE [%d, %d] with %d\n", i3, i4, randValue); - - if (rangeUpdateFn == GenericSegmentTree2.RangeUpdateFn.ADDITION) { - bruteForceSumRangeUpdate(ar, i3, i4, randValue); - } else if (rangeUpdateFn == GenericSegmentTree2.RangeUpdateFn.ASSIGN) { - bruteForceAssignRangeUpdate(ar, i3, i4, randValue); - } else if (rangeUpdateFn == GenericSegmentTree2.RangeUpdateFn.MULTIPLICATION) { - bruteForceMulRangeUpdate(ar, i3, i4, randValue); - } - - st.rangeUpdate1(i3, i4, randValue); - - long bf = 0; - - if (combinationFn == GenericSegmentTree2.SegmentCombinationFn.SUM) { - bf = bruteForceSum(ar, i1, i2); - } else if (combinationFn == GenericSegmentTree2.SegmentCombinationFn.MIN) { - bf = bruteForceMin(ar, i1, i2); - } else if (combinationFn == GenericSegmentTree2.SegmentCombinationFn.MAX) { - bf = bruteForceMax(ar, i1, i2); - } - - long segTreeAnswer = st.rangeQuery1(i1, i2); - if (bf != segTreeAnswer) { - System.out.printf( - "(%s query, %s range update) | [%d, %d], want = %d, got = %d\n", - combinationFn, rangeUpdateFn, i1, i2, bf, segTreeAnswer); - } - assertThat(bf).isEqualTo(segTreeAnswer); - } - } - } - } - } - - // Finds the sum in an array between [l, r] in the `values` array - private static long bruteForceSum(long[] values, int l, int r) { - long s = 0; - for (int i = l; i <= r; i++) { - s += values[i]; - } - return s; - } - - // Finds the min value in an array between [l, r] in the `values` array - private static long bruteForceMin(long[] values, int l, int r) { - long m = values[l]; - for (int i = l; i <= r; i++) { - m = Math.min(m, values[i]); - } - return m; - } - - // Finds the max value in an array between [l, r] in the `values` array - private static long bruteForceMax(long[] values, int l, int r) { - long m = values[l]; - for (int i = l; i <= r; i++) { - m = Math.max(m, values[i]); - } - return m; - } - - private static void bruteForceSumRangeUpdate(long[] values, int l, int r, long x) { - for (int i = l; i <= r; i++) { - values[i] += x; - } - } - - private static void bruteForceMulRangeUpdate(long[] values, int l, int r, long x) { - for (int i = l; i <= r; i++) { - values[i] *= x; - } - } - - private static void bruteForceAssignRangeUpdate(long[] values, int l, int r, long x) { - for (int i = l; i <= r; i++) { - values[i] = x; - } - } -} diff --git a/src/test/java/com/williamfiset/algorithms/datastructures/segmenttree/GenericSegmentTreeTest.java b/src/test/java/com/williamfiset/algorithms/datastructures/segmenttree/GenericSegmentTreeTest.java index 38d0b4bfe..a84add8a0 100644 --- a/src/test/java/com/williamfiset/algorithms/datastructures/segmenttree/GenericSegmentTreeTest.java +++ b/src/test/java/com/williamfiset/algorithms/datastructures/segmenttree/GenericSegmentTreeTest.java @@ -5,6 +5,7 @@ package com.williamfiset.algorithms.datastructures.segmenttree; import static com.google.common.truth.Truth.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; import com.williamfiset.algorithms.utils.TestUtils; import org.junit.jupiter.api.Test; @@ -14,209 +15,881 @@ public class GenericSegmentTreeTest { static int ITERATIONS = 250; static int MAX_N = 17; - // @Before - // public void setup() {} - - // @Test - // public void testSumQuerySumUpdate_Simple() { - // long[] values = {1, 2, 3, 4, 5}; - // GenericSegmentTree st = - // new GenericSegmentTree( - // values, - // GenericSegmentTree.SegmentCombinationFn.SUM, - // GenericSegmentTree.RangeUpdateFn.ADDITION); - - // assertThat(st.rangeQuery1(0, 1)).isEqualTo(3); - // assertThat(st.rangeQuery1(2, 2)).isEqualTo(3); - // assertThat(st.rangeQuery1(0, 4)).isEqualTo(15); - // } - - // @Test - // public void testSumQuerySumUpdate_RangeUpdate() { - // // 0, 1, 2, 3, 4 - // long[] ar = {1, 2, 1, 2, 1}; - // GenericSegmentTree st = - // new GenericSegmentTree( - // ar, - // GenericSegmentTree.SegmentCombinationFn.SUM, - // GenericSegmentTree.RangeUpdateFn.ADDITION); - - // // Do multiple range updates - // st.rangeUpdate1(0, 1, 5); - // st.rangeUpdate1(3, 4, 2); - // st.rangeUpdate1(0, 4, 3); - - // // Point queries - // assertThat(st.rangeQuery1(0, 0)).isEqualTo(1 + 3 + 5); - // assertThat(st.rangeQuery1(1, 1)).isEqualTo(2 + 3 + 5); - // assertThat(st.rangeQuery1(2, 2)).isEqualTo(1 + 3); - // assertThat(st.rangeQuery1(3, 3)).isEqualTo(2 + 3 + 2); - // assertThat(st.rangeQuery1(4, 4)).isEqualTo(2 + 3 + 1); - - // // Range queries - // assertThat(st.rangeQuery1(0, 1)).isEqualTo(2 * 5 + 2 * 3 + 1 + 2); - // assertThat(st.rangeQuery1(0, 2)).isEqualTo(2 * 5 + 3 * 3 + 1 + 2 + 1); - // assertThat(st.rangeQuery1(3, 4)).isEqualTo(2 * 2 + 2 * 3 + 2 + 1); - // assertThat(st.rangeQuery1(0, 4)).isEqualTo(2 * 5 + 2 * 2 + 3 * 5 + 1 + 1 + 1 + 2 + 2); - // } - - // @Test - // public void testSumQueryAssignUpdate_simple() { - // long[] ar = {2, 1, 3, 4, -1}; - // GenericSegmentTree st = - // new GenericSegmentTree( - // ar, - // GenericSegmentTree.SegmentCombinationFn.SUM, - // GenericSegmentTree.RangeUpdateFn.ASSIGN); - - // st.rangeUpdate1(3, 4, 2); - - // assertThat(st.rangeQuery1(0, 4)).isEqualTo(10); - // assertThat(st.rangeQuery1(3, 4)).isEqualTo(4); - // assertThat(st.rangeQuery1(3, 3)).isEqualTo(2); - // assertThat(st.rangeQuery1(4, 4)).isEqualTo(2); - - // st.rangeUpdate1(1, 3, 4); - - // assertThat(st.rangeQuery1(0, 4)).isEqualTo(16); - // assertThat(st.rangeQuery1(0, 1)).isEqualTo(6); - // assertThat(st.rangeQuery1(3, 4)).isEqualTo(6); - // assertThat(st.rangeQuery1(1, 1)).isEqualTo(4); - // assertThat(st.rangeQuery1(2, 2)).isEqualTo(4); - // assertThat(st.rangeQuery1(3, 3)).isEqualTo(4); - // assertThat(st.rangeQuery1(1, 3)).isEqualTo(12); - // assertThat(st.rangeQuery1(2, 3)).isEqualTo(8); - // assertThat(st.rangeQuery1(1, 2)).isEqualTo(8); - - // st.rangeUpdate1(2, 2, 5); - - // assertThat(st.rangeQuery1(0, 4)).isEqualTo(17); - // assertThat(st.rangeQuery1(0, 2)).isEqualTo(11); - // assertThat(st.rangeQuery1(2, 4)).isEqualTo(11); - // assertThat(st.rangeQuery1(1, 3)).isEqualTo(13); - // assertThat(st.rangeQuery1(2, 2)).isEqualTo(5); - // } - - // @Test - // public void testSumQueryMulUpdate_simple() { - // long[] ar = {1, 4, 5, 3, 2}; - // GenericSegmentTree st = - // new GenericSegmentTree( - // ar, - // GenericSegmentTree.SegmentCombinationFn.SUM, - // GenericSegmentTree.RangeUpdateFn.MULTIPLICATION); - - // st.rangeUpdate1(1, 3, 3); - - // assertThat(st.rangeQuery1(1, 3)).isEqualTo(4 * 3 + 5 * 3 + 3 * 3); - // assertThat(st.rangeQuery1(0, 4)).isEqualTo(1 + 4 * 3 + 5 * 3 + 3 * 3 + 2); - // assertThat(st.rangeQuery1(0, 2)).isEqualTo(1 + 4 * 3 + 5 * 3); - // assertThat(st.rangeQuery1(2, 4)).isEqualTo(5 * 3 + 3 * 3 + 2); - - // st.rangeUpdate1(1, 3, 2); - // assertThat(st.rangeQuery1(1, 3)).isEqualTo(4 * 3 * 2 + 5 * 3 * 2 + 3 * 3 * 2); - // } - - // @Test - // public void minQuerySumUpdates_simple() { - // long[] ar = {2, 1, 3, 4, -1}; - // GenericSegmentTree st = - // new GenericSegmentTree( - // ar, - // GenericSegmentTree.SegmentCombinationFn.MIN, - // GenericSegmentTree.RangeUpdateFn.ADDITION); - - // st.rangeUpdate1(0, 4, 1); - - // assertThat(st.rangeQuery1(0, 4)).isEqualTo(0); - // assertThat(st.rangeQuery1(1, 3)).isEqualTo(2); - // assertThat(st.rangeQuery1(2, 4)).isEqualTo(0); - // assertThat(st.rangeQuery1(3, 3)).isEqualTo(5); - - // st.rangeUpdate1(3, 4, 4); - - // assertThat(st.rangeQuery1(0, 4)).isEqualTo(2); - // assertThat(st.rangeQuery1(0, 1)).isEqualTo(2); - // assertThat(st.rangeQuery1(3, 4)).isEqualTo(4); - // assertThat(st.rangeQuery1(1, 1)).isEqualTo(2); - // assertThat(st.rangeQuery1(2, 2)).isEqualTo(4); - // assertThat(st.rangeQuery1(3, 3)).isEqualTo(9); - // assertThat(st.rangeQuery1(1, 3)).isEqualTo(2); - // assertThat(st.rangeQuery1(2, 3)).isEqualTo(4); - // assertThat(st.rangeQuery1(1, 2)).isEqualTo(2); - - // st.rangeUpdate1(1, 3, 3); - - // assertThat(st.rangeQuery1(0, 4)).isEqualTo(3); - // assertThat(st.rangeQuery1(0, 2)).isEqualTo(3); - // assertThat(st.rangeQuery1(2, 4)).isEqualTo(4); - // assertThat(st.rangeQuery1(1, 3)).isEqualTo(5); - // assertThat(st.rangeQuery1(0, 0)).isEqualTo(3); - // assertThat(st.rangeQuery1(1, 1)).isEqualTo(5); - // assertThat(st.rangeQuery1(2, 2)).isEqualTo(7); - // assertThat(st.rangeQuery1(3, 3)).isEqualTo(12); - // assertThat(st.rangeQuery1(4, 4)).isEqualTo(4); - // } - - // @Test - // public void maxQuerySumUpdate_simple() { - // long[] ar = {2, 1, 3, 4, -1}; - // GenericSegmentTree st = - // new GenericSegmentTree( - // ar, - // GenericSegmentTree.SegmentCombinationFn.MAX, - // GenericSegmentTree.RangeUpdateFn.ADDITION); - - // st.printDebugInfo(); - // st.rangeUpdate1(0, 4, 1); - // st.printDebugInfo(); - - // assertThat(st.rangeQuery1(0, 4)).isEqualTo(5); - // assertThat(st.rangeQuery1(0, 1)).isEqualTo(3); - // assertThat(st.rangeQuery1(1, 2)).isEqualTo(4); - // assertThat(st.rangeQuery1(1, 3)).isEqualTo(5); - - // st.rangeUpdate1(3, 4, 4); - - // assertThat(st.rangeQuery1(0, 4)).isEqualTo(9); - // assertThat(st.rangeQuery1(0, 1)).isEqualTo(3); - // assertThat(st.rangeQuery1(3, 4)).isEqualTo(9); - // assertThat(st.rangeQuery1(1, 1)).isEqualTo(2); - // assertThat(st.rangeQuery1(2, 2)).isEqualTo(4); - // assertThat(st.rangeQuery1(3, 3)).isEqualTo(9); - // assertThat(st.rangeQuery1(1, 3)).isEqualTo(9); - // assertThat(st.rangeQuery1(2, 3)).isEqualTo(9); - // assertThat(st.rangeQuery1(1, 2)).isEqualTo(4); - - // st.rangeUpdate1(1, 3, 3); - - // assertThat(st.rangeQuery1(0, 4)).isEqualTo(12); - // assertThat(st.rangeQuery1(0, 2)).isEqualTo(7); - // assertThat(st.rangeQuery1(2, 4)).isEqualTo(12); - // assertThat(st.rangeQuery1(1, 3)).isEqualTo(12); - // assertThat(st.rangeQuery1(0, 0)).isEqualTo(3); - // assertThat(st.rangeQuery1(1, 1)).isEqualTo(5); - // assertThat(st.rangeQuery1(2, 2)).isEqualTo(7); - // assertThat(st.rangeQuery1(3, 3)).isEqualTo(12); - // assertThat(st.rangeQuery1(4, 4)).isEqualTo(4); - // } - - // @Test - // public void maxQueryMulUpdate_simple() { - // long[] ar = {2, 1, 3, 4, -1}; - // GenericSegmentTree st = - // new GenericSegmentTree( - // ar, - // GenericSegmentTree.SegmentCombinationFn.MAX, - // GenericSegmentTree.RangeUpdateFn.MULTIPLICATION); - - // st.rangeUpdate1(0, 4, 1); - // assertThat(st.rangeQuery1(0, 4)).isEqualTo(4); - - // // TODO(issue/208): Negative numbers are a known issue - // // st.rangeUpdate1(0, 4, -2); - // // assertThat(st.rangeQuery1(0, 4)).isEqualTo(2); // Returns -8 as max but should be 2 - // } + @Test + public void testSumQuerySumUpdate_Simple() { + long[] values = {1, 2, 3, 4, 5}; + GenericSegmentTree st = + new GenericSegmentTree( + values, + GenericSegmentTree.SegmentCombinationFn.SUM, + GenericSegmentTree.RangeUpdateFn.ADDITION); + + assertThat(st.rangeQuery(0, 1)).isEqualTo(3); + assertThat(st.rangeQuery(2, 2)).isEqualTo(3); + assertThat(st.rangeQuery(0, 4)).isEqualTo(15); + } + + @Test + public void testSumQuerySumUpdate_RangeUpdate() { + long[] ar = {1, 2, 1, 2, 1}; + GenericSegmentTree st = + new GenericSegmentTree( + ar, + GenericSegmentTree.SegmentCombinationFn.SUM, + GenericSegmentTree.RangeUpdateFn.ADDITION); + + st.rangeUpdate(0, 1, 5); + st.rangeUpdate(3, 4, 2); + st.rangeUpdate(0, 4, 3); + + assertThat(st.rangeQuery(0, 0)).isEqualTo(1 + 3 + 5); + assertThat(st.rangeQuery(1, 1)).isEqualTo(2 + 3 + 5); + assertThat(st.rangeQuery(2, 2)).isEqualTo(1 + 3); + assertThat(st.rangeQuery(3, 3)).isEqualTo(2 + 3 + 2); + assertThat(st.rangeQuery(4, 4)).isEqualTo(2 + 3 + 1); + + assertThat(st.rangeQuery(0, 1)).isEqualTo(2 * 5 + 2 * 3 + 1 + 2); + assertThat(st.rangeQuery(0, 2)).isEqualTo(2 * 5 + 3 * 3 + 1 + 2 + 1); + assertThat(st.rangeQuery(3, 4)).isEqualTo(2 * 2 + 2 * 3 + 2 + 1); + assertThat(st.rangeQuery(0, 4)).isEqualTo(2 * 5 + 2 * 2 + 3 * 5 + 1 + 1 + 1 + 2 + 2); + } + + @Test + public void testSumQueryAssignUpdate_simple() { + long[] ar = {2, 1, 3, 4, -1}; + GenericSegmentTree st = + new GenericSegmentTree( + ar, + GenericSegmentTree.SegmentCombinationFn.SUM, + GenericSegmentTree.RangeUpdateFn.ASSIGN); + + st.rangeUpdate(3, 4, 2); + + assertThat(st.rangeQuery(0, 4)).isEqualTo(10); + assertThat(st.rangeQuery(3, 4)).isEqualTo(4); + assertThat(st.rangeQuery(3, 3)).isEqualTo(2); + assertThat(st.rangeQuery(4, 4)).isEqualTo(2); + + st.rangeUpdate(1, 3, 4); + + assertThat(st.rangeQuery(0, 4)).isEqualTo(16); + assertThat(st.rangeQuery(0, 1)).isEqualTo(6); + assertThat(st.rangeQuery(3, 4)).isEqualTo(6); + assertThat(st.rangeQuery(1, 1)).isEqualTo(4); + assertThat(st.rangeQuery(2, 2)).isEqualTo(4); + assertThat(st.rangeQuery(3, 3)).isEqualTo(4); + assertThat(st.rangeQuery(1, 3)).isEqualTo(12); + assertThat(st.rangeQuery(2, 3)).isEqualTo(8); + assertThat(st.rangeQuery(1, 2)).isEqualTo(8); + + st.rangeUpdate(2, 2, 5); + + assertThat(st.rangeQuery(0, 4)).isEqualTo(17); + assertThat(st.rangeQuery(0, 2)).isEqualTo(11); + assertThat(st.rangeQuery(2, 4)).isEqualTo(11); + assertThat(st.rangeQuery(1, 3)).isEqualTo(13); + assertThat(st.rangeQuery(2, 2)).isEqualTo(5); + } + + @Test + public void testSumQueryMulUpdate_simple() { + long[] ar = {1, 4, 5, 3, 2}; + GenericSegmentTree st = + new GenericSegmentTree( + ar, + GenericSegmentTree.SegmentCombinationFn.SUM, + GenericSegmentTree.RangeUpdateFn.MULTIPLICATION); + + st.rangeUpdate(1, 3, 3); + + assertThat(st.rangeQuery(1, 3)).isEqualTo(4 * 3 + 5 * 3 + 3 * 3); + assertThat(st.rangeQuery(0, 4)).isEqualTo(1 + 4 * 3 + 5 * 3 + 3 * 3 + 2); + assertThat(st.rangeQuery(0, 2)).isEqualTo(1 + 4 * 3 + 5 * 3); + assertThat(st.rangeQuery(2, 4)).isEqualTo(5 * 3 + 3 * 3 + 2); + + st.rangeUpdate(1, 3, 2); + assertThat(st.rangeQuery(1, 3)).isEqualTo(4 * 3 * 2 + 5 * 3 * 2 + 3 * 3 * 2); + } + + @Test + public void minQuerySumUpdates_simple() { + long[] ar = {2, 1, 3, 4, -1}; + GenericSegmentTree st = + new GenericSegmentTree( + ar, + GenericSegmentTree.SegmentCombinationFn.MIN, + GenericSegmentTree.RangeUpdateFn.ADDITION); + + st.rangeUpdate(0, 4, 1); + + assertThat(st.rangeQuery(0, 4)).isEqualTo(0); + assertThat(st.rangeQuery(1, 3)).isEqualTo(2); + assertThat(st.rangeQuery(2, 4)).isEqualTo(0); + assertThat(st.rangeQuery(3, 3)).isEqualTo(5); + + st.rangeUpdate(3, 4, 4); + + assertThat(st.rangeQuery(0, 4)).isEqualTo(2); + assertThat(st.rangeQuery(0, 1)).isEqualTo(2); + assertThat(st.rangeQuery(3, 4)).isEqualTo(4); + assertThat(st.rangeQuery(1, 1)).isEqualTo(2); + assertThat(st.rangeQuery(2, 2)).isEqualTo(4); + assertThat(st.rangeQuery(3, 3)).isEqualTo(9); + assertThat(st.rangeQuery(1, 3)).isEqualTo(2); + assertThat(st.rangeQuery(2, 3)).isEqualTo(4); + assertThat(st.rangeQuery(1, 2)).isEqualTo(2); + + st.rangeUpdate(1, 3, 3); + + assertThat(st.rangeQuery(0, 4)).isEqualTo(3); + assertThat(st.rangeQuery(0, 2)).isEqualTo(3); + assertThat(st.rangeQuery(2, 4)).isEqualTo(4); + assertThat(st.rangeQuery(1, 3)).isEqualTo(5); + assertThat(st.rangeQuery(0, 0)).isEqualTo(3); + assertThat(st.rangeQuery(1, 1)).isEqualTo(5); + assertThat(st.rangeQuery(2, 2)).isEqualTo(7); + assertThat(st.rangeQuery(3, 3)).isEqualTo(12); + assertThat(st.rangeQuery(4, 4)).isEqualTo(4); + } + + @Test + public void maxQuerySumUpdate_simple() { + long[] ar = {2, 1, 3, 4, -1}; + GenericSegmentTree st = + new GenericSegmentTree( + ar, + GenericSegmentTree.SegmentCombinationFn.MAX, + GenericSegmentTree.RangeUpdateFn.ADDITION); + + st.rangeUpdate(0, 4, 1); + + assertThat(st.rangeQuery(0, 4)).isEqualTo(5); + assertThat(st.rangeQuery(0, 1)).isEqualTo(3); + assertThat(st.rangeQuery(1, 2)).isEqualTo(4); + assertThat(st.rangeQuery(1, 3)).isEqualTo(5); + + st.rangeUpdate(3, 4, 4); + + assertThat(st.rangeQuery(0, 4)).isEqualTo(9); + assertThat(st.rangeQuery(0, 1)).isEqualTo(3); + assertThat(st.rangeQuery(3, 4)).isEqualTo(9); + assertThat(st.rangeQuery(1, 1)).isEqualTo(2); + assertThat(st.rangeQuery(2, 2)).isEqualTo(4); + assertThat(st.rangeQuery(3, 3)).isEqualTo(9); + assertThat(st.rangeQuery(1, 3)).isEqualTo(9); + assertThat(st.rangeQuery(2, 3)).isEqualTo(9); + assertThat(st.rangeQuery(1, 2)).isEqualTo(4); + + st.rangeUpdate(1, 3, 3); + + assertThat(st.rangeQuery(0, 4)).isEqualTo(12); + assertThat(st.rangeQuery(0, 2)).isEqualTo(7); + assertThat(st.rangeQuery(2, 4)).isEqualTo(12); + assertThat(st.rangeQuery(1, 3)).isEqualTo(12); + assertThat(st.rangeQuery(0, 0)).isEqualTo(3); + assertThat(st.rangeQuery(1, 1)).isEqualTo(5); + assertThat(st.rangeQuery(2, 2)).isEqualTo(7); + assertThat(st.rangeQuery(3, 3)).isEqualTo(12); + assertThat(st.rangeQuery(4, 4)).isEqualTo(4); + } + + @Test + public void maxQueryMulUpdate_simple() { + long[] ar = {2, 1, 3, 4, -1}; + GenericSegmentTree st = + new GenericSegmentTree( + ar, + GenericSegmentTree.SegmentCombinationFn.MAX, + GenericSegmentTree.RangeUpdateFn.MULTIPLICATION); + + st.rangeUpdate(0, 4, 1); + assertThat(st.rangeQuery(0, 4)).isEqualTo(4); + } + + @Test + public void testGcdQueryAdditionUpdate_throwsException() { + long[] v = {12, 24, 3}; + assertThrows( + UnsupportedOperationException.class, + () -> + new GenericSegmentTree( + v, + GenericSegmentTree.SegmentCombinationFn.GCD, + GenericSegmentTree.RangeUpdateFn.ADDITION)); + } + + @Test + public void testProductQueryAdditionUpdate_throwsException() { + long[] v = {2, 3, 4}; + assertThrows( + UnsupportedOperationException.class, + () -> + new GenericSegmentTree( + v, + GenericSegmentTree.SegmentCombinationFn.PRODUCT, + GenericSegmentTree.RangeUpdateFn.ADDITION)); + } + + @Test + public void testProductQueryMulUpdate_simple() { + long[] v = {3, 2, 2, 1}; + GenericSegmentTree st = + new GenericSegmentTree( + v, + GenericSegmentTree.SegmentCombinationFn.PRODUCT, + GenericSegmentTree.RangeUpdateFn.MULTIPLICATION); + + assertThat(st.rangeQuery(0, 3)).isEqualTo(12); + + st.rangeUpdate(1, 2, 4); // {3, 8, 8, 1} -> 192 + assertThat(st.rangeQuery(0, 3)).isEqualTo(192); + + st.rangeUpdate(2, 3, 2); // {3, 8, 16, 2} -> 768 + assertThat(st.rangeQuery(0, 3)).isEqualTo(768); + } + + @Test + public void testProductQueryAssignUpdate_simple() { + long[] v = {2, 3, 1, 5}; + GenericSegmentTree st = + new GenericSegmentTree( + v, + GenericSegmentTree.SegmentCombinationFn.PRODUCT, + GenericSegmentTree.RangeUpdateFn.ASSIGN); + + assertThat(st.rangeQuery(0, 3)).isEqualTo(30); + + st.rangeUpdate(1, 2, 4); // {2, 4, 4, 5} -> 160 + assertThat(st.rangeQuery(0, 3)).isEqualTo(160); + + st.rangeUpdate(0, 3, 3); // {3, 3, 3, 3} -> 81 + assertThat(st.rangeQuery(0, 3)).isEqualTo(81); + } + + @Test + public void testGcdQueryMulUpdate_simple() { + long[] v = {12, 24, 3, 4, -1}; + GenericSegmentTree st = + new GenericSegmentTree( + v, + GenericSegmentTree.SegmentCombinationFn.GCD, + GenericSegmentTree.RangeUpdateFn.MULTIPLICATION); + + assertThat(st.rangeQuery(0, 2)).isEqualTo(3); + st.rangeUpdate(2, 2, 2); + assertThat(st.rangeQuery(0, 2)).isEqualTo(6); + assertThat(st.rangeQuery(0, 1)).isEqualTo(12); + } + + @Test + public void testGcdQueryAssignUpdate_simple() { + long[] v = {12, 24, 3, 12, 48}; + GenericSegmentTree st = + new GenericSegmentTree( + v, + GenericSegmentTree.SegmentCombinationFn.GCD, + GenericSegmentTree.RangeUpdateFn.ASSIGN); + + assertThat(st.rangeQuery(0, 2)).isEqualTo(3); + + st.rangeUpdate(2, 2, 48); // {12, 24, 48, 12, 48} + assertThat(st.rangeQuery(0, 2)).isEqualTo(12); + + st.rangeUpdate(2, 3, 24); // {12, 24, 24, 24, 48} + assertThat(st.rangeQuery(0, 4)).isEqualTo(12); + } + + // ====================================================== + // Constructor validation tests + // ====================================================== + + @Test + public void testNullValues_throwsException() { + assertThrows( + IllegalArgumentException.class, + () -> + new GenericSegmentTree( + null, + GenericSegmentTree.SegmentCombinationFn.SUM, + GenericSegmentTree.RangeUpdateFn.ADDITION)); + } + + @Test + public void testNullCombinationFn_throwsException() { + long[] v = {1, 2, 3}; + assertThrows( + IllegalArgumentException.class, + () -> new GenericSegmentTree(v, null, GenericSegmentTree.RangeUpdateFn.ADDITION)); + } + + @Test + public void testNullRangeUpdateFn_throwsException() { + long[] v = {1, 2, 3}; + assertThrows( + IllegalArgumentException.class, + () -> new GenericSegmentTree(v, GenericSegmentTree.SegmentCombinationFn.SUM, null)); + } + + // ====================================================== + // Single element array tests + // ====================================================== + + @Test + public void testSingleElement_sumAddition() { + long[] v = {42}; + GenericSegmentTree st = + new GenericSegmentTree( + v, + GenericSegmentTree.SegmentCombinationFn.SUM, + GenericSegmentTree.RangeUpdateFn.ADDITION); + + assertThat(st.rangeQuery(0, 0)).isEqualTo(42); + + st.rangeUpdate(0, 0, 8); + assertThat(st.rangeQuery(0, 0)).isEqualTo(50); + } + + @Test + public void testSingleElement_minAssign() { + long[] v = {7}; + GenericSegmentTree st = + new GenericSegmentTree( + v, + GenericSegmentTree.SegmentCombinationFn.MIN, + GenericSegmentTree.RangeUpdateFn.ASSIGN); + + assertThat(st.rangeQuery(0, 0)).isEqualTo(7); + + st.rangeUpdate(0, 0, -3); + assertThat(st.rangeQuery(0, 0)).isEqualTo(-3); + } + + @Test + public void testSingleElement_productMul() { + long[] v = {5}; + GenericSegmentTree st = + new GenericSegmentTree( + v, + GenericSegmentTree.SegmentCombinationFn.PRODUCT, + GenericSegmentTree.RangeUpdateFn.MULTIPLICATION); + + assertThat(st.rangeQuery(0, 0)).isEqualTo(5); + + st.rangeUpdate(0, 0, 3); + assertThat(st.rangeQuery(0, 0)).isEqualTo(15); + } + + // ====================================================== + // MIN + ASSIGN simple test + // ====================================================== + + @Test + public void testMinQueryAssignUpdate_simple() { + long[] ar = {2, 1, 3, 4, -1}; + GenericSegmentTree st = + new GenericSegmentTree( + ar, + GenericSegmentTree.SegmentCombinationFn.MIN, + GenericSegmentTree.RangeUpdateFn.ASSIGN); + + assertThat(st.rangeQuery(0, 4)).isEqualTo(-1); + assertThat(st.rangeQuery(0, 2)).isEqualTo(1); + assertThat(st.rangeQuery(1, 3)).isEqualTo(1); + + // Assign [1,3] = 5 -> {2, 5, 5, 5, -1} + st.rangeUpdate(1, 3, 5); + assertThat(st.rangeQuery(0, 4)).isEqualTo(-1); + assertThat(st.rangeQuery(0, 3)).isEqualTo(2); + assertThat(st.rangeQuery(1, 3)).isEqualTo(5); + assertThat(st.rangeQuery(1, 1)).isEqualTo(5); + + // Assign [0,4] = 0 -> {0, 0, 0, 0, 0} + st.rangeUpdate(0, 4, 0); + assertThat(st.rangeQuery(0, 4)).isEqualTo(0); + assertThat(st.rangeQuery(2, 3)).isEqualTo(0); + } + + // ====================================================== + // MAX + ASSIGN simple test + // ====================================================== + + @Test + public void testMaxQueryAssignUpdate_simple() { + long[] ar = {2, 1, 3, 4, -1}; + GenericSegmentTree st = + new GenericSegmentTree( + ar, + GenericSegmentTree.SegmentCombinationFn.MAX, + GenericSegmentTree.RangeUpdateFn.ASSIGN); + + assertThat(st.rangeQuery(0, 4)).isEqualTo(4); + assertThat(st.rangeQuery(0, 2)).isEqualTo(3); + + // Assign [0,2] = 10 -> {10, 10, 10, 4, -1} + st.rangeUpdate(0, 2, 10); + assertThat(st.rangeQuery(0, 4)).isEqualTo(10); + assertThat(st.rangeQuery(3, 4)).isEqualTo(4); + assertThat(st.rangeQuery(0, 0)).isEqualTo(10); + + // Assign [2,4] = -5 -> {10, 10, -5, -5, -5} + st.rangeUpdate(2, 4, -5); + assertThat(st.rangeQuery(0, 4)).isEqualTo(10); + assertThat(st.rangeQuery(2, 4)).isEqualTo(-5); + assertThat(st.rangeQuery(0, 1)).isEqualTo(10); + } + + // ====================================================== + // MIN + MULTIPLICATION simple test + // ====================================================== + + @Test + public void testMinQueryMulUpdate_simple() { + long[] ar = {2, 1, 3, 4, 5}; + GenericSegmentTree st = + new GenericSegmentTree( + ar, + GenericSegmentTree.SegmentCombinationFn.MIN, + GenericSegmentTree.RangeUpdateFn.MULTIPLICATION); + + assertThat(st.rangeQuery(0, 4)).isEqualTo(1); + + // Multiply [0,4] by 2 -> {4, 2, 6, 8, 10} + st.rangeUpdate(0, 4, 2); + assertThat(st.rangeQuery(0, 4)).isEqualTo(2); + assertThat(st.rangeQuery(0, 0)).isEqualTo(4); + assertThat(st.rangeQuery(0, 2)).isEqualTo(2); + + // Multiply [2,4] by 3 -> {4, 2, 18, 24, 30} + st.rangeUpdate(2, 4, 3); + assertThat(st.rangeQuery(0, 4)).isEqualTo(2); + assertThat(st.rangeQuery(2, 4)).isEqualTo(18); + } + + // ====================================================== + // Edge case: multiply by zero + // ====================================================== + + @Test + public void testSumQueryMulUpdate_multiplyByZero() { + long[] ar = {5, 10, 15, 20}; + GenericSegmentTree st = + new GenericSegmentTree( + ar, + GenericSegmentTree.SegmentCombinationFn.SUM, + GenericSegmentTree.RangeUpdateFn.MULTIPLICATION); + + assertThat(st.rangeQuery(0, 3)).isEqualTo(50); + + // Multiply [1,2] by 0 -> {5, 0, 0, 20} + st.rangeUpdate(1, 2, 0); + assertThat(st.rangeQuery(0, 3)).isEqualTo(25); + assertThat(st.rangeQuery(1, 2)).isEqualTo(0); + assertThat(st.rangeQuery(0, 0)).isEqualTo(5); + assertThat(st.rangeQuery(3, 3)).isEqualTo(20); + } + + // ====================================================== + // Edge case: all same values + // ====================================================== + + @Test + public void testAllSameValues_gcdAssign() { + long[] ar = {6, 6, 6, 6, 6}; + GenericSegmentTree st = + new GenericSegmentTree( + ar, + GenericSegmentTree.SegmentCombinationFn.GCD, + GenericSegmentTree.RangeUpdateFn.ASSIGN); + + assertThat(st.rangeQuery(0, 4)).isEqualTo(6); + + // Assign [2,4] = 4 -> {6, 6, 4, 4, 4} + st.rangeUpdate(2, 4, 4); + assertThat(st.rangeQuery(0, 4)).isEqualTo(2); + assertThat(st.rangeQuery(0, 1)).isEqualTo(6); + assertThat(st.rangeQuery(2, 4)).isEqualTo(4); + } + + // ====================================================== + // Edge case: update entire array + // ====================================================== + + @Test + public void testUpdateEntireArray_sumAssign() { + long[] ar = {1, 2, 3, 4, 5}; + GenericSegmentTree st = + new GenericSegmentTree( + ar, + GenericSegmentTree.SegmentCombinationFn.SUM, + GenericSegmentTree.RangeUpdateFn.ASSIGN); + + // Assign all to 7 -> {7, 7, 7, 7, 7} + st.rangeUpdate(0, 4, 7); + assertThat(st.rangeQuery(0, 4)).isEqualTo(35); + assertThat(st.rangeQuery(0, 0)).isEqualTo(7); + assertThat(st.rangeQuery(2, 3)).isEqualTo(14); + } + + @Test + public void testUpdateEntireArray_maxAddition() { + long[] ar = {3, -1, 7, 2, 5}; + GenericSegmentTree st = + new GenericSegmentTree( + ar, + GenericSegmentTree.SegmentCombinationFn.MAX, + GenericSegmentTree.RangeUpdateFn.ADDITION); + + assertThat(st.rangeQuery(0, 4)).isEqualTo(7); + + // Add 10 to all -> {13, 9, 17, 12, 15} + st.rangeUpdate(0, 4, 10); + assertThat(st.rangeQuery(0, 4)).isEqualTo(17); + assertThat(st.rangeQuery(0, 1)).isEqualTo(13); + assertThat(st.rangeQuery(1, 1)).isEqualTo(9); + } + + // ====================================================== + // Edge case: point updates (single index range update) + // ====================================================== + + @Test + public void testPointUpdates_sumAddition() { + long[] ar = {0, 0, 0, 0, 0}; + GenericSegmentTree st = + new GenericSegmentTree( + ar, + GenericSegmentTree.SegmentCombinationFn.SUM, + GenericSegmentTree.RangeUpdateFn.ADDITION); + + // Set individual elements via point updates + st.rangeUpdate(0, 0, 1); + st.rangeUpdate(1, 1, 2); + st.rangeUpdate(2, 2, 3); + st.rangeUpdate(3, 3, 4); + st.rangeUpdate(4, 4, 5); + + assertThat(st.rangeQuery(0, 4)).isEqualTo(15); + assertThat(st.rangeQuery(0, 0)).isEqualTo(1); + assertThat(st.rangeQuery(2, 4)).isEqualTo(12); + } + + // ====================================================== + // Negative values tests + // ====================================================== + + @Test + public void testNegativeValues_sumAddition() { + long[] ar = {-5, -3, -1, -4, -2}; + GenericSegmentTree st = + new GenericSegmentTree( + ar, + GenericSegmentTree.SegmentCombinationFn.SUM, + GenericSegmentTree.RangeUpdateFn.ADDITION); + + assertThat(st.rangeQuery(0, 4)).isEqualTo(-15); + + // Add -5 to [1,3] -> {-5, -8, -6, -9, -2} + st.rangeUpdate(1, 3, -5); + assertThat(st.rangeQuery(0, 4)).isEqualTo(-30); + assertThat(st.rangeQuery(1, 3)).isEqualTo(-23); + } + + @Test + public void testNegativeValues_minAddition() { + long[] ar = {-5, -3, -1, -4, -2}; + GenericSegmentTree st = + new GenericSegmentTree( + ar, + GenericSegmentTree.SegmentCombinationFn.MIN, + GenericSegmentTree.RangeUpdateFn.ADDITION); + + assertThat(st.rangeQuery(0, 4)).isEqualTo(-5); + + // Add 10 to [0,0] -> {5, -3, -1, -4, -2} + st.rangeUpdate(0, 0, 10); + assertThat(st.rangeQuery(0, 4)).isEqualTo(-4); + assertThat(st.rangeQuery(0, 0)).isEqualTo(5); + } + + // ====================================================== + // Multiple overlapping range updates (non-SUM combos) + // ====================================================== + + @Test + public void testOverlappingUpdates_minAssign() { + long[] ar = {10, 20, 30, 40, 50}; + GenericSegmentTree st = + new GenericSegmentTree( + ar, + GenericSegmentTree.SegmentCombinationFn.MIN, + GenericSegmentTree.RangeUpdateFn.ASSIGN); + + // Assign [0,3] = 5 -> {5, 5, 5, 5, 50} + st.rangeUpdate(0, 3, 5); + assertThat(st.rangeQuery(0, 4)).isEqualTo(5); + + // Assign [2,4] = 3 -> {5, 5, 3, 3, 3} + st.rangeUpdate(2, 4, 3); + assertThat(st.rangeQuery(0, 4)).isEqualTo(3); + assertThat(st.rangeQuery(0, 1)).isEqualTo(5); + assertThat(st.rangeQuery(2, 4)).isEqualTo(3); + + // Assign [1,2] = 1 -> {5, 1, 1, 3, 3} + st.rangeUpdate(1, 2, 1); + assertThat(st.rangeQuery(0, 4)).isEqualTo(1); + assertThat(st.rangeQuery(0, 0)).isEqualTo(5); + assertThat(st.rangeQuery(1, 1)).isEqualTo(1); + assertThat(st.rangeQuery(3, 4)).isEqualTo(3); + } + + @Test + public void testOverlappingUpdates_productMul() { + long[] ar = {1, 2, 3, 4}; + GenericSegmentTree st = + new GenericSegmentTree( + ar, + GenericSegmentTree.SegmentCombinationFn.PRODUCT, + GenericSegmentTree.RangeUpdateFn.MULTIPLICATION); + + // product = 1*2*3*4 = 24 + assertThat(st.rangeQuery(0, 3)).isEqualTo(24); + + // Multiply [0,1] by 2 -> {2, 4, 3, 4} -> product = 96 + st.rangeUpdate(0, 1, 2); + assertThat(st.rangeQuery(0, 3)).isEqualTo(96); + assertThat(st.rangeQuery(0, 1)).isEqualTo(8); + + // Multiply [1,2] by 3 -> {2, 12, 9, 4} -> product = 864 + st.rangeUpdate(1, 2, 3); + assertThat(st.rangeQuery(0, 3)).isEqualTo(864); + assertThat(st.rangeQuery(1, 2)).isEqualTo(108); + assertThat(st.rangeQuery(0, 0)).isEqualTo(2); + assertThat(st.rangeQuery(3, 3)).isEqualTo(4); + } + + @Test + public void testOverlappingUpdates_gcdMul() { + long[] ar = {6, 12, 18}; + GenericSegmentTree st = + new GenericSegmentTree( + ar, + GenericSegmentTree.SegmentCombinationFn.GCD, + GenericSegmentTree.RangeUpdateFn.MULTIPLICATION); + + // gcd(6, 12, 18) = 6 + assertThat(st.rangeQuery(0, 2)).isEqualTo(6); + + // Multiply [0,1] by 3 -> {18, 36, 18}, gcd = 18 + st.rangeUpdate(0, 1, 3); + assertThat(st.rangeQuery(0, 2)).isEqualTo(18); + + // Multiply [1,2] by 2 -> {18, 72, 36}, gcd = 18 + st.rangeUpdate(1, 2, 2); + assertThat(st.rangeQuery(0, 2)).isEqualTo(18); + assertThat(st.rangeQuery(1, 2)).isEqualTo(36); + } + + // ====================================================== + // Query before any updates (initial state) + // ====================================================== + + @Test + public void testQueryWithoutUpdates_allCombos() { + long[] ar = {6, 3, 12, 9}; + + // SUM + GenericSegmentTree sumSt = + new GenericSegmentTree( + ar, + GenericSegmentTree.SegmentCombinationFn.SUM, + GenericSegmentTree.RangeUpdateFn.ADDITION); + assertThat(sumSt.rangeQuery(0, 3)).isEqualTo(30); + assertThat(sumSt.rangeQuery(1, 2)).isEqualTo(15); + assertThat(sumSt.rangeQuery(0, 0)).isEqualTo(6); + + // MIN + GenericSegmentTree minSt = + new GenericSegmentTree( + ar, + GenericSegmentTree.SegmentCombinationFn.MIN, + GenericSegmentTree.RangeUpdateFn.ADDITION); + assertThat(minSt.rangeQuery(0, 3)).isEqualTo(3); + assertThat(minSt.rangeQuery(2, 3)).isEqualTo(9); + + // MAX + GenericSegmentTree maxSt = + new GenericSegmentTree( + ar, + GenericSegmentTree.SegmentCombinationFn.MAX, + GenericSegmentTree.RangeUpdateFn.ADDITION); + assertThat(maxSt.rangeQuery(0, 3)).isEqualTo(12); + assertThat(maxSt.rangeQuery(0, 1)).isEqualTo(6); + + // GCD + GenericSegmentTree gcdSt = + new GenericSegmentTree( + ar, + GenericSegmentTree.SegmentCombinationFn.GCD, + GenericSegmentTree.RangeUpdateFn.ASSIGN); + assertThat(gcdSt.rangeQuery(0, 3)).isEqualTo(3); + assertThat(gcdSt.rangeQuery(0, 2)).isEqualTo(3); + assertThat(gcdSt.rangeQuery(2, 3)).isEqualTo(3); + + // PRODUCT + GenericSegmentTree prodSt = + new GenericSegmentTree( + ar, + GenericSegmentTree.SegmentCombinationFn.PRODUCT, + GenericSegmentTree.RangeUpdateFn.MULTIPLICATION); + assertThat(prodSt.rangeQuery(0, 3)).isEqualTo(6 * 3 * 12 * 9); + assertThat(prodSt.rangeQuery(1, 2)).isEqualTo(36); + } + + // ====================================================== + // Product + Assign with overlapping updates + // ====================================================== + + @Test + public void testProductQueryAssignUpdate_overlapping() { + long[] ar = {1, 2, 3, 4, 5}; + GenericSegmentTree st = + new GenericSegmentTree( + ar, + GenericSegmentTree.SegmentCombinationFn.PRODUCT, + GenericSegmentTree.RangeUpdateFn.ASSIGN); + + // Assign [0,4] = 2 -> {2, 2, 2, 2, 2}, product = 32 + st.rangeUpdate(0, 4, 2); + assertThat(st.rangeQuery(0, 4)).isEqualTo(32); + assertThat(st.rangeQuery(0, 0)).isEqualTo(2); + assertThat(st.rangeQuery(2, 3)).isEqualTo(4); + + // Assign [1,3] = 3 -> {2, 3, 3, 3, 2}, product = 108 + st.rangeUpdate(1, 3, 3); + assertThat(st.rangeQuery(0, 4)).isEqualTo(108); + assertThat(st.rangeQuery(1, 3)).isEqualTo(27); + assertThat(st.rangeQuery(0, 0)).isEqualTo(2); + assertThat(st.rangeQuery(4, 4)).isEqualTo(2); + } + + // ====================================================== + // MIN/MAX + MUL randomized test (positive-only values) + // ====================================================== + + @Test + public void testMinMaxMulUpdate_randomized_positiveOnly() { + GenericSegmentTree.SegmentCombinationFn[] combinationFns = { + GenericSegmentTree.SegmentCombinationFn.MIN, + GenericSegmentTree.SegmentCombinationFn.MAX, + }; + + for (GenericSegmentTree.SegmentCombinationFn combinationFn : combinationFns) { + for (int n = 5; n < 50; n++) { + long[] ar = TestUtils.randomLongArray(n, 1, 10); + GenericSegmentTree st = + new GenericSegmentTree( + ar, combinationFn, GenericSegmentTree.RangeUpdateFn.MULTIPLICATION); + + for (int i = 0; i < n; i++) { + int j = TestUtils.randValue(0, n - 1); + int k = TestUtils.randValue(0, n - 1); + int i1 = Math.min(j, k); + int i2 = Math.max(j, k); + + j = TestUtils.randValue(0, n - 1); + k = TestUtils.randValue(0, n - 1); + int i3 = Math.min(j, k); + int i4 = Math.max(j, k); + + // Only positive multipliers to avoid min/max sign-flip issues + long randValue = TestUtils.randValue(1, 3); + + bruteForceMulRangeUpdate(ar, i3, i4, randValue); + st.rangeUpdate(i3, i4, randValue); + + long bf; + if (combinationFn == GenericSegmentTree.SegmentCombinationFn.MIN) { + bf = bruteForceMin(ar, i1, i2); + } else { + bf = bruteForceMax(ar, i1, i2); + } + + long segTreeAnswer = st.rangeQuery(i1, i2); + assertThat(segTreeAnswer).isEqualTo(bf); + } + } + } + } + + // ====================================================== + // PRODUCT randomized tests (small values to avoid overflow) + // ====================================================== + + @Test + public void testProductMul_randomized() { + for (int n = 5; n < 20; n++) { + long[] ar = TestUtils.randomLongArray(n, 1, 3); + GenericSegmentTree st = + new GenericSegmentTree( + ar, + GenericSegmentTree.SegmentCombinationFn.PRODUCT, + GenericSegmentTree.RangeUpdateFn.MULTIPLICATION); + + for (int i = 0; i < n; i++) { + int j = TestUtils.randValue(0, n - 1); + int k = TestUtils.randValue(0, n - 1); + int i1 = Math.min(j, k); + int i2 = Math.max(j, k); + + j = TestUtils.randValue(0, n - 1); + k = TestUtils.randValue(0, n - 1); + int i3 = Math.min(j, k); + int i4 = Math.max(j, k); + + long randValue = TestUtils.randValue(1, 2); + bruteForceMulRangeUpdate(ar, i3, i4, randValue); + st.rangeUpdate(i3, i4, randValue); + + long bf = bruteForceMul(ar, i1, i2); + assertThat(st.rangeQuery(i1, i2)).isEqualTo(bf); + } + } + } + + @Test + public void testProductAssign_randomized() { + for (int n = 5; n < 20; n++) { + long[] ar = TestUtils.randomLongArray(n, 1, 3); + GenericSegmentTree st = + new GenericSegmentTree( + ar, + GenericSegmentTree.SegmentCombinationFn.PRODUCT, + GenericSegmentTree.RangeUpdateFn.ASSIGN); + + for (int i = 0; i < n; i++) { + int j = TestUtils.randValue(0, n - 1); + int k = TestUtils.randValue(0, n - 1); + int i1 = Math.min(j, k); + int i2 = Math.max(j, k); + + j = TestUtils.randValue(0, n - 1); + k = TestUtils.randValue(0, n - 1); + int i3 = Math.min(j, k); + int i4 = Math.max(j, k); + + long randValue = TestUtils.randValue(1, 3); + bruteForceAssignRangeUpdate(ar, i3, i4, randValue); + st.rangeUpdate(i3, i4, randValue); + + long bf = bruteForceMul(ar, i1, i2); + assertThat(st.rangeQuery(i1, i2)).isEqualTo(bf); + } + } + } + + // ====================================================== + // Randomized test for all valid function combinations + // ====================================================== @Test public void testAllFunctionCombinations() { @@ -225,7 +898,6 @@ public void testAllFunctionCombinations() { GenericSegmentTree.SegmentCombinationFn.MIN, GenericSegmentTree.SegmentCombinationFn.MAX, GenericSegmentTree.SegmentCombinationFn.GCD, - // GenericSegmentTree.SegmentCombinationFn.PRODUCT, }; GenericSegmentTree.RangeUpdateFn[] rangeUpdateFns = { @@ -237,20 +909,16 @@ public void testAllFunctionCombinations() { for (GenericSegmentTree.SegmentCombinationFn combinationFn : combinationFns) { for (GenericSegmentTree.RangeUpdateFn rangeUpdateFn : rangeUpdateFns) { - // TODO(issue/208): The multiplication range update function seems to be suffering - // from overflow issues and not being able to handle negative numbers. - // - // One idea might be to also track the min value for the max query and vice versa - // and swap values when a negative number is found? + // MIN/MAX + MULTIPLICATION may produce incorrect results with negative multipliers if (rangeUpdateFn == GenericSegmentTree.RangeUpdateFn.MULTIPLICATION && (combinationFn == GenericSegmentTree.SegmentCombinationFn.MIN || combinationFn == GenericSegmentTree.SegmentCombinationFn.MAX)) { continue; } + // GCD + ADDITION is not supported if (combinationFn == GenericSegmentTree.SegmentCombinationFn.GCD && rangeUpdateFn == GenericSegmentTree.RangeUpdateFn.ADDITION) { - // Not supported continue; } @@ -265,7 +933,6 @@ public void testAllFunctionCombinations() { GenericSegmentTree st = new GenericSegmentTree(ar, combinationFn, rangeUpdateFn); for (int i = 0; i < n; i++) { - // System.out.printf("i = %d\n", i); int j = TestUtils.randValue(0, n - 1); int k = TestUtils.randValue(0, n - 1); int i1 = Math.min(j, k); @@ -276,9 +943,7 @@ public void testAllFunctionCombinations() { int i3 = Math.min(j, k); int i4 = Math.max(j, k); - // Range update long randValue = getRandValueByTestType(combinationFn); - // System.out.printf("UPDATE [%d, %d] with %d\n", i3, i4, randValue); if (rangeUpdateFn == GenericSegmentTree.RangeUpdateFn.ADDITION) { bruteForceSumRangeUpdate(ar, i3, i4, randValue); @@ -288,9 +953,8 @@ public void testAllFunctionCombinations() { bruteForceMulRangeUpdate(ar, i3, i4, randValue); } - st.rangeUpdate1(i3, i4, randValue); + st.rangeUpdate(i3, i4, randValue); - // Range query long bf = 0; if (combinationFn == GenericSegmentTree.SegmentCombinationFn.SUM) { @@ -305,13 +969,11 @@ public void testAllFunctionCombinations() { bf = bruteForceMul(ar, i1, i2); } - long segTreeAnswer = st.rangeQuery1(i1, i2); + long segTreeAnswer = st.rangeQuery(i1, i2); if (bf != segTreeAnswer) { - System.out.printf( "Range query type: %s, range update type: %s, QUERY [%d, %d], want = %d, got = %d\n", combinationFn, rangeUpdateFn, i1, i2, bf, segTreeAnswer); - System.out.println(java.util.Arrays.toString(ar)); } assertThat(segTreeAnswer).isEqualTo(bf); @@ -331,14 +993,12 @@ private static long getRandValueByTestType( private static long[] generateRandomArrayByTestType( int n, GenericSegmentTree.SegmentCombinationFn combinationFn) { - // GCD doesn't play well with negative numbers if (combinationFn != GenericSegmentTree.SegmentCombinationFn.GCD) { return TestUtils.randomLongArray(n, -100, +100); } return TestUtils.randomLongArray(n, 1, +10); } - // Finds the sum in an array between [l, r] in the `values` array private static long bruteForceSum(long[] values, int l, int r) { long s = 0; for (int i = l; i <= r; i++) { @@ -347,7 +1007,6 @@ private static long bruteForceSum(long[] values, int l, int r) { return s; } - // Finds the min value in an array between [l, r] in the `values` array private static long bruteForceMin(long[] values, int l, int r) { long m = values[l]; for (int i = l; i <= r; i++) { @@ -356,7 +1015,6 @@ private static long bruteForceMin(long[] values, int l, int r) { return m; } - // Finds the max value in an array between [l, r] in the `values` array private static long bruteForceMax(long[] values, int l, int r) { long m = values[l]; for (int i = l; i <= r; i++) { @@ -374,16 +1032,14 @@ private static long bruteForceMul(long[] values, int l, int r) { } private static long gcd(long a, long b) { - long gcd = a; while (b != 0) { - gcd = b; + long tmp = b; b = a % b; - a = gcd; + a = tmp; } - return Math.abs(gcd); + return Math.abs(a); } - // Finds the sum in an array between [l, r] in the `values` array private static long bruteForceGcd(long[] values, int l, int r) { long s = values[l]; for (int i = l; i <= r; i++) {