Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
59b40f0
REFACTOR Pull apart get_closest_embeddings to make testing easier
sahilds1 Feb 13, 2026
3ffb74a
ADD Add infra required to run pytest
sahilds1 Feb 13, 2026
12b09a7
ADD Start adding tests for embedding_services
sahilds1 Feb 13, 2026
da9afaa
DOC Add a note about running pytest in the README
sahilds1 Feb 17, 2026
5ce7782
Preload SentenceTransformer model at Django startup before traffic is…
sahilds1 Feb 27, 2026
50a8bd3
Merge branch 'develop' into 441-embedding-models
sahilds1 Mar 11, 2026
795f218
Run python-app workflow on pushes and PRs to develop branch
sahilds1 Mar 11, 2026
d498a00
Pytest won’t automatically discover config files in subdirectories
sahilds1 Mar 19, 2026
6d3d8d1
Merge branch 'develop' into 441-embedding-models
sahilds1 Mar 19, 2026
3824d81
Suppress E402 import violations
sahilds1 Mar 19, 2026
46e9969
Add build_query tests and document coverage gaps in embedding_services
sahilds1 Mar 20, 2026
64a19ef
Fill test gaps in test_embedding_services
sahilds1 Mar 20, 2026
dec3c12
Fix incorrect build_query test assertions
sahilds1 Mar 20, 2026
f9e890a
Guard TransformerModel preload to runserver processes only
sahilds1 Mar 23, 2026
67176a8
Revert GitHub Workflow changes
sahilds1 Mar 25, 2026
d273921
Add section header comments to all four test groups in test_embedding…
sahilds1 Mar 26, 2026
8198574
Document why tests are split by responsibility
sahilds1 Mar 26, 2026
5d8c8b3
Improve logging and comments
sahilds1 Mar 31, 2026
31498dc
Fall back to lazy load using try except block
sahilds1 Mar 31, 2026
a39d33c
Revert settings.py to develop state
sahilds1 Mar 31, 2026
fe1eeca
Manually test fall back to lazy loading
sahilds1 Mar 31, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ df = pd.read_sql(query, engine)

#### Django REST
- The email and password are set in `server/api/management/commands/createsu.py`
- Backend tests can be run using `pytest` by running the below command inside the running backend container:

```
docker compose exec backend pytest api/ -v
```

## API Documentation

Expand Down
35 changes: 35 additions & 0 deletions server/api/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,38 @@
class ApiConfig(AppConfig):
    default_auto_field = 'django.db.models.BigAutoField'
    name = 'api'

    def ready(self):
        """Warm up the sentence-transformer model in request-serving processes.

        Django calls ready() once per process for every management command
        (migrate, test, shell, runserver, ...), so the guards below restrict
        the preload to processes that will actually answer HTTP requests.
        Any failure is logged and swallowed so startup is never blocked.
        """
        try:
            import os
            import sys

            # Dev (docker-compose.yml) and prod (Dockerfile.prod CMD) both
            # start the server via `manage.py runserver ...`; every other
            # command (migrate, createsu, populatedb, ...) skips the preload.
            if sys.argv[1:2] != ['runserver']:
                return

            # The dev autoreloader forks a parent file-watcher and a child
            # server; only the child (RUN_MAIN=true) serves requests, so skip
            # the parent to avoid loading the model twice per code change.
            # Prod runs with --noreload, where RUN_MAIN is never set.
            will_serve_requests = (
                os.environ.get('RUN_MAIN') == 'true'
                or '--noreload' in sys.argv
            )
            if not will_serve_requests:
                return

            # First use downloads paraphrase-MiniLM-L6-v2 (~80MB) from
            # HuggingFace into ~/.cache/torch/sentence_transformers/ inside
            # the container; that cache is lost on rebuild unless a volume is
            # mounted at that path.
            from .services.sentencetTransformer_model import TransformerModel
            TransformerModel.get_instance()
        except Exception:
            # TransformerModel._instance stays None on failure, so the first
            # request that calls get_instance() retries the load lazily.
            import logging
            logging.getLogger(__name__).exception(
                "Failed to preload the embedding model at startup"
            )
