Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 44 additions & 5 deletions api/src/feeds/impl/search_api_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def get_parsed_search_tsquery(search_query: str) -> str:

@staticmethod
def add_search_query_filters(
query, search_query, data_type, feed_id, status, is_official, features, version
query, search_query, data_type, feed_id, status, is_official, features, version, license_ids, license_is_spdx
) -> Query:
"""
Add filters to the search query.
Expand Down Expand Up @@ -68,6 +68,19 @@ def add_search_query_filters(
query = query.filter(
t_feedsearch.c.document.op("@@")(SearchApiImpl.get_parsed_search_tsquery(search_query))
)
if license_ids:
license_ids_list = [lid.strip() for lid in license_ids.split(",") if len(lid.strip()) > 0]
if len(license_ids_list) > 0:
query = query.where(t_feedsearch.c.license_id.in_(license_ids_list))

if license_is_spdx is not None:
if license_is_spdx:
query = query.where(t_feedsearch.c.license_is_spdx.is_(True))
else:
query = query.where(
or_(t_feedsearch.c.license_is_spdx.is_(False), t_feedsearch.c.license_is_spdx.is_(None))
)

# Add feature filter with OR logic
if features:
features_list = [s.strip() for s in features[0].split(",") if s]
Expand All @@ -86,13 +99,24 @@ def create_count_search_query(
features,
version: str,
search_query: str,
license_ids: str,
license_is_spdx: bool,
) -> Query:
"""
Create a search query for the database.
"""
query = select(func.count(t_feedsearch.c.feed_id))
return SearchApiImpl.add_search_query_filters(
query, search_query, data_type, feed_id, status, is_official, features, version
query,
search_query,
data_type,
feed_id,
status,
is_official,
features,
version,
license_ids,
license_is_spdx,
)

@staticmethod
Expand All @@ -104,6 +128,8 @@ def create_search_query(
search_query: str,
features: List[str],
version: str,
license_ids: str,
license_is_spdx: bool,
) -> Query:
"""
Create a search query for the database.
Expand All @@ -117,7 +143,16 @@ def create_search_query(
*feed_search_columns,
)
query = SearchApiImpl.add_search_query_filters(
query, search_query, data_type, feed_id, status, is_official, features, version
query,
search_query,
data_type,
feed_id,
status,
is_official,
features,
version,
license_ids,
license_is_spdx,
)
# If search query is provided, use it as secondary sort after timestamp
if search_query and len(search_query.strip()) > 0:
Expand All @@ -140,10 +175,14 @@ def search_feeds(
version: str,
search_query: str,
feature: List[str],
license_ids: str,
license_is_spdx: bool,
db_session: "Session",
) -> SearchFeeds200Response:
"""Search feeds using full-text search on feed, location and provider's information."""
query = self.create_search_query(status, feed_id, data_type, is_official, search_query, feature, version)
query = self.create_search_query(
status, feed_id, data_type, is_official, search_query, feature, version, license_ids, license_is_spdx
)
feed_rows = Database().select(
session=db_session,
query=query,
Expand All @@ -153,7 +192,7 @@ def search_feeds(
feed_total_count = Database().select(
session=db_session,
query=self.create_count_search_query(
status, feed_id, data_type, is_official, feature, version, search_query
status, feed_id, data_type, is_official, feature, version, search_query, license_ids, license_is_spdx
),
)
if feed_rows is None or feed_total_count is None:
Expand Down
24 changes: 24 additions & 0 deletions api/src/scripts/populate_db_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
Gbfsendpoint,
Gbfsfeed,
Rule,
Feed,
)
from scripts.populate_db import set_up_configs, DatabasePopulateHelper
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -83,6 +84,29 @@ def populate_test_datasets(self, filepath, db_session: "Session"):
db_session.add(license_obj)
db_session.commit()

# Link licenses to feeds if specified
if "feed_licenses" in data:
for lf in data["feed_licenses"]:
license_id = lf.get("license_id")
feed_stable_id = lf.get("feed_stable_id")
if not license_id or not feed_stable_id:
continue
license_obj = db_session.get(License, license_id)
if not license_obj:
self.logger.error(
f"No license found with id: {license_id}; skipping license_feed for feed " f"{feed_stable_id}"
)
continue
feed_obj = db_session.query(Feed).filter(Feed.stable_id == feed_stable_id).one_or_none()
if not feed_obj:
self.logger.error(
f"No feed found with stable_id: {feed_stable_id}; skipping license_feed for"
f" license {license_id}"
)
continue
feed_obj.license = license_obj
db_session.commit()

# Rules (optional section to seed rule metadata used by license_rules)
if "rules" in data:
for rule in data["rules"]:
Expand Down
6 changes: 6 additions & 0 deletions api/src/shared/db_models/search_feed_item_result_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ def from_orm_gtfs(cls, feed_search_row: t_feedsearch):
authentication_info_url=feed_search_row.authentication_info_url,
api_key_parameter_name=feed_search_row.api_key_parameter_name,
license_url=feed_search_row.license_url,
license_id=feed_search_row.license_id,
license_is_spdx=feed_search_row.license_is_spdx,
),
redirects=feed_search_row.redirect_ids,
locations=cls.resolve_locations(feed_search_row.locations),
Expand Down Expand Up @@ -91,6 +93,8 @@ def from_orm_gbfs(cls, feed_search_row):
authentication_info_url=feed_search_row.authentication_info_url,
api_key_parameter_name=feed_search_row.api_key_parameter_name,
license_url=feed_search_row.license_url,
license_id=feed_search_row.license_id,
license_is_spdx=feed_search_row.license_is_spdx,
),
redirects=feed_search_row.redirect_ids,
locations=cls.resolve_locations(feed_search_row.locations),
Expand Down Expand Up @@ -118,6 +122,8 @@ def from_orm_gtfs_rt(cls, feed_search_row):
authentication_info_url=feed_search_row.authentication_info_url,
api_key_parameter_name=feed_search_row.api_key_parameter_name,
license_url=feed_search_row.license_url,
license_id=feed_search_row.license_id,
license_is_spdx=feed_search_row.license_is_spdx,
),
redirects=feed_search_row.redirect_ids,
locations=cls.resolve_locations(feed_search_row.locations),
Expand Down
43 changes: 43 additions & 0 deletions api/tests/integration/test_search_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,49 @@ def test_search_filter_by_versions(client: TestClient, values: dict):
), f"There should be {expected_count} feeds for versions={values['versions']}"


