Support torch.distributed as alias for pytorch distribution type #44968
Conversation
Please go ahead and make the changes.
- Added `TORCH_DISTRIBUTED` constant as legacy alias in `DistributionType`
- Updated `DISTRIBUTION_TYPE_MAP` to map both `pytorch` and `torch.distributed` to `PyTorchDistribution`
- Updated `PyTorchDistributionSchema` to accept both values in `allowed_values`
- Added test YAML file with `torch.distributed` type
- Added unit test to verify both values work interchangeably

Co-authored-by: mohammadsheraj <221899694+mohammadsheraj@users.noreply.github.com>

- Modified `from_dict_to_rest_distribution` to accept both `pytorch` and `torch.distributed`
- Updated error message to include `torch.distributed` in valid types
- Added comprehensive unit tests for distribution helper functions

Co-authored-by: mohammadsheraj <221899694+mohammadsheraj@users.noreply.github.com>

- Added `TORCH_DISTRIBUTED` mapping to `RestDistributionType.PY_TORCH`
- Ensures proper REST API conversion for both `pytorch` and `torch.distributed`

Co-authored-by: mohammadsheraj <221899694+mohammadsheraj@users.noreply.github.com>
Pull request overview
This PR adds support for torch.distributed as a legacy alias for the pytorch distribution type, enabling backwards compatibility with AML SDK v1.5 while maintaining the existing SDK v2 behavior.
Changes:
- Added `TORCH_DISTRIBUTED = "torch.distributed"` constant to `DistributionType`
- Updated entity mappings (`SDK_TO_REST` and `DISTRIBUTION_TYPE_MAP`) to handle both `pytorch` and `torch.distributed` types, mapping them to the same PyTorch classes
- Modified schema validation to accept both values in `PyTorchDistributionSchema.allowed_values`
- Updated pipeline helper `from_dict_to_rest_distribution()` to recognize both distribution types
- Added comprehensive test coverage including YAML configuration, schema round-trips, and case-insensitive handling
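For orientation, here is a minimal, self-contained sketch of the alias pattern these changes describe. The constant and map names follow the PR; the stub class and everything around it are illustrative assumptions, not the SDK's actual implementation:

```python
class DistributionType:
    PYTORCH = "pytorch"
    TORCH_DISTRIBUTED = "torch.distributed"  # legacy alias for AML SDK v1.5 compatibility
    MPI = "mpi"
    TENSORFLOW = "tensorflow"


class PyTorchDistribution:
    """Stub standing in for azure.ai.ml.PyTorchDistribution (illustrative only)."""

    def __init__(self, process_count_per_instance=None):
        # The entity always reports the canonical type, never the alias.
        self.type = DistributionType.PYTORCH
        self.process_count_per_instance = process_count_per_instance


# Both spellings resolve to the same entity class, so payloads using either
# value deserialize into PyTorchDistribution.
DISTRIBUTION_TYPE_MAP = {
    DistributionType.PYTORCH: PyTorchDistribution,
    DistributionType.TORCH_DISTRIBUTED: PyTorchDistribution,
}
```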
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| sdk/ml/azure-ai-ml/azure/ai/ml/constants/_job/job.py | Adds TORCH_DISTRIBUTED constant for backwards compatibility |
| sdk/ml/azure-ai-ml/azure/ai/ml/entities/_job/distribution.py | Updates SDK_TO_REST and DISTRIBUTION_TYPE_MAP to map both types to PyTorch classes |
| sdk/ml/azure-ai-ml/azure/ai/ml/_schema/job/distribution.py | Modifies PyTorchDistributionSchema to accept both pytorch and torch.distributed in validation |
| sdk/ml/azure-ai-ml/azure/ai/ml/entities/_job/pipeline/_pipeline_job_helpers.py | Updates from_dict_to_rest_distribution() to handle both types and updates error message |
| sdk/ml/azure-ai-ml/tests/test_configs/command_job/dist_job_pytorch_torch_distributed.yml | Adds YAML test fixture using torch.distributed type |
| sdk/ml/azure-ai-ml/tests/command_job/unittests/test_command_job_schema.py | Adds test for interchangeability of pytorch and torch.distributed types |
| sdk/ml/azure-ai-ml/tests/pipeline_job/unittests/test_distribution_helpers.py | Adds tests for pipeline helper function handling both types |
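The schema change listed above for `distribution.py` amounts to widening the accepted literal set for the `type` field. A hedged marshmallow-style sketch of the idea follows; the field layout and validator are illustrative assumptions, not the SDK's actual `PyTorchDistributionSchema` internals:

```python
from marshmallow import Schema, fields, validate


class PyTorchDistributionSchema(Schema):
    # Accept both the canonical value and the legacy alias; the loader is
    # responsible for normalizing either spelling to PyTorchDistribution.
    type = fields.Str(
        required=True,
        validate=validate.OneOf(["pytorch", "torch.distributed"]),
    )
    process_count_per_instance = fields.Int()
```

Under this sketch, `PyTorchDistributionSchema().load({"type": "torch.distributed", "process_count_per_instance": 4})` would pass validation just as the canonical spelling does.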
```diff
 if target_type == "ray":
     return Ray(**distribution_dict)
-msg = "Distribution type must be pytorch, mpi, tensorflow or ray: {}".format(target_type)
+msg = "Distribution type must be pytorch, torch.distributed, mpi, tensorflow or ray: {}".format(target_type)
```
The error message lists "torch.distributed" separately from "pytorch", which might confuse users since torch.distributed is documented as an internal alias for backwards compatibility. Consider updating the message to either:
- Only mention "pytorch" as the valid type and omit "torch.distributed" from user-facing error messages (since it's a legacy alias)
- Clarify that "torch.distributed" is a legacy alias
For consistency with the PR's intention to support the alias transparently, option 1 would be more user-friendly.
| msg = "Distribution type must be pytorch, torch.distributed, mpi, tensorflow or ray: {}".format(target_type) | |
| msg = "Distribution type must be pytorch, mpi, tensorflow or ray: {}".format(target_type) |
```python
# Verify roundtrip for torch.distributed
rest_obj = job_torch_distributed._to_rest_object()
reconstructed = CommandJob._load_from_rest(rest_obj)
assert isinstance(reconstructed.distribution, PyTorchDistribution)
assert reconstructed.distribution.process_count_per_instance == 4
```
The test validates that torch.distributed deserializes correctly and creates a PyTorchDistribution object. However, it doesn't verify that the resulting REST object uses the correct distribution type (RestDistributionType.PY_TORCH). Consider adding an assertion to check that the REST representation is correct, for example:

```python
rest_obj = job_torch_distributed._to_rest_object()
assert rest_obj.properties.distribution.distribution_type == "PyTorch"
```

This would ensure that the SDK_TO_REST mapping is working correctly and that the alias normalization is complete.
```python
def test_from_dict_to_rest_distribution_torch_distributed(self):
    """Test that torch.distributed type is properly handled"""
    distribution_dict = {
        "distribution_type": "torch.distributed",
        "process_count_per_instance": 4
    }
    result = from_dict_to_rest_distribution(distribution_dict)
    assert result is not None
    assert hasattr(result, 'process_count_per_instance')
    assert result.process_count_per_instance == 4
```
The test validates that the from_dict_to_rest_distribution helper function accepts torch.distributed, but it doesn't verify that both pytorch and torch.distributed types result in the same REST object type. Consider adding an assertion to verify that both types produce equivalent REST objects:

```python
# Compare results from both types
pytorch_dict = {"distribution_type": "pytorch", "process_count_per_instance": 4}
torch_dist_dict = {"distribution_type": "torch.distributed", "process_count_per_instance": 4}
pytorch_result = from_dict_to_rest_distribution(pytorch_dict)
torch_dist_result = from_dict_to_rest_distribution(torch_dist_dict)
assert type(pytorch_result) == type(torch_dist_result)
```

This would strengthen the test by verifying that the alias normalization produces equivalent results.
```python
# Verify roundtrip for torch.distributed
rest_obj = job_torch_distributed._to_rest_object()
reconstructed = CommandJob._load_from_rest(rest_obj)
assert isinstance(reconstructed.distribution, PyTorchDistribution)
assert reconstructed.distribution.process_count_per_instance == 4
```
The test validates round-trip serialization but doesn't verify what value the type field has after deserialization and re-serialization. According to the PR description, torch.distributed should be normalized to pytorch internally. Consider adding an assertion to verify the normalized type after round-trip:

```python
rest_obj = job_torch_distributed._to_rest_object()
reconstructed = CommandJob._load_from_rest(rest_obj)
# Verify that the type is normalized to "pytorch" after round-trip
assert reconstructed.distribution.type == DistributionType.PYTORCH
```

This ensures that the internal normalization is working correctly and that torch.distributed is transparently converted to pytorch.
AML Designer stores PyTorch distribution configs as `type: torch.distributed` for SDK v1.5 compatibility, but SDK v2 only recognizes `type: pytorch`, causing deserialization failures.

Changes

- Added `TORCH_DISTRIBUTED = "torch.distributed"` to `DistributionType`
- Both `pytorch` and `torch.distributed` now map to `PyTorchDistribution` in `DISTRIBUTION_TYPE_MAP` and to `RestDistributionType.PY_TORCH` in `SDK_TO_REST`
- `PyTorchDistributionSchema` accepts both values in `allowed_values`
- `from_dict_to_rest_distribution()` handles both types
Behavior

The normalization is unidirectional: `torch.distributed` → `pytorch` during deserialization. `PyTorchDistribution` always uses `pytorch` internally, maintaining consistency while accepting the legacy alias from service responses.

Testing

- Added tests covering the `torch.distributed` alias: YAML fixture, schema round-trips, case-insensitive handling, and the pipeline helper functions
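To make the unidirectional normalization concrete, here is a short, self-contained sketch of the behavior described above. The helper and stub class are hypothetical stand-ins for the SDK's `from_dict_to_rest_distribution` and `PyTorchDistribution`, not the actual code path:

```python
class PyTorchDistribution:
    """Stub standing in for azure.ai.ml.PyTorchDistribution (illustrative only)."""

    def __init__(self, process_count_per_instance=None):
        self.type = "pytorch"  # always normalized, regardless of the incoming alias
        self.process_count_per_instance = process_count_per_instance


def load_distribution(distribution_dict):
    # Hypothetical helper mirroring the PR's from_dict_to_rest_distribution:
    # both spellings collapse to the same entity during deserialization.
    d = dict(distribution_dict)
    target_type = d.pop("distribution_type", "").lower()
    if target_type in ("pytorch", "torch.distributed"):
        return PyTorchDistribution(**d)
    raise ValueError(f"Unsupported distribution type: {target_type}")


legacy = load_distribution({"distribution_type": "torch.distributed", "process_count_per_instance": 4})
canonical = load_distribution({"distribution_type": "pytorch", "process_count_per_instance": 4})
assert type(legacy) is type(canonical)
assert legacy.type == "pytorch"  # the alias never survives deserialization
```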