Skip to content

Commit 71c0351

Browse files
committed
Enhance VM entity ID generation to ensure uniqueness and handle collisions in Unraid switch integration
1 parent 4a0fc28 commit 71c0351

File tree

2 files changed

+125
-26
lines changed

2 files changed

+125
-26
lines changed

custom_components/unraid/api/vm_operations.py

Lines changed: 48 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -81,17 +81,18 @@ async def get_vms(self) -> List[Dict[str, Any]]:
8181

8282
# Collect VM information in a single command
8383
try:
84-
# Use a single command to collect all VM information
84+
# Use a more robust command that properly handles VM names with spaces and special characters
85+
# We use a unique delimiter (§§§) that's unlikely to appear in VM names or XML data
8586
cmd = (
8687
"if [ -x /etc/rc.d/rc.libvirt ] && /etc/rc.d/rc.libvirt status | grep -q 'is currently running'; then "
87-
" for vm in $(virsh list --all --name); do "
88-
" if [ -n \"$vm\" ]; then "
88+
" virsh list --all --name | while IFS= read -r vm; do "
89+
" if [ -n \"$vm\" ] && [ \"$vm\" != \" \" ]; then "
8990
" state=$(virsh domstate \"$vm\" 2>/dev/null || echo 'unknown'); "
9091
" info=$(virsh dominfo \"$vm\" 2>/dev/null); "
91-
" cpus=$(echo \"$info\" | grep 'CPU(s)' | awk '{print $2}'); "
92-
" mem=$(echo \"$info\" | grep 'Max memory' | sed 's/Max memory://g' | xargs); "
93-
" xml=$(virsh dumpxml \"$vm\" 2>/dev/null | grep -A5 \"<os>\"); "
94-
" echo \"$vm|$state|$cpus|$mem|$xml\"; "
92+
" cpus=$(echo \"$info\" | grep 'CPU(s)' | awk '{print $2}' | head -1); "
93+
" mem=$(echo \"$info\" | grep 'Max memory' | sed 's/Max memory://g' | xargs | head -1); "
94+
" xml=$(virsh dumpxml \"$vm\" 2>/dev/null | grep -A5 \"<os>\" | tr '\\n' ' '); "
95+
" echo \"$vm§§§$state§§§$cpus§§§$mem§§§$xml\"; "
9596
" fi; "
9697
" done; "
9798
"else "
@@ -107,18 +108,26 @@ async def get_vms(self) -> List[Dict[str, Any]]:
107108

108109
vms = []
109110
for line in result.stdout.splitlines():
110-
if not line.strip() or '|' not in line:
111+
if not line.strip() or '§§§' not in line:
111112
continue
112113

113114
try:
114-
parts = line.split('|')
115+
# Split on our unique delimiter
116+
parts = line.split('§§§')
115117
if len(parts) >= 5:
116118
vm_name = parts[0].strip()
117119
status = VMState.parse(parts[1].strip())
118-
cpus = parts[2].strip()
119-
memory = parts[3].strip()
120+
cpus = parts[2].strip() or '0'
121+
memory = parts[3].strip() or '0'
120122
xml_data = parts[4].strip()
121123

124+
# Skip empty VM names (shouldn't happen but be safe)
125+
if not vm_name:
126+
_LOGGER.debug("Skipping VM with empty name")
127+
continue
128+
129+
_LOGGER.debug("Processing VM: '%s' with status: '%s'", vm_name, status)
130+
122131
# Determine OS type from XML data
123132
os_type = 'unknown'
124133
xml_lower = xml_data.lower()
@@ -149,6 +158,7 @@ async def get_vms(self) -> List[Dict[str, Any]]:
149158
_LOGGER.debug("Error processing VM line '%s': %s", line, str(vm_err))
150159
continue
151160

161+
_LOGGER.debug("Successfully processed %d VMs", len(vms))
152162
return vms
153163

154164
except Exception as virsh_err:
@@ -170,23 +180,27 @@ async def _get_vms_original(self) -> List[Dict[str, Any]]:
170180

171181
vms = []
172182
for line in result.stdout.splitlines():
173-
if not line.strip():
183+
vm_name = line.strip()
184+
if not vm_name:
174185
continue
175186

176187
try:
177-
vm_name = line.strip()
188+
_LOGGER.debug("Processing VM (fallback): '%s'", vm_name)
178189
status = await self.get_vm_status(vm_name)
179190
os_type = await self.get_vm_os_info(vm_name)
180191

181192
vms.append({
182193
"name": vm_name,
183194
"status": status,
184-
"os_type": os_type
195+
"os_type": os_type,
196+
"cpus": "0", # Add default values for consistency
197+
"memory": "0"
185198
})
186199
except Exception as vm_err:
187-
_LOGGER.debug("Error processing VM '%s': %s", line.strip(), str(vm_err))
200+
_LOGGER.debug("Error processing VM '%s': %s", vm_name, str(vm_err))
188201
continue
189202

203+
_LOGGER.debug("Successfully processed %d VMs (fallback)", len(vms))
190204
return vms
191205
except Exception as err:
192206
_LOGGER.debug("Error in original VM method: %s", str(err))
@@ -236,7 +250,9 @@ async def get_vm_os_info(self, vm_name: str) -> str:
236250
async def get_vm_status(self, vm_name: str) -> str:
237251
"""Get detailed status of a specific virtual machine."""
238252
try:
239-
result = await self.execute_command(f'virsh domstate "{vm_name}"')
253+
# Use shlex.quote to properly escape VM names with special characters
254+
escaped_name = shlex.quote(vm_name)
255+
result = await self.execute_command(f'virsh domstate {escaped_name}')
240256
if result.exit_status != 0:
241257
_LOGGER.error("Failed to get VM status for '%s': %s", vm_name, result.stderr)
242258
return VMState.CRASHED.value
@@ -256,7 +272,9 @@ async def start_vm(self, vm_name: str) -> bool:
256272
_LOGGER.info("VM '%s' is already running", vm_name)
257273
return True
258274

259-
result = await self.execute_command(f'virsh start "{vm_name}"')
275+
# Use shlex.quote to properly escape VM names with special characters
276+
escaped_name = shlex.quote(vm_name)
277+
result = await self.execute_command(f'virsh start {escaped_name}')
260278
success = result.exit_status == 0
261279

262280
if not success:
@@ -289,7 +307,9 @@ async def stop_vm(self, vm_name: str) -> bool:
289307
_LOGGER.info("VM '%s' is already shut off", vm_name)
290308
return True
291309

292-
result = await self.execute_command(f'virsh shutdown "{vm_name}"')
310+
# Use shlex.quote to properly escape VM names with special characters
311+
escaped_name = shlex.quote(vm_name)
312+
result = await self.execute_command(f'virsh shutdown {escaped_name}')
293313
success = result.exit_status == 0
294314

295315
if not success:
@@ -326,7 +346,8 @@ async def pause_vm(self, vm_name: str) -> bool:
326346
_LOGGER.error("Cannot pause VM '%s' because it is not running (current state: %s)", vm_name, current_state)
327347
return False
328348

329-
result = await self.execute_command(f'virsh suspend "{vm_name}"')
349+
escaped_name = shlex.quote(vm_name)
350+
result = await self.execute_command(f'virsh suspend {escaped_name}')
330351
success = result.exit_status == 0
331352

332353
if not success:
@@ -363,7 +384,8 @@ async def resume_vm(self, vm_name: str) -> bool:
363384
_LOGGER.error("Cannot resume VM '%s' because it is not paused (current state: %s)", vm_name, current_state)
364385
return False
365386

366-
result = await self.execute_command(f'virsh resume "{vm_name}"')
387+
escaped_name = shlex.quote(vm_name)
388+
result = await self.execute_command(f'virsh resume {escaped_name}')
367389
success = result.exit_status == 0
368390

369391
if not success:
@@ -396,7 +418,8 @@ async def restart_vm(self, vm_name: str) -> bool:
396418
_LOGGER.error("Cannot restart VM '%s' because it is not running (current state: %s)", vm_name, current_state)
397419
return False
398420

399-
result = await self.execute_command(f'virsh reboot "{vm_name}"')
421+
escaped_name = shlex.quote(vm_name)
422+
result = await self.execute_command(f'virsh reboot {escaped_name}')
400423
success = result.exit_status == 0
401424

402425
if not success:
@@ -433,7 +456,8 @@ async def hibernate_vm(self, vm_name: str) -> bool:
433456
_LOGGER.error("Cannot hibernate VM '%s' because it is not running (current state: %s)", vm_name, current_state)
434457
return False
435458

436-
result = await self.execute_command(f'virsh dompmsuspend "{vm_name}" disk')
459+
escaped_name = shlex.quote(vm_name)
460+
result = await self.execute_command(f'virsh dompmsuspend {escaped_name} disk')
437461
success = result.exit_status == 0
438462

439463
if not success:
@@ -466,7 +490,8 @@ async def force_stop_vm(self, vm_name: str) -> bool:
466490
_LOGGER.info("VM '%s' is already shut off", vm_name)
467491
return True
468492

469-
result = await self.execute_command(f'virsh destroy "{vm_name}"')
493+
escaped_name = shlex.quote(vm_name)
494+
result = await self.execute_command(f'virsh destroy {escaped_name}')
470495
success = result.exit_status == 0
471496

472497
if not success:

custom_components/unraid/switch.py

Lines changed: 77 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,20 +150,94 @@ def __init__(
150150
self._vm_name = vm_name
151151
self._last_known_state = None
152152

153-
# Remove any leading numbers and spaces for the entity ID
154-
cleaned_name = ''.join(c for c in vm_name if not c.isdigit()).strip()
153+
# Create a safe entity ID from VM name with collision detection
154+
from .utils import normalize_name
155+
safe_name = normalize_name(vm_name)
156+
# Ensure it doesn't start with a number (Home Assistant requirement)
157+
if safe_name and safe_name[0].isdigit():
158+
safe_name = f"vm_{safe_name}"
159+
160+
# Handle entity ID collisions by checking existing VMs
161+
safe_name = self._ensure_unique_entity_id(safe_name, vm_name, coordinator)
155162

156163
super().__init__(
157164
coordinator,
158165
UnraidSwitchEntityDescription(
159-
key=f"vm_{cleaned_name}",
166+
key=f"vm_{safe_name}",
160167
name=f"{vm_name}",
161168
value_fn=self._get_vm_state,
162169
)
163170
)
164171
self._attr_entity_registry_enabled_default = True
165172

166173
# Get OS type for specific model info
174+
self._get_os_type_info(vm_name, coordinator)
175+
176+
def _ensure_unique_entity_id(self, base_name: str, vm_name: str, coordinator) -> str:
177+
"""Ensure the entity ID is unique by checking for collisions with existing VMs.
178+
179+
Args:
180+
base_name: The normalized base name for the entity ID
181+
vm_name: The original VM name
182+
coordinator: The coordinator containing VM data
183+
184+
Returns:
185+
A unique entity ID that won't collide with existing VMs
186+
"""
187+
from .utils import normalize_name
188+
189+
# Get all existing VM names from coordinator data
190+
existing_vms = coordinator.data.get("vms", [])
191+
existing_normalized_names = set()
192+
193+
for vm in existing_vms:
194+
existing_vm_name = vm.get("name", "")
195+
if existing_vm_name and existing_vm_name != vm_name: # Don't include current VM
196+
existing_normalized = normalize_name(existing_vm_name)
197+
if existing_normalized and existing_normalized[0].isdigit():
198+
existing_normalized = f"vm_{existing_normalized}"
199+
existing_normalized_names.add(existing_normalized)
200+
201+
# If no collision, return the base name
202+
if base_name not in existing_normalized_names:
203+
return base_name
204+
205+
# Handle collision by adding a suffix based on original VM name characteristics
206+
# Strategy: Use distinguishing characteristics from the original name
207+
208+
# Try to create a unique suffix based on the original name
209+
import re
210+
211+
# Extract unique characteristics from the original VM name
212+
# 1. Check for different separators (dash vs space vs underscore)
213+
if '-' in vm_name and ' ' not in vm_name:
214+
suffix = "dash"
215+
elif ' ' in vm_name and '-' not in vm_name:
216+
suffix = "space"
217+
elif '_' in vm_name:
218+
suffix = "underscore"
219+
else:
220+
# 2. Use position-based numbering or character-based differentiation
221+
# Extract any numbers or special patterns
222+
numbers = re.findall(r'\d+', vm_name)
223+
if numbers:
224+
suffix = f"n{numbers[-1]}" # Use last number found
225+
else:
226+
# 3. Use length or hash-based suffix as last resort
227+
suffix = f"len{len(vm_name)}"
228+
229+
candidate_name = f"{base_name}_{suffix}"
230+
231+
# If still collision, fall back to incremental numbering
232+
counter = 2
233+
while candidate_name in existing_normalized_names:
234+
candidate_name = f"{base_name}_{suffix}_{counter}"
235+
counter += 1
236+
237+
return candidate_name
238+
239+
def _get_os_type_info(self, vm_name: str, coordinator) -> None:
240+
"""Get OS type for specific model info."""
167241
for vm in coordinator.data.get("vms", []):
168242
if vm["name"] == vm_name and "os_type" in vm:
169243
self._attr_device_info["model"] = f"{vm['os_type'].capitalize()} Virtual Machine"

0 commit comments

Comments
 (0)