Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .env.development
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ MEDCAT_MODELS_DIR=/Volumes/clinical-data/medcat
NER_SERVICE_PORT=8003

EMBEDDING_SERVICE_PORT=8004
EMBEDDING_SERVICE_HOST=embedding-service-dev
EMBEDDING_SERVICE_HOST=embedding-service-dev-cpu
BATCH_SIZE=1

CHROMA_SERVICE_HOST=chromadb-dev
Expand Down
36 changes: 7 additions & 29 deletions adrenaline/api/patients/answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@

import json
import logging
import os
from typing import Tuple

from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableSequence
from langchain_openai import ChatOpenAI

from api.pages.data import Answer
from api.patients.llm import LLM
from api.patients.prompts import (
general_answer_template,
patient_answer_template,
Expand All @@ -21,27 +20,6 @@
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set up OpenAI client with custom endpoint
LLM_SERVICE_URL = os.getenv("LLM_SERVICE_URL")
if not LLM_SERVICE_URL:
raise ValueError("LLM_SERVICE_URL is not set")
logger.info(f"LLM_SERVICE_URL is set to: {LLM_SERVICE_URL}")

os.environ["OPENAI_API_KEY"] = "EMPTY"

# Initialize LLM with increased timeout
try:
llm = ChatOpenAI(
base_url=LLM_SERVICE_URL,
model_name="Meta-Llama-3.1-70B-Instruct",
temperature=0.3,
max_tokens=4096,
request_timeout=60,
)
logger.info("ChatOpenAI initialized successfully")
except Exception as e:
logger.error(f"Error initializing ChatOpenAI: {str(e)}")
raise

answer_parser = PydanticOutputParser(pydantic_object=Answer)
patient_answer_prompt = PromptTemplate(
Expand All @@ -54,8 +32,8 @@
)

# Initialize the LLMChains
patient_answer_chain = RunnableSequence(patient_answer_prompt | llm)
general_answer_chain = RunnableSequence(general_answer_prompt | llm)
patient_answer_chain = RunnableSequence(patient_answer_prompt | LLM)
general_answer_chain = RunnableSequence(general_answer_prompt | LLM)


def parse_llm_output_answer(output: str) -> Tuple[str, str]:
Expand Down Expand Up @@ -139,7 +117,7 @@ async def generate_answer(
raise


async def test_llm_connection():
async def test_llm_connection() -> bool:
"""Test the connection to the LLM.

Returns
Expand All @@ -158,12 +136,12 @@ async def test_llm_connection():
return False


async def initialize_llm():
async def initialize_llm() -> bool:
"""Initialize the LLM.

Returns
-------
bool
True if the connection is successful, False otherwise.
True if the LLM is initialized successfully, False otherwise.
"""
await test_llm_connection()
return await test_llm_connection()
12 changes: 12 additions & 0 deletions adrenaline/api/patients/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,18 @@
from pydantic import BaseModel, Field, field_validator


class MedicationsRequest(BaseModel):
    """Payload accepted by the medication-formatting endpoint.

    Attributes
    ----------
    medications: str
        Raw medication list (comma-separated) to be formatted.
    """

    medications: str


class CohortSearchQuery(BaseModel):
"""Query for cohort search.

Expand Down
73 changes: 73 additions & 0 deletions adrenaline/api/patients/ehr.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,74 @@ def fetch_patient_events_by_type(
)
raise

def fetch_latest_medications(self, patient_id: int) -> str:
    """Fetch medication events from the latest encounter and format them.

    Parameters
    ----------
    patient_id : int
        The patient ID.

    Returns
    -------
    str
        Comma-separated, alphabetically sorted list of distinct
        medications recorded during the latest hospital admission, or
        an empty string when the patient has no events/admissions.

    Raises
    ------
    ValueError
        If the LazyFrame has not been initialized.
    """
    if self.lazy_df is None:
        raise ValueError("LazyFrame not initialized")

    try:
        # Materialize all events for this patient.
        filtered_df = (
            self.lazy_df.filter(pl.col("patient_id") == patient_id)
            .select(list(self._required_columns))
            .collect(streaming=True)
        )

        if filtered_df.height == 0:
            logger.info(f"No events found for patient {patient_id}")
            return ""

        # Normalize raw rows into event dicts once, up front.
        processed_events = [
            self._process_event(row) for row in filtered_df.to_dicts()
        ]

        # Identify the most recent hospital admission and its encounter.
        latest_encounter = None
        latest_timestamp = None
        for event in processed_events:
            if event["event_type"] == "HOSPITAL_ADMISSION" and (
                latest_timestamp is None or event["timestamp"] > latest_timestamp
            ):
                latest_timestamp = event["timestamp"]
                latest_encounter = event["encounter_id"]

        if latest_encounter is None:
            logger.info(f"No hospital admissions found for patient {patient_id}")
            return ""

        # BUGFIX: the original compared `timestamp > latest_timestamp` and
        # never used `latest_encounter`, which excluded medications recorded
        # exactly at admission time and could include events from other
        # encounters. Match on the encounter id, as the docstring promises.
        medications = {
            event["details"]
            for event in processed_events
            if (
                event["event_type"] == "MEDICATION"
                and event["encounter_id"] == latest_encounter
            )
        }

        # Sorted set -> deterministic, de-duplicated output.
        return ", ".join(sorted(medications))

    except Exception as e:
        logger.error(
            f"Error fetching medications for patient {patient_id}: {str(e)}",
            exc_info=True,
        )
        raise


def fetch_patient_encounters(patient_id: int) -> List[dict]:
"""Fetch encounters with admission dates for a patient.
Expand Down Expand Up @@ -277,3 +345,8 @@ def fetch_patient_events(patient_id: int) -> List[Event]:
def fetch_patient_events_by_type(patient_id: int, event_type: str) -> List[Event]:
    """Fetch events filtered by event_type for a patient.

    Thin module-level wrapper delegating to the shared ``ehr_data_manager``.

    Parameters
    ----------
    patient_id : int
        The patient ID.
    event_type : str
        Event type to filter by.

    Returns
    -------
    List[Event]
        Matching events for the patient.
    """
    return ehr_data_manager.fetch_patient_events_by_type(patient_id, event_type)


def fetch_latest_medications(patient_id: int) -> str:
    """Fetch medication list from the latest encounter.

    Thin module-level wrapper delegating to the shared ``ehr_data_manager``.

    Parameters
    ----------
    patient_id : int
        The patient ID.

    Returns
    -------
    str
        Comma-separated medications from the latest encounter.
    """
    return ehr_data_manager.fetch_latest_medications(patient_id)
33 changes: 33 additions & 0 deletions adrenaline/api/patients/llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""LLM module for patients API."""

import logging
import os

from langchain_openai import ChatOpenAI


# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set up OpenAI client with custom endpoint
LLM_SERVICE_URL = os.getenv("LLM_SERVICE_URL")
if not LLM_SERVICE_URL:
raise ValueError("LLM_SERVICE_URL is not set")
logger.info(f"LLM_SERVICE_URL is set to: {LLM_SERVICE_URL}")

os.environ["OPENAI_API_KEY"] = "EMPTY"

# Initialize LLM with increased timeout
try:
LLM = ChatOpenAI(
base_url=LLM_SERVICE_URL,
model_name="Meta-Llama-3.1-70B-Instruct",
temperature=0.3,
max_tokens=4096,
request_timeout=60,
)
logger.info("ChatOpenAI initialized successfully")
except Exception as e:
logger.error(f"Error initializing ChatOpenAI: {str(e)}")
raise
1 change: 1 addition & 0 deletions adrenaline/api/patients/workflows.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""EHR workflow functions."""
59 changes: 56 additions & 3 deletions adrenaline/api/routes/answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
import logging
import os
from datetime import datetime
from typing import Dict, List
from typing import Any, Dict, List

from fastapi import APIRouter, Body, Depends, HTTPException
from motor.motor_asyncio import AsyncIOMotorDatabase

from api.pages.data import Query
from api.patients.answer import generate_answer
from api.patients.data import CohortSearchQuery, CohortSearchResult
from api.patients.data import CohortSearchQuery, CohortSearchResult, MedicationsRequest
from api.patients.db import get_database
from api.patients.rag import (
ChromaManager,
Expand Down Expand Up @@ -53,10 +53,63 @@
RAG_MANAGER = RAGManager(EMBEDDING_MANAGER, CHROMA_MANAGER, NER_MANAGER)


@router.post("/format_medications")
async def format_medications(
request: MedicationsRequest,
current_user: User = Depends(get_current_active_user), # noqa: B008
) -> Dict[str, str]:
"""Format medications into a markdown table.

Parameters
----------
request : MedicationsRequest
Request containing medications string
current_user : User
The current authenticated user

Returns
-------
Dict[str, str]
Formatted markdown table of medications
"""
try:
if not request.medications.strip():
return {"formatted_medications": "No medications found"}

# Prepare the prompt for the LLM
prompt = f"""
Convert this comma-separated list of medications into a well-formatted markdown table with columns for Medication Name and Status.
Sort them alphabetically by medication name. Remove any duplicate entries.

Medications: {request.medications}

Format the table like this:
| Medication Name | Status |
|----------------|---------|
| Med 1 | Status 1 |
"""

# Generate the formatted table
formatted_table = await generate_answer(
user_query=prompt,
mode="general",
context="",
)

return {"formatted_medications": formatted_table[0]}

except Exception as e:
logger.error(f"Error formatting medications: {str(e)}")
raise HTTPException(
status_code=500,
detail="An error occurred while formatting medications",
) from e


@router.post("/generate_answer")
async def generate_answer_endpoint(
query: Query = Body(...), # noqa: B008
db: AsyncIOMotorDatabase = Depends(get_database), # noqa: B008
db: AsyncIOMotorDatabase[Any] = Depends(get_database), # noqa: B008
current_user: User = Depends(get_current_active_user), # noqa: B008
) -> Dict[str, str]:
"""Generate an answer using RAG."""
Expand Down
12 changes: 6 additions & 6 deletions adrenaline/api/routes/pages.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
from datetime import UTC, datetime
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional

from bson import ObjectId
from fastapi import APIRouter, Body, Depends, HTTPException
Expand Down Expand Up @@ -32,7 +32,7 @@ class CreatePageRequest(BaseModel):
@router.post("/pages/create")
async def create_page(
request: CreatePageRequest,
db: AsyncIOMotorDatabase = Depends(get_database), # noqa: B008
db: AsyncIOMotorDatabase[Any] = Depends(get_database), # noqa: B008
current_user: User = Depends(get_current_active_user), # noqa: B008
) -> Dict[str, str]:
"""Create a new page for a user."""
Expand Down Expand Up @@ -63,9 +63,9 @@ async def append_to_page(
page_id: str,
question: str = Body(...),
answer: str = Body(...),
db: AsyncIOMotorDatabase = Depends(get_database), # noqa: B008
db: AsyncIOMotorDatabase[Any] = Depends(get_database), # noqa: B008
current_user: User = Depends(get_current_active_user), # noqa: B008
) -> dict:
) -> Dict[str, str]:
"""Append a follow-up question and answer to an existing page."""
existing_page = await db.pages.find_one(
{"_id": ObjectId(page_id), "user_id": str(current_user.id)}
Expand All @@ -89,7 +89,7 @@ async def append_to_page(

@router.get("/pages/history", response_model=List[Page])
async def get_user_page_history(
db: AsyncIOMotorDatabase = Depends(get_database), # noqa: B008
db: AsyncIOMotorDatabase[Any] = Depends(get_database), # noqa: B008
current_user: User = Depends(get_current_active_user), # noqa: B008
) -> List[Page]:
"""Retrieve all pages for the current user.
Expand All @@ -114,7 +114,7 @@ async def get_user_page_history(
@router.get("/pages/{page_id}", response_model=Page)
async def get_page(
page_id: str,
db: AsyncIOMotorDatabase = Depends(get_database), # noqa: B008
db: AsyncIOMotorDatabase[Any] = Depends(get_database), # noqa: B008
current_user: User = Depends(get_current_active_user), # noqa: B008
) -> Page:
"""Retrieve a specific page.
Expand Down
Loading