166 changes: 115 additions & 51 deletions server/api/services/embedding_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
from statistics import median

# Use Q objects to express OR conditions in Django queries
from django.db.models import Q
from pgvector.django import L2Distance

Expand All @@ -11,18 +12,17 @@

logger = logging.getLogger(__name__)

def get_closest_embeddings(
user, message_data, document_name=None, guid=None, num_results=10
):

def build_query(user, embedding_vector, document_name=None, guid=None, num_results=10):
"""
Find the closest embeddings to a given message for a specific user.
Build an unevaluated QuerySet for the closest embeddings.

Parameters
----------
user : User
The user whose uploaded documents will be searched
message_data : str
The input message to find similar embeddings for
embedding_vector : array-like
Pre-computed embedding vector to compare against
document_name : str, optional
Filter results to a specific document name
guid : str, optional
Expand All @@ -32,59 +32,52 @@ def get_closest_embeddings(

Returns
-------
list[dict]
List of dictionaries containing embedding results with keys:
- name: document name
- text: embedded text content
- page_number: page number in source document
- chunk_number: chunk number within the document
- distance: L2 distance from query embedding
- file_id: GUID of the source file
QuerySet
Unevaluated Django QuerySet ordered by L2 distance, sliced to num_results
"""

encoding_start = time.time()
transformerModel = TransformerModel.get_instance().model
embedding_message = transformerModel.encode(message_data)
encoding_time = time.time() - encoding_start

db_query_start = time.time()

