1515from homeassistant .core import HomeAssistant
1616from homeassistant .helpers .entity_platform import AddEntitiesCallback
1717from homeassistant .helpers .entity import generate_entity_id
18+ from homeassistant .helpers .restore_state import RestoreEntity
19+
1820from .const import (
1921 CONF_API_KEY ,
2022 CONF_MODEL ,
@@ -73,7 +75,7 @@ async def async_setup_entry(
7375 async_add_entities ([OpenAITTSEntity (hass , config_entry , engine )])
7476
7577
76- class OpenAITTSEntity (TextToSpeechEntity ):
78+ class OpenAITTSEntity (TextToSpeechEntity , RestoreEntity ):
7779 _attr_has_entity_name = True
7880 _attr_should_poll = False
7981
@@ -93,6 +95,41 @@ def __init__(self, hass: HomeAssistant, config: ConfigEntry, engine: OpenAITTSEn
9395 self ._last_total_time = None
9496 self ._last_media_duration_ms = None # Store in milliseconds
9597
98+ async def async_added_to_hass (self ) -> None :
99+ """When entity is added to hass, restore previous state."""
100+ await super ().async_added_to_hass ()
101+
102+ # Restore previous state if it exists
103+ last_state = await self .async_get_last_state ()
104+
105+ if last_state is not None and last_state .attributes :
106+ # Restore from attributes
107+ self ._engine_active = last_state .attributes .get ("engine_active" , False )
108+
109+ # Restore time values
110+ api_time_str = last_state .attributes .get ("last_api_time" )
111+ if api_time_str and " msec" in api_time_str :
112+ self ._last_api_time = int (api_time_str .replace (" msec" , "" ))
113+
114+ ffmpeg_time_str = last_state .attributes .get ("last_ffmpeg_time" )
115+ if ffmpeg_time_str and " msec" in ffmpeg_time_str :
116+ self ._last_ffmpeg_time = int (ffmpeg_time_str .replace (" msec" , "" ))
117+
118+ total_time_str = last_state .attributes .get ("last_total_time" )
119+ if total_time_str and " msec" in total_time_str :
120+ self ._last_total_time = int (total_time_str .replace (" msec" , "" ))
121+
122+ # Restore media duration directly (stored as raw milliseconds)
123+ self ._last_media_duration_ms = last_state .attributes .get ("media_duration" )
124+
125+ _LOGGER .debug (
126+ "Restored OpenAI TTS entity state: api_time=%s, ffmpeg_time=%s, total_time=%s, media_duration=%s" ,
127+ self ._last_api_time ,
128+ self ._last_ffmpeg_time ,
129+ self ._last_total_time ,
130+ self ._last_media_duration_ms
131+ )
132+
96133 @property
97134 def default_language (self ) -> str :
98135 return "en"
@@ -266,6 +303,10 @@ def get_tts_audio(
266303 # Compute media duration in milliseconds before cleaning up.
267304 duration_seconds = get_media_duration (merged_output_path )
268305 self ._last_media_duration_ms = int (duration_seconds * 1000 )
306+
307+ # DO NOT call self.async_write_ha_state() here - thread safety issue
308+ # It will be called from the async_get_tts_audio method
309+
269310 # Cleanup temporary files.
270311 try :
271312 os .remove (tts_path )
@@ -306,6 +347,10 @@ def get_tts_audio(
306347 # Compute media duration in milliseconds for the normalized file.
307348 duration_seconds = get_media_duration (norm_output_path )
308349 self ._last_media_duration_ms = int (duration_seconds * 1000 )
350+
351+ # DO NOT call self.async_write_ha_state() here - thread safety issue
352+ # It will be called from the async_get_tts_audio method
353+
309354 try :
310355 os .remove (norm_input_path )
311356 os .remove (norm_output_path )
@@ -328,6 +373,10 @@ def get_tts_audio(
328373 _LOGGER .debug ("Overall TTS processing time: %.2f ms" , overall_duration )
329374 self ._last_total_time = overall_duration
330375 self ._last_ffmpeg_time = 0 # No ffmpeg processing used.
376+
377+ # DO NOT call self.async_write_ha_state() here - thread safety issue
378+ # It will be called from the async_get_tts_audio method
379+
331380 return "mp3" , audio_content
332381
333382 except CancelledError as ce :
@@ -347,11 +396,17 @@ async def async_get_tts_audio(
347396 try :
348397 self ._engine_active = True
349398 self .async_write_ha_state ()
350- return await asyncio .shield (
399+
400+ result = await asyncio .shield (
351401 self .hass .async_add_executor_job (
352402 partial (self .get_tts_audio , message , language , options = options )
353403 )
354404 )
405+
406+ # Update the entity state from within the event loop
407+ self .async_write_ha_state ()
408+
409+ return result
355410 except asyncio .CancelledError :
356411 _LOGGER .exception ("async_get_tts_audio cancelled" )
357412 raise
0 commit comments