diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index e4d8a66..a4887be 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -30,6 +30,20 @@ jobs: - name: Test run: | pytest tests -v - - name: Examples + - name: Examples (internal) run: | python examples/fodo.py + - name: Examples (external) + run: | + # Copy examples directory from the main PALS repository + cd examples + git clone --no-checkout https://github.com/pals-project/pals.git pals_temp + cd pals_temp + git sparse-checkout init + git sparse-checkout set examples/ + git checkout main + # Test all external examples + cd - + for file in pals_temp/examples/*.pals.yaml; do + python test_external_examples.py --path "${file}" + done \ No newline at end of file diff --git a/examples/test_external_examples.py b/examples/test_external_examples.py new file mode 100644 index 0000000..3f0a513 --- /dev/null +++ b/examples/test_external_examples.py @@ -0,0 +1,21 @@ +import argparse + +from pals import Lattice + + +def main(): + # Parse command-line arguments + parser = argparse.ArgumentParser() + parser.add_argument( + "--path", + required=True, + help="Path to the example file", + ) + args = parser.parse_args() + example_file = args.path + # Parse and validate YAML data from file + Lattice.from_file(example_file) + + +if __name__ == "__main__": + main() diff --git a/src/pals/kinds/BeamLine.py b/src/pals/kinds/BeamLine.py index f7dd881..2df4fca 100644 --- a/src/pals/kinds/BeamLine.py +++ b/src/pals/kinds/BeamLine.py @@ -3,7 +3,6 @@ from .all_elements import get_all_elements_as_annotation from .mixin import BaseElement -from ..functions import load_file_to_dict, store_dict_to_file class BeamLine(BaseElement): @@ -26,14 +25,3 @@ def model_dump(self, *args, **kwargs): from pals.kinds.mixin.all_element_mixin import dump_element_list return dump_element_list(self, "line", *args, **kwargs) - - @staticmethod - def from_file(filename: str) -> "BeamLine": - """Load a BeamLine from a text file""" - pals_dict = load_file_to_dict(filename) - return BeamLine(**pals_dict) - - def to_file(self, filename: str): - """Save a BeamLine to a text file""" - pals_dict = self.model_dump() - store_dict_to_file(filename, pals_dict) diff --git a/src/pals/kinds/Lattice.py b/src/pals/kinds/Lattice.py new file mode 100644 index 0000000..c2f204f --- /dev/null +++ b/src/pals/kinds/Lattice.py @@ -0,0 +1,39 @@ +from pydantic import model_validator +from typing import List, Literal + +from .all_elements import get_all_elements_as_annotation +from .mixin import BaseElement +from ..functions import load_file_to_dict, store_dict_to_file + + +class Lattice(BaseElement): + """A line of elements and/or other lines""" + + kind: Literal["Lattice"] = "Lattice" + + branches: List[Annotated[Union[BeamLine], Field(discriminator="kind")]] + + @model_validator(mode="before") + @classmethod + def unpack_json_structure(cls, data): + """Deserialize the JSON/YAML/...-like dict for Lattice elements""" + from pals.kinds.mixin.all_element_mixin import unpack_element_list_structure + + return unpack_element_list_structure(data, "line", "line") + + def model_dump(self, *args, **kwargs): + """Custom model dump for Lattice to handle element list formatting""" + from pals.kinds.mixin.all_element_mixin import dump_element_list + + return dump_element_list(self, "line", *args, **kwargs) + + @staticmethod + def from_file(filename: str) -> "Lattice": + """Load a Lattice from a text file""" + pals_dict = load_file_to_dict(filename) + return Lattice(**pals_dict) + + def to_file(self, filename: str): + """Save a Lattice to a text file""" + pals_dict = self.model_dump() + store_dict_to_file(filename, pals_dict) diff --git a/src/pals/kinds/__init__.py b/src/pals/kinds/__init__.py index d64c12c..36471e8 100644 --- a/src/pals/kinds/__init__.py +++ b/src/pals/kinds/__init__.py @@ -5,6 +5,7 @@ from .ACKicker import ACKicker # noqa: F401 from .BeamBeam import BeamBeam # noqa: F401 from .BeamLine import BeamLine # noqa: F401 +from .Lattice import Lattice # noqa: F401 from .BeginningEle import BeginningEle # noqa: F401 from .Converter import Converter # noqa: F401 from .CrabCavity import CrabCavity # noqa: F401 @@ -39,3 +40,4 @@ # Rebuild pydantic models that depend on other classes UnionEle.model_rebuild() BeamLine.model_rebuild() +Lattice.model_rebuild() diff --git a/src/pals/kinds/all_elements.py b/src/pals/kinds/all_elements.py index 580388e..f2d0d70 100644 --- a/src/pals/kinds/all_elements.py +++ b/src/pals/kinds/all_elements.py @@ -43,6 +43,7 @@ def get_all_element_types(extra_types: tuple = None): """Return a tuple of all element types that can be used in BeamLine or UnionEle.""" element_types = ( + "Lattice", # Forward reference to handle circular import "BeamLine", # Forward reference to handle circular import "UnionEle", # Forward reference to handle circular import ACKicker,