Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions src/py/mat3ra/api_client/__init__.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
# ruff: noqa: F401
try:
from ._version import version as __version__
except ModuleNotFoundError:
__version__ = None

from mat3ra.api_client.endpoints.bank_materials import BankMaterialEndpoints
from mat3ra.api_client.endpoints.bank_workflows import BankWorkflowEndpoints
from mat3ra.api_client.endpoints.jobs import JobEndpoints
from mat3ra.api_client.endpoints.login import LoginEndpoint
from mat3ra.api_client.endpoints.logout import LogoutEndpoint
from mat3ra.api_client.endpoints.materials import MaterialEndpoints
from mat3ra.api_client.endpoints.metaproperties import MetaPropertiesEndpoints
from mat3ra.api_client.endpoints.projects import ProjectEndpoints
from mat3ra.api_client.endpoints.properties import PropertiesEndpoints
from mat3ra.api_client.endpoints.workflows import WorkflowEndpoints

from mat3ra.api_client.client import APIClient
from .client import APIClient
from .constants import ACCESS_TOKEN_ENV_VAR, CLIENT_ID, SCOPE, build_oidc_base_url
from .models import Account, APIEnv, AuthContext, AuthEnv
from .endpoints.bank_materials import BankMaterialEndpoints
from .endpoints.bank_workflows import BankWorkflowEndpoints
from .endpoints.jobs import JobEndpoints
from .endpoints.login import LoginEndpoint
from .endpoints.logout import LogoutEndpoint
from .endpoints.materials import MaterialEndpoints
from .endpoints.metaproperties import MetaPropertiesEndpoints
from .endpoints.projects import ProjectEndpoints
from .endpoints.properties import PropertiesEndpoints
from .endpoints.workflows import WorkflowEndpoints
224 changes: 93 additions & 131 deletions src/py/mat3ra/api_client/client.py
Original file line number Diff line number Diff line change
@@ -1,90 +1,21 @@
import os
from typing import Any, Optional, Tuple
import re
from typing import Any, List, Optional, Tuple

import requests
from pydantic import BaseModel, ConfigDict, Field

from mat3ra.api_client.endpoints.bank_materials import BankMaterialEndpoints
from mat3ra.api_client.endpoints.bank_workflows import BankWorkflowEndpoints
from mat3ra.api_client.endpoints.jobs import JobEndpoints
from mat3ra.api_client.endpoints.materials import MaterialEndpoints
from mat3ra.api_client.endpoints.metaproperties import MetaPropertiesEndpoints
from mat3ra.api_client.endpoints.projects import ProjectEndpoints
from mat3ra.api_client.endpoints.properties import PropertiesEndpoints
from mat3ra.api_client.endpoints.workflows import WorkflowEndpoints

# Default API Configuration
DEFAULT_API_HOST = "platform-new.mat3ra.com"
DEFAULT_API_PORT = 443
DEFAULT_API_VERSION = "2018-10-01"
DEFAULT_API_SECURE = True

# Environment Variable Names
ACCESS_TOKEN_ENV_VAR = "OIDC_ACCESS_TOKEN"
API_HOST_ENV_VAR = "API_HOST"
API_PORT_ENV_VAR = "API_PORT"
API_VERSION_ENV_VAR = "API_VERSION"
API_SECURE_ENV_VAR = "API_SECURE"
ACCOUNT_ID_ENV_VAR = "ACCOUNT_ID"
AUTH_TOKEN_ENV_VAR = "AUTH_TOKEN"

# Default OIDC Configuration
CLIENT_ID = "cli-device-client"
SCOPE = "openid profile email"

# API Paths
USERS_ME_PATH = "/api/v1/users/me"


class AuthContext(BaseModel):
access_token: Optional[str] = None
account_id: Optional[str] = None
auth_token: Optional[str] = None


