Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion core/src/main/java/com/google/adk/agents/InvocationContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package com.google.adk.agents;

import static com.google.common.base.Strings.isNullOrEmpty;

import com.google.adk.apps.ResumabilityConfig;
import com.google.adk.artifacts.BaseArtifactService;
import com.google.adk.events.Event;
Expand Down Expand Up @@ -747,12 +749,30 @@ public Builder callbackContextData(Map<String, Object> callbackContextData) {
*
* @throws IllegalStateException if any required parameters are missing.
*/
// TODO: b/462183912 - Add validation for required parameters.
public InvocationContext build() {
validate(this);
return new InvocationContext(this);
}
}

/**
* Validates the required parameters fields: invocationId, agent, and session.
*
* @param builder the builder to validate.
* @throws IllegalStateException if any required parameters are missing.
*/
private static void validate(Builder builder) {
if (isNullOrEmpty(builder.invocationId)) {
throw new IllegalStateException("Invocation ID must be non-empty.");
}
if (builder.agent == null) {
throw new IllegalStateException("Agent must be set.");
}
if (builder.session == null) {
throw new IllegalStateException("Session must be set.");
}
}

@Override
public boolean equals(Object o) {
if (this == o) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.google.adk.agents;

import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertThrows;
import static org.mockito.Mockito.mock;

import com.google.adk.apps.ResumabilityConfig;
Expand Down Expand Up @@ -764,7 +765,7 @@ public void testSetEndInvocation() {
}

@Test
@SuppressWarnings("deprecation") // Testing deprecated methods.
// Testing deprecated methods.
public void testBranch() {
InvocationContext context =
InvocationContext.builder()
Expand All @@ -785,7 +786,7 @@ public void testBranch() {
}

@Test
@SuppressWarnings("deprecation") // Testing deprecated methods.
// Testing deprecated methods.
public void testDeprecatedCreateMethods() {
InvocationContext context1 =
InvocationContext.builder()
Expand Down Expand Up @@ -855,7 +856,7 @@ public void testEventsCompactionConfig() {
}

@Test
@SuppressWarnings("deprecation") // Testing deprecated methods.
// Testing deprecated methods.
public void testBuilderOptionalParameters() {
InvocationContext context =
InvocationContext.builder()
Expand All @@ -874,7 +875,7 @@ public void testBuilderOptionalParameters() {
}

@Test
@SuppressWarnings("deprecation") // Testing deprecated methods.
// Testing deprecated methods.
public void testDeprecatedConstructor() {
InvocationContext context =
new InvocationContext(
Expand Down Expand Up @@ -906,7 +907,7 @@ public void testDeprecatedConstructor() {
}

@Test
@SuppressWarnings("deprecation") // Testing deprecated methods.
// Testing deprecated methods.
public void testDeprecatedConstructor_11params() {
InvocationContext context =
new InvocationContext(
Expand Down Expand Up @@ -986,4 +987,60 @@ public void populateAgentStates_populatesAgentStatesAndEndOfAgents() {
assertThat(context.endOfAgents()).hasSize(1);
assertThat(context.endOfAgents()).containsEntry("agent1", true);
}

@Test
public void build_missingInvocationId_null_throwsException() {
InvocationContext.Builder builder =
InvocationContext.builder()
.sessionService(mockSessionService)
.artifactService(mockArtifactService)
.memoryService(mockMemoryService)
.agent(mockAgent)
.invocationId(null)
.session(session);

IllegalStateException exception = assertThrows(IllegalStateException.class, builder::build);
assertThat(exception).hasMessageThat().isEqualTo("Invocation ID must be non-empty.");
}

@Test
public void build_missingInvocationId_empty_throwsException() {
InvocationContext.Builder builder =
InvocationContext.builder()
.sessionService(mockSessionService)
.artifactService(mockArtifactService)
.memoryService(mockMemoryService)
.agent(mockAgent)
.invocationId("")
.session(session);

IllegalStateException exception = assertThrows(IllegalStateException.class, builder::build);
assertThat(exception).hasMessageThat().isEqualTo("Invocation ID must be non-empty.");
}

@Test
public void build_missingAgent_throwsException() {
InvocationContext.Builder builder =
InvocationContext.builder()
.sessionService(mockSessionService)
.artifactService(mockArtifactService)
.memoryService(mockMemoryService)
.session(session);

IllegalStateException exception = assertThrows(IllegalStateException.class, builder::build);
assertThat(exception).hasMessageThat().isEqualTo("Agent must be set.");
}

@Test
public void build_missingSession_throwsException() {
InvocationContext.Builder builder =
InvocationContext.builder()
.sessionService(mockSessionService)
.artifactService(mockArtifactService)
.memoryService(mockMemoryService)
.agent(mockAgent);

IllegalStateException exception = assertThrows(IllegalStateException.class, builder::build);
assertThat(exception).hasMessageThat().isEqualTo("Session must be set.");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import com.google.adk.examples.BaseExampleProvider;
import com.google.adk.examples.Example;
import com.google.adk.models.LlmRequest;
import com.google.adk.sessions.Session;
import com.google.common.collect.ImmutableList;
import com.google.genai.types.Content;
import com.google.genai.types.Part;
Expand Down Expand Up @@ -55,6 +56,7 @@ public void processRequest_withExampleProvider_addsExamplesToInstructions() {
InvocationContext context =
InvocationContext.builder()
.invocationId("invocation1")
.session(Session.builder("session1").build())
.agent(agent)
.userContent(Content.fromParts(Part.fromText("what is up?")))
.runConfig(RunConfig.builder().build())
Expand All @@ -76,6 +78,7 @@ public void processRequest_withoutExampleProvider_doesNotAddExamplesToInstructio
InvocationContext context =
InvocationContext.builder()
.invocationId("invocation1")
.session(Session.builder("session1").build())
.agent(agent)
.userContent(Content.fromParts(Part.fromText("what is up?")))
.runConfig(RunConfig.builder().build())
Expand Down
43 changes: 36 additions & 7 deletions core/src/test/java/com/google/adk/tools/FunctionToolTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import static org.junit.Assert.assertThrows;

import com.google.adk.agents.InvocationContext;
import com.google.adk.agents.LlmAgent;
import com.google.adk.events.ToolConfirmation;
import com.google.adk.sessions.Session;
import com.google.common.collect.ImmutableList;
Expand Down Expand Up @@ -232,10 +233,14 @@ public void create_withParameterizedList() {

@Test
public void call_withAllSupportedParameterTypes() throws Exception {
LlmAgent agent = LlmAgent.builder().name("test-agent").build();
FunctionTool tool = FunctionTool.create(Functions.class, "returnAllSupportedParametersAsMap");
ToolContext toolContext =
ToolContext.builder(
InvocationContext.builder().session(Session.builder("123").build()).build())
InvocationContext.builder()
.agent(agent)
.session(Session.builder("123").build())
.build())
.functionCallId("functionCallId")
.build();

Expand Down Expand Up @@ -576,12 +581,16 @@ public void call_withMaybePojoReturnType() throws Exception {
@Test
@SuppressWarnings("BooleanLiteral")
public void call_nonStaticWithAllSupportedParameterTypes() throws Exception {
LlmAgent agent = LlmAgent.builder().name("test-agent").build();
Functions functions = new Functions();
FunctionTool tool =
FunctionTool.create(functions, "nonStaticReturnAllSupportedParametersAsMap");
ToolContext toolContext =
ToolContext.builder(
InvocationContext.builder().session(Session.builder("123").build()).build())
InvocationContext.builder()
.agent(agent)
.session(Session.builder("123").build())
.build())
.functionCallId("functionCallId")
.build();

Expand Down Expand Up @@ -627,12 +636,16 @@ public void call_nonStaticWithAllSupportedParameterTypes() throws Exception {

@Test
public void runAsync_withRequireConfirmation() throws Exception {
LlmAgent agent = LlmAgent.builder().name("test-agent").build();
Method method = Functions.class.getMethod("returnsMap");
FunctionTool tool =
new FunctionTool(null, method, /* isLongRunning= */ false, /* requireConfirmation= */ true);
ToolContext toolContext =
ToolContext.builder(
InvocationContext.builder().session(Session.builder("123").build()).build())
InvocationContext.builder()
.agent(agent)
.session(Session.builder("123").build())
.build())
.functionCallId("functionCallId")
.build();

Expand Down Expand Up @@ -660,12 +673,16 @@ public void runAsync_withRequireConfirmation() throws Exception {

@Test
public void create_instanceMethodWithConfirmation_requestsConfirmation() throws Exception {
LlmAgent agent = LlmAgent.builder().name("test-agent").build();
Functions functions = new Functions();
Method method = Functions.class.getMethod("nonStaticVoidReturnWithoutSchema");
FunctionTool tool = FunctionTool.create(functions, method, /* requireConfirmation= */ true);
ToolContext toolContext =
ToolContext.builder(
InvocationContext.builder().session(Session.builder("123").build()).build())
InvocationContext.builder()
.agent(agent)
.session(Session.builder("123").build())
.build())
.functionCallId("functionCallId")
.build();

Expand All @@ -678,11 +695,15 @@ public void create_instanceMethodWithConfirmation_requestsConfirmation() throws

@Test
public void create_staticMethodWithConfirmation_requestsConfirmation() throws Exception {
LlmAgent agent = LlmAgent.builder().name("test-agent").build();
Method method = Functions.class.getMethod("voidReturnWithoutSchema");
FunctionTool tool = FunctionTool.create(method, /* requireConfirmation= */ true);
ToolContext toolContext =
ToolContext.builder(
InvocationContext.builder().session(Session.builder("123").build()).build())
InvocationContext.builder()
.agent(agent)
.session(Session.builder("123").build())
.build())
.functionCallId("functionCallId")
.build();

Expand All @@ -695,12 +716,16 @@ public void create_staticMethodWithConfirmation_requestsConfirmation() throws Ex

@Test
public void create_classMethodNameWithConfirmation_requestsConfirmation() throws Exception {
LlmAgent agent = LlmAgent.builder().name("test-agent").build();
FunctionTool tool =
FunctionTool.create(
Functions.class, "voidReturnWithoutSchema", /* requireConfirmation= */ true);
ToolContext toolContext =
ToolContext.builder(
InvocationContext.builder().session(Session.builder("123").build()).build())
InvocationContext.builder()
.agent(agent)
.session(Session.builder("123").build())
.build())
.functionCallId("functionCallId")
.build();

Expand All @@ -713,13 +738,17 @@ public void create_classMethodNameWithConfirmation_requestsConfirmation() throws

@Test
public void create_instanceMethodNameWithConfirmation_requestsConfirmation() throws Exception {
LlmAgent agent = LlmAgent.builder().name("test-agent").build();
Functions functions = new Functions();
FunctionTool tool =
FunctionTool.create(
functions, "nonStaticVoidReturnWithoutSchema", /* requireConfirmation= */ true);
ToolContext toolContext =
ToolContext.builder(
InvocationContext.builder().session(Session.builder("123").build()).build())
InvocationContext.builder()
.agent(agent)
.session(Session.builder("123").build())
.build())
.functionCallId("functionCallId")
.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import static org.mockito.Mockito.when;

import com.google.adk.agents.InvocationContext;
import com.google.adk.agents.LlmAgent;
import com.google.adk.models.LlmRequest;
import com.google.adk.sessions.Session;
import com.google.adk.tools.ToolContext;
Expand Down Expand Up @@ -41,6 +42,7 @@ public final class VertexAiRagRetrievalTest {

@Test
public void runAsync_withResults_returnsContexts() throws Exception {
LlmAgent agent = LlmAgent.builder().name("test-agent").build();
ImmutableList<RagResource> ragResources =
ImmutableList.of(RagResource.newBuilder().setRagCorpus("corpus1").build());
Double vectorDistanceThreshold = 0.5;
Expand All @@ -55,7 +57,10 @@ public void runAsync_withResults_returnsContexts() throws Exception {
String query = "test query";
ToolContext toolContext =
ToolContext.builder(
InvocationContext.builder().session(Session.builder("123").build()).build())
InvocationContext.builder()
.agent(agent)
.session(Session.builder("123").build())
.build())
.functionCallId("functionCallId")
.build();
RetrieveContextsRequest expectedRequest =
Expand Down Expand Up @@ -85,6 +90,7 @@ public void runAsync_withResults_returnsContexts() throws Exception {

@Test
public void runAsync_noResults_returnsNoResultFoundMessage() throws Exception {
LlmAgent agent = LlmAgent.builder().name("test-agent").build();
ImmutableList<RagResource> ragResources =
ImmutableList.of(RagResource.newBuilder().setRagCorpus("corpus1").build());
Double vectorDistanceThreshold = 0.5;
Expand All @@ -99,7 +105,10 @@ public void runAsync_noResults_returnsNoResultFoundMessage() throws Exception {
String query = "test query";
ToolContext toolContext =
ToolContext.builder(
InvocationContext.builder().session(Session.builder("123").build()).build())
InvocationContext.builder()
.agent(agent)
.session(Session.builder("123").build())
.build())
.functionCallId("functionCallId")
.build();
RetrieveContextsRequest expectedRequest =
Expand Down Expand Up @@ -129,6 +138,7 @@ public void runAsync_noResults_returnsNoResultFoundMessage() throws Exception {

@Test
public void processLlmRequest_gemini2Model_addVertexRagStoreToConfig() {
LlmAgent agent = LlmAgent.builder().name("test-agent").build();
// This test's behavior depends on the GOOGLE_GENAI_USE_VERTEXAI environment variable
boolean useVertexAi = Boolean.parseBoolean(System.getenv("GOOGLE_GENAI_USE_VERTEXAI"));
ImmutableList<RagResource> ragResources =
Expand All @@ -145,7 +155,10 @@ public void processLlmRequest_gemini2Model_addVertexRagStoreToConfig() {
LlmRequest.Builder llmRequestBuilder = LlmRequest.builder().model("gemini-2-pro");
ToolContext toolContext =
ToolContext.builder(
InvocationContext.builder().session(Session.builder("123").build()).build())
InvocationContext.builder()
.agent(agent)
.session(Session.builder("123").build())
.build())
.functionCallId("functionCallId")
.build();

Expand Down Expand Up @@ -197,6 +210,7 @@ public void processLlmRequest_gemini2Model_addVertexRagStoreToConfig() {

@Test
public void processLlmRequest_otherModel_doNotAddVertexRagStoreToConfig() {
LlmAgent agent = LlmAgent.builder().name("test-agent").build();
ImmutableList<RagResource> ragResources =
ImmutableList.of(RagResource.newBuilder().setRagCorpus("corpus1").build());
Double vectorDistanceThreshold = 0.5;
Expand All @@ -211,7 +225,10 @@ public void processLlmRequest_otherModel_doNotAddVertexRagStoreToConfig() {
LlmRequest.Builder llmRequestBuilder = LlmRequest.builder().model("gemini-1-pro");
ToolContext toolContext =
ToolContext.builder(
InvocationContext.builder().session(Session.builder("123").build()).build())
InvocationContext.builder()
.agent(agent)
.session(Session.builder("123").build())
.build())
.functionCallId("functionCallId")
.build();
GenerateContentConfig initialConfig = GenerateContentConfig.builder().build();
Expand Down