diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 0000000..2f8faaf --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,26 @@ +name: tests + +on: + push: + branches: ["main"] + pull_request: + +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.8", "3.10", "3.12"] + + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install package and dev deps + run: pip install -e ".[dev]" + + - name: Run tests + run: pytest --tb=short diff --git a/README.md b/README.md index 07f9ded..8ada93d 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,35 @@ You can install the package using `pip`: pip install brainstem_python_api_tools ``` +## Authentication + +Authentication uses a **browser-based device authorization flow** that supports two-factor authentication (2FA). No credentials are entered into the tool itself. + +```python +from brainstem_api_tools import BrainstemClient + +# First run: opens a browser window for secure login. +# The token is cached at ~/.config/brainstem/token and reused automatically. +client = BrainstemClient() +``` + +To skip the browser flow, pass a token directly: +```python +client = BrainstemClient(token="YOUR_TOKEN") +``` + +For headless environments (no browser available): +```python +client = BrainstemClient(headless=True) # prints a URL + code to enter manually +``` + +To connect to a different server (e.g. a local development instance): +```python +client = BrainstemClient(url="http://127.0.0.1:8000/") +``` + +> **Security note:** The cached token file is stored with owner-read-only permissions (`0600`). Treat it like a password and do not commit it to version control. + ## Getting Started To get started with the BrainSTEM API tools, please refer to the tutorial script provided: @@ -15,21 +44,118 @@ To get started with the BrainSTEM API tools, please refer to the tutorial script The tutorial demonstrates how to: -- **Authenticate:** Load the client and authenticate using your credentials. -- **Loading Data:** Load sessions and filter data using flexible options. -- **Updating Entries:** Modify existing models and update them in the database. -- **Creating Entries:** Submit new data entries with required fields. -- **Loading Public Data:** Access public projects and data using the public portal. +- **Authenticate:** Load the client and authenticate via your browser. +- **Load Data:** Load sessions and filter data using flexible options. +- **Paginate:** Retrieve large datasets using `limit` / `offset` or `load_all=True`. +- **Convenience Loaders:** Use `load_session()`, `load_subject()`, etc. for common queries. +- **Update Entries:** Modify existing records and update them in the database. +- **Create Entries:** Submit new data entries with required fields. +- **Delete Entries:** Remove records by ID. +- **Load Public Data:** Access public projects and data using the public portal. ## Example Usage ```python from brainstem_api_tools import BrainstemClient client = BrainstemClient() -response = client.load_model('session', sort=['name']) + +# Load sessions sorted by name +response = client.load('session', sort=['name']) print(response.json()) + +# Filter and include related data +response = client.load( + 'session', + filters={'name.icontains': 'Rat'}, + sort=['-name'], + include=['projects'], +) + +# Paginate results (manual) +page2 = client.load('session', limit=20, offset=20).json() + +# Auto-paginate — fetches every page and returns a merged dict +all_sessions = client.load('session', load_all=True) + +# Convenience loaders — sensible defaults with named filter kwargs +# load_session embeds dataacquisition, behaviors, manipulations, epochs +sessions = client.load_session(name='Rat', load_all=True) + +# load_subject embeds procedures and subjectlogs +subjects = client.load_subject(sex='M', projects='', load_all=True) + +# load_project embeds sessions, subjects, collections, cohorts +projects = client.load_project(name='MyProject') + +# load_behavior / load_dataacquisition / load_manipulation scope by session UUID +behaviors = client.load_behavior(session='', load_all=True) + +# load_procedure scopes by subject UUID +procedures = client.load_procedure(subject='', load_all=True) + +# Create a record +client.save('session', data={'name': 'New session', 'projects': ['']}) + +# Update a record +client.save('session', id='', data={'description': 'updated'}) + +# Delete a record +client.delete('session', id='') +``` + +## Command-line Interface + +After installation a `brainstem` command is available in your shell. + +### Authentication +```bash +# Authenticate (opens browser) and cache token +brainstem login + +# Headless — prints URL + code instead of opening browser +brainstem login --headless + +# Connect to a local dev server +brainstem login --url http://127.0.0.1:8000/ + +# Remove cached token +brainstem logout +``` + +### Loading data +```bash +# Load all sessions (private portal) +brainstem load session + +# Filter, sort and embed related data +brainstem load session --filters name.icontains=Rat --sort -name --include projects + +# Load a single record by UUID +brainstem load session --id + +# Manual pagination +brainstem load session --limit 20 --offset 20 + +# Public portal +brainstem load project --portal public ``` +### Creating and updating records +```bash +# Create a new session +brainstem save session --data '{"name":"New session","projects":[""]}' + +# Update an existing record +brainstem save session --id --data '{"description":"updated"}' +``` + +### Deleting records +```bash +brainstem delete session --id +``` + +All subcommands accept `--token`, `--headless`, and `--url` to override defaults. + ## Contributing Contributions are welcome! Feel free to open issues or submit pull requests on GitHub. diff --git a/brainstem_api_tools/__init__.py b/brainstem_api_tools/__init__.py index 9fc63f9..7b19a80 100644 --- a/brainstem_api_tools/__init__.py +++ b/brainstem_api_tools/__init__.py @@ -1 +1,3 @@ -from .brainstem_api_client import BrainstemClient, ModelType, PortalType +from .brainstem_api_client import BrainstemClient, ModelType, PortalType, AuthenticationError + +__version__ = "2.0.0" diff --git a/brainstem_api_tools/brainstem_api_client.py b/brainstem_api_tools/brainstem_api_client.py index 1cb838a..da55b1f 100644 --- a/brainstem_api_tools/brainstem_api_client.py +++ b/brainstem_api_tools/brainstem_api_client.py @@ -1,43 +1,104 @@ import os -from getpass import getpass -import requests -from requests.models import Response -import json +import stat +import time +import webbrowser from enum import Enum +from pathlib import Path +from typing import Union - -class ModelType(Enum): - project = "project" - subject = "subject" - session = "session" - collection = "collection" - cohort = "cohort" - procedure = "procedure" - behavior = "behavior" - dataacquisition = "dataacquisition" - manipulation = "manipulation" - equipment = "equipment" - consumablestock = "consumablestock" - procedurelog = "procedurelog" - subjectlog = "subjectlog" - behavioralparadigm = "behavioralparadigm" - datastorage = "datastorage" - inventory = "inventory" - setup = "setup" - consumable = "consumable" - hardwaredevice = "hardwaredevice" - supplier = "supplier" - brainregion = "brainregion" - setuptype = "setuptype" - species = "species" - strain = "strain" - strainapproval = "strainapproval" - journal = "journal" - laboratory = "laboratory" - publication = "publication" - journalapproval = "journalapproval" - user = "user" - group = "group" +import requests +from requests.adapters import HTTPAdapter +from requests.models import Response +from urllib3.util.retry import Retry + + +# --------------------------------------------------------------------------- +# Model → app routing table (single source of truth) +# --------------------------------------------------------------------------- + +_MODEL_TO_APP: dict = { + # stem + "project": "stem", + "subject": "stem", + "session": "stem", + "collection": "stem", + "cohort": "stem", + "breeding": "stem", + "project_membership_invitation": "stem", + "project_group_membership_invitation": "stem", + # modules + "procedure": "modules", + "behavior": "modules", + "dataacquisition": "modules", + "manipulation": "modules", + "equipment": "modules", + "consumablestock": "modules", + "procedurelog": "modules", + "subjectlog": "modules", + # personal_attributes + "behavioralassay": "personal_attributes", + "datastorage": "personal_attributes", + "inventory": "personal_attributes", + "license": "personal_attributes", + "protocol": "personal_attributes", + "setup": "personal_attributes", + # resources + "consumable": "resources", + "hardwaredevice": "resources", + "supplier": "resources", + # taxonomies + "behavioralcategory": "taxonomies", + "behavioralparadigm": "taxonomies", + "brainregion": "taxonomies", + "regulatoryauthority": "taxonomies", + "setuptype": "taxonomies", + "species": "taxonomies", + "strain": "taxonomies", + "strainapproval": "taxonomies", + # dissemination + "journal": "dissemination", + "journalapproval": "dissemination", + "publication": "dissemination", + # users + "group": "users", + "group_membership_invitation": "users", + "group_membership_request": "users", + "laboratory": "users", + "user": "users", +} + +_TOKEN_FILE = Path.home() / ".config" / "brainstem" / "token" + +# Derived from _MODEL_TO_APP so there is exactly one source of truth. +ModelType = Enum("ModelType", {k: k for k in _MODEL_TO_APP}) # type: ignore[misc] + +_VALID_PORTALS = {"private", "public", "super"} + + +def _resolve_model(model) -> str: + """Accept a plain string or a ModelType member and return the string value.""" + if isinstance(model, ModelType): + return model.value + model = str(model) + if model not in _MODEL_TO_APP: + raise ValueError( + f"Unknown model '{model}'. Valid models are: " + + ", ".join(sorted(_MODEL_TO_APP)) + ) + return model + + +def _resolve_portal(portal) -> str: + """Accept a plain string or a PortalType member and return the string value.""" + if isinstance(portal, Enum): + portal = portal.value + portal = str(portal) + if portal not in _VALID_PORTALS: + raise ValueError( + f"Unknown portal '{portal}'. Valid portals are: " + + ", ".join(sorted(_VALID_PORTALS)) + ) + return portal class PortalType(Enum): @@ -46,202 +107,789 @@ class PortalType(Enum): super = "super" -class LoginError(Exception): - def __str__(self): - return f'User/password combination incorrect or user does not exist' +# --------------------------------------------------------------------------- +# Exceptions +# --------------------------------------------------------------------------- +class AuthenticationError(Exception): + pass + + +# --------------------------------------------------------------------------- +# Client +# --------------------------------------------------------------------------- class BrainstemClient: - def __init__(self, - token: str = None) -> None: + BASE_URL = "https://www.brainstem.org/" + DEFAULT_TIMEOUT: int = 30 # seconds; applied to all HTTP calls + + def __init__( + self, + token: str = None, + headless: bool = False, + url: str = None, + ) -> None: + base = (url.rstrip("/") + "/") if url else self.BASE_URL + self._address = base + "api/" + self._session = requests.Session() - # Server path - self._address = 'https://www.brainstem.org/api/' + # Automatically retry transient server errors + _retry = Retry(total=3, backoff_factor=0.5, status_forcelist={502, 503, 504}) + self._session.mount("https://", HTTPAdapter(max_retries=_retry)) + self._session.mount("http://", HTTPAdapter(max_retries=_retry)) if token: self._token = token else: - username = input("Please enter your username/email:") - password = getpass("Please enter your password:") - - self._token = self.__set_token_authentication( - url=self._address + "token/", - username=username, - password=password, + self._token = self._load_cached_token() + if not self._token: + self._token = self._device_auth_flow(headless=headless) + self._save_token(self._token) + + self._session.headers.update({"Authorization": f"Bearer {self._token}"}) + + # ------------------------------------------------------------------ + # Token management + # ------------------------------------------------------------------ + + def _load_cached_token(self) -> str: + """Return the cached token from disk, or None if not present.""" + if _TOKEN_FILE.exists(): + return _TOKEN_FILE.read_text().strip() or None + return None + + def _save_token(self, token: str) -> None: + """Persist the token to disk with owner-read-only permissions.""" + _TOKEN_FILE.parent.mkdir(parents=True, exist_ok=True) + _TOKEN_FILE.write_text(token) + _TOKEN_FILE.chmod(stat.S_IRUSR | stat.S_IWUSR) # 0o600 + + def _device_auth_flow(self, headless: bool = False, max_wait: int = 900) -> str: + """Run the BrainSTEM device authorization flow and return a token. + + The user authenticates in their browser (supports 2FA). No + credentials are sent by this tool. + + Parameters + ---------- + headless : If ``True``, print the verification URI and user code + instead of opening a browser window. + max_wait : Maximum seconds to wait for browser approval (default: 900). + """ + # Step 1 — initiate a device session + resp = requests.post( + self._address + "auth/device/", + json={"client_name": "brainstem-python"}, + timeout=10, + ) + resp.raise_for_status() + data = resp.json() + + device_code = data["device_code"] + interval = int(data.get("interval", 5)) + + # Step 2 — open browser or print instructions + if headless: + print(f"Open {data['verification_uri']} and enter: {data['user_code']}") + else: + webbrowser.open(data["verification_uri_complete"]) + print("Waiting for browser approval...") + + # Step 3 — poll until resolved (timeout after max_wait seconds) + deadline = time.monotonic() + max_wait + while True: + if time.monotonic() > deadline: + raise AuthenticationError( + f"Device authorization timed out after {max_wait} seconds." + ) + time.sleep(interval) + poll = requests.post( + self._address + "auth/device/token/", + json={"device_code": device_code}, + timeout=10, ) - - print("Your authorization token is:\n", self._token) - print("\nPlease keep it in a safe place.") - - def __set_token_authentication(self, - url: str, - username: str, - password: str) -> str: - - headers = { - "accept": "application/json", - "Content-Type": "application/json" - } - params = { - "username": username, - "password": password - } - resp = requests.post(url, headers=headers, json=params) - - if resp.status_code != 200: - raise LoginError() - - token = json.loads(resp.text)['token'] - return token - - def __get_app_from_model(self, modelname: str) -> str: - app = None - - if modelname in ['project', 'subject', 'session', 'collection', 'cohort', - 'projectmembershipinvitation', - 'projectgroupmembershipinvitation']: - app = 'stem' - - elif modelname in ['procedure', 'behavior', 'dataacquisition', 'manipulation', - 'equipment', 'consumablestock', 'procedurelog', 'subjectlog']: - app = 'modules' - - elif modelname in ['behavioralparadigm', 'datastorage', 'inventory', - 'setup']: - app = 'personal_attributes' - - elif any([x in modelname for x in - ['consumable', 'hardwaredevice', 'supplier']]): - app = 'resources' - - elif any([x in modelname for x in - ['brainregion', 'setuptype', - 'species', 'strain']]): - app = 'taxonomies' - - elif any([x in modelname for x in - ['journal', 'laboratory', 'publication']]): - app = 'attributes' - - elif modelname in ['user', 'group', 'groupmembershipinvitation', - 'groupmembershiprequest']: - app = 'users' - - return app - - def load_model(self, - model: ModelType, - portal: PortalType = "private", - id: str = None, - options: str = None, - filters: dict = None, - sort: list = None, - include: list = None) -> Response: - - app = self.__get_app_from_model(model) - if app is None: - resp = Response() - resp.status_code = 404 - return resp - - if options is None: - options = "" - + poll.raise_for_status() + result = poll.json() + + if result.get("status") == "success": + return result["token"] + elif result.get("status") == "authorization_pending": + continue + else: + raise AuthenticationError( + f"Authorization failed: {result.get('error')}" + ) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _build_url(self, portal: str, app: str, model: str, + id: str = None, options: str = None) -> str: + url = f"{self._address}{portal}/{app}/{model}/" if id: - query_parameters = id + "/" + options - else: - query_parameters = "" - - if filters: - for key in filters.keys(): - if query_parameters == "": - prefix = "?" - else: - prefix = "&" - query_parameters += (prefix - + "filter{" + key + "}=" - + filters[key]) - if sort: - for elem in sort: - if query_parameters == "": - prefix = "?" - else: - prefix = "&" - query_parameters += prefix + "sort[]=" + elem - - if include: - for elem in include: - if query_parameters == "": - prefix = "?" - else: - prefix = "&" - query_parameters += prefix + "include[]=" + elem + ".*" - - request_url = (self._address + portal - + "/" + app + "/" + model - + "/" + query_parameters) - - resp = requests.get(request_url, - headers={"Authorization": "Bearer %s" - % self._token}) - return resp - - def save_model(self, - model: ModelType, - portal: PortalType = "private", - id: str = None, - data: dict = None, - options: str = None) -> Response: - - app = self.__get_app_from_model(model) - if app is None: - resp = Response() - resp.status_code = 404 - return resp + url += f"{id}/" + if options: + url += options + return url + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def load(self, + model, + portal="private", + id: str = None, + options: str = None, + filters: dict = None, + sort: list = None, + include: list = None, + limit: int = None, + offset: int = None, + load_all: bool = False) -> Union[Response, dict]: + """Load one or more records of *model*. + + Parameters + ---------- + model : Model name string or ``ModelType`` member, e.g. ``'session'``. + portal : ``'private'`` (default) or ``'public'``, or ``PortalType`` member. + id : UUID of a specific record. Returns a single object when set. + filters : Dict of field filters, e.g. ``{'name.icontains': 'rat'}``. + Supports ``.icontains``, ``.startswith``, ``.endswith``, + ``.gt``, ``.gte``, ``.lt``, ``.lte``. + sort : List of fields to sort by. Prefix with ``'-'`` for descending. + include : Related models to embed, e.g. ``['dataacquisition', 'behaviors']``. + limit : Max records per page (API maximum: 100). + offset : Number of records to skip (for pagination). + load_all : When ``True``, automatically follow pagination and return a + combined ``dict`` with all records merged under the model key. + When ``False`` (default), returns the raw ``Response``. + """ + model = _resolve_model(model) + portal = _resolve_portal(portal) + app = _MODEL_TO_APP[model] + url = self._build_url(portal, app, model, id, options) + + params = {} + if not id: + for key, val in (filters or {}).items(): + params[f"filter{{{key}}}"] = val + for field in (sort or []): + params.setdefault("sort[]", []).append(field) + for rel in (include or []): + params.setdefault("include[]", []).append(f"{rel}.*") + if limit is not None: + params["limit"] = limit + if offset is not None: + params["offset"] = offset + + if not load_all: + return self._session.get(url, params=params, timeout=self.DEFAULT_TIMEOUT) + if id is not None: + raise ValueError("load_all=True cannot be used together with id.") + + # --- auto-paginate and merge all pages --- + page_size = limit or 100 + params["limit"] = page_size + params.setdefault("offset", offset or 0) + + combined: dict = {} + records_key: str = None + + while True: + resp = self._session.get(url, params=params, timeout=self.DEFAULT_TIMEOUT) + if resp.status_code == 401: + raise AuthenticationError( + "API token is invalid or expired. Run `brainstem login` to re-authenticate." + ) + resp.raise_for_status() + data = resp.json() + + if records_key is None: + # Detect the list key (e.g. 'sessions', 'projects') + records_key = next( + (k for k, v in data.items() if isinstance(v, list)), None + ) + if records_key is None: + raise ValueError( + "load_all=True requires a paginated list response, but the API " + "returned no list-valued key. Use load_all=False for single-object endpoints." + ) + combined = {k: v for k, v in data.items() if k != records_key} + combined[records_key] = [] + + combined[records_key].extend(data.get(records_key, [])) + total = data.get("count", len(combined[records_key])) + + if len(combined[records_key]) >= total: + break + + params["offset"] = params.get("offset", 0) + page_size + + return combined + + def save(self, + model, + portal="private", + id: str = None, + data: dict = None, + options: str = None) -> Response: + """Create or update a record. + + When *id* is provided the record is **updated** (PATCH). + When *id* is omitted a **new record** is created (POST). + + Parameters + ---------- + model : Model name string or ``ModelType`` member, e.g. ``'session'``. + portal : ``'private'`` (default) or ``'public'``, or ``PortalType`` member. + id : UUID of the record to update. Omit to create a new record. + data : Dict of fields to submit. + """ + model = _resolve_model(model) + portal = _resolve_portal(portal) + app = _MODEL_TO_APP[model] if data is None: data = {} - if options is None: - options = "" - - # Check if entry already exists if id is not None: - # This is an update request - request_url = (self._address + portal + "/" + app + "/" + model - + "/" + id + "/" + options) - resp = requests.patch(request_url, json=data, - headers={"Authorization": "Bearer %s" - % self._token}) - + url = self._build_url(portal, app, model, id, options) + return self._session.patch(url, json=data, timeout=self.DEFAULT_TIMEOUT) else: - # This is a create request - request_url = (self._address + portal + "/" + app + "/" - + model + "/" + options) - resp = requests.post(request_url, json=data, - headers={"Authorization": "Bearer %s" - % self._token}) - - return resp - - def delete_model(self, - model: ModelType, - portal: PortalType = "private", - id: str = None): - - app = self.__get_app_from_model(model) - if app is None: - resp = Response() - resp.status_code = 404 - return resp - - # Check if entry already exists - if id is not None: - request_url = (self._address + portal + "/" + app + "/" + model - + "/" + id + "/") - resp = requests.delete(request_url, - headers={"Authorization": "Bearer %s" - % self._token}) - - return resp + url = self._build_url(portal, app, model, options=options) + return self._session.post(url, json=data, timeout=self.DEFAULT_TIMEOUT) + + def delete(self, + model, + portal="private", + id: str = None) -> Response: + """Delete a record by ID. + + Parameters + ---------- + model : Model name string or ``ModelType`` member, e.g. ``'session'``. + portal : ``'private'`` (default) or ``'public'``, or ``PortalType`` member. + id : UUID of the record to delete (required). + """ + if id is None: + raise ValueError("'id' is required to delete a record.") + model = _resolve_model(model) + portal = _resolve_portal(portal) + app = _MODEL_TO_APP[model] + url = self._build_url(portal, app, model, id) + return self._session.delete(url, timeout=self.DEFAULT_TIMEOUT) + + def __enter__(self): + return self + + def __exit__(self, *args): + self._session.close() + + # ------------------------------------------------------------------ + # Convenience loaders (mirror the MATLAB load_* helpers) + # ------------------------------------------------------------------ + + def _convenience_load(self, model, default_include, filter_map, + portal, id, filters, sort, include, limit, offset, + load_all, **field_kwargs): + """Shared implementation for all convenience loaders.""" + unknown = set(field_kwargs) - set(filter_map) + if unknown: + raise TypeError( + f"Unexpected keyword argument(s): {', '.join(sorted(unknown))}" + ) + merged_filters = dict(filters or {}) + for kwarg, api_field in filter_map.items(): + value = field_kwargs.get(kwarg) + if value is not None: + merged_filters[api_field] = value + return self.load( + model, + portal=portal, + id=id, + filters=merged_filters or None, + sort=sort, + include=include if include is not None else default_include, + limit=limit, + offset=offset, + load_all=load_all, + ) + + def load_project(self, portal="private", id: str = None, + name: str = None, sessions: str = None, + subjects: str = None, tags: str = None, + filters: dict = None, sort: list = None, + include: list = None, limit: int = None, + offset: int = None, load_all: bool = False): + """Load project(s). Embeds sessions, subjects, collections and cohorts by default. + + Parameters + ---------- + name : Filter by project name (case-insensitive contains). + sessions : Filter by session UUID. + subjects : Filter by subject UUID. + tags : Filter by tag. + """ + return self._convenience_load( + "project", + default_include=["sessions", "subjects", "collections", "cohorts"], + filter_map={"name": "name.icontains", "sessions": "sessions.id", + "subjects": "subjects.id", "tags": "tags"}, + portal=portal, id=id, filters=filters, sort=sort, include=include, + limit=limit, offset=offset, load_all=load_all, + name=name, sessions=sessions, subjects=subjects, tags=tags, + ) + + def load_subject(self, portal="private", id: str = None, + name: str = None, projects: str = None, + strain: str = None, sex: str = None, tags: str = None, + filters: dict = None, sort: list = None, + include: list = None, limit: int = None, + offset: int = None, load_all: bool = False): + """Load subject(s). Embeds procedures and subjectlogs by default. + + Parameters + ---------- + name : Filter by subject name (case-insensitive contains). + projects : Filter by project UUID. + strain : Filter by strain UUID. + sex : Filter by sex (``'M'``, ``'F'``, or ``'U'``). + tags : Filter by tag. + """ + return self._convenience_load( + "subject", + default_include=["procedures", "subjectlogs"], + filter_map={"name": "name.icontains", "projects": "projects.id", + "strain": "strain.id", "sex": "sex", "tags": "tags"}, + portal=portal, id=id, filters=filters, sort=sort, include=include, + limit=limit, offset=offset, load_all=load_all, + name=name, projects=projects, strain=strain, sex=sex, tags=tags, + ) + + def load_session(self, portal="private", id: str = None, + name: str = None, projects: str = None, + datastorage: str = None, tags: str = None, + filters: dict = None, sort: list = None, + include: list = None, limit: int = None, + offset: int = None, load_all: bool = False): + """Load session(s). Embeds dataacquisition, behaviors, manipulations and epochs by default. + + Parameters + ---------- + name : Filter by session name (case-insensitive contains). + projects : Filter by project UUID. + datastorage : Filter by data storage UUID. + tags : Filter by tag. + """ + return self._convenience_load( + "session", + default_include=["dataacquisition", "behaviors", "manipulations", "epochs"], + filter_map={"name": "name.icontains", "projects": "projects.id", + "datastorage": "datastorage.id", "tags": "tags"}, + portal=portal, id=id, filters=filters, sort=sort, include=include, + limit=limit, offset=offset, load_all=load_all, + name=name, projects=projects, datastorage=datastorage, tags=tags, + ) + + def load_collection(self, portal="private", id: str = None, + name: str = None, tags: str = None, + filters: dict = None, sort: list = None, + include: list = None, limit: int = None, + offset: int = None, load_all: bool = False): + """Load collection(s). Embeds sessions by default. + + Parameters + ---------- + name : Filter by collection name (case-insensitive contains). + tags : Filter by tag. + """ + return self._convenience_load( + "collection", + default_include=["sessions"], + filter_map={"name": "name.icontains", "tags": "tags"}, + portal=portal, id=id, filters=filters, sort=sort, include=include, + limit=limit, offset=offset, load_all=load_all, + name=name, tags=tags, + ) + + def load_cohort(self, portal="private", id: str = None, + name: str = None, tags: str = None, + filters: dict = None, sort: list = None, + include: list = None, limit: int = None, + offset: int = None, load_all: bool = False): + """Load cohort(s). Embeds subjects by default. + + Parameters + ---------- + name : Filter by cohort name (case-insensitive contains). + tags : Filter by tag. + """ + return self._convenience_load( + "cohort", + default_include=["subjects"], + filter_map={"name": "name.icontains", "tags": "tags"}, + portal=portal, id=id, filters=filters, sort=sort, include=include, + limit=limit, offset=offset, load_all=load_all, + name=name, tags=tags, + ) + + def load_behavior(self, portal="private", id: str = None, + session: str = None, tags: str = None, + filters: dict = None, sort: list = None, + include: list = None, limit: int = None, + offset: int = None, load_all: bool = False): + """Load behavior record(s). + + Parameters + ---------- + session : Filter by session UUID. + tags : Filter by tag. + """ + return self._convenience_load( + "behavior", + default_include=[], + filter_map={"session": "session.id", "tags": "tags"}, + portal=portal, id=id, filters=filters, sort=sort, include=include, + limit=limit, offset=offset, load_all=load_all, + session=session, tags=tags, + ) + + def load_dataacquisition(self, portal="private", id: str = None, + session: str = None, tags: str = None, + filters: dict = None, sort: list = None, + include: list = None, limit: int = None, + offset: int = None, load_all: bool = False): + """Load data acquisition record(s). + + Parameters + ---------- + session : Filter by session UUID. + tags : Filter by tag. + """ + return self._convenience_load( + "dataacquisition", + default_include=[], + filter_map={"session": "session.id", "tags": "tags"}, + portal=portal, id=id, filters=filters, sort=sort, include=include, + limit=limit, offset=offset, load_all=load_all, + session=session, tags=tags, + ) + + def load_manipulation(self, portal="private", id: str = None, + session: str = None, tags: str = None, + filters: dict = None, sort: list = None, + include: list = None, limit: int = None, + offset: int = None, load_all: bool = False): + """Load manipulation record(s). + + Parameters + ---------- + session : Filter by session UUID. + tags : Filter by tag. + """ + return self._convenience_load( + "manipulation", + default_include=[], + filter_map={"session": "session.id", "tags": "tags"}, + portal=portal, id=id, filters=filters, sort=sort, include=include, + limit=limit, offset=offset, load_all=load_all, + session=session, tags=tags, + ) + + def load_procedure(self, portal="private", id: str = None, + subject: str = None, tags: str = None, + filters: dict = None, sort: list = None, + include: list = None, limit: int = None, + offset: int = None, load_all: bool = False): + """Load procedure record(s). + + Parameters + ---------- + subject : Filter by subject UUID. + tags : Filter by tag. + """ + return self._convenience_load( + "procedure", + default_include=[], + filter_map={"subject": "subject.id", "tags": "tags"}, + portal=portal, id=id, filters=filters, sort=sort, include=include, + limit=limit, offset=offset, load_all=load_all, + subject=subject, tags=tags, + ) + + def load_procedurelog(self, portal="private", id: str = None, + procedure: str = None, tags: str = None, + filters: dict = None, sort: list = None, + include: list = None, limit: int = None, + offset: int = None, load_all: bool = False): + """Load procedure log record(s). + + Parameters + ---------- + procedure : Filter by procedure UUID. + tags : Filter by tag. + """ + return self._convenience_load( + "procedurelog", + default_include=[], + filter_map={"procedure": "procedure.id", "tags": "tags"}, + portal=portal, id=id, filters=filters, sort=sort, include=include, + limit=limit, offset=offset, load_all=load_all, + procedure=procedure, tags=tags, + ) + + def load_subjectlog(self, portal="private", id: str = None, + subject: str = None, tags: str = None, + filters: dict = None, sort: list = None, + include: list = None, limit: int = None, + offset: int = None, load_all: bool = False): + """Load subject log record(s). + + Parameters + ---------- + subject : Filter by subject UUID. + tags : Filter by tag. + """ + return self._convenience_load( + "subjectlog", + default_include=[], + filter_map={"subject": "subject.id", "tags": "tags"}, + portal=portal, id=id, filters=filters, sort=sort, include=include, + limit=limit, offset=offset, load_all=load_all, + subject=subject, tags=tags, + ) + + def load_equipment(self, portal="private", id: str = None, + name: str = None, setup: str = None, tags: str = None, + filters: dict = None, sort: list = None, + include: list = None, limit: int = None, + offset: int = None, load_all: bool = False): + """Load equipment record(s). + + Parameters + ---------- + name : Filter by equipment name (case-insensitive contains). + setup : Filter by setup UUID. + tags : Filter by tag. + """ + return self._convenience_load( + "equipment", + default_include=[], + filter_map={"name": "name.icontains", "setup": "setup.id", "tags": "tags"}, + portal=portal, id=id, filters=filters, sort=sort, include=include, + limit=limit, offset=offset, load_all=load_all, + name=name, setup=setup, tags=tags, + ) + + def load_consumablestock(self, portal="private", id: str = None, + tags: str = None, + filters: dict = None, sort: list = None, + include: list = None, limit: int = None, + offset: int = None, load_all: bool = False): + """Load consumable stock record(s). + + Parameters + ---------- + tags : Filter by tag. + """ + return self._convenience_load( + "consumablestock", + default_include=[], + filter_map={"tags": "tags"}, + portal=portal, id=id, filters=filters, sort=sort, include=include, + limit=limit, offset=offset, load_all=load_all, + tags=tags, + ) + + def load_behavioralassay(self, portal="private", id: str = None, + name: str = None, tags: str = None, + filters: dict = None, sort: list = None, + include: list = None, limit: int = None, + offset: int = None, load_all: bool = False): + """Load behavioral assay record(s). + + Parameters + ---------- + name : Filter by name (case-insensitive contains). + tags : Filter by tag. + """ + return self._convenience_load( + "behavioralassay", + default_include=[], + filter_map={"name": "name.icontains", "tags": "tags"}, + portal=portal, id=id, filters=filters, sort=sort, include=include, + limit=limit, offset=offset, load_all=load_all, + name=name, tags=tags, + ) + + def load_datastorage(self, portal="private", id: str = None, + name: str = None, tags: str = None, + filters: dict = None, sort: list = None, + include: list = None, limit: int = None, + offset: int = None, load_all: bool = False): + """Load data storage record(s). + + Parameters + ---------- + name : Filter by name (case-insensitive contains). + tags : Filter by tag. + """ + return self._convenience_load( + "datastorage", + default_include=[], + filter_map={"name": "name.icontains", "tags": "tags"}, + portal=portal, id=id, filters=filters, sort=sort, include=include, + limit=limit, offset=offset, load_all=load_all, + name=name, tags=tags, + ) + + def load_setup(self, portal="private", id: str = None, + name: str = None, tags: str = None, + filters: dict = None, sort: list = None, + include: list = None, limit: int = None, + offset: int = None, load_all: bool = False): + """Load setup record(s). + + Parameters + ---------- + name : Filter by name (case-insensitive contains). + tags : Filter by tag. + """ + return self._convenience_load( + "setup", + default_include=[], + filter_map={"name": "name.icontains", "tags": "tags"}, + portal=portal, id=id, filters=filters, sort=sort, include=include, + limit=limit, offset=offset, load_all=load_all, + name=name, tags=tags, + ) + + def load_hardwaredevice(self, portal="private", id: str = None, + name: str = None, tags: str = None, + filters: dict = None, sort: list = None, + include: list = None, limit: int = None, + offset: int = None, load_all: bool = False): + """Load hardware device record(s). + + Parameters + ---------- + name : Filter by name (case-insensitive contains). + tags : Filter by tag. + """ + return self._convenience_load( + "hardwaredevice", + default_include=[], + filter_map={"name": "name.icontains", "tags": "tags"}, + portal=portal, id=id, filters=filters, sort=sort, include=include, + limit=limit, offset=offset, load_all=load_all, + name=name, tags=tags, + ) + + def load_brainregion(self, portal="private", id: str = None, + name: str = None, tags: str = None, + filters: dict = None, sort: list = None, + include: list = None, limit: int = None, + offset: int = None, load_all: bool = False): + """Load brain region record(s). + + Parameters + ---------- + name : Filter by name (case-insensitive contains). + tags : Filter by tag. + """ + return self._convenience_load( + "brainregion", + default_include=[], + filter_map={"name": "name.icontains", "tags": "tags"}, + portal=portal, id=id, filters=filters, sort=sort, include=include, + limit=limit, offset=offset, load_all=load_all, + name=name, tags=tags, + ) + + def load_species(self, portal="private", id: str = None, + name: str = None, tags: str = None, + filters: dict = None, sort: list = None, + include: list = None, limit: int = None, + offset: int = None, load_all: bool = False): + """Load species record(s). + + Parameters + ---------- + name : Filter by name (case-insensitive contains). + tags : Filter by tag. + """ + return self._convenience_load( + "species", + default_include=[], + filter_map={"name": "name.icontains", "tags": "tags"}, + portal=portal, id=id, filters=filters, sort=sort, include=include, + limit=limit, offset=offset, load_all=load_all, + name=name, tags=tags, + ) + + def load_strain(self, portal="private", id: str = None, + name: str = None, species: str = None, tags: str = None, + filters: dict = None, sort: list = None, + include: list = None, limit: int = None, + offset: int = None, load_all: bool = False): + """Load strain record(s). + + Parameters + ---------- + name : Filter by name (case-insensitive contains). + species : Filter by species UUID. + tags : Filter by tag. + """ + return self._convenience_load( + "strain", + default_include=[], + filter_map={"name": "name.icontains", "species": "species.id", "tags": "tags"}, + portal=portal, id=id, filters=filters, sort=sort, include=include, + limit=limit, offset=offset, load_all=load_all, + name=name, species=species, tags=tags, + ) + + def load_publication(self, portal="private", id: str = None, + name: str = None, tags: str = None, + filters: dict = None, sort: list = None, + include: list = None, limit: int = None, + offset: int = None, load_all: bool = False): + """Load publication record(s). + + Parameters + ---------- + name : Filter by name (case-insensitive contains). + tags : Filter by tag. + """ + return self._convenience_load( + "publication", + default_include=[], + filter_map={"name": "name.icontains", "tags": "tags"}, + portal=portal, id=id, filters=filters, sort=sort, include=include, + limit=limit, offset=offset, load_all=load_all, + name=name, tags=tags, + ) + + def load_laboratory(self, portal="private", id: str = None, + name: str = None, tags: str = None, + filters: dict = None, sort: list = None, + include: list = None, limit: int = None, + offset: int = None, load_all: bool = False): + """Load laboratory record(s). + + Parameters + ---------- + name : Filter by name (case-insensitive contains). + tags : Filter by tag. + """ + return self._convenience_load( + "laboratory", + default_include=[], + filter_map={"name": "name.icontains", "tags": "tags"}, + portal=portal, id=id, filters=filters, sort=sort, include=include, + limit=limit, offset=offset, load_all=load_all, + name=name, tags=tags, + ) diff --git a/brainstem_api_tools/cli.py b/brainstem_api_tools/cli.py new file mode 100644 index 0000000..043b5d1 --- /dev/null +++ b/brainstem_api_tools/cli.py @@ -0,0 +1,180 @@ +"""Command-line interface for BrainSTEM API tools. + +Usage examples +-------------- + brainstem login + brainstem login --headless + brainstem login --url http://127.0.0.1:8000/ + brainstem logout + brainstem load session + brainstem load session --portal public --filters name.icontains=rat --sort -name + brainstem load session --id + brainstem load session --limit 20 --offset 40 + brainstem save session --data '{"name":"New","projects":[""]}' + brainstem save session --id --data '{"description":"updated"}' + brainstem delete session --id +""" + +import argparse +import json +import os +import sys + +from .brainstem_api_client import BrainstemClient, _MODEL_TO_APP, _TOKEN_FILE + + +def _build_parser() -> argparse.ArgumentParser: + # Shared flags available on every subcommand + common = argparse.ArgumentParser(add_help=False) + common.add_argument( + "--token", + default=os.environ.get("BRAINSTEM_API_TOKEN"), + help="API token. Defaults to $BRAINSTEM_API_TOKEN or the cached token.", + ) + common.add_argument( + "--headless", + action="store_true", + help="Print verification URL + code instead of opening a browser.", + ) + common.add_argument( + "--url", + default=None, + help="Base URL of the BrainSTEM server (default: https://www.brainstem.org/).", + ) + + parser = argparse.ArgumentParser( + prog="brainstem", + description="BrainSTEM command-line API client.", + parents=[common], + ) + + sub = parser.add_subparsers(dest="command", required=True) + + # ---- login ------------------------------------------------------ + sub.add_parser( + "login", + parents=[common], + help="Authenticate and cache your API token.", + ) + + # ---- logout ----------------------------------------------------- + sub.add_parser( + "logout", + parents=[common], + help="Remove the cached API token.", + ) + + # ---- load ------------------------------------------------------- + p_load = sub.add_parser("load", parents=[common], help="Load records from BrainSTEM.") + p_load.add_argument("model", choices=sorted(_MODEL_TO_APP), help="Model name.") + p_load.add_argument("--portal", default="private", help="'private' or 'public'.") + p_load.add_argument("--id", help="UUID of a specific record.") + p_load.add_argument( + "--filters", + nargs="+", + metavar="FIELD=VALUE", + help="Filter expressions, e.g. name.icontains=rat", + ) + p_load.add_argument( + "--sort", + nargs="+", + metavar="FIELD", + help="Sort fields. Prefix with '-' for descending.", + ) + p_load.add_argument( + "--include", + nargs="+", + metavar="RELATION", + help="Related models to embed.", + ) + p_load.add_argument("--limit", type=int, help="Max records (API max: 100).") + p_load.add_argument("--offset", type=int, help="Records to skip (pagination).") + + # ---- save ------------------------------------------------------- + p_save = sub.add_parser("save", parents=[common], help="Create or update a record.") + p_save.add_argument("model", choices=sorted(_MODEL_TO_APP), help="Model name.") + p_save.add_argument("--portal", default="private") + p_save.add_argument("--id", help="UUID of the record to update (omit to create).") + p_save.add_argument( + "--data", + required=True, + help="JSON string of fields to submit, e.g. '{\"name\":\"x\"}'.", + ) + + # ---- delete ----------------------------------------------------- + p_delete = sub.add_parser("delete", parents=[common], help="Delete a record by ID.") + p_delete.add_argument("model", choices=sorted(_MODEL_TO_APP), help="Model name.") + p_delete.add_argument("--portal", default="private") + p_delete.add_argument("--id", required=True, help="UUID of the record to delete.") + + return parser + + +def main(): + parser = _build_parser() + args = parser.parse_args() + + if args.command == "logout": + if _TOKEN_FILE.exists(): + _TOKEN_FILE.unlink() + print(f"Logged out. Token removed from {_TOKEN_FILE}") + else: + print("No cached token found.") + return + + client = BrainstemClient( + token=args.token, + headless=getattr(args, "headless", False), + url=getattr(args, "url", None), + ) + + if args.command == "login": + print(f"Token: {client._token}") + print(f"Cached at: {_TOKEN_FILE}") + return + + if args.command == "load": + filters = {} + for expr in (args.filters or []): + if "=" not in expr: + parser.error(f"Invalid filter '{expr}': expected FIELD=VALUE") + k, v = expr.split("=", 1) + filters[k] = v + + resp = client.load( + args.model, + portal=args.portal, + id=args.id, + filters=filters or None, + sort=args.sort, + include=args.include, + limit=args.limit, + offset=args.offset, + ) + + elif args.command == "save": + try: + data = json.loads(args.data) + except json.JSONDecodeError as exc: + parser.error(f"Invalid JSON for --data: {exc}") + + resp = client.save(args.model, portal=args.portal, id=args.id, data=data) + + elif args.command == "delete": + resp = client.delete(args.model, portal=args.portal, id=args.id) + + # Output + if resp.status_code == 204: + print("Deleted successfully.") + else: + try: + print(json.dumps(resp.json(), indent=2)) + except Exception: + print(resp.text) + + if not resp.ok: + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/brainstem_api_tutorial.ipynb b/brainstem_api_tutorial.ipynb index b5ec0c8..fa521e1 100644 --- a/brainstem_api_tutorial.ipynb +++ b/brainstem_api_tutorial.ipynb @@ -22,9 +22,7 @@ "metadata": {}, "outputs": [], "source": [ - "from brainstem_api_tools import BrainstemClient\n", - "\n", - "import json" + "from brainstem_api_tools import BrainstemClient" ] }, { @@ -32,7 +30,10 @@ "metadata": {}, "source": [ "# 1. Client Setup and Authentication\n", - "The Brainstem API client provides easy access to the Brainstem data platform. To get started, initialize the client, which will prompt you for login credentials the first time. After successful authentication, the client will generate a token that can be saved for future use. This allows you to avoid re-entering your credentials each time." + "\n", + "Authentication uses a **browser-based device authorization flow** that supports two-factor authentication (2FA). No credentials are entered into this tool.\n", + "\n", + "On the first run, a browser window opens for you to approve access. The token is then cached at `~/.config/brainstem/token` and reused automatically on subsequent runs." ] }, { @@ -41,7 +42,8 @@ "metadata": {}, "outputs": [], "source": [ - "# When initializing client without a token\n", + "# First run: opens a browser window for secure login.\n", + "# Subsequent runs reuse the cached token automatically.\n", "client = BrainstemClient()" ] }, @@ -49,8 +51,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# 2. Loging in with token\n", - "For convenience, you can save your authentication token to a configuration file. This allows you to quickly initialize the client in future sessions without entering login details. Simply load the token from your saved file and pass it to the client constructor." + "# 2. Alternative Authentication Options\n", + "\n", + "You can also pass a token directly, use headless mode (prints a URL + code instead of opening a browser), or point the client at a different server." ] }, { @@ -59,23 +62,23 @@ "metadata": {}, "outputs": [], "source": [ - "token = None # Input your token here\n", - "if token:\n", - " client = BrainstemClient(token=token)\n", - " print(\"Client initialized with saved token\")\n", - "else:\n", - " # This will prompt for username/password if run\n", - " print(\"No saved token found. Will need to login.\")\n", - " client = BrainstemClient()" + "# Pass a token directly to skip the browser flow:\n", + "# client = BrainstemClient(token=\"YOUR_TOKEN\")\n", + "\n", + "# Headless mode — prints a URL + code to enter manually (useful in remote/CI environments):\n", + "# client = BrainstemClient(headless=True)\n", + "\n", + "# Connect to a local development server:\n", + "# client = BrainstemClient(url=\"http://127.0.0.1:8000/\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "# 3. Loading Data Models (Sessions Example - Public Data)\n", + "# 3. Loading Public Data\n", "\n", - "The Brainstem platform hosts both public and private data. Public data can be accessed without authentication. Use the portal=\"public\" parameter to specifically query public repositories. This is useful for exploring datasets like the Allen Institute's Visual Coding – Neuropixels project." + "The BrainSTEM platform hosts both public and private data. Public data can be accessed without authentication. Use `portal=\"public\"` to query public repositories. This is useful for exploring datasets like the Allen Institute's Visual Coding – Neuropixels project." ] }, { @@ -84,8 +87,8 @@ "metadata": {}, "outputs": [], "source": [ - "# Load public projects (no authentication needed)\n", - "public_projects = client.load_model(\"project\", portal=\"public\").json()\n", + "# Load public projects (no authentication needed for public portal)\n", + "public_projects = client.load(\"project\", portal=\"public\").json()\n", "\n", "# Print the number of available public projects\n", "print(f\"Found {len(public_projects.get('projects', []))} public projects\")\n", @@ -95,8 +98,7 @@ "for i, project in enumerate(public_projects.get('projects', [])[:3]):\n", " print(f\"{i+1}. {project.get('name', 'Unnamed')}\")\n", " print(f\" Description: {project.get('description', 'No description')[:100]}...\")\n", - " print(\"\")\n", - "\n" + " print(\"\")" ] }, { @@ -105,20 +107,9 @@ "metadata": {}, "outputs": [], "source": [ - "# To load private projects (requires authentication):\n", - "\"\"\"\n", - "# Load your private projects (requires valid token)\n", - "private_projects = client.load_model(\"project\").json() # default is private\n", - "\n", - "# Print the number of your private projects\n", - "print(f\"Found {len(private_projects.get('projects', []))} private projects\")\n", - "\n", - "# Display your private projects\n", - "print(\"\\nYour private projects:\")\n", - "for i, project in enumerate(private_projects.get('projects', [])[:3]):\n", - " print(f\"{i+1}. {project.get('name', 'Unnamed')}\")\n", - " print(f\" Created by: {project.get('principal_investigator', 'Unknown')}\")\n", - "\"\"\"" + "# Load your private projects (requires authentication — client must be initialised with a valid token)\n", + "# private_projects = client.load(\"project\").json() # default portal is \"private\"\n", + "# print(f\"Found {len(private_projects.get('projects', []))} private projects\")" ] }, { @@ -127,12 +118,10 @@ "source": [ "# 4. Filtering\n", "\n", - "The API supports powerful filtering capabilities to help you find specific records. You can filter by exact matches, partial text matches (case-sensitive or insensitive), and various other criteria. Multiple filters can be combined to narrow your search. Common filter modifiers include .icontains for case-insensitive text search and .iexact for exact matching.\n", + "The API supports powerful filtering capabilities. You can filter by exact matches, partial text matches (case-sensitive or insensitive), and various other criteria. Multiple filters on different fields are combined with AND logic.\n", "\n", "### Filter Modifiers\n", "\n", - "You can use the following filter modifiers to refine your search:\n", - "\n", "- `attribute.contains`: Case-sensitive partial match\n", "- `attribute.icontains`: Case-insensitive partial match\n", "- `attribute.iexact`: Case-insensitive exact match\n", @@ -148,43 +137,81 @@ "metadata": {}, "outputs": [], "source": [ - "# Basic filtering example\n", - "filtered_projects = client.load_model(\n", + "# Case-insensitive name filter\n", + "filtered_projects = client.load(\n", " \"project\",\n", " portal=\"public\",\n", - " filters={'name.icontains': 'Allen'} # Case-insensitive contains\n", + " filters={'name.icontains': 'Allen'}\n", ").json()\n", - "\n", "print(f\"Found {len(filtered_projects.get('projects', []))} projects matching 'Allen'\")\n", "\n", - "# Multiple filters with AND logic\n", - "multi_filter = client.load_model(\n", + "# Filters on different fields are combined with AND logic\n", + "# Note: a Python dict cannot have two identical keys — use different field names to combine filters\n", + "multi_filter = client.load(\n", " \"project\",\n", " portal=\"public\",\n", " filters={\n", - " 'name.icontains': 'institute',\n", - " 'name.iendswith': 'neuropixels' # Fixed typo in operator name\n", + " 'name.icontains': 'Allen',\n", + " 'description.icontains': 'Neuropixels'\n", " }\n", ").json()\n", "\n", - "if isinstance(multi_filter, dict):\n", - " count = len(multi_filter.get('projects', []))\n", - "else:\n", - " count = 0\n", + "print(f\"Found {len(multi_filter.get('projects', []))} projects matching both criteria\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 5. Retrieving Related Data\n", + "# Find the Allen Institute: Visual Coding – Neuropixels project by exact name\n", + "allen_projects = client.load(\n", + " \"project\",\n", + " portal=\"public\",\n", + " filters={'name.iexact': 'Allen Institute: Visual Coding \\u2013 Neuropixels'}\n", + ").json()\n", + "\n", + "project_id = allen_projects['projects'][0]['id']\n", + "\n", + "# Load the project by ID to get its full details\n", + "project_details = client.load(\"project\", portal=\"public\", id=project_id).json()\n", + "project = project_details['project']\n", + "print(f\"Project: {project['name']}\")\n", + "print(f\"Description: {project['description'][:200]}...\")\n", + "\n", + "# Load subjects belonging to this project\n", + "subjects = client.load(\n", + " \"subject\",\n", + " portal=\"public\",\n", + " filters={'projects': project_id}\n", + ").json()\n", + "print(f\"\\nSubjects: {len(subjects['subjects'])}\")\n", "\n", - "print(f\"Found {count} projects matching all criteria\")\n", + "# Load sessions and include experiment data\n", + "sessions = client.load(\n", + " \"session\",\n", + " portal=\"public\",\n", + " filters={'projects': project_id},\n", + " include=['dataacquisition', 'behaviors']\n", + ").json()\n", + "print(f\"Sessions: {len(sessions['sessions'])}\")\n", "\n", - "print(\"\\nResponse type:\", type(multi_filter))\n", - "if isinstance(multi_filter, dict) and 'projects' in multi_filter:\n", - " print(\"Response contains 'projects' key with data\")\n", - " print(f\"The id of project is {multi_filter['projects'][0]['id']} and the name is {multi_filter['projects'][0]['name']}\")" + "# Show details for the first session\n", + "if sessions['sessions']:\n", + " s = sessions['sessions'][0]\n", + " print(f\"\\nFirst session: {s['name']}\")\n", + " print(f\" Data acquisitions: {len(s.get('dataacquisition', []))}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "# 5. Basic Data Retrieval " + "# 6. Sorting\n", + "\n", + "You can control the order of returned results using the `sort` parameter. Sort by any field in ascending order or use a `-` prefix for descending order." ] }, { @@ -193,12 +220,26 @@ "metadata": {}, "outputs": [], "source": [ - "allen_projects = client.load_model(\n", - " \"project\", \n", - " portal=\"public\", \n", - " filters={'name.iexact': 'Allen Institute: Visual Coding – Neuropixels'}\n", - ").json()\n", - "print(allen_projects['projects'][0])" + "# Sort projects alphabetically by name (ascending)\n", + "alpha_projects = client.load(\"project\", portal=\"public\", sort=['name']).json()\n", + "print(\"Alphabetically sorted:\")\n", + "for p in alpha_projects.get('projects', [])[:3]:\n", + " print(f\" {p['name']}\")\n", + "\n", + "# Sort descending\n", + "reverse_alpha = client.load(\"project\", portal=\"public\", sort=['-name']).json()\n", + "print(\"\\nReverse alphabetical:\")\n", + "for p in reverse_alpha.get('projects', [])[:3]:\n", + " print(f\" {p['name']}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 7. Pagination\n", + "\n", + "By default the API returns one page of results. Use `limit` and `offset` for manual pagination, or `load_all=True` to automatically fetch every page and return a combined dict." ] }, { @@ -207,67 +248,24 @@ "metadata": {}, "outputs": [], "source": [ - "# Find the Allen Institute: Visual Coding – Neuropixels project\n", - "allen_projects = client.load_model(\n", - " \"project\", \n", - " portal=\"public\", \n", - " filters={'name.iexact': 'Allen Institute: Visual Coding – Neuropixels'}\n", - ").json()\n", - "\n", - "project_id = allen_projects['projects'][0]['id']\n", - "\n", - "# Get basic project information without including sessions\n", - "project_details = client.load_model(\n", - " \"project\",\n", - " portal=\"public\",\n", - " id=project_id\n", - ").json()\n", - "\n", - "# Access project information\n", - "project = project_details['project']\n", - "print(f\"Project name: {project['name']}\")\n", - "print(f\"Description: {project['description'][:200]}...\")\n", - "\n", - "# Get subjects related to this project\n", - "subjects = client.load_model(\n", - " \"subject\", \n", - " portal=\"public\",\n", - " filters={'projects': project_id}\n", - ").json()\n", - "\n", - "print(f\"\\nThis project has {len(subjects['subjects'])} subjects\")\n", - "\n", - "# Access the sessions included with the project - sessions should be available now\n", - "if 'sessions' in project:\n", - " session_ids = project['sessions']\n", - " print(f\"This project has {len(session_ids)} sessions\")\n", - " \n", - " # Get details for the first session as an example\n", - " if session_ids:\n", - " first_session_id = session_ids[0]\n", - " session_details = client.load_model(\n", - " \"session\",\n", - " portal=\"public\",\n", - " id=first_session_id\n", - " ).json()\n", - " \n", - " # Print session details\n", - " if 'session' in session_details:\n", - " session = session_details['session']\n", - " print(f\"\\nExample session details:\")\n", - " print(f\"ID: {first_session_id}\")\n", - " print(f\"Name: {session['name']}\")\n", - "else:\n", - " print(\"No sessions found in project data\")" + "# Manual pagination: fetch records 1-10, then 11-20\n", + "page1 = client.load(\"project\", portal=\"public\", limit=10, offset=0).json()\n", + "page2 = client.load(\"project\", portal=\"public\", limit=10, offset=10).json()\n", + "print(f\"Page 1: {len(page1['projects'])} projects\")\n", + "print(f\"Page 2: {len(page2['projects'])} projects\")\n", + "\n", + "# Auto-pagination: fetch ALL records across all pages in one call\n", + "all_projects = client.load(\"project\", portal=\"public\", load_all=True)\n", + "print(f\"\\nTotal public projects (all pages): {len(all_projects['projects'])}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "# 6. Sorting\n", + "# 8. Convenience Loaders\n", "\n", - "You can control the order of returned results using the sort parameter. Sort by any field in ascending order (e.g., alphabetically by name) or use a minus sign prefix for descending order." + "Convenience loaders are shortcut methods for the most common models. They set sensible default `include` relations and expose named keyword arguments for the most frequent filter fields." ] }, { @@ -276,35 +274,45 @@ "metadata": {}, "outputs": [], "source": [ - "# Sort projects alphabetically by name\n", - "alpha_projects = client.load_model(\n", - " \"project\",\n", - " portal=\"public\",\n", - " sort=['name'] # Ascending alphabetical order\n", - ").json()\n", - "\n", - "print(\"\\nAlphabetically sorted projects:\")\n", - "for i, project in enumerate(alpha_projects.get('projects', [])[:3]):\n", - " print(f\"{i+1}. {project.get('name')}\")\n", - "\n", - "# Sort projects reverse-alphabetically by name\n", - "reverse_alpha = client.load_model(\n", - " \"project\",\n", - " portal=\"public\",\n", - " sort=['-name'] # Descending alphabetical order\n", - ").json()\n", - "\n", - "print(\"\\nReverse alphabetically sorted projects:\")\n", - "for i, project in enumerate(reverse_alpha.get('projects', [])[:3]):\n", - " print(f\"{i+1}. {project.get('name')}\")" + "# load_project — embeds sessions, subjects, collections, cohorts by default\n", + "projects = client.load_project(portal=\"public\", name=\"Allen\", load_all=True)\n", + "print(f\"Projects matching 'Allen': {len(projects['projects'])}\")\n", + "\n", + "# load_session — embeds dataacquisition, behaviors, manipulations, epochs by default\n", + "# (use portal=\"public\" for public data, or omit for private)\n", + "# sessions = client.load_session(name=\"Rat\", load_all=True)\n", + "\n", + "# load_subject — embeds procedures and subjectlogs by default\n", + "# subjects = client.load_subject(sex=\"M\", projects=\"\", load_all=True)\n", + "\n", + "# load_behavior / load_dataacquisition / load_manipulation — scope by session UUID\n", + "# behaviors = client.load_behavior(session=\"\", load_all=True)\n", + "# dataacquisition = client.load_dataacquisition(session=\"\", load_all=True)\n", + "# manipulations = client.load_manipulation(session=\"\", load_all=True)\n", + "\n", + "# load_procedure — scopes by subject UUID\n", + "# procedures = client.load_procedure(subject=\"\", load_all=True)\n", + "\n", + "# load_collection / load_cohort\n", + "# collection = client.load_collection(name=\"MyCollection\")\n", + "# cohort = client.load_cohort(name=\"MyCohort\")\n", + "\n", + "# All convenience loaders accept extra filters= and a custom include= override:\n", + "# sessions = client.load_session(\n", + "# name=\"Rat\",\n", + "# include=[\"projects\"],\n", + "# filters={\"description.icontains\": \"hippocampus\"},\n", + "# load_all=True,\n", + "# )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "# 7. Updating data\n", - "The API supports creating new records and updating existing ones. When creating a record, provide the required fields for that model type. For updates, fetch the existing record, modify its attributes, and save it back. Both operations require appropriate permissions." + "# 9. Creating and Updating Data\n", + "\n", + "Use `client.save()` to create or update records. Omit `id` to create a new record (POST); provide `id` to update an existing one (PATCH). Write operations require authentication on the private portal." ] }, { @@ -313,62 +321,46 @@ "metadata": {}, "outputs": [], "source": [ - "# Example of updating a session (use your own private data)\n", - "# Note: This is just syntax demonstration and won't execute on public data\n", - "token = None # Input your token here\n", - "if token:\n", - " client = BrainstemClient(token=token)\n", - " print(\"Client initialized with saved token\")\n", - "else:\n", - " # This will prompt for username/password if run\n", - " print(\"No saved token found. Will need to login.\")\n", - " client = BrainstemClient()\n", - "\"\"\"\n", - "# First load a session you have permission to modify\n", - "filtered_session = client.load_model('session', filters={'name.iexact': 'your session'}).json()\n", - "\n", - "# Your existing code works without using a 'session' key\n", - "filtered_session['description'] = 'This is a test description, showing how to update a session'\n", - "\n", - "# Pass the whole filtered_session object to save_model\n", - "updated_session = client.save_model(\n", - " 'session', \n", - " id=filtered_session['sessions'][0]['id'], \n", - " data=filtered_session\n", + "# Create a new session (replace project UUID with your own)\n", + "new_session = client.save(\n", + " \"session\",\n", + " data={\n", + " \"name\": \"My new session\",\n", + " \"description\": \"Created via the Python API\",\n", + " \"projects\": [\"\"],\n", + " }\n", ").json()\n", + "# print(new_session)\n", "\n", - "print(\"Session updated successfully\")\n", - "\"\"\"\n", + "# Update an existing session (replace UUID with your own)\n", + "updated = client.save(\n", + " \"session\",\n", + " id=\"\",\n", + " data={\"description\": \"Updated description\"}\n", + ").json()\n", + "# print(updated)\n", "\n", - "# Creating a new subject - SYNTAX DEMONSTRATION ONLY\n", - "\"\"\"\n", "# Create a new subject\n", - "new_subject_data = {\n", - " 'name': 'Test Subject 001',\n", - " 'species': 'mouse',\n", - " 'sex': 'M',\n", - " 'projects': ['your-project-id'] # Project this subject belongs to\n", - "}\n", - "\n", - "# Save the new subject\n", - "new_subject = client.save_model(\n", - " 'subject',\n", - " data=new_subject_data\n", + "new_subject = client.save(\n", + " \"subject\",\n", + " data={\n", + " \"name\": \"Sub001\",\n", + " \"sex\": \"M\",\n", + " \"projects\": [\"\"],\n", + " }\n", ").json()\n", + "# print(new_subject)\n", "\n", - "print(f\"New subject created with ID: {new_subject['subject']['id']}\")\n", - "\"\"\"\n", - "\n", - "print(\"To use them, replace placeholder IDs with your own private data IDs.\")\n", - "print(\"Only run these commands on data you have permission to modify.\")" + "print(\"Replace the placeholder UUIDs above with your own IDs before running.\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "# 8. Deleting data\n", - "You can remove records from the database using their IDs. Deletion is permanent, so use this operation with caution. The API will return a success status code (204) when deletion is successful. Like other write operations, deletion requires appropriate permissions." + "# 10. Deleting Data\n", + "\n", + "Use `client.delete()` with a model name and record UUID. A successful deletion returns HTTP 204 (no content). This operation is **permanent** — use with care." ] }, { @@ -377,43 +369,15 @@ "metadata": {}, "outputs": [], "source": [ - "token = None # Input your token here.\n", - "if token:\n", - " client = BrainstemClient(token=token)\n", - " print(\"Client initialized with saved token\")\n", + "# Delete a record by UUID (replace with your own ID)\n", + "response = client.delete(\"session\", id=\"\")\n", + "\n", + "if response.status_code == 204:\n", + " print(\"Session deleted successfully.\")\n", "else:\n", - " # This will prompt for username/password if run\n", - " print(\"No saved token found. Will need to login.\")\n", - " client = BrainstemClient()\n", - "\n", - "# Example of deleting a session (demonstration only)\n", - "\"\"\"\n", - "# First create a test session that we can safely delete\n", - "new_session_data = {\n", - " 'name': 'Temporary Test Session',\n", - " 'description': 'This session will be deleted',\n", - " 'project': 'your-project-id' # Replace with your own project ID\n", - "}\n", - "\n", - "# Create the session\n", - "created_session = client.save_model('session', data=new_session_data).json()\n", - "\n", - "if 'session' in created_session:\n", - " session_id = created_session['session']['id']\n", - " print(f\"Created temporary session with ID: {session_id}\")\n", - " \n", - " # Now delete the session we just created\n", - " delete_response = client.delete_model(\"session\", id=session_id)\n", - " \n", - " # Check if deletion was successful\n", - " if delete_response.status_code == 204: # 204 No Content indicates success\n", - " print(\"Session deleted successfully\")\n", - " else:\n", - " print(f\"Delete failed with status code: {delete_response.status_code}\")\n", - "\"\"\"\n", - "\n", - "print(\"To use it, you would need to have write permissions and valid IDs.\")\n", - "print(\"Only delete data that you have created and have permission to remove.\")" + " print(f\"Delete failed: {response.status_code} — {response.text}\")\n", + "\n", + "print(\"Replace the placeholder UUID above with your own ID before running.\")" ] } ], diff --git a/brainstem_api_tutorial.py b/brainstem_api_tutorial.py index 2b94ce2..e959f8e 100644 --- a/brainstem_api_tutorial.py +++ b/brainstem_api_tutorial.py @@ -1,44 +1,102 @@ from brainstem_api_tools import BrainstemClient -# 0. Load the client. User email and password will be requested. +# 0. Load the client. +# On first run a browser window opens for secure login (supports 2FA). +# The token is cached at ~/.config/brainstem/token and reused automatically. +# Pass a saved token directly to skip the browser flow: +# client = BrainstemClient(token="YOUR_TOKEN") +# To connect to a local development server: +# client = BrainstemClient(url="http://127.0.0.1:8000/") client = BrainstemClient() # Loading sessions. -## load_model can be used to load any model: +## load can be used to load any model: ## We just need to pass our settings and the name of the model. -output1 = client.load_model('session').json() +output1 = client.load('session').json() ### We can fetch a single session entry from the loaded models. session = output1["sessions"][0] ## We can also filter the models by providing a dictionary where -## the keys are the fields and the values the possible contents. -## In this example, it will just load sessions whose name is "yeah". -output1 = client.load_model('session', filters={'name': 'yeah'}).json() +## In this example, it will just load sessions whose name exactly equals "yeah". +output1 = client.load('session', filters={'name': 'yeah'}).json() ## Loaded models can be sorted by different criteria applying to -## their fields. In this example, sessions will be sorted in -## descending ording according to their name. -output1 = client.load_model('session', sort=['-name']).json() +## their fields. In this example, sessions will be sorted in +## descending order according to their name. +output1 = client.load('session', sort=['-name']).json() ## In some cases models contain relations with other models, and ## they can be also loaded with the models if requested. In this ## example, all the projects, data acquisition, behaviors and -## manipulations related to each session will be included. -output1 = client.load_model('session', include=['projects', 'dataacquisition', 'behaviors', 'manipulations']).json() +## manipulations related to each session will be included. +output1 = client.load('session', include=['projects', 'dataacquisition', 'behaviors', 'manipulations']).json() -### The list of related experiment data can be retrived from the +### The list of related experiment data can be retrieved from the ### returned dictionary. dataacquisition = output1["sessions"][0]["dataacquisition"] ## All these options can be combined to suit the requirements -## of the users. For example, we can get only the session that +## of the users. For example, we can get only the sessions that ## contain the word "Rat" in their name, sorted in descending ## order by their name and including the related projects. -output1 = client.load_model('session', filters={'name.icontains': 'Rat'}, sort=["-name"], include=['projects']).json() +output1 = client.load('session', filters={'name.icontains': 'Rat'}, sort=["-name"], include=['projects']).json() + +## Pagination: load the second page of results (records 21-40). +output1 = client.load('session', limit=20, offset=20).json() + +## load_all=True automatically fetches every page and returns a +## combined dict — useful when the record count exceeds one page. +output1 = client.load('session', load_all=True) +all_sessions = output1["sessions"] + + +# Convenience loaders + +## Each convenience loader sets sensible defaults (related models +## included automatically) and exposes named keyword arguments for +## the most common filter fields. + +## load_session includes dataacquisition, behaviors, manipulations +## and epochs by default. Filter by name substring: +output_sessions = client.load_session(name='Rat', load_all=True) + +## load_subject includes procedures and subjectlogs. +## Filter by sex ('M' / 'F') and/or project UUID: +project_uuid = '' # replace with a real project UUID +output_subjects = client.load_subject(sex='M', projects=project_uuid, load_all=True) + +## load_project includes sessions, subjects, collections and cohorts. +output_projects = client.load_project(name='MyProject') + +## load_collection includes sessions by default. +output_collection = client.load_collection(name='MyCollection') + +## load_cohort includes subjects by default. +output_cohort = client.load_cohort(name='MyCohort') + +## load_behavior / load_dataacquisition / load_manipulation all +## accept session= to scope results to a single session. +session_uuid = '' # replace with a real session UUID +output_behaviors = client.load_behavior(session=session_uuid, load_all=True) +output_dataacquisition = client.load_dataacquisition(session=session_uuid, load_all=True) +output_manipulations = client.load_manipulation(session=session_uuid, load_all=True) + +## load_procedure scopes by subject UUID. +subject_uuid = '' # replace with a real subject UUID +output_procedures = client.load_procedure(subject=subject_uuid, load_all=True) + +## Any convenience loader also accepts extra filters= and custom +## include= overrides just like load() itself. +output_sessions = client.load_session( + name='Rat', + include=['projects'], + filters={'description.icontains': 'hippocampus'}, + load_all=True, +) # Updating a session @@ -48,7 +106,7 @@ ## previously loaded sessions. session = {} session["description"] = 'new description' -output2 = client.save_model("session", id="0e39c1fd-f413-4142-95f7-f50185e81fa4", data=session).json() +output2 = client.save("session", id="", data=session).json() # Creating a new session @@ -57,14 +115,22 @@ ## required fields. session = {} session["name"] = "New session" -session["projects"] = ['e7475834-7733-48cf-9e3b-f4f2d2d0305a'] +session["projects"] = [''] # replace with a real project UUID session["description"] = 'description' ## Submitting session -output3 = client.save_model("session", data=session).json() +output3 = client.save("session", data=session).json() + + +# Deleting a session + +## Pass the model name and the UUID of the record to remove. +response = client.delete("session", id="") +if response.status_code == 204: + print("Session deleted") # Load public projects ## Request the public data by defining the portal to be public -output4 = client.load_model("project", portal="public").json() +output4 = client.load("project", portal="public").json() diff --git a/pyproject.toml b/pyproject.toml index f5b211d..6e85b8c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,33 @@ [build-system] requires = ["setuptools>=61.0", "wheel"] -build-backend = "setuptools.build_meta" \ No newline at end of file +build-backend = "setuptools.build_meta" + +[project] +name = "brainstem_python_api_tools" +version = "2.0.0" +description = "A Python toolset for interacting with the BrainSTEM API." +readme = "README.md" +license = { text = "MIT" } +authors = [ + { name = "BrainSTEM Team", email = "petersen.peter@gmail.com" } +] +requires-python = ">=3.8" +dependencies = [ + "requests>=2.28", +] + +[project.optional-dependencies] +dev = [ + "pytest", + "pytest-mock", +] + +[project.urls] +Homepage = "https://github.com/brainstem-org/brainstem_python_api_tools" + +[project.scripts] +brainstem = "brainstem_api_tools.cli:main" + +[tool.setuptools.packages.find] +where = ["."] +include = ["brainstem_api_tools*"] diff --git a/setup.py b/setup.py deleted file mode 100644 index 42d8586..0000000 --- a/setup.py +++ /dev/null @@ -1,24 +0,0 @@ -from setuptools import setup, find_packages - -setup( - name="brainstem_python_api_tools", - version="0.2.0", - packages=find_packages(), - install_requires=[ - "requests", - "jupyter", - ], - author="BrainSTEM Team", - author_email="petersen.peter@gmail.com", - description="A Python toolset for interacting with API of BrainSTEM.", - long_description=open("README.md").read(), - long_description_content_type="text/markdown", - url="https://github.com/brainstem-org/brainstem_python_api_tools", - license="MIT", - classifiers=[ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: MIT License", - "Operating System :: OS Independent", - ], - python_requires='>=3.7', -) diff --git a/__init__.py b/tests/__init__.py similarity index 100% rename from __init__.py rename to tests/__init__.py diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 0000000..3851c2e --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,729 @@ +"""Unit tests for BrainstemClient. + +All HTTP calls are intercepted with pytest-mock / unittest.mock so no +network access is required. +""" + +import pytest +from unittest.mock import MagicMock, patch, PropertyMock + +from brainstem_api_tools.brainstem_api_client import ( + BrainstemClient, + AuthenticationError, + _MODEL_TO_APP, + _resolve_model, + _resolve_portal, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +TOKEN = "test-token-abc123" + + +def make_client(token=TOKEN): + """Return a BrainstemClient with a pre-set token (no auth flow).""" + return BrainstemClient(token=token) + + +def mock_response(status_code=200, json_body=None): + resp = MagicMock() + resp.status_code = status_code + resp.ok = status_code < 400 + resp.json.return_value = json_body or {} + return resp + + +# --------------------------------------------------------------------------- +# _resolve_model / _resolve_portal +# --------------------------------------------------------------------------- + +class TestResolvers: + def test_resolve_model_string(self): + assert _resolve_model("session") == "session" + + def test_resolve_model_invalid(self): + with pytest.raises(ValueError, match="Unknown model 'foobar'"): + _resolve_model("foobar") + + def test_resolve_portal_string(self): + assert _resolve_portal("public") == "public" + + def test_resolve_portal_invalid(self): + with pytest.raises(ValueError, match="Unknown portal 'nope'"): + _resolve_portal("nope") + + +# --------------------------------------------------------------------------- +# BrainstemClient construction +# --------------------------------------------------------------------------- + +class TestClientInit: + def test_token_set_directly(self): + client = make_client() + assert client._token == TOKEN + + def test_auth_header_on_session(self): + client = make_client() + assert client._session.headers["Authorization"] == f"Bearer {TOKEN}" + + def test_cached_token_used_on_init(self, tmp_path, monkeypatch): + token_file = tmp_path / "token" + token_file.write_text("cached-token") + monkeypatch.setattr( + "brainstem_api_tools.brainstem_api_client._TOKEN_FILE", token_file + ) + client = BrainstemClient() + assert client._token == "cached-token" + + def test_device_flow_triggered_when_no_cache(self, tmp_path, monkeypatch): + token_file = tmp_path / "token" # does not exist + monkeypatch.setattr( + "brainstem_api_tools.brainstem_api_client._TOKEN_FILE", token_file + ) + with patch.object(BrainstemClient, "_device_auth_flow", return_value="new-token") as mock_flow, \ + patch.object(BrainstemClient, "_save_token") as mock_save: + client = BrainstemClient() + mock_flow.assert_called_once() + mock_save.assert_called_once_with("new-token") + assert client._token == "new-token" + + +# --------------------------------------------------------------------------- +# load() +# --------------------------------------------------------------------------- + +class TestLoad: + def setup_method(self): + self.client = make_client() + + def _mock_get(self, json_body=None, status=200): + resp = mock_response(status, json_body) + self.client._session.get = MagicMock(return_value=resp) + return resp + + def test_load_all_sessions(self): + self._mock_get({"sessions": []}) + self.client.load("session") + self.client._session.get.assert_called_once() + url = self.client._session.get.call_args[0][0] + assert "stem/session/" in url + + def test_load_with_id(self): + uid = "00000000-0000-0000-0000-000000000001" + self._mock_get({"session": {}}) + self.client.load("session", id=uid) + url = self.client._session.get.call_args[0][0] + assert uid in url + + def test_load_filters_built_correctly(self): + self._mock_get() + self.client.load("session", filters={"name.icontains": "rat"}) + params = self.client._session.get.call_args[1]["params"] + assert params["filter{name.icontains}"] == "rat" + + def test_load_sort_built_correctly(self): + self._mock_get() + self.client.load("session", sort=["-name", "date"]) + params = self.client._session.get.call_args[1]["params"] + assert params["sort[]"] == ["-name", "date"] + + def test_load_include_appends_wildcard(self): + self._mock_get() + self.client.load("session", include=["behaviors"]) + params = self.client._session.get.call_args[1]["params"] + assert params["include[]"] == ["behaviors.*"] + + def test_load_pagination_params(self): + self._mock_get() + self.client.load("session", limit=50, offset=20) + params = self.client._session.get.call_args[1]["params"] + assert params["limit"] == 50 + assert params["offset"] == 20 + + def test_load_public_portal(self): + self._mock_get() + self.client.load("project", portal="public") + url = self.client._session.get.call_args[0][0] + assert "/public/" in url + + def test_load_invalid_model_raises(self): + with pytest.raises(ValueError, match="Unknown model"): + self.client.load("notamodel") + + def test_load_invalid_portal_raises(self): + with pytest.raises(ValueError, match="Unknown portal"): + self.client.load("session", portal="badportal") + + def test_model_to_app_routing(self): + """Spot-check a few app assignments.""" + cases = [ + ("session", "stem"), + ("breeding", "stem"), + ("project_membership_invitation", "stem"), + ("procedure", "modules"), + ("behavioralassay", "personal_attributes"), + ("license", "personal_attributes"), + ("setup", "personal_attributes"), + ("consumable", "resources"), + ("behavioralparadigm", "taxonomies"), + ("behavioralcategory", "taxonomies"), + ("regulatoryauthority", "taxonomies"), + ("brainregion", "taxonomies"), + ("publication", "dissemination"), + ("laboratory", "users"), + ("group_membership_invitation", "users"), + ("group_membership_request", "users"), + ] + for model, expected_app in cases: + self._mock_get() + self.client.load(model) + url = self.client._session.get.call_args[0][0] + assert f"/{expected_app}/{model}/" in url, ( + f"Expected app '{expected_app}' for model '{model}', got URL: {url}" + ) + + +# --------------------------------------------------------------------------- +# save() +# --------------------------------------------------------------------------- + +class TestSave: + def setup_method(self): + self.client = make_client() + + def test_create_uses_post(self): + self.client._session.post = MagicMock(return_value=mock_response(201)) + self.client.save("session", data={"name": "x", "projects": ["uuid"]}) + self.client._session.post.assert_called_once() + + def test_update_uses_patch(self): + uid = "00000000-0000-0000-0000-000000000002" + self.client._session.patch = MagicMock(return_value=mock_response(200)) + self.client.save("session", id=uid, data={"description": "updated"}) + self.client._session.patch.assert_called_once() + url = self.client._session.patch.call_args[0][0] + assert uid in url + + def test_save_invalid_model_raises(self): + with pytest.raises(ValueError): + self.client.save("ghost") + + +# --------------------------------------------------------------------------- +# delete() +# --------------------------------------------------------------------------- + +class TestDelete: + def setup_method(self): + self.client = make_client() + + def test_delete_sends_correct_url(self): + uid = "00000000-0000-0000-0000-000000000003" + self.client._session.delete = MagicMock(return_value=mock_response(204)) + self.client.delete("session", id=uid) + url = self.client._session.delete.call_args[0][0] + assert uid in url + assert "stem/session/" in url + + def test_delete_without_id_raises(self): + with pytest.raises(ValueError, match="'id' is required"): + self.client.delete("session") + + def test_delete_invalid_model_raises(self): + with pytest.raises(ValueError): + self.client.delete("ghost", id="some-uuid") + + +# --------------------------------------------------------------------------- +# Device auth flow +# --------------------------------------------------------------------------- + +class TestDeviceAuthFlow: + def test_success_flow(self): + client = make_client() # skip auth + + device_resp = MagicMock() + device_resp.raise_for_status = MagicMock() + device_resp.json.return_value = { + "device_code": "dev123", + "user_code": "ABC-DEF", + "verification_uri": "https://brainstem.org/activate", + "verification_uri_complete": "https://brainstem.org/activate?code=ABC-DEF", + "interval": 0, + } + + pending_resp = MagicMock() + pending_resp.raise_for_status = MagicMock() + pending_resp.json.return_value = {"status": "authorization_pending"} + + success_resp = MagicMock() + success_resp.raise_for_status = MagicMock() + success_resp.json.return_value = {"status": "success", "token": "new-jwt"} + + with patch("requests.post", side_effect=[device_resp, pending_resp, success_resp]), \ + patch("webbrowser.open"): + token = client._device_auth_flow() + + assert token == "new-jwt" + + def test_access_denied_raises(self): + client = make_client() + + device_resp = MagicMock() + device_resp.raise_for_status = MagicMock() + device_resp.json.return_value = { + "device_code": "dev123", + "user_code": "XYZ", + "verification_uri": "https://brainstem.org/activate", + "verification_uri_complete": "https://brainstem.org/activate?code=XYZ", + "interval": 0, + } + denied_resp = MagicMock() + denied_resp.raise_for_status = MagicMock() + denied_resp.json.return_value = {"error": "access_denied"} + + with patch("requests.post", side_effect=[device_resp, denied_resp]), \ + patch("webbrowser.open"): + with pytest.raises(AuthenticationError, match="Authorization failed"): + client._device_auth_flow() + + +# --------------------------------------------------------------------------- +# load_all auto-pagination +# --------------------------------------------------------------------------- + +class TestLoadAll: + def setup_method(self): + self.client = make_client() + + def _mock_pages(self, pages): + responses = [] + for page in pages: + r = MagicMock() + r.status_code = 200 + r.raise_for_status = MagicMock() + r.json.return_value = page + responses.append(r) + self.client._session.get = MagicMock(side_effect=responses) + + def test_single_page(self): + self._mock_pages([ + {"sessions": [{"id": "1"}, {"id": "2"}], "count": 2}, + ]) + result = self.client.load("session", load_all=True) + assert isinstance(result, dict) + assert len(result["sessions"]) == 2 + + def test_two_pages_merged(self): + self._mock_pages([ + {"sessions": [{"id": "1"}, {"id": "2"}], "count": 3}, + {"sessions": [{"id": "3"}], "count": 3}, + ]) + result = self.client.load("session", load_all=True) + assert len(result["sessions"]) == 3 + assert [r["id"] for r in result["sessions"]] == ["1", "2", "3"] + + def test_load_all_false_returns_response(self): + r = MagicMock() + self.client._session.get = MagicMock(return_value=r) + result = self.client.load("session", load_all=False) + assert result is r + + +# --------------------------------------------------------------------------- +# Convenience loaders +# --------------------------------------------------------------------------- + +class TestConvenienceLoaders: + def setup_method(self): + self.client = make_client() + self.client._session.get = MagicMock(return_value=mock_response(200, {})) + + def _get_params(self): + return self.client._session.get.call_args[1].get("params", {}) + + def test_load_project_default_include(self): + self.client.load_project() + params = self._get_params() + assert set(params["include[]"]) == { + "sessions.*", "subjects.*", "collections.*", "cohorts.*" + } + + def test_load_project_name_filter(self): + self.client.load_project(name="Allen") + params = self._get_params() + assert params["filter{name.icontains}"] == "Allen" + + def test_load_subject_default_include(self): + self.client.load_subject() + params = self._get_params() + assert set(params["include[]"]) == {"procedures.*", "subjectlogs.*"} + + def test_load_subject_sex_filter(self): + self.client.load_subject(sex="M") + assert self._get_params()["filter{sex}"] == "M" + + def test_load_subject_strain_filter(self): + uuid = "aaaaaaaa-0000-0000-0000-000000000000" + self.client.load_subject(strain=uuid) + assert self._get_params()["filter{strain.id}"] == uuid + + def test_load_session_default_include(self): + self.client.load_session() + params = self._get_params() + assert set(params["include[]"]) == { + "dataacquisition.*", "behaviors.*", "manipulations.*", "epochs.*" + } + + def test_load_session_project_filter(self): + uuid = "bbbbbbbb-0000-0000-0000-000000000000" + self.client.load_session(projects=uuid) + assert self._get_params()["filter{projects.id}"] == uuid + + def test_load_collection_default_include(self): + self.client.load_collection() + assert self._get_params()["include[]"] == ["sessions.*"] + + def test_load_cohort_default_include(self): + self.client.load_cohort() + assert self._get_params()["include[]"] == ["subjects.*"] + + def test_load_behavior_session_filter(self): + uuid = "cccccccc-0000-0000-0000-000000000000" + self.client.load_behavior(session=uuid) + assert self._get_params()["filter{session.id}"] == uuid + + def test_load_dataacquisition_session_filter(self): + uuid = "dddddddd-0000-0000-0000-000000000000" + self.client.load_dataacquisition(session=uuid) + assert self._get_params()["filter{session.id}"] == uuid + + def test_load_manipulation_session_filter(self): + uuid = "eeeeeeee-0000-0000-0000-000000000000" + self.client.load_manipulation(session=uuid) + assert self._get_params()["filter{session.id}"] == uuid + + def test_load_procedure_subject_filter(self): + uuid = "ffffffff-0000-0000-0000-000000000000" + self.client.load_procedure(subject=uuid) + assert self._get_params()["filter{subject.id}"] == uuid + + def test_custom_include_overrides_default(self): + self.client.load_session(include=["behaviors"]) + assert self._get_params()["include[]"] == ["behaviors.*"] + + def test_extra_filters_merged_with_field_kwargs(self): + self.client.load_session( + name="Rat", + filters={"description.icontains": "hippocampus"}, + ) + params = self._get_params() + assert params["filter{name.icontains}"] == "Rat" + assert params["filter{description.icontains}"] == "hippocampus" + + +# --------------------------------------------------------------------------- +# CLI logout +# --------------------------------------------------------------------------- + +class TestCLILogout: + def test_logout_removes_token_file(self, tmp_path, capsys): + import brainstem_api_tools.cli as cli_module + token_file = tmp_path / "token" + token_file.write_text("mytoken") + + with patch.object(cli_module, "_TOKEN_FILE", token_file), \ + patch("sys.argv", ["brainstem", "logout"]): + cli_module.main() + + assert not token_file.exists() + assert "Logged out" in capsys.readouterr().out + + def test_logout_no_token_file(self, tmp_path, capsys): + import brainstem_api_tools.cli as cli_module + token_file = tmp_path / "token" # does not exist + + with patch.object(cli_module, "_TOKEN_FILE", token_file), \ + patch("sys.argv", ["brainstem", "logout"]): + cli_module.main() + + assert "No cached token found" in capsys.readouterr().out + + +# --------------------------------------------------------------------------- +# Timeouts +# --------------------------------------------------------------------------- + +class TestTimeouts: + def setup_method(self): + self.client = make_client() + + def test_load_passes_timeout(self): + self.client._session.get = MagicMock(return_value=mock_response(200, {})) + self.client.load("session") + _, kwargs = self.client._session.get.call_args + assert kwargs.get("timeout") == BrainstemClient.DEFAULT_TIMEOUT + + def test_save_post_passes_timeout(self): + self.client._session.post = MagicMock(return_value=mock_response(201)) + self.client.save("session", data={"name": "x"}) + _, kwargs = self.client._session.post.call_args + assert kwargs.get("timeout") == BrainstemClient.DEFAULT_TIMEOUT + + def test_save_patch_passes_timeout(self): + uid = "00000000-0000-0000-0000-000000000010" + self.client._session.patch = MagicMock(return_value=mock_response(200)) + self.client.save("session", id=uid, data={"description": "x"}) + _, kwargs = self.client._session.patch.call_args + assert kwargs.get("timeout") == BrainstemClient.DEFAULT_TIMEOUT + + def test_delete_passes_timeout(self): + uid = "00000000-0000-0000-0000-000000000011" + self.client._session.delete = MagicMock(return_value=mock_response(204)) + self.client.delete("session", id=uid) + _, kwargs = self.client._session.delete.call_args + assert kwargs.get("timeout") == BrainstemClient.DEFAULT_TIMEOUT + + def test_load_all_passes_timeout(self): + r = MagicMock() + r.raise_for_status = MagicMock() + r.status_code = 200 + r.json.return_value = {"sessions": [{"id": "1"}], "count": 1} + self.client._session.get = MagicMock(return_value=r) + self.client.load("session", load_all=True) + _, kwargs = self.client._session.get.call_args + assert kwargs.get("timeout") == BrainstemClient.DEFAULT_TIMEOUT + + +# --------------------------------------------------------------------------- +# 401 → AuthenticationError in load_all path +# --------------------------------------------------------------------------- + +class TestLoadAll401: + def test_401_raises_authentication_error(self): + client = make_client() + r = MagicMock() + r.status_code = 401 + r.raise_for_status = MagicMock() + client._session.get = MagicMock(return_value=r) + with pytest.raises(AuthenticationError, match="brainstem login"): + client.load("session", load_all=True) + + +# --------------------------------------------------------------------------- +# Device auth polling timeout +# --------------------------------------------------------------------------- + +class TestDeviceAuthTimeout: + def test_polling_timeout_raises(self): + client = make_client() + + device_resp = MagicMock() + device_resp.raise_for_status = MagicMock() + device_resp.json.return_value = { + "device_code": "dev123", + "user_code": "ABC-DEF", + "verification_uri": "https://brainstem.org/activate", + "verification_uri_complete": "https://brainstem.org/activate?code=ABC-DEF", + "interval": 0, + } + pending_resp = MagicMock() + pending_resp.raise_for_status = MagicMock() + pending_resp.json.return_value = {"status": "authorization_pending"} + + # max_wait=0 triggers immediate timeout on the first iteration check + with patch("requests.post", side_effect=[device_resp] + [pending_resp] * 10), \ + patch("webbrowser.open"), \ + patch("time.monotonic", side_effect=[0, 0, 999]): + with pytest.raises(AuthenticationError, match="timed out"): + client._device_auth_flow(max_wait=0) + + +# --------------------------------------------------------------------------- +# Context manager +# --------------------------------------------------------------------------- + +class TestContextManager: + def test_enter_returns_client(self): + client = make_client() + with client as c: + assert c is client + + def test_exit_closes_session(self): + client = make_client() + client._session.close = MagicMock() + with client: + pass + client._session.close.assert_called_once() + + +# --------------------------------------------------------------------------- +# Unknown kwarg validation in convenience loaders +# --------------------------------------------------------------------------- + +class TestUnknownKwarg: + def test_unknown_kwarg_raises(self): + client = make_client() + with pytest.raises(TypeError, match="Unexpected keyword argument"): + # Call _convenience_load directly with a kwarg not in filter_map + client._convenience_load( + "session", + default_include=[], + filter_map={"name": "name.icontains"}, + portal="private", id=None, filters=None, sort=None, + include=None, limit=None, offset=None, load_all=False, + name="test", + oops="extra", # not listed in filter_map + ) + + +# --------------------------------------------------------------------------- +# New convenience loaders +# --------------------------------------------------------------------------- + +class TestNewConvenienceLoaders: + def setup_method(self): + self.client = make_client() + self.client._session.get = MagicMock(return_value=mock_response(200, {})) + + def _get_params(self): + return self.client._session.get.call_args[1].get("params", {}) + + def _get_url(self): + return self.client._session.get.call_args[0][0] + + def test_load_procedurelog_procedure_filter(self): + uuid = "11111111-0000-0000-0000-000000000000" + self.client.load_procedurelog(procedure=uuid) + assert self._get_params()["filter{procedure.id}"] == uuid + assert "/modules/procedurelog/" in self._get_url() + + def test_load_subjectlog_subject_filter(self): + uuid = "22222222-0000-0000-0000-000000000000" + self.client.load_subjectlog(subject=uuid) + assert self._get_params()["filter{subject.id}"] == uuid + assert "/modules/subjectlog/" in self._get_url() + + def test_load_equipment_setup_filter(self): + uuid = "33333333-0000-0000-0000-000000000000" + self.client.load_equipment(setup=uuid) + assert self._get_params()["filter{setup.id}"] == uuid + assert "/modules/equipment/" in self._get_url() + + def test_load_equipment_name_filter(self): + self.client.load_equipment(name="headstage") + assert self._get_params()["filter{name.icontains}"] == "headstage" + + def test_load_consumablestock_url(self): + self.client.load_consumablestock() + assert "/modules/consumablestock/" in self._get_url() + + def test_load_behavioralassay_name_filter(self): + self.client.load_behavioralassay(name="open field") + assert self._get_params()["filter{name.icontains}"] == "open field" + assert "/personal_attributes/behavioralassay/" in self._get_url() + + def test_load_datastorage_url(self): + self.client.load_datastorage() + assert "/personal_attributes/datastorage/" in self._get_url() + + def test_load_setup_url(self): + self.client.load_setup() + assert "/personal_attributes/setup/" in self._get_url() + + def test_load_hardwaredevice_url(self): + self.client.load_hardwaredevice() + assert "/resources/hardwaredevice/" in self._get_url() + + def test_load_brainregion_name_filter(self): + self.client.load_brainregion(name="CA1") + assert self._get_params()["filter{name.icontains}"] == "CA1" + assert "/taxonomies/brainregion/" in self._get_url() + + def test_load_species_url(self): + self.client.load_species() + assert "/taxonomies/species/" in self._get_url() + + def test_load_strain_species_filter(self): + uuid = "44444444-0000-0000-0000-000000000000" + self.client.load_strain(species=uuid) + assert self._get_params()["filter{species.id}"] == uuid + assert "/taxonomies/strain/" in self._get_url() + + def test_load_publication_url(self): + self.client.load_publication() + assert "/dissemination/publication/" in self._get_url() + + def test_load_laboratory_name_filter(self): + self.client.load_laboratory(name="Petersen") + assert self._get_params()["filter{name.icontains}"] == "Petersen" + assert "/users/laboratory/" in self._get_url() + + +# --------------------------------------------------------------------------- +# CLI — load / save / delete commands +# --------------------------------------------------------------------------- + +class TestCLICommands: + def _run_cli(self, argv, client_mock): + import brainstem_api_tools.cli as cli_module + with patch("brainstem_api_tools.cli.BrainstemClient", return_value=client_mock), \ + patch("sys.argv", argv): + cli_module.main() + + def test_cli_load_calls_load(self, capsys): + resp = mock_response(200, {"sessions": []}) + client = MagicMock() + client.load.return_value = resp + self._run_cli(["brainstem", "--token", TOKEN, "load", "session"], client) + client.load.assert_called_once() + + def test_cli_load_filter_parsed(self, capsys): + resp = mock_response(200, {}) + client = MagicMock() + client.load.return_value = resp + self._run_cli( + ["brainstem", "--token", TOKEN, "load", "session", + "--filters", "name.icontains=rat"], + client, + ) + _, kwargs = client.load.call_args + assert kwargs["filters"] == {"name.icontains": "rat"} + + def test_cli_save_post(self, capsys): + resp = mock_response(201, {"session": {"id": "abc"}}) + client = MagicMock() + client.save.return_value = resp + self._run_cli( + ["brainstem", "--token", TOKEN, "save", "session", + "--data", '{"name": "x", "projects": []}'], + client, + ) + client.save.assert_called_once() + _, kwargs = client.save.call_args + assert kwargs["data"] == {"name": "x", "projects": []} + + def test_cli_delete(self, capsys): + resp = mock_response(204) + client = MagicMock() + client.delete.return_value = resp + self._run_cli( + ["brainstem", "--token", TOKEN, "delete", "session", + "--id", "00000000-0000-0000-0000-000000000099"], + client, + ) + client.delete.assert_called_once() + _, kwargs = client.delete.call_args + assert kwargs["id"] == "00000000-0000-0000-0000-000000000099" + + def test_cli_invalid_filter_exits(self): + import brainstem_api_tools.cli as cli_module + with patch("brainstem_api_tools.cli.BrainstemClient", return_value=MagicMock()), \ + patch("sys.argv", ["brainstem", "--token", TOKEN, "load", "session", + "--filters", "badfilter"]): + with pytest.raises(SystemExit): + cli_module.main() +