From 67c6c4cb396cc4a27a21baeff35c03442c4bddf9 Mon Sep 17 00:00:00 2001 From: Alex Batisse Date: Wed, 31 Dec 2025 11:56:32 +0100 Subject: [PATCH 1/4] refactor: Refactor hover lookup to iterate over nodes --- pyproject.toml | 1 + src/craft_ls/core.py | 107 ++++++++++++++++++++++++++--------------- src/craft_ls/server.py | 5 +- src/craft_ls/types_.py | 19 +++++++- uv.lock | 11 +++++ 5 files changed, 99 insertions(+), 44 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6edf9bf..4f9cc95 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ dependencies = [ "referencing>=0.36.2", "jsonref>=1.1.0", "jsonpath-ng>=1.7.0", + "more-itertools>=10.8.0", ] classifiers = [ "Development Status :: 3 - Alpha", diff --git a/src/craft_ls/core.py b/src/craft_ls/core.py index 30305b2..84209fd 100644 --- a/src/craft_ls/core.py +++ b/src/craft_ls/core.py @@ -4,7 +4,7 @@ import re from collections import deque from importlib.resources import files -from itertools import chain, tee +from itertools import chain from typing import Iterable, cast import jsonref @@ -15,6 +15,7 @@ from jsonschema.protocols import Validator from jsonschema.validators import validator_for from lsprotocol import types as lsp +from more_itertools import peekable from referencing import Registry, Resource from referencing.jsonschema import DRAFT202012 from yaml.events import ( @@ -31,16 +32,18 @@ BlockEndToken, BlockMappingStartToken, BlockSequenceStartToken, + KeyToken, ScalarToken, Token, + ValueToken, ) from craft_ls.types_ import ( CompleteParsedResult, + DocumentNode, IncompleteParsedResult, MissingTypeCharmcraftValidator, MissingTypeSnapcraftValidator, - Node, ParsedResult, Schema, YamlDocument, @@ -202,9 +205,11 @@ def robust_load(instance_document: str) -> tuple[YamlDocument, list[Event]]: return cast(YamlDocument, yaml.safe_load(truncated_file)), truncated_file -def segmentize_nodes(root: yaml.CollectionNode) -> list[tuple[tuple[str, ...], Node]]: +def segmentize_nodes( + root: yaml.CollectionNode, +) -> list[tuple[tuple[str, ...], DocumentNode]]: """Flatten graph into path segments.""" - segments: list[tuple[tuple[str, ...], Node]] = [] + segments: list[tuple[tuple[str, ...], DocumentNode]] = [] nodes = list(root.value) for node_pair in nodes: @@ -217,7 +222,7 @@ def _do_segmentize_nodes( first: yaml.CollectionNode, second: yaml.CollectionNode, prefix: tuple[str, ...] | None = None, -) -> list[tuple[tuple[str, ...], Node]]: +) -> list[tuple[tuple[str, ...], DocumentNode]]: """Recursive node segmentation. Craft tools don't usually go over three levels, so we don't have to worry about recursion limits. @@ -227,7 +232,7 @@ def _do_segmentize_nodes( match second: case yaml.ScalarNode(end_mark=selection_end): - current_node = Node( + current_node = DocumentNode( value=first.value, start=first.start_mark, end=first.end_mark, @@ -236,7 +241,7 @@ def _do_segmentize_nodes( segments.append((prefix + (str(first.value),), current_node)) case yaml.MappingNode(end_mark=selection_end, value=children): - current_node = Node( + current_node = DocumentNode( value=first.value, start=first.start_mark, end=first.end_mark, @@ -256,8 +261,9 @@ def _do_segmentize_nodes( ) ) - case yaml.CollectionNode(end_mark=selection_end): - current_node = Node( + case yaml.SequenceNode(end_mark=selection_end): + print(selection_end.line) + current_node = DocumentNode( value=first.value, start=first.start_mark, end=first.end_mark, @@ -271,7 +277,9 @@ def _do_segmentize_nodes( def get_diagnostics( - validator: Validator, instance: YamlDocument, segments: dict[tuple[str, ...], Node] + validator: Validator, + instance: YamlDocument, + segments: dict[tuple[str, ...], DocumentNode], ) -> list[lsp.Diagnostic]: """Validate a document against its schema.""" diagnostics = [] @@ -351,14 +359,8 @@ def get_diagnostics( return diagnostics -def peek(tee_iterator: Iterable[Token]) -> Token | None: - """Return the next value without moving the input forward.""" - [forked_iterator] = tee(tee_iterator, 1) - return next(forked_iterator, None) - - def get_diagnostic_range( - document_segments: dict[tuple[str, ...], Node], diag_segments: Iterable[str] + document_segments: dict[tuple[str, ...], DocumentNode], diag_segments: Iterable[str] ) -> lsp.Range: """Link the validation error to the position in the original document.""" if ( @@ -409,43 +411,68 @@ def get_description_from_path(path: Iterable[str | int], schema: Schema) -> str: return MISSING_DESC -def get_schema_path_from_token_position( - position: lsp.Position, tokens: list[Token] -) -> deque[str] | None: - """Parse the document to find the path to the current position.""" +def get_exact_cursor_path(position: lsp.Position, tokens: list[Token]) -> deque[str]: # noqa: C901 + """Get the exact path to the cursor position.""" current_path: deque[str] = deque() + iterator = peekable(tokens) last_scalar_token: str = "" - start_mark: yaml.Mark - end_mark: yaml.Mark + previous: Token | None = None + token: Token | None = None + next_token: Token | None = None + + for token in iterator: + next_token = iterator.peek(None) + early_stop = ( + not next_token + or next_token.start_mark.line > position.line + or ( + next_token.start_mark.line >= position.line + and next_token.start_mark.column >= position.character + ) + ) + if early_stop: + break - for token in tokens: match token: case BlockMappingStartToken() | BlockSequenceStartToken(): - current_path.append(last_scalar_token) - + if last_scalar_token: + current_path.append(last_scalar_token) case BlockEndToken(): - current_path.pop() + if current_path: + current_path.pop() - case ScalarToken(value=value, start_mark=start_mark, end_mark=end_mark): - is_line_matching = start_mark.line == position.line - is_col_matching = ( - start_mark.column <= position.character <= end_mark.column - ) - if is_line_matching and is_col_matching: - current_path.append(value) - current_path.remove("") - return current_path + case KeyToken(): + last_scalar_token = "" - else: + case ScalarToken(value=value): + if isinstance(previous, yaml.KeyToken): last_scalar_token = value - case _: - continue + previous = token + + if ( + isinstance(token, (ValueToken, ScalarToken)) + and last_scalar_token + and next_token + ): + current_path.append(last_scalar_token) + + return current_path + + +def get_node_path_from_token_position( + position: lsp.Position, segments: dict[tuple[str, ...], DocumentNode] +) -> tuple[str, ...] | None: + """Find the innermost node path corresponding to the current position.""" + for segment, node in reversed(segments.items()): + if node.contains(position): + return segment + return None def list_symbols( - instance: YamlDocument, segments: dict[tuple[str, ...], Node] + instance: YamlDocument, segments: dict[tuple[str, ...], DocumentNode] ) -> list[lsp.DocumentSymbol]: """List document symbols. diff --git a/src/craft_ls/server.py b/src/craft_ls/server.py index 62413c3..83bba30 100644 --- a/src/craft_ls/server.py +++ b/src/craft_ls/server.py @@ -12,7 +12,7 @@ from craft_ls.core import ( get_description_from_path, get_diagnostics, - get_schema_path_from_token_position, + get_node_path_from_token_position, get_validator_and_parse, list_symbols, segmentize_nodes, @@ -141,12 +141,13 @@ def hover(ls: CraftLanguageServer, params: lsp.HoverParams) -> lsp.Hover | None: match ls.index.get(Path(uri)): case IndexEntry(validator_found, tokens=tokens): + case IndexEntry(validator_found, segments=segments): validator = validator_found case _: return None - if not (path := get_schema_path_from_token_position(position=pos, tokens=tokens)): + if not (path := get_node_path_from_token_position(position=pos, segments=segments)): return None description = get_description_from_path( diff --git a/src/craft_ls/types_.py b/src/craft_ls/types_.py index c7f50ff..ce20fac 100644 --- a/src/craft_ls/types_.py +++ b/src/craft_ls/types_.py @@ -5,6 +5,7 @@ from typing import Any, Generator, NewType from jsonschema import ValidationError, Validator +from lsprotocol import types as lsp from yaml import CollectionNode, Mark, Token # We can probably do better, but that will do for now @@ -36,7 +37,7 @@ class IncompleteParsedResult(ParsedResult): @dataclass -class Node: +class DocumentNode: """Document node.""" value: str @@ -44,6 +45,19 @@ class Node: end: Mark selection_end: Mark + def contains(self, position: lsp.Position) -> bool: + """Is position contained in node range?""" + range_start_before = self.start.line < position.line or ( + self.start.line == position.line and self.start.column <= position.character + ) + + range_end_after = self.selection_end.line > position.line or ( + self.selection_end.line == position.line + and self.selection_end.column >= position.character + ) + + return range_start_before and range_end_after + @dataclass class IndexEntry: @@ -52,7 +66,8 @@ class IndexEntry: validator: Validator tokens: list[Token] instance: YamlDocument - segments: dict[tuple[str, ...], Node] + segments: dict[tuple[str, ...], DocumentNode] + version: int | None class MissingTypeCharmcraftValidator: diff --git a/uv.lock b/uv.lock index c11e25c..9c0b9ee 100644 --- a/uv.lock +++ b/uv.lock @@ -41,6 +41,7 @@ dependencies = [ { name = "jsonref" }, { name = "jsonschema" }, { name = "lsprotocol" }, + { name = "more-itertools" }, { name = "pygls" }, { name = "pyyaml" }, { name = "referencing" }, @@ -68,6 +69,7 @@ requires-dist = [ { name = "jsonref", specifier = ">=1.1.0" }, { name = "jsonschema", specifier = ">=4.23.0" }, { name = "lsprotocol", specifier = ">=2025.0.0" }, + { name = "more-itertools", specifier = ">=10.8.0" }, { name = "pygls", specifier = ">=1.1.0" }, { name = "pyyaml", specifier = ">=6.0.2" }, { name = "referencing", specifier = ">=0.36.2" }, @@ -221,6 +223,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7b/f0/92f2d609d6642b5f30cb50a885d2bf1483301c69d5786286500d15651ef2/lsprotocol-2025.0.0-py3-none-any.whl", hash = "sha256:f9d78f25221f2a60eaa4a96d3b4ffae011b107537facee61d3da3313880995c7", size = 76250, upload-time = "2025-06-17T21:30:19.455Z" }, ] +[[package]] +name = "more-itertools" +version = "10.8.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ea/5d/38b681d3fce7a266dd9ab73c66959406d565b3e85f21d5e66e1181d93721/more_itertools-10.8.0.tar.gz", hash = "sha256:f638ddf8a1a0d134181275fb5d58b086ead7c6a72429ad725c67503f13ba30bd", size = 137431, upload-time = "2025-09-02T15:23:11.018Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a4/8e/469e5a4a2f5855992e425f3cb33804cc07bf18d48f2db061aec61ce50270/more_itertools-10.8.0-py3-none-any.whl", hash = "sha256:52d4362373dcf7c52546bc4af9a86ee7c4579df9a8dc268be0a2f949d376cc9b", size = 69667, upload-time = "2025-09-02T15:23:09.635Z" }, +] + [[package]] name = "mypy" version = "1.19.1" From b3e7491f3b25a75c585a6188af6ffd5c8da84465 Mon Sep 17 00:00:00 2001 From: Alex Batisse Date: Wed, 31 Dec 2025 11:57:16 +0100 Subject: [PATCH 2/4] fix: Re-trigger document parsing if needed --- src/craft_ls/server.py | 46 ++++++++++++++++++++++++++---------------- 1 file changed, 29 insertions(+), 17 deletions(-) diff --git a/src/craft_ls/server.py b/src/craft_ls/server.py index 83bba30..80a069e 100644 --- a/src/craft_ls/server.py +++ b/src/craft_ls/server.py @@ -41,25 +41,40 @@ def __init__( text_document_sync_kind, notebook_document_sync, ) - self.index: dict[Path, IndexEntry | None] = {} + self.index: dict[str, IndexEntry | None] = {} - def parse_file(self, file_uri: Path, source: str) -> IndexEntry | None: + def parse_file(self, file_uri: str) -> IndexEntry | None: """Parse a document into tokens, nodes and whatnot. The result is cached so we can access it in endpoints. """ - match get_validator_and_parse(file_uri.stem, source): + document = self.workspace.get_text_document(file_uri) + match get_validator_and_parse(Path(file_uri).stem, document.source): case None: self.index[file_uri] = None case validator, ParsedResult(tokens, instance, nodes): segments_nodes = segmentize_nodes(nodes) self.index[file_uri] = IndexEntry( - validator, tokens, instance, dict(segments_nodes) + validator, tokens, instance, dict(segments_nodes), document.version ) return self.index[file_uri] + def get_or_update_index(self, file_uri: str) -> IndexEntry | None: + """Re-parse document if needed.""" + current_version = self.workspace.get_text_document(file_uri).version + entry = self.index.get(file_uri) + match entry: + case IndexEntry(version=cached_version) as cached: + if not cached_version or cached_version != current_version: + return self.parse_file( + file_uri, + ) + return cached + case None: + return None + server = CraftLanguageServer( name="craft-ls", @@ -77,9 +92,6 @@ def shorten_messages(diagnostics: list[lsp.Diagnostic]) -> None: def on_opened(ls: CraftLanguageServer, params: lsp.DidOpenTextDocumentParams) -> None: """Parse each document when it is opened.""" uri = params.text_document.uri - version = params.text_document.version - doc = ls.workspace.get_text_document(params.text_document.uri) - source = doc.source diagnostics = ( [ lsp.Diagnostic( @@ -95,8 +107,10 @@ def on_opened(ls: CraftLanguageServer, params: lsp.DidOpenTextDocumentParams) -> else [] ) - match ls.parse_file(Path(uri), source): - case IndexEntry(validator, instance=instance, segments=segments): + match ls.parse_file(uri): + case IndexEntry( + validator, instance=instance, segments=segments, version=version + ): diagnostics.extend(get_diagnostics(validator, instance, segments)) case _: @@ -115,13 +129,12 @@ def on_opened(ls: CraftLanguageServer, params: lsp.DidOpenTextDocumentParams) -> def on_changed(ls: CraftLanguageServer, params: lsp.DidOpenTextDocumentParams) -> None: """Parse each document when it is changed.""" uri = params.text_document.uri - version = params.text_document.version - doc = ls.workspace.get_text_document(params.text_document.uri) - source = doc.source diagnostics = [] - match ls.parse_file(Path(uri), source): - case IndexEntry(validator, instance=instance, segments=segments): + match ls.parse_file(uri): + case IndexEntry( + validator, instance=instance, segments=segments, version=version + ): diagnostics.extend(get_diagnostics(validator, instance, segments)) case _: @@ -139,8 +152,7 @@ def hover(ls: CraftLanguageServer, params: lsp.HoverParams) -> lsp.Hover | None: pos = params.position uri = params.text_document.uri - match ls.index.get(Path(uri)): - case IndexEntry(validator_found, tokens=tokens): + match ls.get_or_update_index(uri): case IndexEntry(validator_found, segments=segments): validator = validator_found @@ -174,7 +186,7 @@ def document_symbol( uri = params.text_document.uri symbols_results: list[lsp.DocumentSymbol] = [] - match ls.index.get(Path(uri)): + match ls.get_or_update_index(uri): case IndexEntry(instance=instance, segments=segments): symbols_results = list_symbols(instance, segments) From 130fde88568887c6baa43a865f81cdd1bd0eb1e9 Mon Sep 17 00:00:00 2001 From: Alex Batisse Date: Wed, 31 Dec 2025 11:57:49 +0100 Subject: [PATCH 3/4] feat: Initial document completion implementation --- src/craft_ls/core.py | 31 ++++++++++++++++++++++ src/craft_ls/server.py | 28 ++++++++++++++++++++ tests/test_core.py | 59 +++++++++++++++++++++++++++++++++--------- 3 files changed, 106 insertions(+), 12 deletions(-) diff --git a/src/craft_ls/core.py b/src/craft_ls/core.py index 84209fd..536a0be 100644 --- a/src/craft_ls/core.py +++ b/src/craft_ls/core.py @@ -523,3 +523,34 @@ def list_symbols( symbols.append(symbol) return symbols + + +def get_completion_items_from_path( + segments: Iterable[str], schema: Schema, instance: YamlDocument +) -> list[lsp.CompletionItem]: + """Get possible values for children nodes or enum values.""" + sub_instance = instance + sub_schema = schema + for segment in segments: + sub_schema = sub_schema.get("properties", {}).get(segment, {}) + sub_instance = sub_instance.get(segment, {}) + + already_present = ( + set(sub_instance.keys()) if isinstance(sub_instance, dict) else set() + ) + items = [] + + if "cons" in sub_schema.keys(): + items = [lsp.CompletionItem(label=str(key)) for key in [sub_schema["cons"]]] + + elif "enum" in sub_schema.keys(): + items = [ + lsp.CompletionItem(label=str(key)) for key in sub_schema["enum"] if key + ] + + elif "properties" in sub_schema.keys(): + items = [ + lsp.CompletionItem(label=str(key)) + for key in set(sub_schema["properties"].keys()) - already_present + ] + return items diff --git a/src/craft_ls/server.py b/src/craft_ls/server.py index 80a069e..626e0ed 100644 --- a/src/craft_ls/server.py +++ b/src/craft_ls/server.py @@ -10,8 +10,10 @@ from craft_ls import __version__ from craft_ls.core import ( + get_completion_items_from_path, get_description_from_path, get_diagnostics, + get_exact_cursor_path, get_node_path_from_token_position, get_validator_and_parse, list_symbols, @@ -193,6 +195,32 @@ def document_symbol( return symbols_results +@server.feature( + lsp.TEXT_DOCUMENT_COMPLETION, lsp.CompletionOptions(trigger_characters=[" "]) +) +def completions( + ls: CraftLanguageServer, params: lsp.CompletionParams +) -> lsp.CompletionList | None: + """Suggest next element based on the document structure.""" + pos = params.position + uri = params.text_document.uri + items = [] + + match ls.get_or_update_index(uri): + case IndexEntry(validator_found, instance=instance, tokens=tokens): + validator = validator_found + + case _: + return None + + path = get_exact_cursor_path(pos, tokens) + items = get_completion_items_from_path( + segments=path, schema=cast(Schema, validator.schema), instance=instance + ) + + return lsp.CompletionList(is_incomplete=False, items=items) + + def start() -> None: """Start the server.""" server.start_io() diff --git a/tests/test_core.py b/tests/test_core.py index fe61b02..7be9029 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -14,7 +14,8 @@ get_description_from_path, get_diagnostic_range, get_diagnostics, - get_schema_path_from_token_position, + get_exact_cursor_path, + get_node_path_from_token_position, list_symbols, parse_tokens, segmentize_nodes, @@ -71,6 +72,7 @@ """ parsed_document = parse_tokens(document) +document_segments = segmentize_nodes(parsed_document.nodes) def test_get_description_first_level_ok() -> None: @@ -112,49 +114,82 @@ def test_get_description_unknown_path(key: str) -> None: assert description == MISSING_DESC -def test_get_path_from_position_first_level_ok() -> None: +def test_get_node_path_from_position_first_level_ok() -> None: # Given # line is 2 because of initial newline after """ + comment in the document position = lsp.Position(2, 5) # When - path = get_schema_path_from_token_position(position, parsed_document.tokens) + path = get_node_path_from_token_position(position, dict(document_segments)) # Then - assert path == deque(["productId"]) + assert path == ("productId",) -def test_get_path_from_position_nested_ok() -> None: +def test_get_node_path_from_position_nested_ok() -> None: # Given position = lsp.Position(5, 5) # When - path = get_schema_path_from_token_position(position, parsed_document.tokens) + path = get_node_path_from_token_position(position, dict(document_segments)) # Then - assert path == deque(["price", "amount"]) + assert path == ("price", "amount") -def test_get_path_from_comment_ko() -> None: +def test_get_node_path_from_outside_ko() -> None: # Given position = lsp.Position(1, 5) # comment line # When - path = get_schema_path_from_token_position(position, parsed_document.tokens) + path = get_node_path_from_token_position(position, dict(document_segments)) # Then assert not path -def test_get_path_from_empty_space_ko() -> None: +def test_get_node_path_from_value_ok() -> None: # Given position = lsp.Position(4, 10) # to the right of "price" # When - path = get_schema_path_from_token_position(position, parsed_document.tokens) + path = get_node_path_from_token_position(position, dict(document_segments)) # Then - assert not path + assert path == ("price",) + + +def test_get_cursor_path_from_token_ok() -> None: + # Given + position = lsp.Position(4, 3) # inside "price", current path should be root + + # When + path = get_exact_cursor_path(position, parsed_document.tokens) + + # Then + assert path == deque([]) + + +def test_get_cursor_path_from_nested_key_ok() -> None: + # Given + position = lsp.Position(5, 4) # inside "amount", current path should be "price" + + # When + path = get_exact_cursor_path(position, parsed_document.tokens) + + # Then + assert path == deque(["price"]) + + +def test_get_cursor_path_from_value_ok() -> None: + # Given + position = lsp.Position(3, 15) # inside "bar", current path should be "productName" + + # When + path = get_exact_cursor_path(position, parsed_document.tokens) + + # Then + assert path == deque(["productName"]) def test_values_are_not_flagged() -> None: From 0fc53236002d1703ed0c77ad32b8b54007a61df8 Mon Sep 17 00:00:00 2001 From: Alex Batisse Date: Wed, 31 Dec 2025 14:04:28 +0100 Subject: [PATCH 4/4] typo: Remove duplicated space --- snapcraft.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/snapcraft.yaml b/snapcraft.yaml index 09e1682..589180a 100644 --- a/snapcraft.yaml +++ b/snapcraft.yaml @@ -5,7 +5,7 @@ description: | base: core24 confinement: strict adopt-info: craft-ls -license: BSD-3-Clause +license: BSD-3-Clause parts: craft-ls: