Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ dependencies = [
"fastapi>=0.104.0",
"uvicorn>=0.24.0",
"pydantic>=2.0.0",
"openai>=1.52.0",
"lance>=0.17.0",
]
description = "Python bindings for the lance-graph Cypher engine"
authors = [{ name = "Lance Devs", email = "[email protected]" }]
Expand Down Expand Up @@ -43,8 +45,6 @@ build-backend = "maturin"
[project.optional-dependencies]
tests = ["pytest", "pyarrow>=14", "pandas", "ruff"]
dev = ["ruff", "pyright"]
llm = ["openai>=1.52.0"]
lance-storage = ["lance>=0.17.0"]

[project.scripts]
knowledge_graph = "knowledge_graph.main:main"
Expand Down
9 changes: 9 additions & 0 deletions python/python/knowledge_graph/cli/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from .ingest import extract_and_add, preview_extraction
from .interactive import list_datasets, run_interactive

__all__ = [
"run_interactive",
"list_datasets",
"preview_extraction",
"extract_and_add",
]
19 changes: 19 additions & 0 deletions python/python/knowledge_graph/cli/embedding_support.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""Shared helpers for preparing rows and embeddings (internal)."""

from __future__ import annotations

from .ingest import (
_assign_embeddings,
_format_entity_embedding_input,
_format_relationship_embedding_input,
_prepare_entity_rows,
_prepare_relationship_rows,
)

__all__ = [
"_assign_embeddings",
"_format_entity_embedding_input",
"_format_relationship_embedding_input",
"_prepare_entity_rows",
"_prepare_relationship_rows",
]
220 changes: 220 additions & 0 deletions python/python/knowledge_graph/cli/ingest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
"""Extraction preview and ingest helpers for the knowledge graph CLI."""

from __future__ import annotations

import hashlib
import json
import logging
from dataclasses import asdict, is_dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Mapping

from .. import extraction as kg_extraction

if TYPE_CHECKING:
from ..embeddings import EmbeddingGenerator
from ..service import LanceKnowledgeGraph

LOGGER = logging.getLogger(__name__)


def preview_extraction(source: str, extractor: kg_extraction.BaseExtractor) -> None:
    """Preview extracted knowledge from a text source or inline text.

    Resolves *source* (file path or literal text), runs the extractor, and
    prints the entities/relationships as pretty-printed JSON to stdout.
    """
    resolved = _resolve_text_input(source)
    extraction = kg_extraction.preview_extraction(resolved, extractor=extractor)
    payload = _result_to_dict(extraction)
    print(json.dumps(payload, indent=2))


def extract_and_add(
    source: str,
    service: LanceKnowledgeGraph,
    extractor: kg_extraction.BaseExtractor,
    *,
    embedding_generator: EmbeddingGenerator | None = None,
) -> None:
    """Extract knowledge and append it to the backing graph.

    Entities are upserted into the 'Entity' dataset first so that their IDs
    are available when relationship endpoints are resolved; relationships go
    into 'RELATIONSHIP'. Prints a summary of what was written.
    """
    # Lazy import: pyarrow is only needed when actually writing tables.
    import pyarrow as pa

    text = _resolve_text_input(source)
    extraction = kg_extraction.preview_extraction(text, extractor=extractor)
    entity_rows, name_to_id = _prepare_entity_rows(
        extraction.entities, embedding_generator=embedding_generator
    )

    if not entity_rows and not extraction.relationships:
        print("No candidate entities or relationships detected.")
        return

    if entity_rows:
        entity_table = pa.Table.from_pylist(entity_rows)
        service.upsert_table("Entity", entity_table, merge=True)
        print(f"Upserted {entity_table.num_rows} entity rows into dataset 'Entity'.")

    relationship_rows = _prepare_relationship_rows(
        extraction.relationships,
        name_to_id,
        embedding_generator=embedding_generator,
    )
    if relationship_rows:
        rel_table = pa.Table.from_pylist(relationship_rows)
        service.upsert_table("RELATIONSHIP", rel_table, merge=True)
        print(
            "Upserted "
            f"{rel_table.num_rows} relationship rows into dataset "
            "'RELATIONSHIP'."
        )


def _resolve_text_input(raw: str) -> str:
"""Load text from a file if it exists, otherwise treat the string as content."""
candidate = Path(raw)
if candidate.exists():
if candidate.is_dir():
raise IsADirectoryError(f"Expected text file, got directory: {candidate}")
return candidate.read_text(encoding="utf-8")
return raw


def _ensure_dict(item: object) -> dict:
if is_dataclass(item):
return asdict(item) # type: ignore[arg-type]
if isinstance(item, dict):
return item
raise TypeError(f"Unsupported extraction item type: {type(item)!r}")


def _result_to_dict(result: "kg_extraction.ExtractionResult") -> dict[str, list[dict]]:
return {
"entities": [asdict(entity) for entity in result.entities],
"relationships": [asdict(rel) for rel in result.relationships],
}


def _prepare_entity_rows(
entities: list[Any],
*,
embedding_generator: EmbeddingGenerator | None = None,
) -> tuple[list[dict[str, Any]], dict[str, str]]:
rows: list[dict[str, Any]] = []
name_to_id: dict[str, str] = {}
for entity in entities:
payload = _ensure_dict(entity)
name = str(payload.get("name", "")).strip()
entity_type = str(
payload.get("entity_type") or payload.get("type") or ""
).strip()
if not name:
continue
base = f"{name}|{entity_type}".encode("utf-8")
entity_id = hashlib.md5(base).hexdigest()
payload["entity_id"] = entity_id
payload["entity_type"] = entity_type or "UNKNOWN"
payload["name_lower"] = name.lower()
rows.append(payload)
name_to_id.setdefault(name.lower(), entity_id)
if embedding_generator and rows:
_assign_embeddings(
rows,
embedding_generator,
_format_entity_embedding_input,
)
return rows, name_to_id


def _prepare_relationship_rows(
relationships: list[Any],
name_to_id: dict[str, str],
*,
embedding_generator: EmbeddingGenerator | None = None,
) -> list[dict[str, Any]]:
rows: list[dict[str, Any]] = []
for relation in relationships:
payload = _ensure_dict(relation)
source_name = str(
payload.get("source_entity_name") or payload.get("source") or ""
).strip()
target_name = str(
payload.get("target_entity_name") or payload.get("target") or ""
).strip()
source_id = name_to_id.get(source_name.lower())
target_id = name_to_id.get(target_name.lower())
if not (source_id and target_id):
continue
payload["source_entity_id"] = source_id
payload["target_entity_id"] = target_id
payload["relationship_type"] = (
payload.get("relationship_type") or payload.get("type") or "RELATED_TO"
)
payload.setdefault("source_entity_name", source_name)
payload.setdefault("target_entity_name", target_name)
rows.append(payload)
if embedding_generator and rows:
_assign_embeddings(
rows,
embedding_generator,
_format_relationship_embedding_input,
)
return rows


def _assign_embeddings(
rows: list[dict[str, Any]],
embedding_generator: EmbeddingGenerator,
formatter: Callable[[Mapping[str, Any]], str],
) -> None:
texts: list[str] = []
indices: list[int] = []
for idx, row in enumerate(rows):
text = formatter(row)
if text:
texts.append(text)
indices.append(idx)
if not texts:
return
try:
vectors = embedding_generator.embed(texts)
except Exception as exc: # pragma: no cover - defensive logging path
LOGGER.warning("Failed to generate embeddings: %s", exc)
return
if len(vectors) != len(indices):
LOGGER.warning(
"Mismatch between embedding count and row count: expected %s, got %s",
len(indices),
len(vectors),
)
return
for idx, vector in zip(indices, vectors):
rows[idx]["embedding"] = vector


def _format_entity_embedding_input(row: Mapping[str, Any]) -> str:
name = str(row.get("name", "")).strip()
entity_type = str(row.get("entity_type", "")).strip()
context = str(row.get("context", "")).strip()
pieces = []
if name:
pieces.append(name)
if entity_type:
pieces.append(f"Type: {entity_type}")
if context:
pieces.append(f"Context: {context}")
return " | ".join(pieces)


def _format_relationship_embedding_input(row: Mapping[str, Any]) -> str:
source = str(row.get("source_entity_name") or row.get("source") or "").strip()
target = str(row.get("target_entity_name") or row.get("target") or "").strip()
relationship_type = str(row.get("relationship_type", "")).strip()
description = str(row.get("description", "")).strip()
core: list[str] = []
if source or target:
if relationship_type:
core.append(f"{source} -[{relationship_type}]-> {target}".strip())
else:
core.append(f"{source} -> {target}".strip())
if description:
core.append(f"Description: {description}")
return " | ".join(part for part in core if part)
117 changes: 117 additions & 0 deletions python/python/knowledge_graph/cli/interactive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
"""Interactive shell and CLI display helpers for the knowledge graph."""

from __future__ import annotations

import sys
from typing import TYPE_CHECKING

from ..store import LanceGraphStore

if TYPE_CHECKING:
from ..config import KnowledgeGraphConfig
from ..service import LanceKnowledgeGraph

if TYPE_CHECKING:
import pyarrow as pa


def list_datasets(config: "KnowledgeGraphConfig") -> None:
    """List the Lance datasets available under the configured root."""
    store = LanceGraphStore(config)
    store.ensure_layout()
    datasets = store.list_datasets()
    if not datasets:
        print("No Lance datasets found. Load data or run extraction first.")
        return
    print("Available Lance datasets:")
    # Sort by dataset name for stable, scannable output.
    for name in sorted(datasets):
        print(f" - {name}: {datasets[name]}")


def run_interactive(service: "LanceKnowledgeGraph") -> None:
    """Enter an interactive shell for issuing Cypher queries.

    Reads lines until EOF or an explicit quit command. Inputs starting with
    ``:`` are meta-commands; anything else is executed as a Cypher statement.
    """
    print("Lance Knowledge Graph interactive shell")
    print("Type ':help' for commands, or 'quit' to exit.")

    while True:
        try:
            text = input("kg> ").strip()
        except EOFError:
            # Ctrl-D: leave the shell cleanly on its own line.
            print()
            break
        except KeyboardInterrupt:
            # Ctrl-C: abandon the current line but keep the shell alive,
            # matching typical REPL behaviour (previously this crashed
            # with a traceback).
            print()
            continue

        if not text:
            continue
        lowered = text.lower()
        if lowered in {"quit", "exit", "q"}:
            break
        if text.startswith(":"):
            _handle_command(text, service)
            continue

        _execute_query(service, text)


def _handle_command(command: str, service: "LanceKnowledgeGraph") -> None:
"""Handle meta-commands in the interactive shell."""
cmd = command.strip()
if cmd in {":help", ":h"}:
print("Commands:")
print(" :help Show this message")
print(" :datasets List persisted Lance datasets")
print(" :config Show the configured node/relationship mappings")
print(" quit/exit/q Leave the shell")
return
if cmd in {":datasets", ":ls"}:
list_datasets(service.store.config)
return
if cmd in {":config", ":schema"}:
_print_config_summary(service)
return
print(f"Unknown command: {command}")


def _print_config_summary(service: "LanceKnowledgeGraph") -> None:
"""Print a brief summary of the graph configuration."""
config = service.config
# GraphConfig does not currently expose direct iterators; rely on repr.
print("Graph configuration:")
print(f" {config!r}")


def _execute_query(service: "LanceKnowledgeGraph", statement: str) -> None:
"""Execute a single Cypher statement and print results."""
try:
result = service.run(statement)
except Exception as exc: # pragma: no cover - CLI feedback path
print(f"Query failed: {exc}", file=sys.stderr)
return

_print_table(result)


def _print_table(table: "pa.Table") -> None:
"""Render a PyArrow table in a simple textual format."""
if table.num_rows == 0:
print("(no rows)")
return

column_names = table.column_names
columns = [table.column(i).to_pylist() for i in range(len(column_names))]
widths = []
for name, values in zip(column_names, columns):
str_values = ["" if value is None else str(value) for value in values]
if str_values:
width = max(len(name), *(len(value) for value in str_values))
else:
width = len(name)
widths.append(width)

header = " | ".join(name.ljust(width) for name, width in zip(column_names, widths))
separator = "-+-".join("-" * width for width in widths)
print(header)
print(separator)
for row_values in zip(*columns):
str_row = ["" if value is None else str(value) for value in row_values]
line = " | ".join(value.ljust(width) for value, width in zip(str_row, widths))
print(line)
Loading
Loading