def request(self, batch):
    """Upload ``batch``, retrying transient failures, and return the response.

    Retry policy:

    * ``FatalError`` is logged and re-raised without retrying.
    * An ``APIError`` with status 408, 429 or 503 whose response carries a
      positive ``Retry-After`` delay waits for that delay. These waits do
      not consume the backoff budget; they have their own budget of
      ``self.retries + 1`` so that a server which always answers with
      ``Retry-After`` cannot keep this loop alive forever. Once that
      budget is spent the error is handled like any other.
    * Any other ``APIError`` is retried only for status 408, 410, 429, 460
      and every 5xx except 501 and 505; everything else is re-raised at
      once.
    * Every other exception (e.g. network errors) is retried.
    * Budgeted retries sleep with exponential backoff (0.5s doubling, up
      to 10% jitter, capped at 60s) and stop after ``self.retries``
      retries, re-raising the last error.

    Each call to ``post`` receives ``retry_count``: the number of failed
    attempts made so far for this batch.
    """
    max_backoff_attempts = self.retries + 1
    max_retry_after_attempts = max_backoff_attempts

    total_attempts = 0
    backoff_attempts = 0
    retry_after_attempts = 0

    def is_retryable_status(status):
        """Return True when ``status`` is worth retrying."""
        if 400 <= status < 500:
            return status in (408, 410, 429, 460)
        if 500 <= status < 600:
            return status not in (501, 505)
        return False

    def server_requested_delay(error):
        """Return the positive Retry-After delay for ``error``, else None."""
        if error.status not in (408, 429, 503):
            return None
        # Compare with None explicitly: a requests.Response is falsy for
        # every 4xx/5xx status, so a truthiness test would always skip
        # the header on a real error response.
        response = getattr(error, 'response', None)
        if response is None:
            return None
        delay = parse_retry_after(response)
        # Zero keeps falling back to backoff; a negative value would make
        # time.sleep() raise ValueError.
        if delay is None or delay <= 0:
            return None
        return delay

    def next_backoff_delay(error):
        """Charge one attempt to the backoff budget.

        Returns the delay to sleep, or None when the budget is exhausted
        (the caller then re-raises ``error``).
        """
        nonlocal backoff_attempts
        backoff_attempts += 1
        if backoff_attempts >= max_backoff_attempts:
            self.log.error(
                f"All {self.retries} retries exhausted after {total_attempts} total attempts. Final error: {error}"
            )
            return None
        # 0.5 * 2**7 already exceeds the 60s cap, so clamping the exponent
        # leaves the result unchanged while avoiding an OverflowError
        # (int too large for float) when the retry budget is very large.
        base_delay = 0.5 * (2 ** min(backoff_attempts - 1, 7))
        jitter = random.uniform(0, 0.1 * base_delay)
        return min(base_delay + jitter, 60)  # Cap at 60 seconds

    while True:
        try:
            return post(
                self.write_key,
                self.host,
                gzip=self.gzip,
                timeout=self.timeout,
                batch=batch,
                proxies=self.proxies,
                oauth_manager=self.oauth_manager,
                retry_count=total_attempts
            )

        except FatalError as e:
            # Non-retryable error
            self.log.error(f"Fatal error after {total_attempts} attempts: {e}")
            raise

        except APIError as e:
            total_attempts += 1

            retry_after = None
            if retry_after_attempts < max_retry_after_attempts:
                retry_after = server_requested_delay(e)
            if retry_after is not None:
                retry_after_attempts += 1
                self.log.debug(
                    f"Retry-After header present: waiting {retry_after}s (attempt {total_attempts})"
                )
                time.sleep(retry_after)
                continue  # Does not count against backoff budget

            if not is_retryable_status(e.status):
                self.log.error(
                    f"Non-retryable error {e.status} after {total_attempts} attempts: {e}"
                )
                raise

            delay = next_backoff_delay(e)
            if delay is None:
                raise
            self.log.debug(
                f"Retry attempt {backoff_attempts}/{self.retries} (total attempts: {total_attempts}) "
                f"after {delay:.2f}s for status {e.status}"
            )
            time.sleep(delay)

        except Exception as e:
            # Network errors or other exceptions - retry with backoff
            total_attempts += 1

            delay = next_backoff_delay(e)
            if delay is None:
                raise
            self.log.debug(
                f"Network error retry {backoff_attempts}/{self.retries} (total attempts: {total_attempts}) "
                f"after {delay:.2f}s: {e}"
            )
            time.sleep(delay)
