Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,11 @@
import com.google.genai.types.Part;
import io.reactivex.rxjava3.observers.TestObserver;
import java.time.Instant;
import java.util.AbstractMap;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
Expand Down Expand Up @@ -530,44 +526,6 @@ void appendAndGet_withAllPartTypes_serializesAndDeserializesCorrectly() {
});
}

/**
* A wrapper class that implements ConcurrentMap but delegates to a HashMap. This is a workaround
* to allow putting null values, which ConcurrentHashMap forbids, for testing state removal logic.
*/
private static class HashMapAsConcurrentMap<K, V> extends AbstractMap<K, V>
implements ConcurrentMap<K, V> {
private final HashMap<K, V> map;

public HashMapAsConcurrentMap(Map<K, V> map) {
this.map = new HashMap<>(map);
}

@Override
public Set<Entry<K, V>> entrySet() {
return map.entrySet();
}

@Override
public V putIfAbsent(K key, V value) {
return map.putIfAbsent(key, value);
}

@Override
public boolean remove(Object key, Object value) {
return map.remove(key, value);
}

@Override
public boolean replace(K key, V oldValue, V newValue) {
return map.replace(key, oldValue, newValue);
}

@Override
public V replace(K key, V value) {
return map.replace(key, value);
}
}

/** Tests that appendEvent with only app state deltas updates the correct stores. */
@Test
void appendEvent_withAppOnlyStateDeltas_updatesCorrectStores() {
Expand Down Expand Up @@ -662,63 +620,6 @@ void appendEvent_withUserOnlyStateDeltas_updatesCorrectStores() {
verify(mockSessionDocRef, never()).update(eq(Constants.KEY_STATE), any());
}

/**
* Tests that appendEvent with all types of state deltas updates the correct stores and session
* state.
*/
@Test
void appendEvent_withAllStateDeltas_updatesCorrectStores() {
// Arrange
Session session =
Session.builder(SESSION_ID)
.appName(APP_NAME)
.userId(USER_ID)
.state(new ConcurrentHashMap<>()) // The session state itself must be concurrent
.build();
session.state().put("keyToRemove", "someValue");

Map<String, Object> stateDeltaMap = new HashMap<>();
stateDeltaMap.put("sessionKey", "sessionValue");
stateDeltaMap.put("_app_appKey", "appValue");
stateDeltaMap.put("_user_userKey", "userValue");
stateDeltaMap.put("keyToRemove", null);

// Use the wrapper to satisfy the ConcurrentMap interface for the builder
EventActions actions =
EventActions.builder().stateDelta(new HashMapAsConcurrentMap<>(stateDeltaMap)).build();

Event event =
Event.builder()
.author("model")
.content(Content.builder().parts(List.of(Part.fromText("..."))).build())
.actions(actions)
.build();

when(mockSessionsCollection.document(SESSION_ID)).thenReturn(mockSessionDocRef);
when(mockEventsCollection.document()).thenReturn(mockEventDocRef);
when(mockEventDocRef.getId()).thenReturn(EVENT_ID);
// THIS IS THE MISSING MOCK: Stub the call to get the document by its specific ID.
when(mockEventsCollection.document(EVENT_ID)).thenReturn(mockEventDocRef);
// Add the missing mock for the final session update call
when(mockSessionDocRef.update(anyMap()))
.thenReturn(ApiFutures.immediateFuture(mockWriteResult));

// Act
sessionService.appendEvent(session, event).test().assertComplete();

// Assert
assertThat(session.state()).containsEntry("sessionKey", "sessionValue");
assertThat(session.state()).doesNotContainKey("keyToRemove");

ArgumentCaptor<Map<String, Object>> appStateCaptor = ArgumentCaptor.forClass(Map.class);
verify(mockAppStateDocRef).set(appStateCaptor.capture(), any(SetOptions.class));
assertThat(appStateCaptor.getValue()).containsEntry("appKey", "appValue");

ArgumentCaptor<Map<String, Object>> userStateCaptor = ArgumentCaptor.forClass(Map.class);
verify(mockUserStateUserDocRef).set(userStateCaptor.capture(), any(SetOptions.class));
assertThat(userStateCaptor.getValue()).containsEntry("userKey", "userValue");
}

/** Tests that getSession skips malformed events and returns only the well-formed ones. */
@Test
@SuppressWarnings("unchecked")
Expand Down
17 changes: 6 additions & 11 deletions core/src/main/java/com/google/adk/events/EventActions.java
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,6 @@ public void setRequestedToolConfirmations(
Map<String, ToolConfirmation> requestedToolConfirmations) {
if (requestedToolConfirmations == null) {
this.requestedToolConfirmations = new ConcurrentHashMap<>();
} else if (requestedToolConfirmations instanceof ConcurrentMap) {
this.requestedToolConfirmations =
(ConcurrentMap<String, ToolConfirmation>) requestedToolConfirmations;
} else {
this.requestedToolConfirmations = new ConcurrentHashMap<>(requestedToolConfirmations);
}
Expand Down Expand Up @@ -290,8 +287,6 @@ public Builder skipSummarization(boolean skipSummarization) {
public Builder stateDelta(@Nullable Map<String, Object> value) {
if (value == null) {
this.stateDelta = new ConcurrentHashMap<>();
} else if (value instanceof ConcurrentMap) {
this.stateDelta = (ConcurrentMap<String, Object>) value;
} else {
this.stateDelta = new ConcurrentHashMap<>(value);
}
Expand All @@ -300,8 +295,12 @@ public Builder stateDelta(@Nullable Map<String, Object> value) {

@CanIgnoreReturnValue
@JsonProperty("artifactDelta")
public Builder artifactDelta(Map<String, Integer> value) {
this.artifactDelta = new ConcurrentHashMap<>(value);
public Builder artifactDelta(@Nullable Map<String, Integer> value) {
if (value == null) {
this.artifactDelta = new ConcurrentHashMap<>();
} else {
this.artifactDelta = new ConcurrentHashMap<>(value);
}
return this;
}

Expand Down Expand Up @@ -339,10 +338,6 @@ public Builder requestedAuthConfigs(
public Builder requestedToolConfirmations(@Nullable Map<String, ToolConfirmation> value) {
if (value == null) {
this.requestedToolConfirmations = new ConcurrentHashMap<>();
return this;
}
if (value instanceof ConcurrentMap) {
this.requestedToolConfirmations = (ConcurrentMap<String, ToolConfirmation>) value;
} else {
this.requestedToolConfirmations = new ConcurrentHashMap<>(value);
}
Expand Down
11 changes: 0 additions & 11 deletions core/src/test/java/com/google/adk/events/EventActionsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -177,17 +177,6 @@ public void merge_failsOnMismatchedKeyTypesNestedInStateDelta() {
IllegalArgumentException.class, () -> eventActions1.toBuilder().merge(eventActions2));
}

@Test
public void setRequestedToolConfirmations_withConcurrentMap_usesSameInstance() {
ConcurrentHashMap<String, ToolConfirmation> map = new ConcurrentHashMap<>();
map.put("tool", TOOL_CONFIRMATION);

EventActions actions = new EventActions();
actions.setRequestedToolConfirmations(map);

assertThat(actions.requestedToolConfirmations()).isSameInstanceAs(map);
}

@Test
public void setRequestedToolConfirmations_withRegularMap_createsConcurrentMap() {
ImmutableMap<String, ToolConfirmation> map = ImmutableMap.of("tool", TOOL_CONFIRMATION);
Expand Down
Loading