diff --git a/src/blueapi/core/context.py b/src/blueapi/core/context.py index 9bf29aec3..26f065b6d 100644 --- a/src/blueapi/core/context.py +++ b/src/blueapi/core/context.py @@ -44,6 +44,7 @@ from .bluesky_types import ( BLUESKY_PROTOCOLS, + AsyncDevice, Device, Plan, PlanGenerator, @@ -102,8 +103,12 @@ def qualified_generic_name(target: type) -> str: return f"{qualified_name(target)}{subscript}" -def is_bluesky_type(typ: type) -> bool: - return typ in BLUESKY_PROTOCOLS or isinstance(typ, BLUESKY_PROTOCOLS) +def is_bluesky_type(typ: Any) -> bool: + return ( + typ in BLUESKY_PROTOCOLS + or isinstance(typ, BLUESKY_PROTOCOLS) + or (isinstance(typ, type) and issubclass(typ, AsyncDevice)) + ) C = TypeVar("C", covariant=True) @@ -531,7 +536,7 @@ def _convert_type(self, typ: type | Any, no_default: bool = True) -> type: if typ is NoneType and not no_default: return SkipJsonSchema[NoneType] root = get_origin(typ) - if is_bluesky_type(typ) or (root is not None and is_bluesky_type(root)): + if is_bluesky_type(root or typ): return self._reference(typ) args = get_args(typ) if args: diff --git a/tests/unit_tests/core/test_context.py b/tests/unit_tests/core/test_context.py index 2f31784aa..9cca6fe46 100644 --- a/tests/unit_tests/core/test_context.py +++ b/tests/unit_tests/core/test_context.py @@ -18,8 +18,8 @@ from bluesky.run_engine import RunEngine from bluesky.utils import MsgGenerator from dodal.common import PlanGenerator, inject -from ophyd import Device from ophyd_async.core import ( + Device, PathProvider, StandardDetector, StaticPathProvider, @@ -550,6 +550,20 @@ def demo(named: ConcreteStoppable): ... assert spec["named"][1].default_factory is None +class PlainDevice(Device): + """Class that extends Device without any additional protocols""" + + +def test_device_without_protocol_annotation(empty_context: BlueskyContext): + dev_ref = empty_context._reference(PlainDevice) + + def demo_plan(dev: PlainDevice) -> MsgGenerator: + yield from [] + + spec = empty_context._type_spec_for_function(demo_plan) + assert spec["dev"][0] is dev_ref + + def test_str_default(empty_context: BlueskyContext, sim_motor: Motor, alt_motor: Motor): movable_ref = empty_context._reference(Movable) empty_context.register_device(sim_motor)