Skip to content

Commit 747a320

Browse files
Improve Matter DNS-SD service parsing (#1144)
Co-authored-by: Martin Hjelmare <[email protected]>
1 parent 3756178 commit 747a320

File tree

4 files changed

+88
-24
lines changed

4 files changed

+88
-24
lines changed

.pre-commit-config.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,9 @@ repos:
8585
name: 🌟 Starring code with pylint
8686
language: system
8787
types: [python]
88-
entry: scripts/run-in-env.sh pylint
88+
entry: scripts/run-in-env.sh pylint matter_server/ tests/
89+
require_serial: true
90+
pass_filenames: false
8991
- id: trailing-whitespace
9092
name: ✄ Trim Trailing Whitespace
9193
language: system

matter_server/server/device_controller.py

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from datetime import datetime
1414
from functools import cached_property, lru_cache
1515
import logging
16+
import re
1617
import secrets
1718
import time
1819
from typing import TYPE_CHECKING, Any, cast
@@ -128,6 +129,10 @@
128129
0, Clusters.IcdManagement.Attributes.AttributeList
129130
)
130131

132+
RE_MDNS_SERVICE_NAME = re.compile(
133+
rf"^([0-9A-Fa-f]{{16}})-([0-9A-Fa-f]{{16}})\.{re.escape(MDNS_TYPE_OPERATIONAL_NODE)}$"
134+
)
135+
131136

132137
# pylint: disable=too-many-lines,too-many-instance-attributes,too-many-public-methods
133138

