diff --git a/pyproject.toml b/pyproject.toml index c567b4b..ae0ba8b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,9 @@ dependencies = [ ] [project.optional-dependencies] +mqdt = [ + "juliacall >= 0.9.24", +] tests = [ "pytest >= 8.0", "nbmake >= 1.3", diff --git a/src/rydstate/__init__.py b/src/rydstate/__init__.py index 739e3ad..429cffb 100644 --- a/src/rydstate/__init__.py +++ b/src/rydstate/__init__.py @@ -1,11 +1,13 @@ from rydstate import angular, basis, radial, rydberg, species from rydstate.basis import ( + BasisMQDT, BasisSQDTAlkali, BasisSQDTAlkalineFJ, BasisSQDTAlkalineJJ, BasisSQDTAlkalineLS, ) from rydstate.rydberg import ( + RydbergStateMQDT, RydbergStateSQDT, RydbergStateSQDTAlkali, RydbergStateSQDTAlkalineFJ, @@ -15,10 +17,12 @@ from rydstate.units import ureg __all__ = [ + "BasisMQDT", "BasisSQDTAlkali", "BasisSQDTAlkalineFJ", "BasisSQDTAlkalineJJ", "BasisSQDTAlkalineLS", + "RydbergStateMQDT", "RydbergStateSQDT", "RydbergStateSQDTAlkali", "RydbergStateSQDTAlkalineFJ", diff --git a/src/rydstate/angular/__init__.py b/src/rydstate/angular/__init__.py index 66f0142..4394d0e 100644 --- a/src/rydstate/angular/__init__.py +++ b/src/rydstate/angular/__init__.py @@ -1,8 +1,10 @@ from rydstate.angular import utils from rydstate.angular.angular_ket import AngularKetFJ, AngularKetJJ, AngularKetLS +from rydstate.angular.angular_ket_dummy import AngularKetDummy from rydstate.angular.angular_state import AngularState __all__ = [ + "AngularKetDummy", "AngularKetFJ", "AngularKetJJ", "AngularKetLS", diff --git a/src/rydstate/angular/angular_ket.py b/src/rydstate/angular/angular_ket.py index a21b2ce..7492c25 100644 --- a/src/rydstate/angular/angular_ket.py +++ b/src/rydstate/angular/angular_ket.py @@ -1,6 +1,5 @@ from __future__ import annotations -import contextlib import logging from abc import ABC from typing import TYPE_CHECKING, Any, ClassVar, Literal, overload @@ -14,6 +13,7 @@ is_angular_operator_type, ) from rydstate.angular.utils import ( + InvalidQuantumNumbersError, check_spin_addition_rule, get_possible_quantum_number_values, minus_one_pow, @@ -23,23 +23,14 @@ from rydstate.species import SpeciesObject if TYPE_CHECKING: - from typing_extensions import Self + from typing_extensions import Never, Self from rydstate.angular.angular_matrix_element import AngularMomentumQuantumNumbers, AngularOperatorType from rydstate.angular.angular_state import AngularState + from rydstate.angular.utils import CouplingScheme logger = logging.getLogger(__name__) -CouplingScheme = Literal["LS", "JJ", "FJ"] - - -class InvalidQuantumNumbersError(ValueError): - def __init__(self, ket: AngularKetBase, msg: str = "") -> None: - _msg = f"Invalid quantum numbers for {ket!r}" - if len(msg) > 0: - _msg += f"\n {msg}" - super().__init__(_msg) - class AngularKetBase(ABC): """Base class for a angular ket (i.e. a simple canonical spin ketstate).""" @@ -223,6 +214,9 @@ def to_state(self, coupling_scheme: Literal["JJ"]) -> AngularState[AngularKetJJ] @overload def to_state(self, coupling_scheme: Literal["FJ"]) -> AngularState[AngularKetFJ]: ... + @overload + def to_state(self, coupling_scheme: Literal["Dummy"]) -> Never: ... + @overload def to_state(self: Self) -> AngularState[Self]: ... @@ -736,55 +730,3 @@ def sanity_check(self, msgs: list[str] | None = None) -> None: msgs.append(f"{self.f_c=}, {self.j_r=}, {self.f_tot=} don't satisfy spin addition rule.") super().sanity_check(msgs) - - -def quantum_numbers_to_angular_ket( - species: str | SpeciesObject, - s_c: float | None = None, - l_c: int = 0, - j_c: float | None = None, - f_c: float | None = None, - s_r: float = 0.5, - l_r: int | None = None, - j_r: float | None = None, - s_tot: float | None = None, - l_tot: int | None = None, - j_tot: float | None = None, - f_tot: float | None = None, - m: float | None = None, -) -> AngularKetBase: - r"""Return an AngularKet object in the corresponding coupling scheme from the given quantum numbers. - - Args: - species: Atomic species. - s_c: Spin quantum number of the core electron (0 for Alkali, 0.5 for divalent atoms). - l_c: Orbital angular momentum quantum number of the core electron. - j_c: Total angular momentum quantum number of the core electron. - f_c: Total angular momentum quantum number of the core (core electron + nucleus). - s_r: Spin quantum number of the rydberg electron (always 0.5). - l_r: Orbital angular momentum quantum number of the rydberg electron. - j_r: Total angular momentum quantum number of the rydberg electron. - s_tot: Total spin quantum number of all electrons. - l_tot: Total orbital angular momentum quantum number of all electrons. - j_tot: Total angular momentum quantum number of all electrons. - f_tot: Total angular momentum quantum number of the atom (rydberg electron + core). - m: Total magnetic quantum number. - Optional, only needed for concrete angular matrix elements. - - """ - with contextlib.suppress(InvalidQuantumNumbersError, ValueError): - return AngularKetLS( - s_c=s_c, l_c=l_c, s_r=s_r, l_r=l_r, s_tot=s_tot, l_tot=l_tot, j_tot=j_tot, f_tot=f_tot, m=m, species=species - ) - - with contextlib.suppress(InvalidQuantumNumbersError, ValueError): - return AngularKetJJ( - s_c=s_c, l_c=l_c, j_c=j_c, s_r=s_r, l_r=l_r, j_r=j_r, j_tot=j_tot, f_tot=f_tot, m=m, species=species - ) - - with contextlib.suppress(InvalidQuantumNumbersError, ValueError): - return AngularKetFJ( - s_c=s_c, l_c=l_c, j_c=j_c, f_c=f_c, s_r=s_r, l_r=l_r, j_r=j_r, f_tot=f_tot, m=m, species=species - ) - - raise ValueError("Invalid combination of angular quantum numbers provided.") diff --git a/src/rydstate/angular/angular_ket_dummy.py b/src/rydstate/angular/angular_ket_dummy.py new file mode 100644 index 0000000..884cee9 --- /dev/null +++ b/src/rydstate/angular/angular_ket_dummy.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, ClassVar + +from rydstate.angular.angular_ket import AngularKetBase +from rydstate.angular.angular_matrix_element import is_angular_operator_type +from rydstate.angular.utils import InvalidQuantumNumbersError + +if TYPE_CHECKING: + from typing_extensions import Self + + from rydstate.angular.angular_matrix_element import AngularOperatorType + +logger = logging.getLogger(__name__) + + +class AngularKetDummy(AngularKetBase): + """Dummy spin ket for unknown quantum numbers.""" + + __slots__ = ("name",) + quantum_number_names: ClassVar = ("f_tot",) + coupled_quantum_numbers: ClassVar = {} + coupling_scheme = "Dummy" + + name: str + """Name of the dummy ket.""" + + def __init__( + self, + name: str, + f_tot: float, + m: float | None = None, + ) -> None: + """Initialize the Spin ket.""" + self.name = name + + self.f_tot = f_tot + self.m = m + + super()._post_init() + + def sanity_check(self, msgs: list[str] | None = None) -> None: + """Check that the quantum numbers are valid.""" + msgs = msgs if msgs is not None else [] + + if self.m is not None and not -self.f_tot <= self.m <= self.f_tot: + msgs.append(f"m must be between -f_tot and f_tot, but {self.f_tot=}, {self.m=}") + + if msgs: + msg = "\n ".join(msgs) + raise InvalidQuantumNumbersError(self, msg) + + def __repr__(self) -> str: + args = f"{self.name}, f_tot={self.f_tot}" + if self.m is not None: + args += f", m={self.m}" + return f"{self.__class__.__name__}({args})" + + def __str__(self) -> str: + return self.__repr__().replace("AngularKet", "") + + def __eq__(self, other: object) -> bool: + if not isinstance(other, AngularKetBase): + raise NotImplementedError(f"Cannot compare {self!r} with {other!r}.") + if not isinstance(other, AngularKetDummy): + return False + return self.name == other.name and self.f_tot == other.f_tot and self.m == other.m + + def __hash__(self) -> int: + return hash((self.name, self.f_tot, self.m)) + + def calc_reduced_overlap(self, other: AngularKetBase) -> float: + return int(self == other) + + def calc_reduced_matrix_element( + self: Self, + other: AngularKetBase, # noqa: ARG002 + operator: AngularOperatorType, + kappa: int, # noqa: ARG002 + ) -> float: + if not is_angular_operator_type(operator): + raise NotImplementedError(f"calc_reduced_matrix_element is not implemented for operator {operator}.") + + # ignore contributions from dummy kets + return 0 diff --git a/src/rydstate/angular/angular_state.py b/src/rydstate/angular/angular_state.py index e3978fb..0ae7542 100644 --- a/src/rydstate/angular/angular_state.py +++ b/src/rydstate/angular/angular_state.py @@ -12,15 +12,16 @@ AngularKetJJ, AngularKetLS, ) +from rydstate.angular.angular_ket_dummy import AngularKetDummy from rydstate.angular.angular_matrix_element import is_angular_momentum_quantum_number if TYPE_CHECKING: from collections.abc import Iterator, Sequence - from typing_extensions import Self + from typing_extensions import Never, Self - from rydstate.angular.angular_ket import CouplingScheme from rydstate.angular.angular_matrix_element import AngularMomentumQuantumNumbers, AngularOperatorType + from rydstate.angular.utils import CouplingScheme logger = logging.getLogger(__name__) @@ -32,21 +33,28 @@ class AngularState(Generic[_AngularKet]): def __init__( self, coefficients: Sequence[float], kets: Sequence[_AngularKet], *, warn_if_not_normalized: bool = True ) -> None: - self.coefficients = np.array(coefficients) - self.kets = kets + """Initialize an angular state as a linear combination of angular kets. + + All kets must be of the same type (coupling scheme), and no duplicate kets are allowed. + Dummy kets (AngularKetDummy) are ignored in the state representation, + however adding them is recommended for normalization purposes. + """ + self._coefficients = np.array(coefficients) + self._kets = kets + self._warn_if_not_normalized = warn_if_not_normalized if len(coefficients) != len(kets): raise ValueError("Length of coefficients and kets must be the same.") if len(kets) == 0: raise ValueError("At least one ket must be provided.") - if not all(type(ket) is type(kets[0]) for ket in kets): + if not all(type(ket) is type(self.kets[0]) for ket in self.kets): raise ValueError("All kets must be of the same type.") - if len(set(kets)) != len(kets): - raise ValueError("AngularState initialized with duplicate kets.") + if len(set(self.kets)) != len(self.kets): + raise ValueError("AngularState initialized with duplicate kets: %s", self.kets) if abs(self.norm - 1) > 1e-10 and warn_if_not_normalized: logger.warning("AngularState initialized with non-normalized coefficients: %s, %s", coefficients, kets) if self.norm > 1: - self.coefficients /= self.norm + self._coefficients /= self.norm def __iter__(self) -> Iterator[tuple[float, _AngularKet]]: return zip(self.coefficients, self.kets).__iter__() @@ -59,6 +67,16 @@ def __str__(self) -> str: terms = [f"{coeff}*{ket!s}" for coeff, ket in self] return f"{', '.join(terms)}" + @property + def kets(self) -> list[_AngularKet]: + return [ket for ket in self._kets if not isinstance(ket, AngularKetDummy)] + + @property + def coefficients(self) -> np.ndarray: + return np.array( + [coeff for coeff, ket in zip(self._coefficients, self._kets) if not isinstance(ket, AngularKetDummy)] + ) + @property def coupling_scheme(self) -> CouplingScheme: """Return the coupling scheme of the state.""" @@ -67,7 +85,7 @@ def coupling_scheme(self) -> CouplingScheme: @property def norm(self) -> float: """Return the norm of the state (should be 1).""" - return np.linalg.norm(self.coefficients) # type: ignore [return-value] + return np.linalg.norm(self._coefficients) # type: ignore [return-value] @overload def to(self, coupling_scheme: Literal["LS"]) -> AngularState[AngularKetLS]: ... @@ -78,6 +96,9 @@ def to(self, coupling_scheme: Literal["JJ"]) -> AngularState[AngularKetJJ]: ... @overload def to(self, coupling_scheme: Literal["FJ"]) -> AngularState[AngularKetFJ]: ... + @overload + def to(self, coupling_scheme: Literal["Dummy"]) -> Never: ... + def to(self, coupling_scheme: CouplingScheme) -> AngularState[Any]: """Convert to specified coupling scheme. @@ -98,7 +119,8 @@ def to(self, coupling_scheme: CouplingScheme) -> AngularState[Any]: else: kets.append(scheme_ket) coefficients.append(coeff * scheme_coeff) - return AngularState(coefficients, kets, warn_if_not_normalized=abs(self.norm - 1) < 1e-10) + warn_if_not_normalized = self._warn_if_not_normalized and (abs(self.norm - 1) < 1e-10) + return AngularState(coefficients, kets, warn_if_not_normalized=warn_if_not_normalized) def calc_exp_qn(self, q: AngularMomentumQuantumNumbers) -> float: """Calculate the expectation value of a quantum number q. diff --git a/src/rydstate/angular/utils.py b/src/rydstate/angular/utils.py index ecea008..0d73d07 100644 --- a/src/rydstate/angular/utils.py +++ b/src/rydstate/angular/utils.py @@ -1,7 +1,26 @@ from __future__ import annotations +import contextlib +from typing import TYPE_CHECKING, Literal + import numpy as np +if TYPE_CHECKING: + import juliacall + + from rydstate.angular.angular_ket import AngularKetBase + from rydstate.species.species_object import SpeciesObject + +CouplingScheme = Literal["LS", "JJ", "FJ", "Dummy"] + + +class InvalidQuantumNumbersError(ValueError): + def __init__(self, ket: AngularKetBase, msg: str = "") -> None: + _msg = f"Invalid quantum numbers for {ket!r}" + if len(msg) > 0: + _msg += f"\n {msg}" + super().__init__(_msg) + def minus_one_pow(n: float) -> int: """Calculate (-1)^n for an integer n and raise an error if n is not an integer.""" @@ -42,3 +61,68 @@ def get_possible_quantum_number_values(s_1: float, s_2: float, s_tot: float | No if s_tot is not None: return [s_tot] return [float(s) for s in np.arange(abs(s_1 - s_2), s_1 + s_2 + 1, 1)] + + +def julia_qn_to_dict(qn: juliacall.AnyValue) -> dict[str, float]: + """Convert MQDT Julia quantum numbers to dict object.""" + if "fjQuantumNumbers" in str(qn): + return dict(s_c=qn.sc, l_c=qn.lc, j_c=qn.Jc, f_c=qn.Fc, l_r=qn.lr, j_r=qn.Jr, f_tot=qn.F) # noqa: C408 + if "jjQuantumNumbers" in str(qn): + return dict(s_c=qn.sc, l_c=qn.lc, j_c=qn.Jc, l_r=qn.lr, j_r=qn.Jr, j_tot=qn.J, f_tot=qn.F) # noqa: C408 + if "lsQuantumNumbers" in str(qn): + return dict(s_c=qn.sc, s_tot=qn.S, l_c=qn.lc, l_r=qn.lr, l_tot=qn.L, j_tot=qn.J, f_tot=qn.F) # noqa: C408 + raise ValueError(f"Unknown MQDT Julia quantum numbers {qn!s}.") + + +def quantum_numbers_to_angular_ket( + species: str | SpeciesObject, + s_c: float | None = None, + l_c: int = 0, + j_c: float | None = None, + f_c: float | None = None, + s_r: float = 0.5, + l_r: int | None = None, + j_r: float | None = None, + s_tot: float | None = None, + l_tot: int | None = None, + j_tot: float | None = None, + f_tot: float | None = None, + m: float | None = None, +) -> AngularKetBase: + r"""Return an AngularKet object in the corresponding coupling scheme from the given quantum numbers. + + Args: + species: Atomic species. + s_c: Spin quantum number of the core electron (0 for Alkali, 0.5 for divalent atoms). + l_c: Orbital angular momentum quantum number of the core electron. + j_c: Total angular momentum quantum number of the core electron. + f_c: Total angular momentum quantum number of the core (core electron + nucleus). + s_r: Spin quantum number of the rydberg electron (always 0.5). + l_r: Orbital angular momentum quantum number of the rydberg electron. + j_r: Total angular momentum quantum number of the rydberg electron. + s_tot: Total spin quantum number of all electrons. + l_tot: Total orbital angular momentum quantum number of all electrons. + j_tot: Total angular momentum quantum number of all electrons. + f_tot: Total angular momentum quantum number of the atom (rydberg electron + core). + m: Total magnetic quantum number. + Optional, only needed for concrete angular matrix elements. + + """ + from rydstate.angular.angular_ket import AngularKetFJ, AngularKetJJ, AngularKetLS # noqa: PLC0415 + + with contextlib.suppress(InvalidQuantumNumbersError, ValueError): + return AngularKetLS( + s_c=s_c, l_c=l_c, s_r=s_r, l_r=l_r, s_tot=s_tot, l_tot=l_tot, j_tot=j_tot, f_tot=f_tot, m=m, species=species + ) + + with contextlib.suppress(InvalidQuantumNumbersError, ValueError): + return AngularKetJJ( + s_c=s_c, l_c=l_c, j_c=j_c, s_r=s_r, l_r=l_r, j_r=j_r, j_tot=j_tot, f_tot=f_tot, m=m, species=species + ) + + with contextlib.suppress(InvalidQuantumNumbersError, ValueError): + return AngularKetFJ( + s_c=s_c, l_c=l_c, j_c=j_c, f_c=f_c, s_r=s_r, l_r=l_r, j_r=j_r, f_tot=f_tot, m=m, species=species + ) + + raise ValueError("Invalid combination of angular quantum numbers provided.") diff --git a/src/rydstate/basis/__init__.py b/src/rydstate/basis/__init__.py index 7f15364..93ef96c 100644 --- a/src/rydstate/basis/__init__.py +++ b/src/rydstate/basis/__init__.py @@ -1,4 +1,12 @@ from rydstate.basis.basis_base import BasisBase +from rydstate.basis.basis_mqdt import BasisMQDT from rydstate.basis.basis_sqdt import BasisSQDTAlkali, BasisSQDTAlkalineFJ, BasisSQDTAlkalineJJ, BasisSQDTAlkalineLS -__all__ = ["BasisBase", "BasisSQDTAlkali", "BasisSQDTAlkalineFJ", "BasisSQDTAlkalineJJ", "BasisSQDTAlkalineLS"] +__all__ = [ + "BasisBase", + "BasisMQDT", + "BasisSQDTAlkali", + "BasisSQDTAlkalineFJ", + "BasisSQDTAlkalineJJ", + "BasisSQDTAlkalineLS", +] diff --git a/src/rydstate/basis/basis_mqdt.py b/src/rydstate/basis/basis_mqdt.py new file mode 100644 index 0000000..20994c3 --- /dev/null +++ b/src/rydstate/basis/basis_mqdt.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any + +import numpy as np + +from rydstate.angular.angular_ket_dummy import AngularKetDummy +from rydstate.angular.utils import julia_qn_to_dict, quantum_numbers_to_angular_ket +from rydstate.basis.basis_base import BasisBase +from rydstate.rydberg.rydberg_mqdt import RydbergStateMQDT +from rydstate.rydberg.rydberg_sqdt import RydbergStateSQDT + +if TYPE_CHECKING: + from juliacall import ( + JuliaError, + Main as jl, # noqa: N813 + convert, + ) + + from rydstate.angular.angular_ket import AngularKetBase + from rydstate.species import SpeciesObject + +logger = logging.getLogger(__name__) + +try: + USE_JULIACALL = True + from juliacall import ( + JuliaError, + Main as jl, # noqa: N813 + convert, + ) +except ImportError: + USE_JULIACALL = False + + +IS_MQDT_IMPORTED = False + + +def import_mqdt() -> bool: + """Load the MQDT Julia package. + + Since this might be time-consuming, we only do it if needed and ensure it is called only once. + """ + global IS_MQDT_IMPORTED # noqa: PLW0603 + if not IS_MQDT_IMPORTED: + try: + jl.seval("using MQDT") + IS_MQDT_IMPORTED = True + except JuliaError: + logger.exception("Failed to load Julia MQDT package") + return IS_MQDT_IMPORTED + + +class BasisMQDT(BasisBase[RydbergStateMQDT[Any]]): + def __init__( # noqa: PLR0915, C901, PLR0912 + self, + species: str | SpeciesObject, + n_min: int = 0, + n_max: int | None = None, + *, + skip_high_l: bool = True, + ) -> None: + if not USE_JULIACALL: + raise ImportError("JuliaCall is not available, try `pip install rydstate[mqdt]`.") + if not import_mqdt(): + raise ImportError("Failed to load the MQDT Julia package, try `pip install rydstate[mqdt]`.") + super().__init__(species) + + if n_max is None: + raise ValueError("n_max must be given") + + jl_species = jl.Symbol(self.species.name) + parameters = jl.MQDT.get_parameters(jl_species) + + self.models = [] + i_c = self.species.i_c if self.species.i_c is not None else 0 + for l in range(n_max): + jtot_min = min(l, abs(l - 1)) + jtot_max = l + 1 + for f_tot in np.arange(abs(jtot_min - i_c), jtot_max + i_c + 1): + models = jl.MQDT.get_fmodels(jl_species, l, float(f_tot)) + self.models.extend(models) + + n_min_high_l = 25 + + logger.debug("Calculating MQDT states...") + jl_states = [] + for model in self.models: + _n_min = n_min + if model.name.startswith("SQDT"): + if skip_high_l: + continue + _n_min = n_min_high_l + + logger.debug("model name: %s", model.name) + states = jl.MQDT.eigenstates(_n_min, n_max, model, parameters) + jl_states.append(states) + + if len(states.n) == 0: + logger.debug(" no states found") + else: + logger.debug(" nu_min=%s, nu_max=%s, total states=%d", min(states.n), max(states.n), len(states.n)) + + jl_basis = jl.basisarray(convert(jl.Vector, jl_states), convert(jl.Vector, self.models)) + + logger.debug("Generated state table with %d states", len(jl_basis.states)) + + self.states = [] + for jl_state in jl_basis.states: + nus = jl_state.nu_list + nu_energy = jl_state.energy + angular_kets: list[AngularKetBase] = [] + iqn = 0 + model = jl_state.model + for i, core in enumerate(model.core): + if not core: + name = model.name + model.terms[i] + angular_kets.append(AngularKetDummy(name, f_tot=model.f_tot)) + continue + + qn = julia_qn_to_dict(jl_state.channels.i[iqn]) + try: + angular_kets.append(quantum_numbers_to_angular_ket(species=self.species, **qn)) # type: ignore[arg-type] + except ValueError: + name = model.name + model.terms[i] + angular_kets.append(AngularKetDummy(name, f_tot=model.f_tot)) + + iqn += 1 + + sqdt_states = [ + RydbergStateSQDT.from_angular_ket(species, angular_ket, nu=nu) + for nu, angular_ket in zip(nus, angular_kets) + ] + # check angular and radial are created correctly + assert len([(s.angular, s.radial) for s in sqdt_states]) > 0 + + mqdt_state = RydbergStateMQDT(jl_state.coefficients, sqdt_states, nu_energy=nu_energy) + self.states.append(mqdt_state) diff --git a/src/rydstate/rydberg/__init__.py b/src/rydstate/rydberg/__init__.py index f7a895b..b13d7a7 100644 --- a/src/rydstate/rydberg/__init__.py +++ b/src/rydstate/rydberg/__init__.py @@ -1,4 +1,5 @@ from rydstate.rydberg.rydberg_base import RydbergStateBase +from rydstate.rydberg.rydberg_mqdt import RydbergStateMQDT from rydstate.rydberg.rydberg_sqdt import ( RydbergStateSQDT, RydbergStateSQDTAlkali, @@ -9,6 +10,7 @@ __all__ = [ "RydbergStateBase", + "RydbergStateMQDT", "RydbergStateSQDT", "RydbergStateSQDTAlkali", "RydbergStateSQDTAlkalineFJ", diff --git a/src/rydstate/rydberg/rydberg_mqdt.py b/src/rydstate/rydberg/rydberg_mqdt.py new file mode 100644 index 0000000..48b03e4 --- /dev/null +++ b/src/rydstate/rydberg/rydberg_mqdt.py @@ -0,0 +1,121 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload + +import numpy as np + +from rydstate.angular import AngularState +from rydstate.rydberg.rydberg_base import RydbergStateBase +from rydstate.rydberg.rydberg_sqdt import RydbergStateSQDT + +if TYPE_CHECKING: + from collections.abc import Iterator, Sequence + + from rydstate.units import MatrixElementOperator, PintFloat + + +logger = logging.getLogger(__name__) + + +_RydbergState = TypeVar("_RydbergState", bound=RydbergStateSQDT) + + +class RydbergStateMQDT(RydbergStateBase, Generic[_RydbergState]): + angular: AngularState[Any] + """Return the angular part of the MQDT state as an AngularState.""" + + def __init__( + self, + coefficients: Sequence[float], + sqdt_states: Sequence[_RydbergState], + *, + nu_energy: float | None = None, + warn_if_not_normalized: bool = True, + ) -> None: + self.coefficients = np.array(coefficients) + self.sqdt_states = sqdt_states + self.nu_energy = nu_energy + self.angular = AngularState(self.coefficients.tolist(), [ket.angular for ket in sqdt_states]) + + if len(coefficients) != len(sqdt_states): + raise ValueError("Length of coefficients and sqdt_states must be the same.") + if not all(type(sqdt_state) is type(sqdt_states[0]) for sqdt_state in sqdt_states): + raise ValueError("All sqdt_states must be of the same type.") + if len(set(sqdt_states)) != len(sqdt_states): + raise ValueError("RydbergStateMQDT initialized with duplicate sqdt_states.") + if abs(self.norm - 1) > 1e-10 and warn_if_not_normalized: + logger.warning( + "RydbergStateMQDT initialized with non-normalized coefficients " + "(norm=%s, coefficients=%s, sqdt_states=%s)", + self.norm, + coefficients, + sqdt_states, + ) + if self.norm > 1: + self.coefficients /= self.norm + + def __iter__(self) -> Iterator[tuple[float, _RydbergState]]: + return zip(self.coefficients, self.sqdt_states).__iter__() + + def __repr__(self) -> str: + terms = [f"{coeff}*{sqdt_state!r}" for coeff, sqdt_state in self] + return f"{self.__class__.__name__}({', '.join(terms)})" + + def __str__(self) -> str: + terms = [f"{coeff}*{sqdt_state!s}" for coeff, sqdt_state in self] + return f"{', '.join(terms)}" + + @property + def norm(self) -> float: + """Return the norm of the state (should be 1).""" + return np.linalg.norm(self.coefficients) # type: ignore [return-value] + + def calc_reduced_overlap(self, other: RydbergStateBase) -> float: + """Calculate the reduced overlap (ignoring the magnetic quantum number m).""" + if isinstance(other, RydbergStateSQDT): + other = other.to_mqdt() + + if isinstance(other, RydbergStateMQDT): + ov = 0 + for coeff1, sqdt1 in self: + for coeff2, sqdt2 in other: + ov += np.conjugate(coeff1) * coeff2 * sqdt1.calc_reduced_overlap(sqdt2) + return ov + + raise NotImplementedError(f"calc_reduced_overlap not implemented for {type(self)=}, {type(other)=}") + + @overload # type: ignore [override] + def calc_reduced_matrix_element( + self, other: RydbergStateBase, operator: MatrixElementOperator, unit: None = None + ) -> PintFloat: ... + + @overload + def calc_reduced_matrix_element( + self, other: RydbergStateBase, operator: MatrixElementOperator, unit: str + ) -> float: ... + + def calc_reduced_matrix_element( + self, other: RydbergStateBase, operator: MatrixElementOperator, unit: str | None = None + ) -> PintFloat | float: + r"""Calculate the reduced angular matrix element. + + This means, calculate the following matrix element: + + .. math:: + \left\langle self || \hat{O}^{(\kappa)} || other \right\rangle + + """ + if isinstance(other, RydbergStateSQDT): + other = other.to_mqdt() + + if isinstance(other, RydbergStateMQDT): + value = 0 + for coeff1, sqdt1 in self: + for coeff2, sqdt2 in other: + value += ( + np.conjugate(coeff1) * coeff2 * sqdt1.calc_reduced_matrix_element(sqdt2, operator, unit=unit) + ) + return value + + raise NotImplementedError(f"calc_reduced_overlap not implemented for {type(self)=}, {type(other)=}") diff --git a/src/rydstate/rydberg/rydberg_sqdt.py b/src/rydstate/rydberg/rydberg_sqdt.py index 744b582..9ea3b47 100644 --- a/src/rydstate/rydberg/rydberg_sqdt.py +++ b/src/rydstate/rydberg/rydberg_sqdt.py @@ -3,11 +3,11 @@ import logging import math from functools import cached_property -from typing import TYPE_CHECKING, overload +from typing import TYPE_CHECKING, Any, overload import numpy as np -from rydstate.angular.angular_ket import quantum_numbers_to_angular_ket +from rydstate.angular.utils import quantum_numbers_to_angular_ket from rydstate.radial import RadialKet from rydstate.rydberg.rydberg_base import RydbergStateBase from rydstate.species import SpeciesObject @@ -15,6 +15,7 @@ from rydstate.units import BaseQuantities, MatrixElementOperatorRanks, ureg if TYPE_CHECKING: + from rydstate import RydbergStateMQDT from rydstate.angular.angular_ket import AngularKetBase, AngularKetFJ, AngularKetJJ, AngularKetLS from rydstate.units import MatrixElementOperator, PintFloat @@ -182,10 +183,16 @@ def get_energy(self, unit: str | None = None) -> PintFloat | float: return energy return energy.to(unit, "spectroscopy").magnitude + def to_mqdt(self) -> RydbergStateMQDT[Any]: + """Convert to a trivial RydbergMQDT state with only one contribution with coefficient 1.""" + from rydstate import RydbergStateMQDT # noqa: PLC0415 + + return RydbergStateMQDT([1], [self]) + def calc_reduced_overlap(self, other: RydbergStateBase) -> float: """Calculate the reduced overlap (ignoring the magnetic quantum number m).""" if not isinstance(other, RydbergStateSQDT): - raise NotImplementedError("Reduced overlap only implemented between RydbergStateSQDT states.") + return self.to_mqdt().calc_reduced_overlap(other) radial_overlap = self.radial.calc_overlap(other.radial) angular_overlap = self.angular.calc_reduced_overlap(other.angular) @@ -226,7 +233,7 @@ def calc_reduced_matrix_element( """ if not isinstance(other, RydbergStateSQDT): - raise NotImplementedError("Reduced matrix element only implemented between RydbergStateSQDT states.") + return self.to_mqdt().calc_reduced_matrix_element(other, operator, unit=unit) if operator not in MatrixElementOperatorRanks: raise ValueError( diff --git a/tests/test_angular_matrix_elements.py b/tests/test_angular_matrix_elements.py index ea6e2e8..07d31ff 100644 --- a/tests/test_angular_matrix_elements.py +++ b/tests/test_angular_matrix_elements.py @@ -8,8 +8,9 @@ from rydstate.angular.angular_matrix_element import AngularMomentumQuantumNumbers if TYPE_CHECKING: - from rydstate.angular.angular_ket import AngularKetBase, CouplingScheme + from rydstate.angular.angular_ket import AngularKetBase from rydstate.angular.angular_matrix_element import AngularOperatorType + from rydstate.angular.utils import CouplingScheme TEST_KET_PAIRS = [ (