From 29b736991074eab322ac14dd7c62ed1f019582ab Mon Sep 17 00:00:00 2001 From: johannes-moegerle Date: Fri, 9 Jan 2026 17:20:57 +0100 Subject: [PATCH 1/6] start adding mqdt states and basis --- pyproject.toml | 3 + src/rydstate/__init__.py | 4 + src/rydstate/angular/angular_ket.py | 12 +++ src/rydstate/angular/angular_state.py | 4 +- src/rydstate/basis/__init__.py | 10 ++- src/rydstate/basis/basis_mqdt.py | 110 +++++++++++++++++++++++ src/rydstate/rydberg/__init__.py | 2 + src/rydstate/rydberg/rydberg_mqdt.py | 121 ++++++++++++++++++++++++++ src/rydstate/rydberg/rydberg_sqdt.py | 13 ++- 9 files changed, 274 insertions(+), 5 deletions(-) create mode 100644 src/rydstate/basis/basis_mqdt.py create mode 100644 src/rydstate/rydberg/rydberg_mqdt.py diff --git a/pyproject.toml b/pyproject.toml index c567b4b..ae0ba8b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,9 @@ dependencies = [ ] [project.optional-dependencies] +mqdt = [ + "juliacall >= 0.9.24", +] tests = [ "pytest >= 8.0", "nbmake >= 1.3", diff --git a/src/rydstate/__init__.py b/src/rydstate/__init__.py index 739e3ad..429cffb 100644 --- a/src/rydstate/__init__.py +++ b/src/rydstate/__init__.py @@ -1,11 +1,13 @@ from rydstate import angular, basis, radial, rydberg, species from rydstate.basis import ( + BasisMQDT, BasisSQDTAlkali, BasisSQDTAlkalineFJ, BasisSQDTAlkalineJJ, BasisSQDTAlkalineLS, ) from rydstate.rydberg import ( + RydbergStateMQDT, RydbergStateSQDT, RydbergStateSQDTAlkali, RydbergStateSQDTAlkalineFJ, @@ -15,10 +17,12 @@ from rydstate.units import ureg __all__ = [ + "BasisMQDT", "BasisSQDTAlkali", "BasisSQDTAlkalineFJ", "BasisSQDTAlkalineJJ", "BasisSQDTAlkalineLS", + "RydbergStateMQDT", "RydbergStateSQDT", "RydbergStateSQDTAlkali", "RydbergStateSQDTAlkalineFJ", diff --git a/src/rydstate/angular/angular_ket.py b/src/rydstate/angular/angular_ket.py index a21b2ce..3ab7722 100644 --- a/src/rydstate/angular/angular_ket.py +++ b/src/rydstate/angular/angular_ket.py @@ -23,6 +23,7 @@ from rydstate.species import SpeciesObject if TYPE_CHECKING: + import juliacall from typing_extensions import Self from rydstate.angular.angular_matrix_element import AngularMomentumQuantumNumbers, AngularOperatorType @@ -738,6 +739,17 @@ def sanity_check(self, msgs: list[str] | None = None) -> None: super().sanity_check(msgs) +def julia_qn_to_dict(qn: juliacall.AnyValue) -> dict[str, float]: + """Convert MQDT Julia quantum numbers to dict object.""" + if "fjQuantumNumbers" in str(qn): + return dict(s_c=qn.sc, l_c=qn.lc, j_c=qn.Jc, f_c=qn.Fc, l_r=qn.lr, j_r=qn.Jr, f_tot=qn.F) # noqa: C408 + if "jjQuantumNumbers" in str(qn): + return dict(s_c=qn.sc, l_c=qn.lc, j_c=qn.Jc, l_r=qn.lr, j_r=qn.Jr, j_tot=qn.J, f_tot=qn.F) # noqa: C408 + if "lsQuantumNumbers" in str(qn): + return dict(s_c=qn.sc, s_tot=qn.S, l_c=qn.lc, l_r=qn.lr, l_tot=qn.L, j_tot=qn.J, f_tot=qn.F) # noqa: C408 + raise ValueError(f"Unknown MQDT Julia quantum numbers {qn!s}.") + + def quantum_numbers_to_angular_ket( species: str | SpeciesObject, s_c: float | None = None, diff --git a/src/rydstate/angular/angular_state.py b/src/rydstate/angular/angular_state.py index e3978fb..e315a94 100644 --- a/src/rydstate/angular/angular_state.py +++ b/src/rydstate/angular/angular_state.py @@ -34,6 +34,7 @@ def __init__( ) -> None: self.coefficients = np.array(coefficients) self.kets = kets + self._warn_if_not_normalized = warn_if_not_normalized if len(coefficients) != len(kets): raise ValueError("Length of coefficients and kets must be the same.") @@ -98,7 +99,8 @@ def to(self, coupling_scheme: CouplingScheme) -> AngularState[Any]: else: kets.append(scheme_ket) coefficients.append(coeff * scheme_coeff) - return AngularState(coefficients, kets, warn_if_not_normalized=abs(self.norm - 1) < 1e-10) + warn_if_not_normalized = self._warn_if_not_normalized and (abs(self.norm - 1) < 1e-10) + return AngularState(coefficients, kets, warn_if_not_normalized=warn_if_not_normalized) def calc_exp_qn(self, q: AngularMomentumQuantumNumbers) -> float: """Calculate the expectation value of a quantum number q. diff --git a/src/rydstate/basis/__init__.py b/src/rydstate/basis/__init__.py index 7f15364..93ef96c 100644 --- a/src/rydstate/basis/__init__.py +++ b/src/rydstate/basis/__init__.py @@ -1,4 +1,12 @@ from rydstate.basis.basis_base import BasisBase +from rydstate.basis.basis_mqdt import BasisMQDT from rydstate.basis.basis_sqdt import BasisSQDTAlkali, BasisSQDTAlkalineFJ, BasisSQDTAlkalineJJ, BasisSQDTAlkalineLS -__all__ = ["BasisBase", "BasisSQDTAlkali", "BasisSQDTAlkalineFJ", "BasisSQDTAlkalineJJ", "BasisSQDTAlkalineLS"] +__all__ = [ + "BasisBase", + "BasisMQDT", + "BasisSQDTAlkali", + "BasisSQDTAlkalineFJ", + "BasisSQDTAlkalineJJ", + "BasisSQDTAlkalineLS", +] diff --git a/src/rydstate/basis/basis_mqdt.py b/src/rydstate/basis/basis_mqdt.py new file mode 100644 index 0000000..b73b8db --- /dev/null +++ b/src/rydstate/basis/basis_mqdt.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any + +from rydstate.angular.angular_ket import julia_qn_to_dict +from rydstate.basis.basis_base import BasisBase +from rydstate.rydberg.rydberg_mqdt import RydbergStateMQDT +from rydstate.rydberg.rydberg_sqdt import RydbergStateSQDT + +if TYPE_CHECKING: + from rydstate.species import SpeciesObject + +logger = logging.getLogger(__name__) + +try: + USE_JULIACALL = True + from juliacall import ( + JuliaError, + Main as jl, # noqa: N813 + convert, + ) +except ImportError: + USE_JULIACALL = False + + +if USE_JULIACALL: + try: + jl.seval("using MQDT") + jl.seval("using CGcoefficient") + except JuliaError: + logger.exception("Failed to load Julia MQDT or CGcoefficient package") + USE_JULIACALL = False + +FMODEL_MAX_L = {"Sr87": 2, "Sr88": 2, "Yb171": 4, "Yb173": 1, "Yb174": 4} + + +class BasisMQDT(BasisBase[RydbergStateMQDT[Any]]): + def __init__( + self, + species: str | SpeciesObject, + n_min: int = 0, + n_max: int | None = None, + *, + skip_high_l: bool = True, + model_names: list[str] | None = None, + ) -> None: + super().__init__(species) + + if not USE_JULIACALL: + raise ImportError("JuliaCall or the MQDT Julia package is not available.") + + try: + self.jl_species = getattr(jl.MQDT, self.species.name) + parameters = self.jl_species.PARA + except AttributeError as e: + raise ValueError(f"Species '{species}' is not supported in the MQDT Julia package.") from e + + # TODO use n_min and n_max of the different models + + if n_max is None: + raise ValueError("n_max must be given") + + # initialize Wigner symbol calculation + if skip_high_l: + jl.CGcoefficient.wigner_init_float(5, "Jmax", 9) + else: + jl.CGcoefficient.wigner_init_float(n_max - 1, "Jmax", 9) + + logger.debug("Calculating low l MQDT states...") + + jl_species_attr_names = [str(name) for name in jl.seval(f"names(MQDT.{self.species.name}, all=true)")] + self.models = {name: getattr(self.jl_species, name) for name in jl_species_attr_names} + self.models = {k: v for k, v in self.models.items() if str(v).startswith("fModel")} + if model_names is not None: + self.models = {k: v for k, v in self.models.items() if k in model_names} + + if skip_high_l: + logger.debug("Skipping high l states.") + else: + logger.debug("Calculating high l SQDT states...") + l_start = FMODEL_MAX_L[self.species.name] + 1 + high_l_models = { + f"high_l_{l_ryd}": jl.single_channel_models(species, l_ryd, parameters) + for l_ryd in range(l_start, n_max) + } + self.models.update(high_l_models) + + model_names = list(self.models.keys()) + jl_states = {name: jl.eigenstates(n_min, n_max, model, parameters) for name, model in self.models.items()} + _models_vector = convert(jl.Vector, [self.models[name] for name in model_names]) + _jl_states_vector = convert(jl.Vector, [jl_states[name] for name in model_names]) + jl_basis = jl.basisarray(_jl_states_vector, _models_vector) + + logger.debug("Generated state table with %d states", len(jl_basis.states)) + + self.states = [] + for jl_state in jl_basis.states: + coeffs = jl_state.coeff + nus = jl_state.nu + nu_energy = jl_state.energy + qns = jl_state.channels.i + qns = [julia_qn_to_dict(qn) for qn in qns] + + sqdt_states = [RydbergStateSQDT(species, nu=nu, **qn) for nu, qn in zip(nus, qns)] + # check angular and radial are created correctly + [(s.angular, s.radial) for s in sqdt_states] + + mqdt_state = RydbergStateMQDT(coeffs, sqdt_states, nu_energy=nu_energy, warn_if_not_normalized=False) + self.states.append(mqdt_state) diff --git a/src/rydstate/rydberg/__init__.py b/src/rydstate/rydberg/__init__.py index f7a895b..b13d7a7 100644 --- a/src/rydstate/rydberg/__init__.py +++ b/src/rydstate/rydberg/__init__.py @@ -1,4 +1,5 @@ from rydstate.rydberg.rydberg_base import RydbergStateBase +from rydstate.rydberg.rydberg_mqdt import RydbergStateMQDT from rydstate.rydberg.rydberg_sqdt import ( RydbergStateSQDT, RydbergStateSQDTAlkali, @@ -9,6 +10,7 @@ __all__ = [ "RydbergStateBase", + "RydbergStateMQDT", "RydbergStateSQDT", "RydbergStateSQDTAlkali", "RydbergStateSQDTAlkalineFJ", diff --git a/src/rydstate/rydberg/rydberg_mqdt.py b/src/rydstate/rydberg/rydberg_mqdt.py new file mode 100644 index 0000000..48b03e4 --- /dev/null +++ b/src/rydstate/rydberg/rydberg_mqdt.py @@ -0,0 +1,121 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload + +import numpy as np + +from rydstate.angular import AngularState +from rydstate.rydberg.rydberg_base import RydbergStateBase +from rydstate.rydberg.rydberg_sqdt import RydbergStateSQDT + +if TYPE_CHECKING: + from collections.abc import Iterator, Sequence + + from rydstate.units import MatrixElementOperator, PintFloat + + +logger = logging.getLogger(__name__) + + +_RydbergState = TypeVar("_RydbergState", bound=RydbergStateSQDT) + + +class RydbergStateMQDT(RydbergStateBase, Generic[_RydbergState]): + angular: AngularState[Any] + """Return the angular part of the MQDT state as an AngularState.""" + + def __init__( + self, + coefficients: Sequence[float], + sqdt_states: Sequence[_RydbergState], + *, + nu_energy: float | None = None, + warn_if_not_normalized: bool = True, + ) -> None: + self.coefficients = np.array(coefficients) + self.sqdt_states = sqdt_states + self.nu_energy = nu_energy + self.angular = AngularState(self.coefficients.tolist(), [ket.angular for ket in sqdt_states]) + + if len(coefficients) != len(sqdt_states): + raise ValueError("Length of coefficients and sqdt_states must be the same.") + if not all(type(sqdt_state) is type(sqdt_states[0]) for sqdt_state in sqdt_states): + raise ValueError("All sqdt_states must be of the same type.") + if len(set(sqdt_states)) != len(sqdt_states): + raise ValueError("RydbergStateMQDT initialized with duplicate sqdt_states.") + if abs(self.norm - 1) > 1e-10 and warn_if_not_normalized: + logger.warning( + "RydbergStateMQDT initialized with non-normalized coefficients " + "(norm=%s, coefficients=%s, sqdt_states=%s)", + self.norm, + coefficients, + sqdt_states, + ) + if self.norm > 1: + self.coefficients /= self.norm + + def __iter__(self) -> Iterator[tuple[float, _RydbergState]]: + return zip(self.coefficients, self.sqdt_states).__iter__() + + def __repr__(self) -> str: + terms = [f"{coeff}*{sqdt_state!r}" for coeff, sqdt_state in self] + return f"{self.__class__.__name__}({', '.join(terms)})" + + def __str__(self) -> str: + terms = [f"{coeff}*{sqdt_state!s}" for coeff, sqdt_state in self] + return f"{', '.join(terms)}" + + @property + def norm(self) -> float: + """Return the norm of the state (should be 1).""" + return np.linalg.norm(self.coefficients) # type: ignore [return-value] + + def calc_reduced_overlap(self, other: RydbergStateBase) -> float: + """Calculate the reduced overlap (ignoring the magnetic quantum number m).""" + if isinstance(other, RydbergStateSQDT): + other = other.to_mqdt() + + if isinstance(other, RydbergStateMQDT): + ov = 0 + for coeff1, sqdt1 in self: + for coeff2, sqdt2 in other: + ov += np.conjugate(coeff1) * coeff2 * sqdt1.calc_reduced_overlap(sqdt2) + return ov + + raise NotImplementedError(f"calc_reduced_overlap not implemented for {type(self)=}, {type(other)=}") + + @overload # type: ignore [override] + def calc_reduced_matrix_element( + self, other: RydbergStateBase, operator: MatrixElementOperator, unit: None = None + ) -> PintFloat: ... + + @overload + def calc_reduced_matrix_element( + self, other: RydbergStateBase, operator: MatrixElementOperator, unit: str + ) -> float: ... + + def calc_reduced_matrix_element( + self, other: RydbergStateBase, operator: MatrixElementOperator, unit: str | None = None + ) -> PintFloat | float: + r"""Calculate the reduced angular matrix element. + + This means, calculate the following matrix element: + + .. math:: + \left\langle self || \hat{O}^{(\kappa)} || other \right\rangle + + """ + if isinstance(other, RydbergStateSQDT): + other = other.to_mqdt() + + if isinstance(other, RydbergStateMQDT): + value = 0 + for coeff1, sqdt1 in self: + for coeff2, sqdt2 in other: + value += ( + np.conjugate(coeff1) * coeff2 * sqdt1.calc_reduced_matrix_element(sqdt2, operator, unit=unit) + ) + return value + + raise NotImplementedError(f"calc_reduced_overlap not implemented for {type(self)=}, {type(other)=}") diff --git a/src/rydstate/rydberg/rydberg_sqdt.py b/src/rydstate/rydberg/rydberg_sqdt.py index 744b582..441f14c 100644 --- a/src/rydstate/rydberg/rydberg_sqdt.py +++ b/src/rydstate/rydberg/rydberg_sqdt.py @@ -3,7 +3,7 @@ import logging import math from functools import cached_property -from typing import TYPE_CHECKING, overload +from typing import TYPE_CHECKING, Any, overload import numpy as np @@ -15,6 +15,7 @@ from rydstate.units import BaseQuantities, MatrixElementOperatorRanks, ureg if TYPE_CHECKING: + from rydstate import RydbergStateMQDT from rydstate.angular.angular_ket import AngularKetBase, AngularKetFJ, AngularKetJJ, AngularKetLS from rydstate.units import MatrixElementOperator, PintFloat @@ -182,10 +183,16 @@ def get_energy(self, unit: str | None = None) -> PintFloat | float: return energy return energy.to(unit, "spectroscopy").magnitude + def to_mqdt(self) -> RydbergStateMQDT[Any]: + """Convert to a trivial RydbergMQDT state with only one contribution with coefficient 1.""" + from rydstate import RydbergStateMQDT # noqa: PLC0415 + + return RydbergStateMQDT([1], [self]) + def calc_reduced_overlap(self, other: RydbergStateBase) -> float: """Calculate the reduced overlap (ignoring the magnetic quantum number m).""" if not isinstance(other, RydbergStateSQDT): - raise NotImplementedError("Reduced overlap only implemented between RydbergStateSQDT states.") + return self.to_mqdt().calc_reduced_overlap(other) radial_overlap = self.radial.calc_overlap(other.radial) angular_overlap = self.angular.calc_reduced_overlap(other.angular) @@ -226,7 +233,7 @@ def calc_reduced_matrix_element( """ if not isinstance(other, RydbergStateSQDT): - raise NotImplementedError("Reduced matrix element only implemented between RydbergStateSQDT states.") + return self.to_mqdt().calc_reduced_matrix_element(other, operator, unit=unit) if operator not in MatrixElementOperatorRanks: raise ValueError( From aeb212572cf0b3d5f6064635ca06379644527c5c Mon Sep 17 00:00:00 2001 From: johannes-moegerle Date: Wed, 7 Jan 2026 21:05:29 +0100 Subject: [PATCH 2/6] update MQDT improved fModel --- src/rydstate/basis/basis_mqdt.py | 80 ++++++++++++++++---------------- 1 file changed, 41 insertions(+), 39 deletions(-) diff --git a/src/rydstate/basis/basis_mqdt.py b/src/rydstate/basis/basis_mqdt.py index b73b8db..f32b1d4 100644 --- a/src/rydstate/basis/basis_mqdt.py +++ b/src/rydstate/basis/basis_mqdt.py @@ -3,12 +3,20 @@ import logging from typing import TYPE_CHECKING, Any +import numpy as np + from rydstate.angular.angular_ket import julia_qn_to_dict from rydstate.basis.basis_base import BasisBase from rydstate.rydberg.rydberg_mqdt import RydbergStateMQDT from rydstate.rydberg.rydberg_sqdt import RydbergStateSQDT if TYPE_CHECKING: + from juliacall import ( + JuliaError, + Main as jl, # noqa: N813 + convert, + ) + from rydstate.species import SpeciesObject logger = logging.getLogger(__name__) @@ -32,8 +40,6 @@ logger.exception("Failed to load Julia MQDT or CGcoefficient package") USE_JULIACALL = False -FMODEL_MAX_L = {"Sr87": 2, "Sr88": 2, "Yb171": 4, "Yb173": 1, "Yb174": 4} - class BasisMQDT(BasisBase[RydbergStateMQDT[Any]]): def __init__( @@ -43,54 +49,50 @@ def __init__( n_max: int | None = None, *, skip_high_l: bool = True, - model_names: list[str] | None = None, ) -> None: super().__init__(species) - if not USE_JULIACALL: - raise ImportError("JuliaCall or the MQDT Julia package is not available.") - - try: - self.jl_species = getattr(jl.MQDT, self.species.name) - parameters = self.jl_species.PARA - except AttributeError as e: - raise ValueError(f"Species '{species}' is not supported in the MQDT Julia package.") from e - - # TODO use n_min and n_max of the different models - if n_max is None: raise ValueError("n_max must be given") + if not USE_JULIACALL: + raise ImportError("JuliaCall or the MQDT Julia package is not available.") + # initialize Wigner symbol calculation if skip_high_l: - jl.CGcoefficient.wigner_init_float(5, "Jmax", 9) + jl.CGcoefficient.wigner_init_float(10, "Jmax", 9) else: jl.CGcoefficient.wigner_init_float(n_max - 1, "Jmax", 9) - logger.debug("Calculating low l MQDT states...") - - jl_species_attr_names = [str(name) for name in jl.seval(f"names(MQDT.{self.species.name}, all=true)")] - self.models = {name: getattr(self.jl_species, name) for name in jl_species_attr_names} - self.models = {k: v for k, v in self.models.items() if str(v).startswith("fModel")} - if model_names is not None: - self.models = {k: v for k, v in self.models.items() if k in model_names} - - if skip_high_l: - logger.debug("Skipping high l states.") - else: - logger.debug("Calculating high l SQDT states...") - l_start = FMODEL_MAX_L[self.species.name] + 1 - high_l_models = { - f"high_l_{l_ryd}": jl.single_channel_models(species, l_ryd, parameters) - for l_ryd in range(l_start, n_max) - } - self.models.update(high_l_models) - - model_names = list(self.models.keys()) - jl_states = {name: jl.eigenstates(n_min, n_max, model, parameters) for name, model in self.models.items()} - _models_vector = convert(jl.Vector, [self.models[name] for name in model_names]) - _jl_states_vector = convert(jl.Vector, [jl_states[name] for name in model_names]) - jl_basis = jl.basisarray(_jl_states_vector, _models_vector) + jl_species = jl.Symbol(self.species.name) + parameters = jl.MQDT.get_parameters(jl_species) + + self.models = [] + i_c = self.species.i_c if self.species.i_c is not None else 0 + for l in range(n_max): + jtot_min = min(l, abs(l - 1)) + jtot_max = l + 1 + for f_tot in np.arange(abs(jtot_min - i_c), jtot_max + i_c + 1): + models = jl.MQDT.get_fmodels(jl_species, l, f_tot) + self.models.extend(models) + + n_min_high_l = 25 + + logger.debug("Calculating MQDT states...") + jl_states = [] + for model in self.models: + _n_min = n_min + if model.name.startswith("SQDT"): + if skip_high_l: + continue + _n_min = n_min_high_l + + logger.debug(f"{model.name}:") + states = jl.MQDT.eigenstates(_n_min, n_max, model, parameters) + jl_states.append(states) + logger.debug(f" found nu_min={min(states.n)}, nu_max={max(states.n)}, total states={len(states.n)}") + + jl_basis = jl.basisarray(convert(jl.Vector, jl_states), convert(jl.Vector, self.models)) logger.debug("Generated state table with %d states", len(jl_basis.states)) From e4317cbe4028dd23ea6daca1317a1b781c1a432a Mon Sep 17 00:00:00 2001 From: johannes-moegerle Date: Wed, 7 Jan 2026 22:31:49 +0100 Subject: [PATCH 3/6] start adding AngularKetDummy --- src/rydstate/angular/__init__.py | 3 +- src/rydstate/angular/angular_ket.py | 71 ++++++++++++++++++++++++++- src/rydstate/angular/angular_state.py | 36 +++++++++++--- 3 files changed, 99 insertions(+), 11 deletions(-) diff --git a/src/rydstate/angular/__init__.py b/src/rydstate/angular/__init__.py index 66f0142..3b41857 100644 --- a/src/rydstate/angular/__init__.py +++ b/src/rydstate/angular/__init__.py @@ -1,8 +1,9 @@ from rydstate.angular import utils -from rydstate.angular.angular_ket import AngularKetFJ, AngularKetJJ, AngularKetLS +from rydstate.angular.angular_ket import AngularKetDummy, AngularKetFJ, AngularKetJJ, AngularKetLS from rydstate.angular.angular_state import AngularState __all__ = [ + "AngularKetDummy", "AngularKetFJ", "AngularKetJJ", "AngularKetLS", diff --git a/src/rydstate/angular/angular_ket.py b/src/rydstate/angular/angular_ket.py index 3ab7722..a8525fa 100644 --- a/src/rydstate/angular/angular_ket.py +++ b/src/rydstate/angular/angular_ket.py @@ -24,14 +24,14 @@ if TYPE_CHECKING: import juliacall - from typing_extensions import Self + from typing_extensions import Never, Self from rydstate.angular.angular_matrix_element import AngularMomentumQuantumNumbers, AngularOperatorType from rydstate.angular.angular_state import AngularState logger = logging.getLogger(__name__) -CouplingScheme = Literal["LS", "JJ", "FJ"] +CouplingScheme = Literal["LS", "JJ", "FJ", "Dummy"] class InvalidQuantumNumbersError(ValueError): @@ -224,6 +224,9 @@ def to_state(self, coupling_scheme: Literal["JJ"]) -> AngularState[AngularKetJJ] @overload def to_state(self, coupling_scheme: Literal["FJ"]) -> AngularState[AngularKetFJ]: ... + @overload + def to_state(self, coupling_scheme: Literal["Dummy"]) -> Never: ... + @overload def to_state(self: Self) -> AngularState[Self]: ... @@ -372,6 +375,10 @@ def calc_reduced_overlap(self, other: AngularKetBase) -> float: kets = [self, other] + # Dummy overlaps + if any(isinstance(s, AngularKetDummy) for s in kets): + return int(self == other) + # JJ - FJ overlaps if any(isinstance(s, AngularKetJJ) for s in kets) and any(isinstance(s, AngularKetFJ) for s in kets): jj = next(s for s in kets if isinstance(s, AngularKetJJ)) @@ -414,6 +421,10 @@ def calc_reduced_matrix_element( # noqa: C901 if not is_angular_operator_type(operator): raise NotImplementedError(f"calc_reduced_matrix_element is not implemented for operator {operator}.") + # Dummy matrix elements + if any(isinstance(s, AngularKetDummy) for s in [self, other]): + return 0 + if type(self) is not type(other): return self.to_state().calc_reduced_matrix_element(other.to_state(), operator, kappa) if is_angular_momentum_quantum_number(operator) and operator not in self.quantum_number_names: @@ -739,6 +750,62 @@ def sanity_check(self, msgs: list[str] | None = None) -> None: super().sanity_check(msgs) +class AngularKetDummy(AngularKetBase): + """Dummy spin ket for unknown quantum numbers.""" + + __slots__ = ("name",) + quantum_number_names: ClassVar = ("f_tot",) + coupled_quantum_numbers: ClassVar = {} + coupling_scheme = "Dummy" + + name: str + """Name of the dummy ket.""" + + def __init__( + self, + name: str, + f_tot: float, + m: float | None = None, + ) -> None: + """Initialize the Spin ket.""" + self.name = name + + self.f_tot = f_tot + self.m = m + + super()._post_init() + + def sanity_check(self, msgs: list[str] | None = None) -> None: + """Check that the quantum numbers are valid.""" + msgs = msgs if msgs is not None else [] + + if self.m is not None and not -self.f_tot <= self.m <= self.f_tot: + msgs.append(f"m must be between -f_tot and f_tot, but {self.f_tot=}, {self.m=}") + + if msgs: + msg = "\n ".join(msgs) + raise InvalidQuantumNumbersError(self, msg) + + def __repr__(self) -> str: + args = f"{self.name}, f_tot={self.f_tot}" + if self.m is not None: + args += f", m={self.m}" + return f"{self.__class__.__name__}({args})" + + def __str__(self) -> str: + return self.__repr__().replace("AngularKet", "") + + def __eq__(self, other: object) -> bool: + if not isinstance(other, AngularKetBase): + raise NotImplementedError(f"Cannot compare {self!r} with {other!r}.") + if not isinstance(other, AngularKetDummy): + return False + return self.name == other.name and self.f_tot == other.f_tot and self.m == other.m + + def __hash__(self) -> int: + return hash((self.name, self.f_tot, self.m)) + + def julia_qn_to_dict(qn: juliacall.AnyValue) -> dict[str, float]: """Convert MQDT Julia quantum numbers to dict object.""" if "fjQuantumNumbers" in str(qn): diff --git a/src/rydstate/angular/angular_state.py b/src/rydstate/angular/angular_state.py index e315a94..01b5700 100644 --- a/src/rydstate/angular/angular_state.py +++ b/src/rydstate/angular/angular_state.py @@ -8,6 +8,7 @@ from rydstate.angular.angular_ket import ( AngularKetBase, + AngularKetDummy, AngularKetFJ, AngularKetJJ, AngularKetLS, @@ -17,7 +18,7 @@ if TYPE_CHECKING: from collections.abc import Iterator, Sequence - from typing_extensions import Self + from typing_extensions import Never, Self from rydstate.angular.angular_ket import CouplingScheme from rydstate.angular.angular_matrix_element import AngularMomentumQuantumNumbers, AngularOperatorType @@ -32,22 +33,28 @@ class AngularState(Generic[_AngularKet]): def __init__( self, coefficients: Sequence[float], kets: Sequence[_AngularKet], *, warn_if_not_normalized: bool = True ) -> None: - self.coefficients = np.array(coefficients) - self.kets = kets + """Initialize an angular state as a linear combination of angular kets. + + All kets must be of the same type (coupling scheme), and no duplicate kets are allowed. + Dummy kets (AngularKetDummy) are ignored in the state representation, + however adding them is recommended for normalization purposes. + """ + self._coefficients = np.array(coefficients) + self._kets = kets self._warn_if_not_normalized = warn_if_not_normalized if len(coefficients) != len(kets): raise ValueError("Length of coefficients and kets must be the same.") if len(kets) == 0: raise ValueError("At least one ket must be provided.") - if not all(type(ket) is type(kets[0]) for ket in kets): + if not all(type(ket) is type(self.kets[0]) for ket in self.kets): raise ValueError("All kets must be of the same type.") - if len(set(kets)) != len(kets): - raise ValueError("AngularState initialized with duplicate kets.") + if len(set(self.kets)) != len(self.kets): + raise ValueError("AngularState initialized with duplicate kets: %s", self.kets) if abs(self.norm - 1) > 1e-10 and warn_if_not_normalized: logger.warning("AngularState initialized with non-normalized coefficients: %s, %s", coefficients, kets) if self.norm > 1: - self.coefficients /= self.norm + self._coefficients /= self.norm def __iter__(self) -> Iterator[tuple[float, _AngularKet]]: return zip(self.coefficients, self.kets).__iter__() @@ -60,6 +67,16 @@ def __str__(self) -> str: terms = [f"{coeff}*{ket!s}" for coeff, ket in self] return f"{', '.join(terms)}" + @property + def kets(self) -> list[_AngularKet]: + return [ket for ket in self._kets if not isinstance(ket, AngularKetDummy)] + + @property + def coefficients(self) -> np.ndarray: + return np.array( + [coeff for coeff, ket in zip(self._coefficients, self._kets) if not isinstance(ket, AngularKetDummy)] + ) + @property def coupling_scheme(self) -> CouplingScheme: """Return the coupling scheme of the state.""" @@ -68,7 +85,7 @@ def coupling_scheme(self) -> CouplingScheme: @property def norm(self) -> float: """Return the norm of the state (should be 1).""" - return np.linalg.norm(self.coefficients) # type: ignore [return-value] + return np.linalg.norm(self._coefficients) # type: ignore [return-value] @overload def to(self, coupling_scheme: Literal["LS"]) -> AngularState[AngularKetLS]: ... @@ -79,6 +96,9 @@ def to(self, coupling_scheme: Literal["JJ"]) -> AngularState[AngularKetJJ]: ... @overload def to(self, coupling_scheme: Literal["FJ"]) -> AngularState[AngularKetFJ]: ... + @overload + def to(self, coupling_scheme: Literal["Dummy"]) -> Never: ... + def to(self, coupling_scheme: CouplingScheme) -> AngularState[Any]: """Convert to specified coupling scheme. From 75fae8b6274edcb041ba6250022bdea0945395e2 Mon Sep 17 00:00:00 2001 From: johannes-moegerle Date: Wed, 7 Jan 2026 22:33:57 +0100 Subject: [PATCH 4/6] basis mqdt updates/fixes --- src/rydstate/basis/basis_mqdt.py | 50 ++++++++++++++++++++++++-------- 1 file changed, 38 insertions(+), 12 deletions(-) diff --git a/src/rydstate/basis/basis_mqdt.py b/src/rydstate/basis/basis_mqdt.py index f32b1d4..097b9a4 100644 --- a/src/rydstate/basis/basis_mqdt.py +++ b/src/rydstate/basis/basis_mqdt.py @@ -5,7 +5,11 @@ import numpy as np -from rydstate.angular.angular_ket import julia_qn_to_dict +from rydstate.angular.angular_ket import ( + AngularKetDummy, + julia_qn_to_dict, + quantum_numbers_to_angular_ket, +) from rydstate.basis.basis_base import BasisBase from rydstate.rydberg.rydberg_mqdt import RydbergStateMQDT from rydstate.rydberg.rydberg_sqdt import RydbergStateSQDT @@ -17,6 +21,7 @@ convert, ) + from rydstate.angular.angular_ket import AngularKetBase from rydstate.species import SpeciesObject logger = logging.getLogger(__name__) @@ -42,7 +47,7 @@ class BasisMQDT(BasisBase[RydbergStateMQDT[Any]]): - def __init__( + def __init__( # noqa: PLR0915, C901, PLR0912 self, species: str | SpeciesObject, n_min: int = 0, @@ -73,7 +78,7 @@ def __init__( jtot_min = min(l, abs(l - 1)) jtot_max = l + 1 for f_tot in np.arange(abs(jtot_min - i_c), jtot_max + i_c + 1): - models = jl.MQDT.get_fmodels(jl_species, l, f_tot) + models = jl.MQDT.get_fmodels(jl_species, l, float(f_tot)) self.models.extend(models) n_min_high_l = 25 @@ -87,10 +92,14 @@ def __init__( continue _n_min = n_min_high_l - logger.debug(f"{model.name}:") + logger.debug("model name: %s", model.name) states = jl.MQDT.eigenstates(_n_min, n_max, model, parameters) jl_states.append(states) - logger.debug(f" found nu_min={min(states.n)}, nu_max={max(states.n)}, total states={len(states.n)}") + + if len(states.n) == 0: + logger.debug(" no states found") + else: + logger.debug(" nu_min=%s, nu_max=%s, total states=%d", min(states.n), max(states.n), len(states.n)) jl_basis = jl.basisarray(convert(jl.Vector, jl_states), convert(jl.Vector, self.models)) @@ -98,15 +107,32 @@ def __init__( self.states = [] for jl_state in jl_basis.states: - coeffs = jl_state.coeff - nus = jl_state.nu + nus = jl_state.nu_list nu_energy = jl_state.energy - qns = jl_state.channels.i - qns = [julia_qn_to_dict(qn) for qn in qns] + angular_kets: list[AngularKetBase] = [] + iqn = 0 + model = jl_state.model + for i, core in enumerate(model.core): + if not core: + name = model.name + model.terms[i] + angular_kets.append(AngularKetDummy(name, f_tot=model.f_tot)) + continue + + qn = julia_qn_to_dict(jl_state.channels.i[iqn]) + try: + angular_kets.append(quantum_numbers_to_angular_ket(species=self.species, **qn)) # type: ignore[arg-type] + except ValueError: + name = model.name + model.terms[i] + angular_kets.append(AngularKetDummy(name, f_tot=model.f_tot)) + + iqn += 1 - sqdt_states = [RydbergStateSQDT(species, nu=nu, **qn) for nu, qn in zip(nus, qns)] + sqdt_states = [ + RydbergStateSQDT.from_angular_ket(species, angular_ket, nu=nu) + for nu, angular_ket in zip(nus, angular_kets) + ] # check angular and radial are created correctly - [(s.angular, s.radial) for s in sqdt_states] + assert len([(s.angular, s.radial) for s in sqdt_states]) > 0 - mqdt_state = RydbergStateMQDT(coeffs, sqdt_states, nu_energy=nu_energy, warn_if_not_normalized=False) + mqdt_state = RydbergStateMQDT(jl_state.coefficients, sqdt_states, nu_energy=nu_energy) self.states.append(mqdt_state) From f146a92228062118d5088cd066e0767cf7ff624a Mon Sep 17 00:00:00 2001 From: johannes-moegerle Date: Mon, 12 Jan 2026 09:58:10 +0100 Subject: [PATCH 5/6] split up angular ket --- src/rydstate/angular/__init__.py | 3 +- src/rydstate/angular/angular_ket.py | 141 +--------------------- src/rydstate/angular/angular_ket_dummy.py | 86 +++++++++++++ src/rydstate/angular/angular_state.py | 4 +- src/rydstate/angular/utils.py | 84 +++++++++++++ src/rydstate/basis/basis_mqdt.py | 7 +- src/rydstate/rydberg/rydberg_sqdt.py | 2 +- tests/test_angular_matrix_elements.py | 3 +- 8 files changed, 181 insertions(+), 149 deletions(-) create mode 100644 src/rydstate/angular/angular_ket_dummy.py diff --git a/src/rydstate/angular/__init__.py b/src/rydstate/angular/__init__.py index 3b41857..4394d0e 100644 --- a/src/rydstate/angular/__init__.py +++ b/src/rydstate/angular/__init__.py @@ -1,5 +1,6 @@ from rydstate.angular import utils -from rydstate.angular.angular_ket import AngularKetDummy, AngularKetFJ, AngularKetJJ, AngularKetLS +from rydstate.angular.angular_ket import AngularKetFJ, AngularKetJJ, AngularKetLS +from rydstate.angular.angular_ket_dummy import AngularKetDummy from rydstate.angular.angular_state import AngularState __all__ = [ diff --git a/src/rydstate/angular/angular_ket.py b/src/rydstate/angular/angular_ket.py index a8525fa..7492c25 100644 --- a/src/rydstate/angular/angular_ket.py +++ b/src/rydstate/angular/angular_ket.py @@ -1,6 +1,5 @@ from __future__ import annotations -import contextlib import logging from abc import ABC from typing import TYPE_CHECKING, Any, ClassVar, Literal, overload @@ -14,6 +13,7 @@ is_angular_operator_type, ) from rydstate.angular.utils import ( + InvalidQuantumNumbersError, check_spin_addition_rule, get_possible_quantum_number_values, minus_one_pow, @@ -23,24 +23,14 @@ from rydstate.species import SpeciesObject if TYPE_CHECKING: - import juliacall from typing_extensions import Never, Self from rydstate.angular.angular_matrix_element import AngularMomentumQuantumNumbers, AngularOperatorType from rydstate.angular.angular_state import AngularState + from rydstate.angular.utils import CouplingScheme logger = logging.getLogger(__name__) -CouplingScheme = Literal["LS", "JJ", "FJ", "Dummy"] - - -class InvalidQuantumNumbersError(ValueError): - def __init__(self, ket: AngularKetBase, msg: str = "") -> None: - _msg = f"Invalid quantum numbers for {ket!r}" - if len(msg) > 0: - _msg += f"\n {msg}" - super().__init__(_msg) - class AngularKetBase(ABC): """Base class for a angular ket (i.e. a simple canonical spin ketstate).""" @@ -375,10 +365,6 @@ def calc_reduced_overlap(self, other: AngularKetBase) -> float: kets = [self, other] - # Dummy overlaps - if any(isinstance(s, AngularKetDummy) for s in kets): - return int(self == other) - # JJ - FJ overlaps if any(isinstance(s, AngularKetJJ) for s in kets) and any(isinstance(s, AngularKetFJ) for s in kets): jj = next(s for s in kets if isinstance(s, AngularKetJJ)) @@ -421,10 +407,6 @@ def calc_reduced_matrix_element( # noqa: C901 if not is_angular_operator_type(operator): raise NotImplementedError(f"calc_reduced_matrix_element is not implemented for operator {operator}.") - # Dummy matrix elements - if any(isinstance(s, AngularKetDummy) for s in [self, other]): - return 0 - if type(self) is not type(other): return self.to_state().calc_reduced_matrix_element(other.to_state(), operator, kappa) if is_angular_momentum_quantum_number(operator) and operator not in self.quantum_number_names: @@ -748,122 +730,3 @@ def sanity_check(self, msgs: list[str] | None = None) -> None: msgs.append(f"{self.f_c=}, {self.j_r=}, {self.f_tot=} don't satisfy spin addition rule.") super().sanity_check(msgs) - - -class AngularKetDummy(AngularKetBase): - """Dummy spin ket for unknown quantum numbers.""" - - __slots__ = ("name",) - quantum_number_names: ClassVar = ("f_tot",) - coupled_quantum_numbers: ClassVar = {} - coupling_scheme = "Dummy" - - name: str - """Name of the dummy ket.""" - - def __init__( - self, - name: str, - f_tot: float, - m: float | None = None, - ) -> None: - """Initialize the Spin ket.""" - self.name = name - - self.f_tot = f_tot - self.m = m - - super()._post_init() - - def sanity_check(self, msgs: list[str] | None = None) -> None: - """Check that the quantum numbers are valid.""" - msgs = msgs if msgs is not None else [] - - if self.m is not None and not -self.f_tot <= self.m <= self.f_tot: - msgs.append(f"m must be between -f_tot and f_tot, but {self.f_tot=}, {self.m=}") - - if msgs: - msg = "\n ".join(msgs) - raise InvalidQuantumNumbersError(self, msg) - - def __repr__(self) -> str: - args = f"{self.name}, f_tot={self.f_tot}" - if self.m is not None: - args += f", m={self.m}" - return f"{self.__class__.__name__}({args})" - - def __str__(self) -> str: - return self.__repr__().replace("AngularKet", "") - - def __eq__(self, other: object) -> bool: - if not isinstance(other, AngularKetBase): - raise NotImplementedError(f"Cannot compare {self!r} with {other!r}.") - if not isinstance(other, AngularKetDummy): - return False - return self.name == other.name and self.f_tot == other.f_tot and self.m == other.m - - def __hash__(self) -> int: - return hash((self.name, self.f_tot, self.m)) - - -def julia_qn_to_dict(qn: juliacall.AnyValue) -> dict[str, float]: - """Convert MQDT Julia quantum numbers to dict object.""" - if "fjQuantumNumbers" in str(qn): - return dict(s_c=qn.sc, l_c=qn.lc, j_c=qn.Jc, f_c=qn.Fc, l_r=qn.lr, j_r=qn.Jr, f_tot=qn.F) # noqa: C408 - if "jjQuantumNumbers" in str(qn): - return dict(s_c=qn.sc, l_c=qn.lc, j_c=qn.Jc, l_r=qn.lr, j_r=qn.Jr, j_tot=qn.J, f_tot=qn.F) # noqa: C408 - if "lsQuantumNumbers" in str(qn): - return dict(s_c=qn.sc, s_tot=qn.S, l_c=qn.lc, l_r=qn.lr, l_tot=qn.L, j_tot=qn.J, f_tot=qn.F) # noqa: C408 - raise ValueError(f"Unknown MQDT Julia quantum numbers {qn!s}.") - - -def quantum_numbers_to_angular_ket( - species: str | SpeciesObject, - s_c: float | None = None, - l_c: int = 0, - j_c: float | None = None, - f_c: float | None = None, - s_r: float = 0.5, - l_r: int | None = None, - j_r: float | None = None, - s_tot: float | None = None, - l_tot: int | None = None, - j_tot: float | None = None, - f_tot: float | None = None, - m: float | None = None, -) -> AngularKetBase: - r"""Return an AngularKet object in the corresponding coupling scheme from the given quantum numbers. - - Args: - species: Atomic species. - s_c: Spin quantum number of the core electron (0 for Alkali, 0.5 for divalent atoms). - l_c: Orbital angular momentum quantum number of the core electron. - j_c: Total angular momentum quantum number of the core electron. - f_c: Total angular momentum quantum number of the core (core electron + nucleus). - s_r: Spin quantum number of the rydberg electron (always 0.5). - l_r: Orbital angular momentum quantum number of the rydberg electron. - j_r: Total angular momentum quantum number of the rydberg electron. - s_tot: Total spin quantum number of all electrons. - l_tot: Total orbital angular momentum quantum number of all electrons. - j_tot: Total angular momentum quantum number of all electrons. - f_tot: Total angular momentum quantum number of the atom (rydberg electron + core). - m: Total magnetic quantum number. - Optional, only needed for concrete angular matrix elements. - - """ - with contextlib.suppress(InvalidQuantumNumbersError, ValueError): - return AngularKetLS( - s_c=s_c, l_c=l_c, s_r=s_r, l_r=l_r, s_tot=s_tot, l_tot=l_tot, j_tot=j_tot, f_tot=f_tot, m=m, species=species - ) - - with contextlib.suppress(InvalidQuantumNumbersError, ValueError): - return AngularKetJJ( - s_c=s_c, l_c=l_c, j_c=j_c, s_r=s_r, l_r=l_r, j_r=j_r, j_tot=j_tot, f_tot=f_tot, m=m, species=species - ) - - with contextlib.suppress(InvalidQuantumNumbersError, ValueError): - return AngularKetFJ( - s_c=s_c, l_c=l_c, j_c=j_c, f_c=f_c, s_r=s_r, l_r=l_r, j_r=j_r, f_tot=f_tot, m=m, species=species - ) - - raise ValueError("Invalid combination of angular quantum numbers provided.") diff --git a/src/rydstate/angular/angular_ket_dummy.py b/src/rydstate/angular/angular_ket_dummy.py new file mode 100644 index 0000000..884cee9 --- /dev/null +++ b/src/rydstate/angular/angular_ket_dummy.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, ClassVar + +from rydstate.angular.angular_ket import AngularKetBase +from rydstate.angular.angular_matrix_element import is_angular_operator_type +from rydstate.angular.utils import InvalidQuantumNumbersError + +if TYPE_CHECKING: + from typing_extensions import Self + + from rydstate.angular.angular_matrix_element import AngularOperatorType + +logger = logging.getLogger(__name__) + + +class AngularKetDummy(AngularKetBase): + """Dummy spin ket for unknown quantum numbers.""" + + __slots__ = ("name",) + quantum_number_names: ClassVar = ("f_tot",) + coupled_quantum_numbers: ClassVar = {} + coupling_scheme = "Dummy" + + name: str + """Name of the dummy ket.""" + + def __init__( + self, + name: str, + f_tot: float, + m: float | None = None, + ) -> None: + """Initialize the Spin ket.""" + self.name = name + + self.f_tot = f_tot + self.m = m + + super()._post_init() + + def sanity_check(self, msgs: list[str] | None = None) -> None: + """Check that the quantum numbers are valid.""" + msgs = msgs if msgs is not None else [] + + if self.m is not None and not -self.f_tot <= self.m <= self.f_tot: + msgs.append(f"m must be between -f_tot and f_tot, but {self.f_tot=}, {self.m=}") + + if msgs: + msg = "\n ".join(msgs) + raise InvalidQuantumNumbersError(self, msg) + + def __repr__(self) -> str: + args = f"{self.name}, f_tot={self.f_tot}" + if self.m is not None: + args += f", m={self.m}" + return f"{self.__class__.__name__}({args})" + + def __str__(self) -> str: + return self.__repr__().replace("AngularKet", "") + + def __eq__(self, other: object) -> bool: + if not isinstance(other, AngularKetBase): + raise NotImplementedError(f"Cannot compare {self!r} with {other!r}.") + if not isinstance(other, AngularKetDummy): + return False + return self.name == other.name and self.f_tot == other.f_tot and self.m == other.m + + def __hash__(self) -> int: + return hash((self.name, self.f_tot, self.m)) + + def calc_reduced_overlap(self, other: AngularKetBase) -> float: + return int(self == other) + + def calc_reduced_matrix_element( + self: Self, + other: AngularKetBase, # noqa: ARG002 + operator: AngularOperatorType, + kappa: int, # noqa: ARG002 + ) -> float: + if not is_angular_operator_type(operator): + raise NotImplementedError(f"calc_reduced_matrix_element is not implemented for operator {operator}.") + + # ignore contributions from dummy kets + return 0 diff --git a/src/rydstate/angular/angular_state.py b/src/rydstate/angular/angular_state.py index 01b5700..0ae7542 100644 --- a/src/rydstate/angular/angular_state.py +++ b/src/rydstate/angular/angular_state.py @@ -8,11 +8,11 @@ from rydstate.angular.angular_ket import ( AngularKetBase, - AngularKetDummy, AngularKetFJ, AngularKetJJ, AngularKetLS, ) +from rydstate.angular.angular_ket_dummy import AngularKetDummy from rydstate.angular.angular_matrix_element import is_angular_momentum_quantum_number if TYPE_CHECKING: @@ -20,8 +20,8 @@ from typing_extensions import Never, Self - from rydstate.angular.angular_ket import CouplingScheme from rydstate.angular.angular_matrix_element import AngularMomentumQuantumNumbers, AngularOperatorType + from rydstate.angular.utils import CouplingScheme logger = logging.getLogger(__name__) diff --git a/src/rydstate/angular/utils.py b/src/rydstate/angular/utils.py index ecea008..0d73d07 100644 --- a/src/rydstate/angular/utils.py +++ b/src/rydstate/angular/utils.py @@ -1,7 +1,26 @@ from __future__ import annotations +import contextlib +from typing import TYPE_CHECKING, Literal + import numpy as np +if TYPE_CHECKING: + import juliacall + + from rydstate.angular.angular_ket import AngularKetBase + from rydstate.species.species_object import SpeciesObject + +CouplingScheme = Literal["LS", "JJ", "FJ", "Dummy"] + + +class InvalidQuantumNumbersError(ValueError): + def __init__(self, ket: AngularKetBase, msg: str = "") -> None: + _msg = f"Invalid quantum numbers for {ket!r}" + if len(msg) > 0: + _msg += f"\n {msg}" + super().__init__(_msg) + def minus_one_pow(n: float) -> int: """Calculate (-1)^n for an integer n and raise an error if n is not an integer.""" @@ -42,3 +61,68 @@ def get_possible_quantum_number_values(s_1: float, s_2: float, s_tot: float | No if s_tot is not None: return [s_tot] return [float(s) for s in np.arange(abs(s_1 - s_2), s_1 + s_2 + 1, 1)] + + +def julia_qn_to_dict(qn: juliacall.AnyValue) -> dict[str, float]: + """Convert MQDT Julia quantum numbers to dict object.""" + if "fjQuantumNumbers" in str(qn): + return dict(s_c=qn.sc, l_c=qn.lc, j_c=qn.Jc, f_c=qn.Fc, l_r=qn.lr, j_r=qn.Jr, f_tot=qn.F) # noqa: C408 + if "jjQuantumNumbers" in str(qn): + return dict(s_c=qn.sc, l_c=qn.lc, j_c=qn.Jc, l_r=qn.lr, j_r=qn.Jr, j_tot=qn.J, f_tot=qn.F) # noqa: C408 + if "lsQuantumNumbers" in str(qn): + return dict(s_c=qn.sc, s_tot=qn.S, l_c=qn.lc, l_r=qn.lr, l_tot=qn.L, j_tot=qn.J, f_tot=qn.F) # noqa: C408 + raise ValueError(f"Unknown MQDT Julia quantum numbers {qn!s}.") + + +def quantum_numbers_to_angular_ket( + species: str | SpeciesObject, + s_c: float | None = None, + l_c: int = 0, + j_c: float | None = None, + f_c: float | None = None, + s_r: float = 0.5, + l_r: int | None = None, + j_r: float | None = None, + s_tot: float | None = None, + l_tot: int | None = None, + j_tot: float | None = None, + f_tot: float | None = None, + m: float | None = None, +) -> AngularKetBase: + r"""Return an AngularKet object in the corresponding coupling scheme from the given quantum numbers. + + Args: + species: Atomic species. + s_c: Spin quantum number of the core electron (0 for Alkali, 0.5 for divalent atoms). + l_c: Orbital angular momentum quantum number of the core electron. + j_c: Total angular momentum quantum number of the core electron. + f_c: Total angular momentum quantum number of the core (core electron + nucleus). + s_r: Spin quantum number of the rydberg electron (always 0.5). + l_r: Orbital angular momentum quantum number of the rydberg electron. + j_r: Total angular momentum quantum number of the rydberg electron. + s_tot: Total spin quantum number of all electrons. + l_tot: Total orbital angular momentum quantum number of all electrons. + j_tot: Total angular momentum quantum number of all electrons. + f_tot: Total angular momentum quantum number of the atom (rydberg electron + core). + m: Total magnetic quantum number. + Optional, only needed for concrete angular matrix elements. + + """ + from rydstate.angular.angular_ket import AngularKetFJ, AngularKetJJ, AngularKetLS # noqa: PLC0415 + + with contextlib.suppress(InvalidQuantumNumbersError, ValueError): + return AngularKetLS( + s_c=s_c, l_c=l_c, s_r=s_r, l_r=l_r, s_tot=s_tot, l_tot=l_tot, j_tot=j_tot, f_tot=f_tot, m=m, species=species + ) + + with contextlib.suppress(InvalidQuantumNumbersError, ValueError): + return AngularKetJJ( + s_c=s_c, l_c=l_c, j_c=j_c, s_r=s_r, l_r=l_r, j_r=j_r, j_tot=j_tot, f_tot=f_tot, m=m, species=species + ) + + with contextlib.suppress(InvalidQuantumNumbersError, ValueError): + return AngularKetFJ( + s_c=s_c, l_c=l_c, j_c=j_c, f_c=f_c, s_r=s_r, l_r=l_r, j_r=j_r, f_tot=f_tot, m=m, species=species + ) + + raise ValueError("Invalid combination of angular quantum numbers provided.") diff --git a/src/rydstate/basis/basis_mqdt.py b/src/rydstate/basis/basis_mqdt.py index 097b9a4..7d068ae 100644 --- a/src/rydstate/basis/basis_mqdt.py +++ b/src/rydstate/basis/basis_mqdt.py @@ -5,11 +5,8 @@ import numpy as np -from rydstate.angular.angular_ket import ( - AngularKetDummy, - julia_qn_to_dict, - quantum_numbers_to_angular_ket, -) +from rydstate.angular.angular_ket_dummy import AngularKetDummy +from rydstate.angular.utils import julia_qn_to_dict, quantum_numbers_to_angular_ket from rydstate.basis.basis_base import BasisBase from rydstate.rydberg.rydberg_mqdt import RydbergStateMQDT from rydstate.rydberg.rydberg_sqdt import RydbergStateSQDT diff --git a/src/rydstate/rydberg/rydberg_sqdt.py b/src/rydstate/rydberg/rydberg_sqdt.py index 441f14c..9ea3b47 100644 --- a/src/rydstate/rydberg/rydberg_sqdt.py +++ b/src/rydstate/rydberg/rydberg_sqdt.py @@ -7,7 +7,7 @@ import numpy as np -from rydstate.angular.angular_ket import quantum_numbers_to_angular_ket +from rydstate.angular.utils import quantum_numbers_to_angular_ket from rydstate.radial import RadialKet from rydstate.rydberg.rydberg_base import RydbergStateBase from rydstate.species import SpeciesObject diff --git a/tests/test_angular_matrix_elements.py b/tests/test_angular_matrix_elements.py index ea6e2e8..07d31ff 100644 --- a/tests/test_angular_matrix_elements.py +++ b/tests/test_angular_matrix_elements.py @@ -8,8 +8,9 @@ from rydstate.angular.angular_matrix_element import AngularMomentumQuantumNumbers if TYPE_CHECKING: - from rydstate.angular.angular_ket import AngularKetBase, CouplingScheme + from rydstate.angular.angular_ket import AngularKetBase from rydstate.angular.angular_matrix_element import AngularOperatorType + from rydstate.angular.utils import CouplingScheme TEST_KET_PAIRS = [ ( From f6330a29e034a5ce3e8b769e963fc4af5b4302ad Mon Sep 17 00:00:00 2001 From: johannes-moegerle Date: Mon, 12 Jan 2026 09:23:59 +0100 Subject: [PATCH 6/6] improve basis mqdt --- src/rydstate/basis/basis_mqdt.py | 36 ++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/src/rydstate/basis/basis_mqdt.py b/src/rydstate/basis/basis_mqdt.py index 7d068ae..20994c3 100644 --- a/src/rydstate/basis/basis_mqdt.py +++ b/src/rydstate/basis/basis_mqdt.py @@ -34,13 +34,22 @@ USE_JULIACALL = False -if USE_JULIACALL: - try: - jl.seval("using MQDT") - jl.seval("using CGcoefficient") - except JuliaError: - logger.exception("Failed to load Julia MQDT or CGcoefficient package") - USE_JULIACALL = False +IS_MQDT_IMPORTED = False + + +def import_mqdt() -> bool: + """Load the MQDT Julia package. + + Since this might be time-consuming, we only do it if needed and ensure it is called only once. + """ + global IS_MQDT_IMPORTED # noqa: PLW0603 + if not IS_MQDT_IMPORTED: + try: + jl.seval("using MQDT") + IS_MQDT_IMPORTED = True + except JuliaError: + logger.exception("Failed to load Julia MQDT package") + return IS_MQDT_IMPORTED class BasisMQDT(BasisBase[RydbergStateMQDT[Any]]): @@ -52,20 +61,15 @@ def __init__( # noqa: PLR0915, C901, PLR0912 *, skip_high_l: bool = True, ) -> None: + if not USE_JULIACALL: + raise ImportError("JuliaCall is not available, try `pip install rydstate[mqdt]`.") + if not import_mqdt(): + raise ImportError("Failed to load the MQDT Julia package, try `pip install rydstate[mqdt]`.") super().__init__(species) if n_max is None: raise ValueError("n_max must be given") - if not USE_JULIACALL: - raise ImportError("JuliaCall or the MQDT Julia package is not available.") - - # initialize Wigner symbol calculation - if skip_high_l: - jl.CGcoefficient.wigner_init_float(10, "Jmax", 9) - else: - jl.CGcoefficient.wigner_init_float(n_max - 1, "Jmax", 9) - jl_species = jl.Symbol(self.species.name) parameters = jl.MQDT.get_parameters(jl_species)