Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions src/blueapi/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@

from .bluesky_types import (
BLUESKY_PROTOCOLS,
AsyncDevice,
Device,
Plan,
PlanGenerator,
Expand Down Expand Up @@ -102,8 +103,12 @@ def qualified_generic_name(target: type) -> str:
return f"{qualified_name(target)}{subscript}"


def is_bluesky_type(typ: type) -> bool:
return typ in BLUESKY_PROTOCOLS or isinstance(typ, BLUESKY_PROTOCOLS)
def is_bluesky_type(typ: Any) -> bool:
return (
typ in BLUESKY_PROTOCOLS
or isinstance(typ, BLUESKY_PROTOCOLS)
or (isinstance(typ, type) and issubclass(typ, AsyncDevice))
)


C = TypeVar("C", covariant=True)
Expand Down Expand Up @@ -531,7 +536,7 @@ def _convert_type(self, typ: type | Any, no_default: bool = True) -> type:
if typ is NoneType and not no_default:
return SkipJsonSchema[NoneType]
root = get_origin(typ)
if is_bluesky_type(typ) or (root is not None and is_bluesky_type(root)):
if is_bluesky_type(root or typ):
return self._reference(typ)
args = get_args(typ)
if args:
Expand Down
16 changes: 15 additions & 1 deletion tests/unit_tests/core/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
from bluesky.run_engine import RunEngine
from bluesky.utils import MsgGenerator
from dodal.common import PlanGenerator, inject
from ophyd import Device
from ophyd_async.core import (
Device,
PathProvider,
StandardDetector,
StaticPathProvider,
Expand Down Expand Up @@ -550,6 +550,20 @@ def demo(named: ConcreteStoppable): ...
assert spec["named"][1].default_factory is None


class PlainDevice(Device):
"""Class that extends Device without any additional protocols"""


def test_device_without_protocol_annotation(empty_context: BlueskyContext):
dev_ref = empty_context._reference(PlainDevice)

def demo_plan(dev: PlainDevice) -> MsgGenerator:
yield from []

spec = empty_context._type_spec_for_function(demo_plan)
assert spec["dev"][0] is dev_ref


def test_str_default(empty_context: BlueskyContext, sim_motor: Motor, alt_motor: Motor):
movable_ref = empty_context._reference(Movable)
empty_context.register_device(sim_motor)
Expand Down