Skip to content

Commit a74b056

Browse files
authored
Refactor dataset list (#286)
Refactors the dataset-list function to make it easier to read, and adds checks that input parameters are provided correctly.
1 parent 5a9eb8c commit a74b056

1 file changed

Lines changed: 69 additions & 64 deletions

File tree

src/routers/openml/datasets.py

Lines changed: 69 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing import Annotated, Any, Literal, NamedTuple
55

66
from fastapi import APIRouter, Body, Depends
7-
from sqlalchemy import text
7+
from sqlalchemy import bindparam, text
88
from sqlalchemy.engine import Row
99
from sqlalchemy.ext.asyncio import AsyncConnection
1010

@@ -73,9 +73,26 @@ class DatasetStatusFilter(StrEnum):
7373
ALL = "all"
7474

7575

76+
def _quality_clause(quality: str, range_: str | None) -> str:
77+
if not range_:
78+
return ""
79+
if not (match := re.match(integer_range_regex, range_)):
80+
msg = f"`range_` not a valid range: {range_}"
81+
raise ValueError(msg)
82+
start, end = match.groups()
83+
value = f"`value` BETWEEN {start} AND {end[2:]}" if end else f"`value`={start}"
84+
return f""" AND
85+
d.`did` IN (
86+
SELECT `data`
87+
FROM data_quality
88+
WHERE `quality`='{quality}' AND {value}
89+
)
90+
""" # noqa: S608 - `quality` is not user provided, value is filtered with regex
91+
92+
7693
@router.post(path="/list", description="Provided for convenience, same as `GET` endpoint.")
7794
@router.get(path="/list")
78-
async def list_datasets( # noqa: PLR0913
95+
async def list_datasets( # noqa: PLR0913, C901
7996
pagination: Annotated[Pagination, Body(default_factory=Pagination)],
8097
data_name: Annotated[str | None, CasualString128] = None,
8198
tag: Annotated[str | None, SystemString64] = None,
@@ -103,7 +120,7 @@ async def list_datasets( # noqa: PLR0913
103120
expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)] = None,
104121
) -> list[dict[str, Any]]:
105122
assert expdb_db is not None # noqa: S101
106-
current_status = text(
123+
status_subquery = text(
107124
"""
108125
SELECT ds1.`did`, ds1.`status`
109126
FROM dataset_status as ds1
@@ -115,90 +132,78 @@ async def list_datasets( # noqa: PLR0913
115132
""",
116133
)
117134

118-
if status == DatasetStatusFilter.ALL:
119-
statuses = [
120-
DatasetStatusFilter.ACTIVE,
121-
DatasetStatusFilter.DEACTIVATED,
122-
DatasetStatusFilter.IN_PREPARATION,
123-
]
124-
else:
125-
statuses = [status]
135+
clauses = []
136+
parameters: dict[str, Any] = {
137+
"offset": pagination.offset,
138+
"limit": pagination.limit,
139+
}
140+
if status != DatasetStatusFilter.ALL:
141+
clauses.append("AND IFNULL(cs.`status`, 'in_preparation') = :status")
142+
parameters["status"] = status
126143

127-
where_status = ",".join(f"'{status}'" for status in statuses)
128144
if user is None:
129-
visible_to_user = "`visibility`='public'"
130-
elif UserGroup.ADMIN in await user.get_groups():
131-
visible_to_user = "TRUE"
132-
else:
133-
visible_to_user = f"(`visibility`='public' OR `uploader`={user.user_id})"
145+
clauses.append("AND `visibility`='public'")
146+
elif UserGroup.ADMIN not in await user.get_groups():
147+
clauses.append("AND (`visibility`='public' OR `uploader`=:user_id)")
148+
parameters["user_id"] = user.user_id
149+
150+
if uploader:
151+
clauses.append("AND `uploader`=:uploader")
152+
parameters["uploader"] = uploader
153+
154+
if data_name:
155+
clauses.append("AND `name`=:data_name")
156+
parameters["data_name"] = data_name
157+
158+
if data_version:
159+
clauses.append("AND `version`=:data_version")
160+
parameters["data_version"] = data_version
134161

135-
where_name = "" if data_name is None else "AND `name`=:data_name"
136-
where_version = "" if data_version is None else "AND `version`=:data_version"
137-
where_uploader = "" if uploader is None else "AND `uploader`=:uploader"
138-
data_id_str = ",".join(str(did) for did in data_id) if data_id else ""
139-
where_data_id = "" if not data_id else f"AND d.`did` IN ({data_id_str})"
162+
if data_id:
163+
clauses.append("AND d.`did` IN :data_ids")
164+
parameters["data_ids"] = data_id
140165

141166
# requires some benchmarking on whether e.g., IN () is more efficient.
142-
matching_tag = (
143-
text(
167+
if tag:
168+
clauses.append(
144169
"""
145-
AND d.`did` IN (
146-
SELECT `id`
147-
FROM dataset_tag as dt
148-
WHERE dt.`tag`=:tag
149-
)
150-
""",
170+
AND d.`did` IN (
171+
SELECT `id`
172+
FROM dataset_tag as dt
173+
WHERE dt.`tag`=:tag
174+
)
175+
""",
151176
)
152-
if tag
153-
else ""
154-
)
177+
parameters["tag"] = tag
155178

156-
def quality_clause(quality: str, range_: str | None) -> str:
157-
if not range_:
158-
return ""
159-
if not (match := re.match(integer_range_regex, range_)):
160-
msg = f"`range_` not a valid range: {range_}"
161-
raise ValueError(msg)
162-
start, end = match.groups()
163-
value = f"`value` BETWEEN {start} AND {end[2:]}" if end else f"`value`={start}"
164-
return f""" AND
165-
d.`did` IN (
166-
SELECT `data`
167-
FROM data_quality
168-
WHERE `quality`='{quality}' AND {value}
169-
)
170-
""" # noqa: S608 - `quality` is not user provided, value is filtered with regex
179+
number_instances_filter = _quality_clause("NumberOfInstances", number_instances)
180+
number_classes_filter = _quality_clause("NumberOfClasses", number_classes)
181+
number_features_filter = _quality_clause("NumberOfFeatures", number_features)
182+
number_missing_values_filter = _quality_clause("NumberOfMissingValues", number_missing_values)
171183

172-
number_instances_filter = quality_clause("NumberOfInstances", number_instances)
173-
number_classes_filter = quality_clause("NumberOfClasses", number_classes)
174-
number_features_filter = quality_clause("NumberOfFeatures", number_features)
175-
number_missing_values_filter = quality_clause("NumberOfMissingValues", number_missing_values)
184+
columns = ["did", "name", "version", "format", "file_id", "status"]
176185
matching_filter = text(
177186
f"""
178187
SELECT d.`did`,d.`name`,d.`version`,d.`format`,d.`file_id`,
179188
IFNULL(cs.`status`, 'in_preparation')
180189
FROM dataset AS d
181-
LEFT JOIN ({current_status}) AS cs ON d.`did`=cs.`did`
182-
WHERE {visible_to_user} {where_name} {where_version} {where_uploader}
183-
{where_data_id} {matching_tag} {number_instances_filter} {number_features_filter}
190+
LEFT JOIN ({status_subquery}) AS cs ON d.`did`=cs.`did`
191+
WHERE 1=1 {number_instances_filter} {number_features_filter}
184192
{number_classes_filter} {number_missing_values_filter}
185-
AND IFNULL(cs.`status`, 'in_preparation') IN ({where_status})
186-
LIMIT {pagination.limit} OFFSET {pagination.offset}
193+
{" ".join(clauses)}
194+
LIMIT :limit OFFSET :offset
187195
""", # noqa: S608
188196
# I am not sure how to do this correctly without an error from Bandit here.
189197
# However, the `status` input is already checked by FastAPI to be from a set
190198
# of given options, so no injection is possible (I think). The `current_status`
191199
# subquery also has no user input. So I think this should be safe.
192200
)
193-
columns = ["did", "name", "version", "format", "file_id", "status"]
201+
202+
if data_id:
203+
matching_filter.bindparams(bindparam("data_ids", expanding=True))
194204
result = await expdb_db.execute(
195205
matching_filter,
196-
parameters={
197-
"tag": tag,
198-
"data_name": data_name,
199-
"data_version": data_version,
200-
"uploader": uploader,
201-
},
206+
parameters=parameters,
202207
)
203208
rows = result.all()
204209
datasets: dict[int, dict[str, Any]] = {

0 commit comments

Comments
 (0)