@pytest.mark.parametrize(
    "values",
    [
        {"license_ids": "CC BY 4.0", "expected_count": 1},
        {"license_ids": "ODbL-1.0", "expected_count": 1},
        {"license_ids": "ODbL-1.0,CC BY 4.0", "expected_count": 2},
        {"license_ids": "", "expected_count": 16},
    ],
    ids=[
        "License ID CC BY 4.0",
        "License ID ODbL-1.0",
        "License IDs ODbL-1.0 and CC BY 4.0",
        "No license IDs specified",
    ],
)
def test_search_filter_by_license_ids(client: TestClient, values: dict):
    """
    Check that GET /v1/search honours the ``license_ids`` query parameter.

    Each case supplies a comma-separated ``license_ids`` value and the total
    number of feeds the search is expected to report. An empty string is
    expected to apply no license filtering (16 feeds in the test data set).
    """
    license_ids = values["license_ids"]
    # The query parameter is only sent when the case provides a value.
    params = [("license_ids", license_ids)] if license_ids is not None else None

    response = client.request(
        "GET",
        "/v1/search",
        headers={"Authentication": "special-key"},
        params=params,
    )

    # The request itself must succeed before the payload is inspected.
    assert response.status_code == 200

    # Deserialize the JSON payload into the typed response model.
    response_body = SearchFeeds200Response.parse_obj(response.json())
    expected_count = values["expected_count"]
    assert (
        response_body.total == expected_count
    ), f"There should be {expected_count} feeds for license_ids={license_ids}"

Comment thread
cka-y marked this conversation as resolved.

@pytest.mark.parametrize(
"values",
[
Expand Down
24 changes: 24 additions & 0 deletions api/tests/test_data/extra_test_data.json
Original file line number Diff line number Diff line change
Expand Up @@ -825,5 +825,29 @@
}
]
}
],
"licenses": [
{
"id": "CC BY 4.0",
"is_spdx": true,
"name": "Creative Commons Attribution 4.0 International",
"url": "https://creativecommons.org/licenses/by/4.0/"
},
{
"id": "ODbL-1.0",
"is_spdx": true,
"name": "Open Data Commons Open Database License (ODbL) v1.0",
"url": "https://opendatacommons.org/licenses/odbl/1.0/"
}
],
"feed_licenses": [
{
"feed_stable_id": "mdb-1",
"license_id": "CC BY 4.0"
},
{
"feed_stable_id": "gbfs-system_id_1",
"license_id": "ODbL-1.0"
}
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ def __init__(self, **kwargs):
country_translations=[],
subdivision_name_translations=[],
municipality_translations=[],
license_id=None,
license_is_spdx=None,
)