# Maximum Retry-After delay to respect (5 minutes)
MAX_RETRY_AFTER_SECONDS = 300


def parse_retry_after(response):
    """Return the delay requested by the response's ``Retry-After`` header.

    Both forms allowed by HTTP are understood: a non-negative integer
    number of seconds, and an HTTP-date (converted to the number of
    seconds from now, with a date in the past yielding 0).

    Returns the delay in whole seconds, capped at
    ``MAX_RETRY_AFTER_SECONDS``, or ``None`` when the header is absent,
    empty, negative or cannot be parsed.
    """
    # Imported here so the function does not depend on module-level
    # imports beyond those it already had.
    import math
    from datetime import datetime, timezone
    from email.utils import parsedate_to_datetime

    raw = response.headers.get('Retry-After')
    if not raw:
        return None

    value = str(raw).strip()
    try:
        delay = int(value)
    except ValueError:
        # Not delay-seconds; try the HTTP-date form.
        try:
            when = parsedate_to_datetime(value)
        except (TypeError, ValueError, IndexError, OverflowError):
            return None
        if when.tzinfo is None:
            # HTTP-dates are expressed in GMT; treat a naive result as UTC.
            when = when.replace(tzinfo=timezone.utc)
        remaining = (when - datetime.now(timezone.utc)).total_seconds()
        delay = max(math.ceil(remaining), 0)

    if delay < 0:
        # A negative delay is invalid and would make time.sleep() raise.
        return None
    return min(delay, MAX_RETRY_AFTER_SECONDS)
def _track_event(self):
    """Return the minimal track message sent by every retry test."""
    return {'type': 'track', 'event': 'python event', 'userId': 'userId'}


def _ok_response(self):
    """Return a stand-in for a successful (HTTP 200) response."""
    return mock.Mock(status_code=200)


def _api_error(self, status, code='error', message='Error', headers=None):
    """Build an ``APIError``; attach a stub response when ``headers`` is given.

    NOTE(review): the stub response is a ``mock.Mock`` and therefore
    truthy, whereas a real ``requests.Response`` is falsy for 4xx/5xx
    statuses. These tests do not exercise truthiness checks on the
    response object.
    """
    error = APIError(status, code, message)
    if headers is not None:
        error.response = mock.Mock(headers=headers)
    return error


def _attempt_request(self, consumer, outcomes):
    """Run ``consumer.request`` against scripted ``post`` outcomes.

    ``outcomes`` is used as the ``side_effect`` of the patched ``post``:
    a single exception instance is raised on every call, while a list is
    consumed one entry per call (exceptions are raised, other values are
    returned). ``time.sleep`` is patched so no test really waits.

    Returns ``(raised, post_mock, sleep_mock)`` where ``raised`` is the
    exception that escaped ``request`` or ``None`` on success.
    """
    with mock.patch('segment.analytics.consumer.post', side_effect=outcomes) as post_mock:
        with mock.patch('time.sleep') as sleep_mock:
            try:
                consumer.request([self._track_event()])
            except Exception as exc:
                raised = exc
            else:
                raised = None
    return raised, post_mock, sleep_mock


def _retry_counts(self, post_mock):
    """Return the ``retry_count`` keyword of every ``post`` call, in order."""
    return [call[1].get('retry_count', 0) for call in post_mock.call_args_list]


def _sleep_durations(self, sleep_mock):
    """Return the duration passed to every ``time.sleep`` call, in order."""
    return [call[0][0] for call in sleep_mock.call_args_list]


