diff --git a/src/bedrock_agentcore/memory/client.py b/src/bedrock_agentcore/memory/client.py index eb0d6d4..e03fcb0 100644 --- a/src/bedrock_agentcore/memory/client.py +++ b/src/bedrock_agentcore/memory/client.py @@ -860,6 +860,7 @@ def list_events( all_events.extend(events) next_token = response.get("nextToken") + # Break if: no more pages or reached max if not next_token or len(all_events) >= max_results: break diff --git a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py index c3912a7..09e886d 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py +++ b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py @@ -3,9 +3,11 @@ import json import logging import threading +import time from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING, Any, Optional +from enum import Enum +from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar import boto3 from botocore.config import Config as BotocoreConfig @@ -19,6 +21,7 @@ from typing_extensions import override from bedrock_agentcore.memory.client import MemoryClient +from bedrock_agentcore.memory.models.filters import EventMetadataFilter, LeftExpression, OperatorType, RightExpression from .bedrock_converter import AgentCoreMemoryConverter from .config import AgentCoreMemoryConfig, RetrievalConfig @@ -28,11 +31,23 @@ logger = logging.getLogger(__name__) -SESSION_PREFIX = "session_" -AGENT_PREFIX = "agent_" -MESSAGE_PREFIX = "message_" MAX_FETCH_ALL_RESULTS = 10000 +# Legacy prefixes for backwards compatibility with old events +LEGACY_SESSION_PREFIX = "session_" +LEGACY_AGENT_PREFIX = "agent_" + +# Metadata keys for event identification +STATE_TYPE_KEY = "stateType" +AGENT_ID_KEY = "agentId" + + +class StateType(Enum): + """State type for distinguishing session and agent metadata in events.""" + + SESSION = "SESSION" + AGENT = "AGENT" + class AgentCoreMemorySessionManager(RepositorySessionManager, SessionRepository): """AgentCore Memory-based session manager for Bedrock AgentCore Memory integration. @@ -104,7 +119,6 @@ def __init__( session = boto_session or boto3.Session(region_name=region_name) self.has_existing_agent = False - # Override the clients if custom boto session or config is provided # Add strands-agents to the request user agent if boto_client_config: existing_user_agent = getattr(boto_client_config, "user_agent_extra", None) @@ -125,38 +139,6 @@ def __init__( ) super().__init__(session_id=self.config.session_id, session_repository=self) - def _get_full_session_id(self, session_id: str) -> str: - """Get the full session ID with the configured prefix. - - Args: - session_id (str): The session ID. - - Returns: - str: The full session ID with the prefix. - """ - full_session_id = f"{SESSION_PREFIX}{session_id}" - if full_session_id == self.config.actor_id: - raise SessionException( - f"Cannot have session [ {full_session_id} ] with the same ID as the actor ID: {self.config.actor_id}" - ) - return full_session_id - - def _get_full_agent_id(self, agent_id: str) -> str: - """Get the full agent ID with the configured prefix. - - Args: - agent_id (str): The agent ID. - - Returns: - str: The full agent ID with the prefix. - """ - full_agent_id = f"{AGENT_PREFIX}{agent_id}" - if full_agent_id == self.config.actor_id: - raise SessionException( - f"Cannot create agent [ {full_agent_id} ] with the same ID as the actor ID: {self.config.actor_id}" - ) - return full_agent_id - # region SessionRepository interface implementation def create_session(self, session: Session, **kwargs: Any) -> Session: """Create a new session in AgentCore Memory. @@ -179,12 +161,13 @@ def create_session(self, session: Session, **kwargs: Any) -> Session: event = self.memory_client.gmdp_client.create_event( memoryId=self.config.memory_id, - actorId=self._get_full_session_id(session.session_id), + actorId=self.config.actor_id, sessionId=self.session_id, payload=[ {"blob": json.dumps(session.to_dict())}, ], eventTimestamp=self._get_monotonic_timestamp(), + metadata={STATE_TYPE_KEY: {"stringValue": StateType.SESSION.value}}, ) logger.info("Created session: %s with event: %s", session.session_id, event.get("event", {}).get("eventId")) return session @@ -206,17 +189,50 @@ def read_session(self, session_id: str, **kwargs: Any) -> Optional[Session]: if session_id != self.config.session_id: return None + # 1. Try new approach (metadata filter) + event_metadata = [ + EventMetadataFilter.build_expression( + left_operand=LeftExpression.build(STATE_TYPE_KEY), + operator=OperatorType.EQUALS_TO, + right_operand=RightExpression.build(StateType.SESSION.value), + ) + ] + events = self.memory_client.list_events( memory_id=self.config.memory_id, - actor_id=self._get_full_session_id(session_id), + actor_id=self.config.actor_id, session_id=session_id, + event_metadata=event_metadata, max_results=1, ) - if not events: - return None + if events: + session_data = json.loads(events[0].get("payload", {})[0].get("blob")) + return Session.from_dict(session_data) - session_data = json.loads(events[0].get("payload", {})[0].get("blob")) - return Session.from_dict(session_data) + # 2. Fallback: check for legacy event and migrate + legacy_actor_id = f"{LEGACY_SESSION_PREFIX}{session_id}" + events = self.memory_client.list_events( + memory_id=self.config.memory_id, + actor_id=legacy_actor_id, + session_id=session_id, + max_results=1, + ) + if events: + old_event = events[0] + session_data = json.loads(old_event.get("payload", {})[0].get("blob")) + session = Session.from_dict(session_data) + # Migrate: create new event with metadata, delete old + self.create_session(session) + self.memory_client.gmdp_client.delete_event( + memoryId=self.config.memory_id, + actorId=legacy_actor_id, + sessionId=session_id, + eventId=old_event.get("eventId"), + ) + logger.info("Migrated legacy session event for session: %s", session_id) + return session + + return None def delete_session(self, session_id: str, **kwargs: Any) -> None: """Delete session and all associated data. @@ -250,12 +266,16 @@ def create_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A event = self.memory_client.gmdp_client.create_event( memoryId=self.config.memory_id, - actorId=self._get_full_agent_id(session_agent.agent_id), + actorId=self.config.actor_id, sessionId=self.session_id, payload=[ {"blob": json.dumps(session_agent.to_dict())}, ], eventTimestamp=self._get_monotonic_timestamp(), + metadata={ + STATE_TYPE_KEY: {"stringValue": StateType.AGENT.value}, + AGENT_ID_KEY: {"stringValue": session_agent.agent_id}, + }, ) logger.info( "Created agent: %s in session: %s with event %s", @@ -280,18 +300,56 @@ def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> Optional[ if session_id != self.config.session_id: return None try: + # 1. Try new approach (metadata filter) + event_metadata = [ + EventMetadataFilter.build_expression( + left_operand=LeftExpression.build(STATE_TYPE_KEY), + operator=OperatorType.EQUALS_TO, + right_operand=RightExpression.build(StateType.AGENT.value), + ), + EventMetadataFilter.build_expression( + left_operand=LeftExpression.build(AGENT_ID_KEY), + operator=OperatorType.EQUALS_TO, + right_operand=RightExpression.build(agent_id), + ), + ] + events = self.memory_client.list_events( memory_id=self.config.memory_id, - actor_id=self._get_full_agent_id(agent_id), + actor_id=self.config.actor_id, session_id=session_id, + event_metadata=event_metadata, max_results=1, ) - if not events: - return None + if events: + agent_data = json.loads(events[0].get("payload", {})[0].get("blob")) + return SessionAgent.from_dict(agent_data) - agent_data = json.loads(events[0].get("payload", {})[0].get("blob")) - return SessionAgent.from_dict(agent_data) + # 2. Fallback: check for legacy event and migrate + legacy_actor_id = f"{LEGACY_AGENT_PREFIX}{agent_id}" + events = self.memory_client.list_events( + memory_id=self.config.memory_id, + actor_id=legacy_actor_id, + session_id=session_id, + max_results=1, + ) + if events: + old_event = events[0] + agent_data = json.loads(old_event.get("payload", {})[0].get("blob")) + agent = SessionAgent.from_dict(agent_data) + # Migrate: create new event with metadata, delete old + self.create_agent(session_id, agent) + self.memory_client.gmdp_client.delete_event( + memoryId=self.config.memory_id, + actorId=legacy_actor_id, + sessionId=session_id, + eventId=old_event.get("eventId"), + ) + logger.info("Migrated legacy agent event for agent: %s", agent_id) + return agent + + return None except Exception as e: logger.error("Failed to read agent %s", e) return None @@ -311,8 +369,9 @@ def update_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A previous_agent = self.read_agent(session_id=session_id, agent_id=agent_id) if previous_agent is None: raise SessionException(f"Agent {agent_id} in session {session_id} does not exist") + else: + session_agent.created_at = previous_agent.created_at - session_agent.created_at = previous_agent.created_at # Create a new agent as AgentCore Memory is immutable. We always get the latest one in `read_agent` self.create_agent(session_id, session_agent) diff --git a/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py b/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py index a01973c..5461da7 100644 --- a/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py +++ b/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py @@ -160,6 +160,37 @@ def test_read_session_invalid(self, session_manager): assert result is None + def test_read_session_legacy_migration(self, session_manager, mock_memory_client): + """Test reading a legacy session event triggers migration.""" + legacy_session_data = '{"session_id": "test-session-456", "session_type": "AGENT"}' + + # First call (new approach with metadata) returns empty + # Second call (legacy actor_id) returns the legacy event + mock_memory_client.list_events.side_effect = [ + [], # New approach returns nothing + [{"eventId": "legacy-event-1", "payload": [{"blob": legacy_session_data}]}], # Legacy approach + ] + mock_memory_client.gmdp_client.create_event.return_value = {"event": {"eventId": "new-event-1"}} + + result = session_manager.read_session("test-session-456") + + # Verify session was returned + assert result is not None + assert result.session_id == "test-session-456" + assert result.session_type == SessionType.AGENT + + # Verify migration: new event created with metadata + mock_memory_client.gmdp_client.create_event.assert_called_once() + create_call_kwargs = mock_memory_client.gmdp_client.create_event.call_args.kwargs + assert "metadata" in create_call_kwargs + assert create_call_kwargs["metadata"]["stateType"]["stringValue"] == "SESSION" + + # Verify migration: old event deleted + mock_memory_client.gmdp_client.delete_event.assert_called_once() + delete_call_kwargs = mock_memory_client.gmdp_client.delete_event.call_args.kwargs + assert delete_call_kwargs["actorId"] == "session_test-session-456" + assert delete_call_kwargs["eventId"] == "legacy-event-1" + def test_create_agent(self, session_manager): """Test creating an agent.""" session_agent = SessionAgent(agent_id="test-agent-123", state={}, conversation_manager_state={}) @@ -198,6 +229,40 @@ def test_read_agent_no_events(self, session_manager, mock_memory_client): assert result is None + @patch("bedrock_agentcore.memory.integrations.strands.session_manager.time.sleep") + def test_read_agent_legacy_migration(self, mock_sleep, session_manager, mock_memory_client): + """Test reading a legacy agent event triggers migration.""" + legacy_agent_data = '{"agent_id": "test-agent-123", "state": {}, "conversation_manager_state": {}}' + + # New approach with metadata is retried 3 times (all return empty) + # Then legacy actor_id approach returns the legacy event + mock_memory_client.list_events.side_effect = [ + [], # New approach - attempt 1 + [], # New approach - attempt 2 + [], # New approach - attempt 3 + [{"eventId": "legacy-agent-event-1", "payload": [{"blob": legacy_agent_data}]}], # Legacy approach + ] + mock_memory_client.gmdp_client.create_event.return_value = {"event": {"eventId": "new-agent-event-1"}} + + result = session_manager.read_agent("test-session-456", "test-agent-123") + + # Verify agent was returned + assert result is not None + assert result.agent_id == "test-agent-123" + + # Verify migration: new event created with metadata + mock_memory_client.gmdp_client.create_event.assert_called_once() + create_call_kwargs = mock_memory_client.gmdp_client.create_event.call_args.kwargs + assert "metadata" in create_call_kwargs + assert create_call_kwargs["metadata"]["stateType"]["stringValue"] == "AGENT" + assert create_call_kwargs["metadata"]["agentId"]["stringValue"] == "test-agent-123" + + # Verify migration: old event deleted + mock_memory_client.gmdp_client.delete_event.assert_called_once() + delete_call_kwargs = mock_memory_client.gmdp_client.delete_event.call_args.kwargs + assert delete_call_kwargs["actorId"] == "agent_test-agent-123" + assert delete_call_kwargs["eventId"] == "legacy-agent-event-1" + def test_create_message(self, session_manager, mock_memory_client): """Test creating a message.""" mock_memory_client.create_event.return_value = {"eventId": "event-123"} @@ -925,22 +990,6 @@ def test_init_with_boto_config(self, agentcore_config, mock_memory_client): manager = AgentCoreMemorySessionManager(agentcore_config, boto_client_config=boto_config) assert manager.memory_client is not None - def test_get_full_session_id_conflict(self, session_manager): - """Test session ID conflict with actor ID.""" - # Set up a scenario where session ID would conflict with actor ID - session_manager.config.actor_id = "session_test-session" - - with pytest.raises(SessionException, match="Cannot have session"): - session_manager._get_full_session_id("test-session") - - def test_get_full_agent_id_conflict(self, session_manager): - """Test agent ID conflict with actor ID.""" - # Set up a scenario where agent ID would conflict with actor ID - session_manager.config.actor_id = "agent_test-agent" - - with pytest.raises(SessionException, match="Cannot create agent"): - session_manager._get_full_agent_id("test-agent") - def test_retrieve_customer_context_no_messages(self, agentcore_config_with_retrieval, mock_memory_client): """Test retrieve_customer_context with no messages.""" with patch(