|
27 | 27 | ) |
28 | 28 | from temporalio.client import ( |
29 | 29 | Client, |
| 30 | + WorkflowHandle, |
30 | 31 | ) |
31 | 32 | from temporalio.common import PinnedVersioningOverride, RawValue, VersioningBehavior |
32 | 33 | from temporalio.runtime import ( |
@@ -1162,6 +1163,54 @@ async def set_ramping_version( |
1162 | 1163 | return response |
1163 | 1164 |
|
1164 | 1165 |
|
| 1166 | +async def wait_for_worker_deployment_routing_config_propagation( |
| 1167 | + client: Client, |
| 1168 | + deployment_name: str, |
| 1169 | + expected_current_build_id: str, |
| 1170 | + expected_ramping_build_id: str = "", |
| 1171 | +) -> None: |
| 1172 | + """Wait for routing config to be propagated to all task queues.""" |
| 1173 | + import temporalio.api.enums.v1 |
| 1174 | + |
| 1175 | + async def check() -> bool: |
| 1176 | + resp = await client.workflow_service.describe_worker_deployment( |
| 1177 | + DescribeWorkerDeploymentRequest( |
| 1178 | + namespace=client.namespace, |
| 1179 | + deployment_name=deployment_name, |
| 1180 | + ) |
| 1181 | + ) |
| 1182 | + routing_config = resp.worker_deployment_info.routing_config |
| 1183 | + if ( |
| 1184 | + routing_config.current_deployment_version.build_id |
| 1185 | + != expected_current_build_id |
| 1186 | + ): |
| 1187 | + return False |
| 1188 | + if ( |
| 1189 | + routing_config.ramping_deployment_version.build_id |
| 1190 | + != expected_ramping_build_id |
| 1191 | + ): |
| 1192 | + return False |
| 1193 | + state = resp.worker_deployment_info.routing_config_update_state |
| 1194 | + if ( |
| 1195 | + state |
| 1196 | + == temporalio.api.enums.v1.RoutingConfigUpdateState.ROUTING_CONFIG_UPDATE_STATE_COMPLETED |
| 1197 | + ): |
| 1198 | + return True |
| 1199 | + if ( |
| 1200 | + state |
| 1201 | + == temporalio.api.enums.v1.RoutingConfigUpdateState.ROUTING_CONFIG_UPDATE_STATE_UNSPECIFIED |
| 1202 | + ): |
| 1203 | + return True # unimplemented |
| 1204 | + if ( |
| 1205 | + state |
| 1206 | + == temporalio.api.enums.v1.RoutingConfigUpdateState.ROUTING_CONFIG_UPDATE_STATE_IN_PROGRESS |
| 1207 | + ): |
| 1208 | + return False |
| 1209 | + return False |
| 1210 | + |
| 1211 | + await assert_eventually(check) |
| 1212 | + |
| 1213 | + |
1165 | 1214 | def create_worker( |
1166 | 1215 | client: Client, |
1167 | 1216 | on_fatal_error: Callable[[BaseException], Awaitable[None]] | None = None, |
@@ -1316,3 +1365,124 @@ async def capture_client_activity() -> None: |
1316 | 1365 | assert len(captured_clients) == 2 |
1317 | 1366 | assert captured_clients[0] is client |
1318 | 1367 | assert captured_clients[1] is client2 # This will fail before the fix |
| 1368 | + |
| 1369 | + |
| 1370 | +@workflow.defn( |
| 1371 | + name="ContinueAsNewWithVersionUpgrade", |
| 1372 | + versioning_behavior=VersioningBehavior.PINNED, |
| 1373 | +) |
| 1374 | +class ContinueAsNewWithVersionUpgradeV1: |
| 1375 | + @workflow.run |
| 1376 | + async def run(self, attempt: int) -> str: |
| 1377 | + if attempt > 0: |
| 1378 | + return "v1.0" |
| 1379 | + |
| 1380 | + # Loop waiting for CAN suggestion with version changed |
| 1381 | + while True: |
| 1382 | + # Trigger a WFT when timer expires, thereby refreshing the continue-as-new-suggested flag |
| 1383 | + await asyncio.sleep(0.01) |
| 1384 | + info = workflow.info() |
| 1385 | + if info.is_target_worker_deployment_version_changed(): |
| 1386 | + workflow.continue_as_new( |
| 1387 | + arg=attempt + 1, |
| 1388 | + initial_versioning_behavior=workflow.ContinueAsNewVersioningBehavior.AUTO_UPGRADE, |
| 1389 | + ) |
| 1390 | + |
| 1391 | + |
| 1392 | +@workflow.defn( |
| 1393 | + name="ContinueAsNewWithVersionUpgrade", |
| 1394 | + versioning_behavior=VersioningBehavior.PINNED, |
| 1395 | +) |
| 1396 | +class ContinueAsNewWithVersionUpgradeV2: |
| 1397 | + @workflow.run |
| 1398 | + async def run(self, attempt: int) -> str: # type:ignore[reportUnusedParameter] |
| 1399 | + return "v2.0" |
| 1400 | + |
| 1401 | + |
| 1402 | +async def wait_for_workflow_running_on_version( |
| 1403 | + handle: WorkflowHandle[Any, Any], expected_build_id: str |
| 1404 | +) -> None: |
| 1405 | + """Wait until workflow is RUNNING with expected build ID.""" |
| 1406 | + |
| 1407 | + async def check() -> bool: |
| 1408 | + desc = await handle.describe() |
| 1409 | + if ( |
| 1410 | + desc.status |
| 1411 | + != temporalio.api.enums.v1.WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_RUNNING |
| 1412 | + ): |
| 1413 | + return False |
| 1414 | + versioning_info = desc.raw_description.workflow_execution_info.versioning_info |
| 1415 | + if not versioning_info.HasField("deployment_version"): |
| 1416 | + return False |
| 1417 | + return versioning_info.deployment_version.build_id == expected_build_id |
| 1418 | + |
| 1419 | + await assert_eventually(check) |
| 1420 | + |
| 1421 | + |
| 1422 | +async def test_continue_as_new_with_version_upgrade( |
| 1423 | + client: Client, env: WorkflowEnvironment |
| 1424 | +): |
| 1425 | + if env.supports_time_skipping: |
| 1426 | + pytest.skip("Test Server doesn't support worker deployments") |
| 1427 | + |
| 1428 | + deployment_name = f"deployment-can-upgrade-{uuid.uuid4()}" |
| 1429 | + v1 = WorkerDeploymentVersion(deployment_name=deployment_name, build_id="1.0") |
| 1430 | + v2 = WorkerDeploymentVersion(deployment_name=deployment_name, build_id="2.0") |
| 1431 | + |
| 1432 | + async with ( |
| 1433 | + new_worker( |
| 1434 | + client, |
| 1435 | + ContinueAsNewWithVersionUpgradeV1, |
| 1436 | + deployment_config=WorkerDeploymentConfig( |
| 1437 | + version=v1, |
| 1438 | + use_worker_versioning=True, |
| 1439 | + ), |
| 1440 | + ) as w1, |
| 1441 | + new_worker( |
| 1442 | + client, |
| 1443 | + ContinueAsNewWithVersionUpgradeV2, |
| 1444 | + deployment_config=WorkerDeploymentConfig( |
| 1445 | + version=v2, |
| 1446 | + use_worker_versioning=True, |
| 1447 | + ), |
| 1448 | + task_queue=w1.task_queue, |
| 1449 | + ), |
| 1450 | + ): |
| 1451 | + # Wait for the deployment to be ready |
| 1452 | + describe_resp = await wait_until_worker_deployment_visible(client, v1) |
| 1453 | + |
| 1454 | + # Set version 1.0 as current |
| 1455 | + resp2 = await set_current_deployment_version( |
| 1456 | + client, describe_resp.conflict_token, v1 |
| 1457 | + ) |
| 1458 | + |
| 1459 | + # Wait for v1.0-as-Current routing config to be propagated |
| 1460 | + await wait_for_worker_deployment_routing_config_propagation( |
| 1461 | + client, deployment_name, v1.build_id |
| 1462 | + ) |
| 1463 | + |
| 1464 | + # Start workflow with v1 as current |
| 1465 | + handle = await client.start_workflow( |
| 1466 | + "ContinueAsNewWithVersionUpgrade", |
| 1467 | + 0, |
| 1468 | + id=f"test-can-version-upgrade-{uuid.uuid4()}", |
| 1469 | + task_queue=w1.task_queue, |
| 1470 | + ) |
| 1471 | + |
| 1472 | + # Wait for workflow to complete one WFT on v1.0 |
| 1473 | + await wait_for_workflow_running_on_version(handle, v1.build_id) |
| 1474 | + |
| 1475 | + # Wait for version 2.0 to be ready |
| 1476 | + await wait_until_worker_deployment_visible(client, v2) |
| 1477 | + |
| 1478 | + # Set version 2.0 as current |
| 1479 | + await set_current_deployment_version(client, resp2.conflict_token, v2) |
| 1480 | + |
| 1481 | + # Wait for v2.0-as-Current routing config to be propagated |
| 1482 | + await wait_for_worker_deployment_routing_config_propagation( |
| 1483 | + client, deployment_name, v2.build_id |
| 1484 | + ) |
| 1485 | + |
| 1486 | + # Expect workflow to return "v2.0", indicating that it continued-as-new and completed on v2 |
| 1487 | + result = await handle.result() |
| 1488 | + assert result == "v2.0" |
0 commit comments