Skip to content

Commit 73e0e4f

Browse files
authored
export stateful/aggregation/custom function registry (#74)
Signed-off-by: Song Gao <[email protected]>
1 parent 0517d80 commit 73e0e4f

21 files changed

+681
-218
lines changed

src/flow/src/expr/custom_func/mod.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,11 @@
1+
pub mod registry;
12
pub mod string_func;
23

34
use crate::expr::func::EvalError;
45
use datatypes::Value;
6+
pub use registry::{CustomFuncRegistry, CustomFuncRegistryError};
57
pub use string_func::ConcatFunc;
68

7-
/// List of functions that can be called through CallFunc (custom functions)
8-
pub const CUSTOM_FUNCTIONS: &[&str] = &[
9-
"concat",
10-
// Add more custom functions here as they are implemented
11-
];
12-
139
/// Custom function that can be implemented by users
1410
/// This trait allows users to define their own functions for evaluation
1511
pub trait CustomFunc: Send + Sync + std::fmt::Debug {
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
use super::{ConcatFunc, CustomFunc};
2+
use std::collections::HashMap;
3+
use std::sync::{Arc, RwLock};
4+
5+
#[derive(Debug, Clone, PartialEq)]
6+
pub enum CustomFuncRegistryError {
7+
AlreadyRegistered(String),
8+
}
9+
10+
/// Registry for scalar custom functions referenced in SQL (e.g. `concat(a, b)`).
11+
pub struct CustomFuncRegistry {
12+
functions: RwLock<HashMap<String, Arc<dyn CustomFunc>>>,
13+
}
14+
15+
impl CustomFuncRegistry {
16+
pub fn new() -> Self {
17+
Self {
18+
functions: RwLock::new(HashMap::new()),
19+
}
20+
}
21+
22+
pub fn with_builtins() -> Arc<Self> {
23+
let registry = Arc::new(Self::new());
24+
registry.register_builtin_functions();
25+
registry
26+
}
27+
28+
pub fn register_function(
29+
&self,
30+
function: Arc<dyn CustomFunc>,
31+
) -> Result<(), CustomFuncRegistryError> {
32+
let mut write = self
33+
.functions
34+
.write()
35+
.expect("custom func registry poisoned");
36+
let key = function.name().to_lowercase();
37+
if write.contains_key(&key) {
38+
return Err(CustomFuncRegistryError::AlreadyRegistered(key));
39+
}
40+
write.insert(key, function);
41+
Ok(())
42+
}
43+
44+
pub fn get(&self, name: &str) -> Option<Arc<dyn CustomFunc>> {
45+
self.functions
46+
.read()
47+
.expect("custom func registry poisoned")
48+
.get(&name.to_lowercase())
49+
.cloned()
50+
}
51+
52+
pub fn is_registered(&self, name: &str) -> bool {
53+
self.functions
54+
.read()
55+
.expect("custom func registry poisoned")
56+
.contains_key(&name.to_lowercase())
57+
}
58+
59+
pub fn list_names(&self) -> Vec<String> {
60+
let mut names: Vec<_> = self
61+
.functions
62+
.read()
63+
.expect("custom func registry poisoned")
64+
.keys()
65+
.cloned()
66+
.collect();
67+
names.sort();
68+
names
69+
}
70+
71+
fn register_builtin_functions(&self) {
72+
let _ = self.register_function(Arc::new(ConcatFunc));
73+
}
74+
}
75+
76+
impl Default for CustomFuncRegistry {
77+
fn default() -> Self {
78+
Self::new()
79+
}
80+
}
81+
82+
#[cfg(test)]
83+
mod tests {
84+
use super::*;
85+
use crate::expr::func::EvalError;
86+
use datatypes::Value;
87+
88+
#[derive(Debug)]
89+
struct DummyFn;
90+
91+
impl CustomFunc for DummyFn {
92+
fn validate_row(&self, _args: &[Value]) -> Result<(), EvalError> {
93+
Ok(())
94+
}
95+
96+
fn eval_row(&self, _args: &[Value]) -> Result<Value, EvalError> {
97+
Ok(Value::Null)
98+
}
99+
100+
fn name(&self) -> &str {
101+
"dummy"
102+
}
103+
}
104+
105+
#[test]
106+
fn register_and_resolve_custom_function() {
107+
let registry = CustomFuncRegistry::new();
108+
assert!(!registry.is_registered("dummy"));
109+
registry
110+
.register_function(Arc::new(DummyFn))
111+
.expect("register");
112+
assert!(registry.is_registered("dummy"));
113+
assert!(registry.get("dummy").is_some());
114+
assert!(registry.get("DuMmY").is_some());
115+
assert!(registry.get("missing").is_none());
116+
}
117+
118+
#[test]
119+
fn reject_duplicate_registration() {
120+
let registry = CustomFuncRegistry::new();
121+
registry
122+
.register_function(Arc::new(DummyFn))
123+
.expect("register");
124+
let err = registry
125+
.register_function(Arc::new(DummyFn))
126+
.expect_err("duplicate register should fail");
127+
assert_eq!(
128+
err,
129+
CustomFuncRegistryError::AlreadyRegistered("dummy".to_string())
130+
);
131+
}
132+
}

0 commit comments

Comments
 (0)