Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/reference/openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,8 @@ components:
request_id:
title: Request Id
type: string
result:
title: Result
task:
$ref: '#/components/schemas/Task'
task_id:
Expand Down
2 changes: 1 addition & 1 deletion src/blueapi/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ class BlueskyContext:
configuration: InitVar[ApplicationConfig | None] = None

run_engine: RunEngine = field(
default_factory=lambda: RunEngine(context_managers=[])
default_factory=lambda: RunEngine(context_managers=[], call_returns_result=True)
)
tiled_conf: TiledConfig | None = field(default=None, init=False, repr=False)
numtracker: NumtrackerClient | None = field(default=None, init=False, repr=False)
Expand Down
22 changes: 21 additions & 1 deletion src/blueapi/utils/base_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
from pydantic import BaseModel, ConfigDict
import logging
from typing import Annotated, Any

from pydantic import (
BaseModel,
ConfigDict,
PlainSerializer,
TypeAdapter,
)

logger = logging.getLogger(__name__)
# Pydantic config for blueapi API models with common config.
BlueapiModelConfig = ConfigDict(
extra="forbid",
Expand All @@ -16,6 +25,17 @@
)


def _safe_serialize(value: Any) -> Any:
"""Try serializing but skip any type that pydantic can't handle"""
try:
return TypeAdapter(type(value)).dump_python(value, mode="json")
except Exception:
logger.warning("Type '%s' not serializable: %s", type(value), value)


NoneFallback = Annotated[Any, PlainSerializer(_safe_serialize)]


class BlueapiBaseModel(BaseModel):
"""
Base class for blueapi API models.
Expand Down
5 changes: 5 additions & 0 deletions src/blueapi/worker/event.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from collections.abc import Mapping
from enum import Enum

Expand All @@ -6,11 +7,14 @@
from super_state_machine.extras import PropertyMachine, ProxyString

from blueapi.utils import BlueapiBaseModel
from blueapi.utils.base_model import NoneFallback

# The RunEngine can return any of these three types as its state
# RawRunEngineState = type[PropertyMachine | ProxyString | str]
RawRunEngineState = PropertyMachine | ProxyString | str

log = logging.getLogger(__name__)


# NOTE this is interim until refactor
class TaskStatusEnum(str, Enum):
Expand Down Expand Up @@ -109,6 +113,7 @@ class TaskStatus(BlueapiBaseModel):
task_id: str
task_complete: bool
task_failed: bool
result: NoneFallback = None


class WorkerEvent(BlueapiBaseModel):
Expand Down
5 changes: 4 additions & 1 deletion src/blueapi/worker/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ def do_task(self, ctx: BlueskyContext) -> None:
func = ctx.plan_functions[self.name]
prepared_params = self.prepare_params(ctx)
ctx.run_engine.md.update(self.metadata)
ctx.run_engine(func(**prepared_params))
result = ctx.run_engine(func(**prepared_params))
if isinstance(result, tuple):
return None
return result.plan_result


def _lookup_params(ctx: BlueskyContext, task: Task) -> BaseModel:
Expand Down
7 changes: 5 additions & 2 deletions src/blueapi/worker/task_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
)
from blueapi.core.bluesky_event_loop import configure_bluesky_event_loop
from blueapi.log import plan_tag_filter_context
from blueapi.utils.base_model import BlueapiBaseModel
from blueapi.utils.base_model import BlueapiBaseModel, NoneFallback
from blueapi.utils.thread_exception import handle_all_exceptions

from .event import (
Expand Down Expand Up @@ -69,6 +69,7 @@ class TrackableTask(BlueapiBaseModel):
is_complete: bool = False
is_pending: bool = True
errors: list[str] = Field(default_factory=list)
result: NoneFallback = None


class TaskWorker:
Expand Down Expand Up @@ -427,7 +428,8 @@ def process_task():
LOGGER.info(f"Got new task: {next_task}")
self._current = next_task
self._current.is_pending = False
self._current.task.do_task(self._ctx)
result = self._current.task.do_task(self._ctx)
self._current.result = result

with plan_tag_filter_context(next_task.task.name, LOGGER):
if self._current_task_otel_context is not None:
Expand Down Expand Up @@ -528,6 +530,7 @@ def _report_status(
task_id=self._current.task_id,
task_complete=self._current.is_complete,
task_failed=bool(self._current.errors),
result=self._current.result,
)
correlation_id = self._current.task_id
add_span_attributes(
Expand Down
5 changes: 5 additions & 0 deletions tests/unit_tests/service/test_rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ def test_get_tasks(mock_runner: Mock, client: TestClient) -> None:
"params": {"time": 0.0},
"metadata": {},
},
"result": None,
"task_id": "0",
},
{
Expand All @@ -348,6 +349,7 @@ def test_get_tasks(mock_runner: Mock, client: TestClient) -> None:
"params": {},
"metadata": {},
},
"result": None,
"task_id": "1",
},
]
Expand Down Expand Up @@ -379,6 +381,7 @@ def test_get_tasks_by_status(mock_runner: Mock, client: TestClient) -> None:
"params": {},
"metadata": {},
},
"result": None,
"task_id": "3",
}
]
Expand Down Expand Up @@ -472,6 +475,7 @@ def test_get_task(mock_runner: Mock, client: TestClient):
"foo": "bar",
},
},
"result": None,
"task_id": f"{task_id}",
}

Expand Down Expand Up @@ -500,6 +504,7 @@ def test_get_all_tasks(mock_runner: Mock, client: TestClient):
"is_complete": False,
"is_pending": True,
"request_id": None,
"result": None,
"errors": [],
}
]
Expand Down
9 changes: 6 additions & 3 deletions tests/unit_tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -929,9 +929,12 @@ def test_event_formatting():
OutputFormat.JSON,
worker,
(
"""{"state": "RUNNING", "task_status": """
"""{"task_id": "count", "task_complete": false, "task_failed": false}, """
""""errors": [], "warnings": []}\n"""
'{"state": "RUNNING", "task_status": {'
'"task_id": "count", '
'"task_complete": false, '
'"task_failed": false, '
'"result": null'
'}, "errors": [], "warnings": []}\n'
),
)
_assert_matching_formatting(OutputFormat.COMPACT, worker, "Worker Event: RUNNING\n")
Expand Down
Loading