7 changes: 7 additions & 0 deletions sentry_sdk/ai/monitoring.py
@@ -95,6 +95,7 @@ def record_token_usage(
span: "Span",
input_tokens: "Optional[int]" = None,
input_tokens_cached: "Optional[int]" = None,
input_tokens_cache_write: "Optional[int]" = None,
output_tokens: "Optional[int]" = None,
output_tokens_reasoning: "Optional[int]" = None,
total_tokens: "Optional[int]" = None,
@@ -113,6 +114,12 @@ def record_token_usage(
input_tokens_cached,
)

if input_tokens_cache_write is not None:
span.set_data(
SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE,
input_tokens_cache_write,
)

if output_tokens is not None:
span.set_data(SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS, output_tokens)

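A minimal sketch of how the new input_tokens_cache_write parameter can be passed through, assuming a span opened via sentry_sdk.start_span (the call site below is hypothetical, not part of this diff):

import sentry_sdk
from sentry_sdk.ai.monitoring import record_token_usage

with sentry_sdk.start_span(op="gen_ai.chat", name="chat claude-3-5-sonnet") as span:
    # Counts would normally come from the provider's usage object.
    record_token_usage(
        span,
        input_tokens=100,
        input_tokens_cached=80,  # tokens read from the prompt cache
        input_tokens_cache_write=20,  # tokens written to the prompt cache (new)
        output_tokens=50,
    )
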
6 changes: 6 additions & 0 deletions sentry_sdk/consts.py
@@ -632,6 +632,12 @@ class SPANDATA:
Example: 50
"""

GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE = "gen_ai.usage.input_tokens.cache_write"
"""
The number of tokens written to the cache when processing the AI input (prompt).
Example: 100
"""

GEN_AI_USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens"
"""
The number of tokens in the output.
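For reference, the new attribute ends up in the span's data payload next to the existing cached-input counter; a hypothetical assertion in the style of the tests further down:

assert span["data"]["gen_ai.usage.input_tokens.cached"] == 80
assert span["data"]["gen_ai.usage.input_tokens.cache_write"] == 20
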
102 changes: 88 additions & 14 deletions sentry_sdk/integrations/anthropic.py
@@ -72,29 +72,47 @@ def _capture_exception(exc: "Any") -> None:
sentry_sdk.capture_event(event, hint=hint)


def _get_token_usage(result: "Messages") -> "tuple[int, int]":
def _get_token_usage(result: "Messages") -> "tuple[int, int, int, int]":
"""
Get token usage from the Anthropic response.
Returns: (input_tokens, output_tokens, cache_read_input_tokens, cache_write_input_tokens)
"""
input_tokens = 0
output_tokens = 0
cache_read_input_tokens = 0
cache_write_input_tokens = 0
if hasattr(result, "usage"):
usage = result.usage
if hasattr(usage, "input_tokens") and isinstance(usage.input_tokens, int):
input_tokens = usage.input_tokens
if hasattr(usage, "output_tokens") and isinstance(usage.output_tokens, int):
output_tokens = usage.output_tokens

return input_tokens, output_tokens
if hasattr(usage, "cache_read_input_tokens") and isinstance(
usage.cache_read_input_tokens, int
):
cache_read_input_tokens = usage.cache_read_input_tokens
if hasattr(usage, "cache_creation_input_tokens") and isinstance(
usage.cache_creation_input_tokens, int
):
cache_write_input_tokens = usage.cache_creation_input_tokens

return (
input_tokens,
output_tokens,
cache_read_input_tokens,
cache_write_input_tokens,
)
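
A quick sketch of the helper's behavior against a stubbed response, with SimpleNamespace standing in for an Anthropic Message carrying the Usage fields read above:

from types import SimpleNamespace

usage = SimpleNamespace(
    input_tokens=100,
    output_tokens=50,
    cache_read_input_tokens=80,
    cache_creation_input_tokens=20,
)
assert _get_token_usage(SimpleNamespace(usage=usage)) == (100, 50, 80, 20)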


def _collect_ai_data(
event: "MessageStreamEvent",
model: "str | None",
input_tokens: int,
output_tokens: int,
cache_read_input_tokens: int,
cache_write_input_tokens: int,
content_blocks: "list[str]",
) -> "tuple[str | None, int, int, list[str]]":
) -> "tuple[str | None, int, int, int, int, list[str]]":
"""
Collect model information, token usage, and content blocks from the AI streaming response.
"""
@@ -104,6 +122,14 @@ def _collect_ai_data(
usage = event.message.usage
input_tokens += usage.input_tokens
output_tokens += usage.output_tokens
if hasattr(usage, "cache_read_input_tokens") and isinstance(
usage.cache_read_input_tokens, int
):
cache_read_input_tokens += usage.cache_read_input_tokens
if hasattr(usage, "cache_creation_input_tokens") and isinstance(
usage.cache_creation_input_tokens, int
):
cache_write_input_tokens += usage.cache_creation_input_tokens
model = event.message.model or model
elif event.type == "content_block_start":
pass
@@ -117,7 +143,14 @@ def _set_input_data(
elif event.type == "message_delta":
output_tokens += event.usage.output_tokens

return model, input_tokens, output_tokens, content_blocks
return (
model,
input_tokens,
output_tokens,
cache_read_input_tokens,
cache_write_input_tokens,
content_blocks,
)
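
A small sketch of folding a single message_start event through the helper (event stubbed with SimpleNamespace; the real code receives anthropic MessageStreamEvent instances):

from types import SimpleNamespace

event = SimpleNamespace(
    type="message_start",
    message=SimpleNamespace(
        model="claude-3-5-sonnet-20241022",
        usage=SimpleNamespace(
            input_tokens=100,
            output_tokens=0,
            cache_read_input_tokens=80,
            cache_creation_input_tokens=20,
        ),
    ),
)
model, in_tok, out_tok, cache_read, cache_write, blocks = _collect_ai_data(
    event, None, 0, 0, 0, 0, []
)
assert (model, in_tok, out_tok, cache_read, cache_write) == (
    "claude-3-5-sonnet-20241022", 100, 0, 80, 20
)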


def _set_input_data(
@@ -219,6 +252,8 @@ def _set_output_data(
model: "str | None",
input_tokens: "int | None",
output_tokens: "int | None",
cache_read_input_tokens: "int | None",
cache_write_input_tokens: "int | None",
content_blocks: "list[Any]",
finish_span: bool = False,
) -> None:
@@ -254,6 +289,8 @@ def _set_output_data(
span,
input_tokens=input_tokens,
output_tokens=output_tokens,
input_tokens_cached=cache_read_input_tokens,
input_tokens_cache_write=cache_write_input_tokens,
)

if finish_span:
@@ -288,7 +325,12 @@ def _sentry_patched_create_common(f: "Any", *args: "Any", **kwargs: "Any") -> "Any":

with capture_internal_exceptions():
if hasattr(result, "content"):
input_tokens, output_tokens = _get_token_usage(result)
(
input_tokens,
output_tokens,
cache_read_input_tokens,
cache_write_input_tokens,
) = _get_token_usage(result)

content_blocks = []
for content_block in result.content:
@@ -305,6 +347,8 @@ def _sentry_patched_create_common(f: "Any", *args: "Any", **kwargs: "Any") -> "Any":
model=getattr(result, "model", None),
input_tokens=input_tokens,
output_tokens=output_tokens,
cache_read_input_tokens=cache_read_input_tokens,
cache_write_input_tokens=cache_write_input_tokens,
content_blocks=content_blocks,
finish_span=True,
)
@@ -317,13 +361,26 @@ def new_iterator() -> "Iterator[MessageStreamEvent]":
model = None
input_tokens = 0
output_tokens = 0
cache_read_input_tokens = 0
cache_write_input_tokens = 0
content_blocks: "list[str]" = []

for event in old_iterator:
model, input_tokens, output_tokens, content_blocks = (
_collect_ai_data(
event, model, input_tokens, output_tokens, content_blocks
)
(
model,
input_tokens,
output_tokens,
cache_read_input_tokens,
cache_write_input_tokens,
content_blocks,
) = _collect_ai_data(
event,
model,
input_tokens,
output_tokens,
cache_read_input_tokens,
cache_write_input_tokens,
content_blocks,
)
yield event

@@ -333,6 +390,8 @@ def new_iterator() -> "Iterator[MessageStreamEvent]":
model=model,
input_tokens=input_tokens,
output_tokens=output_tokens,
cache_read_input_tokens=cache_read_input_tokens,
cache_write_input_tokens=cache_write_input_tokens,
content_blocks=[{"text": "".join(content_blocks), "type": "text"}],
finish_span=True,
)
@@ -341,13 +400,26 @@ async def new_iterator_async() -> "AsyncIterator[MessageStreamEvent]":
model = None
input_tokens = 0
output_tokens = 0
cache_read_input_tokens = 0
cache_write_input_tokens = 0
content_blocks: "list[str]" = []

async for event in old_iterator:
model, input_tokens, output_tokens, content_blocks = (
_collect_ai_data(
event, model, input_tokens, output_tokens, content_blocks
)
(
model,
input_tokens,
output_tokens,
cache_read_input_tokens,
cache_write_input_tokens,
content_blocks,
) = _collect_ai_data(
event,
model,
input_tokens,
output_tokens,
cache_read_input_tokens,
cache_write_input_tokens,
content_blocks,
)
yield event

@@ -357,6 +429,8 @@ async def new_iterator_async() -> "AsyncIterator[MessageStreamEvent]":
model=model,
input_tokens=input_tokens,
output_tokens=output_tokens,
cache_read_input_tokens=cache_read_input_tokens,
cache_write_input_tokens=cache_write_input_tokens,
content_blocks=[{"text": "".join(content_blocks), "type": "text"}],
finish_span=True,
)
22 changes: 22 additions & 0 deletions sentry_sdk/integrations/pydantic_ai/spans/utils.py
@@ -28,8 +28,30 @@ def _set_usage_data(
if hasattr(usage, "input_tokens") and usage.input_tokens is not None:
span.set_data(SPANDATA.GEN_AI_USAGE_INPUT_TOKENS, usage.input_tokens)

# Pydantic AI uses cache_read_tokens (not input_tokens_cached)
if hasattr(usage, "cache_read_tokens") and usage.cache_read_tokens is not None:
span.set_data(
SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED, usage.cache_read_tokens
)

# Pydantic AI uses cache_write_tokens (not input_tokens_cache_write)
if hasattr(usage, "cache_write_tokens") and usage.cache_write_tokens is not None:
span.set_data(
SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE,
usage.cache_write_tokens,
)

if hasattr(usage, "output_tokens") and usage.output_tokens is not None:
span.set_data(SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS, usage.output_tokens)

if (
hasattr(usage, "output_tokens_reasoning")
and usage.output_tokens_reasoning is not None
):
span.set_data(
SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS_REASONING,
usage.output_tokens_reasoning,
)

if hasattr(usage, "total_tokens") and usage.total_tokens is not None:
span.set_data(SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS, usage.total_tokens)
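
A minimal sketch of the usage shape this helper maps from, assuming a (span, usage) signature and a Pydantic AI RunUsage-like object (stubbed here with SimpleNamespace):

from types import SimpleNamespace

usage = SimpleNamespace(
    input_tokens=100,
    cache_read_tokens=80,
    cache_write_tokens=20,
    output_tokens=50,
    output_tokens_reasoning=None,  # None values are skipped
    total_tokens=150,
)
_set_usage_data(span, usage)  # hypothetical span; sets the gen_ai.usage.* attributes above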
88 changes: 86 additions & 2 deletions tests/integrations/anthropic/test_anthropic.py
@@ -850,8 +850,10 @@ def test_collect_ai_data_with_input_json_delta():
output_tokens = 20
content_blocks = []

model, new_input_tokens, new_output_tokens, new_content_blocks = _collect_ai_data(
event, model, input_tokens, output_tokens, content_blocks
model, new_input_tokens, new_output_tokens, _, _, new_content_blocks = (
_collect_ai_data(
event, model, input_tokens, output_tokens, 0, 0, content_blocks
)
)

assert model is None
@@ -881,6 +883,8 @@ def test_set_output_data_with_input_json_delta(sentry_init):
model="",
input_tokens=10,
output_tokens=20,
cache_read_input_tokens=0,
cache_write_input_tokens=0,
content_blocks=[{"text": "".join(json_deltas), "type": "text"}],
)

@@ -1446,3 +1450,83 @@ def test_system_prompt_with_complex_structure(sentry_init, capture_events):
assert stored_messages[0]["content"][1]["text"] == "Be concise and clear."
assert stored_messages[1]["role"] == "user"
assert stored_messages[1]["content"] == "Hello"


def test_cache_tokens_nonstreaming(sentry_init, capture_events):
"""Test cache read/write tokens are tracked for non-streaming responses."""
sentry_init(integrations=[AnthropicIntegration()], traces_sample_rate=1.0)
events = capture_events()
client = Anthropic(api_key="z")

client.messages._post = mock.Mock(
return_value=Message(
id="id",
model="claude-3-5-sonnet-20241022",
role="assistant",
content=[TextBlock(type="text", text="Response")],
type="message",
usage=Usage(
input_tokens=100,
output_tokens=50,
cache_read_input_tokens=80,
cache_creation_input_tokens=20,
),
)
)

with start_transaction(name="anthropic"):
client.messages.create(
max_tokens=1024,
messages=[{"role": "user", "content": "Hello"}],
model="claude-3-5-sonnet-20241022",
)

(span,) = events[0]["spans"]
assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED] == 80
assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE] == 20


def test_cache_tokens_streaming(sentry_init, capture_events):
"""Test cache tokens are tracked for streaming responses."""
client = Anthropic(api_key="z")
returned_stream = Stream(cast_to=None, response=None, client=client)
returned_stream._iterator = [
MessageStartEvent(
type="message_start",
message=Message(
id="id",
model="claude-3-5-sonnet-20241022",
role="assistant",
content=[],
type="message",
usage=Usage(
input_tokens=100,
output_tokens=0,
cache_read_input_tokens=80,
cache_creation_input_tokens=20,
),
),
),
MessageDeltaEvent(
type="message_delta",
delta=Delta(stop_reason="end_turn"),
usage=MessageDeltaUsage(output_tokens=10),
),
]

sentry_init(integrations=[AnthropicIntegration()], traces_sample_rate=1.0)
events = capture_events()
client.messages._post = mock.Mock(return_value=returned_stream)

with start_transaction(name="anthropic"):
for _ in client.messages.create(
max_tokens=1024,
messages=[{"role": "user", "content": "Hello"}],
model="claude-3-5-sonnet-20241022",
stream=True,
):
pass

(span,) = events[0]["spans"]
assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED] == 80
assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE] == 20