@@ -152,7 +157,6 @@ def __init__(
152157
# we keep the last events in memory so we can include them in the diagnostics dump
153158
self.event_history: deque[Attribute.EventReadResult] = deque(maxlen=25)
154159
self._compressed_fabric_id: int | None = None
155-
self._fabric_id_hex: str | None = None
156160
self._wifi_credentials_set: bool = False
157161
self._thread_credentials_set: bool = False
158162
self._setup_node_tasks = dict[int, asyncio.Task]()
@@ -179,7 +183,6 @@ async def initialize(self) -> None:
179183
self._compressed_fabric_id = (
180184
await self._chip_device_controller.get_compressed_fabric_id()
181185
)
182-
self._fabric_id_hex = hex(self._compressed_fabric_id)[2:]
183186
await load_local_updates(self._ota_provider_dir)
184187

185188
async def start(self) -> None:
@@ -245,8 +248,10 @@ async def stop(self) -> None:
245248
LOGGER.debug("Stopped.")
246249

247250
@property
248-
def compressed_fabric_id(self) -> int | None:
251+
def compressed_fabric_id(self) -> int:
249252
"""Return the compressed fabric id."""
253+
if self._compressed_fabric_id is None:
254+
raise RuntimeError("Compressed Fabric ID not set")
250255
return self._compressed_fabric_id
251256

252257
@property
@@ -1524,25 +1529,36 @@ def _on_mdns_service_state_change(
15241529
)
15251530
return
15261531

1527-
if service_type == MDNS_TYPE_OPERATIONAL_NODE:
1528-
if self._fabric_id_hex is None or self._fabric_id_hex not in name.lower():
1529-
# filter out messages that are not for our fabric
1530-
return
1531-
# process the event with a debounce timer
1532+
if service_type != MDNS_TYPE_OPERATIONAL_NODE:
1533+
return
1534+
1535+
if not (match := RE_MDNS_SERVICE_NAME.match(name)):
1536+
LOGGER.getChild("mdns").warning(
1537+
"Service name doesn't match expected operational node pattern: %s", name
1538+
)
1539+
return
1540+
1541+
fabric_id_hex, node_id_hex = match.groups()
1542+
1543+
# Filter messages of other fabrics
1544+
if int(fabric_id_hex, 16) != self.compressed_fabric_id:
1545+
return
1546+
1547+
# Process the event with a debounce timer
15321548
self._mdns_event_timer[name] = self._loop.call_later(
1533-
0.5, self._on_mdns_operational_node_state, name, state_change
1549+
0.5,
1550+
self._on_mdns_operational_node_state,
1551+
name,
1552+
int(node_id_hex, 16),
1553+
state_change,
15341554
)
15351555

15361556
def _on_mdns_operational_node_state(
1537-
self, name: str, state_change: ServiceStateChange
1557+
self, name: str, node_id: int, state_change: ServiceStateChange
15381558
) -> None:
15391559
"""Handle a (operational) Matter node MDNS state change."""
15401560
self._mdns_event_timer.pop(name, None)
1541-
logger = LOGGER.getChild("mdns")
1542-
# the mdns name is constructed as [fabricid]-[nodeid]._matter._tcp.local.
1543-
# extract the node id from the name
1544-
node_id = int(name.split("-")[1].split(".")[0], 16)
1545-
node_logger = self.get_node_logger(logger, node_id)
1561+
node_logger = self.get_node_logger(LOGGER.getChild("mdns"), node_id)
15461562

15471563
if not (node := self._nodes.get(node_id)):
15481564
return # this should not happen, but guard just in case

matter_server/server/server.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import asyncio
66
from functools import cached_property, partial
7+
import inspect
78
import ipaddress
89
import logging
910
import os
@@ -315,18 +316,24 @@ def register_api_command(
315316

316317
def _register_api_commands(self) -> None:
317318
"""Register all methods decorated as api_command."""
318-
for cls in (self, self._device_controller, self.vendor_info):
319-
for attr_name in dir(cls):
319+
for obj in (self, self._device_controller, self.vendor_info):
320+
cls = obj.__class__
321+
for attr_name, attr in inspect.getmembers(cls):
320322
if attr_name.startswith("_"):
321323
continue
322-
val = getattr(cls, attr_name)
323-
if not hasattr(val, "api_cmd"):
324+
325+
if isinstance(attr, property):
326+
continue # skip properties
327+
328+
# attr is the (unbound) function, we can check for the decorator
329+
if not hasattr(attr, "api_cmd"):
324330
continue
325-
if hasattr(val, "mock_calls"):
326-
# filter out mocks
331+
if hasattr(attr, "mock_calls"):
327332
continue
328-
# method is decorated with our api decorator
329-
self.register_api_command(val.api_cmd, val)
333+
334+
# Get the instance method before registering
335+
bound_method = getattr(obj, attr_name)
336+
self.register_api_command(attr.api_cmd, bound_method)
330337

331338
async def _handle_info(self, request: web.Request) -> web.Response:
332339
"""Handle info endpoint to serve basic server (version) info."""

tests/test_device_controller.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
"""Device controller tests."""
2+
3+
import pytest
4+
5+
from matter_server.server.device_controller import RE_MDNS_SERVICE_NAME
6+
7+
8+
@pytest.mark.parametrize(
9+
("name", "expected"),
10+
[
11+
(
12+
"D22DC25523A78A89-0000000000000125._matter._tcp.local.",
13+
("D22DC25523A78A89", "0000000000000125"),
14+
),
15+
(
16+
"d22dc25523a78a89-0000000000000125._matter._tcp.local.",
17+
("d22dc25523a78a89", "0000000000000125"),
18+
),
19+
],
20+
)
21+
def test_valid_mdns_service_names(name, expected):
22+
"""Test valid mDNS service names."""
23+
match = RE_MDNS_SERVICE_NAME.match(name)
24+
assert match is not None
25+
assert match.groups() == expected
26+
27+
28+
@pytest.mark.parametrize(
29+
"name",
30+
[
31+
"D22DC25523A78A89-0000000000000125 (2)._matter._tcp.local.",
32+
"D22DC25523A78A89-0000000000000125.._matter._tcp.local.",
33+
"G22DC25523A78A89-0000000000000125._matter._tcp.local.", # invalid hex
34+
"D22DC25523A78A89-0000000000000125._matterc._udp.local.",
35+
],
36+
)
37+
def test_invalid_mdns_service_names(name):
38+
"""Test invalid mDNS service names."""
39+
assert RE_MDNS_SERVICE_NAME.match(name) is None

0 commit comments

Comments
 (0)