From 08c934e1ba878eb6ef55c5c4d9c2279f07793bae Mon Sep 17 00:00:00 2001 From: "Peter C. Petersen" Date: Mon, 30 Mar 2026 09:28:01 +0200 Subject: [PATCH 1/3] Refactor client: device auth, CLI, pagination MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace password auth with a browser-based device authorization flow and persistent token caching (~/.config/brainstem/token). Export AuthenticationError and modernize BrainstemClient: session-backed requests, unified load/save/delete methods, auto-pagination (load_all), query param handling (filters/sort/include/limit/offset), model→app routing table, and many convenience loaders (load_project/subject/session/... ). Add a command-line interface (brainstem entrypoint), update README and tutorial for new usage, add pyproject.toml (remove setup.py), and include comprehensive unit tests for client behavior and auth flow. --- README.md | 141 +++- brainstem_api_tools/__init__.py | 2 +- brainstem_api_tools/brainstem_api_client.py | 772 ++++++++++++++------ brainstem_api_tools/cli.py | 179 +++++ brainstem_api_tutorial.ipynb | 402 +++++----- brainstem_api_tutorial.py | 97 ++- pyproject.toml | 32 +- setup.py | 24 - __init__.py => tests/__init__.py | 0 tests/test_client.py | 452 ++++++++++++ 10 files changed, 1617 insertions(+), 484 deletions(-) create mode 100644 brainstem_api_tools/cli.py delete mode 100644 setup.py rename __init__.py => tests/__init__.py (100%) create mode 100644 tests/test_client.py diff --git a/README.md b/README.md index 07f9ded..1c95367 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,35 @@ You can install the package using `pip`: pip install brainstem_python_api_tools ``` +## Authentication + +Authentication uses a **browser-based device authorization flow** that supports two-factor authentication (2FA). No credentials are entered into the tool itself. + +```python +from brainstem_api_tools import BrainstemClient + +# First run: opens a browser window for secure login. +# The token is cached at ~/.config/brainstem/token and reused automatically. +client = BrainstemClient() +``` + +To skip the browser flow, pass a token directly: +```python +client = BrainstemClient(token="YOUR_TOKEN") +``` + +For headless environments (no browser available): +```python +client = BrainstemClient(headless=True) # prints a URL + code to enter manually +``` + +To connect to a different server (e.g. a local development instance): +```python +client = BrainstemClient(url="http://127.0.0.1:8000/") +``` + +> **Security note:** The cached token file is stored with owner-read-only permissions (`0600`). Treat it like a password and do not commit it to version control. + ## Getting Started To get started with the BrainSTEM API tools, please refer to the tutorial script provided: @@ -15,21 +44,121 @@ To get started with the BrainSTEM API tools, please refer to the tutorial script The tutorial demonstrates how to: -- **Authenticate:** Load the client and authenticate using your credentials. -- **Loading Data:** Load sessions and filter data using flexible options. -- **Updating Entries:** Modify existing models and update them in the database. -- **Creating Entries:** Submit new data entries with required fields. -- **Loading Public Data:** Access public projects and data using the public portal. +- **Authenticate:** Load the client and authenticate via your browser. +- **Load Data:** Load sessions and filter data using flexible options. +- **Paginate:** Retrieve large datasets using `limit` / `offset` or `load_all=True`. +- **Convenience Loaders:** Use `load_session()`, `load_subject()`, etc. for common queries. +- **Update Entries:** Modify existing records and update them in the database. +- **Create Entries:** Submit new data entries with required fields. +- **Delete Entries:** Remove records by ID. +- **Load Public Data:** Access public projects and data using the public portal. ## Example Usage ```python from brainstem_api_tools import BrainstemClient client = BrainstemClient() -response = client.load_model('session', sort=['name']) + +# Load sessions sorted by name +response = client.load('session', sort=['name']) print(response.json()) + +# Filter and include related data +response = client.load( + 'session', + filters={'name.icontains': 'Rat'}, + sort=['-name'], + include=['projects'], +) + +# Paginate results (manual) +page2 = client.load('session', limit=20, offset=20).json() + +# Auto-paginate — fetches every page and returns a merged dict +all_sessions = client.load('session', load_all=True) + +# Convenience loaders — sensible defaults with named filter kwargs +# load_session embeds dataacquisition, behaviors, manipulations, epochs +sessions = client.load_session(name='Rat', load_all=True) + +# load_subject embeds procedures and subjectlogs +subjects = client.load_subject(sex='M', projects='', load_all=True) + +# load_project embeds sessions, subjects, collections, cohorts +projects = client.load_project(name='MyProject') + +# load_behavior / load_dataacquisition / load_manipulation scope by session UUID +behaviors = client.load_behavior(session='', load_all=True) + +# load_procedure scopes by subject UUID +procedures = client.load_procedure(subject='', load_all=True) + +# Create a record +client.save('session', data={'name': 'New session', 'projects': ['']}) + +# Update a record +client.save('session', id='', data={'description': 'updated'}) + +# Delete a record +client.delete('session', id='') +``` + +## Contributing +Contributions are welcome! Feel free to open issues or submit pull requests on GitHub. + +## Command-line Interface + +After installation a `brainstem` command is available in your shell. + +### Authentication +```bash +# Authenticate (opens browser) and cache token +brainstem login + +# Headless — prints URL + code instead of opening browser +brainstem login --headless + +# Connect to a local dev server +brainstem login --url http://127.0.0.1:8000/ + +# Remove cached token +brainstem logout +``` + +### Loading data +```bash +# Load all sessions (private portal) +brainstem load session + +# Filter, sort and embed related data +brainstem load session --filters name.icontains=Rat --sort -name --include projects + +# Load a single record by UUID +brainstem load session --id + +# Manual pagination +brainstem load session --limit 20 --offset 20 + +# Public portal +brainstem load project --portal public +``` + +### Creating and updating records +```bash +# Create a new session +brainstem save session --data '{"name":"New session","projects":[""]}' + +# Update an existing record +brainstem save session --id --data '{"description":"updated"}' ``` +### Deleting records +```bash +brainstem delete session --id +``` + +All subcommands accept `--token`, `--headless`, and `--url` to override defaults. + ## Contributing Contributions are welcome! Feel free to open issues or submit pull requests on GitHub. diff --git a/brainstem_api_tools/__init__.py b/brainstem_api_tools/__init__.py index 9fc63f9..e98a356 100644 --- a/brainstem_api_tools/__init__.py +++ b/brainstem_api_tools/__init__.py @@ -1 +1 @@ -from .brainstem_api_client import BrainstemClient, ModelType, PortalType +from .brainstem_api_client import BrainstemClient, ModelType, PortalType, AuthenticationError diff --git a/brainstem_api_tools/brainstem_api_client.py b/brainstem_api_tools/brainstem_api_client.py index 1cb838a..31c7153 100644 --- a/brainstem_api_tools/brainstem_api_client.py +++ b/brainstem_api_tools/brainstem_api_client.py @@ -1,43 +1,101 @@ import os -from getpass import getpass +import stat +import time +import webbrowser +from enum import Enum +from pathlib import Path + import requests from requests.models import Response -import json -from enum import Enum -class ModelType(Enum): - project = "project" - subject = "subject" - session = "session" - collection = "collection" - cohort = "cohort" - procedure = "procedure" - behavior = "behavior" - dataacquisition = "dataacquisition" - manipulation = "manipulation" - equipment = "equipment" - consumablestock = "consumablestock" - procedurelog = "procedurelog" - subjectlog = "subjectlog" - behavioralparadigm = "behavioralparadigm" - datastorage = "datastorage" - inventory = "inventory" - setup = "setup" - consumable = "consumable" - hardwaredevice = "hardwaredevice" - supplier = "supplier" - brainregion = "brainregion" - setuptype = "setuptype" - species = "species" - strain = "strain" - strainapproval = "strainapproval" - journal = "journal" - laboratory = "laboratory" - publication = "publication" - journalapproval = "journalapproval" - user = "user" - group = "group" +# --------------------------------------------------------------------------- +# Model → app routing table (single source of truth) +# --------------------------------------------------------------------------- + +_MODEL_TO_APP: dict = { + # stem + "project": "stem", + "subject": "stem", + "session": "stem", + "collection": "stem", + "cohort": "stem", + "breeding": "stem", + "project_membership_invitation": "stem", + "project_group_membership_invitation": "stem", + # modules + "procedure": "modules", + "behavior": "modules", + "dataacquisition": "modules", + "manipulation": "modules", + "equipment": "modules", + "consumablestock": "modules", + "procedurelog": "modules", + "subjectlog": "modules", + # personal_attributes + "behavioralassay": "personal_attributes", + "datastorage": "personal_attributes", + "inventory": "personal_attributes", + "license": "personal_attributes", + "protocol": "personal_attributes", + "setup": "personal_attributes", + # resources + "consumable": "resources", + "hardwaredevice": "resources", + "supplier": "resources", + # taxonomies + "behavioralcategory": "taxonomies", + "behavioralparadigm": "taxonomies", + "brainregion": "taxonomies", + "regulatoryauthority": "taxonomies", + "setuptype": "taxonomies", + "species": "taxonomies", + "strain": "taxonomies", + "strainapproval": "taxonomies", + # dissemination + "journal": "dissemination", + "journalapproval": "dissemination", + "publication": "dissemination", + # users + "group": "users", + "group_membership_invitation": "users", + "group_membership_request": "users", + "laboratory": "users", + "user": "users", +} + +_TOKEN_FILE = Path.home() / ".config" / "brainstem" / "token" + +# Derived from _MODEL_TO_APP so there is exactly one source of truth. +ModelType = Enum("ModelType", {k: k for k in _MODEL_TO_APP}) # type: ignore[misc] + +_VALID_PORTALS = {"private", "public", "super"} + + +def _resolve_model(model) -> str: + """Accept a plain string or a ModelType member and return the string value.""" + if isinstance(model, ModelType): + return model.value + model = str(model) + if model not in _MODEL_TO_APP: + raise ValueError( + f"Unknown model '{model}'. Valid models are: " + + ", ".join(sorted(_MODEL_TO_APP)) + ) + return model + + +def _resolve_portal(portal) -> str: + """Accept a plain string or a PortalType member and return the string value.""" + if isinstance(portal, Enum): + portal = portal.value + portal = str(portal) + if portal not in _VALID_PORTALS: + raise ValueError( + f"Unknown portal '{portal}'. Valid portals are: " + + ", ".join(sorted(_VALID_PORTALS)) + ) + return portal class PortalType(Enum): @@ -46,202 +104,480 @@ class PortalType(Enum): super = "super" -class LoginError(Exception): - def __str__(self): - return f'User/password combination incorrect or user does not exist' +# --------------------------------------------------------------------------- +# Exceptions +# --------------------------------------------------------------------------- + +class AuthenticationError(Exception): + pass +# --------------------------------------------------------------------------- +# Client +# --------------------------------------------------------------------------- + class BrainstemClient: - def __init__(self, - token: str = None) -> None: + BASE_URL = "https://www.brainstem.org/" - # Server path - self._address = 'https://www.brainstem.org/api/' + def __init__( + self, + token: str = None, + headless: bool = False, + url: str = None, + ) -> None: + base = (url.rstrip("/") + "/") if url else self.BASE_URL + self._address = base + "api/" + self._session = requests.Session() if token: self._token = token else: - username = input("Please enter your username/email:") - password = getpass("Please enter your password:") - - self._token = self.__set_token_authentication( - url=self._address + "token/", - username=username, - password=password, + self._token = self._load_cached_token() + if not self._token: + self._token = self._device_auth_flow(headless=headless) + self._save_token(self._token) + + self._session.headers.update({"Authorization": f"Bearer {self._token}"}) + + # ------------------------------------------------------------------ + # Token management + # ------------------------------------------------------------------ + + def _load_cached_token(self) -> str: + """Return the cached token from disk, or None if not present.""" + if _TOKEN_FILE.exists(): + return _TOKEN_FILE.read_text().strip() or None + return None + + def _save_token(self, token: str) -> None: + """Persist the token to disk with owner-read-only permissions.""" + _TOKEN_FILE.parent.mkdir(parents=True, exist_ok=True) + _TOKEN_FILE.write_text(token) + _TOKEN_FILE.chmod(stat.S_IRUSR | stat.S_IWUSR) # 0o600 + + def _device_auth_flow(self, headless: bool = False) -> str: + """Run the BrainSTEM device authorization flow and return a token. + + The user authenticates in their browser (supports 2FA). No + credentials are sent by this tool. + + Parameters + ---------- + headless : If ``True``, print the verification URI and user code + instead of opening a browser window. + """ + # Step 1 — initiate a device session + resp = requests.post( + self._address + "auth/device/", + json={"client_name": "brainstem-python"}, + timeout=10, + ) + resp.raise_for_status() + data = resp.json() + + device_code = data["device_code"] + interval = int(data.get("interval", 5)) + + # Step 2 — open browser or print instructions + if headless: + print(f"Open {data['verification_uri']} and enter: {data['user_code']}") + else: + webbrowser.open(data["verification_uri_complete"]) + print("Waiting for browser approval...") + + # Step 3 — poll until resolved + while True: + time.sleep(interval) + poll = requests.post( + self._address + "auth/device/token/", + json={"device_code": device_code}, + timeout=10, ) - - print("Your authorization token is:\n", self._token) - print("\nPlease keep it in a safe place.") - - def __set_token_authentication(self, - url: str, - username: str, - password: str) -> str: - - headers = { - "accept": "application/json", - "Content-Type": "application/json" - } - params = { - "username": username, - "password": password - } - resp = requests.post(url, headers=headers, json=params) - - if resp.status_code != 200: - raise LoginError() - - token = json.loads(resp.text)['token'] - return token - - def __get_app_from_model(self, modelname: str) -> str: - app = None - - if modelname in ['project', 'subject', 'session', 'collection', 'cohort', - 'projectmembershipinvitation', - 'projectgroupmembershipinvitation']: - app = 'stem' - - elif modelname in ['procedure', 'behavior', 'dataacquisition', 'manipulation', - 'equipment', 'consumablestock', 'procedurelog', 'subjectlog']: - app = 'modules' - - elif modelname in ['behavioralparadigm', 'datastorage', 'inventory', - 'setup']: - app = 'personal_attributes' - - elif any([x in modelname for x in - ['consumable', 'hardwaredevice', 'supplier']]): - app = 'resources' - - elif any([x in modelname for x in - ['brainregion', 'setuptype', - 'species', 'strain']]): - app = 'taxonomies' - - elif any([x in modelname for x in - ['journal', 'laboratory', 'publication']]): - app = 'attributes' - - elif modelname in ['user', 'group', 'groupmembershipinvitation', - 'groupmembershiprequest']: - app = 'users' - - return app - - def load_model(self, - model: ModelType, - portal: PortalType = "private", - id: str = None, - options: str = None, - filters: dict = None, - sort: list = None, - include: list = None) -> Response: - - app = self.__get_app_from_model(model) - if app is None: - resp = Response() - resp.status_code = 404 - return resp - - if options is None: - options = "" - + poll.raise_for_status() + result = poll.json() + + if result.get("status") == "success": + return result["token"] + elif result.get("status") == "authorization_pending": + continue + else: + raise AuthenticationError( + f"Authorization failed: {result.get('error')}" + ) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _build_url(self, portal: str, app: str, model: str, + id: str = None, options: str = None) -> str: + url = f"{self._address}{portal}/{app}/{model}/" if id: - query_parameters = id + "/" + options - else: - query_parameters = "" - - if filters: - for key in filters.keys(): - if query_parameters == "": - prefix = "?" - else: - prefix = "&" - query_parameters += (prefix - + "filter{" + key + "}=" - + filters[key]) - if sort: - for elem in sort: - if query_parameters == "": - prefix = "?" - else: - prefix = "&" - query_parameters += prefix + "sort[]=" + elem - - if include: - for elem in include: - if query_parameters == "": - prefix = "?" - else: - prefix = "&" - query_parameters += prefix + "include[]=" + elem + ".*" - - request_url = (self._address + portal - + "/" + app + "/" + model - + "/" + query_parameters) - - resp = requests.get(request_url, - headers={"Authorization": "Bearer %s" - % self._token}) - return resp - - def save_model(self, - model: ModelType, - portal: PortalType = "private", - id: str = None, - data: dict = None, - options: str = None) -> Response: - - app = self.__get_app_from_model(model) - if app is None: - resp = Response() - resp.status_code = 404 - return resp - + url += f"{id}/" + if options: + url += options + return url + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def load(self, + model, + portal="private", + id: str = None, + options: str = None, + filters: dict = None, + sort: list = None, + include: list = None, + limit: int = None, + offset: int = None, + load_all: bool = False): + """Load one or more records of *model*. + + Parameters + ---------- + model : Model name string or ``ModelType`` member, e.g. ``'session'``. + portal : ``'private'`` (default) or ``'public'``, or ``PortalType`` member. + id : UUID of a specific record. Returns a single object when set. + filters : Dict of field filters, e.g. ``{'name.icontains': 'rat'}``. + Supports ``.icontains``, ``.startswith``, ``.endswith``, + ``.gt``, ``.gte``, ``.lt``, ``.lte``. + sort : List of fields to sort by. Prefix with ``'-'`` for descending. + include : Related models to embed, e.g. ``['dataacquisition', 'behaviors']``. + limit : Max records per page (API maximum: 100). + offset : Number of records to skip (for pagination). + load_all : When ``True``, automatically follow pagination and return a + combined ``dict`` with all records merged under the model key. + When ``False`` (default), returns the raw ``Response``. + """ + model = _resolve_model(model) + portal = _resolve_portal(portal) + app = _MODEL_TO_APP[model] + url = self._build_url(portal, app, model, id, options if id else None) + + params = {} + if not id: + for key, val in (filters or {}).items(): + params[f"filter{{{key}}}"] = val + for field in (sort or []): + params.setdefault("sort[]", []).append(field) + for rel in (include or []): + params.setdefault("include[]", []).append(f"{rel}.*") + if limit is not None: + params["limit"] = limit + if offset is not None: + params["offset"] = offset + + if not load_all: + return self._session.get(url, params=params) + + # --- auto-paginate and merge all pages --- + page_size = limit or 100 + params["limit"] = page_size + params.setdefault("offset", offset or 0) + + combined: dict = {} + records_key: str = None + + while True: + resp = self._session.get(url, params=params) + resp.raise_for_status() + data = resp.json() + + if records_key is None: + # Detect the list key (e.g. 'sessions', 'projects') + records_key = next( + (k for k, v in data.items() if isinstance(v, list)), None + ) + combined = {k: v for k, v in data.items() if k != records_key} + combined[records_key] = [] + + combined[records_key].extend(data.get(records_key, [])) + total = data.get("count", len(combined[records_key])) + + if len(combined[records_key]) >= total: + break + + params["offset"] = params.get("offset", 0) + page_size + + return combined + + def save(self, + model, + portal="private", + id: str = None, + data: dict = None, + options: str = None) -> Response: + """Create or update a record. + + When *id* is provided the record is **updated** (PATCH). + When *id* is omitted a **new record** is created (POST). + + Parameters + ---------- + model : Model name string or ``ModelType`` member, e.g. ``'session'``. + portal : ``'private'`` (default) or ``'public'``, or ``PortalType`` member. + id : UUID of the record to update. Omit to create a new record. + data : Dict of fields to submit. + """ + model = _resolve_model(model) + portal = _resolve_portal(portal) + app = _MODEL_TO_APP[model] if data is None: data = {} - if options is None: - options = "" - - # Check if entry already exists if id is not None: - # This is an update request - request_url = (self._address + portal + "/" + app + "/" + model - + "/" + id + "/" + options) - resp = requests.patch(request_url, json=data, - headers={"Authorization": "Bearer %s" - % self._token}) - + url = self._build_url(portal, app, model, id, options) + return self._session.patch(url, json=data) else: - # This is a create request - request_url = (self._address + portal + "/" + app + "/" - + model + "/" + options) - resp = requests.post(request_url, json=data, - headers={"Authorization": "Bearer %s" - % self._token}) - - return resp - - def delete_model(self, - model: ModelType, - portal: PortalType = "private", - id: str = None): - - app = self.__get_app_from_model(model) - if app is None: - resp = Response() - resp.status_code = 404 - return resp - - # Check if entry already exists - if id is not None: - request_url = (self._address + portal + "/" + app + "/" + model - + "/" + id + "/") - resp = requests.delete(request_url, - headers={"Authorization": "Bearer %s" - % self._token}) - - return resp + url = self._build_url(portal, app, model, options=options) + return self._session.post(url, json=data) + + def delete(self, + model, + portal="private", + id: str = None) -> Response: + """Delete a record by ID. + + Parameters + ---------- + model : Model name string or ``ModelType`` member, e.g. ``'session'``. + portal : ``'private'`` (default) or ``'public'``, or ``PortalType`` member. + id : UUID of the record to delete (required). + """ + if id is None: + raise ValueError("'id' is required to delete a record.") + model = _resolve_model(model) + portal = _resolve_portal(portal) + app = _MODEL_TO_APP[model] + url = self._build_url(portal, app, model, id) + return self._session.delete(url) + + # ------------------------------------------------------------------ + # Convenience loaders (mirror the MATLAB load_* helpers) + # ------------------------------------------------------------------ + + def _convenience_load(self, model, default_include, filter_map, + portal, id, filters, sort, include, limit, offset, + load_all, **field_kwargs): + """Shared implementation for all convenience loaders.""" + merged_filters = dict(filters or {}) + for kwarg, api_field in filter_map.items(): + value = field_kwargs.get(kwarg) + if value: + merged_filters[api_field] = value + return self.load( + model, + portal=portal, + id=id, + filters=merged_filters or None, + sort=sort, + include=include if include is not None else default_include, + limit=limit, + offset=offset, + load_all=load_all, + ) + + def load_project(self, portal="private", id: str = None, + name: str = None, sessions: str = None, + subjects: str = None, tags: str = None, + filters: dict = None, sort: list = None, + include: list = None, limit: int = None, + offset: int = 0, load_all: bool = False): + """Load project(s). Embeds sessions, subjects, collections and cohorts by default. + + Parameters + ---------- + name : Filter by project name (case-insensitive contains). + sessions : Filter by session UUID. + subjects : Filter by subject UUID. + tags : Filter by tag. + """ + return self._convenience_load( + "project", + default_include=["sessions", "subjects", "collections", "cohorts"], + filter_map={"name": "name.icontains", "sessions": "sessions.id", + "subjects": "subjects.id", "tags": "tags"}, + portal=portal, id=id, filters=filters, sort=sort, include=include, + limit=limit, offset=offset, load_all=load_all, + name=name, sessions=sessions, subjects=subjects, tags=tags, + ) + + def load_subject(self, portal="private", id: str = None, + name: str = None, projects: str = None, + strain: str = None, sex: str = None, tags: str = None, + filters: dict = None, sort: list = None, + include: list = None, limit: int = None, + offset: int = 0, load_all: bool = False): + """Load subject(s). Embeds procedures and subjectlogs by default. + + Parameters + ---------- + name : Filter by subject name (case-insensitive contains). + projects : Filter by project UUID. + strain : Filter by strain UUID. + sex : Filter by sex (``'M'``, ``'F'``, or ``'U'``). + tags : Filter by tag. + """ + return self._convenience_load( + "subject", + default_include=["procedures", "subjectlogs"], + filter_map={"name": "name.icontains", "projects": "projects.id", + "strain": "strain.id", "sex": "sex", "tags": "tags"}, + portal=portal, id=id, filters=filters, sort=sort, include=include, + limit=limit, offset=offset, load_all=load_all, + name=name, projects=projects, strain=strain, sex=sex, tags=tags, + ) + + def load_session(self, portal="private", id: str = None, + name: str = None, projects: str = None, + datastorage: str = None, tags: str = None, + filters: dict = None, sort: list = None, + include: list = None, limit: int = None, + offset: int = 0, load_all: bool = False): + """Load session(s). Embeds dataacquisition, behaviors, manipulations and epochs by default. + + Parameters + ---------- + name : Filter by session name (case-insensitive contains). + projects : Filter by project UUID. + datastorage : Filter by data storage UUID. + tags : Filter by tag. + """ + return self._convenience_load( + "session", + default_include=["dataacquisition", "behaviors", "manipulations", "epochs"], + filter_map={"name": "name.icontains", "projects": "projects.id", + "datastorage": "datastorage.id", "tags": "tags"}, + portal=portal, id=id, filters=filters, sort=sort, include=include, + limit=limit, offset=offset, load_all=load_all, + name=name, projects=projects, datastorage=datastorage, tags=tags, + ) + + def load_collection(self, portal="private", id: str = None, + name: str = None, tags: str = None, + filters: dict = None, sort: list = None, + include: list = None, limit: int = None, + offset: int = 0, load_all: bool = False): + """Load collection(s). Embeds sessions by default. + + Parameters + ---------- + name : Filter by collection name (case-insensitive contains). + tags : Filter by tag. + """ + return self._convenience_load( + "collection", + default_include=["sessions"], + filter_map={"name": "name.icontains", "tags": "tags"}, + portal=portal, id=id, filters=filters, sort=sort, include=include, + limit=limit, offset=offset, load_all=load_all, + name=name, tags=tags, + ) + + def load_cohort(self, portal="private", id: str = None, + name: str = None, tags: str = None, + filters: dict = None, sort: list = None, + include: list = None, limit: int = None, + offset: int = 0, load_all: bool = False): + """Load cohort(s). Embeds subjects by default. + + Parameters + ---------- + name : Filter by cohort name (case-insensitive contains). + tags : Filter by tag. + """ + return self._convenience_load( + "cohort", + default_include=["subjects"], + filter_map={"name": "name.icontains", "tags": "tags"}, + portal=portal, id=id, filters=filters, sort=sort, include=include, + limit=limit, offset=offset, load_all=load_all, + name=name, tags=tags, + ) + + def load_behavior(self, portal="private", id: str = None, + session: str = None, tags: str = None, + filters: dict = None, sort: list = None, + include: list = None, limit: int = None, + offset: int = 0, load_all: bool = False): + """Load behavior record(s). + + Parameters + ---------- + session : Filter by session UUID. + tags : Filter by tag. + """ + return self._convenience_load( + "behavior", + default_include=[], + filter_map={"session": "session.id", "tags": "tags"}, + portal=portal, id=id, filters=filters, sort=sort, include=include, + limit=limit, offset=offset, load_all=load_all, + session=session, tags=tags, + ) + + def load_dataacquisition(self, portal="private", id: str = None, + session: str = None, tags: str = None, + filters: dict = None, sort: list = None, + include: list = None, limit: int = None, + offset: int = 0, load_all: bool = False): + """Load data acquisition record(s). + + Parameters + ---------- + session : Filter by session UUID. + tags : Filter by tag. + """ + return self._convenience_load( + "dataacquisition", + default_include=[], + filter_map={"session": "session.id", "tags": "tags"}, + portal=portal, id=id, filters=filters, sort=sort, include=include, + limit=limit, offset=offset, load_all=load_all, + session=session, tags=tags, + ) + + def load_manipulation(self, portal="private", id: str = None, + session: str = None, tags: str = None, + filters: dict = None, sort: list = None, + include: list = None, limit: int = None, + offset: int = 0, load_all: bool = False): + """Load manipulation record(s). + + Parameters + ---------- + session : Filter by session UUID. + tags : Filter by tag. + """ + return self._convenience_load( + "manipulation", + default_include=[], + filter_map={"session": "session.id", "tags": "tags"}, + portal=portal, id=id, filters=filters, sort=sort, include=include, + limit=limit, offset=offset, load_all=load_all, + session=session, tags=tags, + ) + + def load_procedure(self, portal="private", id: str = None, + subject: str = None, tags: str = None, + filters: dict = None, sort: list = None, + include: list = None, limit: int = None, + offset: int = 0, load_all: bool = False): + """Load procedure record(s). + + Parameters + ---------- + subject : Filter by subject UUID. + tags : Filter by tag. + """ + return self._convenience_load( + "procedure", + default_include=[], + filter_map={"subject": "subject.id", "tags": "tags"}, + portal=portal, id=id, filters=filters, sort=sort, include=include, + limit=limit, offset=offset, load_all=load_all, + subject=subject, tags=tags, + ) diff --git a/brainstem_api_tools/cli.py b/brainstem_api_tools/cli.py new file mode 100644 index 0000000..8cfe715 --- /dev/null +++ b/brainstem_api_tools/cli.py @@ -0,0 +1,179 @@ +"""Command-line interface for BrainSTEM API tools. + +Usage examples +-------------- + brainstem login + brainstem login --headless + brainstem login --url http://127.0.0.1:8000/ + brainstem logout + brainstem load session + brainstem load session --portal public --filters name.icontains=rat --sort -name + brainstem load session --id + brainstem load session --limit 20 --offset 40 + brainstem save session --data '{"name":"New","projects":[""]}' + brainstem save session --id --data '{"description":"updated"}' + brainstem delete session --id +""" + +import argparse +import json +import os +import sys + +from .brainstem_api_client import BrainstemClient, _MODEL_TO_APP, _TOKEN_FILE + + +def _build_parser() -> argparse.ArgumentParser: + # Shared flags available on every subcommand + common = argparse.ArgumentParser(add_help=False) + common.add_argument( + "--token", + default=os.environ.get("BRAINSTEM_API_TOKEN"), + help="API token. Defaults to $BRAINSTEM_API_TOKEN or the cached token.", + ) + common.add_argument( + "--headless", + action="store_true", + help="Print verification URL + code instead of opening a browser.", + ) + common.add_argument( + "--url", + default=None, + help="Base URL of the BrainSTEM server (default: https://www.brainstem.org/).", + ) + + parser = argparse.ArgumentParser( + prog="brainstem", + description="BrainSTEM command-line API client.", + parents=[common], + ) + + sub = parser.add_subparsers(dest="command", required=True) + + # ---- login ------------------------------------------------------ + sub.add_parser( + "login", + parents=[common], + help="Authenticate and cache your API token.", + ) + + # ---- logout ----------------------------------------------------- + sub.add_parser( + "logout", + help="Remove the cached API token.", + ) + + # ---- load ------------------------------------------------------- + p_load = sub.add_parser("load", parents=[common], help="Load records from BrainSTEM.") + p_load.add_argument("model", choices=sorted(_MODEL_TO_APP), help="Model name.") + p_load.add_argument("--portal", default="private", help="'private' or 'public'.") + p_load.add_argument("--id", help="UUID of a specific record.") + p_load.add_argument( + "--filters", + nargs="+", + metavar="FIELD=VALUE", + help="Filter expressions, e.g. name.icontains=rat", + ) + p_load.add_argument( + "--sort", + nargs="+", + metavar="FIELD", + help="Sort fields. Prefix with '-' for descending.", + ) + p_load.add_argument( + "--include", + nargs="+", + metavar="RELATION", + help="Related models to embed.", + ) + p_load.add_argument("--limit", type=int, help="Max records (API max: 100).") + p_load.add_argument("--offset", type=int, help="Records to skip (pagination).") + + # ---- save ------------------------------------------------------- + p_save = sub.add_parser("save", parents=[common], help="Create or update a record.") + p_save.add_argument("model", choices=sorted(_MODEL_TO_APP), help="Model name.") + p_save.add_argument("--portal", default="private") + p_save.add_argument("--id", help="UUID of the record to update (omit to create).") + p_save.add_argument( + "--data", + required=True, + help="JSON string of fields to submit, e.g. '{\"name\":\"x\"}'.", + ) + + # ---- delete ----------------------------------------------------- + p_delete = sub.add_parser("delete", parents=[common], help="Delete a record by ID.") + p_delete.add_argument("model", choices=sorted(_MODEL_TO_APP), help="Model name.") + p_delete.add_argument("--portal", default="private") + p_delete.add_argument("--id", required=True, help="UUID of the record to delete.") + + return parser + + +def main(): + parser = _build_parser() + args = parser.parse_args() + + if args.command == "logout": + if _TOKEN_FILE.exists(): + _TOKEN_FILE.unlink() + print(f"Logged out. Token removed from {_TOKEN_FILE}") + else: + print("No cached token found.") + return + + client = BrainstemClient( + token=args.token, + headless=getattr(args, "headless", False), + url=getattr(args, "url", None), + ) + + if args.command == "login": + print(f"Token: {client._token}") + print(f"Cached at: {_TOKEN_FILE}") + return + + if args.command == "load": + filters = {} + for expr in (args.filters or []): + if "=" not in expr: + parser.error(f"Invalid filter '{expr}': expected FIELD=VALUE") + k, v = expr.split("=", 1) + filters[k] = v + + resp = client.load( + args.model, + portal=args.portal, + id=args.id, + filters=filters or None, + sort=args.sort, + include=args.include, + limit=args.limit, + offset=args.offset, + ) + + elif args.command == "save": + try: + data = json.loads(args.data) + except json.JSONDecodeError as exc: + parser.error(f"Invalid JSON for --data: {exc}") + + resp = client.save(args.model, portal=args.portal, id=args.id, data=data) + + elif args.command == "delete": + resp = client.delete(args.model, portal=args.portal, id=args.id) + + # Output + if resp.status_code == 204: + print("Deleted successfully.") + else: + try: + print(json.dumps(resp.json(), indent=2)) + except Exception: + print(resp.text) + + if not resp.ok: + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/brainstem_api_tutorial.ipynb b/brainstem_api_tutorial.ipynb index b5ec0c8..fa521e1 100644 --- a/brainstem_api_tutorial.ipynb +++ b/brainstem_api_tutorial.ipynb @@ -22,9 +22,7 @@ "metadata": {}, "outputs": [], "source": [ - "from brainstem_api_tools import BrainstemClient\n", - "\n", - "import json" + "from brainstem_api_tools import BrainstemClient" ] }, { @@ -32,7 +30,10 @@ "metadata": {}, "source": [ "# 1. Client Setup and Authentication\n", - "The Brainstem API client provides easy access to the Brainstem data platform. To get started, initialize the client, which will prompt you for login credentials the first time. After successful authentication, the client will generate a token that can be saved for future use. This allows you to avoid re-entering your credentials each time." + "\n", + "Authentication uses a **browser-based device authorization flow** that supports two-factor authentication (2FA). No credentials are entered into this tool.\n", + "\n", + "On the first run, a browser window opens for you to approve access. The token is then cached at `~/.config/brainstem/token` and reused automatically on subsequent runs." ] }, { @@ -41,7 +42,8 @@ "metadata": {}, "outputs": [], "source": [ - "# When initializing client without a token\n", + "# First run: opens a browser window for secure login.\n", + "# Subsequent runs reuse the cached token automatically.\n", "client = BrainstemClient()" ] }, @@ -49,8 +51,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# 2. Loging in with token\n", - "For convenience, you can save your authentication token to a configuration file. This allows you to quickly initialize the client in future sessions without entering login details. Simply load the token from your saved file and pass it to the client constructor." + "# 2. Alternative Authentication Options\n", + "\n", + "You can also pass a token directly, use headless mode (prints a URL + code instead of opening a browser), or point the client at a different server." ] }, { @@ -59,23 +62,23 @@ "metadata": {}, "outputs": [], "source": [ - "token = None # Input your token here\n", - "if token:\n", - " client = BrainstemClient(token=token)\n", - " print(\"Client initialized with saved token\")\n", - "else:\n", - " # This will prompt for username/password if run\n", - " print(\"No saved token found. Will need to login.\")\n", - " client = BrainstemClient()" + "# Pass a token directly to skip the browser flow:\n", + "# client = BrainstemClient(token=\"YOUR_TOKEN\")\n", + "\n", + "# Headless mode — prints a URL + code to enter manually (useful in remote/CI environments):\n", + "# client = BrainstemClient(headless=True)\n", + "\n", + "# Connect to a local development server:\n", + "# client = BrainstemClient(url=\"http://127.0.0.1:8000/\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "# 3. Loading Data Models (Sessions Example - Public Data)\n", + "# 3. Loading Public Data\n", "\n", - "The Brainstem platform hosts both public and private data. Public data can be accessed without authentication. Use the portal=\"public\" parameter to specifically query public repositories. This is useful for exploring datasets like the Allen Institute's Visual Coding – Neuropixels project." + "The BrainSTEM platform hosts both public and private data. Public data can be accessed without authentication. Use `portal=\"public\"` to query public repositories. This is useful for exploring datasets like the Allen Institute's Visual Coding – Neuropixels project." ] }, { @@ -84,8 +87,8 @@ "metadata": {}, "outputs": [], "source": [ - "# Load public projects (no authentication needed)\n", - "public_projects = client.load_model(\"project\", portal=\"public\").json()\n", + "# Load public projects (no authentication needed for public portal)\n", + "public_projects = client.load(\"project\", portal=\"public\").json()\n", "\n", "# Print the number of available public projects\n", "print(f\"Found {len(public_projects.get('projects', []))} public projects\")\n", @@ -95,8 +98,7 @@ "for i, project in enumerate(public_projects.get('projects', [])[:3]):\n", " print(f\"{i+1}. {project.get('name', 'Unnamed')}\")\n", " print(f\" Description: {project.get('description', 'No description')[:100]}...\")\n", - " print(\"\")\n", - "\n" + " print(\"\")" ] }, { @@ -105,20 +107,9 @@ "metadata": {}, "outputs": [], "source": [ - "# To load private projects (requires authentication):\n", - "\"\"\"\n", - "# Load your private projects (requires valid token)\n", - "private_projects = client.load_model(\"project\").json() # default is private\n", - "\n", - "# Print the number of your private projects\n", - "print(f\"Found {len(private_projects.get('projects', []))} private projects\")\n", - "\n", - "# Display your private projects\n", - "print(\"\\nYour private projects:\")\n", - "for i, project in enumerate(private_projects.get('projects', [])[:3]):\n", - " print(f\"{i+1}. {project.get('name', 'Unnamed')}\")\n", - " print(f\" Created by: {project.get('principal_investigator', 'Unknown')}\")\n", - "\"\"\"" + "# Load your private projects (requires authentication — client must be initialised with a valid token)\n", + "# private_projects = client.load(\"project\").json() # default portal is \"private\"\n", + "# print(f\"Found {len(private_projects.get('projects', []))} private projects\")" ] }, { @@ -127,12 +118,10 @@ "source": [ "# 4. Filtering\n", "\n", - "The API supports powerful filtering capabilities to help you find specific records. You can filter by exact matches, partial text matches (case-sensitive or insensitive), and various other criteria. Multiple filters can be combined to narrow your search. Common filter modifiers include .icontains for case-insensitive text search and .iexact for exact matching.\n", + "The API supports powerful filtering capabilities. You can filter by exact matches, partial text matches (case-sensitive or insensitive), and various other criteria. Multiple filters on different fields are combined with AND logic.\n", "\n", "### Filter Modifiers\n", "\n", - "You can use the following filter modifiers to refine your search:\n", - "\n", "- `attribute.contains`: Case-sensitive partial match\n", "- `attribute.icontains`: Case-insensitive partial match\n", "- `attribute.iexact`: Case-insensitive exact match\n", @@ -148,43 +137,81 @@ "metadata": {}, "outputs": [], "source": [ - "# Basic filtering example\n", - "filtered_projects = client.load_model(\n", + "# Case-insensitive name filter\n", + "filtered_projects = client.load(\n", " \"project\",\n", " portal=\"public\",\n", - " filters={'name.icontains': 'Allen'} # Case-insensitive contains\n", + " filters={'name.icontains': 'Allen'}\n", ").json()\n", - "\n", "print(f\"Found {len(filtered_projects.get('projects', []))} projects matching 'Allen'\")\n", "\n", - "# Multiple filters with AND logic\n", - "multi_filter = client.load_model(\n", + "# Filters on different fields are combined with AND logic\n", + "# Note: a Python dict cannot have two identical keys — use different field names to combine filters\n", + "multi_filter = client.load(\n", " \"project\",\n", " portal=\"public\",\n", " filters={\n", - " 'name.icontains': 'institute',\n", - " 'name.iendswith': 'neuropixels' # Fixed typo in operator name\n", + " 'name.icontains': 'Allen',\n", + " 'description.icontains': 'Neuropixels'\n", " }\n", ").json()\n", "\n", - "if isinstance(multi_filter, dict):\n", - " count = len(multi_filter.get('projects', []))\n", - "else:\n", - " count = 0\n", + "print(f\"Found {len(multi_filter.get('projects', []))} projects matching both criteria\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 5. Retrieving Related Data\n", + "# Find the Allen Institute: Visual Coding – Neuropixels project by exact name\n", + "allen_projects = client.load(\n", + " \"project\",\n", + " portal=\"public\",\n", + " filters={'name.iexact': 'Allen Institute: Visual Coding \\u2013 Neuropixels'}\n", + ").json()\n", + "\n", + "project_id = allen_projects['projects'][0]['id']\n", + "\n", + "# Load the project by ID to get its full details\n", + "project_details = client.load(\"project\", portal=\"public\", id=project_id).json()\n", + "project = project_details['project']\n", + "print(f\"Project: {project['name']}\")\n", + "print(f\"Description: {project['description'][:200]}...\")\n", + "\n", + "# Load subjects belonging to this project\n", + "subjects = client.load(\n", + " \"subject\",\n", + " portal=\"public\",\n", + " filters={'projects': project_id}\n", + ").json()\n", + "print(f\"\\nSubjects: {len(subjects['subjects'])}\")\n", "\n", - "print(f\"Found {count} projects matching all criteria\")\n", + "# Load sessions and include experiment data\n", + "sessions = client.load(\n", + " \"session\",\n", + " portal=\"public\",\n", + " filters={'projects': project_id},\n", + " include=['dataacquisition', 'behaviors']\n", + ").json()\n", + "print(f\"Sessions: {len(sessions['sessions'])}\")\n", "\n", - "print(\"\\nResponse type:\", type(multi_filter))\n", - "if isinstance(multi_filter, dict) and 'projects' in multi_filter:\n", - " print(\"Response contains 'projects' key with data\")\n", - " print(f\"The id of project is {multi_filter['projects'][0]['id']} and the name is {multi_filter['projects'][0]['name']}\")" + "# Show details for the first session\n", + "if sessions['sessions']:\n", + " s = sessions['sessions'][0]\n", + " print(f\"\\nFirst session: {s['name']}\")\n", + " print(f\" Data acquisitions: {len(s.get('dataacquisition', []))}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "# 5. Basic Data Retrieval " + "# 6. Sorting\n", + "\n", + "You can control the order of returned results using the `sort` parameter. Sort by any field in ascending order or use a `-` prefix for descending order." ] }, { @@ -193,12 +220,26 @@ "metadata": {}, "outputs": [], "source": [ - "allen_projects = client.load_model(\n", - " \"project\", \n", - " portal=\"public\", \n", - " filters={'name.iexact': 'Allen Institute: Visual Coding – Neuropixels'}\n", - ").json()\n", - "print(allen_projects['projects'][0])" + "# Sort projects alphabetically by name (ascending)\n", + "alpha_projects = client.load(\"project\", portal=\"public\", sort=['name']).json()\n", + "print(\"Alphabetically sorted:\")\n", + "for p in alpha_projects.get('projects', [])[:3]:\n", + " print(f\" {p['name']}\")\n", + "\n", + "# Sort descending\n", + "reverse_alpha = client.load(\"project\", portal=\"public\", sort=['-name']).json()\n", + "print(\"\\nReverse alphabetical:\")\n", + "for p in reverse_alpha.get('projects', [])[:3]:\n", + " print(f\" {p['name']}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 7. Pagination\n", + "\n", + "By default the API returns one page of results. Use `limit` and `offset` for manual pagination, or `load_all=True` to automatically fetch every page and return a combined dict." ] }, { @@ -207,67 +248,24 @@ "metadata": {}, "outputs": [], "source": [ - "# Find the Allen Institute: Visual Coding – Neuropixels project\n", - "allen_projects = client.load_model(\n", - " \"project\", \n", - " portal=\"public\", \n", - " filters={'name.iexact': 'Allen Institute: Visual Coding – Neuropixels'}\n", - ").json()\n", - "\n", - "project_id = allen_projects['projects'][0]['id']\n", - "\n", - "# Get basic project information without including sessions\n", - "project_details = client.load_model(\n", - " \"project\",\n", - " portal=\"public\",\n", - " id=project_id\n", - ").json()\n", - "\n", - "# Access project information\n", - "project = project_details['project']\n", - "print(f\"Project name: {project['name']}\")\n", - "print(f\"Description: {project['description'][:200]}...\")\n", - "\n", - "# Get subjects related to this project\n", - "subjects = client.load_model(\n", - " \"subject\", \n", - " portal=\"public\",\n", - " filters={'projects': project_id}\n", - ").json()\n", - "\n", - "print(f\"\\nThis project has {len(subjects['subjects'])} subjects\")\n", - "\n", - "# Access the sessions included with the project - sessions should be available now\n", - "if 'sessions' in project:\n", - " session_ids = project['sessions']\n", - " print(f\"This project has {len(session_ids)} sessions\")\n", - " \n", - " # Get details for the first session as an example\n", - " if session_ids:\n", - " first_session_id = session_ids[0]\n", - " session_details = client.load_model(\n", - " \"session\",\n", - " portal=\"public\",\n", - " id=first_session_id\n", - " ).json()\n", - " \n", - " # Print session details\n", - " if 'session' in session_details:\n", - " session = session_details['session']\n", - " print(f\"\\nExample session details:\")\n", - " print(f\"ID: {first_session_id}\")\n", - " print(f\"Name: {session['name']}\")\n", - "else:\n", - " print(\"No sessions found in project data\")" + "# Manual pagination: fetch records 1-10, then 11-20\n", + "page1 = client.load(\"project\", portal=\"public\", limit=10, offset=0).json()\n", + "page2 = client.load(\"project\", portal=\"public\", limit=10, offset=10).json()\n", + "print(f\"Page 1: {len(page1['projects'])} projects\")\n", + "print(f\"Page 2: {len(page2['projects'])} projects\")\n", + "\n", + "# Auto-pagination: fetch ALL records across all pages in one call\n", + "all_projects = client.load(\"project\", portal=\"public\", load_all=True)\n", + "print(f\"\\nTotal public projects (all pages): {len(all_projects['projects'])}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "# 6. Sorting\n", + "# 8. Convenience Loaders\n", "\n", - "You can control the order of returned results using the sort parameter. Sort by any field in ascending order (e.g., alphabetically by name) or use a minus sign prefix for descending order." + "Convenience loaders are shortcut methods for the most common models. They set sensible default `include` relations and expose named keyword arguments for the most frequent filter fields." ] }, { @@ -276,35 +274,45 @@ "metadata": {}, "outputs": [], "source": [ - "# Sort projects alphabetically by name\n", - "alpha_projects = client.load_model(\n", - " \"project\",\n", - " portal=\"public\",\n", - " sort=['name'] # Ascending alphabetical order\n", - ").json()\n", - "\n", - "print(\"\\nAlphabetically sorted projects:\")\n", - "for i, project in enumerate(alpha_projects.get('projects', [])[:3]):\n", - " print(f\"{i+1}. {project.get('name')}\")\n", - "\n", - "# Sort projects reverse-alphabetically by name\n", - "reverse_alpha = client.load_model(\n", - " \"project\",\n", - " portal=\"public\",\n", - " sort=['-name'] # Descending alphabetical order\n", - ").json()\n", - "\n", - "print(\"\\nReverse alphabetically sorted projects:\")\n", - "for i, project in enumerate(reverse_alpha.get('projects', [])[:3]):\n", - " print(f\"{i+1}. {project.get('name')}\")" + "# load_project — embeds sessions, subjects, collections, cohorts by default\n", + "projects = client.load_project(portal=\"public\", name=\"Allen\", load_all=True)\n", + "print(f\"Projects matching 'Allen': {len(projects['projects'])}\")\n", + "\n", + "# load_session — embeds dataacquisition, behaviors, manipulations, epochs by default\n", + "# (use portal=\"public\" for public data, or omit for private)\n", + "# sessions = client.load_session(name=\"Rat\", load_all=True)\n", + "\n", + "# load_subject — embeds procedures and subjectlogs by default\n", + "# subjects = client.load_subject(sex=\"M\", projects=\"\", load_all=True)\n", + "\n", + "# load_behavior / load_dataacquisition / load_manipulation — scope by session UUID\n", + "# behaviors = client.load_behavior(session=\"\", load_all=True)\n", + "# dataacquisition = client.load_dataacquisition(session=\"\", load_all=True)\n", + "# manipulations = client.load_manipulation(session=\"\", load_all=True)\n", + "\n", + "# load_procedure — scopes by subject UUID\n", + "# procedures = client.load_procedure(subject=\"\", load_all=True)\n", + "\n", + "# load_collection / load_cohort\n", + "# collection = client.load_collection(name=\"MyCollection\")\n", + "# cohort = client.load_cohort(name=\"MyCohort\")\n", + "\n", + "# All convenience loaders accept extra filters= and a custom include= override:\n", + "# sessions = client.load_session(\n", + "# name=\"Rat\",\n", + "# include=[\"projects\"],\n", + "# filters={\"description.icontains\": \"hippocampus\"},\n", + "# load_all=True,\n", + "# )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "# 7. Updating data\n", - "The API supports creating new records and updating existing ones. When creating a record, provide the required fields for that model type. For updates, fetch the existing record, modify its attributes, and save it back. Both operations require appropriate permissions." + "# 9. Creating and Updating Data\n", + "\n", + "Use `client.save()` to create or update records. Omit `id` to create a new record (POST); provide `id` to update an existing one (PATCH). Write operations require authentication on the private portal." ] }, { @@ -313,62 +321,46 @@ "metadata": {}, "outputs": [], "source": [ - "# Example of updating a session (use your own private data)\n", - "# Note: This is just syntax demonstration and won't execute on public data\n", - "token = None # Input your token here\n", - "if token:\n", - " client = BrainstemClient(token=token)\n", - " print(\"Client initialized with saved token\")\n", - "else:\n", - " # This will prompt for username/password if run\n", - " print(\"No saved token found. Will need to login.\")\n", - " client = BrainstemClient()\n", - "\"\"\"\n", - "# First load a session you have permission to modify\n", - "filtered_session = client.load_model('session', filters={'name.iexact': 'your session'}).json()\n", - "\n", - "# Your existing code works without using a 'session' key\n", - "filtered_session['description'] = 'This is a test description, showing how to update a session'\n", - "\n", - "# Pass the whole filtered_session object to save_model\n", - "updated_session = client.save_model(\n", - " 'session', \n", - " id=filtered_session['sessions'][0]['id'], \n", - " data=filtered_session\n", + "# Create a new session (replace project UUID with your own)\n", + "new_session = client.save(\n", + " \"session\",\n", + " data={\n", + " \"name\": \"My new session\",\n", + " \"description\": \"Created via the Python API\",\n", + " \"projects\": [\"\"],\n", + " }\n", ").json()\n", + "# print(new_session)\n", "\n", - "print(\"Session updated successfully\")\n", - "\"\"\"\n", + "# Update an existing session (replace UUID with your own)\n", + "updated = client.save(\n", + " \"session\",\n", + " id=\"\",\n", + " data={\"description\": \"Updated description\"}\n", + ").json()\n", + "# print(updated)\n", "\n", - "# Creating a new subject - SYNTAX DEMONSTRATION ONLY\n", - "\"\"\"\n", "# Create a new subject\n", - "new_subject_data = {\n", - " 'name': 'Test Subject 001',\n", - " 'species': 'mouse',\n", - " 'sex': 'M',\n", - " 'projects': ['your-project-id'] # Project this subject belongs to\n", - "}\n", - "\n", - "# Save the new subject\n", - "new_subject = client.save_model(\n", - " 'subject',\n", - " data=new_subject_data\n", + "new_subject = client.save(\n", + " \"subject\",\n", + " data={\n", + " \"name\": \"Sub001\",\n", + " \"sex\": \"M\",\n", + " \"projects\": [\"\"],\n", + " }\n", ").json()\n", + "# print(new_subject)\n", "\n", - "print(f\"New subject created with ID: {new_subject['subject']['id']}\")\n", - "\"\"\"\n", - "\n", - "print(\"To use them, replace placeholder IDs with your own private data IDs.\")\n", - "print(\"Only run these commands on data you have permission to modify.\")" + "print(\"Replace the placeholder UUIDs above with your own IDs before running.\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "# 8. Deleting data\n", - "You can remove records from the database using their IDs. Deletion is permanent, so use this operation with caution. The API will return a success status code (204) when deletion is successful. Like other write operations, deletion requires appropriate permissions." + "# 10. Deleting Data\n", + "\n", + "Use `client.delete()` with a model name and record UUID. A successful deletion returns HTTP 204 (no content). This operation is **permanent** — use with care." ] }, { @@ -377,43 +369,15 @@ "metadata": {}, "outputs": [], "source": [ - "token = None # Input your token here.\n", - "if token:\n", - " client = BrainstemClient(token=token)\n", - " print(\"Client initialized with saved token\")\n", + "# Delete a record by UUID (replace with your own ID)\n", + "response = client.delete(\"session\", id=\"\")\n", + "\n", + "if response.status_code == 204:\n", + " print(\"Session deleted successfully.\")\n", "else:\n", - " # This will prompt for username/password if run\n", - " print(\"No saved token found. Will need to login.\")\n", - " client = BrainstemClient()\n", - "\n", - "# Example of deleting a session (demonstration only)\n", - "\"\"\"\n", - "# First create a test session that we can safely delete\n", - "new_session_data = {\n", - " 'name': 'Temporary Test Session',\n", - " 'description': 'This session will be deleted',\n", - " 'project': 'your-project-id' # Replace with your own project ID\n", - "}\n", - "\n", - "# Create the session\n", - "created_session = client.save_model('session', data=new_session_data).json()\n", - "\n", - "if 'session' in created_session:\n", - " session_id = created_session['session']['id']\n", - " print(f\"Created temporary session with ID: {session_id}\")\n", - " \n", - " # Now delete the session we just created\n", - " delete_response = client.delete_model(\"session\", id=session_id)\n", - " \n", - " # Check if deletion was successful\n", - " if delete_response.status_code == 204: # 204 No Content indicates success\n", - " print(\"Session deleted successfully\")\n", - " else:\n", - " print(f\"Delete failed with status code: {delete_response.status_code}\")\n", - "\"\"\"\n", - "\n", - "print(\"To use it, you would need to have write permissions and valid IDs.\")\n", - "print(\"Only delete data that you have created and have permission to remove.\")" + " print(f\"Delete failed: {response.status_code} — {response.text}\")\n", + "\n", + "print(\"Replace the placeholder UUID above with your own ID before running.\")" ] } ], diff --git a/brainstem_api_tutorial.py b/brainstem_api_tutorial.py index 2b94ce2..a125d97 100644 --- a/brainstem_api_tutorial.py +++ b/brainstem_api_tutorial.py @@ -1,14 +1,20 @@ from brainstem_api_tools import BrainstemClient -# 0. Load the client. User email and password will be requested. +# 0. Load the client. +# On first run a browser window opens for secure login (supports 2FA). +# The token is cached at ~/.config/brainstem/token and reused automatically. +# Pass a saved token directly to skip the browser flow: +# client = BrainstemClient(token="YOUR_TOKEN") +# To connect to a local development server: +# client = BrainstemClient(url="http://127.0.0.1:8000/") client = BrainstemClient() # Loading sessions. -## load_model can be used to load any model: +## load can be used to load any model: ## We just need to pass our settings and the name of the model. -output1 = client.load_model('session').json() +output1 = client.load('session').json() ### We can fetch a single session entry from the loaded models. session = output1["sessions"][0] @@ -16,29 +22,82 @@ ## We can also filter the models by providing a dictionary where ## the keys are the fields and the values the possible contents. ## In this example, it will just load sessions whose name is "yeah". -output1 = client.load_model('session', filters={'name': 'yeah'}).json() +output1 = client.load('session', filters={'name': 'yeah'}).json() ## Loaded models can be sorted by different criteria applying to -## their fields. In this example, sessions will be sorted in -## descending ording according to their name. -output1 = client.load_model('session', sort=['-name']).json() +## their fields. In this example, sessions will be sorted in +## descending order according to their name. +output1 = client.load('session', sort=['-name']).json() ## In some cases models contain relations with other models, and ## they can be also loaded with the models if requested. In this ## example, all the projects, data acquisition, behaviors and -## manipulations related to each session will be included. -output1 = client.load_model('session', include=['projects', 'dataacquisition', 'behaviors', 'manipulations']).json() +## manipulations related to each session will be included. +output1 = client.load('session', include=['projects', 'dataacquisition', 'behaviors', 'manipulations']).json() -### The list of related experiment data can be retrived from the +### The list of related experiment data can be retrieved from the ### returned dictionary. dataacquisition = output1["sessions"][0]["dataacquisition"] ## All these options can be combined to suit the requirements -## of the users. For example, we can get only the session that +## of the users. For example, we can get only the sessions that ## contain the word "Rat" in their name, sorted in descending ## order by their name and including the related projects. -output1 = client.load_model('session', filters={'name.icontains': 'Rat'}, sort=["-name"], include=['projects']).json() +output1 = client.load('session', filters={'name.icontains': 'Rat'}, sort=["-name"], include=['projects']).json() + +## Pagination: load the second page of results (records 21-40). +output1 = client.load('session', limit=20, offset=20).json() + +## load_all=True automatically fetches every page and returns a +## combined dict — useful when the record count exceeds one page. +output1 = client.load('session', load_all=True) +all_sessions = output1["sessions"] + + +# Convenience loaders + +## Each convenience loader sets sensible defaults (related models +## included automatically) and exposes named keyword arguments for +## the most common filter fields. + +## load_session includes dataacquisition, behaviors, manipulations +## and epochs by default. Filter by name substring: +output_sessions = client.load_session(name='Rat', load_all=True) + +## load_subject includes procedures and subjectlogs. +## Filter by sex ('M' / 'F') and/or project UUID: +project_uuid = 'e7475834-7733-48cf-9e3b-f4f2d2d0305a' +output_subjects = client.load_subject(sex='M', projects=project_uuid, load_all=True) + +## load_project includes sessions, subjects, collections and cohorts. +output_projects = client.load_project(name='MyProject') + +## load_collection includes sessions by default. +output_collection = client.load_collection(name='MyCollection') + +## load_cohort includes subjects by default. +output_cohort = client.load_cohort(name='MyCohort') + +## load_behavior / load_dataacquisition / load_manipulation all +## accept session= to scope results to a single session. +session_uuid = '0e39c1fd-f413-4142-95f7-f50185e81fa4' +output_behaviors = client.load_behavior(session=session_uuid, load_all=True) +output_dataacquisition = client.load_dataacquisition(session=session_uuid, load_all=True) +output_manipulations = client.load_manipulation(session=session_uuid, load_all=True) + +## load_procedure scopes by subject UUID. +subject_uuid = 'bfb0e3e2-2c48-4b72-9034-1ef50c5c432a' +output_procedures = client.load_procedure(subject=subject_uuid, load_all=True) + +## Any convenience loader also accepts extra filters= and custom +## include= overrides just like load() itself. +output_sessions = client.load_session( + name='Rat', + include=['projects'], + filters={'description.icontains': 'hippocampus'}, + load_all=True, +) # Updating a session @@ -48,7 +107,7 @@ ## previously loaded sessions. session = {} session["description"] = 'new description' -output2 = client.save_model("session", id="0e39c1fd-f413-4142-95f7-f50185e81fa4", data=session).json() +output2 = client.save("session", id="0e39c1fd-f413-4142-95f7-f50185e81fa4", data=session).json() # Creating a new session @@ -61,10 +120,18 @@ session["description"] = 'description' ## Submitting session -output3 = client.save_model("session", data=session).json() +output3 = client.save("session", data=session).json() + + +# Deleting a session + +## Pass the model name and the UUID of the record to remove. +response = client.delete("session", id="0e39c1fd-f413-4142-95f7-f50185e81fa4") +if response.status_code == 204: + print("Session deleted") # Load public projects ## Request the public data by defining the portal to be public -output4 = client.load_model("project", portal="public").json() +output4 = client.load("project", portal="public").json() diff --git a/pyproject.toml b/pyproject.toml index f5b211d..e46a1f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,33 @@ [build-system] requires = ["setuptools>=61.0", "wheel"] -build-backend = "setuptools.build_meta" \ No newline at end of file +build-backend = "setuptools.build_meta" + +[project] +name = "brainstem_python_api_tools" +version = "2.0.0" +description = "A Python toolset for interacting with the BrainSTEM API." +readme = "README.md" +license = { text = "MIT" } +authors = [ + { name = "BrainSTEM Team", email = "petersen.peter@gmail.com" } +] +requires-python = ">=3.8" +dependencies = [ + "requests", +] + +[project.optional-dependencies] +dev = [ + "pytest", + "pytest-mock", +] + +[project.urls] +Homepage = "https://github.com/brainstem-org/brainstem_python_api_tools" + +[project.scripts] +brainstem = "brainstem_api_tools.cli:main" + +[tool.setuptools.packages.find] +where = ["."] +include = ["brainstem_api_tools*"] diff --git a/setup.py b/setup.py deleted file mode 100644 index 42d8586..0000000 --- a/setup.py +++ /dev/null @@ -1,24 +0,0 @@ -from setuptools import setup, find_packages - -setup( - name="brainstem_python_api_tools", - version="0.2.0", - packages=find_packages(), - install_requires=[ - "requests", - "jupyter", - ], - author="BrainSTEM Team", - author_email="petersen.peter@gmail.com", - description="A Python toolset for interacting with API of BrainSTEM.", - long_description=open("README.md").read(), - long_description_content_type="text/markdown", - url="https://github.com/brainstem-org/brainstem_python_api_tools", - license="MIT", - classifiers=[ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: MIT License", - "Operating System :: OS Independent", - ], - python_requires='>=3.7', -) diff --git a/__init__.py b/tests/__init__.py similarity index 100% rename from __init__.py rename to tests/__init__.py diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 0000000..68d3531 --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,452 @@ +"""Unit tests for BrainstemClient. + +All HTTP calls are intercepted with pytest-mock / unittest.mock so no +network access is required. +""" + +import pytest +from unittest.mock import MagicMock, patch, PropertyMock + +from brainstem_api_tools.brainstem_api_client import ( + BrainstemClient, + AuthenticationError, + _MODEL_TO_APP, + _resolve_model, + _resolve_portal, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +TOKEN = "test-token-abc123" + + +def make_client(token=TOKEN): + """Return a BrainstemClient with a pre-set token (no auth flow).""" + return BrainstemClient(token=token) + + +def mock_response(status_code=200, json_body=None): + resp = MagicMock() + resp.status_code = status_code + resp.ok = status_code < 400 + resp.json.return_value = json_body or {} + return resp + + +# --------------------------------------------------------------------------- +# _resolve_model / _resolve_portal +# --------------------------------------------------------------------------- + +class TestResolvers: + def test_resolve_model_string(self): + assert _resolve_model("session") == "session" + + def test_resolve_model_invalid(self): + with pytest.raises(ValueError, match="Unknown model 'foobar'"): + _resolve_model("foobar") + + def test_resolve_portal_string(self): + assert _resolve_portal("public") == "public" + + def test_resolve_portal_invalid(self): + with pytest.raises(ValueError, match="Unknown portal 'nope'"): + _resolve_portal("nope") + + +# --------------------------------------------------------------------------- +# BrainstemClient construction +# --------------------------------------------------------------------------- + +class TestClientInit: + def test_token_set_directly(self): + client = make_client() + assert client._token == TOKEN + + def test_auth_header_on_session(self): + client = make_client() + assert client._session.headers["Authorization"] == f"Bearer {TOKEN}" + + def test_cached_token_used_on_init(self, tmp_path, monkeypatch): + token_file = tmp_path / "token" + token_file.write_text("cached-token") + monkeypatch.setattr( + "brainstem_api_tools.brainstem_api_client._TOKEN_FILE", token_file + ) + client = BrainstemClient() + assert client._token == "cached-token" + + def test_device_flow_triggered_when_no_cache(self, tmp_path, monkeypatch): + token_file = tmp_path / "token" # does not exist + monkeypatch.setattr( + "brainstem_api_tools.brainstem_api_client._TOKEN_FILE", token_file + ) + with patch.object(BrainstemClient, "_device_auth_flow", return_value="new-token") as mock_flow, \ + patch.object(BrainstemClient, "_save_token") as mock_save: + client = BrainstemClient() + mock_flow.assert_called_once() + mock_save.assert_called_once_with("new-token") + assert client._token == "new-token" + + +# --------------------------------------------------------------------------- +# load() +# --------------------------------------------------------------------------- + +class TestLoad: + def setup_method(self): + self.client = make_client() + + def _mock_get(self, json_body=None, status=200): + resp = mock_response(status, json_body) + self.client._session.get = MagicMock(return_value=resp) + return resp + + def test_load_all_sessions(self): + self._mock_get({"sessions": []}) + self.client.load("session") + self.client._session.get.assert_called_once() + url = self.client._session.get.call_args[0][0] + assert "stem/session/" in url + + def test_load_with_id(self): + uid = "00000000-0000-0000-0000-000000000001" + self._mock_get({"session": {}}) + self.client.load("session", id=uid) + url = self.client._session.get.call_args[0][0] + assert uid in url + + def test_load_filters_built_correctly(self): + self._mock_get() + self.client.load("session", filters={"name.icontains": "rat"}) + params = self.client._session.get.call_args[1]["params"] + assert params["filter{name.icontains}"] == "rat" + + def test_load_sort_built_correctly(self): + self._mock_get() + self.client.load("session", sort=["-name", "date"]) + params = self.client._session.get.call_args[1]["params"] + assert params["sort[]"] == ["-name", "date"] + + def test_load_include_appends_wildcard(self): + self._mock_get() + self.client.load("session", include=["behaviors"]) + params = self.client._session.get.call_args[1]["params"] + assert params["include[]"] == ["behaviors.*"] + + def test_load_pagination_params(self): + self._mock_get() + self.client.load("session", limit=50, offset=20) + params = self.client._session.get.call_args[1]["params"] + assert params["limit"] == 50 + assert params["offset"] == 20 + + def test_load_public_portal(self): + self._mock_get() + self.client.load("project", portal="public") + url = self.client._session.get.call_args[0][0] + assert "/public/" in url + + def test_load_invalid_model_raises(self): + with pytest.raises(ValueError, match="Unknown model"): + self.client.load("notamodel") + + def test_load_invalid_portal_raises(self): + with pytest.raises(ValueError, match="Unknown portal"): + self.client.load("session", portal="badportal") + + def test_model_to_app_routing(self): + """Spot-check a few app assignments.""" + cases = [ + ("session", "stem"), + ("breeding", "stem"), + ("project_membership_invitation", "stem"), + ("procedure", "modules"), + ("behavioralassay", "personal_attributes"), + ("license", "personal_attributes"), + ("setup", "personal_attributes"), + ("consumable", "resources"), + ("behavioralparadigm", "taxonomies"), + ("behavioralcategory", "taxonomies"), + ("regulatoryauthority", "taxonomies"), + ("brainregion", "taxonomies"), + ("publication", "dissemination"), + ("laboratory", "users"), + ("group_membership_invitation", "users"), + ("group_membership_request", "users"), + ] + for model, expected_app in cases: + self._mock_get() + self.client.load(model) + url = self.client._session.get.call_args[0][0] + assert f"/{expected_app}/{model}/" in url, ( + f"Expected app '{expected_app}' for model '{model}', got URL: {url}" + ) + + +# --------------------------------------------------------------------------- +# save() +# --------------------------------------------------------------------------- + +class TestSave: + def setup_method(self): + self.client = make_client() + + def test_create_uses_post(self): + self.client._session.post = MagicMock(return_value=mock_response(201)) + self.client.save("session", data={"name": "x", "projects": ["uuid"]}) + self.client._session.post.assert_called_once() + + def test_update_uses_patch(self): + uid = "00000000-0000-0000-0000-000000000002" + self.client._session.patch = MagicMock(return_value=mock_response(200)) + self.client.save("session", id=uid, data={"description": "updated"}) + self.client._session.patch.assert_called_once() + url = self.client._session.patch.call_args[0][0] + assert uid in url + + def test_save_invalid_model_raises(self): + with pytest.raises(ValueError): + self.client.save("ghost") + + +# --------------------------------------------------------------------------- +# delete() +# --------------------------------------------------------------------------- + +class TestDelete: + def setup_method(self): + self.client = make_client() + + def test_delete_sends_correct_url(self): + uid = "00000000-0000-0000-0000-000000000003" + self.client._session.delete = MagicMock(return_value=mock_response(204)) + self.client.delete("session", id=uid) + url = self.client._session.delete.call_args[0][0] + assert uid in url + assert "stem/session/" in url + + def test_delete_without_id_raises(self): + with pytest.raises(ValueError, match="'id' is required"): + self.client.delete("session") + + def test_delete_invalid_model_raises(self): + with pytest.raises(ValueError): + self.client.delete("ghost", id="some-uuid") + + +# --------------------------------------------------------------------------- +# Device auth flow +# --------------------------------------------------------------------------- + +class TestDeviceAuthFlow: + def test_success_flow(self): + client = make_client() # skip auth + + device_resp = MagicMock() + device_resp.raise_for_status = MagicMock() + device_resp.json.return_value = { + "device_code": "dev123", + "user_code": "ABC-DEF", + "verification_uri": "https://brainstem.org/activate", + "verification_uri_complete": "https://brainstem.org/activate?code=ABC-DEF", + "interval": 0, + } + + pending_resp = MagicMock() + pending_resp.raise_for_status = MagicMock() + pending_resp.json.return_value = {"status": "authorization_pending"} + + success_resp = MagicMock() + success_resp.raise_for_status = MagicMock() + success_resp.json.return_value = {"status": "success", "token": "new-jwt"} + + with patch("requests.post", side_effect=[device_resp, pending_resp, success_resp]), \ + patch("webbrowser.open"): + token = client._device_auth_flow() + + assert token == "new-jwt" + + def test_access_denied_raises(self): + client = make_client() + + device_resp = MagicMock() + device_resp.raise_for_status = MagicMock() + device_resp.json.return_value = { + "device_code": "dev123", + "user_code": "XYZ", + "verification_uri": "https://brainstem.org/activate", + "verification_uri_complete": "https://brainstem.org/activate?code=XYZ", + "interval": 0, + } + denied_resp = MagicMock() + denied_resp.raise_for_status = MagicMock() + denied_resp.json.return_value = {"error": "access_denied"} + + with patch("requests.post", side_effect=[device_resp, denied_resp]), \ + patch("webbrowser.open"): + with pytest.raises(AuthenticationError, match="Authorization failed"): + client._device_auth_flow() + + +# --------------------------------------------------------------------------- +# load_all auto-pagination +# --------------------------------------------------------------------------- + +class TestLoadAll: + def setup_method(self): + self.client = make_client() + + def _mock_pages(self, pages): + responses = [] + for page in pages: + r = MagicMock() + r.raise_for_status = MagicMock() + r.json.return_value = page + responses.append(r) + self.client._session.get = MagicMock(side_effect=responses) + + def test_single_page(self): + self._mock_pages([ + {"sessions": [{"id": "1"}, {"id": "2"}], "count": 2}, + ]) + result = self.client.load("session", load_all=True) + assert isinstance(result, dict) + assert len(result["sessions"]) == 2 + + def test_two_pages_merged(self): + self._mock_pages([ + {"sessions": [{"id": "1"}, {"id": "2"}], "count": 3}, + {"sessions": [{"id": "3"}], "count": 3}, + ]) + result = self.client.load("session", load_all=True) + assert len(result["sessions"]) == 3 + assert [r["id"] for r in result["sessions"]] == ["1", "2", "3"] + + def test_load_all_false_returns_response(self): + r = MagicMock() + self.client._session.get = MagicMock(return_value=r) + result = self.client.load("session", load_all=False) + assert result is r + + +# --------------------------------------------------------------------------- +# Convenience loaders +# --------------------------------------------------------------------------- + +class TestConvenienceLoaders: + def setup_method(self): + self.client = make_client() + self.client._session.get = MagicMock(return_value=mock_response(200, {})) + + def _get_params(self): + return self.client._session.get.call_args[1].get("params", {}) + + def test_load_project_default_include(self): + self.client.load_project() + params = self._get_params() + assert set(params["include[]"]) == { + "sessions.*", "subjects.*", "collections.*", "cohorts.*" + } + + def test_load_project_name_filter(self): + self.client.load_project(name="Allen") + params = self._get_params() + assert params["filter{name.icontains}"] == "Allen" + + def test_load_subject_default_include(self): + self.client.load_subject() + params = self._get_params() + assert set(params["include[]"]) == {"procedures.*", "subjectlogs.*"} + + def test_load_subject_sex_filter(self): + self.client.load_subject(sex="M") + assert self._get_params()["filter{sex}"] == "M" + + def test_load_subject_strain_filter(self): + uuid = "aaaaaaaa-0000-0000-0000-000000000000" + self.client.load_subject(strain=uuid) + assert self._get_params()["filter{strain.id}"] == uuid + + def test_load_session_default_include(self): + self.client.load_session() + params = self._get_params() + assert set(params["include[]"]) == { + "dataacquisition.*", "behaviors.*", "manipulations.*", "epochs.*" + } + + def test_load_session_project_filter(self): + uuid = "bbbbbbbb-0000-0000-0000-000000000000" + self.client.load_session(projects=uuid) + assert self._get_params()["filter{projects.id}"] == uuid + + def test_load_collection_default_include(self): + self.client.load_collection() + assert self._get_params()["include[]"] == ["sessions.*"] + + def test_load_cohort_default_include(self): + self.client.load_cohort() + assert self._get_params()["include[]"] == ["subjects.*"] + + def test_load_behavior_session_filter(self): + uuid = "cccccccc-0000-0000-0000-000000000000" + self.client.load_behavior(session=uuid) + assert self._get_params()["filter{session.id}"] == uuid + + def test_load_dataacquisition_session_filter(self): + uuid = "dddddddd-0000-0000-0000-000000000000" + self.client.load_dataacquisition(session=uuid) + assert self._get_params()["filter{session.id}"] == uuid + + def test_load_manipulation_session_filter(self): + uuid = "eeeeeeee-0000-0000-0000-000000000000" + self.client.load_manipulation(session=uuid) + assert self._get_params()["filter{session.id}"] == uuid + + def test_load_procedure_subject_filter(self): + uuid = "ffffffff-0000-0000-0000-000000000000" + self.client.load_procedure(subject=uuid) + assert self._get_params()["filter{subject.id}"] == uuid + + def test_custom_include_overrides_default(self): + self.client.load_session(include=["behaviors"]) + assert self._get_params()["include[]"] == ["behaviors.*"] + + def test_extra_filters_merged_with_field_kwargs(self): + self.client.load_session( + name="Rat", + filters={"description.icontains": "hippocampus"}, + ) + params = self._get_params() + assert params["filter{name.icontains}"] == "Rat" + assert params["filter{description.icontains}"] == "hippocampus" + + +# --------------------------------------------------------------------------- +# CLI logout +# --------------------------------------------------------------------------- + +class TestCLILogout: + def test_logout_removes_token_file(self, tmp_path, capsys): + import brainstem_api_tools.cli as cli_module + token_file = tmp_path / "token" + token_file.write_text("mytoken") + + with patch.object(cli_module, "_TOKEN_FILE", token_file), \ + patch("sys.argv", ["brainstem", "logout"]): + cli_module.main() + + assert not token_file.exists() + assert "Logged out" in capsys.readouterr().out + + def test_logout_no_token_file(self, tmp_path, capsys): + import brainstem_api_tools.cli as cli_module + token_file = tmp_path / "token" # does not exist + + with patch.object(cli_module, "_TOKEN_FILE", token_file), \ + patch("sys.argv", ["brainstem", "logout"]): + cli_module.main() + + assert "No cached token found" in capsys.readouterr().out From 477bcef8ba7ff6c59722b4f5ad69a28dfbee5d7e Mon Sep 17 00:00:00 2001 From: "Peter C. Petersen" Date: Mon, 30 Mar 2026 10:48:22 +0200 Subject: [PATCH 2/3] Add timeouts, retries, new loaders, tests & CI Introduce request retries and a DEFAULT_TIMEOUT applied to all HTTP calls, add timeout parameters to get/post/patch/delete, and raise AuthenticationError on 401 during paginated loads. Add device auth polling timeout (max_wait), context-manager support (__enter__/__exit__), and validation for unknown kwargs in _convenience_load; change many convenience loader offsets to default to None. Add numerous new convenience loader helpers (procedurelog, subjectlog, equipment, consumablestock, behavioralassay, datastorage, setup, hardwaredevice, brainregion, species, strain, publication, laboratory). Bump package __version__ to 2.0.0 and require requests>=2.28, update tutorial to use placeholder UUIDs, add comprehensive unit tests for timeouts, auth polling, context manager, loader validation and new loaders, and add a GitHub Actions workflow to run tests on push/PR across Python versions. --- .github/workflows/tests.yml | 26 ++ brainstem_api_tools/__init__.py | 2 + brainstem_api_tools/brainstem_api_client.py | 338 +++++++++++++++++++- brainstem_api_tutorial.py | 15 +- pyproject.toml | 2 +- tests/test_client.py | 276 ++++++++++++++++ 6 files changed, 633 insertions(+), 26 deletions(-) create mode 100644 .github/workflows/tests.yml diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 0000000..2f8faaf --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,26 @@ +name: tests + +on: + push: + branches: ["main"] + pull_request: + +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.8", "3.10", "3.12"] + + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install package and dev deps + run: pip install -e ".[dev]" + + - name: Run tests + run: pytest --tb=short diff --git a/brainstem_api_tools/__init__.py b/brainstem_api_tools/__init__.py index e98a356..7b19a80 100644 --- a/brainstem_api_tools/__init__.py +++ b/brainstem_api_tools/__init__.py @@ -1 +1,3 @@ from .brainstem_api_client import BrainstemClient, ModelType, PortalType, AuthenticationError + +__version__ = "2.0.0" diff --git a/brainstem_api_tools/brainstem_api_client.py b/brainstem_api_tools/brainstem_api_client.py index 31c7153..4c97a90 100644 --- a/brainstem_api_tools/brainstem_api_client.py +++ b/brainstem_api_tools/brainstem_api_client.py @@ -4,9 +4,12 @@ import webbrowser from enum import Enum from pathlib import Path +from typing import Union import requests +from requests.adapters import HTTPAdapter from requests.models import Response +from urllib3.util.retry import Retry # --------------------------------------------------------------------------- @@ -119,6 +122,7 @@ class AuthenticationError(Exception): class BrainstemClient: BASE_URL = "https://www.brainstem.org/" + DEFAULT_TIMEOUT: int = 30 # seconds; applied to all HTTP calls def __init__( self, @@ -130,6 +134,11 @@ def __init__( self._address = base + "api/" self._session = requests.Session() + # Automatically retry transient server errors + _retry = Retry(total=3, backoff_factor=0.5, status_forcelist={502, 503, 504}) + self._session.mount("https://", HTTPAdapter(max_retries=_retry)) + self._session.mount("http://", HTTPAdapter(max_retries=_retry)) + if token: self._token = token else: @@ -156,7 +165,7 @@ def _save_token(self, token: str) -> None: _TOKEN_FILE.write_text(token) _TOKEN_FILE.chmod(stat.S_IRUSR | stat.S_IWUSR) # 0o600 - def _device_auth_flow(self, headless: bool = False) -> str: + def _device_auth_flow(self, headless: bool = False, max_wait: int = 900) -> str: """Run the BrainSTEM device authorization flow and return a token. The user authenticates in their browser (supports 2FA). No @@ -166,6 +175,7 @@ def _device_auth_flow(self, headless: bool = False) -> str: ---------- headless : If ``True``, print the verification URI and user code instead of opening a browser window. + max_wait : Maximum seconds to wait for browser approval (default: 900). """ # Step 1 — initiate a device session resp = requests.post( @@ -186,8 +196,13 @@ def _device_auth_flow(self, headless: bool = False) -> str: webbrowser.open(data["verification_uri_complete"]) print("Waiting for browser approval...") - # Step 3 — poll until resolved + # Step 3 — poll until resolved (timeout after max_wait seconds) + deadline = time.monotonic() + max_wait while True: + if time.monotonic() > deadline: + raise AuthenticationError( + f"Device authorization timed out after {max_wait} seconds." + ) time.sleep(interval) poll = requests.post( self._address + "auth/device/token/", @@ -233,7 +248,7 @@ def load(self, include: list = None, limit: int = None, offset: int = None, - load_all: bool = False): + load_all: bool = False) -> Union[Response, dict]: """Load one or more records of *model*. Parameters @@ -271,7 +286,7 @@ def load(self, params["offset"] = offset if not load_all: - return self._session.get(url, params=params) + return self._session.get(url, params=params, timeout=self.DEFAULT_TIMEOUT) # --- auto-paginate and merge all pages --- page_size = limit or 100 @@ -282,7 +297,11 @@ def load(self, records_key: str = None while True: - resp = self._session.get(url, params=params) + resp = self._session.get(url, params=params, timeout=self.DEFAULT_TIMEOUT) + if resp.status_code == 401: + raise AuthenticationError( + "API token is invalid or expired. Run `brainstem login` to re-authenticate." + ) resp.raise_for_status() data = resp.json() @@ -330,10 +349,10 @@ def save(self, if id is not None: url = self._build_url(portal, app, model, id, options) - return self._session.patch(url, json=data) + return self._session.patch(url, json=data, timeout=self.DEFAULT_TIMEOUT) else: url = self._build_url(portal, app, model, options=options) - return self._session.post(url, json=data) + return self._session.post(url, json=data, timeout=self.DEFAULT_TIMEOUT) def delete(self, model, @@ -353,7 +372,13 @@ def delete(self, portal = _resolve_portal(portal) app = _MODEL_TO_APP[model] url = self._build_url(portal, app, model, id) - return self._session.delete(url) + return self._session.delete(url, timeout=self.DEFAULT_TIMEOUT) + + def __enter__(self): + return self + + def __exit__(self, *args): + self._session.close() # ------------------------------------------------------------------ # Convenience loaders (mirror the MATLAB load_* helpers) @@ -363,6 +388,11 @@ def _convenience_load(self, model, default_include, filter_map, portal, id, filters, sort, include, limit, offset, load_all, **field_kwargs): """Shared implementation for all convenience loaders.""" + unknown = set(field_kwargs) - set(filter_map) + if unknown: + raise TypeError( + f"Unexpected keyword argument(s): {', '.join(sorted(unknown))}" + ) merged_filters = dict(filters or {}) for kwarg, api_field in filter_map.items(): value = field_kwargs.get(kwarg) @@ -385,7 +415,7 @@ def load_project(self, portal="private", id: str = None, subjects: str = None, tags: str = None, filters: dict = None, sort: list = None, include: list = None, limit: int = None, - offset: int = 0, load_all: bool = False): + offset: int = None, load_all: bool = False): """Load project(s). Embeds sessions, subjects, collections and cohorts by default. Parameters @@ -410,7 +440,7 @@ def load_subject(self, portal="private", id: str = None, strain: str = None, sex: str = None, tags: str = None, filters: dict = None, sort: list = None, include: list = None, limit: int = None, - offset: int = 0, load_all: bool = False): + offset: int = None, load_all: bool = False): """Load subject(s). Embeds procedures and subjectlogs by default. Parameters @@ -436,7 +466,7 @@ def load_session(self, portal="private", id: str = None, datastorage: str = None, tags: str = None, filters: dict = None, sort: list = None, include: list = None, limit: int = None, - offset: int = 0, load_all: bool = False): + offset: int = None, load_all: bool = False): """Load session(s). Embeds dataacquisition, behaviors, manipulations and epochs by default. Parameters @@ -460,7 +490,7 @@ def load_collection(self, portal="private", id: str = None, name: str = None, tags: str = None, filters: dict = None, sort: list = None, include: list = None, limit: int = None, - offset: int = 0, load_all: bool = False): + offset: int = None, load_all: bool = False): """Load collection(s). Embeds sessions by default. Parameters @@ -481,7 +511,7 @@ def load_cohort(self, portal="private", id: str = None, name: str = None, tags: str = None, filters: dict = None, sort: list = None, include: list = None, limit: int = None, - offset: int = 0, load_all: bool = False): + offset: int = None, load_all: bool = False): """Load cohort(s). Embeds subjects by default. Parameters @@ -502,7 +532,7 @@ def load_behavior(self, portal="private", id: str = None, session: str = None, tags: str = None, filters: dict = None, sort: list = None, include: list = None, limit: int = None, - offset: int = 0, load_all: bool = False): + offset: int = None, load_all: bool = False): """Load behavior record(s). Parameters @@ -523,7 +553,7 @@ def load_dataacquisition(self, portal="private", id: str = None, session: str = None, tags: str = None, filters: dict = None, sort: list = None, include: list = None, limit: int = None, - offset: int = 0, load_all: bool = False): + offset: int = None, load_all: bool = False): """Load data acquisition record(s). Parameters @@ -544,7 +574,7 @@ def load_manipulation(self, portal="private", id: str = None, session: str = None, tags: str = None, filters: dict = None, sort: list = None, include: list = None, limit: int = None, - offset: int = 0, load_all: bool = False): + offset: int = None, load_all: bool = False): """Load manipulation record(s). Parameters @@ -565,7 +595,7 @@ def load_procedure(self, portal="private", id: str = None, subject: str = None, tags: str = None, filters: dict = None, sort: list = None, include: list = None, limit: int = None, - offset: int = 0, load_all: bool = False): + offset: int = None, load_all: bool = False): """Load procedure record(s). Parameters @@ -581,3 +611,277 @@ def load_procedure(self, portal="private", id: str = None, limit=limit, offset=offset, load_all=load_all, subject=subject, tags=tags, ) + + def load_procedurelog(self, portal="private", id: str = None, + procedure: str = None, tags: str = None, + filters: dict = None, sort: list = None, + include: list = None, limit: int = None, + offset: int = None, load_all: bool = False): + """Load procedure log record(s). + + Parameters + ---------- + procedure : Filter by procedure UUID. + tags : Filter by tag. + """ + return self._convenience_load( + "procedurelog", + default_include=[], + filter_map={"procedure": "procedure.id", "tags": "tags"}, + portal=portal, id=id, filters=filters, sort=sort, include=include, + limit=limit, offset=offset, load_all=load_all, + procedure=procedure, tags=tags, + ) + + def load_subjectlog(self, portal="private", id: str = None, + subject: str = None, tags: str = None, + filters: dict = None, sort: list = None, + include: list = None, limit: int = None, + offset: int = None, load_all: bool = False): + """Load subject log record(s). + + Parameters + ---------- + subject : Filter by subject UUID. + tags : Filter by tag. + """ + return self._convenience_load( + "subjectlog", + default_include=[], + filter_map={"subject": "subject.id", "tags": "tags"}, + portal=portal, id=id, filters=filters, sort=sort, include=include, + limit=limit, offset=offset, load_all=load_all, + subject=subject, tags=tags, + ) + + def load_equipment(self, portal="private", id: str = None, + name: str = None, setup: str = None, tags: str = None, + filters: dict = None, sort: list = None, + include: list = None, limit: int = None, + offset: int = None, load_all: bool = False): + """Load equipment record(s). + + Parameters + ---------- + name : Filter by equipment name (case-insensitive contains). + setup : Filter by setup UUID. + tags : Filter by tag. + """ + return self._convenience_load( + "equipment", + default_include=[], + filter_map={"name": "name.icontains", "setup": "setup.id", "tags": "tags"}, + portal=portal, id=id, filters=filters, sort=sort, include=include, + limit=limit, offset=offset, load_all=load_all, + name=name, setup=setup, tags=tags, + ) + + def load_consumablestock(self, portal="private", id: str = None, + tags: str = None, + filters: dict = None, sort: list = None, + include: list = None, limit: int = None, + offset: int = None, load_all: bool = False): + """Load consumable stock record(s). + + Parameters + ---------- + tags : Filter by tag. + """ + return self._convenience_load( + "consumablestock", + default_include=[], + filter_map={"tags": "tags"}, + portal=portal, id=id, filters=filters, sort=sort, include=include, + limit=limit, offset=offset, load_all=load_all, + tags=tags, + ) + + def load_behavioralassay(self, portal="private", id: str = None, + name: str = None, tags: str = None, + filters: dict = None, sort: list = None, + include: list = None, limit: int = None, + offset: int = None, load_all: bool = False): + """Load behavioral assay record(s). + + Parameters + ---------- + name : Filter by name (case-insensitive contains). + tags : Filter by tag. + """ + return self._convenience_load( + "behavioralassay", + default_include=[], + filter_map={"name": "name.icontains", "tags": "tags"}, + portal=portal, id=id, filters=filters, sort=sort, include=include, + limit=limit, offset=offset, load_all=load_all, + name=name, tags=tags, + ) + + def load_datastorage(self, portal="private", id: str = None, + name: str = None, tags: str = None, + filters: dict = None, sort: list = None, + include: list = None, limit: int = None, + offset: int = None, load_all: bool = False): + """Load data storage record(s). + + Parameters + ---------- + name : Filter by name (case-insensitive contains). + tags : Filter by tag. + """ + return self._convenience_load( + "datastorage", + default_include=[], + filter_map={"name": "name.icontains", "tags": "tags"}, + portal=portal, id=id, filters=filters, sort=sort, include=include, + limit=limit, offset=offset, load_all=load_all, + name=name, tags=tags, + ) + + def load_setup(self, portal="private", id: str = None, + name: str = None, tags: str = None, + filters: dict = None, sort: list = None, + include: list = None, limit: int = None, + offset: int = None, load_all: bool = False): + """Load setup record(s). + + Parameters + ---------- + name : Filter by name (case-insensitive contains). + tags : Filter by tag. + """ + return self._convenience_load( + "setup", + default_include=[], + filter_map={"name": "name.icontains", "tags": "tags"}, + portal=portal, id=id, filters=filters, sort=sort, include=include, + limit=limit, offset=offset, load_all=load_all, + name=name, tags=tags, + ) + + def load_hardwaredevice(self, portal="private", id: str = None, + name: str = None, tags: str = None, + filters: dict = None, sort: list = None, + include: list = None, limit: int = None, + offset: int = None, load_all: bool = False): + """Load hardware device record(s). + + Parameters + ---------- + name : Filter by name (case-insensitive contains). + tags : Filter by tag. + """ + return self._convenience_load( + "hardwaredevice", + default_include=[], + filter_map={"name": "name.icontains", "tags": "tags"}, + portal=portal, id=id, filters=filters, sort=sort, include=include, + limit=limit, offset=offset, load_all=load_all, + name=name, tags=tags, + ) + + def load_brainregion(self, portal="private", id: str = None, + name: str = None, tags: str = None, + filters: dict = None, sort: list = None, + include: list = None, limit: int = None, + offset: int = None, load_all: bool = False): + """Load brain region record(s). + + Parameters + ---------- + name : Filter by name (case-insensitive contains). + tags : Filter by tag. + """ + return self._convenience_load( + "brainregion", + default_include=[], + filter_map={"name": "name.icontains", "tags": "tags"}, + portal=portal, id=id, filters=filters, sort=sort, include=include, + limit=limit, offset=offset, load_all=load_all, + name=name, tags=tags, + ) + + def load_species(self, portal="private", id: str = None, + name: str = None, tags: str = None, + filters: dict = None, sort: list = None, + include: list = None, limit: int = None, + offset: int = None, load_all: bool = False): + """Load species record(s). + + Parameters + ---------- + name : Filter by name (case-insensitive contains). + tags : Filter by tag. + """ + return self._convenience_load( + "species", + default_include=[], + filter_map={"name": "name.icontains", "tags": "tags"}, + portal=portal, id=id, filters=filters, sort=sort, include=include, + limit=limit, offset=offset, load_all=load_all, + name=name, tags=tags, + ) + + def load_strain(self, portal="private", id: str = None, + name: str = None, species: str = None, tags: str = None, + filters: dict = None, sort: list = None, + include: list = None, limit: int = None, + offset: int = None, load_all: bool = False): + """Load strain record(s). + + Parameters + ---------- + name : Filter by name (case-insensitive contains). + species : Filter by species UUID. + tags : Filter by tag. + """ + return self._convenience_load( + "strain", + default_include=[], + filter_map={"name": "name.icontains", "species": "species.id", "tags": "tags"}, + portal=portal, id=id, filters=filters, sort=sort, include=include, + limit=limit, offset=offset, load_all=load_all, + name=name, species=species, tags=tags, + ) + + def load_publication(self, portal="private", id: str = None, + name: str = None, tags: str = None, + filters: dict = None, sort: list = None, + include: list = None, limit: int = None, + offset: int = None, load_all: bool = False): + """Load publication record(s). + + Parameters + ---------- + name : Filter by name (case-insensitive contains). + tags : Filter by tag. + """ + return self._convenience_load( + "publication", + default_include=[], + filter_map={"name": "name.icontains", "tags": "tags"}, + portal=portal, id=id, filters=filters, sort=sort, include=include, + limit=limit, offset=offset, load_all=load_all, + name=name, tags=tags, + ) + + def load_laboratory(self, portal="private", id: str = None, + name: str = None, tags: str = None, + filters: dict = None, sort: list = None, + include: list = None, limit: int = None, + offset: int = None, load_all: bool = False): + """Load laboratory record(s). + + Parameters + ---------- + name : Filter by name (case-insensitive contains). + tags : Filter by tag. + """ + return self._convenience_load( + "laboratory", + default_include=[], + filter_map={"name": "name.icontains", "tags": "tags"}, + portal=portal, id=id, filters=filters, sort=sort, include=include, + limit=limit, offset=offset, load_all=load_all, + name=name, tags=tags, + ) diff --git a/brainstem_api_tutorial.py b/brainstem_api_tutorial.py index a125d97..e959f8e 100644 --- a/brainstem_api_tutorial.py +++ b/brainstem_api_tutorial.py @@ -20,8 +20,7 @@ session = output1["sessions"][0] ## We can also filter the models by providing a dictionary where -## the keys are the fields and the values the possible contents. -## In this example, it will just load sessions whose name is "yeah". +## In this example, it will just load sessions whose name exactly equals "yeah". output1 = client.load('session', filters={'name': 'yeah'}).json() ## Loaded models can be sorted by different criteria applying to @@ -67,7 +66,7 @@ ## load_subject includes procedures and subjectlogs. ## Filter by sex ('M' / 'F') and/or project UUID: -project_uuid = 'e7475834-7733-48cf-9e3b-f4f2d2d0305a' +project_uuid = '' # replace with a real project UUID output_subjects = client.load_subject(sex='M', projects=project_uuid, load_all=True) ## load_project includes sessions, subjects, collections and cohorts. @@ -81,13 +80,13 @@ ## load_behavior / load_dataacquisition / load_manipulation all ## accept session= to scope results to a single session. -session_uuid = '0e39c1fd-f413-4142-95f7-f50185e81fa4' +session_uuid = '' # replace with a real session UUID output_behaviors = client.load_behavior(session=session_uuid, load_all=True) output_dataacquisition = client.load_dataacquisition(session=session_uuid, load_all=True) output_manipulations = client.load_manipulation(session=session_uuid, load_all=True) ## load_procedure scopes by subject UUID. -subject_uuid = 'bfb0e3e2-2c48-4b72-9034-1ef50c5c432a' +subject_uuid = '' # replace with a real subject UUID output_procedures = client.load_procedure(subject=subject_uuid, load_all=True) ## Any convenience loader also accepts extra filters= and custom @@ -107,7 +106,7 @@ ## previously loaded sessions. session = {} session["description"] = 'new description' -output2 = client.save("session", id="0e39c1fd-f413-4142-95f7-f50185e81fa4", data=session).json() +output2 = client.save("session", id="", data=session).json() # Creating a new session @@ -116,7 +115,7 @@ ## required fields. session = {} session["name"] = "New session" -session["projects"] = ['e7475834-7733-48cf-9e3b-f4f2d2d0305a'] +session["projects"] = [''] # replace with a real project UUID session["description"] = 'description' ## Submitting session @@ -126,7 +125,7 @@ # Deleting a session ## Pass the model name and the UUID of the record to remove. -response = client.delete("session", id="0e39c1fd-f413-4142-95f7-f50185e81fa4") +response = client.delete("session", id="") if response.status_code == 204: print("Session deleted") diff --git a/pyproject.toml b/pyproject.toml index e46a1f0..6e85b8c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ authors = [ ] requires-python = ">=3.8" dependencies = [ - "requests", + "requests>=2.28", ] [project.optional-dependencies] diff --git a/tests/test_client.py b/tests/test_client.py index 68d3531..b77b72f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -450,3 +450,279 @@ def test_logout_no_token_file(self, tmp_path, capsys): cli_module.main() assert "No cached token found" in capsys.readouterr().out + + +# --------------------------------------------------------------------------- +# Timeouts +# --------------------------------------------------------------------------- + +class TestTimeouts: + def setup_method(self): + self.client = make_client() + + def test_load_passes_timeout(self): + self.client._session.get = MagicMock(return_value=mock_response(200, {})) + self.client.load("session") + _, kwargs = self.client._session.get.call_args + assert kwargs.get("timeout") == BrainstemClient.DEFAULT_TIMEOUT + + def test_save_post_passes_timeout(self): + self.client._session.post = MagicMock(return_value=mock_response(201)) + self.client.save("session", data={"name": "x"}) + _, kwargs = self.client._session.post.call_args + assert kwargs.get("timeout") == BrainstemClient.DEFAULT_TIMEOUT + + def test_save_patch_passes_timeout(self): + uid = "00000000-0000-0000-0000-000000000010" + self.client._session.patch = MagicMock(return_value=mock_response(200)) + self.client.save("session", id=uid, data={"description": "x"}) + _, kwargs = self.client._session.patch.call_args + assert kwargs.get("timeout") == BrainstemClient.DEFAULT_TIMEOUT + + def test_delete_passes_timeout(self): + uid = "00000000-0000-0000-0000-000000000011" + self.client._session.delete = MagicMock(return_value=mock_response(204)) + self.client.delete("session", id=uid) + _, kwargs = self.client._session.delete.call_args + assert kwargs.get("timeout") == BrainstemClient.DEFAULT_TIMEOUT + + def test_load_all_passes_timeout(self): + r = MagicMock() + r.raise_for_status = MagicMock() + r.status_code = 200 + r.json.return_value = {"sessions": [{"id": "1"}], "count": 1} + self.client._session.get = MagicMock(return_value=r) + self.client.load("session", load_all=True) + _, kwargs = self.client._session.get.call_args + assert kwargs.get("timeout") == BrainstemClient.DEFAULT_TIMEOUT + + +# --------------------------------------------------------------------------- +# 401 → AuthenticationError in load_all path +# --------------------------------------------------------------------------- + +class TestLoadAll401: + def test_401_raises_authentication_error(self): + client = make_client() + r = MagicMock() + r.status_code = 401 + r.raise_for_status = MagicMock() + client._session.get = MagicMock(return_value=r) + with pytest.raises(AuthenticationError, match="brainstem login"): + client.load("session", load_all=True) + + +# --------------------------------------------------------------------------- +# Device auth polling timeout +# --------------------------------------------------------------------------- + +class TestDeviceAuthTimeout: + def test_polling_timeout_raises(self): + client = make_client() + + device_resp = MagicMock() + device_resp.raise_for_status = MagicMock() + device_resp.json.return_value = { + "device_code": "dev123", + "user_code": "ABC-DEF", + "verification_uri": "https://brainstem.org/activate", + "verification_uri_complete": "https://brainstem.org/activate?code=ABC-DEF", + "interval": 0, + } + pending_resp = MagicMock() + pending_resp.raise_for_status = MagicMock() + pending_resp.json.return_value = {"status": "authorization_pending"} + + # max_wait=0 triggers immediate timeout on the first iteration check + with patch("requests.post", side_effect=[device_resp] + [pending_resp] * 10), \ + patch("webbrowser.open"), \ + patch("time.monotonic", side_effect=[0, 0, 999]): + with pytest.raises(AuthenticationError, match="timed out"): + client._device_auth_flow(max_wait=0) + + +# --------------------------------------------------------------------------- +# Context manager +# --------------------------------------------------------------------------- + +class TestContextManager: + def test_enter_returns_client(self): + client = make_client() + with client as c: + assert c is client + + def test_exit_closes_session(self): + client = make_client() + client._session.close = MagicMock() + with client: + pass + client._session.close.assert_called_once() + + +# --------------------------------------------------------------------------- +# Unknown kwarg validation in convenience loaders +# --------------------------------------------------------------------------- + +class TestUnknownKwarg: + def test_unknown_kwarg_raises(self): + client = make_client() + with pytest.raises(TypeError, match="Unexpected keyword argument"): + # Call _convenience_load directly with a kwarg not in filter_map + client._convenience_load( + "session", + default_include=[], + filter_map={"name": "name.icontains"}, + portal="private", id=None, filters=None, sort=None, + include=None, limit=None, offset=None, load_all=False, + name="test", + oops="extra", # not listed in filter_map + ) + + +# --------------------------------------------------------------------------- +# New convenience loaders +# --------------------------------------------------------------------------- + +class TestNewConvenienceLoaders: + def setup_method(self): + self.client = make_client() + self.client._session.get = MagicMock(return_value=mock_response(200, {})) + + def _get_params(self): + return self.client._session.get.call_args[1].get("params", {}) + + def _get_url(self): + return self.client._session.get.call_args[0][0] + + def test_load_procedurelog_procedure_filter(self): + uuid = "11111111-0000-0000-0000-000000000000" + self.client.load_procedurelog(procedure=uuid) + assert self._get_params()["filter{procedure.id}"] == uuid + assert "/modules/procedurelog/" in self._get_url() + + def test_load_subjectlog_subject_filter(self): + uuid = "22222222-0000-0000-0000-000000000000" + self.client.load_subjectlog(subject=uuid) + assert self._get_params()["filter{subject.id}"] == uuid + assert "/modules/subjectlog/" in self._get_url() + + def test_load_equipment_setup_filter(self): + uuid = "33333333-0000-0000-0000-000000000000" + self.client.load_equipment(setup=uuid) + assert self._get_params()["filter{setup.id}"] == uuid + assert "/modules/equipment/" in self._get_url() + + def test_load_equipment_name_filter(self): + self.client.load_equipment(name="headstage") + assert self._get_params()["filter{name.icontains}"] == "headstage" + + def test_load_consumablestock_url(self): + self.client.load_consumablestock() + assert "/modules/consumablestock/" in self._get_url() + + def test_load_behavioralassay_name_filter(self): + self.client.load_behavioralassay(name="open field") + assert self._get_params()["filter{name.icontains}"] == "open field" + assert "/personal_attributes/behavioralassay/" in self._get_url() + + def test_load_datastorage_url(self): + self.client.load_datastorage() + assert "/personal_attributes/datastorage/" in self._get_url() + + def test_load_setup_url(self): + self.client.load_setup() + assert "/personal_attributes/setup/" in self._get_url() + + def test_load_hardwaredevice_url(self): + self.client.load_hardwaredevice() + assert "/resources/hardwaredevice/" in self._get_url() + + def test_load_brainregion_name_filter(self): + self.client.load_brainregion(name="CA1") + assert self._get_params()["filter{name.icontains}"] == "CA1" + assert "/taxonomies/brainregion/" in self._get_url() + + def test_load_species_url(self): + self.client.load_species() + assert "/taxonomies/species/" in self._get_url() + + def test_load_strain_species_filter(self): + uuid = "44444444-0000-0000-0000-000000000000" + self.client.load_strain(species=uuid) + assert self._get_params()["filter{species.id}"] == uuid + assert "/taxonomies/strain/" in self._get_url() + + def test_load_publication_url(self): + self.client.load_publication() + assert "/dissemination/publication/" in self._get_url() + + def test_load_laboratory_name_filter(self): + self.client.load_laboratory(name="Petersen") + assert self._get_params()["filter{name.icontains}"] == "Petersen" + assert "/users/laboratory/" in self._get_url() + + +# --------------------------------------------------------------------------- +# CLI — load / save / delete commands +# --------------------------------------------------------------------------- + +class TestCLICommands: + def _run_cli(self, argv, client_mock): + import brainstem_api_tools.cli as cli_module + with patch("brainstem_api_tools.cli.BrainstemClient", return_value=client_mock), \ + patch("sys.argv", argv): + cli_module.main() + + def test_cli_load_calls_load(self, capsys): + resp = mock_response(200, {"sessions": []}) + client = MagicMock() + client.load.return_value = resp + self._run_cli(["brainstem", "--token", TOKEN, "load", "session"], client) + client.load.assert_called_once() + + def test_cli_load_filter_parsed(self, capsys): + resp = mock_response(200, {}) + client = MagicMock() + client.load.return_value = resp + self._run_cli( + ["brainstem", "--token", TOKEN, "load", "session", + "--filters", "name.icontains=rat"], + client, + ) + _, kwargs = client.load.call_args + assert kwargs["filters"] == {"name.icontains": "rat"} + + def test_cli_save_post(self, capsys): + resp = mock_response(201, {"session": {"id": "abc"}}) + client = MagicMock() + client.save.return_value = resp + self._run_cli( + ["brainstem", "--token", TOKEN, "save", "session", + "--data", '{"name": "x", "projects": []}'], + client, + ) + client.save.assert_called_once() + _, kwargs = client.save.call_args + assert kwargs["data"] == {"name": "x", "projects": []} + + def test_cli_delete(self, capsys): + resp = mock_response(204) + client = MagicMock() + client.delete.return_value = resp + self._run_cli( + ["brainstem", "--token", TOKEN, "delete", "session", + "--id", "00000000-0000-0000-0000-000000000099"], + client, + ) + client.delete.assert_called_once() + _, kwargs = client.delete.call_args + assert kwargs["id"] == "00000000-0000-0000-0000-000000000099" + + def test_cli_invalid_filter_exits(self): + import brainstem_api_tools.cli as cli_module + with patch("brainstem_api_tools.cli.BrainstemClient", return_value=MagicMock()), \ + patch("sys.argv", ["brainstem", "--token", TOKEN, "load", "session", + "--filters", "badfilter"]): + with pytest.raises(SystemExit): + cli_module.main() + From 4159329c5532210e6d59aa66798208adb953a09e Mon Sep 17 00:00:00 2001 From: "Peter C. Petersen" Date: Mon, 30 Mar 2026 15:21:40 +0200 Subject: [PATCH 3/3] Adding Copilot review: Improve load_all handling and CLI logout parser Enforce and clarify pagination/parameter behavior and fix CLI parser inheritance. - brainstem_api_client: always pass options to _build_url; disallow load_all=True when an id is provided; raise a clear ValueError if auto-pagination expects a list but the API response contains no list-valued key; change filter merging to include falsy values (check for None instead of truthiness). - cli: make the logout subparser inherit common arguments (parents=[common]). - tests: set mocked response.status_code to 200 for pagination tests. - README: remove the Contributing section. These changes make load_all usage safer and errors clearer, ensure logout CLI gets shared options, and update tests to better mock HTTP responses. --- README.md | 3 --- brainstem_api_tools/brainstem_api_client.py | 12 ++++++++++-- brainstem_api_tools/cli.py | 1 + tests/test_client.py | 1 + 4 files changed, 12 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 1c95367..8ada93d 100644 --- a/README.md +++ b/README.md @@ -103,9 +103,6 @@ client.save('session', id='', data={'description': 'updated'}) client.delete('session', id='') ``` -## Contributing -Contributions are welcome! Feel free to open issues or submit pull requests on GitHub. - ## Command-line Interface After installation a `brainstem` command is available in your shell. diff --git a/brainstem_api_tools/brainstem_api_client.py b/brainstem_api_tools/brainstem_api_client.py index 4c97a90..da55b1f 100644 --- a/brainstem_api_tools/brainstem_api_client.py +++ b/brainstem_api_tools/brainstem_api_client.py @@ -270,7 +270,7 @@ def load(self, model = _resolve_model(model) portal = _resolve_portal(portal) app = _MODEL_TO_APP[model] - url = self._build_url(portal, app, model, id, options if id else None) + url = self._build_url(portal, app, model, id, options) params = {} if not id: @@ -288,6 +288,9 @@ def load(self, if not load_all: return self._session.get(url, params=params, timeout=self.DEFAULT_TIMEOUT) + if id is not None: + raise ValueError("load_all=True cannot be used together with id.") + # --- auto-paginate and merge all pages --- page_size = limit or 100 params["limit"] = page_size @@ -310,6 +313,11 @@ def load(self, records_key = next( (k for k, v in data.items() if isinstance(v, list)), None ) + if records_key is None: + raise ValueError( + "load_all=True requires a paginated list response, but the API " + "returned no list-valued key. Use load_all=False for single-object endpoints." + ) combined = {k: v for k, v in data.items() if k != records_key} combined[records_key] = [] @@ -396,7 +404,7 @@ def _convenience_load(self, model, default_include, filter_map, merged_filters = dict(filters or {}) for kwarg, api_field in filter_map.items(): value = field_kwargs.get(kwarg) - if value: + if value is not None: merged_filters[api_field] = value return self.load( model, diff --git a/brainstem_api_tools/cli.py b/brainstem_api_tools/cli.py index 8cfe715..043b5d1 100644 --- a/brainstem_api_tools/cli.py +++ b/brainstem_api_tools/cli.py @@ -60,6 +60,7 @@ def _build_parser() -> argparse.ArgumentParser: # ---- logout ----------------------------------------------------- sub.add_parser( "logout", + parents=[common], help="Remove the cached API token.", ) diff --git a/tests/test_client.py b/tests/test_client.py index b77b72f..3851c2e 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -303,6 +303,7 @@ def _mock_pages(self, pages): responses = [] for page in pages: r = MagicMock() + r.status_code = 200 r.raise_for_status = MagicMock() r.json.return_value = page responses.append(r)