diff --git a/README.md b/README.md index eb586aa..c6fecb5 100644 --- a/README.md +++ b/README.md @@ -205,6 +205,11 @@ predicate-authority revoke intent --host 127.0.0.1 --port 8787 --hash ` to enable optional local persistence and restart recovery. + ### Identity mode options (`predicate-authorityd`) - `--identity-mode local`: deterministic local bridge (default). diff --git a/docs/authorityd-operations.md b/docs/authorityd-operations.md index 3cbe919..3a8b64d 100644 --- a/docs/authorityd-operations.md +++ b/docs/authorityd-operations.md @@ -43,6 +43,23 @@ PYTHONPATH=. predicate-authorityd \ --credential-store-file ./.predicate-authorityd/credentials.json ``` +By design, mandate/revocation cache is in-memory (ephemeral) unless you explicitly +enable persistence with `--mandate-store-file`. + +### Optional: enable persisted mandate/revocation cache (parity extension) + +Use this only when restart-recovery for local revocations/mandate lineage is required. +If omitted, default behavior remains ephemeral. + +```bash +PYTHONPATH=. predicate-authorityd \ + --host 127.0.0.1 \ + --port 8787 \ + --mode local_only \ + --policy-file examples/authorityd/policy.json \ + --mandate-store-file ./.predicate-authorityd/mandates.json +``` + ### Optional: enable control-plane shipping To automatically ship proof events and usage records to @@ -506,6 +523,7 @@ Example response: { "mode": "local_only", "policy_hot_reload_enabled": true, + "mandate_store_persistence_enabled": false, "revoked_principal_count": 0, "revoked_intent_count": 0, "revoked_mandate_count": 0, diff --git a/predicate_authority/client.py b/predicate_authority/client.py index 7909b16..c443678 100644 --- a/predicate_authority/client.py +++ b/predicate_authority/client.py @@ -113,12 +113,15 @@ def authorize( reason=AuthorizationReason.INVALID_MANDATE, violated_rule="revocation_cache", ) + if decision.allowed and decision.mandate is not None: + self._revocation_cache.register_mandate(decision.mandate) return decision def verify_token(self, token: str) -> SignedMandate | None: mandate = self._mandate_signer.verify(token) if mandate is None: return None + self._revocation_cache.register_mandate(mandate) if self._revocation_cache.is_mandate_revoked(mandate): return None return mandate @@ -142,5 +145,5 @@ def verify_delegation_chain( def revoke_principal(self, principal_id: str) -> None: self._revocation_cache.revoke_principal(principal_id) - def revoke_mandate(self, mandate_id: str) -> None: - self._revocation_cache.revoke_mandate_id(mandate_id) + def revoke_mandate(self, mandate_id: str, cascade: bool = False) -> int: + return self._revocation_cache.revoke_mandate_id(mandate_id, cascade=cascade) diff --git a/predicate_authority/control_plane.py b/predicate_authority/control_plane.py index 4e89b93..752f65b 100644 --- a/predicate_authority/control_plane.py +++ b/predicate_authority/control_plane.py @@ -1,6 +1,7 @@ from __future__ import annotations import hashlib +import hmac import http.client import json import time @@ -27,6 +28,7 @@ class ControlPlaneClientConfig: sync_poll_interval_ms: int = 200 sync_project_id: str | None = None sync_environment: str | None = None + replay_signing_secret: str | None = None @dataclass(frozen=True) @@ -212,10 +214,11 @@ def poll_authority_updates( return AuthoritySyncSnapshot.from_payload(payload) def _post_json(self, path: str, payload: Mapping[str, object]) -> bool: + replay_headers = self._build_replay_headers(path) attempts = self.config.max_retries + 1 for attempt in range(attempts): try: - self._post_json_once(path, payload) + self._post_json_once(path, payload, replay_headers=replay_headers) return True except Exception as exc: is_last_attempt = attempt == attempts - 1 @@ -240,12 +243,19 @@ def _get_json(self, path: str) -> Mapping[str, object]: time.sleep(self.config.backoff_initial_s * (2**attempt)) return {} - def _post_json_once(self, path: str, payload: Mapping[str, object]) -> None: + def _post_json_once( + self, + path: str, + payload: Mapping[str, object], + *, + replay_headers: Mapping[str, str], + ) -> None: target_path = path if path.startswith("/") else f"/{path}" connection = self._new_connection() headers = {"Content-Type": "application/json"} if self.config.auth_token: headers["Authorization"] = f"Bearer {self.config.auth_token}" + headers.update(replay_headers) body = json.dumps(payload) try: connection.request("POST", target_path, body=body, headers=headers) @@ -280,6 +290,26 @@ def _new_connection(self) -> http.client.HTTPConnection: return http.client.HTTPSConnection(self._base.netloc, timeout=self.config.timeout_s) return http.client.HTTPConnection(self._base.netloc, timeout=self.config.timeout_s) + def _build_replay_headers(self, path: str) -> dict[str, str]: + timestamp = str(int(time.time())) + nonce = hashlib.sha256( + f"{self.config.tenant_id}|{path}|{time.time_ns()}".encode() + ).hexdigest()[:32] + headers = { + "X-PA-Nonce": nonce, + "X-PA-Timestamp": timestamp, + "X-PA-Idempotency-Token": hashlib.sha256( + f"{nonce}|{timestamp}|{path}".encode() + ).hexdigest()[:32], + } + if self.config.replay_signing_secret is not None: + message = f"{nonce}:{timestamp}:POST:{path}".encode() + signature = hmac.new( + self.config.replay_signing_secret.encode("utf-8"), message, hashlib.sha256 + ).hexdigest() + headers["X-PA-Signature"] = signature + return headers + @dataclass class ControlPlaneTraceEmitter: diff --git a/predicate_authority/daemon.py b/predicate_authority/daemon.py index cd46d15..01be190 100644 --- a/predicate_authority/daemon.py +++ b/predicate_authority/daemon.py @@ -370,11 +370,23 @@ def _handle_revoke_intent(self) -> None: def _handle_revoke_mandate(self) -> None: payload = self._read_json_body() mandate_id = payload.get("mandate_id") + cascade_raw = payload.get("cascade") + cascade = bool(cascade_raw) if isinstance(cascade_raw, bool) else False if not isinstance(mandate_id, str) or mandate_id.strip() == "": self._send_json(400, {"error": "mandate_id is required"}) return - self.server.daemon_ref.revoke_mandate(mandate_id.strip()) # type: ignore[attr-defined] - self._send_json(200, {"ok": True, "mandate_id": mandate_id.strip()}) + revoked_count = self.server.daemon_ref.revoke_mandate( # type: ignore[attr-defined] + mandate_id.strip(), cascade=cascade + ) + self._send_json( + 200, + { + "ok": True, + "mandate_id": mandate_id.strip(), + "cascade": cascade, + "revoked_count": int(revoked_count), + }, + ) def _handle_identity_task(self) -> None: payload = self._read_json_body() @@ -585,8 +597,8 @@ def revoke_principal(self, principal_id: str) -> None: def revoke_intent(self, intent_hash: str) -> None: self._sidecar.revoke_intent_hash(intent_hash) - def revoke_mandate(self, mandate_id: str) -> None: - self._sidecar.revoke_mandate_id(mandate_id) + def revoke_mandate(self, mandate_id: str, cascade: bool = False) -> int: + return self._sidecar.revoke_mandate_id(mandate_id, cascade=cascade) def max_request_body_bytes(self) -> int: return max(0, int(self._config.max_request_body_bytes)) @@ -826,15 +838,16 @@ def _apply_sync_snapshot(self, snapshot: AuthoritySyncSnapshot) -> None: elif item.type == "intent" and item.intent_hash is not None: self._sidecar.revoke_intent_hash(item.intent_hash) elif item.type == "tags": - # Tag revocation support is modeled in control-plane API but not yet represented in - # sidecar's revocation cache keys. - continue + tags = {tag.strip().lower() for tag in item.tags if tag.strip() != ""} + if "global_kill_switch" in tags: + self._sidecar.activate_global_kill_switch() def _build_default_sidecar( mode: AuthorityMode, policy_file: str | None, credential_store_file: str, + mandate_store_file: str | None = None, control_plane_config: ControlPlaneBootstrapConfig | None = None, local_identity_config: LocalIdentityBootstrapConfig | None = None, identity_bridge: ExchangeTokenBridge | None = None, @@ -912,7 +925,7 @@ def _build_default_sidecar( proof_ledger=proof_ledger, identity_bridge=identity_bridge or IdentityBridge(), credential_store=LocalCredentialStore(credential_store_file), - revocation_cache=LocalRevocationCache(), + revocation_cache=LocalRevocationCache(store_file_path=mandate_store_file), policy_engine=policy_engine, local_identity_registry=local_identity_registry, ) @@ -1070,6 +1083,14 @@ def main() -> None: "--credential-store-file", default=str(Path.home() / ".predicate-authorityd" / "credentials.json"), ) + parser.add_argument( + "--mandate-store-file", + default=None, + help=( + "Optional path for persisted local revocation/mandate cache. " + "If omitted, mandate cache remains in-memory (ephemeral default)." + ), + ) parser.add_argument( "--local-identity-enabled", action="store_true", @@ -1307,6 +1328,7 @@ def main() -> None: mode=mode, policy_file=args.policy_file, credential_store_file=args.credential_store_file, + mandate_store_file=args.mandate_store_file, control_plane_config=control_plane_bootstrap, local_identity_config=local_identity_bootstrap, identity_bridge=identity_bridge, diff --git a/predicate_authority/mandate.py b/predicate_authority/mandate.py index ca854b4..db1f1b2 100644 --- a/predicate_authority/mandate.py +++ b/predicate_authority/mandate.py @@ -5,17 +5,78 @@ import hmac import json import time -from dataclasses import asdict +from dataclasses import asdict, dataclass +from typing import Any, Literal, cast + +from cryptography.exceptions import InvalidSignature +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import ec, utils from predicate_contracts import ActionRequest, MandateClaims, SignedMandate +@dataclass(frozen=True) +class _SigningKeyMaterial: + secret_key: bytes + private_key: ec.EllipticCurvePrivateKey + public_key: ec.EllipticCurvePublicKey + + class LocalMandateSigner: - def __init__(self, secret_key: str, ttl_seconds: int = 300) -> None: + def __init__( + self, + secret_key: str, + ttl_seconds: int = 300, + signing_alg: Literal["ES256", "HS256"] = "ES256", + allow_legacy_hs256_verify: bool = True, + token_issuer: str | None = None, + token_audience: str | None = None, + ) -> None: if ttl_seconds <= 0: raise ValueError("ttl_seconds must be > 0") - self._secret_key = secret_key.encode("utf-8") + if signing_alg not in {"ES256", "HS256"}: + raise ValueError("signing_alg must be one of: ES256, HS256") self._ttl_seconds = ttl_seconds + self._signing_alg = signing_alg + self._allow_legacy_hs256_verify = allow_legacy_hs256_verify + self._token_issuer = token_issuer if token_issuer is not None else "predicate-authorityd" + self._token_audience = ( + token_audience if token_audience is not None else "predicate-authority" + ) + initial_kid = self._key_id_for_secret(secret_key) + initial_material = self._build_key_material(secret_key) + self._active_kid = initial_kid + self._next_kid: str | None = None + self._verification_keys: dict[str, _SigningKeyMaterial] = {initial_kid: initial_material} + + def key_lifecycle_status(self) -> dict[str, object]: + return { + "active_kid": self._active_kid, + "next_kid": self._next_kid, + "verification_kids": tuple(sorted(self._verification_keys.keys())), + "signing_alg": self._signing_alg, + } + + def stage_next_signing_key(self, secret_key: str) -> str: + next_kid = self._key_id_for_secret(secret_key) + self._verification_keys[next_kid] = self._build_key_material(secret_key) + self._next_kid = next_kid + return next_kid + + def activate_staged_signing_key(self) -> str: + if self._next_kid is None: + raise RuntimeError("No staged signing key to activate.") + self._active_kid = self._next_kid + self._next_kid = None + return self._active_kid + + def retire_verification_key(self, kid: str) -> bool: + if kid == self._active_kid or kid == self._next_kid: + return False + if kid not in self._verification_keys: + return False + del self._verification_keys[kid] + return True def issue( self, @@ -24,20 +85,26 @@ def issue( ) -> SignedMandate: issued_at = int(time.time()) expires_at = issued_at + self._ttl_seconds + issued_at_ns = time.time_ns() intent_hash = hashlib.sha256(request.action_spec.intent.encode("utf-8")).hexdigest() + delegated_by = parent_mandate.claims.principal_id if parent_mandate is not None else None + parent_mandate_id = parent_mandate.claims.mandate_id if parent_mandate is not None else None + delegation_depth = ( + parent_mandate.claims.delegation_depth + 1 if parent_mandate is not None else 0 + ) mandate_id_seed = ( f"{request.principal.principal_id}|" f"{request.action_spec.action}|" f"{request.action_spec.resource}|" f"{intent_hash}|" f"{request.state_evidence.state_hash}|" - f"{issued_at}" + f"{delegated_by or 'none'}|" + f"{parent_mandate_id or 'none'}|" + f"{delegation_depth}|" + f"{issued_at}|" + f"{issued_at_ns}" ) mandate_id = hashlib.sha256(mandate_id_seed.encode("utf-8")).hexdigest()[:24] - delegated_by = parent_mandate.claims.principal_id if parent_mandate is not None else None - delegation_depth = ( - parent_mandate.claims.delegation_depth + 1 if parent_mandate is not None else 0 - ) delegation_chain_hash = self._compute_delegation_chain_hash( request=request, mandate_id=mandate_id, @@ -57,8 +124,16 @@ def issue( issued_at_epoch_s=issued_at, expires_at_epoch_s=expires_at, delegated_by=delegated_by, + parent_mandate_id=parent_mandate_id, delegation_depth=delegation_depth, delegation_chain_hash=delegation_chain_hash, + iss=self._token_issuer, + aud=self._token_audience, + sub=request.principal.principal_id, + iat=issued_at, + exp=expires_at, + nbf=issued_at, + jti=mandate_id, ) token, signature = self._sign_claims(claims) return SignedMandate(token=token, claims=claims, signature=signature) @@ -69,29 +144,24 @@ def verify(self, token: str) -> SignedMandate | None: return None encoded_header, encoded_payload, encoded_signature = parts + alg, kid = self._parse_header_fields(encoded_header) + if alg is None: + return None signing_input = f"{encoded_header}.{encoded_payload}".encode() - expected_signature = self._hmac(signing_input) - expected_signature_encoded = self._base64url_encode(expected_signature) - if not hmac.compare_digest(expected_signature_encoded, encoded_signature): + if not self._verify_signature( + alg=alg, + signing_input=signing_input, + encoded_signature=encoded_signature, + kid=kid, + ): return None - try: - payload_json = self._base64url_decode(encoded_payload).decode("utf-8") - payload = json.loads(payload_json) - claims = MandateClaims(**payload) - except (ValueError, TypeError, json.JSONDecodeError): + claims = self._parse_claims(encoded_payload) + if claims is None: return None now_epoch = int(time.time()) - if claims.expires_at_epoch_s < now_epoch: - return None - if claims.delegation_depth < 0: - return None - if claims.delegation_depth == 0 and claims.delegated_by is not None: - return None - if claims.delegation_depth > 0 and claims.delegated_by is None: - return None - if claims.delegation_chain_hash is None: + if not self._claims_valid_for_epoch(claims, now_epoch): return None return SignedMandate(token=token, claims=claims, signature=encoded_signature) @@ -124,20 +194,165 @@ def verify_delegation( return hmac.compare_digest(expected_hash, claims.delegation_chain_hash) def _sign_claims(self, claims: MandateClaims) -> tuple[str, str]: + active_material = self._verification_keys[self._active_kid] header_json = json.dumps( - {"alg": "HS256", "typ": "JWT"}, separators=(",", ":"), sort_keys=True + {"alg": self._signing_alg, "typ": "JWT", "kid": self._active_kid}, + separators=(",", ":"), + sort_keys=True, ) payload_json = json.dumps(asdict(claims), separators=(",", ":"), sort_keys=True) encoded_header = self._base64url_encode(header_json.encode("utf-8")) encoded_payload = self._base64url_encode(payload_json.encode("utf-8")) signing_input = f"{encoded_header}.{encoded_payload}".encode() - signature = self._base64url_encode(self._hmac(signing_input)) + if self._signing_alg == "HS256": + signature_bytes = self._hmac(signing_input, secret_key=active_material.secret_key) + else: + der_signature = active_material.private_key.sign( + signing_input, ec.ECDSA(hashes.SHA256()) + ) + signature_bytes = self._der_signature_to_raw(der_signature) + signature = self._base64url_encode(signature_bytes) token = f"{encoded_header}.{encoded_payload}.{signature}" return token, signature - def _hmac(self, payload: bytes) -> bytes: - return hmac.new(self._secret_key, payload, hashlib.sha256).digest() + def _hmac(self, payload: bytes, secret_key: bytes) -> bytes: + return hmac.new(secret_key, payload, hashlib.sha256).digest() + + def _verify_signature( + self, + alg: str, + signing_input: bytes, + encoded_signature: str, + kid: str | None, + ) -> bool: + if alg == "HS256": + if not self._allow_legacy_hs256_verify and self._signing_alg != "HS256": + return False + candidate_kids = self._candidate_kids_for_verify(kid) + for candidate_kid in candidate_kids: + material = self._verification_keys.get(candidate_kid) + if material is None: + continue + expected_signature = self._hmac(signing_input, secret_key=material.secret_key) + expected_signature_encoded = self._base64url_encode(expected_signature) + if hmac.compare_digest(expected_signature_encoded, encoded_signature): + return True + return False + if alg == "ES256": + candidate_kids = self._candidate_kids_for_verify(kid) + try: + raw_signature = self._base64url_decode(encoded_signature) + der_signature = self._raw_signature_to_der(raw_signature) + except ValueError: + return False + for candidate_kid in candidate_kids: + material = self._verification_keys.get(candidate_kid) + if material is None: + continue + try: + material.public_key.verify( + der_signature, signing_input, ec.ECDSA(hashes.SHA256()) + ) + return True + except InvalidSignature: + continue + return False + return False + + def _candidate_kids_for_verify(self, kid: str | None) -> tuple[str, ...]: + if isinstance(kid, str) and kid.strip() != "": + normalized = kid.strip() + if normalized in self._verification_keys: + return (normalized,) + if self._active_kid in self._verification_keys: + return (self._active_kid, *tuple(self._verification_keys.keys())) + return tuple(self._verification_keys.keys()) + + def _parse_header_fields(self, encoded_header: str) -> tuple[str | None, str | None]: + try: + header_json = self._base64url_decode(encoded_header).decode("utf-8") + loaded_header = json.loads(header_json) + except (ValueError, TypeError, json.JSONDecodeError): + return None, None + if not isinstance(loaded_header, dict): + return None, None + header: dict[str, Any] = loaded_header + alg = header.get("alg") + if not isinstance(alg, str): + return None, None + kid_value = header.get("kid") + kid = kid_value if isinstance(kid_value, str) else None + return alg, kid + + def _parse_claims(self, encoded_payload: str) -> MandateClaims | None: + try: + payload_json = self._base64url_decode(encoded_payload).decode("utf-8") + loaded_payload = json.loads(payload_json) + except (ValueError, TypeError, json.JSONDecodeError): + return None + if not isinstance(loaded_payload, dict): + return None + try: + return MandateClaims(**loaded_payload) + except TypeError: + return None + + @staticmethod + def _claims_valid_for_epoch(claims: MandateClaims, now_epoch: int) -> bool: + effective_exp = claims.exp if claims.exp is not None else claims.expires_at_epoch_s + if effective_exp < now_epoch: + return False + if claims.iat is not None and claims.iat > now_epoch: + return False + if claims.nbf is not None and claims.nbf > now_epoch: + return False + if claims.delegation_depth < 0: + return False + if claims.delegation_depth == 0 and claims.delegated_by is not None: + return False + if claims.delegation_depth > 0 and claims.delegated_by is None: + return False + return claims.delegation_chain_hash is not None + + @classmethod + def _build_key_material(cls, secret_key: str) -> _SigningKeyMaterial: + secret = secret_key.encode("utf-8") + private_key = cls._derive_private_key(secret_key) + return _SigningKeyMaterial( + secret_key=secret, + private_key=private_key, + public_key=private_key.public_key(), + ) + + @staticmethod + def _key_id_for_secret(secret_key: str) -> str: + return hashlib.sha256(f"mandate-signing:{secret_key}".encode()).hexdigest()[:16] + + @staticmethod + def _derive_private_key(secret_key: str) -> ec.EllipticCurvePrivateKey: + digest = hashlib.sha256(secret_key.encode("utf-8")).digest() + private_value = int.from_bytes(digest, "big") + # Ensure private value is in valid range for the curve. + order = int( + "FFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551", + 16, + ) + private_value = (private_value % (order - 1)) + 1 + return ec.derive_private_key(private_value, ec.SECP256R1()) + + @staticmethod + def _der_signature_to_raw(der_signature: bytes) -> bytes: + r_value, s_value = cast(tuple[int, int], utils.decode_dss_signature(der_signature)) + return r_value.to_bytes(32, "big") + s_value.to_bytes(32, "big") + + @staticmethod + def _raw_signature_to_der(raw_signature: bytes) -> bytes: + if len(raw_signature) != 64: + raise ValueError("ES256 signature must be 64 bytes") + r_value = int.from_bytes(raw_signature[:32], "big") + s_value = int.from_bytes(raw_signature[32:], "big") + return cast(bytes, utils.encode_dss_signature(r_value, s_value)) @staticmethod def _compute_delegation_chain_hash( diff --git a/predicate_authority/pyproject.toml b/predicate_authority/pyproject.toml index 62c30c4..83a4a18 100644 --- a/predicate_authority/pyproject.toml +++ b/predicate_authority/pyproject.toml @@ -18,6 +18,7 @@ classifiers = [ dependencies = [ "predicate-contracts>=0.1.0,<0.5.0", "pyyaml>=6.0", + "cryptography>=42.0.0", ] [project.scripts] diff --git a/predicate_authority/revocation.py b/predicate_authority/revocation.py index b4bda6b..062def9 100644 --- a/predicate_authority/revocation.py +++ b/predicate_authority/revocation.py @@ -1,17 +1,33 @@ from __future__ import annotations import hashlib +import json +import os +import uuid +from pathlib import Path from threading import Lock +from typing import Any from predicate_contracts import ActionRequest, SignedMandate class LocalRevocationCache: - def __init__(self) -> None: + def __init__(self, store_file_path: str | None = None) -> None: + self._store_file_path = ( + Path(store_file_path) + if isinstance(store_file_path, str) and store_file_path.strip() != "" + else None + ) self._revoked_principal_ids: set[str] = set() self._revoked_intent_hashes: set[str] = set() self._revoked_mandate_ids: set[str] = set() + self._global_kill_switch_enabled = False + self._mandate_parent_by_id: dict[str, str] = {} + self._mandate_children_by_id: dict[str, set[str]] = {} self._lock = Lock() + if self._store_file_path is not None: + self._ensure_store_path() + self._load_from_store() @property def revoked_principal_ids(self) -> set[str]: @@ -40,20 +56,53 @@ def revoked_mandate_count(self) -> int: with self._lock: return len(self._revoked_mandate_ids) + def persistence_enabled(self) -> bool: + return self._store_file_path is not None + + def global_kill_switch_enabled(self) -> bool: + with self._lock: + return self._global_kill_switch_enabled + def revoke_principal(self, principal_id: str) -> None: with self._lock: self._revoked_principal_ids.add(principal_id) + self._persist_unlocked() def revoke_intent_hash(self, intent_hash: str) -> None: with self._lock: self._revoked_intent_hashes.add(intent_hash) + self._persist_unlocked() - def revoke_mandate_id(self, mandate_id: str) -> None: + def revoke_mandate_id(self, mandate_id: str, cascade: bool = False) -> int: with self._lock: self._revoked_mandate_ids.add(mandate_id) + revoked_count = 1 + if cascade: + revoked_count += self._revoke_mandate_descendants_locked(mandate_id) + self._persist_unlocked() + return revoked_count + + def enable_global_kill_switch(self) -> None: + with self._lock: + self._global_kill_switch_enabled = True + self._persist_unlocked() + + def register_mandate(self, mandate: SignedMandate) -> None: + with self._lock: + mandate_id = mandate.claims.mandate_id + parent_mandate_id = mandate.claims.parent_mandate_id + if parent_mandate_id is None or parent_mandate_id.strip() == "": + return + parent_id = parent_mandate_id.strip() + self._mandate_parent_by_id[mandate_id] = parent_id + children = self._mandate_children_by_id.setdefault(parent_id, set()) + children.add(mandate_id) + self._persist_unlocked() def is_request_revoked(self, request: ActionRequest) -> bool: with self._lock: + if self._global_kill_switch_enabled: + return True if request.principal.principal_id in self._revoked_principal_ids: return True intent_hash = hashlib.sha256(request.action_spec.intent.encode("utf-8")).hexdigest() @@ -61,8 +110,143 @@ def is_request_revoked(self, request: ActionRequest) -> bool: def is_mandate_revoked(self, mandate: SignedMandate) -> bool: with self._lock: + if self._global_kill_switch_enabled: + return True if mandate.claims.principal_id in self._revoked_principal_ids: return True if mandate.claims.intent_hash in self._revoked_intent_hashes: return True return mandate.claims.mandate_id in self._revoked_mandate_ids + + def _revoke_mandate_descendants_locked(self, root_mandate_id: str) -> int: + revoked_count = 0 + queue = [root_mandate_id] + visited: set[str] = set() + while len(queue) > 0: + current = queue.pop(0) + if current in visited: + continue + visited.add(current) + children = self._mandate_children_by_id.get(current, set()) + for child_id in children: + if child_id not in self._revoked_mandate_ids: + self._revoked_mandate_ids.add(child_id) + revoked_count += 1 + queue.append(child_id) + return revoked_count + + def _ensure_store_path(self) -> None: + if self._store_file_path is None: + return + self._store_file_path.parent.mkdir(parents=True, exist_ok=True) + try: + os.chmod(self._store_file_path.parent, 0o700) + except OSError: + pass + if not self._store_file_path.exists(): + self._atomic_write_json(self._default_payload()) + self._chmod_file_safe() + + def _load_from_store(self) -> None: + if self._store_file_path is None: + return + loaded = self._read_store_payload() + self._revoked_principal_ids = self._parse_string_set(loaded.get("revoked_principal_ids")) + self._revoked_intent_hashes = self._parse_string_set(loaded.get("revoked_intent_hashes")) + self._revoked_mandate_ids = self._parse_string_set(loaded.get("revoked_mandate_ids")) + self._global_kill_switch_enabled = bool(loaded.get("global_kill_switch_enabled", False)) + self._mandate_parent_by_id = self._parse_string_map(loaded.get("mandate_parent_by_id")) + self._mandate_children_by_id = self._parse_children_map( + loaded.get("mandate_children_by_id") + ) + + def _persist_unlocked(self) -> None: + if self._store_file_path is None: + return + payload = { + "schema_version": 1, + "revoked_principal_ids": sorted(self._revoked_principal_ids), + "revoked_intent_hashes": sorted(self._revoked_intent_hashes), + "revoked_mandate_ids": sorted(self._revoked_mandate_ids), + "global_kill_switch_enabled": self._global_kill_switch_enabled, + "mandate_parent_by_id": dict(sorted(self._mandate_parent_by_id.items())), + "mandate_children_by_id": { + parent_id: sorted(children) + for parent_id, children in sorted(self._mandate_children_by_id.items()) + }, + } + self._atomic_write_json(payload) + self._chmod_file_safe() + + def _read_store_payload(self) -> dict[str, Any]: + if self._store_file_path is None or not self._store_file_path.exists(): + return self._default_payload() + content = self._store_file_path.read_text(encoding="utf-8").strip() + if content == "": + return self._default_payload() + try: + loaded = json.loads(content) + except json.JSONDecodeError: + return self._default_payload() + if not isinstance(loaded, dict): + return self._default_payload() + return loaded + + def _atomic_write_json(self, payload: dict[str, Any]) -> None: + if self._store_file_path is None: + return + tmp_path = self._store_file_path.with_name( + f"{self._store_file_path.name}.{uuid.uuid4().hex}.tmp" + ) + tmp_path.write_text(json.dumps(payload, indent=2), encoding="utf-8") + os.replace(tmp_path, self._store_file_path) + + def _chmod_file_safe(self) -> None: + if self._store_file_path is None: + return + try: + os.chmod(self._store_file_path, 0o600) + except OSError: + pass + + @staticmethod + def _default_payload() -> dict[str, Any]: + return { + "schema_version": 1, + "revoked_principal_ids": [], + "revoked_intent_hashes": [], + "revoked_mandate_ids": [], + "global_kill_switch_enabled": False, + "mandate_parent_by_id": {}, + "mandate_children_by_id": {}, + } + + @staticmethod + def _parse_string_set(raw: object) -> set[str]: + if not isinstance(raw, list): + return set() + return {str(item) for item in raw if isinstance(item, str) and item.strip() != ""} + + @staticmethod + def _parse_string_map(raw: object) -> dict[str, str]: + if not isinstance(raw, dict): + return {} + result: dict[str, str] = {} + for key, value in raw.items(): + if not isinstance(key, str) or key.strip() == "": + continue + if not isinstance(value, str) or value.strip() == "": + continue + result[key] = value + return result + + @classmethod + def _parse_children_map(cls, raw: object) -> dict[str, set[str]]: + if not isinstance(raw, dict): + return {} + result: dict[str, set[str]] = {} + for key, value in raw.items(): + if not isinstance(key, str) or key.strip() == "": + continue + result[key] = cls._parse_string_set(value) + return result diff --git a/predicate_authority/sidecar.py b/predicate_authority/sidecar.py index ae47145..d955302 100644 --- a/predicate_authority/sidecar.py +++ b/predicate_authority/sidecar.py @@ -43,6 +43,8 @@ class SidecarConfig: class SidecarStatus: mode: AuthorityMode policy_hot_reload_enabled: bool + mandate_store_persistence_enabled: bool + global_kill_switch_enabled: bool revoked_principal_count: int revoked_intent_count: int revoked_mandate_count: int @@ -140,6 +142,7 @@ def issue_mandate(self, request: ActionRequest) -> AuthorizationDecision: return decision decision = self._action_guard.authorize(request) if decision.allowed and decision.mandate is not None: + self._revocation_cache.register_mandate(decision.mandate) if self._revocation_cache.is_mandate_revoked(decision.mandate): revoked_decision = AuthorizationDecision( allowed=False, @@ -182,8 +185,11 @@ def revoke_by_invariant(self, principal_id: str) -> None: def revoke_intent_hash(self, intent_hash: str) -> None: self._revocation_cache.revoke_intent_hash(intent_hash) - def revoke_mandate_id(self, mandate_id: str) -> None: - self._revocation_cache.revoke_mandate_id(mandate_id) + def revoke_mandate_id(self, mandate_id: str, cascade: bool = False) -> int: + return self._revocation_cache.revoke_mandate_id(mandate_id, cascade=cascade) + + def activate_global_kill_switch(self) -> None: + self._revocation_cache.enable_global_kill_switch() def hot_reload_policy(self) -> bool: if self._policy_source is None: @@ -219,6 +225,8 @@ def status(self) -> SidecarStatus: return SidecarStatus( mode=self._config.mode, policy_hot_reload_enabled=self._policy_source is not None, + mandate_store_persistence_enabled=self._revocation_cache.persistence_enabled(), + global_kill_switch_enabled=self._revocation_cache.global_kill_switch_enabled(), revoked_principal_count=self._revocation_cache.revoked_principal_count(), revoked_intent_count=self._revocation_cache.revoked_intent_count(), revoked_mandate_count=self._revocation_cache.revoked_mandate_count(), diff --git a/predicate_contracts/models.py b/predicate_contracts/models.py index fb81689..a1dd6e0 100644 --- a/predicate_contracts/models.py +++ b/predicate_contracts/models.py @@ -96,8 +96,16 @@ class MandateClaims: issued_at_epoch_s: int expires_at_epoch_s: int delegated_by: str | None = None + parent_mandate_id: str | None = None delegation_depth: int = 0 delegation_chain_hash: str | None = None + iss: str | None = None + aud: str | None = None + sub: str | None = None + iat: int | None = None + exp: int | None = None + nbf: int | None = None + jti: str | None = None @dataclass(frozen=True) diff --git a/tests/test_authority_client_local_yaml.py b/tests/test_authority_client_local_yaml.py index 78bcc59..34e9013 100644 --- a/tests/test_authority_client_local_yaml.py +++ b/tests/test_authority_client_local_yaml.py @@ -16,6 +16,19 @@ ) +def _request() -> ActionRequest: + return ActionRequest( + principal=PrincipalRef(principal_id="agent:checkout"), + action_spec=ActionSpec( + action="http.post", + resource="https://api.vendor.com/orders", + intent="submit order", + ), + state_evidence=StateEvidence(source="unit-test", state_hash="sha256:test"), + verification_evidence=VerificationEvidence(), + ) + + def test_authority_client_mint_and_verify_with_local_yaml_policy(tmp_path: Path) -> None: policy = tmp_path / "policy.yaml" policy.write_text( @@ -37,16 +50,7 @@ def test_authority_client_mint_and_verify_with_local_yaml_policy(tmp_path: Path) context = AuthorityClient.from_policy_file(str(policy), secret_key="local-test-secret") client = context.client - request = ActionRequest( - principal=PrincipalRef(principal_id="agent:checkout"), - action_spec=ActionSpec( - action="http.post", - resource="https://api.vendor.com/orders", - intent="submit order", - ), - state_evidence=StateEvidence(source="unit-test", state_hash="sha256:test"), - verification_evidence=VerificationEvidence(), - ) + request = _request() decision = client.authorize(request) assert decision.allowed @@ -77,16 +81,7 @@ def test_authority_client_global_max_depth_from_yaml_is_enforced(tmp_path: Path) ) context = AuthorityClient.from_policy_file(str(policy), secret_key="local-test-secret") client = context.client - request = ActionRequest( - principal=PrincipalRef(principal_id="agent:checkout"), - action_spec=ActionSpec( - action="http.post", - resource="https://api.vendor.com/orders", - intent="submit order", - ), - state_evidence=StateEvidence(source="unit-test", state_hash="sha256:test"), - verification_evidence=VerificationEvidence(), - ) + request = _request() root = client.authorize(request) assert root.allowed is True assert root.mandate is not None @@ -119,3 +114,75 @@ def test_authority_client_from_env(tmp_path: Path, monkeypatch: MonkeyPatch) -> monkeypatch.setenv("PREDICATE_AUTHORITY_MANDATE_TTL_SECONDS", "120") context = AuthorityClient.from_env() assert context.policy_file == str(policy) + + +def test_revoke_mandate_without_cascade_keeps_child_active(tmp_path: Path) -> None: + policy = tmp_path / "policy.yaml" + policy.write_text( + "\n".join( + [ + "rules:", + " - name: allow-orders-create", + " effect: allow", + " principals:", + " - agent:checkout", + " actions:", + " - http.post", + " resources:", + " - https://api.vendor.com/orders", + ] + ), + encoding="utf-8", + ) + context = AuthorityClient.from_policy_file(str(policy), secret_key="local-test-secret") + client = context.client + request = _request() + + root = client.authorize(request) + assert root.allowed is True + assert root.mandate is not None + child = client.authorize(request, parent_mandate=root.mandate) + assert child.allowed is True + assert child.mandate is not None + + revoked_count = client.revoke_mandate(root.mandate.claims.mandate_id, cascade=False) + child_verified = client.verify_token(child.mandate.token) + + assert revoked_count == 1 + assert child_verified is not None + + +def test_revoke_mandate_with_cascade_revokes_descendants(tmp_path: Path) -> None: + policy = tmp_path / "policy.yaml" + policy.write_text( + "\n".join( + [ + "rules:", + " - name: allow-orders-create", + " effect: allow", + " principals:", + " - agent:checkout", + " actions:", + " - http.post", + " resources:", + " - https://api.vendor.com/orders", + ] + ), + encoding="utf-8", + ) + context = AuthorityClient.from_policy_file(str(policy), secret_key="local-test-secret") + client = context.client + request = _request() + + root = client.authorize(request) + assert root.allowed is True + assert root.mandate is not None + child = client.authorize(request, parent_mandate=root.mandate) + assert child.allowed is True + assert child.mandate is not None + + revoked_count = client.revoke_mandate(root.mandate.claims.mandate_id, cascade=True) + child_verified = client.verify_token(child.mandate.token) + + assert revoked_count >= 2 + assert child_verified is None diff --git a/tests/test_control_plane_integration.py b/tests/test_control_plane_integration.py index 44fe588..7051c33 100644 --- a/tests/test_control_plane_integration.py +++ b/tests/test_control_plane_integration.py @@ -102,6 +102,9 @@ def test_control_plane_client_posts_audit_and_usage() -> None: assert any( headers.get("Authorization") == "Bearer token-123" for headers in recorder.headers ) + assert all("X-PA-Nonce" in headers for headers in recorder.headers) + assert all("X-PA-Timestamp" in headers for headers in recorder.headers) + assert all("X-PA-Idempotency-Token" in headers for headers in recorder.headers) finally: server.shutdown() server.server_close() @@ -142,6 +145,42 @@ def test_control_plane_trace_emitter_sends_from_proof_event() -> None: server.server_close() +def test_control_plane_client_includes_replay_signature_when_configured() -> None: + recorder = Recorder() + server, _ = _start_server(recorder) + try: + base_url = f"http://127.0.0.1:{server.server_port}" + client = ControlPlaneClient( + ControlPlaneClientConfig( + base_url=base_url, + tenant_id="tenant-a", + project_id="project-a", + replay_signing_secret="test-replay-secret", + fail_open=False, + ) + ) + sent = client.send_audit_events( + ( + AuditEventEnvelope( + event_id="evt_1", + tenant_id="tenant-a", + principal_id="agent:orders-1", + action="http.post", + resource="https://api.vendor.com/orders", + allowed=True, + reason="allowed", + timestamp="2026-01-01T00:00:00+00:00", + ), + ) + ) + assert sent is True + assert len(recorder.headers) == 1 + assert "X-PA-Signature" in recorder.headers[0] + finally: + server.shutdown() + server.server_close() + + def test_control_plane_client_fail_open_returns_false() -> None: client = ControlPlaneClient( ControlPlaneClientConfig( diff --git a/tests/test_daemon_phase2.py b/tests/test_daemon_phase2.py index 57f747f..a35118f 100644 --- a/tests/test_daemon_phase2.py +++ b/tests/test_daemon_phase2.py @@ -168,6 +168,29 @@ def test_daemon_exposes_health_and_status_endpoints(tmp_path: Path) -> None: assert health["mode"] == "local_only" assert status["daemon_running"] is True assert status["policy_hot_reload_enabled"] is True + assert status["mandate_store_persistence_enabled"] is False + finally: + daemon.stop() + + +def test_daemon_status_exposes_mandate_store_persistence_mode(tmp_path: Path) -> None: + policy_file = tmp_path / "policy.json" + policy_file.write_text(json.dumps({"rules": []}), encoding="utf-8") + sidecar = _build_default_sidecar( + mode=AuthorityMode.LOCAL_ONLY, + policy_file=str(policy_file), + credential_store_file=str(tmp_path / "credentials.json"), + mandate_store_file=str(tmp_path / "mandates.json"), + ) + daemon = PredicateAuthorityDaemon( + sidecar=sidecar, + config=DaemonConfig(host="127.0.0.1", port=0, policy_poll_interval_s=0.05), + ) + daemon.start() + try: + base_url = f"http://127.0.0.1:{daemon.bound_port}" + status = _fetch_json(f"{base_url}/status") + assert status["mandate_store_persistence_enabled"] is True finally: daemon.stop() @@ -321,6 +344,8 @@ def test_daemon_supports_policy_reload_and_revoke_endpoints(tmp_path: Path) -> N assert revoke_principal["ok"] is True assert revoke_intent["ok"] is True assert revoke_mandate["ok"] is True + assert revoke_mandate["cascade"] is False + assert int(revoke_mandate["revoked_count"]) >= 1 assert int(status["revoked_principal_count"]) >= 1 assert int(status["revoked_intent_count"]) >= 1 assert int(status["revoked_mandate_count"]) >= 1 @@ -757,6 +782,128 @@ def test_daemon_long_poll_sync_applies_policy_and_revocations(tmp_path: Path) -> server.server_close() +def test_daemon_long_poll_sync_applies_global_kill_switch_tags(tmp_path: Path) -> None: + class SyncKillSwitchHandler(BaseHTTPRequestHandler): + requests: list[str] = [] + + def do_GET(self) -> None: # noqa: N802 + parsed = urlsplit(self.path) + if parsed.path != "/v1/sync/authority-updates": + self.send_response(404) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write(b'{"error":"not_found"}') + return + self.requests.append(self.path) + payload = { + "changed": True, + "sync_token": "sync-kill-1", + "tenant_id": "tenant-sync", + "project_id": "project-sync", + "environment": "prod", + "policy_id": "pol-sync-1", + "policy_revision": 1, + "policy_document": { + "rules": [ + { + "name": "allow-sync-http", + "effect": "allow", + "principals": ["agent:*"], + "actions": ["http.*"], + "resources": ["https://*/*"], + } + ] + }, + "revocations": [ + { + "revocation_id": "rev-kill-1", + "tenant_id": "tenant-sync", + "type": "tags", + "principal_id": None, + "intent_hash": None, + "tags": ["global_kill_switch"], + "reason": "incident", + "created_at": "2026-02-19T00:00:00+00:00", + } + ], + } + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write(json.dumps(payload).encode("utf-8")) + + def do_POST(self) -> None: # noqa: N802 + raw_length = self.headers.get("Content-Length", "0") + content_length = int(raw_length) if raw_length.isdigit() else 0 + _ = self.rfile.read(content_length) if content_length > 0 else b"" + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write(b"{}") + + def log_message(self, fmt: str, *args: Any) -> None: # noqa: A003 + _ = fmt + return + + server = ThreadingHTTPServer(("127.0.0.1", 0), SyncKillSwitchHandler) + thread = threading.Thread(target=server.serve_forever, daemon=True) + thread.start() + daemon: PredicateAuthorityDaemon | None = None + try: + sidecar = _build_default_sidecar( + mode=AuthorityMode.CLOUD_CONNECTED, + policy_file=None, + credential_store_file=str(tmp_path / "credentials.json"), + control_plane_config=ControlPlaneBootstrapConfig( + enabled=True, + base_url=f"http://127.0.0.1:{server.server_port}", + tenant_id="tenant-sync", + project_id="project-sync", + auth_token="token-sync", + fail_open=False, + sync_enabled=True, + sync_wait_timeout_s=0.2, + sync_poll_interval_ms=50, + sync_project_id="project-sync", + sync_environment="prod", + ), + ) + daemon = PredicateAuthorityDaemon( + sidecar=sidecar, + config=DaemonConfig(host="127.0.0.1", port=0, policy_poll_interval_s=1.0), + ) + daemon.start() + status_url = f"http://127.0.0.1:{daemon.bound_port}/status" + deadline = time.time() + 2.0 + while time.time() < deadline: + status = _fetch_json(status_url) + if status.get("global_kill_switch_enabled") is True: + break + time.sleep(0.05) + else: + raise AssertionError("global kill-switch tag was not applied in time") + + denied = sidecar.issue_mandate( + ActionRequest( + principal=PrincipalRef(principal_id="agent:any"), + action_spec=ActionSpec( + action="http.post", + resource="https://api.vendor.com/orders", + intent="sync path", + ), + state_evidence=StateEvidence(source="test", state_hash="sync-kill"), + verification_evidence=VerificationEvidence(), + ) + ) + assert denied.allowed is False + assert denied.reason.value == "invalid_mandate" + finally: + if daemon is not None: + daemon.stop() + server.shutdown() + server.server_close() + + def test_daemon_identity_mode_local_idp_builder() -> None: os.environ["LOCAL_IDP_SIGNING_KEY"] = "daemon-local-idp-key" args = Namespace( @@ -1420,3 +1567,157 @@ def test_daemon_restart_recovers_queue_after_partition(tmp_path: Path) -> None: daemon_after_restart.stop() healthy_server.shutdown() healthy_server.server_close() + + +def test_daemon_restart_recovers_revocations_when_mandate_store_enabled(tmp_path: Path) -> None: + policy_file = tmp_path / "policy.json" + policy_file.write_text( + json.dumps( + { + "rules": [ + { + "name": "allow-any-http", + "effect": "allow", + "principals": ["agent:*"], + "actions": ["http.*"], + "resources": ["https://*/*"], + } + ] + } + ), + encoding="utf-8", + ) + mandate_store_file = tmp_path / "mandates.json" + daemon: PredicateAuthorityDaemon | None = None + try: + sidecar = _build_default_sidecar( + mode=AuthorityMode.LOCAL_ONLY, + policy_file=str(policy_file), + credential_store_file=str(tmp_path / "credentials.json"), + mandate_store_file=str(mandate_store_file), + ) + daemon = PredicateAuthorityDaemon( + sidecar=sidecar, + config=DaemonConfig(host="127.0.0.1", port=0, policy_poll_interval_s=10.0), + ) + daemon.start() + base_url = f"http://127.0.0.1:{daemon.bound_port}" + pre_revoke_status, pre_revoke_payload = _post_json_with_status( + f"{base_url}/v1/authorize", + { + "principal": "agent:persisted-revoke", + "action": "http.post", + "resource": "https://api.vendor.com/orders", + "intent_hash": "persisted-revocation", + "state_evidence": {"source": "test", "state_hash": "persisted-state"}, + }, + ) + assert pre_revoke_status == 200 + assert pre_revoke_payload["allowed"] is True + revoke_response = _post_json( + f"{base_url}/revoke/principal", {"principal_id": "agent:persisted-revoke"} + ) + assert revoke_response["ok"] is True + finally: + if daemon is not None: + daemon.stop() + + daemon_after_restart: PredicateAuthorityDaemon | None = None + try: + restarted_sidecar = _build_default_sidecar( + mode=AuthorityMode.LOCAL_ONLY, + policy_file=str(policy_file), + credential_store_file=str(tmp_path / "credentials.json"), + mandate_store_file=str(mandate_store_file), + ) + daemon_after_restart = PredicateAuthorityDaemon( + sidecar=restarted_sidecar, + config=DaemonConfig(host="127.0.0.1", port=0, policy_poll_interval_s=10.0), + ) + daemon_after_restart.start() + base_url = f"http://127.0.0.1:{daemon_after_restart.bound_port}" + post_revoke_status, post_revoke_payload = _post_json_with_status( + f"{base_url}/v1/authorize", + { + "principal": "agent:persisted-revoke", + "action": "http.post", + "resource": "https://api.vendor.com/orders", + "intent_hash": "persisted-revocation", + "state_evidence": {"source": "test", "state_hash": "persisted-state"}, + }, + ) + assert post_revoke_status == 403 + assert post_revoke_payload["allowed"] is False + assert post_revoke_payload["reason"] == "invalid_mandate" + finally: + if daemon_after_restart is not None: + daemon_after_restart.stop() + + +def test_daemon_restart_drops_revocations_when_mandate_store_disabled(tmp_path: Path) -> None: + policy_file = tmp_path / "policy.json" + policy_file.write_text( + json.dumps( + { + "rules": [ + { + "name": "allow-any-http", + "effect": "allow", + "principals": ["agent:*"], + "actions": ["http.*"], + "resources": ["https://*/*"], + } + ] + } + ), + encoding="utf-8", + ) + daemon: PredicateAuthorityDaemon | None = None + try: + sidecar = _build_default_sidecar( + mode=AuthorityMode.LOCAL_ONLY, + policy_file=str(policy_file), + credential_store_file=str(tmp_path / "credentials.json"), + ) + daemon = PredicateAuthorityDaemon( + sidecar=sidecar, + config=DaemonConfig(host="127.0.0.1", port=0, policy_poll_interval_s=10.0), + ) + daemon.start() + base_url = f"http://127.0.0.1:{daemon.bound_port}" + revoke_response = _post_json( + f"{base_url}/revoke/principal", {"principal_id": "agent:ephemeral-revoke"} + ) + assert revoke_response["ok"] is True + finally: + if daemon is not None: + daemon.stop() + + daemon_after_restart: PredicateAuthorityDaemon | None = None + try: + restarted_sidecar = _build_default_sidecar( + mode=AuthorityMode.LOCAL_ONLY, + policy_file=str(policy_file), + credential_store_file=str(tmp_path / "credentials.json"), + ) + daemon_after_restart = PredicateAuthorityDaemon( + sidecar=restarted_sidecar, + config=DaemonConfig(host="127.0.0.1", port=0, policy_poll_interval_s=10.0), + ) + daemon_after_restart.start() + base_url = f"http://127.0.0.1:{daemon_after_restart.bound_port}" + post_restart_status, post_restart_payload = _post_json_with_status( + f"{base_url}/v1/authorize", + { + "principal": "agent:ephemeral-revoke", + "action": "http.post", + "resource": "https://api.vendor.com/orders", + "intent_hash": "ephemeral-revocation", + "state_evidence": {"source": "test", "state_hash": "ephemeral-state"}, + }, + ) + assert post_restart_status == 200 + assert post_restart_payload["allowed"] is True + finally: + if daemon_after_restart is not None: + daemon_after_restart.stop() diff --git a/tests/test_mandate_signer.py b/tests/test_mandate_signer.py index 1534a9b..11fdd68 100644 --- a/tests/test_mandate_signer.py +++ b/tests/test_mandate_signer.py @@ -1,5 +1,11 @@ from __future__ import annotations +import base64 +import hmac +import json +from dataclasses import asdict +from hashlib import sha256 + # pylint: disable=import-error from predicate_authority import LocalMandateSigner from predicate_contracts import ( @@ -11,6 +17,32 @@ ) +def _jwt_header(token: str) -> dict[str, object]: + encoded_header = token.split(".")[0] + padding = "=" * ((4 - len(encoded_header) % 4) % 4) + return json.loads(base64.urlsafe_b64decode(encoded_header + padding).decode("utf-8")) + + +def _jwt_payload(token: str) -> dict[str, object]: + encoded_payload = token.split(".")[1] + padding = "=" * ((4 - len(encoded_payload) % 4) % 4) + return json.loads(base64.urlsafe_b64decode(encoded_payload + padding).decode("utf-8")) + + +def _base64url(value: bytes) -> str: + return base64.urlsafe_b64encode(value).rstrip(b"=").decode("ascii") + + +def _build_legacy_hs256_token(secret_key: str, payload: dict[str, object]) -> str: + header_json = json.dumps({"alg": "HS256", "typ": "JWT"}, separators=(",", ":"), sort_keys=True) + payload_json = json.dumps(payload, separators=(",", ":"), sort_keys=True) + encoded_header = _base64url(header_json.encode("utf-8")) + encoded_payload = _base64url(payload_json.encode("utf-8")) + signing_input = f"{encoded_header}.{encoded_payload}".encode() + signature = hmac.new(secret_key.encode("utf-8"), signing_input, sha256).digest() + return f"{encoded_header}.{encoded_payload}.{_base64url(signature)}" + + def test_mandate_signature_verifies() -> None: signer = LocalMandateSigner(secret_key="test-key", ttl_seconds=60) request = ActionRequest( @@ -45,7 +77,9 @@ def test_mandate_tamper_is_rejected() -> None: ) signed = signer.issue(request) - tampered = signed.token[:-1] + ("A" if signed.token[-1] != "A" else "B") + token_parts = signed.token.split(".") + tampered_payload = token_parts[1][:-1] + ("A" if token_parts[1][-1] != "A" else "B") + tampered = f"{token_parts[0]}.{tampered_payload}.{token_parts[2]}" assert signer.verify(tampered) is None @@ -84,3 +118,153 @@ def test_multi_hop_delegation_claims_and_chain_verification() -> None: assert signer.verify_delegation(root_mandate, parent_mandate=None) is True assert signer.verify_delegation(child_mandate, parent_mandate=root_mandate) is True assert signer.verify_delegation(child_mandate, parent_mandate=None) is False + + +def test_mandate_signer_defaults_to_es256_issue_and_verify() -> None: + signer = LocalMandateSigner(secret_key="test-key", ttl_seconds=60) + request = ActionRequest( + principal=PrincipalRef(principal_id="agent:writer"), + action_spec=ActionSpec( + action="mcp.execute", resource="mcp://tools/write_file", intent="write report" + ), + state_evidence=StateEvidence(source="non-web", state_hash="state-xyz"), + verification_evidence=VerificationEvidence(), + ) + + signed = signer.issue(request) + header = _jwt_header(signed.token) + payload = _jwt_payload(signed.token) + verified = signer.verify(signed.token) + + assert header["alg"] == "ES256" + assert "kid" in header + assert payload["iss"] == "predicate-authorityd" + assert payload["aud"] == "predicate-authority" + assert payload["sub"] == "agent:writer" + assert payload["jti"] == signed.claims.mandate_id + assert payload["iat"] == signed.claims.issued_at_epoch_s + assert payload["exp"] == signed.claims.expires_at_epoch_s + assert payload["nbf"] == signed.claims.issued_at_epoch_s + assert verified is not None + + +def test_es256_signer_verifies_legacy_hs256_tokens_by_default() -> None: + legacy = LocalMandateSigner(secret_key="test-key", ttl_seconds=60, signing_alg="HS256") + current = LocalMandateSigner(secret_key="test-key", ttl_seconds=60) + request = ActionRequest( + principal=PrincipalRef(principal_id="agent:writer"), + action_spec=ActionSpec( + action="mcp.execute", resource="mcp://tools/write_file", intent="write report" + ), + state_evidence=StateEvidence(source="non-web", state_hash="state-xyz"), + verification_evidence=VerificationEvidence(), + ) + legacy_token = legacy.issue(request).token + header = _jwt_header(legacy_token) + + verified = current.verify(legacy_token) + + assert header["alg"] == "HS256" + assert verified is not None + + +def test_es256_signer_rejects_legacy_hs256_when_disabled() -> None: + legacy = LocalMandateSigner(secret_key="test-key", ttl_seconds=60, signing_alg="HS256") + strict_current = LocalMandateSigner( + secret_key="test-key", + ttl_seconds=60, + signing_alg="ES256", + allow_legacy_hs256_verify=False, + ) + request = ActionRequest( + principal=PrincipalRef(principal_id="agent:writer"), + action_spec=ActionSpec( + action="mcp.execute", resource="mcp://tools/write_file", intent="write report" + ), + state_evidence=StateEvidence(source="non-web", state_hash="state-xyz"), + verification_evidence=VerificationEvidence(), + ) + legacy_token = legacy.issue(request).token + + assert strict_current.verify(legacy_token) is None + + +def test_new_signer_parses_legacy_payload_without_standard_claims() -> None: + legacy_signer = LocalMandateSigner(secret_key="test-key", ttl_seconds=60, signing_alg="HS256") + current_signer = LocalMandateSigner(secret_key="test-key", ttl_seconds=60, signing_alg="ES256") + request = ActionRequest( + principal=PrincipalRef(principal_id="agent:writer"), + action_spec=ActionSpec( + action="mcp.execute", resource="mcp://tools/write_file", intent="write report" + ), + state_evidence=StateEvidence(source="non-web", state_hash="state-xyz"), + verification_evidence=VerificationEvidence(), + ) + signed = legacy_signer.issue(request) + legacy_payload = asdict(signed.claims) + for key in ("iss", "aud", "sub", "iat", "exp", "nbf", "jti"): + legacy_payload.pop(key, None) + legacy_token = _build_legacy_hs256_token("test-key", legacy_payload) + + verified = current_signer.verify(legacy_token) + + assert verified is not None + assert verified.claims.iss is None + assert verified.claims.aud is None + assert verified.claims.sub is None + assert verified.claims.iat is None + assert verified.claims.exp is None + assert verified.claims.nbf is None + assert verified.claims.jti is None + + +def test_key_rotation_activate_preserves_overlap_verify_window() -> None: + signer = LocalMandateSigner(secret_key="key-v1", ttl_seconds=60, signing_alg="ES256") + request = ActionRequest( + principal=PrincipalRef(principal_id="agent:writer"), + action_spec=ActionSpec( + action="mcp.execute", resource="mcp://tools/write_file", intent="write report" + ), + state_evidence=StateEvidence(source="non-web", state_hash="state-xyz"), + verification_evidence=VerificationEvidence(), + ) + old_signed = signer.issue(request) + old_kid = str(_jwt_header(old_signed.token)["kid"]) + + staged_kid = signer.stage_next_signing_key("key-v2") + activated_kid = signer.activate_staged_signing_key() + new_signed = signer.issue(request) + new_kid = str(_jwt_header(new_signed.token)["kid"]) + status = signer.key_lifecycle_status() + + assert staged_kid == activated_kid + assert new_kid == activated_kid + assert activated_kid != old_kid + assert signer.verify(old_signed.token) is not None + assert signer.verify(new_signed.token) is not None + assert status["active_kid"] == activated_kid + assert status["next_kid"] is None + assert old_kid in status["verification_kids"] + assert activated_kid in status["verification_kids"] + + +def test_key_rotation_retire_old_key_invalidates_old_token() -> None: + signer = LocalMandateSigner(secret_key="key-v1", ttl_seconds=60, signing_alg="ES256") + request = ActionRequest( + principal=PrincipalRef(principal_id="agent:writer"), + action_spec=ActionSpec( + action="mcp.execute", resource="mcp://tools/write_file", intent="write report" + ), + state_evidence=StateEvidence(source="non-web", state_hash="state-xyz"), + verification_evidence=VerificationEvidence(), + ) + old_signed = signer.issue(request) + old_kid = str(_jwt_header(old_signed.token)["kid"]) + signer.stage_next_signing_key("key-v2") + signer.activate_staged_signing_key() + _ = signer.issue(request) + + retired = signer.retire_verification_key(old_kid) + + assert retired is True + assert signer.verify(old_signed.token) is None