77
88use crate :: ast:: DistanceMetric ;
99use crate :: error:: { GraphError , Result } ;
10- use arrow:: array:: { Array , ArrayRef , FixedSizeListArray , Float32Array } ;
10+ use arrow:: array:: { Array , ArrayRef , FixedSizeListArray , Float32Array , ListArray } ;
1111
12- /// Extract vectors from Arrow FixedSizeListArray
12+ /// Extract vectors from Arrow ListArray or FixedSizeListArray
13+ ///
14+ /// Accepts both types for user convenience:
15+ /// - FixedSizeListArray: from Lance datasets or explicit construction
16+ /// - ListArray: from natural table construction with nested lists
1317pub fn extract_vectors ( array : & ArrayRef ) -> Result < Vec < Vec < f32 > > > {
14- let list_array = array
15- . as_any ( )
16- . downcast_ref :: < FixedSizeListArray > ( )
17- . ok_or_else ( || GraphError :: ExecutionError {
18- message : "Expected FixedSizeListArray for vector column" . to_string ( ) ,
19- location : snafu:: Location :: new ( file ! ( ) , line ! ( ) , column ! ( ) ) ,
20- } ) ?;
18+ // Try FixedSizeListArray first (more common in Lance)
19+ if let Some ( list_array) = array. as_any ( ) . downcast_ref :: < FixedSizeListArray > ( ) {
20+ let mut vectors = Vec :: with_capacity ( list_array. len ( ) ) ;
21+ for i in 0 ..list_array. len ( ) {
22+ if list_array. is_null ( i) {
23+ return Err ( GraphError :: ExecutionError {
24+ message : "Null vector in FixedSizeListArray" . to_string ( ) ,
25+ location : snafu:: Location :: new ( file ! ( ) , line ! ( ) , column ! ( ) ) ,
26+ } ) ;
27+ }
28+ let value_array = list_array. value ( i) ;
29+ let float_array = value_array
30+ . as_any ( )
31+ . downcast_ref :: < Float32Array > ( )
32+ . ok_or_else ( || GraphError :: ExecutionError {
33+ message : "Expected Float32Array in vector" . to_string ( ) ,
34+ location : snafu:: Location :: new ( file ! ( ) , line ! ( ) , column ! ( ) ) ,
35+ } ) ?;
36+
37+ let vec: Vec < f32 > = ( 0 ..float_array. len ( ) )
38+ . map ( |j| float_array. value ( j) )
39+ . collect ( ) ;
40+ vectors. push ( vec) ;
41+ }
42+ return Ok ( vectors) ;
43+ }
2144
22- let mut vectors = Vec :: with_capacity ( list_array. len ( ) ) ;
23- for i in 0 ..list_array. len ( ) {
24- let value_array = list_array. value ( i) ;
25- let float_array = value_array
26- . as_any ( )
27- . downcast_ref :: < Float32Array > ( )
28- . ok_or_else ( || GraphError :: ExecutionError {
29- message : "Expected Float32Array in vector" . to_string ( ) ,
30- location : snafu:: Location :: new ( file ! ( ) , line ! ( ) , column ! ( ) ) ,
31- } ) ?;
32-
33- let vec: Vec < f32 > = ( 0 ..float_array. len ( ) )
34- . map ( |j| float_array. value ( j) )
35- . collect ( ) ;
36- vectors. push ( vec) ;
45+ // Try ListArray (from nested list construction)
46+ if let Some ( list_array) = array. as_any ( ) . downcast_ref :: < ListArray > ( ) {
47+ let mut vectors = Vec :: with_capacity ( list_array. len ( ) ) ;
48+ for i in 0 ..list_array. len ( ) {
49+ if list_array. is_null ( i) {
50+ return Err ( GraphError :: ExecutionError {
51+ message : "Null vector in ListArray" . to_string ( ) ,
52+ location : snafu:: Location :: new ( file ! ( ) , line ! ( ) , column ! ( ) ) ,
53+ } ) ;
54+ }
55+ let value_array = list_array. value ( i) ;
56+ let float_array = value_array
57+ . as_any ( )
58+ . downcast_ref :: < Float32Array > ( )
59+ . ok_or_else ( || GraphError :: ExecutionError {
60+ message : "Expected Float32Array in vector" . to_string ( ) ,
61+ location : snafu:: Location :: new ( file ! ( ) , line ! ( ) , column ! ( ) ) ,
62+ } ) ?;
63+
64+ let vec: Vec < f32 > = ( 0 ..float_array. len ( ) )
65+ . map ( |j| float_array. value ( j) )
66+ . collect ( ) ;
67+ vectors. push ( vec) ;
68+ }
69+ return Ok ( vectors) ;
3770 }
3871
39- Ok ( vectors)
72+ Err ( GraphError :: ExecutionError {
73+ message : "Expected ListArray or FixedSizeListArray for vector column" . to_string ( ) ,
74+ location : snafu:: Location :: new ( file ! ( ) , line ! ( ) , column ! ( ) ) ,
75+ } )
4076}
4177
4278/// Extract a single vector from a ScalarValue
@@ -136,8 +172,8 @@ pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
136172/// For similarity search, we return the negative (so lower is better for sorting)
137173pub fn dot_product_distance ( a : & [ f32 ] , b : & [ f32 ] ) -> f32 {
138174 if a. len ( ) != b. len ( ) {
139- // Dimension mismatch - return max distance
140- return f32:: MIN ;
175+ // Dimension mismatch - return worst distance to exclude from results
176+ return f32:: MAX ;
141177 }
142178
143179 -a. iter ( ) . zip ( b. iter ( ) ) . map ( |( x, y) | x * y) . sum :: < f32 > ( )
@@ -146,7 +182,7 @@ pub fn dot_product_distance(a: &[f32], b: &[f32]) -> f32 {
146182/// Compute dot product similarity (for vector_similarity function)
147183pub fn dot_product_similarity ( a : & [ f32 ] , b : & [ f32 ] ) -> f32 {
148184 if a. len ( ) != b. len ( ) {
149- // Dimension mismatch
185+ // Dimension mismatch - return worst similarity to exclude from results
150186 return f32:: MIN ;
151187 }
152188
@@ -256,6 +292,12 @@ mod tests {
256292
257293 let dist = cosine_distance ( & a, & b) ;
258294 assert_eq ! ( dist, 2.0 ) ;
295+
296+ let dist = dot_product_distance ( & a, & b) ;
297+ assert_eq ! ( dist, f32 :: MAX ) ;
298+
299+ let sim = dot_product_similarity ( & a, & b) ;
300+ assert_eq ! ( sim, f32 :: MIN ) ;
259301 }
260302
261303 #[ test]
@@ -340,4 +382,104 @@ mod tests {
340382 assert_eq ! ( similarities[ 1 ] , 0.0 ) ; // Orthogonal
341383 assert ! ( ( similarities[ 2 ] - 0.707 ) . abs( ) < 0.01 ) ; // cos(45°) ≈ 0.707
342384 }
385+
/// A non-null `FixedSizeListArray` of 3-D vectors is returned row by row.
#[test]
fn test_extract_vectors_from_fixed_size_list() {
    use arrow::datatypes::{DataType, Field};

    // Three 3-D vectors laid out back to back in one flat values array.
    let flat_values = Arc::new(Float32Array::from(vec![
        1.0, 0.0, 0.0, // Vector 1
        0.0, 1.0, 0.0, // Vector 2
        0.0, 0.0, 1.0, // Vector 3
    ]));
    let item_field = Arc::new(Field::new("item", DataType::Float32, true));
    let fixed = FixedSizeListArray::try_new(item_field, 3, flat_values, None).unwrap();
    let array_ref: ArrayRef = Arc::new(fixed);

    let vectors =
        extract_vectors(&array_ref).expect("FixedSizeListArray input should be accepted");

    assert_eq!(vectors.len(), 3);
    let expected = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];
    for (actual, wanted) in vectors.iter().zip(expected.iter()) {
        assert_eq!(actual, &wanted.to_vec());
    }
}
409+
/// A non-null `ListArray` of float rows is returned row by row.
#[test]
fn test_extract_vectors_from_list_array() {
    use arrow::array::ListBuilder;

    // ListArray rows may vary in length; here every row has three elements.
    let rows: [[f32; 3]; 3] = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]];

    let mut list_builder = ListBuilder::new(Float32Array::builder(9));
    for row in &rows {
        // Push the row's elements into the child builder, then close the row
        // as a valid (non-null) list entry.
        list_builder.values().append_slice(row);
        list_builder.append(true);
    }
    let array_ref: ArrayRef = Arc::new(list_builder.finish());

    let vectors = extract_vectors(&array_ref).expect("ListArray input should be accepted");

    assert_eq!(vectors.len(), 3);
    for (actual, wanted) in vectors.iter().zip(rows.iter()) {
        assert_eq!(actual, &wanted.to_vec());
    }
}
448+
/// A plain (non-list) array is rejected with the "expected list" error.
#[test]
fn test_extract_vectors_rejects_invalid_type() {
    // A flat Float32Array is not a list of vectors.
    let array_ref: ArrayRef = Arc::new(Float32Array::from(vec![1.0, 2.0, 3.0]));

    let err = extract_vectors(&array_ref).expect_err("non-list arrays must be rejected");

    let rendered = err.to_string();
    assert!(rendered.contains("Expected ListArray or FixedSizeListArray"));
}
462+
/// A `FixedSizeListArray` containing a null row is rejected.
#[test]
fn test_extract_vectors_rejects_null_in_fixed_size_list() {
    use arrow::datatypes::{DataType, Field};

    // Three rows of three floats; the validity mask marks the middle row null.
    let flat_values = Arc::new(Float32Array::from(vec![
        1.0, 0.0, 0.0, // Vector 1
        0.0, 1.0, 0.0, // Vector 2 (will be null)
        0.0, 0.0, 1.0, // Vector 3
    ]));
    let validity = arrow::buffer::NullBuffer::from(vec![true, false, true]);
    let item_field = Arc::new(Field::new("item", DataType::Float32, true));
    let fixed =
        FixedSizeListArray::try_new(item_field, 3, flat_values, Some(validity)).unwrap();
    let array_ref: ArrayRef = Arc::new(fixed);

    let err = extract_vectors(&array_ref).expect_err("null rows must be rejected");

    let rendered = err.to_string();
    assert!(rendered.contains("Null vector in FixedSizeListArray"));
}
343485}
0 commit comments