def test_retry_count_header_increments(self):
    """Test that X-Retry-Count header increments on each retry"""
    consumer = Consumer(None, 'testsecret', retries=3)
    outcomes = [
        self._api_error(500, message='Server Error'),
        self._api_error(500, message='Server Error'),
        self._ok_response(),  # Success on third attempt
    ]

    raised, post_mock, _ = self._attempt_request(consumer, outcomes)

    self.assertIsNone(raised)
    # Should have been called 3 times with retry counts 0, 1, 2
    self.assertEqual(self._retry_counts(post_mock), [0, 1, 2])


def test_non_retryable_4xx_status_codes(self):
    """Test that non-retryable 4xx errors are not retried"""
    consumer = Consumer(None, 'testsecret', retries=3)

    for status_code in [400, 401, 403, 404, 413, 422]:
        with self.subTest(status_code=status_code):
            error = self._api_error(status_code, message=f'Client Error {status_code}')

            raised, post_mock, _ = self._attempt_request(consumer, error)

            self.assertIsInstance(raised, APIError)
            self.assertEqual(raised.status, status_code)
            # Should only be called once (no retries)
            self.assertEqual(post_mock.call_count, 1, f'Status {status_code} should not be retried')


def test_retryable_4xx_status_codes(self):
    """Test that retryable 4xx errors are retried"""
    consumer = Consumer(None, 'testsecret', retries=3)

    for status_code in [408, 410, 429, 460]:
        with self.subTest(status_code=status_code):
            outcomes = [
                self._api_error(status_code, message=f'Retryable Error {status_code}'),
                self._api_error(status_code, message=f'Retryable Error {status_code}'),
                self._ok_response(),  # Success on third attempt
            ]

            raised, post_mock, _ = self._attempt_request(consumer, outcomes)

            self.assertIsNone(raised)
            # Should have been called 3 times
            self.assertEqual(post_mock.call_count, 3, f'Status {status_code} should be retried')


def test_non_retryable_5xx_status_codes(self):
    """Test that non-retryable 5xx errors are not retried"""
    consumer = Consumer(None, 'testsecret', retries=3)

    for status_code in [501, 505]:
        with self.subTest(status_code=status_code):
            error = self._api_error(status_code, message=f'Server Error {status_code}')

            raised, post_mock, _ = self._attempt_request(consumer, error)

            self.assertIsInstance(raised, APIError)
            self.assertEqual(raised.status, status_code)
            # Should only be called once (no retries)
            self.assertEqual(post_mock.call_count, 1, f'Status {status_code} should not be retried')


def test_retryable_5xx_status_codes(self):
    """Test that retryable 5xx errors are retried"""
    consumer = Consumer(None, 'testsecret', retries=3)

    for status_code in [500, 502, 503, 504]:
        with self.subTest(status_code=status_code):
            outcomes = [
                self._api_error(status_code, message=f'Server Error {status_code}'),
                self._api_error(status_code, message=f'Server Error {status_code}'),
                self._ok_response(),  # Success on third attempt
            ]

            raised, post_mock, _ = self._attempt_request(consumer, outcomes)

            self.assertIsNone(raised)
            # Should have been called 3 times
            self.assertEqual(post_mock.call_count, 3, f'Status {status_code} should be retried')


def test_retry_after_header_support(self):
    """Test that Retry-After header is respected and doesn't count against retry budget"""
    consumer = Consumer(None, 'testsecret', retries=2)
    # 429 with Retry-After for the first 3 attempts, success on the 4th
    outcomes = [
        self._api_error(429, 'rate_limit', 'Too Many Requests', headers={'Retry-After': '10'})
        for _ in range(3)
    ]
    outcomes.append(self._ok_response())

    raised, post_mock, sleep_mock = self._attempt_request(consumer, outcomes)

    self.assertIsNone(raised)
    # Should succeed after 4 attempts (3 Retry-After, then success)
    self.assertEqual(post_mock.call_count, 4)
    # First 3 sleeps should be for Retry-After (10 seconds each)
    self.assertEqual(self._sleep_durations(sleep_mock)[:3], [10, 10, 10])


