diff --git a/src/py/mat3ra/api_client/__init__.py b/src/py/mat3ra/api_client/__init__.py index 43e090a..c469dad 100644 --- a/src/py/mat3ra/api_client/__init__.py +++ b/src/py/mat3ra/api_client/__init__.py @@ -1,18 +1,18 @@ -# ruff: noqa: F401 try: from ._version import version as __version__ except ModuleNotFoundError: __version__ = None -from mat3ra.api_client.endpoints.bank_materials import BankMaterialEndpoints -from mat3ra.api_client.endpoints.bank_workflows import BankWorkflowEndpoints -from mat3ra.api_client.endpoints.jobs import JobEndpoints -from mat3ra.api_client.endpoints.login import LoginEndpoint -from mat3ra.api_client.endpoints.logout import LogoutEndpoint -from mat3ra.api_client.endpoints.materials import MaterialEndpoints -from mat3ra.api_client.endpoints.metaproperties import MetaPropertiesEndpoints -from mat3ra.api_client.endpoints.projects import ProjectEndpoints -from mat3ra.api_client.endpoints.properties import PropertiesEndpoints -from mat3ra.api_client.endpoints.workflows import WorkflowEndpoints - -from mat3ra.api_client.client import APIClient +from .client import APIClient +from .constants import ACCESS_TOKEN_ENV_VAR, CLIENT_ID, SCOPE, build_oidc_base_url +from .models import Account, APIEnv, AuthContext, AuthEnv +from .endpoints.bank_materials import BankMaterialEndpoints +from .endpoints.bank_workflows import BankWorkflowEndpoints +from .endpoints.jobs import JobEndpoints +from .endpoints.login import LoginEndpoint +from .endpoints.logout import LogoutEndpoint +from .endpoints.materials import MaterialEndpoints +from .endpoints.metaproperties import MetaPropertiesEndpoints +from .endpoints.projects import ProjectEndpoints +from .endpoints.properties import PropertiesEndpoints +from .endpoints.workflows import WorkflowEndpoints diff --git a/src/py/mat3ra/api_client/client.py b/src/py/mat3ra/api_client/client.py index fd9867d..00914aa 100644 --- a/src/py/mat3ra/api_client/client.py +++ b/src/py/mat3ra/api_client/client.py @@ -1,90 +1,21 @@ import os -from typing import Any, Optional, Tuple +import re +from typing import Any, List, Optional, Tuple import requests -from pydantic import BaseModel, ConfigDict, Field - -from mat3ra.api_client.endpoints.bank_materials import BankMaterialEndpoints -from mat3ra.api_client.endpoints.bank_workflows import BankWorkflowEndpoints -from mat3ra.api_client.endpoints.jobs import JobEndpoints -from mat3ra.api_client.endpoints.materials import MaterialEndpoints -from mat3ra.api_client.endpoints.metaproperties import MetaPropertiesEndpoints -from mat3ra.api_client.endpoints.projects import ProjectEndpoints -from mat3ra.api_client.endpoints.properties import PropertiesEndpoints -from mat3ra.api_client.endpoints.workflows import WorkflowEndpoints - -# Default API Configuration -DEFAULT_API_HOST = "platform-new.mat3ra.com" -DEFAULT_API_PORT = 443 -DEFAULT_API_VERSION = "2018-10-01" -DEFAULT_API_SECURE = True - -# Environment Variable Names -ACCESS_TOKEN_ENV_VAR = "OIDC_ACCESS_TOKEN" -API_HOST_ENV_VAR = "API_HOST" -API_PORT_ENV_VAR = "API_PORT" -API_VERSION_ENV_VAR = "API_VERSION" -API_SECURE_ENV_VAR = "API_SECURE" -ACCOUNT_ID_ENV_VAR = "ACCOUNT_ID" -AUTH_TOKEN_ENV_VAR = "AUTH_TOKEN" - -# Default OIDC Configuration -CLIENT_ID = "cli-device-client" -SCOPE = "openid profile email" - -# API Paths -USERS_ME_PATH = "/api/v1/users/me" - - -class AuthContext(BaseModel): - access_token: Optional[str] = None - account_id: Optional[str] = None - auth_token: Optional[str] = None - - -class APIEnv(BaseModel): - host: str = Field(default=DEFAULT_API_HOST, validation_alias=API_HOST_ENV_VAR) - port: int = Field(default=DEFAULT_API_PORT, validation_alias=API_PORT_ENV_VAR) - version: str = Field(default=DEFAULT_API_VERSION, validation_alias=API_VERSION_ENV_VAR) - secure: bool = Field(default=DEFAULT_API_SECURE, validation_alias=API_SECURE_ENV_VAR) +from pydantic import BaseModel, ConfigDict - @classmethod - def from_env(cls) -> "APIEnv": - return cls.model_validate(os.environ) - - -class AuthEnv(BaseModel): - access_token: Optional[str] = Field(None, validation_alias=ACCESS_TOKEN_ENV_VAR) - account_id: Optional[str] = Field(None, validation_alias=ACCOUNT_ID_ENV_VAR) - auth_token: Optional[str] = Field(None, validation_alias=AUTH_TOKEN_ENV_VAR) - - @classmethod - def from_env(cls) -> "AuthEnv": - return cls.model_validate(os.environ) - - -class Account(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True) +from .constants import ACCESS_TOKEN_ENV_VAR, _build_base_url +from .endpoints.bank_materials import BankMaterialEndpoints +from .endpoints.bank_workflows import BankWorkflowEndpoints +from .endpoints.jobs import JobEndpoints +from .endpoints.materials import MaterialEndpoints +from .endpoints.metaproperties import MetaPropertiesEndpoints +from .endpoints.projects import ProjectEndpoints +from .endpoints.properties import PropertiesEndpoints +from .endpoints.workflows import WorkflowEndpoints +from .models import Account, APIEnv, AuthContext, AuthEnv - client: Any = Field(exclude=True, repr=False) - id_cache: Optional[str] = None - - @property - def id(self) -> str: - if self.id_cache: - return self.id_cache - self.id_cache = self.client._resolve_account_id() - return self.id_cache - - -def _build_base_url(host: str, port: int, secure: bool, path: str) -> str: - protocol = "https" if secure else "http" - port_str = f":{port}" if port not in (80, 443) else "" - return f"{protocol}://{host}{port_str}{path}" - -# Used in API-examples utils -def build_oidc_base_url(host: str, port: int, secure: bool) -> str: - return _build_base_url(host, port, secure, "/oidc") class APIClient(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow", validate_assignment=True) @@ -99,8 +30,15 @@ class APIClient(BaseModel): def model_post_init(self, __context: Any) -> None: self.my_account = Account(client=self) self.account = self.my_account + self._my_organization: Optional[Account] = None self._init_endpoints(self.timeout_seconds) + @property + def my_organization(self) -> Optional[Account]: + if self._my_organization is None: + self._my_organization = self.get_default_organization() + return self._my_organization + @classmethod def env(cls) -> APIEnv: return APIEnv.from_env() @@ -110,39 +48,17 @@ def auth_env(cls) -> AuthEnv: return AuthEnv.from_env() def _init_endpoints(self, timeout_seconds: int) -> None: - kwargs = {"timeout": timeout_seconds, "auth": self.auth} - account_id = self.auth.account_id or "" - auth_token = self.auth.auth_token or "" - self._init_core_endpoints(kwargs, account_id, auth_token) - self._init_bank_endpoints(kwargs, account_id, auth_token) - - def _init_core_endpoints(self, kwargs: dict, account_id: str, auth_token: str) -> None: - self.materials = MaterialEndpoints( - self.host, self.port, account_id, auth_token, version=self.version, secure=self.secure, **kwargs - ) - self.workflows = WorkflowEndpoints( - self.host, self.port, account_id, auth_token, version=self.version, secure=self.secure, **kwargs - ) - self.jobs = JobEndpoints( - self.host, self.port, account_id, auth_token, version=self.version, secure=self.secure, **kwargs - ) - self.projects = ProjectEndpoints( - self.host, self.port, account_id, auth_token, version=self.version, secure=self.secure, **kwargs - ) - self.properties = PropertiesEndpoints( - self.host, self.port, account_id, auth_token, version=self.version, secure=self.secure, **kwargs - ) - self.metaproperties = MetaPropertiesEndpoints( - self.host, self.port, account_id, auth_token, version=self.version, secure=self.secure, **kwargs - ) - - def _init_bank_endpoints(self, kwargs: dict, account_id: str, auth_token: str) -> None: - self.bank_materials = BankMaterialEndpoints( - self.host, self.port, account_id, auth_token, version=self.version, secure=self.secure, **kwargs - ) - self.bank_workflows = BankWorkflowEndpoints( - self.host, self.port, account_id, auth_token, version=self.version, secure=self.secure, **kwargs - ) + base_args = (self.host, self.port, self.auth.account_id or "", self.auth.auth_token or "") + base_kwargs = {"version": self.version, "secure": self.secure, "timeout": timeout_seconds, "auth": self.auth} + + self.materials = MaterialEndpoints(*base_args, **base_kwargs) + self.workflows = WorkflowEndpoints(*base_args, **base_kwargs) + self.jobs = JobEndpoints(*base_args, **base_kwargs) + self.projects = ProjectEndpoints(*base_args, **base_kwargs) + self.properties = PropertiesEndpoints(*base_args, **base_kwargs) + self.metaproperties = MetaPropertiesEndpoints(*base_args, **base_kwargs) + self.bank_materials = BankMaterialEndpoints(*base_args, **base_kwargs) + self.bank_workflows = BankWorkflowEndpoints(*base_args, **base_kwargs) @staticmethod def _resolve_config( @@ -195,27 +111,73 @@ def authenticate( auth_token: Optional[str] = None, timeout_seconds: int = 60, ) -> "APIClient": - host_value, port_value, version_value, secure_value = cls._resolve_config(host, port, version, secure, - cls.env()) + host_value, port_value, version_value, secure_value = cls._resolve_config( + host, port, version, secure, cls.env() + ) auth = cls._auth_from_env(access_token=access_token, account_id=account_id, auth_token=auth_token) cls._validate_auth(auth) - return cls(host=host_value, port=port_value, version=version_value, secure=secure_value, auth=auth, - timeout_seconds=timeout_seconds) - - def _resolve_account_id(self) -> str: - account_id = self.auth.account_id or os.environ.get(ACCOUNT_ID_ENV_VAR) - if account_id: - self.auth.account_id = account_id - return account_id + return cls( + host=host_value, + port=port_value, + version=version_value, + secure=secure_value, + auth=auth, + timeout_seconds=timeout_seconds, + ) + def _fetch_data(self) -> dict: access_token = self.auth.access_token or os.environ.get(ACCESS_TOKEN_ENV_VAR) if not access_token: - raise ValueError("ACCOUNT_ID is not set and no OIDC access token is available.") + raise ValueError("Access token is required to fetch user data") - url = _build_base_url(self.host, self.port, self.secure, USERS_ME_PATH) + url = _build_base_url(self.host, self.port, self.secure, "/api/v1/users/me") response = requests.get(url, headers={"Authorization": f"Bearer {access_token}"}, timeout=30) response.raise_for_status() - account_id = response.json()["data"]["user"]["entity"]["defaultAccountId"] - os.environ[ACCOUNT_ID_ENV_VAR] = account_id - self.auth.account_id = account_id - return account_id + return response.json()["data"] + + def _fetch_user_accounts(self) -> List[dict]: + return self._fetch_data().get("accounts", []) + + def list_accounts(self) -> List[dict]: + accounts = self._fetch_user_accounts() + return [ + { + "_id": account["entity"]["_id"], + "name": account["entity"].get("name", ""), + "type": account["entity"].get("type", "personal"), + "isDefault": account.get("isDefault", False), + } + for account in accounts + ] + + def get_account(self, name: Optional[str] = None, index: Optional[int] = None) -> Account: + """Get account by name (partial regex match) or index from the list of user accounts.""" + if name is None and index is None: + raise ValueError("Either 'name' or 'index' must be provided") + + accounts = self._fetch_user_accounts() + + if index is not None: + return Account(client=self, entity_cache=accounts[index]["entity"]) + + pattern = re.compile(name, re.IGNORECASE) + matches = [account for account in accounts if pattern.search(account["entity"].get("name", ""))] + + if not matches: + raise ValueError(f"No account found matching '{name}'") + if len(matches) > 1: + names = [acc["entity"].get("name", "") for acc in matches] + raise ValueError(f"Multiple accounts match '{name}': {names}") + + return Account(client=self, entity_cache=matches[0]["entity"]) + + def get_default_organization(self) -> Optional[Account]: + accounts = self._fetch_user_accounts() + organizations = [account for account in accounts if + account["entity"].get("type") in ("organization", "enterprise")] + + if not organizations: + return None + + default_org = next((org for org in organizations if org.get("isDefault")), organizations[0]) + return Account(client=self, entity_cache=default_org["entity"]) diff --git a/src/py/mat3ra/api_client/constants.py b/src/py/mat3ra/api_client/constants.py new file mode 100644 index 0000000..b22b5ff --- /dev/null +++ b/src/py/mat3ra/api_client/constants.py @@ -0,0 +1,20 @@ +# Environment Variable Names - Exported for external use (e.g., api-examples) +ACCESS_TOKEN_ENV_VAR = "OIDC_ACCESS_TOKEN" +ACCOUNT_ID_ENV_VAR = "ACCOUNT_ID" +AUTH_TOKEN_ENV_VAR = "AUTH_TOKEN" + +# OIDC Configuration - Exported for external use +CLIENT_ID = "cli-device-client" +SCOPE = "openid profile email" + + +def _build_base_url(host: str, port: int, secure: bool, path: str) -> str: + protocol = "https" if secure else "http" + port_str = f":{port}" if port not in (80, 443) else "" + return f"{protocol}://{host}{port_str}{path}" + + +def build_oidc_base_url(host: str, port: int, secure: bool) -> str: + """Used in api-examples utils.""" + return _build_base_url(host, port, secure, "/oidc") + diff --git a/src/py/mat3ra/api_client/endpoints/jobs.py b/src/py/mat3ra/api_client/endpoints/jobs.py index d5b703c..0001695 100644 --- a/src/py/mat3ra/api_client/endpoints/jobs.py +++ b/src/py/mat3ra/api_client/endpoints/jobs.py @@ -54,7 +54,8 @@ def terminate(self, id_): """ self.request("POST", "/".join((self.name, id_, "submit")), headers=self.headers) - def get_config(self, material_ids, workflow_id, project_id, owner_id, name, compute=None, is_multi_material=False): + def build_config(self, material_ids, workflow_id, project_id, owner_id, name, compute=None, + is_multi_material=False): """ Returns a job config based on the given parameters. @@ -85,7 +86,7 @@ def get_config(self, material_ids, workflow_id, project_id, owner_id, name, comp config.update({"_material": {"_id": material_ids[0]}}) return config - def get_compute(self, cluster, ppn=1, nodes=1, queue="D", time_limit="01:00:00", notify="abe"): + def build_compute_config(self, cluster, ppn=1, nodes=1, queue="D", time_limit="01:00:00", notify="abe"): """ Returns job compute configuration. @@ -128,7 +129,7 @@ def create_by_ids(self, materials, workflow_id, project_id, prefix, owner_id=Non jobs = [] for material in materials: job_name = " ".join((prefix, material["formula"])) - job_config = self.get_config([material["_id"]], workflow_id, project_id, owner_id, job_name, compute) + job_config = self.build_config([material["_id"]], workflow_id, project_id, owner_id, job_name, compute) jobs.append(self.create(job_config)) return jobs diff --git a/src/py/mat3ra/api_client/endpoints/properties.py b/src/py/mat3ra/api_client/endpoints/properties.py index d6208cb..6eaade0 100644 --- a/src/py/mat3ra/api_client/endpoints/properties.py +++ b/src/py/mat3ra/api_client/endpoints/properties.py @@ -3,11 +3,11 @@ class BasePropertiesEndpoints(EntityEndpoint): - def get_property_selector(self, job_id, unit_flowchart_id, property_name): + def build_property_selector(self, job_id, unit_flowchart_id, property_name): return {"source.info.jobId": job_id, "source.info.unitId": unit_flowchart_id, "data.name": property_name} def get_property(self, job_id, unit_flowchart_id, property_name): - selector = self.get_property_selector(job_id, unit_flowchart_id, property_name) + selector = self.build_property_selector(job_id, unit_flowchart_id, property_name) return self.list(query=selector)[0] def get_band_gap_by_type(self, job_id, unit_flowchart_id, type): diff --git a/src/py/mat3ra/api_client/models.py b/src/py/mat3ra/api_client/models.py new file mode 100644 index 0000000..e834ed8 --- /dev/null +++ b/src/py/mat3ra/api_client/models.py @@ -0,0 +1,83 @@ +import os +from typing import Any, Optional + +from pydantic import BaseModel, ConfigDict, Field + +from .constants import ACCESS_TOKEN_ENV_VAR, ACCOUNT_ID_ENV_VAR, AUTH_TOKEN_ENV_VAR + + +class AuthContext(BaseModel): + access_token: Optional[str] = None + account_id: Optional[str] = None + auth_token: Optional[str] = None + + +class APIEnv(BaseModel): + host: str = Field(default="platform-new.mat3ra.com", validation_alias="API_HOST") + port: int = Field(default=443, validation_alias="API_PORT") + version: str = Field(default="2018-10-01", validation_alias="API_VERSION") + secure: bool = Field(default=True, validation_alias="API_SECURE") + + @classmethod + def from_env(cls) -> "APIEnv": + return cls.model_validate(os.environ) + + +class AuthEnv(BaseModel): + access_token: Optional[str] = Field(None, validation_alias=ACCESS_TOKEN_ENV_VAR) + account_id: Optional[str] = Field(None, validation_alias=ACCOUNT_ID_ENV_VAR) + auth_token: Optional[str] = Field(None, validation_alias=AUTH_TOKEN_ENV_VAR) + + @classmethod + def from_env(cls) -> "AuthEnv": + return cls.model_validate(os.environ) + + +class Account(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True) + + client: Any = Field(exclude=True, repr=False) + entity_cache: Optional[dict] = None + + @property + def id(self) -> str: + if not self.entity_cache: + self._get_entity() + return self.entity_cache["_id"] + + @property + def name(self) -> str: + if not self.entity_cache: + self._get_entity() + return self.entity_cache.get("name", "") + + def _get_entity(self) -> None: + account_id, accounts = self._get_account_id_and_accounts() + self.entity_cache = self._find_account_entity(account_id, accounts) + + def _get_account_id_and_accounts(self) -> tuple[str, Optional[list]]: + account_id = self.client.auth.account_id or os.environ.get(ACCOUNT_ID_ENV_VAR) + + if account_id: + return account_id, None + + if not (self.client.auth.access_token or os.environ.get(ACCESS_TOKEN_ENV_VAR)): + raise ValueError("ACCOUNT_ID is not set and no OIDC access token is available.") + + data = self.client._fetch_data() + account_id = data["user"]["entity"]["defaultAccountId"] + os.environ[ACCOUNT_ID_ENV_VAR] = account_id + self.client.auth.account_id = account_id + return account_id, data.get("accounts", []) + + def _find_account_entity(self, account_id: str, accounts: Optional[list]) -> dict: + if accounts is None and (self.client.auth.access_token or os.environ.get(ACCESS_TOKEN_ENV_VAR)): + accounts = self.client._fetch_user_accounts() + + if accounts: + for account in accounts: + if account["entity"]["_id"] == account_id: + return account["entity"] + + return {"_id": account_id} + diff --git a/tests/py/unit/test_client.py b/tests/py/unit/test_client.py index 59306ad..b19aa41 100644 --- a/tests/py/unit/test_client.py +++ b/tests/py/unit/test_client.py @@ -17,6 +17,28 @@ ME_ACCOUNT_ID = "my-account-id" USERS_ME_RESPONSE = {"data": {"user": {"entity": {"defaultAccountId": ME_ACCOUNT_ID}}}} +ACCOUNTS_RESPONSE = { + "data": { + "user": { + "entity": {"defaultAccountId": ME_ACCOUNT_ID} + }, + "accounts": [ + { + "entity": {"_id": "user-acc-1", "name": "John Doe", "type": "personal"}, + "isDefault": True, + }, + { + "entity": {"_id": "org-acc-1", "name": "Acme Corp", "type": "enterprise"}, + "isDefault": True, + }, + { + "entity": {"_id": "org-acc-2", "name": "Beta Industries", "type": "organization"}, + "isDefault": False, + }, + ], + } +} + class APIClientUnitTest(EndpointBaseUnitTest): def _base_env(self): @@ -27,9 +49,9 @@ def _base_env(self): "API_SECURE": API_SECURE_FALSE, } - def _mock_users_me(self, mock_get): + def _mock_users_me(self, mock_get, response=None): mock_resp = mock.Mock() - mock_resp.json.return_value = USERS_ME_RESPONSE + mock_resp.json.return_value = response or USERS_ME_RESPONSE mock_resp.raise_for_status.return_value = None mock_get.return_value = mock_resp @@ -61,8 +83,16 @@ def test_my_account_id_uses_existing_account_id(self, mock_get): @mock.patch("requests.get") def test_my_account_id_fetches_and_caches(self, mock_get): env = self._base_env() | {"OIDC_ACCESS_TOKEN": OIDC_ACCESS_TOKEN} + response_with_account = { + "data": { + "user": {"entity": {"defaultAccountId": ME_ACCOUNT_ID}}, + "accounts": [ + {"entity": {"_id": ME_ACCOUNT_ID, "name": "Test User", "type": "personal"}, "isDefault": True} + ], + } + } with mock.patch.dict("os.environ", env, clear=True): - self._mock_users_me(mock_get) + self._mock_users_me(mock_get, response_with_account) client = APIClient.authenticate() self.assertEqual(client.my_account.id, ME_ACCOUNT_ID) self.assertEqual(client.my_account.id, ME_ACCOUNT_ID) @@ -71,3 +101,44 @@ def test_my_account_id_fetches_and_caches(self, mock_get): self.assertEqual(mock_get.call_args[1]["headers"]["Authorization"], f"Bearer {OIDC_ACCESS_TOKEN}") self.assertEqual(mock_get.call_args[1]["timeout"], 30) self.assertEqual(os.environ.get("ACCOUNT_ID"), ME_ACCOUNT_ID) + + @mock.patch("requests.get") + def test_list_accounts(self, mock_get): + env = self._base_env() | {"OIDC_ACCESS_TOKEN": OIDC_ACCESS_TOKEN} + with mock.patch.dict("os.environ", env, clear=True): + self._mock_users_me(mock_get, ACCOUNTS_RESPONSE) + client = APIClient.authenticate() + accounts = client.list_accounts() + + self.assertEqual(len(accounts), 3) + self.assertEqual(accounts[0]["_id"], "user-acc-1") + self.assertEqual(accounts[0]["name"], "John Doe") + self.assertEqual(accounts[0]["type"], "personal") + self.assertTrue(accounts[0]["isDefault"]) + self.assertEqual(accounts[1]["_id"], "org-acc-1") + self.assertEqual(accounts[1]["name"], "Acme Corp") + self.assertEqual(accounts[1]["type"], "enterprise") + + @mock.patch("requests.get") + def test_get_account(self, mock_get): + env = self._base_env() | {"OIDC_ACCESS_TOKEN": OIDC_ACCESS_TOKEN} + with mock.patch.dict("os.environ", env, clear=True): + self._mock_users_me(mock_get, ACCOUNTS_RESPONSE) + client = APIClient.authenticate() + + account = client.get_account(index=1) + self.assertEqual(account.id, "org-acc-1") + self.assertEqual(account.name, "Acme Corp") + + self.assertEqual(client.get_account(name="Acme").id, "org-acc-1") + self.assertEqual(client.get_account(name="Beta.*").id, "org-acc-2") + + @mock.patch("requests.get") + def test_my_organization(self, mock_get): + env = self._base_env() | {"OIDC_ACCESS_TOKEN": OIDC_ACCESS_TOKEN} + with mock.patch.dict("os.environ", env, clear=True): + self._mock_users_me(mock_get, ACCOUNTS_RESPONSE) + client = APIClient.authenticate() + org = client.my_organization + self.assertEqual(org.id, "org-acc-1") + self.assertEqual(org.name, "Acme Corp")