Expand Down
17 changes: 17 additions & 0 deletions docs/DatabaseCatalogAPI.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,8 @@ paths:
- $ref: "#/components/parameters/version_query_param"
- $ref: "#/components/parameters/search_text_query_param"
- $ref: "#/components/parameters/feature"
- $ref: "#/components/parameters/license_ids"
- $ref: "#/components/parameters/license_is_spdx"
security:
- Authentication: []
responses:
Expand Down Expand Up @@ -1485,6 +1487,21 @@ components:
type: array
items:
type: string
license_ids:
name: license_ids
in: query
description: Comma-separated list of license IDs to filter feeds by their license.
required: false
schema:
type: string
example: CC-BY-4.0,ODbL-1.0
Comment thread
cka-y marked this conversation as resolved.
license_is_spdx:
name: license_is_spdx
in: query
description: Filter feeds by whether their license is an SPDX license. When false, feeds with no license are also returned.
required: false
schema:
type: boolean
provider:
name: provider
in: query
Expand Down
27 changes: 25 additions & 2 deletions liquibase/materialized_views/feed_search.sql
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ SELECT
Feed.feed_name,
Feed.note,
Feed.feed_contact_email,

-- source
Feed.producer_url,
Feed.authentication_info_url,
Expand All @@ -16,10 +17,18 @@ SELECT
Feed.license_url,
Feed.provider,
Feed.operational_status,

-- official status
Feed.official AS official,

-- created_at
Feed.created_at AS created_at,

-- license fields
Feed.license_id AS license_id,
License.is_spdx AS license_is_spdx,
License.name AS license_name,

-- latest_dataset
Latest_dataset.stable_id AS latest_dataset_id,
Latest_dataset.hosted_url AS latest_dataset_hosted_url,
Expand All @@ -29,28 +38,37 @@ SELECT
Latest_dataset.agency_timezone AS latest_dataset_agency_timezone,
Latest_dataset.service_date_range_start AS latest_dataset_service_date_range_start,
Latest_dataset.service_date_range_end AS latest_dataset_service_date_range_end,

-- Latest dataset features
LatestDatasetFeatures AS latest_dataset_features,

-- Latest dataset validation totals
COALESCE(LatestDatasetValidationReportJoin.total_error, 0) as latest_total_error,
COALESCE(LatestDatasetValidationReportJoin.total_warning, 0) as latest_total_warning,
COALESCE(LatestDatasetValidationReportJoin.total_info, 0) as latest_total_info,
COALESCE(LatestDatasetValidationReportJoin.unique_error_count, 0) as latest_unique_error_count,
COALESCE(LatestDatasetValidationReportJoin.unique_warning_count, 0) as latest_unique_warning_count,
COALESCE(LatestDatasetValidationReportJoin.unique_info_count, 0) as latest_unique_info_count,

-- external_ids
ExternalIdJoin.external_ids,

-- redirect_ids
RedirectingIdJoin.redirect_ids,

-- feed gtfs_rt references
FeedReferenceJoin.feed_reference_ids,

-- feed gtfs_rt entities
EntityTypeFeedJoin.entities,

-- locations
FeedLocationJoin.locations,

-- osm locations grouped
OsmLocationJoin.osm_locations,
-- gbfs versions

-- gbfs versions
COALESCE(GbfsVersionsJoin.versions, '[]'::jsonb) AS versions,

-- full-text searchable document
Expand All @@ -70,6 +88,9 @@ SELECT
AS document
FROM Feed

-- license join
LEFT JOIN License ON License.id = Feed.license_id

-- Latest dataset
LEFT JOIN gtfsfeed gtf ON gtf.id = Feed.id AND Feed.data_type = 'gtfs'
LEFT JOIN gtfsdataset Latest_dataset ON Latest_dataset.id = gtf.latest_dataset_id
Expand Down Expand Up @@ -149,7 +170,6 @@ LEFT JOIN (
GROUP BY gtfs_rt_feed_id
) AS FeedReferenceJoin ON FeedReferenceJoin.gtfs_rt_feed_id = Feed.id AND Feed.data_type = 'gtfs_rt'

-- Redirect ids
-- Redirect ids
LEFT JOIN (
SELECT
Expand All @@ -159,6 +179,7 @@ LEFT JOIN (
JOIN Feed f ON r.target_id = f.id
GROUP BY r.target_id
) AS RedirectingIdJoin ON RedirectingIdJoin.target_id = Feed.id

-- Feed locations
LEFT JOIN (
SELECT
Expand Down Expand Up @@ -247,4 +268,6 @@ CREATE INDEX feedsearch_document_idx ON FeedSearch USING GIN(document);
CREATE INDEX feedsearch_feed_stable_id ON FeedSearch(feed_stable_id);
CREATE INDEX feedsearch_data_type ON FeedSearch(data_type);
CREATE INDEX feedsearch_status ON FeedSearch(status);
CREATE INDEX feedsearch_license_id ON FeedSearch(license_id);
CREATE INDEX feedsearch_license_is_spdx ON FeedSearch(license_is_spdx);

Loading
Loading