def test_retry_after_capped_at_300_seconds(self):
    """Test that Retry-After delay is capped at 300 seconds"""
    consumer = Consumer(None, 'testsecret', retries=2)
    outcomes = [
        # 429 asking for 10 minutes
        self._api_error(429, 'rate_limit', 'Too Many Requests', headers={'Retry-After': '600'}),
        self._ok_response(),  # Success on 2nd attempt
    ]

    raised, _, sleep_mock = self._attempt_request(consumer, outcomes)

    self.assertIsNone(raised)
    # Sleep should be capped at 300 seconds
    self.assertEqual(self._sleep_durations(sleep_mock), [300])


def test_retry_after_for_408_and_503(self):
    """Test that Retry-After is respected for 408 and 503 status codes"""
    consumer = Consumer(None, 'testsecret', retries=2)

    for status_code in [408, 503]:
        with self.subTest(status_code=status_code):
            outcomes = [
                self._api_error(status_code, headers={'Retry-After': '5'}),
                self._ok_response(),
            ]

            raised, _, sleep_mock = self._attempt_request(consumer, outcomes)

            self.assertIsNone(raised)
            self.assertEqual(
                self._sleep_durations(sleep_mock), [5],
                f'Retry-After should be respected for {status_code}')


def test_exponential_backoff_with_jitter(self):
    """Test that exponential backoff is used for retries without Retry-After"""
    consumer = Consumer(None, 'testsecret', retries=4)
    outcomes = [self._api_error(500, message='Server Error') for _ in range(3)]
    outcomes.append(self._ok_response())

    raised, _, sleep_mock = self._attempt_request(consumer, outcomes)

    self.assertIsNone(raised)
    sleep_durations = self._sleep_durations(sleep_mock)
    # Should have 3 backoff delays
    self.assertEqual(len(sleep_durations), 3)

    # Delays should be increasing (exponential)
    # First: ~0.5s, Second: ~1s, Third: ~2s (with jitter)
    self.assertGreater(sleep_durations[0], 0.4)
    self.assertLess(sleep_durations[0], 1.0)
    self.assertGreater(sleep_durations[1], 0.9)
    self.assertLess(sleep_durations[1], 2.0)
    self.assertGreater(sleep_durations[2], 1.8)
    self.assertLess(sleep_durations[2], 4.0)


def test_fatal_error_not_retried(self):
    """Test that FatalError is not retried"""
    consumer = Consumer(None, 'testsecret', retries=3)

    raised, post_mock, _ = self._attempt_request(
        consumer, FatalError('Fatal error occurred'))

    self.assertIsInstance(raised, FatalError)
    # Should only be called once (no retries)
    self.assertEqual(post_mock.call_count, 1)


def test_max_retries_exhausted(self):
    """Test that request fails after max retries exhausted"""
    consumer = Consumer(None, 'testsecret', retries=2)

    # Always fail with retryable error
    raised, post_mock, _ = self._attempt_request(
        consumer, self._api_error(500, message='Server Error'))

    self.assertIsInstance(raised, APIError)
    self.assertEqual(raised.status, 500)
    # Should be called 3 times (initial + 2 retries)
    self.assertEqual(post_mock.call_count, 3)


def test_first_request_has_retry_count_zero(self):
    """T01: First successful request includes X-Retry-Count=0"""
    consumer = Consumer(None, 'testsecret')

    raised, post_mock, _ = self._attempt_request(consumer, [self._ok_response()])

    self.assertIsNone(raised)
    # First request should have retry_count=0 (no default: it must be passed)
    self.assertEqual(post_mock.call_args[1].get('retry_count'), 0)


def test_429_without_retry_after_uses_backoff(self):
    """T09: 429 without Retry-After header uses backoff retry"""
    consumer = Consumer(None, 'testsecret', retries=3)
    outcomes = [
        # 429 without Retry-After header
        self._api_error(429, 'rate_limit', 'Too Many Requests', headers={}),
        self._ok_response(),
    ]

    raised, post_mock, sleep_mock = self._attempt_request(consumer, outcomes)

    self.assertIsNone(raised)
    # Should have two attempts
    self.assertEqual(post_mock.call_count, 2)
    self.assertEqual(self._retry_counts(post_mock), [0, 1])

    # Should use backoff delay (around 0.5s with jitter)
    sleep_durations = self._sleep_durations(sleep_mock)
    self.assertEqual(len(sleep_durations), 1)
    self.assertGreater(sleep_durations[0], 0.4)
    self.assertLess(sleep_durations[0], 1.0)


