From d6f81ca20527773c891ebb7a972b33d49519ed4f Mon Sep 17 00:00:00 2001 From: Yash Date: Fri, 27 Mar 2026 14:59:07 +1100 Subject: [PATCH] Add function implementation --- src/azure-cli-core/azure/cli/core/_profile.py | 17 +++ .../azure/cli/core/tests/test_profile.py | 129 ++++++++++++++++++ 2 files changed, 146 insertions(+) diff --git a/src/azure-cli-core/azure/cli/core/_profile.py b/src/azure-cli-core/azure/cli/core/_profile.py index c1950d26c0e..9b192b84199 100644 --- a/src/azure-cli-core/azure/cli/core/_profile.py +++ b/src/azure-cli-core/azure/cli/core/_profile.py @@ -856,6 +856,23 @@ def find_using_specific_tenant(self, tenant, credential, tenant_id_description=N self.tenants.append(tenant) return all_subscriptions + def find_specific_subscriptions(self, tenant, credential, subscription_ids): + """Fetch specific subscriptions by ID using GET /subscriptions/{id} + instead of listing all subscriptions. + https://learn.microsoft.com/en-us/rest/api/resources/subscriptions/get + """ + client = self._create_subscription_client(credential) + all_subscriptions = [] + for sub_id in subscription_ids: + try: + s = client.subscriptions.get(sub_id) + _attach_token_tenant(s, tenant) + all_subscriptions.append(s) + except Exception as ex: # pylint: disable=broad-except + logger.warning("Failed to retrieve subscription %s: %s", sub_id, ex) + self.tenants.append(tenant) + return all_subscriptions + def _create_subscription_client(self, credential): from azure.cli.core.profiles import ResourceType, get_api_version from azure.cli.core.profiles._shared import get_client_class diff --git a/src/azure-cli-core/azure/cli/core/tests/test_profile.py b/src/azure-cli-core/azure/cli/core/tests/test_profile.py index 0f899cebacb..20ae4c892cb 100644 --- a/src/azure-cli-core/azure/cli/core/tests/test_profile.py +++ b/src/azure-cli-core/azure/cli/core/tests/test_profile.py @@ -1638,5 +1638,134 @@ class SimpleManagedByTenant: assert d == {'managedByTenants': [{"tenantId": tenant_id}]} +class TestSubscriptionFinderFindSpecific(unittest.TestCase): + """Tests for SubscriptionFinder.find_specific_subscriptions()""" + + @classmethod + def setUpClass(cls): + cls.tenant_id = 'test.onmicrosoft.com' + cls.sub_id_1 = '00000000-0000-0000-0000-000000000001' + cls.sub_id_2 = '00000000-0000-0000-0000-000000000002' + + cls.subscription1_raw = SubscriptionStub( + 'subscriptions/{}'.format(cls.sub_id_1), + 'sub1', 'Enabled', tenant_id=cls.tenant_id) + cls.subscription2_raw = SubscriptionStub( + 'subscriptions/{}'.format(cls.sub_id_2), + 'sub2', 'Enabled', tenant_id=cls.tenant_id) + + @mock.patch('azure.cli.core._profile.SubscriptionFinder._create_subscription_client', autospec=True) + def test_find_specific_subscriptions_single(self, create_subscription_client_mock): + """Single subscription ID is fetched via GET, not LIST.""" + cli = DummyCli() + mock_client = mock.MagicMock() + mock_client.subscriptions.get.return_value = deepcopy(self.subscription1_raw) + create_subscription_client_mock.return_value = mock_client + + finder = SubscriptionFinder(cli) + credential = mock.MagicMock() + result = finder.find_specific_subscriptions(self.tenant_id, credential, [self.sub_id_1]) + + # Assert GET was called, LIST was not + mock_client.subscriptions.get.assert_called_once_with(self.sub_id_1) + mock_client.subscriptions.list.assert_not_called() + + # Assert result + self.assertEqual(len(result), 1) + self.assertEqual(result[0].subscription_id, self.sub_id_1) + # Assert tenant_id is attached (by _attach_token_tenant) + self.assertEqual(result[0].tenant_id, self.tenant_id) + # Assert tenant is tracked + self.assertIn(self.tenant_id, finder.tenants) + + @mock.patch('azure.cli.core._profile.SubscriptionFinder._create_subscription_client', autospec=True) + def test_find_specific_subscriptions_multiple(self, create_subscription_client_mock): + """Multiple subscription IDs are each fetched individually via GET.""" + cli = DummyCli() + mock_client = mock.MagicMock() + mock_client.subscriptions.get.side_effect = [ + deepcopy(self.subscription1_raw), + deepcopy(self.subscription2_raw) + ] + create_subscription_client_mock.return_value = mock_client + + finder = SubscriptionFinder(cli) + credential = mock.MagicMock() + result = finder.find_specific_subscriptions( + self.tenant_id, credential, [self.sub_id_1, self.sub_id_2]) + + # Assert GET was called for each sub + self.assertEqual(mock_client.subscriptions.get.call_count, 2) + mock_client.subscriptions.get.assert_any_call(self.sub_id_1) + mock_client.subscriptions.get.assert_any_call(self.sub_id_2) + mock_client.subscriptions.list.assert_not_called() + + # Assert both results returned + self.assertEqual(len(result), 2) + self.assertEqual(result[0].subscription_id, self.sub_id_1) + self.assertEqual(result[1].subscription_id, self.sub_id_2) + + @mock.patch('azure.cli.core._profile.SubscriptionFinder._create_subscription_client', autospec=True) + def test_find_specific_subscriptions_inaccessible_warns_and_continues(self, create_subscription_client_mock): + """If a subscription is inaccessible, log warning and continue with others.""" + cli = DummyCli() + mock_client = mock.MagicMock() + mock_client.subscriptions.get.side_effect = [ + Exception("Subscription not found or not accessible"), + deepcopy(self.subscription2_raw) + ] + create_subscription_client_mock.return_value = mock_client + + finder = SubscriptionFinder(cli) + credential = mock.MagicMock() + + with mock.patch('azure.cli.core._profile.logger') as mock_logger: + result = finder.find_specific_subscriptions( + self.tenant_id, credential, [self.sub_id_1, self.sub_id_2]) + + # Assert warning was logged for the failed sub + mock_logger.warning.assert_called_once() + self.assertIn(self.sub_id_1, mock_logger.warning.call_args[0][1]) + + # Assert only the accessible sub is returned + self.assertEqual(len(result), 1) + self.assertEqual(result[0].subscription_id, self.sub_id_2) + # Assert tenant is still tracked + self.assertIn(self.tenant_id, finder.tenants) + + @mock.patch('azure.cli.core._profile.SubscriptionFinder._create_subscription_client', autospec=True) + def test_find_specific_subscriptions_all_inaccessible(self, create_subscription_client_mock): + """If all subscriptions are inaccessible, return empty list.""" + cli = DummyCli() + mock_client = mock.MagicMock() + mock_client.subscriptions.get.side_effect = Exception("Not found") + create_subscription_client_mock.return_value = mock_client + + finder = SubscriptionFinder(cli) + credential = mock.MagicMock() + result = finder.find_specific_subscriptions( + self.tenant_id, credential, [self.sub_id_1]) + + self.assertEqual(len(result), 0) + # Tenant is still tracked even with no results + self.assertIn(self.tenant_id, finder.tenants) + + @mock.patch('azure.cli.core._profile.SubscriptionFinder._create_subscription_client', autospec=True) + def test_find_specific_subscriptions_empty_list(self, create_subscription_client_mock): + """Empty subscription_ids list returns empty result.""" + cli = DummyCli() + mock_client = mock.MagicMock() + create_subscription_client_mock.return_value = mock_client + + finder = SubscriptionFinder(cli) + credential = mock.MagicMock() + result = finder.find_specific_subscriptions(self.tenant_id, credential, []) + + mock_client.subscriptions.get.assert_not_called() + mock_client.subscriptions.list.assert_not_called() + self.assertEqual(len(result), 0) + self.assertIn(self.tenant_id, finder.tenants) + + if __name__ == '__main__': unittest.main()