Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 69 additions & 14 deletions python/python/knowledge_graph/llm/qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,23 +64,47 @@ def ask_question(

schema_summary = summarize_schema(service)
type_hints = service.store.config.type_hints()
allowed_relationship_types = tuple(
str(t) for t in (type_hints.get("relationship_types") or ())
)
if not allowed_relationship_types:
discovered = _discover_relationship_types(service)
allowed_relationship_types = tuple(discovered)
if discovered:
LOGGER.debug(
"Discovered relationship_type values from dataset: %s",
", ".join(discovered),
)
type_hint_lines = build_type_hint_lines(type_hints)

# Discover actual relationship types from data
discovered_rel_types = _discover_relationship_types(service)
if discovered_rel_types:
allowed_relationship_types = tuple(discovered_rel_types)
LOGGER.debug(
"Discovered relationship_type values from dataset: %s",
", ".join(discovered_rel_types),
)
else:
# Fall back to config types if discovery fails
allowed_relationship_types = tuple(
str(t) for t in (type_hints.get("relationship_types") or ())
)

# Discover actual entity types from data
discovered_entity_types = _discover_entity_types(service)
if discovered_entity_types:
LOGGER.debug(
"Discovered entity_type values from dataset: %s",
", ".join(discovered_entity_types),
)
else:
# Fall back to config types if discovery fails
discovered_entity_types = list(
str(t) for t in (type_hints.get("entity_types") or ())
)

# Use discovered types in the prompt instead of config types
actual_type_hints = dict(type_hints)
if discovered_rel_types:
actual_type_hints["relationship_types"] = tuple(discovered_rel_types)
if discovered_entity_types:
actual_type_hints["entity_types"] = tuple(discovered_entity_types)

type_hint_lines = build_type_hint_lines(actual_type_hints)
query_prompt = build_query_prompt(
question,
schema_summary,
type_hint_lines,
type_hints,
actual_type_hints,
seed_entities,
seed_neighbors,
)
Expand Down Expand Up @@ -291,13 +315,44 @@ def replace_in(match: re.Match[str]) -> str:


def _discover_relationship_types(service: LanceKnowledgeGraph) -> list[str]:
"""Discover distinct relationship_type values from the dataset as a fallback."""
"""Discover distinct relationship_type values from the dataset.

Results are cached on the service object to avoid repeated table loads.
"""
if hasattr(service, "_cached_rel_types"):
return service._cached_rel_types

try:
table = service.load_table("RELATIONSHIP")
if "relationship_type" in table.column_names:
values = table.column("relationship_type").to_pylist()
distinct = sorted({str(v) for v in values if v is not None and str(v)})
service._cached_rel_types = distinct
return distinct
except Exception as exc: # pragma: no cover - defensive
LOGGER.debug("Unable to discover relationship types: %s", exc)

service._cached_rel_types = []
return []


def _discover_entity_types(service: LanceKnowledgeGraph) -> list[str]:
"""Discover distinct entity_type values from the dataset.

Results are cached on the service object to avoid repeated table loads.
"""
if hasattr(service, "_cached_entity_types"):
return service._cached_entity_types

try:
table = service.load_table("Entity")
if "entity_type" in table.column_names:
values = table.column("entity_type").to_pylist()
distinct = sorted({str(v) for v in values if v is not None and str(v)})
service._cached_entity_types = distinct
return distinct
except Exception as exc: # pragma: no cover - defensive
LOGGER.debug("Unable to discover entity types: %s", exc)

service._cached_entity_types = []
return []
Loading