Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 4 additions & 14 deletions src/pysafeguard/async_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from multidict import CIMultiDict
from truststore import SSLContext

from .data_types import A2ATypes, HttpMethods, JsonType, Services, SshKeyFormats
from .utility import LiteralString, assemble_path, assemble_url
from .data_types import A2ATypes, HttpMethods, Services, SshKeyFormats
from .utility import JsonType, LiteralString, assemble_path, assemble_url, get_access_token, get_user_token


class AsyncWebRequestError(Exception):
Expand Down Expand Up @@ -49,8 +49,6 @@ async def __execute_web_request(
updated_headers = CIMultiDict(headers)
if body and httpMethod in [HttpMethods.POST, HttpMethods.PUT] and not headers.get("content-type"):
data_body = None
if not isinstance(body, dict):
raise TypeError("expected: body as a JSON object")
json_body = body
updated_headers["content-type"] = "application/json"
else:
Expand Down Expand Up @@ -136,20 +134,12 @@ async def get_provider_id(self, name: str) -> str:
return typing.cast(str, matches[0]["RstsProviderId"])

async def __connect(self, body: JsonType, cert: tuple[str, str] | None = None) -> None:
data: JsonType
resp = await self.invoke(HttpMethods.POST, Services.RSTS, "oauth2/token", body=body, cert=cert)
if resp.status == 200 and "application/json" in resp.headers.get("content-type", ""):
data = await resp.json()
if not isinstance(data, dict):
raise TypeError("expected: JSON object with field `access_token`")
access_token = data.get("access_token")

access_token = get_access_token(await resp.json())
resp = await self.invoke(HttpMethods.POST, Services.CORE, "Token/LoginResponse", body=dict(StsAccessToken=access_token))
if resp.status == 200 and "application/json" in resp.headers.get("content-type", ""):
data = await resp.json()
if not isinstance(data, dict):
raise TypeError("expected: JSON object with field `UserToken`")
user_token = typing.cast(str, data.get("UserToken"))
user_token = get_user_token(await resp.json())
self.connect_token(user_token)
else:
raise AsyncWebRequestError(resp)
Expand Down
18 changes: 4 additions & 14 deletions src/pysafeguard/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from requests import Response, request
from requests.structures import CaseInsensitiveDict

from .data_types import A2ATypes, HttpMethods, JsonType, Services, SshKeyFormats
from .utility import LiteralString, assemble_path, assemble_url
from .data_types import A2ATypes, HttpMethods, Services, SshKeyFormats
from .utility import JsonType, LiteralString, assemble_path, assemble_url, get_access_token, get_user_token


class WebRequestError(Exception):
Expand Down Expand Up @@ -47,8 +47,6 @@ def __execute_web_request(
updated_headers = CaseInsensitiveDict(headers)
if body and httpMethod in [HttpMethods.POST, HttpMethods.PUT] and not headers.get("content-type"):
data_body = None
if not isinstance(body, dict):
raise TypeError("expected: body as a JSON object")
json_body = body
updated_headers["content-type"] = "application/json"
else:
Expand Down Expand Up @@ -125,20 +123,12 @@ def get_provider_id(self, name: str) -> str:
return typing.cast(str, matches[0]["RstsProviderId"])

def __connect(self, body: JsonType, cert: tuple[str, str] | None = None) -> None:
data: JsonType
resp = self.invoke(HttpMethods.POST, Services.RSTS, "oauth2/token", body=body, cert=cert)
if resp.status_code == 200 and "application/json" in resp.headers.get("content-type", ""):
data = resp.json()
if not isinstance(data, dict):
raise TypeError("expected: JSON object with field `access_token`")
access_token = data.get("access_token")

access_token = get_access_token(resp.json())
resp = self.invoke(HttpMethods.POST, Services.CORE, "Token/LoginResponse", body=dict(StsAccessToken=access_token))
if resp.status_code == 200 and "application/json" in resp.headers.get("content-type", ""):
data = resp.json()
if not isinstance(data, dict):
raise TypeError("expected: JSON object with field `UserToken`")
user_token = typing.cast(str, data.get("UserToken"))
user_token = get_user_token(resp.json())
self.connect_token(user_token)
else:
raise WebRequestError(resp)
Expand Down
3 changes: 0 additions & 3 deletions src/pysafeguard/data_types.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import enum
import sys

JsonType = None | bool | int | float | str | dict[str, "JsonType"] | list["JsonType"]


if sys.version_info < (3, 11):

class StrEnum(str, enum.Enum):
Expand Down
18 changes: 18 additions & 0 deletions src/pysafeguard/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,28 @@
else:
from typing import LiteralString as LiteralString

JsonType = None | bool | int | float | str | dict[str, "JsonType"] | list["JsonType"]


def assemble_path(*args: str | None) -> str:
return "/".join(arg.strip("/") for arg in args if arg is not None)


def assemble_url(netloc: str = "", path: str = "", query: Mapping[str, str] = {}, fragment: str = "", scheme: LiteralString = "https") -> str:
return urlunparse((scheme, netloc, path, "", urlencode(query, True), fragment))


def get_access_token(data: JsonType) -> str:
if not isinstance(data, dict) or (access_token := data.get("access_token")) is None:
raise TypeError("expected: JSON object with field `access_token`")
if not isinstance(access_token, str):
raise TypeError("expected: `access_token` as a string")
return access_token


def get_user_token(data: JsonType) -> str:
if not isinstance(data, dict) or (user_token := data.get("UserToken")) is None:
raise TypeError("expected: JSON object with field `UserToken`")
if not isinstance(user_token, str):
raise TypeError("expected: `UserToken` as a string")
return user_token