diff --git a/pyproject.toml b/pyproject.toml index bc63ec9210..9d38e5fbda 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ dependencies = [ "pyjwt[crypto]", "tomlkit", "graypy>=2.1.0", + "jinja2>=3.1.6", ] dynamic = ["version"] license.file = "LICENSE" diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index c39ff13dff..523732565e 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -17,7 +17,6 @@ from click.exceptions import ClickException from observability_utils.tracing import setup_tracing from pydantic import ValidationError -from requests.exceptions import ConnectionError from blueapi import __version__, config from blueapi.cli.format import OutputFormat @@ -26,6 +25,7 @@ from blueapi.client.rest import ( BlueskyRemoteControlError, InvalidParametersError, + ServiceUnavailableError, UnauthorisedAccessError, UnknownPlanError, ) @@ -36,9 +36,10 @@ from blueapi.core import OTLP_EXPORT_ENABLED, DataEvent from blueapi.log import set_up_logging from blueapi.service.authentication import SessionCacheManager, SessionManager -from blueapi.service.model import SourceInfo, TaskRequest +from blueapi.service.model import DeviceResponse, PlanResponse, SourceInfo, TaskRequest from blueapi.worker import ProgressEvent, WorkerEvent +from . import stubgen from .scratch import setup_scratch from .updates import CliEventRenderer @@ -152,6 +153,23 @@ def start_application(obj: dict): start(config) +@main.command() +@click.pass_obj +@click.argument("target", type=click.Path(file_okay=False)) +def generate_stubs(obj: dict, target: Path): + """ + Generate a type-stubs project for blueapi for the currently running server. + This enables users using blueapi as a library to benefit from type checking + and linting when writing scripts against the BlueapiClient. + """ + click.echo(f"Writing stubs to {target}") + + config: ApplicationConfig = obj["config"] + bc = BlueapiClient.from_config(config) + + stubgen.generate_stubs(Path(target), list(bc.plans), list(bc.devices)) + + @main.group() @click.option( "-o", @@ -183,7 +201,7 @@ def check_connection(func: Callable[P, T]) -> Callable[P, T]: def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: try: return func(*args, **kwargs) - except ConnectionError as ce: + except ServiceUnavailableError as ce: raise ClickException( "Failed to establish connection to blueapi server." ) from ce @@ -204,7 +222,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: def get_plans(obj: dict) -> None: """Get a list of plans available for the worker to use""" client: BlueapiClient = obj["client"] - obj["fmt"].display(client.get_plans()) + obj["fmt"].display(PlanResponse(plans=[p.model for p in client.plans])) @controller.command(name="devices") @@ -213,7 +231,7 @@ def get_plans(obj: dict) -> None: def get_devices(obj: dict) -> None: """Get a list of devices available for the worker to use""" client: BlueapiClient = obj["client"] - obj["fmt"].display(client.get_devices()) + obj["fmt"].display(DeviceResponse(devices=[dev.model for dev in client.devices])) @controller.command(name="listen") @@ -345,7 +363,7 @@ def get_state(obj: dict) -> None: """Print the current state of the worker""" client: BlueapiClient = obj["client"] - print(client.get_state().name) + print(client.state.name) @controller.command(name="pause") @@ -428,7 +446,7 @@ def env( status = client.reload_environment(timeout=timeout) print("Environment is initialized") else: - status = client.get_environment() + status = client.environment print(status) @@ -470,14 +488,13 @@ def login(obj: dict) -> None: print("Logged in") except Exception: client = BlueapiClient.from_config(config) - oidc_config = client.get_oidc_config() - if oidc_config is None: + if oidc := client.oidc_config: + auth = SessionManager( + oidc, cache_manager=SessionCacheManager(config.auth_token_path) + ) + auth.start_device_flow() + else: print("Server is not configured to use authentication!") - return - auth = SessionManager( - oidc_config, cache_manager=SessionCacheManager(config.auth_token_path) - ) - auth.start_device_flow() @main.command(name="logout") diff --git a/src/blueapi/cli/format.py b/src/blueapi/cli/format.py index 490e57cfa5..7ce4fffbb0 100644 --- a/src/blueapi/cli/format.py +++ b/src/blueapi/cli/format.py @@ -12,7 +12,9 @@ from blueapi.core.bluesky_types import DataEvent from blueapi.service.model import ( + DeviceModel, DeviceResponse, + PlanModel, PlanResponse, PythonEnvironmentResponse, SourceInfo, @@ -54,17 +56,21 @@ def display_full(obj: Any, stream: Stream): match obj: case PlanResponse(plans=plans): for plan in plans: - print(plan.name) - if desc := plan.description: - print(indent(dedent(desc).strip(), " ")) - if schema := plan.parameter_schema: - print(" Schema") - print(indent(json.dumps(schema, indent=2), " ")) + display_full(plan, stream) + case PlanModel(name=name, description=desc, parameter_schema=schema): + print(name) + if desc: + print(indent(dedent(desc).strip(), " ")) + if schema: + print(" Schema") + print(indent(json.dumps(schema, indent=2), " ")) case DeviceResponse(devices=devices): for dev in devices: - print(dev.name) - for proto in dev.protocols: - print(f" {proto}") + display_full(dev, stream) + case DeviceModel(name=name, protocols=protocols): + print(name) + for proto in protocols: + print(f" {proto}") case DataEvent(name=name, doc=doc): print(f"{name.title()}:{fmt_dict(doc)}") case WorkerEvent(state=st, task_status=task): @@ -100,11 +106,13 @@ def display_json(obj: Any, stream: Stream): print = partial(builtins.print, file=stream) match obj: case PlanResponse(plans=plans): - print(json.dumps([p.model_dump() for p in plans], indent=2)) + display_json(plans, stream) case DeviceResponse(devices=devices): - print(json.dumps([d.model_dump() for d in devices], indent=2)) + display_json(devices, stream) case BaseModel(): print(json.dumps(obj.model_dump())) + case list(): + print(json.dumps([it.model_dump() for it in obj], indent=2)) case _: print(json.dumps(obj)) @@ -114,26 +122,30 @@ def display_compact(obj: Any, stream: Stream): match obj: case PlanResponse(plans=plans): for plan in plans: - print(plan.name) - if desc := plan.description: - print(indent(dedent(desc.split("\n\n")[0].strip("\n")), " ")) - if schema := plan.parameter_schema: - print(" Args") - for arg, spec in schema.get("properties", {}).items(): - req = arg in schema.get("required", {}) - print(f" {arg}={_describe_type(spec, req)}") + display_compact(plan, stream) + case PlanModel(name=name, description=desc, parameter_schema=schema): + print(name) + if desc: + print(indent(dedent(desc.split("\n\n")[0].strip("\n")), " ")) + if schema: + print(" Args") + for arg, spec in schema.get("properties", {}).items(): + req = arg in schema.get("required", {}) + print(f" {arg}={_describe_type(spec, req)}") case DeviceResponse(devices=devices): for dev in devices: - print(dev.name) - print( - indent( - textwrap.fill( - ", ".join(str(proto) for proto in dev.protocols), - 80, - ), - " ", - ) + display_compact(dev, stream) + case DeviceModel(name=name, protocols=protocols): + print(name) + print( + indent( + textwrap.fill( + ", ".join(str(proto) for proto in protocols), + 80, + ), + " ", ) + ) case DataEvent(name=name): print(f"Data Event: {name}") case WorkerEvent(state=state): diff --git a/src/blueapi/cli/stubgen.py b/src/blueapi/cli/stubgen.py new file mode 100644 index 0000000000..6f6fbf4bd6 --- /dev/null +++ b/src/blueapi/cli/stubgen.py @@ -0,0 +1,117 @@ +import logging +from dataclasses import dataclass +from inspect import cleandoc +from pathlib import Path +from textwrap import dedent +from typing import Self, TextIO + +from jinja2 import Environment, PackageLoader + +from blueapi.client.cache import DeviceRef, Plan +from blueapi.core import context +from blueapi.core.bluesky_types import BLUESKY_PROTOCOLS + +log = logging.getLogger(__name__) + + +@dataclass +class ArgSpec: + name: str + type: str + optional: bool + + +@dataclass +class PlanSpec: + name: str + docs: str + args: list[ArgSpec] + + @classmethod + def from_plan(cls, plan: Plan) -> Self: + req = set(plan.required) + args = [ + ArgSpec(arg, _type_string(spec), arg not in req) + for arg, spec in plan.model.parameter_schema.get("properties", {}).items() + ] + return cls(plan.name, plan.help_text, args) + + +BLUESKY_PROTOCOL_NAMES = {context.qualified_name(proto) for proto in BLUESKY_PROTOCOLS} + + +def _type_string(spec) -> str: + """Best effort attempt at making useful type hints for plans""" + match spec.get("type"): + case "array": + return f"list[{_type_string(spec.get('items'))}]" + case "integer": + return "int" + case "number": + return "float" + case proto if proto in BLUESKY_PROTOCOL_NAMES: + return "DeviceRef" + case "object": + return "dict[str, Any]" + case "string": + return "str" + case "boolean": + return "bool" + case None if opts := spec.get("anyOf"): + return " | ".join(_type_string(opt) for opt in opts) + case _: + return "Any" + + +def generate_stubs(target: Path, plans: list[Plan], devices: list[DeviceRef]): + log.info("Generating stubs for %d plans and %d devices", len(plans), len(devices)) + target.mkdir(parents=True, exist_ok=True) + client_dir = target / "src" / "blueapi-stubs" / "client" + + log.debug("Making project structure: %s", client_dir) + client_dir.mkdir(parents=True, exist_ok=True) + + stub_file = client_dir / "cache.pyi" + project_file = target / "pyproject.toml" + py_typed = target / "src" / "blueapi-stubs" / "py.typed" + + log.debug("Writing pyproject.toml to %s", project_file) + with open(project_file, "w") as out: + out.write( + dedent(""" + [project] + name = "blueapi-stubs" + version = "0.1.0" + description = "Generated client stubs for a running server" + readme = "README.md" + requires-python = ">=3.11" + + dependencies = [ + "blueapi" + ] + """) + ) + + log.debug("Writing py.typed file to %s", py_typed) + with open(py_typed, "w") as out: + out.write("partial\n") + + log.debug("Writing stub file to %s", stub_file) + with open(stub_file, "w") as out: + render_stub_file(out, plans, devices) + + +def _docstring(text: str) -> str: + # """Convert a docstring to a format that can be inserted into the template""" + return cleandoc(text).replace('"""', '\\"""') + + +def render_stub_file( + stub_file: TextIO, plan_models: list[Plan], devices: list[DeviceRef] +): + plans = [PlanSpec.from_plan(p) for p in plan_models] + + env = Environment(loader=PackageLoader("blueapi", package_path="stubs/templates")) + env.filters["docstring"] = _docstring + tmpl = env.get_template("cache_template.pyi.jinja") + stub_file.write(tmpl.render(plans=plans, devices=devices)) diff --git a/src/blueapi/client/cache.py b/src/blueapi/client/cache.py new file mode 100644 index 0000000000..0ec8c4c87c --- /dev/null +++ b/src/blueapi/client/cache.py @@ -0,0 +1,177 @@ +import logging +from collections.abc import Callable +from itertools import chain +from typing import Any + +from blueapi.client.rest import BlueapiRestClient +from blueapi.service.model import DeviceModel, PlanModel +from blueapi.worker.event import WorkerEvent + +log = logging.getLogger(__name__) + + +# This file should be kept in sync with the type stub template in stubs/templates + + +PlanRunner = Callable[[str, dict[str, Any]], WorkerEvent] + + +class PlanCache: + """ + Cache of plans available on the server + """ + + def __init__(self, runner: PlanRunner, plans: list[PlanModel]): + self._cache = {model.name: Plan(model=model, runner=runner) for model in plans} + for name, plan in self._cache.items(): + if name.startswith("_"): + continue + setattr(self, name, plan) + + def __getitem__(self, name: str) -> "Plan": + return self._cache[name] + + def __getattr__(self, name: str) -> "Plan": + raise AttributeError(f"No plan named '{name}' available") + + def __iter__(self): + return iter(self._cache.values()) + + def __repr__(self) -> str: + return f"PlanCache({len(self._cache)} plans)" + + +class Plan: + """ + An interface to a plan on the blueapi server + + This allows remote plans to be called (mostly) as if they were local + methods when writing user scripts. + + If you are seeing this help while using blueapi as a library, generating + type stubs may be helpful for type checking and plan discovery, eg + + blueapi generate-stubs /tmp/blueapi-stubs + uv add --editable /tmp/blueapi-stubs + + """ + + model: PlanModel + + def __init__(self, model: PlanModel, runner: PlanRunner): + self.model = model + self._runner = runner + self.__doc__ = model.description + + def __call__(self, *args, **kwargs) -> WorkerEvent: + """ + Run the plan on the server mapping the given args into the required parameters + """ + return self._runner(self.name, self._build_args(*args, **kwargs)) + + @property + def name(self) -> str: + return self.model.name + + @property + def help_text(self) -> str: + return self.model.description or f"Plan {self!r}" + + @property + def properties(self) -> set[str]: + return self.model.parameter_schema.get("properties", {}).keys() + + @property + def required(self) -> list[str]: + return self.model.parameter_schema.get("required", []) + + def _build_args(self, *args, **kwargs): + log.info( + "Building args for %s, using %s and %s", + "[" + ",".join(self.properties) + "]", + args, + kwargs, + ) + + if len(args) > len(self.properties): + raise TypeError(f"{self.name} got too many arguments") + if extra := {k for k in kwargs if k not in self.properties}: + raise TypeError(f"{self.name} got unexpected arguments: {extra}") + + params = {} + # Initially fill parameters using positional args assuming the order + # from the parameter_schema + for req, arg in zip(self.properties, args, strict=False): + params[req] = arg + + # Then append any values given via kwargs + for key, value in kwargs.items(): + # If we've already assumed a positional arg was this value, bail out + if key in params: + raise TypeError(f"{self.name} got multiple values for {key}") + params[key] = value + + if missing := {k for k in self.required if k not in params}: + raise TypeError(f"Missing argument(s) for {missing}") + return params + + def __repr__(self): + opts = [p for p in self.properties if p not in self.required] + params = ", ".join(chain(self.required, (f"{opt}=None" for opt in opts))) + return f"{self.name}({params})" + + +class DeviceCache: + def __init__(self, rest: BlueapiRestClient): + self._rest = rest + self._cache = { + model.name: DeviceRef(name=model.name, cache=self, model=model) + for model in rest.get_devices().devices + } + for name, device in self._cache.items(): + if name.startswith("_"): + continue + setattr(self, name, device) + + def __getitem__(self, name: str) -> "DeviceRef": + if dev := self._cache.get(name): + return dev + try: + model = self._rest.get_device(name) + device = DeviceRef(name=name, cache=self, model=model) + self._cache[name] = device + setattr(self, model.name, device) + return device + except KeyError: + pass + raise AttributeError(f"No device named '{name}' available") + + def __getattr__(self, name: str) -> "DeviceRef": + if name.startswith("_"): + return super().__getattribute__(name) + return self[name] + + def __iter__(self): + return iter(self._cache.values()) + + def __repr__(self) -> str: + return f"DeviceCache({len(self._cache)} devices)" + + +class DeviceRef(str): + model: DeviceModel + _cache: DeviceCache + + def __new__(cls, name: str, cache: DeviceCache, model: DeviceModel): + instance = super().__new__(cls, name) + instance.model = model + instance._cache = cache + return instance + + def __getattr__(self, name) -> "DeviceRef": + if name.startswith("_"): + raise AttributeError(f"No child device named {name}") + return self._cache[f"{self}.{name}"] + + def __repr__(self): + return f"Device({self})" diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index 0930e240a9..39fbf50540 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -1,5 +1,11 @@ +import itertools +import logging import time +from collections.abc import Iterable from concurrent.futures import Future +from functools import cached_property +from pathlib import Path +from typing import Any, Self from bluesky_stomp.messaging import MessageContext, StompClient from bluesky_stomp.models import Broker @@ -8,37 +14,47 @@ start_as_current_span, ) -from blueapi.config import ApplicationConfig, MissingStompConfigurationError +from blueapi.config import ( + ApplicationConfig, + ConfigLoader, + MissingStompConfigurationError, +) from blueapi.core.bluesky_types import DataEvent from blueapi.service.authentication import SessionManager from blueapi.service.model import ( - DeviceModel, - DeviceResponse, EnvironmentResponse, OIDCConfig, - PlanModel, - PlanResponse, PythonEnvironmentResponse, SourceInfo, TaskRequest, TaskResponse, - TasksListResponse, WorkerTask, ) -from blueapi.worker import TrackableTask, WorkerEvent, WorkerState +from blueapi.worker import WorkerEvent, WorkerState from blueapi.worker.event import ProgressEvent, TaskStatus +from .cache import DeviceCache, PlanCache from .event_bus import AnyEvent, BlueskyStreamingError, EventBusClient, OnAnyEvent from .rest import BlueapiRestClient, BlueskyRemoteControlError TRACER = get_tracer("client") +log = logging.getLogger(__name__) + + +class MissingInstrumentSessionError(Exception): + pass + + class BlueapiClient: """Unified client for controlling blueapi""" _rest: BlueapiRestClient _events: EventBusClient | None + _instrument_session: str | None = None + _callbacks: dict[int, OnAnyEvent] + _callback_id: itertools.count def __init__( self, @@ -47,9 +63,30 @@ def __init__( ): self._rest = rest self._events = events + self._callbacks = {} + self._callback_id = itertools.count() + + @cached_property + @start_as_current_span(TRACER) + def plans(self) -> PlanCache: + return PlanCache(self.run_plan, self._rest.get_plans().plans) + + @cached_property + @start_as_current_span(TRACER) + def devices(self) -> DeviceCache: + return DeviceCache(self._rest) @classmethod - def from_config(cls, config: ApplicationConfig) -> "BlueapiClient": + def from_config_file(cls, config_file: str) -> Self: + conf = ConfigLoader(ApplicationConfig) + conf.use_values_from_yaml(Path(config_file)) + return cls.from_config(conf.load()) + + @classmethod + def from_config( + cls, + config: ApplicationConfig, + ) -> Self: session_manager: SessionManager | None = None try: session_manager = SessionManager.from_cache(config.auth_token_path) @@ -71,56 +108,36 @@ def from_config(cls, config: ApplicationConfig) -> "BlueapiClient": else: return cls(rest) - @start_as_current_span(TRACER) - def get_plans(self) -> PlanResponse: - """ - List plans available + @property + def instrument_session(self) -> str: + if self._instrument_session is None: + raise MissingInstrumentSessionError() + return self._instrument_session - Returns: - PlanResponse: Plans that can be run - """ - return self._rest.get_plans() + @instrument_session.setter + def instrument_session(self, session: str): + log.debug("Setting instrument_session to %s", session) + self._instrument_session = session - @start_as_current_span(TRACER, "name") - def get_plan(self, name: str) -> PlanModel: - """ - Get details of a single plan + def with_instrument_session(self, session: str) -> Self: + self.instrument_session = session + return self - Args: - name: Plan name + def add_callback(self, callback: OnAnyEvent) -> int: + cb_id = next(self._callback_id) + self._callbacks[cb_id] = callback + return cb_id - Returns: - PlanModel: Details of the plan if found - """ - return self._rest.get_plan(name) + def remove_callback(self, id: int): + self._callbacks.pop(id) - @start_as_current_span(TRACER) - def get_devices(self) -> DeviceResponse: - """ - List devices available - - Returns: - DeviceResponse: Devices that can be used in plans - """ - - return self._rest.get_devices() - - @start_as_current_span(TRACER, "name") - def get_device(self, name: str) -> DeviceModel: - """ - Get details of a single device - - Args: - name: Device name - - Returns: - DeviceModel: Details of the device if found - """ - - return self._rest.get_device(name) + @property + def callbacks(self) -> Iterable[OnAnyEvent]: + return self._callbacks.values() + @property @start_as_current_span(TRACER) - def get_state(self) -> WorkerState: + def state(self) -> WorkerState: """ Get current state of the blueapi worker @@ -158,33 +175,9 @@ def resume(self) -> WorkerState: return self._rest.set_state(WorkerState.RUNNING, defer=False) - @start_as_current_span(TRACER, "task_id") - def get_task(self, task_id: str) -> TrackableTask: - """ - Get a task stored by the worker - - Args: - task_id: Unique ID for the task - - Returns: - TrackableTask: Task details - """ - assert task_id, "Task ID not provided!" - return self._rest.get_task(task_id) - - @start_as_current_span(TRACER) - def get_all_tasks(self) -> TasksListResponse: - """ - Get a list of all task stored by the worker - - Returns: - TasksListResponse: List of all Trackable Task - """ - - return self._rest.get_all_tasks() - + @property @start_as_current_span(TRACER) - def get_active_task(self) -> WorkerTask: + def active_task(self) -> WorkerTask: """ Get the currently active task, if any @@ -195,6 +188,15 @@ def get_active_task(self) -> WorkerTask: return self._rest.get_active_task() + @start_as_current_span(TRACER, "name", "params") + def run_plan(self, name: str, params: dict[str, Any]) -> WorkerEvent: + req = TaskRequest( + name=name, + params=params, + instrument_session=self.instrument_session, + ) + return self.run_task(req) + @start_as_current_span(TRACER, "task", "timeout") def run_task( self, @@ -221,7 +223,7 @@ def run_task( "Stomp configuration required to run plans is missing or disabled" ) - task_response = self.create_task(task) + task_response = self._rest.create_task(task) task_id = task_response.task_id complete: Future[WorkerEvent] = Future() @@ -239,6 +241,13 @@ def inner_on_event(event: AnyEvent, ctx: MessageContext) -> None: if relates_to_task: if on_event is not None: on_event(event) + for cb in self._callbacks.values(): + try: + cb(event) + except Exception as e: + log.error( + f"Callback ({cb}) failed for event: {event}", exc_info=e + ) if isinstance(event, WorkerEvent) and ( (event.is_complete()) and (ctx.correlation_id == task_id) ): @@ -255,7 +264,7 @@ def inner_on_event(event: AnyEvent, ctx: MessageContext) -> None: with self._events: self._events.subscribe_to_all_events(inner_on_event) - self.start_task(WorkerTask(task_id=task_id)) + self._rest.update_worker_task(WorkerTask(task_id=task_id)) return complete.result(timeout=timeout) @start_as_current_span(TRACER, "task") @@ -271,8 +280,10 @@ def create_and_start_task(self, task: TaskRequest) -> TaskResponse: TaskResponse: Acknowledgement of request """ - response = self.create_task(task) - worker_response = self.start_task(WorkerTask(task_id=response.task_id)) + response = self._rest.create_task(task) + worker_response = self._rest.update_worker_task( + WorkerTask(task_id=response.task_id) + ) if worker_response.task_id == response.task_id: return response else: @@ -281,48 +292,6 @@ def create_and_start_task(self, task: TaskRequest) -> TaskResponse: f"but {worker_response.task_id} was started instead" ) - @start_as_current_span(TRACER, "task") - def create_task(self, task: TaskRequest) -> TaskResponse: - """ - Create a new task, does not start execution - - Args: - task: Request object for task to create on the worker - - Returns: - TaskResponse: Acknowledgement of request - """ - - return self._rest.create_task(task) - - @start_as_current_span(TRACER) - def clear_task(self, task_id: str) -> TaskResponse: - """ - Delete a stored task on the worker - - Args: - task_id: ID for the task - - Returns: - TaskResponse: Acknowledgement of request - """ - - return self._rest.clear_task(task_id) - - @start_as_current_span(TRACER, "task") - def start_task(self, task: WorkerTask) -> WorkerTask: - """ - Instruct the worker to start a stored task immediately - - Args: - task: WorkerTask to start - - Returns: - WorkerTask: Acknowledgement of request - """ - - return self._rest.update_worker_task(task) - @start_as_current_span(TRACER, "reason") def abort(self, reason: str | None = None) -> WorkerState: """ @@ -358,15 +327,10 @@ def stop(self) -> WorkerState: return self._rest.cancel_current_task(WorkerState.STOPPING) + @property @start_as_current_span(TRACER) - def get_environment(self) -> EnvironmentResponse: - """ - Get details of the worker environment - - Returns: - EnvironmentResponse: Details of the worker - environment. - """ + def environment(self) -> EnvironmentResponse: + """Details of the worker environment""" return self._rest.get_environment() @@ -433,14 +397,10 @@ def _wait_for_reload( "seconds, a server restart is recommended" ) + @property @start_as_current_span(TRACER) - def get_oidc_config(self) -> OIDCConfig | None: - """ - Get oidc config from the server - - Returns: - OIDCConfig: Details of the oidc Config - """ + def oidc_config(self) -> OIDCConfig | None: + """OIDC config from the server""" return self._rest.get_oidc_config() diff --git a/src/blueapi/client/rest.py b/src/blueapi/client/rest.py index 3ff119449e..b6cebcd099 100644 --- a/src/blueapi/client/rest.py +++ b/src/blueapi/client/rest.py @@ -136,6 +136,7 @@ def _create_task_exceptions(response: requests.Response) -> Exception | None: class BlueapiRestClient: _config: RestConfig + _pool: requests.Session def __init__( self, @@ -144,6 +145,7 @@ def __init__( ) -> None: self._config = config or RestConfig() self._session_manager = session_manager + self._pool = requests.Session() def get_plans(self) -> PlanResponse: return self._request_and_deserialize("/plans", PlanResponse) @@ -252,14 +254,17 @@ def _request_and_deserialize( url = self._config.url.unicode_string().removesuffix("/") + suffix # Get the trace context to propagate to the REST API carr = get_context_propagator() - response = requests.request( - method, - url, - json=data, - params=params, - headers=carr, - auth=JWTAuth(self._session_manager), - ) + try: + response = self._pool.request( + method, + url, + json=data, + params=params, + headers=carr, + auth=JWTAuth(self._session_manager), + ) + except requests.exceptions.ConnectionError as ce: + raise ServiceUnavailableError() from ce exception = get_exception(response) if exception is not None: raise exception @@ -289,3 +294,7 @@ def __getattr__(name: str): ) return rename raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +class ServiceUnavailableError(Exception): + pass diff --git a/src/blueapi/service/interface.py b/src/blueapi/service/interface.py index 60b93250ea..9bc8bcef85 100644 --- a/src/blueapi/service/interface.py +++ b/src/blueapi/service/interface.py @@ -141,7 +141,10 @@ def get_devices() -> list[DeviceModel]: def get_device(name: str) -> DeviceModel: """Retrieve device by name from the BlueskyContext""" - return DeviceModel.from_device(context().devices[name]) + if not (device := context().find_device(name)): + raise KeyError(name) + + return DeviceModel.from_device(device) def submit_task(task_request: TaskRequest) -> str: diff --git a/src/blueapi/stubs/templates/cache_template.pyi.jinja b/src/blueapi/stubs/templates/cache_template.pyi.jinja new file mode 100644 index 0000000000..b06ef36385 --- /dev/null +++ b/src/blueapi/stubs/templates/cache_template.pyi.jinja @@ -0,0 +1,72 @@ +from collections.abc import Callable +from typing import Any +from blueapi.client.rest import BlueapiRestClient +from blueapi.service.model import DeviceModel, PlanModel +from blueapi.worker.event import WorkerEvent + +{#- + This file is based on the cache.py file in blueapi/client/cache.py and should + be kept in sync with changes there. +#} + +# This file is auto-generated for a live server and should not be modified directly + +PlanRunner = Callable[[str, dict[str, Any]], WorkerEvent] + +class PlanCache: + def __init__(self, runner: PlanRunner, plans: list[PlanModel]) -> None: ... + def __getitem__(self, name: str) -> Plan: ... + def __iter__(self): # -> Iterator[Plan]: + ... + def __repr__(self) -> str: ... + +### Generated plans +{%- for item in plans %} + def {{ item.name }}(self,{% for arg in item.args %} + {{ arg.name }}: {{ arg.type }}{% if arg.optional %} | None = None{% endif %}, + {%- endfor %} + ) -> WorkerEvent: + """ + {{ item.docs | docstring | indent(8) }} + """ + ... +{%- endfor %} +### End + + +class Plan: + model: PlanModel + def __init__(self, model: PlanModel, runner: PlanRunner) -> None: ... + def __call__(self, *args, **kwargs): # -> None: + ... + + @property + def name(self) -> str: ... + @property + def help_text(self) -> str: ... + @property + def properties(self) -> set[str]: ... + @property + def required(self) -> list[str]: ... + def __repr__(self) -> str: ... + + +class DeviceRef(str): + model: DeviceModel + _cache: DeviceCache + def __new__(cls, name: str, cache: DeviceCache, model: DeviceModel): ... + def __getattr__(self, name) -> DeviceRef: ... + def __repr__(self) -> str: ... + +class DeviceCache: + def __init__(self, rest: BlueapiRestClient) -> None: ... + def __getitem__(self, name: str) -> DeviceRef: ... + def __iter__(self): # -> Iterator[DeviceRef]: + ... + def __repr__(self) -> str: ... + +### Generated devices + {%- for item in devices %} + {{ item }}: DeviceRef + {%- endfor %} +### End diff --git a/tests/system_tests/test_blueapi_system.py b/tests/system_tests/test_blueapi_system.py index 0271331a6d..dc82c394ee 100644 --- a/tests/system_tests/test_blueapi_system.py +++ b/tests/system_tests/test_blueapi_system.py @@ -9,7 +9,6 @@ import requests from bluesky_stomp.models import BasicAuthentication from pydantic import TypeAdapter -from requests.exceptions import ConnectionError from scanspec.specs import Line from blueapi.client.client import ( @@ -17,7 +16,11 @@ BlueskyRemoteControlError, ) from blueapi.client.event_bus import AnyEvent, BlueskyStreamingError -from blueapi.client.rest import BlueskyRequestError +from blueapi.client.rest import ( + BlueapiRestClient, + BlueskyRequestError, + ServiceUnavailableError, +) from blueapi.config import ( ApplicationConfig, ConfigLoader, @@ -130,9 +133,9 @@ def client_with_stomp() -> Generator[BlueapiClient]: def wait_for_server(client: BlueapiClient): for _ in range(20): try: - client.get_environment() + _ = client.environment return - except ConnectionError: + except ServiceUnavailableError: ... time.sleep(0.5) raise TimeoutError("No connection to the blueapi server") @@ -149,6 +152,11 @@ def client() -> Generator[BlueapiClient]: yield BlueapiClient.from_config(config=ApplicationConfig()) +@pytest.fixture +def rest_client(client: BlueapiClient) -> BlueapiRestClient: + return client._rest + + @pytest.fixture def expected_plans() -> PlanResponse: return TypeAdapter(PlanResponse).validate_json( @@ -164,25 +172,22 @@ def expected_devices() -> DeviceResponse: @pytest.fixture -def blueapi_client_get_methods() -> list[str]: +def blueapi_rest_client_get_methods() -> list[str]: # Get a list of methods that take only one argument (self) - # This will currently return - # ['get_plans', 'get_devices', 'get_state', 'get_all_tasks', - # 'get_active_task','get_environment','resume', 'stop','get_oidc_config'] return [ - method - for method in BlueapiClient.__dict__ - if callable(getattr(BlueapiClient, method)) - and not method.startswith("__") - and len(inspect.signature(getattr(BlueapiClient, method)).parameters) == 1 - and "self" in inspect.signature(getattr(BlueapiClient, method)).parameters + name + for name, method in BlueapiRestClient.__dict__.items() + if not name.startswith("__") + and callable(method) + and len(params := inspect.signature(method).parameters) == 1 + and "self" in params ] @pytest.fixture(autouse=True) -def clean_existing_tasks(client: BlueapiClient): - for task in client.get_all_tasks().tasks: - client.clear_task(task.task_id) +def clean_existing_tasks(rest_client: BlueapiRestClient): + for task in rest_client.get_all_tasks().tasks: + rest_client.clear_task(task.task_id) yield @@ -214,26 +219,26 @@ def reset_numtracker(server_config: ApplicationConfig): def test_cannot_access_endpoints( - client_without_auth: BlueapiClient, blueapi_client_get_methods: list[str] + client_without_auth: BlueapiClient, blueapi_rest_client_get_methods: list[str] ): - blueapi_client_get_methods.remove( + blueapi_rest_client_get_methods.remove( "get_oidc_config" ) # get_oidc_config can be accessed without auth - for get_method in blueapi_client_get_methods: + for get_method in blueapi_rest_client_get_methods: with pytest.raises(BlueskyRemoteControlError, match=r""): - getattr(client_without_auth, get_method)() + getattr(client_without_auth._rest, get_method)() def test_can_get_oidc_config_without_auth(client_without_auth: BlueapiClient): - assert client_without_auth.get_oidc_config() == OIDCConfig( + assert client_without_auth.oidc_config == OIDCConfig( well_known_url=KEYCLOAK_BASE_URL + "realms/master/.well-known/openid-configuration", client_id="ixx-cli-blueapi", ) -def test_get_plans(client: BlueapiClient, expected_plans: PlanResponse): - retrieved_plans = client.get_plans() +def test_get_plans(rest_client: BlueapiRestClient, expected_plans: PlanResponse): + retrieved_plans = rest_client.get_plans() retrieved_plans.plans.sort(key=lambda x: x.name) expected_plans.plans.sort(key=lambda x: x.name) @@ -242,40 +247,52 @@ def test_get_plans(client: BlueapiClient, expected_plans: PlanResponse): def test_get_plans_by_name(client: BlueapiClient, expected_plans: PlanResponse): for plan in expected_plans.plans: - assert client.get_plan(plan.name) == plan + assert client.plans[plan.name].model == plan -def test_get_non_existent_plan(client: BlueapiClient): +def test_get_non_existent_plan(rest_client: BlueapiRestClient): with pytest.raises(KeyError, match="{'detail': 'Item not found'}"): - client.get_plan("Not exists") + rest_client.get_plan("Not exists") + +def test_client_non_existent_plan(client: BlueapiClient): + with pytest.raises(AttributeError, match="No plan named 'missing' available"): + _ = client.plans.missing -def test_get_devices(client: BlueapiClient, expected_devices: DeviceResponse): - retrieved_devices = client.get_devices() + +def test_get_devices(rest_client: BlueapiRestClient, expected_devices: DeviceResponse): + retrieved_devices = rest_client.get_devices() retrieved_devices.devices.sort(key=lambda x: x.name) expected_devices.devices.sort(key=lambda x: x.name) assert retrieved_devices == expected_devices -def test_get_device_by_name(client: BlueapiClient, expected_devices: DeviceResponse): +def test_get_device_by_name( + rest_client: BlueapiRestClient, expected_devices: DeviceResponse +): for device in expected_devices.devices: - assert client.get_device(device.name) == device + assert rest_client.get_device(device.name) == device -def test_get_non_existent_device(client: BlueapiClient): +def test_get_non_existent_device(rest_client: BlueapiRestClient): with pytest.raises(KeyError, match="{'detail': 'Item not found'}"): - client.get_device("Not exists") + rest_client.get_device("Not exists") + + +def test_client_non_existent_device(client: BlueapiClient): + with pytest.raises(AttributeError, match="No device named 'missing' available"): + _ = client.devices.missing -def test_create_task_and_delete_task_by_id(client: BlueapiClient): - create_task = client.create_task(_SIMPLE_TASK) - client.clear_task(create_task.task_id) +def test_create_task_and_delete_task_by_id(rest_client: BlueapiRestClient): + create_task = rest_client.create_task(_SIMPLE_TASK) + rest_client.clear_task(create_task.task_id) -def test_instrument_session_propagated(client: BlueapiClient): - response = client.create_task(_SIMPLE_TASK) - trackable_task = client.get_task(response.task_id) +def test_instrument_session_propagated(rest_client: BlueapiRestClient): + response = rest_client.create_task(_SIMPLE_TASK) + trackable_task = rest_client.get_task(response.task_id) assert trackable_task.task.metadata == { "instrument_session": AUTHORIZED_INSTRUMENT_SESSION, "tiled_access_tags": [ @@ -284,9 +301,9 @@ def test_instrument_session_propagated(client: BlueapiClient): } -def test_create_task_validation_error(client: BlueapiClient): +def test_create_task_validation_error(rest_client: BlueapiRestClient): with pytest.raises(BlueskyRequestError, match="Internal Server Error"): - client.create_task( + rest_client.create_task( TaskRequest( name="Not-exists", params={"Not-exists": 0.0}, @@ -295,26 +312,26 @@ def test_create_task_validation_error(client: BlueapiClient): ) -def test_get_all_tasks(client: BlueapiClient): +def test_get_all_tasks(rest_client: BlueapiRestClient): created_tasks: list[TaskResponse] = [] for task in [_SIMPLE_TASK, _LONG_TASK]: - created_task = client.create_task(task) + created_task = rest_client.create_task(task) created_tasks.append(created_task) task_ids = [task.task_id for task in created_tasks] - task_list = client.get_all_tasks() + task_list = rest_client.get_all_tasks() for trackable_task in task_list.tasks: assert trackable_task.task_id in task_ids assert trackable_task.is_complete is False and trackable_task.is_pending is True for task_id in task_ids: - client.clear_task(task_id) + rest_client.clear_task(task_id) -def test_get_task_by_id(client: BlueapiClient): - created_task = client.create_task(_SIMPLE_TASK) +def test_get_task_by_id(rest_client: BlueapiRestClient): + created_task = rest_client.create_task(_SIMPLE_TASK) - get_task = client.get_task(created_task.task_id) + get_task = rest_client.get_task(created_task.task_id) assert ( get_task.task_id == created_task.task_id and get_task.is_pending @@ -322,45 +339,45 @@ def test_get_task_by_id(client: BlueapiClient): and len(get_task.errors) == 0 ) - client.clear_task(created_task.task_id) + rest_client.clear_task(created_task.task_id) -def test_get_non_existent_task(client: BlueapiClient): +def test_get_non_existent_task(rest_client: BlueapiRestClient): with pytest.raises(KeyError, match="{'detail': 'Item not found'}"): - client.get_task("Not-exists") + rest_client.get_task("Not-exists") -def test_delete_non_existent_task(client: BlueapiClient): +def test_delete_non_existent_task(rest_client: BlueapiRestClient): with pytest.raises(KeyError, match="{'detail': 'Item not found'}"): - client.clear_task("Not-exists") + rest_client.clear_task("Not-exists") -def test_put_worker_task(client: BlueapiClient): - created_task = client.create_task(_SIMPLE_TASK) - client.start_task(WorkerTask(task_id=created_task.task_id)) - active_task = client.get_active_task() +def test_put_worker_task(rest_client: BlueapiRestClient): + created_task = rest_client.create_task(_SIMPLE_TASK) + rest_client.update_worker_task(WorkerTask(task_id=created_task.task_id)) + active_task = rest_client.get_active_task() assert active_task.task_id == created_task.task_id - client.clear_task(created_task.task_id) + rest_client.clear_task(created_task.task_id) -def test_put_worker_task_fails_if_not_idle(client: BlueapiClient): - small_task = client.create_task(_SIMPLE_TASK) - long_task = client.create_task(_LONG_TASK) +def test_put_worker_task_fails_if_not_idle(rest_client: BlueapiRestClient): + small_task = rest_client.create_task(_SIMPLE_TASK) + long_task = rest_client.create_task(_LONG_TASK) - client.start_task(WorkerTask(task_id=long_task.task_id)) - active_task = client.get_active_task() + rest_client.update_worker_task(WorkerTask(task_id=long_task.task_id)) + active_task = rest_client.get_active_task() assert active_task.task_id == long_task.task_id with pytest.raises(BlueskyRemoteControlError) as exception: - client.start_task(WorkerTask(task_id=small_task.task_id)) + rest_client.update_worker_task(WorkerTask(task_id=small_task.task_id)) assert "" in str(exception) - client.abort() - client.clear_task(small_task.task_id) - client.clear_task(long_task.task_id) + rest_client.cancel_current_task(WorkerState.ABORTING) + rest_client.clear_task(small_task.task_id) + rest_client.clear_task(long_task.task_id) def test_get_worker_state(client: BlueapiClient): - assert client.get_state() == WorkerState.IDLE + assert client.state == WorkerState.IDLE def test_set_state_transition_error(client: BlueapiClient): @@ -372,10 +389,10 @@ def test_set_state_transition_error(client: BlueapiClient): assert "" in str(exception) -def test_get_task_by_status(client: BlueapiClient): - task_1 = client.create_task(_SIMPLE_TASK) - task_2 = client.create_task(_SIMPLE_TASK) - task_by_pending = client.get_all_tasks() +def test_get_task_by_status(rest_client: BlueapiRestClient): + task_1 = rest_client.create_task(_SIMPLE_TASK) + task_2 = rest_client.create_task(_SIMPLE_TASK) + task_by_pending = rest_client.get_all_tasks() # https://github.com/DiamondLightSource/blueapi/issues/680 # task_by_pending = client.get_tasks_by_status(TaskStatusEnum.PENDING) assert len(task_by_pending.tasks) == 2 @@ -384,13 +401,13 @@ def test_get_task_by_status(client: BlueapiClient): trackable_task = TypeAdapter(TrackableTask).validate_python(task) assert trackable_task.is_complete is False and trackable_task.is_pending is True - client.start_task(WorkerTask(task_id=task_1.task_id)) - while not client.get_task(task_1.task_id).is_complete: + rest_client.update_worker_task(WorkerTask(task_id=task_1.task_id)) + while not rest_client.get_task(task_1.task_id).is_complete: time.sleep(0.1) - client.start_task(WorkerTask(task_id=task_2.task_id)) - while not client.get_task(task_2.task_id).is_complete: + rest_client.update_worker_task(WorkerTask(task_id=task_2.task_id)) + while not rest_client.get_task(task_2.task_id).is_complete: time.sleep(0.1) - task_by_completed = client.get_all_tasks() + task_by_completed = rest_client.get_all_tasks() # https://github.com/DiamondLightSource/blueapi/issues/680 # task_by_pending = client.get_tasks_by_status(TaskStatusEnum.COMPLETE) assert len(task_by_completed.tasks) == 2 @@ -399,8 +416,8 @@ def test_get_task_by_status(client: BlueapiClient): trackable_task = TypeAdapter(TrackableTask).validate_python(task) assert trackable_task.is_complete is True and trackable_task.is_pending is False - client.clear_task(task_id=task_1.task_id) - client.clear_task(task_id=task_2.task_id) + rest_client.clear_task(task_id=task_1.task_id) + rest_client.clear_task(task_id=task_2.task_id) def test_progress_with_stomp(client_with_stomp: BlueapiClient): @@ -441,13 +458,13 @@ def on_event(event: AnyEvent): def test_get_current_state_of_environment(client: BlueapiClient): - assert client.get_environment().initialized + assert client.environment.initialized def test_delete_current_environment(client: BlueapiClient): - old_env = client.get_environment() + old_env = client.environment client.reload_environment() - new_env = client.get_environment() + new_env = client.environment assert new_env.initialized assert new_env.environment_id != old_env.environment_id assert new_env.error_message is None diff --git a/tests/unit_tests/cli/test_stubgen.py b/tests/unit_tests/cli/test_stubgen.py new file mode 100644 index 0000000000..766f3e2f06 --- /dev/null +++ b/tests/unit_tests/cli/test_stubgen.py @@ -0,0 +1,214 @@ +from io import StringIO +from textwrap import dedent +from types import FunctionType +from unittest.mock import Mock + +import pytest + +from blueapi.cli.stubgen import ( + _docstring, + _type_string, + generate_stubs, + render_stub_file, +) +from blueapi.client.cache import DeviceRef, Plan +from blueapi.service.model import DeviceModel, PlanModel + + +def single_line(): + """Single line docstring""" + + +def single_line_new_line(): + """ + Single line docstring + """ + + +def multi_line_inline(): + """First line + Second line""" + + +def multi_line_new_line(): + """ + First line + Second line + """ + + +def indented_multi_line(): + """ + First line + indented + """ + + +@pytest.mark.parametrize( + "input,expected", + [ + (single_line, "Single line docstring"), + (single_line_new_line, "Single line docstring"), + (multi_line_inline, "First line\nSecond line"), + (multi_line_new_line, "First line\nSecond line"), + (indented_multi_line, "First line\n indented"), + ], +) +def test_docstring_filter(input: FunctionType, expected: str): + assert input.__doc__ + assert _docstring(input.__doc__) == expected + + +@pytest.mark.parametrize( + "typ,expected", + [ + ({"type": "string"}, "str"), + ({"type": "number"}, "float"), + ({"type": "integer"}, "int"), + ({"type": "object"}, "dict[str, Any]"), + ({"type": "boolean"}, "bool"), + ({"type": "array", "items": {"type": "integer"}}, "list[int]"), + ({"type": "array", "items": {"type": "object"}}, "list[dict[str, Any]]"), + ( + { + "type": "array", + "items": {"anyOf": [{"type": "integer"}, {"type": "boolean"}]}, + }, + "list[int | bool]", + ), + ({"anyOf": [{"type": "object"}, {"type": "string"}]}, "dict[str, Any] | str"), + ({"type": "unknown.other.Type"}, "Any"), + # Special case the bluesky protocols to require device references + ({"type": "bluesky.protocols.Readable"}, "DeviceRef"), + ({}, "Any"), + ], + ids=lambda param: param.get("type") if isinstance(param, dict) else param, +) +def test_type_string(typ: dict, expected: str): + assert _type_string(typ) == expected + + +def test_render_empty(): + output = StringIO() + + render_stub_file(output, [], []) + plan_text, device_text = _extract_rendered(output) + + assert plan_text == "" + assert device_text == "" + + +FOO = PlanModel(name="empty", description="Doc string for empty", schema={}) + +BAR = PlanModel( + name="two_args", + description="Doc string for two_args", + schema={ + "properties": { + "one": {"type": "integer"}, + "two": {"type": "string"}, + }, + "required": ["one"], + }, +) + + +def test_render_empty_plan_function(): + output = StringIO() + plans = [Plan(model=FOO, runner=Mock())] + render_stub_file(output, plans, []) + plan_text, device_text = _extract_rendered(output) + + assert device_text == "" + + assert ( + plan_text + == """\ + def empty(self, + ) -> WorkerEvent: + \""" + Doc string for empty + \""" + ...\n""" + ) + + +def test_render_multiple_plan_functions(): + output = StringIO() + runner = Mock() + plans = [Plan(FOO, runner), Plan(BAR, runner)] + render_stub_file(output, plans, []) + plan_text, device_text = _extract_rendered(output) + assert device_text == "" + + assert ( + plan_text + == """\ + def empty(self, + ) -> WorkerEvent: + \""" + Doc string for empty + \""" + ... + def two_args(self, + one: int, + two: str | None = None, + ) -> WorkerEvent: + \""" + Doc string for two_args + \""" + ...\n""" + ) + + +def test_device_fields(): + output = StringIO() + cache = Mock() + devices = [ + DeviceRef("one", cache, DeviceModel(name="one", protocols=[])), + DeviceRef("two", cache, DeviceModel(name="two", protocols=[])), + ] + render_stub_file(output, [], devices) + + plan_text, device_text = _extract_rendered(output) + assert plan_text == "" + assert device_text == " one: DeviceRef\n two: DeviceRef\n" + + +def test_package_creation(tmp_path): + generate_stubs(tmp_path / "blueapi-stubs", [], []) + with open(tmp_path / "blueapi-stubs" / "pyproject.toml") as pyproj: + assert pyproj.read().startswith( + dedent(""" + [project] + name = "blueapi-stubs" + version = "0.1.0" + """) + ) + with open( + tmp_path / "blueapi-stubs" / "src" / "blueapi-stubs" / "py.typed" + ) as typed: + assert typed.read() == "partial\n" + + assert ( + tmp_path / "blueapi-stubs" / "src" / "blueapi-stubs" / "client" / "cache.pyi" + ).exists() + + +def _extract_rendered(src: StringIO) -> tuple[str, str]: + src.seek(0) + _read_until_line(src, "### Generated plans") + plan_text = _read_until_line(src, "### End") + _read_until_line(src, "### Generated devices") + device_text = _read_until_line(src, "### End") + return plan_text, device_text + + +def _read_until_line(src: StringIO, match: str) -> str: + text = "" + for line in src: + if line.startswith(match): + break + text += line + + return text diff --git a/tests/unit_tests/client/test_client.py b/tests/unit_tests/client/test_client.py index d13ccce80d..8bbc54578d 100644 --- a/tests/unit_tests/client/test_client.py +++ b/tests/unit_tests/client/test_client.py @@ -8,8 +8,13 @@ JsonObjectSpanExporter, asserting_span_exporter, ) +from pydantic import HttpUrl -from blueapi.client.client import BlueapiClient +from blueapi.client.cache import DeviceCache, DeviceRef, Plan, PlanCache +from blueapi.client.client import ( + BlueapiClient, + MissingInstrumentSessionError, +) from blueapi.client.event_bus import AnyEvent, BlueskyStreamingError, EventBusClient from blueapi.client.rest import BlueapiRestClient, BlueskyRemoteControlError from blueapi.config import MissingStompConfigurationError @@ -20,6 +25,7 @@ EnvironmentResponse, PlanModel, PlanResponse, + ProtocolInfo, TaskRequest, TaskResponse, TasksListResponse, @@ -35,6 +41,19 @@ ] ) PLAN = PlanModel(name="foo") +FULL_PLAN = PlanModel( + name="foobar", + description="Description of plan foobar", + schema={ + "title": "foobar", + "description": "Model description of plan foobar", + "properties": { + "one": {}, + "two": {}, + }, + "required": ["one"], + }, +) DEVICES = DeviceResponse( devices=[ DeviceModel(name="foo", protocols=[]), @@ -72,9 +91,9 @@ def mock_rest() -> BlueapiRestClient: mock = Mock(spec=BlueapiRestClient) mock.get_plans.return_value = PLANS - mock.get_plan.return_value = PLAN + mock.get_plan.side_effect = lambda n: {p.name: p for p in PLANS.plans}[n] mock.get_devices.return_value = DEVICES - mock.get_device.return_value = DEVICE + mock.get_device.side_effect = lambda n: {d.name: d for d in DEVICES.devices}[n] mock.get_state.return_value = WorkerState.IDLE mock.get_task.return_value = TASK mock.get_all_tasks.return_value = TASKS @@ -105,114 +124,62 @@ def client_with_events(mock_rest: Mock, mock_events: MagicMock): return BlueapiClient(rest=mock_rest, events=mock_events) +def test_client_from_config(): + bc = BlueapiClient.from_config_file( + "tests/unit_tests/valid_example_config/client.yaml" + ) + assert bc._rest._config.url == HttpUrl("http://example.com:8082") + + def test_get_plans(client: BlueapiClient): - assert client.get_plans() == PLANS + assert PlanResponse(plans=[p.model for p in client.plans]) == PLANS def test_get_plan(client: BlueapiClient): - assert client.get_plan("foo") == PLAN + assert client.plans.foo.model == PLAN + assert client.plans["foo"].model == PLAN def test_get_nonexistant_plan( client: BlueapiClient, - mock_rest: Mock, ): - mock_rest.get_plan.side_effect = KeyError("Not found") - with pytest.raises(KeyError): - client.get_plan("baz") + with pytest.raises(AttributeError): + _ = client.plans.fizz_buzz.model def test_get_devices(client: BlueapiClient): - assert client.get_devices() == DEVICES + assert DeviceResponse(devices=[d.model for d in client.devices]) == DEVICES def test_get_device(client: BlueapiClient): - assert client.get_device("foo") == DEVICE - - -def test_get_nonexistant_device( - client: BlueapiClient, - mock_rest: Mock, -): - mock_rest.get_device.side_effect = KeyError("Not found") - with pytest.raises(KeyError): - client.get_device("baz") - - -def test_get_state(client: BlueapiClient): - assert client.get_state() == WorkerState.IDLE - - -def test_get_task(client: BlueapiClient): - assert client.get_task("foo") == TASK - - -def test_get_nonexistent_task( - client: BlueapiClient, - mock_rest: Mock, -): - mock_rest.get_task.side_effect = KeyError("Not found") - with pytest.raises(KeyError): - client.get_task("baz") + assert client.devices.foo.model == DEVICE -def test_get_task_with_empty_id(client: BlueapiClient): - with pytest.raises(AssertionError) as exc: - client.get_task("") - assert str(exc) == "Task ID not provided!" - - -def test_get_all_tasks( +def test_get_nonexistent_device( client: BlueapiClient, ): - assert client.get_all_tasks() == TASKS + with pytest.raises(AttributeError): + _ = client.devices.baz -def test_create_task( - client: BlueapiClient, - mock_rest: Mock, -): - client.create_task(task=TaskRequest(name="foo", instrument_session="cm12345-1")) - mock_rest.create_task.assert_called_once_with( - TaskRequest(name="foo", instrument_session="cm12345-1") +def test_get_child_device(mock_rest: Mock, client: BlueapiClient): + mock_rest.get_device.side_effect = ( + lambda name: DeviceModel(name="foo.x", protocols=[ProtocolInfo(name="One")]) + if name == "foo.x" + else None ) + foo = client.devices.foo + assert foo == "foo" + x = client.devices.foo.x + assert x == "foo.x" -def test_create_task_does_not_start_task( - client: BlueapiClient, - mock_rest: Mock, -): - client.create_task(task=TaskRequest(name="foo", instrument_session="cm12345-1")) - mock_rest.update_worker_task.assert_not_called() - - -def test_clear_task( - client: BlueapiClient, - mock_rest: Mock, -): - client.clear_task(task_id="foo") - mock_rest.clear_task.assert_called_once_with("foo") +def test_get_state(client: BlueapiClient): + assert client.state == WorkerState.IDLE def test_get_active_task(client: BlueapiClient): - assert client.get_active_task() == ACTIVE_TASK - - -def test_start_task( - client: BlueapiClient, - mock_rest: Mock, -): - client.start_task(task=WorkerTask(task_id="bar")) - mock_rest.update_worker_task.assert_called_once_with(WorkerTask(task_id="bar")) - - -def test_start_nonexistant_task( - client: BlueapiClient, - mock_rest: Mock, -): - mock_rest.update_worker_task.side_effect = KeyError("Not found") - with pytest.raises(KeyError): - client.start_task(task=WorkerTask(task_id="bar")) + assert client.active_task == ACTIVE_TASK def test_create_and_start_task_calls_both_creating_and_starting_endpoints( @@ -266,7 +233,7 @@ def test_create_and_start_task_fails_if_task_start_fails( def test_get_environment(client: BlueapiClient): - assert client.get_environment() == ENV + assert client.environment == ENV def test_reload_environment( @@ -439,6 +406,15 @@ def test_run_task_fails_on_failing_event( on_event.assert_called_with(FAILED_EVENT) +@patch("blueapi.client.client.BlueapiClient.run_task") +def test_run_plan(run_task, client, mock_rest): + client.instrument_session = "cm12345-2" + client.run_plan("foo", {"foo": "bar"}) + run_task.assert_called_once_with( + TaskRequest(name="foo", params={"foo": "bar"}, instrument_session="cm12345-2") + ) + + @pytest.mark.parametrize( "test_event", [ @@ -521,76 +497,40 @@ def callback(on_event: Callable[[AnyEvent, MessageContext], None]): mock_on_event.assert_called_once_with(COMPLETE_EVENT) +def test_get_oidc_config(client, mock_rest): + assert client.oidc_config == mock_rest.get_oidc_config() + + def test_get_plans_span_ok(exporter: JsonObjectSpanExporter, client: BlueapiClient): - with asserting_span_exporter(exporter, "get_plans"): - client.get_plans() + with asserting_span_exporter(exporter, "plans"): + _ = client.plans def test_get_plan_span_ok(exporter: JsonObjectSpanExporter, client: BlueapiClient): - with asserting_span_exporter(exporter, "get_plan", "name"): - client.get_plan("foo") + with asserting_span_exporter(exporter, "plans"): + _ = client.plans.foo def test_get_devices_span_ok(exporter: JsonObjectSpanExporter, client: BlueapiClient): - with asserting_span_exporter(exporter, "get_devices"): - client.get_devices() + with asserting_span_exporter(exporter, "devices"): + _ = client.devices def test_get_device_span_ok(exporter: JsonObjectSpanExporter, client: BlueapiClient): - with asserting_span_exporter(exporter, "get_device", "name"): - client.get_device("foo") - + with asserting_span_exporter(exporter, "devices"): + _ = client.devices.foo -def test_get_state_ok(exporter: JsonObjectSpanExporter, client: BlueapiClient): - with asserting_span_exporter(exporter, "get_state"): - client.get_state() - -def test_get_task_span_ok(exporter: JsonObjectSpanExporter, client: BlueapiClient): - with asserting_span_exporter(exporter, "get_task", "task_id"): - client.get_task("foo") - - -def test_get_all_tasks_span_ok( - exporter: JsonObjectSpanExporter, - client: BlueapiClient, -): - with asserting_span_exporter(exporter, "get_all_tasks"): - client.get_all_tasks() - - -def test_create_task_span_ok( - exporter: JsonObjectSpanExporter, - client: BlueapiClient, - mock_rest: Mock, -): - with asserting_span_exporter(exporter, "create_task", "task"): - client.create_task(task=TaskRequest(name="foo", instrument_session="cm12345-1")) - - -def test_clear_task_span_ok( - exporter: JsonObjectSpanExporter, - client: BlueapiClient, - mock_rest: Mock, -): - with asserting_span_exporter(exporter, "clear_task"): - client.clear_task(task_id="foo") +def test_get_state_span_ok(exporter: JsonObjectSpanExporter, client: BlueapiClient): + with asserting_span_exporter(exporter, "state"): + _ = client.state def test_get_active_task_span_ok( exporter: JsonObjectSpanExporter, client: BlueapiClient ): - with asserting_span_exporter(exporter, "get_active_task"): - client.get_active_task() - - -def test_start_task_span_ok( - exporter: JsonObjectSpanExporter, - client: BlueapiClient, - mock_rest: Mock, -): - with asserting_span_exporter(exporter, "start_task", "task"): - client.start_task(task=WorkerTask(task_id="bar")) + with asserting_span_exporter(exporter, "active_task"): + _ = client.active_task def test_create_and_start_task_span_ok( @@ -609,8 +549,8 @@ def test_create_and_start_task_span_ok( def test_get_environment_span_ok( exporter: JsonObjectSpanExporter, client: BlueapiClient ): - with asserting_span_exporter(exporter, "get_environment"): - client.get_environment() + with asserting_span_exporter(exporter, "environment"): + _ = client.environment def test_reload_environment_span_ok( @@ -668,3 +608,239 @@ def test_cannot_run_task_span_ok( ): with asserting_span_exporter(exporter, "grun_task"): client.run_task(TaskRequest(name="foo", instrument_session="cm12345-1")) + + +def test_instrument_session_required(client): + with pytest.raises(MissingInstrumentSessionError): + _ = client.instrument_session + + +def test_setting_instrument_session(client): + # This looks like a completely pointless test but instrument_session is a + # property with some logic so it's not purely to get coverage up + client.instrument_session = "cm12345-4" + assert client.instrument_session == "cm12345-4" + + +def test_fluent_instrument_session_setter(client): + client2 = client.with_instrument_session("cm12345-3") + assert client is client2 + assert client.instrument_session == "cm12345-3" + + +def test_plan_cache_ignores_underscores(client): + cache = PlanCache(client, [PlanModel(name="_ignored"), PlanModel(name="used")]) + with pytest.raises(AttributeError, match="_ignored"): + _ = cache._ignored + + +def test_plan_cache_repr(client): + assert repr(client.plans) == "PlanCache(2 plans)" + + +def test_device_cache_ignores_underscores(): + rest = Mock() + rest.get_devices.return_value = DeviceResponse( + devices=[ + DeviceModel(name="_ignored", protocols=[]), + ] + ) + cache = DeviceCache(rest) + with pytest.raises(AttributeError, match="_ignored"): + _ = cache._ignored + + rest.get_devices.reset_mock() + with pytest.raises(AttributeError, match="_anything"): + _ = cache._anything + rest.get_device.assert_not_called() + + +def test_devices_are_cached(mock_rest): + cache = DeviceCache(mock_rest) + _ = cache.foo + mock_rest.get_device.assert_not_called() + _ = cache["foo"] + mock_rest.get_device.assert_not_called() + + +def test_device_cache_repr(client): + assert repr(client.devices) == "DeviceCache(2 devices)" + + +def test_device_repr(): + cache = Mock() + model = Mock() + dev = DeviceRef(name="foo", cache=cache, model=model) + assert repr(dev) == "Device(foo)" + + +def test_device_ignores_underscores(): + cache = MagicMock() + model = Mock() + dev = DeviceRef(name="foo", cache=cache, model=model) + with pytest.raises(AttributeError, match="_underscore"): + _ = dev._underscore + cache.__getitem__.assert_not_called() + + +def test_plan_help_text(): + plan = Plan(PlanModel(name="foo", description="help for foo"), Mock()) + assert plan.help_text == "help for foo" + + +def test_plan_fallback_help_text(): + plan = Plan( + PlanModel( + name="foo", + schema={"properties": {"one": {}, "two": {}}, "required": ["one"]}, + ), + Mock(), + ) + assert plan.help_text == "Plan foo(one, two=None)" + + +def test_plan_properties(): + plan = Plan( + PlanModel( + name="foo", + schema={"properties": {"one": {}, "two": {}}, "required": ["one"]}, + ), + Mock(), + ) + + assert plan.properties == {"one", "two"} + assert plan.required == ["one"] + + +def test_plan_empty_fallback_help_text(): + plan = Plan( + PlanModel(name="foo", schema={"properties": {}, "required": []}), Mock() + ) + assert plan.help_text == "Plan foo()" + + +p = pytest.param + + +@pytest.mark.parametrize( + "args,kwargs,params", + [ + p((1,), {}, {"one": 1}, id="required_as_positional"), + p((), {"one": 7}, {"one": 7}, id="required_as_keyword"), + p((1,), {"two": 23}, {"one": 1, "two": 23}, id="all_as_mixed_args_kwargs"), + p((1, 2), {}, {"one": 1, "two": 2}, id="all_as_positional"), + p((), {"one": 21, "two": 42}, {"one": 21, "two": 42}, id="all_as_keyword"), + ], +) +def test_plan_param_mapping(args, kwargs, params): + runner = Mock() + plan = Plan(FULL_PLAN, runner) + + plan(*args, **kwargs) + runner.assert_called_once_with("foobar", params) + + +@pytest.mark.parametrize( + "args,kwargs,msg", + [ + p((), {}, r"Missing argument\(s\) for \{'one'\}", id="missing_required"), + p((1,), {"one": 7}, "multiple values for one", id="duplicate_required"), + p((1, 2), {"two": 23}, "multiple values for two", id="duplicate_optional"), + p((1, 2, 3), {}, "too many arguments", id="too_many_args"), + p( + (), + {"unknown_key": 42}, + r"got unexpected arguments: \{'unknown_key'\}", + id="unknown_arg", + ), + ], +) +def test_plan_invalid_param_mapping(args, kwargs, msg): + runner = Mock(spec=Callable) + plan = Plan( + FULL_PLAN, + runner, + ) + + with pytest.raises(TypeError, match=msg): + plan(*args, **kwargs) + runner.assert_not_called() + + +def test_adding_removing_callback(client): + def callback(*a, **kw): + pass + + cb_id = client.add_callback(callback) + assert len(client.callbacks) == 1 + client.remove_callback(cb_id) + assert len(client.callbacks) == 0 + + +@pytest.mark.parametrize( + "test_event", + [ + WorkerEvent( + state=WorkerState.RUNNING, + task_status=TaskStatus( + task_id="foo", + task_complete=False, + task_failed=False, + ), + ), + ProgressEvent(task_id="foo"), + DataEvent(name="start", doc={}, task_id="0000-1111"), + ], +) +def test_client_callbacks( + client_with_events: BlueapiClient, + mock_rest: Mock, + mock_events: MagicMock, + test_event: AnyEvent, +): + callback = Mock() + client_with_events.add_callback(callback) + mock_rest.create_task.return_value = TaskResponse(task_id="foo") + mock_rest.update_worker_task.return_value = TaskResponse(task_id="foo") + + ctx = Mock() + ctx.correlation_id = "foo" + + def subscribe(on_event: Callable[[AnyEvent, MessageContext], None]): + on_event(test_event, ctx) + on_event(COMPLETE_EVENT, ctx) + + mock_events.subscribe_to_all_events = subscribe # type: ignore + + client_with_events.run_task(TaskRequest(name="foo", instrument_session="cm12345-1")) + + assert callback.mock_calls == [call(test_event), call(COMPLETE_EVENT)] + + +def test_client_callback_failures( + client_with_events: BlueapiClient, + mock_rest: Mock, + mock_events: MagicMock, +): + failing_callback = Mock(side_effect=ValueError("Broken callback")) + callback = Mock() + client_with_events.add_callback(failing_callback) + client_with_events.add_callback(callback) + mock_rest.create_task.return_value = TaskResponse(task_id="foo") + mock_rest.update_worker_task.return_value = TaskResponse(task_id="foo") + + ctx = Mock() + ctx.correlation_id = "foo" + + evt = DataEvent(name="start", doc={}, task_id="foo") + + def subscribe(on_event: Callable[[AnyEvent, MessageContext], None]): + on_event(evt, ctx) + on_event(COMPLETE_EVENT, ctx) + + mock_events.subscribe_to_all_events = subscribe # type: ignore + + client_with_events.run_task(TaskRequest(name="foo", instrument_session="cm12345-1")) + + assert failing_callback.mock_calls == [call(evt), call(COMPLETE_EVENT)] + assert callback.mock_calls == [call(evt), call(COMPLETE_EVENT)] diff --git a/tests/unit_tests/client/test_rest.py b/tests/unit_tests/client/test_rest.py index c8fce9d101..2ddcdd3800 100644 --- a/tests/unit_tests/client/test_rest.py +++ b/tests/unit_tests/client/test_rest.py @@ -45,7 +45,7 @@ def rest_with_auth(oidc_config: OIDCConfig, tmp_path) -> BlueapiRestClient: (500, BlueskyRemoteControlError), ], ) -@patch("blueapi.client.rest.requests.request") +@patch("blueapi.client.rest.requests.Session.request") def test_rest_error_code( mock_request: Mock, rest: BlueapiRestClient, diff --git a/tests/unit_tests/test_cli.py b/tests/unit_tests/test_cli.py index 2a49a1fe80..a789d21eb5 100644 --- a/tests/unit_tests/test_cli.py +++ b/tests/unit_tests/test_cli.py @@ -110,7 +110,7 @@ def test_runs_with_umask_002( mock_umask.assert_called_once_with(0o002) -@patch("requests.request") +@patch("blueapi.client.rest.requests.Session.request") def test_connection_error_caught_by_wrapper_func( mock_requests: Mock, runner: CliRunner ): @@ -120,7 +120,7 @@ def test_connection_error_caught_by_wrapper_func( assert result.output == "Error: Failed to establish connection to blueapi server.\n" -@patch("requests.request") +@patch("blueapi.client.rest.requests.Session.request") def test_authentication_error_caught_by_wrapper_func( mock_requests: Mock, runner: CliRunner ): @@ -133,7 +133,7 @@ def test_authentication_error_caught_by_wrapper_func( ) -@patch("requests.request") +@patch("blueapi.client.rest.requests.Session.request") def test_remote_error_raised_by_wrapper_func(mock_requests: Mock, runner: CliRunner): mock_requests.side_effect = BlueskyRemoteControlError("Response [450]") @@ -198,15 +198,15 @@ def test_invalid_config_path_handling(runner: CliRunner): assert result.exit_code == 1 -@patch("blueapi.cli.cli.BlueapiClient.get_plans") +@patch("blueapi.cli.cli.BlueapiClient.plans") @patch("blueapi.cli.cli.OutputFormat.FULL.display") def test_options_via_env(mock_display, mock_plans, runner: CliRunner): result = runner.invoke( main, args=["controller", "plans"], env={"BLUEAPI_CONTROLLER_OUTPUT": "full"} ) - mock_plans.assert_called_once_with() - mock_display.assert_called_once_with(mock_plans.return_value) + mock_plans.__iter__.assert_called_once_with() + mock_display.assert_called_once_with(PlanResponse(plans=list(mock_plans))) assert result.exit_code == 0 @@ -493,9 +493,7 @@ def test_valid_stomp_config_for_listener( @responses.activate -def test_get_env( - runner: CliRunner, -): +def test_get_env(runner: CliRunner): environment_id = uuid.uuid4() responses.add( responses.GET, @@ -514,6 +512,17 @@ def test_get_env( ) +@responses.activate +def test_get_state(runner: CliRunner): + responses.add( + responses.GET, "http://localhost:8000/worker/state", json="IDLE", status=200 + ) + state = runner.invoke(main, ["controller", "state"]) + print(state.stderr) + assert state.exit_code == 0 + assert state.output == "IDLE\n" + + @responses.activate(assert_all_requests_are_fired=True) @patch("blueapi.client.client.time.sleep", return_value=None) def test_reset_env_client_behavior( @@ -1320,3 +1329,17 @@ def test_config_schema( stream.write.assert_called() else: assert json.loads(result.output) == expected + pass + + +@patch("blueapi.client.client.BlueapiClient.from_config") +@patch("blueapi.cli.cli.stubgen") +def test_genstubs( + stubgen, + client, + runner: CliRunner, +): + runner.invoke(main, ["generate-stubs", "/path/to/stub_dir"]) + stubgen.generate_stubs.assert_called_once_with( + Path("/path/to/stub_dir"), list(client().plans), list(client().devices) + ) diff --git a/uv.lock b/uv.lock index efbb556aa6..d44bac795a 100644 --- a/uv.lock +++ b/uv.lock @@ -438,6 +438,7 @@ dependencies = [ { name = "fastapi" }, { name = "gitpython" }, { name = "graypy" }, + { name = "jinja2" }, { name = "observability-utils" }, { name = "opentelemetry-distro" }, { name = "opentelemetry-instrumentation-fastapi" }, @@ -495,6 +496,7 @@ requires-dist = [ { name = "fastapi", specifier = ">=0.112.0" }, { name = "gitpython" }, { name = "graypy", specifier = ">=2.1.0" }, + { name = "jinja2", specifier = ">=3.1.6" }, { name = "observability-utils", specifier = ">=0.1.4" }, { name = "opentelemetry-distro", specifier = ">=0.48b0" }, { name = "opentelemetry-instrumentation-fastapi", specifier = ">=0.48b0" },