diff --git a/app/demo_adapter.py b/app/demo_adapter.py index 50a9267..b028554 100644 --- a/app/demo_adapter.py +++ b/app/demo_adapter.py @@ -1,27 +1,44 @@ +import base64 import datetime -import random -import uuid +import glob +import grp import os -import stat +import pathlib import pwd -import grp -import glob +import random +import stat import subprocess -import pathlib -import base64 +import uuid from typing import Any, Tuple -from pydantic import BaseModel + from fastapi import HTTPException -from .routers.common import AllocationUnit, Capability -from .routers.facility import models as facility_models, facility_adapter as facility_adapter -from .routers.status import models as status_models, facility_adapter as status_adapter -from .routers.account import models as account_models, facility_adapter as account_adapter -from .routers.compute import models as compute_models, facility_adapter as compute_adapter -from .routers.filesystem import models as filesystem_models, facility_adapter as filesystem_adapter -from .routers.task import models as task_models, facility_adapter as task_adapter +from pydantic import BaseModel + +from .routers.account import facility_adapter as account_adapter +from .routers.account import models as account_models +from .routers.compute import facility_adapter as compute_adapter +from .routers.compute import models as compute_models +from .routers.facility import facility_adapter +from .routers.facility import models as facility_models +from .routers.filesystem import facility_adapter as filesystem_adapter +from .routers.filesystem import models as filesystem_models +from .routers.status import facility_adapter as status_adapter +from .routers.status import models as status_models +from .routers.task import facility_adapter as task_adapter +from .routers.task import models as task_models +from .types.models import Capability +from .types.scalars import AllocationUnit DEMO_QUEUE_UPDATE_SECS = 5 +def paginate_list(items, offset: int | None, limit: int | None): + """Return a sliced items using offset and limit.""" + if offset is not None and offset > 0: + items = items[offset:] + if limit is not None and limit >= 0: + items = items[:limit] + return items + class PathSandbox: _base_temp_dir = None @@ -126,21 +143,30 @@ def _init_state(self): "gpfs": Capability(id=str(uuid.uuid4()), name="GPFS Storage", units=[AllocationUnit.bytes, AllocationUnit.inodes]), } - pm = status_models.Resource(id=str(uuid.uuid4()), group="perlmutter", name="compute nodes", description="the perlmutter computer compute nodes", capability_ids=[ - self.capabilities["cpu"].id, - self.capabilities["gpu"].id, - ], current_status=status_models.Status.degraded, last_modified=day_ago, resource_type=status_models.ResourceType.compute) - hpss = status_models.Resource(id=str(uuid.uuid4()), group="hpss", name="hpss", description="hpss tape storage", capability_ids=[self.capabilities["hpss"].id], current_status=status_models.Status.up, last_modified=day_ago, resource_type=status_models.ResourceType.storage) - cfs = status_models.Resource(id=str(uuid.uuid4()), group="cfs", name="cfs", description="cfs storage", capability_ids=[self.capabilities["gpfs"].id], current_status=status_models.Status.up, last_modified=day_ago, resource_type=status_models.ResourceType.storage) - - self.resources = [ - pm, - hpss, - cfs, - status_models.Resource(id=str(uuid.uuid4()), group="perlmutter", name="login nodes", description="the perlmutter computer login nodes", capability_ids=[], current_status=status_models.Status.degraded, last_modified=day_ago, resource_type=status_models.ResourceType.system), - status_models.Resource(id=str(uuid.uuid4()), group="services", name="Iris", description="Iris webapp", capability_ids=[], current_status=status_models.Status.down, last_modified=day_ago, resource_type=status_models.ResourceType.website), - status_models.Resource(id=str(uuid.uuid4()), group="services", name="sfapi", description="the Superfacility API", capability_ids=[], current_status=status_models.Status.up, last_modified=day_ago, resource_type=status_models.ResourceType.service), - ] + pm = status_models.Resource(id=str(uuid.uuid4()), site_id=site1.id, group="perlmutter", name="compute nodes", description="the perlmutter computer compute nodes", + capability_ids=[self.capabilities["cpu"].id, self.capabilities["gpu"].id,], current_status=status_models.Status.degraded, + last_modified=day_ago, resource_type=status_models.ResourceType.compute, located_at_uri=site1.self_uri) + + hpss = status_models.Resource(id=str(uuid.uuid4()), site_id=site1.id, group="hpss", name="hpss", description="hpss tape storage", + capability_ids=[self.capabilities["hpss"].id], current_status=status_models.Status.up, + last_modified=day_ago, resource_type=status_models.ResourceType.storage, located_at_uri=site1.self_uri) + + cfs = status_models.Resource(id=str(uuid.uuid4()), site_id=site1.id, group="cfs", name="cfs", description="cfs storage", + capability_ids=[self.capabilities["gpfs"].id], current_status=status_models.Status.up, + last_modified=day_ago, resource_type=status_models.ResourceType.storage, located_at_uri=site1.self_uri) + + login = status_models.Resource(id=str(uuid.uuid4()), site_id=site2.id, group="perlmutter", name="login nodes", description="the perlmutter computer login nodes", + capability_ids=[], current_status=status_models.Status.degraded, + last_modified=day_ago, resource_type=status_models.ResourceType.system, located_at_uri=site2.self_uri) + + iris = status_models.Resource(id=str(uuid.uuid4()), site_id=site2.id, group="services", name="Iris", description="Iris webapp", + capability_ids=[], current_status=status_models.Status.down, + last_modified=day_ago, resource_type=status_models.ResourceType.website, located_at_uri=site2.self_uri) + sfapi = status_models.Resource(id=str(uuid.uuid4()), site_id=site2.id, group="services", name="sfapi", description="the Superfacility API", + capability_ids=[], current_status=status_models.Status.up, + last_modified=day_ago, resource_type=status_models.ResourceType.service, located_at_uri=site2.self_uri) + + self.resources = [pm, hpss, cfs, login, iris, sfapi] self.projects = [ account_models.Project( @@ -249,8 +275,8 @@ def _init_state(self): async def get_facility( self: "DemoAdapter", - modified_since: str | None = None, - ) -> facility_models.Facility: + modified_since: str | None = None + ) -> facility_models.Facility: return self.facility @@ -260,9 +286,8 @@ async def list_sites( name: str | None = None, offset: int | None = None, limit: int | None = None, - short_name: str | None = None, - ) -> list[facility_models.Site]: - + short_name: str | None = None + ) -> list[facility_models.Site]: sites = self.sites if name: @@ -283,9 +308,8 @@ async def list_sites( async def get_site( self: "DemoAdapter", site_id: str, - modified_since: str | None = None, - ) -> facility_models.Site: - + modified_since: str | None = None + ) -> facility_models.Site: site = next((s for s in self.sites if s.id == site_id), None) if not site: raise HTTPException(status_code=404, detail="Site not found") @@ -312,17 +336,19 @@ async def get_resources( modified_since : datetime.datetime | None = None, resource_type : status_models.ResourceType | None = None, current_status : status_models.Status | None = None, - capability: Capability | None = None + capability: Capability | None = None, + site_id: str | None = None ) -> list[status_models.Resource]: - return status_models.Resource.find(self.resources, name, description, group, modified_since, resource_type)[offset:offset + limit] + resources = status_models.Resource.find(self.resources, name=name, description=description, group=group, modified_since=modified_since, + resource_type=resource_type, current_status=current_status, capability=capability, site_id=site_id) + return paginate_list(resources, offset, limit) async def get_resource( self : "DemoAdapter", - id : str + id_ : str ) -> status_models.Resource: - return status_models.Resource.find_by_id(self.resources, id) - + return status_models.Resource.find_by_id(self.resources, id_) async def get_events( self : "DemoAdapter", @@ -336,17 +362,19 @@ async def get_events( from_ : datetime.datetime | None = None, to : datetime.datetime | None = None, time_ : datetime.datetime | None = None, - modified_since : datetime.datetime | None = None, + modified_since : datetime.datetime | None = None ) -> list[status_models.Event]: - return status_models.Event.find([e for e in self.events if e.incident_id == incident_id], resource_id, name, description, status, from_, to, time_, modified_since)[offset:offset + limit] + events = status_models.Event.find([e for e in self.events if e.incident_id == incident_id], resource_id=resource_id, name=name, description=description, + status=status, from_=from_, to=to, time_=time_, modified_since=modified_since) + return paginate_list(events, offset, limit) async def get_event( self : "DemoAdapter", incident_id : str, - id : str + id_ : str ) -> status_models.Event: - return status_models.Event.find_by_id(self.events, id) + return status_models.Event.find_by_id(self.events, id_) async def get_incidents( @@ -356,26 +384,32 @@ async def get_incidents( name : str | None = None, description : str | None = None, status : status_models.Status | None = None, - type : status_models.IncidentType | None = None, + type_ : status_models.IncidentType | None = None, from_ : datetime.datetime | None = None, to : datetime.datetime | None = None, time_ : datetime.datetime | None = None, modified_since : datetime.datetime | None = None, resource_id : str | None = None, - resolution: status_models.Resolution | None = None, + resolution: status_models.Resolution | None = None ) -> list[status_models.Incident]: - return status_models.Incident.find(self.incidents, name, description, status, type, from_, to, time_, modified_since, resource_id)[offset:offset + limit] + incidents = status_models.Incident.find(self.incidents, name=name, description=description, status=status, type_=type_,from_=from_, to=to, + time_=time_, modified_since=modified_since, resource_id=resource_id, resolution=resolution) + return paginate_list(incidents, offset, limit) async def get_incident( self : "DemoAdapter", - id : str + id_ : str ) -> status_models.Incident: - return status_models.Incident.find_by_id(self.incidents, id) + return status_models.Incident.find_by_id(self.incidents, id_) async def get_capabilities( self : "DemoAdapter", + name : str | None = None, + modified_since : str | None = None, + offset : int = 0, + limit : int = 1000 ) -> list[Capability]: return self.capabilities.values() @@ -415,7 +449,7 @@ async def get_projects( async def get_project_allocations( self : "DemoAdapter", project: account_models.Project, - user: account_models.User + user: account_models.User, ) -> list[account_models.ProjectAllocation]: return [pa for pa in self.project_allocations if pa.project_id == project.id] @@ -614,7 +648,7 @@ async def chown( self : "DemoAdapter", resource: status_models.Resource, user: account_models.User, - request_model: filesystem_models.PutFileChownRequest + request_model: filesystem_models.PutFileChownRequest, ) -> filesystem_models.PutFileChownResponse: rp = self.validate_path(request_model.path) os.chown(rp, request_model.owner, request_model.group) @@ -684,7 +718,7 @@ async def tail( path: str, file_bytes: int | None, lines: int | None, - skip_trailing: bool, + skip_trailing: bool ) -> Tuple[Any, int]: return self._headtail("tail", path, file_bytes, lines) @@ -695,8 +729,8 @@ async def view( user: account_models.User, path: str, size: int, - offset: int, - ) -> filesystem_models.GetViewFileResponse: + offset: int + ) -> filesystem_models.GetViewFileResponse: rp = self.validate_path(path) result = subprocess.run( f"tail -c +{offset+1} {rp} | head -c {size}", @@ -714,8 +748,8 @@ async def checksum( self : "DemoAdapter", resource: status_models.Resource, user: account_models.User, - path: str, - ) -> filesystem_models.GetFileChecksumResponse: + path: str + ) -> filesystem_models.GetFileChecksumResponse: rp = self.validate_path(path) result = subprocess.run( ["sha256sum", rp], @@ -734,7 +768,7 @@ async def file( self : "DemoAdapter", resource: status_models.Resource, user: account_models.User, - path: str, + path: str ) -> filesystem_models.GetFileTypeResponse: rp = self.validate_path(path) result = subprocess.run( @@ -752,7 +786,7 @@ async def stat( resource: status_models.Resource, user: account_models.User, path: str, - dereference: bool, + dereference: bool ) -> filesystem_models.GetFileStatResponse: rp = self.validate_path(path) if dereference: @@ -792,8 +826,8 @@ async def mkdir( self : "DemoAdapter", resource: status_models.Resource, user: account_models.User, - request_model: filesystem_models.PostMakeDirRequest, - ) -> filesystem_models.PostMkdirResponse: + request_model: filesystem_models.PostMakeDirRequest + ) -> filesystem_models.PostMkdirResponse: rp = self.validate_path(request_model.path) args = ["mkdir"] if request_model.parent: @@ -809,7 +843,7 @@ async def symlink( self : "DemoAdapter", resource: status_models.Resource, user: account_models.User, - request_model: filesystem_models.PostFileSymlinkRequest, + request_model: filesystem_models.PostFileSymlinkRequest ) -> filesystem_models.PostFileSymlinkResponse: rp_src = self.validate_path(request_model.path) rp_dst = self.validate_path(request_model.link_path) @@ -823,7 +857,7 @@ async def download( self : "DemoAdapter", resource: status_models.Resource, user: account_models.User, - path: str, + path: str ) -> Any: rp = self.validate_path(path) raw_content = pathlib.Path(rp).read_bytes() @@ -839,7 +873,7 @@ async def upload( resource: status_models.Resource, user: account_models.User, path: str, - content: str, + content: str ) -> None: rp = self.validate_path(path) if isinstance(content, bytes): @@ -854,8 +888,8 @@ async def compress( self : "DemoAdapter", resource: status_models.Resource, user: account_models.User, - request_model: filesystem_models.PostCompressRequest, - ) -> filesystem_models.PostCompressResponse: + request_model: filesystem_models.PostCompressRequest + ) -> filesystem_models.PostCompressResponse: src_rp = self.validate_path(request_model.path) dst_rp = self.validate_path(request_model.target_path) @@ -887,8 +921,8 @@ async def extract( self : "DemoAdapter", resource: status_models.Resource, user: account_models.User, - request_model: filesystem_models.PostExtractRequest, - ) -> filesystem_models.PostExtractResponse: + request_model: filesystem_models.PostExtractRequest + ) -> filesystem_models.PostExtractResponse: src_rp = self.validate_path(request_model.path) dst_rp = self.validate_path(request_model.target_path) @@ -915,8 +949,8 @@ async def mv( self : "DemoAdapter", resource: status_models.Resource, user: account_models.User, - request_model: filesystem_models.PostMoveRequest, - ) -> filesystem_models.PostMoveResponse: + request_model: filesystem_models.PostMoveRequest + ) -> filesystem_models.PostMoveResponse: src_rp = self.validate_path(request_model.path) dst_rp = self.validate_path(request_model.target_path) subprocess.run(["mv", src_rp, dst_rp], check=True) @@ -929,8 +963,8 @@ async def cp( self : "DemoAdapter", resource: status_models.Resource, user: account_models.User, - request_model: filesystem_models.PostCopyRequest, - ) -> filesystem_models.PostCopyResponse: + request_model: filesystem_models.PostCopyRequest + ) -> filesystem_models.PostCopyResponse: src_rp = self.validate_path(request_model.path) dst_rp = self.validate_path(request_model.target_path) args = ["cp"] @@ -947,7 +981,7 @@ async def cp( async def get_task( self : "DemoAdapter", user: account_models.User, - task_id: str, + task_id: str ) -> task_models.Task|None: await DemoTaskQueue._process_tasks(self) return next((t for t in DemoTaskQueue.tasks if t.user.name == user.name and t.id == task_id), None) @@ -955,7 +989,7 @@ async def get_task( async def get_tasks( self : "DemoAdapter", - user: account_models.User, + user: account_models.User ) -> list[task_models.Task]: await DemoTaskQueue._process_tasks(self) return [t for t in DemoTaskQueue.tasks if t.user.name == user.name] @@ -965,15 +999,14 @@ async def put_task( self: "DemoAdapter", user: account_models.User, resource: status_models.Resource, - body: str - ) -> str: + task: str) -> str: await DemoTaskQueue._process_tasks(self) - return DemoTaskQueue._create_task(user, resource, body) + return DemoTaskQueue._create_task(user, resource, task) class DemoTask(BaseModel): id: str - body: str + task: str resource: status_models.Resource user: account_models.User start: float @@ -996,7 +1029,7 @@ async def _process_tasks(da: DemoAdapter): t.status = task_models.TaskStatus.active t.start = now elif t.status == task_models.TaskStatus.active and now - t.start > DEMO_QUEUE_UPDATE_SECS: - cmd = task_models.TaskCommand.model_validate_json(t.body) + cmd = task_models.TaskCommand.model_validate_json(t.task) (result, status) = await DemoAdapter.on_task(t.resource, t.user, cmd) t.result = result t.status = status @@ -1007,5 +1040,5 @@ async def _process_tasks(da: DemoAdapter): @staticmethod def _create_task(user: account_models.User, resource: status_models.Resource, command: task_models.TaskCommand) -> str: task_id = f"task_{len(DemoTaskQueue.tasks)}" - DemoTaskQueue.tasks.append(DemoTask(id=task_id, body=command.model_dump_json(), user=user, resource=resource, start=utc_timestamp())) + DemoTaskQueue.tasks.append(DemoTask(id=task_id, task=command.model_dump_json(), user=user, resource=resource, start=utc_timestamp())) return task_id diff --git a/app/routers/account/account.py b/app/routers/account/account.py index f856312..abd55ad 100644 --- a/app/routers/account/account.py +++ b/app/routers/account/account.py @@ -1,9 +1,11 @@ -from fastapi import HTTPException, Request, Depends, Query -from . import models, facility_adapter +from fastapi import Depends, HTTPException, Query, Request + +from ...types.http import forbidExtraQueryParams +from ...types.models import Capability +from ...types.scalars import StrictDateTime from .. import iri_router from ..error_handlers import DEFAULT_RESPONSES -from ..common import forbidExtraQueryParams, StrictDateTime, Capability - +from . import facility_adapter, models router = iri_router.IriRouter( facility_adapter.FacilityAdapter, @@ -18,8 +20,7 @@ description="Get a list of capabilities at this facility.", responses=DEFAULT_RESPONSES, operation_id="getCapabilities", - -) + response_model_exclude_none=True) async def get_capabilities( request : Request, name : str = Query(default=None, min_length=1), @@ -28,7 +29,7 @@ async def get_capabilities( limit : int = Query(default=100, ge=0, le=1000), _forbid = Depends(forbidExtraQueryParams("name", "modified_since", "offset", "limit")), ) -> list[Capability]: - return await router.adapter.get_capabilities() + return await router.adapter.get_capabilities(name=name, modified_since=modified_since, offset=offset, limit=limit) @router.get( @@ -44,7 +45,7 @@ async def get_capability( modified_since: StrictDateTime = Query(default=None), _forbid = Depends(forbidExtraQueryParams("modified_since")), ) -> Capability: - caps = await router.adapter.get_capabilities() + caps = await router.adapter.get_capabilities(name=None, modified_since=modified_since, offset=0, limit=100) cc = next((c for c in caps if c.id == capability_id), None) if not cc: raise HTTPException(status_code=404, detail="Capability not found") @@ -63,7 +64,7 @@ async def get_projects( request : Request, _forbid = Depends(forbidExtraQueryParams()), ) -> list[models.Project]: - user = await router.adapter.get_user(request.state.current_user_id, request.state.api_key, iri_router.get_client_ip(request)) + user = await router.adapter.get_user(user_id=request.state.current_user_id, api_key=request.state.api_key, client_ip=iri_router.get_client_ip(request)) if not user: raise HTTPException(status_code=404, detail="User not found") return await router.adapter.get_projects(user) @@ -82,10 +83,10 @@ async def get_project( request : Request, _forbid = Depends(forbidExtraQueryParams()), ) -> models.Project: - user = await router.adapter.get_user(request.state.current_user_id, request.state.api_key, iri_router.get_client_ip(request)) + user = await router.adapter.get_user(user_id=request.state.current_user_id, api_key=request.state.api_key, client_ip=iri_router.get_client_ip(request)) if not user: raise HTTPException(status_code=404, detail="User not found") - projects = await router.adapter.get_projects(user) + projects = await router.adapter.get_projects(user=user) pp = next((p for p in projects if p.id == project_id), None) if not pp: raise HTTPException(status_code=404, detail="Project not found") @@ -105,14 +106,14 @@ async def get_project_allocations( request : Request, _forbid = Depends(forbidExtraQueryParams()), ) -> list[models.ProjectAllocation]: - user = await router.adapter.get_user(request.state.current_user_id, request.state.api_key, iri_router.get_client_ip(request)) + user = await router.adapter.get_user(user_id=request.state.current_user_id, api_key=request.state.api_key, client_ip=iri_router.get_client_ip(request)) if not user: raise HTTPException(status_code=404, detail="User not found") - projects = await router.adapter.get_projects(user) + projects = await router.adapter.get_projects(user=user) project = next((p for p in projects if p.id == project_id), None) if not project: raise HTTPException(status_code=404, detail="Project not found") - return await router.adapter.get_project_allocations(project, user) + return await router.adapter.get_project_allocations(project=project, user=user) @router.get( @@ -129,12 +130,12 @@ async def get_project_allocation( request : Request, _forbid = Depends(forbidExtraQueryParams()), ) -> models.ProjectAllocation: - user = await router.adapter.get_user(request.state.current_user_id, request.state.api_key, iri_router.get_client_ip(request)) + user = await router.adapter.get_user(user_id=request.state.current_user_id, api_key=request.state.api_key, client_ip=iri_router.get_client_ip(request)) if not user: raise HTTPException(status_code=404, detail="User not found") - projects = await router.adapter.get_projects(user) + projects = await router.adapter.get_projects(user=user) project = next((p for p in projects if p.id == project_id), None) - pas = await router.adapter.get_project_allocations(project, user) + pas = await router.adapter.get_project_allocations(project=project, user=user) pa = next((pa for pa in pas if pa.id == project_allocation_id), None) if not pa: raise HTTPException(status_code=404, detail="Project allocation not found") @@ -155,18 +156,18 @@ async def get_user_allocations( request : Request, _forbid = Depends(forbidExtraQueryParams()), ) -> list[models.UserAllocation]: - user = await router.adapter.get_user(request.state.current_user_id, request.state.api_key, iri_router.get_client_ip(request)) + user = await router.adapter.get_user(user_id=request.state.current_user_id, api_key=request.state.api_key, client_ip=iri_router.get_client_ip(request)) if not user: raise HTTPException(status_code=404, detail="User not found") - projects = await router.adapter.get_projects(user) + projects = await router.adapter.get_projects(user=user) project = next((p for p in projects if p.id == project_id), None) if not project: raise HTTPException(status_code=404, detail="Project not found") - pas = await router.adapter.get_project_allocations(project, user) + pas = await router.adapter.get_project_allocations(project=project, user=user) pa = next((pa for pa in pas if pa.id == project_allocation_id), None) if not pa: raise HTTPException(status_code=404, detail="Project allocation not found") - return await router.adapter.get_user_allocations(user, pa) + return await router.adapter.get_user_allocations(user=user, project_allocation=pa) @router.get( @@ -184,18 +185,18 @@ async def get_user_allocation( request : Request, _forbid = Depends(forbidExtraQueryParams()), ) -> models.UserAllocation: - user = await router.adapter.get_user(request.state.current_user_id, request.state.api_key, iri_router.get_client_ip(request)) + user = await router.adapter.get_user(user_id=request.state.current_user_id, api_key=request.state.api_key, client_ip=iri_router.get_client_ip(request)) if not user: raise HTTPException(status_code=404, detail="User not found") - projects = await router.adapter.get_projects(user) + projects = await router.adapter.get_projects(user=user) project = next((p for p in projects if p.id == project_id), None) if not project: raise HTTPException(status_code=404, detail="Project not found") - pas = await router.adapter.get_project_allocations(project, user) + pas = await router.adapter.get_project_allocations(project=project, user=user) pa = next((pa for pa in pas if pa.id == project_allocation_id), None) if not pa: raise HTTPException(status_code=404, detail="Project allocation not found") - uas = await router.adapter.get_user_allocations(user, pa) + uas = await router.adapter.get_user_allocations(user=user, project_allocation=pa) ua = next((ua for ua in uas if ua.id == user_allocation_id), None) if not ua: raise HTTPException(status_code=404, detail="User allocation not found") diff --git a/app/routers/account/facility_adapter.py b/app/routers/account/facility_adapter.py index 235a2f7..fb2a66f 100644 --- a/app/routers/account/facility_adapter.py +++ b/app/routers/account/facility_adapter.py @@ -1,7 +1,8 @@ from abc import abstractmethod -from . import models as account_models -from ..common import Capability + +from ...types.models import Capability from ..iri_router import AuthenticatedAdapter +from . import models as account_models class FacilityAdapter(AuthenticatedAdapter): @@ -14,6 +15,10 @@ class FacilityAdapter(AuthenticatedAdapter): @abstractmethod async def get_capabilities( self : "FacilityAdapter", + name : str | None = None, + modified_since : str | None = None, + offset : int = 0, + limit : int = 1000 ) -> list[Capability]: pass @@ -39,6 +44,6 @@ async def get_project_allocations( async def get_user_allocations( self : "FacilityAdapter", user: account_models.User, - project_allocation: account_models.ProjectAllocation, + project_allocation: account_models.ProjectAllocation ) -> list[account_models.UserAllocation]: pass diff --git a/app/routers/account/models.py b/app/routers/account/models.py index 1a9333d..6bde3be 100644 --- a/app/routers/account/models.py +++ b/app/routers/account/models.py @@ -1,6 +1,8 @@ -from pydantic import computed_field, Field +from pydantic import Field, computed_field + from ... import config -from ..common import IRIBaseModel, AllocationUnit +from ...types.base import IRIBaseModel +from ...types.scalars import AllocationUnit class User(IRIBaseModel): diff --git a/app/routers/common.py b/app/routers/common.py deleted file mode 100644 index fd2882f..0000000 --- a/app/routers/common.py +++ /dev/null @@ -1,204 +0,0 @@ -"""Default models used by multiple routers.""" -import datetime -import enum -from typing import Optional -from urllib.parse import parse_qs - -from pydantic_core import core_schema -from pydantic import BaseModel, ConfigDict, Field, computed_field, model_serializer -from fastapi import Request, HTTPException - -from .. import config - - -# These are Pydantic custom types for strict validation -# that are not implmented in Pydantic by default. -# ----------------------------------------------------------------------- -# StrictBool: a strict boolean type -class StrictBool: - """Strict boolean: - - Accepts: real booleans, 'true', 'false' - - Rejects everything else. - """ - - @classmethod - def __get_pydantic_core_schema__(cls, source, handler): - return core_schema.no_info_plain_validator_function(cls.validate) - - @staticmethod - def validate(value): - """Validate the input value as a strict boolean.""" - if isinstance(value, bool): - return value - if isinstance(value, str): - v = value.strip().lower() - if v == "true": - return True - if v == "false": - return False - raise ValueError("Invalid boolean value. Expected 'true' or 'false'.") - raise ValueError("Invalid boolean value. Expected true/false or 'true'/'false'.") - - @classmethod - def __get_pydantic_json_schema__(cls, schema, handler): - return { - "type": "boolean", - "description": "Strict boolean. Only true/false allowed (bool or string)." - } - -# ----------------------------------------------------------------------- -# StrictDateTime: a strict ISO8601 datetime type - -class StrictDateTime: - """ - Strict ISO8601 datetime: - - Accepts datetime objects - - Accepts ISO8601 strings: 2025-12-06T10:00:00Z, 2025-12-06T10:00:00+00:00 - - Converts 'Z' → UTC - - Converts naive datetimes → UTC - - Rejects integers ("0"), null, garbage strings, etc. - """ - - @classmethod - def __get_pydantic_core_schema__(cls, source, handler): - return core_schema.no_info_plain_validator_function(cls.validate) - - @staticmethod - def validate(value): - if isinstance(value, datetime.datetime): - return StrictDateTime._normalize(value) - if not isinstance(value, str): - raise ValueError("Invalid datetime value. Expected ISO8601 datetime string.") - v = value.strip() - if v.endswith("Z"): - v = v[:-1] + "+00:00" - try: - dt = datetime.datetime.fromisoformat(v) - except Exception as ex: - raise ValueError("Invalid datetime format. Expected ISO8601 string.") from ex - - return StrictDateTime._normalize(dt) - - @staticmethod - def _normalize(dt: datetime.datetime) -> datetime.datetime: - if dt.tzinfo is None: - return dt.replace(tzinfo=datetime.timezone.utc) - return dt - - @classmethod - def __get_pydantic_json_schema__(cls, schema, handler): - return { - "type": "string", - "format": "date-time", - "description": "Strict ISO8601 datetime. Only valid ISO8601 datetime strings are accepted." - } - - -def forbidExtraQueryParams(*allowedParams: str, multiParams: set[str] | None = None): - multiParams = multiParams or set() - - async def checker(req: Request): - if "*" in allowedParams: - return - - raw_qs = req.scope.get("query_string", b"") - parsed = parse_qs(raw_qs.decode("utf-8", errors="strict"), keep_blank_values=True) - - allowed = set(allowedParams) - - for key, values in parsed.items(): - if key not in allowed: - raise HTTPException(status_code=422, - detail=[{"type": "extra_forbidden", - "loc": ["query", key], - "msg": f"Unexpected query parameter: {key}"}]) - - - if len(values) > 1 and key not in multiParams: - raise HTTPException(status_code=422, - detail=[{"type": "duplicate_forbidden", - "loc": ["query", key], - "msg": f"Duplicate query parameter: {key}"}]) - - return checker - - - - -class IRIBaseModel(BaseModel): - """Base model for IRI models.""" - model_config = ConfigDict(extra="allow") - - @model_serializer(mode="wrap") - def _hide_extra(self, handler, info): - data = handler(self) - - model_fields = set(self.model_fields or {}) - computed_fields = set(self.model_computed_fields or {}) - extra = getattr(self, "__pydantic_extra__", {}) or {} - for k in extra: - if k not in model_fields and k not in computed_fields: - data.pop(k, None) - return data - - def get_extra(self, key, default=None): - return getattr(self, "__pydantic_extra__", {}).get(key, default) - - -class NamedObject(IRIBaseModel): - id: str = Field(..., description="The unique identifier for the object. Typically a UUID or URN.") - def _self_path(self) -> str: - raise NotImplementedError - - @computed_field(description="The canonical URL of this object") - @property - def self_uri(self) -> str: - """Computed self URI property.""" - return f"{config.API_URL_ROOT}{config.API_PREFIX}{config.API_URL}{self._self_path()}" - - name: Optional[str] = Field(None, description="The long name of the object.") - description: Optional[str] = Field(None, description="Human-readable description of the object.") - last_modified: StrictDateTime = Field(..., description="ISO 8601 timestamp when this object was last modified.") - - @staticmethod - def find_by_id(a, id, allow_name: bool|None=False): - # Find a resource by its id. - # If allow_name is True, the id parameter can also match the resource's name. - return next((r for r in a if r.id == id or (allow_name and r.name == id)), None) - - - @staticmethod - def find(a, name, description, modified_since): - def normalize(dt: datetime) -> datetime: - # Convert naive datetimes into UTC-aware versions - if dt.tzinfo is None: - return dt.replace(tzinfo=datetime.timezone.utc) - return dt - if name: - a = [aa for aa in a if aa.name == name] - if description: - a = [aa for aa in a if description in aa.description] - if modified_since: - if modified_since.tzinfo is None: - modified_since = modified_since.replace(tzinfo=datetime.timezone.utc) - a = [aa for aa in a if normalize(aa.last_modified) >= modified_since] - return a - - -class AllocationUnit(enum.Enum): - node_hours = "node_hours" - bytes = "bytes" - inodes = "inodes" - - -class Capability(IRIBaseModel): - """ - An aspect of a resource that can have an allocation. - For example, Perlmutter nodes with GPUs - For some resources at a facility, this will be 1 to 1 with the resource. - It is a way to further subdivide a resource into allocatable sub-resources. - The word "capability" is also known to users as something they need for a job to run. (eg. gpu) - """ - id: str - name: str - units: list[AllocationUnit] \ No newline at end of file diff --git a/app/routers/compute/compute.py b/app/routers/compute/compute.py index e915048..d1eadf4 100644 --- a/app/routers/compute/compute.py +++ b/app/routers/compute/compute.py @@ -1,12 +1,12 @@ """Compute resource API router""" -from fastapi import HTTPException, Request, Depends, status, Query -from . import models, facility_adapter -from .. import iri_router +from fastapi import Depends, HTTPException, Query, Request, status +from ...types.http import forbidExtraQueryParams +from ...types.scalars import StrictHTTPBool +from .. import iri_router from ..error_handlers import DEFAULT_RESPONSES from ..status.status import router as status_router -from ..common import forbidExtraQueryParams, StrictBool - +from . import facility_adapter, models router = iri_router.IriRouter( facility_adapter.FacilityAdapter, @@ -36,7 +36,7 @@ async def submit_job( This command will attempt to submit a job and return its id. """ - user = await router.adapter.get_user(request.state.current_user_id, request.state.api_key, iri_router.get_client_ip(request)) + user = await router.adapter.get_user(user_id=request.state.current_user_id, api_key=request.state.api_key, client_ip=iri_router.get_client_ip(request)) if not user: raise HTTPException(status_code=404, detail="User not found") @@ -45,7 +45,7 @@ async def submit_job( # the handler can use whatever means it wants to submit the job and then fill in its id # see: https://exaworks.org/psij-python/docs/v/0.9.11/user_guide.html#submitting-jobs - return await router.adapter.submit_job(resource, user, job_spec) + return await router.adapter.submit_job(resource=resource, user=user, job_spec=job_spec) # TODO: this conflicts with PUT commented out while we finalize the API design @@ -73,7 +73,7 @@ async def submit_job( # # This command will attempt to submit a job and return its id. # """ -# user = await router.adapter.get_user(request.state.current_user_id, request.state.api_key, iri_router.get_client_ip(request)) +# user = await router.adapter.get_user(user_id=request.state.current_user_id, api_key=request.state.api_key, client_ip=iri_router.get_client_ip(request)) # if not user: # raise HTTPException(status_code=404, detail="User not found") # @@ -82,7 +82,7 @@ async def submit_job( # # # the handler can use whatever means it wants to submit the job and then fill in its id # # see: https://exaworks.org/psij-python/docs/v/0.9.11/user_guide.html#submitting-jobs -# return await router.adapter.submit_job_script(resource, user, job_script_path, args) +# return await router.adapter.submit_job_script(resource=resource, user=user, job_script_path=job_script_path, args=args) @router.put( @@ -108,7 +108,7 @@ async def update_job( - **job_request**: a PSIJ job spec as defined here """ - user = await router.adapter.get_user(request.state.current_user_id, request.state.api_key, iri_router.get_client_ip(request)) + user = await router.adapter.get_user(user_id=request.state.current_user_id, api_key=request.state.api_key, client_ip=iri_router.get_client_ip(request)) if not user: raise HTTPException(status_code=404, detail="User not found") @@ -117,7 +117,7 @@ async def update_job( # the handler can use whatever means it wants to submit the job and then fill in its id # see: https://exaworks.org/psij-python/docs/v/0.9.11/user_guide.html#submitting-jobs - return await router.adapter.update_job(resource, user, job_spec, job_id) + return await router.adapter.update_job(resource=resource, user=user, job_spec=job_spec, job_id=job_id) @router.get( @@ -132,12 +132,12 @@ async def get_job_status( resource_id : str, job_id : str, request : Request, - historical : StrictBool = Query(default=False, description="Whether to include historical jobs. Defaults to false"), - include_spec: StrictBool = Query(default=False, description="Whether to include the job specification. Defaults to false"), + historical : StrictHTTPBool = Query(default=False, description="Whether to include historical jobs. Defaults to false"), + include_spec: StrictHTTPBool = Query(default=False, description="Whether to include the job specification. Defaults to false"), _forbid = Depends(forbidExtraQueryParams("historical", "include_spec")), ): """Get a job's status""" - user = await router.adapter.get_user(request.state.current_user_id, request.state.api_key, iri_router.get_client_ip(request)) + user = await router.adapter.get_user(user_id=request.state.current_user_id, api_key=request.state.api_key, client_ip=iri_router.get_client_ip(request)) if not user: raise HTTPException(status_code=404, detail="User not found") @@ -145,7 +145,7 @@ async def get_job_status( # This could be done via slurm (in the adapter) or via psij's "attach" (https://exaworks.org/psij-python/docs/v/0.9.11/user_guide.html#detaching-and-attaching-jobs) resource = await status_router.adapter.get_resource(resource_id) - job = await router.adapter.get_job(resource, user, job_id, historical, include_spec) + job = await router.adapter.get_job(resource=resource, user=user, job_id=job_id, historical=historical, include_spec=include_spec) return job @@ -164,12 +164,12 @@ async def get_job_statuses( offset : int = Query(default=0, ge=0, le=1000), limit : int = Query(default=100, ge=0, le=1000), filters : dict[str, object] | None = None, - historical : StrictBool = Query(default=False, description="Whether to include historical jobs. Defaults to false"), - include_spec: StrictBool = Query(default=False, description="Whether to include the job specification. Defaults to false"), + historical : StrictHTTPBool = Query(default=False, description="Whether to include historical jobs. Defaults to false"), + include_spec: StrictHTTPBool = Query(default=False, description="Whether to include the job specification. Defaults to false"), _forbid = Depends(forbidExtraQueryParams("offset", "limit", "filters", "historical", "include_spec")), ): """Get multiple jobs' statuses""" - user = await router.adapter.get_user(request.state.current_user_id, request.state.api_key, iri_router.get_client_ip(request)) + user = await router.adapter.get_user(user_id=request.state.current_user_id, api_key=request.state.api_key, client_ip=iri_router.get_client_ip(request)) if not user: raise HTTPException(status_code=404, detail="User not found") @@ -177,7 +177,7 @@ async def get_job_statuses( # This could be done via slurm (in the adapter) or via psij's "attach" (https://exaworks.org/psij-python/docs/v/0.9.11/user_guide.html#detaching-and-attaching-jobs) resource = await status_router.adapter.get_resource(resource_id) - jobs = await router.adapter.get_jobs(resource, user, offset, limit, filters, historical, include_spec) + jobs = await router.adapter.get_jobs(resource=resource, user=user, offset=offset, limit=limit, filters=filters, historical=historical, include_spec=include_spec) return jobs @@ -198,13 +198,13 @@ async def cancel_job( _forbid = Depends(forbidExtraQueryParams()), ): """Cancel a job""" - user = await router.adapter.get_user(request.state.current_user_id, request.state.api_key, iri_router.get_client_ip(request)) + user = await router.adapter.get_user(user_id=request.state.current_user_id, api_key=request.state.api_key, client_ip=iri_router.get_client_ip(request)) if not user: raise HTTPException(status_code=404, detail="User not found") # look up the resource (todo: maybe ensure it's available) resource = await status_router.adapter.get_resource(resource_id) - await router.adapter.cancel_job(resource, user, job_id) + await router.adapter.cancel_job(resource=resource, user=user, job_id=job_id) return None diff --git a/app/routers/compute/facility_adapter.py b/app/routers/compute/facility_adapter.py index 6cf0bb2..910cacd 100644 --- a/app/routers/compute/facility_adapter.py +++ b/app/routers/compute/facility_adapter.py @@ -18,8 +18,8 @@ async def submit_job( self: "FacilityAdapter", resource: status_models.Resource, user: account_models.User, - job_spec: compute_models.JobSpec, - ) -> compute_models.Job: + job_spec: compute_models.JobSpec + ) -> compute_models.Job: pass @@ -29,7 +29,7 @@ async def submit_job_script( resource: status_models.Resource, user: account_models.User, job_script_path: str, - args: list[str] = [], + args: list[str] = [] ) -> compute_models.Job: pass @@ -40,7 +40,7 @@ async def update_job( resource: status_models.Resource, user: account_models.User, job_spec: compute_models.JobSpec, - job_id: str, + job_id: str ) -> compute_models.Job: pass @@ -66,7 +66,7 @@ async def get_jobs( limit : int, filters: dict[str, object] | None = None, historical: bool = False, - include_spec: bool = False, + include_spec: bool = False ) -> list[compute_models.Job]: pass @@ -76,6 +76,6 @@ async def cancel_job( self: "FacilityAdapter", resource: status_models.Resource, user: account_models.User, - job_id: str, + job_id: str ) -> bool: pass diff --git a/app/routers/compute/models.py b/app/routers/compute/models.py index a56d4fe..87142a4 100644 --- a/app/routers/compute/models.py +++ b/app/routers/compute/models.py @@ -1,7 +1,9 @@ -from typing import Annotated from enum import IntEnum -from pydantic import field_serializer, ConfigDict, StrictBool, Field -from ..common import IRIBaseModel +from typing import Annotated + +from pydantic import ConfigDict, Field, StrictBool, field_serializer + +from ...types.base import IRIBaseModel class ResourceSpec(IRIBaseModel): diff --git a/app/routers/error_handlers.py b/app/routers/error_handlers.py index 337b5fc..bf781be 100644 --- a/app/routers/error_handlers.py +++ b/app/routers/error_handlers.py @@ -3,7 +3,7 @@ Default problem schema and example responses for various HTTP status codes. """ import logging -from urllib.parse import unquote +from urllib.parse import urlsplit, urlunsplit, quote from fastapi import FastAPI, HTTPException, Request from fastapi.responses import JSONResponse from fastapi.exceptions import RequestValidationError @@ -16,12 +16,36 @@ def get_url_base(request: Request) -> str: proto = request.headers.get("x-forwarded-proto") or request.url.scheme return f"{proto}://{host}/problems" +def safe_instance_url(request: Request) -> str: + """Return a URL-safe version of the request URL for the 'instance' field.""" + parts = urlsplit(str(request.url)) + + # Encode unsafe characters in each component + safe_path = quote(parts.path, safe="/:@&+$,;=-._~") + safe_query = quote(parts.query, safe="=&?/:@+$,;=-._~") + safe_fragment = quote(parts.fragment, safe="=&?/:@+$,;=-._~") + + return urlunsplit((parts.scheme, parts.netloc, safe_path, safe_query, safe_fragment)) + def problem_response(*, request: Request, status: int, - title: str, detail: str, problem_type: str, + title, detail, problem_type: str, invalid_params=None, extra_headers=None): """Return a JSON problem response with the given status, title, and detail.""" - instance = unquote(str(request.url)) + instance = safe_instance_url(request) url_base = get_url_base(request) + + # Normalize title and detail to strings (Official spec says they must be strings) + # but fastapi validation errors may provide lists/dicts + if not isinstance(title, str): + title = "Error" + + if not isinstance(detail, str): + if isinstance(detail, list): + detail = ", ".join(err.get("msg", str(err)) if isinstance(err, dict) else str(err) + for err in detail) + else: + detail = str(detail) + body = { "type": f"{url_base}/{problem_type}", "title": title, @@ -34,7 +58,13 @@ def problem_response(*, request: Request, status: int, body["invalid_params"] = invalid_params headers = extra_headers or {} - return JSONResponse(status_code=status, content=body, headers=headers) + return JSONResponse( + status_code=status, + content=body, + headers=headers, + media_type="application/problem+json" + ) + def install_error_handlers(app: FastAPI): @@ -45,8 +75,7 @@ async def validation_error_handler(request: Request, exc: RequestValidationError invalid_params = [] for err in exc.errors(): - loc = err.get("loc", []) - name = loc[-1] if loc else "unknown" + name = str((err.get("loc") or ["unknown"])[-1]) reason = err.get("msg", "Invalid parameter") invalid_params.append({"name": name, "reason": reason}) diff --git a/app/routers/facility/facility.py b/app/routers/facility/facility.py index acd4d39..7a1a1ea 100644 --- a/app/routers/facility/facility.py +++ b/app/routers/facility/facility.py @@ -1,15 +1,14 @@ -from fastapi import Request, Depends, Query +from fastapi import Depends, Query, Request + +from ...types.http import forbidExtraQueryParams +from ...types.scalars import StrictDateTime from .. import iri_router from ..error_handlers import DEFAULT_RESPONSES -from .import models, facility_adapter -from ..common import StrictDateTime, forbidExtraQueryParams - +from . import facility_adapter, models -router = iri_router.IriRouter( - facility_adapter.FacilityAdapter, - prefix="/facility", - tags=["facility"], -) +router = iri_router.IriRouter(facility_adapter.FacilityAdapter, + prefix="/facility", + tags=["facility"]) @router.get("", responses=DEFAULT_RESPONSES, operation_id="getFacility") async def get_facility( @@ -41,4 +40,4 @@ async def get_site( _forbid = Depends(forbidExtraQueryParams("modified_since")), )-> models.Site: """Get site by ID""" - return await router.adapter.get_site(site_id=site_id, modified_since=modified_since) \ No newline at end of file + return await router.adapter.get_site(site_id=site_id, modified_since=modified_since) diff --git a/app/routers/facility/facility_adapter.py b/app/routers/facility/facility_adapter.py index 7758f24..b765640 100644 --- a/app/routers/facility/facility_adapter.py +++ b/app/routers/facility/facility_adapter.py @@ -13,8 +13,8 @@ class FacilityAdapter(AuthenticatedAdapter): @abstractmethod async def get_facility( self: "FacilityAdapter", - modified_since: str | None = None, - ) -> facility_models.Facility | None: + modified_since: str | None = None + ) -> facility_models.Facility | None: pass @abstractmethod @@ -24,8 +24,8 @@ async def list_sites( name: str | None = None, offset: int | None = None, limit: int | None = None, - short_name: str | None = None, - ) -> list[facility_models.Site]: + short_name: str | None = None + ) -> list[facility_models.Site]: pass @abstractmethod @@ -34,4 +34,4 @@ async def get_site( site_id: str, modified_since: str | None = None, ) -> facility_models.Site | None: - pass \ No newline at end of file + pass diff --git a/app/routers/facility/models.py b/app/routers/facility/models.py index a3a781c..021d9a2 100644 --- a/app/routers/facility/models.py +++ b/app/routers/facility/models.py @@ -1,8 +1,9 @@ """Facility-related models.""" -from typing import Optional, List +from typing import List, Optional + from pydantic import Field, HttpUrl -from ..common import NamedObject +from ...types.base import NamedObject class Site(NamedObject): @@ -20,6 +21,16 @@ def _self_path(self) -> str: longitude: Optional[float] = Field(None, description="Longitude of the Location.") resource_uris: List[HttpUrl] = Field(default_factory=list, description="URIs of Resources hosted at this Site.") + @classmethod + def find(cls, items, name=None, description=None, modified_since=None, short_name=None, country_name=None): + """ Find Locations matching the given criteria. """ + items = super().find(items, name=name, description=description, modified_since=modified_since) + if short_name: + items = [item for item in items if item.short_name == short_name] + if country_name: + items = [item for item in items if item.country_name == country_name] + return items + class Facility(NamedObject): def _self_path(self) -> str: diff --git a/app/routers/filesystem/facility_adapter.py b/app/routers/filesystem/facility_adapter.py index a70efb0..636b0a9 100644 --- a/app/routers/filesystem/facility_adapter.py +++ b/app/routers/filesystem/facility_adapter.py @@ -30,7 +30,7 @@ async def chmod( resource: status_models.Resource, user: account_models.User, request_model: filesystem_models.PutFileChmodRequest - ) -> filesystem_models.PutFileChmodResponse: + ) -> filesystem_models.PutFileChmodResponse: pass @@ -40,7 +40,7 @@ async def chown( resource: status_models.Resource, user: account_models.User, request_model: filesystem_models.PutFileChownRequest - ) -> filesystem_models.PutFileChownResponse: + ) -> filesystem_models.PutFileChownResponse: pass @@ -53,7 +53,7 @@ async def ls( show_hidden: bool, numeric_uid: bool, recursive: bool, - dereference: bool, + dereference: bool ) -> filesystem_models.GetDirectoryLsResponse: pass @@ -66,7 +66,7 @@ async def head( path: str, file_bytes: int, lines: int, - skip_trailing: bool, + skip_trailing: bool ) -> Tuple[Any, int]: pass @@ -79,8 +79,8 @@ async def tail( path: str, file_bytes: int | None, lines: int | None, - skip_trailing: bool, - ) -> Tuple[Any, int]: + skip_trailing: bool + ) -> Tuple[Any, int]: pass @@ -91,8 +91,8 @@ async def view( user: account_models.User, path: str, size: int, - offset: int, - ) -> filesystem_models.GetViewFileResponse: + offset: int + ) -> filesystem_models.GetViewFileResponse: pass @@ -101,8 +101,8 @@ async def checksum( self : "FacilityAdapter", resource: status_models.Resource, user: account_models.User, - path: str, - ) -> filesystem_models.GetFileChecksumResponse: + path: str + ) -> filesystem_models.GetFileChecksumResponse: pass @@ -111,8 +111,8 @@ async def file( self : "FacilityAdapter", resource: status_models.Resource, user: account_models.User, - path: str, - ) -> filesystem_models.GetFileTypeResponse: + path: str + ) -> filesystem_models.GetFileTypeResponse: pass @@ -122,8 +122,8 @@ async def stat( resource: status_models.Resource, user: account_models.User, path: str, - dereference: bool, - ) -> filesystem_models.GetFileStatResponse: + dereference: bool + ) -> filesystem_models.GetFileStatResponse: pass @@ -132,8 +132,7 @@ async def rm( self : "FacilityAdapter", resource: status_models.Resource, user: account_models.User, - path: str, - ): + path: str): pass @@ -142,8 +141,8 @@ async def mkdir( self : "FacilityAdapter", resource: status_models.Resource, user: account_models.User, - request_model: filesystem_models.PostMakeDirRequest, - ) -> filesystem_models.PostMkdirResponse: + request_model: filesystem_models.PostMakeDirRequest + ) -> filesystem_models.PostMkdirResponse: pass @@ -153,7 +152,7 @@ async def symlink( resource: status_models.Resource, user: account_models.User, request_model: filesystem_models.PostFileSymlinkRequest, - ) -> filesystem_models.PostFileSymlinkResponse: + ) -> filesystem_models.PostFileSymlinkResponse: pass @@ -162,8 +161,8 @@ async def download( self : "FacilityAdapter", resource: status_models.Resource, user: account_models.User, - path: str, - ) -> Any: + path: str + ) -> Any: pass @@ -173,8 +172,8 @@ async def upload( resource: status_models.Resource, user: account_models.User, path: str, - content: str, - ) -> None: + content: str + ) -> None: pass @@ -183,8 +182,8 @@ async def compress( self : "FacilityAdapter", resource: status_models.Resource, user: account_models.User, - request_model: filesystem_models.PostCompressRequest, - ) -> filesystem_models.PostCompressResponse: + request_model: filesystem_models.PostCompressRequest + ) -> filesystem_models.PostCompressResponse: pass @@ -193,8 +192,8 @@ async def extract( self : "FacilityAdapter", resource: status_models.Resource, user: account_models.User, - request_model: filesystem_models.PostExtractRequest, - ) -> filesystem_models.PostExtractResponse: + request_model: filesystem_models.PostExtractRequest + ) -> filesystem_models.PostExtractResponse: pass @@ -203,8 +202,8 @@ async def mv( self : "FacilityAdapter", resource: status_models.Resource, user: account_models.User, - request_model: filesystem_models.PostMoveRequest, - ) -> filesystem_models.PostMoveResponse: + request_model: filesystem_models.PostMoveRequest + ) -> filesystem_models.PostMoveResponse: pass @@ -213,6 +212,6 @@ async def cp( self : "FacilityAdapter", resource: status_models.Resource, user: account_models.User, - request_model: filesystem_models.PostCopyRequest, - ) -> filesystem_models.PostCopyResponse: + request_model: filesystem_models.PostCopyRequest + ) -> filesystem_models.PostCopyResponse: pass diff --git a/app/routers/filesystem/filesystem.py b/app/routers/filesystem/filesystem.py index d583c64..fb27427 100644 --- a/app/routers/filesystem/filesystem.py +++ b/app/routers/filesystem/filesystem.py @@ -34,12 +34,12 @@ async def _user_resource( resource_id: str, request: Request, ) -> tuple[account_models.User, status_models.Resource]: - user = await router.adapter.get_user(request.state.current_user_id, request.state.api_key, iri_router.get_client_ip(request)) + user = await router.adapter.get_user(user_id=request.state.current_user_id, api_key=request.state.api_key, client_ip=iri_router.get_client_ip(request)) if not user: raise HTTPException(status_code=404, detail="User not found") # look up the resource (todo: maybe ensure it's available) - resource = await status_router.adapter.get_resource(resource_id) + resource = await status_router.adapter.get_resource(resource_id=resource_id) if not resource: raise HTTPException(status_code=404, detail="Resource not found") return (user, resource) @@ -62,9 +62,9 @@ async def put_chmod( ) -> str: user, resource = await _user_resource(resource_id, request) return await router.task_adapter.put_task( - user, - resource, - task_models.TaskCommand( + user=user, + resource=resource, + task=task_models.TaskCommand( router=router.get_router_name(), command="chmod", args={ @@ -91,9 +91,9 @@ async def put_chown( ) -> str: user, resource = await _user_resource(resource_id, request) return await router.task_adapter.put_task( - user, - resource, - task_models.TaskCommand( + user=user, + resource=resource, + task=task_models.TaskCommand( router=router.get_router_name(), command="chown", args={ @@ -121,9 +121,9 @@ async def get_file( ) -> str: user, resource = await _user_resource(resource_id, request) return await router.task_adapter.put_task( - user, - resource, - task_models.TaskCommand( + user=user, + resource=resource, + task=task_models.TaskCommand( router=router.get_router_name(), command="file", args={ @@ -151,9 +151,9 @@ async def get_stat( ) -> str: user, resource = await _user_resource(resource_id, request) return await router.task_adapter.put_task( - user, - resource, - task_models.TaskCommand( + user=user, + resource=resource, + task=task_models.TaskCommand( router=router.get_router_name(), command="stat", args={ @@ -181,9 +181,9 @@ async def post_mkdir( ) -> str: user, resource = await _user_resource(resource_id, request) return await router.task_adapter.put_task( - user, - resource, - task_models.TaskCommand( + user=user, + resource=resource, + task=task_models.TaskCommand( router=router.get_router_name(), command="mkdir", args={ @@ -211,9 +211,9 @@ async def post_symlink( ) -> str: user, resource = await _user_resource(resource_id, request) return await router.task_adapter.put_task( - user, - resource, - task_models.TaskCommand( + user=user, + resource=resource, + task=task_models.TaskCommand( router=router.get_router_name(), command="symlink", args={ @@ -257,9 +257,9 @@ async def get_ls_async( ) -> str: user, resource = await _user_resource(resource_id, request) return await router.task_adapter.put_task( - user, - resource, - task_models.TaskCommand( + user=user, + resource=resource, + task=task_models.TaskCommand( router=router.get_router_name(), command="ls", args={ @@ -325,9 +325,9 @@ async def get_head( detail="Exactly one of `bytes` or `lines` must be specified." ) return await router.task_adapter.put_task( - user, - resource, - task_models.TaskCommand( + user=user, + resource=resource, + task=task_models.TaskCommand( router=router.get_router_name(), command="head", args={ @@ -363,9 +363,9 @@ async def get_view( user, resource = await _user_resource(resource_id, request) return await router.task_adapter.put_task( - user, - resource, - task_models.TaskCommand( + user=user, + resource=resource, + task=task_models.TaskCommand( router=router.get_router_name(), command="view", args={ @@ -422,9 +422,9 @@ async def get_tail( detail="Exactly one of `bytes` or `lines` must be specified." ) return await router.task_adapter.put_task( - user, - resource, - task_models.TaskCommand( + user=user, + resource=resource, + task=task_models.TaskCommand( router=router.get_router_name(), command="tail", args={ @@ -455,9 +455,9 @@ async def get_checksum( ) -> str: user, resource = await _user_resource(resource_id, request) return await router.task_adapter.put_task( - user, - resource, - task_models.TaskCommand( + user=user, + resource=resource, + task=task_models.TaskCommand( router=router.get_router_name(), command="checksum", args={ @@ -481,9 +481,9 @@ async def delete_rm( ) -> str: user, resource = await _user_resource(resource_id, request) return await router.task_adapter.put_task( - user, - resource, - task_models.TaskCommand( + user=user, + resource=resource, + task=task_models.TaskCommand( router=router.get_router_name(), command="rm", args={ @@ -510,9 +510,9 @@ async def post_compress( ) -> str: user, resource = await _user_resource(resource_id, request) return await router.task_adapter.put_task( - user, - resource, - task_models.TaskCommand( + user=user, + resource=resource, + task=task_models.TaskCommand( router=router.get_router_name(), command="compress", args={ @@ -539,9 +539,9 @@ async def post_extract( ) -> str: user, resource = await _user_resource(resource_id, request) return await router.task_adapter.put_task( - user, - resource, - task_models.TaskCommand( + user=user, + resource=resource, + task=task_models.TaskCommand( router=router.get_router_name(), command="extract", args={ @@ -568,9 +568,9 @@ async def move_mv( ) -> str: user, resource = await _user_resource(resource_id, request) return await router.task_adapter.put_task( - user, - resource, - task_models.TaskCommand( + user=user, + resource=resource, + task=task_models.TaskCommand( router=router.get_router_name(), command="mv", args={ @@ -597,9 +597,9 @@ async def post_cp( ) -> str: user, resource = await _user_resource(resource_id, request) return await router.task_adapter.put_task( - user, - resource, - task_models.TaskCommand( + user=user, + resource=resource, + task=task_models.TaskCommand( router=router.get_router_name(), command="cp", args={ @@ -625,9 +625,9 @@ async def get_download( ) -> str: user, resource = await _user_resource(resource_id, request) return await router.task_adapter.put_task( - user, - resource, - task_models.TaskCommand( + user=user, + resource=resource, + task=task_models.TaskCommand( router=router.get_router_name(), command="download", args={ @@ -665,9 +665,9 @@ async def post_upload( ) return await router.task_adapter.put_task( - user, - resource, - task_models.TaskCommand( + user=user, + resource=resource, + task=task_models.TaskCommand( router=router.get_router_name(), command="upload", args={ diff --git a/app/routers/iri_router.py b/app/routers/iri_router.py index f0b5b49..be31762 100644 --- a/app/routers/iri_router.py +++ b/app/routers/iri_router.py @@ -102,7 +102,7 @@ class AuthenticatedAdapter(ABC): async def get_current_user( self : "AuthenticatedAdapter", api_key: str, - client_ip: str|None, + client_ip: str|None ) -> str: """ Decode the api_key and return the authenticated user's id. @@ -117,7 +117,7 @@ async def get_user( self : "AuthenticatedAdapter", user_id: str, api_key: str, - client_ip: str|None, + client_ip: str|None ) -> User: """ Retrieve additional user information (name, email, etc.) for the given user_id. diff --git a/app/routers/status/facility_adapter.py b/app/routers/status/facility_adapter.py index d7358c5..6be6c88 100644 --- a/app/routers/status/facility_adapter.py +++ b/app/routers/status/facility_adapter.py @@ -1,8 +1,10 @@ -from abc import ABC, abstractmethod import datetime +from abc import ABC, abstractmethod + from fastapi import Query + +from ...types.models import Capability from . import models as status_models -from ..common import Capability class FacilityAdapter(ABC): @@ -25,6 +27,7 @@ async def get_resources( resource_type: status_models.ResourceType = Query(default=None), current_status: status_models.Status = Query(default=None), capability: Capability | None = None, + site_id: str | None = None ) -> list[status_models.Resource]: pass @@ -32,7 +35,7 @@ async def get_resources( @abstractmethod async def get_resource( self : "FacilityAdapter", - id : str + id_ : str ) -> status_models.Resource: pass @@ -49,8 +52,8 @@ async def get_events( status : status_models.Status | None = None, from_ : datetime.datetime | None = None, to : datetime.datetime | None = None, - time : datetime.datetime | None = None, - modified_since : datetime.datetime | None = None, + time_ : datetime.datetime | None = None, + modified_since : datetime.datetime | None = None ) -> list[status_models.Event]: pass @@ -59,7 +62,7 @@ async def get_events( async def get_event( self : "FacilityAdapter", incident_id : str, - id : str + id_ : str ) -> status_models.Event: pass @@ -78,7 +81,7 @@ async def get_incidents( time_ : datetime.datetime | None = None, modified_since : datetime.datetime | None = None, resource_id : str | None = None, - resolution: status_models.Resolution | None = None, + resolution: status_models.Resolution | None = None ) -> list[status_models.Incident]: pass @@ -86,6 +89,6 @@ async def get_incidents( @abstractmethod async def get_incident( self : "FacilityAdapter", - id : str + id_ : str ) -> status_models.Incident: pass diff --git a/app/routers/status/models.py b/app/routers/status/models.py index 15a9e36..7f0d51d 100644 --- a/app/routers/status/models.py +++ b/app/routers/status/models.py @@ -1,9 +1,11 @@ import datetime import enum from typing import Optional -from pydantic import BaseModel, computed_field, Field, HttpUrl + +from pydantic import BaseModel, Field, HttpUrl, computed_field, field_validator + from ... import config -from ..common import NamedObject +from ...types.base import NamedObject class Link(BaseModel): @@ -31,12 +33,16 @@ class ResourceType(enum.Enum): class Resource(NamedObject): def _self_path(self) -> str: + """ Return the API path for this resource. """ return f"/status/resources/{self.id}" - capability_ids: list[str] = Field(exclude=True) + # NOTE (TBR): If site_id is required, then located_at_uri should be also required. This can be easily identified by Site.self_uri + # Is there a specific Resource, that has no Site? + site_id: str = Field(..., description="The site identifier this resource is located at") + capability_ids: list[str] = Field(default_factory=list, exclude=True) + group: str | None + current_status: Status | None = Field(default=None, description="The current status comes from the status of the last event for this resource") resource_type: ResourceType - group: str | None = Field(None, description="Group this resource belongs to") - current_status: Status | None = Field(None, description="The current status comes from the status of the last event for this resource") located_at_uri: Optional[HttpUrl] = Field(None, description="Resource located at specific Site") @@ -44,23 +50,39 @@ def _self_path(self) -> str: @computed_field(description="The list of capabilities in this resource") @property def capability_uris(self) -> list[str]: + """ Return the list of capability URIs for this resource. """ return [f"{config.API_URL_ROOT}{config.API_PREFIX}{config.API_URL}/account/capabilities/{e}" for e in self.capability_ids] - @staticmethod - def find(resources, name, description, group, modified_since, resource_type): - a = NamedObject.find(resources, name, description, modified_since) + @classmethod + def find(cls, items, name=None, description=None, modified_since=None, group=None, + resource_type=None, current_status=None, capability=None, site_id=None) -> list: + items = super().find(items, name=name, description=description, modified_since=modified_since) if group: - a = [aa for aa in a if aa.group == group] + items = [item for item in items if item.group == group] if resource_type: - a = [aa for aa in a if aa.resource_type == resource_type] - return a - + if isinstance(resource_type, str): + resource_type = ResourceType(resource_type) + items = [item for item in items if item.resource_type == resource_type] + if current_status: + items = [item for item in items if item.current_status == current_status] + if capability: + items = [item for item in items + if any(cap_id in item.capability_ids for cap_id in capability)] + if site_id: + items = [item for item in items if item.site_id == site_id] + return items class Event(NamedObject): def _self_path(self) -> str: + """ Return the API path for this event. """ return f"/status/incidents/{self.incident_id}/events/{self.id}" + @field_validator("occurred_at", mode="before") + @classmethod + def _norm_dt_field(cls, v): + return cls.normalize_dt(v) + occurred_at : datetime.datetime status : Status resource_id : str = Field(exclude=True) @@ -69,38 +91,39 @@ def _self_path(self) -> str: @computed_field(description="The resource belonging to this event") @property def resource_uri(self) -> str: + """ Return the resource URI for this event. """ return f"{config.API_URL_ROOT}{config.API_PREFIX}{config.API_URL}/status/resources/{self.resource_id}" @computed_field(description="The event's incident") @property def incident_uri(self) -> str|None: + """ Return the incident URI for this event. """ return f"{config.API_URL_ROOT}{config.API_PREFIX}{config.API_URL}/status/incidents/{self.incident_id}" if self.incident_id else None - @staticmethod - def find( - events : list, - resource_id : str | None = None, - name : str | None = None, - description : str | None = None, - status : Status | None = None, - from_ : datetime.datetime | None = None, - to : datetime.datetime | None = None, - time_ : datetime.datetime | None = None, - modified_since : datetime.datetime | None = None, - ) -> list: - events = NamedObject.find(events, name, description, modified_since) + @classmethod + def find(cls, items, name=None, description=None, modified_since=None, + resource_id=None, status=None, from_=None, to=None, time_=None) -> list: + items = super().find(items, name=name, description=description, modified_since=modified_since) + if resource_id: - events = [e for e in events if e.resource_id == resource_id] + items = [e for e in items if e.resource_id == resource_id] if status: - events = [e for e in events if e.status == status] + if isinstance(status, str): + status = Status(status) + items = [e for e in items if e.status == status] + + from_ = cls.normalize_dt(from_) if from_ else None + to = cls.normalize_dt(to) if to else None + time_ = cls.normalize_dt(time_) if time_ else None + if from_: - events = [e for e in events if e.occurred_at >= from_] + items = [e for e in items if e.occurred_at >= from_] if to: - events = [e for e in events if e.occurred_at < to] + items = [e for e in items if e.occurred_at < to] if time_: - events = [e for e in events if e.occurred_at == time_] - return events + items = [e for e in items if e.occurred_at == time_] + return items class IncidentType(enum.Enum): @@ -121,11 +144,17 @@ class Resolution(enum.Enum): class Incident(NamedObject): def _self_path(self) -> str: + """ Return the API path for this incident. """ return f"/status/incidents/{self.id}" + @field_validator("start", "end", mode="before") + @classmethod + def _norm_dt_field(cls, v): + return cls.normalize_dt(v) + status : Status - resource_ids : list[str] = Field(exclude=True) - event_ids : list[str] = Field(exclude=True) + resource_ids : list[str] = Field(default_factory=list, exclude=True) + event_ids : list[str] = Field(default_factory=list, exclude=True) start : datetime.datetime end : datetime.datetime | None type : IncidentType @@ -134,37 +163,39 @@ def _self_path(self) -> str: @computed_field(description="The list of past events in this incident") @property def event_uris(self) -> list[str]: + """ Return the list of event URIs for this incident. """ return [f"{config.API_URL_ROOT}{config.API_PREFIX}{config.API_URL}/status/incidents/{self.id}/events/{e}" for e in self.event_ids] @computed_field(description="The list of resources that may be impacted by this incident") @property def resource_uris(self) -> list[str]: + """ Return the list of resource URIs for this incident. """ return [f"{config.API_URL_ROOT}{config.API_PREFIX}{config.API_URL}/status/resources/{r}" for r in self.resource_ids] - @staticmethod - def find( - incidents : list, - name : str | None = None, - description : str | None = None, - status : Status | None = None, - type_ : IncidentType | None = None, - from_ : datetime.datetime | None = None, - to : datetime.datetime | None = None, - time_ : datetime.datetime | None = None, - modified_since : datetime.datetime | None = None, - resource_id : str | None = None, - ) -> list: - incidents = NamedObject.find(incidents, name, description, modified_since) + @classmethod + def find(cls, items, name=None, description=None, modified_since=None, status=None, + type_=None, from_= None, to = None, time_ = None, resource_id = None, resolution=None) -> list: + items = super().find(items, name=name, description=description, modified_since=modified_since) + if resource_id: - incidents = [e for e in incidents if resource_id in e.resource_ids] + items = [e for e in items if resource_id in e.resource_ids] if status: - incidents = [e for e in incidents if e.status == status] + items = [e for e in items if e.status == status] if type_: - incidents = [e for e in incidents if e.type == type_] + items = [e for e in items if e.type == type_] + if resolution: + items = [e for e in items if e.resolution == resolution] + + from_ = cls.normalize_dt(from_) if from_ else None + to = cls.normalize_dt(to) if to else None + time_ = cls.normalize_dt(time_) if time_ else None + if from_: - incidents = [e for e in incidents if e.start >= from_] + items = [e for e in items if e.start >= from_] if to: - incidents = [e for e in incidents if e.end < to] + items = [e for e in items if e.end and e.end < to] + if time_: - incidents = [e for e in incidents if e.start <= time_ and e.end > time_] - return incidents + items = [e for e in items + if e.start <= time_ and (e.end is None or e.end > time_)] + return items \ No newline at end of file diff --git a/app/routers/status/status.py b/app/routers/status/status.py index 7f4fae6..6ffa0ba 100644 --- a/app/routers/status/status.py +++ b/app/routers/status/status.py @@ -1,9 +1,12 @@ -from typing import Optional, List, Annotated -from fastapi import HTTPException, Request, Query, Depends -from . import models, facility_adapter +from typing import Annotated, List, Optional + +from fastapi import Depends, HTTPException, Query, Request + +from ...types.http import forbidExtraQueryParams +from ...types.scalars import AllocationUnit, StrictDateTime from .. import iri_router from ..error_handlers import DEFAULT_RESPONSES -from ..common import StrictDateTime, forbidExtraQueryParams, AllocationUnit +from . import facility_adapter, models router = iri_router.IriRouter( facility_adapter.FacilityAdapter, @@ -17,6 +20,7 @@ description="Get a list of all resources at this facility. You can optionally filter the returned list by specifying attribtes.", responses=DEFAULT_RESPONSES, operation_id="getResources", + response_model_exclude_none=True ) async def get_resources( request : Request, @@ -31,7 +35,8 @@ async def get_resources( capability: List[AllocationUnit] = Query(default=None, min_length=1), _forbid = Depends(forbidExtraQueryParams("name", "description", "group", "offset", "limit", "modified_since", "resource_type", "current_status", "capability", multiParams={"capability"})), ) -> list[models.Resource]: - return await router.adapter.get_resources(offset, limit, name, description, group, modified_since, resource_type, current_status, capability) + return await router.adapter.get_resources(offset=offset, limit=limit, name=name, description=description, group=group, modified_since=modified_since, + resource_type=resource_type, current_status=current_status, capability=capability) @router.get( @@ -75,7 +80,8 @@ async def get_incidents( _forbid = Depends(forbidExtraQueryParams("name", "description", "status", "type", "from", "to", "time", "modified_since", "resource_id", "offset", "limit", "resolution", "resource_uris", "event_uris", multiParams={"resource_uris", "event_uris"})), ) -> list[models.Incident]: - return await router.adapter.get_incidents(offset, limit, name, description, status, type_, from_, to, time_, modified_since, resource_id, resolution) + return await router.adapter.get_incidents(offset=offset, limit=limit, name=name, description=description, status=status, type_=type_, from_=from_, to=to, + time_=time_, modified_since=modified_since, resource_id=resource_id, resolution=resolution) @router.get( @@ -118,7 +124,8 @@ async def get_events( limit : int = Query(default=100, ge=0, le=1000), _forbid = Depends(forbidExtraQueryParams("resource_id", "name", "description", "status", "from", "to", "time", "modified_since", "offset", "limit")), ) -> list[models.Event]: - return await router.adapter.get_events(incident_id, offset, limit, resource_id, name, description, status, from_, to, time_, modified_since) + return await router.adapter.get_events(incident_id, offset=offset, limit=limit, resource_id=resource_id, name=name, description=description, status=status, + from_=from_, to=to, time_=time_, modified_since=modified_since) @router.get( diff --git a/app/routers/task/facility_adapter.py b/app/routers/task/facility_adapter.py index 6659d15..47f686e 100644 --- a/app/routers/task/facility_adapter.py +++ b/app/routers/task/facility_adapter.py @@ -1,5 +1,4 @@ from abc import abstractmethod -from typing import Any from . import models as task_models from ..account import models as account_models from ..status import models as status_models @@ -19,7 +18,7 @@ class FacilityAdapter(AuthenticatedAdapter): async def get_task( self : "FacilityAdapter", user: account_models.User, - task_id: str, + task_id: str ) -> task_models.Task|None: pass @@ -27,7 +26,7 @@ async def get_task( @abstractmethod async def get_tasks( self : "FacilityAdapter", - user: account_models.User, + user: account_models.User ) -> list[task_models.Task]: pass @@ -37,8 +36,8 @@ async def put_task( self: "FacilityAdapter", user: account_models.User, resource: status_models.Resource|None, - command: task_models.TaskCommand - ) -> str: + task: task_models.TaskCommand + ) -> str: pass @@ -46,78 +45,78 @@ async def put_task( async def on_task( resource: status_models.Resource, user: account_models.User, - cmd: task_models.TaskCommand, - ) -> tuple[str, task_models.TaskStatus]: + task: task_models.TaskCommand + ) -> tuple[str, task_models.TaskStatus]: # Handle a task from the facility message queue. # Returns: (result, status) try: r = None - if cmd.router == "filesystem": - fs_adapter = IriRouter.create_adapter(cmd.router, filesystem_adapter.FacilityAdapter) - if cmd.command == "chmod": - request_model = filesystem_models.PutFileChmodRequest.model_validate(cmd.args["request_model"]) + if task.router == "filesystem": + fs_adapter = IriRouter.create_adapter(task.router, filesystem_adapter.FacilityAdapter) + if task.command == "chmod": + request_model = filesystem_models.PutFileChmodRequest.model_validate(task.args["request_model"]) o = await fs_adapter.chmod(resource, user, request_model) r = o.model_dump_json() - elif cmd.command == "chown": - request_model = filesystem_models.PutFileChownRequest.model_validate(cmd.args["request_model"]) + elif task.command == "chown": + request_model = filesystem_models.PutFileChownRequest.model_validate(task.args["request_model"]) o = await fs_adapter.chown(resource, user, request_model) r = o.model_dump_json() - elif cmd.command == "file": - o = await fs_adapter.file(resource, user, **cmd.args) + elif task.command == "file": + o = await fs_adapter.file(resource, user, **task.args) r = o.model_dump_json() - elif cmd.command == "stat": - o = await fs_adapter.stat(resource, user, **cmd.args) + elif task.command == "stat": + o = await fs_adapter.stat(resource, user, **task.args) r = o.model_dump_json() - elif cmd.command == "mkdir": - request_model = filesystem_models.PostMakeDirRequest.model_validate(cmd.args["request_model"]) + elif task.command == "mkdir": + request_model = filesystem_models.PostMakeDirRequest.model_validate(task.args["request_model"]) o = await fs_adapter.mkdir(resource, user, request_model) r = o.model_dump_json() - elif cmd.command == "symlink": - request_model = filesystem_models.PostFileSymlinkRequest.model_validate(cmd.args["request_model"]) + elif task.command == "symlink": + request_model = filesystem_models.PostFileSymlinkRequest.model_validate(task.args["request_model"]) o = await fs_adapter.symlink(resource, user, request_model) r = o.model_dump_json() - elif cmd.command == "ls": - o = await fs_adapter.ls(resource, user, **cmd.args) + elif task.command == "ls": + o = await fs_adapter.ls(resource, user, **task.args) r = o.model_dump_json() - elif cmd.command == "head": - o = await fs_adapter.head(resource, user, **cmd.args) + elif task.command == "head": + o = await fs_adapter.head(resource, user, **task.args) r = o.model_dump_json() - elif cmd.command == "view": - o = await fs_adapter.view(resource, user, **cmd.args) + elif task.command == "view": + o = await fs_adapter.view(resource, user, **task.args) r = o.model_dump_json() - elif cmd.command == "tail": - o = await fs_adapter.tail(resource, user, **cmd.args) + elif task.command == "tail": + o = await fs_adapter.tail(resource, user, **task.args) r = o.model_dump_json() - elif cmd.command == "checksum": - o = await fs_adapter.checksum(resource, user, **cmd.args) + elif task.command == "checksum": + o = await fs_adapter.checksum(resource, user, **task.args) r = o.model_dump_json() - elif cmd.command == "rm": - o = await fs_adapter.rm(resource, user, **cmd.args) + elif task.command == "rm": + o = await fs_adapter.rm(resource, user, **task.args) r = o.model_dump_json() - elif cmd.command == "compress": - request_model = filesystem_models.PostCompressRequest.model_validate(cmd.args["request_model"]) + elif task.command == "compress": + request_model = filesystem_models.PostCompressRequest.model_validate(task.args["request_model"]) o = await fs_adapter.compress(resource, user, request_model) r = o.model_dump_json() - elif cmd.command == "extract": - request_model = filesystem_models.PostExtractRequest.model_validate(cmd.args["request_model"]) + elif task.command == "extract": + request_model = filesystem_models.PostExtractRequest.model_validate(task.args["request_model"]) o = await fs_adapter.extract(resource, user, request_model) r = o.model_dump_json() - elif cmd.command == "mv": - request_model = filesystem_models.PostMoveRequest.model_validate(cmd.args["request_model"]) + elif task.command == "mv": + request_model = filesystem_models.PostMoveRequest.model_validate(task.args["request_model"]) o = await fs_adapter.mv(resource, user, request_model) r = o.model_dump_json() - elif cmd.command == "cp": - request_model = filesystem_models.PostCopyRequest.model_validate(cmd.args["request_model"]) + elif task.command == "cp": + request_model = filesystem_models.PostCopyRequest.model_validate(task.args["request_model"]) o = await fs_adapter.cp(resource, user, request_model) r = o.model_dump_json() - elif cmd.command == "download": - r = await fs_adapter.download(resource, user, **cmd.args) - elif cmd.command == "upload": - o = await fs_adapter.upload(resource, user, **cmd.args) + elif task.command == "download": + r = await fs_adapter.download(resource, user, **task.args) + elif task.command == "upload": + o = await fs_adapter.upload(resource, user, **task.args) r = "File uploaded successfully" if r: return (r, task_models.TaskStatus.completed) else: - return (f"Task was cancelled due to unknown router/command: {cmd.router}:{cmd.command}", task_models.TaskStatus.failed) + return (f"Task was cancelled due to unknown router/command: {task.router}:{task.command}", task_models.TaskStatus.failed) except Exception as exc: return (f"Error: {exc}", task_models.TaskStatus.failed) diff --git a/app/routers/task/task.py b/app/routers/task/task.py index 094cdd0..1a65459 100644 --- a/app/routers/task/task.py +++ b/app/routers/task/task.py @@ -22,10 +22,10 @@ async def get_task( task_id : str, ) -> models.Task: """Get a task""" - user = await router.adapter.get_user(request.state.current_user_id, request.state.api_key, iri_router.get_client_ip(request)) + user = await router.adapter.get_user(user_id=request.state.current_user_id, api_key=request.state.api_key, client_ip=iri_router.get_client_ip(request)) if not user: raise HTTPException(status_code=404, detail="User not found") - task = await router.adapter.get_task(user, task_id) + task = await router.adapter.get_task(user=user, task_id=task_id) if not task: raise HTTPException(status_code=404, detail=f"Task {task_id} not found") return task @@ -42,7 +42,7 @@ async def get_tasks( request : Request, ) -> list[models.Task]: """Get all tasks""" - user = await router.adapter.get_user(request.state.current_user_id, request.state.api_key, iri_router.get_client_ip(request)) + user = await router.adapter.get_user(user_id=request.state.current_user_id, api_key=request.state.api_key, client_ip=iri_router.get_client_ip(request)) if not user: raise HTTPException(status_code=404, detail="User not found") - return await router.adapter.get_tasks(user) + return await router.adapter.get_tasks(user=user) \ No newline at end of file diff --git a/app/types/__init__.py b/app/types/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/types/base.py b/app/types/base.py new file mode 100644 index 0000000..4a3b245 --- /dev/null +++ b/app/types/base.py @@ -0,0 +1,101 @@ +"""Default models used by multiple routers.""" +import datetime +from collections.abc import Iterable +from typing import Optional + +from pydantic import (BaseModel, ConfigDict, Field, computed_field, + field_validator, model_serializer) + +from .. import config +from .scalars import StrictDateTime + + +class IRIBaseModel(BaseModel): + """Base model for IRI models.""" + model_config = ConfigDict(extra="allow") + + @model_serializer(mode="wrap") + def _hide_extra(self, handler, info): + data = handler(self) + + model_fields = set(self.model_fields or {}) + computed_fields = set(self.model_computed_fields or {}) + extra = getattr(self, "__pydantic_extra__", {}) or {} + for k in extra: + if k not in model_fields and k not in computed_fields: + data.pop(k, None) + return data + + def get_extra(self, key, default=None): + """Get an extra field value that is not defined in the model. Returns default if not found.""" + return getattr(self, "__pydantic_extra__", {}).get(key, default) + + +class NamedObject(IRIBaseModel): + """Base model for named objects.""" + id: str = Field(..., description="The unique identifier for the object. Typically a UUID or URN.") + def _self_path(self) -> str: + raise NotImplementedError + + @classmethod + def normalize_dt(cls, dt: datetime | None) -> datetime | None: + """Normalize datetime to UTC-aware.""" + # Convert naive datetimes into UTC-aware versions + if dt is None: + return None + if isinstance(dt, str): + dt = StrictDateTime.validate(dt) + if dt.tzinfo is None: + return dt.replace(tzinfo=datetime.timezone.utc) + return dt + + @field_validator("last_modified", mode="before") + @classmethod + def _norm_dt_field(cls, v): + return cls.normalize_dt(v) + + @computed_field(description="The canonical URL of this object") + @property + def self_uri(self) -> str: + """Computed self URI property.""" + return f"{config.API_URL_ROOT}{config.API_PREFIX}{config.API_URL}{self._self_path()}" + + name: Optional[str] = Field(None, description="The long name of the object.") + description: Optional[str] = Field(None, description="Human-readable description of the object.") + last_modified: StrictDateTime = Field(..., description="ISO 8601 timestamp when this object was last modified.") + + @classmethod + def find_by_id(cls, items, id_, allow_name: bool = False): + """ Find an object by its id or name == id. """ + # Find a resource by its id. + # If allow_name is True, the id parameter can also match the resource's name. + matches = [r for r in items if r.id == id_ or (allow_name and r.name == id_)] + if not matches: + return None + if len(matches) > 1: + raise ValueError(f"Multiple {cls.__name__} objects matched identifier '{id_}'") + + return matches[0] + + @classmethod + def find(cls, items, name=None, description=None, modified_since=None): + """ Find objects matching the given criteria. """ + single = False + if not any((name, description, modified_since)): + return items + + if not isinstance(items, Iterable) or isinstance(items, BaseModel): + items = [items] + single = True + + if name: + items = [item for item in items if item.name == name] + if description: + items = [item for item in items if item.description and description in item.description] + if modified_since: + modified_since = cls.normalize_dt(modified_since) + items = [item for item in items + if item.last_modified and item.last_modified >= modified_since] + if single: + return items[0] if items else None + return items diff --git a/app/types/http.py b/app/types/http.py new file mode 100644 index 0000000..c59c8ef --- /dev/null +++ b/app/types/http.py @@ -0,0 +1,90 @@ +"""HTTP-related types and utilities for the IRI Facility API""" +import datetime +from email.utils import parsedate_to_datetime +from urllib.parse import parse_qs + +from fastapi import HTTPException, Request, status + +from .scalars import StrictDateTime + +# ----------------------------------------------------------------------- +# modifiedSinceDatetime: combine modified_since (ISO8601) and If-Modified-Since (RFC1123) +# If both are provided, the most recent timestamp is used. Strict validation is applied to both formats. +# modified_since must be a valid ISO8601 datetime string. +# If-Modified-Since must be a valid RFC1123 datetime string. +# TODO: If-Modified-Since is not yet supported/used by the API. + +def modifiedSinceDatetime( + modified_since: str | None, + header_modified_since: str | None +) -> datetime.datetime | None: + """ + Combine modified_since (ISO8601) and If-Modified-Since (RFC1123). + If both are provided, the most recent timestamp is used. + """ + + parsed_times: list[datetime.datetime] = [] + + # Query param (ISO 8601) + if modified_since is not None: + try: + dt = StrictDateTime.validate(modified_since) + parsed_times.append(dt) + except ValueError as exc: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Invalid modified_since query param: {exc}", + ) from exc + + # Header (RFC 1123) + if header_modified_since is not None: + try: + dt = parsedate_to_datetime(header_modified_since) + if dt is None: + raise ValueError("Invalid RFC1123 date") + if dt.tzinfo is None: + dt = dt.replace(tzinfo=datetime.timezone.utc) + parsed_times.append(dt.astimezone(datetime.timezone.utc)) + except Exception as exc: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid If-Modified-Since header format (must be RFC1123)", + ) from exc + + if not parsed_times: + return None + + # Stricter constraint wins + return max(parsed_times) + +# ----------------------------------------------------------------------- +# forbidExtraQueryParams: a dependency to forbid extra query parameters + +def forbidExtraQueryParams(*allowedParams: str, multiParams: set[str] | None = None): + """Dependency to forbid extra query parameters. If allowedParams contains "*", all params are allowed.""" + multiParams = multiParams or set() + + async def checker(req: Request): + if "*" in allowedParams: + return + + raw_qs = req.scope.get("query_string", b"") + parsed = parse_qs(raw_qs.decode("utf-8", errors="strict"), keep_blank_values=True) + + allowed = set(allowedParams) + + for key, values in parsed.items(): + if key not in allowed: + raise HTTPException(status_code=422, + detail=[{"type": "extra_forbidden", + "loc": ["query", key], + "msg": f"Unexpected query parameter: {key}"}]) + + + if len(values) > 1 and key not in multiParams: + raise HTTPException(status_code=422, + detail=[{"type": "duplicate_forbidden", + "loc": ["query", key], + "msg": f"Duplicate query parameter: {key}"}]) + + return checker diff --git a/app/types/models.py b/app/types/models.py new file mode 100644 index 0000000..9378f62 --- /dev/null +++ b/app/types/models.py @@ -0,0 +1,21 @@ +"""Models for the IRI Facility API.""" +from pydantic import Field + +from .base import NamedObject +from .scalars import AllocationUnit, StrictDateTime + + +class Capability(NamedObject): + """ + An aspect of a resource that can have an allocation. + For example, Perlmutter nodes with GPUs + For some resources at a facility, this will be 1 to 1 with the resource. + It is a way to further subdivide a resource into allocatable sub-resources. + The word "capability" is also known to users as something they need for a job to run. (eg. gpu) + """ + def _self_path(self) -> str: + return f"/account/capabilities/{self.id}" + + last_modified: StrictDateTime | None = Field(default=None, description="ISO 8601 timestamp when this object was last modified.") + + units: list[AllocationUnit] diff --git a/app/types/scalars.py b/app/types/scalars.py new file mode 100644 index 0000000..582efce --- /dev/null +++ b/app/types/scalars.py @@ -0,0 +1,97 @@ +"""Scalar types for the IRI Facility API""" +# pylint: disable=unused-argument +import datetime +import enum + +from pydantic_core import core_schema + +# ----------------------------------------------------------------------- +# StrictHTTPBool: a strict boolean type + +class StrictHTTPBool: + """Strict boolean: + - Accepts: real booleans, 'true', 'false' + - Rejects everything else. + """ + + @classmethod + def __get_pydantic_core_schema__(cls, source, handler): + return core_schema.no_info_plain_validator_function(cls.validate) + + @staticmethod + def validate(value): + """Validate the input value as a strict boolean.""" + if isinstance(value, bool): + return value + if isinstance(value, str): + v = value.strip().lower() + if v == "true": + return True + if v == "false": + return False + raise ValueError("Invalid boolean value. Expected 'true' or 'false'.") + raise ValueError("Invalid boolean value. Expected true/false or 'true'/'false'.") + + @classmethod + def __get_pydantic_json_schema__(cls, schema, handler): + return { + "type": "boolean", + "description": "Strict boolean. Only true/false allowed (bool or string)." + } + +# ----------------------------------------------------------------------- +# StrictDateTime: a strict ISO8601 datetime type + +class StrictDateTime: + """ + Strict ISO8601 datetime: + - Accepts datetime objects + - Accepts ISO8601 strings: 2025-12-06T10:00:00Z, 2025-12-06T10:00:00+00:00 + - Converts 'Z' → UTC + - Converts naive datetimes → UTC + - Rejects integers ("0"), null, garbage strings, etc. + """ + + @classmethod + def __get_pydantic_core_schema__(cls, source, handler): + return core_schema.no_info_plain_validator_function(cls.validate) + + @staticmethod + def validate(value): + """Validate the input value as a strict ISO8601 datetime.""" + if isinstance(value, datetime.datetime): + return StrictDateTime._normalize(value) + if not isinstance(value, str): + raise ValueError("Invalid datetime value. Expected ISO8601 datetime string.") + v = value.strip() + if v.endswith("Z"): + v = v[:-1] + "+00:00" + try: + dt = datetime.datetime.fromisoformat(v) + except Exception as ex: + raise ValueError("Invalid datetime format. Expected ISO8601 string.") from ex + + return StrictDateTime._normalize(dt) + + @staticmethod + def _normalize(dt: datetime.datetime) -> datetime.datetime: + if dt.tzinfo is None: + return dt.replace(tzinfo=datetime.timezone.utc) + return dt + + @classmethod + def __get_pydantic_json_schema__(cls, schema, handler): + return { + "type": "string", + "format": "date-time", + "description": "Strict ISO8601 datetime. Only valid ISO8601 datetime strings are accepted." + } + +# ----------------------------------------------------------------------- +# AllocationUnit: an enum for allocation units + +class AllocationUnit(enum.Enum): + """Units for allocation""" + node_hours = "node_hours" + bytes = "bytes" + inodes = "inodes"