diff --git a/infisical_sdk/infisical_requests.py b/infisical_sdk/infisical_requests.py index 89ef9f4..0b91516 100644 --- a/infisical_sdk/infisical_requests.py +++ b/infisical_sdk/infisical_requests.py @@ -1,9 +1,27 @@ -from typing import Any, Dict, Generic, Optional, TypeVar, Type +from typing import Any, Dict, Generic, Optional, TypeVar, Type, Callable, List +import socket import requests +import functools from dataclasses import dataclass +import time +import random T = TypeVar("T") +# List of network-related exceptions that should trigger retries +NETWORK_ERRORS = [ + requests.exceptions.ConnectionError, + requests.exceptions.ChunkedEncodingError, + requests.exceptions.ReadTimeout, + requests.exceptions.ConnectTimeout, + socket.gaierror, + socket.timeout, + ConnectionResetError, + ConnectionRefusedError, + ConnectionError, + ConnectionAbortedError, +] + def join_url(base: str, path: str) -> str: """ Join base URL and path properly, handling slashes appropriately. @@ -49,6 +67,42 @@ def from_dict(cls, data: Dict) -> 'APIResponse[T]': headers=data['headers'] ) +def with_retry( + max_retries: int = 3, + base_delay: float = 1.0, + network_errors: Optional[List[Type[Exception]]] = None +) -> Callable: + """ + Decorator to add retry logic with exponential backoff to requests methods. + """ + if network_errors is None: + network_errors = NETWORK_ERRORS + + def decorator(func: Callable) -> Callable: + @functools.wraps(func) + def wrapper(*args, **kwargs): + retry_count = 0 + + while True: + try: + return func(*args, **kwargs) + except tuple(network_errors) as error: + retry_count += 1 + if retry_count > max_retries: + raise + + base_delay_with_backoff = base_delay * (2 ** (retry_count - 1)) + + # +/-20% jitter + jitter = random.uniform(-0.2, 0.2) * base_delay_with_backoff + delay = base_delay_with_backoff + jitter + + time.sleep(delay) + + return wrapper + + return decorator + class InfisicalRequests: def __init__(self, host: str, token: Optional[str] = None): @@ -93,6 +147,7 @@ def _handle_response(self, response: requests.Response) -> Dict[str, Any]: except ValueError: raise InfisicalError("Invalid JSON response") + @with_retry(max_retries=4, base_delay=1.0) def get( self, path: str, @@ -119,6 +174,7 @@ def get( headers=dict(response.headers) ) + @with_retry(max_retries=4, base_delay=1.0) def post( self, path: str, @@ -143,6 +199,7 @@ def post( headers=dict(response.headers) ) + @with_retry(max_retries=4, base_delay=1.0) def patch( self, path: str, @@ -167,6 +224,7 @@ def patch( headers=dict(response.headers) ) + @with_retry(max_retries=4, base_delay=1.0) def delete( self, path: str,