def test_408_without_retry_after_uses_backoff(self):
    """T10: 408 without Retry-After header uses backoff retry"""
    consumer = Consumer(None, 'testsecret', retries=3)
    outcomes = [
        # 408 without Retry-After header
        self._api_error(408, 'timeout', 'Request Timeout', headers={}),
        self._ok_response(),
    ]

    raised, post_mock, sleep_mock = self._attempt_request(consumer, outcomes)

    self.assertIsNone(raised)
    # Should have two attempts
    self.assertEqual(post_mock.call_count, 2)
    self.assertEqual(self._retry_counts(post_mock), [0, 1])

    # Should use backoff delay
    sleep_durations = self._sleep_durations(sleep_mock)
    self.assertEqual(len(sleep_durations), 1)
    self.assertGreater(sleep_durations[0], 0.4)
    self.assertLess(sleep_durations[0], 1.0)


def test_network_error_retried_with_backoff(self):
    """T15: Network/IO error is retried with backoff"""
    consumer = Consumer(None, 'testsecret', retries=3)
    outcomes = [
        ConnectionError('Network connection failed'),
        self._ok_response(),
    ]

    raised, post_mock, sleep_mock = self._attempt_request(consumer, outcomes)

    self.assertIsNone(raised)
    # Should have two attempts
    self.assertEqual(post_mock.call_count, 2)
    self.assertEqual(self._retry_counts(post_mock), [0, 1])

    # Should use backoff delay
    sleep_durations = self._sleep_durations(sleep_mock)
    self.assertEqual(len(sleep_durations), 1)
    self.assertGreater(sleep_durations[0], 0.4)
    self.assertLess(sleep_durations[0], 1.0)


def test_511_is_retryable(self):
    """T05: 511 status code is retryable (part of 5xx family, not in non-retryable list)"""
    consumer = Consumer(None, 'testsecret', retries=3)
    outcomes = [
        self._api_error(511, 'auth_required', 'Network Authentication Required'),
        self._api_error(511, 'auth_required', 'Network Authentication Required'),
        self._ok_response(),
    ]

    raised, post_mock, _ = self._attempt_request(consumer, outcomes)

    self.assertIsNone(raised)
    # Should have been called 3 times (511 is retryable)
    self.assertEqual(post_mock.call_count, 3)
    self.assertEqual(self._retry_counts(post_mock), [0, 1, 2])


def test_retry_after_not_counted_against_backoff_budget(self):
    """T17: Retry-After attempts don't consume backoff retry budget"""
    consumer = Consumer(None, 'testsecret', retries=1)
    outcomes = [
        # First two: 429 with Retry-After (shouldn't count against budget)
        self._api_error(429, 'rate_limit', 'Too Many Requests', headers={'Retry-After': '1'}),
        self._api_error(429, 'rate_limit', 'Too Many Requests', headers={'Retry-After': '1'}),
        # Third: 500 without Retry-After (counts against budget)
        self._api_error(500, message='Server Error'),
        # Success on 4th attempt
        self._ok_response(),
    ]

    raised, post_mock, _ = self._attempt_request(consumer, outcomes)

    self.assertIsNone(raised)
    # retries=1 allows exactly one budgeted retry. The two Retry-After
    # waits are free, the 500 uses the single budgeted retry, and the
    # 4th call succeeds. Had the Retry-After waits been charged to the
    # budget, the request would have failed before reaching call 4.
    self.assertEqual(post_mock.call_count, 4)
    self.assertEqual(self._retry_counts(post_mock), [0, 1, 2, 3])


def test_413_payload_too_large_not_retried(self):
    """T12: 413 Payload Too Large is non-retryable (won't succeed on retry)"""
    consumer = Consumer(None, 'testsecret', retries=3)

    raised, post_mock, _ = self._attempt_request(
        consumer, self._api_error(413, 'payload_too_large', 'Payload Too Large'))

    self.assertIsInstance(raised, APIError)
    self.assertEqual(raised.status, 413)
    # Should only be called once (no retries)
    self.assertEqual(post_mock.call_count, 1)
