@@ -0,0 +1,28 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.vectorized.execution.search.spi;

import java.util.List;
import java.util.Map;

/**
* Service Provider Interface for query execution results.
* Implementations provide access to columnar query results from different execution engines.
*
* @opensearch.experimental
*/
public interface QueryResult {

/**
* Returns the columnar result data where each entry maps a column name to its list of values.
*
* @return Map of column names to their corresponding value lists
*/
Map<String, List<Object>> getColumns();
}
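
For illustration only, a consumer of this SPI might pivot the columnar map back into row-major form; a minimal sketch (helper name and usage hypothetical, assumes java.util imports):

// Sketch: pivot a QueryResult's columns into rows (hypothetical helper, not part of this PR).
static List<Map<String, Object>> toRows(QueryResult result) {
    Map<String, List<Object>> columns = result.getColumns();
    // All columns are assumed to be the same length.
    int rows = columns.values().stream().mapToInt(List::size).findFirst().orElse(0);
    List<Map<String, Object>> out = new ArrayList<>(rows);
    for (int i = 0; i < rows; i++) {
        Map<String, Object> row = new HashMap<>();
        for (Map.Entry<String, List<Object>> e : columns.entrySet()) {
            row.put(e.getKey(), e.getValue().get(i));
        }
        out.add(row);
    }
    return out;
}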
@@ -68,6 +68,7 @@
import org.opensearch.search.fetch.subphase.FetchSourceContext;
import org.opensearch.search.internal.ReaderContext;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.datafusion.search.DfResult;
import org.opensearch.search.internal.ShardSearchRequest;
import org.opensearch.search.lookup.SourceLookup;

@@ -213,7 +214,7 @@ public void close() {

@Override
public void executeQueryPhase(DatafusionContext context) {
Map<String, Object[]> finalRes = new HashMap<>();
Map<String, List<Object>> finalRes = new HashMap<>();
List<Long> rowIdResult = new ArrayList<>();
RecordBatchStream stream = null;

@@ -244,7 +245,11 @@ public void executeQueryPhase(DatafusionContext context) {
fieldValues[i] = fieldVector.getObject(i);
}
}
finalRes.put(fieldName, fieldValues);
if(finalRes.containsKey(fieldName)) {
finalRes.get(fieldName).addAll(Arrays.asList(fieldValues));
} else {
finalRes.put(fieldName, new ArrayList<>(Arrays.asList(fieldValues)));
}
}
}
};
@@ -278,18 +283,18 @@ public void executeQueryPhase(DatafusionContext context) {
throw new RuntimeException(e);
}
}
context.setDFResults(finalRes);
context.setDFResults(new DfResult(finalRes));
context.queryResult().topDocs(new TopDocsAndMaxScore(new TopDocs(new TotalHits(rowIdResult.size(), TotalHits.Relation.EQUAL_TO), rowIdResult.stream().map(d-> new ScoreDoc(d.intValue(), Float.NaN, context.indexShard().shardId().getId())).toList().toArray(ScoreDoc[]::new)) , Float.NaN), new DocValueFormat[0]);
}
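
Side note on the accumulation above: a DataFusion stream can emit multiple RecordBatches, so each batch's values are appended to a growing list rather than overwriting the previous batch's Object[]. The containsKey/put pair could equivalently be written with computeIfAbsent; a sketch, not the committed code:

// Equivalent merge of one batch's values into the running column (sketch).
finalRes.computeIfAbsent(fieldName, k -> new ArrayList<>())
    .addAll(Arrays.asList(fieldValues));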

@Override
public void executeQueryPhaseAsync(DatafusionContext context, Executor executor, ActionListener<Map<String, Object[]>> listener) {
public void executeQueryPhaseAsync(DatafusionContext context, Executor executor, ActionListener<DfResult> listener) {
try {
DatafusionSearcher datafusionSearcher = context.getEngineSearcher();
context.getDatafusionQuery().setQueryPlanExplainEnabled(context.evaluateSearchQueryExplainMode());

datafusionSearcher.searchAsync(context.getDatafusionQuery(), datafusionService.getRuntimePointer()).whenCompleteAsync((streamPointer, error)-> {
Map<String, Object[]> finalRes = new HashMap<>();
Map<String, List<Object>> finalResColumns = new HashMap<>();
List<Long> rowIdResult = new ArrayList<>();
if(streamPointer == null) {
throw new RuntimeException(error);
@@ -303,24 +308,27 @@ public void collect(RecordBatchStream value) {
for (Field field : root.getSchema().getFields()) {
String fieldName = field.getName();
FieldVector fieldVector = root.getVector(fieldName);
Object[] fieldValues = new Object[fieldVector.getValueCount()];
List<Object> fieldValues = new ArrayList<>(fieldVector.getValueCount());
if (fieldName.equals(CompositeDataFormatWriter.ROW_ID)) {
FieldVector rowIdVector = root.getVector(fieldName);
for(int i=0; i<fieldVector.getValueCount(); i++) {
rowIdResult.add((long) rowIdVector.getObject(i));
fieldValues[i] = fieldVector.getObject(i);
fieldValues.add(fieldVector.getObject(i));
}
}
else {
} else {
for (int i = 0; i < fieldVector.getValueCount(); i++) {
fieldValues[i] = fieldVector.getObject(i);
fieldValues.add(fieldVector.getObject(i));
}
}
finalRes.put(fieldName, fieldValues);
if(finalResColumns.containsKey(fieldName)) {
finalResColumns.get(fieldName).addAll(fieldValues);
} else {
finalResColumns.put(fieldName, fieldValues);
}
}
}
};
loadNextBatch(stream, executor, collector, finalRes, allocator, listener, context, rowIdResult);
loadNextBatch(stream, executor, collector, finalResColumns, allocator, listener, context, rowIdResult);
});

// logger.info("Memory Pool Allocation Post Query ShardID:{}", context.getQueryShardContext().getShardId());
@@ -343,9 +351,9 @@ private void loadNextBatch(
RecordBatchStream stream,
Executor executor,
SearchResultsCollector<RecordBatchStream> collector,
Map<String, Object[]> finalRes,
Map<String, List<Object>> finalRes,
RootAllocator allocator,
ActionListener<Map<String, Object[]>> listener,
ActionListener<DfResult> listener,
DatafusionContext context,
List<Long> rowIdResult
) {
@@ -365,7 +373,8 @@ private void loadNextBatch(
context.queryResult().topDocs(new TopDocsAndMaxScore(new TopDocs(new TotalHits(rowIdResult.size(),
TotalHits.Relation.EQUAL_TO), rowIdResult.stream().map(d-> new ScoreDoc(d.intValue(),
Float.NaN, context.indexShard().shardId().getId())).toList().toArray(ScoreDoc[]::new)) , Float.NaN), new DocValueFormat[0]);
listener.onResponse(finalRes);
// ArrayList<> --> Object[]
listener.onResponse(new DfResult(finalRes));
}
}, error -> {
cleanup(stream, allocator);
@@ -86,7 +86,7 @@ public class DatafusionContext extends SearchContext {
private final IndexService indexService;
private final QueryShardContext queryShardContext;
private DatafusionQuery datafusionQuery;
private Map<String, Object[]> dfResults;
private QueryResult dfResults;
private SearchContextAggregations aggregations;
private final BigArrays bigArrays;
private final Map<Class<?>, CollectorManager<? extends Collector, ReduceableSearchResult>> queryCollectorManagers = new HashMap<>();
@@ -825,11 +825,11 @@ public ContextEngineSearcher<DatafusionQuery, RecordBatchStream> contextEngineSe
return new ContextEngineSearcher<>(this.engineSearcher, this);
}

public void setDFResults(Map<String, Object[]> dfResults) {
public void setDFResults(QueryResult dfResults) {
this.dfResults = dfResults;
}

public Map<String, Object[]> getDFResults() {
public QueryResult getDFResults() {
return dfResults;
}

@@ -0,0 +1,33 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.datafusion.search;

import org.opensearch.vectorized.execution.search.spi.QueryResult;

import java.util.List;
import java.util.Map;

/**
* Wraps the columnar result from a DataFusion query execution.
* Each entry maps a column name to its list of values.
* Implements the QueryResult SPI so core can consume these results without depending on the DataFusion module.
*/
public class DfResult implements QueryResult {

private final Map<String, List<Object>> columns;

public DfResult(Map<String, List<Object>> columns) {
this.columns = columns;
}

@Override
public Map<String, List<Object>> getColumns() {
return columns;
}
}
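
A usage sketch (column names and data hypothetical): the DataFusion side wraps its columnar output in a DfResult, and core reads it back strictly through the QueryResult interface.

Map<String, List<Object>> columns = new HashMap<>();
columns.put("UserID", List.of(1L, 2L));
columns.put("count()", List.of(10L, 7L));
QueryResult result = new DfResult(columns); // producer side (DataFusion module)
List<Object> counts = result.getColumns().get("count()"); // consumer side (core), via the SPI only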
server/build.gradle (1 addition, 0 deletions)
@@ -72,6 +72,7 @@ dependencies {
api project(":libs:opensearch-geo")
api project(":libs:opensearch-telemetry")
api project(":libs:opensearch-task-commons")
api project(":libs:opensearch-vectorized-exec-spi")

compileOnly project(":libs:agent-sm:bootstrap")
compileOnly project(':libs:opensearch-plugin-classloader')
@@ -18,6 +18,7 @@
import org.opensearch.search.internal.ReaderContext;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.internal.ShardSearchRequest;
import org.opensearch.vectorized.execution.search.spi.QueryResult;

import java.io.IOException;
import java.util.Map;
@@ -44,7 +45,7 @@ public abstract class SearchExecEngine<C extends SearchContext, S extends Engine
*/
public abstract void executeQueryPhase(C context) throws IOException;

public abstract void executeQueryPhaseAsync(C context, Executor executor, ActionListener<Map<String, Object[]>> listener);
public abstract void executeQueryPhaseAsync(C context, Executor executor, ActionListener<QueryResult> listener);

/**
* execute Fetch Phase
@@ -142,6 +142,7 @@
import org.opensearch.search.profile.ProfileShardResult;
import org.opensearch.search.profile.Profilers;
import org.opensearch.search.profile.SearchProfileShardResults;
import org.opensearch.vectorized.execution.search.spi.QueryResult;
import org.opensearch.search.query.*;
import org.opensearch.search.rescore.RescorerBuilder;
import org.opensearch.search.searchafter.SearchAfterBuilder;
@@ -958,9 +959,9 @@ private void executeNativeQueryPhaseAsync(
SearchExecEngine searchExecEngine = indexer instanceof CompositeEngine ? ((CompositeEngine) indexer).getPrimaryReadEngine() : null;

// Execute native query async
searchExecEngine.executeQueryPhaseAsync(finalContext, executor, new ActionListener<Map<String, Object[]>>() {
searchExecEngine.executeQueryPhaseAsync(finalContext, executor, new ActionListener<QueryResult>() {
@Override
public void onResponse(Map<String, Object[]> result) {
public void onResponse(QueryResult result) {
try {
finalContext.setDFResults(result);
// Continue with rest of query phase
@@ -9,23 +9,25 @@
package org.opensearch.search.aggregations;

import org.opensearch.search.internal.SearchContext;
import org.opensearch.vectorized.execution.search.spi.QueryResult;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

public interface ShardResultConvertor {

default List<InternalAggregation> convert(Map<String, Object[]> shardResult, SearchContext searchContext) {
int rows = shardResult.entrySet().stream().findFirst().get().getValue().length;
default List<InternalAggregation> convert(QueryResult queryResult, SearchContext searchContext) {
Map<String, List<Object>> shardResult = queryResult.getColumns();
int rows = shardResult.entrySet().stream().findFirst().get().getValue().size();
List<InternalAggregation> internalAggregations = new ArrayList<>();
for (int i = 0; i < rows; i++) {
internalAggregations.add(convertRow(shardResult, i, searchContext));
}
return internalAggregations;
}

default InternalAggregation convertRow(Map<String, Object[]> shardResult, int row, SearchContext searchContext) {
default InternalAggregation convertRow(Map<String, List<Object>> shardResult, int row, SearchContext searchContext) {
throw new UnsupportedOperationException("Row conversion not supported");
}

@@ -61,6 +61,7 @@
import org.opensearch.common.Rounding;
import org.opensearch.common.collect.Tuple;
import org.opensearch.common.lease.Releasables;
import org.opensearch.vectorized.execution.search.spi.QueryResult;
import org.opensearch.index.IndexSortConfig;
import org.opensearch.lucene.queries.SearchAfterSortedDocQuery;
import org.opensearch.search.DocValueFormat;
@@ -734,25 +735,26 @@ public void collect(int doc, long zeroBucket) throws IOException {
}

@Override
public List<InternalAggregation> convert(Map<String, Object[]> shardResult, SearchContext searchContext) {
public List<InternalAggregation> convert(QueryResult dfResult, SearchContext searchContext) {
Map<String, List<Object>> shardResult = dfResult.getColumns();
if(shardResult.isEmpty()) {
return Collections.singletonList(buildEmptyAggregation());
}
// Generate the composite keys
List<Comparable<?>> currentCompositeKey = new ArrayList<>(sourceConfigs.length);
List<CompositeKey> compositeKeys = new ArrayList<>(shardResult.size());
if (shardResult.isEmpty() == false) {
for (int i = 0; i < shardResult.get(shardResult.keySet().stream().findFirst().get()).length; i++) {
for (int i = 0; i < shardResult.get(shardResult.keySet().stream().findFirst().get()).size(); i++) {
for (CompositeValuesSourceConfig sourceConfig : sourceConfigs) {
// if (sourceConfig.fieldType() == null) {
// throw new UnsupportedOperationException("Composite aggregation does not support script field types");
// }
// source=hits | eval m = extract(minute from EventTime) | stats count() by UserID, m, SearchPhrase | sort - `count()` | head 10
// Without this change, the query above would fail here
// We can get the name directly from sourceConfig
Object[] values = shardResult.get(sourceConfig.name());
List<Object> values = shardResult.get(sourceConfig.name());
// TODO : Would require conversion for certain types,
currentCompositeKey.add(searchContext.convertToComparable(values[i]));
currentCompositeKey.add(searchContext.convertToComparable(values.get(i)));
}
compositeKeys.add(new CompositeKey(currentCompositeKey.toArray(new Comparable[0])));
currentCompositeKey.clear();
@@ -35,6 +35,7 @@
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.util.BytesRef;
import org.opensearch.common.collect.Tuple;
import org.opensearch.vectorized.execution.search.spi.QueryResult;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.aggregations.Aggregator;
import org.opensearch.search.aggregations.AggregatorFactories;
@@ -120,14 +121,15 @@ protected SignificantStringTerms buildEmptySignificantTermsAggregation(long subs
}

@Override
public List<InternalAggregation> convert(Map<String, Object[]> shardResult, SearchContext searchContext) {
public List<InternalAggregation> convert(QueryResult dfResult, SearchContext searchContext) {
Map<String, List<Object>> shardResult = dfResult.getColumns();
if(shardResult.isEmpty()) {
return Collections.singletonList(buildEmptyTermsAggregation());
}
int rowCount = shardResult.get(shardResult.keySet().stream().findFirst().get()).length;
int rowCount = shardResult.get(shardResult.keySet().stream().findFirst().get()).size();
List<StringTerms.Bucket> buckets = new ArrayList<>(rowCount);
for (int row = 0; row < rowCount; row++) {
String termKey = (String) searchContext.convertToComparable(shardResult.get(name)[row]);
String termKey = (String) searchContext.convertToComparable(shardResult.get(name).get(row));
Tuple<List<InternalAggregation>, Long> subAggsAndDocCount = SearchEngineResultConversionUtils.extractSubAggsAndDocCount(subAggregators, searchContext, shardResult, row);
buckets.add(new StringTerms.Bucket(
new BytesRef(termKey),
@@ -20,6 +20,7 @@
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.lease.Releasable;
import org.opensearch.common.lease.Releasables;
import org.opensearch.vectorized.execution.search.spi.QueryResult;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
@@ -709,12 +710,13 @@ static InternalValuesSource doubleValueSource(ValuesSource.Numeric valuesSource,
}

@Override
public List<InternalAggregation> convert(Map<String, Object[]> shardResult, SearchContext searchContext) {
int rowCount = shardResult.isEmpty() ? 0 : shardResult.get(fields.getFirst()).length ;
public List<InternalAggregation> convert(QueryResult dfResult, SearchContext searchContext) {
Map<String, List<Object>> shardResult = dfResult.getColumns();
int rowCount = shardResult.isEmpty() ? 0 : shardResult.get(fields.getFirst()).size() ;
List<InternalMultiTerms.Bucket> buckets = new ArrayList<>(rowCount);
for (int i = 0; i < rowCount; i++) {
final int j = i;
List<Object> key = fields.stream().map(fieldName -> (Object) searchContext.convertToComparable(shardResult.get(fieldName)[j])).toList();
List<Object> key = fields.stream().map(fieldName -> (Object) searchContext.convertToComparable(shardResult.get(fieldName).get(j))).toList();
Tuple<List<InternalAggregation>, Long> subAggsAndDocCount = SearchEngineResultConversionUtils.extractSubAggsAndDocCount(subAggregators, searchContext, shardResult, i);
buckets.add(new InternalMultiTerms.Bucket(key, subAggsAndDocCount.v2(), InternalAggregations.from(subAggsAndDocCount.v1()), showTermDocCountError, 0, formats));
}