# Django QuerySets are lazily evaluated
if user.is_authenticated:
# User sees their own files + files uploaded by superusers
closest_embeddings_query = (
Embeddings.objects.filter(
Q(upload_file__uploaded_by=user) | Q(upload_file__uploaded_by__is_superuser=True)
)
.annotate(
distance=L2Distance("embedding_sentence_transformers", embedding_message)
)
.order_by("distance")
queryset = Embeddings.objects.filter(
Q(upload_file__uploaded_by=user) | Q(upload_file__uploaded_by__is_superuser=True)
)
else:
# Unauthenticated users only see superuser-uploaded files
closest_embeddings_query = (
Embeddings.objects.filter(upload_file__uploaded_by__is_superuser=True)
.annotate(
distance=L2Distance("embedding_sentence_transformers", embedding_message)
)
.order_by("distance")
)
queryset = Embeddings.objects.filter(upload_file__uploaded_by__is_superuser=True)

queryset = (
queryset
.annotate(distance=L2Distance("embedding_sentence_transformers", embedding_vector))
.order_by("distance")
)

# Filtering to a document GUID takes precedence over a document name
if guid:
closest_embeddings_query = closest_embeddings_query.filter(
upload_file__guid=guid
)
queryset = queryset.filter(upload_file__guid=guid)
elif document_name:
closest_embeddings_query = closest_embeddings_query.filter(name=document_name)
queryset = queryset.filter(name=document_name)

# Slicing is equivalent to SQL's LIMIT clause
closest_embeddings_query = closest_embeddings_query[:num_results]
return queryset[:num_results]
Comment on lines 16 to +61
Copy link

Copilot AI Mar 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

build_query() introduces/relocates important filtering + precedence logic (authenticated vs unauthenticated visibility; guid-over-document_name; LIMIT slicing), but the new tests only cover evaluate_query and log_usage. Add unit/integration tests covering build_query behavior (e.g., guid precedence and the authenticated/unauthenticated queryset filters) to prevent regressions in access control and filtering.

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Building on Copilot's comment, the specifics of the QuerySet object's structure aren't publicly documented. To inspect the QuerySets, we should actually execute them.

There are a couple of ways we could handle DB access for these tests. We could use pytest-django's [`@pytest.mark.django_db`](https://pytest-django.readthedocs.io/en/latest/database.html), which wraps the test in a transaction that rolls back automatically afterwards. Django also has a built-in `django.test.TestCase`, which does a similar thing.

Copy link
Copy Markdown
Collaborator Author

@sahilds1 sahilds1 Mar 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for sharing the docs references -- I added tests for build_query and didn't have to access the database because I was able to inspect which methods and arguments were called on the model ("Embeddings")



def evaluate_query(queryset):
"""
Evaluate a QuerySet and return a list of result dicts.

Parameters
----------
queryset : iterable
Iterable of Embeddings objects (or any objects with the expected attributes)

Returns
-------
list[dict]
List of dicts with keys: name, text, page_number, chunk_number, distance, file_id
"""
# Iterating evaluates the QuerySet and hits the database
# TODO: Research improving the query evaluation performance
results = [
return [
{
"name": obj.name,
"text": obj.text,
Expand All @@ -93,13 +86,36 @@ def get_closest_embeddings(
"distance": obj.distance,
"file_id": obj.upload_file.guid if obj.upload_file else None,
}
for obj in closest_embeddings_query
for obj in queryset
]

db_query_time = time.time() - db_query_start

def log_usage(
results, message_data, user, guid, document_name, num_results, encoding_time, db_query_time
):
"""
Create a SemanticSearchUsage record. Swallows exceptions so search isn't interrupted.

Parameters
----------
results : list[dict]
The search results, each containing a "distance" key
message_data : str
The original search query text
user : User
The user who performed the search
guid : str or None
Document GUID filter used in the search
document_name : str or None
Document name filter used in the search
num_results : int
Number of results requested
encoding_time : float
Time in seconds to encode the query
db_query_time : float
Time in seconds for the database query
"""
try:
# Handle user having no uploaded docs or doc filtering returning no matches
if results:
distances = [r["distance"] for r in results]
SemanticSearchUsage.objects.create(
Expand All @@ -113,11 +129,10 @@ def get_closest_embeddings(
num_results_returned=len(results),
max_distance=max(distances),
median_distance=median(distances),
min_distance=min(distances)
min_distance=min(distances),
)
else:
logger.warning("Semantic search returned no results")

SemanticSearchUsage.objects.create(
query_text=message_data,
user=user if (user and user.is_authenticated) else None,
Expand All @@ -129,9 +144,58 @@ def get_closest_embeddings(
num_results_returned=0,
max_distance=None,
median_distance=None,
min_distance=None
min_distance=None,
)
except Exception as e:
logger.error(f"Failed to create semantic search usage database record: {e}")
except Exception:
logger.exception("Failed to create semantic search usage database record")


def get_closest_embeddings(
    user, message_data, document_name=None, guid=None, num_results=10
):
    """
    Run a semantic search: embed ``message_data`` and return the stored
    embeddings nearest to it that are visible to ``user``.

    Parameters
    ----------
    user : User
        The user whose uploaded documents will be searched
    message_data : str
        The input message to find similar embeddings for
    document_name : str, optional
        Filter results to a specific document name
    guid : str, optional
        Filter results to a specific document GUID (takes precedence over document_name)
    num_results : int, default 10
        Maximum number of results to return

    Returns
    -------
    list[dict]
        One dict per match with keys:
        - name: document name
        - text: embedded text content
        - page_number: page number in source document
        - chunk_number: chunk number within the document
        - distance: L2 distance from query embedding
        - file_id: GUID of the source file

    Notes
    -----
    Records a SemanticSearchUsage row; that step swallows its own exceptions
    so usage tracking can never break the search itself.
    """
    # Time the encoding step separately from the database round trip so both
    # can be reported in the usage record.
    started_encoding = time.time()
    transformer = TransformerModel.get_instance().model
    query_vector = transformer.encode(message_data)
    encoding_time = time.time() - started_encoding

    started_db = time.time()
    nearest = build_query(user, query_vector, document_name, guid, num_results)
    results = evaluate_query(nearest)
    db_query_time = time.time() - started_db

    log_usage(
        results, message_data, user, guid, document_name, num_results,
        encoding_time, db_query_time,
    )

    return results
Loading
Loading