def _post_batch(self, response, **post_kwargs):
    """Call ``post`` with a one-event batch against a stubbed HTTP session.

    ``response`` is what the patched ``_session.post`` returns; extra
    keyword arguments are forwarded to ``post``. Returns the mock that
    replaced ``_session.post`` so callers can inspect the outgoing
    request. Any exception raised by ``post`` propagates to the caller.
    """
    batch = [{
        'userId': 'userId',
        'event': 'python event',
        'type': 'track'
    }]
    with mock.patch('segment.analytics.request._session.post',
                    return_value=response) as session_post:
        post('testsecret', batch=batch, **post_kwargs)
    return session_post


def test_authorization_header_basic_auth(self):
    """Test that Basic Authorization header is added when no OAuth manager"""
    session_post = self._post_batch(mock.Mock(status_code=200))

    headers = session_post.call_args[1]['headers']
    self.assertIn('Authorization', headers)

    # Verify it's Basic auth with correct encoding ("writeKey:" in base64)
    expected_credentials = base64.b64encode(b'testsecret:').decode('utf-8')
    self.assertEqual(headers['Authorization'], f'Basic {expected_credentials}')


def test_authorization_header_oauth(self):
    """Test that Bearer Authorization header is used with OAuth manager"""
    oauth_manager = mock.Mock()
    oauth_manager.get_token.return_value = 'test_token_123'

    session_post = self._post_batch(
        mock.Mock(status_code=200), oauth_manager=oauth_manager)

    headers = session_post.call_args[1]['headers']
    self.assertIn('Authorization', headers)
    self.assertEqual(headers['Authorization'], 'Bearer test_token_123')


def test_x_retry_count_header(self):
    """Test that X-Retry-Count header is included"""
    # 0 is the first attempt; 5 stands for an arbitrary later retry.
    for retry_count in (0, 5):
        with self.subTest(retry_count=retry_count):
            session_post = self._post_batch(
                mock.Mock(status_code=200), retry_count=retry_count)

            headers = session_post.call_args[1]['headers']
            self.assertIn('X-Retry-Count', headers)
            self.assertEqual(headers['X-Retry-Count'], str(retry_count))


def test_parse_retry_after_integer(self):
    """Test parsing Retry-After header with integer seconds"""
    response = mock.Mock(headers={'Retry-After': '30'})
    self.assertEqual(parse_retry_after(response), 30)


def test_parse_retry_after_capped(self):
    """Test that Retry-After is capped at 300 seconds"""
    response = mock.Mock(headers={'Retry-After': '600'})
    self.assertEqual(parse_retry_after(response), 300)


def test_parse_retry_after_missing(self):
    """Test parsing when Retry-After header is missing"""
    response = mock.Mock(headers={})
    self.assertIsNone(parse_retry_after(response))


def test_parse_retry_after_invalid(self):
    """Test parsing with invalid Retry-After header"""
    response = mock.Mock(headers={'Retry-After': 'invalid'})
    self.assertIsNone(parse_retry_after(response))


def test_oauth_token_cleared_on_511(self):
    """Test that OAuth token is cleared on 511 status"""
    oauth_manager = mock.Mock()
    oauth_manager.get_token.return_value = 'test_token'

    res = mock.Mock(status_code=511)
    res.json.return_value = {'code': 'error', 'message': 'Network Authentication Required'}

    # The error must surface to the caller as well as clearing the token.
    with self.assertRaises(APIError):
        self._post_batch(res, oauth_manager=oauth_manager)

    # Verify clear_token was called
    oauth_manager.clear_token.assert_called_once()


def test_api_error_includes_response(self):
    """Test that APIError includes the response object"""
    res = mock.Mock(status_code=429)
    res.json.return_value = {'code': 'rate_limit', 'message': 'Too Many Requests'}

    with self.assertRaises(APIError) as ctx:
        self._post_batch(res)

    self.assertEqual(ctx.exception.status, 429)
    # The very response returned by the session must be attached.
    self.assertIs(ctx.exception.response, res)