Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/bedrock_agentcore/memory/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,6 +860,7 @@ def list_events(
all_events.extend(events)

next_token = response.get("nextToken")
# Break if: no more pages or reached max
if not next_token or len(all_events) >= max_results:
break

Expand Down
159 changes: 109 additions & 50 deletions src/bedrock_agentcore/memory/integrations/strands/session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
import json
import logging
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Optional
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar

import boto3
from botocore.config import Config as BotocoreConfig
Expand All @@ -19,6 +21,7 @@
from typing_extensions import override

from bedrock_agentcore.memory.client import MemoryClient
from bedrock_agentcore.memory.models.filters import EventMetadataFilter, LeftExpression, OperatorType, RightExpression

from .bedrock_converter import AgentCoreMemoryConverter
from .config import AgentCoreMemoryConfig, RetrievalConfig
Expand All @@ -28,11 +31,23 @@

logger = logging.getLogger(__name__)

SESSION_PREFIX = "session_"
AGENT_PREFIX = "agent_"
MESSAGE_PREFIX = "message_"
MAX_FETCH_ALL_RESULTS = 10000

# Legacy prefixes for backwards compatibility with old events
LEGACY_SESSION_PREFIX = "session_"
LEGACY_AGENT_PREFIX = "agent_"

# Metadata keys for event identification
STATE_TYPE_KEY = "stateType"
AGENT_ID_KEY = "agentId"


class StateType(Enum):
"""State type for distinguishing session and agent metadata in events."""

SESSION = "SESSION"
AGENT = "AGENT"


class AgentCoreMemorySessionManager(RepositorySessionManager, SessionRepository):
"""AgentCore Memory-based session manager for Bedrock AgentCore Memory integration.
Expand Down Expand Up @@ -104,7 +119,6 @@ def __init__(
session = boto_session or boto3.Session(region_name=region_name)
self.has_existing_agent = False

# Override the clients if custom boto session or config is provided
# Add strands-agents to the request user agent
if boto_client_config:
existing_user_agent = getattr(boto_client_config, "user_agent_extra", None)
Expand All @@ -125,38 +139,6 @@ def __init__(
)
super().__init__(session_id=self.config.session_id, session_repository=self)

def _get_full_session_id(self, session_id: str) -> str:
"""Get the full session ID with the configured prefix.

Args:
session_id (str): The session ID.

Returns:
str: The full session ID with the prefix.
"""
full_session_id = f"{SESSION_PREFIX}{session_id}"
if full_session_id == self.config.actor_id:
raise SessionException(
f"Cannot have session [ {full_session_id} ] with the same ID as the actor ID: {self.config.actor_id}"
)
return full_session_id

def _get_full_agent_id(self, agent_id: str) -> str:
"""Get the full agent ID with the configured prefix.

Args:
agent_id (str): The agent ID.

Returns:
str: The full agent ID with the prefix.
"""
full_agent_id = f"{AGENT_PREFIX}{agent_id}"
if full_agent_id == self.config.actor_id:
raise SessionException(
f"Cannot create agent [ {full_agent_id} ] with the same ID as the actor ID: {self.config.actor_id}"
)
return full_agent_id

# region SessionRepository interface implementation
def create_session(self, session: Session, **kwargs: Any) -> Session:
"""Create a new session in AgentCore Memory.
Expand All @@ -179,12 +161,13 @@ def create_session(self, session: Session, **kwargs: Any) -> Session:

event = self.memory_client.gmdp_client.create_event(
memoryId=self.config.memory_id,
actorId=self._get_full_session_id(session.session_id),
actorId=self.config.actor_id,
sessionId=self.session_id,
payload=[
{"blob": json.dumps(session.to_dict())},
],
eventTimestamp=self._get_monotonic_timestamp(),
metadata={STATE_TYPE_KEY: {"stringValue": StateType.SESSION.value}},
)
logger.info("Created session: %s with event: %s", session.session_id, event.get("event", {}).get("eventId"))
return session
Expand All @@ -206,17 +189,50 @@ def read_session(self, session_id: str, **kwargs: Any) -> Optional[Session]:
if session_id != self.config.session_id:
return None

# 1. Try new approach (metadata filter)
event_metadata = [
EventMetadataFilter.build_expression(
left_operand=LeftExpression.build(STATE_TYPE_KEY),
operator=OperatorType.EQUALS_TO,
right_operand=RightExpression.build(StateType.SESSION.value),
)
]

events = self.memory_client.list_events(
memory_id=self.config.memory_id,
actor_id=self._get_full_session_id(session_id),
actor_id=self.config.actor_id,
session_id=session_id,
event_metadata=event_metadata,
max_results=1,
)
if not events:
return None
if events:
session_data = json.loads(events[0].get("payload", {})[0].get("blob"))
return Session.from_dict(session_data)

session_data = json.loads(events[0].get("payload", {})[0].get("blob"))
return Session.from_dict(session_data)
# 2. Fallback: check for legacy event and migrate
legacy_actor_id = f"{LEGACY_SESSION_PREFIX}{session_id}"
events = self.memory_client.list_events(
memory_id=self.config.memory_id,
actor_id=legacy_actor_id,
session_id=session_id,
max_results=1,
)
if events:
old_event = events[0]
session_data = json.loads(old_event.get("payload", {})[0].get("blob"))
session = Session.from_dict(session_data)
# Migrate: create new event with metadata, delete old
self.create_session(session)
self.memory_client.gmdp_client.delete_event(
memoryId=self.config.memory_id,
actorId=legacy_actor_id,
sessionId=session_id,
eventId=old_event.get("eventId"),
)
logger.info("Migrated legacy session event for session: %s", session_id)
return session

return None

def delete_session(self, session_id: str, **kwargs: Any) -> None:
"""Delete session and all associated data.
Expand Down Expand Up @@ -250,12 +266,16 @@ def create_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A

event = self.memory_client.gmdp_client.create_event(
memoryId=self.config.memory_id,
actorId=self._get_full_agent_id(session_agent.agent_id),
actorId=self.config.actor_id,
sessionId=self.session_id,
payload=[
{"blob": json.dumps(session_agent.to_dict())},
],
eventTimestamp=self._get_monotonic_timestamp(),
metadata={
STATE_TYPE_KEY: {"stringValue": StateType.AGENT.value},
AGENT_ID_KEY: {"stringValue": session_agent.agent_id},
},
)
logger.info(
"Created agent: %s in session: %s with event %s",
Expand All @@ -280,18 +300,56 @@ def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> Optional[
if session_id != self.config.session_id:
return None
try:
# 1. Try new approach (metadata filter)
event_metadata = [
EventMetadataFilter.build_expression(
left_operand=LeftExpression.build(STATE_TYPE_KEY),
operator=OperatorType.EQUALS_TO,
right_operand=RightExpression.build(StateType.AGENT.value),
),
EventMetadataFilter.build_expression(
left_operand=LeftExpression.build(AGENT_ID_KEY),
operator=OperatorType.EQUALS_TO,
right_operand=RightExpression.build(agent_id),
),
]

events = self.memory_client.list_events(
memory_id=self.config.memory_id,
actor_id=self._get_full_agent_id(agent_id),
actor_id=self.config.actor_id,
session_id=session_id,
event_metadata=event_metadata,
max_results=1,
)

if not events:
return None
if events:
agent_data = json.loads(events[0].get("payload", {})[0].get("blob"))
return SessionAgent.from_dict(agent_data)

agent_data = json.loads(events[0].get("payload", {})[0].get("blob"))
return SessionAgent.from_dict(agent_data)
# 2. Fallback: check for legacy event and migrate
legacy_actor_id = f"{LEGACY_AGENT_PREFIX}{agent_id}"
events = self.memory_client.list_events(
memory_id=self.config.memory_id,
actor_id=legacy_actor_id,
session_id=session_id,
max_results=1,
)
if events:
old_event = events[0]
agent_data = json.loads(old_event.get("payload", {})[0].get("blob"))
agent = SessionAgent.from_dict(agent_data)
# Migrate: create new event with metadata, delete old
self.create_agent(session_id, agent)
self.memory_client.gmdp_client.delete_event(
memoryId=self.config.memory_id,
actorId=legacy_actor_id,
sessionId=session_id,
eventId=old_event.get("eventId"),
)
logger.info("Migrated legacy agent event for agent: %s", agent_id)
return agent

return None
except Exception as e:
logger.error("Failed to read agent %s", e)
return None
Expand All @@ -311,8 +369,9 @@ def update_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A
previous_agent = self.read_agent(session_id=session_id, agent_id=agent_id)
if previous_agent is None:
raise SessionException(f"Agent {agent_id} in session {session_id} does not exist")
else:
session_agent.created_at = previous_agent.created_at

session_agent.created_at = previous_agent.created_at
# Create a new agent as AgentCore Memory is immutable. We always get the latest one in `read_agent`
self.create_agent(session_id, session_agent)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,37 @@ def test_read_session_invalid(self, session_manager):

assert result is None

def test_read_session_legacy_migration(self, session_manager, mock_memory_client):
"""Test reading a legacy session event triggers migration."""
legacy_session_data = '{"session_id": "test-session-456", "session_type": "AGENT"}'

# First call (new approach with metadata) returns empty
# Second call (legacy actor_id) returns the legacy event
mock_memory_client.list_events.side_effect = [
[], # New approach returns nothing
[{"eventId": "legacy-event-1", "payload": [{"blob": legacy_session_data}]}], # Legacy approach
]
mock_memory_client.gmdp_client.create_event.return_value = {"event": {"eventId": "new-event-1"}}

result = session_manager.read_session("test-session-456")

# Verify session was returned
assert result is not None
assert result.session_id == "test-session-456"
assert result.session_type == SessionType.AGENT

# Verify migration: new event created with metadata
mock_memory_client.gmdp_client.create_event.assert_called_once()
create_call_kwargs = mock_memory_client.gmdp_client.create_event.call_args.kwargs
assert "metadata" in create_call_kwargs
assert create_call_kwargs["metadata"]["stateType"]["stringValue"] == "SESSION"

# Verify migration: old event deleted
mock_memory_client.gmdp_client.delete_event.assert_called_once()
delete_call_kwargs = mock_memory_client.gmdp_client.delete_event.call_args.kwargs
assert delete_call_kwargs["actorId"] == "session_test-session-456"
assert delete_call_kwargs["eventId"] == "legacy-event-1"

def test_create_agent(self, session_manager):
"""Test creating an agent."""
session_agent = SessionAgent(agent_id="test-agent-123", state={}, conversation_manager_state={})
Expand Down Expand Up @@ -198,6 +229,40 @@ def test_read_agent_no_events(self, session_manager, mock_memory_client):

assert result is None

@patch("bedrock_agentcore.memory.integrations.strands.session_manager.time.sleep")
def test_read_agent_legacy_migration(self, mock_sleep, session_manager, mock_memory_client):
"""Test reading a legacy agent event triggers migration."""
legacy_agent_data = '{"agent_id": "test-agent-123", "state": {}, "conversation_manager_state": {}}'

# New approach with metadata is retried 3 times (all return empty)
# Then legacy actor_id approach returns the legacy event
mock_memory_client.list_events.side_effect = [
[], # New approach - attempt 1
[], # New approach - attempt 2
[], # New approach - attempt 3
[{"eventId": "legacy-agent-event-1", "payload": [{"blob": legacy_agent_data}]}], # Legacy approach
]
mock_memory_client.gmdp_client.create_event.return_value = {"event": {"eventId": "new-agent-event-1"}}

result = session_manager.read_agent("test-session-456", "test-agent-123")

# Verify agent was returned
assert result is not None
assert result.agent_id == "test-agent-123"

# Verify migration: new event created with metadata
mock_memory_client.gmdp_client.create_event.assert_called_once()
create_call_kwargs = mock_memory_client.gmdp_client.create_event.call_args.kwargs
assert "metadata" in create_call_kwargs
assert create_call_kwargs["metadata"]["stateType"]["stringValue"] == "AGENT"
assert create_call_kwargs["metadata"]["agentId"]["stringValue"] == "test-agent-123"

# Verify migration: old event deleted
mock_memory_client.gmdp_client.delete_event.assert_called_once()
delete_call_kwargs = mock_memory_client.gmdp_client.delete_event.call_args.kwargs
assert delete_call_kwargs["actorId"] == "agent_test-agent-123"
assert delete_call_kwargs["eventId"] == "legacy-agent-event-1"

def test_create_message(self, session_manager, mock_memory_client):
"""Test creating a message."""
mock_memory_client.create_event.return_value = {"eventId": "event-123"}
Expand Down Expand Up @@ -925,22 +990,6 @@ def test_init_with_boto_config(self, agentcore_config, mock_memory_client):
manager = AgentCoreMemorySessionManager(agentcore_config, boto_client_config=boto_config)
assert manager.memory_client is not None

def test_get_full_session_id_conflict(self, session_manager):
"""Test session ID conflict with actor ID."""
# Set up a scenario where session ID would conflict with actor ID
session_manager.config.actor_id = "session_test-session"

with pytest.raises(SessionException, match="Cannot have session"):
session_manager._get_full_session_id("test-session")

def test_get_full_agent_id_conflict(self, session_manager):
"""Test agent ID conflict with actor ID."""
# Set up a scenario where agent ID would conflict with actor ID
session_manager.config.actor_id = "agent_test-agent"

with pytest.raises(SessionException, match="Cannot create agent"):
session_manager._get_full_agent_id("test-agent")

def test_retrieve_customer_context_no_messages(self, agentcore_config_with_retrieval, mock_memory_client):
"""Test retrieve_customer_context with no messages."""
with patch(
Expand Down
Loading