diff --git a/docs/architecture.md b/docs/architecture.md new file mode 100644 index 0000000..6c3c0ac --- /dev/null +++ b/docs/architecture.md @@ -0,0 +1,548 @@ +# GrandCypher Architecture + +> High-level architecture overview of the GrandCypher query engine + +## Table of Contents +- [Query Lifecycle](#query-lifecycle) +- [AST Structure](#ast-structure) +- [Expression Types](#expression-types) +- [Execution Model](#execution-model) +- [Nesting & Composition](#nesting--composition) + +--- + +## Query Lifecycle + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ QUERY LIFECYCLE │ +└─────────────────────────────────────────────────────────────────┘ + +Input Query String + │ + ▼ +┌─────────────────┐ +│ LARK PARSER │ Parse grammar → Parse Tree +└────────┬────────┘ + │ + ▼ +┌─────────────────┐ +│ TRANSFORMER │ Parse Tree → AST (Abstract Syntax Tree) +└────────┬────────┘ + │ + ▼ +┌─────────────────┐ +│ AST OBJECTS │ • MatchClause +│ (Query Plan) │ • WhereCondition +│ │ • ReturnItems +│ │ • OrderByClause +└────────┬────────┘ + │ + ▼ +┌─────────────────┐ +│ GRAPH MATCHING │ Find subgraph isomorphisms → Match objects +└────────┬────────┘ + │ + ▼ +┌─────────────────┐ +│ WHERE FILTER │ Evaluate conditions → Filter matches +└────────┬────────┘ + │ + ▼ +┌─────────────────┐ +│ RETURN BUILD │ • Evaluate scalar functions (per match) +│ │ • Collect values for aggregations +└────────┬────────┘ + │ + ▼ +┌─────────────────┐ +│ AGGREGATION │ Group and aggregate → Final values +└────────┬────────┘ + │ + ▼ +┌─────────────────┐ +│ ORDER BY │ Sort results +└────────┬────────┘ + │ + ▼ + Result Dict +``` + +--- + +## AST Structure + +The AST is composed of **Condition** objects that form a tree structure. 
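
The idea above — Condition objects composing into an evaluable tree — can be sketched with simplified stand-ins. The class names `CompoundCondition` and `AND` mirror the AST nodes described in this document, but the signatures here are deliberately reduced (the real classes also thread `host`, `return_edges`, and `scope` through each call and return a `(bool, results)` tuple):

```python
# Illustrative sketch only: simplified stand-ins for the AST node classes.
# The real grandcypher classes take (match, host, return_edges, scope);
# here each node just returns a bool for one match.

class CompoundCondition:
    """Leaf node: compare one entity attribute against a literal value."""
    def __init__(self, attribute, op, value):
        self.attribute, self.op, self.value = attribute, op, value

    def __call__(self, match, host):
        # Resolve n's attribute for this match, then apply the operator
        node_attrs = host[match["n"]]
        return self.op(node_attrs.get(self.attribute), self.value)

class AND:
    """Inner node: combine the results of two child conditions."""
    def __init__(self, condition_a, condition_b):
        self.condition_a, self.condition_b = condition_a, condition_b

    def __call__(self, match, host):
        return self.condition_a(match, host) and self.condition_b(match, host)

# WHERE n.age > 20 AND n.age < 65, built as a two-level condition tree
where = AND(
    CompoundCondition("age", lambda a, b: a > b, 20),
    CompoundCondition("age", lambda a, b: a < b, 65),
)

host = {"A": {"age": 25}, "B": {"age": 70}}   # toy host graph
matches = [{"n": "A"}, {"n": "B"}]            # one node mapping per match
passing = [m for m in matches if where(m, host)]
print(passing)  # [{'n': 'A'}]
```

Evaluation follows the WHERE phase described above: the tree is built once at parse time, then called once per match.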
+
+### Condition Hierarchy
+
+```
+                    Condition (base)
+                         │
+        ┌────────────────┼────────────────┐
+        │                │                │
+  ScalarFunction   AggregationFunc   Comparison
+   (per-match)       (grouped)        (boolean)
+        │                │                │
+   ┌────┴────┐      ┌────┴────┐      ┌────┴────┐
+   │         │      │         │      │         │
+  ID     toLower  COUNT     SUM   Compound   AND/OR
+  Type   toUpper  AVG       MAX      │
+  Size   trim     MIN     COLLECT    │
+  coalesce                        EntityOp
+```
+
+### AST Node Types
+
+| Node Type | Purpose | Evaluation Time | Returns |
+|-----------|---------|----------------|---------|
+| **ScalarFunction** | Transform single values | Per match | Single value |
+| **AggregationFunction** | Aggregate multiple values | After all matches | Grouped dict |
+| **CompoundCondition** | Entity comparison | During WHERE | Boolean |
+| **BoolCondition** | Logical operators | During WHERE | Boolean |
+| **EntityAttributeGetter** | Entity resolution | Runtime | Entity value |
+
+---
+
+## Expression Types
+
+### 1. Scalar Expressions
+**Used in:** WHERE, RETURN, ORDER BY
+**Evaluated:** Once per match
+**Examples:**
+```cypher
+ID(n)
+toLower(n.name)
+trim(n.description)
+coalesce(n.nickname, n.name)
+```
+
+**AST:**
+```
+ScalarFunction
+  ├─ function_name: "toLower"
+  └─ argument: EntityAttributeGetter("n.name")
+```
+
+### 2. Aggregation Expressions
+**Used in:** RETURN, ORDER BY (not WHERE!)
+**Evaluated:** After all matches collected
+**Examples:**
+```cypher
+COUNT(n)
+SUM(n.value)
+AVG(n.score)
+MAX(n.timestamp)
+COLLECT(n.tags)
+```
+
+**AST** (for `SUM(n.value)`)**:**
+```
+AggregationFunction
+  ├─ function_name: "SUM"
+  ├─ entity: "n"
+  └─ attribute: "value"
+```
+
+### 3. Comparison Expressions
+**Used in:** WHERE clause only
+**Evaluated:** During match filtering
+**Examples:**
+```cypher
+n.age > 20
+n.name = "Alice"
+n.score >= 100
+```
+
+**AST:**
+```
+CompoundCondition
+  ├─ entity_id: EntityAttributeGetter("n.age")
+  ├─ operator: LambdaCompareCondition(">")
+  └─ value: 20
+```
+
+### 4. 
Boolean Expressions +**Used in:** WHERE clause +**Evaluated:** During match filtering +**Examples:** +```cypher +n.age > 20 AND n.age < 65 +n.active = true OR n.verified = true +``` + +**AST:** +``` +AND + ├─ condition_a: CompoundCondition(n.age > 20) + └─ condition_b: CompoundCondition(n.age < 65) +``` + +--- + +## Execution Model + +### Phase 1: Parse Time (Static) +**What happens:** Build AST from query string + +``` +Query: MATCH (n) WHERE n.age > 20 RETURN toLower(n.name), COUNT(n) + +Parse Tree → Transformer → AST: + ├─ MatchClause(nodes=["n"]) + ├─ WhereCondition( + │ CompoundCondition( + │ entity="n.age", + │ operator=">", + │ value=20 + │ ) + │ ) + └─ ReturnItems([ + ScalarFunction(toLower, "n.name"), + AggregationFunction(COUNT, "n") + ]) +``` + +**Key:** All function objects created at parse time, stored as AST nodes. + +### Phase 2: Match Time (Graph Traversal) +**What happens:** Find subgraph patterns + +``` +Motif: (n) +Host Graph: A, B, C, D, E + +grandiso.find_motifs() → + Match(node_mappings={"n": "A"}) + Match(node_mappings={"n": "B"}) + Match(node_mappings={"n": "C"}) + ... +``` + +### Phase 3: Filter Time (WHERE Evaluation) +**What happens:** Evaluate conditions per match + +``` +For each match: + ┌──────────────────────────────────┐ + │ WHERE CONDITION EVALUATION │ + ├──────────────────────────────────┤ + │ CompoundCondition.__call__(): │ + │ 1. Resolve entity_id │ + │ n.age → host.nodes["A"]["age"] = 25 + │ 2. Apply operator │ + │ 25 > 20 → True │ + │ 3. Return (bool, results) │ + └──────────────────────────────────┘ + + Keep match if True, discard if False +``` + +**Timing:** **Scalar functions in WHERE evaluated here** (per match) + +### Phase 4: Return Time (Data Collection) +**What happens:** Evaluate RETURN expressions + +``` +For each passing match: + ┌──────────────────────────────────────┐ + │ SCALAR FUNCTION EVALUATION │ + ├──────────────────────────────────────┤ + │ toLower(n.name).__call__(): │ + │ 1. 
Get entity value │ + │ n.name → "Alice" │ + │ 2. Apply function │ + │ toLower("Alice") → "alice" │ + │ 3. Store result │ + │ results["toLower(n.name)"] = "alice" + └──────────────────────────────────────┘ + + Also collect values for aggregations: + scope["n"] = ["A", "B", "C", ...] +``` + +**Timing:** **Scalar functions in RETURN evaluated here** (per match) + +### Phase 5: Aggregation Time (After All Matches) +**What happens:** Compute aggregated values + +``` +After ALL matches collected: + ┌──────────────────────────────────────┐ + │ AGGREGATION FUNCTION EVALUATION │ + ├──────────────────────────────────────┤ + │ COUNT(n).evaluate(): │ + │ 1. Get scope values │ + │ scope["n"] = ["A", "B", "C"] │ + │ 2. Group by group_keys │ + │ No grouping → {(): ["A","B","C"]} + │ 3. Compute aggregation │ + │ Count non-null → 3 │ + │ 4. Return {(): 3} │ + └──────────────────────────────────────┘ +``` + +**Timing:** **Aggregations evaluated here** (after all matches) + +### Phase 6: Order & Format Time +**What happens:** Sort and finalize results + +``` +ORDER BY applied +Results formatted as dict +Return to user +``` + +--- + +## Nesting & Composition + +### Scalar Function Nesting + +Scalar functions can be nested arbitrarily deep: + +```cypher +RETURN toLower(trim(n.name)) +``` + +**AST:** +``` +ToLower + └─ argument: Trim + └─ argument: EntityAttributeGetter("n.name") +``` + +**Execution (inside-out):** +``` +1. EntityAttributeGetter("n.name").evaluate() → " Alice " +2. Trim.__call__(" Alice ") → "Alice" +3. 
ToLower.__call__("Alice") → "alice" +``` + +**Call Stack:** +``` +toLower.__call__(): + │ + ├─ Evaluate argument (Trim object) + │ └─ trim.__call__(): + │ │ + │ ├─ Evaluate argument (EntityAttributeGetter) + │ │ └─ get(" Alice ") + │ │ + │ └─ return "Alice" + │ + └─ return "alice" +``` + +### Composition in WHERE + +```cypher +WHERE toLower(n.name) = "alice" AND n.age > 20 +``` + +**AST:** +``` +AND + ├─ CompoundCondition + │ ├─ entity_id: ToLower(EntityAttributeGetter("n.name")) + │ ├─ operator: "=" + │ └─ value: "alice" + └─ CompoundCondition + ├─ entity_id: EntityAttributeGetter("n.age") + ├─ operator: ">" + └─ value: 20 +``` + +**Execution:** +``` +AND.__call__(): + ├─ Evaluate condition_a: + │ └─ ToLower(n.name) → "alice" + │ └─ "alice" = "alice" → True + │ + ├─ Evaluate condition_b: + │ └─ n.age → 25 + │ └─ 25 > 20 → True + │ + └─ True AND True → True +``` + +### Mixed Scalar + Aggregation + +```cypher +RETURN toLower(n.name), COUNT(n), SUM(n.value) +``` + +**AST:** +``` +ReturnItems: [ + ScalarFunction(toLower), ← Evaluated per match + AggregationFunction(COUNT), ← Evaluated after all matches + AggregationFunction(SUM) ← Evaluated after all matches +] +``` + +**Execution Timeline:** +``` +Time → +├─ Match phase: Find all matches +│ +├─ RETURN phase (per match): +│ └─ Evaluate toLower(n.name) for each match +│ Results: ["alice", "bob", "charlie"] +│ +└─ Aggregation phase (all matches): + ├─ COUNT(n): 3 + └─ SUM(n.value): 150 +``` + +--- + +## Key Architectural Patterns + +### 1. Wrap at Parse, Evaluate at Runtime +``` +Parse Time: Create ScalarFunction objects (AST nodes) +Run Time: Call __call__() to get actual values +``` + +### 2. Two-Phase Return Processing +``` +Phase 1 (per match): Evaluate scalar expressions +Phase 2 (after all): Evaluate aggregation expressions +``` + +### 3. Scope-Based Communication +``` +Scalar phase → Collect values → Store in scope dict + ↓ +Aggregation phase ← Read from scope ← Group & compute +``` + +### 4. 
AST Optimization +``` +WHERE conditions → to_indexer_ast() → IndexerAST + ↓ + Binary search + on indexed attrs +``` + +### 5. Entity vs Literal Distinction +``` +EntityAttributeGetter("n.name") ← Entity reference (resolve at runtime) +"Unknown" ← Literal value (use as-is) +``` + +--- + +## Summary: When Things Are Called + +| Expression Type | Created (Parse) | Called (Runtime) | Context | +|----------------|----------------|------------------|---------| +| **Scalar in WHERE** | ✓ AST node | During filtering | Per match | +| **Scalar in RETURN** | ✓ AST node | During collection | Per match | +| **Aggregation in RETURN** | ✓ AST node | After all matches | Grouped | +| **Comparison in WHERE** | ✓ AST node | During filtering | Per match | +| **Boolean (AND/OR)** | ✓ AST node | During filtering | Per match | + +**Rule of Thumb:** +- **Scalar:** Called for each match individually +- **Aggregation:** Called once with all matches +- **WHERE:** Evaluated during filtering (short-circuits) +- **RETURN:** Evaluated during collection & aggregation + +--- + +## Architecture Decisions + +### Why Two Function Types? + +**Scalars:** +- Operate on single values +- Can be used anywhere (WHERE, RETURN, ORDER BY) +- Example: `toLower("Alice")` → `"alice"` + +**Aggregations:** +- Require full dataset to compute +- Only in RETURN/ORDER BY (not WHERE) +- Example: `COUNT(["A", "B", "C"])` → `3` + +### Why EntityAttributeGetter? + +Without wrapping, parser can't distinguish: +```cypher +coalesce(n.name, "Unknown") + ^^^^^^^ ^^^^^^^^^ + entity ref literal +``` + +With wrapping: +```python +coalesce([ + EntityAttributeGetter("n.name"), # Resolve at runtime + "Unknown" # Use as-is +]) +``` + +### Why Scope Dictionary? 
+
+Aggregations need pre-computed values:
+```python
+# Can't evaluate COUNT during match iteration
+# Need all values first, then count
+
+matches = [match1, match2, match3]
+
+# Phase 1: Collect
+scope = {"n": ["A", "B", "C"]}
+
+# Phase 2: Aggregate
+COUNT.evaluate(scope) → 3
+```
+
+---
+
+## Example: Full Query Trace
+
+**Query:**
+```cypher
+MATCH (n:Person)
+WHERE n.age > 20 AND toLower(n.name) STARTS WITH "a"
+RETURN n.name, COUNT(n) AS total
+ORDER BY total DESC
+```
+
+**AST:**
+```
+Match: (n:Person)
+Where: AND(
+    CompoundCondition(n.age > 20),
+    CompoundCondition(toLower(n.name) STARTS WITH "a")
+)
+Return: [
+    EntityAttributeGetter("n.name"),
+    COUNT("n")
+]
+OrderBy: "total" DESC
+```
+
+**Execution:**
+```
+1. PARSE: Query → AST (functions created, not called)
+2. MATCH: Find all (n:Person) → 100 matches
+3. WHERE: Filter each match:
+   - n.age > 20? (50 pass)
+   - toLower(n.name) starts with "a"? (10 pass)
+   Final: 10 matches
+4. RETURN: For each of the 10 matches:
+   - Evaluate n.name → ["Alice", "Aaron", ...]
+   - Store scope["n"] = ["A", "B", ...]
+5. AGGREGATE: COUNT(scope["n"]) → {(): 10}
+6. ORDER: Sort by total DESC
+7. 
RESULT: {"n.name": [...], "total": [10]} +``` + +--- + +**Last Updated:** 2025-12-06 +**Version:** 1.0 diff --git a/grandcypher/__init__.py b/grandcypher/__init__.py index 81cc93f..0fc9024 100644 --- a/grandcypher/__init__.py +++ b/grandcypher/__init__.py @@ -49,8 +49,9 @@ | "(" compound_condition boolean_arithmetic compound_condition ")" | compound_condition boolean_arithmetic compound_condition -condition : (entity_id | scalar_function) op entity_id_or_value - | (entity_id | scalar_function) op_list value_list +condition : (entity_id | scalar_function | list_predicate_function) op entity_id_or_value + | (entity_id | scalar_function | list_predicate_function) op_list value_list + | list_predicate_function | sub_query | "not"i condition -> condition_not @@ -87,11 +88,34 @@ return_item : (entity_id | aggregation_function | scalar_function | entity_id "." attribute_id) ( "AS"i alias )? alias : CNAME -aggregation_function : AGGREGATE_FUNC "(" entity_id ( "." attribute_id )? ")" -AGGREGATE_FUNC : "COUNT" | "SUM" | "AVG" | "MAX" | "MIN" +aggregation_function : AGGREGATE_FUNC "(" agg_argument ")" +agg_argument : scalar_function + | entity_id ( "." attribute_id )? +AGGREGATE_FUNC : "COUNT" | "SUM" | "AVG" | "MAX" | "MIN" | "COLLECT" attribute_id : CNAME scalar_function : "id"i "(" entity_id ")" -> id_function + | "size"i "(" list_expression ")" -> size_function + | "tolower"i "(" scalar_func_arg ")" -> tolower_function + | "toupper"i "(" scalar_func_arg ")" -> toupper_function + | "trim"i "(" scalar_func_arg ")" -> trim_function + | "type"i "(" entity_id ")" -> type_function + | "coalesce"i "(" coalesce_args ")" -> coalesce_function + +scalar_func_arg : scalar_function + | entity_id ("." attribute_id)? + +coalesce_args : coalesce_arg ("," coalesce_arg)* +coalesce_arg : value + | entity_id ("." attribute_id)? 
+ +list_predicate_function : "all"i "(" CNAME "in"i list_expression "where"i compound_condition ")" -> all_function + | "any"i "(" CNAME "in"i list_expression "where"i compound_condition ")" -> any_function + | "none"i "(" CNAME "in"i list_expression "where"i compound_condition ")" -> none_function + | "single"i "(" CNAME "in"i list_expression "where"i compound_condition ")" -> single_function + +list_expression : "relationships"i "(" entity_id ")" -> relationships_function + | entity_id -> entity_list distinct_return : "DISTINCT"i limit_clause : "limit"i NUMBER @@ -578,6 +602,329 @@ def generate_multiedge_edge_hop_key( class Condition: ... + +class ScalarFunction(Condition): + """ + Base class for scalar functions that return a single value per row. + + Characteristics: + - Return single value (not aggregate) + - Can be used in WHERE clauses (comparison) + - Can be used in RETURN clauses (output) + - Evaluated once per match + + Examples: ID(), SIZE(), type(), timestamp() + """ + + def __call__(self, match: Match, host: nx.DiGraph, return_edges: list, + scope: Optional[dict] = None): + """Evaluate scalar function for a single match.""" + raise NotImplementedError + + def __str__(self) -> str: + """Return string representation for result keys.""" + raise NotImplementedError + + +class EntityAttributeGetter: + """ + Wrapper to distinguish entity references from literal values. + + This class is used to represent references to graph entities and their attributes + (e.g., n.name, n) in expressions, allowing the runtime to distinguish them from + literal string values (e.g., "Unknown"). + + Examples: + EntityAttributeGetter("n.name") represents n.name + EntityAttributeGetter("n") represents n + """ + + def __init__(self, expression: str): + """ + Initialize entity attribute getter from expression string. + + Args: + expression: The entity reference (e.g., "n", "n.name", "r.weight") + """ + # Parse expression: "n.name" -> entity="n", attribute="name" + if "." 
in expression: + self.entity, self.attribute = expression.split(".", 1) + else: + self.entity = expression + self.attribute = None + + def evaluate(self, match: Match, host: nx.DiGraph, + return_edges: dict = None, scope: dict = None): + """ + Evaluate this entity reference against a match. + + Priority order for resolution: + 1. Scope variables (highest priority - for list predicates) + 2. Node mappings (standard case) + 3. Edge mappings (for edge references) + 4. None (not found) + + Args: + match: The current match containing node mappings + host: The graph to query + return_edges: Optional edge mappings for edge references + scope: Optional scope dictionary for list predicate variables + + Returns: + The attribute value if found, None otherwise + """ + # 1. Check scope first (highest priority for list predicates) + if scope and self.entity in scope: + element = scope[self.entity] + if self.attribute: + # Scope variable with attribute access: e.related + return element.get(self.attribute) if isinstance(element, dict) else None + # Simple scope variable: e + return element + + # 2. Check node mappings (standard case) + if self.entity in match.node_mappings: + node_id = match.node_mappings[self.entity] + if self.attribute: + # Node with attribute: n.name + return host.nodes[node_id].get(self.attribute) + # Simple node reference: n - return full node dictionary + return dict(host.nodes[node_id]) + + # 3. 
Check edge mappings (for edge references) + if return_edges and self.entity in return_edges: + edge_mapping = return_edges[self.entity] + host_edges = match.mth.edge(*edge_mapping).edges + return get_edge_from_host(host, host_edges, self.attribute) + + return None + + def __str__(self) -> str: + """String representation for debugging.""" + if self.attribute: + return f"{self.entity}.{self.attribute}" + return self.entity + + def __repr__(self) -> str: + """Detailed representation for debugging.""" + if self.attribute: + return f"EntityAttributeGetter({self.entity!r}.{self.attribute!r})" + return f"EntityAttributeGetter({self.entity!r})" + + +class ID(ScalarFunction): + """ + Implements id() scalar function. + Returns the node ID from the host graph. + + Usage: + - WHERE: WHERE ID(n) = 1 + - RETURN: RETURN ID(n) AS nodeId + """ + + def __init__(self, entity_name: str): + self._entity_name = entity_name + + def __call__(self, match: Match, host: nx.DiGraph, return_edges: list, + scope: Optional[dict] = None): + """Return the node ID from the match.""" + if self._entity_name in match.node_mappings: + return match.node_mappings[self._entity_name] + else: + raise IndexError(f"Entity {self._entity_name} not in match.") + + def __str__(self) -> str: + return f"ID({self._entity_name})" + + +class AggregationFunction(Condition): + """ + Base class for aggregation functions that compute over multiple matches. + + Unlike ScalarFunction (per-match evaluation), AggregationFunction requires + the full result set to compute values. + + Characteristics: + - Requires multiple matches to evaluate + - Needs grouping context + - Used in RETURN and ORDER BY clauses + - Architecture supports future WITH statements + + Examples: COUNT, SUM, AVG, MAX, MIN + """ + + def __init__(self, entity, entity_attribute: Optional[str] = None): + """ + Initialize aggregation function. 
+ + Args: + entity: Entity name (e.g., 'r' in COUNT(r)) OR ScalarFunction (e.g., size(r) in AVG(size(r))) + entity_attribute: Optional attribute (e.g., 'value' in SUM(r.value)) + """ + self._entity = entity + self._entity_attribute = entity_attribute + self._is_scalar_function = isinstance(entity, ScalarFunction) + + def evaluate(self, matches: List[Match], host: nx.DiGraph, + return_edges: dict, group_keys: List[str], scope: dict) -> Dict[tuple, Any]: + """ + Evaluate aggregation over all matches with grouping. + + Args: + matches: All matches to aggregate over + host: Target graph + return_edges: Edge mappings + group_keys: Entity names to group by (e.g., ["n.name"]) + scope: Pre-computed values from outer context (e.g., from _lookup) + + Returns: + Dict mapping group_tuple -> aggregated_value + """ + raise NotImplementedError(f"{self.__class__.__name__}.evaluate() must be implemented") + + def __str__(self) -> str: + """String representation for result keys: COUNT(r) or SUM(r.value) or AVG(size(r))""" + if self._is_scalar_function: + # Scalar function: use its string representation + entity_str = str(self._entity) + else: + # Regular entity reference + entity_str = self._entity + if self._entity_attribute: + entity_str += f".{self._entity_attribute}" + return f"{self.__class__.__name__}({entity_str})" + + def _group_matches(self, scope: dict, group_keys: List[str], + matches: List[Match] = None, host: nx.DiGraph = None, + return_edges: dict = None) -> Dict[tuple, List[Any]]: + """ + Group values and extract data for aggregation using scope. + + Uses pre-computed values from scope when available (RETURN case), + OR evaluates scalar functions per-match if entity is a ScalarFunction. 
+ + Args: + scope: Pre-computed values from outer context (e.g., {"n.name": [...], "r.value": [...]}) + group_keys: Keys to group by (e.g., ["n.name"]) + matches: List of matches (needed for scalar function evaluation) + host: Target graph (needed for scalar function evaluation) + return_edges: Edge mappings (needed for scalar function evaluation) + + Returns: + Dict mapping group_tuple -> list of values to aggregate + """ + # If entity is a scalar function, evaluate it for each match + if self._is_scalar_function: + entity_values = [] + for match in matches: + value = self._entity(match, host, return_edges, scope=None) + entity_values.append(value) + else: + # Build entity path: "r" or "r.value" + entity_path = self._entity + ('.' + self._entity_attribute if self._entity_attribute else '') + + # Get entity values from scope (already extracted by _lookup) + entity_values = scope.get(entity_path, []) + + # Group by group_keys (adapted from old aggregate() method) + grouped_data = {} + for i in range(len(entity_values)): + # Build group tuple from scope values + group_tuple = tuple(scope.get(key, [])[i] if i < len(scope.get(key, [])) else None + for key in group_keys) + + if group_tuple not in grouped_data: + grouped_data[group_tuple] = [] + grouped_data[group_tuple].append(entity_values[i]) + + return grouped_data + + +class COUNT(AggregationFunction): + """ + COUNT aggregation function. + Returns the number of non-null values. + """ + + def evaluate(self, matches: List[Match], host: nx.DiGraph, + return_edges: dict, group_keys: List[str], scope: dict) -> Dict[tuple, int]: + grouped = self._group_matches(scope, group_keys, matches, host, return_edges) + # COUNT only counts non-null values (filter out None) + return {group: sum(1 for v in values if v is not None) for group, values in grouped.items()} + + +class SUM(AggregationFunction): + """ + SUM aggregation function. + Returns the sum of numeric values (None treated as 0). 
+ """ + + def evaluate(self, matches: List[Match], host: nx.DiGraph, + return_edges: dict, group_keys: List[str], scope: dict) -> Dict[tuple, float]: + grouped = self._group_matches(scope, group_keys, matches, host, return_edges) + # SUM treats None as 0 + return {group: sum(v or 0 for v in values) + for group, values in grouped.items()} + + +class AVG(AggregationFunction): + """ + AVG aggregation function. + Returns the average of numeric values (None treated as 0). + """ + + def evaluate(self, matches: List[Match], host: nx.DiGraph, + return_edges: dict, group_keys: List[str], scope: dict) -> Dict[tuple, float]: + grouped = self._group_matches(scope, group_keys, matches, host, return_edges) + # AVG treats None as 0 + result = {} + for group, values in grouped.items(): + collated = [v or 0 for v in values] + result[group] = sum(collated) / len(collated) if collated else 0 + return result + + +class MAX(AggregationFunction): + """ + MAX aggregation function. + Returns the maximum value (None treated as negative infinity). + """ + + def evaluate(self, matches: List[Match], host: nx.DiGraph, + return_edges: dict, group_keys: List[str], scope: dict) -> Dict[tuple, Any]: + grouped = self._group_matches(scope, group_keys, matches, host, return_edges) + # MAX treats None as -infinity + return {group: max((d if d is not None else -float("inf")) for d in values) + for group, values in grouped.items()} + + +class MIN(AggregationFunction): + """ + MIN aggregation function. + Returns the minimum value (None treated as positive infinity). 
+ """ + + def evaluate(self, matches: List[Match], host: nx.DiGraph, + return_edges: dict, group_keys: List[str], scope: dict) -> Dict[tuple, Any]: + grouped = self._group_matches(scope, group_keys, matches, host, return_edges) + # MIN treats None as +infinity + return {group: min((d if d is not None else float("inf")) for d in values) + for group, values in grouped.items()} + + +class COLLECT(AggregationFunction): + """ + COLLECT aggregation function. + Collects all values into a list (like SQL's array_agg). + """ + + def evaluate(self, matches: List[Match], host: nx.DiGraph, + return_edges: dict, group_keys: List[str], scope: dict) -> Dict[tuple, list]: + grouped = self._group_matches(scope, group_keys, matches, host, return_edges) + # Collect all values (including None) into lists + return {group: list(values) for group, values in grouped.items()} + + class BoolCondition(Condition): ... @@ -588,9 +935,9 @@ def __init__(self, condition_a: CONDITION, condition_b: CONDITION): self._condition_b = condition_b self._operator = "and" - def __call__(self, match: dict, host: nx.DiGraph, return_edges: list) -> bool: - condition_a, where_a = self._condition_a(match, host, return_edges) - condition_b, where_b = self._condition_b(match, host, return_edges) + def __call__(self, match: dict, host: nx.DiGraph, return_edges: list, scope: dict = None) -> bool: + condition_a, where_a = self._condition_a(match, host, return_edges, scope) + condition_b, where_b = self._condition_b(match, host, return_edges, scope) where_result = [a and b for a, b in zip(where_a, where_b)] return (condition_a and condition_b), where_result @@ -601,9 +948,9 @@ def __init__(self, condition_a: CONDITION, condition_b: CONDITION): self._condition_b = condition_b self._operator = "or" - def __call__(self, match: dict, host: nx.DiGraph, return_edges: list) -> tuple[bool, dict]: - condition_a, where_a = self._condition_a(match, host, return_edges) - condition_b, where_b = self._condition_b(match, host, 
return_edges) + def __call__(self, match: dict, host: nx.DiGraph, return_edges: list, scope: dict = None) -> tuple[bool, dict]: + condition_a, where_a = self._condition_a(match, host, return_edges, scope) + condition_b, where_b = self._condition_b(match, host, return_edges, scope) where_result = [a or b for a, b in zip(where_a, where_b)] return (condition_a or condition_b), where_result @@ -626,57 +973,540 @@ def __str__(self) -> str: class CompoundCondition(Condition): """compound condition""" - def __init__(self, should_be: bool, entity_id: str, operator, value): + def __init__(self, should_be: bool, entity_id, operator, value): + """ + Initialize CompoundCondition. + + Args: + should_be: Boolean expectation for the condition + entity_id: Either a ScalarFunction or string entity reference (e.g., "n.name") + operator: Comparison operator + value: Value to compare against + """ self._should_be = should_be - self._entity_id = entity_id self._operator = operator self._value = value - def __str__(self): - return f"compound of {self._operator} for key {self._entity_id}: value {self._value}" - - # def __call__(self, match: dict, host: nx.DiGraph, return_edges: list, edge_hop_map = None, edge_hop_key = None) -> bool: - def __call__(self, match: Match, host: nx.DiGraph, return_edges: list) -> bool: - # Check if this is an ID function call - if self._entity_id.startswith("ID(") and self._entity_id.endswith(")"): - # Extract the entity name from ID(entity_name) - actual_entity_name = self._entity_id[3:-1] # Remove "ID(" and ")" - if actual_entity_name in match.node_mappings: - # Return the node ID directly - node_id = match.node_mappings[actual_entity_name] - try: - val = self._operator(node_id, self._value) - except: - val = False - operator_results = [val] - else: - raise IndexError(f"Entity {actual_entity_name} not in match.") + # Wrap entity references in EntityAttributeGetter at init time + if isinstance(entity_id, ScalarFunction): + self._entity_id = entity_id else: - 
host_entity_id = self._entity_id.split(".") - if isinstance(self._operator, SUBOP): - # SUBOP operator doesn't need a entity id. - val = self._operator(match.node_mappings, self._value) - elif host_entity_id[0] in match.node_mappings: - # Regular entity attribute access - host_entity_id[0] = match.node_mappings[host_entity_id[0]] - val = self._operator(get_node_from_host(host, host_entity_id[0], host_entity_id[1]), self._value) - elif host_entity_id[0] in return_edges: - # looking for edge... - entity_name, entity_attribute = _data_path_to_entity_name_attribute(self._entity_id) - edge_mapping = return_edges[entity_name] - - host_edges = match.mth.edge(*edge_mapping).edges - val = self._operator(get_edge_from_host(host, host_edges, entity_attribute), self._value) - else: - raise IndexError(f"Entity {host_entity_id} not in graph.") + # Store both the original string and the getter + self._entity_id_str = entity_id + self._entity_id = EntityAttributeGetter(entity_id) + + def __str__(self): + entity_repr = self._entity_id_str if hasattr(self, '_entity_id_str') else str(self._entity_id) + return f"compound of {self._operator} for key {entity_repr}: value {self._value}" + + def __call__(self, match: Match, host: nx.DiGraph, return_edges: list, scope: dict = None) -> bool: + # Handle scalar functions (ID, SIZE, etc.) 
+ if isinstance(self._entity_id, ScalarFunction): + # Evaluate scalar function to get value + scalar_value = self._entity_id(match, host, return_edges, scope) + + # Apply comparison operator + val = self._operator(scalar_value, self._value) + operator_results = [val] + + if val is None: + val = False + if val != self._should_be: + return False, operator_results + return True, operator_results + + # Handle SUBOP operators (special case - don't need entity resolution) + if isinstance(self._operator, SUBOP): + val = self._operator(match.node_mappings, self._value) + operator_results = [val] + if val is None: + val = False + if val != self._should_be: + return False, operator_results + return True, operator_results + + # Use EntityAttributeGetter for all entity references (handles scope, nodes, edges) + entity_value = self._entity_id.evaluate(match, host, return_edges, scope) + + # Apply comparison operator + val = self._operator(entity_value, self._value) operator_results = [val] + if val is None: - val is False + val = False if val != self._should_be: return False, operator_results return True, operator_results +# ==================== List Expression Classes for all()/any() ==================== + +class ListExpression: + """Base class for list expressions that can reference scope.""" + + def evaluate(self, match: Match, host: nx.DiGraph, return_edges: list, scope: Optional[dict] = None) -> list: + """Evaluate to get list, possibly using scope variables.""" + raise NotImplementedError + + +class ScopedListExpression(ListExpression): + """ + List expression that can reference scope variables. + + Examples: + - 'r' → get edge list from return_edges + - 'e.related' → get 'related' attr from scope variable 'e' + """ + def __init__(self, expr: str): + """ + Initialize ScopedListExpression. 
+ + Args: + expr: Expression string (e.g., "r", "e.related") + """ + self._expr = expr + # Wrap expression in EntityAttributeGetter at init time + self._getter = EntityAttributeGetter(expr) + + def evaluate(self, match: Match, host: nx.DiGraph, return_edges: list, scope: Optional[dict] = None): + # Use EntityAttributeGetter to resolve the value + value = self._getter.evaluate(match, host, return_edges, scope) + + # Normalize to list + if value is None: + return [] + elif isinstance(value, list): + return value + elif isinstance(value, dict): + return [value] + else: + # Scalar value - wrap in list + return [value] + + +class RelationshipsFunction(ListExpression): + """Implements relationships() function.""" + def __init__(self, path_variable: str): + self._path_variable = path_variable + + def evaluate(self, match: Match, host: nx.DiGraph, return_edges: list, scope: Optional[dict] = None): + # Use ScopedListExpression to handle scope references + return ScopedListExpression(self._path_variable).evaluate(match, host, return_edges, scope) + + +# ==================== ALL and ANY Condition Classes ==================== + +class ALL(Condition): + """ + Implements all() predicate function. + + REUSES existing Condition classes (CompoundCondition, AND, OR)! + """ + def __init__(self, name: str, list_expr: str, pred): + self._name = name # Loop variable name + self._list_expr = list_expr # String expr or ListExpression object + self._pred = pred # Regular Condition (CompoundCondition, AND, OR, etc.) + + def __call__(self, match, host: nx.DiGraph, return_edges: list, + scope: dict = None) -> tuple[bool, list]: + # 1. Evaluate list expression (may reference scope) + if isinstance(self._list_expr, str): + list_obj = ScopedListExpression(self._list_expr) + else: + list_obj = self._list_expr + + elements = list_obj.evaluate(match, host, return_edges, scope) + + # 2. 
Handle empty/null lists (Neo4j semantics)
+        if elements is None:
+            return None, [None]
+        if not elements:
+            return True, [True]  # Vacuously true
+
+        # 3. Iterate and evaluate predicate for each element
+        has_null = False
+        for element in elements:
+            # Create new scope with current element
+            new_scope = {**scope} if scope else {}
+            new_scope[self._name] = element
+
+            # Evaluate predicate with scope (pred is a regular Condition!)
+            result, _ = self._pred(match, host, return_edges, new_scope)
+
+            # Short-circuit on False
+            if result is False:
+                return False, [False]
+            elif result is None:
+                # Track null for Neo4j semantics: an unknown element means
+                # all() cannot be definitively true
+                has_null = True
+
+        if has_null:
+            return None, [None]
+        return True, [True]
+
+
+class ANY(Condition):
+    """Similar structure to ALL but returns True if any element satisfies."""
+    def __init__(self, name: str, list_expr: str, pred):
+        self._name = name
+        self._list_expr = list_expr
+        self._pred = pred  # Regular Condition!
+
+    def __call__(self, match, host, return_edges, scope=None):
+        # Evaluate list
+        if isinstance(self._list_expr, str):
+            list_obj = ScopedListExpression(self._list_expr)
+        else:
+            list_obj = self._list_expr
+
+        elements = list_obj.evaluate(match, host, return_edges, scope)
+
+        # Handle empty/null (Neo4j semantics)
+        if elements is None:
+            return None, [None]
+        if not elements:
+            return False, [False]  # False for empty list
+
+        # Iterate
+        has_null = False
+        for element in elements:
+            new_scope = {**scope} if scope else {}
+            new_scope[self._name] = element
+
+            # Evaluate predicate with scope
+            result, _ = self._pred(match, host, return_edges, new_scope)
+
+            # Short-circuit on True
+            if result is True:
+                return True, [True]
+            elif result is None:
+                # Track null for Neo4j semantics: an unknown element means
+                # any() cannot be definitively false
+                has_null = True
+
+        if has_null:
+            return None, [None]
+        return False, [False]
+
+
+class NONE(Condition):
+    """
+    Implements none() predicate function.
+    Returns true when NO elements satisfy the predicate.
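The `ALL` and `ANY` classes above implement Cypher's three-valued (true/false/null) logic while threading a loop variable through `scope`. Stripped of the `Match` machinery, their decision rules can be sketched as plain functions (the helper names here are illustrative, not part of GrandCypher's API; `pred` returns `True`, `False`, or `None` for "unknown"):

```python
def all_ternary(elements, pred):
    """Neo4j-style all(): a definite False wins, then None, then True."""
    if elements is None:
        return None
    saw_null = False
    for el in elements:
        r = pred(el)
        if r is False:
            return False          # short-circuit on a definite failure
        if r is None:
            saw_null = True
    return None if saw_null else True  # empty list is vacuously True


def any_ternary(elements, pred):
    """Neo4j-style any(): a definite True wins, then None, then False."""
    if elements is None:
        return None
    saw_null = False
    for el in elements:
        r = pred(el)
        if r is True:
            return True           # short-circuit on a definite success
        if r is None:
            saw_null = True
    return None if saw_null else False  # empty list is False
```

Note the asymmetry on empty input: `all_ternary([], pred)` is vacuously `True`, while `any_ternary([], pred)` is `False`, matching the comments in the classes above.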
+ + Neo4j semantics: + - Empty list [] → True (no elements violate) + - Null list → None + - Short-circuits on first True (element satisfies) + """ + def __init__(self, name: str, list_expr, pred: Condition): + self._name = name # Loop variable name + self._list_expr = list_expr # String expr or ListExpression + self._pred = pred # Regular Condition + + def __call__(self, match: Match, host: nx.DiGraph, return_edges: list, + scope: Optional[dict] = None): + # Evaluate list expression + if isinstance(self._list_expr, str): + list_obj = ScopedListExpression(self._list_expr) + else: + list_obj = self._list_expr + + elements = list_obj.evaluate(match, host, return_edges, scope) + + # Handle empty/null (Neo4j semantics) + if elements is None: + return None, [None] + if not elements: + return True, [True] # Vacuously true (no elements violate) + + # Iterate and check for violations + for element in elements: + # Create new scope with current element + new_scope = {**scope} if scope else {} + new_scope[self._name] = element + + # Evaluate predicate with scope + result, _ = self._pred(match, host, return_edges, new_scope) + + # Short-circuit on True (found element that satisfies) + if result is True: + return False, [False] # none() fails if any element satisfies + + return True, [True] # No elements satisfied the predicate + + +class SINGLE(Condition): + """ + Implements single() predicate function. + Returns true when EXACTLY ONE element satisfies the predicate. 
+ + Neo4j semantics: + - Empty list [] → False (not exactly one) + - Null list → None + - Cannot short-circuit (must check all elements to count) + """ + def __init__(self, name: str, list_expr, pred: Condition): + self._name = name # Loop variable name + self._list_expr = list_expr # String expr or ListExpression + self._pred = pred # Regular Condition + + def __call__(self, match: Match, host: nx.DiGraph, return_edges: list, + scope: Optional[dict] = None): + # Evaluate list expression + if isinstance(self._list_expr, str): + list_obj = ScopedListExpression(self._list_expr) + else: + list_obj = self._list_expr + + elements = list_obj.evaluate(match, host, return_edges, scope) + + # Handle empty/null (Neo4j semantics) + if elements is None: + return None, [None] + if not elements: + return False, [False] # Empty doesn't have exactly one + + # Count satisfying elements + count = 0 + has_null = False + + for element in elements: + # Create new scope with current element + new_scope = {**scope} if scope else {} + new_scope[self._name] = element + + # Evaluate predicate with scope + result, _ = self._pred(match, host, return_edges, new_scope) + + if result is True: + count += 1 + # Early exit if count > 1 + if count > 1: + return False, [False] + elif result is None: + has_null = True + + # Return based on count + if count == 1: + return True, [True] + elif count == 0 and has_null: + return None, [None] # Uncertain due to nulls + else: + return False, [False] + + +class SIZE(ScalarFunction): + """ + Implements size() scalar function. + Returns the length of a list as an integer. 
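The counting rules that `SINGLE` applies per match can be illustrated standalone (hypothetical helper name; `pred` returns `True`, `False`, or `None`):

```python
def single_ternary(elements, pred):
    """Neo4j-style single(): true iff exactly one element satisfies pred."""
    if elements is None:
        return None
    count, saw_null = 0, False
    for el in elements:
        r = pred(el)
        if r is True:
            count += 1
            if count > 1:
                return False      # two hits can never be "exactly one"
        elif r is None:
            saw_null = True
    if count == 1:
        return True
    # zero definite hits plus an unknown: the answer is uncertain
    return None if (count == 0 and saw_null) else False
```

Unlike `all()`/`any()`, this cannot short-circuit on the first satisfying element; it can only bail out early once a *second* hit is seen.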
+ + Can be used in: + - WHERE clauses: WHERE size(r) > 2 + - RETURN clauses: RETURN size(r) AS pathLength + + Neo4j semantics: + - Null list → None + - Empty list [] → 0 + """ + def __init__(self, list_expr): + self._list_expr = list_expr # String expr or ListExpression + + def __call__(self, match: Match, host: nx.DiGraph, return_edges: list, + scope: Optional[dict] = None): + # Evaluate list expression + if isinstance(self._list_expr, str): + list_obj = ScopedListExpression(self._list_expr) + else: + list_obj = self._list_expr + + elements = list_obj.evaluate(match, host, return_edges, scope) + + # Handle null + if elements is None: + return None + + # Return length + return len(elements) + + def __str__(self) -> str: + """Return string representation for result keys.""" + expr_str = self._list_expr if isinstance(self._list_expr, str) else "..." + return f"size({expr_str})" + + +class ToLower(ScalarFunction): + """ + Implements toLower() scalar function. + Converts a string to lowercase. + """ + + def __init__(self, expression): + """ + Initialize toLower with an expression. 
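`ToLower` (and the sibling string functions below) share one null-safety pattern: apply the transform only when the resolved value is a string, and pass anything else — including `None` — through unchanged. A minimal sketch of that pattern (illustrative helper, not GrandCypher API):

```python
def null_safe(transform):
    """Wrap a str->str transform so NULLs and non-strings propagate untouched."""
    def apply(value):
        return transform(value) if isinstance(value, str) else value
    return apply

to_lower = null_safe(str.lower)
to_upper = null_safe(str.upper)
trim = null_safe(str.strip)
```

Because each wrapper returns a plain value, they compose naturally, mirroring nested Cypher calls such as `toLower(trim(n.name))`.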
+ + Args: + expression: Either a ScalarFunction or string entity reference (e.g., "n.name") + """ + if isinstance(expression, ScalarFunction): + self._expression = expression + else: + # Wrap entity reference in EntityAttributeGetter at init time + self._expression = EntityAttributeGetter(expression) + + def __call__(self, match: Match, host: nx.DiGraph, return_edges: list, + scope: Optional[dict] = None): + # Evaluate expression (either ScalarFunction or EntityAttributeGetter) + if isinstance(self._expression, ScalarFunction): + value = self._expression(match, host, return_edges, scope) + else: + # It's an EntityAttributeGetter + value = self._expression.evaluate(match, host, return_edges, scope) + + return value.lower() if isinstance(value, str) else value + + def __str__(self) -> str: + return f"toLower({self._expression})" + + +class ToUpper(ScalarFunction): + """ + Implements toUpper() scalar function. + Converts a string to uppercase. + """ + + def __init__(self, expression): + """ + Initialize toUpper with an expression. + + Args: + expression: Either a ScalarFunction or string entity reference (e.g., "n.name") + """ + if isinstance(expression, ScalarFunction): + self._expression = expression + else: + # Wrap entity reference in EntityAttributeGetter at init time + self._expression = EntityAttributeGetter(expression) + + def __call__(self, match: Match, host: nx.DiGraph, return_edges: list, + scope: Optional[dict] = None): + # Evaluate expression (either ScalarFunction or EntityAttributeGetter) + if isinstance(self._expression, ScalarFunction): + value = self._expression(match, host, return_edges, scope) + else: + # It's an EntityAttributeGetter + value = self._expression.evaluate(match, host, return_edges, scope) + + return value.upper() if isinstance(value, str) else value + + def __str__(self) -> str: + return f"toUpper({self._expression})" + + +class Trim(ScalarFunction): + """ + Implements trim() scalar function. + Trims whitespace from a string. 
+ """ + + def __init__(self, expression): + """ + Initialize trim with an expression. + + Args: + expression: Either a ScalarFunction or string entity reference (e.g., "n.name") + """ + if isinstance(expression, ScalarFunction): + self._expression = expression + else: + # Wrap entity reference in EntityAttributeGetter at init time + self._expression = EntityAttributeGetter(expression) + + def __call__(self, match: Match, host: nx.DiGraph, return_edges: list, + scope: Optional[dict] = None): + # Evaluate expression (either ScalarFunction or EntityAttributeGetter) + if isinstance(self._expression, ScalarFunction): + value = self._expression(match, host, return_edges, scope) + else: + # It's an EntityAttributeGetter + value = self._expression.evaluate(match, host, return_edges, scope) + + return value.strip() if isinstance(value, str) else value + + def __str__(self) -> str: + return f"trim({self._expression})" + + +class Type(ScalarFunction): + """ + Implements type() scalar function. + Returns the type/label of a relationship. 
+    """
+
+    def __init__(self, expression: str):
+        self._expression = expression  # e.g., "r"
+
+    def __call__(self, match: Match, host: nx.DiGraph, return_edges: list,
+                 scope: Optional[dict] = None):
+        # Type works on relationships only
+        # Use return_edges to find the edge
+        if self._expression in return_edges:
+            edge_mapping = return_edges[self._expression]
+            host_edges = match.mth.edge(*edge_mapping).edges
+            edge_data = get_edge_from_host(host, host_edges, None)
+
+            # edge_data might be a dict or list
+            if isinstance(edge_data, dict):
+                # Single edge
+                labels = edge_data.get('__labels__', set())
+                if labels:
+                    return list(labels)[0] if isinstance(labels, set) else labels
+            elif isinstance(edge_data, list) and len(edge_data) > 0:
+                # Multiple edges - return first label
+                labels = edge_data[0].get('__labels__', set())
+                if labels:
+                    return list(labels)[0] if isinstance(labels, set) else labels
+
+        return None
+
+    def __str__(self) -> str:
+        return f"type({self._expression})"
+
+
+class Coalesce(ScalarFunction):
+    """
+    Implements coalesce() scalar function.
+    Returns the first non-null value from the argument list.
+    """
+
+    def __init__(self, expressions: list):
+        self._expressions = expressions  # List of expressions
+
+    def __call__(self, match: Match, host: nx.DiGraph, return_edges: list,
+                 scope: Optional[dict] = None):
+        # Evaluate each expression and return first non-null
+        for expr in self._expressions:
+            # Check if it's an EntityAttributeGetter (entity reference)
+            if isinstance(expr, EntityAttributeGetter):
+                # It's an entity reference like n.name or n
+                # (pass return_edges and scope through, matching the
+                # evaluate() signature used elsewhere)
+                value = expr.evaluate(match, host, return_edges, scope)
+                if value is not None:
+                    return value
+            else:
+                # It's a literal value (string, number, bool, None, etc.)
+ if expr is not None: + return expr + return None + + def __str__(self) -> str: + # Format expressions for display + expr_strs = [] + for expr in self._expressions: + if isinstance(expr, EntityAttributeGetter): + # Entity reference: n.name or n + expr_strs.append(str(expr)) + elif isinstance(expr, str): + # String literal: show with quotes + expr_strs.append(f'"{expr}"') + else: + # Other literals: numbers, booleans, None + expr_strs.append(repr(expr)) + return f"coalesce({', '.join(expr_strs)})" + + +# ==================== End of List Predicate Classes ==================== + + def none_wrapper(func) -> Callable[[Any, Any], Union[bool, None]]: def inner(x, y) -> Union[bool, None]: try: @@ -710,15 +1540,21 @@ def inner(x, y) -> Union[bool, None]: def _data_path_to_entity_name_attribute(data_path): + """ + Parse data path into entity name and attribute using EntityAttributeGetter. + + Args: + data_path: String path (e.g., "n.name", "n") or Token containing the path + + Returns: + tuple: (entity_name, entity_attribute) + """ if isinstance(data_path, Token): data_path = data_path.value - if "." 
in data_path: - entity_name, entity_attribute = data_path.split(".") - else: - entity_name = data_path - entity_attribute = None - return entity_name, entity_attribute + # Use EntityAttributeGetter to parse the path + getter = EntityAttributeGetter(data_path) + return getter.entity, getter.attribute # this is to convert WHERE OPERATOR to INDEXER OPERATOR @@ -761,8 +1597,19 @@ def to_indexer_ast(condition: Condition, entity_id = None, value = None, should_ should_be=condition._should_be) if (isinstance(condition, LambdaCompareCondition) and condition._operator in WHERE_OPERATORS_TO_INDEXER_OPERATORS): - if entity_id.startswith("ID()"): - entity_id = entity_id[3:-1] + # Handle scalar functions + if isinstance(entity_id, ID): + # ID() can be optimized - extract entity name + entity_id = entity_id._entity_name + elif isinstance(entity_id, EntityAttributeGetter): + # EntityAttributeGetter can be optimized - extract full expression + if entity_id.attribute: + entity_id = f"{entity_id.entity}.{entity_id.attribute}" + else: + entity_id = entity_id.entity + elif not isinstance(entity_id, str): + # Other scalar functions can't be optimized by indexer + return IndexerUnsupportedOp(condition, entity_id, value) operator = condition._operator if should_be is True: operator = WHERE_OPERATORS_TO_INDEXER_OPERATORS[operator] @@ -864,25 +1711,32 @@ def _lookup(self, data_paths: List[str], offset_limit) -> Dict[str, List]: result = {} processed_paths = set() # Keep track of processed paths - # handling RETURN ID(A) + # Handle all scalar functions (ID, SIZE, future functions) - UNIFIED! + for data_path in data_paths: + if isinstance(data_path, ScalarFunction): + # Evaluate scalar function for each match + ret = [] + for match in true_matches: + result_value = data_path( + match, + self._target_graph, + self._return_edges, + scope=None + ) + ret.append(result_value) + + # Use str(data_path) as key: "ID(A)", "size(r)", etc. 
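The rewritten `_data_path_to_entity_name_attribute` above delegates parsing to `EntityAttributeGetter`. Its observable behavior on plain strings can be sketched as follows (hypothetical helper name; splitting on the *first* dot is an assumption — the removed inline code unpacked a bare `split(".")`, which only handled a single dot):

```python
def split_data_path(data_path: str):
    """Split 'entity.attribute' paths; a bare entity has no attribute."""
    if "." in data_path:
        entity, attribute = data_path.split(".", 1)
        return entity, attribute
    return data_path, None
```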
+ result[str(data_path)] = ret[offset_limit] + processed_paths.add(data_path) + processed_paths.add(str(data_path)) + continue + + # Validate entity names for non-scalar-function data paths for data_path in data_paths: + if isinstance(data_path, ScalarFunction): + continue # Skip scalar functions, already processed + entity_name, _ = _data_path_to_entity_name_attribute(data_path) - # Special handling for ID function - if entity_name.upper().startswith("ID(") and entity_name.endswith(")"): - # Extract the original entity name - original_entity = entity_name[3:-1] - if original_entity in motif_nodes: - # Return the node ID directly instead of the node attributes - ret = [match.mth.node(original_entity) for match in true_matches] - result[data_path] = ret[offset_limit] - result[original_entity] = ret[ - offset_limit - ] # Also store under original entity name - processed_paths.add(data_path) # Mark as processed - processed_paths.add( - original_entity - ) # Mark original also as processed - continue if ( entity_name not in motif_nodes and entity_name not in self._return_edges @@ -900,22 +1754,14 @@ def _lookup(self, data_paths: List[str], offset_limit) -> Dict[str, List]: if entity_name in motif_nodes: # We are looking for a node mapping in the target graph: - - if entity_attribute: - # Get the correct entity from the target host graph, - # and then return the attribute: - ret = ( - self._target_graph.nodes[match.mth.node(entity_name)].get( - entity_attribute, None - ) - for match in true_matches - ) - else: - # Return the full node dictionary with all attributes - ret = ( - self._target_graph.nodes[match.mth.node(entity_name)] - for match in true_matches - ) + # Use EntityAttributeGetter for consistent entity access + # If entity_attribute exists: returns specific attribute value + # If entity_attribute is None: returns full node dictionary + getter = EntityAttributeGetter(data_path) + ret = ( + getter.evaluate(match, self._target_graph, self._return_edges) + for 
match in true_matches + ) elif entity_name in self._paths: ret = [] @@ -938,18 +1784,13 @@ def _lookup(self, data_paths: List[str], offset_limit) -> Dict[str, List]: ret.append(path) else: - edge_mapping = self._return_edges[entity_name] # We are looking for an edge mapping in the target graph: - ret = [] - for match in true_matches: - host_edges = match.mth.edge(*edge_mapping).edges - ret.append( - get_edge_from_host( - self._target_graph, - host_edges, - entity_attribute, - ) - ) + # Use EntityAttributeGetter for consistent entity access + getter = EntityAttributeGetter(data_path) + ret = ( + getter.evaluate(match, self._target_graph, self._return_edges) + for match in true_matches + ) result[data_path] = list(ret)[offset_limit] return result @@ -1014,27 +1855,40 @@ def returns(self, ignore_limit=False): offset_limit=slice(0, None), ) if len(self._aggregate_functions) > 0: + # Determine group keys: exclude keys that end with aggregated entity path + # (matches old aggregate() logic) group_keys = [ key for key in results.keys() - if not any(key.endswith(func[1]) for func in self._aggregate_functions) + if not any( + not agg_func._is_scalar_function and key.endswith( + agg_func._entity + ('.' 
+ agg_func._entity_attribute if agg_func._entity_attribute else '') + ) + for agg_func in self._aggregate_functions + ) ] aggregated_results = {} - for func, entity in self._aggregate_functions: - aggregated_data = self.aggregate(func, results, entity, group_keys) + # Evaluate each aggregation function using scope-based architecture + for agg_func in self._aggregate_functions: + # Call evaluate() with scope (results dict from _lookup) + aggregated_data = agg_func.evaluate( + self._get_true_matches(), + self._target_graph, + self._return_edges, + group_keys, + scope=results # Pass results as scope + ) aggregated_values = list(aggregated_data.values()) aggregated_keys = list(aggregated_data.keys()) - func_key = self._format_aggregation_key(func, entity) + # Use str(agg_func) for result key: "COUNT(r)", "SUM(r.value)", etc. + func_key = str(agg_func) aggregated_results[func_key] = aggregated_values self._return_requests.append(func_key) - # TODO: the group_keys is the same for all func - # let's have aggregated keys 1st - # then have aggregated values - # so we don't have to repeat the groups key population here - # for i in range(len(gro up_keys)): - # results[group_keys[i]] = [k[i] for k in aggregated_keys] + + # Merge aggregated results with regular results results.update(aggregated_results) + # Reconstruct grouped results for i in range(len(group_keys)): results[group_keys[i]] = [k[i] for k in aggregated_keys] @@ -1050,16 +1904,19 @@ def returns(self, ignore_limit=False): # Only after all other transformations, apply pagination results = self._apply_pagination(results, ignore_limit) - self._return_requests = list(map(str, self._return_requests)) + # Convert all return_requests to strings (including scalar functions) for key matching + # Use a local variable to avoid modifying self._return_requests (breaks reusability) + return_requests_str = [str(item) for item in self._return_requests] # Only include keys that were asked for in `RETURN` in the final results 
results = { self._entity2alias.get(key, key): values for key, values in results.items() - if key in self._return_requests - or self._alias2entity.get(key, key) in self._return_requests + if key in return_requests_str + or self._alias2entity.get(key, key) in return_requests_str } + # TODO: remove this hack # HACK: convert to [None] if edge is None for key, values in results.items(): parsed_values = [] @@ -1478,36 +2335,85 @@ def return_clause(self, clause): alias = self._extract_alias(item) item = item.children[0] if isinstance(item, Tree) else item if isinstance(item, Tree) and item.data == "aggregation_function": - func, entity = self._parse_aggregation_token(item) + # Parse to AggregationFunction object (not tuple) + agg_func = self._parse_aggregation_token(item) if alias: - self._executors[-1]._entity2alias[ - self._executors[-1]._format_aggregation_key(func, entity) - ] = alias - self._executors[-1]._aggregation_attributes.add(entity) - self._executors[-1]._aggregate_functions.append((func, entity)) + # Use str(agg_func) for alias key: "COUNT(r)", "SUM(r.value)", etc. + self._executors[-1]._entity2alias[str(agg_func)] = alias + # Add full entity path to aggregation_attributes for _lookup (only for entity references) + if not agg_func._is_scalar_function: + entity_path = agg_func._entity + ('.' + agg_func._entity_attribute if agg_func._entity_attribute else '') + self._executors[-1]._aggregation_attributes.add(entity_path) + # Store AggregationFunction object, not tuple + self._executors[-1]._aggregate_functions.append(agg_func) else: - if not isinstance(item, str): + # Handle scalar functions (ID, SIZE, etc.) - keep object for evaluation + if isinstance(item, ScalarFunction): + # Keep scalar function object + self._executors[-1]._original_return_requests.add(item) + if alias: + # Use str(item) for alias key: "ID(A)", "size(r)", etc. 
+ self._executors[-1]._entity2alias[str(item)] = alias + self._executors[-1]._return_requests.append(item) + elif not isinstance(item, str): + # Convert non-string, non-scalar-function items to string item = str(item.value) - self._executors[-1]._original_return_requests.add(item) - - if alias: - self._executors[-1]._entity2alias[item] = alias - self._executors[-1]._return_requests.append(item) + self._executors[-1]._original_return_requests.add(item) + if alias: + self._executors[-1]._entity2alias[item] = alias + self._executors[-1]._return_requests.append(item) + else: + # Already a string, use as-is + self._executors[-1]._original_return_requests.add(item) + if alias: + self._executors[-1]._entity2alias[item] = alias + self._executors[-1]._return_requests.append(item) self._executors[-1]._alias2entity.update({v: k for k, v in self._executors[-1]._entity2alias.items()}) - def _parse_aggregation_token(self, item: Tree): + def _parse_aggregation_token(self, item: Tree) -> AggregationFunction: """ - Parse the aggregation function token and return the function and entity - input: Tree('aggregation_function', [Token('AGGREGATE_FUNC', 'SUM'), Token('CNAME', 'r'), Tree('attribute_id', [Token('CNAME', 'value')])]) - output: ('SUM', 'r.value') + Parse the aggregation function token and return an AggregationFunction object. + input: Tree('aggregation_function', [Token('AGGREGATE_FUNC', 'SUM'), Tree('agg_argument', [...])]) + output: SUM('r', 'value') object or SUM(scalar_function) object """ - func = str(item.children[0].value) # AGGREGATE_FUNC - entity = str(item.children[1].value) - if len(item.children) > 2: - entity += "." 
+ str(item.children[2].children[0].value) + func_name = str(item.children[0].value).upper() # COUNT, SUM, AVG, MAX, MIN + agg_arg = item.children[1] # Tree('agg_argument', [...]) + + # Check if argument is a scalar function or entity reference + if isinstance(agg_arg, Tree) and agg_arg.data == "agg_argument": + arg_child = agg_arg.children[0] - return func, entity + # Case 1: Scalar function argument (e.g., size(relationships(r))) + if isinstance(arg_child, ScalarFunction): + # Scalar function was already parsed by transformer + entity = arg_child + entity_attribute = None + + # Case 2: Entity reference (e.g., r.value or r) + else: + entity = str(arg_child.value) if hasattr(arg_child, 'value') else str(arg_child) + entity_attribute = None + + # Check for attribute (e.g., r.value) + if len(agg_arg.children) > 1 and isinstance(agg_arg.children[1], Tree): + entity_attribute = str(agg_arg.children[1].children[0].value) + else: + # Fallback for old format (backward compatibility) + entity = str(agg_arg.value) if hasattr(agg_arg, 'value') else str(agg_arg) + entity_attribute = None + + # Create appropriate AggregationFunction class instance + func_class = { + "COUNT": COUNT, + "SUM": SUM, + "AVG": AVG, + "MAX": MAX, + "MIN": MIN, + "COLLECT": COLLECT, + }[func_name] + + return func_class(entity, entity_attribute) def _extract_alias(self, item: Tree): """ @@ -1540,9 +2446,13 @@ def order_clause(self, order_clause): isinstance(item.children[0], Tree) and item.children[0].data == "aggregation_function" ): - func, entity = self._parse_aggregation_token(item.children[0]) - field = self._executors[-1]._format_aggregation_key(func, entity) - self._executors[-1]._order_by_attributes.add(entity) + # Parse to AggregationFunction object + agg_func = self._parse_aggregation_token(item.children[0]) + # Use str(agg_func) for field name: "COUNT(r)", "SUM(r.value)", etc. 
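`_parse_aggregation_token` above selects the aggregation class through a name→class dictionary. The same dispatch pattern, shown with self-contained stand-in callables rather than the real `AggregationFunction` classes (the None-handling rules — COUNT skips NULLs, SUM treats NULL as 0 — follow the expectations in this diff's tests):

```python
def make_aggregator(func_name: str):
    """Illustrative dispatch: map an aggregation name to a reducer callable."""
    dispatch = {
        "COUNT": lambda vals: sum(1 for v in vals if v is not None),
        "SUM": lambda vals: sum(v if v is not None else 0 for v in vals),
        "AVG": lambda vals: (sum(vals) / len(vals)) if vals else None,
        "MAX": max,
        "MIN": min,
        "COLLECT": list,
    }
    # Unknown names surface as KeyError, which a caller can turn
    # into a parse error.
    return dispatch[func_name.upper()]
```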
+            field = str(agg_func)
+            # Add full entity path to order_by_attributes for _lookup
+            # (only for plain entity references; a scalar-function argument
+            # has no entity path to register)
+            if not agg_func._is_scalar_function:
+                entity_path = agg_func._entity + ('.' + agg_func._entity_attribute if agg_func._entity_attribute else '')
+                self._executors[-1]._order_by_attributes.add(entity_path)
         else:
             field = str(
                 item.children[0]
@@ -1700,7 +2610,11 @@ def where_clause(self, where_clause: tuple):
 
     def compound_condition(self, val):
         if len(val) == 1:
-            val = CompoundCondition(*val[0])
+            item = val[0]
+            # Check if already a Condition object (ALL, ANY, or other)
+            if isinstance(item, Condition):
+                return item
+            val = CompoundCondition(*item)
         else:  # len == 3
             compound_a, operator, compound_b = val
             val = operator(compound_a, compound_b)
@@ -1713,8 +2627,12 @@ def where_or(self, val):
         return _BOOL_ARI["or"]
 
     def condition(self, condition):
-        if len(condition) == 1:  # sub query
-            condition = condition[0]
+        if len(condition) == 1:  # sub query or list predicate or scalar function
+            item = condition[0]
+            # Check if it's already a Condition object (ALL, ANY, NONE, SINGLE, ScalarFunction)
+            if isinstance(item, (ALL, ANY, NONE, SINGLE, ScalarFunction)):
+                return item
+            condition = item
 
         if len(condition) == 3:
             (entity_id, operator, value) = condition
@@ -1731,11 +2649,252 @@ def condition_not(self, processed_condition):
 
     def id_function(self, entity_id):
         entity_name = entity_id[0].value
-        # Add the raw entity ID to the return requests as well
-        # This ensures tests like test_id can still access res["A"]
-        # self._return_requests.append(entity_name)
-        # Return a special identifier that will be processed in _lookup method
-        return f"ID({entity_name})"
+        # Return ID object (class-based, not string-based)
+        return ID(entity_name)
+
+    def tolower_function(self, items):
+        """
+        Parse: toLower(n) or toLower(n.name) or toLower(trim(n.name))
+
+        items: [scalar_func_arg] which is a Tree containing either a ScalarFunction or entity_id
+        """
+        arg = items[0]
+
+        # arg is a Tree with 'scalar_func_arg'
+        if hasattr(arg, 'children') and 
len(arg.children) > 0: + first_child = arg.children[0] + + # Check if the first child is a ScalarFunction (nested) + if isinstance(first_child, ScalarFunction): + return ToLower(first_child) + + # Check if first child is a Token (entity name) + if hasattr(first_child, 'value'): + if len(arg.children) == 1: + # Just entity: n + expression = first_child.value + else: + # Has attribute: n.name + entity_name = first_child.value + attribute_name = arg.children[1].children[0].value + expression = f"{entity_name}.{attribute_name}" + return ToLower(expression) + + # Fallback + return ToLower(str(arg)) + + def toupper_function(self, items): + """ + Parse: toUpper(n) or toUpper(n.name) or toUpper(trim(n.name)) + + items: [scalar_func_arg] which is a Tree containing either a ScalarFunction or entity_id + """ + arg = items[0] + + # arg is a Tree with 'scalar_func_arg' + if hasattr(arg, 'children') and len(arg.children) > 0: + first_child = arg.children[0] + + # Check if the first child is a ScalarFunction (nested) + if isinstance(first_child, ScalarFunction): + return ToUpper(first_child) + + # Check if first child is a Token (entity name) + if hasattr(first_child, 'value'): + if len(arg.children) == 1: + # Just entity: n + expression = first_child.value + else: + # Has attribute: n.name + entity_name = first_child.value + attribute_name = arg.children[1].children[0].value + expression = f"{entity_name}.{attribute_name}" + return ToUpper(expression) + + # Fallback + return ToUpper(str(arg)) + + def trim_function(self, items): + """ + Parse: trim(n) or trim(n.name) or trim(toLower(n.name)) + + items: [scalar_func_arg] which is a Tree containing either a ScalarFunction or entity_id + """ + arg = items[0] + + # arg is a Tree with 'scalar_func_arg' + if hasattr(arg, 'children') and len(arg.children) > 0: + first_child = arg.children[0] + + # Check if the first child is a ScalarFunction (nested) + if isinstance(first_child, ScalarFunction): + return Trim(first_child) + + # Check if 
first child is a Token (entity name) + if hasattr(first_child, 'value'): + if len(arg.children) == 1: + # Just entity: n + expression = first_child.value + else: + # Has attribute: n.name + entity_name = first_child.value + attribute_name = arg.children[1].children[0].value + expression = f"{entity_name}.{attribute_name}" + return Trim(expression) + + # Fallback + return Trim(str(arg)) + + def type_function(self, items): + """ + Parse: type(r) + + items: [entity_id] + """ + entity_name = items[0].value + return Type(entity_name) + + def coalesce_function(self, items): + """ + Parse: coalesce(n.name, n.id, 'default') + + items: [coalesce_args Tree] + """ + # items[0] is the coalesce_args tree + args_tree = items[0] + expressions = [] + + # Process each coalesce_arg + for arg in args_tree.children: + # arg is a Tree with data='coalesce_arg' + if hasattr(arg, 'data') and arg.data == 'coalesce_arg': + # It's a Tree containing the argument + if len(arg.children) == 1: + child = arg.children[0] + # Check if it's a value or entity_id (both are Tokens or Trees) + if hasattr(child, 'value'): + # It's a Token - need to determine if it's a literal or entity reference + if hasattr(child, 'type'): + # Check token type to distinguish literals from entity references + if child.type == 'ESTRING': + # String literal - parse it (remove quotes and handle escapes) + expressions.append(child.value.strip('"').encode().decode('unicode_escape')) + elif child.type == 'NUMBER': + # Number literal - parse it + try: + # Try int first, then float + expressions.append(int(child.value)) + except ValueError: + expressions.append(float(child.value)) + elif child.type == 'CNAME': + # Entity reference without attribute: n + expressions.append(EntityAttributeGetter(child.value)) + else: + # Other token types (shouldn't happen in coalesce) + expressions.append(child.value) + else: + # No type attribute - fallback + expressions.append(child.value) + elif hasattr(child, 'data'): + # It's a Tree (like 
entity_id or null/true/false) + if child.data == 'entity_id': + # Just entity name: n + entity_name = child.children[0].value + expressions.append(EntityAttributeGetter(entity_name)) + elif child.data == 'null': + # NULL literal + expressions.append(None) + elif child.data == 'true': + # TRUE literal + expressions.append(True) + elif child.data == 'false': + # FALSE literal + expressions.append(False) + else: + # Some other tree - shouldn't happen + expressions.append(child) + else: + expressions.append(child) + elif len(arg.children) >= 2: + # entity_id with attribute_id: n.name + # arg.children[0] is Token CNAME for entity + # arg.children[1] is Tree attribute_id + entity_name = arg.children[0].value + attribute_name = arg.children[1].children[0].value + expressions.append(EntityAttributeGetter(f"{entity_name}.{attribute_name}")) + else: + # Direct value (shouldn't happen with current grammar) + expressions.append(arg) + + return Coalesce(expressions) + + def all_function(self, items): + """ + Parse: all(edge IN r WHERE edge.weight > 5) + + items structure: + [0]: CNAME (loop variable, e.g., "edge") + [1]: list_expression (string or ListExpression) + [2]: compound_condition (already transformed into CompoundCondition/AND/OR!) + """ + loop_variable = items[0].value + list_expression = items[1] # Already a string or ListExpression + inner_condition = items[2] # Already a Condition - perfect! + + # Just pass it through! No conversion needed. 
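The number-literal branch in `coalesce_function` above tries `int` first and falls back to `float`. That idiom, isolated (hypothetical helper name):

```python
def parse_number_literal(text: str):
    """Parse a NUMBER token: prefer int, fall back to float."""
    try:
        return int(text)
    except ValueError:
        return float(text)
```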
+ return ALL(name=loop_variable, list_expr=list_expression, pred=inner_condition) + + def any_function(self, items): + """Similar to all_function - trivial!""" + loop_variable = items[0].value + list_expression = items[1] + inner_condition = items[2] # Already a Condition + + return ANY(name=loop_variable, list_expr=list_expression, pred=inner_condition) + + def none_function(self, items): + """ + Parse: none(edge IN r WHERE edge.weight > 5) + + items: [loop_var, list_expr, condition] + """ + loop_variable = items[0].value + list_expression = items[1] + inner_condition = items[2] + + return NONE(name=loop_variable, list_expr=list_expression, pred=inner_condition) + + def single_function(self, items): + """ + Parse: single(edge IN r WHERE edge.weight > 5) + + items: [loop_var, list_expr, condition] + """ + loop_variable = items[0].value + list_expression = items[1] + inner_condition = items[2] + + return SINGLE(name=loop_variable, list_expr=list_expression, pred=inner_condition) + + def size_function(self, items): + """ + Parse: size(r) or size(relationships(r)) + + items: [list_expression] + """ + list_expression = items[0] + + return SIZE(list_expr=list_expression) + + def relationships_function(self, items): + """Parse: relationships(path_variable)""" + path_variable = items[0].value if isinstance(items[0], Token) else items[0] + return path_variable # Return as string, will be wrapped in ScopedListExpression + + def entity_list(self, items): + """Parse: direct entity reference as list""" + entity_id = items[0] + return entity_id.value if isinstance(entity_id, Token) else entity_id def value_list(self, items): return list(items) diff --git a/grandcypher/test_aggregation_functions.py b/grandcypher/test_aggregation_functions.py new file mode 100644 index 0000000..8a0f1b9 --- /dev/null +++ b/grandcypher/test_aggregation_functions.py @@ -0,0 +1,338 @@ +""" +Unit tests for AggregationFunction classes. 
+ +Tests the class-based aggregation architecture (COUNT, SUM, AVG, MAX, MIN). +""" + +import pytest +import networkx as nx +from . import COUNT, SUM, AVG, MAX, MIN +from .struct import Match + + +@pytest.fixture +def simple_graph(): + """Simple graph for testing aggregations.""" + host = nx.DiGraph() + host.add_node("A", value=10) + host.add_node("B", value=20) + host.add_node("C", value=30) + host.add_edge("A", "B", weight=5) + host.add_edge("B", "C", weight=15) + return host + + +@pytest.fixture +def simple_matches(simple_graph): + """Create simple matches for testing.""" + matches = [] + + # Match 1: A->B + match1 = Match( + node_mappings={"n": "A", "m": "B"}, + where_results=None, + edge_mapping=None + ) + matches.append(match1) + + # Match 2: B->C + match2 = Match( + node_mappings={"n": "B", "m": "C"}, + where_results=None, + edge_mapping=None + ) + matches.append(match2) + + return matches + + +class TestCOUNT: + """Tests for COUNT aggregation.""" + + def test_count_basic(self, simple_graph, simple_matches): + """Test basic COUNT functionality.""" + agg = COUNT("n", None) + # Build scope with values extracted from matches + scope = { + "n": [m.node_mappings["n"] for m in simple_matches], # ["A", "B"] + "m": [m.node_mappings["m"] for m in simple_matches], # ["B", "C"] + } + result = agg.evaluate(simple_matches, simple_graph, {}, ["m"], scope=scope) + + # Should group by 'm' and count 'n' values + assert len(result) == 2 + assert all(count == 1 for count in result.values()) + + def test_count_str(self): + """Test __str__ representation.""" + agg = COUNT("r", None) + assert str(agg) == "COUNT(r)" + + agg_with_attr = COUNT("r", "value") + assert str(agg_with_attr) == "COUNT(r.value)" + + def test_count_with_none(self, simple_graph): + """Test COUNT excludes None values.""" + # Create matches with None values + matches = [] + match = Match( + node_mappings={"n": "A"}, + where_results=None, + edge_mapping=None + ) + matches.append(match) + + agg = COUNT("n", 
"nonexistent") # This attribute doesn't exist + # Build scope with None values (attribute doesn't exist) + scope = { + "n.nonexistent": [simple_graph.nodes[m.node_mappings["n"]].get("nonexistent") for m in matches], + } + result = agg.evaluate(matches, simple_graph, {}, [], scope=scope) + + # COUNT excludes None values + assert result[()] == 0 + + +class TestSUM: + """Tests for SUM aggregation.""" + + def test_sum_basic(self, simple_graph, simple_matches): + """Test basic SUM functionality.""" + agg = SUM("n", "value") + # Build scope with attribute values from graph nodes + scope = { + "n.value": [simple_graph.nodes[m.node_mappings["n"]].get("value") for m in simple_matches], + } + result = agg.evaluate(simple_matches, simple_graph, {}, [], scope=scope) + + # Should sum all values: 10 + 20 = 30 + assert result[()] == 30 + + def test_sum_str(self): + """Test __str__ representation.""" + agg = SUM("r", "weight") + assert str(agg) == "SUM(r.weight)" + + def test_sum_with_none(self, simple_graph): + """Test SUM treats None as 0.""" + matches = [] + match = Match( + node_mappings={"n": "A"}, + where_results=None, + edge_mapping=None + ) + matches.append(match) + + agg = SUM("n", "nonexistent") # This attribute doesn't exist (None) + # Build scope with None values + scope = { + "n.nonexistent": [simple_graph.nodes[m.node_mappings["n"]].get("nonexistent") for m in matches], + } + result = agg.evaluate(matches, simple_graph, {}, [], scope=scope) + + # Should sum to 0 (None treated as 0) + assert result[()] == 0 + + +class TestAVG: + """Tests for AVG aggregation.""" + + def test_avg_basic(self, simple_graph, simple_matches): + """Test basic AVG functionality.""" + agg = AVG("n", "value") + # Build scope with attribute values + scope = { + "n.value": [simple_graph.nodes[m.node_mappings["n"]].get("value") for m in simple_matches], + } + result = agg.evaluate(simple_matches, simple_graph, {}, [], scope=scope) + + # Should average: (10 + 20) / 2 = 15 + assert result[()] == 15 
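+
+    def test_avg_mixed_none_and_values(self, simple_graph, simple_matches):
+        """Illustrative sketch: AVG over a mix of real and missing values.
+
+        Assumes the None-as-0 semantics asserted by the surrounding tests,
+        so a scope of [10, None] should average to (10 + 0) / 2 = 5. The
+        scope list is built by hand here (the second value stands in for a
+        missing attribute) rather than read from the graph.
+        """
+        agg = AVG("n", "value")
+        # Hand-built scope parallel to simple_matches: one real value, one None
+        scope = {"n.value": [10, None]}
+        result = agg.evaluate(simple_matches, simple_graph, {}, [], scope=scope)
+
+        # (10 + 0) / 2 = 5, assuming None counts as 0 in the sum and is
+        # still included in the denominator (matching test_avg_with_none)
+        assert result[()] == 5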
+ + def test_avg_str(self): + """Test __str__ representation.""" + agg = AVG("r", "weight") + assert str(agg) == "AVG(r.weight)" + + def test_avg_with_none(self, simple_graph): + """Test AVG treats None as 0.""" + matches = [] + match = Match( + node_mappings={"n": "A"}, + where_results=None, + edge_mapping=None + ) + matches.append(match) + + agg = AVG("n", "nonexistent") + # Build scope with None values + scope = { + "n.nonexistent": [simple_graph.nodes[m.node_mappings["n"]].get("nonexistent") for m in matches], + } + result = agg.evaluate(matches, simple_graph, {}, [], scope=scope) + + # Should average to 0 + assert result[()] == 0 + + +class TestMAX: + """Tests for MAX aggregation.""" + + def test_max_basic(self, simple_graph, simple_matches): + """Test basic MAX functionality.""" + agg = MAX("n", "value") + # Build scope with attribute values + scope = { + "n.value": [simple_graph.nodes[m.node_mappings["n"]].get("value") for m in simple_matches], + } + result = agg.evaluate(simple_matches, simple_graph, {}, [], scope=scope) + + # Should find max: max(10, 20) = 20 + assert result[()] == 20 + + def test_max_str(self): + """Test __str__ representation.""" + agg = MAX("r", "weight") + assert str(agg) == "MAX(r.weight)" + + def test_max_with_none(self, simple_graph): + """Test MAX treats None as negative infinity.""" + matches = [] + for node_id in ["A", "B"]: + match = Match( + node_mappings={"n": node_id}, + where_results=None, + edge_mapping=None + ) + matches.append(match) + + agg = MAX("n", "nonexistent") + # Build scope with None values + scope = { + "n.nonexistent": [simple_graph.nodes[m.node_mappings["n"]].get("nonexistent") for m in matches], + } + result = agg.evaluate(matches, simple_graph, {}, [], scope=scope) + + # Should return -inf (all None values) + assert result[()] == -float("inf") + + +class TestMIN: + """Tests for MIN aggregation.""" + + def test_min_basic(self, simple_graph, simple_matches): + """Test basic MIN functionality.""" + agg = 
MIN("n", "value") + # Build scope with attribute values + scope = { + "n.value": [simple_graph.nodes[m.node_mappings["n"]].get("value") for m in simple_matches], + } + result = agg.evaluate(simple_matches, simple_graph, {}, [], scope=scope) + + # Should find min: min(10, 20) = 10 + assert result[()] == 10 + + def test_min_str(self): + """Test __str__ representation.""" + agg = MIN("r", "weight") + assert str(agg) == "MIN(r.weight)" + + def test_min_with_none(self, simple_graph): + """Test MIN treats None as positive infinity.""" + matches = [] + for node_id in ["A", "B"]: + match = Match( + node_mappings={"n": node_id}, + where_results=None, + edge_mapping=None + ) + matches.append(match) + + agg = MIN("n", "nonexistent") + # Build scope with None values + scope = { + "n.nonexistent": [simple_graph.nodes[m.node_mappings["n"]].get("nonexistent") for m in matches], + } + result = agg.evaluate(matches, simple_graph, {}, [], scope=scope) + + # Should return +inf (all None values) + assert result[()] == float("inf") + + +class TestGrouping: + """Tests for grouping behavior.""" + + def test_grouping_by_one_key(self, simple_graph, simple_matches): + """Test grouping by a single key.""" + agg = COUNT("n", None) + # Build scope with node mappings + scope = { + "n": [m.node_mappings["n"] for m in simple_matches], + "m": [m.node_mappings["m"] for m in simple_matches], + } + result = agg.evaluate(simple_matches, simple_graph, {}, ["m"], scope=scope) + + # Should have 2 groups (one for each 'm' value) + assert len(result) == 2 + assert ("B",) in result + assert ("C",) in result + + def test_grouping_by_multiple_keys(self, simple_graph): + """Test grouping by multiple keys.""" + matches = [] + for i, (n_val, m_val) in enumerate([("A", "B"), ("A", "C"), ("B", "C")]): + match = Match( + node_mappings={"n": n_val, "m": m_val, "o": f"O{i}"}, + where_results=None, + edge_mapping=None + ) + matches.append(match) + + agg = COUNT("o", None) + # Build scope with node mappings + scope = { 
+ "o": [m.node_mappings["o"] for m in matches], + "n": [m.node_mappings["n"] for m in matches], + "m": [m.node_mappings["m"] for m in matches], + } + result = agg.evaluate(matches, simple_graph, {}, ["n", "m"], scope=scope) + + # Should have 3 groups (one for each (n, m) combination) + assert len(result) == 3 + assert ("A", "B") in result + assert ("A", "C") in result + assert ("B", "C") in result + + def test_no_grouping(self, simple_graph, simple_matches): + """Test aggregation without grouping (single group).""" + agg = COUNT("n", None) + # Build scope with node mappings + scope = { + "n": [m.node_mappings["n"] for m in simple_matches], + } + result = agg.evaluate(simple_matches, simple_graph, {}, [], scope=scope) + + # Should have 1 group (empty tuple) + assert len(result) == 1 + assert () in result + assert result[()] == 2 # 2 matches total + + +class TestEdgeCases: + """Tests for edge cases and error conditions.""" + + def test_empty_matches(self, simple_graph): + """Test aggregation with no matches.""" + agg = COUNT("n", None) + result = agg.evaluate([], simple_graph, {}, [], scope={}) + + # Should return empty dict + assert result == {} + + def test_evaluate_not_implemented(self, simple_graph): + """Test that base class evaluate() raises NotImplementedError.""" + from grandcypher import AggregationFunction + + agg = AggregationFunction("n", None) + with pytest.raises(NotImplementedError): + agg.evaluate([], simple_graph, {}, [], scope={}) diff --git a/grandcypher/test_entity_attribute_getter.py b/grandcypher/test_entity_attribute_getter.py new file mode 100644 index 0000000..940b192 --- /dev/null +++ b/grandcypher/test_entity_attribute_getter.py @@ -0,0 +1,176 @@ +""" +Unit tests for EntityAttributeGetter class. + +Tests the enhanced EntityAttributeGetter that supports string parsing, +scope variables, and edge references. +""" + +import pytest +import networkx as nx +from . 
import EntityAttributeGetter +from .struct import Match + + +@pytest.fixture +def simple_graph(): + """Simple graph for testing.""" + host = nx.DiGraph() + host.add_node("A", name="Alice", value=10) + host.add_node("B", name="Bob", value=20) + host.add_node("C", name="Charlie", value=30) + host.add_edge("A", "B", weight=5) + host.add_edge("B", "C", weight=15) + return host + + +@pytest.fixture +def simple_match(simple_graph): + """Create a simple match for testing.""" + return Match( + node_mappings={"n": "A", "m": "B"}, + where_results=None, + edge_mapping=None + ) + + +class TestEntityAttributeGetterInit: + """Tests for EntityAttributeGetter constructor.""" + + def test_init_simple_entity(self): + """Test initializing with simple entity reference: 'n'""" + getter = EntityAttributeGetter("n") + assert getter.entity == "n" + assert getter.attribute is None + + def test_init_entity_with_attribute(self): + """Test initializing with entity.attribute: 'n.name'""" + getter = EntityAttributeGetter("n.name") + assert getter.entity == "n" + assert getter.attribute == "name" + + def test_init_nested_attribute(self): + """Test parsing with nested dots (only first dot splits)""" + getter = EntityAttributeGetter("n.data.value") + assert getter.entity == "n" + assert getter.attribute == "data.value" # Everything after first dot + + +class TestEntityAttributeGetterEvaluate: + """Tests for EntityAttributeGetter.evaluate() method.""" + + def test_evaluate_node_attribute(self, simple_graph, simple_match): + """Test getting node attribute from match""" + getter = EntityAttributeGetter("n.name") + result = getter.evaluate(simple_match, simple_graph) + assert result == "Alice" # Node A has name="Alice" + + def test_evaluate_node_id(self, simple_graph, simple_match): + """Test getting full node dictionary (no attribute)""" + getter = EntityAttributeGetter("n") + result = getter.evaluate(simple_match, simple_graph) + assert result == {"name": "Alice", "value": 10} # n returns full node 
dictionary + + def test_evaluate_nonexistent_attribute(self, simple_graph, simple_match): + """Test getting nonexistent attribute returns None""" + getter = EntityAttributeGetter("n.nonexistent") + result = getter.evaluate(simple_match, simple_graph) + assert result is None + + def test_evaluate_nonexistent_entity(self, simple_graph, simple_match): + """Test getting nonexistent entity returns None""" + getter = EntityAttributeGetter("z") + result = getter.evaluate(simple_match, simple_graph) + assert result is None + + +class TestEntityAttributeGetterScope: + """Tests for scope variable handling.""" + + def test_evaluate_with_scope_simple(self, simple_graph, simple_match): + """Test scope variable takes priority over node mappings""" + scope = {"n": {"name": "ScopeName", "value": 100}} + getter = EntityAttributeGetter("n.name") + result = getter.evaluate(simple_match, simple_graph, scope=scope) + # Should get from scope, not from node mapping + assert result == "ScopeName" + + def test_evaluate_with_scope_entity_only(self, simple_graph, simple_match): + """Test getting simple scope variable (no attribute)""" + scope = {"e": {"weight": 10}} + getter = EntityAttributeGetter("e") + result = getter.evaluate(simple_match, simple_graph, scope=scope) + assert result == {"weight": 10} + + def test_evaluate_scope_non_dict(self, simple_graph, simple_match): + """Test scope variable that is not a dict""" + scope = {"n": "simple_value"} + getter = EntityAttributeGetter("n.name") + result = getter.evaluate(simple_match, simple_graph, scope=scope) + # Trying to access attribute on non-dict should return None + assert result is None + + def test_evaluate_scope_priority_order(self, simple_graph, simple_match): + """Test that scope takes priority over node_mappings""" + # simple_match has n -> "A" mapping + # scope has n with different value + scope = {"n": {"name": "ScopePriority"}} + getter = EntityAttributeGetter("n.name") + result = getter.evaluate(simple_match, simple_graph, 
scope=scope) + # Should get from scope (priority 1), not node (priority 2) + assert result == "ScopePriority" + + +class TestEntityAttributeGetterStrRepresentation: + """Tests for __str__() and __repr__() methods.""" + + def test_str_simple_entity(self): + """Test __str__() with simple entity""" + getter = EntityAttributeGetter("n") + assert str(getter) == "n" + + def test_str_entity_with_attribute(self): + """Test __str__() with entity.attribute""" + getter = EntityAttributeGetter("n.name") + assert str(getter) == "n.name" + + def test_repr_simple_entity(self): + """Test __repr__() with simple entity""" + getter = EntityAttributeGetter("m") + assert repr(getter) == "EntityAttributeGetter('m')" + + def test_repr_entity_with_attribute(self): + """Test __repr__() with entity.attribute""" + getter = EntityAttributeGetter("r.weight") + assert repr(getter) == "EntityAttributeGetter('r'.'weight')" + + +class TestEntityAttributeGetterIntegration: + """Integration tests with real Cypher-like scenarios.""" + + def test_multiple_attributes_on_different_nodes(self, simple_graph): + """Test accessing attributes on different nodes in a match""" + match = Match( + node_mappings={"start": "A", "end": "C"}, + where_results=None, + edge_mapping=None + ) + + getter_start = EntityAttributeGetter("start.name") + getter_end = EntityAttributeGetter("end.value") + + assert getter_start.evaluate(match, simple_graph) == "Alice" + assert getter_end.evaluate(match, simple_graph) == 30 + + def test_fallback_to_none_pattern(self, simple_graph, simple_match): + """Test pattern: try attribute, fallback if None""" + # Node "A" has name="Alice" + getter = EntityAttributeGetter("n.nickname") + result = getter.evaluate(simple_match, simple_graph) + + # nickname doesn't exist, should return None + assert result is None + + # Can then use actual attribute as fallback + getter2 = EntityAttributeGetter("n.name") + result2 = getter2.evaluate(simple_match, simple_graph) + assert result2 == "Alice" diff 
--git a/grandcypher/test_queries.py b/grandcypher/test_queries.py index 6957ce4..4aa4236 100644 --- a/grandcypher/test_queries.py +++ b/grandcypher/test_queries.py @@ -842,7 +842,7 @@ def test_order_by_aggregation_function(self): assert res["n.name"] == ["Alice", "Bob"] assert res["MIN(r.value)"] == [9, 14] assert res["MAX(r.value)"] == [40, 14] - assert res["COUNT(r.value)"] == [3, 1] + assert res["COUNT(r.value)"] == [2, 1] # COUNT excludes None values @pytest.mark.parametrize("graph_type", ACCEPTED_GRAPH_TYPES) def test_order_by_aggregation_fails_if_not_requested_in_return(self, graph_type): @@ -1384,6 +1384,267 @@ def test_multigraph_multiple_aggregation_functions(self): assert res["COUNT(r.amount)"] == [2, 1] assert res["SUM(r.amount)"] == [52, 6] + def test_multigraph_aggregation_function_collect(self): + """Test COLLECT aggregation function""" + host = nx.MultiDiGraph() + host.add_node("a", name="Alice", age=25) + host.add_node("b", name="Bob", age=30) + host.add_node("c", name="Christine", age=35) + host.add_edge("a", "b", __labels__={"paid"}, amount=40) + host.add_edge("a", "b", __labels__={"paid"}, amount=12) + host.add_edge("a", "c", __labels__={"owes"}, amount=39) + host.add_edge("b", "a", __labels__={"paid"}, amount=6) + + # Test COLLECT with edge attributes + qry = """ + MATCH (n)-[r:paid]->(m) + RETURN n.name, COLLECT(r.amount) + """ + res = GrandCypher(host).run(qry) + assert res["n.name"] == ["Alice", "Bob"] + assert res["COLLECT(r.amount)"] == [[40, 12], [6]] + + # Test COLLECT with node attributes + qry = """ + MATCH (n) + RETURN COLLECT(n.name) + """ + res = GrandCypher(host).run(qry) + assert res["COLLECT(n.name)"] == [["Alice", "Bob", "Christine"]] + + def test_multigraph_collect_with_grouping(self): + """Test COLLECT with grouping by multiple keys""" + host = nx.MultiDiGraph() + host.add_node("a", name="Alice") + host.add_node("b", name="Bob") + host.add_node("c", name="Charlie") + host.add_edge("a", "b", value=1) + host.add_edge("a", "b", 
value=2) + host.add_edge("b", "c", value=3) + + qry = """ + MATCH (n)-[r]->(m) + RETURN n.name, m.name, COLLECT(r.value) + """ + res = GrandCypher(host).run(qry) + assert res["n.name"] == ["Alice", "Bob"] + assert res["m.name"] == ["Bob", "Charlie"] + assert res["COLLECT(r.value)"] == [[1, 2], [3]] + + +class TestStringScalarFunctions: + """Tests for string scalar functions: toLower, toUpper, trim""" + + def test_tolower_basic(self): + """Test toLower with node attribute""" + host = nx.DiGraph() + host.add_node("a", name="ALICE") + host.add_node("b", name="BOB") + + qry = """ + MATCH (n) + RETURN n.name, toLower(n.name) + """ + res = GrandCypher(host).run(qry) + assert res["n.name"] == ["ALICE", "BOB"] + assert res["toLower(n.name)"] == ["alice", "bob"] + + def test_tolower_mixed_case(self): + """Test toLower with mixed case strings""" + host = nx.DiGraph() + host.add_node("a", name="AlIcE") + + qry = """ + MATCH (n) + RETURN toLower(n.name) + """ + res = GrandCypher(host).run(qry) + assert res["toLower(n.name)"] == ["alice"] + + def test_toupper_basic(self): + """Test toUpper with node attribute""" + host = nx.DiGraph() + host.add_node("a", name="alice") + host.add_node("b", name="bob") + + qry = """ + MATCH (n) + RETURN n.name, toUpper(n.name) + """ + res = GrandCypher(host).run(qry) + assert res["n.name"] == ["alice", "bob"] + assert res["toUpper(n.name)"] == ["ALICE", "BOB"] + + def test_toupper_mixed_case(self): + """Test toUpper with mixed case strings""" + host = nx.DiGraph() + host.add_node("a", name="AlIcE") + + qry = """ + MATCH (n) + RETURN toUpper(n.name) + """ + res = GrandCypher(host).run(qry) + assert res["toUpper(n.name)"] == ["ALICE"] + + def test_trim_whitespace(self): + """Test trim with leading/trailing whitespace""" + host = nx.DiGraph() + host.add_node("a", name=" alice ") + host.add_node("b", name="bob ") + host.add_node("c", name=" charlie") + + qry = """ + MATCH (n) + RETURN n.name, trim(n.name) + """ + res = GrandCypher(host).run(qry) + 
assert res["n.name"] == [" alice ", "bob ", " charlie"] + assert res["trim(n.name)"] == ["alice", "bob", "charlie"] + + def test_trim_no_whitespace(self): + """Test trim with no whitespace""" + host = nx.DiGraph() + host.add_node("a", name="alice") + + qry = """ + MATCH (n) + RETURN trim(n.name) + """ + res = GrandCypher(host).run(qry) + assert res["trim(n.name)"] == ["alice"] + + # NOTE: Scalar functions work on LEFT side of WHERE conditions but not RIGHT side yet + # Currently supported: WHERE ID(A) == 1, WHERE toLower(n.name) = 'value' + # Not yet supported: WHERE ID(A) == ID(B), WHERE toLower(n.name) = toLower(m.name) + def test_string_functions_with_where(self): + """Test string functions in WHERE clause""" + host = nx.DiGraph() + host.add_node("a", name="ALICE") + host.add_node("b", name="BOB") + + qry = """ + MATCH (n) + WHERE toLower(n.name) = "alice" + RETURN n.name + """ + res = GrandCypher(host).run(qry) + assert set(res["n.name"]) == {"ALICE"} + + def test_string_functions_combined(self): + """Test combining multiple string functions (nested)""" + host = nx.DiGraph() + host.add_node("a", name=" ALICE ") + host.add_node("b", name=" bob ") + + qry = """ + MATCH (n) + RETURN toLower(trim(n.name)) + """ + res = GrandCypher(host).run(qry) + assert set(res["toLower(trim(n.name))"]) == {"alice", "bob"} + + def test_nested_functions_multiple_levels(self): + """Test deeply nested functions""" + host = nx.DiGraph() + host.add_node("a", name=" HELLO ") + + qry = """ + MATCH (n) + RETURN toUpper(trim(toLower(n.name))) + """ + res = GrandCypher(host).run(qry) + assert res["toUpper(trim(toLower(n.name)))"] == ["HELLO"] + + +class TestTypeAndCoalesceScalarFunctions: + """Tests for type() and coalesce() scalar functions""" + + def test_type_basic(self): + """Test type() with relationship labels""" + host = nx.MultiDiGraph() + host.add_node("a", name="Alice") + host.add_node("b", name="Bob") + host.add_edge("a", "b", __labels__={"paid"}, amount=50) + host.add_edge("a", 
"b", __labels__={"owes"}, amount=20) + + qry = """ + MATCH (n)-[r]->(m) + RETURN n.name, type(r), r.amount + """ + res = GrandCypher(host).run(qry) + assert res["n.name"] == ["Alice", "Alice"] + assert set(res["type(r)"]) == {"paid", "owes"} + + def test_type_with_specific_label(self): + """Test type() with specific relationship label filter""" + host = nx.MultiDiGraph() + host.add_node("a", name="Alice") + host.add_node("b", name="Bob") + host.add_edge("a", "b", __labels__={"paid"}, amount=50) + host.add_edge("a", "b", __labels__={"owes"}, amount=20) + + qry = """ + MATCH (n)-[r:paid]->(m) + RETURN type(r) + """ + res = GrandCypher(host).run(qry) + assert res["type(r)"] == ["paid"] + + def test_coalesce_basic(self): + """Test coalesce() with multiple attributes""" + host = nx.DiGraph() + host.add_node("a", name="Alice", nickname=None) + host.add_node("b", nickname="Bobby") + host.add_node("c", name="Charlie") + + qry = """ + MATCH (n) + RETURN coalesce(n.nickname, n.name) + """ + res = GrandCypher(host).run(qry) + # a: nickname is None, so use name "Alice" + # b: nickname is "Bobby" + # c: nickname missing, so use name "Charlie" + assert set(res["coalesce(n.nickname, n.name)"]) == {"Alice", "Bobby", "Charlie"} + + def test_coalesce_first_non_null(self): + """Test that coalesce returns first non-null value""" + host = nx.DiGraph() + host.add_node("a", first=None, second="Second", third="Third") + + qry = """ + MATCH (n) + RETURN coalesce(n.first, n.second, n.third) + """ + res = GrandCypher(host).run(qry) + assert res["coalesce(n.first, n.second, n.third)"] == ["Second"] + + def test_coalesce_all_null(self): + """Test coalesce when all values are null""" + host = nx.DiGraph() + host.add_node("a") + + qry = """ + MATCH (n) + RETURN coalesce(n.missing1, n.missing2) + """ + res = GrandCypher(host).run(qry) + assert res["coalesce(n.missing1, n.missing2)"] == [None] + + def test_coalesce_with_fallback(self): + """Test coalesce with a fallback attribute""" + host = 
nx.DiGraph() + host.add_node("a", id="id_a") + host.add_node("b", name="Bob", id="id_b") + + qry = """ + MATCH (n) + RETURN coalesce(n.name, n.id) + """ + res = GrandCypher(host).run(qry) + assert set(res["coalesce(n.name, n.id)"]) == {"Bob", "id_a"} + class TestAlias: @pytest.mark.benchmark @@ -2731,3 +2992,628 @@ def test_equijoin2(): assert GrandCypher(G).run(qry) == { "ID(n)": ["x"] } + + +# ============================================================================== +# CONSOLIDATED SCALAR AND AGGREGATION FUNCTION TESTS +# ============================================================================== +# Tests moved from test_aggregation_functions.py and test_list_predicates.py +# to centralize all Cypher query tests in one place. +# ============================================================================== + + +class TestCoalesceScalarFunction: + """Tests for coalesce() scalar function with literal values. + Moved from test_aggregation_functions.py""" + + def test_coalesce_with_double_quote_literal(self): + """Test coalesce with double-quoted string literal as fallback""" + host = nx.DiGraph() + host.add_node("a", name="Alice") + host.add_node("b") # No name attribute + host.add_node("c", name="Charlie") + + qry = """ + MATCH (n) + RETURN coalesce(n.name, "Unknown") + """ + res = GrandCypher(host).run(qry) + + # Should return "Unknown" for node b which has no name + assert set(res[list(res.keys())[0]]) == {"Alice", "Unknown", "Charlie"} + + def test_coalesce_distinguishes_literal_from_entity(self): + """Test that coalesce distinguishes between string literals and entity references""" + host = nx.DiGraph() + host.add_node("a", name="Alice", backup="BackupA") + host.add_node("b", backup="BackupB") # No name, but has backup + host.add_node("Unknown", value="I am a node") # Node with ID "Unknown" + + # Test 1: String literal "Unknown" should return the literal string + qry1 = """ + MATCH (n) + RETURN coalesce(n.name, "Unknown") + """ + res1 = 
GrandCypher(host).run(qry1) + # Should return "Unknown" as literal string for nodes b and Unknown (which have no name) + # Node a has name="Alice", nodes b and Unknown have no name so get "Unknown" + results1 = set(res1[list(res1.keys())[0]]) + assert "Alice" in results1 # Node a + assert "Unknown" in results1 # Fallback for nodes b and Unknown + assert "I am a node" not in results1 # Should NOT look up node ID "Unknown" + + # Test 2: Entity reference backup (no quotes) should look up the backup attribute + qry2 = """ + MATCH (n) + RETURN coalesce(n.name, n.backup) + """ + res2 = GrandCypher(host).run(qry2) + # Should return actual backup values + results2 = set(res2["coalesce(n.name, n.backup)"]) + assert "Alice" in results2 # Node a has name + assert "BackupB" in results2 # Node b has no name, uses backup + # Node "Unknown" has neither name nor backup, so returns None + assert None in results2 + + def test_coalesce_multiple_literals(self): + """Test coalesce with multiple string literal fallbacks""" + host = nx.DiGraph() + host.add_node("a", name="Alice") + host.add_node("b") + + qry = """ + MATCH (n) + RETURN coalesce(n.name, n.nickname, "Default", "Final") + """ + res = GrandCypher(host).run(qry) + + # Node b has neither name nor nickname, should get first literal "Default" + assert set(res[list(res.keys())[0]]) == {"Alice", "Default"} + + def test_coalesce_number_literal(self): + """Test coalesce with number literal as fallback""" + host = nx.DiGraph() + host.add_node("a", score=100) + host.add_node("b") # No score + + qry = """ + MATCH (n) + RETURN coalesce(n.score, 0) + """ + res = GrandCypher(host).run(qry) + + assert set(res["coalesce(n.score, 0)"]) == {100, 0} + + def test_coalesce_null_literal(self): + """Test coalesce explicitly with NULL""" + host = nx.DiGraph() + host.add_node("a", name="Alice") + host.add_node("b") + + qry = """ + MATCH (n) + RETURN coalesce(n.name, NULL, "Fallback") + """ + res = GrandCypher(host).run(qry) + + # NULL should be 
skipped, should use "Fallback" for node b + assert set(res[list(res.keys())[0]]) == {"Alice", "Fallback"} + + +# ============================================================================== +# LIST PREDICATE TESTS +# ============================================================================== +# Tests moved from test_list_predicates.py +# Includes: ALL, ANY, NONE, SINGLE, SIZE, and combined tests +# ============================================================================== + + +class TestListPredicatesALL: + """Tests for ALL() list predicate. + Moved from test_list_predicates.py""" + + def test_all_predicate_on_path_weights_true(self): + """Test ALL predicate returns true when all edges meet condition""" + G = nx.DiGraph() + G.add_node("a", name="Alice") + G.add_node("b", name="Bob") + G.add_node("c", name="Charlie") + G.add_node("d", name="David") + G.add_edge("a", "b", weight=10) + G.add_edge("b", "c", weight=20) + G.add_edge("c", "d", weight=5) + + qry = """ + MATCH (a)-[r*2]->(c) + WHERE ALL(edge IN r WHERE edge.weight > 5) + RETURN a.name, c.name + """ + res = GrandCypher(G).run(qry) + + # Path a->b->c has weights [10, 20], both > 5 + assert res == {"a.name": ["Alice"], "c.name": ["Charlie"]} + + def test_all_predicate_on_path_weights_false(self): + """Test ALL predicate returns false when not all edges meet condition""" + G = nx.DiGraph() + G.add_node("a", name="Alice") + G.add_node("b", name="Bob") + G.add_node("c", name="Charlie") + G.add_node("d", name="David") + G.add_edge("a", "b", weight=10) + G.add_edge("b", "c", weight=20) + G.add_edge("c", "d", weight=5) + + qry = """ + MATCH (a)-[r*2]->(c) + WHERE ALL(edge IN r WHERE edge.weight > 15) + RETURN a.name, c.name + """ + res = GrandCypher(G).run(qry) + + # Path a->b->c has weights [10, 20], not all > 15 + # Path b->c->d has weights [20, 5], not all > 15 + assert res == {"a.name": [], "c.name": []} + + def test_all_predicate_variable_length_path(self): + """Test ALL with variable-length path 
[r*1..3]"""
+        G = nx.DiGraph()
+        G.add_node("a", name="Alice")
+        G.add_node("b", name="Bob")
+        G.add_node("c", name="Charlie")
+        G.add_node("d", name="David")
+        G.add_edge("a", "b", weight=10)
+        G.add_edge("b", "c", weight=20)
+        G.add_edge("c", "d", weight=5)
+
+        qry = """
+        MATCH (a)-[r*1..3]->(b)
+        WHERE ALL(edge IN r WHERE edge.weight >= 10)
+        RETURN a.name, b.name, size(relationships(r)) AS path_length
+        """
+        res = GrandCypher(G).run(qry)
+
+        # Should find paths where all edges have weight >= 10
+        assert "Alice" in res["a.name"]  # a->b (weight=10)
+        assert "Bob" in res["a.name"]  # b->c (weight=20)
+
+    def test_all_combined_with_any(self):
+        """Test ALL combined with ANY in same query"""
+        G = nx.DiGraph()
+        G.add_node("a", name="Alice")
+        G.add_node("b", name="Bob")
+        G.add_node("c", name="Charlie")
+        G.add_node("d", name="David")
+        G.add_edge("a", "b", weight=10)
+        G.add_edge("b", "c", weight=20)
+        G.add_edge("c", "d", weight=5)
+
+        qry = """
+        MATCH (a)-[r*2]->(c)
+        WHERE ALL(edge IN r WHERE edge.weight > 0)
+          AND ANY(edge IN r WHERE edge.weight > 15)
+        RETURN a.name, c.name
+        """
+        res = GrandCypher(G).run(qry)
+
+        # Two paths match: a->b->c [10, 20] and b->c->d [20, 5]
+        # Both have all weights > 0 AND at least one > 15
+        assert set(res["a.name"]) == {"Alice", "Bob"}
+        assert set(res["c.name"]) == {"Charlie", "David"}
+
+    def test_all_with_size_function(self):
+        """Test ALL combined with size() function"""
+        G = nx.DiGraph()
+        G.add_node("a", name="Alice")
+        G.add_node("b", name="Bob")
+        G.add_node("c", name="Charlie")
+        G.add_node("d", name="David")
+        G.add_edge("a", "b", weight=10)
+        G.add_edge("b", "c", weight=20)
+        G.add_edge("c", "d", weight=5)
+
+        qry = """
+        MATCH (a)-[r*1..3]->(b)
+        WHERE size(relationships(r)) = 2
+          AND ALL(edge IN r WHERE edge.weight > 5)
+        RETURN a.name, b.name
+        """
+        res = GrandCypher(G).run(qry)
+
+        # Only 2-hop paths where all weights > 5 (a->b->c [10, 20];
+        # b->c->d is excluded because 5 is not > 5)
+        assert "Alice" in res["a.name"]  # a->b->c
+
+
+class TestListPredicatesANY:
+    """Tests for ANY() list predicate.
+    Moved from test_list_predicates.py"""
+
+    def test_any_predicate_true(self):
+        """Test ANY predicate returns true when at least one edge meets condition"""
+        G = nx.DiGraph()
+        G.add_node("a", name="Alice")
+        G.add_node("b", name="Bob")
+        G.add_node("c", name="Charlie")
+        G.add_node("d", name="David")
+        G.add_edge("a", "b", weight=10)
+        G.add_edge("b", "c", weight=20)
+        G.add_edge("c", "d", weight=5)
+
+        qry = """
+        MATCH (a)-[r*2]->(c)
+        WHERE ANY(edge IN r WHERE edge.weight > 15)
+        RETURN a.name, c.name
+        """
+        res = GrandCypher(G).run(qry)
+
+        # Path a->b->c has an edge with weight 20 > 15
+        assert "Alice" in res["a.name"]
+        assert "Charlie" in res["c.name"]
+
+    def test_any_predicate_false(self):
+        """Test ANY predicate returns false when no edges meet condition"""
+        G = nx.DiGraph()
+        G.add_node("a", name="Alice")
+        G.add_node("b", name="Bob")
+        G.add_node("c", name="Charlie")
+        G.add_node("d", name="David")
+        G.add_edge("a", "b", weight=10)
+        G.add_edge("b", "c", weight=20)
+        G.add_edge("c", "d", weight=5)
+
+        qry = """
+        MATCH (a)-[r*2]->(c)
+        WHERE ANY(edge IN r WHERE edge.weight > 100)
+        RETURN a.name, c.name
+        """
+        res = GrandCypher(G).run(qry)
+
+        # No edges have weight > 100
+        assert res == {"a.name": [], "c.name": []}
+
+
+class TestListPredicatesNONE:
+    """Tests for NONE() list predicate.
+    Moved from test_list_predicates.py"""
+
+    def test_none_predicate_true(self):
+        """Test NONE predicate returns true when no edges meet condition"""
+        G = nx.DiGraph()
+        G.add_node("a", name="Alice")
+        G.add_node("b", name="Bob")
+        G.add_node("c", name="Charlie")
+        G.add_node("d", name="David")
+        G.add_edge("a", "b", weight=10)
+        G.add_edge("b", "c", weight=20)
+        G.add_edge("c", "d", weight=5)
+
+        qry = """
+        MATCH (a)-[r*2]->(c)
+        WHERE NONE(edge IN r WHERE edge.weight > 100)
+        RETURN a.name, c.name
+        """
+        res = GrandCypher(G).run(qry)
+
+        # No edges have weight > 100, so NONE returns true for all paths
+        assert "Alice" in res["a.name"]
+        assert "Bob" in res["a.name"]
+
+    def test_none_predicate_false(self):
+        """Test NONE predicate returns false when at least one edge meets condition"""
+        G = nx.DiGraph()
+        G.add_node("a", name="Alice")
+        G.add_node("b", name="Bob")
+        G.add_node("c", name="Charlie")
+        G.add_node("d", name="David")
+        G.add_edge("a", "b", weight=10)
+        G.add_edge("b", "c", weight=20)
+        G.add_edge("c", "d", weight=5)
+
+        qry = """
+        MATCH (a)-[r*2]->(c)
+        WHERE NONE(edge IN r WHERE edge.weight > 15)
+        RETURN a.name, c.name
+        """
+        res = GrandCypher(G).run(qry)
+
+        # Path a->b->c has weight 20 > 15, so excluded
+        # Path b->c->d has weight 20 > 15, so excluded
+        assert res == {"a.name": [], "c.name": []}
+
+    def test_none_combined_with_all(self):
+        """Test NONE combined with ALL"""
+        G = nx.DiGraph()
+        G.add_node("a", name="Alice")
+        G.add_node("b", name="Bob")
+        G.add_node("c", name="Charlie")
+        G.add_node("d", name="David")
+        G.add_edge("a", "b", weight=10)
+        G.add_edge("b", "c", weight=20)
+        G.add_edge("c", "d", weight=5)
+
+        qry = """
+        MATCH (a)-[r*2]->(c)
+        WHERE NONE(edge IN r WHERE edge.weight < 5)
+          AND ALL(edge IN r WHERE edge.weight > 0)
+        RETURN a.name, c.name
+        """
+        res = GrandCypher(G).run(qry)
+
+        # Paths where no edges < 5 and all edges > 0
+        # (b->c->d [20, 5] also qualifies, since 5 is not < 5)
+        assert "Alice" in res["a.name"]  # a->b->c: [10, 20]
+
+
+class TestListPredicatesSINGLE:
+    """Tests for SINGLE() list predicate.
+    Moved from test_list_predicates.py"""
+
+    def test_single_predicate_true(self):
+        """Test SINGLE predicate returns true when exactly one edge meets condition"""
+        G = nx.DiGraph()
+        G.add_node("a", name="Alice")
+        G.add_node("b", name="Bob")
+        G.add_node("c", name="Charlie")
+        G.add_node("d", name="David")
+        G.add_edge("a", "b", weight=10)
+        G.add_edge("b", "c", weight=20)
+        G.add_edge("c", "d", weight=5)
+
+        qry = """
+        MATCH (a)-[r*2]->(c)
+        WHERE SINGLE(edge IN r WHERE edge.weight > 15)
+        RETURN a.name, c.name
+        """
+        res = GrandCypher(G).run(qry)
+
+        # Path a->b->c has exactly one edge with weight > 15 (20)
+        assert "Alice" in res["a.name"]
+        assert "Charlie" in res["c.name"]
+
+    def test_single_predicate_false_zero_matches(self):
+        """Test SINGLE predicate returns false when zero edges meet condition"""
+        G = nx.DiGraph()
+        G.add_node("a", name="Alice")
+        G.add_node("b", name="Bob")
+        G.add_node("c", name="Charlie")
+        G.add_node("d", name="David")
+        G.add_edge("a", "b", weight=10)
+        G.add_edge("b", "c", weight=20)
+        G.add_edge("c", "d", weight=5)
+
+        qry = """
+        MATCH (a)-[r*2]->(c)
+        WHERE SINGLE(edge IN r WHERE edge.weight > 100)
+        RETURN a.name, c.name
+        """
+        res = GrandCypher(G).run(qry)
+
+        # No edges have weight > 100
+        assert res == {"a.name": [], "c.name": []}
+
+    def test_single_predicate_false_multiple_matches(self):
+        """Test SINGLE predicate returns false when multiple edges meet condition"""
+        G = nx.DiGraph()
+        G.add_node("a", name="Alice")
+        G.add_node("b", name="Bob")
+        G.add_node("c", name="Charlie")
+        G.add_node("d", name="David")
+        G.add_edge("a", "b", weight=10)
+        G.add_edge("b", "c", weight=20)
+        G.add_edge("c", "d", weight=5)
+
+        qry = """
+        MATCH (a)-[r*2]->(c)
+        WHERE SINGLE(edge IN r WHERE edge.weight > 5)
+        RETURN a.name, c.name
+        """
+        res = GrandCypher(G).run(qry)
+
+        # Path a->b->c has TWO edges with weight > 5, so SINGLE is false
+        # (b->c->d [20, 5] still matches, since only 20 is > 5)
+        assert "Alice" not in res["a.name"]
+
+
+class TestSIZEFunction:
+    """Tests for size() function with lists.
+    Moved from test_list_predicates.py"""
+
+    def test_size_of_relationships(self):
+        """Test size() function on relationships list"""
+        G = nx.DiGraph()
+        G.add_node("a", name="Alice")
+        G.add_node("b", name="Bob")
+        G.add_node("c", name="Charlie")
+        G.add_node("d", name="David")
+        G.add_edge("a", "b", weight=10)
+        G.add_edge("b", "c", weight=20)
+        G.add_edge("c", "d", weight=5)
+
+        qry = """
+        MATCH (a)-[r*1..3]->(b)
+        WHERE size(relationships(r)) = 2
+        RETURN a.name, b.name
+        """
+        res = GrandCypher(G).run(qry)
+
+        # Should find all 2-hop paths
+        assert "Alice" in res["a.name"]
+        assert "Bob" in res["a.name"]
+
+    def test_size_in_return_clause(self):
+        """Test size() in RETURN clause"""
+        G = nx.DiGraph()
+        G.add_node("a", name="Alice")
+        G.add_node("b", name="Bob")
+        G.add_node("c", name="Charlie")
+        G.add_node("d", name="David")
+        G.add_edge("a", "b", weight=10)
+        G.add_edge("b", "c", weight=20)
+        G.add_edge("c", "d", weight=5)
+
+        qry = """
+        MATCH (a)-[r*1..2]->(b)
+        RETURN a.name, b.name, size(relationships(r)) AS path_length
+        """
+        res = GrandCypher(G).run(qry)
+
+        assert "Alice" in res["a.name"]
+        assert 1 in res["path_length"]  # 1-hop paths
+        assert 2 in res["path_length"]  # 2-hop paths
+
+    def test_size_with_where_comparison(self):
+        """Test size() with comparison operators"""
+        G = nx.DiGraph()
+        G.add_node("a", name="Alice")
+        G.add_node("b", name="Bob")
+        G.add_node("c", name="Charlie")
+        G.add_node("d", name="David")
+        G.add_edge("a", "b", weight=10)
+        G.add_edge("b", "c", weight=20)
+        G.add_edge("c", "d", weight=5)
+
+        qry = """
+        MATCH (a)-[r*1..3]->(b)
+        WHERE size(relationships(r)) >= 2
+        RETURN a.name, b.name, size(relationships(r)) AS hops
+        """
+        res = GrandCypher(G).run(qry)
+
+        # Should only get paths with 2 or more hops
+        assert all(h >= 2 for h in res["hops"])
+
+    def test_size_combined_with_predicates(self):
+        """Test size() combined with list predicates"""
+        G = nx.DiGraph()
+        G.add_node("a", name="Alice")
+        G.add_node("b", name="Bob")
+        G.add_node("c", name="Charlie")
+        G.add_node("d", name="David")
+        G.add_edge("a", "b", weight=10)
+        G.add_edge("b", "c", weight=20)
+        G.add_edge("c", "d", weight=5)
+
+        qry = """
+        MATCH (a)-[r*1..3]->(b)
+        WHERE size(relationships(r)) > 1
+          AND ALL(edge IN r WHERE edge.weight > 0)
+        RETURN a.name, b.name, size(relationships(r)) AS path_len
+        """
+        res = GrandCypher(G).run(qry)
+
+        # Paths with >1 hop and all positive weights
+        assert all(pl > 1 for pl in res["path_len"])
+
+    def test_size_in_aggregation_context(self):
+        """Test size() with aggregation functions"""
+        G = nx.DiGraph()
+        G.add_node("a", name="Alice")
+        G.add_node("b", name="Bob")
+        G.add_node("c", name="Charlie")
+        G.add_node("d", name="David")
+        G.add_edge("a", "b", weight=10)
+        G.add_edge("b", "c", weight=20)
+        G.add_edge("c", "d", weight=5)
+
+        qry = """
+        MATCH (a)-[r*1..2]->(b)
+        RETURN a.name, COUNT(r) AS path_count, AVG(size(relationships(r))) AS avg_length
+        """
+        res = GrandCypher(G).run(qry)
+
+        assert "Alice" in res["a.name"]
+        assert all(isinstance(c, int) for c in res["path_count"])
+        # Check that avg_length is calculated for the nested scalar function
+        assert all(isinstance(al, (int, float)) for al in res["avg_length"])
+
+
+class TestListPredicatesCombined:
+    """Tests combining multiple list predicates.
+    Moved from test_list_predicates.py"""
+
+    def test_all_and_any_combined(self):
+        """Test ALL and ANY in same WHERE clause"""
+        G = nx.DiGraph()
+        G.add_node("a", name="Alice")
+        G.add_node("b", name="Bob")
+        G.add_node("c", name="Charlie")
+        G.add_node("d", name="David")
+        G.add_edge("a", "b", weight=10)
+        G.add_edge("b", "c", weight=20)
+        G.add_edge("c", "d", weight=5)
+
+        qry = """
+        MATCH (a)-[r*2]->(c)
+        WHERE ALL(edge IN r WHERE edge.weight > 0)
+          AND ANY(edge IN r WHERE edge.weight > 15)
+        RETURN a.name, c.name
+        """
+        res = GrandCypher(G).run(qry)
+
+        # Path must have all edges positive AND at least one > 15
+        assert "Alice" in res["a.name"]
+
+    def test_single_and_none_combined(self):
+        """Test SINGLE and NONE in same WHERE clause"""
+        G = nx.DiGraph()
+        G.add_node("a", name="Alice")
+        G.add_node("b", name="Bob")
+        G.add_node("c", name="Charlie")
+        G.add_node("d", name="David")
+        G.add_edge("a", "b", weight=10)
+        G.add_edge("b", "c", weight=20)
+        G.add_edge("c", "d", weight=5)
+
+        qry = """
+        MATCH (a)-[r*2]->(c)
+        WHERE SINGLE(edge IN r WHERE edge.weight > 15)
+          AND NONE(edge IN r WHERE edge.weight < 5)
+        RETURN a.name, c.name
+        """
+        res = GrandCypher(G).run(qry)
+
+        # Exactly one edge > 15 and no edges < 5
+        assert "Alice" in res["a.name"]  # a->b->c: [10, 20]
+
+    def test_nested_list_expressions(self):
+        """Test predicates with size() and other functions"""
+        G = nx.DiGraph()
+        G.add_node("a", name="Alice")
+        G.add_node("b", name="Bob")
+        G.add_node("c", name="Charlie")
+        G.add_node("d", name="David")
+        G.add_edge("a", "b", weight=10)
+        G.add_edge("b", "c", weight=20)
+        G.add_edge("c", "d", weight=5)
+
+        qry = """
+        MATCH (a)-[r*1..3]->(b)
+        WHERE size(relationships(r)) = 2
+          AND ALL(edge IN r WHERE edge.weight >= 10)
+          AND ANY(edge IN r WHERE edge.weight = 20)
+        RETURN a.name, b.name
+        """
+        res = GrandCypher(G).run(qry)
+
+        # 2-hop paths where all weights >= 10 and at least one = 20
+        assert "Alice" in res["a.name"]
+
+    def test_complex_predicate_logic(self):
+        """Test complex boolean logic with predicates"""
+        G = nx.DiGraph()
+        G.add_node("a", name="Alice")
+        G.add_node("b", name="Bob")
+        G.add_node("c", name="Charlie")
+        G.add_node("d", name="David")
+        G.add_edge("a", "b", weight=10)
+        G.add_edge("b", "c", weight=20)
+        G.add_edge("c", "d", weight=5)
+
+        qry = """
+        MATCH (a)-[r*1..3]->(b)
+        WHERE (size(relationships(r)) = 1 AND ANY(edge IN r WHERE edge.weight = 10))
+           OR (size(relationships(r)) = 2 AND ALL(edge IN r WHERE edge.weight >= 10))
+        RETURN a.name, b.name, size(relationships(r)) AS hops
+        """
+        res = GrandCypher(G).run(qry)
+
+        # Should match: 1-hop with weight=10 OR 2-hop with all weights>=10
+        assert "Alice" in res["a.name"]
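The four predicates exercised above reduce to simple operations over the per-edge boolean results of a matched relationship list. A minimal reference sketch of the semantics these tests assume, in plain Python (the helper name `predicate_semantics` is illustrative and not part of the GrandCypher API):

```python
def predicate_semantics(weights, test):
    """Reference semantics (a sketch): evaluate ALL / ANY / NONE / SINGLE
    over the edge weights of one matched path."""
    hits = [test(w) for w in weights]
    return {
        "ALL": all(hits),          # every edge passes the test
        "ANY": any(hits),          # at least one edge passes
        "NONE": not any(hits),     # no edge passes
        "SINGLE": sum(hits) == 1,  # exactly one edge passes
    }

# The a->b->c path above has weights [10, 20]; with the test `w > 15`:
result = predicate_semantics([10, 20], lambda w: w > 15)
# -> {"ALL": False, "ANY": True, "NONE": False, "SINGLE": True}
```

For the `[10, 20]` path this yields ALL=False, ANY=True, NONE=False, SINGLE=True, which is consistent with the assertions in the test classes above.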