class APIEnv(BaseModel):
host: str = Field(default=DEFAULT_API_HOST, validation_alias=API_HOST_ENV_VAR)
port: int = Field(default=DEFAULT_API_PORT, validation_alias=API_PORT_ENV_VAR)
version: str = Field(default=DEFAULT_API_VERSION, validation_alias=API_VERSION_ENV_VAR)
secure: bool = Field(default=DEFAULT_API_SECURE, validation_alias=API_SECURE_ENV_VAR)
from pydantic import BaseModel, ConfigDict

@classmethod
def from_env(cls) -> "APIEnv":
return cls.model_validate(os.environ)


class AuthEnv(BaseModel):
access_token: Optional[str] = Field(None, validation_alias=ACCESS_TOKEN_ENV_VAR)
account_id: Optional[str] = Field(None, validation_alias=ACCOUNT_ID_ENV_VAR)
auth_token: Optional[str] = Field(None, validation_alias=AUTH_TOKEN_ENV_VAR)

@classmethod
def from_env(cls) -> "AuthEnv":
return cls.model_validate(os.environ)


class Account(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True)
from .constants import ACCESS_TOKEN_ENV_VAR, _build_base_url
from .endpoints.bank_materials import BankMaterialEndpoints
from .endpoints.bank_workflows import BankWorkflowEndpoints
from .endpoints.jobs import JobEndpoints
from .endpoints.materials import MaterialEndpoints
from .endpoints.metaproperties import MetaPropertiesEndpoints
from .endpoints.projects import ProjectEndpoints
from .endpoints.properties import PropertiesEndpoints
from .endpoints.workflows import WorkflowEndpoints
from .models import Account, APIEnv, AuthContext, AuthEnv

client: Any = Field(exclude=True, repr=False)
id_cache: Optional[str] = None

@property
def id(self) -> str:
if self.id_cache:
return self.id_cache
self.id_cache = self.client._resolve_account_id()
return self.id_cache


def _build_base_url(host: str, port: int, secure: bool, path: str) -> str:
protocol = "https" if secure else "http"
port_str = f":{port}" if port not in (80, 443) else ""
return f"{protocol}://{host}{port_str}{path}"

# Used in API-examples utils
def build_oidc_base_url(host: str, port: int, secure: bool) -> str:
return _build_base_url(host, port, secure, "/oidc")

