From 9c1cd53f16398f0cd0cbfebcae7a10063c7b30d5 Mon Sep 17 00:00:00 2001 From: Levente Hunyadi Date: Tue, 3 Feb 2026 12:36:43 +0100 Subject: [PATCH] Improve type safety --- src/pysafeguard/async_connection.py | 18 ++++-------------- src/pysafeguard/connection.py | 18 ++++-------------- src/pysafeguard/data_types.py | 3 --- src/pysafeguard/utility.py | 18 ++++++++++++++++++ 4 files changed, 26 insertions(+), 31 deletions(-) diff --git a/src/pysafeguard/async_connection.py b/src/pysafeguard/async_connection.py index 12a6af4..426e131 100644 --- a/src/pysafeguard/async_connection.py +++ b/src/pysafeguard/async_connection.py @@ -7,8 +7,8 @@ from multidict import CIMultiDict from truststore import SSLContext -from .data_types import A2ATypes, HttpMethods, JsonType, Services, SshKeyFormats -from .utility import LiteralString, assemble_path, assemble_url +from .data_types import A2ATypes, HttpMethods, Services, SshKeyFormats +from .utility import JsonType, LiteralString, assemble_path, assemble_url, get_access_token, get_user_token class AsyncWebRequestError(Exception): @@ -49,8 +49,6 @@ async def __execute_web_request( updated_headers = CIMultiDict(headers) if body and httpMethod in [HttpMethods.POST, HttpMethods.PUT] and not headers.get("content-type"): data_body = None - if not isinstance(body, dict): - raise TypeError("expected: body as a JSON object") json_body = body updated_headers["content-type"] = "application/json" else: @@ -136,20 +134,12 @@ async def get_provider_id(self, name: str) -> str: return typing.cast(str, matches[0]["RstsProviderId"]) async def __connect(self, body: JsonType, cert: tuple[str, str] | None = None) -> None: - data: JsonType resp = await self.invoke(HttpMethods.POST, Services.RSTS, "oauth2/token", body=body, cert=cert) if resp.status == 200 and "application/json" in resp.headers.get("content-type", ""): - data = await resp.json() - if not isinstance(data, dict): - raise TypeError("expected: JSON object with field `access_token`") - access_token = data.get("access_token") - + access_token = get_access_token(await resp.json()) resp = await self.invoke(HttpMethods.POST, Services.CORE, "Token/LoginResponse", body=dict(StsAccessToken=access_token)) if resp.status == 200 and "application/json" in resp.headers.get("content-type", ""): - data = await resp.json() - if not isinstance(data, dict): - raise TypeError("expected: JSON object with field `UserToken`") - user_token = typing.cast(str, data.get("UserToken")) + user_token = get_user_token(await resp.json()) self.connect_token(user_token) else: raise AsyncWebRequestError(resp) diff --git a/src/pysafeguard/connection.py b/src/pysafeguard/connection.py index 76976bf..e124c9f 100644 --- a/src/pysafeguard/connection.py +++ b/src/pysafeguard/connection.py @@ -5,8 +5,8 @@ from requests import Response, request from requests.structures import CaseInsensitiveDict -from .data_types import A2ATypes, HttpMethods, JsonType, Services, SshKeyFormats -from .utility import LiteralString, assemble_path, assemble_url +from .data_types import A2ATypes, HttpMethods, Services, SshKeyFormats +from .utility import JsonType, LiteralString, assemble_path, assemble_url, get_access_token, get_user_token class WebRequestError(Exception): @@ -47,8 +47,6 @@ def __execute_web_request( updated_headers = CaseInsensitiveDict(headers) if body and httpMethod in [HttpMethods.POST, HttpMethods.PUT] and not headers.get("content-type"): data_body = None - if not isinstance(body, dict): - raise TypeError("expected: body as a JSON object") json_body = body updated_headers["content-type"] = "application/json" else: @@ -125,20 +123,12 @@ def get_provider_id(self, name: str) -> str: return typing.cast(str, matches[0]["RstsProviderId"]) def __connect(self, body: JsonType, cert: tuple[str, str] | None = None) -> None: - data: JsonType resp = self.invoke(HttpMethods.POST, Services.RSTS, "oauth2/token", body=body, cert=cert) if resp.status_code == 200 and "application/json" in resp.headers.get("content-type", ""): - data = resp.json() - if not isinstance(data, dict): - raise TypeError("expected: JSON object with field `access_token`") - access_token = data.get("access_token") - + access_token = get_access_token(resp.json()) resp = self.invoke(HttpMethods.POST, Services.CORE, "Token/LoginResponse", body=dict(StsAccessToken=access_token)) if resp.status_code == 200 and "application/json" in resp.headers.get("content-type", ""): - data = resp.json() - if not isinstance(data, dict): - raise TypeError("expected: JSON object with field `UserToken`") - user_token = typing.cast(str, data.get("UserToken")) + user_token = get_user_token(resp.json()) self.connect_token(user_token) else: raise WebRequestError(resp) diff --git a/src/pysafeguard/data_types.py b/src/pysafeguard/data_types.py index 75e4ebb..55b12db 100644 --- a/src/pysafeguard/data_types.py +++ b/src/pysafeguard/data_types.py @@ -1,9 +1,6 @@ import enum import sys -JsonType = None | bool | int | float | str | dict[str, "JsonType"] | list["JsonType"] - - if sys.version_info < (3, 11): class StrEnum(str, enum.Enum): diff --git a/src/pysafeguard/utility.py b/src/pysafeguard/utility.py index 4a19009..482f070 100644 --- a/src/pysafeguard/utility.py +++ b/src/pysafeguard/utility.py @@ -7,6 +7,8 @@ else: from typing import LiteralString as LiteralString +JsonType = None | bool | int | float | str | dict[str, "JsonType"] | list["JsonType"] + def assemble_path(*args: str | None) -> str: return "/".join(arg.strip("/") for arg in args if arg is not None) @@ -14,3 +16,19 @@ def assemble_path(*args: str | None) -> str: def assemble_url(netloc: str = "", path: str = "", query: Mapping[str, str] = {}, fragment: str = "", scheme: LiteralString = "https") -> str: return urlunparse((scheme, netloc, path, "", urlencode(query, True), fragment)) + + +def get_access_token(data: JsonType) -> str: + if not isinstance(data, dict) or (access_token := data.get("access_token")) is None: + raise TypeError("expected: JSON object with field `access_token`") + if not isinstance(access_token, str): + raise TypeError("expected: `access_token` as a string") + return access_token + + +def get_user_token(data: JsonType) -> str: + if not isinstance(data, dict) or (user_token := data.get("UserToken")) is None: + raise TypeError("expected: JSON object with field `UserToken`") + if not isinstance(user_token, str): + raise TypeError("expected: `UserToken` as a string") + return user_token