diff --git a/azurefunctions/src/main/java/com/microsoft/durabletask/azurefunctions/DurableClientContext.java b/azurefunctions/src/main/java/com/microsoft/durabletask/azurefunctions/DurableClientContext.java index bbf5e198..34d7d549 100644 --- a/azurefunctions/src/main/java/com/microsoft/durabletask/azurefunctions/DurableClientContext.java +++ b/azurefunctions/src/main/java/com/microsoft/durabletask/azurefunctions/DurableClientContext.java @@ -9,6 +9,13 @@ import com.microsoft.azure.functions.HttpStatus; import com.microsoft.durabletask.DurableTaskClient; import com.microsoft.durabletask.DurableTaskGrpcClientFactory; +import com.microsoft.durabletask.DurableEntityClient; +import com.microsoft.durabletask.EntityInstanceId; +import com.microsoft.durabletask.EntityMetadata; +import com.microsoft.durabletask.EntityQuery; +import com.microsoft.durabletask.EntityQueryResult; +import com.microsoft.durabletask.CleanEntityStorageRequest; +import com.microsoft.durabletask.CleanEntityStorageResult; import com.microsoft.durabletask.OrchestrationMetadata; import com.microsoft.durabletask.OrchestrationRuntimeStatus; @@ -133,6 +140,79 @@ public HttpManagementPayload createHttpManagementPayload(HttpRequestMessage r return this.getClientResponseLinks(request, instanceId); } + /** + * Gets the entity client for interacting with durable entities. + *

+ * This mirrors the .NET SDK's {@code DurableTaskClient.Entities} property. + * + * @return the {@link DurableEntityClient} for this client + */ + public DurableEntityClient getEntities() { + return getClient().getEntities(); + } + + /** + * Sends a fire-and-forget signal to a durable entity. + * + * @param entityId the target entity's instance ID + * @param operationName the name of the operation to invoke on the entity + * @param input the input to pass to the operation (may be {@code null}) + */ + public void signalEntity(EntityInstanceId entityId, String operationName, Object input) { + getClient().getEntities().signalEntity(entityId, operationName, input); + } + + /** + * Sends a fire-and-forget signal to a durable entity with no input. + * + * @param entityId the target entity's instance ID + * @param operationName the name of the operation to invoke on the entity + */ + public void signalEntity(EntityInstanceId entityId, String operationName) { + getClient().getEntities().signalEntity(entityId, operationName); + } + + /** + * Gets the metadata for a durable entity, including optionally its serialized state. + * + * @param entityId the entity's instance ID + * @param includeState whether to include the entity's serialized state in the result + * @return the entity metadata, or {@code null} if the entity does not exist + */ + public EntityMetadata getEntityMetadata(EntityInstanceId entityId, boolean includeState) { + return getClient().getEntities().getEntityMetadata(entityId, includeState); + } + + /** + * Gets the metadata for a durable entity without including its serialized state. + * + * @param entityId the entity's instance ID + * @return the entity metadata, or {@code null} if the entity does not exist + */ + public EntityMetadata getEntityMetadata(EntityInstanceId entityId) { + return getClient().getEntities().getEntityMetadata(entityId); + } + + /** + * Queries the durable store for entity instances matching the specified filter criteria. + * + * @param query the query filter criteria + * @return the query result containing matching entities and an optional continuation token + */ + public EntityQueryResult queryEntities(EntityQuery query) { + return getClient().getEntities().queryEntities(query); + } + + /** + * Cleans up entity storage by removing empty entities and/or releasing orphaned locks. + * + * @param request the clean storage request specifying what to clean + * @return the result of the clean operation, including counts of removed entities and released locks + */ + public CleanEntityStorageResult cleanEntityStorage(CleanEntityStorageRequest request) { + return getClient().getEntities().cleanEntityStorage(request); + } + private HttpManagementPayload getClientResponseLinks(HttpRequestMessage request, String instanceId) { String instanceStatusURL = this.getInstanceStatusURL(request, instanceId); return new HttpManagementPayload(instanceId, instanceStatusURL, this.requiredQueryStringParameters); diff --git a/azurefunctions/src/main/java/com/microsoft/durabletask/azurefunctions/DurableEntityTrigger.java b/azurefunctions/src/main/java/com/microsoft/durabletask/azurefunctions/DurableEntityTrigger.java new file mode 100644 index 00000000..46ef4857 --- /dev/null +++ b/azurefunctions/src/main/java/com/microsoft/durabletask/azurefunctions/DurableEntityTrigger.java @@ -0,0 +1,68 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See License.txt in the project root for + * license information. + */ + +package com.microsoft.durabletask.azurefunctions; + +import com.microsoft.azure.functions.annotation.CustomBinding; +import com.microsoft.azure.functions.annotation.HasImplicitOutput; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + *

+ * Azure Functions attribute for binding a function parameter to a Durable Task entity request. + *

+ * The following is an example of an entity function that uses this trigger binding to implement + * a counter entity backed by a {@code TaskEntity} subclass. + *

+ *
+ * {@literal @}FunctionName("Counter")
+ * public String counterEntity(
+ *         {@literal @}DurableEntityTrigger(name = "req") String req) {
+ *     return EntityRunner.loadAndRun(req, () -> new CounterEntity());
+ * }
+ * 
+ * + * @since 2.0.0 + */ +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.PARAMETER) +@CustomBinding(direction = "in", name = "", type = "entityTrigger") +@HasImplicitOutput +public @interface DurableEntityTrigger { + /** + *

The name of the entity function.

+ *

If not specified, the function name is used as the name of the entity.

+ *

This property supports binding parameters.

+ * + * @return The name of the entity function. + */ + String entityName() default ""; + + /** + * The variable name used in function.json. + * + * @return The variable name used in function.json. + */ + String name(); + + /** + *

+ * Defines how Functions runtime should treat the parameter value. Possible values are: + *

+ * + * + * @return The dataType which will be used by the Functions runtime. + */ + String dataType() default "string"; +} diff --git a/azurefunctions/src/main/java/com/microsoft/durabletask/azurefunctions/internal/middleware/EntityMiddleware.java b/azurefunctions/src/main/java/com/microsoft/durabletask/azurefunctions/internal/middleware/EntityMiddleware.java new file mode 100644 index 00000000..e9f2ea63 --- /dev/null +++ b/azurefunctions/src/main/java/com/microsoft/durabletask/azurefunctions/internal/middleware/EntityMiddleware.java @@ -0,0 +1,49 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See License.txt in the project root for + * license information. + */ + +package com.microsoft.durabletask.azurefunctions.internal.middleware; + +import com.microsoft.azure.functions.internal.spi.middleware.Middleware; +import com.microsoft.azure.functions.internal.spi.middleware.MiddlewareChain; +import com.microsoft.azure.functions.internal.spi.middleware.MiddlewareContext; +import com.microsoft.durabletask.DataConverter; + +/** + * Durable Function Entity Middleware + * + *

This class is internal and is hence not for public use. Its APIs are unstable and can change + * at any time. + */ +public class EntityMiddleware implements Middleware { + + private static final String ENTITY_TRIGGER = "DurableEntityTrigger"; + + @Override + public void invoke(MiddlewareContext context, MiddlewareChain chain) throws Exception { + String parameterName = context.getParameterName(ENTITY_TRIGGER); + if (parameterName == null) { + chain.doNext(context); + return; + } + + // The entity function receives the raw base64-encoded EntityBatchRequest as a String. + // The user function is expected to call EntityRunner.loadAndRun() with a TaskEntityFactory + // and return the base64-encoded EntityBatchResult. + // + // Unlike orchestrations, entity operations are simple request/response calls with no + // replay-based blocking (no OrchestratorBlockedException equivalent), so the middleware + // delegates directly to the user function. + try { + chain.doNext(context); + } catch (Exception e) { + Throwable cause = e.getCause(); + if (cause instanceof DataConverter.DataConverterException) { + throw (DataConverter.DataConverterException) cause; + } + throw new RuntimeException("Unexpected failure in entity function execution", e); + } + } +} diff --git a/azurefunctions/src/main/resources/META-INF/services/com.microsoft.azure.functions.internal.spi.middleware.Middleware b/azurefunctions/src/main/resources/META-INF/services/com.microsoft.azure.functions.internal.spi.middleware.Middleware index 26168496..0ba98d04 100644 --- a/azurefunctions/src/main/resources/META-INF/services/com.microsoft.azure.functions.internal.spi.middleware.Middleware +++ b/azurefunctions/src/main/resources/META-INF/services/com.microsoft.azure.functions.internal.spi.middleware.Middleware @@ -1 +1,2 @@ -com.microsoft.durabletask.azurefunctions.internal.middleware.OrchestrationMiddleware \ No newline at end of file +com.microsoft.durabletask.azurefunctions.internal.middleware.OrchestrationMiddleware +com.microsoft.durabletask.azurefunctions.internal.middleware.EntityMiddleware \ No newline at end of file diff --git a/azuremanaged/build.gradle b/azuremanaged/build.gradle index bdd640be..293d86e7 100644 --- a/azuremanaged/build.gradle +++ b/azuremanaged/build.gradle @@ -58,7 +58,8 @@ compileTestJava { sourceCompatibility = JavaVersion.VERSION_11 targetCompatibility = JavaVersion.VERSION_11 options.fork = true - options.forkOptions.executable = "${PATH_TO_TEST_JAVA_RUNTIME}/bin/javac" + def javacExe = org.gradle.internal.os.OperatingSystem.current().isWindows() ? 'javac.exe' : 'javac' + options.forkOptions.executable = "${PATH_TO_TEST_JAVA_RUNTIME}/bin/${javacExe}" } test { diff --git a/client/build.gradle b/client/build.gradle index 9ba98dce..1829ca2c 100644 --- a/client/build.gradle +++ b/client/build.gradle @@ -57,7 +57,8 @@ compileTestJava { sourceCompatibility = JavaVersion.VERSION_11 targetCompatibility = JavaVersion.VERSION_11 options.fork = true - options.forkOptions.executable = "${PATH_TO_TEST_JAVA_RUNTIME}/bin/javac" + def javacExe = org.gradle.internal.os.OperatingSystem.current().isWindows() ? 'javac.exe' : 'javac' + options.forkOptions.executable = "${PATH_TO_TEST_JAVA_RUNTIME}/bin/${javacExe}" } task downloadProtoFiles { @@ -107,7 +108,8 @@ sourceSets { } tasks.withType(Test) { - executable = new File("${PATH_TO_TEST_JAVA_RUNTIME}", 'bin/java') + def javaExe = org.gradle.internal.os.OperatingSystem.current().isWindows() ? 'java.exe' : 'java' + executable = new File("${PATH_TO_TEST_JAVA_RUNTIME}", "bin/${javaExe}") } test { diff --git a/client/src/main/java/com/microsoft/durabletask/CallEntityOptions.java b/client/src/main/java/com/microsoft/durabletask/CallEntityOptions.java new file mode 100644 index 00000000..3b92ff32 --- /dev/null +++ b/client/src/main/java/com/microsoft/durabletask/CallEntityOptions.java @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.microsoft.durabletask; + +import javax.annotation.Nullable; +import java.time.Duration; + +/** + * Options for calling a durable entity and waiting for a response. + */ +public final class CallEntityOptions { + private Duration timeout; + + /** + * Creates a new {@code CallEntityOptions} with default settings. + */ + public CallEntityOptions() { + } + + /** + * Sets the timeout for the entity call. If the entity does not respond within this duration, + * the call will fail with a timeout exception. + * + * @param timeout the maximum duration to wait for a response + * @return this {@code CallEntityOptions} object for chaining + */ + public CallEntityOptions setTimeout(@Nullable Duration timeout) { + this.timeout = timeout; + return this; + } + + /** + * Gets the timeout for the entity call, or {@code null} if no timeout is configured. + * + * @return the timeout duration, or {@code null} + */ + @Nullable + public Duration getTimeout() { + return this.timeout; + } +} diff --git a/client/src/main/java/com/microsoft/durabletask/CleanEntityStorageRequest.java b/client/src/main/java/com/microsoft/durabletask/CleanEntityStorageRequest.java new file mode 100644 index 00000000..349dea3a --- /dev/null +++ b/client/src/main/java/com/microsoft/durabletask/CleanEntityStorageRequest.java @@ -0,0 +1,109 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.microsoft.durabletask; + +import javax.annotation.Nullable; + +/** + * Represents a request to clean up entity storage by removing empty entities and/or releasing orphaned locks. + *

+ * Use the builder-style setters to configure the request, then pass it to + * {@link DurableEntityClient#cleanEntityStorage(CleanEntityStorageRequest)}. + */ +public final class CleanEntityStorageRequest { + private String continuationToken; + private boolean removeEmptyEntities = true; + private boolean releaseOrphanedLocks = true; + + /** + * Creates a new {@code CleanEntityStorageRequest} with default settings. + * By default, both {@code removeEmptyEntities} and {@code releaseOrphanedLocks} are {@code true}. + */ + public CleanEntityStorageRequest() { + } + + /** + * Sets the continuation token for resuming a previous clean operation. + * + * @param continuationToken the continuation token, or {@code null} to start from the beginning + * @return this {@code CleanEntityStorageRequest} for chaining + */ + public CleanEntityStorageRequest setContinuationToken(@Nullable String continuationToken) { + this.continuationToken = continuationToken; + return this; + } + + /** + * Gets the continuation token for resuming a previous clean operation. + * + * @return the continuation token, or {@code null} + */ + @Nullable + public String getContinuationToken() { + return this.continuationToken; + } + + /** + * Sets whether to remove entities that have no state and no pending operations. + * + * @param removeEmptyEntities {@code true} to remove empty entities + * @return this {@code CleanEntityStorageRequest} for chaining + */ + public CleanEntityStorageRequest setRemoveEmptyEntities(boolean removeEmptyEntities) { + this.removeEmptyEntities = removeEmptyEntities; + return this; + } + + /** + * Gets whether empty entities should be removed. + * + * @return {@code true} if empty entities will be removed + */ + public boolean isRemoveEmptyEntities() { + return this.removeEmptyEntities; + } + + /** + * Sets whether to release locks held by orchestrations that no longer exist. + * + * @param releaseOrphanedLocks {@code true} to release orphaned locks + * @return this {@code CleanEntityStorageRequest} for chaining + */ + public CleanEntityStorageRequest setReleaseOrphanedLocks(boolean releaseOrphanedLocks) { + this.releaseOrphanedLocks = releaseOrphanedLocks; + return this; + } + + /** + * Gets whether orphaned locks should be released. + * + * @return {@code true} if orphaned locks will be released + */ + public boolean isReleaseOrphanedLocks() { + return this.releaseOrphanedLocks; + } + + /** + * Sets whether the client should automatically continue cleaning with continuation tokens + * until all entities have been processed. When {@code true}, the client will loop internally, + * accumulating results across multiple pages. + * + * @param continueUntilComplete {@code true} to automatically continue until complete + * @return this {@code CleanEntityStorageRequest} for chaining + */ + public CleanEntityStorageRequest setContinueUntilComplete(boolean continueUntilComplete) { + this.continueUntilComplete = continueUntilComplete; + return this; + } + + /** + * Gets whether the client should automatically continue cleaning until all entities are processed. + * + * @return {@code true} if the client will automatically continue until complete + */ + public boolean isContinueUntilComplete() { + return this.continueUntilComplete; + } + + private boolean continueUntilComplete; +} diff --git a/client/src/main/java/com/microsoft/durabletask/CleanEntityStorageResult.java b/client/src/main/java/com/microsoft/durabletask/CleanEntityStorageResult.java new file mode 100644 index 00000000..f5ad588d --- /dev/null +++ b/client/src/main/java/com/microsoft/durabletask/CleanEntityStorageResult.java @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.microsoft.durabletask; + +import javax.annotation.Nullable; + +/** + * Represents the result of a {@link DurableEntityClient#cleanEntityStorage(CleanEntityStorageRequest)} operation. + */ +public final class CleanEntityStorageResult { + private final String continuationToken; + private final int emptyEntitiesRemoved; + private final int orphanedLocksReleased; + + /** + * Creates a new {@code CleanEntityStorageResult}. + * + * @param continuationToken the continuation token for resuming the clean operation, or {@code null} if complete + * @param emptyEntitiesRemoved the number of empty entities removed in this batch + * @param orphanedLocksReleased the number of orphaned locks released in this batch + */ + CleanEntityStorageResult( + @Nullable String continuationToken, + int emptyEntitiesRemoved, + int orphanedLocksReleased) { + this.continuationToken = continuationToken; + this.emptyEntitiesRemoved = emptyEntitiesRemoved; + this.orphanedLocksReleased = orphanedLocksReleased; + } + + /** + * Gets the continuation token for resuming the clean operation. + * If {@code null}, the clean operation has processed all entities. + * + * @return the continuation token, or {@code null} if the operation is complete + */ + @Nullable + public String getContinuationToken() { + return this.continuationToken; + } + + /** + * Gets the number of empty entities that were removed in this batch. + * + * @return the count of empty entities removed + */ + public int getEmptyEntitiesRemoved() { + return this.emptyEntitiesRemoved; + } + + /** + * Gets the number of orphaned locks that were released in this batch. + * + * @return the count of orphaned locks released + */ + public int getOrphanedLocksReleased() { + return this.orphanedLocksReleased; + } +} diff --git a/client/src/main/java/com/microsoft/durabletask/DurableEntityClient.java b/client/src/main/java/com/microsoft/durabletask/DurableEntityClient.java new file mode 100644 index 00000000..45ed5765 --- /dev/null +++ b/client/src/main/java/com/microsoft/durabletask/DurableEntityClient.java @@ -0,0 +1,232 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.microsoft.durabletask; + +import javax.annotation.Nullable; + +/** + * Client for interacting with durable entities. + *

+ * This class provides operations for signaling entities, querying entity metadata, + * and performing entity storage maintenance. Instances are obtained from + * {@link DurableTaskClient#getEntities()}. + *

+ * This design mirrors the .NET SDK's {@code DurableEntityClient} which is accessed + * via the {@code DurableTaskClient.Entities} property. + */ +public abstract class DurableEntityClient { + + private final String name; + + /** + * Creates a new {@code DurableEntityClient} instance. + * + * @param name the name of the client + */ + protected DurableEntityClient(String name) { + this.name = name; + } + + /** + * Gets the name of this client. + * + * @return the client name + */ + public String getName() { + return this.name; + } + + /** + * Sends a signal to a durable entity instance without waiting for a response. + *

+ * If the target entity does not exist, it will be created automatically when it receives the signal. + * + * @param entityId the target entity's instance ID + * @param operationName the name of the operation to invoke on the entity + */ + public void signalEntity(EntityInstanceId entityId, String operationName) { + this.signalEntity(entityId, operationName, null, null); + } + + /** + * Sends a signal with input to a durable entity instance without waiting for a response. + *

+ * If the target entity does not exist, it will be created automatically when it receives the signal. + * + * @param entityId the target entity's instance ID + * @param operationName the name of the operation to invoke on the entity + * @param input the serializable input for the operation, or {@code null} + */ + public void signalEntity(EntityInstanceId entityId, String operationName, @Nullable Object input) { + this.signalEntity(entityId, operationName, input, null); + } + + /** + * Sends a signal with input and options to a durable entity instance without waiting for a response. + *

+ * If the target entity does not exist, it will be created automatically when it receives the signal. + * Use {@link SignalEntityOptions#setScheduledTime(java.time.Instant)} to schedule the signal for + * delivery at a future time. + * + * @param entityId the target entity's instance ID + * @param operationName the name of the operation to invoke on the entity + * @param input the serializable input for the operation, or {@code null} + * @param options additional options for the signal, or {@code null} + */ + public abstract void signalEntity( + EntityInstanceId entityId, + String operationName, + @Nullable Object input, + @Nullable SignalEntityOptions options); + + /** + * Fetches the metadata for a durable entity instance, including its state by default. + *

+ * This matches the .NET SDK behavior where {@code includeState} defaults to {@code true}. + * + * @param entityId the entity instance ID to query + * @return the entity metadata, or {@code null} if the entity does not exist + */ + @Nullable + public EntityMetadata getEntityMetadata(EntityInstanceId entityId) { + return this.getEntityMetadata(entityId, true); + } + + /** + * Fetches the metadata for a durable entity instance, optionally including its state. + * + * @param entityId the entity instance ID to query + * @param includeState {@code true} to include the entity's serialized state in the result + * @return the entity metadata, or {@code null} if the entity does not exist + */ + @Nullable + public abstract EntityMetadata getEntityMetadata(EntityInstanceId entityId, boolean includeState); + + /** + * Fetches the metadata for a durable entity instance with typed state access. + *

+ * This always includes state in the result, matching the .NET SDK's + * {@code GetEntityAsync()} pattern. The returned {@link TypedEntityMetadata} provides + * a {@link TypedEntityMetadata#getState()} method for direct typed state access. + * + *

{@code
+     * TypedEntityMetadata metadata = client.getEntities()
+     *     .getEntityMetadata(entityId, Integer.class);
+     * if (metadata != null) {
+     *     Integer state = metadata.getState();
+     *     System.out.println("Counter value: " + state);
+     * }
+     * }
+ * + * @param entityId the entity instance ID to query + * @param stateType the class to deserialize the entity's state into + * @param the entity state type + * @return the typed entity metadata with state, or {@code null} if the entity does not exist + */ + @Nullable + public TypedEntityMetadata getEntityMetadata(EntityInstanceId entityId, Class stateType) { + EntityMetadata metadata = this.getEntityMetadata(entityId, true); + if (metadata == null) { + return null; + } + return new TypedEntityMetadata<>(metadata, stateType); + } + + /** + * Queries the durable store for entity instances matching the specified filter criteria. + * + * @param query the query filter criteria + * @return the query result containing matching entities and an optional continuation token + */ + public abstract EntityQueryResult queryEntities(EntityQuery query); + + /** + * Returns an auto-paginating iterable over entity instances matching the specified filter criteria. + *

+ * This method automatically handles pagination when iterating over results. It fetches pages + * from the store on demand, making it convenient when you want to process all matching entities + * without manually managing continuation tokens. + *

+ * You can iterate over individual items: + *

{@code
+     * for (EntityMetadata entity : client.getEntities().getAllEntities(query)) {
+     *     System.out.println(entity.getEntityInstanceId());
+     * }
+     * }
+ *

+ * Or iterate page by page for more control: + *

{@code
+     * for (EntityQueryResult page : client.getEntities().getAllEntities(query).byPage()) {
+     *     for (EntityMetadata entity : page.getEntities()) {
+     *         System.out.println(entity.getEntityInstanceId());
+     *     }
+     * }
+     * }
+ * + * @param query the query filter criteria + * @return a pageable iterable over all matching entities + */ + public EntityQueryPageable getAllEntities(EntityQuery query) { + return new EntityQueryPageable(query, this::queryEntities); + } + + /** + * Returns an auto-paginating iterable over all entity instances. + *

+ * This is a convenience overload equivalent to {@code getAllEntities(new EntityQuery())}. + * + * @return a pageable iterable over all entities + */ + public EntityQueryPageable getAllEntities() { + return getAllEntities(new EntityQuery()); + } + + /** + * Returns an auto-paginating iterable over entity instances matching the specified filter criteria, + * with typed state access. + *

+ * This mirrors the .NET SDK's {@code GetAllEntitiesAsync()} pattern. Entity state is always + * included in the results and eagerly deserialized into the specified type. Each item is a + * {@link TypedEntityMetadata} with a {@link TypedEntityMetadata#getState()} accessor. + *

+ * Note: A copy of the query is made with {@code includeState} set to {@code true} so the + * original query is not modified. + * + *

{@code
+     * EntityQuery query = new EntityQuery().setInstanceIdStartsWith("counter");
+     * for (TypedEntityMetadata entity : client.getEntities().getAllEntities(query, Integer.class)) {
+     *     Integer state = entity.getState();
+     *     System.out.println("Counter value: " + state);
+     * }
+     * }
+ * + * @param query the query filter criteria + * @param stateType the class to deserialize each entity's state into + * @param the entity state type + * @return a pageable iterable over all matching entities with typed state + */ + public TypedEntityQueryPageable getAllEntities(EntityQuery query, Class stateType) { + // Create a copy with includeState=true so we don't mutate the caller's query + EntityQuery typedQuery = new EntityQuery() + .setInstanceIdStartsWith(query.getInstanceIdStartsWith()) + .setLastModifiedFrom(query.getLastModifiedFrom()) + .setLastModifiedTo(query.getLastModifiedTo()) + .setIncludeState(true) + .setIncludeTransient(query.isIncludeTransient()) + .setPageSize(query.getPageSize()) + .setContinuationToken(query.getContinuationToken()); + EntityQueryPageable inner = new EntityQueryPageable(typedQuery, this::queryEntities); + return new TypedEntityQueryPageable<>(inner, stateType); + } + + /** + * Cleans up entity storage by removing empty entities and/or releasing orphaned locks. + *

+ * This is an administrative operation that can be used to reclaim storage space and fix + * entity state inconsistencies. + * + * @param request the clean storage request specifying what to clean + * @return the result of the clean operation, including counts of removed entities and released locks + */ + public abstract CleanEntityStorageResult cleanEntityStorage(CleanEntityStorageRequest request); +} diff --git a/client/src/main/java/com/microsoft/durabletask/DurableTaskClient.java b/client/src/main/java/com/microsoft/durabletask/DurableTaskClient.java index 1e1b3cb0..adb33673 100644 --- a/client/src/main/java/com/microsoft/durabletask/DurableTaskClient.java +++ b/client/src/main/java/com/microsoft/durabletask/DurableTaskClient.java @@ -356,4 +356,142 @@ public void resumeInstance(String instanceId) { * @param reason the reason for resuming the orchestration instance */ public abstract void resumeInstance(String instanceId, @Nullable String reason); + + // region Entity APIs + + /** + * Gets the entity client for interacting with durable entities. + *

+ * This mirrors the .NET SDK's {@code DurableTaskClient.Entities} property, providing a + * dedicated client for entity operations such as signaling, querying, and storage management. + * + * @return the {@link DurableEntityClient} for this client + * @throws UnsupportedOperationException if the current client implementation does not support entities + */ + public DurableEntityClient getEntities() { + throw new UnsupportedOperationException("Entity operations are not supported by this client implementation."); + } + + /** + * Sends a signal to a durable entity instance without waiting for a response. + *

+ * If the target entity does not exist, it will be created automatically when it receives the signal. + * + * @param entityId the target entity's instance ID + * @param operationName the name of the operation to invoke on the entity + * @deprecated Use {@code getEntities().signalEntity(entityId, operationName)} instead. + */ + @Deprecated + public void signalEntity(EntityInstanceId entityId, String operationName) { + this.getEntities().signalEntity(entityId, operationName); + } + + /** + * Sends a signal with input to a durable entity instance without waiting for a response. + *

+ * If the target entity does not exist, it will be created automatically when it receives the signal. + * + * @param entityId the target entity's instance ID + * @param operationName the name of the operation to invoke on the entity + * @param input the serializable input for the operation, or {@code null} + * @deprecated Use {@code getEntities().signalEntity(entityId, operationName, input)} instead. + */ + @Deprecated + public void signalEntity(EntityInstanceId entityId, String operationName, @Nullable Object input) { + this.getEntities().signalEntity(entityId, operationName, input); + } + + /** + * Sends a signal with input and options to a durable entity instance without waiting for a response. + * + * @param entityId the target entity's instance ID + * @param operationName the name of the operation to invoke on the entity + * @param input the serializable input for the operation, or {@code null} + * @param options additional options for the signal, or {@code null} + * @deprecated Use {@code getEntities().signalEntity(entityId, operationName, input, options)} instead. + */ + @Deprecated + public void signalEntity( + EntityInstanceId entityId, + String operationName, + @Nullable Object input, + @Nullable SignalEntityOptions options) { + this.getEntities().signalEntity(entityId, operationName, input, options); + } + + /** + * Fetches the metadata for a durable entity instance, excluding its state. + * + * @param entityId the entity instance ID to query + * @return the entity metadata, or {@code null} if the entity does not exist + * @deprecated Use {@code getEntities().getEntityMetadata(entityId)} instead. + */ + @Deprecated + @Nullable + public EntityMetadata getEntityMetadata(EntityInstanceId entityId) { + return this.getEntities().getEntityMetadata(entityId); + } + + /** + * Fetches the metadata for a durable entity instance, optionally including its state. + * + * @param entityId the entity instance ID to query + * @param includeState {@code true} to include the entity's serialized state in the result + * @return the entity metadata, or {@code null} if the entity does not exist + * @deprecated Use {@code getEntities().getEntityMetadata(entityId, includeState)} instead. + */ + @Deprecated + @Nullable + public EntityMetadata getEntityMetadata(EntityInstanceId entityId, boolean includeState) { + return this.getEntities().getEntityMetadata(entityId, includeState); + } + + /** + * Queries the durable store for entity instances matching the specified filter criteria. + * + * @param query the query filter criteria + * @return the query result containing matching entities and an optional continuation token + * @deprecated Use {@code getEntities().queryEntities(query)} instead. + */ + @Deprecated + public EntityQueryResult queryEntities(EntityQuery query) { + return this.getEntities().queryEntities(query); + } + + /** + * Returns an auto-paginating iterable over entity instances matching the specified filter criteria. + * + * @param query the query filter criteria + * @return a pageable iterable over all matching entities + * @deprecated Use {@code getEntities().getAllEntities(query)} instead. + */ + @Deprecated + public EntityQueryPageable getAllEntities(EntityQuery query) { + return this.getEntities().getAllEntities(query); + } + + /** + * Returns an auto-paginating iterable over all entity instances. + * + * @return a pageable iterable over all entities + * @deprecated Use {@code getEntities().getAllEntities()} instead. + */ + @Deprecated + public EntityQueryPageable getAllEntities() { + return this.getEntities().getAllEntities(); + } + + /** + * Cleans up entity storage by removing empty entities and/or releasing orphaned locks. + * + * @param request the clean storage request specifying what to clean + * @return the result of the clean operation, including counts of removed entities and released locks + * @deprecated Use {@code getEntities().cleanEntityStorage(request)} instead. + */ + @Deprecated + public CleanEntityStorageResult cleanEntityStorage(CleanEntityStorageRequest request) { + return this.getEntities().cleanEntityStorage(request); + } + + // endregion } \ No newline at end of file diff --git a/client/src/main/java/com/microsoft/durabletask/DurableTaskGrpcClient.java b/client/src/main/java/com/microsoft/durabletask/DurableTaskGrpcClient.java index 897c93f0..0801c0ac 100644 --- a/client/src/main/java/com/microsoft/durabletask/DurableTaskGrpcClient.java +++ b/client/src/main/java/com/microsoft/durabletask/DurableTaskGrpcClient.java @@ -33,6 +33,7 @@ public final class DurableTaskGrpcClient extends DurableTaskClient { private final ManagedChannel managedSidecarChannel; private final TaskHubSidecarServiceBlockingStub sidecarClient; private final String defaultVersion; + private final GrpcDurableEntityClient entityClient; DurableTaskGrpcClient(DurableTaskGrpcClientBuilder builder) { this.dataConverter = builder.dataConverter != null ? builder.dataConverter : new JacksonDataConverter(); @@ -59,6 +60,7 @@ public final class DurableTaskGrpcClient extends DurableTaskClient { } this.sidecarClient = TaskHubSidecarServiceGrpc.newBlockingStub(sidecarGrpcChannel); + this.entityClient = new GrpcDurableEntityClient("GrpcDurableEntityClient", this.sidecarClient, this.dataConverter); } DurableTaskGrpcClient(int port, String defaultVersion) { @@ -71,6 +73,7 @@ public final class DurableTaskGrpcClient extends DurableTaskClient { .usePlaintext() .build(); this.sidecarClient = TaskHubSidecarServiceGrpc.newBlockingStub(this.managedSidecarChannel); + this.entityClient = new GrpcDurableEntityClient("GrpcDurableEntityClient", this.sidecarClient, this.dataConverter); } /** @@ -407,4 +410,13 @@ public String restartInstance(String instanceId, boolean restartWithNewInstanceI private PurgeResult toPurgeResult(PurgeInstancesResponse response){ return new PurgeResult(response.getDeletedInstanceCount()); } + + // region Entity APIs + + @Override + public DurableEntityClient getEntities() { + return this.entityClient; + } + + // endregion } diff --git a/client/src/main/java/com/microsoft/durabletask/DurableTaskGrpcWorker.java b/client/src/main/java/com/microsoft/durabletask/DurableTaskGrpcWorker.java index 552cf579..9456227a 100644 --- a/client/src/main/java/com/microsoft/durabletask/DurableTaskGrpcWorker.java +++ b/client/src/main/java/com/microsoft/durabletask/DurableTaskGrpcWorker.java @@ -14,6 +14,8 @@ import java.time.Duration; import java.util.*; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.logging.Level; import java.util.logging.Logger; @@ -29,17 +31,22 @@ public final class DurableTaskGrpcWorker implements AutoCloseable { private final HashMap orchestrationFactories = new HashMap<>(); private final HashMap activityFactories = new HashMap<>(); + private final HashMap entityFactories = new HashMap<>(); private final ManagedChannel managedSidecarChannel; private final DataConverter dataConverter; private final Duration maximumTimerInterval; private final DurableTaskGrpcWorkerVersioningOptions versioningOptions; + private final int maxConcurrentEntityWorkItems; + private final ExecutorService workItemExecutor; private final TaskHubSidecarServiceBlockingStub sidecarClient; DurableTaskGrpcWorker(DurableTaskGrpcWorkerBuilder builder) { this.orchestrationFactories.putAll(builder.orchestrationFactories); this.activityFactories.putAll(builder.activityFactories); + this.entityFactories.putAll(builder.entityFactories); + this.maxConcurrentEntityWorkItems = builder.maxConcurrentEntityWorkItems; Channel sidecarGrpcChannel; if (builder.channel != null) { @@ -65,6 +72,11 @@ public final class DurableTaskGrpcWorker implements AutoCloseable { this.dataConverter = builder.dataConverter != null ? builder.dataConverter : new JacksonDataConverter(); this.maximumTimerInterval = builder.maximumTimerInterval != null ? builder.maximumTimerInterval : DEFAULT_MAXIMUM_TIMER_INTERVAL; this.versioningOptions = builder.versioningOptions; + this.workItemExecutor = Executors.newCachedThreadPool(r -> { + Thread t = new Thread(r, "durabletask-worker"); + t.setDaemon(true); + return t; + }); } /** @@ -85,6 +97,15 @@ public void start() { * configured. */ public void close() { + this.workItemExecutor.shutdown(); + try { + if (!this.workItemExecutor.awaitTermination(10, TimeUnit.SECONDS)) { + this.workItemExecutor.shutdownNow(); + } + } catch (InterruptedException e) { + this.workItemExecutor.shutdownNow(); + } + if (this.managedSidecarChannel != null) { try { this.managedSidecarChannel.shutdownNow().awaitTermination(5, TimeUnit.SECONDS); @@ -123,11 +144,20 @@ public void startAndBlock() { this.activityFactories, this.dataConverter, logger); + TaskEntityExecutor taskEntityExecutor = new TaskEntityExecutor( + this.entityFactories, + this.dataConverter, + logger); // TODO: How do we interrupt manually? while (true) { try { - GetWorkItemsRequest getWorkItemsRequest = GetWorkItemsRequest.newBuilder().build(); + GetWorkItemsRequest.Builder requestBuilder = GetWorkItemsRequest.newBuilder(); + if (!this.entityFactories.isEmpty()) { + // Signal to the sidecar that this worker can handle entity work items + requestBuilder.setMaxConcurrentEntityWorkItems(this.maxConcurrentEntityWorkItems); + } + GetWorkItemsRequest getWorkItemsRequest = requestBuilder.build(); Iterator workItemStream = this.sidecarClient.getWorkItems(getWorkItemsRequest); while (workItemStream.hasNext()) { WorkItem workItem = workItemStream.next(); @@ -219,37 +249,137 @@ public void startAndBlock() { } else if (requestType == RequestCase.ACTIVITYREQUEST) { ActivityRequest activityRequest = workItem.getActivityRequest(); - // TODO: Run this on a worker pool thread: https://www.baeldung.com/thread-pool-java-and-guava - String output = null; - TaskFailureDetails failureDetails = null; - try { - output = taskActivityExecutor.execute( - activityRequest.getName(), - activityRequest.getInput().getValue(), - activityRequest.getTaskId()); - } catch (Throwable e) { - failureDetails = TaskFailureDetails.newBuilder() - .setErrorType(e.getClass().getName()) - .setErrorMessage(e.getMessage()) - .setStackTrace(StringValue.of(FailureDetails.getFullStackTrace(e))) - .build(); - } + this.workItemExecutor.submit(() -> { + String output = null; + TaskFailureDetails failureDetails = null; + try { + output = taskActivityExecutor.execute( + activityRequest.getName(), + activityRequest.getInput().getValue(), + activityRequest.getTaskId()); + } catch (Throwable e) { + failureDetails = TaskFailureDetails.newBuilder() + .setErrorType(e.getClass().getName()) + .setErrorMessage(e.getMessage()) + .setStackTrace(StringValue.of(FailureDetails.getFullStackTrace(e))) + .build(); + } - ActivityResponse.Builder responseBuilder = ActivityResponse.newBuilder() - .setInstanceId(activityRequest.getOrchestrationInstance().getInstanceId()) - .setTaskId(activityRequest.getTaskId()) - .setCompletionToken(workItem.getCompletionToken()); + ActivityResponse.Builder responseBuilder = ActivityResponse.newBuilder() + .setInstanceId(activityRequest.getOrchestrationInstance().getInstanceId()) + .setTaskId(activityRequest.getTaskId()) + .setCompletionToken(workItem.getCompletionToken()); - if (output != null) { - responseBuilder.setResult(StringValue.of(output)); - } + if (output != null) { + responseBuilder.setResult(StringValue.of(output)); + } - if (failureDetails != null) { - responseBuilder.setFailureDetails(failureDetails); - } + if (failureDetails != null) { + responseBuilder.setFailureDetails(failureDetails); + } - this.sidecarClient.completeActivityTask(responseBuilder.build()); - } + this.sidecarClient.completeActivityTask(responseBuilder.build()); + }); + } else if (requestType == RequestCase.ENTITYREQUEST) { + EntityBatchRequest entityRequest = workItem.getEntityRequest(); + this.workItemExecutor.submit(() -> { + try { + EntityBatchResult result = taskEntityExecutor.execute(entityRequest); + EntityBatchResult responseWithToken = result.toBuilder() + .setCompletionToken(workItem.getCompletionToken()) + .build(); + this.sidecarClient.completeEntityTask(responseWithToken); + } catch (Exception e) { + logger.log(Level.WARNING, + String.format("Failed to execute entity batch for '%s'. Abandoning work item.", + entityRequest.getInstanceId()), + e); + this.sidecarClient.abandonTaskEntityWorkItem( + AbandonEntityTaskRequest.newBuilder() + .setCompletionToken(workItem.getCompletionToken()) + .build()); + } + }); + } else if (requestType == RequestCase.ENTITYREQUESTV2) { + EntityRequest entityRequestV2 = workItem.getEntityRequestV2(); + this.workItemExecutor.submit(() -> { + try { + // Convert V2 (history-based) format to V1 (flat) format + EntityBatchRequest.Builder batchBuilder = EntityBatchRequest.newBuilder() + .setInstanceId(entityRequestV2.getInstanceId()); + if (entityRequestV2.hasEntityState()) { + batchBuilder.setEntityState(entityRequestV2.getEntityState()); + } + + List operationInfos = new ArrayList<>(); + for (HistoryEvent event : entityRequestV2.getOperationRequestsList()) { + if (event.hasEntityOperationSignaled()) { + EntityOperationSignaledEvent signaled = event.getEntityOperationSignaled(); + OperationRequest.Builder opBuilder = OperationRequest.newBuilder() + .setRequestId(signaled.getRequestId()) + .setOperation(signaled.getOperation()); + if (signaled.hasInput()) { + opBuilder.setInput(signaled.getInput()); + } + batchBuilder.addOperations(opBuilder.build()); + // Fire-and-forget: no response destination + operationInfos.add(OperationInfo.newBuilder() + .setRequestId(signaled.getRequestId()) + .build()); + } else if (event.hasEntityOperationCalled()) { + EntityOperationCalledEvent called = event.getEntityOperationCalled(); + OperationRequest.Builder opBuilder = OperationRequest.newBuilder() + .setRequestId(called.getRequestId()) + .setOperation(called.getOperation()); + if (called.hasInput()) { + opBuilder.setInput(called.getInput()); + } + batchBuilder.addOperations(opBuilder.build()); + // Two-way call: include response destination + OperationInfo.Builder infoBuilder = OperationInfo.newBuilder() + .setRequestId(called.getRequestId()); + if (called.hasParentInstanceId()) { + OrchestrationInstance.Builder destBuilder = OrchestrationInstance.newBuilder() + .setInstanceId(called.getParentInstanceId().getValue()); + if (called.hasParentExecutionId()) { + destBuilder.setExecutionId(StringValue.of(called.getParentExecutionId().getValue())); + } + infoBuilder.setResponseDestination(destBuilder.build()); + } + operationInfos.add(infoBuilder.build()); + } else { + logger.log(Level.WARNING, + "Skipping unsupported history event type in ENTITYREQUESTV2: {0}", + event.getEventTypeCase()); + } + } + + EntityBatchRequest batchRequest = batchBuilder.build(); + EntityBatchResult result = taskEntityExecutor.execute(batchRequest); + + // Attach completion token and operation infos for response routing + EntityBatchResult.Builder responseBuilder = result.toBuilder() + .setCompletionToken(workItem.getCompletionToken()); + // Trim operationInfos to match actual result count + int resultCount = result.getResultsCount(); + if (operationInfos.size() > resultCount) { + responseBuilder.addAllOperationInfos(operationInfos.subList(0, resultCount)); + } else { + responseBuilder.addAllOperationInfos(operationInfos); + } + this.sidecarClient.completeEntityTask(responseBuilder.build()); + } catch (Exception e) { + logger.log(Level.WARNING, + String.format("Failed to execute V2 entity batch for '%s'. Abandoning work item.", + entityRequestV2.getInstanceId()), + e); + this.sidecarClient.abandonTaskEntityWorkItem( + AbandonEntityTaskRequest.newBuilder() + .setCompletionToken(workItem.getCompletionToken()) + .build()); + } + }); + } else if (requestType == RequestCase.HEALTHPING) { // No-op diff --git a/client/src/main/java/com/microsoft/durabletask/DurableTaskGrpcWorkerBuilder.java b/client/src/main/java/com/microsoft/durabletask/DurableTaskGrpcWorkerBuilder.java index ec39fee2..fc6d879a 100644 --- a/client/src/main/java/com/microsoft/durabletask/DurableTaskGrpcWorkerBuilder.java +++ b/client/src/main/java/com/microsoft/durabletask/DurableTaskGrpcWorkerBuilder.java @@ -4,8 +4,10 @@ import io.grpc.Channel; +import java.lang.reflect.InvocationTargetException; import java.time.Duration; import java.util.HashMap; +import java.util.Locale; /** * Builder object for constructing customized {@link DurableTaskGrpcWorker} instances. @@ -13,11 +15,13 @@ public final class DurableTaskGrpcWorkerBuilder { final HashMap orchestrationFactories = new HashMap<>(); final HashMap activityFactories = new HashMap<>(); + final HashMap entityFactories = new HashMap<>(); int port; Channel channel; DataConverter dataConverter; Duration maximumTimerInterval; DurableTaskGrpcWorkerVersioningOptions versioningOptions; + int maxConcurrentEntityWorkItems = 1; /** * Adds an orchestration factory to be used by the constructed {@link DurableTaskGrpcWorker}. @@ -62,6 +66,115 @@ public DurableTaskGrpcWorkerBuilder addActivity(TaskActivityFactory factory) { return this; } + /** + * Adds an entity factory to be used by the constructed {@link DurableTaskGrpcWorker}. + * + * @param name the name of the entity type + * @param factory the factory that creates instances of the entity + * @return this builder object + */ + public DurableTaskGrpcWorkerBuilder addEntity(String name, TaskEntityFactory factory) { + if (name == null || name.isEmpty()) { + throw new IllegalArgumentException("A non-empty entity name is required."); + } + if (factory == null) { + throw new IllegalArgumentException("An entity factory is required."); + } + + String key = name.toLowerCase(Locale.ROOT); + if (this.entityFactories.containsKey(key)) { + throw new IllegalArgumentException( + String.format("An entity factory named %s is already registered.", name)); + } + + this.entityFactories.put(key, factory); + return this; + } + + /** + * Registers an entity type for the constructed {@link DurableTaskGrpcWorker}. + *

+ * The entity class must implement {@link ITaskEntity} and have a public no-argument constructor. + * A new instance of the entity is created for each operation batch using reflection. + *

+ * The entity name is derived from the simple class name of the provided type. + * + * @param entityClass the entity class to register; must implement {@link ITaskEntity} + * @return this builder object + * @throws IllegalArgumentException if the class does not implement {@link ITaskEntity} + */ + public DurableTaskGrpcWorkerBuilder addEntity(Class entityClass) { + if (entityClass == null) { + throw new IllegalArgumentException("entityClass must not be null."); + } + String name = entityClass.getSimpleName(); + return this.addEntity(name, entityClass); + } + + /** + * Registers an entity type with a specific name for the constructed {@link DurableTaskGrpcWorker}. + *

+ * The entity class must implement {@link ITaskEntity} and have a public no-argument constructor. + * A new instance of the entity is created for each operation batch using reflection. + * + * @param name the name of the entity type + * @param entityClass the entity class to register; must implement {@link ITaskEntity} + * @return this builder object + * @throws IllegalArgumentException if the class does not implement {@link ITaskEntity} + */ + public DurableTaskGrpcWorkerBuilder addEntity(String name, Class entityClass) { + if (entityClass == null) { + throw new IllegalArgumentException("entityClass must not be null."); + } + if (!ITaskEntity.class.isAssignableFrom(entityClass)) { + throw new IllegalArgumentException( + String.format("Type %s does not implement ITaskEntity.", entityClass.getName())); + } + return this.addEntity(name, () -> { + try { + return entityClass.getDeclaredConstructor().newInstance(); + } catch (InstantiationException | IllegalAccessException | InvocationTargetException | NoSuchMethodException e) { + throw new RuntimeException( + String.format("Failed to create instance of entity type %s. Ensure it has a public no-argument constructor.", entityClass.getName()), e); + } + }); + } + + /** + * Registers an entity singleton for the constructed {@link DurableTaskGrpcWorker}. + *

+ * The same entity instance is reused for every operation batch. This is useful for stateless entities + * or entities that manage their own lifecycle. + *

+ * The entity name is derived from the simple class name of the provided entity instance. + * + * @param entity the entity instance to register + * @return this builder object + */ + public DurableTaskGrpcWorkerBuilder addEntity(ITaskEntity entity) { + if (entity == null) { + throw new IllegalArgumentException("entity must not be null."); + } + String name = entity.getClass().getSimpleName(); + return this.addEntity(name, () -> entity); + } + + /** + * Registers an entity singleton with a specific name for the constructed {@link DurableTaskGrpcWorker}. + *

+ * The same entity instance is reused for every operation batch. + * + * @param name the name of the entity type + * @param entity the entity instance to register + * @return this builder object + */ + public DurableTaskGrpcWorkerBuilder addEntity(String name, ITaskEntity entity) { + if (entity == null) { + throw new IllegalArgumentException("entity must not be null."); + } + return this.addEntity(name, () -> entity); + } + /** * Sets the gRPC channel to use for communicating with the sidecar process. *

@@ -114,6 +227,24 @@ public DurableTaskGrpcWorkerBuilder maximumTimerInterval(Duration maximumTimerIn return this; } + /** + * Sets the maximum number of entity work items that can be processed concurrently by this worker. + *

+ * Each entity instance is always single-threaded (serial execution), but this setting controls + * how many different entity instances can process work items in parallel. The default value is 1. + * + * @param maxConcurrentEntityWorkItems the maximum number of concurrent entity work items (must be at least 1) + * @return this builder object + * @throws IllegalArgumentException if the value is less than 1 + */ + public DurableTaskGrpcWorkerBuilder maxConcurrentEntityWorkItems(int maxConcurrentEntityWorkItems) { + if (maxConcurrentEntityWorkItems < 1) { + throw new IllegalArgumentException("maxConcurrentEntityWorkItems must be at least 1."); + } + this.maxConcurrentEntityWorkItems = maxConcurrentEntityWorkItems; + return this; + } + /** * Sets the versioning options for this worker. * diff --git a/client/src/main/java/com/microsoft/durabletask/EntityInstanceId.java b/client/src/main/java/com/microsoft/durabletask/EntityInstanceId.java new file mode 100644 index 00000000..ad8153d8 --- /dev/null +++ b/client/src/main/java/com/microsoft/durabletask/EntityInstanceId.java @@ -0,0 +1,119 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.microsoft.durabletask; + +import javax.annotation.Nonnull; +import java.util.Locale; +import java.util.Objects; + +/** + * Immutable identifier for a durable entity instance, consisting of a name and key pair. + *

+ * The name typically corresponds to the entity class/type name, and the key identifies the specific + * entity instance (e.g., a user ID or account number). + */ +public final class EntityInstanceId implements Comparable { + private final String name; + private final String key; + + /** + * Creates a new {@code EntityInstanceId} with the specified name and key. + * + * @param name the entity name (type), must not be null + * @param key the entity key (instance identifier), must not be null + * @throws IllegalArgumentException if name or key is null or empty + */ + public EntityInstanceId(@Nonnull String name, @Nonnull String key) { + if (name == null || name.isEmpty()) { + throw new IllegalArgumentException("Entity name must not be null or empty."); + } + if (name.contains("@")) { + throw new IllegalArgumentException("Entity name must not contain '@'."); + } + if (key == null || key.isEmpty()) { + throw new IllegalArgumentException("Entity key must not be null or empty."); + } + this.name = name.toLowerCase(Locale.ROOT); + this.key = key; + } + + /** + * Gets the entity name (type). + * + * @return the entity name + */ + @Nonnull + public String getName() { + return this.name; + } + + /** + * Gets the entity key (instance identifier). + * + * @return the entity key + */ + @Nonnull + public String getKey() { + return this.key; + } + + /** + * Parses an {@code EntityInstanceId} from its string representation. + *

+ * The expected format is {@code @{name}@{key}}, matching the .NET {@code EntityId.ToString()} format. + * + * @param value the string to parse + * @return the parsed {@code EntityInstanceId} + * @throws IllegalArgumentException if the string is not in the expected format + */ + @Nonnull + public static EntityInstanceId fromString(@Nonnull String value) { + if (value == null || value.isEmpty()) { + throw new IllegalArgumentException("Value must not be null or empty."); + } + if (!value.startsWith("@")) { + throw new IllegalArgumentException( + "Invalid EntityInstanceId format. Expected '@{name}@{key}', got: " + value); + } + int secondAt = value.indexOf('@', 1); + if (secondAt < 0) { + throw new IllegalArgumentException( + "Invalid EntityInstanceId format. Expected '@{name}@{key}', got: " + value); + } + String name = value.substring(1, secondAt); + String key = value.substring(secondAt + 1); + return new EntityInstanceId(name, key); + } + + /** + * Returns the string representation in the format {@code @{name}@{key}}. + * + * @return the string representation of this entity instance ID + */ + @Override + public String toString() { + return "@" + this.name + "@" + this.key; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + EntityInstanceId that = (EntityInstanceId) o; + return this.name.equals(that.name) && this.key.equals(that.key); + } + + @Override + public int hashCode() { + return Objects.hash(this.name, this.key); + } + + @Override + public int compareTo(@Nonnull EntityInstanceId other) { + int nameCompare = this.name.compareTo(other.name); + if (nameCompare != 0) { + return nameCompare; + } + return this.key.compareTo(other.key); + } +} diff --git a/client/src/main/java/com/microsoft/durabletask/EntityMetadata.java b/client/src/main/java/com/microsoft/durabletask/EntityMetadata.java new file mode 100644 index 00000000..d32155a7 --- /dev/null +++ b/client/src/main/java/com/microsoft/durabletask/EntityMetadata.java @@ -0,0 +1,148 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.microsoft.durabletask; + +import javax.annotation.Nullable; +import java.time.Instant; + +/** + * Represents metadata about a durable entity instance, including its identity, state, and lock status. + *

+ * For typed state access, see {@link TypedEntityMetadata} which provides a {@code getState()} method + * that returns the deserialized state as a specific type. + * + * @see TypedEntityMetadata + */ +public class EntityMetadata { + private final String instanceId; + private final Instant lastModifiedTime; + private final int backlogQueueSize; + private final String lockedBy; + private final String serializedState; + private final boolean includesState; + private final DataConverter dataConverter; + + /** + * Creates a new {@code EntityMetadata} instance. + * + * @param instanceId the entity instance ID string (in {@code @name@key} format) + * @param lastModifiedTime the time the entity was last modified + * @param backlogQueueSize the number of operations waiting in the entity's backlog queue + * @param lockedBy the orchestration instance ID that currently holds a lock on this entity, or {@code null} + * @param serializedState the serialized entity state, or {@code null} if state was not fetched + * @param includesState {@code true} if the state was requested and is included in this metadata + * @param dataConverter the data converter used to deserialize state + */ + EntityMetadata( + String instanceId, + Instant lastModifiedTime, + int backlogQueueSize, + @Nullable String lockedBy, + @Nullable String serializedState, + boolean includesState, + DataConverter dataConverter) { + this.instanceId = instanceId; + this.lastModifiedTime = lastModifiedTime; + this.backlogQueueSize = backlogQueueSize; + this.lockedBy = lockedBy; + this.serializedState = serializedState; + this.includesState = includesState; + this.dataConverter = dataConverter; + } + + /** + * Gets the entity instance ID string. + * + * @return the instance ID + */ + public String getInstanceId() { + return this.instanceId; + } + + /** + * Gets the parsed {@link EntityInstanceId} from the instance ID string. + * + * @return the parsed entity instance ID + */ + public EntityInstanceId getEntityInstanceId() { + return EntityInstanceId.fromString(this.instanceId); + } + + /** + * Gets the time the entity was last modified. + * + * @return the last modified time + */ + public Instant getLastModifiedTime() { + return this.lastModifiedTime; + } + + /** + * Gets the number of operations waiting in the entity's backlog queue. + * + * @return the backlog queue size + */ + public int getBacklogQueueSize() { + return this.backlogQueueSize; + } + + /** + * Gets the orchestration instance ID that currently holds a lock on this entity, + * or {@code null} if the entity is not locked. + * + * @return the locking orchestration instance ID, or {@code null} + */ + @Nullable + public String getLockedBy() { + return this.lockedBy; + } + + /** + * Gets the raw serialized entity state, or {@code null} if state was not fetched. + * + * @return the serialized state string, or {@code null} + */ + @Nullable + public String getSerializedState() { + return this.serializedState; + } + + /** + * Gets whether this metadata response includes the entity state. + *

+ * Queries can exclude the state of the entity from the metadata that is retrieved. + * When this returns {@code false}, {@link #getSerializedState()} and {@link #readStateAs(Class)} + * will return {@code null}. + * + * @return {@code true} if state was requested and included in this metadata + */ + public boolean isIncludesState() { + return this.includesState; + } + + /** + * Gets the data converter used for state deserialization. + *

+ * This is package-private to allow {@link TypedEntityMetadata} to pass it to the superclass constructor. + * + * @return the data converter + */ + DataConverter getDataConverter() { + return this.dataConverter; + } + + /** + * Deserializes the entity state into an object of the specified type. + * + * @param stateType the class to deserialize the state into + * @param the target type + * @return the deserialized state, or {@code null} if no state is available + */ + @Nullable + public T readStateAs(Class stateType) { + if (this.serializedState == null) { + return null; + } + return this.dataConverter.deserialize(this.serializedState, stateType); + } +} diff --git a/client/src/main/java/com/microsoft/durabletask/EntityOperationFailedException.java b/client/src/main/java/com/microsoft/durabletask/EntityOperationFailedException.java new file mode 100644 index 00000000..d34071af --- /dev/null +++ b/client/src/main/java/com/microsoft/durabletask/EntityOperationFailedException.java @@ -0,0 +1,70 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.microsoft.durabletask; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +/** + * Exception thrown when a two-way entity call fails because the entity operation threw an exception. + *

+ * This is analogous to {@link TaskFailedException} but specific to entity operations invoked + * via {@code callEntity}. The {@link #getFailureDetails()} method provides detailed information + * about the failure. + */ +public class EntityOperationFailedException extends RuntimeException { + private final EntityInstanceId entityId; + private final String operationName; + private final FailureDetails failureDetails; + + /** + * Creates a new {@code EntityOperationFailedException}. + * + * @param entityId the ID of the entity that failed + * @param operationName the name of the operation that failed + * @param failureDetails the details of the failure + */ + public EntityOperationFailedException( + @Nonnull EntityInstanceId entityId, + @Nonnull String operationName, + @Nonnull FailureDetails failureDetails) { + super(String.format( + "Entity operation '%s' on entity '%s' failed: %s", + operationName, + entityId.toString(), + failureDetails.getErrorMessage())); + this.entityId = entityId; + this.operationName = operationName; + this.failureDetails = failureDetails; + } + + /** + * Gets the ID of the entity that failed. + * + * @return the entity instance ID + */ + @Nonnull + public EntityInstanceId getEntityId() { + return this.entityId; + } + + /** + * Gets the name of the operation that failed. + * + * @return the operation name + */ + @Nonnull + public String getOperationName() { + return this.operationName; + } + + /** + * Gets the failure details, including the error type, message, and stack trace. + * + * @return the failure details + */ + @Nonnull + public FailureDetails getFailureDetails() { + return this.failureDetails; + } +} diff --git a/client/src/main/java/com/microsoft/durabletask/EntityQuery.java b/client/src/main/java/com/microsoft/durabletask/EntityQuery.java new file mode 100644 index 00000000..61b7179d --- /dev/null +++ b/client/src/main/java/com/microsoft/durabletask/EntityQuery.java @@ -0,0 +1,201 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.microsoft.durabletask; + +import javax.annotation.Nullable; +import java.time.Instant; +import java.util.Locale; + +/** + * Represents a query filter for fetching durable entity metadata from the store. + *

+ * Use the builder-style setters to configure the query parameters, then pass this object to + * {@link DurableEntityClient#queryEntities(EntityQuery)}. + */ +public final class EntityQuery { + + /** + * The default page size for entity queries ({@value}). + * This matches the .NET SDK's {@code EntityQuery.DefaultPageSize}. + */ + public static final int DEFAULT_PAGE_SIZE = 100; + + private String instanceIdStartsWith; + private Instant lastModifiedFrom; + private Instant lastModifiedTo; + private boolean includeState = true; + private boolean includeTransient; + private Integer pageSize; + private String continuationToken; + + /** + * Creates a new {@code EntityQuery} with default settings. + */ + public EntityQuery() { + } + + /** + * Sets a prefix filter on entity instance IDs. + *

+ * If the value does not start with {@code @}, it is treated as a raw entity name and will be + * normalized to the internal format {@code @entityname} (lowercased) to match how entity + * instance IDs are stored. + * + * @param instanceIdStartsWith the instance ID prefix to filter by, or {@code null} for no filter + * @return this {@code EntityQuery} for chaining + */ + public EntityQuery setInstanceIdStartsWith(@Nullable String instanceIdStartsWith) { + if (instanceIdStartsWith != null && !instanceIdStartsWith.startsWith("@")) { + // Normalize: treat as raw entity name, prepend '@' and lowercase + instanceIdStartsWith = "@" + instanceIdStartsWith.toLowerCase(Locale.ROOT); + } else if (instanceIdStartsWith != null) { + // Already in @name format, but ensure the entity name portion is lowercased + // Format is @entityName or @entityName@key + int secondAt = instanceIdStartsWith.indexOf('@', 1); + if (secondAt > 0) { + // Has key part: lowercase only the name between first and second @ + instanceIdStartsWith = "@" + instanceIdStartsWith.substring(1, secondAt).toLowerCase(Locale.ROOT) + + instanceIdStartsWith.substring(secondAt); + } else { + // Only @name, lowercase the name portion + instanceIdStartsWith = "@" + instanceIdStartsWith.substring(1).toLowerCase(Locale.ROOT); + } + } + this.instanceIdStartsWith = instanceIdStartsWith; + return this; + } + + /** + * Gets the instance ID prefix filter. + * + * @return the instance ID prefix, or {@code null} + */ + @Nullable + public String getInstanceIdStartsWith() { + return this.instanceIdStartsWith; + } + + /** + * Sets the minimum last-modified time filter (inclusive). + * + * @param lastModifiedFrom the minimum last-modified time, or {@code null} for no lower bound + * @return this {@code EntityQuery} for chaining + */ + public EntityQuery setLastModifiedFrom(@Nullable Instant lastModifiedFrom) { + this.lastModifiedFrom = lastModifiedFrom; + return this; + } + + /** + * Gets the minimum last-modified time filter. + * + * @return the minimum last-modified time, or {@code null} + */ + @Nullable + public Instant getLastModifiedFrom() { + return this.lastModifiedFrom; + } + + /** + * Sets the maximum last-modified time filter (inclusive). + * + * @param lastModifiedTo the maximum last-modified time, or {@code null} for no upper bound + * @return this {@code EntityQuery} for chaining + */ + public EntityQuery setLastModifiedTo(@Nullable Instant lastModifiedTo) { + this.lastModifiedTo = lastModifiedTo; + return this; + } + + /** + * Gets the maximum last-modified time filter. + * + * @return the maximum last-modified time, or {@code null} + */ + @Nullable + public Instant getLastModifiedTo() { + return this.lastModifiedTo; + } + + /** + * Sets whether to include entity state in the query results. + * + * @param includeState {@code true} to include state, {@code false} to omit it + * @return this {@code EntityQuery} for chaining + */ + public EntityQuery setIncludeState(boolean includeState) { + this.includeState = includeState; + return this; + } + + /** + * Gets whether entity state is included in query results. + * + * @return {@code true} if state is included + */ + public boolean isIncludeState() { + return this.includeState; + } + + /** + * Sets whether to include transient (not yet persisted) entities in the results. + * + * @param includeTransient {@code true} to include transient entities + * @return this {@code EntityQuery} for chaining + */ + public EntityQuery setIncludeTransient(boolean includeTransient) { + this.includeTransient = includeTransient; + return this; + } + + /** + * Gets whether transient entities are included in query results. + * + * @return {@code true} if transient entities are included + */ + public boolean isIncludeTransient() { + return this.includeTransient; + } + + /** + * Sets the maximum number of results to return per page. + * + * @param pageSize the page size, or {@code null} for the server default + * @return this {@code EntityQuery} for chaining + */ + public EntityQuery setPageSize(@Nullable Integer pageSize) { + this.pageSize = pageSize; + return this; + } + + /** + * Gets the maximum number of results per page. + * + * @return the page size, or {@code null} for the server default + */ + @Nullable + public Integer getPageSize() { + return this.pageSize; + } + + /** + * Sets the continuation token for fetching the next page of results. + * + * @param continuationToken the continuation token from a previous query, or {@code null} to start from the beginning + * @return this {@code EntityQuery} for chaining + */ + public EntityQuery setContinuationToken(@Nullable String continuationToken) { + this.continuationToken = continuationToken; + return this; + } + + /** + * Gets the continuation token for pagination. + * + * @return the continuation token, or {@code null} + */ + @Nullable + public String getContinuationToken() { + return this.continuationToken; + } +} diff --git a/client/src/main/java/com/microsoft/durabletask/EntityQueryPageable.java b/client/src/main/java/com/microsoft/durabletask/EntityQueryPageable.java new file mode 100644 index 00000000..b21e64b7 --- /dev/null +++ b/client/src/main/java/com/microsoft/durabletask/EntityQueryPageable.java @@ -0,0 +1,180 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.microsoft.durabletask; + +import javax.annotation.Nullable; +import java.util.Iterator; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.function.Function; + +/** + * An auto-paginating iterable over entity query results. + *

+ * This class automatically handles pagination when iterating over entity metadata results. + * It fetches pages from the store on demand and yields individual {@link EntityMetadata} + * items to the caller. + *

+ * Use {@link DurableEntityClient#getAllEntities(EntityQuery)} to obtain an instance of this class. + * + *

Example: iterate over all entities

+ *
{@code
+ * EntityQuery query = new EntityQuery()
+ *     .setInstanceIdStartsWith("counter")
+ *     .setIncludeState(true);
+ *
+ * for (EntityMetadata entity : client.getEntities().getAllEntities(query)) {
+ *     System.out.println(entity.getEntityInstanceId());
+ * }
+ * }
+ * + *

Example: iterate page by page

+ *
{@code
+ * for (EntityQueryResult page : client.getEntities().getAllEntities(query).byPage()) {
+ *     System.out.println("Got " + page.getEntities().size() + " entities");
+ *     for (EntityMetadata entity : page.getEntities()) {
+ *         System.out.println(entity.getEntityInstanceId());
+ *     }
+ * }
+ * }
+ */ +public final class EntityQueryPageable implements Iterable { + private final EntityQuery baseQuery; + private final Function queryExecutor; + + /** + * Creates a new {@code EntityQueryPageable}. + * + * @param baseQuery the base query parameters + * @param queryExecutor the function that executes a single page query + */ + EntityQueryPageable(EntityQuery baseQuery, Function queryExecutor) { + this.baseQuery = baseQuery; + this.queryExecutor = queryExecutor; + } + + /** + * Returns an iterator over individual {@link EntityMetadata} items, automatically + * fetching subsequent pages as needed. + * + * @return an iterator over all matching entities + */ + @Override + public Iterator iterator() { + return new EntityItemIterator(); + } + + /** + * Returns an iterable over pages of results, where each page is an {@link EntityQueryResult} + * containing a list of entities and an optional continuation token. + * + * @return an iterable over result pages + */ + public Iterable byPage() { + return PageIterable::new; + } + + private class EntityItemIterator implements Iterator { + private String continuationToken = baseQuery.getContinuationToken(); + private Iterator currentPageIterator; + private boolean finished; + + EntityItemIterator() { + fetchNextPage(); + } + + @Override + public boolean hasNext() { + while (true) { + if (currentPageIterator != null && currentPageIterator.hasNext()) { + return true; + } + if (finished) { + return false; + } + fetchNextPage(); + } + } + + @Override + public EntityMetadata next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + return currentPageIterator.next(); + } + + private void fetchNextPage() { + if (finished) { + return; + } + + EntityQuery pageQuery = cloneQuery(baseQuery); + pageQuery.setContinuationToken(continuationToken); + + EntityQueryResult result = queryExecutor.apply(pageQuery); + List entities = result.getEntities(); + + if (entities == null || entities.isEmpty()) { + finished = true; + currentPageIterator = null; + return; + } + + currentPageIterator = entities.iterator(); + continuationToken = result.getContinuationToken(); + + if (continuationToken == null || continuationToken.isEmpty()) { + finished = true; + } + } + } + + private class PageIterable implements Iterator { + private String continuationToken = baseQuery.getContinuationToken(); + private boolean finished; + private boolean firstPage = true; + + @Override + public boolean hasNext() { + return !finished; + } + + @Override + public EntityQueryResult next() { + if (finished) { + throw new NoSuchElementException(); + } + + EntityQuery pageQuery = cloneQuery(baseQuery); + if (!firstPage) { + pageQuery.setContinuationToken(continuationToken); + } + firstPage = false; + + EntityQueryResult result = queryExecutor.apply(pageQuery); + continuationToken = result.getContinuationToken(); + + if (continuationToken == null || continuationToken.isEmpty()) { + finished = true; + } + + return result; + } + } + + private static EntityQuery cloneQuery(EntityQuery source) { + EntityQuery clone = new EntityQuery(); + if (source.getInstanceIdStartsWith() != null) { + // Use raw setter value since the source is already normalized + clone.setInstanceIdStartsWith(source.getInstanceIdStartsWith()); + } + clone.setLastModifiedFrom(source.getLastModifiedFrom()); + clone.setLastModifiedTo(source.getLastModifiedTo()); + clone.setIncludeState(source.isIncludeState()); + clone.setIncludeTransient(source.isIncludeTransient()); + clone.setPageSize(source.getPageSize()); + clone.setContinuationToken(source.getContinuationToken()); + return clone; + } +} diff --git a/client/src/main/java/com/microsoft/durabletask/EntityQueryResult.java b/client/src/main/java/com/microsoft/durabletask/EntityQueryResult.java new file mode 100644 index 00000000..cac22cc7 --- /dev/null +++ b/client/src/main/java/com/microsoft/durabletask/EntityQueryResult.java @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.microsoft.durabletask; + +import javax.annotation.Nullable; +import java.util.List; + +/** + * Represents the result of an entity query operation, including the matching entities and + * an optional continuation token for pagination. + */ +public final class EntityQueryResult { + private final List entities; + private final String continuationToken; + + /** + * Creates a new {@code EntityQueryResult}. + * + * @param entities the list of entity metadata records matching the query + * @param continuationToken the continuation token for fetching the next page, or {@code null} if no more results + */ + EntityQueryResult(List entities, @Nullable String continuationToken) { + this.entities = entities; + this.continuationToken = continuationToken; + } + + /** + * Gets the list of entity metadata records matching the query. + * + * @return the list of entity metadata + */ + public List getEntities() { + return this.entities; + } + + /** + * Gets the continuation token for fetching the next page of results. + * + * @return the continuation token, or {@code null} if there are no more results + */ + @Nullable + public String getContinuationToken() { + return this.continuationToken; + } +} diff --git a/client/src/main/java/com/microsoft/durabletask/EntityRunner.java b/client/src/main/java/com/microsoft/durabletask/EntityRunner.java new file mode 100644 index 00000000..d28c6952 --- /dev/null +++ b/client/src/main/java/com/microsoft/durabletask/EntityRunner.java @@ -0,0 +1,132 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.microsoft.durabletask; + +import com.google.protobuf.InvalidProtocolBufferException; +import com.microsoft.durabletask.implementation.protobuf.OrchestratorService.EntityBatchRequest; +import com.microsoft.durabletask.implementation.protobuf.OrchestratorService.EntityBatchResult; + +import java.util.Base64; +import java.util.HashMap; +import java.util.logging.Logger; + +/** + * Helper class for invoking entity operations directly, without constructing a {@link DurableTaskGrpcWorker} object. + *

+ * This static class can be used to execute entity logic directly. In order to use it for this purpose, the + * caller must provide entity state as serialized protobuf bytes. This is the entity equivalent of + * {@link OrchestrationRunner}. + *

+ * Typical usage in an Azure Functions entity trigger: + *

+ * {@literal @}FunctionName("Counter")
+ * public String counterEntity(
+ *         {@literal @}DurableEntityTrigger(name = "req") String req) {
+ *     return EntityRunner.loadAndRun(req, () -> new CounterEntity());
+ * }
+ * 
+ */ +public final class EntityRunner { + private static final Logger logger = Logger.getLogger(EntityRunner.class.getPackage().getName()); + + private EntityRunner() { + } + + /** + * Loads an entity batch request from {@code base64EncodedEntityRequest} and uses it to execute + * entity operations using the entity created by {@code entityFactory}. + * + * @param base64EncodedEntityRequest the base64-encoded protobuf payload representing an entity batch request + * @param entityFactory a factory that creates the entity instance to handle operations + * @return a base64-encoded protobuf payload of the entity batch result + * @throws IllegalArgumentException if either parameter is {@code null} or if {@code base64EncodedEntityRequest} + * is not valid base64-encoded protobuf + */ + public static String loadAndRun(String base64EncodedEntityRequest, TaskEntityFactory entityFactory) { + byte[] decodedBytes = Base64.getDecoder().decode(base64EncodedEntityRequest); + byte[] resultBytes = loadAndRun(decodedBytes, entityFactory); + return Base64.getEncoder().encodeToString(resultBytes); + } + + /** + * Loads an entity batch request from {@code entityRequestBytes} and uses it to execute + * entity operations using the entity created by {@code entityFactory}. + * + * @param entityRequestBytes the protobuf payload representing an entity batch request + * @param entityFactory a factory that creates the entity instance to handle operations + * @return a protobuf-encoded payload of the entity batch result + * @throws IllegalArgumentException if either parameter is {@code null} or if {@code entityRequestBytes} + * is not valid protobuf + */ + public static byte[] loadAndRun(byte[] entityRequestBytes, TaskEntityFactory entityFactory) { + if (entityRequestBytes == null || entityRequestBytes.length == 0) { + throw new IllegalArgumentException("entityRequestBytes must not be null or empty"); + } + + if (entityFactory == null) { + throw new IllegalArgumentException("entityFactory must not be null"); + } + + EntityBatchRequest request; + try { + request = EntityBatchRequest.parseFrom(entityRequestBytes); + } catch (InvalidProtocolBufferException e) { + throw new IllegalArgumentException("entityRequestBytes was not valid protobuf", e); + } + + // Parse entity name from the instance ID so the executor can look it up + String instanceId = request.getInstanceId(); + String entityName; + try { + entityName = EntityInstanceId.fromString(instanceId).getName(); + } catch (Exception e) { + // Fallback: use the raw instance ID as the entity name + entityName = instanceId; + } + + HashMap factories = new HashMap<>(); + factories.put(entityName, entityFactory); + + TaskEntityExecutor executor = new TaskEntityExecutor( + factories, + new JacksonDataConverter(), + logger); + + EntityBatchResult result = executor.execute(request); + return result.toByteArray(); + } + + /** + * Loads an entity batch request from {@code base64EncodedEntityRequest} and uses it to execute + * entity operations using the provided {@code entity} instance. + * + * @param base64EncodedEntityRequest the base64-encoded protobuf payload representing an entity batch request + * @param entity the entity instance to handle operations + * @return a base64-encoded protobuf payload of the entity batch result + * @throws IllegalArgumentException if either parameter is {@code null} or if {@code base64EncodedEntityRequest} + * is not valid base64-encoded protobuf + */ + public static String loadAndRun(String base64EncodedEntityRequest, ITaskEntity entity) { + if (entity == null) { + throw new IllegalArgumentException("entity must not be null"); + } + return loadAndRun(base64EncodedEntityRequest, () -> entity); + } + + /** + * Loads an entity batch request from {@code entityRequestBytes} and uses it to execute + * entity operations using the provided {@code entity} instance. + * + * @param entityRequestBytes the protobuf payload representing an entity batch request + * @param entity the entity instance to handle operations + * @return a protobuf-encoded payload of the entity batch result + * @throws IllegalArgumentException if either parameter is {@code null} or if {@code entityRequestBytes} + * is not valid protobuf + */ + public static byte[] loadAndRun(byte[] entityRequestBytes, ITaskEntity entity) { + if (entity == null) { + throw new IllegalArgumentException("entity must not be null"); + } + return loadAndRun(entityRequestBytes, () -> entity); + } +} diff --git a/client/src/main/java/com/microsoft/durabletask/GrpcDurableEntityClient.java b/client/src/main/java/com/microsoft/durabletask/GrpcDurableEntityClient.java new file mode 100644 index 00000000..9e92c602 --- /dev/null +++ b/client/src/main/java/com/microsoft/durabletask/GrpcDurableEntityClient.java @@ -0,0 +1,175 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.microsoft.durabletask; + +import com.google.protobuf.StringValue; +import com.google.protobuf.Timestamp; +import com.microsoft.durabletask.implementation.protobuf.OrchestratorService.*; +import com.microsoft.durabletask.implementation.protobuf.TaskHubSidecarServiceGrpc.TaskHubSidecarServiceBlockingStub; + +import javax.annotation.Nullable; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.UUID; + +/** + * gRPC-based implementation of {@link DurableEntityClient}. + */ +final class GrpcDurableEntityClient extends DurableEntityClient { + + private final TaskHubSidecarServiceBlockingStub sidecarClient; + private final DataConverter dataConverter; + + GrpcDurableEntityClient( + String name, + TaskHubSidecarServiceBlockingStub sidecarClient, + DataConverter dataConverter) { + super(name); + this.sidecarClient = sidecarClient; + this.dataConverter = dataConverter; + } + + @Override + public void signalEntity( + EntityInstanceId entityId, + String operationName, + @Nullable Object input, + @Nullable SignalEntityOptions options) { + Helpers.throwIfArgumentNull(entityId, "entityId"); + Helpers.throwIfArgumentNull(operationName, "operationName"); + + SignalEntityRequest.Builder builder = SignalEntityRequest.newBuilder() + .setInstanceId(entityId.toString()) + .setName(operationName) + .setRequestId(UUID.randomUUID().toString()); + + if (input != null) { + String serializedInput = this.dataConverter.serialize(input); + if (serializedInput != null) { + builder.setInput(StringValue.of(serializedInput)); + } + } + + if (options != null && options.getScheduledTime() != null) { + Timestamp ts = DataConverter.getTimestampFromInstant(options.getScheduledTime()); + builder.setScheduledTime(ts); + } + + this.sidecarClient.signalEntity(builder.build()); + } + + @Override + @Nullable + public EntityMetadata getEntityMetadata(EntityInstanceId entityId, boolean includeState) { + Helpers.throwIfArgumentNull(entityId, "entityId"); + + GetEntityRequest request = GetEntityRequest.newBuilder() + .setInstanceId(entityId.toString()) + .setIncludeState(includeState) + .build(); + + GetEntityResponse response = this.sidecarClient.getEntity(request); + if (!response.getExists()) { + return null; + } + + return toEntityMetadata(response.getEntity()); + } + + @Override + public EntityQueryResult queryEntities(EntityQuery query) { + Helpers.throwIfArgumentNull(query, "query"); + + com.microsoft.durabletask.implementation.protobuf.OrchestratorService.EntityQuery.Builder queryBuilder = + com.microsoft.durabletask.implementation.protobuf.OrchestratorService.EntityQuery.newBuilder(); + + if (query.getInstanceIdStartsWith() != null) { + queryBuilder.setInstanceIdStartsWith(StringValue.of(query.getInstanceIdStartsWith())); + } + if (query.getLastModifiedFrom() != null) { + queryBuilder.setLastModifiedFrom(DataConverter.getTimestampFromInstant(query.getLastModifiedFrom())); + } + if (query.getLastModifiedTo() != null) { + queryBuilder.setLastModifiedTo(DataConverter.getTimestampFromInstant(query.getLastModifiedTo())); + } + queryBuilder.setIncludeState(query.isIncludeState()); + queryBuilder.setIncludeTransient(query.isIncludeTransient()); + if (query.getPageSize() != null) { + queryBuilder.setPageSize(com.google.protobuf.Int32Value.of(query.getPageSize())); + } + if (query.getContinuationToken() != null) { + queryBuilder.setContinuationToken(StringValue.of(query.getContinuationToken())); + } + + QueryEntitiesRequest request = QueryEntitiesRequest.newBuilder() + .setQuery(queryBuilder) + .build(); + + QueryEntitiesResponse response = this.sidecarClient.queryEntities(request); + + List entities = new ArrayList<>(); + for (com.microsoft.durabletask.implementation.protobuf.OrchestratorService.EntityMetadata protoEntity + : response.getEntitiesList()) { + entities.add(toEntityMetadata(protoEntity)); + } + + String continuationToken = response.hasContinuationToken() + ? response.getContinuationToken().getValue() + : null; + + return new EntityQueryResult(entities, continuationToken); + } + + @Override + public CleanEntityStorageResult cleanEntityStorage(CleanEntityStorageRequest request) { + Helpers.throwIfArgumentNull(request, "request"); + + int totalEmptyEntitiesRemoved = 0; + int totalOrphanedLocksReleased = 0; + String continuationToken = request.getContinuationToken(); + + do { + com.microsoft.durabletask.implementation.protobuf.OrchestratorService.CleanEntityStorageRequest.Builder builder = + com.microsoft.durabletask.implementation.protobuf.OrchestratorService.CleanEntityStorageRequest.newBuilder() + .setRemoveEmptyEntities(request.isRemoveEmptyEntities()) + .setReleaseOrphanedLocks(request.isReleaseOrphanedLocks()); + + if (continuationToken != null) { + builder.setContinuationToken(StringValue.of(continuationToken)); + } + + CleanEntityStorageResponse response = this.sidecarClient.cleanEntityStorage(builder.build()); + + totalEmptyEntitiesRemoved += response.getEmptyEntitiesRemoved(); + totalOrphanedLocksReleased += response.getOrphanedLocksReleased(); + + continuationToken = response.hasContinuationToken() + ? response.getContinuationToken().getValue() + : null; + } while (request.isContinueUntilComplete() && continuationToken != null); + + return new CleanEntityStorageResult( + continuationToken, + totalEmptyEntitiesRemoved, + totalOrphanedLocksReleased); + } + + private EntityMetadata toEntityMetadata( + com.microsoft.durabletask.implementation.protobuf.OrchestratorService.EntityMetadata protoEntity) { + Instant lastModifiedTime = DataConverter.getInstantFromTimestamp(protoEntity.getLastModifiedTime()); + String lockedBy = protoEntity.hasLockedBy() ? protoEntity.getLockedBy().getValue() : null; + String serializedState = protoEntity.hasSerializedState() + ? protoEntity.getSerializedState().getValue() + : null; + + return new EntityMetadata( + protoEntity.getInstanceId(), + lastModifiedTime, + protoEntity.getBacklogQueueSize(), + lockedBy, + serializedState, + protoEntity.hasSerializedState(), + this.dataConverter); + } +} diff --git a/client/src/main/java/com/microsoft/durabletask/ITaskEntity.java b/client/src/main/java/com/microsoft/durabletask/ITaskEntity.java new file mode 100644 index 00000000..65d49c54 --- /dev/null +++ b/client/src/main/java/com/microsoft/durabletask/ITaskEntity.java @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.microsoft.durabletask; + +/** + * Common interface for durable entity implementations. + *

+ * Entities are stateful, single-threaded actors identified by a name+key pair + * ({@link EntityInstanceId}). The durable task runtime manages state persistence, + * message routing, and concurrency control. + *

+ * Implement this interface directly for full control over entity behavior, or extend + * {@link TaskEntity} for a higher-level programming model with automatic reflection-based + * operation dispatch. + * + * @see TaskEntity + */ +@FunctionalInterface +public interface ITaskEntity { + /** + * Executes the entity logic for a single operation. + * + * @param operation the operation to execute, including the operation name, input, state, and context + * @return the result of the operation, which will be serialized and returned to the caller + * (for two-way calls). May be {@code null} for void operations. + * @throws Exception if the operation fails + */ + Object runAsync(TaskEntityOperation operation) throws Exception; +} diff --git a/client/src/main/java/com/microsoft/durabletask/SignalEntityOptions.java b/client/src/main/java/com/microsoft/durabletask/SignalEntityOptions.java new file mode 100644 index 00000000..1c3a91ed --- /dev/null +++ b/client/src/main/java/com/microsoft/durabletask/SignalEntityOptions.java @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.microsoft.durabletask; + +import javax.annotation.Nullable; +import java.time.Instant; + +/** + * Options for signaling a durable entity. + */ +public final class SignalEntityOptions { + private Instant scheduledTime; + + /** + * Creates a new {@code SignalEntityOptions} with default settings. + */ + public SignalEntityOptions() { + } + + /** + * Sets the scheduled time for the signal. If set, the signal will be delivered at the specified time + * rather than immediately. + * + * @param scheduledTime the time at which the signal should be delivered + * @return this {@code SignalEntityOptions} object for chaining + */ + public SignalEntityOptions setScheduledTime(@Nullable Instant scheduledTime) { + this.scheduledTime = scheduledTime; + return this; + } + + /** + * Gets the scheduled time for the signal, or {@code null} if the signal should be delivered immediately. + * + * @return the scheduled time, or {@code null} + */ + @Nullable + public Instant getScheduledTime() { + return this.scheduledTime; + } +} diff --git a/client/src/main/java/com/microsoft/durabletask/TaskEntity.java b/client/src/main/java/com/microsoft/durabletask/TaskEntity.java new file mode 100644 index 00000000..271453ab --- /dev/null +++ b/client/src/main/java/com/microsoft/durabletask/TaskEntity.java @@ -0,0 +1,275 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.microsoft.durabletask; + +import javax.annotation.Nullable; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; + +/** + * Base class for durable entities that provides automatic reflection-based operation dispatch. + *

+ * Subclasses define operations as public methods. When an operation is received, {@code TaskEntity} + * resolves it by: + *

    + *
  1. Reflection dispatch on {@code this}: Look for a public method on the entity class + * whose name matches the operation name (case-insensitive).
  2. + *
  3. State dispatch: If no method is found on the entity, look for a matching public + * method on the {@code TState} object.
  4. + *
  5. Implicit delete: If the operation name is "delete" and no explicit method exists, + * delete the entity state.
  6. + *
  7. If none of the above match, throw {@link UnsupportedOperationException}.
  8. + *
+ *

+ * Methods may accept 0, 1, or 2 parameters. Supported parameter types are the operation input type + * and {@link TaskEntityContext}. The method may return the operation result or be {@code void}. + * + *

Example: + *

{@code
+ * public class CounterEntity extends TaskEntity {
+ *     public void add(int amount) { this.state += amount; }
+ *     public void reset() { this.state = 0; }
+ *     public int get() { return this.state; }
+ *
+ *     protected Integer initializeState(TaskEntityOperation operation) {
+ *         return 0;
+ *     }
+ * }
+ * }
+ * + * @param the type of the entity's state + */ +public abstract class TaskEntity implements ITaskEntity { + + /** + * The current state of the entity. Subclasses may read and write this field directly + * in their operation methods. + */ + protected TState state; + + /** + * The current entity context, providing access to entity metadata such as the entity ID + * and the ability to signal other entities. + *

+ * This property is automatically set before each operation dispatch and is available for use + * in operation methods. This mirrors the .NET SDK's {@code TaskEntity.Context} property. + */ + protected TaskEntityContext context; + + /** + * Controls whether operations can be dispatched to methods on the state object. + * When {@code true}, if no matching method is found on the entity class itself, + * the framework will look for a matching method on the state object. + * When {@code false} (the default), only methods on the entity class are considered. + *

+ * This matches the .NET SDK default where {@code AllowStateDispatch} is {@code false}. + */ + private boolean allowStateDispatch = false; + + // Cache for resolved methods, keyed by (class, operationName). + // Uses Optional so that "not found" results are also cached. + private static final Map> methodCache = new ConcurrentHashMap<>(); + + /** + * Creates a new {@code TaskEntity} instance. + */ + protected TaskEntity() { + } + + /** + * Gets whether state dispatch is allowed. + * + * @return {@code true} if operations can be dispatched to state object methods + */ + protected boolean getAllowStateDispatch() { + return this.allowStateDispatch; + } + + /** + * Sets whether operations can be dispatched to methods on the state object. + *

+ * When {@code true}, if no matching method is found on the entity class itself, + * the framework will look for a matching method on the state object. + * When {@code false} (the default), only methods on the entity class are considered. + * + * @param allowStateDispatch {@code true} to allow state dispatch, {@code false} to disable + */ + protected void setAllowStateDispatch(boolean allowStateDispatch) { + this.allowStateDispatch = allowStateDispatch; + } + + /** + * Called to initialize the entity state when no prior state exists. + *

+ * The default implementation attempts to create a new instance of {@code TState} using + * its no-arg constructor via reflection. Override this method to provide custom initialization. + * + * @param operation the operation that triggered state initialization + * @return the initial state value + */ + @SuppressWarnings("unchecked") + protected TState initializeState(TaskEntityOperation operation) { + Class stateType = getStateType(); + if (stateType == null) { + return null; + } + try { + return stateType.getDeclaredConstructor().newInstance(); + } catch (Exception e) { + throw new RuntimeException( + "Failed to initialize entity state of type " + stateType.getName() + + ". Override initializeState() to provide custom initialization.", e); + } + } + + /** + * Gets the runtime class of the state type parameter. Override this method in concrete + * entity classes to provide the state type class, which is needed for deserialization. + * + * @return the state type class, or {@code null} if not applicable + */ + @Nullable + protected abstract Class getStateType(); + + @Override + public Object runAsync(TaskEntityOperation operation) throws Exception { + // Set the context before dispatch so subclass methods can access it + this.context = operation.getContext(); + + // Step 1: Load or initialize state + Class stateType = getStateType(); + if (stateType != null) { + this.state = operation.getState().getState(stateType); + if (this.state == null) { + this.state = initializeState(operation); + } + } + + Object result; + + // Step 2: Try reflection dispatch on this entity class + Method method = findMethod(this.getClass(), operation.getName()); + if (method != null) { + result = invokeMethod(method, this, operation); + } else if (this.allowStateDispatch && this.state != null) { + // Step 3: Try state dispatch (only if allowStateDispatch is true) + Method stateMethod = findMethod(this.state.getClass(), operation.getName()); + if (stateMethod != null) { + result = invokeMethod(stateMethod, this.state, operation); + } else { + // Step 4: Implicit delete + result = handleImplicitOperations(operation); + } + } else { + // Step 4: Implicit delete (no state loaded) + result = handleImplicitOperations(operation); + } + + // Step 5: Save state back only on success (the executor handles rollback on failure) + if (stateType != null) { + operation.getState().setState(this.state); + } + + return result; + } + + private Object handleImplicitOperations(TaskEntityOperation operation) { + if ("delete".equalsIgnoreCase(operation.getName())) { + operation.getState().deleteState(); + this.state = null; + return null; + } + throw new UnsupportedOperationException( + "Entity '" + this.getClass().getSimpleName() + "' does not support operation '" + + operation.getName() + "'."); + } + + /** + * Finds a public method on the target class matching the operation name (case-insensitive). + * Methods inherited from {@code Object} and from {@code TaskEntity} itself are excluded. + */ + @Nullable + private static Method findMethod(Class targetClass, String operationName) { + String cacheKey = targetClass.getName() + "#" + operationName.toLowerCase(); + return methodCache.computeIfAbsent(cacheKey, k -> { + List matches = new ArrayList<>(); + for (Method m : targetClass.getMethods()) { + // Skip static methods — only instance methods should be dispatchable + if (Modifier.isStatic(m.getModifiers())) { + continue; + } + // Skip methods from Object + if (m.getDeclaringClass() == Object.class) { + continue; + } + // Skip methods from TaskEntity base class itself + if (m.getDeclaringClass() == TaskEntity.class) { + continue; + } + // Skip methods from ITaskEntity interface + if (m.getDeclaringClass() == ITaskEntity.class) { + continue; + } + // Skip methods from JDK packages to prevent unintended state dispatch + // (e.g., Integer.intValue(), String.length()) when state type is a JDK class + String declaringPackage = m.getDeclaringClass().getPackageName(); + if (declaringPackage.startsWith("java.") || declaringPackage.startsWith("javax.")) { + continue; + } + if (m.getName().equalsIgnoreCase(operationName)) { + matches.add(m); + } + } + if (matches.size() > 1) { + throw new IllegalStateException( + "Ambiguous match: multiple methods named '" + operationName + "' found on " + + targetClass.getName() + ". Entity operation methods must have unique names."); + } + return matches.isEmpty() ? Optional.empty() : Optional.of(matches.get(0)); + }).orElse(null); + } + + /** + * Invokes the resolved method with appropriate parameter binding. + *

+ * Supports 0–2 parameters among: + *

    + *
  • The operation input (deserialized to the parameter type)
  • + *
  • {@link TaskEntityContext}
  • + *
+ */ + private static Object invokeMethod(Method method, Object target, TaskEntityOperation operation) + throws Exception { + Class[] paramTypes = method.getParameterTypes(); + Object[] args = new Object[paramTypes.length]; + + for (int i = 0; i < paramTypes.length; i++) { + if (TaskEntityContext.class.isAssignableFrom(paramTypes[i])) { + args[i] = operation.getContext(); + } else if (TaskEntityOperation.class.isAssignableFrom(paramTypes[i])) { + args[i] = operation; + } else { + // Assume this is the input parameter + args[i] = operation.getInput(paramTypes[i]); + } + } + + try { + return method.invoke(target, args); + } catch (InvocationTargetException e) { + // Unwrap the target exception + Throwable cause = e.getTargetException(); + if (cause instanceof Exception) { + throw (Exception) cause; + } + throw new RuntimeException(cause); + } + } +} diff --git a/client/src/main/java/com/microsoft/durabletask/TaskEntityContext.java b/client/src/main/java/com/microsoft/durabletask/TaskEntityContext.java new file mode 100644 index 00000000..23490ef2 --- /dev/null +++ b/client/src/main/java/com/microsoft/durabletask/TaskEntityContext.java @@ -0,0 +1,96 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.microsoft.durabletask; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.ArrayList; +import java.util.List; + +/** + * Provides contextual information and side-effecting capabilities to a durable entity during operation execution. + *

+ * The context allows entities to: + *

    + *
  • Get their own entity instance ID
  • + *
  • Signal other entities (fire-and-forget)
  • + *
  • Start new orchestration instances
  • + *
+ *

+ * Actions performed through the context (signals and orchestration starts) are collected and + * submitted as part of the entity batch result. They support transactional semantics: + * if an operation fails, any actions enqueued during that operation are rolled back. + */ +public abstract class TaskEntityContext { + /** + * Gets the instance ID of the currently executing entity. + * + * @return the entity instance ID + */ + @Nonnull + public abstract EntityInstanceId getId(); + + /** + * Sends a fire-and-forget signal to another entity. + * + * @param entityId the target entity's instance ID + * @param operationName the name of the operation to invoke on the target entity + */ + public void signalEntity(@Nonnull EntityInstanceId entityId, @Nonnull String operationName) { + this.signalEntity(entityId, operationName, null, null); + } + + /** + * Sends a fire-and-forget signal to another entity with input data. + * + * @param entityId the target entity's instance ID + * @param operationName the name of the operation to invoke on the target entity + * @param input the input data for the operation, or {@code null} + */ + public void signalEntity( + @Nonnull EntityInstanceId entityId, + @Nonnull String operationName, + @Nullable Object input) { + this.signalEntity(entityId, operationName, input, null); + } + + /** + * Sends a fire-and-forget signal to another entity with input data and options. + * + * @param entityId the target entity's instance ID + * @param operationName the name of the operation to invoke on the target entity + * @param input the input data for the operation, or {@code null} + * @param options signal options (e.g., scheduled time), or {@code null} + */ + public abstract void signalEntity( + @Nonnull EntityInstanceId entityId, + @Nonnull String operationName, + @Nullable Object input, + @Nullable SignalEntityOptions options); + + /** + * Starts a new orchestration instance. + * + * @param name the name of the orchestration to start + * @param input the input for the orchestration, or {@code null} + * @return the instance ID of the newly started orchestration + */ + @Nonnull + public String startNewOrchestration(@Nonnull String name, @Nullable Object input) { + return this.startNewOrchestration(name, input, null); + } + + /** + * Starts a new orchestration instance with options. + * + * @param name the name of the orchestration to start + * @param input the input for the orchestration, or {@code null} + * @param options the orchestration start options (e.g., instance ID), or {@code null} + * @return the instance ID of the newly started orchestration + */ + @Nonnull + public abstract String startNewOrchestration( + @Nonnull String name, + @Nullable Object input, + @Nullable NewOrchestrationInstanceOptions options); +} diff --git a/client/src/main/java/com/microsoft/durabletask/TaskEntityExecutor.java b/client/src/main/java/com/microsoft/durabletask/TaskEntityExecutor.java new file mode 100644 index 00000000..d12565d1 --- /dev/null +++ b/client/src/main/java/com/microsoft/durabletask/TaskEntityExecutor.java @@ -0,0 +1,364 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.microsoft.durabletask; + +import com.google.protobuf.StringValue; +import com.google.protobuf.Timestamp; +import com.microsoft.durabletask.implementation.protobuf.OrchestratorService.*; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.time.Instant; +import java.util.*; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * Executes entity batch requests by dispatching operations to registered entity factories. + *

+ * Each operation in the batch is executed independently with transactional semantics: + * successful operations commit their state and actions, while failed operations roll back + * to the previous committed state and discard any actions enqueued during the failed operation. + */ +final class TaskEntityExecutor { + private final HashMap entityFactories; + private final DataConverter dataConverter; + private final Logger logger; + + TaskEntityExecutor( + HashMap entityFactories, + DataConverter dataConverter, + Logger logger) { + this.entityFactories = entityFactories; + this.dataConverter = dataConverter; + this.logger = logger; + } + + /** + * Executes a batch of entity operations from an {@code EntityBatchRequest}. + * + * @param request the entity batch request from the sidecar + * @return the entity batch result to send back to the sidecar + */ + @Nonnull + EntityBatchResult execute(@Nonnull EntityBatchRequest request) { + String instanceId = request.getInstanceId(); + EntityInstanceId entityId = EntityInstanceId.fromString(instanceId); + String entityName = entityId.getName(); + + logger.log(Level.FINE, "Executing entity batch for '{0}' with {1} operation(s).", + new Object[]{instanceId, request.getOperationsCount()}); + + // Look up the entity factory + TaskEntityFactory factory = this.entityFactories.get(entityName); + if (factory == null) { + // Try case-insensitive lookup + for (Map.Entry entry : this.entityFactories.entrySet()) { + if (entry.getKey().equalsIgnoreCase(entityName)) { + factory = entry.getValue(); + break; + } + } + } + + if (factory == null) { + String errorMessage = String.format("No entity named '%s' is registered.", entityName); + logger.log(Level.WARNING, errorMessage); + TaskFailureDetails failureDetails = TaskFailureDetails.newBuilder() + .setErrorType(IllegalStateException.class.getName()) + .setErrorMessage(errorMessage) + .build(); + return EntityBatchResult.newBuilder() + .setFailureDetails(failureDetails) + .build(); + } + + // Initialize state from the request + String initialState = request.hasEntityState() + ? request.getEntityState().getValue() + : null; + TaskEntityState entityState = new TaskEntityState(this.dataConverter, initialState); + + // Create the concrete context that collects actions + TaskEntityContextImpl context = new TaskEntityContextImpl(entityId, this.dataConverter); + + // Process each operation + List results = new ArrayList<>(); + + // Create a single entity instance for the entire batch + ITaskEntity entity; + try { + entity = factory.create(); + if (entity == null) { + String errorMsg = String.format("The entity factory for '%s' returned a null entity.", entityName); + logger.log(Level.WARNING, errorMsg); + TaskFailureDetails failureDetails = TaskFailureDetails.newBuilder() + .setErrorType(IllegalStateException.class.getName()) + .setErrorMessage(errorMsg) + .build(); + return EntityBatchResult.newBuilder() + .setFailureDetails(failureDetails) + .build(); + } + } catch (Exception e) { + String errorMsg = String.format("Failed to create entity instance for '%s': %s", entityName, e.getMessage()); + logger.log(Level.WARNING, errorMsg, e); + TaskFailureDetails failureDetails = TaskFailureDetails.newBuilder() + .setErrorType(e.getClass().getName()) + .setErrorMessage(e.getMessage() != null ? e.getMessage() : "") + .setStackTrace(StringValue.of(FailureDetails.getFullStackTrace(e))) + .build(); + return EntityBatchResult.newBuilder() + .setFailureDetails(failureDetails) + .build(); + } + + for (OperationRequest opRequest : request.getOperationsList()) { + String operationName = opRequest.getOperation(); + String requestId = opRequest.getRequestId(); + String serializedInput = opRequest.hasInput() ? opRequest.getInput().getValue() : null; + + logger.log(Level.FINE, "Executing operation '{0}' (requestId={1}) on entity '{2}'.", + new Object[]{operationName, requestId, instanceId}); + + // Snapshot state and actions before each operation (for rollback on failure) + entityState.commit(); + context.commit(); + + Instant startTime = Instant.now(); + + try { + // Build the operation + TaskEntityOperation operation = new TaskEntityOperation( + operationName, serializedInput, context, entityState, this.dataConverter); + + // Execute + Object result = entity.runAsync(operation); + + Instant endTime = Instant.now(); + + // Build success result + OperationResultSuccess.Builder successBuilder = OperationResultSuccess.newBuilder() + .setStartTimeUtc(toTimestamp(startTime)) + .setEndTimeUtc(toTimestamp(endTime)); + + if (result != null) { + String serializedResult = this.dataConverter.serialize(result); + if (serializedResult != null) { + successBuilder.setResult(StringValue.of(serializedResult)); + } + } + + results.add(OperationResult.newBuilder() + .setSuccess(successBuilder.build()) + .build()); + + // Commit state and actions on success + entityState.commit(); + context.commit(); + + logger.log(Level.FINE, "Operation '{0}' on entity '{1}' completed successfully.", + new Object[]{operationName, instanceId}); + + } catch (Exception e) { + Instant endTime = Instant.now(); + + logger.log(Level.WARNING, + String.format("Operation '%s' on entity '%s' failed: %s", + operationName, instanceId, e.getMessage()), + e); + + // Build failure result + TaskFailureDetails failureDetails = TaskFailureDetails.newBuilder() + .setErrorType(e.getClass().getName()) + .setErrorMessage(e.getMessage() != null ? e.getMessage() : "") + .setStackTrace(StringValue.of(FailureDetails.getFullStackTrace(e))) + .build(); + + OperationResultFailure failure = OperationResultFailure.newBuilder() + .setFailureDetails(failureDetails) + .setStartTimeUtc(toTimestamp(startTime)) + .setEndTimeUtc(toTimestamp(endTime)) + .build(); + + results.add(OperationResult.newBuilder() + .setFailure(failure) + .build()); + + // Rollback state and actions on failure + entityState.rollback(); + context.rollback(); + } + } + + // Build the final result + EntityBatchResult.Builder resultBuilder = EntityBatchResult.newBuilder() + .addAllResults(results) + .addAllActions(context.getCommittedActions(0)); + + // Set the final entity state + String finalState = entityState.getSerializedState(); + if (finalState != null) { + resultBuilder.setEntityState(StringValue.of(finalState)); + } + + return resultBuilder.build(); + } + + private static Timestamp toTimestamp(Instant instant) { + return Timestamp.newBuilder() + .setSeconds(instant.getEpochSecond()) + .setNanos(instant.getNano()) + .build(); + } + + /** + * Concrete implementation of {@link TaskEntityContext} that collects {@link OperationAction} protos + * during entity operation execution. + */ + private static class TaskEntityContextImpl extends TaskEntityContext { + private final EntityInstanceId entityId; + private final DataConverter dataConverter; + private final List pendingActions = new ArrayList<>(); + private int committedActionCount = 0; + + TaskEntityContextImpl(EntityInstanceId entityId, DataConverter dataConverter) { + this.entityId = entityId; + this.dataConverter = dataConverter; + } + + @Nonnull + @Override + public EntityInstanceId getId() { + return this.entityId; + } + + @Override + public void signalEntity( + @Nonnull EntityInstanceId targetEntityId, + @Nonnull String operationName, + @Nullable Object input, + @Nullable SignalEntityOptions options) { + Objects.requireNonNull(targetEntityId, "targetEntityId must not be null"); + Objects.requireNonNull(operationName, "operationName must not be null"); + + SendSignalAction.Builder signalBuilder = SendSignalAction.newBuilder() + .setInstanceId(targetEntityId.toString()) + .setName(operationName); + + if (input != null) { + String serializedInput = this.dataConverter.serialize(input); + if (serializedInput != null) { + signalBuilder.setInput(StringValue.of(serializedInput)); + } + } + + if (options != null && options.getScheduledTime() != null) { + Instant scheduledTime = options.getScheduledTime(); + signalBuilder.setScheduledTime(Timestamp.newBuilder() + .setSeconds(scheduledTime.getEpochSecond()) + .setNanos(scheduledTime.getNano()) + .build()); + } + + this.pendingActions.add(new PendingAction(PendingAction.Type.SEND_SIGNAL, signalBuilder.build(), null)); + } + + @Nonnull + @Override + public String startNewOrchestration( + @Nonnull String name, + @Nullable Object input, + @Nullable NewOrchestrationInstanceOptions options) { + Objects.requireNonNull(name, "orchestration name must not be null"); + + String instanceId = (options != null && options.getInstanceId() != null) + ? options.getInstanceId() + : UUID.randomUUID().toString(); + + StartNewOrchestrationAction.Builder orchBuilder = StartNewOrchestrationAction.newBuilder() + .setInstanceId(instanceId) + .setName(name); + + if (input != null) { + String serializedInput = this.dataConverter.serialize(input); + if (serializedInput != null) { + orchBuilder.setInput(StringValue.of(serializedInput)); + } + } + + if (options != null) { + if (options.getVersion() != null) { + orchBuilder.setVersion(StringValue.of(options.getVersion())); + } + if (options.getStartTime() != null) { + Instant startTime = options.getStartTime(); + orchBuilder.setScheduledTime(Timestamp.newBuilder() + .setSeconds(startTime.getEpochSecond()) + .setNanos(startTime.getNano()) + .build()); + } + } + + this.pendingActions.add(new PendingAction( + PendingAction.Type.START_NEW_ORCHESTRATION, null, orchBuilder.build())); + + return instanceId; + } + + /** + * Marks the current set of pending actions as committed (snapshot for rollback). + */ + void commit() { + this.committedActionCount = this.pendingActions.size(); + } + + /** + * Rolls back any uncommitted actions (discards actions added since last commit). + */ + void rollback() { + while (this.pendingActions.size() > this.committedActionCount) { + this.pendingActions.remove(this.pendingActions.size() - 1); + } + } + + /** + * Returns all committed actions as proto {@link OperationAction} objects. + * + * @param startId the starting ID for action numbering + * @return the list of committed operation actions + */ + List getCommittedActions(int startId) { + List actions = new ArrayList<>(); + int id = startId; + for (PendingAction pending : this.pendingActions) { + OperationAction.Builder actionBuilder = OperationAction.newBuilder() + .setId(id++); + if (pending.type == PendingAction.Type.SEND_SIGNAL) { + actionBuilder.setSendSignal(pending.sendSignal); + } else { + actionBuilder.setStartNewOrchestration(pending.startNewOrchestration); + } + actions.add(actionBuilder.build()); + } + return actions; + } + + /** + * Represents a pending action (signal or orchestration start) collected during entity execution. + */ + private static class PendingAction { + enum Type { SEND_SIGNAL, START_NEW_ORCHESTRATION } + + final Type type; + final SendSignalAction sendSignal; + final StartNewOrchestrationAction startNewOrchestration; + + PendingAction(Type type, SendSignalAction sendSignal, StartNewOrchestrationAction startNewOrchestration) { + this.type = type; + this.sendSignal = sendSignal; + this.startNewOrchestration = startNewOrchestration; + } + } + } +} diff --git a/client/src/main/java/com/microsoft/durabletask/TaskEntityFactory.java b/client/src/main/java/com/microsoft/durabletask/TaskEntityFactory.java new file mode 100644 index 00000000..010fe26e --- /dev/null +++ b/client/src/main/java/com/microsoft/durabletask/TaskEntityFactory.java @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.microsoft.durabletask; + +/** + * Functional interface for creating {@link ITaskEntity} instances. + *

+ * Entity factories are registered with the {@link DurableTaskGrpcWorkerBuilder} and are used to create + * new entity instances when entity work items are received from the sidecar. + */ +@FunctionalInterface +public interface TaskEntityFactory { + /** + * Creates a new instance of {@link ITaskEntity}. + * + * @return a new entity instance + */ + ITaskEntity create(); +} diff --git a/client/src/main/java/com/microsoft/durabletask/TaskEntityOperation.java b/client/src/main/java/com/microsoft/durabletask/TaskEntityOperation.java new file mode 100644 index 00000000..f9974f60 --- /dev/null +++ b/client/src/main/java/com/microsoft/durabletask/TaskEntityOperation.java @@ -0,0 +1,104 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.microsoft.durabletask; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +/** + * Represents a single operation to be executed against a durable entity. + *

+ * An operation includes: + *

    + *
  • The operation name (e.g., "add", "get", "delete")
  • + *
  • Optional input data
  • + *
  • Access to the entity's state via {@link TaskEntityState}
  • + *
  • Access to the entity's context via {@link TaskEntityContext}
  • + *
+ */ +public class TaskEntityOperation { + private final String name; + private final String serializedInput; + private final TaskEntityContext context; + private final TaskEntityState state; + private final DataConverter dataConverter; + + /** + * Creates a new {@code TaskEntityOperation}. + * + * @param name the operation name + * @param serializedInput the serialized input data, or {@code null} if no input + * @param context the entity context + * @param state the entity state + * @param dataConverter the data converter for deserializing input + */ + public TaskEntityOperation( + @Nonnull String name, + @Nullable String serializedInput, + @Nonnull TaskEntityContext context, + @Nonnull TaskEntityState state, + @Nonnull DataConverter dataConverter) { + if (name == null || name.isEmpty()) { + throw new IllegalArgumentException("Operation name must not be null or empty."); + } + this.name = name; + this.serializedInput = serializedInput; + this.context = context; + this.state = state; + this.dataConverter = dataConverter; + } + + /** + * Gets the name of this operation. + * + * @return the operation name + */ + @Nonnull + public String getName() { + return this.name; + } + + /** + * Deserializes and returns the input data for this operation. + * + * @param inputType the class to deserialize the input into + * @param the expected type of the input + * @return the deserialized input, or {@code null} if no input was provided + */ + @Nullable + public T getInput(@Nonnull Class inputType) { + if (this.serializedInput == null) { + return null; + } + return this.dataConverter.deserialize(this.serializedInput, inputType); + } + + /** + * Returns whether this operation has input data. + * + * @return {@code true} if input data was provided, {@code false} otherwise + */ + public boolean hasInput() { + return this.serializedInput != null; + } + + /** + * Gets the context for the currently executing entity. + * + * @return the entity context + */ + @Nonnull + public TaskEntityContext getContext() { + return this.context; + } + + /** + * Gets the state accessor for the currently executing entity. + * + * @return the entity state + */ + @Nonnull + public TaskEntityState getState() { + return this.state; + } +} diff --git a/client/src/main/java/com/microsoft/durabletask/TaskEntityState.java b/client/src/main/java/com/microsoft/durabletask/TaskEntityState.java new file mode 100644 index 00000000..dad45e6d --- /dev/null +++ b/client/src/main/java/com/microsoft/durabletask/TaskEntityState.java @@ -0,0 +1,111 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.microsoft.durabletask; + +import javax.annotation.Nullable; + +/** + * Provides access to the state of a durable entity during operation execution. + *

+ * Entity state is automatically persisted by the durable task runtime after each successfully + * committed operation. The state supports transactional semantics: if an operation fails, + * the state is rolled back to the last committed snapshot. + */ +public class TaskEntityState { + private final DataConverter dataConverter; + private String serializedState; + private boolean stateExists; + + // Transactional rollback support + private String committedSerializedState; + private boolean committedStateExists; + + /** + * Creates a new {@code TaskEntityState} instance. + * + * @param dataConverter the data converter used for serialization/deserialization + * @param serializedState the initial serialized state, or {@code null} if no state exists + */ + public TaskEntityState(DataConverter dataConverter, @Nullable String serializedState) { + this.dataConverter = dataConverter; + this.serializedState = serializedState; + this.stateExists = serializedState != null; + // Initialize committed state to match initial state + this.committedSerializedState = serializedState; + this.committedStateExists = this.stateExists; + } + + /** + * Gets the current entity state, deserialized to the specified type. + * + * @param stateType the class to deserialize the state into + * @param the type of the state + * @return the deserialized state, or {@code null} if no state exists + */ + @Nullable + public T getState(Class stateType) { + if (!this.stateExists || this.serializedState == null) { + return null; + } + return this.dataConverter.deserialize(this.serializedState, stateType); + } + + /** + * Sets the entity state. The state will be serialized using the configured {@link DataConverter}. + * + * @param state the state value to set + */ + public void setState(@Nullable Object state) { + if (state == null) { + deleteState(); + } else { + this.serializedState = this.dataConverter.serialize(state); + this.stateExists = true; + } + } + + /** + * Returns whether the entity currently has state. + * + * @return {@code true} if the entity has state, {@code false} otherwise + */ + public boolean hasState() { + return this.stateExists; + } + + /** + * Deletes the entity state. After calling this method, {@link #hasState()} will return {@code false}. + */ + public void deleteState() { + this.serializedState = null; + this.stateExists = false; + } + + /** + * Gets the serialized (raw) state string. + * + * @return the serialized state, or {@code null} if no state exists + */ + @Nullable + String getSerializedState() { + return this.serializedState; + } + + /** + * Takes a snapshot of the current state as a rollback point. + * Called before each operation in a batch to support transactional semantics. + */ + void commit() { + this.committedSerializedState = this.serializedState; + this.committedStateExists = this.stateExists; + } + + /** + * Reverts the state to the last committed snapshot. + * Called when an operation fails to maintain transactional semantics. + */ + void rollback() { + this.serializedState = this.committedSerializedState; + this.stateExists = this.committedStateExists; + } +} diff --git a/client/src/main/java/com/microsoft/durabletask/TaskOrchestrationContext.java b/client/src/main/java/com/microsoft/durabletask/TaskOrchestrationContext.java index 370c4b66..ccc70408 100644 --- a/client/src/main/java/com/microsoft/durabletask/TaskOrchestrationContext.java +++ b/client/src/main/java/com/microsoft/durabletask/TaskOrchestrationContext.java @@ -9,6 +9,7 @@ import java.util.Arrays; import java.util.List; import java.util.UUID; +import javax.annotation.Nonnull; /** * Used by orchestrators to perform actions such as scheduling tasks, durable timers, waiting for external events, @@ -558,6 +559,187 @@ default Task waitForExternalEvent(String name, Class dataType) { } } + /** + * Gets the durable entity feature for this orchestration context. + *

+ * This mirrors the .NET SDK's {@code TaskOrchestrationContext.Entities} surface, + * adapted to Java as a method-based accessor. + * + * @return the entity feature for this orchestration context + */ + default TaskOrchestrationEntityFeature getEntities() { + return new ContextBackedTaskOrchestrationEntityFeature(this); + } + + /** + * Gets the durable entity feature for this orchestration context. + *

+ * This is an alias of {@link #getEntities()}. + * + * @return the entity feature for this orchestration context + */ + default TaskOrchestrationEntityFeature entities() { + return this.getEntities(); + } + + // region Entity integration methods + + /** + * Sends a fire-and-forget signal to a durable entity. + *

+ * Signals are one-way messages that do not return a result. The target entity will execute the specified + * operation asynchronously. If the entity does not exist, it will be created automatically. + * + * @param entityId the unique identifier of the target entity + * @param operationName the name of the operation to invoke on the entity + */ + default void signalEntity(@Nonnull EntityInstanceId entityId, @Nonnull String operationName) { + this.signalEntity(entityId, operationName, null, null); + } + + /** + * Sends a fire-and-forget signal to a durable entity with the specified input. + *

+ * Signals are one-way messages that do not return a result. The target entity will execute the specified + * operation asynchronously. If the entity does not exist, it will be created automatically. + * + * @param entityId the unique identifier of the target entity + * @param operationName the name of the operation to invoke on the entity + * @param input the serializable input to pass to the entity operation, or {@code null} + */ + default void signalEntity(@Nonnull EntityInstanceId entityId, @Nonnull String operationName, @Nullable Object input) { + this.signalEntity(entityId, operationName, input, null); + } + + /** + * Sends a fire-and-forget signal to a durable entity with the specified input and options. + *

+ * Signals are one-way messages that do not return a result. The target entity will execute the specified + * operation asynchronously. If the entity does not exist, it will be created automatically. + * + * @param entityId the unique identifier of the target entity + * @param operationName the name of the operation to invoke on the entity + * @param input the serializable input to pass to the entity operation, or {@code null} + * @param options signal options such as scheduled delivery time, or {@code null} + */ + void signalEntity(@Nonnull EntityInstanceId entityId, @Nonnull String operationName, @Nullable Object input, @Nullable SignalEntityOptions options); + + /** + * Calls an operation on a durable entity and waits for the result. + *

+ * Unlike {@link #signalEntity}, this method is a two-way call that returns a result. The calling orchestration + * will block until the entity operation completes and returns a response. + * + * @param entityId the unique identifier of the target entity + * @param operationName the name of the operation to invoke on the entity + * @param input the serializable input to pass to the entity operation, or {@code null} + * @param returnType the expected class type of the entity operation output + * @param the expected type of the entity operation output + * @return a {@link Task} that completes when the entity operation completes + */ + Task callEntity(@Nonnull EntityInstanceId entityId, @Nonnull String operationName, @Nullable Object input, @Nonnull Class returnType); + + /** + * Calls an operation on a durable entity and waits for the result, with options. + *

+ * Unlike {@link #signalEntity}, this method is a two-way call that returns a result. The calling orchestration + * will block until the entity operation completes and returns a response. + * + * @param entityId the unique identifier of the target entity + * @param operationName the name of the operation to invoke on the entity + * @param input the serializable input to pass to the entity operation, or {@code null} + * @param returnType the expected class type of the entity operation output + * @param options call options such as timeout, or {@code null} + * @param the expected type of the entity operation output + * @return a {@link Task} that completes when the entity operation completes + */ + Task callEntity(@Nonnull EntityInstanceId entityId, @Nonnull String operationName, @Nullable Object input, @Nonnull Class returnType, @Nullable CallEntityOptions options); + + /** + * Calls an operation on a durable entity and waits for it to complete (no return value). + * + * @param entityId the unique identifier of the target entity + * @param operationName the name of the operation to invoke on the entity + */ + default Task callEntity(@Nonnull EntityInstanceId entityId, @Nonnull String operationName) { + return this.callEntity(entityId, operationName, null, Void.class); + } + + /** + * Calls an operation on a durable entity with input and waits for it to complete (no return value). + * + * @param entityId the unique identifier of the target entity + * @param operationName the name of the operation to invoke on the entity + * @param input the serializable input to pass to the entity operation, or {@code null} + */ + default Task callEntity(@Nonnull EntityInstanceId entityId, @Nonnull String operationName, @Nullable Object input) { + return this.callEntity(entityId, operationName, input, Void.class); + } + + /** + * Calls an operation on a durable entity and waits for the result (no input). + * + * @param entityId the unique identifier of the target entity + * @param operationName the name of the operation to invoke on the entity + * @param returnType the expected class type of the entity operation output + * @param the expected type of the entity operation output + * @return a {@link Task} that completes when the entity operation completes + */ + default Task callEntity(@Nonnull EntityInstanceId entityId, @Nonnull String operationName, @Nonnull Class returnType) { + return this.callEntity(entityId, operationName, null, returnType); + } + + /** + * Acquires one or more entity locks for the duration of a critical section. + *

+ * Entity locks are used to coordinate access and prevent conflicts when multiple orchestrations need + * to access the same entities. The returned {@link AutoCloseable} must be closed to release the locks. + *

+ * Entity IDs are sorted deterministically before acquiring locks to prevent deadlocks. + * Nesting of lock calls is not supported and will throw an {@link IllegalStateException}. + *

+ * Example usage: + *

{@code
+     * try (AutoCloseable lock = ctx.lockEntities(entityIds).await()) {
+     *     // Perform operations on the locked entities
+     *     ctx.callEntity(entityId, "transfer", amount).await();
+     * }
+     * }
+ * + * @param entityIds the list of entity instance IDs to lock; must not be empty + * @return a {@link Task} that completes with an {@link AutoCloseable} when all locks are acquired + */ + Task lockEntities(@Nonnull List entityIds); + + /** + * Acquires one or more entity locks for the duration of a critical section (varargs overload). + * + * @param entityIds the entity instance IDs to lock; must not be empty + * @return a {@link Task} that completes with an {@link AutoCloseable} when all locks are acquired + */ + default Task lockEntities(@Nonnull EntityInstanceId... entityIds) { + return this.lockEntities(Arrays.asList(entityIds)); + } + + /** + * Gets a value indicating whether this orchestration is currently executing inside a critical section + * that was created by {@link #lockEntities}. + * + * @return {@code true} if the orchestration is inside a critical section, otherwise {@code false} + */ + boolean isInCriticalSection(); + + /** + * Gets the list of entity instance IDs that are currently locked by this orchestration. + *

+ * Returns an empty list if the orchestration is not inside a critical section. + * + * @return an unmodifiable list of locked entity instance IDs + */ + List getLockedEntities(); + + // endregion + /** * Assigns a custom status value to the current orchestration. *

@@ -575,4 +757,4 @@ default Task waitForExternalEvent(String name, Class dataType) { * Clears the orchestration's custom status. */ void clearCustomStatus(); -} +} \ No newline at end of file diff --git a/client/src/main/java/com/microsoft/durabletask/TaskOrchestrationEntityFeature.java b/client/src/main/java/com/microsoft/durabletask/TaskOrchestrationEntityFeature.java new file mode 100644 index 00000000..a633ffc0 --- /dev/null +++ b/client/src/main/java/com/microsoft/durabletask/TaskOrchestrationEntityFeature.java @@ -0,0 +1,188 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.microsoft.durabletask; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.List; + +/** + * Feature for interacting with durable entities from an orchestration. + *

+ * This mirrors the .NET SDK's {@code TaskOrchestrationContext.Entities} shape, adapted to Java. + */ +public abstract class TaskOrchestrationEntityFeature { + + /** + * Calls an operation on an entity and waits for it to complete. + * + * @param entityId the target entity + * @param operationName the name of the operation + * @param input the operation input, or {@code null} + * @param returnType the expected return type + * @param the result type + * @return a task that completes with the operation result + */ + public abstract Task callEntity( + @Nonnull EntityInstanceId entityId, + @Nonnull String operationName, + @Nullable Object input, + @Nonnull Class returnType); + + /** + * Calls an operation on an entity and waits for it to complete, with options. + * + * @param entityId the target entity + * @param operationName the name of the operation + * @param input the operation input, or {@code null} + * @param returnType the expected return type + * @param options the call options, or {@code null} + * @param the result type + * @return a task that completes with the operation result + */ + public abstract Task callEntity( + @Nonnull EntityInstanceId entityId, + @Nonnull String operationName, + @Nullable Object input, + @Nonnull Class returnType, + @Nullable CallEntityOptions options); + + /** + * Calls an operation on an entity and waits for it to complete. + * + * @param entityId the target entity + * @param operationName the name of the operation + * @param the result type + * @return a task that completes with the operation result + */ + public Task callEntity( + @Nonnull EntityInstanceId entityId, + @Nonnull String operationName, + @Nonnull Class returnType) { + return this.callEntity(entityId, operationName, null, returnType, null); + } + + /** + * Signals an entity operation without waiting for completion. + * + * @param entityId the target entity + * @param operationName the operation name + * @param input the operation input, or {@code null} + * @param options signal options, or {@code null} + */ + public abstract void signalEntity( + @Nonnull EntityInstanceId entityId, + @Nonnull String operationName, + @Nullable Object input, + @Nullable SignalEntityOptions options); + + /** + * Signals an entity operation without waiting for completion. + * + * @param entityId the target entity + * @param operationName the operation name + */ + public void signalEntity(@Nonnull EntityInstanceId entityId, @Nonnull String operationName) { + this.signalEntity(entityId, operationName, null, null); + } + + /** + * Signals an entity operation without waiting for completion. + * + * @param entityId the target entity + * @param operationName the operation name + * @param input the operation input, or {@code null} + */ + public void signalEntity( + @Nonnull EntityInstanceId entityId, + @Nonnull String operationName, + @Nullable Object input) { + this.signalEntity(entityId, operationName, input, null); + } + + /** + * Acquires one or more entity locks. + * + * @param entityIds the entity IDs to lock + * @return a task that completes with an {@link AutoCloseable} used to release locks + */ + public abstract Task lockEntities(@Nonnull List entityIds); + + /** + * Acquires one or more entity locks. + * + * @param entityIds the entity IDs to lock + * @return a task that completes with an {@link AutoCloseable} used to release locks + */ + public abstract Task lockEntities(@Nonnull EntityInstanceId... entityIds); + + /** + * Gets whether this orchestration is currently inside a critical section. + * + * @return {@code true} if inside a critical section, otherwise {@code false} + */ + public abstract boolean isInCriticalSection(); + + /** + * Gets the currently locked entity IDs. + * + * @return the list of currently locked entities, or an empty list + */ + public abstract List getLockedEntities(); +} + +final class ContextBackedTaskOrchestrationEntityFeature extends TaskOrchestrationEntityFeature { + private final TaskOrchestrationContext context; + + ContextBackedTaskOrchestrationEntityFeature(TaskOrchestrationContext context) { + this.context = context; + } + + @Override + public Task callEntity( + @Nonnull EntityInstanceId entityId, + @Nonnull String operationName, + @Nullable Object input, + @Nonnull Class returnType) { + return this.context.callEntity(entityId, operationName, input, returnType); + } + + @Override + public Task callEntity( + @Nonnull EntityInstanceId entityId, + @Nonnull String operationName, + @Nullable Object input, + @Nonnull Class returnType, + @Nullable CallEntityOptions options) { + return this.context.callEntity(entityId, operationName, input, returnType, options); + } + + @Override + public void signalEntity( + @Nonnull EntityInstanceId entityId, + @Nonnull String operationName, + @Nullable Object input, + @Nullable SignalEntityOptions options) { + this.context.signalEntity(entityId, operationName, input, options); + } + + @Override + public Task lockEntities(@Nonnull List entityIds) { + return this.context.lockEntities(entityIds); + } + + @Override + public Task lockEntities(@Nonnull EntityInstanceId... entityIds) { + return this.context.lockEntities(entityIds); + } + + @Override + public boolean isInCriticalSection() { + return this.context.isInCriticalSection(); + } + + @Override + public List getLockedEntities() { + return this.context.getLockedEntities(); + } +} \ No newline at end of file diff --git a/client/src/main/java/com/microsoft/durabletask/TaskOrchestrationExecutor.java b/client/src/main/java/com/microsoft/durabletask/TaskOrchestrationExecutor.java index 9c70db02..078cd960 100644 --- a/client/src/main/java/com/microsoft/durabletask/TaskOrchestrationExecutor.java +++ b/client/src/main/java/com/microsoft/durabletask/TaskOrchestrationExecutor.java @@ -2,6 +2,10 @@ // Licensed under the MIT License. package com.microsoft.durabletask; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; import com.google.protobuf.StringValue; import com.google.protobuf.Timestamp; import com.microsoft.durabletask.interruption.ContinueAsNewInterruption; @@ -26,6 +30,8 @@ final class TaskOrchestrationExecutor { private static final String EMPTY_STRING = ""; + // ObjectMapper for parsing DTFx entity ResponseMessage JSON wrappers in the trigger binding code path + private static final ObjectMapper JSON_MAPPER = new ObjectMapper(); private final HashMap orchestrationFactories; private final DataConverter dataConverter; private final Logger logger; @@ -60,7 +66,11 @@ public TaskOrchestratorResult execute(List pastEvents, List lockedEntityIds; + private final Map> pendingLockSets = new HashMap<>(); + public ContextImplTask(List pastEvents, List newEvents) { this.historyEventPlayer = new OrchestrationHistoryIterator(pastEvents, newEvents); } @@ -353,6 +370,275 @@ public UUID newUUID() { return UUIDGenerator.generate(version, hashV5, UUID.fromString(dnsNameSpace), name); } + // region Entity integration methods (Phase 4) + + @Override + public void signalEntity(EntityInstanceId entityId, String operationName, Object input, SignalEntityOptions options) { + Helpers.throwIfOrchestratorComplete(this.isComplete); + Helpers.throwIfArgumentNull(entityId, "entityId"); + Helpers.throwIfArgumentNull(operationName, "operationName"); + + int id = this.sequenceNumber++; + String requestId = this.newUUID().toString(); + String serializedInput = this.dataConverter.serialize(input); + + // Build DTFx RequestMessage JSON payload matching the legacy format that the + // Azure Functions extension (DTFx backend) understands. The extension processes + // entity messages as external events (SendEventAction), NOT the newer proto-native + // SendEntityMessageAction which is designed for the DTS backend. + ObjectNode requestMessage = JSON_MAPPER.createObjectNode(); + requestMessage.put("op", operationName); + requestMessage.put("signal", true); + if (serializedInput != null) { + requestMessage.put("input", serializedInput); + } + requestMessage.put("id", requestId); + String eventName = "op"; + if (options != null && options.getScheduledTime() != null) { + String scheduledTimeStr = options.getScheduledTime().toString(); + requestMessage.put("due", scheduledTimeStr); + eventName = "op@" + scheduledTimeStr; + } + + this.pendingActions.put(id, OrchestratorAction.newBuilder() + .setId(id) + .setSendEvent(SendEventAction.newBuilder() + .setInstance(OrchestrationInstance.newBuilder() + .setInstanceId(entityId.toString())) + .setName(eventName) + .setData(StringValue.of(requestMessage.toString()))) + .build()); + + if (!this.isReplaying) { + this.logger.fine(() -> String.format( + "%s: signaling entity '%s' operation '%s' (#%d)", + this.instanceId, + entityId, + operationName, + id)); + } + } + + @Override + public Task callEntity(EntityInstanceId entityId, String operationName, Object input, Class returnType) { + return this.callEntity(entityId, operationName, input, returnType, null); + } + + @Override + public Task callEntity(EntityInstanceId entityId, String operationName, Object input, Class returnType, CallEntityOptions options) { + Helpers.throwIfOrchestratorComplete(this.isComplete); + Helpers.throwIfArgumentNull(entityId, "entityId"); + Helpers.throwIfArgumentNull(operationName, "operationName"); + Helpers.throwIfArgumentNull(returnType, "returnType"); + + // Validate critical section: calls must target locked entities to prevent deadlocks + if (this.isInCriticalSection && this.lockedEntityIds != null + && !this.lockedEntityIds.contains(entityId.toString())) { + throw new IllegalStateException(String.format( + "Cannot call entity '%s' from within a critical section because it is not locked. " + + "Only locked entities can be called inside a critical section to prevent deadlocks.", + entityId)); + } + + int id = this.sequenceNumber++; + String requestId = this.newUUID().toString(); + String serializedInput = this.dataConverter.serialize(input); + + // Build DTFx RequestMessage JSON for entity call (two-way operation). + // Uses SendEventAction (external event) instead of SendEntityMessageAction for + // compatibility with the Azure Functions extension (DTFx backend). + ObjectNode requestMessage = JSON_MAPPER.createObjectNode(); + requestMessage.put("op", operationName); + if (serializedInput != null) { + requestMessage.put("input", serializedInput); + } + requestMessage.put("id", requestId); + requestMessage.put("parent", this.instanceId); + if (this.executionId != null) { + requestMessage.put("parentExecution", this.executionId); + } + + this.pendingActions.put(id, OrchestratorAction.newBuilder() + .setId(id) + .setSendEvent(SendEventAction.newBuilder() + .setInstance(OrchestrationInstance.newBuilder() + .setInstanceId(entityId.toString())) + .setName("op") + .setData(StringValue.of(requestMessage.toString()))) + .build()); + + if (!this.isReplaying) { + this.logger.info(() -> String.format( + "%s: calling entity '%s' operation '%s' (#%d) requestId=%s", + this.instanceId, + entityId, + operationName, + id, + requestId)); + } + + CompletableTask task = new CompletableTask<>(); + TaskRecord record = new TaskRecord<>(task, operationName, returnType, entityId); + Queue> eventQueue = this.outstandingEvents.computeIfAbsent(requestId, k -> new LinkedList<>()); + eventQueue.add(record); + + // If a timeout is specified, schedule a durable timer to cancel the call if the entity + // doesn't respond in time (same pattern as waitForExternalEvent with timeout). + Duration timeout = options != null ? options.getTimeout() : null; + if (timeout != null && !Helpers.isInfiniteTimeout(timeout)) { + if (timeout.isZero()) { + // Immediately cancel + eventQueue.removeIf(t -> t.task == task); + if (eventQueue.isEmpty()) { + this.outstandingEvents.remove(requestId); + } + String message = String.format( + "Timeout of %s expired while calling entity '%s' operation '%s' (requestId=%s).", + timeout, entityId, operationName, requestId); + task.completeExceptionally(new TaskCanceledException(message, operationName, id)); + } else { + this.createTimer(timeout).future.thenRun(() -> { + if (!task.isDone()) { + eventQueue.removeIf(t -> t.task == task); + if (eventQueue.isEmpty()) { + this.outstandingEvents.remove(requestId); + } + String message = String.format( + "Timeout of %s expired while calling entity '%s' operation '%s' (requestId=%s).", + timeout, entityId, operationName, requestId); + task.completeExceptionally(new TaskCanceledException(message, operationName, id)); + } + }); + } + } + + return task; + } + + @Override + public Task lockEntities(List entityIds) { + Helpers.throwIfOrchestratorComplete(this.isComplete); + Helpers.throwIfArgumentNull(entityIds, "entityIds"); + if (entityIds.isEmpty()) { + throw new IllegalArgumentException("entityIds must not be empty"); + } + if (this.isInCriticalSection) { + throw new IllegalStateException( + "Cannot nest lock calls. The orchestration is already inside a critical section."); + } + + // Sort entity IDs deterministically to prevent deadlocks + List sortedIds = new ArrayList<>(entityIds); + Collections.sort(sortedIds); + + String criticalSectionId = this.newUUID().toString(); + + // Build lock set as string list + List lockSet = new ArrayList<>(sortedIds.size()); + for (EntityInstanceId eid : sortedIds) { + lockSet.add(eid.toString()); + } + + // Send a lock request to the FIRST entity in the sorted lock set. + // DTFx entity infrastructure handles chaining the lock acquisition + // through subsequent entities in the lock set. + { + int id = this.sequenceNumber++; + ObjectNode lockRequestMessage = JSON_MAPPER.createObjectNode(); + lockRequestMessage.putNull("op"); + lockRequestMessage.put("id", criticalSectionId); + ArrayNode lockSetArray = lockRequestMessage.putArray("lockset"); + for (EntityInstanceId eid : sortedIds) { + ObjectNode entityIdNode = JSON_MAPPER.createObjectNode(); + entityIdNode.put("name", eid.getName()); + entityIdNode.put("key", eid.getKey()); + lockSetArray.add(entityIdNode); + } + lockRequestMessage.put("pos", 0); + lockRequestMessage.put("parent", this.instanceId); + + String targetEntityId = lockSet.get(0); + this.pendingActions.put(id, OrchestratorAction.newBuilder() + .setId(id) + .setSendEvent(SendEventAction.newBuilder() + .setInstance(OrchestrationInstance.newBuilder() + .setInstanceId(targetEntityId)) + .setName("op") + .setData(StringValue.of(lockRequestMessage.toString()))) + .build()); + } + + // Store the lock set so handleEntityLockGranted can populate lockedEntityIds + Set lockSetForStorage = new HashSet<>(lockSet); + this.pendingLockSets.put(criticalSectionId, lockSetForStorage); + + if (!this.isReplaying) { + this.logger.fine(() -> String.format( + "%s: requesting locks on %d entities, criticalSectionId=%s", + this.instanceId, + sortedIds.size(), + criticalSectionId)); + } + + // Create a waiter keyed by criticalSectionId + CompletableTask lockTask = new CompletableTask<>(); + TaskRecord record = new TaskRecord<>(lockTask, "(lock)", AutoCloseable.class); + Queue> eventQueue = this.outstandingEvents.computeIfAbsent(criticalSectionId, k -> new LinkedList<>()); + eventQueue.add(record); + + // Wrap the result so that when the lock is granted, we return an AutoCloseable + // that releases all locks on close() + return lockTask.thenApply(ignored -> (AutoCloseable) () -> { + // Release all locks + for (EntityInstanceId lockedEntity : sortedIds) { + int unlockId = this.sequenceNumber++; + // Build DTFx ReleaseMessage JSON for releasing entity locks + ObjectNode releaseMessage = JSON_MAPPER.createObjectNode(); + releaseMessage.put("parent", this.instanceId); + releaseMessage.put("id", criticalSectionId); + + this.pendingActions.put(unlockId, OrchestratorAction.newBuilder() + .setId(unlockId) + .setSendEvent(SendEventAction.newBuilder() + .setInstance(OrchestrationInstance.newBuilder() + .setInstanceId(lockedEntity.toString())) + .setName("release") + .setData(StringValue.of(releaseMessage.toString()))) + .build()); + } + + this.isInCriticalSection = false; + this.currentCriticalSectionId = null; + this.lockedEntityIds = null; + + if (!this.isReplaying) { + this.logger.fine(() -> String.format( + "%s: released locks for criticalSectionId=%s", + this.instanceId, + criticalSectionId)); + } + }); + } + + @Override + public boolean isInCriticalSection() { + return this.isInCriticalSection; + } + + @Override + public List getLockedEntities() { + if (!this.isInCriticalSection || this.lockedEntityIds == null || this.lockedEntityIds.isEmpty()) { + return Collections.emptyList(); + } + List result = new ArrayList<>(this.lockedEntityIds.size()); + for (String id : this.lockedEntityIds) { + result.add(EntityInstanceId.fromString(id)); + } + return Collections.unmodifiableList(result); + } + + // endregion + @Override public void sendEvent(String instanceId, String eventName, Object eventData) { Helpers.throwIfOrchestratorComplete(this.isComplete); @@ -586,6 +872,13 @@ private void handleEventRaised(HistoryEvent e) { Queue> outstandingEventQueue = this.outstandingEvents.get(eventName); if (outstandingEventQueue == null) { // No code is waiting for this event. Buffer it in case user-code waits for it later. + if (!this.isReplaying) { + this.logger.info(() -> String.format( + "%s: Received EventRaised '%s' but no outstanding waiter found. Buffering as unprocessed. Raw input: %s", + this.instanceId, + eventName, + eventRaised.getInput().getValue())); + } this.unprocessedEvents.add(e); return; } @@ -597,16 +890,294 @@ private void handleEventRaised(HistoryEvent e) { } String rawResult = eventRaised.getInput().getValue(); CompletableTask task = matchingTaskRecord.getTask(); + + // In the Azure Functions trigger binding code path, entity operation responses arrive as + // standard EventRaised events (not EntityOperationCompleted proto events). DTFx wraps entity + // responses in a ResponseMessage JSON format: {"result":"","errorMessage":...,"failureDetails":...} + // We detect entity call responses by checking if the task record has an associated entityId. + if (matchingTaskRecord.getEntityId() != null) { + if (!this.isReplaying) { + this.logger.info(() -> String.format( + "%s: Routing EventRaised '%s' to entity response handler for entity '%s'. Raw result: %s", + this.instanceId, + eventName, + matchingTaskRecord.getEntityId(), + rawResult != null ? rawResult : "(null)")); + } + this.handleEntityResponseFromEventRaised(matchingTaskRecord, rawResult); + } else { + try { + Object result = this.dataConverter.deserialize( + rawResult, + matchingTaskRecord.getDataType()); + task.complete(result); + } catch (Exception ex) { + task.completeExceptionally(ex); + } + } + } + + /** + * Handles an entity operation response that arrived as an EventRaised event (trigger binding path). + *

+ * In the trigger binding code path used by Azure Functions, DTFx.Core wraps entity responses + * in a ResponseMessage JSON format rather than using proto EntityOperationCompleted events. + * This method parses the ResponseMessage wrapper and extracts the actual operation result. + *

+ * The DTFx ResponseMessage class (DurableTask.Core.Entities.EventFormat.ResponseMessage) serializes as: + *

    + *
  • {@code "result"} — the serialized operation result (always present, may be null)
  • + *
  • {@code "exceptionType"} — the error message string (misleading name: the C# property is + * {@code ErrorMessage} but its {@code [DataMember(Name = "exceptionType")]} annotation overrides + * the JSON key). Omitted when null (EmitDefaultValue=false).
  • + *
  • {@code "failureDetails"} — a FailureDetails object with PascalCase fields + * ({@code ErrorType}, {@code ErrorMessage}, {@code StackTrace}, etc.). + * Omitted when null (EmitDefaultValue=false).
  • + *
+ * + * @param matchingTaskRecord the task record for the entity call + * @param rawResult the raw JSON string from the EventRaised event + */ + @SuppressWarnings("unchecked") + private void handleEntityResponseFromEventRaised(TaskRecord matchingTaskRecord, String rawResult) { + CompletableTask task = matchingTaskRecord.getTask(); + try { + // Parse the ResponseMessage JSON wrapper from DTFx + JsonNode responseNode = JSON_MAPPER.readTree(rawResult); + + if (responseNode == null || !responseNode.isObject() || !responseNode.has("result")) { + // Not a recognized ResponseMessage format — fall back to direct deserialization. + // This handles the case where the extension may send raw results in the future. + Object result = this.dataConverter.deserialize(rawResult, matchingTaskRecord.getDataType()); + task.complete(result); + return; + } + + // Check for error in the response. + // DTFx ResponseMessage uses "exceptionType" as the JSON key for the ErrorMessage property + // (due to [DataMember(Name = "exceptionType")]). These fields are omitted when null. + JsonNode exceptionTypeNode = responseNode.get("exceptionType"); + JsonNode failureDetailsNode = responseNode.get("failureDetails"); + boolean hasExceptionType = exceptionTypeNode != null && !exceptionTypeNode.isNull(); + boolean hasFailureDetails = failureDetailsNode != null && !failureDetailsNode.isNull(); + + if (hasExceptionType || hasFailureDetails) { + // Entity operation failed — extract error info and complete exceptionally. + // The "exceptionType" JSON field actually contains the error message (misleading name). + String errorMessage = hasExceptionType + ? exceptionTypeNode.asText() : "Entity operation failed"; + String errorType = "unknown"; + + if (hasFailureDetails) { + // FailureDetails has PascalCase JSON fields: ErrorType, ErrorMessage, StackTrace, etc. + JsonNode errorTypeNode = failureDetailsNode.get("ErrorType"); + if (errorTypeNode != null && !errorTypeNode.isNull()) { + errorType = errorTypeNode.asText(); + } + JsonNode detailErrorMsgNode = failureDetailsNode.get("ErrorMessage"); + if (detailErrorMsgNode != null && !detailErrorMsgNode.isNull()) { + errorMessage = detailErrorMsgNode.asText(); + } + } + + if (!this.isReplaying) { + final String logErrorType = errorType; + final String logErrorMessage = errorMessage; + this.logger.warning(() -> String.format( + "%s: Entity operation on '%s' failed: [%s] %s", + this.instanceId, + matchingTaskRecord.getEntityId(), + logErrorType, + logErrorMessage)); + } + + FailureDetails details = new FailureDetails(errorType, errorMessage, null, false); + task.completeExceptionally(new EntityOperationFailedException( + matchingTaskRecord.getEntityId(), + matchingTaskRecord.getTaskName(), + details)); + } else { + // Success — extract the inner result value + JsonNode resultNode = responseNode.get("result"); + String innerResult = (resultNode == null || resultNode.isNull()) ? null : resultNode.asText(); + + if (!this.isReplaying) { + this.logger.info(() -> String.format( + "%s: Entity operation on '%s' completed via EventRaised with result: %s", + this.instanceId, + matchingTaskRecord.getEntityId(), + innerResult != null ? innerResult : "(null)")); + } + + Object result = this.dataConverter.deserialize(innerResult, matchingTaskRecord.getDataType()); + task.complete(result); + } + } catch (EntityOperationFailedException ex) { + // Re-throw entity failures (already handled above via completeExceptionally) + task.completeExceptionally(ex); + } catch (Exception ex) { + task.completeExceptionally(ex); + } + } + + private void handleEventSent(HistoryEvent e) { + // During replay, remove the pending action so we don't re-send already-processed + // events. This applies to entity operations (signal, call, lock, unlock) which + // now use SendEventAction, as well as regular sendEvent calls. + int taskId = e.getEventId(); + this.pendingActions.remove(taskId); + } + + // region Entity event handlers (Phase 4) + + private void handleEntityOperationSignaled(HistoryEvent e) { + int taskId = e.getEventId(); + OrchestratorAction taskAction = this.pendingActions.remove(taskId); + if (taskAction == null) { + String message = String.format( + "Non-deterministic orchestrator detected: a history event for entity signal with sequence ID %d was replayed but the current orchestrator implementation didn't schedule this signal.", + taskId); + throw new NonDeterministicOrchestratorException(message); + } + } + + private void handleEntityOperationCalled(HistoryEvent e) { + int taskId = e.getEventId(); + OrchestratorAction taskAction = this.pendingActions.remove(taskId); + if (taskAction == null) { + String message = String.format( + "Non-deterministic orchestrator detected: a history event for entity call with sequence ID %d was replayed but the current orchestrator implementation didn't schedule this call.", + taskId); + throw new NonDeterministicOrchestratorException(message); + } + } + + @SuppressWarnings("unchecked") + private void handleEntityOperationCompleted(HistoryEvent e) { + EntityOperationCompletedEvent completedEvent = e.getEntityOperationCompleted(); + String requestId = completedEvent.getRequestId(); + + Queue> outstandingQueue = this.outstandingEvents.get(requestId); + if (outstandingQueue == null) { + this.logger.warning("Discarding entity operation completed event with requestId=" + requestId + ": no waiter found. Outstanding event keys: " + this.outstandingEvents.keySet()); + return; + } + + TaskRecord record = outstandingQueue.remove(); + if (outstandingQueue.isEmpty()) { + this.outstandingEvents.remove(requestId); + } + + String rawResult = completedEvent.hasOutput() ? completedEvent.getOutput().getValue() : null; + + if (!this.isReplaying) { + this.logger.info(() -> String.format( + "%s: Entity operation completed for requestId=%s with output: %s", + this.instanceId, + requestId, + rawResult != null ? rawResult : "(null)")); + } + + CompletableTask task = record.getTask(); try { - Object result = this.dataConverter.deserialize( - rawResult, - matchingTaskRecord.getDataType()); + Object result = this.dataConverter.deserialize(rawResult, record.getDataType()); task.complete(result); } catch (Exception ex) { task.completeExceptionally(ex); } } + private void handleEntityOperationFailed(HistoryEvent e) { + EntityOperationFailedEvent failedEvent = e.getEntityOperationFailed(); + String requestId = failedEvent.getRequestId(); + + Queue> outstandingQueue = this.outstandingEvents.get(requestId); + if (outstandingQueue == null) { + this.logger.warning("Discarding entity operation failed event with requestId=" + requestId + ": no waiter found"); + return; + } + + TaskRecord record = outstandingQueue.remove(); + if (outstandingQueue.isEmpty()) { + this.outstandingEvents.remove(requestId); + } + + FailureDetails details = new FailureDetails(failedEvent.getFailureDetails()); + + if (!this.isReplaying) { + this.logger.info(() -> String.format( + "%s: Entity operation failed for requestId=%s: %s", + this.instanceId, + requestId, + details.getErrorMessage())); + } + + CompletableTask task = record.getTask(); + EntityInstanceId failedEntityId = record.getEntityId() != null + ? record.getEntityId() + : EntityInstanceId.fromString("@unknown@unknown"); + EntityOperationFailedException exception = new EntityOperationFailedException( + failedEntityId, + record.getTaskName(), + details); + task.completeExceptionally(exception); + } + + private void handleEntityLockRequested(HistoryEvent e) { + int taskId = e.getEventId(); + OrchestratorAction taskAction = this.pendingActions.remove(taskId); + if (taskAction == null) { + String message = String.format( + "Non-deterministic orchestrator detected: a history event for entity lock request with sequence ID %d was replayed but the current orchestrator implementation didn't issue this lock request.", + taskId); + throw new NonDeterministicOrchestratorException(message); + } + } + + @SuppressWarnings("unchecked") + private void handleEntityLockGranted(HistoryEvent e) { + EntityLockGrantedEvent lockGrantedEvent = e.getEntityLockGranted(); + String criticalSectionId = lockGrantedEvent.getCriticalSectionId(); + + Queue> outstandingQueue = this.outstandingEvents.get(criticalSectionId); + if (outstandingQueue == null) { + this.logger.warning("Discarding entity lock granted event with criticalSectionId=" + criticalSectionId + ": no waiter found"); + return; + } + + TaskRecord record = outstandingQueue.remove(); + if (outstandingQueue.isEmpty()) { + this.outstandingEvents.remove(criticalSectionId); + } + + this.isInCriticalSection = true; + this.currentCriticalSectionId = criticalSectionId; + this.lockedEntityIds = this.pendingLockSets.remove(criticalSectionId); + + if (!this.isReplaying) { + this.logger.fine(() -> String.format( + "%s: Entity lock granted for criticalSectionId=%s", + this.instanceId, + criticalSectionId)); + } + + CompletableTask task = record.getTask(); + task.complete(null); // The actual AutoCloseable is created via thenApply in lockEntities() + } + + private void handleEntityUnlockSent(HistoryEvent e) { + int taskId = e.getEventId(); + OrchestratorAction taskAction = this.pendingActions.remove(taskId); + if (taskAction == null) { + String message = String.format( + "Non-deterministic orchestrator detected: a history event for entity unlock with sequence ID %d was replayed but the current orchestrator implementation didn't issue this unlock.", + taskId); + throw new NonDeterministicOrchestratorException(message); + } + } + + // endregion + private void handleEventWhileSuspended (HistoryEvent historyEvent){ if (historyEvent.getEventTypeCase() != HistoryEvent.EventTypeCase.EXECUTIONSUSPENDED) { eventsWhileSuspended.offer(historyEvent); @@ -855,6 +1426,13 @@ private boolean processNextEvent() { } private void processEvent(HistoryEvent e) { + if (!this.isReplaying) { + this.logger.info(() -> String.format( + "%s: Processing new event: %s (eventId=%d)", + this.instanceId, + e.getEventTypeCase(), + e.getEventId())); + } boolean overrideSuspension = e.getEventTypeCase() == HistoryEvent.EventTypeCase.EXECUTIONRESUMED || e.getEventTypeCase() == HistoryEvent.EventTypeCase.EXECUTIONTERMINATED; if (this.isSuspended && !overrideSuspension) { this.handleEventWhileSuspended(e); @@ -873,6 +1451,9 @@ private void processEvent(HistoryEvent e) { this.setName(name); String instanceId = startedEvent.getOrchestrationInstance().getInstanceId(); this.setInstanceId(instanceId); + if (startedEvent.getOrchestrationInstance().hasExecutionId()) { + this.executionId = startedEvent.getOrchestrationInstance().getExecutionId().getValue(); + } String input = startedEvent.getInput().getValue(); this.setInput(input); String version = startedEvent.getVersion().getValue(); @@ -922,8 +1503,9 @@ private void processEvent(HistoryEvent e) { case SUBORCHESTRATIONINSTANCEFAILED: this.handleSubOrchestrationFailed(e); break; -// case EVENTSENT: -// break; + case EVENTSENT: + this.handleEventSent(e); + break; case EVENTRAISED: this.handleEventRaised(e); break; @@ -939,6 +1521,28 @@ private void processEvent(HistoryEvent e) { case EXECUTIONRESUMED: this.handleExecutionResumed(e); break; + // Entity event cases (Phase 4) + case ENTITYOPERATIONSIGNALED: + this.handleEntityOperationSignaled(e); + break; + case ENTITYOPERATIONCALLED: + this.handleEntityOperationCalled(e); + break; + case ENTITYOPERATIONCOMPLETED: + this.handleEntityOperationCompleted(e); + break; + case ENTITYOPERATIONFAILED: + this.handleEntityOperationFailed(e); + break; + case ENTITYLOCKREQUESTED: + this.handleEntityLockRequested(e); + break; + case ENTITYLOCKGRANTED: + this.handleEntityLockGranted(e); + break; + case ENTITYUNLOCKSENT: + this.handleEntityUnlockSent(e); + break; default: throw new IllegalStateException("Don't know how to handle history type " + e.getEventTypeCase()); } @@ -949,11 +1553,17 @@ private class TaskRecord { private final CompletableTask task; private final String taskName; private final Class dataType; + private final EntityInstanceId entityId; public TaskRecord(CompletableTask task, String taskName, Class dataType) { + this(task, taskName, dataType, null); + } + + public TaskRecord(CompletableTask task, String taskName, Class dataType, EntityInstanceId entityId) { this.task = task; this.taskName = taskName; this.dataType = dataType; + this.entityId = entityId; } public CompletableTask getTask() { @@ -967,6 +1577,10 @@ public String getTaskName() { public Class getDataType() { return this.dataType; } + + public EntityInstanceId getEntityId() { + return this.entityId; + } } private class OrchestrationHistoryIterator { @@ -1387,6 +2001,10 @@ protected void handleException(Throwable e) { throw (DataConverter.DataConverterException)e; } + if (e instanceof EntityOperationFailedException) { + throw (EntityOperationFailedException)e; + } + throw new RuntimeException("Unexpected failure in the task execution", e); } diff --git a/client/src/main/java/com/microsoft/durabletask/TypedEntityMetadata.java b/client/src/main/java/com/microsoft/durabletask/TypedEntityMetadata.java new file mode 100644 index 00000000..9fcd8b32 --- /dev/null +++ b/client/src/main/java/com/microsoft/durabletask/TypedEntityMetadata.java @@ -0,0 +1,74 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.microsoft.durabletask; + +import javax.annotation.Nullable; + +/** + * An extension of {@link EntityMetadata} that provides typed access to the entity's state. + *

+ * This mirrors the .NET SDK's {@code EntityMetadata} which inherits from {@code EntityMetadata} + * and provides a typed {@code State} property. In Java, the state is eagerly deserialized and accessible + * via {@link #getState()}. + * + *

Example:

+ *
{@code
+ * TypedEntityMetadata metadata = client.getEntities()
+ *     .getEntityMetadata(entityId, Integer.class);
+ * if (metadata != null) {
+ *     Integer state = metadata.getState();
+ *     System.out.println("Counter value: " + state);
+ * }
+ * }
+ * + * @param the type of the entity's state + * @see EntityMetadata + * @see DurableEntityClient#getEntityMetadata(EntityInstanceId, Class) + */ +public final class TypedEntityMetadata extends EntityMetadata { + + private final T state; + private final Class stateType; + + /** + * Creates a new {@code TypedEntityMetadata} from an existing {@link EntityMetadata} and a state type. + *

+ * The state is eagerly deserialized from the metadata's serialized state. + * + * @param source the source metadata to wrap + * @param stateType the class to deserialize the state into + */ + TypedEntityMetadata(EntityMetadata source, Class stateType) { + super( + source.getInstanceId(), + source.getLastModifiedTime(), + source.getBacklogQueueSize(), + source.getLockedBy(), + source.getSerializedState(), + source.isIncludesState(), + source.getDataConverter()); + this.stateType = stateType; + this.state = source.readStateAs(stateType); + } + + /** + * Gets the deserialized entity state. + *

+ * Returns {@code null} if the entity has no state or if state was not included in the query. + * + * @return the deserialized state, or {@code null} + */ + @Nullable + public T getState() { + return this.state; + } + + /** + * Gets the state type class used for deserialization. + * + * @return the state type class + */ + public Class getStateType() { + return this.stateType; + } +} diff --git a/client/src/main/java/com/microsoft/durabletask/TypedEntityQueryPageable.java b/client/src/main/java/com/microsoft/durabletask/TypedEntityQueryPageable.java new file mode 100644 index 00000000..a8a7c1a9 --- /dev/null +++ b/client/src/main/java/com/microsoft/durabletask/TypedEntityQueryPageable.java @@ -0,0 +1,74 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.microsoft.durabletask; + +import java.util.Iterator; +import java.util.NoSuchElementException; + +/** + * An auto-paginating iterable over entity query results with typed state access. + *

+ * This class wraps an {@link EntityQueryPageable} and yields {@link TypedEntityMetadata} items + * with eagerly deserialized state, mirroring the .NET SDK's {@code AsyncPageable>} + * returned by {@code GetAllEntitiesAsync()}. + *

+ * Use {@link DurableEntityClient#getAllEntities(EntityQuery, Class)} to obtain an instance. + * + *

Example:

+ *
{@code
+ * EntityQuery query = new EntityQuery().setInstanceIdStartsWith("counter");
+ * for (TypedEntityMetadata entity : client.getEntities().getAllEntities(query, Integer.class)) {
+ *     Integer state = entity.getState();
+ *     System.out.println("Counter value: " + state);
+ * }
+ * }
+ * + * @param the entity state type + */ +public final class TypedEntityQueryPageable implements Iterable> { + private final EntityQueryPageable inner; + private final Class stateType; + + /** + * Creates a new {@code TypedEntityQueryPageable}. + * + * @param inner the underlying pageable that fetches raw entity metadata + * @param stateType the class to deserialize each entity's state into + */ + TypedEntityQueryPageable(EntityQueryPageable inner, Class stateType) { + this.inner = inner; + this.stateType = stateType; + } + + /** + * Returns an iterator over individual {@link TypedEntityMetadata} items with eagerly + * deserialized state, automatically fetching subsequent pages as needed. + * + * @return an iterator over all matching entities with typed state + */ + @Override + public Iterator> iterator() { + return new TypedEntityItemIterator(inner.iterator()); + } + + private class TypedEntityItemIterator implements Iterator> { + private final Iterator delegate; + + TypedEntityItemIterator(Iterator delegate) { + this.delegate = delegate; + } + + @Override + public boolean hasNext() { + return delegate.hasNext(); + } + + @Override + public TypedEntityMetadata next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + return new TypedEntityMetadata<>(delegate.next(), stateType); + } + } +} diff --git a/client/src/test/java/com/microsoft/durabletask/CleanEntityStorageRequestTest.java b/client/src/test/java/com/microsoft/durabletask/CleanEntityStorageRequestTest.java new file mode 100644 index 00000000..58155968 --- /dev/null +++ b/client/src/test/java/com/microsoft/durabletask/CleanEntityStorageRequestTest.java @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.microsoft.durabletask; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Unit tests for {@link CleanEntityStorageRequest}. + */ +public class CleanEntityStorageRequestTest { + + @Test + void defaults_matchDotNetDefaults() { + CleanEntityStorageRequest request = new CleanEntityStorageRequest(); + assertNull(request.getContinuationToken()); + assertTrue(request.isRemoveEmptyEntities()); + assertTrue(request.isReleaseOrphanedLocks()); + assertFalse(request.isContinueUntilComplete()); + } + + @Test + void setContinuationToken_roundTrip() { + CleanEntityStorageRequest request = new CleanEntityStorageRequest() + .setContinuationToken("token123"); + assertEquals("token123", request.getContinuationToken()); + } + + @Test + void setContinuationToken_null_allowed() { + CleanEntityStorageRequest request = new CleanEntityStorageRequest() + .setContinuationToken("token") + .setContinuationToken(null); + assertNull(request.getContinuationToken()); + } + + @Test + void setRemoveEmptyEntities_roundTrip() { + CleanEntityStorageRequest request = new CleanEntityStorageRequest() + .setRemoveEmptyEntities(true); + assertTrue(request.isRemoveEmptyEntities()); + } + + @Test + void setReleaseOrphanedLocks_roundTrip() { + CleanEntityStorageRequest request = new CleanEntityStorageRequest() + .setReleaseOrphanedLocks(true); + assertTrue(request.isReleaseOrphanedLocks()); + } + + @Test + void setContinueUntilComplete_roundTrip() { + CleanEntityStorageRequest request = new CleanEntityStorageRequest() + .setContinueUntilComplete(true); + assertTrue(request.isContinueUntilComplete()); + } + + @Test + void fluentChaining_returnsSameInstance() { + CleanEntityStorageRequest request = new CleanEntityStorageRequest(); + assertSame(request, request.setContinuationToken("t")); + assertSame(request, request.setRemoveEmptyEntities(true)); + assertSame(request, request.setReleaseOrphanedLocks(true)); + assertSame(request, request.setContinueUntilComplete(true)); + } +} diff --git a/client/src/test/java/com/microsoft/durabletask/CleanEntityStorageResultTest.java b/client/src/test/java/com/microsoft/durabletask/CleanEntityStorageResultTest.java new file mode 100644 index 00000000..90b5ebd2 --- /dev/null +++ b/client/src/test/java/com/microsoft/durabletask/CleanEntityStorageResultTest.java @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.microsoft.durabletask; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Unit tests for {@link CleanEntityStorageResult}. + */ +public class CleanEntityStorageResultTest { + + @Test + void constructor_setsAllFields() { + CleanEntityStorageResult result = new CleanEntityStorageResult("nextToken", 5, 3); + assertEquals("nextToken", result.getContinuationToken()); + assertEquals(5, result.getEmptyEntitiesRemoved()); + assertEquals(3, result.getOrphanedLocksReleased()); + } + + @Test + void constructor_nullContinuationToken_allowed() { + CleanEntityStorageResult result = new CleanEntityStorageResult(null, 0, 0); + assertNull(result.getContinuationToken()); + assertEquals(0, result.getEmptyEntitiesRemoved()); + assertEquals(0, result.getOrphanedLocksReleased()); + } + + @Test + void constructor_zeroCounts() { + CleanEntityStorageResult result = new CleanEntityStorageResult("token", 0, 0); + assertEquals("token", result.getContinuationToken()); + assertEquals(0, result.getEmptyEntitiesRemoved()); + assertEquals(0, result.getOrphanedLocksReleased()); + } +} diff --git a/client/src/test/java/com/microsoft/durabletask/EntityInstanceIdTest.java b/client/src/test/java/com/microsoft/durabletask/EntityInstanceIdTest.java new file mode 100644 index 00000000..c5083a6e --- /dev/null +++ b/client/src/test/java/com/microsoft/durabletask/EntityInstanceIdTest.java @@ -0,0 +1,215 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.microsoft.durabletask; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Unit tests for {@link EntityInstanceId}. + */ +public class EntityInstanceIdTest { + + @Test + void constructor_validNameAndKey() { + EntityInstanceId id = new EntityInstanceId("Counter", "myCounter"); + assertEquals("counter", id.getName()); + assertEquals("myCounter", id.getKey()); + } + + @Test + void constructor_nullName_throwsException() { + assertThrows(IllegalArgumentException.class, () -> new EntityInstanceId(null, "key")); + } + + @Test + void constructor_emptyName_throwsException() { + assertThrows(IllegalArgumentException.class, () -> new EntityInstanceId("", "key")); + } + + @Test + void constructor_nullKey_throwsException() { + assertThrows(IllegalArgumentException.class, () -> new EntityInstanceId("name", null)); + } + + @Test + void constructor_emptyKey_throwsException() { + assertThrows(IllegalArgumentException.class, () -> new EntityInstanceId("name", "")); + } + + @Test + void toString_format() { + EntityInstanceId id = new EntityInstanceId("Counter", "myCounter"); + assertEquals("@counter@myCounter", id.toString()); + } + + @Test + void fromString_validFormat() { + EntityInstanceId id = EntityInstanceId.fromString("@Counter@myCounter"); + assertEquals("counter", id.getName()); + assertEquals("myCounter", id.getKey()); + } + + @Test + void fromString_keyContainsAtSymbol() { + // The key can contain @ symbols — only the first two @ delimiters matter + EntityInstanceId id = EntityInstanceId.fromString("@Counter@key@with@ats"); + assertEquals("counter", id.getName()); + assertEquals("key@with@ats", id.getKey()); + } + + @Test + void fromString_null_throwsException() { + assertThrows(IllegalArgumentException.class, () -> EntityInstanceId.fromString(null)); + } + + @Test + void fromString_empty_throwsException() { + assertThrows(IllegalArgumentException.class, () -> EntityInstanceId.fromString("")); + } + + @Test + void fromString_noLeadingAt_throwsException() { + assertThrows(IllegalArgumentException.class, () -> EntityInstanceId.fromString("Counter@key")); + } + + @Test + void fromString_onlyOneAt_throwsException() { + assertThrows(IllegalArgumentException.class, () -> EntityInstanceId.fromString("@Counter")); + } + + @Test + void roundTrip_toStringAndFromString() { + EntityInstanceId original = new EntityInstanceId("BankAccount", "acct-123"); + EntityInstanceId parsed = EntityInstanceId.fromString(original.toString()); + assertEquals(original, parsed); + } + + @Test + void equals_sameValues_areEqual() { + EntityInstanceId id1 = new EntityInstanceId("Counter", "c1"); + EntityInstanceId id2 = new EntityInstanceId("Counter", "c1"); + assertEquals(id1, id2); + assertEquals(id1.hashCode(), id2.hashCode()); + } + + @Test + void equals_differentName_notEqual() { + EntityInstanceId id1 = new EntityInstanceId("Counter", "c1"); + EntityInstanceId id2 = new EntityInstanceId("Timer", "c1"); + assertNotEquals(id1, id2); + } + + @Test + void equals_differentKey_notEqual() { + EntityInstanceId id1 = new EntityInstanceId("Counter", "c1"); + EntityInstanceId id2 = new EntityInstanceId("Counter", "c2"); + assertNotEquals(id1, id2); + } + + @Test + void equals_null_notEqual() { + EntityInstanceId id = new EntityInstanceId("Counter", "c1"); + assertNotEquals(null, id); + } + + @Test + void equals_differentType_notEqual() { + EntityInstanceId id = new EntityInstanceId("Counter", "c1"); + assertNotEquals("@counter@c1", id); + } + + @Test + void constructor_nameIsLowercased() { + EntityInstanceId id = new EntityInstanceId("Counter", "c1"); + assertEquals("counter", id.getName()); + } + + @Test + void equals_differentCaseName_areEqual() { + EntityInstanceId id1 = new EntityInstanceId("Counter", "c1"); + EntityInstanceId id2 = new EntityInstanceId("counter", "c1"); + assertEquals(id1, id2); + assertEquals(id1.hashCode(), id2.hashCode()); + } + + @Test + void equals_mixedCaseName_areEqual() { + EntityInstanceId id1 = new EntityInstanceId("COUNTER", "c1"); + EntityInstanceId id2 = new EntityInstanceId("counter", "c1"); + assertEquals(id1, id2); + } + + @Test + void toString_nameIsLowercased() { + EntityInstanceId id = new EntityInstanceId("MyEntity", "key1"); + assertEquals("@myentity@key1", id.toString()); + } + + @Test + void fromString_nameIsLowercased() { + EntityInstanceId id = EntityInstanceId.fromString("@MyEntity@key1"); + assertEquals("myentity", id.getName()); + } + + @Test + void constructor_nameContainsAt_throwsException() { + assertThrows(IllegalArgumentException.class, () -> new EntityInstanceId("my@entity", "key")); + } + + @Test + void constructor_nameStartsWithAt_throwsException() { + assertThrows(IllegalArgumentException.class, () -> new EntityInstanceId("@entity", "key")); + } + + @Test + void constructor_nameEndsWithAt_throwsException() { + assertThrows(IllegalArgumentException.class, () -> new EntityInstanceId("entity@", "key")); + } + + @Test + void compareTo_ordering() { + EntityInstanceId a = new EntityInstanceId("A", "1"); + EntityInstanceId b = new EntityInstanceId("B", "1"); + EntityInstanceId a2 = new EntityInstanceId("A", "2"); + + // Same values + assertEquals(0, a.compareTo(new EntityInstanceId("A", "1"))); + + // Same values different case + assertEquals(0, a.compareTo(new EntityInstanceId("a", "1"))); + + // Sort by name first + assertTrue(a.compareTo(b) < 0); + assertTrue(b.compareTo(a) > 0); + + // Same name, sort by key + assertTrue(a.compareTo(a2) < 0); + assertTrue(a2.compareTo(a) > 0); + } + + @Test + void compareTo_sortsList() { + EntityInstanceId c1 = new EntityInstanceId("Counter", "1"); + EntityInstanceId b2 = new EntityInstanceId("BankAccount", "2"); + EntityInstanceId c3 = new EntityInstanceId("Counter", "3"); + EntityInstanceId a1 = new EntityInstanceId("Account", "1"); + + List ids = Arrays.asList(c1, b2, c3, a1); + Collections.sort(ids); + + assertEquals("account", ids.get(0).getName()); + assertEquals("bankaccount", ids.get(1).getName()); + assertEquals("counter", ids.get(2).getName()); + assertEquals("1", ids.get(2).getKey()); + assertEquals("counter", ids.get(3).getName()); + assertEquals("3", ids.get(3).getKey()); + } +} diff --git a/client/src/test/java/com/microsoft/durabletask/EntityIntegrationTests.java b/client/src/test/java/com/microsoft/durabletask/EntityIntegrationTests.java new file mode 100644 index 00000000..39246732 --- /dev/null +++ b/client/src/test/java/com/microsoft/durabletask/EntityIntegrationTests.java @@ -0,0 +1,271 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.microsoft.durabletask; + +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; + +import java.time.Duration; +import java.util.Arrays; +import java.util.concurrent.TimeoutException; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Integration tests for durable entity features. + *

+ * These tests require a sidecar process running on {@code localhost:4001}. + * They exercise the complete entity workflow: client signaling, entity execution, + * state persistence, and orchestration-entity interaction. + */ +@Tag("integration") +public class EntityIntegrationTests extends IntegrationTestBase { + + // region Test entity classes + + /** + * A simple counter entity for integration testing. + */ + static class CounterEntity extends TaskEntity { + public void add(int amount) { + this.state += amount; + } + + public void reset() { + this.state = 0; + } + + public int get() { + return this.state; + } + + @Override + protected Integer initializeState(TaskEntityOperation operation) { + return 0; + } + + @Override + protected Class getStateType() { + return Integer.class; + } + } + + // endregion + + // region Signal entity + state query tests + + @Test + void signalEntityAndGetState() throws TimeoutException, InterruptedException { + final String entityName = "Counter"; + EntityInstanceId entityId = new EntityInstanceId(entityName, "counter-signal-test"); + + this.createWorkerBuilder() + .addEntity(entityName, CounterEntity::new) + .buildAndStart(); + + DurableTaskClient client = this.createClientBuilder().build(); + + // Signal the entity to add 10 + client.getEntities().signalEntity(entityId, "add", 10); + + // Wait for the entity to process the signal + Thread.sleep(5000); + + // Query entity state + EntityMetadata metadata = client.getEntities().getEntityMetadata(entityId, true); + assertNotNull(metadata, "Entity metadata should not be null"); + + Integer state = metadata.readStateAs(Integer.class); + assertNotNull(state, "Entity state should not be null"); + assertEquals(10, state, "Entity state should be 10 after adding 10"); + } + + @Test + void signalEntityMultipleSignals() throws TimeoutException, InterruptedException { + final String entityName = "Counter"; + EntityInstanceId entityId = new EntityInstanceId(entityName, "counter-multi-signal"); + + this.createWorkerBuilder() + .addEntity(entityName, CounterEntity::new) + .buildAndStart(); + + DurableTaskClient client = this.createClientBuilder().build(); + + // Send multiple signals + client.getEntities().signalEntity(entityId, "add", 5); + client.getEntities().signalEntity(entityId, "add", 3); + client.getEntities().signalEntity(entityId, "add", 2); + + // Wait for all signals to be processed + Thread.sleep(8000); + + EntityMetadata metadata = client.getEntities().getEntityMetadata(entityId, true); + assertNotNull(metadata); + + Integer state = metadata.readStateAs(Integer.class); + assertNotNull(state); + assertEquals(10, state, "Entity state should be 10 after adding 5+3+2"); + } + + @Test + void signalEntityResetAndGet() throws TimeoutException, InterruptedException { + final String entityName = "Counter"; + EntityInstanceId entityId = new EntityInstanceId(entityName, "counter-reset"); + + this.createWorkerBuilder() + .addEntity(entityName, CounterEntity::new) + .buildAndStart(); + + DurableTaskClient client = this.createClientBuilder().build(); + + // Add, then reset + client.getEntities().signalEntity(entityId, "add", 42); + Thread.sleep(3000); + client.getEntities().signalEntity(entityId, "reset"); + Thread.sleep(5000); + + EntityMetadata metadata = client.getEntities().getEntityMetadata(entityId, true); + assertNotNull(metadata); + + Integer state = metadata.readStateAs(Integer.class); + assertNotNull(state); + assertEquals(0, state, "Entity state should be 0 after reset"); + } + + // endregion + + // region Orchestration ↔ entity tests + + @Test + void orchestrationCallsEntity() throws TimeoutException, InterruptedException { + final String entityName = "Counter"; + final String orchestratorName = "CallEntityOrchestration"; + EntityInstanceId entityId = new EntityInstanceId(entityName, "counter-orch-call"); + + this.createWorkerBuilder() + .addOrchestrator(orchestratorName, ctx -> { + // Signal the entity to set up some state + ctx.signalEntity(entityId, "add", 42); + // Wait for the signal to be processed + ctx.createTimer(Duration.ofSeconds(3)).await(); + // Then call to get the value + int value = ctx.callEntity(entityId, "get", null, int.class).await(); + ctx.complete(value); + }) + .addEntity(entityName, CounterEntity::new) + .buildAndStart(); + + DurableTaskClient client = this.createClientBuilder().build(); + + String instanceId = client.scheduleNewOrchestrationInstance(orchestratorName); + OrchestrationMetadata instance = client.waitForInstanceCompletion( + instanceId, defaultTimeout, true); + + assertNotNull(instance); + assertEquals(OrchestrationRuntimeStatus.COMPLETED, instance.getRuntimeStatus()); + assertEquals(42, instance.readOutputAs(int.class)); + } + + @Test + void orchestrationSignalsEntityFireAndForget() throws TimeoutException, InterruptedException { + final String entityName = "Counter"; + final String orchestratorName = "SignalEntityOrchestration"; + EntityInstanceId entityId = new EntityInstanceId(entityName, "counter-orch-signal"); + + this.createWorkerBuilder() + .addOrchestrator(orchestratorName, ctx -> { + ctx.signalEntity(entityId, "add", 100); + ctx.complete("signaled"); + }) + .addEntity(entityName, CounterEntity::new) + .buildAndStart(); + + DurableTaskClient client = this.createClientBuilder().build(); + + String instanceId = client.scheduleNewOrchestrationInstance(orchestratorName); + OrchestrationMetadata instance = client.waitForInstanceCompletion( + instanceId, defaultTimeout, true); + + assertNotNull(instance); + assertEquals(OrchestrationRuntimeStatus.COMPLETED, instance.getRuntimeStatus()); + assertEquals("signaled", instance.readOutputAs(String.class)); + + // Wait for the signal to be processed + Thread.sleep(5000); + + // Verify entity state was updated + EntityMetadata metadata = client.getEntities().getEntityMetadata(entityId, true); + assertNotNull(metadata); + Integer state = metadata.readStateAs(Integer.class); + assertNotNull(state); + assertEquals(100, state); + } + + // endregion + + // region Case-insensitive entity name tests + + @Test + void signalEntity_caseInsensitiveName_entityProcessesSignal() throws TimeoutException, InterruptedException { + final String entityName = "Counter"; + // Use mixed case for the entity ID — should still work since names are lowercased + EntityInstanceId entityId = new EntityInstanceId("COUNTER", "counter-case-test"); + + this.createWorkerBuilder() + .addEntity(entityName, CounterEntity::new) + .buildAndStart(); + + DurableTaskClient client = this.createClientBuilder().build(); + + client.getEntities().signalEntity(entityId, "add", 25); + Thread.sleep(5000); + + EntityMetadata metadata = client.getEntities().getEntityMetadata(entityId, true); + assertNotNull(metadata, "Entity metadata should not be null"); + + Integer state = metadata.readStateAs(Integer.class); + assertNotNull(state, "Entity state should not be null"); + assertEquals(25, state, "Entity state should be 25 after adding 25"); + } + + // endregion + + // region Lock entities + getLockedEntities tests + + @Test + void orchestration_lockAndCallEntity_succeeds() throws TimeoutException, InterruptedException { + final String entityName = "Counter"; + final String orchestratorName = "LockEntityOrchestration"; + EntityInstanceId entityId = new EntityInstanceId(entityName, "counter-lock-test"); + + this.createWorkerBuilder() + .addOrchestrator(orchestratorName, ctx -> { + AutoCloseable lock = ctx.lockEntities(Arrays.asList(entityId)).await(); + // Verify we are in a critical section with locked entities + assertTrue(ctx.isInCriticalSection()); + assertFalse(ctx.getLockedEntities().isEmpty()); + // Call the locked entity + ctx.signalEntity(entityId, "add", 10); + try { + lock.close(); + } catch (Exception e) { + throw new RuntimeException(e); + } + ctx.complete("lock-success"); + }) + .addEntity(entityName, CounterEntity::new) + .buildAndStart(); + + DurableTaskClient client = this.createClientBuilder().build(); + + String instanceId = client.scheduleNewOrchestrationInstance(orchestratorName); + OrchestrationMetadata instance = client.waitForInstanceCompletion( + instanceId, defaultTimeout, true); + + assertNotNull(instance); + assertEquals(OrchestrationRuntimeStatus.COMPLETED, instance.getRuntimeStatus()); + assertEquals("lock-success", instance.readOutputAs(String.class)); + } + + // endregion +} diff --git a/client/src/test/java/com/microsoft/durabletask/EntityMetadataTest.java b/client/src/test/java/com/microsoft/durabletask/EntityMetadataTest.java new file mode 100644 index 00000000..41d5576c --- /dev/null +++ b/client/src/test/java/com/microsoft/durabletask/EntityMetadataTest.java @@ -0,0 +1,96 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.microsoft.durabletask; + +import org.junit.jupiter.api.Test; + +import java.time.Instant; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Unit tests for {@link EntityMetadata}. + */ +public class EntityMetadataTest { + + private static final DataConverter dataConverter = new JacksonDataConverter(); + + @Test + void getters_returnConstructorValues() { + Instant now = Instant.now(); + EntityMetadata metadata = new EntityMetadata( + "@counter@myKey", now, 5, "orch-123", "42", true, dataConverter); + + assertEquals("@counter@myKey", metadata.getInstanceId()); + assertEquals(now, metadata.getLastModifiedTime()); + assertEquals(5, metadata.getBacklogQueueSize()); + assertEquals("orch-123", metadata.getLockedBy()); + assertEquals("42", metadata.getSerializedState()); + } + + @Test + void getEntityInstanceId_parsesCorrectly() { + EntityMetadata metadata = new EntityMetadata( + "@counter@myKey", Instant.EPOCH, 0, null, null, false, dataConverter); + + EntityInstanceId entityId = metadata.getEntityInstanceId(); + assertEquals("counter", entityId.getName()); + assertEquals("myKey", entityId.getKey()); + } + + @Test + void readStateAs_deserializesIntegerState() { + EntityMetadata metadata = new EntityMetadata( + "@counter@c1", Instant.EPOCH, 0, null, "42", true, dataConverter); + + Integer state = metadata.readStateAs(Integer.class); + assertNotNull(state); + assertEquals(42, state); + } + + @Test + void readStateAs_deserializesStringState() { + EntityMetadata metadata = new EntityMetadata( + "@myEntity@k1", Instant.EPOCH, 0, null, "\"hello\"", true, dataConverter); + + String state = metadata.readStateAs(String.class); + assertEquals("hello", state); + } + + @Test + void readStateAs_nullState_returnsNull() { + EntityMetadata metadata = new EntityMetadata( + "@counter@c1", Instant.EPOCH, 0, null, null, false, dataConverter); + + Integer state = metadata.readStateAs(Integer.class); + assertNull(state); + } + + @Test + void lockedBy_nullWhenNotLocked() { + EntityMetadata metadata = new EntityMetadata( + "@counter@c1", Instant.EPOCH, 0, null, null, false, dataConverter); + assertNull(metadata.getLockedBy()); + } + + @Test + void backlogQueueSize_zero() { + EntityMetadata metadata = new EntityMetadata( + "@counter@c1", Instant.EPOCH, 0, null, null, false, dataConverter); + assertEquals(0, metadata.getBacklogQueueSize()); + } + + @Test + void includesState_true_whenStateIncluded() { + EntityMetadata metadata = new EntityMetadata( + "@counter@c1", Instant.EPOCH, 0, null, "42", true, dataConverter); + assertTrue(metadata.isIncludesState()); + } + + @Test + void includesState_false_whenStateNotIncluded() { + EntityMetadata metadata = new EntityMetadata( + "@counter@c1", Instant.EPOCH, 0, null, null, false, dataConverter); + assertFalse(metadata.isIncludesState()); + } +} diff --git a/client/src/test/java/com/microsoft/durabletask/EntityOptionsTest.java b/client/src/test/java/com/microsoft/durabletask/EntityOptionsTest.java new file mode 100644 index 00000000..9e579718 --- /dev/null +++ b/client/src/test/java/com/microsoft/durabletask/EntityOptionsTest.java @@ -0,0 +1,85 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.microsoft.durabletask; + +import org.junit.jupiter.api.Test; + +import java.time.Duration; +import java.time.Instant; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Unit tests for {@link SignalEntityOptions} and {@link CallEntityOptions}. + */ +public class EntityOptionsTest { + + // region SignalEntityOptions tests + + @Test + void signalEntityOptions_default_scheduledTimeIsNull() { + SignalEntityOptions options = new SignalEntityOptions(); + assertNull(options.getScheduledTime()); + } + + @Test + void signalEntityOptions_setScheduledTime_roundTrip() { + Instant time = Instant.parse("2025-06-15T10:00:00Z"); + SignalEntityOptions options = new SignalEntityOptions().setScheduledTime(time); + assertEquals(time, options.getScheduledTime()); + } + + @Test + void signalEntityOptions_setScheduledTime_null_resetsToNull() { + Instant time = Instant.parse("2025-06-15T10:00:00Z"); + SignalEntityOptions options = new SignalEntityOptions() + .setScheduledTime(time) + .setScheduledTime(null); + assertNull(options.getScheduledTime()); + } + + @Test + void signalEntityOptions_fluentChaining_returnsSameInstance() { + SignalEntityOptions options = new SignalEntityOptions(); + assertSame(options, options.setScheduledTime(Instant.now())); + } + + // endregion + + // region CallEntityOptions tests + + @Test + void callEntityOptions_default_timeoutIsNull() { + CallEntityOptions options = new CallEntityOptions(); + assertNull(options.getTimeout()); + } + + @Test + void callEntityOptions_setTimeout_roundTrip() { + Duration timeout = Duration.ofSeconds(30); + CallEntityOptions options = new CallEntityOptions().setTimeout(timeout); + assertEquals(timeout, options.getTimeout()); + } + + @Test + void callEntityOptions_setTimeout_null_resetsToNull() { + CallEntityOptions options = new CallEntityOptions() + .setTimeout(Duration.ofSeconds(30)) + .setTimeout(null); + assertNull(options.getTimeout()); + } + + @Test + void callEntityOptions_setTimeout_zero_allowed() { + CallEntityOptions options = new CallEntityOptions().setTimeout(Duration.ZERO); + assertEquals(Duration.ZERO, options.getTimeout()); + } + + @Test + void callEntityOptions_fluentChaining_returnsSameInstance() { + CallEntityOptions options = new CallEntityOptions(); + assertSame(options, options.setTimeout(Duration.ofMinutes(5))); + } + + // endregion +} diff --git a/client/src/test/java/com/microsoft/durabletask/EntityQueryPageableTest.java b/client/src/test/java/com/microsoft/durabletask/EntityQueryPageableTest.java new file mode 100644 index 00000000..7c14198c --- /dev/null +++ b/client/src/test/java/com/microsoft/durabletask/EntityQueryPageableTest.java @@ -0,0 +1,183 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.microsoft.durabletask; + +import org.junit.jupiter.api.Test; + +import java.time.Instant; +import java.util.*; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Unit tests for {@link EntityQueryPageable}. + */ +public class EntityQueryPageableTest { + + private static final DataConverter dataConverter = new JacksonDataConverter(); + + private EntityMetadata makeEntity(String name, String key) { + return new EntityMetadata( + "@" + name + "@" + key, Instant.EPOCH, 0, null, null, false, dataConverter); + } + + @Test + void iterator_emptyResult_yieldsNoItems() { + EntityQuery query = new EntityQuery(); + EntityQueryPageable pageable = new EntityQueryPageable(query, q -> { + return new EntityQueryResult(Collections.emptyList(), null); + }); + + Iterator it = pageable.iterator(); + assertFalse(it.hasNext()); + } + + @Test + void iterator_singlePage_yieldsAllItems() { + List entities = Arrays.asList( + makeEntity("counter", "a"), + makeEntity("counter", "b"), + makeEntity("counter", "c")); + + EntityQuery query = new EntityQuery(); + EntityQueryPageable pageable = new EntityQueryPageable(query, q -> { + return new EntityQueryResult(entities, null); + }); + + List collected = new ArrayList<>(); + for (EntityMetadata e : pageable) { + collected.add(e); + } + assertEquals(3, collected.size()); + assertEquals("@counter@a", collected.get(0).getInstanceId()); + assertEquals("@counter@b", collected.get(1).getInstanceId()); + assertEquals("@counter@c", collected.get(2).getInstanceId()); + } + + @Test + void iterator_multiplePages_yieldsAllItems() { + List page1 = Arrays.asList( + makeEntity("counter", "a"), + makeEntity("counter", "b")); + List page2 = Arrays.asList( + makeEntity("counter", "c")); + + EntityQuery query = new EntityQuery(); + final int[] callCount = {0}; + EntityQueryPageable pageable = new EntityQueryPageable(query, q -> { + callCount[0]++; + if (callCount[0] == 1) { + return new EntityQueryResult(page1, "token1"); + } else { + return new EntityQueryResult(page2, null); + } + }); + + List collected = new ArrayList<>(); + for (EntityMetadata e : pageable) { + collected.add(e); + } + assertEquals(3, collected.size()); + assertEquals("@counter@a", collected.get(0).getInstanceId()); + assertEquals("@counter@c", collected.get(2).getInstanceId()); + } + + @Test + void iterator_propagatesContinuationToken() { + EntityQuery query = new EntityQuery(); + List tokensReceived = new ArrayList<>(); + + EntityQueryPageable pageable = new EntityQueryPageable(query, q -> { + tokensReceived.add(q.getContinuationToken()); + if (tokensReceived.size() == 1) { + return new EntityQueryResult( + Collections.singletonList(makeEntity("e", "1")), "pageToken"); + } else { + return new EntityQueryResult( + Collections.singletonList(makeEntity("e", "2")), null); + } + }); + + Iterator iterator = pageable.iterator(); + while (iterator.hasNext()) { + iterator.next(); + } + + assertEquals(2, tokensReceived.size()); + assertEquals("pageToken", tokensReceived.get(1)); + } + + @Test + void byPage_emptyResult_yieldsSinglePage() { + EntityQuery query = new EntityQuery(); + EntityQueryPageable pageable = new EntityQueryPageable(query, q -> { + return new EntityQueryResult(Collections.emptyList(), null); + }); + + List pages = new ArrayList<>(); + for (EntityQueryResult page : pageable.byPage()) { + pages.add(page); + } + assertEquals(1, pages.size()); + assertTrue(pages.get(0).getEntities().isEmpty()); + } + + @Test + void byPage_multiplePages_yieldsAllPages() { + List page1 = Arrays.asList( + makeEntity("counter", "a"), + makeEntity("counter", "b")); + List page2 = Collections.singletonList( + makeEntity("counter", "c")); + + EntityQuery query = new EntityQuery(); + final int[] callCount = {0}; + EntityQueryPageable pageable = new EntityQueryPageable(query, q -> { + callCount[0]++; + if (callCount[0] == 1) { + return new EntityQueryResult(page1, "nextToken"); + } else { + return new EntityQueryResult(page2, null); + } + }); + + List pages = new ArrayList<>(); + for (EntityQueryResult page : pageable.byPage()) { + pages.add(page); + } + assertEquals(2, pages.size()); + assertEquals(2, pages.get(0).getEntities().size()); + assertEquals(1, pages.get(1).getEntities().size()); + } + + @Test + void byPage_preservesQueryParameters() { + EntityQuery query = new EntityQuery() + .setInstanceIdStartsWith("counter") + .setIncludeState(true) + .setPageSize(10); + + EntityQueryPageable pageable = new EntityQueryPageable(query, q -> { + // Verify query parameters are preserved (instanceIdStartsWith is normalized to @counter) + assertEquals("@counter", q.getInstanceIdStartsWith()); + assertTrue(q.isIncludeState()); + assertEquals(10, q.getPageSize()); + return new EntityQueryResult(Collections.emptyList(), null); + }); + + for (EntityQueryResult ignored : pageable.byPage()) { + // consume + } + } + + @Test + void iterator_nextWithoutHasNext_throwsWhenExhausted() { + EntityQuery query = new EntityQuery(); + EntityQueryPageable pageable = new EntityQueryPageable(query, q -> { + return new EntityQueryResult(Collections.emptyList(), null); + }); + + Iterator it = pageable.iterator(); + assertThrows(NoSuchElementException.class, it::next); + } +} diff --git a/client/src/test/java/com/microsoft/durabletask/EntityQueryResultTest.java b/client/src/test/java/com/microsoft/durabletask/EntityQueryResultTest.java new file mode 100644 index 00000000..fbd82360 --- /dev/null +++ b/client/src/test/java/com/microsoft/durabletask/EntityQueryResultTest.java @@ -0,0 +1,49 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.microsoft.durabletask; + +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Unit tests for {@link EntityQueryResult}. + */ +public class EntityQueryResultTest { + + private static final DataConverter dataConverter = new JacksonDataConverter(); + + @Test + void constructor_setsEntitiesAndToken() { + List entities = Arrays.asList( + new EntityMetadata("@counter@c1", java.time.Instant.EPOCH, 0, null, null, false, dataConverter), + new EntityMetadata("@counter@c2", java.time.Instant.EPOCH, 0, null, null, false, dataConverter)); + + EntityQueryResult result = new EntityQueryResult(entities, "nextPage"); + + assertEquals(2, result.getEntities().size()); + assertEquals("nextPage", result.getContinuationToken()); + } + + @Test + void constructor_nullContinuationToken_meansNoMorePages() { + EntityQueryResult result = new EntityQueryResult(Collections.emptyList(), null); + + assertTrue(result.getEntities().isEmpty()); + assertNull(result.getContinuationToken()); + } + + @Test + void getEntities_returnsProvidedList() { + EntityMetadata metadata = new EntityMetadata( + "@counter@c1", java.time.Instant.EPOCH, 0, null, "42", true, dataConverter); + EntityQueryResult result = new EntityQueryResult(Collections.singletonList(metadata), null); + + assertEquals(1, result.getEntities().size()); + assertEquals("@counter@c1", result.getEntities().get(0).getInstanceId()); + } +} diff --git a/client/src/test/java/com/microsoft/durabletask/EntityQueryTest.java b/client/src/test/java/com/microsoft/durabletask/EntityQueryTest.java new file mode 100644 index 00000000..3f8a61af --- /dev/null +++ b/client/src/test/java/com/microsoft/durabletask/EntityQueryTest.java @@ -0,0 +1,152 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.microsoft.durabletask; + +import org.junit.jupiter.api.Test; + +import java.time.Instant; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Unit tests for {@link EntityQuery}. + */ +public class EntityQueryTest { + + // region prefix normalization tests + + @Test + void setInstanceIdStartsWith_rawEntityName_prependsAtAndLowercases() { + EntityQuery query = new EntityQuery().setInstanceIdStartsWith("Counter"); + assertEquals("@counter", query.getInstanceIdStartsWith()); + } + + @Test + void setInstanceIdStartsWith_alreadyPrefixed_lowercasesName() { + EntityQuery query = new EntityQuery().setInstanceIdStartsWith("@Counter"); + assertEquals("@counter", query.getInstanceIdStartsWith()); + } + + @Test + void setInstanceIdStartsWith_fullEntityId_lowercasesNameOnly() { + EntityQuery query = new EntityQuery().setInstanceIdStartsWith("@Counter@myKey"); + assertEquals("@counter@myKey", query.getInstanceIdStartsWith()); + } + + @Test + void setInstanceIdStartsWith_null_staysNull() { + EntityQuery query = new EntityQuery().setInstanceIdStartsWith(null); + assertNull(query.getInstanceIdStartsWith()); + } + + @Test + void setInstanceIdStartsWith_alreadyLowercase_unchanged() { + EntityQuery query = new EntityQuery().setInstanceIdStartsWith("@counter"); + assertEquals("@counter", query.getInstanceIdStartsWith()); + } + + @Test + void setInstanceIdStartsWith_mixedCase_lowercased() { + EntityQuery query = new EntityQuery().setInstanceIdStartsWith("MyEntity"); + assertEquals("@myentity", query.getInstanceIdStartsWith()); + } + + // endregion + + // region defaults + + @Test + void defaults_includeStateTrue_othersNullOrFalse() { + EntityQuery query = new EntityQuery(); + assertNull(query.getInstanceIdStartsWith()); + assertNull(query.getLastModifiedFrom()); + assertNull(query.getLastModifiedTo()); + assertTrue(query.isIncludeState()); + assertFalse(query.isIncludeTransient()); + assertNull(query.getPageSize()); + assertNull(query.getContinuationToken()); + } + + // endregion + + // region DefaultPageSize constant + + @Test + void defaultPageSize_is100() { + assertEquals(100, EntityQuery.DEFAULT_PAGE_SIZE); + } + + // endregion + + // region setter/getter round-trip tests + + @Test + void setLastModifiedFrom_roundTrip() { + Instant time = Instant.parse("2025-01-15T10:00:00Z"); + EntityQuery query = new EntityQuery().setLastModifiedFrom(time); + assertEquals(time, query.getLastModifiedFrom()); + } + + @Test + void setLastModifiedTo_roundTrip() { + Instant time = Instant.parse("2025-12-31T23:59:59Z"); + EntityQuery query = new EntityQuery().setLastModifiedTo(time); + assertEquals(time, query.getLastModifiedTo()); + } + + @Test + void setIncludeState_roundTrip() { + EntityQuery query = new EntityQuery().setIncludeState(true); + assertTrue(query.isIncludeState()); + } + + @Test + void setIncludeTransient_roundTrip() { + EntityQuery query = new EntityQuery().setIncludeTransient(true); + assertTrue(query.isIncludeTransient()); + } + + @Test + void setPageSize_roundTrip() { + EntityQuery query = new EntityQuery().setPageSize(50); + assertEquals(50, query.getPageSize()); + } + + @Test + void setPageSize_null_allowed() { + EntityQuery query = new EntityQuery().setPageSize(50).setPageSize(null); + assertNull(query.getPageSize()); + } + + @Test + void setContinuationToken_roundTrip() { + EntityQuery query = new EntityQuery().setContinuationToken("token-abc"); + assertEquals("token-abc", query.getContinuationToken()); + } + + @Test + void setContinuationToken_null_allowed() { + EntityQuery query = new EntityQuery() + .setContinuationToken("token") + .setContinuationToken(null); + assertNull(query.getContinuationToken()); + } + + // endregion + + // region fluent chaining + + @Test + void fluentChaining_returnsSameInstance() { + EntityQuery query = new EntityQuery(); + assertSame(query, query.setInstanceIdStartsWith("Counter")); + assertSame(query, query.setLastModifiedFrom(Instant.now())); + assertSame(query, query.setLastModifiedTo(Instant.now())); + assertSame(query, query.setIncludeState(true)); + assertSame(query, query.setIncludeTransient(true)); + assertSame(query, query.setPageSize(100)); + assertSame(query, query.setContinuationToken("t")); + } + + // endregion +} diff --git a/client/src/test/java/com/microsoft/durabletask/EntityRegistrationTest.java b/client/src/test/java/com/microsoft/durabletask/EntityRegistrationTest.java new file mode 100644 index 00000000..7f0471a3 --- /dev/null +++ b/client/src/test/java/com/microsoft/durabletask/EntityRegistrationTest.java @@ -0,0 +1,212 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.microsoft.durabletask; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Unit tests for entity registration overloads on {@link DurableTaskGrpcWorkerBuilder}. + */ +public class EntityRegistrationTest { + + // region Test entity classes + + /** + * A simple entity with a public no-arg constructor. + */ + static class TestEntity implements ITaskEntity { + @Override + public Object runAsync(TaskEntityOperation operation) { + return null; + } + } + + /** + * An entity that does NOT have a public no-arg constructor. + */ + static class NoDefaultConstructorEntity implements ITaskEntity { + private final String value; + + public NoDefaultConstructorEntity(String value) { + this.value = value; + } + + @Override + public Object runAsync(TaskEntityOperation operation) { + return null; + } + } + + // endregion + + // region addEntity(Class) + + @Test + void addEntity_class_derivesNameFromSimpleName() { + DurableTaskGrpcWorkerBuilder builder = new DurableTaskGrpcWorkerBuilder(); + builder.addEntity(TestEntity.class); + + // The name should be derived from the simple class name, lowercased + assertTrue(builder.entityFactories.containsKey("testentity")); + } + + @Test + void addEntity_class_nullClass_throws() { + DurableTaskGrpcWorkerBuilder builder = new DurableTaskGrpcWorkerBuilder(); + assertThrows(IllegalArgumentException.class, () -> { + builder.addEntity((Class) null); + }); + } + + @Test + void addEntity_class_factoryCreatesInstance() throws Exception { + DurableTaskGrpcWorkerBuilder builder = new DurableTaskGrpcWorkerBuilder(); + builder.addEntity(TestEntity.class); + + TaskEntityFactory factory = builder.entityFactories.get("testentity"); + assertNotNull(factory); + + ITaskEntity entity = factory.create(); + assertNotNull(entity); + assertInstanceOf(TestEntity.class, entity); + } + + @Test + void addEntity_class_factoryCreatesNewInstanceEachTime() throws Exception { + DurableTaskGrpcWorkerBuilder builder = new DurableTaskGrpcWorkerBuilder(); + builder.addEntity(TestEntity.class); + + TaskEntityFactory factory = builder.entityFactories.get("testentity"); + ITaskEntity entity1 = factory.create(); + ITaskEntity entity2 = factory.create(); + + assertNotSame(entity1, entity2, "Factory should create a new instance each time"); + } + + // endregion + + // region addEntity(String, Class) + + @Test + void addEntity_nameAndClass_usesProvidedName() { + DurableTaskGrpcWorkerBuilder builder = new DurableTaskGrpcWorkerBuilder(); + builder.addEntity("MyCustomName", TestEntity.class); + + assertTrue(builder.entityFactories.containsKey("mycustomname")); + } + + @Test + void addEntity_nameAndClass_nullClass_throws() { + DurableTaskGrpcWorkerBuilder builder = new DurableTaskGrpcWorkerBuilder(); + assertThrows(IllegalArgumentException.class, () -> { + builder.addEntity("name", (Class) null); + }); + } + + @Test + void addEntity_nameAndClass_noDefaultConstructor_throwsOnCreate() { + DurableTaskGrpcWorkerBuilder builder = new DurableTaskGrpcWorkerBuilder(); + builder.addEntity("myEntity", NoDefaultConstructorEntity.class); + + TaskEntityFactory factory = builder.entityFactories.get("myentity"); + assertNotNull(factory); + + // Should throw at creation time, not registration time + assertThrows(RuntimeException.class, factory::create); + } + + @Test + void addEntity_nameAndClass_duplicateName_throws() { + DurableTaskGrpcWorkerBuilder builder = new DurableTaskGrpcWorkerBuilder(); + builder.addEntity("dup", TestEntity.class); + + assertThrows(IllegalArgumentException.class, () -> { + builder.addEntity("dup", TestEntity.class); + }); + } + + // endregion + + // region addEntity(ITaskEntity) — singleton + + @Test + void addEntity_singleton_derivesNameFromClass() { + DurableTaskGrpcWorkerBuilder builder = new DurableTaskGrpcWorkerBuilder(); + TestEntity instance = new TestEntity(); + builder.addEntity(instance); + + assertTrue(builder.entityFactories.containsKey("testentity")); + } + + @Test + void addEntity_singleton_returnsSameInstance() throws Exception { + DurableTaskGrpcWorkerBuilder builder = new DurableTaskGrpcWorkerBuilder(); + TestEntity instance = new TestEntity(); + builder.addEntity(instance); + + TaskEntityFactory factory = builder.entityFactories.get("testentity"); + ITaskEntity created1 = factory.create(); + ITaskEntity created2 = factory.create(); + + assertSame(instance, created1, "Singleton registration should return the same instance"); + assertSame(instance, created2, "Singleton registration should return the same instance"); + } + + @Test + void addEntity_singleton_null_throws() { + DurableTaskGrpcWorkerBuilder builder = new DurableTaskGrpcWorkerBuilder(); + assertThrows(IllegalArgumentException.class, () -> { + builder.addEntity((ITaskEntity) null); + }); + } + + // endregion + + // region addEntity(String, ITaskEntity) — named singleton + + @Test + void addEntity_namedSingleton_usesProvidedName() { + DurableTaskGrpcWorkerBuilder builder = new DurableTaskGrpcWorkerBuilder(); + TestEntity instance = new TestEntity(); + builder.addEntity("customName", instance); + + assertTrue(builder.entityFactories.containsKey("customname")); + } + + @Test + void addEntity_namedSingleton_returnsSameInstance() throws Exception { + DurableTaskGrpcWorkerBuilder builder = new DurableTaskGrpcWorkerBuilder(); + TestEntity instance = new TestEntity(); + builder.addEntity("myEntity", instance); + + TaskEntityFactory factory = builder.entityFactories.get("myentity"); + assertSame(instance, factory.create()); + } + + @Test + void addEntity_namedSingleton_null_throws() { + DurableTaskGrpcWorkerBuilder builder = new DurableTaskGrpcWorkerBuilder(); + assertThrows(IllegalArgumentException.class, () -> { + builder.addEntity("name", (ITaskEntity) null); + }); + } + + // endregion + + // region Chaining + + @Test + void addEntity_returnsBuilderForChaining() { + DurableTaskGrpcWorkerBuilder builder = new DurableTaskGrpcWorkerBuilder(); + DurableTaskGrpcWorkerBuilder result = builder + .addEntity(TestEntity.class) + .addEntity("custom", new TestEntity()); + + assertSame(builder, result); + assertEquals(2, builder.entityFactories.size()); + } + + // endregion +} diff --git a/client/src/test/java/com/microsoft/durabletask/IntegrationTestBase.java b/client/src/test/java/com/microsoft/durabletask/IntegrationTestBase.java index 01a485d9..f7b45576 100644 --- a/client/src/test/java/com/microsoft/durabletask/IntegrationTestBase.java +++ b/client/src/test/java/com/microsoft/durabletask/IntegrationTestBase.java @@ -119,5 +119,10 @@ public TestDurableTaskWorkerBuilder useVersioning(DurableTaskGrpcWorkerVersionin this.innerBuilder.useVersioning(options); return this; } + + public TestDurableTaskWorkerBuilder addEntity(String name, TaskEntityFactory factory) { + this.innerBuilder.addEntity(name, factory); + return this; + } } } diff --git a/client/src/test/java/com/microsoft/durabletask/TaskEntityExecutorTest.java b/client/src/test/java/com/microsoft/durabletask/TaskEntityExecutorTest.java new file mode 100644 index 00000000..a5910527 --- /dev/null +++ b/client/src/test/java/com/microsoft/durabletask/TaskEntityExecutorTest.java @@ -0,0 +1,511 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.microsoft.durabletask; + +import com.google.protobuf.StringValue; +import com.microsoft.durabletask.implementation.protobuf.OrchestratorService.*; +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.logging.Logger; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Unit tests for {@link TaskEntityExecutor}. + *

+ * Tests construct {@link EntityBatchRequest} proto objects directly and assert on + * the returned {@link EntityBatchResult}, following the same pattern as + * {@link TaskOrchestrationExecutorTest}. + */ +public class TaskEntityExecutorTest { + + private static final Logger logger = Logger.getLogger(TaskEntityExecutorTest.class.getName()); + private static final DataConverter dataConverter = new JacksonDataConverter(); + + // region Test entity implementations + + /** + * Simple counter entity for testing. + */ + static class TestCounterEntity extends TaskEntity { + public void add(int amount) { + this.state += amount; + } + + public void reset() { + this.state = 0; + } + + public int get() { + return this.state; + } + + @Override + protected Integer initializeState(TaskEntityOperation operation) { + return 0; + } + + @Override + protected Class getStateType() { + return Integer.class; + } + } + + /** + * Entity that always throws (for failure testing). + */ + static class FailingEntity implements ITaskEntity { + @Override + public Object runAsync(TaskEntityOperation operation) throws Exception { + throw new RuntimeException("Intentional failure: " + operation.getName()); + } + } + + /** + * Entity that signals another entity during execution (for action testing). + */ + static class SignalingEntity implements ITaskEntity { + @Override + public Object runAsync(TaskEntityOperation operation) throws Exception { + if ("signalOther".equals(operation.getName())) { + EntityInstanceId targetId = new EntityInstanceId("Counter", "target1"); + operation.getContext().signalEntity(targetId, "add", 10); + return "signaled"; + } + return null; + } + } + + /** + * Entity that starts an orchestration during execution. + */ + static class OrchestrationStartingEntity implements ITaskEntity { + @Override + public Object runAsync(TaskEntityOperation operation) throws Exception { + if ("startOrch".equals(operation.getName())) { + String orchId = operation.getContext().startNewOrchestration("MyOrchestration", "orchInput"); + return orchId; + } + return null; + } + } + + /** + * Entity that conditionally fails (first op succeeds, second fails). + */ + static class ConditionalFailEntity implements ITaskEntity { + private int callCount = 0; + + @Override + public Object runAsync(TaskEntityOperation operation) throws Exception { + callCount++; + if ("failOnSecond".equals(operation.getName()) && callCount == 2) { + throw new RuntimeException("Second call failed"); + } + // Modify state to test rollback + operation.getState().setState("state-after-op-" + callCount); + return "result-" + callCount; + } + } + + // endregion + + // region Helper methods + + private TaskEntityExecutor createExecutor(String entityName, TaskEntityFactory factory) { + HashMap factories = new HashMap<>(); + factories.put(entityName.toLowerCase(java.util.Locale.ROOT), factory); + return new TaskEntityExecutor(factories, dataConverter, logger); + } + + private OperationRequest buildOperationRequest(String operationName, Object input, String requestId) { + OperationRequest.Builder builder = OperationRequest.newBuilder() + .setOperation(operationName) + .setRequestId(requestId != null ? requestId : "req-" + operationName); + if (input != null) { + builder.setInput(StringValue.of(dataConverter.serialize(input))); + } + return builder.build(); + } + + private OperationRequest buildOperationRequest(String operationName, Object input) { + return buildOperationRequest(operationName, input, null); + } + + private OperationRequest buildOperationRequest(String operationName) { + return buildOperationRequest(operationName, null, null); + } + + private EntityBatchRequest buildBatchRequest(String entityName, String entityKey, String entityState, + OperationRequest... operations) { + EntityBatchRequest.Builder builder = EntityBatchRequest.newBuilder() + .setInstanceId("@" + entityName + "@" + entityKey); + if (entityState != null) { + builder.setEntityState(StringValue.of(entityState)); + } + for (OperationRequest op : operations) { + builder.addOperations(op); + } + return builder.build(); + } + + // endregion + + // region Single operation tests + + @Test + void execute_singleSuccessfulOperation_returnsSuccess() { + TaskEntityExecutor executor = createExecutor("Counter", TestCounterEntity::new); + + EntityBatchRequest request = buildBatchRequest("Counter", "c1", + dataConverter.serialize(10), + buildOperationRequest("add", 5)); + + EntityBatchResult result = executor.execute(request); + + assertEquals(1, result.getResultsCount()); + assertTrue(result.getResults(0).hasSuccess()); + + // Result has no explicit output (void method) + assertFalse(result.getResults(0).getSuccess().hasResult()); + + // Final state should be 15 (10 + 5) + assertTrue(result.hasEntityState()); + assertEquals("15", result.getEntityState().getValue()); + } + + @Test + void execute_operationWithReturnValue_returnsSerializedResult() { + TaskEntityExecutor executor = createExecutor("Counter", TestCounterEntity::new); + + EntityBatchRequest request = buildBatchRequest("Counter", "c1", + dataConverter.serialize(42), + buildOperationRequest("get")); + + EntityBatchResult result = executor.execute(request); + + assertEquals(1, result.getResultsCount()); + assertTrue(result.getResults(0).hasSuccess()); + assertTrue(result.getResults(0).getSuccess().hasResult()); + assertEquals("42", result.getResults(0).getSuccess().getResult().getValue()); + } + + @Test + void execute_operationWithNoExistingState_initializesState() { + TaskEntityExecutor executor = createExecutor("Counter", TestCounterEntity::new); + + // No entity state provided — should call initializeState() which returns 0 + EntityBatchRequest request = buildBatchRequest("Counter", "c1", null, + buildOperationRequest("add", 7)); + + EntityBatchResult result = executor.execute(request); + + assertEquals(1, result.getResultsCount()); + assertTrue(result.getResults(0).hasSuccess()); + + // State should be 0 + 7 = 7 + assertTrue(result.hasEntityState()); + assertEquals("7", result.getEntityState().getValue()); + } + + @Test + void execute_operationFails_returnsFailure() { + TaskEntityExecutor executor = createExecutor("Failing", FailingEntity::new); + + EntityBatchRequest request = buildBatchRequest("Failing", "f1", null, + buildOperationRequest("anyOp")); + + EntityBatchResult result = executor.execute(request); + + assertEquals(1, result.getResultsCount()); + assertTrue(result.getResults(0).hasFailure()); + + TaskFailureDetails failure = result.getResults(0).getFailure().getFailureDetails(); + assertEquals("java.lang.RuntimeException", failure.getErrorType()); + assertTrue(failure.getErrorMessage().contains("Intentional failure")); + } + + // endregion + + // region Batch (multi-operation) tests + + @Test + void execute_multipleSuccessfulOperations_allSucceed() { + TaskEntityExecutor executor = createExecutor("Counter", TestCounterEntity::new); + + EntityBatchRequest request = buildBatchRequest("Counter", "c1", + dataConverter.serialize(0), + buildOperationRequest("add", 3), + buildOperationRequest("add", 7), + buildOperationRequest("get")); + + EntityBatchResult result = executor.execute(request); + + assertEquals(3, result.getResultsCount()); + + // All should succeed + assertTrue(result.getResults(0).hasSuccess()); + assertTrue(result.getResults(1).hasSuccess()); + assertTrue(result.getResults(2).hasSuccess()); + + // Final get should return 10 + assertEquals("10", result.getResults(2).getSuccess().getResult().getValue()); + + // Final state should be 10 + assertEquals("10", result.getEntityState().getValue()); + } + + @Test + void execute_batchWithFailure_rollbacksFailedOperation() { + TaskEntityExecutor executor = createExecutor("Conditional", ConditionalFailEntity::new); + + EntityBatchRequest request = buildBatchRequest("Conditional", "key1", null, + buildOperationRequest("failOnSecond"), + buildOperationRequest("failOnSecond")); + + EntityBatchResult result = executor.execute(request); + + assertEquals(2, result.getResultsCount()); + + // First operation succeeds + assertTrue(result.getResults(0).hasSuccess()); + assertEquals("\"result-1\"", result.getResults(0).getSuccess().getResult().getValue()); + + // Second operation fails + assertTrue(result.getResults(1).hasFailure()); + assertTrue(result.getResults(1).getFailure().getFailureDetails().getErrorMessage().contains("Second call failed")); + + // State should be from the first successful operation (rolled back from second) + assertTrue(result.hasEntityState()); + assertEquals("\"state-after-op-1\"", result.getEntityState().getValue()); + } + + @Test + void execute_batchAfterFailure_continuesExecution() { + // After a failed op, subsequent ops should still execute. + // Using a fresh entity since ConditionalFailEntity tracks callCount. + // Build a batch: op1 = fail, op2 = different entity that succeeds + TaskEntityExecutor executor = createExecutor("Failing", FailingEntity::new); + + EntityBatchRequest request = buildBatchRequest("Failing", "key1", null, + buildOperationRequest("op1"), + buildOperationRequest("op2")); + + EntityBatchResult result = executor.execute(request); + + assertEquals(2, result.getResultsCount()); + // Both should fail since FailingEntity always fails + assertTrue(result.getResults(0).hasFailure()); + assertTrue(result.getResults(1).hasFailure()); + } + + // endregion + + // region Unregistered entity tests + + @Test + void execute_unregisteredEntity_returnsFailure() { + TaskEntityExecutor executor = createExecutor("Counter", TestCounterEntity::new); + + // Request for a non-existent entity + EntityBatchRequest request = buildBatchRequest("NonExistent", "key1", null, + buildOperationRequest("op")); + + EntityBatchResult result = executor.execute(request); + + assertTrue(result.hasFailureDetails()); + assertEquals(IllegalStateException.class.getName(), result.getFailureDetails().getErrorType()); + assertTrue(result.getFailureDetails().getErrorMessage().contains("nonexistent")); + } + + @Test + void execute_caseInsensitiveEntityLookup_succeeds() { + TaskEntityExecutor executor = createExecutor("Counter", TestCounterEntity::new); + + // Use different case for entity name in the instance ID + EntityBatchRequest request = buildBatchRequest("counter", "c1", + dataConverter.serialize(0), + buildOperationRequest("add", 5)); + + EntityBatchResult result = executor.execute(request); + + assertEquals(1, result.getResultsCount()); + assertTrue(result.getResults(0).hasSuccess()); + } + + // endregion + + // region Action (signal/orchestration) tests + + @Test + void execute_entitySignalsOther_actionsIncluded() { + TaskEntityExecutor executor = createExecutor("Signaler", SignalingEntity::new); + + EntityBatchRequest request = buildBatchRequest("Signaler", "s1", null, + buildOperationRequest("signalOther")); + + EntityBatchResult result = executor.execute(request); + + assertEquals(1, result.getResultsCount()); + assertTrue(result.getResults(0).hasSuccess()); + assertEquals("\"signaled\"", result.getResults(0).getSuccess().getResult().getValue()); + + // Should have one action: sendSignal + assertEquals(1, result.getActionsCount()); + assertTrue(result.getActions(0).hasSendSignal()); + + SendSignalAction signalAction = result.getActions(0).getSendSignal(); + assertEquals("@counter@target1", signalAction.getInstanceId()); + assertEquals("add", signalAction.getName()); + } + + @Test + void execute_entityStartsOrchestration_actionsIncluded() { + TaskEntityExecutor executor = createExecutor("OrchStarter", OrchestrationStartingEntity::new); + + EntityBatchRequest request = buildBatchRequest("OrchStarter", "o1", null, + buildOperationRequest("startOrch")); + + EntityBatchResult result = executor.execute(request); + + assertEquals(1, result.getResultsCount()); + assertTrue(result.getResults(0).hasSuccess()); + + // Should have one action: startNewOrchestration + assertEquals(1, result.getActionsCount()); + assertTrue(result.getActions(0).hasStartNewOrchestration()); + + StartNewOrchestrationAction orchAction = result.getActions(0).getStartNewOrchestration(); + assertEquals("MyOrchestration", orchAction.getName()); + } + + @Test + void execute_failedOperationRollsBackActions() { + // Create an entity that signals another entity then fails + TaskEntityFactory factory = () -> (ITaskEntity) operation -> { + operation.getContext().signalEntity( + new EntityInstanceId("Other", "o1"), "op", null); + throw new RuntimeException("fail after signal"); + }; + + TaskEntityExecutor executor = createExecutor("FailAfterSignal", factory); + + EntityBatchRequest request = buildBatchRequest("FailAfterSignal", "key1", null, + buildOperationRequest("op")); + + EntityBatchResult result = executor.execute(request); + + assertEquals(1, result.getResultsCount()); + assertTrue(result.getResults(0).hasFailure()); + + // Actions should be empty because the operation failed and actions were rolled back + assertEquals(0, result.getActionsCount()); + } + + // endregion + + // region Reset/delete tests + + @Test + void execute_deleteOperation_deletesState() { + TaskEntityExecutor executor = createExecutor("Counter", TestCounterEntity::new); + + EntityBatchRequest request = buildBatchRequest("Counter", "c1", + dataConverter.serialize(42), + buildOperationRequest("delete")); + + EntityBatchResult result = executor.execute(request); + + assertEquals(1, result.getResultsCount()); + assertTrue(result.getResults(0).hasSuccess()); + + // State should be deleted (not set) + assertFalse(result.hasEntityState()); + } + + @Test + void execute_resetOperation_setsStateToZero() { + TaskEntityExecutor executor = createExecutor("Counter", TestCounterEntity::new); + + EntityBatchRequest request = buildBatchRequest("Counter", "c1", + dataConverter.serialize(42), + buildOperationRequest("reset"), + buildOperationRequest("get")); + + EntityBatchResult result = executor.execute(request); + + assertEquals(2, result.getResultsCount()); + assertTrue(result.getResults(0).hasSuccess()); + assertTrue(result.getResults(1).hasSuccess()); + assertEquals("0", result.getResults(1).getSuccess().getResult().getValue()); + } + + // endregion + + // region Timestamps tests + + @Test + void execute_successResult_hasTimestamps() { + TaskEntityExecutor executor = createExecutor("Counter", TestCounterEntity::new); + + EntityBatchRequest request = buildBatchRequest("Counter", "c1", + dataConverter.serialize(0), + buildOperationRequest("add", 1)); + + EntityBatchResult result = executor.execute(request); + + OperationResultSuccess success = result.getResults(0).getSuccess(); + assertTrue(success.hasStartTimeUtc()); + assertTrue(success.hasEndTimeUtc()); + assertTrue(success.getEndTimeUtc().getSeconds() >= success.getStartTimeUtc().getSeconds()); + } + + @Test + void execute_failureResult_hasTimestamps() { + TaskEntityExecutor executor = createExecutor("Failing", FailingEntity::new); + + EntityBatchRequest request = buildBatchRequest("Failing", "f1", null, + buildOperationRequest("op")); + + EntityBatchResult result = executor.execute(request); + + OperationResultFailure failure = result.getResults(0).getFailure(); + assertTrue(failure.hasStartTimeUtc()); + assertTrue(failure.hasEndTimeUtc()); + } + + // endregion + + // region Null factory result test + + @Test + void execute_factoryReturnsNull_returnsFailure() { + TaskEntityExecutor executor = createExecutor("NullEntity", () -> null); + + EntityBatchRequest request = buildBatchRequest("NullEntity", "n1", null, + buildOperationRequest("op")); + + EntityBatchResult result = executor.execute(request); + + assertTrue(result.hasFailureDetails()); + assertTrue(result.getFailureDetails().getErrorMessage().contains("null")); + } + + @Test + void execute_factoryThrows_returnsFailure() { + TaskEntityExecutor executor = createExecutor("ThrowFactory", () -> { + throw new RuntimeException("Factory boom"); + }); + + EntityBatchRequest request = buildBatchRequest("ThrowFactory", "t1", null, + buildOperationRequest("op")); + + EntityBatchResult result = executor.execute(request); + + assertTrue(result.hasFailureDetails()); + assertTrue(result.getFailureDetails().getErrorMessage().contains("Factory boom")); + } + + // endregion +} diff --git a/client/src/test/java/com/microsoft/durabletask/TaskEntityTest.java b/client/src/test/java/com/microsoft/durabletask/TaskEntityTest.java new file mode 100644 index 00000000..62b94c7a --- /dev/null +++ b/client/src/test/java/com/microsoft/durabletask/TaskEntityTest.java @@ -0,0 +1,421 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.microsoft.durabletask; + +import org.junit.jupiter.api.Test; + +import javax.annotation.Nullable; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Unit tests for {@link TaskEntity} reflection-based dispatch. + */ +public class TaskEntityTest { + + // region Test entity classes + + /** + * A simple counter entity for testing basic operations. + */ + static class CounterEntity extends TaskEntity { + public void add(int amount) { + this.state += amount; + } + + public void reset() { + this.state = 0; + } + + public int get() { + return this.state; + } + + @Override + protected Integer initializeState(TaskEntityOperation operation) { + return 0; + } + + @Override + protected Class getStateType() { + return Integer.class; + } + } + + /** + * Entity that accepts a TaskEntityContext as a method parameter. + */ + static class EntityWithContextParam extends TaskEntity { + public String info(TaskEntityContext context) { + return "Entity ID: " + context.getId().toString(); + } + + @Override + protected String initializeState(TaskEntityOperation operation) { + return ""; + } + + @Override + protected Class getStateType() { + return String.class; + } + } + + /** + * Entity that accepts both input and context parameters. + */ + static class EntityWithTwoParams extends TaskEntity { + public String greet(String name, TaskEntityContext context) { + return "Hello, " + name + " from " + context.getId().getKey(); + } + + @Override + protected String initializeState(TaskEntityOperation operation) { + return ""; + } + + @Override + protected Class getStateType() { + return String.class; + } + } + + /** + * State class used for state dispatch testing. + */ + static class MyState { + private int value; + + public MyState() { + this.value = 0; + } + + public int getValue() { + return this.value; + } + + public void increment() { + this.value++; + } + + public void setValue(int value) { + this.value = value; + } + } + + /** + * Entity that has no matching method but whose state type has the method (state dispatch). + * Explicitly enables state dispatch since the default is now {@code false}. + */ + static class StateDispatchEntity extends TaskEntity { + // No "increment" method on the entity itself — should dispatch to MyState.increment() + + public StateDispatchEntity() { + setAllowStateDispatch(true); + } + + @Override + protected Class getStateType() { + return MyState.class; + } + } + + /** + * Entity that throws during an operation. + */ + static class ThrowingEntity extends TaskEntity { + public void fail() { + throw new RuntimeException("Intentional failure"); + } + + @Override + protected String initializeState(TaskEntityOperation operation) { + return "initial"; + } + + @Override + protected Class getStateType() { + return String.class; + } + } + + /** + * Entity that disables state dispatch. + */ + static class NoStateDispatchEntity extends TaskEntity { + public NoStateDispatchEntity() { + setAllowStateDispatch(false); + } + + @Override + protected Class getStateType() { + return MyState.class; + } + } + + /** + * Entity with overloaded methods (ambiguous match). + */ + static class AmbiguousEntity extends TaskEntity { + public void add(int amount) { + this.state += amount; + } + + public void add(String label) { + // overloaded — should trigger ambiguous match error + } + + @Override + protected Integer initializeState(TaskEntityOperation operation) { + return 0; + } + + @Override + protected Class getStateType() { + return Integer.class; + } + } + + // endregion + + // region Helper methods + + private TaskEntityOperation createOperation(String operationName, Object input, String serializedState) { + DataConverter converter = new JacksonDataConverter(); + String serializedInput = input != null ? converter.serialize(input) : null; + + // Create a minimal TaskEntityContext for testing + EntityInstanceId entityId = new EntityInstanceId("TestEntity", "testKey"); + TaskEntityContext context = new TaskEntityContext() { + @Override + public EntityInstanceId getId() { + return entityId; + } + + @Override + public void signalEntity(EntityInstanceId entityId, String operationName, Object input, SignalEntityOptions options) { + // no-op for tests + } + + @Override + public String startNewOrchestration(String name, Object input, NewOrchestrationInstanceOptions options) { + return "test-orchestration-id"; + } + }; + + TaskEntityState state = new TaskEntityState(converter, serializedState); + + return new TaskEntityOperation(operationName, serializedInput, context, state, converter); + } + + private TaskEntityOperation createOperation(String operationName, Object input) { + return createOperation(operationName, input, null); + } + + private TaskEntityOperation createOperation(String operationName) { + return createOperation(operationName, null, null); + } + + // endregion + + // region Reflection dispatch tests + + @Test + void reflectionDispatch_voidMethodNoArgs() throws Exception { + CounterEntity entity = new CounterEntity(); + TaskEntityOperation op = createOperation("reset"); + entity.runAsync(op); + assertEquals(0, entity.state); + } + + @Test + void reflectionDispatch_voidMethodWithArg() throws Exception { + CounterEntity entity = new CounterEntity(); + // Set initial state to 0 + TaskEntityOperation op = createOperation("add", 5); + entity.runAsync(op); + assertEquals(5, entity.state); + } + + @Test + void reflectionDispatch_methodWithReturnValue() throws Exception { + CounterEntity entity = new CounterEntity(); + // Pre-load entity state to 42 and call "get" + DataConverter converter = new JacksonDataConverter(); + String serializedState = converter.serialize(42); + Object result = entity.runAsync(createOperation("get", null, serializedState)); + assertEquals(42, result); + } + + @Test + void reflectionDispatch_caseInsensitive() throws Exception { + CounterEntity entity = new CounterEntity(); + + // Method is "add" but call with "ADD" + TaskEntityOperation addOp = createOperation("ADD", 10); + entity.runAsync(addOp); + // After "add", state was saved to the operation's state + String serializedState = addOp.getState().getSerializedState(); + assertEquals(10, entity.state); + + // Method is "get" but call with "Get" — carry state forward + Object result = entity.runAsync(createOperation("Get", null, serializedState)); + assertEquals(10, result); + } + + @Test + void reflectionDispatch_methodWithContextParam() throws Exception { + EntityWithContextParam entity = new EntityWithContextParam(); + Object result = entity.runAsync(createOperation("info")); + assertNotNull(result); + assertTrue(result.toString().contains("testentity")); + assertTrue(result.toString().contains("testKey")); + } + + @Test + void reflectionDispatch_methodWithTwoParams() throws Exception { + EntityWithTwoParams entity = new EntityWithTwoParams(); + Object result = entity.runAsync(createOperation("greet", "World")); + assertEquals("Hello, World from testKey", result); + } + + // endregion + + // region Implicit delete tests + + @Test + void implicitDelete_deletesState() throws Exception { + CounterEntity entity = new CounterEntity(); + DataConverter converter = new JacksonDataConverter(); + String serializedState = converter.serialize(42); + + TaskEntityOperation op = createOperation("delete", null, serializedState); + entity.runAsync(op); + + assertFalse(op.getState().hasState()); + assertNull(entity.state); + } + + @Test + void implicitDelete_caseInsensitive() throws Exception { + CounterEntity entity = new CounterEntity(); + DataConverter converter = new JacksonDataConverter(); + String serializedState = converter.serialize(42); + + TaskEntityOperation op = createOperation("DELETE", null, serializedState); + entity.runAsync(op); + + assertFalse(op.getState().hasState()); + } + + // endregion + + // region State dispatch tests + + @Test + void stateDispatch_delegatesToStateMethod() throws Exception { + StateDispatchEntity entity = new StateDispatchEntity(); + DataConverter converter = new JacksonDataConverter(); + String serializedState = converter.serialize(new MyState()); + + TaskEntityOperation op = createOperation("increment", null, serializedState); + entity.runAsync(op); + + // State should have been incremented + assertEquals(1, entity.state.getValue()); + } + + @Test + void stateDispatch_disabledWithAllowStateDispatchFalse() { + NoStateDispatchEntity entity = new NoStateDispatchEntity(); + DataConverter converter = new JacksonDataConverter(); + String serializedState = converter.serialize(new MyState()); + + // "increment" exists on MyState but not on NoStateDispatchEntity. + // With allowStateDispatch=false, it should throw UnsupportedOperationException. + assertThrows(UnsupportedOperationException.class, () -> { + entity.runAsync(createOperation("increment", null, serializedState)); + }); + } + + @Test + void stateDispatch_disabledByDefault() throws Exception { + // Default is now false, matching the .NET SDK + CounterEntity entity = new CounterEntity(); + assertFalse(entity.getAllowStateDispatch()); + } + + // endregion + + // region Ambiguous match tests + + @Test + void ambiguousMatch_throwsIllegalStateException() { + AmbiguousEntity entity = new AmbiguousEntity(); + // "add" has two overloads: add(int) and add(String) — should throw + assertThrows(IllegalStateException.class, () -> { + entity.runAsync(createOperation("add", 5)); + }); + } + + // endregion + + // region Error handling tests + + @Test + void unknownOperation_throwsException() { + CounterEntity entity = new CounterEntity(); + assertThrows(UnsupportedOperationException.class, () -> { + entity.runAsync(createOperation("nonExistentOperation")); + }); + } + + @Test + void throwingOperation_propagatesException() { + ThrowingEntity entity = new ThrowingEntity(); + RuntimeException ex = assertThrows(RuntimeException.class, () -> { + entity.runAsync(createOperation("fail")); + }); + assertEquals("Intentional failure", ex.getMessage()); + } + + // endregion + + // region State initialization tests + + @Test + void stateInitialization_defaultInitializer() throws Exception { + CounterEntity entity = new CounterEntity(); + TaskEntityOperation op = createOperation("get"); + Object result = entity.runAsync(op); + // initializeState returns 0 for CounterEntity + assertEquals(0, result); + } + + @Test + void stateInitialization_withExistingState() throws Exception { + CounterEntity entity = new CounterEntity(); + DataConverter converter = new JacksonDataConverter(); + String serializedState = converter.serialize(99); + + TaskEntityOperation op = createOperation("get", null, serializedState); + Object result = entity.runAsync(op); + assertEquals(99, result); + } + + @Test + void statePersistence_stateIsSavedAfterOperation() throws Exception { + CounterEntity entity = new CounterEntity(); + TaskEntityOperation op = createOperation("add", 5); + entity.runAsync(op); + + // State should have been saved back + assertTrue(op.getState().hasState()); + + // Verify the saved state can be deserialized + Integer savedState = op.getState().getState(Integer.class); + assertEquals(5, savedState); + } + + // endregion +} diff --git a/client/src/test/java/com/microsoft/durabletask/TaskOrchestrationEntityEventTest.java b/client/src/test/java/com/microsoft/durabletask/TaskOrchestrationEntityEventTest.java new file mode 100644 index 00000000..31e81a26 --- /dev/null +++ b/client/src/test/java/com/microsoft/durabletask/TaskOrchestrationEntityEventTest.java @@ -0,0 +1,1321 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.microsoft.durabletask; + +import com.google.protobuf.StringValue; +import com.google.protobuf.Timestamp; +import com.microsoft.durabletask.implementation.protobuf.OrchestratorService.*; +import org.junit.jupiter.api.Test; + +import java.time.Duration; +import java.time.Instant; +import java.util.*; +import java.util.logging.Logger; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Unit tests for orchestration ↔ entity integration (Phase 4). + *

+ * These tests construct {@link HistoryEvent} protobufs manually, run them through + * {@link TaskOrchestrationExecutor}, and assert on the returned actions and orchestration + * state. This mirrors the pattern of {@link TaskOrchestrationExecutorTest}. + */ +public class TaskOrchestrationEntityEventTest { + + private static final Logger logger = Logger.getLogger(TaskOrchestrationEntityEventTest.class.getName()); + + // region Helper methods + + private TaskOrchestrationExecutor createExecutor(String orchestratorName, TaskOrchestration orchestration) { + HashMap factories = new HashMap<>(); + factories.put(orchestratorName, new TaskOrchestrationFactory() { + @Override + public String getName() { + return orchestratorName; + } + + @Override + public TaskOrchestration create() { + return orchestration; + } + }); + return new TaskOrchestrationExecutor( + factories, + new JacksonDataConverter(), + Duration.ofDays(3), + logger, + null); + } + + private HistoryEvent orchestratorStarted() { + return HistoryEvent.newBuilder() + .setEventId(-1) + .setTimestamp(Timestamp.getDefaultInstance()) + .setOrchestratorStarted(OrchestratorStartedEvent.getDefaultInstance()) + .build(); + } + + private HistoryEvent executionStarted(String name, String input) { + return HistoryEvent.newBuilder() + .setEventId(-1) + .setTimestamp(Timestamp.getDefaultInstance()) + .setExecutionStarted(ExecutionStartedEvent.newBuilder() + .setName(name) + .setVersion(StringValue.of("")) + .setInput(StringValue.of(input != null ? input : "null")) + .setOrchestrationInstance(OrchestrationInstance.newBuilder() + .setInstanceId("test-instance-id") + .build()) + .build()) + .build(); + } + + private HistoryEvent orchestratorCompleted() { + return HistoryEvent.newBuilder() + .setEventId(-1) + .setTimestamp(Timestamp.getDefaultInstance()) + .setOrchestratorCompleted(OrchestratorCompletedEvent.getDefaultInstance()) + .build(); + } + + private HistoryEvent entityOperationSignaledEvent(int eventId) { + return HistoryEvent.newBuilder() + .setEventId(eventId) + .setTimestamp(Timestamp.getDefaultInstance()) + .setEntityOperationSignaled(EntityOperationSignaledEvent.newBuilder() + .setRequestId("signal-request-id") + .setOperation("add") + .setTargetInstanceId(StringValue.of("@Counter@c1")) + .build()) + .build(); + } + + private HistoryEvent entityOperationCalledEvent(int eventId) { + return HistoryEvent.newBuilder() + .setEventId(eventId) + .setTimestamp(Timestamp.getDefaultInstance()) + .setEntityOperationCalled(EntityOperationCalledEvent.newBuilder() + .setRequestId("call-request-id") + .setOperation("get") + .setTargetInstanceId(StringValue.of("@Counter@c1")) + .build()) + .build(); + } + + private HistoryEvent entityOperationCompletedEvent(String requestId, String output) { + EntityOperationCompletedEvent.Builder builder = EntityOperationCompletedEvent.newBuilder() + .setRequestId(requestId); + if (output != null) { + builder.setOutput(StringValue.of(output)); + } + return HistoryEvent.newBuilder() + .setEventId(-1) + .setTimestamp(Timestamp.getDefaultInstance()) + .setEntityOperationCompleted(builder.build()) + .build(); + } + + private HistoryEvent entityOperationFailedEvent(String requestId, String errorType, String errorMessage) { + return HistoryEvent.newBuilder() + .setEventId(-1) + .setTimestamp(Timestamp.getDefaultInstance()) + .setEntityOperationFailed(EntityOperationFailedEvent.newBuilder() + .setRequestId(requestId) + .setFailureDetails(TaskFailureDetails.newBuilder() + .setErrorType(errorType) + .setErrorMessage(errorMessage) + .build()) + .build()) + .build(); + } + + private HistoryEvent entityLockRequestedEvent(int eventId) { + return HistoryEvent.newBuilder() + .setEventId(eventId) + .setTimestamp(Timestamp.getDefaultInstance()) + .setEntityLockRequested(EntityLockRequestedEvent.newBuilder() + .setCriticalSectionId("lock-cs-id") + .addLockSet("@Counter@c1") + .setPosition(0) + .setParentInstanceId(StringValue.of("test-instance-id")) + .build()) + .build(); + } + + private HistoryEvent entityLockGrantedEvent(String criticalSectionId) { + return HistoryEvent.newBuilder() + .setEventId(-1) + .setTimestamp(Timestamp.getDefaultInstance()) + .setEntityLockGranted(EntityLockGrantedEvent.newBuilder() + .setCriticalSectionId(criticalSectionId) + .build()) + .build(); + } + + private HistoryEvent entityUnlockSentEvent(int eventId) { + return HistoryEvent.newBuilder() + .setEventId(eventId) + .setTimestamp(Timestamp.getDefaultInstance()) + .setEntityUnlockSent(EntityUnlockSentEvent.newBuilder() + .setCriticalSectionId("lock-cs-id") + .setParentInstanceId(StringValue.of("test-instance-id")) + .setTargetInstanceId(StringValue.of("@Counter@c1")) + .build()) + .build(); + } + + private HistoryEvent eventSentEvent(int eventId) { + return HistoryEvent.newBuilder() + .setEventId(eventId) + .setTimestamp(Timestamp.getDefaultInstance()) + .setEventSent(EventSentEvent.newBuilder() + .setName("someEvent") + .setInstanceId("some-instance") + .build()) + .build(); + } + + /** + * Creates an EventRaised HistoryEvent. This simulates the trigger binding code path + * where DTFx.Core delivers entity responses as EventRaised events (not proto entity events). + */ + private HistoryEvent eventRaisedEvent(String name, String input) { + return HistoryEvent.newBuilder() + .setEventId(-1) + .setTimestamp(Timestamp.getDefaultInstance()) + .setEventRaised(EventRaisedEvent.newBuilder() + .setName(name) + .setInput(StringValue.of(input)) + .build()) + .build(); + } + + /** + * Builds a DTFx ResponseMessage JSON for a successful entity operation. + * Matches the format produced by DTFx.Core's TaskEntityDispatcher: + * - "result": the serialized operation result + * - Error fields ("exceptionType", "failureDetails") are omitted on success (EmitDefaultValue=false) + */ + private String successResponseMessageJson(String result) { + // DTFx serializes with Newtonsoft.Json; result is always present. + // On success, exceptionType and failureDetails are omitted. + if (result == null) { + return "{\"result\":null}"; + } + return "{\"result\":" + escapeJsonString(result) + "}"; + } + + /** + * Builds a DTFx ResponseMessage JSON for a failed entity operation. + * The C# ResponseMessage property ErrorMessage maps to JSON key "exceptionType" + * (due to [DataMember(Name = "exceptionType")]). FailureDetails uses PascalCase fields. + */ + private String failedResponseMessageJson(String errorMessage, String errorType) { + return "{\"result\":null," + + "\"exceptionType\":" + escapeJsonString(errorMessage) + "," + + "\"failureDetails\":{" + + "\"ErrorType\":" + escapeJsonString(errorType) + "," + + "\"ErrorMessage\":" + escapeJsonString(errorMessage) + + "}}"; + } + + /** + * Builds a DTFx ResponseMessage JSON for a failed entity operation (exceptionType only, no failureDetails). + */ + private String failedResponseMessageJsonSimple(String errorMessage) { + return "{\"result\":null," + + "\"exceptionType\":" + escapeJsonString(errorMessage) + "}"; + } + + private String escapeJsonString(String value) { + if (value == null) return "null"; + // Simple JSON string escaping for test purposes + return "\"" + value.replace("\\", "\\\\").replace("\"", "\\\"") + "\""; + } + + // endregion + + // region signalEntity tests + + @Test + void signalEntity_producesSendEntityMessageAction() { + final String orchestratorName = "SignalEntityOrchestration"; + EntityInstanceId entityId = new EntityInstanceId("Counter", "c1"); + + TaskOrchestrationExecutor executor = createExecutor(orchestratorName, ctx -> { + ctx.signalEntity(entityId, "add", 5); + ctx.complete("done"); + }); + + List pastEvents = Arrays.asList( + orchestratorStarted(), + executionStarted(orchestratorName, "null")); + List newEvents = Collections.singletonList(orchestratorCompleted()); + + TaskOrchestratorResult result = executor.execute(pastEvents, newEvents); + + // Should have two actions: sendEntityMessage (signal) and completeOrchestration + Collection actions = result.getActions(); + boolean hasSignal = false; + boolean hasComplete = false; + for (OrchestratorAction action : actions) { + if (action.hasSendEntityMessage()) { + SendEntityMessageAction msg = action.getSendEntityMessage(); + assertTrue(msg.hasEntityOperationSignaled()); + EntityOperationSignaledEvent signal = msg.getEntityOperationSignaled(); + assertEquals("add", signal.getOperation()); + assertEquals("@counter@c1", signal.getTargetInstanceId().getValue()); + hasSignal = true; + } + if (action.hasCompleteOrchestration()) { + hasComplete = true; + } + } + assertTrue(hasSignal, "Expected a sendEntityMessage action with signal"); + assertTrue(hasComplete, "Expected a completeOrchestration action"); + } + + @Test + void getEntities_signalEntity_producesSendEntityMessageAction() { + final String orchestratorName = "SignalViaEntitiesFeatureOrchestration"; + EntityInstanceId entityId = new EntityInstanceId("Counter", "c1"); + + TaskOrchestrationExecutor executor = createExecutor(orchestratorName, ctx -> { + ctx.getEntities().signalEntity(entityId, "add", 7); + ctx.complete("done"); + }); + + List pastEvents = Arrays.asList( + orchestratorStarted(), + executionStarted(orchestratorName, "null")); + List newEvents = Collections.singletonList(orchestratorCompleted()); + + TaskOrchestratorResult result = executor.execute(pastEvents, newEvents); + + boolean hasSignal = false; + for (OrchestratorAction action : result.getActions()) { + if (action.hasSendEntityMessage()) { + SendEntityMessageAction msg = action.getSendEntityMessage(); + if (msg.hasEntityOperationSignaled()) { + EntityOperationSignaledEvent signal = msg.getEntityOperationSignaled(); + assertEquals("add", signal.getOperation()); + assertEquals("@counter@c1", signal.getTargetInstanceId().getValue()); + hasSignal = true; + } + } + } + + assertTrue(hasSignal, "Expected a sendEntityMessage action with signal"); + } + + @Test + void signalEntity_replayPassesNonDeterminismCheck() { + final String orchestratorName = "SignalEntityReplay"; + EntityInstanceId entityId = new EntityInstanceId("Counter", "c1"); + + TaskOrchestrationExecutor executor = createExecutor(orchestratorName, ctx -> { + ctx.signalEntity(entityId, "add", 5); + ctx.complete("done"); + }); + + // First execution produces the signal action (eventId = 0) + // On replay, the ENTITYOPERATIONSIGNALED event confirms the action + List pastEvents = Arrays.asList( + orchestratorStarted(), + executionStarted(orchestratorName, "null"), + entityOperationSignaledEvent(0), + orchestratorCompleted()); + List newEvents = Collections.singletonList(orchestratorCompleted()); + + // This should NOT throw a NonDeterministicOrchestratorException + TaskOrchestratorResult result = executor.execute(pastEvents, newEvents); + + // Should still have the complete action + boolean hasComplete = false; + for (OrchestratorAction action : result.getActions()) { + if (action.hasCompleteOrchestration()) { + hasComplete = true; + } + } + assertTrue(hasComplete); + } + + // endregion + + // region callEntity tests + + @Test + void callEntity_producesActionAndWaitsForResponse() { + final String orchestratorName = "CallEntityOrchestration"; + EntityInstanceId entityId = new EntityInstanceId("Counter", "c1"); + + TaskOrchestrationExecutor executor = createExecutor(orchestratorName, ctx -> { + int value = ctx.callEntity(entityId, "get", null, int.class).await(); + ctx.complete(value); + }); + + // First execution: produces the call action but blocks because no response yet + List pastEvents = Arrays.asList( + orchestratorStarted(), + executionStarted(orchestratorName, "null")); + List newEvents = Collections.singletonList(orchestratorCompleted()); + + TaskOrchestratorResult result = executor.execute(pastEvents, newEvents); + + // Should have the sendEntityMessage (call) action + boolean hasCall = false; + for (OrchestratorAction action : result.getActions()) { + if (action.hasSendEntityMessage()) { + SendEntityMessageAction msg = action.getSendEntityMessage(); + assertTrue(msg.hasEntityOperationCalled()); + EntityOperationCalledEvent call = msg.getEntityOperationCalled(); + assertEquals("get", call.getOperation()); + assertEquals("@counter@c1", call.getTargetInstanceId().getValue()); + hasCall = true; + } + } + assertTrue(hasCall, "Expected a sendEntityMessage action with call"); + + // Should NOT have a complete action (it's waiting for the response) + boolean hasComplete = false; + for (OrchestratorAction action : result.getActions()) { + if (action.hasCompleteOrchestration()) { + hasComplete = true; + } + } + assertFalse(hasComplete, "Should not complete while waiting for entity response"); + } + + @Test + void entities_callEntity_producesActionAndWaitsForResponse() { + final String orchestratorName = "CallViaEntitiesFeatureOrchestration"; + EntityInstanceId entityId = new EntityInstanceId("Counter", "c1"); + + TaskOrchestrationExecutor executor = createExecutor(orchestratorName, ctx -> { + int value = ctx.entities().callEntity(entityId, "get", null, int.class).await(); + ctx.complete(value); + }); + + List pastEvents = Arrays.asList( + orchestratorStarted(), + executionStarted(orchestratorName, "null")); + List newEvents = Collections.singletonList(orchestratorCompleted()); + + TaskOrchestratorResult result = executor.execute(pastEvents, newEvents); + + boolean hasCall = false; + boolean hasComplete = false; + for (OrchestratorAction action : result.getActions()) { + if (action.hasSendEntityMessage()) { + SendEntityMessageAction msg = action.getSendEntityMessage(); + if (msg.hasEntityOperationCalled()) { + EntityOperationCalledEvent call = msg.getEntityOperationCalled(); + assertEquals("get", call.getOperation()); + assertEquals("@counter@c1", call.getTargetInstanceId().getValue()); + hasCall = true; + } + } + if (action.hasCompleteOrchestration()) { + hasComplete = true; + } + } + + assertTrue(hasCall, "Expected a sendEntityMessage action with call"); + assertFalse(hasComplete, "Should not complete while waiting for entity response"); + } + + @Test + void callEntity_completesWhenResponseArrives() { + final String orchestratorName = "CallEntityComplete"; + EntityInstanceId entityId = new EntityInstanceId("Counter", "c1"); + + // We need to figure out the requestId that the executor generates. + // The executor uses newUUID() which is deterministic based on instanceId + timestamp + counter. + // For this test we need to get the requestId from the first execution. + + TaskOrchestrationExecutor executor = createExecutor(orchestratorName, ctx -> { + int value = ctx.callEntity(entityId, "get", null, int.class).await(); + ctx.complete(value); + }); + + // First pass: execute and capture the requestId from the generated action + List pastEvents1 = Arrays.asList( + orchestratorStarted(), + executionStarted(orchestratorName, "null")); + List newEvents1 = Collections.singletonList(orchestratorCompleted()); + + TaskOrchestratorResult result1 = executor.execute(pastEvents1, newEvents1); + + // Extract the requestId from the call action + String requestId = null; + for (OrchestratorAction action : result1.getActions()) { + if (action.hasSendEntityMessage() && action.getSendEntityMessage().hasEntityOperationCalled()) { + requestId = action.getSendEntityMessage().getEntityOperationCalled().getRequestId(); + } + } + assertNotNull(requestId, "Should have captured the requestId"); + + // Second pass (replay): include the call event in past and provide the response + executor = createExecutor(orchestratorName, ctx -> { + int value = ctx.callEntity(entityId, "get", null, int.class).await(); + ctx.complete(value); + }); + + List pastEvents2 = Arrays.asList( + orchestratorStarted(), + executionStarted(orchestratorName, "null"), + entityOperationCalledEvent(0), + orchestratorCompleted()); + List newEvents2 = Arrays.asList( + orchestratorStarted(), + entityOperationCompletedEvent(requestId, "42"), + orchestratorCompleted()); + + TaskOrchestratorResult result2 = executor.execute(pastEvents2, newEvents2); + + // Should now have a complete action with value 42 + boolean hasComplete = false; + for (OrchestratorAction action : result2.getActions()) { + if (action.hasCompleteOrchestration()) { + CompleteOrchestrationAction complete = action.getCompleteOrchestration(); + assertEquals(OrchestrationStatus.ORCHESTRATION_STATUS_COMPLETED, complete.getOrchestrationStatus()); + assertEquals("42", complete.getResult().getValue()); + hasComplete = true; + } + } + assertTrue(hasComplete, "Expected orchestration to complete with entity result"); + } + + @Test + void callEntity_failedResponse_completesExceptionally() { + final String orchestratorName = "CallEntityFail"; + EntityInstanceId entityId = new EntityInstanceId("Counter", "c1"); + + TaskOrchestrationExecutor executor = createExecutor(orchestratorName, ctx -> { + try { + ctx.callEntity(entityId, "get", null, int.class).await(); + ctx.complete("should not reach here"); + } catch (EntityOperationFailedException e) { + ctx.complete("caught: " + e.getFailureDetails().getErrorMessage()); + } + }); + + // First pass: capture requestId + List pastEvents1 = Arrays.asList( + orchestratorStarted(), + executionStarted(orchestratorName, "null")); + List newEvents1 = Collections.singletonList(orchestratorCompleted()); + + TaskOrchestratorResult result1 = executor.execute(pastEvents1, newEvents1); + + String requestId = null; + for (OrchestratorAction action : result1.getActions()) { + if (action.hasSendEntityMessage() && action.getSendEntityMessage().hasEntityOperationCalled()) { + requestId = action.getSendEntityMessage().getEntityOperationCalled().getRequestId(); + } + } + assertNotNull(requestId); + + // Second pass: replay with failed response + executor = createExecutor(orchestratorName, ctx -> { + try { + ctx.callEntity(entityId, "get", null, int.class).await(); + ctx.complete("should not reach here"); + } catch (EntityOperationFailedException e) { + ctx.complete("caught: " + e.getFailureDetails().getErrorMessage()); + } + }); + + List pastEvents2 = Arrays.asList( + orchestratorStarted(), + executionStarted(orchestratorName, "null"), + entityOperationCalledEvent(0), + orchestratorCompleted()); + List newEvents2 = Arrays.asList( + orchestratorStarted(), + entityOperationFailedEvent(requestId, "java.lang.RuntimeException", "Entity error!"), + orchestratorCompleted()); + + TaskOrchestratorResult result2 = executor.execute(pastEvents2, newEvents2); + + boolean hasComplete = false; + for (OrchestratorAction action : result2.getActions()) { + if (action.hasCompleteOrchestration()) { + CompleteOrchestrationAction complete = action.getCompleteOrchestration(); + assertEquals(OrchestrationStatus.ORCHESTRATION_STATUS_COMPLETED, complete.getOrchestrationStatus()); + assertTrue(complete.getResult().getValue().contains("Entity error!")); + hasComplete = true; + } + } + assertTrue(hasComplete, "Expected orchestration to complete after catching entity failure"); + } + + // endregion + + // region callEntity via EventRaised (trigger binding path) tests + // + // These tests simulate the Azure Functions trigger binding code path where entity + // operation responses arrive as EventRaised events containing DTFx ResponseMessage JSON, + // rather than proto EntityOperationCompleted/Failed events (gRPC path). + // + // In the trigger binding path: + // - Past events: EVENTSENT (no-op) instead of ENTITYOPERATIONCALLED + // - New events: EVENTRAISED with ResponseMessage JSON instead of ENTITYOPERATIONCOMPLETED + + @Test + void callEntity_completesViaEventRaised_triggerBindingPath() { + final String orchestratorName = "CallEntityTriggerPath"; + EntityInstanceId entityId = new EntityInstanceId("Counter", "c1"); + + TaskOrchestrationExecutor executor = createExecutor(orchestratorName, ctx -> { + int value = ctx.callEntity(entityId, "get", null, int.class).await(); + ctx.complete(value); + }); + + // First pass: capture the requestId from the generated action + List pastEvents1 = Arrays.asList( + orchestratorStarted(), + executionStarted(orchestratorName, "null")); + List newEvents1 = Collections.singletonList(orchestratorCompleted()); + + TaskOrchestratorResult result1 = executor.execute(pastEvents1, newEvents1); + + String requestId = null; + for (OrchestratorAction action : result1.getActions()) { + if (action.hasSendEntityMessage() && action.getSendEntityMessage().hasEntityOperationCalled()) { + requestId = action.getSendEntityMessage().getEntityOperationCalled().getRequestId(); + } + } + assertNotNull(requestId, "Should have captured the requestId"); + + // Second pass (replay): use EVENTSENT in past (trigger binding records EventSent, not EntityOperationCalled) + // and EVENTRAISED with ResponseMessage JSON in new events (not EntityOperationCompleted) + executor = createExecutor(orchestratorName, ctx -> { + int value = ctx.callEntity(entityId, "get", null, int.class).await(); + ctx.complete(value); + }); + + String responseJson = successResponseMessageJson("42"); + List pastEvents2 = Arrays.asList( + orchestratorStarted(), + executionStarted(orchestratorName, "null"), + eventSentEvent(0), // Trigger binding records EventSent, not EntityOperationCalled + orchestratorCompleted()); + List newEvents2 = Arrays.asList( + orchestratorStarted(), + eventRaisedEvent(requestId, responseJson), // ResponseMessage JSON wrapper + orchestratorCompleted()); + + TaskOrchestratorResult result2 = executor.execute(pastEvents2, newEvents2); + + boolean hasComplete = false; + for (OrchestratorAction action : result2.getActions()) { + if (action.hasCompleteOrchestration()) { + CompleteOrchestrationAction complete = action.getCompleteOrchestration(); + assertEquals(OrchestrationStatus.ORCHESTRATION_STATUS_COMPLETED, complete.getOrchestrationStatus()); + assertEquals("42", complete.getResult().getValue()); + hasComplete = true; + } + } + assertTrue(hasComplete, "Expected orchestration to complete with entity result via EventRaised"); + } + + @Test + void callEntity_stringResultViaEventRaised_triggerBindingPath() { + final String orchestratorName = "CallEntityTriggerPathString"; + EntityInstanceId entityId = new EntityInstanceId("Counter", "c1"); + + TaskOrchestrationExecutor executor = createExecutor(orchestratorName, ctx -> { + String value = ctx.callEntity(entityId, "getName", null, String.class).await(); + ctx.complete(value); + }); + + // First pass: capture requestId + List pastEvents1 = Arrays.asList( + orchestratorStarted(), + executionStarted(orchestratorName, "null")); + List newEvents1 = Collections.singletonList(orchestratorCompleted()); + TaskOrchestratorResult result1 = executor.execute(pastEvents1, newEvents1); + + String requestId = null; + for (OrchestratorAction action : result1.getActions()) { + if (action.hasSendEntityMessage() && action.getSendEntityMessage().hasEntityOperationCalled()) { + requestId = action.getSendEntityMessage().getEntityOperationCalled().getRequestId(); + } + } + assertNotNull(requestId); + + // Second pass: replay with EventRaised containing a JSON-serialized string result + executor = createExecutor(orchestratorName, ctx -> { + String value = ctx.callEntity(entityId, "getName", null, String.class).await(); + ctx.complete(value); + }); + + // DTFx serializes string results as JSON strings: "\"hello\"" + String responseJson = successResponseMessageJson("\"hello\""); + List pastEvents2 = Arrays.asList( + orchestratorStarted(), + executionStarted(orchestratorName, "null"), + eventSentEvent(0), + orchestratorCompleted()); + List newEvents2 = Arrays.asList( + orchestratorStarted(), + eventRaisedEvent(requestId, responseJson), + orchestratorCompleted()); + + TaskOrchestratorResult result2 = executor.execute(pastEvents2, newEvents2); + + boolean hasComplete = false; + for (OrchestratorAction action : result2.getActions()) { + if (action.hasCompleteOrchestration()) { + CompleteOrchestrationAction complete = action.getCompleteOrchestration(); + assertEquals(OrchestrationStatus.ORCHESTRATION_STATUS_COMPLETED, complete.getOrchestrationStatus()); + assertEquals("\"hello\"", complete.getResult().getValue()); + hasComplete = true; + } + } + assertTrue(hasComplete, "Expected orchestration to complete with string entity result via EventRaised"); + } + + @Test + void callEntity_nullResultViaEventRaised_triggerBindingPath() { + final String orchestratorName = "CallEntityTriggerPathNull"; + EntityInstanceId entityId = new EntityInstanceId("Counter", "c1"); + + TaskOrchestrationExecutor executor = createExecutor(orchestratorName, ctx -> { + // Void operation — result is null + Void value = ctx.callEntity(entityId, "reset", null, Void.class).await(); + ctx.complete("done"); + }); + + // First pass: capture requestId + List pastEvents1 = Arrays.asList( + orchestratorStarted(), + executionStarted(orchestratorName, "null")); + List newEvents1 = Collections.singletonList(orchestratorCompleted()); + TaskOrchestratorResult result1 = executor.execute(pastEvents1, newEvents1); + + String requestId = null; + for (OrchestratorAction action : result1.getActions()) { + if (action.hasSendEntityMessage() && action.getSendEntityMessage().hasEntityOperationCalled()) { + requestId = action.getSendEntityMessage().getEntityOperationCalled().getRequestId(); + } + } + assertNotNull(requestId); + + // Second pass: replay with null result in ResponseMessage + executor = createExecutor(orchestratorName, ctx -> { + Void value = ctx.callEntity(entityId, "reset", null, Void.class).await(); + ctx.complete("done"); + }); + + String responseJson = successResponseMessageJson(null); + List pastEvents2 = Arrays.asList( + orchestratorStarted(), + executionStarted(orchestratorName, "null"), + eventSentEvent(0), + orchestratorCompleted()); + List newEvents2 = Arrays.asList( + orchestratorStarted(), + eventRaisedEvent(requestId, responseJson), + orchestratorCompleted()); + + TaskOrchestratorResult result2 = executor.execute(pastEvents2, newEvents2); + + boolean hasComplete = false; + for (OrchestratorAction action : result2.getActions()) { + if (action.hasCompleteOrchestration()) { + CompleteOrchestrationAction complete = action.getCompleteOrchestration(); + assertEquals(OrchestrationStatus.ORCHESTRATION_STATUS_COMPLETED, complete.getOrchestrationStatus()); + assertEquals("\"done\"", complete.getResult().getValue()); + hasComplete = true; + } + } + assertTrue(hasComplete, "Expected orchestration to complete after void entity call via EventRaised"); + } + + @Test + void callEntity_failedViaEventRaised_withFailureDetails_triggerBindingPath() { + final String orchestratorName = "CallEntityTriggerPathFail"; + EntityInstanceId entityId = new EntityInstanceId("Counter", "c1"); + + TaskOrchestrationExecutor executor = createExecutor(orchestratorName, ctx -> { + try { + ctx.callEntity(entityId, "get", null, int.class).await(); + ctx.complete("should not reach here"); + } catch (EntityOperationFailedException e) { + ctx.complete("caught: " + e.getFailureDetails().getErrorMessage()); + } + }); + + // First pass: capture requestId + List pastEvents1 = Arrays.asList( + orchestratorStarted(), + executionStarted(orchestratorName, "null")); + List newEvents1 = Collections.singletonList(orchestratorCompleted()); + TaskOrchestratorResult result1 = executor.execute(pastEvents1, newEvents1); + + String requestId = null; + for (OrchestratorAction action : result1.getActions()) { + if (action.hasSendEntityMessage() && action.getSendEntityMessage().hasEntityOperationCalled()) { + requestId = action.getSendEntityMessage().getEntityOperationCalled().getRequestId(); + } + } + assertNotNull(requestId); + + // Second pass: replay with failed ResponseMessage via EventRaised + executor = createExecutor(orchestratorName, ctx -> { + try { + ctx.callEntity(entityId, "get", null, int.class).await(); + ctx.complete("should not reach here"); + } catch (EntityOperationFailedException e) { + ctx.complete("caught: " + e.getFailureDetails().getErrorMessage()); + } + }); + + String responseJson = failedResponseMessageJson("Entity error!", "java.lang.RuntimeException"); + List pastEvents2 = Arrays.asList( + orchestratorStarted(), + executionStarted(orchestratorName, "null"), + eventSentEvent(0), + orchestratorCompleted()); + List newEvents2 = Arrays.asList( + orchestratorStarted(), + eventRaisedEvent(requestId, responseJson), + orchestratorCompleted()); + + TaskOrchestratorResult result2 = executor.execute(pastEvents2, newEvents2); + + boolean hasComplete = false; + for (OrchestratorAction action : result2.getActions()) { + if (action.hasCompleteOrchestration()) { + CompleteOrchestrationAction complete = action.getCompleteOrchestration(); + assertEquals(OrchestrationStatus.ORCHESTRATION_STATUS_COMPLETED, complete.getOrchestrationStatus()); + assertTrue(complete.getResult().getValue().contains("Entity error!"), + "Expected result to contain the error message"); + hasComplete = true; + } + } + assertTrue(hasComplete, "Expected orchestration to complete after catching entity failure via EventRaised"); + } + + @Test + void callEntity_failedViaEventRaised_simpleError_triggerBindingPath() { + final String orchestratorName = "CallEntityTriggerPathFailSimple"; + EntityInstanceId entityId = new EntityInstanceId("Counter", "c1"); + + TaskOrchestrationExecutor executor = createExecutor(orchestratorName, ctx -> { + try { + ctx.callEntity(entityId, "get", null, int.class).await(); + ctx.complete("should not reach here"); + } catch (EntityOperationFailedException e) { + ctx.complete("caught: " + e.getFailureDetails().getErrorMessage()); + } + }); + + // First pass: capture requestId + List pastEvents1 = Arrays.asList( + orchestratorStarted(), + executionStarted(orchestratorName, "null")); + List newEvents1 = Collections.singletonList(orchestratorCompleted()); + TaskOrchestratorResult result1 = executor.execute(pastEvents1, newEvents1); + + String requestId = null; + for (OrchestratorAction action : result1.getActions()) { + if (action.hasSendEntityMessage() && action.getSendEntityMessage().hasEntityOperationCalled()) { + requestId = action.getSendEntityMessage().getEntityOperationCalled().getRequestId(); + } + } + assertNotNull(requestId); + + // Second pass: replay with simple error (exceptionType only, no failureDetails) + executor = createExecutor(orchestratorName, ctx -> { + try { + ctx.callEntity(entityId, "get", null, int.class).await(); + ctx.complete("should not reach here"); + } catch (EntityOperationFailedException e) { + ctx.complete("caught: " + e.getFailureDetails().getErrorMessage()); + } + }); + + String responseJson = failedResponseMessageJsonSimple("Simple error occurred"); + List pastEvents2 = Arrays.asList( + orchestratorStarted(), + executionStarted(orchestratorName, "null"), + eventSentEvent(0), + orchestratorCompleted()); + List newEvents2 = Arrays.asList( + orchestratorStarted(), + eventRaisedEvent(requestId, responseJson), + orchestratorCompleted()); + + TaskOrchestratorResult result2 = executor.execute(pastEvents2, newEvents2); + + boolean hasComplete = false; + for (OrchestratorAction action : result2.getActions()) { + if (action.hasCompleteOrchestration()) { + CompleteOrchestrationAction complete = action.getCompleteOrchestration(); + assertEquals(OrchestrationStatus.ORCHESTRATION_STATUS_COMPLETED, complete.getOrchestrationStatus()); + assertTrue(complete.getResult().getValue().contains("Simple error occurred"), + "Expected result to contain the simple error message"); + hasComplete = true; + } + } + assertTrue(hasComplete, "Expected orchestration to complete after catching simple entity failure via EventRaised"); + } + + // endregion + + // region EVENTSENT no-op test + + @Test + void eventSent_doesNotCrash() { + final String orchestratorName = "EventSentOrchestration"; + + TaskOrchestrationExecutor executor = createExecutor(orchestratorName, ctx -> { + ctx.sendEvent("some-instance", "someEvent", "data"); + ctx.complete("done"); + }); + + // Include EVENTSENT in past events — should be handled as a no-op + List pastEvents = Arrays.asList( + orchestratorStarted(), + executionStarted(orchestratorName, "null"), + eventSentEvent(0), + orchestratorCompleted()); + List newEvents = Collections.singletonList(orchestratorCompleted()); + + // This should NOT throw (EVENTSENT is a no-op) + TaskOrchestratorResult result = executor.execute(pastEvents, newEvents); + + boolean hasComplete = false; + for (OrchestratorAction action : result.getActions()) { + if (action.hasCompleteOrchestration()) { + hasComplete = true; + } + } + assertTrue(hasComplete, "Orchestration should still complete with EVENTSENT in history"); + } + + // endregion + + // region Non-determinism tests + + @Test + void entityOperationSignaled_nonDeterminism_throwsException() { + final String orchestratorName = "NonDetSignal"; + + // Orchestrator that does NOT signal any entity + TaskOrchestrationExecutor executor = createExecutor(orchestratorName, ctx -> { + ctx.complete("done"); + }); + + // But history has an ENTITYOPERATIONSIGNALED event + List pastEvents = Arrays.asList( + orchestratorStarted(), + executionStarted(orchestratorName, "null"), + entityOperationSignaledEvent(0), + orchestratorCompleted()); + List newEvents = Collections.singletonList(orchestratorCompleted()); + + // The orchestrator already completed, so when the non-determinism is detected, + // context.fail() throws IllegalStateException ("already completed") + assertThrows(IllegalStateException.class, () -> + executor.execute(pastEvents, newEvents)); + } + + @Test + void entityOperationCalled_nonDeterminism_throwsException() { + final String orchestratorName = "NonDetCall"; + + // Orchestrator that does NOT call any entity + TaskOrchestrationExecutor executor = createExecutor(orchestratorName, ctx -> { + ctx.complete("done"); + }); + + // But history has an ENTITYOPERATIONCALLED event + List pastEvents = Arrays.asList( + orchestratorStarted(), + executionStarted(orchestratorName, "null"), + entityOperationCalledEvent(0), + orchestratorCompleted()); + List newEvents = Collections.singletonList(orchestratorCompleted()); + + assertThrows(IllegalStateException.class, () -> + executor.execute(pastEvents, newEvents)); + } + + @Test + void entityLockRequested_nonDeterminism_throwsException() { + final String orchestratorName = "NonDetLock"; + + // Orchestrator that does NOT lock any entities + TaskOrchestrationExecutor executor = createExecutor(orchestratorName, ctx -> { + ctx.complete("done"); + }); + + // But history has an ENTITYLOCKREQUESTED event + List pastEvents = Arrays.asList( + orchestratorStarted(), + executionStarted(orchestratorName, "null"), + entityLockRequestedEvent(0), + orchestratorCompleted()); + List newEvents = Collections.singletonList(orchestratorCompleted()); + + assertThrows(IllegalStateException.class, () -> + executor.execute(pastEvents, newEvents)); + } + + @Test + void entityUnlockSent_nonDeterminism_throwsException() { + final String orchestratorName = "NonDetUnlock"; + + // Orchestrator that does NOT unlock any entities + TaskOrchestrationExecutor executor = createExecutor(orchestratorName, ctx -> { + ctx.complete("done"); + }); + + // But history has an ENTITYUNLOCKSENT event + List pastEvents = Arrays.asList( + orchestratorStarted(), + executionStarted(orchestratorName, "null"), + entityUnlockSentEvent(0), + orchestratorCompleted()); + List newEvents = Collections.singletonList(orchestratorCompleted()); + + assertThrows(IllegalStateException.class, () -> + executor.execute(pastEvents, newEvents)); + } + + // endregion + + // region SignalEntityOptions / CallEntityOptions / getLockedEntities / varargs lockEntities tests + + @Test + void signalEntity_withScheduledTime_setsScheduledTimeOnAction() { + final String orchestratorName = "SignalScheduledTimeTest"; + Instant scheduledTime = Instant.parse("2025-06-15T12:00:00Z"); + EntityInstanceId entityId = new EntityInstanceId("Counter", "c1"); + + TaskOrchestrationExecutor executor = createExecutor(orchestratorName, ctx -> { + SignalEntityOptions options = new SignalEntityOptions().setScheduledTime(scheduledTime); + ctx.signalEntity(entityId, "add", 5, options); + ctx.complete("done"); + }); + + List pastEvents = Arrays.asList( + orchestratorStarted(), + executionStarted(orchestratorName, "null")); + List newEvents = Collections.singletonList(orchestratorCompleted()); + + TaskOrchestratorResult result = executor.execute(pastEvents, newEvents); + + boolean hasScheduledSignal = false; + for (OrchestratorAction action : result.getActions()) { + if (action.hasSendEntityMessage()) { + SendEntityMessageAction msg = action.getSendEntityMessage(); + if (msg.hasEntityOperationSignaled()) { + EntityOperationSignaledEvent signal = msg.getEntityOperationSignaled(); + assertTrue(signal.hasScheduledTime(), "Expected scheduledTime to be set"); + assertEquals(scheduledTime.getEpochSecond(), signal.getScheduledTime().getSeconds()); + hasScheduledSignal = true; + } + } + } + assertTrue(hasScheduledSignal, "Expected a signal action with scheduledTime"); + } + + @Test + void signalEntity_withNullOptions_noScheduledTime() { + final String orchestratorName = "SignalNullOptionsTest"; + EntityInstanceId entityId = new EntityInstanceId("Counter", "c1"); + + TaskOrchestrationExecutor executor = createExecutor(orchestratorName, ctx -> { + ctx.signalEntity(entityId, "add", 5, null); + ctx.complete("done"); + }); + + List pastEvents = Arrays.asList( + orchestratorStarted(), + executionStarted(orchestratorName, "null")); + List newEvents = Collections.singletonList(orchestratorCompleted()); + + TaskOrchestratorResult result = executor.execute(pastEvents, newEvents); + + for (OrchestratorAction action : result.getActions()) { + if (action.hasSendEntityMessage() && action.getSendEntityMessage().hasEntityOperationSignaled()) { + EntityOperationSignaledEvent signal = action.getSendEntityMessage().getEntityOperationSignaled(); + assertFalse(signal.hasScheduledTime(), "Expected no scheduledTime when options are null"); + } + } + } + + @Test + void callEntity_withOptions_producesAction() { + final String orchestratorName = "CallWithOptionsTest"; + EntityInstanceId entityId = new EntityInstanceId("Counter", "c1"); + + TaskOrchestrationExecutor executor = createExecutor(orchestratorName, ctx -> { + CallEntityOptions options = new CallEntityOptions().setTimeout(Duration.ofSeconds(30)); + int value = ctx.callEntity(entityId, "get", null, int.class, options).await(); + ctx.complete(value); + }); + + List pastEvents = Arrays.asList( + orchestratorStarted(), + executionStarted(orchestratorName, "null")); + List newEvents = Collections.singletonList(orchestratorCompleted()); + + TaskOrchestratorResult result = executor.execute(pastEvents, newEvents); + + boolean hasCall = false; + for (OrchestratorAction action : result.getActions()) { + if (action.hasSendEntityMessage() && action.getSendEntityMessage().hasEntityOperationCalled()) { + EntityOperationCalledEvent call = action.getSendEntityMessage().getEntityOperationCalled(); + assertEquals("get", call.getOperation()); + assertEquals("@counter@c1", call.getTargetInstanceId().getValue()); + hasCall = true; + } + } + assertTrue(hasCall, "Expected a sendEntityMessage action with call"); + } + + @Test + void callEntity_withTimeout_producesCallAndTimerActions() { + final String orchestratorName = "CallWithTimeoutTest"; + EntityInstanceId entityId = new EntityInstanceId("Counter", "c1"); + + TaskOrchestrationExecutor executor = createExecutor(orchestratorName, ctx -> { + CallEntityOptions options = new CallEntityOptions().setTimeout(Duration.ofSeconds(30)); + ctx.callEntity(entityId, "get", null, int.class, options).await(); + ctx.complete("done"); + }); + + List pastEvents = Arrays.asList( + orchestratorStarted(), + executionStarted(orchestratorName, "null")); + List newEvents = Collections.singletonList(orchestratorCompleted()); + + TaskOrchestratorResult result = executor.execute(pastEvents, newEvents); + + boolean hasCall = false; + boolean hasTimer = false; + for (OrchestratorAction action : result.getActions()) { + if (action.hasSendEntityMessage() && action.getSendEntityMessage().hasEntityOperationCalled()) { + hasCall = true; + } + if (action.hasCreateTimer()) { + hasTimer = true; + } + } + assertTrue(hasCall, "Expected a sendEntityMessage action with call"); + assertTrue(hasTimer, "Expected a createTimer action for the timeout"); + } + + @Test + void callEntity_withoutTimeout_producesNoTimerAction() { + final String orchestratorName = "CallNoTimeoutTest"; + EntityInstanceId entityId = new EntityInstanceId("Counter", "c1"); + + TaskOrchestrationExecutor executor = createExecutor(orchestratorName, ctx -> { + // No options = no timeout + ctx.callEntity(entityId, "get", null, int.class).await(); + ctx.complete("done"); + }); + + List pastEvents = Arrays.asList( + orchestratorStarted(), + executionStarted(orchestratorName, "null")); + List newEvents = Collections.singletonList(orchestratorCompleted()); + + TaskOrchestratorResult result = executor.execute(pastEvents, newEvents); + + boolean hasCall = false; + boolean hasTimer = false; + for (OrchestratorAction action : result.getActions()) { + if (action.hasSendEntityMessage() && action.getSendEntityMessage().hasEntityOperationCalled()) { + hasCall = true; + } + if (action.hasCreateTimer()) { + hasTimer = true; + } + } + assertTrue(hasCall, "Expected a sendEntityMessage action with call"); + assertFalse(hasTimer, "Expected no createTimer action when no timeout is specified"); + } + + @Test + void callEntity_withZeroTimeout_cancelledImmediately() { + final String orchestratorName = "CallZeroTimeoutTest"; + EntityInstanceId entityId = new EntityInstanceId("Counter", "c1"); + + TaskOrchestrationExecutor executor = createExecutor(orchestratorName, ctx -> { + CallEntityOptions options = new CallEntityOptions().setTimeout(Duration.ZERO); + try { + ctx.callEntity(entityId, "get", null, int.class, options).await(); + ctx.complete("should not reach here"); + } catch (TaskCanceledException e) { + ctx.complete("cancelled"); + } + }); + + List pastEvents = Arrays.asList( + orchestratorStarted(), + executionStarted(orchestratorName, "null")); + List newEvents = Collections.singletonList(orchestratorCompleted()); + + TaskOrchestratorResult result = executor.execute(pastEvents, newEvents); + + boolean hasComplete = false; + for (OrchestratorAction action : result.getActions()) { + if (action.hasCompleteOrchestration()) { + String output = action.getCompleteOrchestration().getResult().getValue(); + assertEquals("\"cancelled\"", output); + hasComplete = true; + } + } + assertTrue(hasComplete, "Expected orchestration to complete with 'cancelled' after zero timeout"); + } + + @Test + void getLockedEntities_insideCriticalSection_returnsLockedIds() { + final String orchestratorName = "GetLockedEntitiesTest"; + EntityInstanceId entityId = new EntityInstanceId("Counter", "c1"); + + TaskOrchestrationExecutor executor = createExecutor(orchestratorName, ctx -> { + ctx.lockEntities(Arrays.asList(entityId)).await(); + // Inside critical section, getLockedEntities should return the locked entities + List locked = ctx.getLockedEntities(); + assertFalse(locked.isEmpty(), "Expected locked entities inside critical section"); + assertEquals("counter", locked.get(0).getName()); + assertEquals("c1", locked.get(0).getKey()); + ctx.complete("done"); + }); + + // First execution: orchestrator calls lockEntities, which produces a lock request action + List pastEvents1 = Arrays.asList( + orchestratorStarted(), + executionStarted(orchestratorName, "null")); + List newEvents1 = Collections.singletonList(orchestratorCompleted()); + + TaskOrchestratorResult result1 = executor.execute(pastEvents1, newEvents1); + + // Extract the criticalSectionId from the lock request action + String criticalSectionId = null; + for (OrchestratorAction action : result1.getActions()) { + if (action.hasSendEntityMessage() && action.getSendEntityMessage().hasEntityLockRequested()) { + criticalSectionId = action.getSendEntityMessage().getEntityLockRequested().getCriticalSectionId(); + break; + } + } + assertNotNull(criticalSectionId, "Expected a lock request action with criticalSectionId"); + + // Second execution: replay with the lock request in past, grant in new events + List pastEvents2 = Arrays.asList( + orchestratorStarted(), + executionStarted(orchestratorName, "null"), + entityLockRequestedEvent(0), + orchestratorCompleted()); + List newEvents2 = Arrays.asList( + orchestratorStarted(), + entityLockGrantedEvent(criticalSectionId), + orchestratorCompleted()); + + TaskOrchestratorResult result2 = executor.execute(pastEvents2, newEvents2); + + boolean hasComplete = false; + for (OrchestratorAction action : result2.getActions()) { + if (action.hasCompleteOrchestration()) { + hasComplete = true; + } + } + assertTrue(hasComplete, "Expected orchestration to complete after lock granted"); + } + + @Test + void getLockedEntities_outsideCriticalSection_returnsEmpty() { + final String orchestratorName = "GetLockedEntitiesEmptyTest"; + + TaskOrchestrationExecutor executor = createExecutor(orchestratorName, ctx -> { + List locked = ctx.getLockedEntities(); + assertTrue(locked.isEmpty(), "Expected empty locked entities outside critical section"); + ctx.complete("done"); + }); + + List pastEvents = Arrays.asList( + orchestratorStarted(), + executionStarted(orchestratorName, "null")); + List newEvents = Collections.singletonList(orchestratorCompleted()); + + TaskOrchestratorResult result = executor.execute(pastEvents, newEvents); + + boolean hasComplete = false; + for (OrchestratorAction action : result.getActions()) { + if (action.hasCompleteOrchestration()) { + hasComplete = true; + } + } + assertTrue(hasComplete, "Expected orchestration to complete"); + } + + @Test + void lockEntities_varargs_producesLockAction() { + final String orchestratorName = "VarargsLockTest"; + EntityInstanceId entityId1 = new EntityInstanceId("Counter", "c1"); + EntityInstanceId entityId2 = new EntityInstanceId("Counter", "c2"); + + TaskOrchestrationExecutor executor = createExecutor(orchestratorName, ctx -> { + // Use varargs overload + ctx.lockEntities(entityId1, entityId2).await(); + ctx.complete("locked"); + }); + + List pastEvents = Arrays.asList( + orchestratorStarted(), + executionStarted(orchestratorName, "null")); + List newEvents = Collections.singletonList(orchestratorCompleted()); + + TaskOrchestratorResult result = executor.execute(pastEvents, newEvents); + + boolean hasLockRequest = false; + for (OrchestratorAction action : result.getActions()) { + if (action.hasSendEntityMessage() && action.getSendEntityMessage().hasEntityLockRequested()) { + hasLockRequest = true; + } + } + assertTrue(hasLockRequest, "Expected an entityLockRequest action from varargs lockEntities"); + } + + // endregion + + // region DurableTaskGrpcWorkerBuilder entity config tests + + @Test + void maxConcurrentEntityWorkItems_rejectsZero() { + DurableTaskGrpcWorkerBuilder builder = new DurableTaskGrpcWorkerBuilder(); + assertThrows(IllegalArgumentException.class, () -> builder.maxConcurrentEntityWorkItems(0)); + } + + @Test + void maxConcurrentEntityWorkItems_rejectsNegative() { + DurableTaskGrpcWorkerBuilder builder = new DurableTaskGrpcWorkerBuilder(); + assertThrows(IllegalArgumentException.class, () -> builder.maxConcurrentEntityWorkItems(-1)); + } + + @Test + void maxConcurrentEntityWorkItems_acceptsValidValue() { + DurableTaskGrpcWorkerBuilder builder = new DurableTaskGrpcWorkerBuilder(); + DurableTaskGrpcWorkerBuilder result = builder.maxConcurrentEntityWorkItems(4); + assertSame(builder, result, "Builder should return itself for fluent chaining"); + } + + @Test + void maxConcurrentEntityWorkItems_defaultIsOne() { + DurableTaskGrpcWorkerBuilder builder = new DurableTaskGrpcWorkerBuilder(); + assertEquals(1, builder.maxConcurrentEntityWorkItems, + "Default maxConcurrentEntityWorkItems should be 1"); + } + + // endregion +} diff --git a/client/src/test/java/com/microsoft/durabletask/TypedEntityMetadataTest.java b/client/src/test/java/com/microsoft/durabletask/TypedEntityMetadataTest.java new file mode 100644 index 00000000..5f8eb48c --- /dev/null +++ b/client/src/test/java/com/microsoft/durabletask/TypedEntityMetadataTest.java @@ -0,0 +1,102 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.microsoft.durabletask; + +import org.junit.jupiter.api.Test; + +import java.time.Instant; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Unit tests for {@link TypedEntityMetadata}. + */ +public class TypedEntityMetadataTest { + + private static final DataConverter dataConverter = new JacksonDataConverter(); + + @Test + void getState_deserializesIntegerState() { + EntityMetadata base = new EntityMetadata( + "@counter@c1", Instant.EPOCH, 0, null, "42", true, dataConverter); + + TypedEntityMetadata typed = new TypedEntityMetadata<>(base, Integer.class); + assertNotNull(typed.getState()); + assertEquals(42, typed.getState()); + } + + @Test + void getState_deserializesStringState() { + EntityMetadata base = new EntityMetadata( + "@myEntity@k1", Instant.EPOCH, 0, null, "\"hello\"", true, dataConverter); + + TypedEntityMetadata typed = new TypedEntityMetadata<>(base, String.class); + assertEquals("hello", typed.getState()); + } + + @Test + void getState_nullWhenNoState() { + EntityMetadata base = new EntityMetadata( + "@counter@c1", Instant.EPOCH, 0, null, null, false, dataConverter); + + TypedEntityMetadata typed = new TypedEntityMetadata<>(base, Integer.class); + assertNull(typed.getState()); + } + + @Test + void getStateType_returnsCorrectClass() { + EntityMetadata base = new EntityMetadata( + "@counter@c1", Instant.EPOCH, 0, null, "42", true, dataConverter); + + TypedEntityMetadata typed = new TypedEntityMetadata<>(base, Integer.class); + assertEquals(Integer.class, typed.getStateType()); + } + + @Test + void inheritsEntityMetadataFields() { + Instant now = Instant.now(); + EntityMetadata base = new EntityMetadata( + "@counter@myKey", now, 5, "orch-123", "99", true, dataConverter); + + TypedEntityMetadata typed = new TypedEntityMetadata<>(base, Integer.class); + + // Inherited fields from EntityMetadata + assertEquals("@counter@myKey", typed.getInstanceId()); + assertEquals(now, typed.getLastModifiedTime()); + assertEquals(5, typed.getBacklogQueueSize()); + assertEquals("orch-123", typed.getLockedBy()); + assertEquals("99", typed.getSerializedState()); + assertTrue(typed.isIncludesState()); + + // Parsed entity ID + EntityInstanceId entityId = typed.getEntityInstanceId(); + assertEquals("counter", entityId.getName()); + assertEquals("myKey", entityId.getKey()); + + // Typed state + assertEquals(99, typed.getState()); + } + + @Test + void isInstanceOfEntityMetadata() { + EntityMetadata base = new EntityMetadata( + "@counter@c1", Instant.EPOCH, 0, null, "42", true, dataConverter); + + TypedEntityMetadata typed = new TypedEntityMetadata<>(base, Integer.class); + + // TypedEntityMetadata IS-A EntityMetadata (matches .NET's EntityMetadata : EntityMetadata) + assertInstanceOf(EntityMetadata.class, typed); + } + + @Test + void readStateAs_stillWorksOnTypedInstance() { + EntityMetadata base = new EntityMetadata( + "@counter@c1", Instant.EPOCH, 0, null, "42", true, dataConverter); + + TypedEntityMetadata typed = new TypedEntityMetadata<>(base, Integer.class); + + // readStateAs from base class should still work + Integer state = typed.readStateAs(Integer.class); + assertEquals(42, state); + } +} diff --git a/endtoendtests/src/main/java/com/functions/CounterEntity.java b/endtoendtests/src/main/java/com/functions/CounterEntity.java new file mode 100644 index 00000000..a0adb87f --- /dev/null +++ b/endtoendtests/src/main/java/com/functions/CounterEntity.java @@ -0,0 +1,33 @@ +package com.functions; + +import com.microsoft.durabletask.TaskEntity; +import com.microsoft.durabletask.TaskEntityOperation; + +/** + * A simple counter entity for e2e testing. + */ +public class CounterEntity extends TaskEntity { + + public int add(int input) { + this.state += input; + return this.state; + } + + public int get() { + return this.state; + } + + public void reset() { + this.state = 0; + } + + @Override + protected Integer initializeState(TaskEntityOperation operation) { + return 0; + } + + @Override + protected Class getStateType() { + return Integer.class; + } +} diff --git a/endtoendtests/src/main/java/com/functions/EntityFunctions.java b/endtoendtests/src/main/java/com/functions/EntityFunctions.java new file mode 100644 index 00000000..1e95c176 --- /dev/null +++ b/endtoendtests/src/main/java/com/functions/EntityFunctions.java @@ -0,0 +1,280 @@ +package com.functions; + +import com.microsoft.azure.functions.*; +import com.microsoft.azure.functions.annotation.*; +import com.microsoft.durabletask.*; +import com.microsoft.durabletask.azurefunctions.DurableClientContext; +import com.microsoft.durabletask.azurefunctions.DurableClientInput; +import com.microsoft.durabletask.azurefunctions.DurableEntityTrigger; +import com.microsoft.durabletask.azurefunctions.DurableOrchestrationTrigger; + +import java.util.Optional; + +/** + * Azure Functions for entity e2e tests. + *

+ * Tests: + *

    + *
  • {@code callEntity} — orchestrator calls entity and returns result
  • + *
  • {@code signalEntity} — orchestrator signals entity (fire-and-forget), then calls to verify
  • + *
  • {@code signalAndCallEntity} — orchestrator signals entity to add, then calls to get updated value
  • + *
+ */ +public class EntityFunctions { + + // ─── Entity trigger ─── + + @FunctionName("Counter") + public String counterEntity( + @DurableEntityTrigger(name = "req") String req) { + return EntityRunner.loadAndRun(req, CounterEntity::new); + } + + // ─── Orchestrations ─── + + /** + * Orchestration that calls the counter entity "add" operation and returns the result. + * Input: JSON with entityKey and addValue fields. + */ + @FunctionName("CallEntityOrchestration") + public int callEntityOrchestration( + @DurableOrchestrationTrigger(name = "ctx") TaskOrchestrationContext ctx) { + EntityPayload input = ctx.getInput(EntityPayload.class); + EntityInstanceId entityId = new EntityInstanceId("counter", input.entityKey); + return ctx.callEntity(entityId, "add", input.addValue, Integer.class).await(); + } + + /** + * Orchestration that signals the counter entity to "add" then calls "get" to verify. + * This tests both signalEntity (fire-and-forget) and callEntity (request-response). + */ + @FunctionName("SignalThenCallEntityOrchestration") + public int signalThenCallEntityOrchestration( + @DurableOrchestrationTrigger(name = "ctx") TaskOrchestrationContext ctx) { + EntityPayload input = ctx.getInput(EntityPayload.class); + EntityInstanceId entityId = new EntityInstanceId("counter", input.entityKey); + + // Signal — fire-and-forget + ctx.signalEntity(entityId, "add", input.addValue); + + // Call — request-response (should see the updated state) + return ctx.callEntity(entityId, "get", null, Integer.class).await(); + } + + /** + * Orchestration that calls "get" on a fresh counter entity (tests zero-initialized state). + */ + @FunctionName("CallEntityGetOrchestration") + public int callEntityGetOrchestration( + @DurableOrchestrationTrigger(name = "ctx") TaskOrchestrationContext ctx) { + String entityKey = ctx.getInput(String.class); + EntityInstanceId entityId = new EntityInstanceId("counter", entityKey); + return ctx.callEntity(entityId, "get", null, Integer.class).await(); + } + + /** + * Comprehensive orchestration that exercises signal, call, and reset on a counter entity. + * Steps: signal add(5) -> call get (expect 5) -> signal add(10) -> call get (expect 15) + * -> signal reset -> call get (expect 0). + * Returns a summary string with pass/fail. + */ + @FunctionName("ComprehensiveEntityOrchestration") + public String comprehensiveEntityOrchestration( + @DurableOrchestrationTrigger(name = "ctx") TaskOrchestrationContext ctx) { + String entityKey = ctx.getInput(String.class); + EntityInstanceId counterId = new EntityInstanceId("counter", entityKey); + StringBuilder result = new StringBuilder(); + + // Test 1: signalEntity (fire-and-forget) — add 5 + ctx.signalEntity(counterId, "add", 5); + result.append("Step 1: Signaled add(5)\n"); + + // Test 2: callEntity (request-response) — get current value, should be 5 + int valueAfterAdd5 = ctx.callEntity(counterId, "get", null, Integer.class).await(); + result.append("Step 2: callEntity get() returned ").append(valueAfterAdd5).append("\n"); + + // Test 3: signalEntity — add 10 + ctx.signalEntity(counterId, "add", 10); + result.append("Step 3: Signaled add(10)\n"); + + // Test 4: callEntity — get current value, should be 15 + int valueAfterAdd10 = ctx.callEntity(counterId, "get", null, Integer.class).await(); + result.append("Step 4: callEntity get() returned ").append(valueAfterAdd10).append("\n"); + + // Test 5: signalEntity — reset + ctx.signalEntity(counterId, "reset"); + result.append("Step 5: Signaled reset()\n"); + + // Test 6: callEntity — get current value, should be 0 + int valueAfterReset = ctx.callEntity(counterId, "get", null, Integer.class).await(); + result.append("Step 6: callEntity get() returned ").append(valueAfterReset).append("\n"); + + // Summary + boolean passed = (valueAfterAdd5 == 5) && (valueAfterAdd10 == 15) && (valueAfterReset == 0); + result.append("\nAll tests passed: ").append(passed); + return result.toString(); + } + + /** + * Orchestration that calls "add" twice with different values to produce a result + * that differs from either individual input (proving the entity accumulates state). + * Input: JSON with entityKey and addValue fields. Calls add(addValue) then add(addValue + 2). + * Returns the final result. + */ + @FunctionName("CallEntityTwiceOrchestration") + public int callEntityTwiceOrchestration( + @DurableOrchestrationTrigger(name = "ctx") TaskOrchestrationContext ctx) { + EntityPayload input = ctx.getInput(EntityPayload.class); + EntityInstanceId entityId = new EntityInstanceId("counter", input.entityKey); + // First add + ctx.callEntity(entityId, "add", input.addValue, Integer.class).await(); + // Second add with a different value + int secondValue = input.addValue + 2; + return ctx.callEntity(entityId, "add", secondValue, Integer.class).await(); + } + + // ─── HTTP triggers ─── + + /** + * POST /api/StartCallEntityOrchestration?key={key}&value={value} + */ + @FunctionName("StartCallEntityOrchestration") + public HttpResponseMessage startCallEntityOrchestration( + @HttpTrigger(name = "req", methods = {HttpMethod.POST}, + authLevel = AuthorizationLevel.ANONYMOUS) HttpRequestMessage> request, + @DurableClientInput(name = "durableContext") DurableClientContext durableContext, + final ExecutionContext context) { + String key = request.getQueryParameters().getOrDefault("key", "e2e-call-" + System.currentTimeMillis()); + int value = Integer.parseInt(request.getQueryParameters().getOrDefault("value", "5")); + + DurableTaskClient client = durableContext.getClient(); + EntityPayload payload = new EntityPayload(key, value); + String instanceId = client.scheduleNewOrchestrationInstance("CallEntityOrchestration", payload); + context.getLogger().info("Started CallEntityOrchestration: " + instanceId); + return durableContext.createCheckStatusResponse(request, instanceId); + } + + /** + * POST /api/StartSignalThenCallEntityOrchestration?key={key}&value={value} + */ + @FunctionName("StartSignalThenCallEntityOrchestration") + public HttpResponseMessage startSignalThenCallEntityOrchestration( + @HttpTrigger(name = "req", methods = {HttpMethod.POST}, + authLevel = AuthorizationLevel.ANONYMOUS) HttpRequestMessage> request, + @DurableClientInput(name = "durableContext") DurableClientContext durableContext, + final ExecutionContext context) { + String key = request.getQueryParameters().getOrDefault("key", "e2e-signal-" + System.currentTimeMillis()); + int value = Integer.parseInt(request.getQueryParameters().getOrDefault("value", "10")); + + DurableTaskClient client = durableContext.getClient(); + EntityPayload payload = new EntityPayload(key, value); + String instanceId = client.scheduleNewOrchestrationInstance("SignalThenCallEntityOrchestration", payload); + context.getLogger().info("Started SignalThenCallEntityOrchestration: " + instanceId); + return durableContext.createCheckStatusResponse(request, instanceId); + } + + /** + * POST /api/StartCallEntityGetOrchestration?key={key} + */ + @FunctionName("StartCallEntityGetOrchestration") + public HttpResponseMessage startCallEntityGetOrchestration( + @HttpTrigger(name = "req", methods = {HttpMethod.POST}, + authLevel = AuthorizationLevel.ANONYMOUS) HttpRequestMessage> request, + @DurableClientInput(name = "durableContext") DurableClientContext durableContext, + final ExecutionContext context) { + String key = request.getQueryParameters().getOrDefault("key", "e2e-get-" + System.currentTimeMillis()); + + DurableTaskClient client = durableContext.getClient(); + String instanceId = client.scheduleNewOrchestrationInstance("CallEntityGetOrchestration", key); + context.getLogger().info("Started CallEntityGetOrchestration: " + instanceId); + return durableContext.createCheckStatusResponse(request, instanceId); + } + + /** + * POST /api/StartComprehensiveEntityOrchestration?key={key} + */ + @FunctionName("StartComprehensiveEntityOrchestration") + public HttpResponseMessage startComprehensiveEntityOrchestration( + @HttpTrigger(name = "req", methods = {HttpMethod.POST}, + authLevel = AuthorizationLevel.ANONYMOUS) HttpRequestMessage> request, + @DurableClientInput(name = "durableContext") DurableClientContext durableContext, + final ExecutionContext context) { + String key = request.getQueryParameters().getOrDefault("key", "e2e-comprehensive-" + System.currentTimeMillis()); + + DurableTaskClient client = durableContext.getClient(); + String instanceId = client.scheduleNewOrchestrationInstance("ComprehensiveEntityOrchestration", key); + context.getLogger().info("Started ComprehensiveEntityOrchestration: " + instanceId); + return durableContext.createCheckStatusResponse(request, instanceId); + } + + /** + * POST /api/StartCallEntityTwiceOrchestration?key={key}&value={value} + */ + @FunctionName("StartCallEntityTwiceOrchestration") + public HttpResponseMessage startCallEntityTwiceOrchestration( + @HttpTrigger(name = "req", methods = {HttpMethod.POST}, + authLevel = AuthorizationLevel.ANONYMOUS) HttpRequestMessage> request, + @DurableClientInput(name = "durableContext") DurableClientContext durableContext, + final ExecutionContext context) { + String key = request.getQueryParameters().getOrDefault("key", "e2e-twice-" + System.currentTimeMillis()); + int value = Integer.parseInt(request.getQueryParameters().getOrDefault("value", "3")); + + DurableTaskClient client = durableContext.getClient(); + EntityPayload payload = new EntityPayload(key, value); + String instanceId = client.scheduleNewOrchestrationInstance("CallEntityTwiceOrchestration", payload); + context.getLogger().info("Started CallEntityTwiceOrchestration: " + instanceId); + return durableContext.createCheckStatusResponse(request, instanceId); + } + + /** + * GET /api/GetEntityState?name={name}&key={key} + * Returns the entity's current state as a JSON integer. + */ + @FunctionName("GetEntityState") + public HttpResponseMessage getEntityState( + @HttpTrigger(name = "req", methods = {HttpMethod.GET}, + authLevel = AuthorizationLevel.ANONYMOUS) HttpRequestMessage> request, + @DurableClientInput(name = "durableContext") DurableClientContext durableContext, + final ExecutionContext context) { + String name = request.getQueryParameters().get("name"); + String key = request.getQueryParameters().get("key"); + if (name == null || key == null) { + return request.createResponseBuilder(HttpStatus.BAD_REQUEST) + .body("Missing 'name' or 'key' query parameter") + .build(); + } + + EntityInstanceId entityId = new EntityInstanceId(name, key); + EntityMetadata metadata = durableContext.getEntityMetadata(entityId, true); + if (metadata == null) { + return request.createResponseBuilder(HttpStatus.NOT_FOUND) + .body("Entity not found: " + name + "/" + key) + .build(); + } + + String serializedState = metadata.getSerializedState(); + context.getLogger().info("Entity " + name + "/" + key + " state: " + serializedState); + return request.createResponseBuilder(HttpStatus.OK) + .header("Content-Type", "application/json") + .body(serializedState) + .build(); + } + + // ─── Helpers ─── + + /** + * Payload for entity orchestrations. + */ + public static class EntityPayload { + public String entityKey; + public int addValue; + + public EntityPayload() { + } + + public EntityPayload(String entityKey, int addValue) { + this.entityKey = entityKey; + this.addValue = addValue; + } + } +} diff --git a/endtoendtests/src/test/java/com/functions/EndToEndTests.java b/endtoendtests/src/test/java/com/functions/EndToEndTests.java index e06b71e9..26973af8 100644 --- a/endtoendtests/src/test/java/com/functions/EndToEndTests.java +++ b/endtoendtests/src/test/java/com/functions/EndToEndTests.java @@ -411,6 +411,162 @@ public void VersionedSubOrchestrationTests(String version) throws InterruptedExc } } + // ─── Entity tests ─── + + /** + * Tests callEntity: orchestrator calls counter entity "add" and returns the result. + * Uses value=7 (not 0) so the expected output (7) differs from the initial state (0), + * proving the entity actually processed the add operation. + * Also verifies entity state directly via GetEntityState endpoint. + */ + @Test + public void callEntityTest() throws InterruptedException { + Set continueStates = new HashSet<>(); + continueStates.add("Pending"); + continueStates.add("Running"); + String entityKey = "call-test-" + System.currentTimeMillis(); + Response response = post("/api/StartCallEntityOrchestration?key=" + entityKey + "&value=7"); + JsonPath jsonPath = response.jsonPath(); + String statusQueryGetUri = jsonPath.get("statusQueryGetUri"); + boolean completed = pollingCheck(statusQueryGetUri, "Completed", continueStates, Duration.ofSeconds(30)); + assertTrue(completed, "CallEntityOrchestration should complete"); + + // Verify orchestration output + Response statusResponse = get(statusQueryGetUri); + int output = statusResponse.jsonPath().get("output"); + assertEquals(7, output, "Counter entity should return the added value (0 + 7 = 7)"); + + // Verify the actual entity state directly + int entityState = getEntityStateValue("counter", entityKey); + assertEquals(7, entityState, "Entity state should be 7 after add(7)"); + } + + /** + * Tests callEntity called twice: orchestrator calls add(3) then add(5) on the same entity. + * The final output is 8 (3+5), which differs from either input, proving the entity + * accumulates state across calls and doesn't just echo the input. + * Also verifies entity state directly. + */ + @Test + public void callEntityTwiceTest() throws InterruptedException { + Set continueStates = new HashSet<>(); + continueStates.add("Pending"); + continueStates.add("Running"); + String entityKey = "twice-test-" + System.currentTimeMillis(); + // value=3 -> orchestration calls add(3) then add(3+2=5), expecting 3+5=8 + Response response = post("/api/StartCallEntityTwiceOrchestration?key=" + entityKey + "&value=3"); + JsonPath jsonPath = response.jsonPath(); + String statusQueryGetUri = jsonPath.get("statusQueryGetUri"); + boolean completed = pollingCheck(statusQueryGetUri, "Completed", continueStates, Duration.ofSeconds(30)); + assertTrue(completed, "CallEntityTwiceOrchestration should complete"); + + // Verify orchestration output: add(3) -> 3, add(5) -> 8, returns 8 + Response statusResponse = get(statusQueryGetUri); + int output = statusResponse.jsonPath().get("output"); + assertEquals(8, output, "Counter entity should return cumulative value (3 + 5 = 8)"); + assertNotEquals(3, output, "Output should differ from first input"); + assertNotEquals(5, output, "Output should differ from second input"); + + // Verify the actual entity state directly + int entityState = getEntityStateValue("counter", entityKey); + assertEquals(8, entityState, "Entity state should be 8 after add(3) + add(5)"); + } + + /** + * Comprehensive entity test: exercises signal, call, and reset in a single orchestration. + * Steps: signal add(5) -> call get (5) -> signal add(10) -> call get (15) -> signal reset -> call get (0). + * Verifies the orchestration output string and entity state. + */ + @Test + public void comprehensiveEntityTest() throws InterruptedException { + Set continueStates = new HashSet<>(); + continueStates.add("Pending"); + continueStates.add("Running"); + String entityKey = "comprehensive-test-" + System.currentTimeMillis(); + Response response = post("/api/StartComprehensiveEntityOrchestration?key=" + entityKey); + JsonPath jsonPath = response.jsonPath(); + String statusQueryGetUri = jsonPath.get("statusQueryGetUri"); + boolean completed = pollingCheck(statusQueryGetUri, "Completed", continueStates, Duration.ofSeconds(60)); + assertTrue(completed, "ComprehensiveEntityOrchestration should complete"); + + // Verify orchestration output contains the pass summary + Response statusResponse = get(statusQueryGetUri); + String output = statusResponse.jsonPath().get("output"); + assertTrue(output.contains("All tests passed: true"), + "Comprehensive entity test should pass all steps. Output:\n" + output); + assertTrue(output.contains("Step 2: callEntity get() returned 5"), "Step 2 should return 5"); + assertTrue(output.contains("Step 4: callEntity get() returned 15"), "Step 4 should return 15"); + assertTrue(output.contains("Step 6: callEntity get() returned 0"), "Step 6 should return 0 after reset"); + + // Verify the actual entity state directly (should be 0 after reset) + int entityState = getEntityStateValue("counter", entityKey); + assertEquals(0, entityState, "Entity state should be 0 after reset"); + } + + /** + * Tests signalEntity + callEntity: orchestrator signals counter entity to "add", + * then calls "get" to verify the updated state. + * Also verifies entity state directly via GetEntityState endpoint. + */ + @Test + public void signalThenCallEntityTest() throws InterruptedException { + Set continueStates = new HashSet<>(); + continueStates.add("Pending"); + continueStates.add("Running"); + String entityKey = "signal-test-" + System.currentTimeMillis(); + Response response = post("/api/StartSignalThenCallEntityOrchestration?key=" + entityKey + "&value=10"); + JsonPath jsonPath = response.jsonPath(); + String statusQueryGetUri = jsonPath.get("statusQueryGetUri"); + boolean completed = pollingCheck(statusQueryGetUri, "Completed", continueStates, Duration.ofSeconds(30)); + assertTrue(completed, "SignalThenCallEntityOrchestration should complete"); + + // Verify orchestration output + Response statusResponse = get(statusQueryGetUri); + int output = statusResponse.jsonPath().get("output"); + assertEquals(10, output, "Counter entity should return the signaled value after get"); + + // Verify the actual entity state directly + int entityState = getEntityStateValue("counter", entityKey); + assertEquals(10, entityState, "Entity state should be 10 after signal add(10)"); + } + + /** + * Tests callEntity on a fresh entity: orchestrator calls "get" on a new counter entity, + * which should return zero (the initial state). + * Also verifies entity state directly via GetEntityState endpoint. + */ + @Test + public void callEntityGetInitialStateTest() throws InterruptedException { + Set continueStates = new HashSet<>(); + continueStates.add("Pending"); + continueStates.add("Running"); + String entityKey = "get-test-" + System.currentTimeMillis(); + Response response = post("/api/StartCallEntityGetOrchestration?key=" + entityKey); + JsonPath jsonPath = response.jsonPath(); + String statusQueryGetUri = jsonPath.get("statusQueryGetUri"); + boolean completed = pollingCheck(statusQueryGetUri, "Completed", continueStates, Duration.ofSeconds(30)); + assertTrue(completed, "CallEntityGetOrchestration should complete"); + + // Verify orchestration output + Response statusResponse = get(statusQueryGetUri); + int output = statusResponse.jsonPath().get("output"); + assertEquals(0, output, "Fresh counter entity should return initial state of 0"); + + // Verify the actual entity state directly + int entityState = getEntityStateValue("counter", entityKey); + assertEquals(0, entityState, "Entity state should be 0 for fresh entity"); + } + + /** + * Queries the entity state directly via the GetEntityState HTTP endpoint. + */ + private int getEntityStateValue(String entityName, String entityKey) { + Response response = get("/api/GetEntityState?name=" + entityName + "&key=" + entityKey); + assertEquals(200, response.getStatusCode(), + "GetEntityState should return 200, got: " + response.getStatusCode() + " body: " + response.getBody().asString()); + return Integer.parseInt(response.getBody().asString().trim()); + } + private boolean pollingCheck(String statusQueryGetUri, String expectedState, Set continueStates, diff --git a/internal/durabletask-protobuf/PROTO_SOURCE_COMMIT_HASH b/internal/durabletask-protobuf/PROTO_SOURCE_COMMIT_HASH index fdb90d6a..0ef1ed22 100644 --- a/internal/durabletask-protobuf/PROTO_SOURCE_COMMIT_HASH +++ b/internal/durabletask-protobuf/PROTO_SOURCE_COMMIT_HASH @@ -1 +1 @@ -026329c53fe6363985655857b9ca848ec7238bd2 \ No newline at end of file +1caadbd7ecfdf5f2309acbeac28a3e36d16aa156 \ No newline at end of file diff --git a/internal/durabletask-protobuf/protos/orchestrator_service.proto b/internal/durabletask-protobuf/protos/orchestrator_service.proto index 8ef46a4a..0c34d986 100644 --- a/internal/durabletask-protobuf/protos/orchestrator_service.proto +++ b/internal/durabletask-protobuf/protos/orchestrator_service.proto @@ -822,6 +822,7 @@ message GetWorkItemsRequest { int32 maxConcurrentEntityWorkItems = 3; repeated WorkerCapability capabilities = 10; + WorkItemFilters workItemFilters = 11; } enum WorkerCapability { @@ -844,6 +845,26 @@ enum WorkerCapability { WORKER_CAPABILITY_LARGE_PAYLOADS = 3; } +message WorkItemFilters { + repeated OrchestrationFilter orchestrations = 1; + repeated ActivityFilter activities = 2; + repeated EntityFilter entities = 3; +} + +message OrchestrationFilter { + string name = 1; + repeated string versions = 2; +} + +message ActivityFilter { + string name = 1; + repeated string versions = 2; +} + +message EntityFilter { + string name = 1; +} + message WorkItem { oneof request { OrchestratorRequest orchestratorRequest = 1; diff --git a/samples-azure-functions/src/main/java/com/functions/entities/CounterEntity.java b/samples-azure-functions/src/main/java/com/functions/entities/CounterEntity.java new file mode 100644 index 00000000..f46dd9ff --- /dev/null +++ b/samples-azure-functions/src/main/java/com/functions/entities/CounterEntity.java @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.functions.entities; + +import com.microsoft.durabletask.TaskEntity; +import com.microsoft.durabletask.TaskEntityOperation; + +/** + * A simple counter entity that demonstrates the three main dispatch modes: + *
    + *
  • Entity dispatch (class-based) — extends {@link TaskEntity} directly
  • + *
  • State dispatch (POJO) — dispatches operations to the state object itself
  • + *
  • Manual dispatch (low-level) — uses {@code ITaskEntity} for manual operation routing
  • + *
+ *

+ * This sample mirrors the .NET {@code Counter.cs} sample from + * {@code durabletask-dotnet/samples/AzureFunctionsApp/Entities/}. + */ +public class CounterEntity extends TaskEntity { + + /** + * Adds the given input to the current counter state and returns the new value. + */ + public int add(int input) { + this.state += input; + return this.state; + } + + /** + * Returns the current counter value. + */ + public int get() { + return this.state; + } + + /** + * Resets the counter to zero. + */ + public void reset() { + this.state = 0; + } + + @Override + protected Integer initializeState(TaskEntityOperation operation) { + return 0; + } + + @Override + protected Class getStateType() { + return Integer.class; + } +} diff --git a/samples-azure-functions/src/main/java/com/functions/entities/CounterFunctions.java b/samples-azure-functions/src/main/java/com/functions/entities/CounterFunctions.java new file mode 100644 index 00000000..e4abec93 --- /dev/null +++ b/samples-azure-functions/src/main/java/com/functions/entities/CounterFunctions.java @@ -0,0 +1,207 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.functions.entities; + +import com.microsoft.azure.functions.*; +import com.microsoft.azure.functions.annotation.*; +import com.microsoft.durabletask.*; +import com.microsoft.durabletask.azurefunctions.DurableClientContext; +import com.microsoft.durabletask.azurefunctions.DurableClientInput; +import com.microsoft.durabletask.azurefunctions.DurableEntityTrigger; +import com.microsoft.durabletask.azurefunctions.DurableOrchestrationTrigger; + +import java.util.Optional; + +/** + * Azure Functions for the Counter entity sample. + *

+ * Demonstrates three entity dispatch modes: + *

    + *
  • {@code mode=entity} (default) — dispatches to {@link CounterEntity} ({@code TaskEntity})
  • + *
  • {@code mode=state} — dispatches to {@link StateCounterEntity} (POJO state dispatch)
  • + *
  • {@code mode=manual} — dispatches to {@link ManualCounterEntity} ({@code ITaskEntity})
  • + *
+ *

+ * This mirrors the .NET {@code Counter.cs} and {@code CounterApis} from + * {@code durabletask-dotnet/samples/AzureFunctionsApp/Entities/}. + *

+ * See {@code counters.http} for example HTTP requests. + */ +public class CounterFunctions { + + // ─── Entity trigger functions ─── + + /** + * Entity function for the class-based counter ({@link CounterEntity}). + */ + @FunctionName("Counter") + public String counterEntity( + @DurableEntityTrigger(name = "req") String req) { + return EntityRunner.loadAndRun(req, CounterEntity::new); + } + + /** + * Entity function for the state-dispatch counter ({@link StateCounterEntity}). + */ + @FunctionName("Counter_State") + public String counterStateEntity( + @DurableEntityTrigger(name = "req") String req) { + return EntityRunner.loadAndRun(req, StateCounterEntity::new); + } + + /** + * Entity function for the manual (low-level) counter ({@link ManualCounterEntity}). + */ + @FunctionName("Counter_Manual") + public String counterManualEntity( + @DurableEntityTrigger(name = "req") String req) { + return EntityRunner.loadAndRun(req, ManualCounterEntity::new); + } + + // ─── Orchestration ─── + + /** + * Orchestration that calls the counter entity to add a value and return the result. + */ + @FunctionName("CounterOrchestration") + public int counterOrchestration( + @DurableOrchestrationTrigger(name = "ctx") TaskOrchestrationContext ctx) { + CounterPayload input = ctx.getInput(CounterPayload.class); + return ctx.callEntity(input.entityId, "add", input.addValue, Integer.class).await(); + } + + // ─── HTTP API functions ─── + + /** + * POST /api/counters/{id}/add/{value}?mode={mode} + *

+ * Starts an orchestration that calls the counter entity to add a value. + */ + @FunctionName("Counter_Add") + public HttpResponseMessage counterAdd( + @HttpTrigger(name = "req", methods = {HttpMethod.POST}, + route = "counters/{id}/add/{value}", + authLevel = AuthorizationLevel.ANONYMOUS) HttpRequestMessage> request, + @DurableClientInput(name = "durableContext") DurableClientContext durableContext, + @BindingName("id") String id, + @BindingName("value") int value, + final ExecutionContext context) { + EntityInstanceId entityId = getEntityId(request, id); + DurableTaskClient client = durableContext.getClient(); + CounterPayload payload = new CounterPayload(entityId, value); + String instanceId = client.scheduleNewOrchestrationInstance("CounterOrchestration", payload); + context.getLogger().info("Started CounterOrchestration: " + instanceId); + return durableContext.createCheckStatusResponse(request, instanceId); + } + + /** + * GET /api/counters/{id}?mode={mode} + *

+ * Gets the current state of the counter entity. + */ + @FunctionName("Counter_Get") + public HttpResponseMessage counterGet( + @HttpTrigger(name = "req", methods = {HttpMethod.GET}, + route = "counters/{id}", + authLevel = AuthorizationLevel.ANONYMOUS) HttpRequestMessage> request, + @DurableClientInput(name = "durableContext") DurableClientContext durableContext, + @BindingName("id") String id, + final ExecutionContext context) { + EntityInstanceId entityId = getEntityId(request, id); + DurableTaskClient client = durableContext.getClient(); + + TypedEntityMetadata entity = client.getEntities().getEntityMetadata(entityId, Integer.class); + if (entity == null) { + return request.createResponseBuilder(HttpStatus.NOT_FOUND).build(); + } + + return request.createResponseBuilder(HttpStatus.OK) + .header("Content-Type", "application/json") + .body(entity) + .build(); + } + + /** + * DELETE /api/counters/{id}?mode={mode} + *

+ * Deletes the counter entity using the built-in implicit "delete" operation. + */ + @FunctionName("Counter_Delete") + public HttpResponseMessage counterDelete( + @HttpTrigger(name = "req", methods = {HttpMethod.DELETE}, + route = "counters/{id}", + authLevel = AuthorizationLevel.ANONYMOUS) HttpRequestMessage> request, + @DurableClientInput(name = "durableContext") DurableClientContext durableContext, + @BindingName("id") String id, + final ExecutionContext context) { + EntityInstanceId entityId = getEntityId(request, id); + DurableTaskClient client = durableContext.getClient(); + client.getEntities().signalEntity(entityId, "delete"); + return request.createResponseBuilder(HttpStatus.ACCEPTED).build(); + } + + /** + * POST /api/counters/{id}/reset?mode={mode} + *

+ * Signals the counter entity to reset to zero. + */ + @FunctionName("Counter_Reset") + public HttpResponseMessage counterReset( + @HttpTrigger(name = "req", methods = {HttpMethod.POST}, + route = "counters/{id}/reset", + authLevel = AuthorizationLevel.ANONYMOUS) HttpRequestMessage> request, + @DurableClientInput(name = "durableContext") DurableClientContext durableContext, + @BindingName("id") String id, + final ExecutionContext context) { + EntityInstanceId entityId = getEntityId(request, id); + DurableTaskClient client = durableContext.getClient(); + client.getEntities().signalEntity(entityId, "reset"); + return request.createResponseBuilder(HttpStatus.ACCEPTED).build(); + } + + // ─── Helpers ─── + + /** + * Resolves the entity name based on the {@code mode} query parameter. + */ + private static EntityInstanceId getEntityId(HttpRequestMessage request, String key) { + String mode = request.getQueryParameters().get("mode"); + String name; + if (mode == null) { + name = "counter"; + } else { + switch (mode.toLowerCase()) { + case "1": + case "state": + name = "counter_state"; + break; + case "2": + case "manual": + name = "counter_manual"; + break; + case "0": + case "entity": + default: + name = "counter"; + break; + } + } + return new EntityInstanceId(name, key); + } + + /** + * Payload for the CounterOrchestration. + */ + public static class CounterPayload { + public EntityInstanceId entityId; + public int addValue; + + public CounterPayload() { + } + + public CounterPayload(EntityInstanceId entityId, int addValue) { + this.entityId = entityId; + this.addValue = addValue; + } + } +} diff --git a/samples-azure-functions/src/main/java/com/functions/entities/LifetimeEntity.java b/samples-azure-functions/src/main/java/com/functions/entities/LifetimeEntity.java new file mode 100644 index 00000000..4f18199c --- /dev/null +++ b/samples-azure-functions/src/main/java/com/functions/entities/LifetimeEntity.java @@ -0,0 +1,77 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.functions.entities; + +import com.microsoft.durabletask.TaskEntity; +import com.microsoft.durabletask.TaskEntityOperation; + +import java.util.Random; +import java.util.UUID; + +/** + * Entity that demonstrates the lifecycle of a durable entity. + *

+ * An entity is initialized on the first operation it receives and is considered deleted when + * its state is {@code null} at the end of an operation. This mirrors the .NET + * {@code Lifetime} entity from {@code Lifetime.cs}. + *

+ * Key concepts demonstrated: + *

    + *
  • {@link #initializeState} — customizes the default state for a new entity
  • + *
  • {@link #getAllowStateDispatch} — controls whether operations can be dispatched to the state object
  • + *
  • {@link #customDelete} — shows that deletion can happen from any operation by nulling state
  • + *
  • Implicit {@code delete} — handled by the base class when no matching method is found
  • + *
+ */ +public class LifetimeEntity extends TaskEntity { + + private static final Random RANDOM = new Random(); + + /** + * Returns the current entity state. + */ + public LifetimeState get() { + return this.state; + } + + /** + * No-op operation that simply initializes the entity if it doesn't already exist. + */ + public void init() { + // No-op — just triggers entity initialization via initializeState + } + + /** + * Demonstrates that entity deletion can be accomplished from any operation by nulling out the state. + * The operation doesn't have to be named "delete" — the only requirement for deletion is that + * state is {@code null} when the operation returns. + */ + public void customDelete() { + this.state = null; + } + + /** + * Explicitly handles the "delete" operation. + *

+ * Entities have an implicit "delete" operation when there is no matching "delete" method. + * By explicitly adding a delete method, it overrides the implicit behavior. + * Since state deletion is determined by nulling {@code this.state}, value-types cannot be + * deleted except by the implicit delete (which will still delete them). + */ + public void delete() { + this.state = null; + } + + @Override + protected LifetimeState initializeState(TaskEntityOperation operation) { + // Customizes the default state value for a new entity + return new LifetimeState( + UUID.randomUUID().toString().replace("-", ""), + RANDOM.nextInt(1000)); + } + + @Override + protected Class getStateType() { + return LifetimeState.class; + } +} diff --git a/samples-azure-functions/src/main/java/com/functions/entities/LifetimeFunctions.java b/samples-azure-functions/src/main/java/com/functions/entities/LifetimeFunctions.java new file mode 100644 index 00000000..d014a223 --- /dev/null +++ b/samples-azure-functions/src/main/java/com/functions/entities/LifetimeFunctions.java @@ -0,0 +1,108 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.functions.entities; + +import com.microsoft.azure.functions.*; +import com.microsoft.azure.functions.annotation.*; +import com.microsoft.durabletask.*; +import com.microsoft.durabletask.azurefunctions.DurableClientContext; +import com.microsoft.durabletask.azurefunctions.DurableClientInput; +import com.microsoft.durabletask.azurefunctions.DurableEntityTrigger; + +import java.util.Optional; + +/** + * Azure Functions for the Lifetime entity sample. + *

+ * Demonstrates entity lifecycle: initialization via {@code initializeState}, + * custom deletion (nulling state), and the implicit "delete" operation. + *

+ * This mirrors the .NET {@code Lifetime.cs} and {@code LifetimeApis} from + * {@code durabletask-dotnet/samples/AzureFunctionsApp/Entities/}. + *

+ * See {@code lifetimes.http} for example HTTP requests. + */ +public class LifetimeFunctions { + + // ─── Entity trigger function ─── + + @FunctionName("Lifetime") + public String lifetimeEntity( + @DurableEntityTrigger(name = "req") String req) { + return EntityRunner.loadAndRun(req, LifetimeEntity::new); + } + + // ─── HTTP API functions ─── + + /** + * GET /api/lifetimes/{id} + *

+ * Gets the current state of the lifetime entity. + */ + @FunctionName("Lifetime_Get") + public HttpResponseMessage lifetimeGet( + @HttpTrigger(name = "req", methods = {HttpMethod.GET}, + route = "lifetimes/{id}", + authLevel = AuthorizationLevel.ANONYMOUS) HttpRequestMessage> request, + @DurableClientInput(name = "durableContext") DurableClientContext durableContext, + @BindingName("id") String id, + final ExecutionContext context) { + EntityInstanceId entityId = new EntityInstanceId("Lifetime", id); + DurableTaskClient client = durableContext.getClient(); + + TypedEntityMetadata entity = client.getEntities() + .getEntityMetadata(entityId, LifetimeState.class); + if (entity == null) { + return request.createResponseBuilder(HttpStatus.NOT_FOUND).build(); + } + + return request.createResponseBuilder(HttpStatus.OK) + .header("Content-Type", "application/json") + .body(entity) + .build(); + } + + /** + * PUT /api/lifetimes/{id} + *

+ * Initializes the lifetime entity by sending an "init" signal. + */ + @FunctionName("Lifetime_Init") + public HttpResponseMessage lifetimeInit( + @HttpTrigger(name = "req", methods = {HttpMethod.PUT}, + route = "lifetimes/{id}", + authLevel = AuthorizationLevel.ANONYMOUS) HttpRequestMessage> request, + @DurableClientInput(name = "durableContext") DurableClientContext durableContext, + @BindingName("id") String id, + final ExecutionContext context) { + EntityInstanceId entityId = new EntityInstanceId("Lifetime", id); + DurableTaskClient client = durableContext.getClient(); + client.getEntities().signalEntity(entityId, "init"); + return request.createResponseBuilder(HttpStatus.ACCEPTED).build(); + } + + /** + * DELETE /api/lifetimes/{id}?custom={true|false} + *

+ * Deletes the lifetime entity. If {@code custom=true}, uses the {@code customDelete} + * operation; otherwise uses the standard {@code delete} operation. + */ + @FunctionName("Lifetime_Delete") + public HttpResponseMessage lifetimeDelete( + @HttpTrigger(name = "req", methods = {HttpMethod.DELETE}, + route = "lifetimes/{id}", + authLevel = AuthorizationLevel.ANONYMOUS) HttpRequestMessage> request, + @DurableClientInput(name = "durableContext") DurableClientContext durableContext, + @BindingName("id") String id, + final ExecutionContext context) { + EntityInstanceId entityId = new EntityInstanceId("Lifetime", id); + DurableTaskClient client = durableContext.getClient(); + + String customParam = request.getQueryParameters().get("custom"); + boolean useCustomDelete = "true".equalsIgnoreCase(customParam); + String operation = useCustomDelete ? "customDelete" : "delete"; + + client.getEntities().signalEntity(entityId, operation); + return request.createResponseBuilder(HttpStatus.ACCEPTED).build(); + } +} diff --git a/samples-azure-functions/src/main/java/com/functions/entities/LifetimeState.java b/samples-azure-functions/src/main/java/com/functions/entities/LifetimeState.java new file mode 100644 index 00000000..083f1cd6 --- /dev/null +++ b/samples-azure-functions/src/main/java/com/functions/entities/LifetimeState.java @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.functions.entities; + +/** + * State class for the {@link LifetimeEntity}. + *

+ * This mirrors the .NET {@code MyState} record from {@code Lifetime.cs}. + */ +public class LifetimeState { + private String propA; + private int propB; + + public LifetimeState() { + } + + public LifetimeState(String propA, int propB) { + this.propA = propA; + this.propB = propB; + } + + public String getPropA() { + return propA; + } + + public void setPropA(String propA) { + this.propA = propA; + } + + public int getPropB() { + return propB; + } + + public void setPropB(int propB) { + this.propB = propB; + } + + @Override + public String toString() { + return "LifetimeState{propA='" + propA + "', propB=" + propB + "}"; + } +} diff --git a/samples-azure-functions/src/main/java/com/functions/entities/ManualCounterEntity.java b/samples-azure-functions/src/main/java/com/functions/entities/ManualCounterEntity.java new file mode 100644 index 00000000..ce18eb0d --- /dev/null +++ b/samples-azure-functions/src/main/java/com/functions/entities/ManualCounterEntity.java @@ -0,0 +1,43 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.functions.entities; + +import com.microsoft.durabletask.ITaskEntity; +import com.microsoft.durabletask.TaskEntityOperation; + +import java.util.Map; + +/** + * A low-level counter entity that uses {@link ITaskEntity} with manual operation dispatch. + *

+ * This mirrors the .NET {@code ManualCounter} from {@code Counter.cs} and demonstrates + * how to process entity operations without the convenience of {@link com.microsoft.durabletask.TaskEntity}. + */ +public class ManualCounterEntity implements ITaskEntity { + + @Override + public Object runAsync(TaskEntityOperation operation) { + if (operation.getState().getState(Integer.class) == null) { + operation.getState().setState(0); + } + + switch (operation.getName().toLowerCase()) { + case "add": + int state = operation.getState().getState(Integer.class); + state += operation.getInput(int.class); + operation.getState().setState(state); + return state; + case "reset": + operation.getState().setState(0); + return null; + case "get": + Integer current = operation.getState().getState(Integer.class); + return current != null ? current : 0; + case "delete": + operation.getState().setState(null); + return null; + default: + return null; + } + } +} diff --git a/samples-azure-functions/src/main/java/com/functions/entities/StateCounterEntity.java b/samples-azure-functions/src/main/java/com/functions/entities/StateCounterEntity.java new file mode 100644 index 00000000..0cf2253d --- /dev/null +++ b/samples-azure-functions/src/main/java/com/functions/entities/StateCounterEntity.java @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.functions.entities; + +import com.microsoft.durabletask.TaskEntity; +import com.microsoft.durabletask.TaskEntityOperation; + +/** + * A counter entity that uses state dispatch (POJO pattern). + *

+ * In this pattern, the state object itself is the entity. All operations are dispatched + * to public methods on the {@link StateCounterState} POJO. This mirrors the .NET + * {@code StateCounter} from {@code Counter.cs}. + *

+ * Note the structural difference between {@link CounterEntity} and {@code StateCounterEntity}: + *

    + *
  • {@code CounterEntity}: state is {@code Integer} — serialized as just a number
  • + *
  • {@code StateCounterEntity}: state is {@link StateCounterState} — serialized as + * {@code {"value": number}}
  • + *
+ */ +public class StateCounterEntity extends TaskEntity { + + public StateCounterEntity() { + // Enable state dispatch so operations are forwarded to methods on StateCounterState + this.setAllowStateDispatch(true); + } + + @Override + protected StateCounterState initializeState(TaskEntityOperation operation) { + return new StateCounterState(); + } + + @Override + protected Class getStateType() { + return StateCounterState.class; + } +} diff --git a/samples-azure-functions/src/main/java/com/functions/entities/StateCounterState.java b/samples-azure-functions/src/main/java/com/functions/entities/StateCounterState.java new file mode 100644 index 00000000..20748b60 --- /dev/null +++ b/samples-azure-functions/src/main/java/com/functions/entities/StateCounterState.java @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.functions.entities; + +/** + * POJO state class used by {@link StateCounterEntity} for state dispatch. + *

+ * When using state dispatch, the entire object is serialized/deserialized as the entity state. + * Operations are dispatched to public methods on this class. + */ +public class StateCounterState { + private int value; + + public StateCounterState() { + this.value = 0; + } + + public int getValue() { + return value; + } + + public void setValue(int value) { + this.value = value; + } + + public int add(int input) { + this.value += input; + return this.value; + } + + public int get() { + return this.value; + } + + public void reset() { + this.value = 0; + } +} diff --git a/samples-azure-functions/src/main/java/com/functions/entities/UserEntity.java b/samples-azure-functions/src/main/java/com/functions/entities/UserEntity.java new file mode 100644 index 00000000..383b32de --- /dev/null +++ b/samples-azure-functions/src/main/java/com/functions/entities/UserEntity.java @@ -0,0 +1,86 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.functions.entities; + +import com.microsoft.durabletask.TaskEntity; +import com.microsoft.durabletask.TaskEntityContext; +import com.microsoft.durabletask.TaskEntityOperation; + +/** + * Entity that demonstrates complex state management and interaction with orchestrations. + *

+ * This mirrors the .NET {@code UserEntity} from {@code User.cs} and shows: + *

    + *
  • Setting and updating complex state
  • + *
  • Using {@link TaskEntityContext} to schedule orchestrations from within an entity
  • + *
  • Custom {@link #initializeState} for types without a no-arg constructor pattern
  • + *
  • Implicit "delete" handling from the {@link TaskEntity} base class
  • + *
+ */ +public class UserEntity extends TaskEntity { + + /** + * Sets the user state to the given value. + */ + public void set(UserState user) { + this.state = user; + } + + /** + * Partially updates the user state. Only non-null fields are applied. + */ + public void update(UserUpdate update) { + String newName = update.getName() != null ? update.getName() : this.state.getName(); + int newAge = update.getAge() != null ? update.getAge() : this.state.getAge(); + this.state = new UserState(newName, newAge); + } + + /** + * Starts a greeting orchestration for this user. + *

+ * Demonstrates using {@link TaskEntityContext} to schedule a new orchestration + * from within an entity operation. The context is accessible via {@code this.context}. + * + * @param message optional custom greeting message (may be null) + */ + public void greet(String message) { + if (this.state.getName() == null) { + throw new IllegalStateException("User has not been initialized."); + } + + // Access the TaskEntityContext to schedule an orchestration from within the entity + GreetingInput input = new GreetingInput( + this.state.getName(), this.state.getAge(), message); + this.context.startNewOrchestration("GreetingOrchestration", input); + } + + @Override + protected UserState initializeState(TaskEntityOperation operation) { + // UserState doesn't need special initialization, but this shows + // how to customize default state (mirroring .NET's new User(null!, -1)) + return new UserState(null, -1); + } + + @Override + protected Class getStateType() { + return UserState.class; + } + + /** + * Input payload for the GreetingOrchestration. + */ + public static class GreetingInput { + public String name; + public int age; + public String customMessage; + + public GreetingInput() { + } + + public GreetingInput(String name, int age, String customMessage) { + this.name = name; + this.age = age; + this.customMessage = customMessage; + } + } +} diff --git a/samples-azure-functions/src/main/java/com/functions/entities/UserFunctions.java b/samples-azure-functions/src/main/java/com/functions/entities/UserFunctions.java new file mode 100644 index 00000000..f623cc62 --- /dev/null +++ b/samples-azure-functions/src/main/java/com/functions/entities/UserFunctions.java @@ -0,0 +1,228 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.functions.entities; + +import com.microsoft.azure.functions.*; +import com.microsoft.azure.functions.annotation.*; +import com.microsoft.durabletask.*; +import com.microsoft.durabletask.azurefunctions.DurableActivityTrigger; +import com.microsoft.durabletask.azurefunctions.DurableClientContext; +import com.microsoft.durabletask.azurefunctions.DurableClientInput; +import com.microsoft.durabletask.azurefunctions.DurableEntityTrigger; +import com.microsoft.durabletask.azurefunctions.DurableOrchestrationTrigger; + +import java.util.Optional; + +/** + * Azure Functions for the User entity sample. + *

+ * Demonstrates: + *

    + *
  • Complex entity state (UserState) with set/update operations
  • + *
  • Entity-triggered orchestrations via {@code TaskEntityContext.startNewOrchestration}
  • + *
  • Implicit "delete" operation from the base class
  • + *
+ *

+ * This mirrors the .NET {@code User.cs}, {@code UserApis}, and {@code Greeting.cs} from + * {@code durabletask-dotnet/samples/AzureFunctionsApp/Entities/}. + *

+ * See {@code users.http} for example HTTP requests. + * + *

APIs: + *

    + *
  • Create User: PUT /api/users/{id}?name={name}&age={age}
  • + *
  • Update User: PATCH /api/users/{id}?name={name}&age={age}
  • + *
  • Get User: GET /api/users/{id}
  • + *
  • Delete User: DELETE /api/users/{id}
  • + *
  • Greet User: POST /api/users/{id}/greet?message={message}
  • + *
+ */ +public class UserFunctions { + + // ─── Entity trigger function ─── + + @FunctionName("User") + public String userEntity( + @DurableEntityTrigger(name = "req") String req) { + return EntityRunner.loadAndRun(req, UserEntity::new); + } + + // ─── Greeting orchestration & activity ─── + + /** + * Orchestration that greets a user. Triggered from within the User entity's + * {@code greet} operation via {@code context.startNewOrchestration}. + */ + @FunctionName("GreetingOrchestration") + public UserEntity.GreetingInput greetingOrchestration( + @DurableOrchestrationTrigger(name = "ctx") TaskOrchestrationContext ctx) { + UserEntity.GreetingInput input = ctx.getInput(UserEntity.GreetingInput.class); + return ctx.callActivity("GreetingActivity", input, UserEntity.GreetingInput.class).await(); + } + + /** + * Activity that performs the greeting logic. + */ + @FunctionName("GreetingActivity") + public UserEntity.GreetingInput greetingActivity( + @DurableActivityTrigger(name = "input") UserEntity.GreetingInput input, + final ExecutionContext context) { + String message = input.customMessage != null + ? input.customMessage + : String.format("Hello, %s! You are %d years old.", input.name, input.age); + context.getLogger().info("Greeting: " + message); + return input; + } + + // ─── HTTP API functions ─── + + /** + * PUT /api/users/{id}?name={name}&age={age} + *

+ * Creates or overwrites a user entity. Both name and age are required. + */ + @FunctionName("PutUser") + public HttpResponseMessage putUser( + @HttpTrigger(name = "req", methods = {HttpMethod.PUT}, + route = "users/{id}", + authLevel = AuthorizationLevel.ANONYMOUS) HttpRequestMessage> request, + @DurableClientInput(name = "durableContext") DurableClientContext durableContext, + @BindingName("id") String id, + final ExecutionContext context) { + String name = request.getQueryParameters().get("name"); + String ageStr = request.getQueryParameters().get("age"); + + if (name == null || ageStr == null) { + return request.createResponseBuilder(HttpStatus.BAD_REQUEST) + .body("Both name and age must be provided.") + .build(); + } + + int age; + try { + age = Integer.parseInt(ageStr); + } catch (NumberFormatException e) { + return request.createResponseBuilder(HttpStatus.BAD_REQUEST) + .body("Age must be a valid integer.") + .build(); + } + + if (age < 0) { + return request.createResponseBuilder(HttpStatus.BAD_REQUEST) + .body("Age must be a positive integer.") + .build(); + } + + EntityInstanceId entityId = new EntityInstanceId("User", id); + DurableTaskClient client = durableContext.getClient(); + client.getEntities().signalEntity(entityId, "set", new UserState(name, age)); + return request.createResponseBuilder(HttpStatus.ACCEPTED).build(); + } + + /** + * PATCH /api/users/{id}?name={name}&age={age} + *

+ * Partially updates a user entity. Either name or age can be updated. + */ + @FunctionName("PatchUser") + public HttpResponseMessage patchUser( + @HttpTrigger(name = "req", methods = {HttpMethod.PATCH}, + route = "users/{id}", + authLevel = AuthorizationLevel.ANONYMOUS) HttpRequestMessage> request, + @DurableClientInput(name = "durableContext") DurableClientContext durableContext, + @BindingName("id") String id, + final ExecutionContext context) { + String name = request.getQueryParameters().get("name"); + String ageStr = request.getQueryParameters().get("age"); + + Integer age = null; + if (ageStr != null) { + try { + age = Integer.parseInt(ageStr); + } catch (NumberFormatException e) { + return request.createResponseBuilder(HttpStatus.BAD_REQUEST) + .body("Age must be a valid integer.") + .build(); + } + if (age < 0) { + return request.createResponseBuilder(HttpStatus.BAD_REQUEST) + .body("Age must be a positive integer.") + .build(); + } + } + + EntityInstanceId entityId = new EntityInstanceId("User", id); + DurableTaskClient client = durableContext.getClient(); + client.getEntities().signalEntity(entityId, "update", new UserUpdate(name, age)); + return request.createResponseBuilder(HttpStatus.ACCEPTED).build(); + } + + /** + * GET /api/users/{id} + *

+ * Gets the current state of the user entity. + */ + @FunctionName("GetUser") + public HttpResponseMessage getUser( + @HttpTrigger(name = "req", methods = {HttpMethod.GET}, + route = "users/{id}", + authLevel = AuthorizationLevel.ANONYMOUS) HttpRequestMessage> request, + @DurableClientInput(name = "durableContext") DurableClientContext durableContext, + @BindingName("id") String id, + final ExecutionContext context) { + EntityInstanceId entityId = new EntityInstanceId("User", id); + DurableTaskClient client = durableContext.getClient(); + + TypedEntityMetadata entity = client.getEntities() + .getEntityMetadata(entityId, UserState.class); + if (entity == null) { + return request.createResponseBuilder(HttpStatus.NOT_FOUND).build(); + } + + return request.createResponseBuilder(HttpStatus.OK) + .header("Content-Type", "application/json") + .body(entity.getState()) + .build(); + } + + /** + * DELETE /api/users/{id} + *

+ * Deletes the user entity using the implicit "delete" operation from {@code TaskEntity}. + */ + @FunctionName("DeleteUser") + public HttpResponseMessage deleteUser( + @HttpTrigger(name = "req", methods = {HttpMethod.DELETE}, + route = "users/{id}", + authLevel = AuthorizationLevel.ANONYMOUS) HttpRequestMessage> request, + @DurableClientInput(name = "durableContext") DurableClientContext durableContext, + @BindingName("id") String id, + final ExecutionContext context) { + EntityInstanceId entityId = new EntityInstanceId("User", id); + DurableTaskClient client = durableContext.getClient(); + // Even though UserEntity does not have a 'delete' method, the base class TaskEntity handles it + client.getEntities().signalEntity(entityId, "delete"); + return request.createResponseBuilder(HttpStatus.ACCEPTED).build(); + } + + /** + * POST /api/users/{id}/greet?message={message} + *

+ * Signals the user entity to initiate a greeting orchestration. + */ + @FunctionName("GreetUser") + public HttpResponseMessage greetUser( + @HttpTrigger(name = "req", methods = {HttpMethod.POST}, + route = "users/{id}/greet", + authLevel = AuthorizationLevel.ANONYMOUS) HttpRequestMessage> request, + @DurableClientInput(name = "durableContext") DurableClientContext durableContext, + @BindingName("id") String id, + final ExecutionContext context) { + EntityInstanceId entityId = new EntityInstanceId("User", id); + DurableTaskClient client = durableContext.getClient(); + + String message = request.getQueryParameters().get("message"); + client.getEntities().signalEntity(entityId, "greet", message); + return request.createResponseBuilder(HttpStatus.ACCEPTED).build(); + } +} diff --git a/samples-azure-functions/src/main/java/com/functions/entities/UserState.java b/samples-azure-functions/src/main/java/com/functions/entities/UserState.java new file mode 100644 index 00000000..c280b766 --- /dev/null +++ b/samples-azure-functions/src/main/java/com/functions/entities/UserState.java @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.functions.entities; + +/** + * State class for the {@link UserEntity}. + *

+ * This mirrors the .NET {@code User} record from {@code User.cs}. + */ +public class UserState { + private String name; + private int age; + + public UserState() { + } + + public UserState(String name, int age) { + this.name = name; + this.age = age; + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public int getAge() { + return age; + } + + public void setAge(int age) { + this.age = age; + } + + @Override + public String toString() { + return "UserState{name='" + name + "', age=" + age + "}"; + } +} diff --git a/samples-azure-functions/src/main/java/com/functions/entities/UserUpdate.java b/samples-azure-functions/src/main/java/com/functions/entities/UserUpdate.java new file mode 100644 index 00000000..a4bbe969 --- /dev/null +++ b/samples-azure-functions/src/main/java/com/functions/entities/UserUpdate.java @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.functions.entities; + +/** + * Partial update DTO for the User entity. + *

+ * Fields that are {@code null} are not updated. This mirrors the .NET {@code UserUpdate} record. + */ +public class UserUpdate { + private String name; + private Integer age; + + public UserUpdate() { + } + + public UserUpdate(String name, Integer age) { + this.name = name; + this.age = age; + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public Integer getAge() { + return age; + } + + public void setAge(Integer age) { + this.age = age; + } +} diff --git a/samples-azure-functions/src/main/java/com/functions/entities/counters.http b/samples-azure-functions/src/main/java/com/functions/entities/counters.http new file mode 100644 index 00000000..acc51263 --- /dev/null +++ b/samples-azure-functions/src/main/java/com/functions/entities/counters.http @@ -0,0 +1,22 @@ +@host = http://localhost:7071/api +@mode = 0 + +// The counter example shows 3 different entity dispatch modes. +// The mode query string controls this: +// mode=0 or mode=entity (default) - dispatch to "counter" entity (TaskEntity) +// mode=1 or mode=state - dispatch to "counter_state" entity (POJO state dispatch) +// mode=2 or mode=manual - dispatch to "counter_manual" entity (ITaskEntity) + +POST {{host}}/counters/1/add/10?mode={{mode}} + +### + +GET {{host}}/counters/1?mode={{mode}} + +### + +POST {{host}}/counters/1/reset?mode={{mode}} + +### + +DELETE {{host}}/counters/1?mode={{mode}} diff --git a/samples-azure-functions/src/main/java/com/functions/entities/lifetimes.http b/samples-azure-functions/src/main/java/com/functions/entities/lifetimes.http new file mode 100644 index 00000000..b1bbaba6 --- /dev/null +++ b/samples-azure-functions/src/main/java/com/functions/entities/lifetimes.http @@ -0,0 +1,15 @@ +@host = http://localhost:7071/api + +PUT {{host}}/lifetimes/1 + +### + +GET {{host}}/lifetimes/1 + +### + +DELETE {{host}}/lifetimes/1 + +### + +DELETE {{host}}/lifetimes/1?custom=true diff --git a/samples-azure-functions/src/main/java/com/functions/entities/users.http b/samples-azure-functions/src/main/java/com/functions/entities/users.http new file mode 100644 index 00000000..5767dd46 --- /dev/null +++ b/samples-azure-functions/src/main/java/com/functions/entities/users.http @@ -0,0 +1,19 @@ +@host = http://localhost:7071/api + +PUT {{host}}/users/1?name=John&age=21 + +### + +PATCH {{host}}/users/1?age=22 + +### + +GET {{host}}/users/1 + +### + +POST {{host}}/users/1/greet?message=custom greeting + +### + +DELETE {{host}}/users/1 diff --git a/samples/src/main/java/io/durabletask/samples/BankAccountSample.java b/samples/src/main/java/io/durabletask/samples/BankAccountSample.java new file mode 100644 index 00000000..699989e0 --- /dev/null +++ b/samples/src/main/java/io/durabletask/samples/BankAccountSample.java @@ -0,0 +1,184 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package io.durabletask.samples; + +import com.microsoft.durabletask.*; + +import java.io.IOException; +import java.time.Duration; +import java.util.Arrays; +import java.util.concurrent.TimeoutException; + +/** + * Sample demonstrating bank account entities with an atomic transfer orchestration. + *

+ * This sample shows two key features: + *

    + *
  • Entity operations: deposit, withdraw, and get on a bank account entity
  • + *
  • Entity locking: using {@code lockEntities} for atomic multi-entity operations
  • + *
+ *

+ * The transfer orchestration locks both the source and destination accounts, then performs + * a withdraw and deposit atomically. This prevents concurrent operations from seeing + * inconsistent state. + *

+ * Usage: Run this sample with a Durable Task sidecar listening on localhost:4001. + *

+ *   docker run -d -p 4001:4001 durabletask-sidecar
+ *   ./gradlew :samples:run -PmainClass=io.durabletask.samples.BankAccountSample
+ * 
+ */ +final class BankAccountSample { + + public static void main(String[] args) throws IOException, InterruptedException, TimeoutException { + // Build the worker with bank account entity and transfer orchestration + DurableTaskGrpcWorker worker = new DurableTaskGrpcWorkerBuilder() + .addEntity("BankAccount", BankAccountEntity::new) + .addOrchestration(new TaskOrchestrationFactory() { + @Override + public String getName() { return "SetupAccounts"; } + + @Override + public TaskOrchestration create() { + return ctx -> { + EntityInstanceId accountA = new EntityInstanceId("BankAccount", "account-A"); + EntityInstanceId accountB = new EntityInstanceId("BankAccount", "account-B"); + + // Initialize both accounts with a deposit + ctx.signalEntity(accountA, "deposit", 1000.0); + ctx.signalEntity(accountB, "deposit", 500.0); + + ctx.complete("Accounts initialized"); + }; + } + }) + .addOrchestration(new TaskOrchestrationFactory() { + @Override + public String getName() { return "TransferFunds"; } + + @Override + public TaskOrchestration create() { + return ctx -> { + TransferRequest request = ctx.getInput(TransferRequest.class); + EntityInstanceId source = new EntityInstanceId("BankAccount", request.sourceAccount); + EntityInstanceId dest = new EntityInstanceId("BankAccount", request.destAccount); + + // Lock both accounts to ensure atomic transfer + try (AutoCloseable lock = ctx.lockEntities(Arrays.asList(source, dest)).await()) { + // Withdraw from source + double sourceBalance = ctx.callEntity( + source, "withdraw", request.amount, Double.class).await(); + + // Deposit to destination + double destBalance = ctx.callEntity( + dest, "deposit", request.amount, Double.class).await(); + + String result = String.format( + "Transferred %.2f from %s (balance: %.2f) to %s (balance: %.2f)", + request.amount, request.sourceAccount, sourceBalance, + request.destAccount, destBalance); + ctx.complete(result); + } catch (Exception e) { + ctx.complete("Transfer failed: " + e.getMessage()); + } + }; + } + }) + .build(); + + worker.start(); + System.out.println("Worker started. BankAccount entity and TransferFunds orchestration registered."); + + DurableTaskClient client = new DurableTaskGrpcClientBuilder().build(); + + // Step 1: Initialize accounts + String setupId = client.scheduleNewOrchestrationInstance("SetupAccounts"); + System.out.printf("Setting up accounts (instance: %s)...%n", setupId); + client.waitForInstanceCompletion(setupId, Duration.ofSeconds(30), true); + System.out.println("Accounts initialized: A=$1000, B=$500"); + + // Step 2: Transfer $250 from account A to account B + TransferRequest transfer = new TransferRequest("account-A", "account-B", 250.0); + String transferId = client.scheduleNewOrchestrationInstance( + "TransferFunds", + new NewOrchestrationInstanceOptions().setInput(transfer)); + System.out.printf("Transferring funds (instance: %s)...%n", transferId); + + OrchestrationMetadata result = client.waitForInstanceCompletion( + transferId, Duration.ofSeconds(30), true); + System.out.printf("Transfer result: %s%n", result.readOutputAs(String.class)); + + // Step 3: Query final account balances + EntityMetadata accountA = client.getEntities().getEntityMetadata( + new EntityInstanceId("BankAccount", "account-A"), true); + EntityMetadata accountB = client.getEntities().getEntityMetadata( + new EntityInstanceId("BankAccount", "account-B"), true); + + if (accountA != null) { + System.out.printf("Account A balance: $%.2f%n", accountA.readStateAs(Double.class)); + } + if (accountB != null) { + System.out.printf("Account B balance: $%.2f%n", accountB.readStateAs(Double.class)); + } + + worker.stop(); + } + + /** + * A bank account entity that stores a balance and supports deposit, withdraw, and get operations. + */ + public static class BankAccountEntity extends TaskEntity { + + public double deposit(double amount) { + if (amount <= 0) { + throw new IllegalArgumentException("Deposit amount must be positive."); + } + this.state += amount; + return this.state; + } + + public double withdraw(double amount) { + if (amount <= 0) { + throw new IllegalArgumentException("Withdrawal amount must be positive."); + } + if (amount > this.state) { + throw new IllegalStateException( + String.format("Insufficient funds. Balance: %.2f, Requested: %.2f", this.state, amount)); + } + this.state -= amount; + return this.state; + } + + public double get() { + return this.state; + } + + @Override + protected Double initializeState(TaskEntityOperation operation) { + return 0.0; + } + + @Override + protected Class getStateType() { + return Double.class; + } + } + + /** + * Represents a request to transfer funds between two accounts. + */ + public static class TransferRequest { + public String sourceAccount; + public String destAccount; + public double amount; + + // Required for deserialization + public TransferRequest() {} + + public TransferRequest(String sourceAccount, String destAccount, double amount) { + this.sourceAccount = sourceAccount; + this.destAccount = destAccount; + this.amount = amount; + } + } +} diff --git a/samples/src/main/java/io/durabletask/samples/CounterEntitySample.java b/samples/src/main/java/io/durabletask/samples/CounterEntitySample.java new file mode 100644 index 00000000..51d472e6 --- /dev/null +++ b/samples/src/main/java/io/durabletask/samples/CounterEntitySample.java @@ -0,0 +1,111 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package io.durabletask.samples; + +import com.microsoft.durabletask.*; + +import java.io.IOException; +import java.time.Duration; +import java.util.concurrent.TimeoutException; + +/** + * Sample demonstrating a simple counter durable entity. + *

+ * The counter entity supports three operations: + *

    + *
  • {@code add} — adds an integer amount to the counter
  • + *
  • {@code reset} — resets the counter to zero
  • + *
  • {@code get} — returns the current counter value
  • + *
+ *

+ * Usage: Run this sample with a Durable Task sidecar listening on localhost:4001. + *

+ *   docker run -d -p 4001:4001 durabletask-sidecar
+ *   ./gradlew :samples:run -PmainClass=io.durabletask.samples.CounterEntitySample
+ * 
+ */ +final class CounterEntitySample { + + public static void main(String[] args) throws IOException, InterruptedException, TimeoutException { + // Build the worker with the counter entity registered + DurableTaskGrpcWorker worker = new DurableTaskGrpcWorkerBuilder() + .addEntity("Counter", CounterEntity::new) + .addOrchestration(new TaskOrchestrationFactory() { + @Override + public String getName() { return "CounterOrchestration"; } + + @Override + public TaskOrchestration create() { + return ctx -> { + EntityInstanceId counterId = new EntityInstanceId("Counter", "myCounter"); + + // Signal entity to add 5 + ctx.signalEntity(counterId, "add", 5); + // Signal entity to add 10 + ctx.signalEntity(counterId, "add", 10); + + // Call entity to get the current value + int value = ctx.callEntity(counterId, "get", Integer.class).await(); + + ctx.complete(value); + }; + } + }) + .build(); + + worker.start(); + System.out.println("Worker started. Counter entity registered."); + + // Use the client to schedule an orchestration that interacts with the entity + DurableTaskClient client = new DurableTaskGrpcClientBuilder().build(); + + String instanceId = client.scheduleNewOrchestrationInstance("CounterOrchestration"); + System.out.printf("Started orchestration: %s%n", instanceId); + + OrchestrationMetadata result = client.waitForInstanceCompletion( + instanceId, Duration.ofSeconds(30), true); + System.out.printf("Orchestration completed: %s%n", result); + System.out.printf("Counter value: %s%n", result.readOutputAs(Integer.class)); + + // Query entity state directly + EntityMetadata entityMetadata = client.getEntities().getEntityMetadata( + new EntityInstanceId("Counter", "myCounter"), true); + if (entityMetadata != null) { + System.out.printf("Entity state: %s%n", entityMetadata.readStateAs(Integer.class)); + } + + // Signal the entity to reset + client.getEntities().signalEntity(new EntityInstanceId("Counter", "myCounter"), "reset"); + System.out.println("Sent reset signal to counter entity."); + + worker.stop(); + } + + /** + * A simple counter entity that stores an integer and supports add, reset, and get operations. + */ + public static class CounterEntity extends TaskEntity { + + public void add(int amount) { + this.state += amount; + } + + public void reset() { + this.state = 0; + } + + public int get() { + return this.state; + } + + @Override + protected Integer initializeState(TaskEntityOperation operation) { + return 0; + } + + @Override + protected Class getStateType() { + return Integer.class; + } + } +} diff --git a/samples/src/main/java/io/durabletask/samples/EntityCommunicationSample.java b/samples/src/main/java/io/durabletask/samples/EntityCommunicationSample.java new file mode 100644 index 00000000..3271a99f --- /dev/null +++ b/samples/src/main/java/io/durabletask/samples/EntityCommunicationSample.java @@ -0,0 +1,200 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package io.durabletask.samples; + +import com.microsoft.durabletask.*; + +import java.io.IOException; +import java.time.Duration; +import java.util.concurrent.TimeoutException; + +/** + * Sample demonstrating entity-to-entity signaling and entity-started orchestrations. + *

+ * This sample shows two advanced entity communication patterns: + *

    + *
  • Entity-to-entity signaling: An entity uses {@link TaskEntityContext#signalEntity} + * to send a fire-and-forget message to another entity.
  • + *
  • Entity starting an orchestration: An entity uses + * {@link TaskEntityContext#startNewOrchestration} to kick off an orchestration when + * a threshold condition is met.
  • + *
+ *

+ * Scenario: A sensor entity receives temperature readings. Each reading is forwarded to an + * aggregator entity that tracks the average. When the aggregator detects a reading above a + * threshold, it starts an alert orchestration. + *

+ * Usage: Run this sample with a Durable Task sidecar listening on localhost:4001. + *

+ *   docker run -d -p 4001:4001 durabletask-sidecar
+ *   ./gradlew :samples:run -PmainClass=io.durabletask.samples.EntityCommunicationSample
+ * 
+ */ +final class EntityCommunicationSample { + + public static void main(String[] args) throws IOException, InterruptedException, TimeoutException { + DurableTaskGrpcWorker worker = new DurableTaskGrpcWorkerBuilder() + .addEntity("Sensor", SensorEntity::new) + .addEntity("Aggregator", AggregatorEntity::new) + .addOrchestration(new TaskOrchestrationFactory() { + @Override + public String getName() { return "TemperatureAlert"; } + + @Override + public TaskOrchestration create() { + return ctx -> { + double temperature = ctx.getInput(Double.class); + String message = String.format( + "ALERT: Temperature %.1f°C exceeds threshold!", temperature); + ctx.complete(message); + }; + } + }) + .addOrchestration(new TaskOrchestrationFactory() { + @Override + public String getName() { return "SendReadings"; } + + @Override + public TaskOrchestration create() { + return ctx -> { + EntityInstanceId sensorId = new EntityInstanceId("Sensor", "sensor-1"); + + // Send several temperature readings to the sensor + double[] readings = { 22.5, 24.0, 26.5, 35.0, 23.0 }; + for (double reading : readings) { + ctx.signalEntity(sensorId, "record", reading); + } + + ctx.complete("Sent " + readings.length + " readings"); + }; + } + }) + .build(); + + worker.start(); + System.out.println("Worker started. Sensor and Aggregator entities registered."); + + DurableTaskClient client = new DurableTaskGrpcClientBuilder().build(); + + // Send readings through an orchestration + String instanceId = client.scheduleNewOrchestrationInstance("SendReadings"); + OrchestrationMetadata result = client.waitForInstanceCompletion( + instanceId, Duration.ofSeconds(30), true); + System.out.printf("Orchestration result: %s%n", result.readOutputAs(String.class)); + + // Wait for entity processing and potential alert orchestration + Thread.sleep(5000); + + // Check final entity states + EntityMetadata sensor = client.getEntities().getEntityMetadata( + new EntityInstanceId("Sensor", "sensor-1"), true); + if (sensor != null) { + System.out.printf("Sensor state: %s%n", sensor.readStateAs(SensorState.class)); + } + + EntityMetadata aggregator = client.getEntities().getEntityMetadata( + new EntityInstanceId("Aggregator", "sensor-1"), true); + if (aggregator != null) { + AggregatorState aggState = aggregator.readStateAs(AggregatorState.class); + System.out.printf("Aggregator state: count=%d, sum=%.1f, avg=%.1f%n", + aggState.count, aggState.sum, aggState.getAverage()); + } + + worker.stop(); + } + + // ---- State classes ---- + + /** + * State for the sensor entity — tracks the last recorded temperature. + */ + public static class SensorState { + public double lastReading; + public int totalReadings; + + @Override + public String toString() { + return String.format("SensorState{lastReading=%.1f, totalReadings=%d}", lastReading, totalReadings); + } + } + + /** + * State for the aggregator entity — tracks sum and count for computing averages. + */ + public static class AggregatorState { + public double sum; + public int count; + + public double getAverage() { + return count > 0 ? sum / count : 0; + } + } + + // ---- Entity implementations ---- + + /** + * Sensor entity that records temperature readings and forwards them to an aggregator. + *

+ * Demonstrates entity-to-entity signaling via {@link TaskEntityContext#signalEntity}. + */ + public static class SensorEntity extends TaskEntity { + + /** + * Records a temperature reading and forwards it to the aggregator entity. + */ + public void record(double temperature, TaskEntityContext ctx) { + this.state.lastReading = temperature; + this.state.totalReadings++; + + // Entity-to-entity signaling: forward the reading to the aggregator + EntityInstanceId aggregatorId = new EntityInstanceId("Aggregator", ctx.getId().getKey()); + ctx.signalEntity(aggregatorId, "addReading", temperature); + } + + @Override + protected SensorState initializeState(TaskEntityOperation operation) { + return new SensorState(); + } + + @Override + protected Class getStateType() { + return SensorState.class; + } + } + + /** + * Aggregator entity that computes running averages and starts an alert orchestration + * when a reading exceeds a threshold. + *

+ * Demonstrates entity starting an orchestration via {@link TaskEntityContext#startNewOrchestration}. + */ + public static class AggregatorEntity extends TaskEntity { + private static final double ALERT_THRESHOLD = 30.0; + + /** + * Adds a reading to the running total. If the reading exceeds the threshold, + * starts an alert orchestration. + */ + public void addReading(double reading, TaskEntityContext ctx) { + this.state.sum += reading; + this.state.count++; + + if (reading > ALERT_THRESHOLD) { + // Entity starting an orchestration when a threshold is breached + String orchestrationId = ctx.startNewOrchestration("TemperatureAlert", reading); + System.out.printf("Aggregator started alert orchestration: %s (reading=%.1f)%n", + orchestrationId, reading); + } + } + + @Override + protected AggregatorState initializeState(TaskEntityOperation operation) { + return new AggregatorState(); + } + + @Override + protected Class getStateType() { + return AggregatorState.class; + } + } +} diff --git a/samples/src/main/java/io/durabletask/samples/EntityQuerySample.java b/samples/src/main/java/io/durabletask/samples/EntityQuerySample.java new file mode 100644 index 00000000..c1c83a5e --- /dev/null +++ b/samples/src/main/java/io/durabletask/samples/EntityQuerySample.java @@ -0,0 +1,133 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package io.durabletask.samples; + +import com.microsoft.durabletask.*; + +import java.io.IOException; +import java.time.Duration; +import java.util.concurrent.TimeoutException; + +/** + * Sample demonstrating entity querying and storage cleanup. + *

+ * This sample shows how to use the client management APIs: + *

    + *
  • {@link DurableEntityClient#queryEntities(EntityQuery)} — query entities with filters and pagination
  • + *
  • {@link DurableEntityClient#cleanEntityStorage(CleanEntityStorageRequest)} — remove empty entities + * and release orphaned locks
  • + *
+ *

+ * The sample creates several counter entities, queries them by name prefix, then deletes + * them and cleans up the empty entity storage. + *

+ * Usage: Run this sample with a Durable Task sidecar listening on localhost:4001. + *

+ *   docker run -d -p 4001:4001 durabletask-sidecar
+ *   ./gradlew :samples:run -PmainClass=io.durabletask.samples.EntityQuerySample
+ * 
+ */ +final class EntityQuerySample { + + public static void main(String[] args) throws IOException, InterruptedException, TimeoutException { + // Build the worker with a simple counter entity + DurableTaskGrpcWorker worker = new DurableTaskGrpcWorkerBuilder() + .addEntity("Counter", CounterEntity::new) + .addOrchestration(new TaskOrchestrationFactory() { + @Override + public String getName() { return "CreateCounters"; } + + @Override + public TaskOrchestration create() { + return ctx -> { + // Create several counter entities with different keys + for (int i = 1; i <= 5; i++) { + EntityInstanceId counterId = new EntityInstanceId("Counter", "counter-" + i); + ctx.signalEntity(counterId, "add", i * 10); + } + ctx.complete("Created 5 counters"); + }; + } + }) + .build(); + + worker.start(); + System.out.println("Worker started. Counter entity registered."); + + DurableTaskClient client = new DurableTaskGrpcClientBuilder().build(); + + // Step 1: Create several counter entities via an orchestration + String instanceId = client.scheduleNewOrchestrationInstance("CreateCounters"); + client.waitForInstanceCompletion(instanceId, Duration.ofSeconds(30), true); + System.out.println("Created 5 counter entities."); + + // Step 2: Query all Counter entities (filter by entity name prefix) + System.out.println("\n--- Querying all Counter entities ---"); + EntityQuery query = new EntityQuery() + .setInstanceIdStartsWith("Counter") // filters to @counter prefix + .setIncludeState(true) + .setPageSize(3); // use small page size to demonstrate pagination + + String continuationToken = null; + int pageNumber = 0; + do { + if (continuationToken != null) { + query.setContinuationToken(continuationToken); + } + + EntityQueryResult result = client.getEntities().queryEntities(query); + pageNumber++; + System.out.printf("Page %d: %d entities%n", pageNumber, result.getEntities().size()); + + for (EntityMetadata entity : result.getEntities()) { + EntityInstanceId entityId = entity.getEntityInstanceId(); + Integer state = entity.readStateAs(Integer.class); + System.out.printf(" %s/%s = %d (lastModified: %s)%n", + entityId.getName(), entityId.getKey(), state, entity.getLastModifiedTime()); + } + + continuationToken = result.getContinuationToken(); + } while (continuationToken != null); + + // Step 3: Delete all counter entities by signaling the implicit "delete" operation + System.out.println("\n--- Deleting all counter entities ---"); + for (int i = 1; i <= 5; i++) { + EntityInstanceId counterId = new EntityInstanceId("Counter", "counter-" + i); + client.getEntities().signalEntity(counterId, "delete"); + } + // Give time for delete signals to be processed + Thread.sleep(3000); + + // Step 4: Clean entity storage to remove the now-empty entities + System.out.println("\n--- Cleaning entity storage ---"); + CleanEntityStorageRequest cleanRequest = new CleanEntityStorageRequest() + .setRemoveEmptyEntities(true) + .setReleaseOrphanedLocks(true); + + CleanEntityStorageResult cleanResult = client.getEntities().cleanEntityStorage(cleanRequest); + System.out.printf("Cleaned storage: %d empty entities removed, %d orphaned locks released%n", + cleanResult.getEmptyEntitiesRemoved(), + cleanResult.getOrphanedLocksReleased()); + + // Step 5: Verify entities are gone + EntityQueryResult afterClean = client.getEntities().queryEntities( + new EntityQuery().setInstanceIdStartsWith("Counter")); + System.out.printf("Entities remaining after cleanup: %d%n", afterClean.getEntities().size()); + + worker.stop(); + } + + /** + * A simple counter entity (reused for this sample). + */ + public static class CounterEntity extends TaskEntity { + public void add(int amount) { this.state += amount; } + public int get() { return this.state; } + + @Override + protected Integer initializeState(TaskEntityOperation operation) { return 0; } + + @Override + protected Class getStateType() { return Integer.class; } + } +} diff --git a/samples/src/main/java/io/durabletask/samples/EntityTimeoutSample.java b/samples/src/main/java/io/durabletask/samples/EntityTimeoutSample.java new file mode 100644 index 00000000..698f6aec --- /dev/null +++ b/samples/src/main/java/io/durabletask/samples/EntityTimeoutSample.java @@ -0,0 +1,163 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package io.durabletask.samples; + +import com.microsoft.durabletask.*; + +import java.io.IOException; +import java.time.Duration; +import java.time.Instant; +import java.util.concurrent.TimeoutException; + +/** + * Sample demonstrating entity call timeouts and scheduled signals. + *

+ * This sample shows two features: + *

    + *
  • Call timeouts: Using {@link CallEntityOptions#setTimeout(Duration)} to set a deadline + * on a {@code callEntity} call. If the entity doesn't respond in time, the orchestration + * receives a {@link TaskCanceledException}.
  • + *
  • Scheduled signals: Using {@link SignalEntityOptions#setScheduledTime(Instant)} to + * schedule a signal for delivery at a future time (e.g., a reminder or expiration pattern).
  • + *
+ *

+ * Usage: Run this sample with a Durable Task sidecar listening on localhost:4001. + *

+ *   docker run -d -p 4001:4001 durabletask-sidecar
+ *   ./gradlew :samples:run -PmainClass=io.durabletask.samples.EntityTimeoutSample
+ * 
+ */ +final class EntityTimeoutSample { + + public static void main(String[] args) throws IOException, InterruptedException, TimeoutException { + DurableTaskGrpcWorker worker = new DurableTaskGrpcWorkerBuilder() + .addEntity("SlowCounter", SlowCounterEntity::new) + .addOrchestration(new TaskOrchestrationFactory() { + @Override + public String getName() { return "CallWithTimeout"; } + + @Override + public TaskOrchestration create() { + return ctx -> { + EntityInstanceId counterId = new EntityInstanceId("SlowCounter", "myCounter"); + + // First: signal entity to set up initial state + ctx.signalEntity(counterId, "add", 100); + + // Call entity with a generous timeout — should succeed + try { + CallEntityOptions options = new CallEntityOptions() + .setTimeout(Duration.ofSeconds(30)); + int value = ctx.callEntity( + counterId, "get", null, Integer.class, options).await(); + ctx.complete("Entity value: " + value); + } catch (TaskCanceledException e) { + ctx.complete("Call timed out: " + e.getMessage()); + } + }; + } + }) + .addOrchestration(new TaskOrchestrationFactory() { + @Override + public String getName() { return "CallWithShortTimeout"; } + + @Override + public TaskOrchestration create() { + return ctx -> { + EntityInstanceId counterId = new EntityInstanceId("SlowCounter", "timeout-test"); + + // Call entity with a very short timeout — demonstrates timeout handling + try { + CallEntityOptions options = new CallEntityOptions() + .setTimeout(Duration.ofMillis(1)); + int value = ctx.callEntity( + counterId, "get", null, Integer.class, options).await(); + ctx.complete("Got value: " + value); + } catch (TaskCanceledException e) { + // The orchestration can gracefully handle the timeout + ctx.complete("Handled timeout gracefully: " + e.getMessage()); + } + }; + } + }) + .build(); + + worker.start(); + System.out.println("Worker started. SlowCounter entity registered."); + + DurableTaskClient client = new DurableTaskGrpcClientBuilder().build(); + + // --- Demo 1: Successful call with generous timeout --- + System.out.println("\n--- Demo 1: callEntity with generous timeout ---"); + String instanceId1 = client.scheduleNewOrchestrationInstance("CallWithTimeout"); + OrchestrationMetadata result1 = client.waitForInstanceCompletion( + instanceId1, Duration.ofSeconds(60), true); + System.out.printf("Result: %s%n", result1.readOutputAs(String.class)); + + // --- Demo 2: Call with very short timeout --- + System.out.println("\n--- Demo 2: callEntity with short timeout ---"); + String instanceId2 = client.scheduleNewOrchestrationInstance("CallWithShortTimeout"); + OrchestrationMetadata result2 = client.waitForInstanceCompletion( + instanceId2, Duration.ofSeconds(30), true); + System.out.printf("Result: %s%n", result2.readOutputAs(String.class)); + + // --- Demo 3: Scheduled signal --- + System.out.println("\n--- Demo 3: Scheduled signal (5 seconds in the future) ---"); + EntityInstanceId counterId = new EntityInstanceId("SlowCounter", "scheduled-test"); + + // Initialize the counter + client.getEntities().signalEntity(counterId, "add", 50); + Thread.sleep(2000); + + // Schedule a signal to add 25 more, delivered 5 seconds from now + Instant scheduledTime = Instant.now().plusSeconds(5); + SignalEntityOptions signalOptions = new SignalEntityOptions() + .setScheduledTime(scheduledTime); + client.getEntities().signalEntity(counterId, "add", 25, signalOptions); + System.out.printf("Scheduled 'add 25' signal for %s%n", scheduledTime); + + // Check state before the scheduled signal is delivered + EntityMetadata beforeMeta = client.getEntities().getEntityMetadata(counterId, true); + if (beforeMeta != null) { + System.out.printf("State before scheduled signal: %d%n", + beforeMeta.readStateAs(Integer.class)); + } + + // Wait for the scheduled signal to be delivered + System.out.println("Waiting for scheduled signal delivery..."); + Thread.sleep(7000); + + // Check state after + EntityMetadata afterMeta = client.getEntities().getEntityMetadata(counterId, true); + if (afterMeta != null) { + System.out.printf("State after scheduled signal: %d%n", + afterMeta.readStateAs(Integer.class)); + } + + worker.stop(); + } + + /** + * A counter entity used for timeout and scheduled signal demonstrations. + */ + public static class SlowCounterEntity extends TaskEntity { + + public void add(int amount) { + this.state += amount; + } + + public int get() { + return this.state; + } + + @Override + protected Integer initializeState(TaskEntityOperation operation) { + return 0; + } + + @Override + protected Class getStateType() { + return Integer.class; + } + } +} diff --git a/samples/src/main/java/io/durabletask/samples/LowLevelEntitySample.java b/samples/src/main/java/io/durabletask/samples/LowLevelEntitySample.java new file mode 100644 index 00000000..4239d987 --- /dev/null +++ b/samples/src/main/java/io/durabletask/samples/LowLevelEntitySample.java @@ -0,0 +1,254 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package io.durabletask.samples; + +import com.microsoft.durabletask.*; + +import java.io.IOException; +import java.time.Duration; +import java.util.concurrent.TimeoutException; + +/** + * Sample demonstrating advanced entity programming models: + *
    + *
  • {@link ITaskEntity} low-level interface: Implementing the raw entity interface with + * manual switch-based operation dispatch instead of reflection-based {@link TaskEntity}.
  • + *
  • State dispatch: Using {@link TaskEntity} with {@code allowStateDispatch=true} + * (default) so operations can be dispatched to methods on the state POJO itself.
  • + *
  • POJO entity state: Using a rich class as the entity state instead of a primitive.
  • + *
  • Implicit delete: Sending a "delete" operation to an entity that has no explicit + * delete method — {@link TaskEntity} handles this automatically by clearing the state.
  • + *
+ *

+ * Usage: Run this sample with a Durable Task sidecar listening on localhost:4001. + *

+ *   docker run -d -p 4001:4001 durabletask-sidecar
+ *   ./gradlew :samples:run -PmainClass=io.durabletask.samples.LowLevelEntitySample
+ * 
+ */ +final class LowLevelEntitySample { + + public static void main(String[] args) throws IOException, InterruptedException, TimeoutException { + DurableTaskGrpcWorker worker = new DurableTaskGrpcWorkerBuilder() + // Demo 1: ITaskEntity — manual dispatch + .addEntity("KeyValue", KeyValueEntity::new) + // Demo 2: TaskEntity with POJO state and state dispatch + .addEntity("ShoppingCart", CartEntity::new) + .addOrchestration(new TaskOrchestrationFactory() { + @Override + public String getName() { return "LowLevelDemo"; } + + @Override + public TaskOrchestration create() { + return ctx -> { + EntityInstanceId kvId = new EntityInstanceId("KeyValue", "config"); + + // Use the low-level ITaskEntity entity + ctx.callEntity(kvId, "set", new KeyValuePair("color", "blue")).await(); + ctx.callEntity(kvId, "set", new KeyValuePair("size", "large")).await(); + String color = ctx.callEntity(kvId, "get", "color", String.class).await(); + + ctx.complete("color=" + color); + }; + } + }) + .addOrchestration(new TaskOrchestrationFactory() { + @Override + public String getName() { return "StateDispatchDemo"; } + + @Override + public TaskOrchestration create() { + return ctx -> { + EntityInstanceId cartId = new EntityInstanceId("ShoppingCart", "cart-1"); + + // Operations dispatched to methods on the CartState POJO + ctx.signalEntity(cartId, "addItem", "Widget"); + ctx.signalEntity(cartId, "addItem", "Gadget"); + ctx.signalEntity(cartId, "addItem", "Widget"); + + int count = ctx.callEntity(cartId, "getItemCount", Integer.class).await(); + ctx.complete("Cart has " + count + " items"); + }; + } + }) + .addOrchestration(new TaskOrchestrationFactory() { + @Override + public String getName() { return "ImplicitDeleteDemo"; } + + @Override + public TaskOrchestration create() { + return ctx -> { + EntityInstanceId cartId = new EntityInstanceId("ShoppingCart", "cart-delete"); + + // Add an item + ctx.signalEntity(cartId, "addItem", "TempItem"); + + // Read state to confirm it exists + int count = ctx.callEntity(cartId, "getItemCount", Integer.class).await(); + + // Implicit delete — no explicit "delete" method on CartEntity or CartState, + // but TaskEntity handles it automatically by clearing the entity state + ctx.callEntity(cartId, "delete").await(); + + ctx.complete("Had " + count + " item(s), then deleted"); + }; + } + }) + .build(); + + worker.start(); + System.out.println("Worker started. KeyValue and ShoppingCart entities registered."); + + DurableTaskClient client = new DurableTaskGrpcClientBuilder().build(); + + // --- Demo 1: Low-level ITaskEntity --- + System.out.println("\n--- Demo 1: ITaskEntity with manual dispatch ---"); + String id1 = client.scheduleNewOrchestrationInstance("LowLevelDemo"); + OrchestrationMetadata result1 = client.waitForInstanceCompletion( + id1, Duration.ofSeconds(30), true); + System.out.printf("Result: %s%n", result1.readOutputAs(String.class)); + + EntityMetadata kvMeta = client.getEntities().getEntityMetadata( + new EntityInstanceId("KeyValue", "config"), true); + if (kvMeta != null) { + System.out.printf("KeyValue entity state: %s%n", kvMeta.readStateAs(Object.class)); + } + + // --- Demo 2: State dispatch (operations dispatched to CartState methods) --- + System.out.println("\n--- Demo 2: State dispatch with POJO state ---"); + String id2 = client.scheduleNewOrchestrationInstance("StateDispatchDemo"); + OrchestrationMetadata result2 = client.waitForInstanceCompletion( + id2, Duration.ofSeconds(30), true); + System.out.printf("Result: %s%n", result2.readOutputAs(String.class)); + + // --- Demo 3: Implicit delete --- + System.out.println("\n--- Demo 3: Implicit delete ---"); + String id3 = client.scheduleNewOrchestrationInstance("ImplicitDeleteDemo"); + OrchestrationMetadata result3 = client.waitForInstanceCompletion( + id3, Duration.ofSeconds(30), true); + System.out.printf("Result: %s%n", result3.readOutputAs(String.class)); + + // Verify entity was deleted + EntityMetadata deletedMeta = client.getEntities().getEntityMetadata( + new EntityInstanceId("ShoppingCart", "cart-delete"), true); + System.out.printf("Entity after delete: %s%n", deletedMeta == null ? "null (deleted)" : "still exists"); + + worker.stop(); + } + + // ---- Data classes ---- + + /** + * A simple key-value pair used as input for the KeyValue entity. + */ + public static class KeyValuePair { + public String key; + public String value; + + public KeyValuePair() {} // for deserialization + public KeyValuePair(String key, String value) { + this.key = key; + this.value = value; + } + } + + // ---- Low-level ITaskEntity implementation ---- + + /** + * A key-value store entity implemented directly with {@link ITaskEntity}. + *

+ * This demonstrates manual switch-based operation dispatch without the reflection-based + * {@link TaskEntity} base class. This gives full control over how operations are routed. + */ + public static class KeyValueEntity implements ITaskEntity { + @Override + public Object runAsync(TaskEntityOperation operation) throws Exception { + // Load current state (a Map stored as JSON) + @SuppressWarnings("unchecked") + java.util.Map store = operation.getState().getState(java.util.Map.class); + if (store == null) { + store = new java.util.HashMap<>(); + } + + Object result = null; + + // Manual switch-based dispatch + switch (operation.getName().toLowerCase()) { + case "set": + KeyValuePair kvp = operation.getInput(KeyValuePair.class); + store.put(kvp.key, kvp.value); + break; + case "get": + String key = operation.getInput(String.class); + result = store.get(key); + break; + case "remove": + String removeKey = operation.getInput(String.class); + result = store.remove(removeKey); + break; + case "getall": + result = new java.util.HashMap<>(store); + break; + case "delete": + operation.getState().deleteState(); + return null; + default: + throw new UnsupportedOperationException( + "KeyValue entity does not support operation: " + operation.getName()); + } + + // Save state back + operation.getState().setState(store); + return result; + } + } + + // ---- State dispatch entity implementation ---- + + /** + * Shopping cart state POJO whose public methods serve as entity operations + * via state dispatch. + *

+ * When {@link CartEntity} receives an operation like "addItem", and no matching method + * exists on {@code CartEntity} itself, the framework dispatches the operation to + * this state class's {@code addItem} method. + */ + public static class CartState { + public java.util.List items = new java.util.ArrayList<>(); + + /** + * Adds an item to the cart. Called via state dispatch. + */ + public void addItem(String item) { + items.add(item); + } + + /** + * Returns the number of items in the cart. Called via state dispatch. + */ + public int getItemCount() { + return items.size(); + } + } + + /** + * Cart entity that delegates operations to methods on the {@link CartState} POJO. + *

+ * This entity has no operation methods of its own — all operations (addItem, getItemCount) + * are dispatched to the state object. This is enabled by default via + * {@code allowStateDispatch = true}. + */ + public static class CartEntity extends TaskEntity { + // No operation methods here — they are on CartState + + @Override + protected CartState initializeState(TaskEntityOperation operation) { + return new CartState(); + } + + @Override + protected Class getStateType() { + return CartState.class; + } + } +}