Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class AzureStorageSchema(PathAwareSchema):
protocol = fields.Str()
description = fields.Str()
tags = fields.Dict(keys=fields.Str(), values=fields.Str())
subscription_id = fields.Str()
resource_group = fields.Str()


class AzureFileSchema(AzureStorageSchema):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ class AzureFileDatastore(Datastore):
:param credentials: Credentials to use for Azure ML workspace to connect to the storage. Defaults to None.
:type credentials: Union[~azure.ai.ml.entities.AccountKeyConfiguration,
~azure.ai.ml.entities.SasTokenConfiguration]
:param subscription_id: Azure subscription ID of the storage account. Defaults to None.
:type subscription_id: Optional[str]
:param resource_group: Azure resource group of the storage account. Defaults to None.
:type resource_group: Optional[str]
:param kwargs: A dictionary of additional configuration parameters.
:type kwargs: dict
"""
Expand All @@ -68,6 +72,8 @@ def __init__(
protocol: str = HTTPS,
properties: Optional[Dict] = None,
credentials: Optional[Union[AccountKeyConfiguration, SasTokenConfiguration]] = None,
subscription_id: Optional[str] = None,
resource_group: Optional[str] = None,
**kwargs: Any
):
kwargs[TYPE] = DatastoreType.AZURE_FILE
Expand All @@ -78,6 +84,8 @@ def __init__(
self.account_name = account_name
self.endpoint = endpoint
self.protocol = protocol
self.subscription_id = subscription_id
self.resource_group = resource_group
Comment on lines 84 to +88
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AzureFileDatastore now has subscription_id/resource_group fields that affect datastore identity/scope, but eq doesn’t compare them. This can cause two datastores with different ARM scope to compare equal and can hide regressions in round-trip tests. Include subscription_id and resource_group in the eq comparison (alongside endpoint/protocol).

Copilot uses AI. Check for mistakes.

def _to_rest_object(self) -> DatastoreData:
file_ds = RestAzureFileDatastore(
Expand All @@ -88,6 +96,8 @@ def _to_rest_object(self) -> DatastoreData:
protocol=self.protocol,
description=self.description,
tags=self.tags,
subscription_id=self.subscription_id,
resource_group=self.resource_group,
)
return DatastoreData(properties=file_ds)

Expand All @@ -109,6 +119,8 @@ def _from_rest_object(cls, datastore_resource: DatastoreData) -> "AzureFileDatas
file_share_name=properties.file_share_name,
description=properties.description,
tags=properties.tags,
subscription_id=properties.subscription_id,
resource_group=properties.resource_group,
)

def __eq__(self, other: Any) -> bool:
Expand Down Expand Up @@ -152,6 +164,10 @@ class AzureBlobDatastore(Datastore):
:param credentials: Credentials to use for Azure ML workspace to connect to the storage.
:type credentials: Union[~azure.ai.ml.entities.AccountKeyConfiguration,
~azure.ai.ml.entities.SasTokenConfiguration]
:param subscription_id: Azure subscription ID of the storage account. Defaults to None.
:type subscription_id: Optional[str]
:param resource_group: Azure resource group of the storage account. Defaults to None.
:type resource_group: Optional[str]
:param kwargs: A dictionary of additional configuration parameters.
:type kwargs: dict
"""
Expand All @@ -168,6 +184,8 @@ def __init__(
protocol: str = HTTPS,
properties: Optional[Dict] = None,
credentials: Optional[Union[AccountKeyConfiguration, SasTokenConfiguration]] = None,
subscription_id: Optional[str] = None,
resource_group: Optional[str] = None,
**kwargs: Any
):
kwargs[TYPE] = DatastoreType.AZURE_BLOB
Expand All @@ -179,6 +197,8 @@ def __init__(
self.account_name = account_name
self.endpoint = endpoint if endpoint else _get_storage_endpoint_from_metadata()
self.protocol = protocol
self.subscription_id = subscription_id
self.resource_group = resource_group

Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AzureBlobDatastore now stores subscription_id/resource_group, but eq does not include these fields. That means objects with different storage ARM scope will compare equal, and round-trip equality assertions won’t validate these new properties. Update eq to compare subscription_id and resource_group as well.

Suggested change
def __eq__(self, other: object) -> bool:
if not isinstance(other, AzureBlobDatastore):
return NotImplemented
return (
super().__eq__(other)
and self.subscription_id == other.subscription_id
and self.resource_group == other.resource_group
)

Copilot uses AI. Check for mistakes.
def _to_rest_object(self) -> DatastoreData:
blob_ds = RestAzureBlobDatastore(
Expand All @@ -189,6 +209,8 @@ def _to_rest_object(self) -> DatastoreData:
protocol=self.protocol,
tags=self.tags,
description=self.description,
subscription_id=self.subscription_id,
resource_group=self.resource_group,
)
return DatastoreData(properties=blob_ds)

Expand All @@ -210,6 +232,8 @@ def _from_rest_object(cls, datastore_resource: DatastoreData) -> "AzureBlobDatas
container_name=properties.container_name,
description=properties.description,
tags=properties.tags,
subscription_id=properties.subscription_id,
resource_group=properties.resource_group,
)

def __eq__(self, other: Any) -> bool:
Expand Down Expand Up @@ -256,6 +280,10 @@ class AzureDataLakeGen2Datastore(Datastore):
]
:param properties: The asset property dictionary.
:type properties: dict[str, str]
:param subscription_id: Azure subscription ID of the storage account. Defaults to None.
:type subscription_id: Optional[str]
:param resource_group: Azure resource group of the storage account. Defaults to None.
:type resource_group: Optional[str]
:param kwargs: A dictionary of additional configuration parameters.
:type kwargs: dict
"""
Expand All @@ -272,6 +300,8 @@ def __init__(
protocol: str = HTTPS,
properties: Optional[Dict] = None,
credentials: Optional[Union[ServicePrincipalConfiguration, CertificateConfiguration]] = None,
subscription_id: Optional[str] = None,
resource_group: Optional[str] = None,
**kwargs: Any
):
kwargs[TYPE] = DatastoreType.AZURE_DATA_LAKE_GEN2
Expand All @@ -283,6 +313,8 @@ def __init__(
self.filesystem = filesystem
self.endpoint = endpoint
self.protocol = protocol
self.subscription_id = subscription_id
self.resource_group = resource_group

def _to_rest_object(self) -> DatastoreData:
gen2_ds = RestAzureDataLakeGen2Datastore(
Expand All @@ -293,6 +325,8 @@ def _to_rest_object(self) -> DatastoreData:
protocol=self.protocol,
description=self.description,
tags=self.tags,
subscription_id=self.subscription_id,
resource_group=self.resource_group,
)
return DatastoreData(properties=gen2_ds)

Expand All @@ -316,6 +350,8 @@ def _from_rest_object(cls, datastore_resource: DatastoreData) -> "AzureDataLakeG
filesystem=properties.filesystem,
description=properties.description,
tags=properties.tags,
subscription_id=properties.subscription_id,
resource_group=properties.resource_group,
)

def __eq__(self, other: Any) -> bool:
Expand Down
Loading