Skip to content

Commit 1780c66

Browse files
kedarbcs16 and agirish committed
MD-3579: Compare double values with float precision (#493)
Co-authored-by: kedarbcs16 <[email protected]> Co-authored-by: agirish <[email protected]>
1 parent f0bb452 commit 1780c66

File tree

2 files changed

+63
-48
lines changed

2 files changed

+63
-48
lines changed

framework/src/main/java/org/apache/drill/test/framework/ColumnList.java

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,18 +30,19 @@
3030
public class ColumnList {
3131
private final List<Object> values;
3232
private final List<Integer> types;
33-
private final boolean Simba;
33+
private final boolean simba;
3434
public static final String SIMBA_JDBC = "sjdbc";
35+
private static final double MAX_DIFF_VALUE = 1.0E-6;
3536

3637
public ColumnList(List<Integer> types, List<Object> values) {
3738
this.values = values;
3839
this.types = types;
3940
if (TestDriver.cmdParam.driverExt != null &&
4041
TestDriver.cmdParam.driverExt.equals(ColumnList.SIMBA_JDBC)) {
41-
this.Simba = true;
42+
this.simba = true;
4243
}
4344
else {
44-
this.Simba = false;
45+
this.simba = false;
4546
}
4647
}
4748

@@ -54,9 +55,15 @@ public List<Object> getValues() {
5455
* "loosened" logic to handle float, double and decimal types. The algorithm
5556
* used for the comparison follows:
5657
*
57-
* Floats: f1 equals f2 iff |((f1-f2)/(average(f1,f2)))| < 0.000001
58+
* Floats: f1 and f2 are equal if any of the following holds:
59+
* f1 == f2
60+
* f1 > f2 and |(f1-f2)/f1| <= 0.000001
61+
* f1 < f2 and |(f1-f2)/f2| <= 0.000001
5862
*
59-
* Doubles: d1 equals d2 iff |((d1-d2)/(average(d1,d2)))| < 0.000000000001
63+
* Doubles: d1 and d2 are equal if any of the following holds:
64+
* d1 == d2
65+
* d1 > d2 and |(d1-d2)/d1| <= 0.000001
66+
* d1 < d2 and |(d1-d2)/d2| <= 0.000001
6067
*
6168
* Decimals: dec1 equals dec2 iff value(dec1) == value(dec2) and scale(dec1)
6269
* == scale(dec2)
@@ -76,7 +83,7 @@ public int hashCode() {
7683
if (values.get(i) == null) {
7784
continue;
7885
}
79-
int type = (Integer) (types.get(i));
86+
int type = types.get(i);
8087
switch (type) {
8188
case Types.FLOAT:
8289
case Types.DOUBLE:
@@ -98,8 +105,8 @@ public int hashCode() {
98105
public String toString() {
99106
StringBuilder sb = new StringBuilder();
100107
for (int i = 0; i < values.size() - 1; i++) {
101-
int type = (Integer) (types.get(i));
102-
if (Simba && (type == Types.VARCHAR)) {
108+
int type = types.get(i);
109+
if (simba && (type == Types.VARCHAR)) {
103110
String s1 = String.valueOf(values.get(i));
104111
// if the field has a JSON string or list, then remove newlines so that
105112
// each record fits on a single line in the actual output file.
@@ -114,8 +121,8 @@ public String toString() {
114121
}
115122
sb.append(values.get(i) + "\t");
116123
}
117-
int type = (Integer) (types.get(values.size()-1));
118-
if (Simba && (type == Types.VARCHAR)) {
124+
int type = types.get(values.size()-1);
125+
if (simba && (type == Types.VARCHAR)) {
119126
String s1 = String.valueOf(values.get(values.size()-1));
120127
// if the field has a JSON string or list, then remove newlines so that
121128
// each record fits on a single line in the actual output file.
@@ -147,17 +154,23 @@ private boolean compare(ColumnList o1, ColumnList o2) {
147154
if (oneNull(list1.get(i), list2.get(i))) {
148155
return false;
149156
}
150-
int type = (Integer) (types.get(i));
157+
int type = types.get(i);
151158
try {
152159
switch (type) {
153160
case Types.FLOAT:
154161
case Types.REAL:
155162
float f1 = (Float) list1.get(i);
156163
float f2 = (Float) list2.get(i);
157-
if ((f1 + f2) / 2 != 0) {
158-
if (!(Math.abs((f1 - f2) / ((f1 + f2) / 2)) < 1.0E-6)) return false;
159-
} else if (f1 != 0) {
160-
return false;
164+
if (f1 != f2) {
165+
double relativeError;
166+
if (f1 > f2) {
167+
relativeError = Math.abs((f1-f2)/f1);
168+
} else {
169+
relativeError = Math.abs((f1-f2)/f2);
170+
}
171+
if (relativeError > MAX_DIFF_VALUE) {
172+
return false;
173+
}
161174
}
162175
break;
163176
case Types.DOUBLE:
@@ -167,9 +180,13 @@ private boolean compare(ColumnList o1, ColumnList o2) {
167180
// especially for the cases when doubles are NaN / POSITIVE_INFINITY / NEGATIVE_INFINITY
168181
// otherwise proceed with "loosened" logic
169182
if (!d1.equals(d2)) {
170-
if ((d1 + d2) / 2 != 0) {
171-
if (!(Math.abs((d1 - d2) / ((d1 + d2) / 2)) < 1.0E-12)) return false;
172-
} else if (d1 != 0) {
183+
double relativeError;
184+
if (d1 > d2) {
185+
relativeError = Math.abs((d1-d2)/d1);
186+
} else {
187+
relativeError = Math.abs((d1-d2)/d2);
188+
}
189+
if (relativeError > MAX_DIFF_VALUE) {
173190
return false;
174191
}
175192
}

framework/src/main/java/org/apache/drill/test/framework/TestVerifier.java

Lines changed: 28 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import java.io.BufferedReader;
2121
import java.io.BufferedWriter;
2222
import java.io.File;
23-
import java.io.FileNotFoundException;
2423
import java.io.FileReader;
2524
import java.io.FileWriter;
2625
import java.io.IOException;
@@ -30,7 +29,6 @@
3029
import java.nio.file.Paths;
3130
import java.sql.Types;
3231
import java.util.ArrayList;
33-
import java.util.Arrays;
3432
import java.util.HashMap;
3533
import java.util.Iterator;
3634
import java.util.LinkedHashMap;
@@ -51,6 +49,7 @@
5149
public class TestVerifier {
5250
private static final Logger LOG = Logger.getLogger("DrillTestLogger");
5351
private static final int MAX_MISMATCH_SIZE = 10;
52+
private static final double MAX_DIFF_VALUE = 1.0E-6;
5453
public TestStatus testStatus = TestStatus.PENDING;
5554
private int mapSize = 0;
5655
private List<ColumnList> resultSet = null;
@@ -253,28 +252,28 @@ private Map<ColumnList, Integer> loadFromFileToMap(String filename, boolean orde
253252
typedFields.add(null);
254253
continue;
255254
}
256-
int type = (Integer) (types.get(i));
255+
int type = types.get(i);
257256
try {
258257
switch (type) {
259-
case Types.INTEGER:
260-
case Types.BIGINT:
261-
case Types.SMALLINT:
262-
case Types.TINYINT:
263-
typedFields.add(new BigInteger(fields[i]));
264-
break;
265-
case Types.FLOAT:
266-
case Types.REAL:
267-
typedFields.add(new Float(fields[i]));
268-
break;
269-
case Types.DOUBLE:
270-
typedFields.add(new Double(fields[i]));
271-
break;
272-
case Types.DECIMAL:
273-
typedFields.add(new BigDecimal(fields[i]));
274-
break;
275-
default:
276-
typedFields.add(fields[i]);
277-
break;
258+
case Types.INTEGER:
259+
case Types.BIGINT:
260+
case Types.SMALLINT:
261+
case Types.TINYINT:
262+
typedFields.add(new BigInteger(fields[i]));
263+
break;
264+
case Types.FLOAT:
265+
case Types.REAL:
266+
typedFields.add(new Float(fields[i]));
267+
break;
268+
case Types.DOUBLE:
269+
typedFields.add(new Double(fields[i]));
270+
break;
271+
case Types.DECIMAL:
272+
typedFields.add(new BigDecimal(fields[i]));
273+
break;
274+
default:
275+
typedFields.add(fields[i]);
276+
break;
278277
}
279278
} catch (Exception e) {
280279
typedFields.add(fields[i]);
@@ -399,7 +398,7 @@ private static int compareTo(ColumnList list1, ColumnList list2,
399398
return 0;
400399
}
401400
int idx = columnIndexAndOrder.get(start).index;
402-
int result = -1;
401+
int result;
403402
Object o1 = list1.getValues().get(idx);
404403
Object o2 = list2.getValues().get(idx);
405404
if (ColumnList.bothNull(o1, o2)) {
@@ -408,9 +407,9 @@ private static int compareTo(ColumnList list1, ColumnList list2,
408407
return 0; // TODO handle NULLS FIRST and NULLS LAST cases
409408
}
410409
if (o1 instanceof Number) {
411-
Number number1 = (Number) o1;
412-
Number number2 = (Number) o2;
413-
double diff = number1.doubleValue() - number2.doubleValue();
410+
double d1 = ((Number) o1).doubleValue();
411+
double d2 = ((Number) o2).doubleValue();
412+
double diff = d1 - d2;
414413
if (diff == 0) {
415414
return compareTo(list1, list2, columnIndexAndOrder, start + 1, orderByColumns);
416415
} else {
@@ -447,9 +446,9 @@ private static int compareTo(ColumnList list1, ColumnList list2,
447446
return 1;
448447
}
449448
if (idNode1.isNumber()) {
450-
Number number1 = (Number) idNode1.asInt();
451-
Number number2 = (Number) idNode2.asInt();
452-
double diff = number1.doubleValue() - number2.doubleValue();
449+
double d1 = idNode1.doubleValue();
450+
double d2 = idNode2.doubleValue();
451+
double diff = d1 - d2;
453452
if (diff == 0) {
454453
return compareTo(list1, list2, columnIndexAndOrder, start + 1, orderByColumns);
455454
} else {
@@ -872,7 +871,6 @@ public static Map<String, String> getOrderByColumns(String statement,
872871
column = column.trim();
873872
String[] columnOrder = column.split("\\s+");
874873
String columnName = columnOrder[0].trim();
875-
String fieldName;
876874
if (columnName.indexOf('.') >= 0) {
877875
// table name precedes column name. remove it
878876
columnName = columnName.substring(columnName.indexOf('.') + 1);

0 commit comments

Comments
 (0)