diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index c5b41ff45..417d7df01 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -6,7 +6,7 @@ from functools import cached_property from itertools import chain from pathlib import Path -from typing import Self +from typing import Any, Self from bluesky_stomp.messaging import MessageContext, StompClient from bluesky_stomp.models import Broker @@ -38,7 +38,7 @@ ) from blueapi.utils import deprecated from blueapi.worker import WorkerEvent, WorkerState -from blueapi.worker.event import ProgressEvent, TaskStatus +from blueapi.worker.event import ProgressEvent, TaskError, TaskResult, TaskStatus from blueapi.worker.task_worker import TrackableTask from .event_bus import AnyEvent, EventBusClient, OnAnyEvent @@ -141,13 +141,17 @@ def __init__(self, name, model: PlanModel, client: "BlueapiClient"): self._client = client self.__doc__ = model.description - def __call__(self, *args, **kwargs): + def __call__(self, *args, **kwargs) -> Any: req = TaskRequest( name=self.name, params=self._build_args(*args, **kwargs), instrument_session=self._client.instrument_session, ) - self._client.run_task(req) + match self._client.run_task(req): + case TaskStatus(result=TaskResult(result=res)): + return res + case TaskStatus(result=TaskError(type=typ, message=msg)): + raise PlanFailedError(typ, msg) @property def help_text(self) -> str: @@ -744,3 +748,9 @@ def login(self, token_path: Path | None = None): auth.start_device_flow() else: print("Server is not configured to use authentication!") + + +class PlanFailedError(Exception): + def __init__(self, typ: str, message: str): + super().__init__(message) + self._type = typ diff --git a/tests/unit_tests/client/test_client.py b/tests/unit_tests/client/test_client.py index a96f428e8..1cff107c1 100644 --- a/tests/unit_tests/client/test_client.py +++ b/tests/unit_tests/client/test_client.py @@ -17,6 +17,7 @@ MissingInstrumentSessionError, Plan, PlanCache, + PlanFailedError, ) from blueapi.client.event_bus import AnyEvent, EventBusClient from blueapi.client.rest import BlueapiRestClient, BlueskyRemoteControlError @@ -512,6 +513,39 @@ def callback(on_event: Callable[[AnyEvent, MessageContext], None]): mock_on_event.assert_called_once_with(COMPLETE_EVENT) +def test_scripting_interface_returns_result(): + client = Mock(spec=BlueapiClient, instrument_session="cm12345-1") + client.run_task.return_value = TaskStatus( + task_id="foobar", + task_complete=True, + task_failed=False, + result=TaskResult(result=42, type="int"), + ) + demo_plan = Plan( + "demo", + client=client, + model=PlanModel(name="demo", description="Demo plan", schema={}), + ) + assert demo_plan() == 42 + + +def test_scripting_interface_raises_exceptions(): + client = Mock(spec=BlueapiClient, instrument_session="cm12345-1") + client.run_task.return_value = TaskStatus( + task_id="foobar", + task_complete=True, + task_failed=True, + result=TaskError(type="ValueError", message="Plan failed"), + ) + demo_plan = Plan( + "demo", + client=client, + model=PlanModel(name="demo", description="Demo plan", schema={}), + ) + with pytest.raises(PlanFailedError, match="Plan failed"): + demo_plan() + + def test_oidc_config_property(client, mock_rest): assert client.oidc_config == mock_rest.get_oidc_config()