diff --git a/core/src/main/java/com/google/adk/agents/InvocationContext.java b/core/src/main/java/com/google/adk/agents/InvocationContext.java index 3b460b073..8725a0142 100644 --- a/core/src/main/java/com/google/adk/agents/InvocationContext.java +++ b/core/src/main/java/com/google/adk/agents/InvocationContext.java @@ -16,6 +16,8 @@ package com.google.adk.agents; +import static com.google.common.base.Strings.isNullOrEmpty; + import com.google.adk.apps.ResumabilityConfig; import com.google.adk.artifacts.BaseArtifactService; import com.google.adk.events.Event; @@ -747,12 +749,30 @@ public Builder callbackContextData(Map callbackContextData) { * * @throws IllegalStateException if any required parameters are missing. */ - // TODO: b/462183912 - Add validation for required parameters. public InvocationContext build() { + validate(this); return new InvocationContext(this); } } + /** + * Validates the required parameters fields: invocationId, agent, and session. + * + * @param builder the builder to validate. + * @throws IllegalStateException if any required parameters are missing. + */ + private static void validate(Builder builder) { + if (isNullOrEmpty(builder.invocationId)) { + throw new IllegalStateException("Invocation ID must be non-empty."); + } + if (builder.agent == null) { + throw new IllegalStateException("Agent must be set."); + } + if (builder.session == null) { + throw new IllegalStateException("Session must be set."); + } + } + @Override public boolean equals(Object o) { if (this == o) { diff --git a/core/src/test/java/com/google/adk/agents/InvocationContextTest.java b/core/src/test/java/com/google/adk/agents/InvocationContextTest.java index c1cb30180..0cddb65f0 100644 --- a/core/src/test/java/com/google/adk/agents/InvocationContextTest.java +++ b/core/src/test/java/com/google/adk/agents/InvocationContextTest.java @@ -17,6 +17,7 @@ package com.google.adk.agents; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; import static org.mockito.Mockito.mock; import com.google.adk.apps.ResumabilityConfig; @@ -764,7 +765,7 @@ public void testSetEndInvocation() { } @Test - @SuppressWarnings("deprecation") // Testing deprecated methods. + // Testing deprecated methods. public void testBranch() { InvocationContext context = InvocationContext.builder() @@ -785,7 +786,7 @@ public void testBranch() { } @Test - @SuppressWarnings("deprecation") // Testing deprecated methods. + // Testing deprecated methods. public void testDeprecatedCreateMethods() { InvocationContext context1 = InvocationContext.builder() @@ -855,7 +856,7 @@ public void testEventsCompactionConfig() { } @Test - @SuppressWarnings("deprecation") // Testing deprecated methods. + // Testing deprecated methods. public void testBuilderOptionalParameters() { InvocationContext context = InvocationContext.builder() @@ -874,7 +875,7 @@ public void testBuilderOptionalParameters() { } @Test - @SuppressWarnings("deprecation") // Testing deprecated methods. + // Testing deprecated methods. public void testDeprecatedConstructor() { InvocationContext context = new InvocationContext( @@ -906,7 +907,7 @@ public void testDeprecatedConstructor() { } @Test - @SuppressWarnings("deprecation") // Testing deprecated methods. + // Testing deprecated methods. public void testDeprecatedConstructor_11params() { InvocationContext context = new InvocationContext( @@ -986,4 +987,60 @@ public void populateAgentStates_populatesAgentStatesAndEndOfAgents() { assertThat(context.endOfAgents()).hasSize(1); assertThat(context.endOfAgents()).containsEntry("agent1", true); } + + @Test + public void build_missingInvocationId_null_throwsException() { + InvocationContext.Builder builder = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .agent(mockAgent) + .invocationId(null) + .session(session); + + IllegalStateException exception = assertThrows(IllegalStateException.class, builder::build); + assertThat(exception).hasMessageThat().isEqualTo("Invocation ID must be non-empty."); + } + + @Test + public void build_missingInvocationId_empty_throwsException() { + InvocationContext.Builder builder = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .agent(mockAgent) + .invocationId("") + .session(session); + + IllegalStateException exception = assertThrows(IllegalStateException.class, builder::build); + assertThat(exception).hasMessageThat().isEqualTo("Invocation ID must be non-empty."); + } + + @Test + public void build_missingAgent_throwsException() { + InvocationContext.Builder builder = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .session(session); + + IllegalStateException exception = assertThrows(IllegalStateException.class, builder::build); + assertThat(exception).hasMessageThat().isEqualTo("Agent must be set."); + } + + @Test + public void build_missingSession_throwsException() { + InvocationContext.Builder builder = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .agent(mockAgent); + + IllegalStateException exception = assertThrows(IllegalStateException.class, builder::build); + assertThat(exception).hasMessageThat().isEqualTo("Session must be set."); + } } diff --git a/core/src/test/java/com/google/adk/flows/llmflows/ExamplesTest.java b/core/src/test/java/com/google/adk/flows/llmflows/ExamplesTest.java index 4a6216029..a2188db3b 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/ExamplesTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/ExamplesTest.java @@ -24,6 +24,7 @@ import com.google.adk.examples.BaseExampleProvider; import com.google.adk.examples.Example; import com.google.adk.models.LlmRequest; +import com.google.adk.sessions.Session; import com.google.common.collect.ImmutableList; import com.google.genai.types.Content; import com.google.genai.types.Part; @@ -55,6 +56,7 @@ public void processRequest_withExampleProvider_addsExamplesToInstructions() { InvocationContext context = InvocationContext.builder() .invocationId("invocation1") + .session(Session.builder("session1").build()) .agent(agent) .userContent(Content.fromParts(Part.fromText("what is up?"))) .runConfig(RunConfig.builder().build()) @@ -76,6 +78,7 @@ public void processRequest_withoutExampleProvider_doesNotAddExamplesToInstructio InvocationContext context = InvocationContext.builder() .invocationId("invocation1") + .session(Session.builder("session1").build()) .agent(agent) .userContent(Content.fromParts(Part.fromText("what is up?"))) .runConfig(RunConfig.builder().build()) diff --git a/core/src/test/java/com/google/adk/tools/FunctionToolTest.java b/core/src/test/java/com/google/adk/tools/FunctionToolTest.java index 5816c427a..c5d194eb3 100644 --- a/core/src/test/java/com/google/adk/tools/FunctionToolTest.java +++ b/core/src/test/java/com/google/adk/tools/FunctionToolTest.java @@ -20,6 +20,7 @@ import static org.junit.Assert.assertThrows; import com.google.adk.agents.InvocationContext; +import com.google.adk.agents.LlmAgent; import com.google.adk.events.ToolConfirmation; import com.google.adk.sessions.Session; import com.google.common.collect.ImmutableList; @@ -232,10 +233,14 @@ public void create_withParameterizedList() { @Test public void call_withAllSupportedParameterTypes() throws Exception { + LlmAgent agent = LlmAgent.builder().name("test-agent").build(); FunctionTool tool = FunctionTool.create(Functions.class, "returnAllSupportedParametersAsMap"); ToolContext toolContext = ToolContext.builder( - InvocationContext.builder().session(Session.builder("123").build()).build()) + InvocationContext.builder() + .agent(agent) + .session(Session.builder("123").build()) + .build()) .functionCallId("functionCallId") .build(); @@ -576,12 +581,16 @@ public void call_withMaybePojoReturnType() throws Exception { @Test @SuppressWarnings("BooleanLiteral") public void call_nonStaticWithAllSupportedParameterTypes() throws Exception { + LlmAgent agent = LlmAgent.builder().name("test-agent").build(); Functions functions = new Functions(); FunctionTool tool = FunctionTool.create(functions, "nonStaticReturnAllSupportedParametersAsMap"); ToolContext toolContext = ToolContext.builder( - InvocationContext.builder().session(Session.builder("123").build()).build()) + InvocationContext.builder() + .agent(agent) + .session(Session.builder("123").build()) + .build()) .functionCallId("functionCallId") .build(); @@ -627,12 +636,16 @@ public void call_nonStaticWithAllSupportedParameterTypes() throws Exception { @Test public void runAsync_withRequireConfirmation() throws Exception { + LlmAgent agent = LlmAgent.builder().name("test-agent").build(); Method method = Functions.class.getMethod("returnsMap"); FunctionTool tool = new FunctionTool(null, method, /* isLongRunning= */ false, /* requireConfirmation= */ true); ToolContext toolContext = ToolContext.builder( - InvocationContext.builder().session(Session.builder("123").build()).build()) + InvocationContext.builder() + .agent(agent) + .session(Session.builder("123").build()) + .build()) .functionCallId("functionCallId") .build(); @@ -660,12 +673,16 @@ public void runAsync_withRequireConfirmation() throws Exception { @Test public void create_instanceMethodWithConfirmation_requestsConfirmation() throws Exception { + LlmAgent agent = LlmAgent.builder().name("test-agent").build(); Functions functions = new Functions(); Method method = Functions.class.getMethod("nonStaticVoidReturnWithoutSchema"); FunctionTool tool = FunctionTool.create(functions, method, /* requireConfirmation= */ true); ToolContext toolContext = ToolContext.builder( - InvocationContext.builder().session(Session.builder("123").build()).build()) + InvocationContext.builder() + .agent(agent) + .session(Session.builder("123").build()) + .build()) .functionCallId("functionCallId") .build(); @@ -678,11 +695,15 @@ public void create_instanceMethodWithConfirmation_requestsConfirmation() throws @Test public void create_staticMethodWithConfirmation_requestsConfirmation() throws Exception { + LlmAgent agent = LlmAgent.builder().name("test-agent").build(); Method method = Functions.class.getMethod("voidReturnWithoutSchema"); FunctionTool tool = FunctionTool.create(method, /* requireConfirmation= */ true); ToolContext toolContext = ToolContext.builder( - InvocationContext.builder().session(Session.builder("123").build()).build()) + InvocationContext.builder() + .agent(agent) + .session(Session.builder("123").build()) + .build()) .functionCallId("functionCallId") .build(); @@ -695,12 +716,16 @@ public void create_staticMethodWithConfirmation_requestsConfirmation() throws Ex @Test public void create_classMethodNameWithConfirmation_requestsConfirmation() throws Exception { + LlmAgent agent = LlmAgent.builder().name("test-agent").build(); FunctionTool tool = FunctionTool.create( Functions.class, "voidReturnWithoutSchema", /* requireConfirmation= */ true); ToolContext toolContext = ToolContext.builder( - InvocationContext.builder().session(Session.builder("123").build()).build()) + InvocationContext.builder() + .agent(agent) + .session(Session.builder("123").build()) + .build()) .functionCallId("functionCallId") .build(); @@ -713,13 +738,17 @@ public void create_classMethodNameWithConfirmation_requestsConfirmation() throws @Test public void create_instanceMethodNameWithConfirmation_requestsConfirmation() throws Exception { + LlmAgent agent = LlmAgent.builder().name("test-agent").build(); Functions functions = new Functions(); FunctionTool tool = FunctionTool.create( functions, "nonStaticVoidReturnWithoutSchema", /* requireConfirmation= */ true); ToolContext toolContext = ToolContext.builder( - InvocationContext.builder().session(Session.builder("123").build()).build()) + InvocationContext.builder() + .agent(agent) + .session(Session.builder("123").build()) + .build()) .functionCallId("functionCallId") .build(); diff --git a/core/src/test/java/com/google/adk/tools/retrieval/VertexAiRagRetrievalTest.java b/core/src/test/java/com/google/adk/tools/retrieval/VertexAiRagRetrievalTest.java index c501d6a43..3caee4915 100644 --- a/core/src/test/java/com/google/adk/tools/retrieval/VertexAiRagRetrievalTest.java +++ b/core/src/test/java/com/google/adk/tools/retrieval/VertexAiRagRetrievalTest.java @@ -6,6 +6,7 @@ import static org.mockito.Mockito.when; import com.google.adk.agents.InvocationContext; +import com.google.adk.agents.LlmAgent; import com.google.adk.models.LlmRequest; import com.google.adk.sessions.Session; import com.google.adk.tools.ToolContext; @@ -41,6 +42,7 @@ public final class VertexAiRagRetrievalTest { @Test public void runAsync_withResults_returnsContexts() throws Exception { + LlmAgent agent = LlmAgent.builder().name("test-agent").build(); ImmutableList ragResources = ImmutableList.of(RagResource.newBuilder().setRagCorpus("corpus1").build()); Double vectorDistanceThreshold = 0.5; @@ -55,7 +57,10 @@ public void runAsync_withResults_returnsContexts() throws Exception { String query = "test query"; ToolContext toolContext = ToolContext.builder( - InvocationContext.builder().session(Session.builder("123").build()).build()) + InvocationContext.builder() + .agent(agent) + .session(Session.builder("123").build()) + .build()) .functionCallId("functionCallId") .build(); RetrieveContextsRequest expectedRequest = @@ -85,6 +90,7 @@ public void runAsync_withResults_returnsContexts() throws Exception { @Test public void runAsync_noResults_returnsNoResultFoundMessage() throws Exception { + LlmAgent agent = LlmAgent.builder().name("test-agent").build(); ImmutableList ragResources = ImmutableList.of(RagResource.newBuilder().setRagCorpus("corpus1").build()); Double vectorDistanceThreshold = 0.5; @@ -99,7 +105,10 @@ public void runAsync_noResults_returnsNoResultFoundMessage() throws Exception { String query = "test query"; ToolContext toolContext = ToolContext.builder( - InvocationContext.builder().session(Session.builder("123").build()).build()) + InvocationContext.builder() + .agent(agent) + .session(Session.builder("123").build()) + .build()) .functionCallId("functionCallId") .build(); RetrieveContextsRequest expectedRequest = @@ -129,6 +138,7 @@ public void runAsync_noResults_returnsNoResultFoundMessage() throws Exception { @Test public void processLlmRequest_gemini2Model_addVertexRagStoreToConfig() { + LlmAgent agent = LlmAgent.builder().name("test-agent").build(); // This test's behavior depends on the GOOGLE_GENAI_USE_VERTEXAI environment variable boolean useVertexAi = Boolean.parseBoolean(System.getenv("GOOGLE_GENAI_USE_VERTEXAI")); ImmutableList ragResources = @@ -145,7 +155,10 @@ public void processLlmRequest_gemini2Model_addVertexRagStoreToConfig() { LlmRequest.Builder llmRequestBuilder = LlmRequest.builder().model("gemini-2-pro"); ToolContext toolContext = ToolContext.builder( - InvocationContext.builder().session(Session.builder("123").build()).build()) + InvocationContext.builder() + .agent(agent) + .session(Session.builder("123").build()) + .build()) .functionCallId("functionCallId") .build(); @@ -197,6 +210,7 @@ public void processLlmRequest_gemini2Model_addVertexRagStoreToConfig() { @Test public void processLlmRequest_otherModel_doNotAddVertexRagStoreToConfig() { + LlmAgent agent = LlmAgent.builder().name("test-agent").build(); ImmutableList ragResources = ImmutableList.of(RagResource.newBuilder().setRagCorpus("corpus1").build()); Double vectorDistanceThreshold = 0.5; @@ -211,7 +225,10 @@ public void processLlmRequest_otherModel_doNotAddVertexRagStoreToConfig() { LlmRequest.Builder llmRequestBuilder = LlmRequest.builder().model("gemini-1-pro"); ToolContext toolContext = ToolContext.builder( - InvocationContext.builder().session(Session.builder("123").build()).build()) + InvocationContext.builder() + .agent(agent) + .session(Session.builder("123").build()) + .build()) .functionCallId("functionCallId") .build(); GenerateContentConfig initialConfig = GenerateContentConfig.builder().build();