class APIClient(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow", validate_assignment=True)
Expand All @@ -99,8 +30,15 @@ class APIClient(BaseModel):
def model_post_init(self, __context: Any) -> None:
self.my_account = Account(client=self)
self.account = self.my_account
self._my_organization: Optional[Account] = None
self._init_endpoints(self.timeout_seconds)

@property
def my_organization(self) -> Optional[Account]:
if self._my_organization is None:
self._my_organization = self.get_default_organization()
return self._my_organization

@classmethod
def env(cls) -> APIEnv:
return APIEnv.from_env()
Expand All @@ -110,39 +48,17 @@ def auth_env(cls) -> AuthEnv:
return AuthEnv.from_env()

def _init_endpoints(self, timeout_seconds: int) -> None:
kwargs = {"timeout": timeout_seconds, "auth": self.auth}
account_id = self.auth.account_id or ""
auth_token = self.auth.auth_token or ""
self._init_core_endpoints(kwargs, account_id, auth_token)
self._init_bank_endpoints(kwargs, account_id, auth_token)

def _init_core_endpoints(self, kwargs: dict, account_id: str, auth_token: str) -> None:
self.materials = MaterialEndpoints(
self.host, self.port, account_id, auth_token, version=self.version, secure=self.secure, **kwargs
)
self.workflows = WorkflowEndpoints(
self.host, self.port, account_id, auth_token, version=self.version, secure=self.secure, **kwargs
)
self.jobs = JobEndpoints(
self.host, self.port, account_id, auth_token, version=self.version, secure=self.secure, **kwargs
)
self.projects = ProjectEndpoints(
self.host, self.port, account_id, auth_token, version=self.version, secure=self.secure, **kwargs
)
self.properties = PropertiesEndpoints(
self.host, self.port, account_id, auth_token, version=self.version, secure=self.secure, **kwargs
)
self.metaproperties = MetaPropertiesEndpoints(
self.host, self.port, account_id, auth_token, version=self.version, secure=self.secure, **kwargs
)

def _init_bank_endpoints(self, kwargs: dict, account_id: str, auth_token: str) -> None:
self.bank_materials = BankMaterialEndpoints(
self.host, self.port, account_id, auth_token, version=self.version, secure=self.secure, **kwargs
)
self.bank_workflows = BankWorkflowEndpoints(
self.host, self.port, account_id, auth_token, version=self.version, secure=self.secure, **kwargs
)
base_args = (self.host, self.port, self.auth.account_id or "", self.auth.auth_token or "")
base_kwargs = {"version": self.version, "secure": self.secure, "timeout": timeout_seconds, "auth": self.auth}

self.materials = MaterialEndpoints(*base_args, **base_kwargs)
self.workflows = WorkflowEndpoints(*base_args, **base_kwargs)
self.jobs = JobEndpoints(*base_args, **base_kwargs)
self.projects = ProjectEndpoints(*base_args, **base_kwargs)
self.properties = PropertiesEndpoints(*base_args, **base_kwargs)
self.metaproperties = MetaPropertiesEndpoints(*base_args, **base_kwargs)
self.bank_materials = BankMaterialEndpoints(*base_args, **base_kwargs)
self.bank_workflows = BankWorkflowEndpoints(*base_args, **base_kwargs)

@staticmethod
def _resolve_config(
Expand Down Expand Up @@ -195,27 +111,73 @@ def authenticate(
auth_token: Optional[str] = None,
timeout_seconds: int = 60,
) -> "APIClient":
host_value, port_value, version_value, secure_value = cls._resolve_config(host, port, version, secure,
cls.env())
host_value, port_value, version_value, secure_value = cls._resolve_config(
host, port, version, secure, cls.env()
)
auth = cls._auth_from_env(access_token=access_token, account_id=account_id, auth_token=auth_token)
cls._validate_auth(auth)
return cls(host=host_value, port=port_value, version=version_value, secure=secure_value, auth=auth,
timeout_seconds=timeout_seconds)

def _resolve_account_id(self) -> str:
account_id = self.auth.account_id or os.environ.get(ACCOUNT_ID_ENV_VAR)
if account_id:
self.auth.account_id = account_id
return account_id
return cls(
host=host_value,
port=port_value,
version=version_value,
secure=secure_value,
auth=auth,
timeout_seconds=timeout_seconds,
)

def _fetch_data(self) -> dict:
access_token = self.auth.access_token or os.environ.get(ACCESS_TOKEN_ENV_VAR)
if not access_token:
raise ValueError("ACCOUNT_ID is not set and no OIDC access token is available.")
raise ValueError("Access token is required to fetch user data")

url = _build_base_url(self.host, self.port, self.secure, USERS_ME_PATH)
url = _build_base_url(self.host, self.port, self.secure, "/api/v1/users/me")
response = requests.get(url, headers={"Authorization": f"Bearer {access_token}"}, timeout=30)
response.raise_for_status()
account_id = response.json()["data"]["user"]["entity"]["defaultAccountId"]
os.environ[ACCOUNT_ID_ENV_VAR] = account_id
self.auth.account_id = account_id
return account_id
return response.json()["data"]

def _fetch_user_accounts(self) -> List[dict]:
return self._fetch_data().get("accounts", [])

def list_accounts(self) -> List[dict]:
accounts = self._fetch_user_accounts()
return [
{
"_id": account["entity"]["_id"],
"name": account["entity"].get("name", ""),
"type": account["entity"].get("type", "personal"),
"isDefault": account.get("isDefault", False),
}
for account in accounts
]

def get_account(self, name: Optional[str] = None, index: Optional[int] = None) -> Account:
"""Get account by name (partial regex match) or index from the list of user accounts."""
if name is None and index is None:
raise ValueError("Either 'name' or 'index' must be provided")

accounts = self._fetch_user_accounts()

if index is not None:
return Account(client=self, entity_cache=accounts[index]["entity"])

pattern = re.compile(name, re.IGNORECASE)
matches = [account for account in accounts if pattern.search(account["entity"].get("name", ""))]

if not matches:
raise ValueError(f"No account found matching '{name}'")
if len(matches) > 1:
names = [acc["entity"].get("name", "") for acc in matches]
raise ValueError(f"Multiple accounts match '{name}': {names}")

return Account(client=self, entity_cache=matches[0]["entity"])

def get_default_organization(self) -> Optional[Account]:
accounts = self._fetch_user_accounts()
organizations = [account for account in accounts if
account["entity"].get("type") in ("organization", "enterprise")]

if not organizations:
return None

default_org = next((org for org in organizations if org.get("isDefault")), organizations[0])
return Account(client=self, entity_cache=default_org["entity"])
20 changes: 20 additions & 0 deletions src/py/mat3ra/api_client/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Environment Variable Names - Exported for external use (e.g., api-examples)
ACCESS_TOKEN_ENV_VAR = "OIDC_ACCESS_TOKEN"
ACCOUNT_ID_ENV_VAR = "ACCOUNT_ID"
AUTH_TOKEN_ENV_VAR = "AUTH_TOKEN"

# OIDC Configuration - Exported for external use
CLIENT_ID = "cli-device-client"
SCOPE = "openid profile email"


def _build_base_url(host: str, port: int, secure: bool, path: str) -> str:
protocol = "https" if secure else "http"
port_str = f":{port}" if port not in (80, 443) else ""
return f"{protocol}://{host}{port_str}{path}"


def build_oidc_base_url(host: str, port: int, secure: bool) -> str:
"""Used in api-examples utils."""
return _build_base_url(host, port, secure, "/oidc")

7 changes: 4 additions & 3 deletions src/py/mat3ra/api_client/endpoints/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ def terminate(self, id_):
"""
self.request("POST", "/".join((self.name, id_, "submit")), headers=self.headers)

def get_config(self, material_ids, workflow_id, project_id, owner_id, name, compute=None, is_multi_material=False):
def build_config(self, material_ids, workflow_id, project_id, owner_id, name, compute=None,
is_multi_material=False):
"""
Returns a job config based on the given parameters.

Expand Down Expand Up @@ -85,7 +86,7 @@ def get_config(self, material_ids, workflow_id, project_id, owner_id, name, comp
config.update({"_material": {"_id": material_ids[0]}})
return config

def get_compute(self, cluster, ppn=1, nodes=1, queue="D", time_limit="01:00:00", notify="abe"):
def build_compute_config(self, cluster, ppn=1, nodes=1, queue="D", time_limit="01:00:00", notify="abe"):
"""
Returns job compute configuration.

Expand Down Expand Up @@ -128,7 +129,7 @@ def create_by_ids(self, materials, workflow_id, project_id, prefix, owner_id=Non
jobs = []
for material in materials:
job_name = " ".join((prefix, material["formula"]))
job_config = self.get_config([material["_id"]], workflow_id, project_id, owner_id, job_name, compute)
job_config = self.build_config([material["_id"]], workflow_id, project_id, owner_id, job_name, compute)
jobs.append(self.create(job_config))
return jobs

Expand Down
4 changes: 2 additions & 2 deletions src/py/mat3ra/api_client/endpoints/properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@


class BasePropertiesEndpoints(EntityEndpoint):
def get_property_selector(self, job_id, unit_flowchart_id, property_name):
def build_property_selector(self, job_id, unit_flowchart_id, property_name):
return {"source.info.jobId": job_id, "source.info.unitId": unit_flowchart_id, "data.name": property_name}

def get_property(self, job_id, unit_flowchart_id, property_name):
selector = self.get_property_selector(job_id, unit_flowchart_id, property_name)
selector = self.build_property_selector(job_id, unit_flowchart_id, property_name)
return self.list(query=selector)[0]

def get_band_gap_by_type(self, job_id, unit_flowchart_id, type):
Expand Down
Loading
Loading