44from typing import Annotated , Any , Literal , NamedTuple
55
66from fastapi import APIRouter , Body , Depends
7- from sqlalchemy import text
7+ from sqlalchemy import bindparam , text
88from sqlalchemy .engine import Row
99from sqlalchemy .ext .asyncio import AsyncConnection
1010
@@ -73,9 +73,26 @@ class DatasetStatusFilter(StrEnum):
7373 ALL = "all"
7474
7575
76+ def _quality_clause (quality : str , range_ : str | None ) -> str :
77+ if not range_ :
78+ return ""
79+ if not (match := re .match (integer_range_regex , range_ )):
80+ msg = f"`range_` not a valid range: { range_ } "
81+ raise ValueError (msg )
82+ start , end = match .groups ()
83+ value = f"`value` BETWEEN { start } AND { end [2 :]} " if end else f"`value`={ start } "
84+ return f""" AND
85+ d.`did` IN (
86+ SELECT `data`
87+ FROM data_quality
88+ WHERE `quality`='{ quality } ' AND { value }
89+ )
90+ """ # noqa: S608 - `quality` is not user provided, value is filtered with regex
91+
92+
7693@router .post (path = "/list" , description = "Provided for convenience, same as `GET` endpoint." )
7794@router .get (path = "/list" )
78- async def list_datasets ( # noqa: PLR0913
95+ async def list_datasets ( # noqa: PLR0913, C901
7996 pagination : Annotated [Pagination , Body (default_factory = Pagination )],
8097 data_name : Annotated [str | None , CasualString128 ] = None ,
8198 tag : Annotated [str | None , SystemString64 ] = None ,
@@ -103,7 +120,7 @@ async def list_datasets( # noqa: PLR0913
103120 expdb_db : Annotated [AsyncConnection , Depends (expdb_connection )] = None ,
104121) -> list [dict [str , Any ]]:
105122 assert expdb_db is not None # noqa: S101
106- current_status = text (
123+ status_subquery = text (
107124 """
108125 SELECT ds1.`did`, ds1.`status`
109126 FROM dataset_status as ds1
@@ -115,90 +132,78 @@ async def list_datasets( # noqa: PLR0913
115132 """ ,
116133 )
117134
118- if status == DatasetStatusFilter . ALL :
119- statuses = [
120- DatasetStatusFilter . ACTIVE ,
121- DatasetStatusFilter . DEACTIVATED ,
122- DatasetStatusFilter . IN_PREPARATION ,
123- ]
124- else :
125- statuses = [ status ]
135+ clauses = []
136+ parameters : dict [ str , Any ] = {
137+ "offset" : pagination . offset ,
138+ "limit" : pagination . limit ,
139+ }
140+ if status != DatasetStatusFilter . ALL :
141+ clauses . append ( "AND IFNULL(cs.`status`, 'in_preparation') = :status" )
142+ parameters [ "status" ] = status
126143
127- where_status = "," .join (f"'{ status } '" for status in statuses )
128144 if user is None :
129- visible_to_user = "`visibility`='public'"
130- elif UserGroup .ADMIN in await user .get_groups ():
131- visible_to_user = "TRUE"
132- else :
133- visible_to_user = f"(`visibility`='public' OR `uploader`={ user .user_id } )"
145+ clauses .append ("AND `visibility`='public'" )
146+ elif UserGroup .ADMIN not in await user .get_groups ():
147+ clauses .append ("AND (`visibility`='public' OR `uploader`=:user_id)" )
148+ parameters ["user_id" ] = user .user_id
149+
150+ if uploader :
151+ clauses .append ("AND `uploader`=:uploader" )
152+ parameters ["uploader" ] = uploader
153+
154+ if data_name :
155+ clauses .append ("AND `name`=:data_name" )
156+ parameters ["data_name" ] = data_name
157+
158+ if data_version :
159+ clauses .append ("AND `version`=:data_version" )
160+ parameters ["data_version" ] = data_version
134161
135- where_name = "" if data_name is None else "AND `name`=:data_name"
136- where_version = "" if data_version is None else "AND `version`=:data_version"
137- where_uploader = "" if uploader is None else "AND `uploader`=:uploader"
138- data_id_str = "," .join (str (did ) for did in data_id ) if data_id else ""
139- where_data_id = "" if not data_id else f"AND d.`did` IN ({ data_id_str } )"
162+ if data_id :
163+ clauses .append ("AND d.`did` IN :data_ids" )
164+ parameters ["data_ids" ] = data_id
140165
141166 # requires some benchmarking on whether e.g., IN () is more efficient.
142- matching_tag = (
143- text (
167+ if tag :
168+ clauses . append (
144169 """
145- AND d.`did` IN (
146- SELECT `id`
147- FROM dataset_tag as dt
148- WHERE dt.`tag`=:tag
149- )
150- """ ,
170+ AND d.`did` IN (
171+ SELECT `id`
172+ FROM dataset_tag as dt
173+ WHERE dt.`tag`=:tag
174+ )
175+ """ ,
151176 )
152- if tag
153- else ""
154- )
177+ parameters ["tag" ] = tag
155178
156- def quality_clause (quality : str , range_ : str | None ) -> str :
157- if not range_ :
158- return ""
159- if not (match := re .match (integer_range_regex , range_ )):
160- msg = f"`range_` not a valid range: { range_ } "
161- raise ValueError (msg )
162- start , end = match .groups ()
163- value = f"`value` BETWEEN { start } AND { end [2 :]} " if end else f"`value`={ start } "
164- return f""" AND
165- d.`did` IN (
166- SELECT `data`
167- FROM data_quality
168- WHERE `quality`='{ quality } ' AND { value }
169- )
170- """ # noqa: S608 - `quality` is not user provided, value is filtered with regex
179+ number_instances_filter = _quality_clause ("NumberOfInstances" , number_instances )
180+ number_classes_filter = _quality_clause ("NumberOfClasses" , number_classes )
181+ number_features_filter = _quality_clause ("NumberOfFeatures" , number_features )
182+ number_missing_values_filter = _quality_clause ("NumberOfMissingValues" , number_missing_values )
171183
172- number_instances_filter = quality_clause ("NumberOfInstances" , number_instances )
173- number_classes_filter = quality_clause ("NumberOfClasses" , number_classes )
174- number_features_filter = quality_clause ("NumberOfFeatures" , number_features )
175- number_missing_values_filter = quality_clause ("NumberOfMissingValues" , number_missing_values )
184+ columns = ["did" , "name" , "version" , "format" , "file_id" , "status" ]
176185 matching_filter = text (
177186 f"""
178187 SELECT d.`did`,d.`name`,d.`version`,d.`format`,d.`file_id`,
179188 IFNULL(cs.`status`, 'in_preparation')
180189 FROM dataset AS d
181- LEFT JOIN ({ current_status } ) AS cs ON d.`did`=cs.`did`
182- WHERE { visible_to_user } { where_name } { where_version } { where_uploader }
183- { where_data_id } { matching_tag } { number_instances_filter } { number_features_filter }
190+ LEFT JOIN ({ status_subquery } ) AS cs ON d.`did`=cs.`did`
191+ WHERE 1=1 { number_instances_filter } { number_features_filter }
184192 { number_classes_filter } { number_missing_values_filter }
185- AND IFNULL(cs.`status`, 'in_preparation') IN ( { where_status } )
186- LIMIT { pagination . limit } OFFSET { pagination . offset }
193+ { " " . join ( clauses ) }
194+ LIMIT : limit OFFSET : offset
187195 """ , # noqa: S608
188196 # I am not sure how to do this correctly without an error from Bandit here.
189197 # However, the `status` input is already checked by FastAPI to be from a set
190198 # of given options, so no injection is possible (I think). The `current_status`
191199 # subquery also has no user input. So I think this should be safe.
192200 )
193- columns = ["did" , "name" , "version" , "format" , "file_id" , "status" ]
201+
202+ if data_id :
203+ matching_filter .bindparams (bindparam ("data_ids" , expanding = True ))
194204 result = await expdb_db .execute (
195205 matching_filter ,
196- parameters = {
197- "tag" : tag ,
198- "data_name" : data_name ,
199- "data_version" : data_version ,
200- "uploader" : uploader ,
201- },
206+ parameters = parameters ,
202207 )
203208 rows = result .all ()
204209 datasets : dict [int , dict [str , Any ]] = {
0 commit comments