diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..d8e589d --- /dev/null +++ b/.dockerignore @@ -0,0 +1,23 @@ +# Datasets +data/ +**/training_set/ +**/supplementary_set/ +**/*.edf + +# Artifacts +model/ +model_smoke/ +model_full_smoke/ +outputs/ +outputs_smoke/ +__pycache__/ +*.pyc +*.pkl +*.sav +*.joblib + +# OS / IDE +.DS_Store +Thumbs.db +.vscode/ +.idea/ \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..7cc07cb --- /dev/null +++ b/.gitignore @@ -0,0 +1,238 @@ +# Dataset +data/ + +# Model artifacts +model/ +model_smoke/ +*.pkl +*.sav +*.joblib + +# Outputs +outputs/ +outputs_smoke/ + +# Python +__pycache__/ +*.pyc + +# OS +.DS_Store +Thumbs.db + +# IDE +.vscode/ +.idea/ + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[codz] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py.cover +.hypothesis/ +.pytest_cache/ +cover/ +graphs/ + +graphs +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# UV +# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +#uv.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock +#poetry.toml + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python. 
+# https://pdm-project.org/en/latest/usage/project/#working-with-version-control +#pdm.lock +#pdm.toml +.pdm-python +.pdm-build/ + +# pixi +# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control. +#pixi.lock +# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one +# in the .venv directory. It is recommended not to include this directory in version control. +.pixi + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.envrc +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# Abstra +# Abstra is an AI-powered process automation framework. +# Ignore directories containing user credentials, local state, and settings. +# Learn more at https://abstra.io/docs +.abstra/ + +# Visual Studio Code +# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore +# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore +# and can be added to the global gitignore or merged into this file. 
However, if you prefer, +# you could uncomment the following to ignore the entire vscode folder +# .vscode/ + +# Ruff stuff: +.ruff_cache/ + +# PyPI configuration file +.pypirc + +# Cursor +# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to +# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data +# refer to https://docs.cursor.com/context/ignore-files +.cursorignore +.cursorindexingignore + +# Marimo +marimo/_static/ +marimo/_lsp/ +__marimo__/ +results_summaryEEG_I0002.csv +results_summaryEEG_I0004.csv +results_summaryEEG_I0006.csv +results_summaryEEG_I0007.csv diff --git a/AUTHORS.txt b/AUTHORS.txt new file mode 100644 index 0000000..143148b --- /dev/null +++ b/AUTHORS.txt @@ -0,0 +1,4 @@ +Sofia Romagnoli - Universidad de Zaragoza +Diego Cajal - CIBER-BBN +Josseline Madrid - Universidad de Zaragoza +Rodrigo Lozano - Universidad de Zaragoza diff --git a/docs/01_overview.md b/docs/01_overview.md new file mode 100644 index 0000000..9805383 --- /dev/null +++ b/docs/01_overview.md @@ -0,0 +1,35 @@ +# CINC 2026 – Visión General del Proyecto + +Estamos participando en el Challenge 2026 de Computing in Cardiology. + +El objetivo es predecir deterioro cognitivo a partir de datos de polisomnografía (PSG). + +## Cómo nos evaluarán + +La organización: + +1. Construirá nuestra imagen Docker. +2. Ejecutará `train_model.py`. +3. Ejecutará `run_model.py`. +4. Evaluará las predicciones generadas. + +Por tanto, la reproducibilidad mediante Docker es obligatoria. + +Nuestro objetivo es garantizar que: +- El código se ejecuta sin intervención manual. +- El modelo se entrena correctamente. +- Las predicciones se generan en el formato requerido. + +## Qué se puede modificar y qué no + +❌ No modificar + +- `train_model.py` +- `run_model.py` +- `helper_code.py` +- `evaluate_model.py` + +✅ Modificar/Añadir + +- `team_code.py` <-- Toda la lógica científica y de modelado debe implementarse ahí. 
+- Helpers, scripts, métodos: añadir a voluntad en `src/` \ No newline at end of file diff --git a/docs/02_docker.md b/docs/02_docker.md new file mode 100644 index 0000000..c24ced2 --- /dev/null +++ b/docs/02_docker.md @@ -0,0 +1,46 @@ +# Uso de Docker + +Este documento define el contexto de ejecución con Docker. + +## Requisitos + +- Docker Desktop instalado (modo Linux containers) +- Dataset descargado desde Kaggle +- Dataset completo disponible en `data/training_set/` (ruta por defecto del proyecto) + +Si tu dataset está en otra ubicación, actualiza la variable de ruta en el script de ejecución. + +## Estructura de trabajo + +Entradas: + +- `data/training_set/` (dataset completo) +- `data/training_smoke/` (dataset reducido para modo desarrollo (smoke)) + +Salidas: + +- `model/` y `outputs/` (flujo completo) +- `model_smoke/` y `outputs_smoke/` (flujo smoke/desarrollo) + +## Orden recomendado de ejecución + +1. Construir imagen Docker (`build`) +2. Preparar dataset smoke (`smoke`) +3. Iterar en modo desarrollo (smoke) (`train-dev` / `run-dev`) +4. Ejecutar validación completa (`train` / `run`) +5. Limpiar artefactos cuando corresponda (`clean`) + +La guía paso a paso está en `docs/04_run_script.md`. + +## Compatibilidad de scripts + +El flujo principal del equipo está documentado con `run.sh` (Git Bash). +También existen equivalentes en PowerShell: `run.ps1` y `scripts/create_smoke.ps1`. + +## Resultado esperado + +Tras ejecutar la generación de predicciones (inferencia) completa, en `outputs/` se genera un `demographics.csv` con: + +- Columnas originales +- `Cognitive_Impairment` +- `Cognitive_Impairment_Probability` \ No newline at end of file diff --git a/docs/03_smoke_dataset.md b/docs/03_smoke_dataset.md new file mode 100644 index 0000000..ae4ee23 --- /dev/null +++ b/docs/03_smoke_dataset.md @@ -0,0 +1,41 @@ +# Dataset smoke (Modo desarrollo) + +Entrenar con el dataset completo tarda aproximadamente 30–40 minutos con el modelo de ejemplo. 
+ +Para desarrollo utilizamos un dataset reducido (5 sujetos por defecto). + +Este documento describe cuándo y por qué usar smoke. +Los comandos de ejecución están centralizados en `docs/04_run_script.md`. + +--- + +## Qué incluye + +- Muestra reducida del dataset (5 sujetos por defecto) +- Estructura compatible con el flujo oficial del proyecto +- Directorio de salida en `data/training_smoke/` +- `demographics.csv` filtrado para que solo incluya los registros copiados al smoke + +## Para qué se usa + +- Validar cambios de código rápidamente +- Detectar errores de integración antes del entrenamiento completo +- Iterar en modo desarrollo (smoke) sin esperar ciclos largos + +## Artefactos asociados + +- Entrenamiento smoke: `model_smoke/` +- Predicciones (inferencia) smoke: `outputs_smoke/` + +## Relación con el flujo principal + +El dataset smoke se crea al inicio del ciclo de desarrollo y se usa junto con `train-dev` y `run-dev`. +El orden detallado de ejecución está en `docs/04_run_script.md`. + +## ¿Cuándo usar smoke? + +- Desarrollo de nuevas funcionalidades +- Comprobación rápida de que el código no rompe +- Validación de cambios en `team_code.py` + +Nunca usar smoke para evaluar rendimiento final. \ No newline at end of file diff --git a/docs/04_run_script.md b/docs/04_run_script.md new file mode 100644 index 0000000..87eef6d --- /dev/null +++ b/docs/04_run_script.md @@ -0,0 +1,130 @@ +# Script unificado de ejecución (`run.sh`) + +Este documento es la guía operativa única para ejecutar el proyecto. +Aquí se define el orden recomendado y los comandos asociados. + +--- + +# Requisitos + +- Docker Desktop instalado +- Dataset descargado en: + +``` +data/training_set/ +data/supplementary_set/ +``` + +⚠️ Si el dataset está en otra ubicación, modificar las variables `$TRAIN_DATA_REL` y `$RUN_DATA_REL` +dentro de `run.sh`. + +⚠️ Ejecutar los comandos desde Git Bash. 
+ +ℹ️ Existen scripts equivalentes en PowerShell (`run.ps1` y `scripts/create_smoke.ps1`) para quienes prefieran ese entorno. + +ℹ️ Para contexto general y definición de artefactos, ver `docs/02_docker.md` y `docs/03_smoke_dataset.md`. +--- + +# Orden de ejecución recomendado + +Desde la raíz del repositorio. + +## 1) Preparar entorno + +### Construir imagen Docker + +```bash +./run.sh build +``` + +Ejecutar la primera vez y cada vez que cambien `requirements.txt` o `Dockerfile`. + +### Crear dataset smoke (5 sujetos) + +```bash +./run.sh smoke +``` + +Genera `data/training_smoke/`. + +## 2) Ciclo en modo desarrollo (smoke) + +### Entrenar en modo desarrollo (smoke) + +```bash +./run.sh train-dev +``` + +Usa `data/training_smoke/` y guarda modelo en `model_smoke/`. + +### Generar predicciones (inferencia) en modo desarrollo (smoke) + +```bash +./run.sh run-dev +``` + +Genera resultados en `outputs_smoke/` y luego imprime métricas de evaluación en consola. + +### Evaluar predicciones existentes en modo desarrollo (smoke) + +```bash +./run.sh eval-dev +``` + +Reutiliza `outputs_smoke/demographics.csv` y muestra AUROC, AUPRC, Accuracy y F-measure sin volver a ejecutar inferencia. + +### Secuencia típica en modo desarrollo (smoke) + +```bash +./run.sh build # solo la primera vez +./run.sh smoke # solo si no existe +./run.sh train-dev +./run.sh run-dev +./run.sh eval-dev # opcional: reevaluar sin correr inferencia +``` + +## 3) Validación completa + +### Entrenar con dataset completo + +```bash +./run.sh train +``` + +Guarda el modelo en `model/`. + +### Generar predicciones (inferencia) completas + +```bash +./run.sh run +``` + +Genera resultados en `outputs/` usando `data/test_set/`. +Si el dataset no tiene etiquetas (como en `test_set`), el script omite la evaluación automáticamente. 
+ +### Evaluar predicciones existentes completas + +```bash +./run.sh eval +``` + +Reutiliza `outputs/demographics.csv` y muestra AUROC, AUPRC, Accuracy y F-measure sin volver a ejecutar inferencia. +Evalúa contra `data/test_set/`. +Si no hay etiquetas en ese set, el script omite la evaluación automáticamente. + +### Evaluar predicciones existentes del dataset smoke + +```bash +./run.sh eval-smoke +``` + +Reutiliza `outputs_smoke/demographics.csv` y muestra AUROC, AUPRC, Accuracy y F-measure sin volver a ejecutar inferencia. + +## 4) Limpieza de artefactos + +```bash +./run.sh clean +``` + +Elimina `model/`, `model_smoke/`, `outputs/` y `outputs_smoke/`. +No elimina datasets. diff --git a/docs/05_optimization_tracking.md b/docs/05_optimization_tracking.md new file mode 100644 index 0000000..db352fa --- /dev/null +++ b/docs/05_optimization_tracking.md @@ -0,0 +1,128 @@ +# Seguimiento De Optimizaciones + +Este documento registra las optimizaciones de tiempo de entrenamiento aplicadas sobre el código de ejemplo original del PhysioNet Challenge en `team_code.py`. + +## Objetivo + +Mantener un registro claro de los cambios respecto a la base proporcionada por la organización para que el equipo pueda: + +- entender qué optimizaciones se probaron; +- medir su efecto sobre el flujo smoke; +- identificar qué cambios merece la pena conservar; +- revertir cambios concretos si la submission se comporta distinto en el entorno del Challenge. + +## Línea Base + +- Fuente de la línea base: implementación de ejemplo proporcionada por la organización en `team_code.py`. +- Tiempo observado de entrenamiento smoke con `./run.sh train-dev`: alrededor de 22 segundos. +- Comportamiento base: extracción secuencial de features, lecturas repetidas de CSV, recarga repetida de reglas de renombrado de canales y carga no utilizada de anotaciones humanas durante entrenamiento. + +## Cambios Aplicados + +### 1. 
Eliminación de la carga no utilizada de anotaciones humanas en entrenamiento + +Cambio: + +- Se eliminó la carga de `human_annotations` dentro de `train_model`. +- Se dejó intacta la función auxiliar `extract_human_annotations_features`. + +Motivo: + +- El vector final de entrenamiento solo concatenaba features demográficas, fisiológicas y algorítmicas. +- Las features de anotaciones humanas se calculaban, pero nunca se incluían en el `np.hstack(...)` que se pasaba al clasificador. + +Efecto observado: + +- El tiempo de entrenamiento smoke pasó de unos 22.0 s a 21.891 s. +- Conclusión: la limpieza es correcta a nivel lógico, pero su impacto en tiempo es despreciable en el dataset smoke. + +Riesgo: + +- Bajo. Solo elimina trabajo muerto. + +### 2. Caché de reglas de renombrado de canales + +Cambio: + +- Se añadió una caché en proceso para las reglas de renombrado cargadas desde `channel_table.csv`. +- Se sustituyeron las llamadas repetidas a `load_rename_rules(os.path.abspath(csv_path))` por una consulta a la caché. + +Motivo: + +- `extract_physiological_features` estaba cargando y parseando el mismo CSV para cada registro. + +Efecto observado: + +- El tiempo smoke medido en la siguiente ejecución fue 22.040 s. +- Conclusión: la optimización es correcta, pero no ataca un cuello de botella relevante en smoke. + +Riesgo: + +- Bajo. El comportamiento no cambia salvo por reutilizar reglas ya parseadas. + +### 3. Caché de demographics y etiquetas para entrenamiento + +Cambio: + +- Se añadió una lectura única de `demographics.csv` al inicio de `train_model`. +- Se construyeron: + - una caché de demographics indexada por `(patient_id, session_id)`; + - una caché de diagnósticos indexada por `patient_id`. +- Se reemplazaron las llamadas por registro a `load_demographics(...)` y `load_diagnoses(...)` durante entrenamiento. + +Motivo: + +- El bucle original de entrenamiento releía el mismo CSV para cada registro. + +Efecto observado: + +- El tiempo smoke bajó a 20.837 s. 
+- Conclusión: es una mejora real, aunque moderada. + +Riesgo: + +- Bajo a medio. +- Asume que las etiquetas de entrenamiento son estables a nivel de paciente cuando se cachean por `patient_id`, igual que hacía el comportamiento original de `load_diagnoses(...)`. + +### 4. Paralelización de la extracción de features en entrenamiento + +Cambio: + +- Se añadió procesamiento paralelo por registro con `ThreadPoolExecutor` dentro de `train_model`. +- Se movió la lógica de extracción por registro a `process_training_record(...)`. +- Se limitó el número de workers con: + +```python +MAX_TRAIN_WORKERS = max(1, min(4, os.cpu_count() or 1)) +``` + +Motivo: + +- Cada registro de entrenamiento se procesa de forma independiente. +- El pipeline mezcla lecturas de archivos EDF y trabajo con NumPy, así que un pool pequeño de hilos puede reducir el tiempo total. + +Efecto observado: + +- El tiempo smoke bajó a 9.578 s en la primera ejecución tras paralelizar. +- Las ejecuciones de seguimiento midieron 9.762 s y 9.655 s. +- Conclusión: esta es la optimización dominante. + +Riesgo: + +- Medio. +- El acceso paralelo a archivos puede comportarse distinto en discos más lentos o en una infraestructura más limitada del Challenge. + +## Plan De Rollback + +Si la submission se comporta distinto en el entorno del Challenge, revertir en este orden: + +1. Eliminar la extracción con hilos y restaurar el bucle secuencial original en `train_model`. +2. Eliminar las cachés de metadata de entrenamiento y volver a `load_demographics(...)` / `load_diagnoses(...)`. +3. Eliminar la caché de reglas de renombrado y volver a las llamadas directas a `load_rename_rules(...)`. +4. Rehabilitar la carga de anotaciones humanas solo si el vector de entrenamiento se modifica explícitamente para usar esas features. + +Este orden de rollback elimina primero la optimización de mayor riesgo y deja para el final los cambios de comportamiento más pequeños. 
+ +## Archivos Modificados + +- `team_code.py` \ No newline at end of file diff --git a/run.ps1 b/run.ps1 new file mode 100644 index 0000000..8d2141b --- /dev/null +++ b/run.ps1 @@ -0,0 +1,262 @@ +param( + [Parameter(Mandatory=$true)] + [ValidateSet( + "build", + "smoke", + "train", + "train-smoke", + "run", + "run-smoke", + "eval", + "eval-smoke", + "train-dev", + "run-dev", + "eval-dev", + "clean" + )] + [string]$Command +) + +# ============================================ +# CONFIGURACIÓN +# ============================================ + +# >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> +# IMPORTANTE: +# Si tu dataset no está en data/training_set o data/test_set, +# modifica estas rutas. +# >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> +$TRAIN_DATA_REL = "data/training_set" +$RUN_DATA_REL = "data/test_set" +$SMOKE_DATA_REL = "data/training_smoke" + +$IMAGE_NAME = "cinc2026" + +$MODEL_FULL_REL = "model" +$MODEL_SMOKE_REL = "model_smoke" + +$OUT_FULL_REL = "outputs" +$OUT_SMOKE_REL = "outputs_smoke" +$DEMOGRAPHICS_FILE = "demographics.csv" + +# ============================================ +# FUNCIONES AUXILIARES +# ============================================ + +function Get-AbsolutePath($relativePath) { + return (Resolve-Path $relativePath).Path +} + +function Ensure-Directory($path) { + if (!(Test-Path $path)) { + New-Item -ItemType Directory -Force -Path $path | Out-Null + } +} + +function Invoke-Evaluation($DataPath, $OutputPath, $Label) { + Write-Host "Evaluating $Label predictions..." + docker run --rm ` + -v "${DataPath}:/challenge/eval_data:ro" ` + -v "${OutputPath}:/challenge/eval_outputs:ro" ` + $IMAGE_NAME ` + python evaluate_model.py -d "/challenge/eval_data/$DEMOGRAPHICS_FILE" -o "/challenge/eval_outputs/$DEMOGRAPHICS_FILE" +} + +function Invoke-EvaluationDev($CodePath, $DataPath, $OutputPath, $Label) { + Write-Host "Evaluating $Label predictions..." 
+ docker run --rm ` + -v "${CodePath}:/challenge" ` + -v "${DataPath}:/challenge/eval_data:ro" ` + $IMAGE_NAME ` + python evaluate_model.py -d "/challenge/eval_data/$DEMOGRAPHICS_FILE" -o "$OutputPath/$DEMOGRAPHICS_FILE" +} + +function Test-DatasetHasLabels($DataPath) { + $demographicsPath = Join-Path $DataPath $DEMOGRAPHICS_FILE + if (!(Test-Path $demographicsPath)) { + return $false + } + + $header = Get-Content -Path $demographicsPath -TotalCount 1 + return $header -match "Cognitive_Impairment" +} + +# ============================================ +# COMANDOS +# ============================================ + +function Build-Image { + docker build -t $IMAGE_NAME . +} + +function Create-Smoke { + Write-Host "Creando dataset smoke..." + powershell -ExecutionPolicy Bypass -File scripts/create_smoke.ps1 +} + +function Train-Full { + + $FULL_DATA = Get-AbsolutePath $TRAIN_DATA_REL + $MODEL_FULL = Join-Path (Get-AbsolutePath ".") $MODEL_FULL_REL + + Ensure-Directory $MODEL_FULL + + docker run --rm ` + -v "${FULL_DATA}:/challenge/training_data:ro" ` + -v "${MODEL_FULL}:/challenge/model" ` + $IMAGE_NAME ` + python train_model.py -d training_data -m model -v +} + +function Train-Smoke { + + $SMOKE_DATA = Get-AbsolutePath $SMOKE_DATA_REL + $MODEL_SMOKE = Join-Path (Get-AbsolutePath ".") $MODEL_SMOKE_REL + + Ensure-Directory $MODEL_SMOKE + + docker run --rm ` + -v "${SMOKE_DATA}:/challenge/training_data:ro" ` + -v "${MODEL_SMOKE}:/challenge/model" ` + $IMAGE_NAME ` + python train_model.py -d training_data -m model -v +} + +function Run-Full { + + $RUN_DATA = Get-AbsolutePath $RUN_DATA_REL + $MODEL_FULL = Get-AbsolutePath $MODEL_FULL_REL + $OUT_FULL = Join-Path (Get-AbsolutePath ".") $OUT_FULL_REL + + Ensure-Directory $OUT_FULL + + docker run --rm ` + -v "${RUN_DATA}:/challenge/holdout_data:ro" ` + -v "${MODEL_FULL}:/challenge/model:ro" ` + -v "${OUT_FULL}:/challenge/holdout_outputs" ` + $IMAGE_NAME ` + python run_model.py -d holdout_data -m model -o holdout_outputs -v + + if 
(Test-DatasetHasLabels $RUN_DATA) { + Invoke-Evaluation $RUN_DATA $OUT_FULL "run-dataset" + } else { + Write-Host "Skipping evaluation for run dataset (labels not present in $RUN_DATA_REL/$DEMOGRAPHICS_FILE)." + } +} + +function Run-Smoke { + + $SMOKE_DATA = Get-AbsolutePath $SMOKE_DATA_REL + $MODEL_SMOKE = Get-AbsolutePath $MODEL_SMOKE_REL + $OUT_SMOKE = Join-Path (Get-AbsolutePath ".") $OUT_SMOKE_REL + + Ensure-Directory $OUT_SMOKE + + docker run --rm ` + -v "${SMOKE_DATA}:/challenge/holdout_data:ro" ` + -v "${MODEL_SMOKE}:/challenge/model:ro" ` + -v "${OUT_SMOKE}:/challenge/holdout_outputs" ` + $IMAGE_NAME ` + python run_model.py -d holdout_data -m model -o holdout_outputs -v + + Invoke-Evaluation $SMOKE_DATA $OUT_SMOKE "smoke" +} + +function Eval-Full { + + $RUN_DATA = Get-AbsolutePath $RUN_DATA_REL + $OUT_FULL = Get-AbsolutePath $OUT_FULL_REL + + if (Test-DatasetHasLabels $RUN_DATA) { + Invoke-Evaluation $RUN_DATA $OUT_FULL "run-dataset" + } else { + Write-Host "Skipping evaluation for run dataset (labels not present in $RUN_DATA_REL/$DEMOGRAPHICS_FILE)." + } +} + +function Eval-Smoke { + + $SMOKE_DATA = Get-AbsolutePath $SMOKE_DATA_REL + $OUT_SMOKE = Get-AbsolutePath $OUT_SMOKE_REL + + Invoke-Evaluation $SMOKE_DATA $OUT_SMOKE "smoke" +} + +# ====================== +# MODO DESARROLLO (SIN REBUILD) +# ====================== + +function Train-Dev { + + $CODE_PATH = Get-AbsolutePath "." + $SMOKE_DATA = Get-AbsolutePath $SMOKE_DATA_REL + $MODEL_SMOKE = Join-Path $CODE_PATH $MODEL_SMOKE_REL + + Ensure-Directory $MODEL_SMOKE + + docker run --rm ` + -v "${CODE_PATH}:/challenge" ` + -v "${SMOKE_DATA}:/challenge/training_data:ro" ` + -v "${MODEL_SMOKE}:/challenge/model" ` + $IMAGE_NAME ` + python train_model.py -d training_data -m model -v +} + +function Run-Dev { + + $CODE_PATH = Get-AbsolutePath "." 
+ $SMOKE_DATA = Get-AbsolutePath $SMOKE_DATA_REL + $MODEL_SMOKE = Get-AbsolutePath $MODEL_SMOKE_REL + $OUT_SMOKE = Join-Path $CODE_PATH $OUT_SMOKE_REL + + Ensure-Directory $OUT_SMOKE + + docker run --rm ` + -v "${CODE_PATH}:/challenge" ` + -v "${SMOKE_DATA}:/challenge/holdout_data:ro" ` + -v "${MODEL_SMOKE}:/challenge/model:ro" ` + -v "${OUT_SMOKE}:/challenge/holdout_outputs" ` + $IMAGE_NAME ` + python run_model.py -d holdout_data -m model -o holdout_outputs -v + + Invoke-EvaluationDev $CODE_PATH $SMOKE_DATA "/challenge/holdout_outputs" "development smoke" +} + +function Eval-Dev { + + $CODE_PATH = Get-AbsolutePath "." + $SMOKE_DATA = Get-AbsolutePath $SMOKE_DATA_REL + + Invoke-EvaluationDev $CODE_PATH $SMOKE_DATA "/challenge/holdout_outputs" "development smoke" +} + +function Clean-All { + + Remove-Item -Recurse -Force $MODEL_FULL_REL -ErrorAction SilentlyContinue + Remove-Item -Recurse -Force $MODEL_SMOKE_REL -ErrorAction SilentlyContinue + Remove-Item -Recurse -Force $OUT_FULL_REL -ErrorAction SilentlyContinue + Remove-Item -Recurse -Force $OUT_SMOKE_REL -ErrorAction SilentlyContinue + + Write-Host "Modelos y outputs eliminados." 
+} + +# ============================================ +# SWITCH PRINCIPAL +# ============================================ + +switch ($Command) { + + "build" { Build-Image } + "smoke" { Create-Smoke } + "train" { Train-Full } + "train-smoke" { Train-Smoke } + "run" { Run-Full } + "run-smoke" { Run-Smoke } + "eval" { Eval-Full } + "eval-smoke" { Eval-Smoke } + "train-dev" { Train-Dev } + "run-dev" { Run-Dev } + "eval-dev" { Eval-Dev } + "clean" { Clean-All } + +} \ No newline at end of file diff --git a/run.sh b/run.sh new file mode 100644 index 0000000..dcabdbf --- /dev/null +++ b/run.sh @@ -0,0 +1,295 @@ +#!/usr/bin/env bash +set -euo pipefail + +if [[ $# -lt 1 ]]; then + echo "Usage: $0 " + exit 1 +fi + +COMMAND="$1" + +# ============================================ +# CONFIGURATION +# ============================================ + +TRAIN_DATA_REL="data/training_set" +RUN_DATA_REL="data/test_set" +SMOKE_DATA_REL="data/training_smoke" + +IMAGE_NAME="cinc2026" + +MODEL_FULL_REL="model" +MODEL_SMOKE_REL="model_smoke" + +OUT_FULL_REL="outputs" +OUT_SMOKE_REL="outputs_smoke" +DEMOGRAPHICS_FILE="demographics.csv" + +# ============================================ +# HELPERS +# ============================================ + +get_absolute_path() { + local rel_path="$1" + (cd "$rel_path" && pwd) +} + +ensure_directory() { + local dir_path="$1" + mkdir -p "$dir_path" +} + +to_docker_path() { + local host_path="$1" + + if command -v cygpath >/dev/null 2>&1; then + cygpath -m "$host_path" + else + echo "$host_path" + fi +} + +docker_cli() { + MSYS_NO_PATHCONV=1 MSYS2_ARG_CONV_EXCL="*" docker "$@" +} + +evaluate_predictions() { + local data_dir="$1" + local output_dir="$2" + local label="$3" + local data_dir_docker output_dir_docker + + data_dir_docker="$(to_docker_path "$data_dir")" + output_dir_docker="$(to_docker_path "$output_dir")" + + echo "Evaluating ${label} predictions..." 
+ docker_cli run --rm \ + -v "${data_dir_docker}:/challenge/eval_data:ro" \ + -v "${output_dir_docker}:/challenge/eval_outputs:ro" \ + "$IMAGE_NAME" \ + python evaluate_model.py \ + -d "/challenge/eval_data/${DEMOGRAPHICS_FILE}" \ + -o "/challenge/eval_outputs/${DEMOGRAPHICS_FILE}" +} + +evaluate_predictions_dev() { + local code_path="$1" + local data_path="$2" + local output_path="$3" + local label="$4" + local code_path_docker data_path_docker + + code_path_docker="$(to_docker_path "$code_path")" + data_path_docker="$(to_docker_path "$data_path")" + + echo "Evaluating ${label} predictions..." + docker_cli run --rm \ + -v "${code_path_docker}:/challenge" \ + -v "${data_path_docker}:/challenge/eval_data:ro" \ + "$IMAGE_NAME" \ + python evaluate_model.py \ + -d "/challenge/eval_data/${DEMOGRAPHICS_FILE}" \ + -o "$output_path/${DEMOGRAPHICS_FILE}" +} + +dataset_has_labels() { + local data_dir="$1" + local demographics_path="$data_dir/$DEMOGRAPHICS_FILE" + + [[ -f "$demographics_path" ]] && head -n 1 "$demographics_path" | grep -q "Cognitive_Impairment" +} + +build_image() { + docker_cli build -t "$IMAGE_NAME" . +} + +create_smoke() { + echo "Creating smoke dataset..." 
+ bash scripts/create_smoke.sh +} + +train_full() { + local full_data model_full + local full_data_docker model_full_docker + + full_data="$(get_absolute_path "$TRAIN_DATA_REL")" + model_full="$(get_absolute_path ".")/${MODEL_FULL_REL}" + full_data_docker="$(to_docker_path "$full_data")" + model_full_docker="$(to_docker_path "$model_full")" + + ensure_directory "$model_full" + + docker_cli run --rm \ + -v "${full_data_docker}:/challenge/training_data:ro" \ + -v "${model_full_docker}:/challenge/model" \ + "$IMAGE_NAME" \ + python train_model.py -d training_data -m model -v +} + +train_smoke() { + local smoke_data model_smoke + local smoke_data_docker model_smoke_docker + + smoke_data="$(get_absolute_path "$SMOKE_DATA_REL")" + model_smoke="$(get_absolute_path ".")/${MODEL_SMOKE_REL}" + smoke_data_docker="$(to_docker_path "$smoke_data")" + model_smoke_docker="$(to_docker_path "$model_smoke")" + + ensure_directory "$model_smoke" + + docker_cli run --rm \ + -v "${smoke_data_docker}:/challenge/training_data:ro" \ + -v "${model_smoke_docker}:/challenge/model" \ + "$IMAGE_NAME" \ + python train_model.py -d training_data -m model -v +} + +run_full() { + local run_data model_full out_full + local run_data_docker model_full_docker out_full_docker + + run_data="$(get_absolute_path "$RUN_DATA_REL")" + model_full="$(get_absolute_path "$MODEL_FULL_REL")" + out_full="$(get_absolute_path ".")/${OUT_FULL_REL}" + run_data_docker="$(to_docker_path "$run_data")" + model_full_docker="$(to_docker_path "$model_full")" + out_full_docker="$(to_docker_path "$out_full")" + + ensure_directory "$out_full" + + docker_cli run --rm \ + -v "${run_data_docker}:/challenge/holdout_data:ro" \ + -v "${model_full_docker}:/challenge/model:ro" \ + -v "${out_full_docker}:/challenge/holdout_outputs" \ + "$IMAGE_NAME" \ + python run_model.py -d holdout_data -m model -o holdout_outputs -v + + if dataset_has_labels "$run_data"; then + evaluate_predictions "$run_data" "$out_full" "run-dataset" + else + echo 
"Skipping evaluation for run dataset (labels not present in ${RUN_DATA_REL}/${DEMOGRAPHICS_FILE})." + fi +} + +run_smoke() { + local smoke_data model_smoke out_smoke + local smoke_data_docker model_smoke_docker out_smoke_docker + + smoke_data="$(get_absolute_path "$SMOKE_DATA_REL")" + model_smoke="$(get_absolute_path "$MODEL_SMOKE_REL")" + out_smoke="$(get_absolute_path ".")/${OUT_SMOKE_REL}" + smoke_data_docker="$(to_docker_path "$smoke_data")" + model_smoke_docker="$(to_docker_path "$model_smoke")" + out_smoke_docker="$(to_docker_path "$out_smoke")" + + ensure_directory "$out_smoke" + + docker_cli run --rm \ + -v "${smoke_data_docker}:/challenge/holdout_data:ro" \ + -v "${model_smoke_docker}:/challenge/model:ro" \ + -v "${out_smoke_docker}:/challenge/holdout_outputs" \ + "$IMAGE_NAME" \ + python run_model.py -d holdout_data -m model -o holdout_outputs -v + + evaluate_predictions "$smoke_data" "$out_smoke" "smoke" +} + +eval_full() { + local run_data out_full + + run_data="$(get_absolute_path "$RUN_DATA_REL")" + out_full="$(get_absolute_path "$OUT_FULL_REL")" + + if dataset_has_labels "$run_data"; then + evaluate_predictions "$run_data" "$out_full" "run-dataset" + else + echo "Skipping evaluation for run dataset (labels not present in ${RUN_DATA_REL}/${DEMOGRAPHICS_FILE})." 
+ fi +} + +eval_smoke() { + local smoke_data out_smoke + + smoke_data="$(get_absolute_path "$SMOKE_DATA_REL")" + out_smoke="$(get_absolute_path "$OUT_SMOKE_REL")" + + evaluate_predictions "$smoke_data" "$out_smoke" "smoke" +} + +# ===================== +# DEVELOPMENT MODE (NO REBUILD) +# ===================== + +train_dev() { + local code_path smoke_data model_smoke + local code_path_docker smoke_data_docker + + code_path="$(get_absolute_path ".")" + smoke_data="$(get_absolute_path "$SMOKE_DATA_REL")" + model_smoke="${code_path}/${MODEL_SMOKE_REL}" + code_path_docker="$(to_docker_path "$code_path")" + smoke_data_docker="$(to_docker_path "$smoke_data")" + + ensure_directory "$model_smoke" + + docker_cli run --rm \ + -v "${code_path_docker}:/challenge" \ + -v "${smoke_data_docker}:/challenge/data_smoke:ro" \ + "$IMAGE_NAME" \ + python train_model.py -d /challenge/data_smoke -m /challenge/model_smoke -v +} + +run_dev() { + local code_path smoke_data out_smoke + local code_path_docker smoke_data_docker + + code_path="$(get_absolute_path ".")" + smoke_data="$(get_absolute_path "$SMOKE_DATA_REL")" + out_smoke="${code_path}/${OUT_SMOKE_REL}" + code_path_docker="$(to_docker_path "$code_path")" + smoke_data_docker="$(to_docker_path "$smoke_data")" + + ensure_directory "$out_smoke" + + docker_cli run --rm \ + -v "${code_path_docker}:/challenge" \ + -v "${smoke_data_docker}:/challenge/data_smoke:ro" \ + "$IMAGE_NAME" \ + python run_model.py -d /challenge/data_smoke -m /challenge/model_smoke -o /challenge/outputs_smoke -v + + evaluate_predictions_dev "$code_path" "$smoke_data" "/challenge/outputs_smoke" "development smoke" +} + +eval_dev() { + local code_path smoke_data + + code_path="$(get_absolute_path ".")" + smoke_data="$(get_absolute_path "$SMOKE_DATA_REL")" + + evaluate_predictions_dev "$code_path" "$smoke_data" "/challenge/outputs_smoke" "development smoke" +} + +clean_all() { + rm -rf "$MODEL_FULL_REL" "$MODEL_SMOKE_REL" "$OUT_FULL_REL" "$OUT_SMOKE_REL" + echo "Models 
and outputs removed." +} + +case "$COMMAND" in + build) build_image ;; + smoke) create_smoke ;; + train) train_full ;; + train-smoke) train_smoke ;; + run) run_full ;; + run-smoke) run_smoke ;; + eval) eval_full ;; + eval-smoke) eval_smoke ;; + train-dev) train_dev ;; + run-dev) run_dev ;; + eval-dev) eval_dev ;; + clean) clean_all ;; + *) + echo "Invalid command: $COMMAND" + echo "Valid commands: build, smoke, train, train-smoke, run, run-smoke, eval, eval-smoke, train-dev, run-dev, eval-dev, clean" + exit 1 + ;; +esac diff --git a/scripts/create_smoke.ps1 b/scripts/create_smoke.ps1 new file mode 100644 index 0000000..c8f45dd --- /dev/null +++ b/scripts/create_smoke.ps1 @@ -0,0 +1,98 @@ +# ============================================ +# Create smoke training dataset +# ============================================ + +# >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> +# IMPORTANT: +# Each team member must modify this path to +# match their local dataset location. +# >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> +$FULL_DATA_PATH = "data/training_set" # <-- CHANGE THIS IF NEEDED + +$SMOKE_PATH = "data/training_smoke" +$N_RECORDS = 5 + +Write-Host "Creating smoke dataset..." 
+Write-Host "Source: $FULL_DATA_PATH" +Write-Host "Destination: $SMOKE_PATH" + +Remove-Item -Recurse -Force $SMOKE_PATH -ErrorAction SilentlyContinue +New-Item -ItemType Directory -Force -Path $SMOKE_PATH | Out-Null + +$selectedRecords = New-Object System.Collections.Generic.List[object] + +# Select first N EDF files +$edfs = Get-ChildItem "$FULL_DATA_PATH/physiological_data" -Recurse -Filter *.edf | + Sort-Object FullName | + Select-Object -First $N_RECORDS + +foreach ($f in $edfs) { + $rel = $f.FullName.Substring((Resolve-Path $FULL_DATA_PATH).Path.Length).TrimStart('\') + $target = Join-Path $SMOKE_PATH $rel + New-Item -ItemType Directory -Force -Path (Split-Path $target) | Out-Null + Copy-Item $f.FullName $target + + $stem = [System.IO.Path]::GetFileNameWithoutExtension($f.Name) + $parts = $stem -split '_ses-' + $selectedRecords.Add([pscustomobject]@{ + SiteID = $f.Directory.Name + Patient = $parts[0] + Session = $parts[1] + }) | Out-Null +} + +# Copy only annotation EDFs for the selected smoke records. 
+foreach ($record in $selectedRecords) { + $algoSource = Join-Path $FULL_DATA_PATH "algorithmic_annotations/$($record.SiteID)/$($record.Patient)_ses-$($record.Session)_caisr_annotations.edf" + $algoTarget = Join-Path $SMOKE_PATH "algorithmic_annotations/$($record.SiteID)/$($record.Patient)_ses-$($record.Session)_caisr_annotations.edf" + if (Test-Path $algoSource) { + New-Item -ItemType Directory -Force -Path (Split-Path $algoTarget) | Out-Null + Copy-Item $algoSource $algoTarget + } + + $humanSource = Join-Path $FULL_DATA_PATH "human_annotations/$($record.SiteID)/$($record.Patient)_ses-$($record.Session)_expert_annotations.edf" + $humanTarget = Join-Path $SMOKE_PATH "human_annotations/$($record.SiteID)/$($record.Patient)_ses-$($record.Session)_expert_annotations.edf" + if (Test-Path $humanSource) { + New-Item -ItemType Directory -Force -Path (Split-Path $humanTarget) | Out-Null + Copy-Item $humanSource $humanTarget + } +} + +# Filter demographics to the copied smoke records. +$env:SMOKE_FULL_DATA_PATH = (Resolve-Path $FULL_DATA_PATH).Path +$env:SMOKE_PATH = (Resolve-Path $SMOKE_PATH).Path +python -c @" +import csv +import os +from pathlib import Path + +full_data = Path(os.environ['SMOKE_FULL_DATA_PATH']) +smoke_path = Path(os.environ['SMOKE_PATH']) + +source_csv = full_data / 'demographics.csv' +target_csv = smoke_path / 'demographics.csv' +phys_root = smoke_path / 'physiological_data' + +selected_records = set() +for edf_path in phys_root.rglob('*.edf'): + site_id = edf_path.parent.name + patient_part, session_part = edf_path.stem.rsplit('_ses-', 1) + selected_records.add((site_id, patient_part, session_part)) + +with source_csv.open('r', newline='', encoding='utf-8') as source_file: + reader = csv.DictReader(source_file) + rows = [ + row for row in reader + if (row['SiteID'], row['BidsFolder'], str(row['SessionID'])) in selected_records + ] + fieldnames = reader.fieldnames + +with target_csv.open('w', newline='', encoding='utf-8') as target_file: + writer = 
csv.DictWriter(target_file, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(rows) +"@ +Remove-Item Env:SMOKE_FULL_DATA_PATH -ErrorAction SilentlyContinue +Remove-Item Env:SMOKE_PATH -ErrorAction SilentlyContinue + +Write-Host "Smoke dataset created successfully." \ No newline at end of file diff --git a/scripts/create_smoke.sh b/scripts/create_smoke.sh new file mode 100644 index 0000000..88765ec --- /dev/null +++ b/scripts/create_smoke.sh @@ -0,0 +1,93 @@ +#!/usr/bin/env bash +set -euo pipefail + +# ============================================ +# Create smoke training dataset +# ============================================ + +# >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> +# IMPORTANT: +# Each team member can modify this path to +# match their local dataset location. +# >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> +FULL_DATA_PATH="${FULL_DATA_PATH:-data/training_set}" # Override with env var if needed + +SMOKE_PATH="data/training_smoke" +N_RECORDS="${N_RECORDS:-5}" + +echo "Creating smoke dataset..." +echo "Source: ${FULL_DATA_PATH}" +echo "Destination: ${SMOKE_PATH}" + +rm -rf "${SMOKE_PATH}" +mkdir -p "${SMOKE_PATH}" + +selected_records_file="$(mktemp)" +trap 'rm -f "${selected_records_file}"' EXIT + +# Select first N EDF files +while IFS= read -r file_path; do + rel_path="${file_path#${FULL_DATA_PATH}/}" + target_path="${SMOKE_PATH}/${rel_path}" + mkdir -p "$(dirname "${target_path}")" + cp "${file_path}" "${target_path}" + stem="$(basename "${file_path}" .edf)" + patient_part="${stem%_ses-*}" + session_part="${stem##*_ses-}" + site_id="$(basename "$(dirname "${file_path}")")" + printf '%s,%s,%s\n' "${site_id}" "${patient_part}" "${session_part}" >> "${selected_records_file}" +done < <( + find "${FULL_DATA_PATH}/physiological_data" -type f -name "*.edf" | sort | head -n "${N_RECORDS}" +) + +# Copy only annotation EDFs for the selected smoke records. 
+while IFS=',' read -r site_id patient_part session_part; do + algo_source="${FULL_DATA_PATH}/algorithmic_annotations/${site_id}/${patient_part}_ses-${session_part}_caisr_annotations.edf" + algo_target="${SMOKE_PATH}/algorithmic_annotations/${site_id}/${patient_part}_ses-${session_part}_caisr_annotations.edf" + if [[ -f "${algo_source}" ]]; then + mkdir -p "$(dirname "${algo_target}")" + cp "${algo_source}" "${algo_target}" + fi + + human_source="${FULL_DATA_PATH}/human_annotations/${site_id}/${patient_part}_ses-${session_part}_expert_annotations.edf" + human_target="${SMOKE_PATH}/human_annotations/${site_id}/${patient_part}_ses-${session_part}_expert_annotations.edf" + if [[ -f "${human_source}" ]]; then + mkdir -p "$(dirname "${human_target}")" + cp "${human_source}" "${human_target}" + fi +done < "${selected_records_file}" + +# Filter demographics to the copied smoke records. +python - <<'PY' +import csv +from pathlib import Path + +full_data = Path("data/training_set") +smoke_path = Path("data/training_smoke") + +source_csv = full_data / "demographics.csv" +target_csv = smoke_path / "demographics.csv" +phys_root = smoke_path / "physiological_data" + +selected_records = set() +for edf_path in phys_root.rglob("*.edf"): + site_id = edf_path.parent.name + stem = edf_path.stem + patient_part, session_part = stem.rsplit("_ses-", 1) + selected_records.add((site_id, patient_part, session_part)) + +with source_csv.open("r", newline="", encoding="utf-8") as source_file: + reader = csv.DictReader(source_file) + rows = [ + row for row in reader + if (row["SiteID"], row["BidsFolder"], str(row["SessionID"])) in selected_records + ] + fieldnames = reader.fieldnames + +with target_csv.open("w", newline="", encoding="utf-8") as target_file: + writer = csv.DictWriter(target_file, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(rows) +PY + +echo "Smoke dataset created successfully." 
diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..6319343 --- /dev/null +++ b/src/__init__.py @@ -0,0 +1 @@ +"""Project source package.""" \ No newline at end of file diff --git a/src/eeg_processing.py b/src/eeg_processing.py new file mode 100644 index 0000000..cf95237 --- /dev/null +++ b/src/eeg_processing.py @@ -0,0 +1,152 @@ +"""EEG_processing.py + +Este módulo contiene funciones para procesar datos EEG de los +hospitales incluidos en el desafío CincChallenge 2026. La principal +función definida es `MetricasHospitlal`, que recorre los archivos EDF +correspondientes a un hospital concreto, extrae las señales EEG, +las filtra, normaliza, crea épocas y calcula potencias de banda y +complejidades. Los resultados se guardan en un CSV resumen por +hospital. + +Características principales: + +- Soporta datos tanto del conjunto de entrenamiento como del + conjunto suplementario. +- Selección automática de canales EEG a partir de la tabla + `notebooks/channel_table.csv`. +- Creación de canales bipolares si están disponibles. +- Filtrado de banda 0.3-35 Hz y normalización de la señal. +- Re-muestreo a 200 Hz si fuese necesario. +- Cálculo de potencias de banda y complejidades usando + funciones auxiliares (`lib/EEG_functions.py`). +- Exportación de resultados en `results_summaryEEG_{hospital}.csv`. + +Uso típico: + +>>> from src.scripts.EEG_processing import MetricasHospitlal +>>> MetricasHospitlal('I0002') + +El módulo depende de `numpy`, `pandas`, `matplotlib`, `plotly` y de +las utilidades definidas en `lib/helper_code` y `lib/EEG_functions`. 
+""" +import sys +import os +import pandas as pd +import numpy as np +import helper_code as helper_code +from .lib import EEG_functions + + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +EEG_FEATURE_NAMES = [ + 'EEG_Channel_Count', + 'EEG_Rel_Delta', + 'EEG_Rel_Theta', + 'EEG_Rel_Alpha', + 'EEG_Rel_Sigma', + 'EEG_Rel_Beta', + 'EEG_Theta_Alpha_Ratio', + 'EEG_Hjorth_Complexity', +] +EEG_FEATURE_LENGTH = len(EEG_FEATURE_NAMES) + + +def _normalize_label(text): + normalized = ''.join(ch if ch.isalnum() else ' ' for ch in str(text).lower()) + return ' '.join(normalized.split()) + + +def _split_aliases(raw_aliases): + return {_normalize_label(alias) for alias in str(raw_aliases).split(';') if alias} + + +def _build_eeg_aliases(channels): + eeg_rows = channels[channels['Category'].eq('eeg')] + aliases = set() + for _, row in eeg_rows.iterrows(): + aliases.update(_split_aliases(row['Channel_Names'])) + return aliases + + +def _resample_signal(signal, fs, target_fs): + signal = np.asarray(signal, dtype=float) + if signal.size == 0: + return signal, target_fs + if fs == target_fs: + return signal, target_fs + + duration = signal.size / fs + target_samples = max(1, int(round(duration * target_fs))) + time_original = np.linspace(0, duration, signal.size) + time_target = np.linspace(0, duration, target_samples) + return np.interp(time_target, time_original, signal), target_fs + + +def _extract_channel_metrics(signal, fs): + signal = np.nan_to_num(np.asarray(signal, dtype=float), nan=0.0, posinf=0.0, neginf=0.0) + if signal.size < max(int(fs * 30), 2): + return None + + if fs != 200: + signal, fs = _resample_signal(signal, fs, 200) + + filtered = EEG_functions.butter_bandpass_filter(signal, lowcut=0.3, highcut=35, fs=fs, order=4) + signal_std = np.std(filtered) + if signal_std == 0 or not np.isfinite(signal_std): + return None + + normalized = (filtered - np.mean(filtered)) / signal_std + epochs = EEG_functions.create_epochs(normalized, fs, 
epoch_duration=30) + if epochs.size == 0: + return None + + band_powers, complexities = EEG_functions.extract_band_powers(epochs, fs, win_len=15) + if len(band_powers) > 60: + band_powers = band_powers.iloc[60:] + complexities = complexities.iloc[60:] + if band_powers.empty: + return None + + total_power = band_powers.sum(axis=1).replace(0, np.nan) + relative_powers = band_powers.div(total_power, axis=0).replace([np.inf, -np.inf], np.nan).fillna(0.0).mean() + alpha_power = float(relative_powers.get('Alpha', 0.0)) + theta_power = float(relative_powers.get('Theta', 0.0)) + theta_alpha_ratio = theta_power / alpha_power if alpha_power > 0 else 0.0 + + complexity_mean = float( + complexities['Hjorth_Complexity'].replace([np.inf, -np.inf], np.nan).fillna(0.0).mean() + ) if 'Hjorth_Complexity' in complexities else 0.0 + + return np.array([ + float(relative_powers.get('Delta', 0.0)), + theta_power, + alpha_power, + float(relative_powers.get('Sigma', 0.0)), + float(relative_powers.get('Beta', 0.0)), + float(theta_alpha_ratio), + complexity_mean, + ], dtype=np.float32) + + +def processEEG(physiological_data, physiological_fs, csv_path): + channels = pd.read_csv(csv_path) + eeg_aliases = _build_eeg_aliases(channels) + channel_metrics = [] + + for label, signal in physiological_data.items(): + if label not in physiological_fs: + continue + if _normalize_label(label) not in eeg_aliases: + continue + + metrics = _extract_channel_metrics(signal, physiological_fs[label]) + if metrics is not None: + channel_metrics.append(metrics) + + if not channel_metrics: + return np.zeros(EEG_FEATURE_LENGTH, dtype=np.float32) + + stacked = np.vstack(channel_metrics) + aggregated = np.mean(stacked, axis=0) + return np.hstack([np.array([len(channel_metrics)], dtype=np.float32), aggregated]).astype(np.float32) \ No newline at end of file diff --git a/src/lib/ECG_processing.py b/src/lib/ECG_processing.py new file mode 100644 index 0000000..6206e23 --- /dev/null +++ b/src/lib/ECG_processing.py @@ 
-0,0 +1,156 @@
+import numpy as np
+import pandas as pd
+from scipy.signal import butter, filtfilt, resample
+
+# Sibling helpers live directly in src/lib/, so they are imported from the
+# current package ("."), not a nested ".lib" package (which does not exist).
+from .pan_tompkins import pan_tompkin
+from .compute_hrv_hrf import compute_HRV_HRF
+from .interpolate_NN import interpolate_NN_pchip
+from .remove_ectopic_beat import remove_ectopic_beats
+
+
+def ECGprocessing(ecg_signal, fs, patient_id):
+
+    all_results = pd.DataFrame()
+
+    ecg_signal = ecg_signal - np.mean(ecg_signal)
+
+    # ===============================
+    # RESAMPLE TO 200 Hz IF NEEDED
+    # ===============================
+    target_fs = 200
+
+    if fs != target_fs:
+        num_samples = int(len(ecg_signal) * target_fs / fs)
+        ecg_signal = resample(ecg_signal, num_samples)
+        fs = target_fs
+
+    # ===============================
+    # SEGMENT INTO 5-MIN WINDOWS
+    # ===============================
+    win_sec = 300
+    win_samples = int(win_sec * fs)
+
+    N = len(ecg_signal)
+    n_windows = N // win_samples
+
+    if n_windows == 0:
+        print("Signal too short.")
+        return None
+
+    # ===============================
+    # FIND VALID WINDOWS
+    # ===============================
+    valid_windows = []
+
+    for w in range(n_windows):
+
+        idx_start = w * win_samples
+        idx_end = (w + 1) * win_samples
+
+        ecg_win = ecg_signal[idx_start:idx_end]
+
+        # Quality check
+        if np.sum(np.isnan(ecg_win)) != 0 or np.sum(ecg_win == 0) > 0.2 * len(ecg_win):
+            continue
+
+        valid_windows.append(w)
+
+    # ===============================
+    # PROCESS WINDOWS
+    # ===============================
+    HRV_all = []
+
+    for w in valid_windows:
+
+        idx_start = w * win_samples
+        idx_end = (w + 1) * win_samples
+
+        ecg_win = ecg_signal[idx_start:idx_end]
+
+        ecg_win = ecg_win - np.mean(ecg_win)
+
+        # --- Filtering ---
+        # Notch
+        b, a = butter(3, [59.5/(fs/2), 60.5/(fs/2)], btype='bandstop')
+        ecg_win = filtfilt(b, a, ecg_win)
+
+        # High-pass
+        b, a = butter(3, 0.5/(fs/2), btype='high')
+        ecg_win = filtfilt(b, a, ecg_win)
+
+        # Low-pass
+        b, a = butter(3, 45/(fs/2), btype='low')
+        ecg_win = filtfilt(b, a, 
ecg_win) + + # =============================== + # QRS DETECTION + # =============================== + qrs_amp_raw, R_locs, delay = pan_tompkin(ecg_win, fs, 0) + + if len(R_locs) < 150: + continue + + # =============================== + # NN INTERVALS + # =============================== + NN = np.diff(R_locs) / fs + + # =============================== + # HRV PREPROCESSING + # =============================== + NN, ectopic_perc = remove_ectopic_beats(NN, 40, 0.10) + + NN = interpolate_NN_pchip(NN, 2) + + valid_ratio = np.sum(~np.isnan(NN)) / len(NN) + NN = NN[~np.isnan(NN)] + + if valid_ratio < 0.75: + continue + + # =============================== + # HRV + HRF METRICS + # =============================== + res = compute_HRV_HRF(NN, fs) + + meanNN = np.mean(NN) + + HRV_all.append([ + meanNN, + res["PIP"], res["PNNLS"], res["PNNSS"], + res["AVNN"], res["SDNN"], res["RMSSD"], res["HF"], + ectopic_perc + ]) + + # =============================== + # SUBJECT-LEVEL METRICS + # =============================== + if len(HRV_all) == 0: + print("No valid windows.") + return None + + HRV_all = np.array(HRV_all) + + median_vals = np.nanmedian(HRV_all, axis=0) + std_vals = np.nanstd(HRV_all, axis=0) + + # =============================== + # SAVE RESULTS (DataFrame row) + # =============================== + row = pd.DataFrame([{ + "ID": patient_id, + "mNNmed": median_vals[0], "mNNstd": std_vals[0], + "PIP_med": median_vals[1], "PIP_std": std_vals[1], + "PNNLS_med": median_vals[2], "PNNLS_std": std_vals[2], + "PNNSS_med": median_vals[3], "PNNSS_std": std_vals[3], + "AVNN_med": median_vals[4], "AVNN_std": std_vals[4], + "SDNN_med": median_vals[5], "SDNN_std": std_vals[5], + "RMSSD_med": median_vals[6], "RMSSD_std": std_vals[6], + "HF_med": median_vals[7], "HF_std": std_vals[7], + "ECTOPIC_med": median_vals[8], "ECTOPIC_std": std_vals[8], + }]) + + all_results = pd.concat([all_results, row], ignore_index=True) + + return all_results \ No newline at end of file diff --git 
a/src/lib/EEG_functions.py b/src/lib/EEG_functions.py new file mode 100644 index 0000000..8c7137c --- /dev/null +++ b/src/lib/EEG_functions.py @@ -0,0 +1,339 @@ +from scipy.signal import butter, filtfilt +import numpy as np +from scipy.signal import welch +import pandas as pd +from scipy import signal +from scipy.stats import kurtosis, entropy + +# try: +# import plotly.graph_objects as go +# from plotly.subplots import make_subplots +# except ModuleNotFoundError: +# go = None +# make_subplots = None + +# try: +# import matplotlib.pyplot as plt +# except ModuleNotFoundError: +# plt = None + +def _safe_sqrt_variance_ratio(numerator_signal, denominator_signal): + numerator_var = np.var(numerator_signal) + denominator_var = np.var(denominator_signal) + if denominator_var <= 0 or not np.isfinite(denominator_var): + return 0.0 + ratio = numerator_var / denominator_var + if ratio <= 0 or not np.isfinite(ratio): + return 0.0 + return float(np.sqrt(ratio)) + +def butter_bandpass_filter(data, lowcut, highcut, fs, order=4): + nyq = 0.5 * fs # Frecuencia de Nyquist + low = lowcut / nyq + high = highcut / nyq + b, a = butter(order, [low, high], btype='bandpass') + # Usamos filtfilt para que no haya desfase en la señal + # w, h = signal.freqz(b, a, worN=8000) + # frequencies = (w * fs) / (2 * np.pi) + # plt.figure(figsize=(10, 5)) + # plt.plot(frequencies, 20 * np.log10(abs(h))) + # plt.xlim(0, highcut + 20) + # plt.ylim(-40, 5) # Para ver bien la caída + # plt.title('Respuesta Frecuencial Digital (Bandpass)') + # plt.xlabel('Frecuencia [Hz]') + # plt.ylabel('Amplitud [dB]') + # plt.grid(which='both', axis='both') + # plt.axvline(lowcut, color='red', linestyle='--', label='Lowcut') + # plt.axvline(highcut, color='red', linestyle='--', label='Highcut') + # plt.legend() + # plt.show() + y = filtfilt(b, a, data) + return y + +# def plot_EEG(df, columns, fs = 200): +# if go is None or make_subplots is None: +# raise ModuleNotFoundError("plotly is required for plot_EEG") + +# fig = 
make_subplots(rows=len(columns), cols=1, +# shared_xaxes=True, +# vertical_spacing=0.02, +# subplot_titles=columns) +# limit = int(3000 * fs) +# x = np.arange(df[0].shape[0]) / fs # Asumiendo fs=100Hz, ajusta si es diferente +# downsample = 10 # Factor de downsampling para mejorar rendimiento (ajusta según necesidad) +# for i, col in enumerate(columns): +# fig.add_trace( +# go.Scattergl(x=x[:limit:downsample], y=df[i][:limit:downsample], name=col, mode='lines'), +# row=i+1, col=1 +# ) +# fig.update_layout( +# height=900, +# title_text="Polisomnografía - Canales EEG", +# showlegend=False, +# template="plotly_white" +# ) +# fig.update_xaxes(title_text="Tiempo (segundos)", row=len(columns), col=1) +# fig.show() + +# def plot_EEG_sel(sel, name = "EEG_plot_raw.html"): +# if go is None or make_subplots is None: +# raise ModuleNotFoundError("plotly is required for plot_EEG_sel") + +# fig = make_subplots(rows=len(sel), cols=1, +# shared_xaxes=True, +# vertical_spacing=0.02, +# subplot_titles=[ch[1].label for ch in sel]) + +# for i, (idx, sig) in enumerate(sel): +# # Crear eje de tiempo en segundos +# fs = sig.sampling_frequency +# time = np.linspace(0, len(sig.data) / fs, len(sig.data)) + +# # Añadir traza (solo mostramos los primeros 30s por defecto para no saturar el navegador) +# # Puedes quitar el slice [:int(30*fs)] para ver todo, pero cuidado con el rendimiento +# limit = int(3000 * fs) +# # limit = len(sig.data) if limit > len(sig.data) else limit +# # limit = len(sig.data) +# # downsample = 10 # Factor de downsampling para mejorar rendimiento (ajusta según necesidad) +# fig.add_trace( +# go.Scattergl(x=time[:limit], y=sig.data[:limit], name=sig.label, mode='lines'), +# row=i+1, col=1 +# ) + +# fig.update_layout( +# height=900, +# title_text="Polisomnografía - Canales EEG", +# showlegend=False, +# template="plotly_white" +# ) + +# fig.update_xaxes(title_text="Tiempo (segundos)", row=len(sel), col=1) +# fig.write_html(f"graphs/{name}.html") # Guardar como HTML para 
visualización interactiva +# # fig.show() + +def filtering_and_normalization(sig, sig_fs): + b, a = signal.butter(4, 0.3, btype='highpass', fs=sig_fs) + sig_filtered = signal.filtfilt(b, a, sig) + b, a = signal.butter(4, 35, btype='lowpass', fs=sig_fs) + sig_filtered = signal.filtfilt(b, a, sig_filtered) + sig_filtered = normalize(sig_filtered) + return sig_filtered + +def normalize(x): + return (x - np.mean(x)) / np.std(x) + +def remove_impulse_artifacts(sig): + # Square of second derivative + aux = np.diff(np.diff(sig)) ** 2 + aux = np.insert(aux, 0, aux[0]) + aux = np.append(aux, aux[-1]) + + # Median filter threshold + wind = 999 + if aux.size < wind: + wind = aux.size + if (wind % 2) != 1: + wind = wind - 1 + mf = signal.medfilt(aux, wind) + + # Find impulses + margin = 20 + impulses = np.asarray(np.where(aux > mf + 0.005)).ravel() + for impulse in impulses: + impulses = np.append(impulses, np.arange(impulse - margin, min(impulse + margin+1, sig.size))) + impulses = np.sort(impulses) + impulses = np.unique(impulses) + impulses = impulses[impulses >= 0] + + # Remove impulses + output = sig + output[impulses] = np.nan + return output + +def clean_movement_artifacts(data, fs, threshold_z=10, window_ms=500): + """ + Identifica y limpia artefactos de gran amplitud. + + Args: + data: Array de la señal. + fs: Frecuencia de muestreo. + threshold_z: Umbral de desviaciones estándar para marcar como artefacto. + window_ms: Tiempo alrededor del artefacto a limpiar para asegurar + que eliminamos la subida y bajada del pico. + """ + cleaned_data = data.copy() + + # 1. Calcular Z-Score de la amplitud + z_scores = np.abs((data - np.mean(data)) / np.std(data)) + + # 2. Encontrar índices que superan el umbral + mask = z_scores > threshold_z + + # 3. Expandir la máscara (el movimiento suele durar un poco más que el pico) + padding = int((window_ms / 1000) * fs) + expanded_mask = np.convolve(mask, np.ones(padding), mode='same') > 0 + + # 4. 
Reemplazar artefactos con el valor medio (0 si está centrada) + cleaned_data[expanded_mask] = 0 + + artifacts_percentage = (np.sum(expanded_mask) / len(data)) * 100 + print(f"Artefactos eliminados: {artifacts_percentage:.2f}% de la señal.") + + return cleaned_data + +def adaptive_variance_cleaner(data, fs, win_size_ms=500, alpha=0.1, threshold=3.5): + """ + Filtro adaptativo que detecta artefactos cuando la varianza local + excede significativamente la varianza histórica adaptativa. + + Args: + data: Array de la señal (1D). + fs: Frecuencia de muestreo. + win_size_ms: Tamaño de la ventana para calcular la varianza local. + alpha: Factor de adaptación (0 a 1). Cuanto más alto, más rápido olvida el pasado. + threshold: Multiplicador de la varianza adaptativa para marcar artefacto. + """ + win_samples = int((win_size_ms / 1000) * fs) + n_samples = len(data) + cleaned_data = np.copy(data) + + # Inicializamos la varianza adaptativa con la varianza de la primera ventana + first_win = data[:win_samples] + adaptive_var = np.var(first_win) + + # Para guardar dónde detectamos artefactos + artifact_mask = np.zeros(n_samples, dtype=bool) + + # Iteramos por ventanas + for i in range(0, n_samples - win_samples, win_samples): + current_win_idx = slice(i, i + win_samples) + current_var = np.var(data[current_win_idx]) + + # Si la varianza actual es mucho mayor que la adaptativa, es un artefacto + if current_var > threshold * adaptive_var: + artifact_mask[current_win_idx] = True + cleaned_data[current_win_idx] = 0 # O podrías interpolar + # No actualizamos la varianza adaptativa con un artefacto para no "contaminarla" + else: + # Actualización adaptativa (Exponential Moving Average) + adaptive_var = alpha * current_var + (1 - alpha) * adaptive_var + + return cleaned_data, artifact_mask + +def create_epochs(data, fs, epoch_duration=30): + samples_per_epoch = int(fs * epoch_duration) + num_epochs = len(data) // samples_per_epoch + + # Recortamos la señal para que sea divisible 
exactamente + data_trimmed = data[:num_epochs * samples_per_epoch] + + # Reshape: (Número de épocas, Puntos por época) + epochs = data_trimmed.reshape(num_epochs, samples_per_epoch) + return epochs + +def extract_band_powers(epochs, fs, win_len = 2): + features = [] + complexities = [] + # Definición de las bandas + bands = { + 'Delta': (0.5, 4), + 'Theta': (4, 8), + 'Alpha': (8, 12), + 'Sigma': (11, 16), + 'Beta': (12, 30) + } + + for epoch in epochs: + # Calcular PSD + freqs, psd = welch(epoch, fs, nperseg=fs*30) # Ventanas de 2 seg para buena resolución + # Plot de PSD para verificar que las bandas se ven bien (opcional) + # plt.semilogy(freqs, psd) + # plt.show() + epoch_features = {} + for band_name, (low, high) in bands.items(): + # Encontrar índices de frecuencia para la banda actual + idx_band = np.logical_and(freqs >= low, freqs <= high) + # Calcular la potencia media en esa banda + epoch_features[band_name] = np.mean(psd[idx_band]) + + features.append(epoch_features) + + diff = np.diff(epoch) + mobility = _safe_sqrt_variance_ratio(diff, epoch) + # 2. Complejidad de Hjorth: Qué tan similar es la señal a una onda senoidal + diff2 = np.diff(diff) + mobility_diff = _safe_sqrt_variance_ratio(diff2, diff) + complexity = mobility_diff / mobility if mobility > 0 else 0 + complexities.append({'Hjorth_Mobility': mobility, 'Hjorth_Complexity': complexity}) + + return pd.DataFrame(features), pd.DataFrame(complexities) + +def get_patient_profile(df_features): + # 1. Calcular Potencia Total por época + total_power = df_features.sum(axis=1) + avg_p = df_features.mean() + total_avg_p = avg_p.sum() + + # 2. Variabilidad (Refleja microdespertares y fragmentación) + # Coeficiente de Variación (CV = std/mean) para normalizar por amplitud + variability = df_features.std() / df_features.mean() + variability.index = ['CV_' + col for col in variability.index] + + # 3. 
Curtosis (Picos súbitos de actividad) + kurt = df_features.apply(kurtosis) + kurt.index = ['Kurt_' + col for col in kurt.index] + + # 4. Índices de potencia relativa específicos + rel_delta = avg_p['Delta'] / total_avg_p + + # 5. Ratios de enlentecimiento + tar = avg_p['Theta'] / avg_p['Alpha'] # Theta-Alpha Ratio + tbr = avg_p['Theta'] / avg_p['Beta'] # Theta-Beta Ratio + + # 6. Entropía Espectral (Complejidad del perfil de potencia promedio) + # Cuanto más baja, más "pobre" es la diversidad de frecuencias del cerebro + spec_entropy = entropy(df_features) + + # 2. Calcular Potencias Relativas (promedio de toda la noche) + rel_powers = df_features.div(total_power, axis=0).mean() + rel_powers.index = ['Rel_' + col for col in rel_powers.index] + + # Calculate main frecuencies of oscilation on each band (peak frequency) + # Esto puede ser un buen indicador de cambios en la arquitectura del sueño + # peak_freqs = {} + # for band in ['Delta', 'Theta', 'Alpha', 'Sigma', 'Beta']: + # freqs, psd = welch(df_features[band], fs=1/30, nperseg=25, noverlap = 25 // 2, nfft=1024) # fs=1/30 porque cada punto es un promedio de 30s + # idx_peak = np.argmax(psd) + # peak_freqs['PeakFreq_' + band] = freqs[idx_peak] + + # 3. 
Calcular Ratios Críticos + # Usamos la media de las potencias absolutas para el ratio global + avg_p = df_features.mean() + ratios = { + 'Ratio_Theta_Alpha': avg_p['Theta'] / avg_p['Alpha'], + 'Ratio_Slow_Fast': (avg_p['Delta'] + avg_p['Theta']) / (avg_p['Alpha'] + avg_p['Beta']), + 'Sigma_Stability': df_features['Sigma'].std() / df_features['Sigma'].mean(), + 'Spectral_Entropy_delta': spec_entropy[0], + 'Spectral_Entropy_theta': spec_entropy[1], + 'Spectral_Entropy_alpha': spec_entropy[2], + 'Spectral_Entropy_sigma': spec_entropy[3], + 'Spectral_Entropy_beta': spec_entropy[4], + 'Theta_Alpha_Ratio': tar, + 'Theta_Beta_Ratio': tbr, + 'Relative_Delta_Power': rel_delta, + 'kurtosis_Delta': kurt['Kurt_Delta'], + 'kurtosis_Theta': kurt['Kurt_Theta'], + 'kurtosis_Alpha': kurt['Kurt_Alpha'], + 'kurtosis_Sigma': kurt['Kurt_Sigma'], + 'kurtosis_Beta': kurt['Kurt_Beta'], + 'variability_Delta': variability['CV_Delta'], + 'variability_Theta': variability['CV_Theta'], + 'variability_Alpha': variability['CV_Alpha'], + 'variability_Sigma': variability['CV_Sigma'], + 'variability_Beta': variability['CV_Beta'], + + } + + # Combinar todo en una sola fila + profile = pd.concat([rel_powers, pd.Series(ratios)]) + return profile \ No newline at end of file diff --git a/src/lib/Resp_features.py b/src/lib/Resp_features.py new file mode 100644 index 0000000..60ff3bf --- /dev/null +++ b/src/lib/Resp_features.py @@ -0,0 +1,318 @@ +import pandas as pd +import numpy as np +from .peakedness import peakednessCost +from scipy.interpolate import interp1d +from scipy.stats import kruskal +from scipy.signal import resample, detrend +import scipy.fft as fft +from scipy.signal import butter, filtfilt + +try: + import plotly.graph_objs as go +except ModuleNotFoundError: + go = None + +try: + import matplotlib.pyplot as plt +except ModuleNotFoundError: + plt = None + +def plot_resp(Data, subjet = 1, DownPrinting = 2): + """ + Plot resp data using Plotly. 
+ """ + if go is None: + raise ModuleNotFoundError("plotly is required for plot_resp") + + if type(Data) == dict: + Data = pd.DataFrame(Data[str(subjet)]) + Data = Data.iloc[::DownPrinting, :] + end = -1 + elif type(Data) == type(pd.DataFrame()): + Data = Data[Data['Subjet'] == str(subjet)] + Data = Data.iloc[::DownPrinting, :] + end = -2 + + # Data.reset_index(drop=True, inplace=True) + print(len(Data.columns)) + fig = go.Figure() + for c in Data.columns[:end]: + fig.add_trace(go.Line(x=Data.Time, y=Data[c], name = c)) + fig.update_layout(title_text='EDA Data', title_x=0.5) + + fig.show() + +def peakedness_application(Data, stage, plotflag = False, subjet = 1): + # print("Compute BR") + fs = 25 + Setup = {} + Setup["K"] = 5 + Setup["DT"] = 5 + Setup["Ts"] = 60 #interval length of Welch periodograms (s) + Setup["Tm"] = 20 #interval length of subintervals for Welch periodograms (s) + # Setup["d"] = 0.1 #interval length of subintervals for Welch periodograms (s) + Setup["Omega_r"] = np.array([5, 25])/60 #respiratory rate range in Hz + Setup["plotflag"] = plotflag + Setup["Nfft"] = np.power(2,13) + tsBR = np.arange(0,Data.shape[0]/fs,1/fs) + + if tsBR.shape[0] != Data.shape[0]: + # print(f"tsBR.shape[0]: {tsBR.shape[0]}, Data.shape[0]: {Data.shape[0]}") + tsBR = np.arange(0,Data.shape[0]/fs,1/fs)[:Data.shape[0]] + + hat_Br, Sk_Br, t_aver, used = peakednessCost(Data, tsBR, fs, Setup, title = stage, storeGraph = False, subjet = subjet) + # print(f"hat_Br: {hat_Br}, Sk_Br: {Sk_Br}, bar_Br: {bar_Br}, t_aver_Br: {t_aver_Br}, f_Br: {f_Br}, used_Br: {used_Br}") + + # print(hat_Br) + return hat_Br, Sk_Br, t_aver, used + +def ODI_application(data, fs, plotflag=True, subjet=1): + """Detecta desaturaciones de más del 3 % en la señal de saturación de + oxígeno (SpO2) y devuelve estadísticas básicas de los eventos. 
+ + El índice de desaturación de oxígeno (ODI) se define como el número de + episodios en los que la saturación cae al menos un 3 % respecto a una + línea de base móvil, normalizado por hora de grabación. Aquí se calcula + una línea base mediante la mediana móvil de 60 segundos y se agrupan + los índices consecutivos que cumplen el criterio en eventos únicos. + + Args: + data (array-like): valores de SpO2 (0‑100). + fs (float): frecuencia de muestreo en Hz. + plotflag (bool): si True, dibuja la señal y marca los eventos. + subjet (int): identificador de sujeto (utilizado en títulos de gráficas). + + Returns: + tuple: + * odi_mean (float): número de desaturaciones normalizado por hora. + * odi_std (float): desviación estándar de las magnitudes de caída + entre eventos (en porcentaje). + """ + # convertir a serie para comodidad + sp = pd.Series(data) + if len(sp) == 0 or fs <= 0: + return 0.0, 0.0 + + # base móvil de 60 segundos (median para ser robusto). ventana en muestras + window = int(fs * 60) + if window < 1: + window = 1 + baseline = sp.rolling(window, min_periods=1, center=True).median() + + # diferencia de base menos señal; buscamos caídas >=3 + diff = baseline - sp + mask = diff >= 3 + + # juntar índices contiguos en eventos + events = [] # lista de (start_idx, end_idx) + in_event = False + for idx, flag in mask.items(): + if flag and not in_event: + start = idx + in_event = True + elif not flag and in_event: + end = prev_idx + events.append((start, end)) + in_event = False + prev_idx = idx + if in_event: + events.append((start, prev_idx)) + + num_events = len(events) + duration_hours = len(sp) / fs / 3600.0 + odi_mean = num_events / duration_hours if duration_hours > 0 else 0.0 + + # calcular magnitudes de caída en cada evento (tomando el valor más bajo) + magnitudes = [] + for start, end in events: + mag = diff.loc[start:end].max() + magnitudes.append(mag) + odi_deepness = np.mean(magnitudes) if magnitudes else 0.0 + + if plotflag: + if plt is None: + 
raise ModuleNotFoundError("matplotlib is required when plotflag=True") + times = np.arange(len(sp)) / fs / 60.0 # minutos + plt.figure(figsize=(10, 4)) + plt.plot(times, sp.values, label='SpO2') + plt.plot(times, baseline.values, label='Baseline (60s med)') + for (start, end) in events: + t0 = start / fs / 60.0 + t1 = end / fs / 60.0 + plt.axvspan(t0, t1, color='red', alpha=0.3) + plt.xlabel('Tiempo (min)') + plt.ylabel('SpO2 (%)') + plt.title(f'Sujeto {subjet} - ODI detectado: {odi_mean:.2f} eventos/h') + plt.legend() + plt.tight_layout() + plt.show() + + return odi_mean, odi_deepness + +# Butterworth low-pass filter +def lowpass_filter(signal, fs, cutoff=2.0, order=4): + nyq = 0.5 * fs + normal_cutoff = cutoff / nyq + b, a = butter(order, normal_cutoff, btype='low', analog=False) + return filtfilt(b, a, signal) + +def Metrics_per_segment(Data): + """ + Compute peakedness per segment. + """ + + # Results = pd.DataFrame(columns=['Subject', 'Stage', 'Peakedness', 'Slope', 'Intercept', 'Relative Peak', 'Bocanada', 'Contraction', "TidalVolume", "Complexity", "Mobility", "Activity"]) + Results = pd.DataFrame() + + for subjet in Data['Subjet'].unique(): + sel_sujeto = Data[Data['Subjet'] == subjet] + sel_sujeto_ref = sel_sujeto.iloc[:,:-2] + Sol_subject = [] + Sol_interSubject = [] + for secc in sel_sujeto_ref.columns: + if secc == 'Time': + continue + else: + section = sel_sujeto[secc].values + section = section[~np.isnan(section)] + + hat_Br, Sk_Br, t_aver, _ = peakedness_application(section, stage=secc, plotflag = False, subjet= subjet) + # print(f"Subjet: {subjet}, section: {secc} hat_Br: {hat_Br}, Sk_Br: {Sk_Br}") + + # Ajuste lineal + coef = np.polyfit(t_aver, hat_Br, 1) # Grado 1 = línea recta + pendiente, interseccion = coef + # print(f"Pendiente: {pendiente:.6f}, Intersección: {interseccion:.6f}") + + #Picudez relativa + rel_peak_list = [] + for ti in range(len(t_aver)): + f_max = np.argmax(Sk_Br[:,ti]) + rel_peak_list.append(np.sum(Sk_Br[f_max-1:f_max+1,ti]) / 
np.sum(Sk_Br[:,ti])) + real_peak = np.mean(rel_peak_list) + + # Derivada + diff = np.diff(section) + bocanada = max(np.percentile(diff, 90), np.abs(np.percentile(diff, 10))) + + Contraction = np.percentile(np.abs(diff), 10) + + #Tidal Volume + TidalVolume = max(np.percentile(section, 99), np.abs(np.percentile(section, 1))) + + # Calculate derivatives + dx = np.diff(section) + ddx = np.diff(dx) + + # Calculate variance and its derivatives + x_var = np.var(section) # = activity + dx_var = np.var(dx) + ddx_var = np.var(ddx) + + # Mobility and complexity + mobility = np.sqrt(dx_var / x_var) + complexity = np.sqrt(ddx_var / dx_var) / mobility + + + filtered_signal = lowpass_filter(section, 100, cutoff=2.0, order=4) + segment4Hz = resample(filtered_signal, int(filtered_signal.size/100*4)) # Resample to 4Hz + + fft_signal = fft.fft(detrend(segment4Hz), n=2**12) + power = np.abs(fft_signal)**2 + freqs = fft.fftfreq(2**12, d = 1/4) + max_freq_index = np.argmax(power) + max_freq = freqs[max_freq_index] + power_at_max_freq = power[max_freq_index-51:max_freq_index+51].sum() + power_ratio = power_at_max_freq / np.sum(power[:len(power)//2]) + + Sol = [subjet, secc[:secc.find('_')], np.mean(hat_Br), pendiente, interseccion, real_peak, bocanada, Contraction, TidalVolume, complexity, mobility, x_var, max_freq, power_ratio] + + Sol_subject.append(Sol) + + + Sol_subject = pd.DataFrame(Sol_subject, columns=['Subject', 'Stage', 'Peakedness', 'Slope', + 'Intercept', 'Relative Peak', 'Bocanada', + "Contraction","TidalVolume", "Complexity", + "Mobility", "Activity", "Max_freq", "Power_ratio"]) + + peakmean = Sol_subject['Peakedness'].mean() + peakmin = Sol_subject['Peakedness'].min() + peakmax = Sol_subject['Peakedness'].max() + + slopemean = Sol_subject['Slope'].mean() + slopemin = Sol_subject['Slope'].min() + slopemax = Sol_subject['Slope'].max() + + Rel_peak_mean = Sol_subject['Relative Peak'].mean() + + Bocanada_max = Sol_subject['Bocanada'].max() + Contraction_max = 
Sol_subject['Contraction'].max() + TidalVolume_max = Sol_subject['TidalVolume'].max() + + Rel_metrics = ['Subject', 'Stage',"Peakmean", "Peakmin", "Peakmax", "Slopemean", "Slopemin", "Slopemax", "Rel_peak_mean", "Bocanada_max", "Contraction_max", "TidalVolume_max"] + # Rel_metrics = ["Peakmean", "Peakmin", "Peakmax", "Slopemean", "Slopemin", "Slopemax", "Rel_peak_mean", "Bocanada_max", "Contraction_max", "TidalVolume_max"] + + Sol_interSubject_DF = pd.DataFrame(Sol_interSubject, columns=Rel_metrics) + Sol_interSubject_DF = pd.DataFrame(Sol_interSubject, columns=Rel_metrics[2:]) + for i in Sol_subject.index: + # Sol_interSubject_DF.at[i,Rel_metrics[0]] = Sol_subject.iloc[i,0] + # Sol_interSubject_DF.at[i,Rel_metrics[1]] = Sol_subject.at[i,'Stage'] + Sol_interSubject_DF.at[i,Rel_metrics[2]] = Sol_subject.at[i,'Peakedness']/peakmean + Sol_interSubject_DF.at[i,Rel_metrics[3]] = Sol_subject.at[i,'Peakedness']/peakmin + Sol_interSubject_DF.at[i,Rel_metrics[4]] = Sol_subject.at[i,'Peakedness']/peakmax + Sol_interSubject_DF.at[i,Rel_metrics[5]] = Sol_subject.at[i,'Slope']/slopemean + Sol_interSubject_DF.at[i,Rel_metrics[6]] = Sol_subject.at[i,'Slope']/slopemin + Sol_interSubject_DF.at[i,Rel_metrics[7]] = Sol_subject.at[i,'Slope']/slopemax + Sol_interSubject_DF.at[i,Rel_metrics[8]] = Sol_subject.at[i,'Relative Peak']/Rel_peak_mean + Sol_interSubject_DF.at[i,Rel_metrics[9]] = Sol_subject.at[i,'Bocanada']/Bocanada_max + Sol_interSubject_DF.at[i,Rel_metrics[10]] = Sol_subject.at[i,"Contraction"]/Contraction_max + Sol_interSubject_DF.at[i,Rel_metrics[11]] = Sol_subject.at[i,"TidalVolume"]/TidalVolume_max + + Sol = pd.concat([Sol_subject, Sol_interSubject_DF], axis=1) + Results = pd.concat([Results, Sol], ignore_index=True) + + + + return Results + +def Significance_tests(RespData): + """ + Compute significance tests for the features. 
+ """ + results = {} + for metrica in RespData.columns[2:]: + # print(f"Realizando prueba de Kruskal-Wallis para la métrica: {metrica}") + # Realizar la prueba de Kruskal-Wallis + estadistico, p_valor = kruskal( + np.array(RespData[RespData.Stage == "Baseline"][metrica].reset_index(drop=True)), + np.array(RespData[RespData.Stage == "LOW"][metrica].reset_index(drop=True)), + np.array(RespData[RespData.Stage == "HIGH"][metrica].reset_index(drop=True)), + np.array(RespData[RespData.Stage == "REST"][metrica].reset_index(drop=True)) + ) + + # Imprimir resultados + + # print(f"Estadístico de Kruskal-Wallis: {estadistico}") + print(f"Metrica: "+metrica+" tiene un valor p: {p_valor}") + results[metrica] = p_valor + # if p_valor < 0.05: + # print("Se rechaza la hipótesis nula: hay diferencias significativas entre los grupos.") + # else: + # print("No se rechaza la hipótesis nula: no hay diferencias significativas entre los grupos.") + + results = pd.DataFrame.from_dict(results, orient='index', columns=['p_value']) + results = results.reset_index() + results.to_excel('./Graphs/kruskal_results.xlsx', index=False) + + if plt is None: + raise ModuleNotFoundError("matplotlib is required for Significance_tests plotting") + + plt.plot(results['index'], results['p_value']) + plt.axhline(y=0.05, color='r', linestyle='--') + plt.xlabel('Métrica') + plt.ylabel('Valor p') + plt.title('Resultados de la prueba de Kruskal-Wallis') + plt.xticks(rotation=90) + plt.tight_layout() + plt.savefig('./Graphs/kruskal_results.png') + plt.show() \ No newline at end of file diff --git a/src/lib/__init__.py b/src/lib/__init__.py new file mode 100644 index 0000000..665f43f --- /dev/null +++ b/src/lib/__init__.py @@ -0,0 +1 @@ +"""Signal-processing helper package.""" \ No newline at end of file diff --git a/src/lib/compute_hrv_hrf.py b/src/lib/compute_hrv_hrf.py new file mode 100644 index 0000000..15b0422 --- /dev/null +++ b/src/lib/compute_hrv_hrf.py @@ -0,0 +1,155 @@ +import numpy as np +from 
scipy.signal import lombscargle + +def compute_HRV_HRF(NN, SF): + """ + NN: array of NN intervals in seconds + SF: sampling frequency (Hz) + """ + + NN = np.asarray(NN).flatten() + + # =============================== + # ΔNN + # =============================== + dNN = np.diff(NN) + + n = 1 + thr = n / SF + + # Classification + acc = dNN <= -thr + dec = dNN >= thr + noch = (dNN > -thr) & (dNN < thr) + + # Sign representation + sign_dNN = np.zeros_like(dNN) + sign_dNN[acc] = -1 + sign_dNN[dec] = 1 + + N = len(dNN) + + # =============================== + # PIP (Inflection Points) + # =============================== + inflection = 0 + + for i in range(N - 1): + if (dNN[i+1] * dNN[i] <= 0) and (dNN[i+1] != dNN[i]): + inflection += 1 + + PIP = (inflection / (N - 1)) * 100 if N > 1 else np.nan + + # =============================== + # Segment Detection + # =============================== + segments = [] + if N > 0: + current_seg = sign_dNN[0] + length_seg = 1 + + for i in range(1, N): + if sign_dNN[i] == current_seg and sign_dNN[i] != 0: + length_seg += 1 + else: + if current_seg != 0: + segments.append(length_seg) + current_seg = sign_dNN[i] + length_seg = 1 + + # Add last segment + if current_seg != 0: + segments.append(length_seg) + + segments = np.array(segments) + + # =============================== + # PNNLS & PNNSS + # =============================== + if len(segments) > 0: + long_segments = segments[segments >= 3] + short_segments = segments[segments < 3] + + PNNLS = np.sum(long_segments) / N * 100 + PNNSS = np.sum(short_segments) / np.sum(segments) * 100 + else: + PNNLS = np.nan + PNNSS = np.nan + + # =============================== + # Time-domain HRV + # =============================== + win_length = 300 # seconds + + time = np.cumsum(NN) + + AVNN_all = [] + SDNN_all = [] + RMSSD_all = [] + + i = 0 + while i < len(NN): + t_start = time[i] + t_end = t_start + win_length + + idx = np.where((time >= t_start) & (time < t_end))[0] + + if len(idx) >= 150: + NN_win = 
NN[idx] + + AVNN_all.append(np.nanmean(NN_win)) + SDNN_all.append(np.nanstd(NN_win, ddof=1)) + + diffNN = np.diff(NN_win) + RMSSD_all.append(np.sqrt(np.nanmean(diffNN**2))) + + next_i = np.where(time >= t_end)[0] + if len(next_i) == 0: + break + i = next_i[0] + + AVNN = np.nanmean(AVNN_all) if len(AVNN_all) > 0 else np.nan + SDNN = np.nanmean(SDNN_all) if len(SDNN_all) > 0 else np.nan + RMSSD = np.nanmean(RMSSD_all) if len(RMSSD_all) > 0 else np.nan + + # =============================== + # Frequency-domain (HF) + # =============================== + HF_all = [] + + for _ in range(len(AVNN_all)): + # NOTE: simplified like MATLAB version + NN_win = NN.copy() + t_win = np.cumsum(NN_win) + + # Convert to angular frequency + f = np.linspace(0.01, 0.5, 1000) + angular_f = 2 * np.pi * f + + # Remove mean (important for Lomb) + NN_detrended = NN_win - np.mean(NN_win) + + Pxx = lombscargle(t_win, NN_detrended, angular_f, normalize=True) + + HF_band = (f >= 0.15) & (f <= 0.4) + + HF_power = np.trapezoid(Pxx[HF_band], f[HF_band]) + + HF_all.append(HF_power) + + HF = np.nanmean(HF_all) if len(HF_all) > 0 else np.nan + + # =============================== + # OUTPUT + # =============================== + results = { + "PIP": PIP, + "PNNLS": PNNLS, + "PNNSS": PNNSS, + "AVNN": AVNN, + "SDNN": SDNN, + "RMSSD": RMSSD, + "HF": HF + } + + return results \ No newline at end of file diff --git a/src/lib/interpolate_NN.py b/src/lib/interpolate_NN.py new file mode 100644 index 0000000..987019f --- /dev/null +++ b/src/lib/interpolate_NN.py @@ -0,0 +1,40 @@ +import numpy as np +from scipy.interpolate import PchipInterpolator + +def interpolate_NN_pchip(NN, maxGap): + """ + NN: array of NN intervals (seconds) + maxGap: max number of consecutive NaNs allowed for interpolation + """ + + NN = np.asarray(NN).flatten() + NN_interp = NN.copy() + + nan_idx = np.isnan(NN) + + # Find NaN segments + d = np.diff(np.concatenate(([0], nan_idx.astype(int), [0]))) + start_idx = np.where(d == 1)[0] + end_idx 
= np.where(d == -1)[0] - 1 + + for k in range(len(start_idx)): + seg_len = end_idx[k] - start_idx[k] + 1 + + if seg_len <= maxGap: + left = start_idx[k] - 1 + right = end_idx[k] + 1 + + # Check bounds + if (left >= 0 and right < len(NN) and + not np.isnan(NN[left]) and not np.isnan(NN[right])): + + x = np.array([left, right]) + y = np.array([NN[left], NN[right]]) + + xi = np.arange(start_idx[k], end_idx[k] + 1) + + # PCHIP interpolation + interpolator = PchipInterpolator(x, y) + NN_interp[xi] = interpolator(xi) + + return NN_interp \ No newline at end of file diff --git a/src/lib/pan_tompkins.py b/src/lib/pan_tompkins.py new file mode 100644 index 0000000..3f37214 --- /dev/null +++ b/src/lib/pan_tompkins.py @@ -0,0 +1,184 @@ +import numpy as np +from scipy.signal import butter, filtfilt, find_peaks + + +def pan_tompkin(ecg, fs, gr=0): + + ecg = np.asarray(ecg).flatten() + delay = 0 + + skip = 0 + m_selected_RR = 0 + mean_RR = 0 + ser_back = 0 + + # ===================== FILTERING ===================== # + ecg = ecg - np.mean(ecg) + + if fs == 200: + # Low-pass + b, a = butter(3, 12*2/fs, btype='low') + ecg_l = filtfilt(b, a, ecg) + ecg_l = ecg_l / np.max(np.abs(ecg_l)) + + # High-pass + b, a = butter(3, 5*2/fs, btype='high') + ecg_h = filtfilt(b, a, ecg_l) + ecg_h = ecg_h / np.max(np.abs(ecg_h)) + else: + b, a = butter(3, [5*2/fs, 15*2/fs], btype='band') + ecg_h = filtfilt(b, a, ecg) + ecg_h = ecg_h / np.max(np.abs(ecg_h)) + + # ===================== DERIVATIVE ===================== # + if fs != 200: + int_c = int((5 - 1) / (fs * (1/40))) + base = np.array([1, 2, 0, -2, -1]) * (1/8) * fs + x_old = np.linspace(1, 5, 5) + x_new = np.linspace(1, 5, int_c) + b = np.interp(x_new, x_old, base) + else: + b = np.array([1, 2, 0, -2, -1]) * (1/8) * fs + + ecg_d = filtfilt(b, [1], ecg_h) + ecg_d = ecg_d / np.max(np.abs(ecg_d)) + + # ===================== SQUARING ===================== # + ecg_s = ecg_d ** 2 + + # ===================== MOVING WINDOW ===================== # + 
win = int(round(0.150 * fs)) + ecg_m = np.convolve(ecg_s, np.ones(win)/win, mode='same') + delay += win // 2 + + # ===================== PEAK DETECTION ===================== # + locs, _ = find_peaks(ecg_m, distance=int(0.2 * fs)) + pks = ecg_m[locs] + + LLp = len(pks) + + qrs_i = [] + qrs_c = [] + qrs_i_raw = [] + qrs_amp_raw = [] + + nois_i = [] + nois_c = [] + + # Threshold initialization + THR_SIG = np.max(ecg_m[:2*fs]) / 3 + THR_NOISE = np.mean(ecg_m[:2*fs]) / 2 + SIG_LEV = THR_SIG + NOISE_LEV = THR_NOISE + + THR_SIG1 = np.max(ecg_h[:2*fs]) / 3 + THR_NOISE1 = np.mean(ecg_h[:2*fs]) / 2 + SIG_LEV1 = THR_SIG1 + NOISE_LEV1 = THR_NOISE1 + + Beat_C = 0 + Beat_C1 = 0 + + for i in range(LLp): + + loc = locs[i] + + # Find peak in filtered signal + left = max(0, loc - int(0.150 * fs)) + right = loc + + if right < len(ecg_h): + segment = ecg_h[left:right+1] + if len(segment) > 0: + y_i = np.max(segment) + x_i = np.argmax(segment) + else: + continue + else: + continue + + # RR interval update + if len(qrs_i) >= 9: + diffRR = np.diff(qrs_i[-8:]) + mean_RR = np.mean(diffRR) + comp = qrs_i[-1] - qrs_i[-2] + + if comp <= 0.92 * mean_RR or comp >= 1.16 * mean_RR: + THR_SIG *= 0.5 + THR_SIG1 *= 0.5 + else: + m_selected_RR = mean_RR + + test_m = m_selected_RR if m_selected_RR else mean_RR + + # ===================== SEARCH BACK ===================== # + if test_m and len(qrs_i) > 0: + if (loc - qrs_i[-1]) >= int(1.66 * test_m): + + sb_left = qrs_i[-1] + int(0.2 * fs) + sb_right = loc - int(0.2 * fs) + + if sb_right > sb_left: + segment = ecg_m[sb_left:sb_right] + if len(segment) > 0: + pks_temp = np.max(segment) + locs_temp = sb_left + np.argmax(segment) + + if pks_temp > THR_NOISE: + qrs_c.append(pks_temp) + qrs_i.append(locs_temp) + + seg = ecg_h[max(0, locs_temp-int(0.150*fs)):locs_temp] + if len(seg) > 0: + y_i_t = np.max(seg) + x_i_t = np.argmax(seg) + + if y_i_t > THR_NOISE1: + qrs_i_raw.append(locs_temp - int(0.150*fs) + x_i_t) + qrs_amp_raw.append(y_i_t) + SIG_LEV1 = 
0.25*y_i_t + 0.75*SIG_LEV1 + + SIG_LEV = 0.25*pks_temp + 0.75*SIG_LEV + + # ===================== CLASSIFICATION ===================== # + if pks[i] >= THR_SIG: + + # T-wave rejection + if len(qrs_i) >= 3: + if (loc - qrs_i[-1]) <= int(0.36 * fs): + + slope1 = np.mean(np.diff(ecg_m[max(0, loc-int(0.075*fs)):loc])) + slope2 = np.mean(np.diff(ecg_m[max(0, qrs_i[-1]-int(0.075*fs)):qrs_i[-1]])) + + if abs(slope1) <= 0.5 * abs(slope2): + NOISE_LEV1 = 0.125*y_i + 0.875*NOISE_LEV1 + NOISE_LEV = 0.125*pks[i] + 0.875*NOISE_LEV + continue + + # Accept QRS + qrs_c.append(pks[i]) + qrs_i.append(loc) + + if y_i >= THR_SIG1: + qrs_i_raw.append(loc - int(0.150*fs) + x_i) + qrs_amp_raw.append(y_i) + SIG_LEV1 = 0.125*y_i + 0.875*SIG_LEV1 + + SIG_LEV = 0.125*pks[i] + 0.875*SIG_LEV + + elif THR_NOISE <= pks[i] < THR_SIG: + NOISE_LEV1 = 0.125*y_i + 0.875*NOISE_LEV1 + NOISE_LEV = 0.125*pks[i] + 0.875*NOISE_LEV + + else: + NOISE_LEV1 = 0.125*y_i + 0.875*NOISE_LEV1 + NOISE_LEV = 0.125*pks[i] + 0.875*NOISE_LEV + + # Update thresholds + THR_SIG = NOISE_LEV + 0.25 * abs(SIG_LEV - NOISE_LEV) + THR_NOISE = 0.5 * THR_SIG + + THR_SIG1 = NOISE_LEV1 + 0.25 * abs(SIG_LEV1 - NOISE_LEV1) + THR_NOISE1 = 0.5 * THR_SIG1 + + return np.array(qrs_amp_raw), np.array(qrs_i_raw), delay \ No newline at end of file diff --git a/src/lib/peakedness.py b/src/lib/peakedness.py new file mode 100644 index 0000000..eaf224a --- /dev/null +++ b/src/lib/peakedness.py @@ -0,0 +1,652 @@ +import pandas as pd +import numpy as np +from numpy.fft import fftshift +from numpy.fft import fft +from scipy.signal import detrend, find_peaks +from time import time +import os + +def _safe_ratio(numerator, denominator, default=0.0): + if denominator is None or not np.isfinite(denominator) or denominator == 0: + return default + value = numerator / denominator + if np.isfinite(value): + return value + return default + +def setParamFr(Setup): + if 'DT' not in Setup.keys(): + Setup["DT"] = 5 + DT = 5 + else: + DT = Setup["DT"] + + if 'Ts' 
not in Setup.keys(): + Setup["Ts"] = 42 + Ts = 42 + else: + Ts = Setup["Ts"] + + if 'Tm' not in Setup.keys(): + Setup["Tm"] = 12 + Tm = 12 + else: + Tm = Setup["Tm"] + + if 'Nfft' not in Setup.keys(): + Setup["Nfft"] = np.power(2,12) + Nfft = np.power(2,12) + else: + Nfft = Setup["Nfft"] + + if 'K' not in Setup.keys(): + Setup["K"] = 5 + K = 5 + else: + K = Setup["K"] + + if 'Omega_r' not in Setup.keys(): + Setup["Omega_r"] = np.array([0.04, 1]) + Omega_r = np.array([0.04, 1]) + else: + Omega_r = Setup["Omega_r"] + + if 'ksi_p' not in Setup.keys(): + Setup["ksi_p"] = 45 + ksi_p = 45 + else: + ksi_p = Setup["ksi_p"] + + if 'N_k' not in Setup.keys(): + Setup["N_k"] = 4 + N_k = 4 + else: + N_k = Setup["N_k"] + + if 'ksi_a' not in Setup.keys(): + Setup["ksi_a"] = 85 + ksi_a = 85 + else: + ksi_a = Setup["ksi_a"] + + if 'd' not in Setup.keys(): + Setup["d"] = 0.125 + d = 0.125 + else: + d = Setup["d"] + + if 'b' not in Setup.keys(): + Setup["b"] = 0.8 + b = 0.8 + else: + b = Setup["b"] + + if 'a' not in Setup.keys(): + Setup["a"] = 0.5 + a = 0.5 + else: + a = Setup["a"] + + if 'plotflag' not in Setup.keys(): + Setup["plotflag"] = False + plotflag =False + else: + plotflag = Setup["plotflag"] + + + return [ DT, Ts, Tm, Nfft, K, Omega_r, ksi_p, ksi_a, d, b, a, N_k, plotflag, Setup] + +def extract_interval( x, t, int_ini, int_end ): + # EXTRACT_INTERVAL Very simple function to extract an interval from a signal + # + # Created by Jesús Lázaro in 2011 + # -------- + # Sintax: [ x_int, t_int, indexes ] = extract_interval( x, t, int_ini, int_end ) + # In: x = signal + # t = time vector + # int_ini = interval begin time (same units as 't') + # int_end = interval end time (same units as 't') + # + # Out: x_int = interval [int_ini, int_end] of 'x' + # t_int = interval [int_ini, int_end] of 't' + # indexes = indexes corresponding to returned time interval + + x_int = x[(t>=int_ini) & (t <=int_end)] + t_int = t[(t>=int_ini) & (t <=int_end)] + + return [ x_int, t_int ] + +def 
normalizar_PSD( PSD, f = None, rango = None): + # NORMALIZAR_PSD Normaliza una densidad espectral de potencia en el rango + # de frecuencias requerido. + # + # Created by Jesús Lázaro in 2011 + # ------- + # Sintax: [ PSD_norm, f_PSD_norm, factor_norm ] = normalizar_PSD( PSD, f, rango ) + # In: PSD = Densidad espectral de potencia + # f = Vector de frecuencias para PSD [Por defecto: frecuencias digitales] + # rango = Rango [f1, f2] en el que se aplicar� la normalizaci�n [Por defecto: Todo f] + # + # Out: PSD_norm = Densidad espectral de potencia notmalizada + # f_PSD_norm = Vector de frecuencias para PSD_norm + # factor_norm = Factor de normalizaci�n utilizado + + if f is None: + f = np.arange(0,PSD.shape[0]) / PSD.shape[0] - 1/2 + + if rango is None: + rango = [f[0], f[-1]] + + + # Seleccionar rango de inter�s: + f_PSD_norm = f[(f>=rango[0]) & (f<=rango[1])] + PSD = PSD[(f>=rango[0]) & (f<=rango[1])] + if not np.any(f_PSD_norm): # El vector de frecuencias no estaba ordenado + print('El vector de frecuencias debe estar ordenado de forma ascendente'); + + + # Calcular factor de normalizaci�n y normalizar: + ##ADD NAN removal from PSD + factor_norm = sum(PSD) + # print("factor_norm "+ str(factor_norm)) + if factor_norm == 0: + # print("stop") // IMPORTANT MODIFICATION TODO + PSD_norm = PSD + else: + PSD_norm = PSD/factor_norm + + return [ PSD_norm, f_PSD_norm, factor_norm ] + +def init_module(kk,vars,param, plotflag): + # function vars = init_module(kk,vars,param, plotflag) + # This function is used for initialization and reinitialization of bar_fr + Skl = vars["Skl"] + t_orig = vars["t_orig"] + t_aver = vars["t_aver"] + f = vars["f"] + L = vars["L"] + + DT = param["DT"] + K = param["K"] + ksi_p = param["ksi_p"] + d = param["d"] + + # Increment of number of spectra for averaging + if kk == 0: # INITIALIZATION + N = 4*np.floor(K/2) + else: # RE-INITIALIZATION + N = 2*np.floor(K/2) + + ###### Peakedness Analysis : + # Indexes of original spectra that take part in the 
average + O = np.bitwise_and(t_orig>=t_aver[kk]-N*DT, t_orig<=t_aver[kk]+N*DT) + W = np.arange(O.shape[0]) + O = W[O] + # W = np.ones([O.shape[0]]) + # O1 = W[O] + # Pre-allocate + Xkl = np.empty((O.shape[0], L)) + Xkl[:] = np.nan + for k in range(O.shape[0]): + for l in range(L): + S = Skl[:, O[k], l] + # print(S.shape) + # Use as reference for Pkl calculation the absolute maximum + i_m = S.argmax() + fr_max = f[i_m] + + # Define the Omega, Omega_p bands + Omega = np.bitwise_and(f>=fr_max-d, f<=fr_max+d) + + # Modified limits for initialization (reduces the risk for 0.1 Hz) + Omega_p = np.bitwise_and(f>=max(fr_max-0.4*d,0.15), f<=min(fr_max+0.4*d,0.8)) + + # Peakedness + # print(S[Omega]) + band_power = np.sum(S[Omega]) + peaky_power = np.sum(S[Omega_p]) + Pkl = 100*_safe_ratio(peaky_power, band_power) + + if Pkl >= ksi_p: + Xkl[k,l] = 1 + else: + Xkl[k,l] = 0 + + # Initialization for averaged spectrum (if cannot be defined) + if L>1: + averS = np.mean(np.squeeze(np.mean(Skl[:, O, :],1)),1) + else: + averS = np.mean(np.mean(Skl[:, O, :],1),1) + + + if kk == 0: #INITIALIZATION + if np.sum(Xkl[:]) > 0: # One or more spectra were peaked enough + + # Sum all peaky spectra + averS = np.zeros((f.shape[0], 1)) + for k in range(O.shape[0]): + for l in range(L): + if Xkl[k, l] == 1: + averS = averS + Skl[:, O[k], l] + + # Select the maximum in the spectrum + i_m = averS.argmax() + + # Save in vars + vars["bar_fr"][0] = f[i_m] + + else: # RE-INITIALIZATION + # One or more spectra were peaked enough + if np.sum(Xkl[:]) > 0: + # Sum all peaky spectra + averS = np.zeros(f.shape[0]) + for k in range(O.shape[0]): + for l in range(L): + if Xkl[k,l] == 1: + averS = averS + Skl[:, O[k], l] + + # Local maxima in the averaged spectrum + j_pk = find_peaks(averS) + j_pk = j_pk[0] + pk = averS[j_pk] + # Extra restriction : consider peaks with important power + # j_del = pk<0.5*np.max(averS) # IMPORTANTE TODO + j_del = pk<0.2*np.max(averS) + pk = pk[~j_del] + j_pk = j_pk[~j_del] + + # 
Cost function for deviation from previous fr and maximum power + max_s = np.max(S) + if np.isfinite(max_s) and max_s != 0: + C_a = 1 - (np.transpose(pk) / max_s) + else: + C_a = np.ones_like(pk, dtype=float) + fr_prev = vars["bar_fr"][np.max(kk,0)] + C_f = abs(f[j_pk[:]]-fr_prev)/(2*d) + # C_f = abs(f(i_pk(:))-fr_prev)/(Omega_r(2)-Omega(1)); + + C = C_a +C_f + if C.size > 0: + j_min = C.argmin() + fj = j_pk[j_min] + vars["bar_fr"][kk] = f[fj] + else: + vars["bar_fr"][kk] = 0 + # Save in vars + # vars["bar_fr"][kk] = f[fj] + + # if plotflag: + # if plt is None: + # raise ModuleNotFoundError("matplotlib is required when plotflag=True") + # plt.plot(f, averS) + # plt.plot(f[fj], averS[fj], '-') + # plt.title('Initialization - Averaged Spectrum') + # plt.show() + + return vars + # # No spectra fulfill the initialization + # if plotflag: + # keyboard + +def compute_Xkl( Skl, f, bar_fr, O, ksi_p, ksi_a, d): + # function [ Xkl ] = compute_Xkl( Skl, f, bar_fr, O, ksi_p, ksi_a, d) + # Created by Spyros Kontaxis in 2019 + # Computation of peakedness for a power spectrum + # Sintax: [ Xkl ] = compute_Xkl( Skl, f, bar_fr, O, ksi_p, ksi_a, d) + # Inputs: + # Skl : Welch TF maps in a 3D matrix (f x t x DR signals) + # f : frequency vector (Hz) + # bar_fr : smoothed estimate of the respiratory rate (Hz) + # O : Indexes of original spectra that take part in the average + # ksi_p : peakedness threshold based on power concentration (%) + # ksi_a : peakedness threshold based on absolute maximum (%) + # d : half bandwith of Omega centered around bar_fr (Hz) + # Outputs: + # Xkl : 1-> the o:th spectrum will be used in the average + # 0-> the o:th spectrum will not be used in the average + # + + # % Define two search window arround the estimated respiratory rate + Omega = np.bitwise_and(f>=bar_fr-d, f<=bar_fr+d) + Omega_p = np.bitwise_and(f>=bar_fr-0.4*d, f<=bar_fr+0.4*d) + + # % Get the ammount of signals + L = Skl.shape[2] + + # % Pre-allocate + Xkl = np.zeros((O.shape[0],L)) + + # % 
Loop over all segments + for k in range(O.shape[0]): + + # % Loop over all signals + for l in range(L): + # % Select the power spectrum of one segment + S = Skl[:, O[k], l] + + # % Define peakedness based on the power concentration + band_power = np.sum(S[Omega]) + peaky_power = np.sum(S[Omega_p]) + Pkl = 100*_safe_ratio(peaky_power, band_power) + + # % Define peakedness based on the absolute maximum + # print(max(S)) + max_s = np.max(S) + max_band = np.max(S[Omega]) if np.any(Omega) else 0.0 + Akl = 100*_safe_ratio(max_band, max_s) + # % If the spectrum is concidered peaky by both conditions, mark as + # % peaky + if np.bitwise_and(Pkl >= ksi_p, Akl >= ksi_a): + Xkl[k,l] = 1 + else: + Xkl[k,l] = 0 + + return Xkl + +def compute_fJmin( S, f, bar_fr, d): + # function [ fJmin ] = compute_fJmin( S, f, bar_fr, d) + # Created by Spyros Kontaxis in 2019 + # Spectral peak selection based on cost function + # Sintax: [ fJmin ] = compute_fJmin( S, f, bar_fr, d) + # Inputs: + # S : Averaged Spectrum + # f : frequency vector (Hz) + # bar_fr : smoothed estimate of the respiratory rate (Hz) + # d : half bandwith of Omega centered around bar_fr (Hz) + # Outputs: + # fJmin : respiratory rate estimate + # + + # Define the search window + Omega = np.bitwise_and(f >= bar_fr-d, f <= bar_fr+d) + + # Pre-allocate + fJmin = np.nan + + # Locate peaks in the search window + [peaks, properties] = find_peaks(S[Omega]) #,'SortStr','descend' + + # Put the location in the correct perspective + lm = peaks + (Omega[:] ==1).argmax() + + # Select the frequency that corresponds to the location + fJ = f[lm] + + # print(len(lm)) + if len(lm) > 0: + # Compute the cost function for deviation from previous fr and maximum power + C_f = abs(fJ-bar_fr)/(2*d) + C_a = 1-S[lm]/max(S[Omega]) + + # Select the minimum cost + C = C_f+C_a + Jmin = C.argmin() + + # Store the frequency with the minimum cost + fJmin = fJ[Jmin] + + return fJmin + +def peakednessCost(signals, ts, fs, Setup = {}, title = "", storeGraph = 
False, subjet =1): + + vars = {} + # Set parameters / Arrange inputs + [ DT, Ts, Tm, Nfft, K, Omega_r, ksi_p, ksi_a, d, b, a,N_k, plotflag , Setup] = setParamFr(Setup) + + # Start the time stamps at zero + ts1 = ts[0] + ts = ts-ts1 + if type(signals) == type(pd.DataFrame()): + signals = signals.to_numpy() + + # Get the number of signals + if len(signals.shape) == 1: + signals = np.reshape(signals, (signals.shape[0],1)) + if signals.shape[0]= Omega_r[0], f < Omega_r[1]) + vars["f"] = f[f_ind] + + # Time vector for original Welch periodograms + # vars["t_orig"] = np.arange(Ts/2, ts[-1]+DT,DT) - Ts/2+DT #Es posible que sea esto lo que quieren pero no es lo que sale de MATLAB + vars["t_orig"] = np.arange(Ts/2, ts[-1]- Ts/2+DT,DT) #Esto es lo que sale de MATLAB + + # Pre-allocate + vars["Skl"] = np.empty((vars["f"].shape[0], vars["t_orig"].shape[0], vars["L"])) + vars["Skl"][:] = np.nan + + t_for1 = time() + for ii in range(vars["L"]): + t_for_L = time() + + # Select signal + signal = signals[:, ii] + # signal = np.reshape(signal, (signal.shape[0],1)) + # Compute the Welch Periodgrams + for k, ki in zip(vars["t_orig"], range(vars["t_orig"].shape[0])): + # Begin of Ts seconds interval + # Ws_begin = vars["t_orig"][k] - Ts/2 + Ws_begin = k - Ts/2 + + # End of Ts seconds interval + Ws_end = Ws_begin + Ts + [int_Ts_sig, int_Ts_t] = extract_interval(signal, ts, Ws_begin, Ws_end); # Ts seconds interval + S = np.zeros((vars["f"].shape[0])) + if int_Ts_sig.shape[0] < (Tm*100)/2: + vars["Skl"][:, ki, ii] = np.zeros((vars["f"].shape[0])) + continue + # Number of Tm length subintervals + NWm = int(np.floor(2*Ts/Tm)) + I=0 + + for i_Tm in range(NWm): + S_i = [] + + # Begin of Tm seconds interval + Wm_begin = Ws_begin + (i_Tm)*Tm/2 + + # End of Tm seconds interval + Wm_end = min(Wm_begin + Tm, Ws_end) + + # Tm seconds interval + [int_Tm_sig, int_Tm_t] = extract_interval(int_Ts_sig, int_Ts_t, Wm_begin, Wm_end) + + # Estimate the spectrum only for intervals without NaNs + if 
~np.isnan((int_Ts_sig.astype(float))).any(): + S_i = abs(fftshift(fft(detrend(int_Tm_sig[:-1]), Nfft)))**2 + # S_i = abs(fftshift(fft(int_Tm_sig[:-1], Nfft)))**2 + S_i = S_i[f_ind] + [ S_i, f_PSD_norm, factor_norm ] = normalizar_PSD(S_i) + if ~np.isnan(S_i).any(): + S = S + (1/NWm)*S_i #TODO hacer una median real, que si uno falla la media se coja con los otros 3 dividido entre 3 + I=I+1 + + if I < 0.5*NWm : + vars["Skl"][:, ki, ii] = np.zeros((vars["f"].shape[0])) + else: + # Define the spectrum when enough subintervals were used + vars["Skl"][:, ki, ii] = S + + + + + ##### Peak-conditioned spectral average: ###### + # Pre-allocate + N = int(np.floor(K/2)) + vars["t_aver"] = vars["t_orig"][N:-N] + if vars["t_aver"].shape[0] == 0: + print("No hay tiempo para promediar") + empty_spectra = np.empty((vars["f"].shape[0], 0)) + empty_used = np.empty((0, vars["L"])) + return np.array([]), empty_spectra, np.array([]), empty_used + vars["Sk"] = np.empty((vars["f"].shape[0], vars["t_aver"].shape[0])) + vars["Sk"][:] = np.nan + vars["bar_fr"] = np.empty(( vars["t_aver"].shape[0])) + vars["bar_fr"][:] = np.nan + vars["hat_fr"] = np.empty((vars["t_aver"].shape[0])) + vars["hat_fr"][:] = np.nan + vars["Naveraged"] = np.zeros((vars["t_aver"].shape[0])) + vars["used"] = np.zeros((vars["t_aver"].shape[0],vars["L"])) + vars["times_used"] = np.zeros((vars["t_orig"].shape[0],vars["L"])) + + # Call the initialization module + k_ini = 0 + plotFlag = False + # print(vars["t_aver"]) + vars = init_module(k_ini,vars,Setup,plotFlag); #bar_fr has been initialized + + for k in np.arange(k_ini, vars["t_aver"].shape[0]): + if k >= 1: + k_prev = k-1 + else: + k_prev = 0 + + # Re-initialization when hat_fr has not been defined for N_k time instants + N_k = 2#3+1#vars["N_k"] + N_prev = np.arange(k,max(k-N_k,-1),-1) + if np.isnan(vars["hat_fr"][N_prev]).all() and k > 2: + vars = init_module(k_prev,vars,Setup,plotFlag) # bar_fr has been re-initialized + + # Peakedness Analysis: + # Indexes of 
original spectra that take part in the average + O = np.bitwise_and(vars["t_orig"]>=vars["t_aver"][k]-N*DT, vars["t_orig"]<=vars["t_aver"][k]+N*DT) + W = np.arange(O.shape[0]) + O = W[O] + + # Compute the peakedness of the power spectrum (1 or 0) + Xkl = compute_Xkl(vars["Skl"], vars["f"], vars["bar_fr"][k_prev], O, ksi_p, ksi_a, d) + + + if np.sum(Xkl) == 0: # No spectrum was peaked + # Store the previous respiratory frequency + vars["bar_fr"][k] = vars["bar_fr"][k_prev] + + # Compute averaged spectrum just for visualization + if vars["L"]>1: + vars["Sk"][:, k] = np.mean(np.squeeze(np.mean(vars["Skl"][:, O, :],1)),1) + else: + try: + vars["Sk"][:,k] = np.mean(vars["Skl"][:, O, :],1)[:,0] + except: + print("Cogido en el except") + print("Cogido en el except") + print("Cogido en el except") + print("Cogido en el except") + vars["Sk"][:,k] = np.nan + + else: #One or more spectra were peaked enough + # Pre-allocate + averS = np.zeros((vars["f"].shape[0])) + + for i_Tm in range(O.shape[0]): + for ii in range(vars["L"]): + if Xkl[i_Tm,ii] == 1: # If this spectrum is considered peaky + # Sum all peaky spectra + averS = averS[:] + vars["Skl"][:, O[i_Tm], ii] + + # Store the nr of peaky spectra + vars["Naveraged"][k] = vars["Naveraged"][k] + 1 + vars["used"][k,ii] = 1 + + # Compute and store the averaged spectrum + vars["Sk"][:, k] = averS/vars["Naveraged"][k] + vars["times_used"][O,:] = vars["times_used"][O,:] + Xkl + + #Spectral peak selection + fJmin = compute_fJmin( vars["Sk"][:, k], vars["f"], vars["bar_fr"][k_prev], d) + + if ~np.isnan(fJmin).any(): # Local maxima inside Omega has been found + # Update bar_fr + + vars["bar_fr"][k] = b*vars["bar_fr"][k_prev] + (1-b)*fJmin + + # Update hat_fr + if ~np.isnan(vars["hat_fr"][k_prev]).any(): + vars["hat_fr"][k]= a*vars["hat_fr"][k_prev] + (1-a)*fJmin + else: + # Use bar_fr(k-1) that always is defined, instead of hat_fr(k-1) + vars["hat_fr"][k]= a*vars["bar_fr"][k_prev] + (1-a)*fJmin + + else: # No local maxima inside Omega 
+ # Update bar_fr + vars["bar_fr"][k] = vars["bar_fr"][k_prev] + + # Don't Update hat_fr + + t_taver = time() + + # Extra : use bar_fr to update hat_fr when was not defined for small gaps (N_k) + # Beginning of the intervals + N_k = 0 + + int_b = np.argwhere(np.isnan(vars["hat_fr"])) + int_b1 = np.append(0,int_b) + int_b = int_b[np.diff(int_b1)>1] + + # End of the intervals + int_e = np.argwhere(np.isnan(vars["hat_fr"])) + int_e1 = np.append(int_e,np.inf) + int_e = int_e[np.diff(int_e1)>1] + + if np.isnan(vars["hat_fr"][0]) and int_e.shape[0]>1: + + int_e = int_e[1:] + if int_b.size > 0 and int_e.size > 0 and (int_e[0]-int_b[0])[0] < 0: + int_e = int_e[1:] + + int_small = (int_e-int_b)<=(N_k-1) + + int_b = int_b[int_small] + int_e = int_e[int_small] + for i in range(int_small.sum()): + vars["hat_fr"][int_b[i]:int_e[i]+1] = vars["hat_fr"][min(int_e[i]+1,vars["hat_fr"].shape[0])] + vars["bar_fr"][int_b[i]:int_e[i]+1] = vars["bar_fr"][min(int_e[i]+1,vars["hat_fr"].shape[0])] + + + + # # Total times a signal can be used + Ntotal = K*(vars["t_orig"].shape[0] - 2) + np.sum(np.arange(1,K)) + + # Times each signal is used + Nused = np.sum(vars["times_used"], 1) + vars["percentage_used"] = 100*Nused/Ntotal + + + vars["t_aver"] = vars["t_aver"] + ts1 + vars["t_orig"] = vars["t_orig"] + ts1 + t_fin = time() + + # if plotflag: + # if go is None or subplots is None: + # raise ModuleNotFoundError("plotly is required when plotflag=True") + + # fig = subplots.make_subplots(rows=2,shared_xaxes=True, subplot_titles=('Peak-condition averaged EDR Spectra in '+title,"EDR/RESP signals"), row_heights=[0.7, 0.3]) + + # fig.add_heatmap(x=vars["t_aver"], y=vars["f"], z=vars["Sk"]/np.max(vars["Sk"]),colorscale='jet',colorbar=dict(orientation='h')) + # fig.update_layout(coloraxis_showscale=False) + # fig.add_trace(go.Line(x=vars["t_aver"], y=vars["hat_fr"],name = 'f\u0302_r(k)'), row = 1, col=1) + # fig.add_trace(go.Line(x=vars["t_aver"],y=vars["bar_fr"],name= 'f\u0304_r(k)'), row = 1, col=1) 
+ + # fig.add_trace(go.Line(x=vars["t_aver"],y=vars["used"]), row = 1, col=1) + # # fig.axis([vars.t_aver(1), vars.t_aver(end), vars.f(1), vars.f(end)]) + # for i in range(signals.shape[1]): + # fig.add_trace(go.Line(x=ts+ts1,y=signals[:,i],name = 'Signal '+str(i)), row = 2, col=1) + + # fig.update_layout(coloraxis_showscale=False) + # fig.update_yaxes(title_text="f (Hz)", row=1, col=1) + # fig.update_yaxes(title_text="(n.u.)", row=2, col=1) + # fig.update_xaxes(title_text="time (s)", row=2, col=1) + # if storeGraph: + # os.makedirs("Graphs/Peakedness/"+str(subjet), exist_ok=True) + # # fig.write_image(os.path.join("Graphs", "Peakedness",str(subjet),title+".png")) + # fig.write_html(os.path.join("Graphs", "Peakedness",str(subjet),title+".html")) + # # fig.write_image() + # else: + # fig.show() + + return vars["hat_fr"], vars["Sk"], vars["t_aver"], vars["used"] + # return vars["hat_fr"], vars["Sk"], vars["bar_fr"],vars["t_aver"], vars["f"], vars["used"] \ No newline at end of file diff --git a/src/lib/remove_ectopic_beat.py b/src/lib/remove_ectopic_beat.py new file mode 100644 index 0000000..de60caa --- /dev/null +++ b/src/lib/remove_ectopic_beat.py @@ -0,0 +1,41 @@ +import numpy as np + +def remove_ectopic_beats(NN, window_size, threshold): + NN = np.asarray(NN).flatten() + NN_corrected = NN.copy() + + half_win = window_size // 2 + ectopic_count = 0 + valid_count = 0 + + for i in range(len(NN)): + + if np.isnan(NN[i]): + continue + + valid_count += 1 + + # Define local window + left = max(0, i - half_win) + right = min(len(NN), i + half_win + 1) # Python slice is exclusive + + local_segment = NN[left:right] + local_segment = local_segment[~np.isnan(local_segment)] + + if local_segment.size == 0: + continue + + med_val = np.median(local_segment) + + # Detect ectopic + if abs(NN[i] - med_val) > threshold * med_val: + NN_corrected[i] = med_val + ectopic_count += 1 + + # Percentage over valid NN + if valid_count > 0: + ectopic_perc = (ectopic_count / valid_count) * 100 
+ else: + ectopic_perc = np.nan + + return NN_corrected, ectopic_perc \ No newline at end of file diff --git a/src/lolai_models.py b/src/lolai_models.py new file mode 100644 index 0000000..c93eb98 --- /dev/null +++ b/src/lolai_models.py @@ -0,0 +1,2214 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +import torch.nn as nn +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import matplotlib.pyplot as plt +from captum.attr import ( + IntegratedGradients, + LayerGradCam, + LayerAttribution, + Occlusion, + GradientShap +) + +from typing import Dict, List, Tuple, Union, Optional + + +class CNN_LSTM_Classifier(nn.Module): + def __init__(self, input_channels=3, hidden_dim=64, num_classes=3, dropout=0.3): + super(CNN_LSTM_Classifier, self).__init__() + + self.cnn = nn.Sequential( + nn.Conv1d(input_channels, 6, kernel_size=5, padding=2), + nn.BatchNorm1d(6), + nn.ReLU(), + nn.MaxPool1d(kernel_size=2), + nn.Dropout(dropout), + + nn.Conv1d(6, 9, kernel_size=3, padding=1), + nn.BatchNorm1d(9), + nn.ReLU(), + nn.MaxPool1d(kernel_size=2), + nn.Dropout(dropout), + + + nn.Conv1d(9, 18, kernel_size=3, padding=1), + nn.BatchNorm1d(18), + nn.ReLU(), + nn.MaxPool1d(kernel_size=2), + nn.Dropout(dropout) + ) + + self.lstm = nn.LSTM( + input_size=18, + hidden_size=hidden_dim, + batch_first=True, + bidirectional=True # Use bidirectional LSTM for better context + ) + + self.classifier = nn.Sequential( + nn.Linear(2 * hidden_dim, 64), # Adjust for bidirectional LSTM + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(64, num_classes) + ) + + def forward(self, x): + # x: (batch, channels, time) + x = self.cnn(x) # (batch, features, time) + x = x.permute(0, 2, 1) # (batch, time, features) + _, (h_n, _) = self.lstm(x) # h_n: (num_layers * num_directions, batch, hidden_dim) + h_n = torch.cat((h_n[-2], h_n[-1]), dim=1) # Concatenate forward and backward states + out = self.classifier(h_n) # (batch, num_classes) + 
return out + +import torch +import torch.nn as nn +import torch.nn.functional as F +import matplotlib.pyplot as plt +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class CNN_LSTM_Classifier_XAI(nn.Module): + def __init__(self, input_channels=3, hidden_dim=32, num_classes=3, dropout=0.4): + super(CNN_LSTM_Classifier_XAI, self).__init__() + + self.cnn_activations = [] + self.lstm_activations = None + self.attention_weights = None + self.gradients = None + self.last_cnn_output = None + self.input = None + + # CNN + self.conv1 = nn.Conv1d(input_channels, 16, kernel_size=5, padding=2) + self.bn1 = nn.BatchNorm1d(16) + self.conv2 = nn.Conv1d(16, 32, kernel_size=3, padding=1) + self.bn2 = nn.BatchNorm1d(32) + self.conv3 = nn.Conv1d(32, 64, kernel_size=3, padding=1) + self.bn3 = nn.BatchNorm1d(64) + self.pool = nn.MaxPool1d(kernel_size=2) + self.dropout = nn.Dropout(dropout) + + # LSTM + self.lstm = nn.LSTM(input_size=64, hidden_size=hidden_dim, + batch_first=True, bidirectional=True) + + # Atención + self.attention = nn.Linear(2 * hidden_dim, 1) + + # Clasificador + self.classifier = nn.Sequential( + nn.Linear(2 * hidden_dim, 32), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(32, num_classes) + ) + + def activations_hook(self, grad): + self.gradients = grad + + def forward(self, x, return_attention=False, track_gradients=False): + self.input = x + + x = self.conv1(x) + x = self.bn1(x) + x = F.relu(x) + self.cnn_activations.append(x.detach()) + x = self.pool(x) + x = self.dropout(x) + + x = self.conv2(x) + x = self.bn2(x) + x = F.relu(x) + self.cnn_activations.append(x.detach()) + x = self.pool(x) + x = self.dropout(x) + + x = self.conv3(x) + x = self.bn3(x) + x = F.relu(x) + cnn_output = x + + if track_gradients and cnn_output.requires_grad: + cnn_output.register_hook(self.activations_hook) + + self.last_cnn_output = cnn_output # necesario para Grad-CAM + self.cnn_activations.append(cnn_output.detach()) + + x = 
self.pool(cnn_output) + x = self.dropout(x) + + x = x.permute(0, 2, 1) # (batch, time, features) + lstm_out, _ = self.lstm(x) + self.lstm_activations = lstm_out.detach() + + attention_scores = self.attention(lstm_out).squeeze(-1) + attention_weights = F.softmax(attention_scores, dim=1) + self.attention_weights = attention_weights.detach() + + context_vector = torch.bmm(attention_weights.unsqueeze(1), lstm_out).squeeze(1) + out = self.classifier(context_vector) + + if return_attention: + return out, attention_weights + return out + + def reset_activation_storage(self): + self.cnn_activations = [] + self.lstm_activations = None + self.attention_weights = None + self.gradients = None + self.last_cnn_output = None + self.input = None + + def interpret(self, x, class_idx=None): + self.reset_activation_storage() + + was_training = self.training + lstm_was_training = self.lstm.training + + self.eval() + self.lstm.train() # necesario para CuDNN backward + + x.requires_grad_() + logits, attention = self.forward(x, return_attention=True, track_gradients=True) + pred = torch.softmax(logits, dim=1) + + if class_idx is None: + class_idx = pred.argmax(dim=1) + + for i in range(x.shape[0]): + pred[i, class_idx[i]].backward(retain_graph=True) + + self.train(was_training) + self.lstm.train(lstm_was_training) + + feature_importance = self.get_feature_importance() + temporal_channel_importance = self.get_temporal_channel_importance() + channel_imp=self.get_channel_importance() + + self.input = None # limpieza para evitar problemas de memoria + torch.cuda.empty_cache() + + return { + 'prediction': pred.detach(), + 'class_idx': class_idx, + 'attention_weights': self.attention_weights, + 'feature_importance': feature_importance, + 'cnn_activations': self.cnn_activations, + 'temporal_channel_importance': temporal_channel_importance, + 'channel_importance': channel_imp + } + + def get_feature_importance(self): + """ + Grad-CAM temporal sobre la salida del último bloque CNN. 
+ Devuelve tensor (batch, time) + """ + if self.gradients is None or self.last_cnn_output is None: + return None + + pooled_gradients = torch.mean(self.gradients, dim=[0, 2]) # (channels,) + cam = self.last_cnn_output.clone() + + for i in range(cam.shape[1]): + cam[:, i, :] *= pooled_gradients[i] + + heatmap = torch.mean(cam, dim=1).detach() # (batch, time) + return heatmap + + def get_channel_importance(self): + """ + Importancia por canal: (batch, channels) + """ + if self.input.grad is None: + raise ValueError("Gradientes de la entrada no están disponibles. Llama primero a interpret().") + return self.input.grad.abs().mean(dim=2) + + def get_temporal_channel_importance(self): + """ + Importancia canal-temporal: (batch, channels, time) + """ + if self.input.grad is None: + raise ValueError("Gradients of the input are not available. Call interpret() first.") + return self.input.grad.abs().detach() + + + + + + +class ContrastiveVAE(nn.Module): + def __init__(self, in_channels=4, latent_dim=32, lstm_hidden=64, n_classes=3, use_classifier=False): + super().__init__() + self.use_classifier = use_classifier + + # Encoder + self.encoder = nn.Sequential( + nn.Conv1d(in_channels, 32, kernel_size=3, padding=1), + nn.ReLU(), + nn.Conv1d(32, 64, kernel_size=3, padding=1), + nn.ReLU() + ) + + self.global_pool = nn.AdaptiveAvgPool1d(1) # for VAE path + self.fc_mu = nn.Linear(64, latent_dim) + self.fc_logvar = nn.Linear(64, latent_dim) + + # Decoder (for reconstruction) + self.decoder_input = nn.Linear(latent_dim, 64) + self.decoder = nn.Sequential( + nn.ConvTranspose1d(64, 32, kernel_size=3, padding=1), + nn.ReLU(), + nn.ConvTranspose1d(32, in_channels, kernel_size=3, padding=1) + ) + + # LSTM + Classifier always initialized (but optionally used) + self.lstm = nn.LSTM(input_size=64, hidden_size=lstm_hidden, batch_first=True, bidirectional=True) + self.classifier = nn.Sequential( + nn.Linear(lstm_hidden * 2, 64), + nn.ReLU(), + nn.Linear(64, n_classes) + ) + + def encode(self, 
x): + h = self.encoder(x) # (B, 64, T) + pooled = self.global_pool(h).squeeze(-1) # (B, 64) + mu = self.fc_mu(pooled) + logvar = self.fc_logvar(pooled) + return mu, logvar, h # h is (B, 64, T) + + def reparameterize(self, mu, logvar): + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + return mu + eps * std + + def decode(self, z, length): + # Mejora la reconstrucción con una capa de proyección inicial + h = self.decoder_input(z).unsqueeze(-1) # (B, 64, 1) + # Usar interpolación para un escalado más suave en lugar de expand + h = F.interpolate(h, size=length, mode='linear', align_corners=False) + x_recon = self.decoder(h) + return x_recon + + def forward(self, x): + B, C, T = x.shape + mu, logvar, features = self.encode(x) # features: (B, 64, T) + z = self.reparameterize(mu, logvar) + x_recon = self.decode(z, T) + + logits = None + if self.use_classifier: + features_t = features.permute(0, 2, 1) # (B, T, 64) + lstm_out, _ = self.lstm(features_t) # (B, T, 2*hidden) + lstm_feat = lstm_out.mean(dim=1) # (B, 2*hidden) + logits = self.classifier(lstm_feat) # (B, n_classes) + + return x_recon, mu, logvar, z, logits + + def get_latents(self, x, use_mean=True): + mu, logvar, _ = self.encode(x) + return mu if use_mean else self.reparameterize(mu, logvar) + + def classify(self, x): + """Forward through the classifier only (requires use_classifier = True).""" + assert self.use_classifier, "Classifier is not enabled. Set model.use_classifier = True before calling classify." 
+ _, _, features = self.encode(x) + features_t = features.permute(0, 2, 1) + lstm_out, _ = self.lstm(features_t) + lstm_feat = lstm_out.mean(dim=1) + return self.classifier(lstm_feat) + + +def vae_loss(recon_x, x, mu, logvar): + recon_loss = F.mse_loss(recon_x, x, reduction='mean') + kl_div = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp()) + return recon_loss + kl_div, recon_loss, kl_div + + +def contrastive_loss(z, ids, temperature=0.1): + z = F.normalize(z, dim=1) + sim = torch.mm(z, z.T) / temperature + labels = ids.view(-1, 1) + mask = torch.eq(labels, labels.T).float().to(z.device) + mask = mask - torch.eye(len(z), device=z.device) + exp_sim = torch.exp(sim) * (1 - torch.eye(len(z), device=z.device)) + log_prob = sim - torch.log(exp_sim.sum(1, keepdim=True) + 1e-8) + mean_log_prob_pos = (mask * log_prob).sum(1) / (mask.sum(1) + 1e-8) + return -mean_log_prob_pos.mean() + + +def intra_patient_loss(latents, patient_ids): + """Versión optimizada que evita bucles explícitos""" + # Convertir IDs a tensor si no lo son ya + if not isinstance(patient_ids, torch.Tensor): + patient_ids = torch.tensor(patient_ids, device=latents.device) + + # Crear matriz de similaridad de pacientes (1 donde son iguales) + patient_sim = (patient_ids.unsqueeze(1) == patient_ids.unsqueeze(0)).float() + # Quitar diagonal (mismo ejemplo) + mask = patient_sim - torch.eye(len(latents), device=latents.device) + # Calcular distancias entre latentes + latent_dists = torch.cdist(latents, latents, p=2) + # Aplicar máscara y promediar + valid_pairs = mask.sum() + if valid_pairs > 0: + return (mask * latent_dists).sum() / valid_pairs + return torch.tensor(0.0, device=latents.device) + + +def training_step(model, batch1, batch2, patient_ids, optimizer, alpha=0.1, beta=1.0): + model.train() + x1, x2 = batch1, batch2 + x1_recon, mu1, logvar1, z1, _ = model(x1) + x2_recon, mu2, logvar2, z2, _ = model(x2) + + recon1, r1, kl1 = vae_loss(x1_recon, x1, mu1, logvar1) + recon2, r2, kl2 = 
vae_loss(x2_recon, x2, mu2, logvar2) + vae_total = (recon1 + recon2) / 2 + + z_all = torch.cat([z1, z2], dim=0) + # Asegurar que IDs sean tensores + ids = torch.arange(len(z1), device=z1.device).repeat(2) + contrastive = contrastive_loss(z_all, ids) + + # Extender patient_ids correctamente + if isinstance(patient_ids, torch.Tensor): + p_ids = torch.cat([patient_ids, patient_ids], dim=0) + else: + p_ids = patient_ids + patient_ids # Si es una lista + + patient_reg = intra_patient_loss(z_all, p_ids) + + total = vae_total + alpha * contrastive + beta * patient_reg + + optimizer.zero_grad() + total.backward() + optimizer.step() + + return { + 'total_loss': total.item(), + 'recon_loss': vae_total.item(), + 'contrastive': contrastive.item(), + 'patient_reg': patient_reg.item(), + 'kl_loss': (kl1 + kl2).item() / 2 + } + + +def fine_tune_step(model, x, y, optimizer, criterion=nn.CrossEntropyLoss()): + model.train() + model.use_classifier = True + logits = model.classify(x) + loss = criterion(logits, y) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + preds = torch.argmax(logits, dim=1) + acc = (preds == y).float().mean().item() + + return { + 'classification_loss': loss.item(), + 'accuracy': acc + } + +class ImprovedPainClassifier(nn.Module): + def __init__(self, input_channels=3, hidden_dim=128, num_classes=3, dropout=0.4): + super(ImprovedPainClassifier, self).__init__() + + # Increased regularization and feature extraction for small datasets + self.cnn = nn.Sequential( + # Layer 1: More filters to capture diverse patterns + nn.Conv1d(input_channels, 64, kernel_size=3, padding=1), + nn.BatchNorm1d(64), + nn.LeakyReLU(0.1), # LeakyReLU helps with gradient flow + nn.MaxPool1d(kernel_size=2), + + # Layer 2: Increased complexity + nn.Conv1d(64, 128, kernel_size=3, padding=1), + nn.BatchNorm1d(128), + nn.LeakyReLU(0.1), + nn.MaxPool1d(kernel_size=2), + nn.Dropout(dropout), + + # Layer 3: Additional layer for better feature extraction + nn.Conv1d(128, 128, 
kernel_size=3, padding=1), + nn.BatchNorm1d(128), + nn.LeakyReLU(0.1), + nn.Dropout(dropout) + ) + + # Attention mechanism to focus on important temporal patterns + self.attention = nn.Sequential( + nn.Linear(256, 64), + nn.Tanh(), + nn.Linear(64, 1) + ) + + # Bidirectional LSTM with residual connections + self.lstm = nn.LSTM( + input_size=128, + hidden_size=hidden_dim, + num_layers=2, # Multiple layers for complex temporal patterns + batch_first=True, + bidirectional=True, + dropout=dropout # Apply dropout between LSTM layers + ) + + # Classifier with additional regularization + self.classifier = nn.Sequential( + nn.Linear(2 * hidden_dim, hidden_dim), + nn.BatchNorm1d(hidden_dim), # Normalize activations + nn.LeakyReLU(0.1), + nn.Dropout(dropout), + nn.Linear(hidden_dim, hidden_dim // 2), + nn.LeakyReLU(0.1), + nn.Dropout(dropout), + nn.Linear(hidden_dim // 2, num_classes) + ) + + def forward(self, x): + # x: (batch, channels, time) - BVP, EDA, and respiratory signals + + # Extract features with CNN + cnn_out = self.cnn(x) # (batch, 128, time') + + # Reshape for LSTM + cnn_out = cnn_out.permute(0, 2, 1) # (batch, time', 128) + + # Process with LSTM + lstm_out, (h_n, _) = self.lstm(cnn_out) # lstm_out: (batch, time', 2*hidden_dim) + + # Apply attention to focus on relevant parts of the signal + attn_weights = self.attention(lstm_out).softmax(dim=1) # (batch, time', 1) + context = torch.sum(attn_weights * lstm_out, dim=1) # (batch, 2*hidden_dim) + + # Alternative: Use concatenated hidden states from both directions + # h_n = torch.cat((h_n[-2], h_n[-1]), dim=1) # (batch, 2*hidden_dim) + + # Classify + out = self.classifier(context) # (batch, num_classes) + return out + + +import matplotlib.pyplot as plt +import numpy as np +import torch + +class ExplainabilityVisualizer: + def __init__(self, channel_names=None): + """ + channel_names: lista opcional con nombres de los canales de entrada + """ + self.channel_names = channel_names + + def plot_attention_weights(self, 
attention_weights, title="Atención temporal"): + attention = attention_weights.squeeze().cpu().numpy() + plt.figure(figsize=(10, 2)) + plt.plot(attention) + plt.title(title) + plt.xlabel("Timestep") + plt.ylabel("Weight") + plt.grid(True) + plt.tight_layout() + plt.show() + + def plot_gradcam_heatmap(self, heatmap, title="Grad-CAM temporal"): + heat = heatmap.squeeze().cpu().numpy() + plt.figure(figsize=(10, 2)) + plt.plot(heat) + plt.title(title) + plt.xlabel("Timestep") + plt.ylabel("Importance") + plt.grid(True) + plt.tight_layout() + plt.show() + + def plot_channel_importance(self, channel_importance, title="Importancia por canal"): + values = channel_importance.squeeze().cpu().numpy() + channels = self.channel_names if self.channel_names else [f"Channel {i}" for i in range(len(values))] + plt.figure(figsize=(6, 3)) + plt.bar(channels, values) + plt.title(title) + plt.ylabel("Importancia media") + plt.xticks(rotation=45) + plt.grid(axis='y') + plt.tight_layout() + plt.show() + + def plot_signals_with_attention_highlight(self, x, importance, threshold=0.85, title="Señal con zonas de atención"): + """ + Dibuja las señales multicanal y sombreado rojo donde la importancia temporal supera el umbral. 
+ + Args: + x: Tensor (channels, time) + importance: Tensor (time,) + threshold: percentil (0-1) o valor absoluto + title: título del gráfico + """ + x = x.detach().cpu().numpy() + importance = importance.detach().cpu().numpy() + time = np.arange(x.shape[1]) + n_channels = x.shape[0] + + if threshold <= 1.0: + threshold_value = np.quantile(importance, threshold) + else: + threshold_value = threshold + + high_attention_mask = importance >= threshold_value + + fig, axs = plt.subplots(n_channels, 1, figsize=(12, 2.5 * n_channels), sharex=True) + if n_channels == 1: + axs = [axs] + + for i in range(n_channels): + axs[i].plot(time, x[i], label=self.channel_names[i] if self.channel_names else f"Canal {i}", color="black") + axs[i].set_ylabel("Valor") + axs[i].grid(True) + + in_high = False + start = 0 + for t in range(len(high_attention_mask)): + if high_attention_mask[t] and not in_high: + start = t + in_high = True + elif not high_attention_mask[t] and in_high: + axs[i].axvspan(start, t, color='red', alpha=0.25) + in_high = False + if in_high: + axs[i].axvspan(start, len(high_attention_mask), color='red', alpha=0.25) + + axs[i].legend(loc="upper right") + + axs[-1].set_xlabel("Tiempo (muestras)") + plt.suptitle(title) + plt.tight_layout() + plt.show() + + + +class FocalLoss(nn.Module): + """ + Focal Loss para clasificación binaria y multiclase. + + Parámetros: + - alpha: Factor de ponderación para manejar desequilibrio de clases. + Puede ser un escalar (mismo valor para todas las clases) o + un tensor (valores específicos por clase). + - gamma: Factor de modulación para enfocar en ejemplos difíciles (>= 0). + - reduction: 'none' | 'mean' | 'sum' + - eps: Pequeño valor para estabilidad numérica + + Referencias: + - Paper original: "Focal Loss for Dense Object Detection" por Lin et al. 
+ """ + def __init__(self, alpha=0.25, gamma=2.0, reduction='mean', eps=1e-6): + super(FocalLoss, self).__init__() + self.alpha = alpha + self.gamma = gamma + self.reduction = reduction + self.eps = eps + + def forward(self, inputs, targets): + """ + Args: + inputs: Logits de forma [B, C] donde B es el tamaño del batch y C es el número de clases. + Para clasificación binaria, C puede ser 1. + targets: Etiquetas de objetivos de forma [B] para multiclase o [B, 1] para binaria. + Valores enteros para multiclase (clases indexadas desde 0 a C-1). + Valores continuos entre 0 y 1 para binaria. + """ + # Determinar si es clasificación binaria o multiclase + if inputs.shape[1] == 1 or inputs.shape[1] == 2: # Binaria + # Aplicar sigmoide para obtener probabilidades + probs = torch.sigmoid(inputs.view(-1)) + targets = targets.view(-1) + + # Calcular pt (probabilidad del objetivo correcto) + pt = probs * targets + (1 - probs) * (1 - targets) + + # Aplicar factores de ponderación + if isinstance(self.alpha, (float, int)): + alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets) + else: + # Si alpha es un tensor, usar indexación + alpha_t = self.alpha if self.alpha is not None else torch.ones_like(pt) + + # Calcular la focal loss + focal_weight = (1 - pt).pow(self.gamma) + loss = -alpha_t * focal_weight * torch.log(pt.clamp(min=self.eps)) + + else: # Multiclase + # Convertir logits a distribución de probabilidad + log_softmax = F.log_softmax(inputs, dim=1) + + # Obtener log probabilidad para las clases objetivo + targets = targets.view(-1, 1) + log_pt = log_softmax.gather(1, targets).view(-1) + pt = log_pt.exp() # Obtener probabilidades + + # Aplicar factores de ponderación + if isinstance(self.alpha, (list, tuple, torch.Tensor)): + # Si alpha es específico por clase + alpha = torch.tensor(self.alpha, device=inputs.device) + alpha_t = alpha.gather(0, targets.view(-1)) + else: + alpha_t = self.alpha if self.alpha is not None else 1.0 + + # Calcular focal loss + 
focal_weight = (1 - pt).pow(self.gamma) + loss = -alpha_t * focal_weight * log_pt + + # Aplicar reduction + if self.reduction == 'mean': + return loss.mean() + elif self.reduction == 'sum': + return loss.sum() + else: # 'none' + return loss + + + + +class CNN_LSTM_Classifier_XAI_2(nn.Module): + def __init__(self, input_channels=3, hidden_dim=32, num_classes=3, dropout=0.1): + super(CNN_LSTM_Classifier_XAI_2, self).__init__() + + self.cnn_activations = [] + self.lstm_activations = None + self.attention_weights = None + self.gradients = None + self.last_cnn_output = None + self.input = None + self.input_channels = input_channels + + # CNN + self.conv1 = nn.Conv1d(input_channels, 16, kernel_size=5, padding=2) + self.bn1 = nn.BatchNorm1d(16) + self.conv2 = nn.Conv1d(16, 32, kernel_size=3, padding=1) + self.bn2 = nn.BatchNorm1d(32) + self.conv3 = nn.Conv1d(32, 64, kernel_size=3, padding=1) + self.bn3 = nn.BatchNorm1d(64) + self.pool = nn.MaxPool1d(kernel_size=2) + self.dropout = nn.Dropout(dropout) + + # LSTM + self.lstm = nn.LSTM(input_size=64, hidden_size=hidden_dim, + batch_first=True, bidirectional=True) + + # Attention + self.attention = nn.Linear(2 * hidden_dim, 1) + + # Classifier + self.classifier = nn.Sequential( + nn.Linear(2 * hidden_dim, 32), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(32, num_classes) + ) + + def activations_hook(self, grad): + self.gradients = grad + + + + def forward(self, x, return_attention=False, track_gradients=False): + # Reset activation storage at the beginning of each forward pass + self.reset_activation_storage() + + self.input = x + + x = self.conv1(x) + x = self.bn1(x) + x = F.relu(x) + self.cnn_activations.append(x.detach()) + x = self.pool(x) + x = self.dropout(x) + + x = self.conv2(x) + x = self.bn2(x) + x = F.relu(x) + self.cnn_activations.append(x.detach()) + x = self.pool(x) + x = self.dropout(x) + + x = self.conv3(x) + x = self.bn3(x) + x = F.relu(x) + cnn_output = x + + if track_gradients and 
cnn_output.requires_grad: + cnn_output.register_hook(self.activations_hook) + + self.last_cnn_output = cnn_output # needed for Grad-CAM + self.cnn_activations.append(cnn_output.detach()) + + x = self.pool(cnn_output) + x = self.dropout(x) + + x = x.permute(0, 2, 1) # (batch, time, features) + lstm_out, (h_n, c_n) = self.lstm(x) + self.lstm_activations = lstm_out.detach() + + attention_scores = self.attention(lstm_out).squeeze(-1) + attention_weights = F.softmax(attention_scores, dim=1) + self.attention_weights = attention_weights.detach() + + context_vector = torch.bmm(attention_weights.unsqueeze(1), lstm_out).squeeze(1) + out = self.classifier(context_vector) + + if return_attention: + return out, attention_weights + return out + + def reset_activation_storage(self): + self.cnn_activations = [] + self.lstm_activations = None + self.attention_weights = None + self.gradients = None + self.last_cnn_output = None + + def interpret(self, x, class_idx=None, methods=None): + """ + Enhanced interpretation method with multiple explainability techniques + + Args: + x: Input data tensor + class_idx: Target class indices to explain (defaults to predicted class) + methods: List of methods to use, options: ['gradcam', 'integrated_gradients', + 'occlusion', 'shap', 'feature_ablation', 'all'] + + Returns: + Dictionary with various interpretability outputs + """ + if methods is None: + methods = ['gradcam', 'attention'] # Default methods + if 'all' in methods: + methods = ['gradcam', 'integrated_gradients', 'occlusion', 'shap', + 'feature_ablation', 'attention', 'layer_importance'] + + # Store original training state + was_training = self.training + lstm_was_training = self.lstm.training + + # Set model to evaluation mode for interpretability + self.eval() + self.lstm.train() # needed for CuDNN backward compatibility + + # Base prediction + x.requires_grad_() + self.input = x # Store input for interpretability methods + + logits, attention = self.forward(x, return_attention=True, 
track_gradients=True) + pred = torch.softmax(logits, dim=1) + + if class_idx is None: + class_idx = pred.argmax(dim=1) + + # Initialize results dictionary + results = { + 'prediction': pred.detach(), + 'class_idx': class_idx, + 'attention_weights': self.attention_weights, + } + + # Apply selected interpretability methods + if 'gradcam' in methods: + for i in range(x.shape[0]): + pred[i, class_idx[i]].backward(retain_graph=True if i < x.shape[0]-1 else False) + + results['feature_importance'] = self.get_feature_importance() + results['temporal_channel_importance'] = self.get_temporal_channel_importance() + results['channel_importance'] = self.get_channel_importance() + results['cnn_activations'] = self.cnn_activations + + # Integrated Gradients + if 'integrated_gradients' in methods: + ig = IntegratedGradients(self.forward_wrapper) + results['integrated_gradients'] = self._compute_integrated_gradients( + ig, x, class_idx) + + # Occlusion analysis + if 'occlusion' in methods: + occlusion = Occlusion(self.forward_wrapper) + results['occlusion'] = self._compute_occlusion(occlusion, x, class_idx) + + # SHAP (GradientSHAP implementation) + if 'shap' in methods: + gradient_shap = GradientShap(self.forward_wrapper) + results['gradient_shap'] = self._compute_gradient_shap(gradient_shap, x, class_idx) + + # Feature ablation (sensitivity analysis) + if 'feature_ablation' in methods: + results['feature_ablation'] = self._feature_ablation_analysis(x, class_idx) + + # Layer importance analysis + if 'layer_importance' in methods: + results['layer_importance'] = self._compute_layer_importance(x, class_idx) + + # Restore original training states + self.train(was_training) + self.lstm.train(lstm_was_training) + + # Clean up to avoid memory issues + self.input = None + torch.cuda.empty_cache() + + return results + + def forward_wrapper(self, x): + """Wrapper for Captum compatibility""" + return self.forward(x) + + def get_feature_importance(self): + """ + Grad-CAM temporal over the 
output of the last CNN block. + Returns tensor (batch, time) + """ + if self.gradients is None or self.last_cnn_output is None: + return None + + pooled_gradients = torch.mean(self.gradients, dim=[0, 2]) # (channels,) + cam = self.last_cnn_output.clone() + + for i in range(cam.shape[1]): + cam[:, i, :] *= pooled_gradients[i] + + heatmap = torch.mean(cam, dim=1).detach() # (batch, time) + + # Apply ReLU to highlight only positive influences + heatmap = F.relu(heatmap) + + # Normalize heatmap for better visualization + if heatmap.max() > 0: + heatmap = heatmap / heatmap.max() + + return heatmap + + def get_channel_importance(self): + """ + Channel importance: (batch, channels) + """ + if self.input is None or self.input.grad is None: + raise ValueError("Input gradients not available. Call interpret() first.") + return self.input.grad.abs().mean(dim=2).detach() + + def get_temporal_channel_importance(self): + """ + Temporal-channel importance: (batch, channels, time) + """ + if self.input is None or self.input.grad is None: + raise ValueError("Input gradients not available. 
Call interpret() first.") + return self.input.grad.abs().detach() + + def _compute_integrated_gradients(self, ig, x, class_idx): + """Compute integrated gradients attribution""" + batch_size = x.shape[0] + attributions = [] + + for i in range(batch_size): + baseline = torch.zeros_like(x[i:i+1]) + attr = ig.attribute( + x[i:i+1], baseline, target=class_idx[i].item(), n_steps=50 + ) + attributions.append(attr) + + return torch.cat(attributions).detach() + + def _compute_occlusion(self, occlusion_algo, x, class_idx): + """Compute occlusion-based feature attribution""" + batch_size = x.shape[0] + attributions = [] + + # Define sliding window parameters for temporal data + window_size = min(5, x.shape[2] // 4) # Adapt window size to input length + + for i in range(batch_size): + attr = occlusion_algo.attribute( + x[i:i+1], + sliding_window_shapes=(1, window_size), + target=class_idx[i].item(), + strides=(1, max(1, window_size // 2)) + ) + attributions.append(attr) + + return torch.cat(attributions).detach() + + def _compute_gradient_shap(self, shap_algo, x, class_idx): + """Compute GradientSHAP attributions""" + batch_size = x.shape[0] + attributions = [] + + for i in range(batch_size): + # Create random baselines (typically 10-50 for good estimates) + baselines = torch.randn(10, *x[i:i+1].shape[1:]) * 0.001 + + # Ensure baselines device matches input + baselines = baselines.to(x.device) + + attr = shap_algo.attribute( + x[i:i+1], baselines=baselines, target=class_idx[i].item() + ) + attributions.append(attr) + + return torch.cat(attributions).detach() + + def _feature_ablation_analysis(self, x, class_idx): + """Analyze model by systematically ablating input features""" + batch_size = x.shape[0] + results = [] + + for i in range(batch_size): + # Store original prediction + with torch.no_grad(): + orig_output = self.forward(x[i:i+1]) + orig_prob = torch.softmax(orig_output, dim=1)[0, class_idx[i]].item() + + # Test ablation of each channel + channel_importance = [] + for 
c in range(self.input_channels): + # Create ablated input (zero out one channel) + ablated_input = x[i:i+1].clone() + ablated_input[:, c, :] = 0 + + # Get prediction on ablated input + with torch.no_grad(): + ablated_output = self.forward(ablated_input) + ablated_prob = torch.softmax(ablated_output, dim=1)[0, class_idx[i]].item() + + # Impact is reduction in probability + channel_impact = orig_prob - ablated_prob + channel_importance.append(channel_impact) + + results.append(torch.tensor(channel_importance)) + + return torch.stack(results) + + def _compute_layer_importance(self, x, class_idx): + """Compute importance of each layer using Layer GradCAM""" + batch_size = x.shape[0] + layer_importance = {} + + # Define layers to analyze + layers = { + 'conv1': self.conv1, + 'conv2': self.conv2, + 'conv3': self.conv3 + } + + for layer_name, layer in layers.items(): + layer_gradcam = LayerGradCam(self.forward_wrapper, layer) + layer_attrs = [] + + for i in range(batch_size): + attr = layer_gradcam.attribute( + x[i:i+1], target=class_idx[i].item() + ) + # Process attribution to create a single importance score per sample + pooled_attr = torch.mean(attr, dim=1) + + layer_attrs.append(pooled_attr) + + layer_importance[layer_name] = torch.cat(layer_attrs).detach() + + return layer_importance + + def visualize_attributions(self, sample_idx, interpretations, time_axis=None, + channel_names=None, class_names=None): + """ + Visualize the various interpretation results + + Args: + sample_idx: Index of the sample to visualize + interpretations: Dictionary returned by interpret() method + time_axis: Optional array/list with time points for x-axis + channel_names: Optional list of channel names + class_names: Optional list of class names + """ + if not channel_names: + channel_names = [f'Channel {i}' for i in range(self.input_channels)] + + if not class_names: + class_idx = interpretations['class_idx'][sample_idx].item() + class_name = f'Class {class_idx}' + else: + class_idx = 
interpretations['class_idx'][sample_idx].item() + class_name = class_names[class_idx] + + # Set up figure + plt.figure(figsize=(15, 12)) + + # Original input visualization (top row, first column) + plt.subplot(3, 3, 1) + if self.input is not None: + input_data = self.input[sample_idx].cpu().detach().numpy() + if time_axis is not None: + for i in range(input_data.shape[0]): + plt.plot(time_axis, input_data[i], label=channel_names[i]) + else: + for i in range(input_data.shape[0]): + plt.plot(input_data[i], label=channel_names[i]) + plt.legend(loc='best') + plt.title('Input Signal') + plt.xlabel('Time') + plt.ylabel('Value') + + # GradCAM feature importance (top row, second column) + if 'feature_importance' in interpretations and interpretations['feature_importance'] is not None: + plt.subplot(3, 3, 2) + heatmap = interpretations['feature_importance'][sample_idx].cpu().numpy() + if time_axis is not None: + plt.plot(time_axis, heatmap) + else: + plt.plot(heatmap) + plt.title('GradCAM Feature Importance') + plt.xlabel('Time') + plt.ylabel('Importance') + + # Attention weights (top row, third column) + if 'attention_weights' in interpretations and interpretations['attention_weights'] is not None: + plt.subplot(3, 3, 3) + attention = interpretations['attention_weights'][sample_idx].cpu().numpy() + + if time_axis is not None: + # Need to match attention time axis to input time axis + # (account for pooling in the network) + x_points = np.linspace(time_axis[0], time_axis[-1], len(attention)) + plt.plot(x_points, attention) + else: + plt.plot(attention) + plt.title('Attention Weights') + plt.xlabel('Time') + plt.ylabel('Attention') + + # Channel importance (middle row, first column) + if 'channel_importance' in interpretations and interpretations['channel_importance'] is not None: + plt.subplot(3, 3, 4) + ch_importance = interpretations['channel_importance'][sample_idx].cpu().numpy() + plt.bar(channel_names, ch_importance) + plt.title('Channel Importance') + 
plt.ylabel('Importance') + plt.xticks(rotation=45) + + # Integrated Gradients (middle row, second column) + if 'integrated_gradients' in interpretations: + plt.subplot(3, 3, 5) + ig_attr = interpretations['integrated_gradients'][sample_idx].cpu().numpy() + ig_attr_mean = np.mean(ig_attr, axis=0) # Average across channels for visualization + + if time_axis is not None: + plt.plot(time_axis, ig_attr_mean) + else: + plt.plot(ig_attr_mean) + plt.title('Integrated Gradients') + plt.xlabel('Time') + plt.ylabel('Attribution') + + # Feature Ablation (middle row, third column) + if 'feature_ablation' in interpretations: + plt.subplot(3, 3, 6) + ablation_scores = interpretations['feature_ablation'][sample_idx].cpu().numpy() + plt.bar(channel_names, ablation_scores) + plt.title('Feature Ablation Impact') + plt.ylabel('Probability Change') + plt.xticks(rotation=45) + + # SHAP values (bottom row, first column) + if 'gradient_shap' in interpretations: + plt.subplot(3, 3, 7) + shap_attr = interpretations['gradient_shap'][sample_idx].cpu().numpy() + # Visualize average SHAP value over time + shap_avg = np.mean(shap_attr, axis=0) + + if time_axis is not None: + plt.plot(time_axis, shap_avg) + else: + plt.plot(shap_avg) + plt.title('GradientSHAP Values') + plt.xlabel('Time') + plt.ylabel('SHAP Value') + + # Occlusion analysis (bottom row, second column) + if 'occlusion' in interpretations: + plt.subplot(3, 3, 8) + occlusion_attr = interpretations['occlusion'][sample_idx].cpu().numpy() + occlusion_avg = np.mean(occlusion_attr, axis=0) + + if time_axis is not None: + plt.plot(time_axis, occlusion_avg) + else: + plt.plot(occlusion_avg) + plt.title('Occlusion Analysis') + plt.xlabel('Time') + plt.ylabel('Attribution') + + # Prediction summary (bottom row, third column) + plt.subplot(3, 3, 9) + pred_probs = interpretations['prediction'][sample_idx].cpu().numpy() + classes = list(range(len(pred_probs))) + if class_names: + classes = class_names + plt.bar(classes, pred_probs) + 
class CNN_LSTM_Classifier_XAI_2(nn.Module):
    """
    CNN-LSTM classifier with additive attention, instrumented for
    explainability: stores per-block CNN activations, LSTM outputs and
    attention weights, and can register a Grad-CAM gradient hook on the
    last CNN block.

    Input shape:  (batch, input_channels, seq_len)
    Output shape: (batch, num_classes) logits.
    """

    def __init__(
        self,
        input_channels=3,
        num_classes=3,
        cnn_channels=(16, 32, 64),
        kernel_sizes=(5, 3, 3),
        pool_type="max",  # or 'avg'
        dropout=0.1,
        lstm_hidden_dim=32,
        lstm_num_layers=1,
        bidirectional=True,
        classifier_hidden_dim=32,
        attention_dim=None,  # accepted for interface compatibility; see note below
    ):
        super().__init__()

        self.input_channels = input_channels
        self.pool_type = pool_type
        self.dropout_rate = dropout
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1

        # Storage for explainability artefacts (filled during forward()).
        self.cnn_activations = []
        self.lstm_activations = None
        self.attention_weights = None
        self.gradients = None
        self.last_cnn_output = None
        self.input = None

        # CNN blocks ('same'-style padding via kernel_size // 2).
        self.conv1 = nn.Conv1d(input_channels, cnn_channels[0], kernel_size=kernel_sizes[0], padding=kernel_sizes[0] // 2)
        self.bn1 = nn.BatchNorm1d(cnn_channels[0])

        self.conv2 = nn.Conv1d(cnn_channels[0], cnn_channels[1], kernel_size=kernel_sizes[1], padding=kernel_sizes[1] // 2)
        self.bn2 = nn.BatchNorm1d(cnn_channels[1])

        self.conv3 = nn.Conv1d(cnn_channels[1], cnn_channels[2], kernel_size=kernel_sizes[2], padding=kernel_sizes[2] // 2)
        self.bn3 = nn.BatchNorm1d(cnn_channels[2])

        self.pool = nn.MaxPool1d(kernel_size=2) if pool_type == "max" else nn.AvgPool1d(kernel_size=2)
        self.dropout = nn.Dropout(dropout)

        # LSTM over the pooled CNN features.
        self.lstm = nn.LSTM(
            input_size=cnn_channels[2],
            hidden_size=lstm_hidden_dim,
            num_layers=lstm_num_layers,
            batch_first=True,
            bidirectional=bidirectional
        )

        # Additive attention: one score per time step.
        # NOTE: `attention_dim` is not used by this single-linear scorer; the
        # parameter is kept only so existing call sites do not break.  (The
        # previous version computed a default for it and then ignored it.)
        self.attention = nn.Linear(self.num_directions * lstm_hidden_dim, 1)

        # Classifier head.
        self.classifier = nn.Sequential(
            nn.Linear(self.num_directions * lstm_hidden_dim, classifier_hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(classifier_hidden_dim, num_classes)
        )

    def activations_hook(self, grad):
        # Backward hook on the last CNN block's output (Grad-CAM gradients).
        self.gradients = grad

    def reset_activation_storage(self):
        """Drop all stored activations/gradients from the previous forward."""
        self.cnn_activations = []
        self.lstm_activations = None
        self.attention_weights = None
        self.gradients = None
        self.last_cnn_output = None

    def forward(self, x, return_attention=False, track_gradients=False):
        """
        Args:
            x: Tensor of shape (batch, channels, seq_len).
            return_attention: if True, also return the attention weights.
            track_gradients: if True, register the Grad-CAM hook on the last
                CNN block's output.

        Returns:
            Logits (batch, num_classes); plus attention weights of shape
            (batch, pooled_seq_len) when ``return_attention`` is True.
        """
        self.reset_activation_storage()
        self.input = x

        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        self.cnn_activations.append(x.detach())
        x = self.dropout(x)

        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        self.cnn_activations.append(x.detach())
        x = self.dropout(x)

        x = F.relu(self.bn3(self.conv3(x)))
        cnn_output = x

        if track_gradients and cnn_output.requires_grad:
            cnn_output.register_hook(self.activations_hook)

        self.last_cnn_output = cnn_output  # needed for Grad-CAM
        self.cnn_activations.append(cnn_output.detach())

        x = self.pool(cnn_output)
        x = self.dropout(x)

        x = x.permute(0, 2, 1)  # (batch, time, features)
        lstm_out, _ = self.lstm(x)
        self.lstm_activations = lstm_out.detach()

        attention_scores = self.attention(lstm_out).squeeze(-1)
        attention_weights = F.softmax(attention_scores, dim=1)
        self.attention_weights = attention_weights.detach()

        context_vector = torch.bmm(attention_weights.unsqueeze(1), lstm_out).squeeze(1)
        out = self.classifier(context_vector)

        if return_attention:
            return out, attention_weights
        return out
+ + Args: + config: Optional dictionary with all hyperparameters to override other arguments + input_channels: Number of input channels + seq_length: Length of input sequence (needed for some operations) + num_classes: Number of output classes + cnn_channels: Tuple of CNN output channels for each layer + kernel_sizes: Tuple of kernel sizes for each CNN layer + pool_type: Pooling type ("max" or "avg") + pool_sizes: Pooling sizes for each layer + use_batch_norm: Whether to use batch normalization + activation: Activation function type + dropout: Default dropout rate + cnn_dropout: CNN-specific dropout (if None, uses dropout) + lstm_hidden_dim: LSTM hidden dimension + lstm_num_layers: Number of LSTM layers + bidirectional: Whether LSTM is bidirectional + lstm_dropout: LSTM-specific dropout (if None, uses dropout) + classifier_hidden_dims: List of hidden dimensions for classifier + attention_dim: Attention dimension + attention_type: Type of attention mechanism + multi_head_num: Number of heads for multi-head attention + residual_connections: Whether to use residual connections + layer_normalization: Whether to use layer normalization + weight_init: Weight initialization strategy + """ + super(CNN_LSTM_Classifier_Tunable, self).__init__() + + # Override with config if provided + if config is not None: + # Set all attributes from config + for key, value in config.items(): + if hasattr(self, key): + setattr(self, key, value) + elif key in locals(): + locals()[key] = value + + # Store parameters + self.input_channels = input_channels + self.seq_length = seq_length + self.num_classes = num_classes + self.cnn_channels = cnn_channels + self.kernel_sizes = kernel_sizes + self.pool_type = pool_type + self.pool_sizes = pool_sizes + self.use_batch_norm = use_batch_norm + self.activation_type = activation + self.dropout_rate = dropout + self.cnn_dropout_rate = cnn_dropout if cnn_dropout is not None else dropout + self.lstm_hidden_dim = lstm_hidden_dim + self.lstm_num_layers = 
lstm_num_layers + self.bidirectional = bidirectional + self.lstm_dropout_rate = lstm_dropout if lstm_dropout is not None else dropout + self.classifier_hidden_dims = classifier_hidden_dims + self.residual_connections = residual_connections + self.layer_normalization = layer_normalization + self.weight_init = weight_init + self.attention_type = attention_type + self.multi_head_num = multi_head_num + + # Calculate directions + self.num_directions = 2 if bidirectional else 1 + + # Default attention dimension if not provided + self.attention_dim = attention_dim or self.num_directions * lstm_hidden_dim + + # Precalcular dimensiones de secuencia después de las capas CNN + self.input_seq_length = seq_length + self.output_seq_length = None + + if seq_length is not None: + # Calcular reducción de secuencia por pooling + seq_reduction = 1 + current_length = seq_length + + for pool_size in self.pool_sizes: + current_length = (current_length + pool_size - 1) // pool_size # Ceil division + seq_reduction *= pool_size + + self.output_seq_length = current_length + + # Verificar dimensiones válidas + if self.output_seq_length <= 0: + raise ValueError(f"La secuencia resultante después del pooling es demasiado corta. 
" + f"Secuencia entrada: {seq_length}, reducción: {seq_reduction}") + + # For visualization and explanation + self.cnn_activations = [] + self.lstm_activations = None + self.attention_weights = None + self.gradients = None + self.last_cnn_output = None + self.input = None + + # Create activation function + self.activation = self._get_activation() + + # Create CNN layers + self.cnn_blocks = nn.ModuleList() + in_channels = input_channels + + for i, (out_channels, kernel_size, pool_size) in enumerate(zip(cnn_channels, kernel_sizes, pool_sizes)): + block = nn.ModuleDict() + block["conv"] = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2) + + if use_batch_norm: + block["bn"] = nn.BatchNorm1d(out_channels) + + if layer_normalization: + # Usamos LayerNorm correctamente para normalizar sobre la dimensión de características + block["ln"] = nn.LayerNorm([out_channels]) + + if pool_type == "max": + block["pool"] = nn.MaxPool1d(kernel_size=pool_size) + else: + block["pool"] = nn.AvgPool1d(kernel_size=pool_size) + + block["dropout"] = nn.Dropout(self.cnn_dropout_rate) + + # Determinar si este bloque puede usar conexión residual + # Solo si tienen mismas dimensiones de entrada y salida + if residual_connections and in_channels == out_channels: + block["has_residual"] = True + else: + block["has_residual"] = False + + self.cnn_blocks.append(block) + in_channels = out_channels + + # LSTM layer + self.lstm = nn.LSTM( + input_size=cnn_channels[-1], + hidden_size=lstm_hidden_dim, + num_layers=lstm_num_layers, + batch_first=True, + bidirectional=bidirectional, + dropout=self.lstm_dropout_rate if lstm_num_layers > 1 else 0 + ) + + # Attention mechanism + lstm_output_dim = self.num_directions * lstm_hidden_dim + if attention_type == "basic": + self.attention = nn.Linear(lstm_output_dim, 1) + elif attention_type == "scaled_dot": + self.query = nn.Linear(lstm_output_dim, self.attention_dim) + self.key = nn.Linear(lstm_output_dim, self.attention_dim) 
+ self.value = nn.Linear(lstm_output_dim, lstm_output_dim) + elif attention_type == "multi_head": + self.mha = nn.MultiheadAttention( + embed_dim=lstm_output_dim, + num_heads=multi_head_num, + batch_first=True + ) + self.attention_ln = nn.LayerNorm(lstm_output_dim) + else: + # Caso por defecto para evitar errores + self.attention = nn.Linear(lstm_output_dim, 1) + print(f"ADVERTENCIA: Tipo de atención '{attention_type}' no reconocido. Usando 'basic'.") + + # Classifier + classifier_layers = [] + in_dim = lstm_output_dim + + for hidden_dim in classifier_hidden_dims: + classifier_layers.append(nn.Linear(in_dim, hidden_dim)) + classifier_layers.append(self.activation) + classifier_layers.append(nn.Dropout(dropout)) + in_dim = hidden_dim + + classifier_layers.append(nn.Linear(in_dim, num_classes)) + self.classifier = nn.Sequential(*classifier_layers) + + # Initialize weights + self._initialize_weights() + + + def _get_activation(self): + if self.activation_type == "relu": + return nn.ReLU() + elif self.activation_type == "leaky_relu": + return nn.LeakyReLU(0.1) + elif self.activation_type == "elu": + return nn.ELU() + elif self.activation_type == "gelu": + return nn.GELU() + else: + return nn.ReLU() + + def _initialize_weights(self): + if self.weight_init == "xavier": + for m in self.modules(): + if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif self.weight_init == "kaiming": + for m in self.modules(): + if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + nn.init.zeros_(m.bias) + + def activations_hook(self, grad): + self.gradients = grad + + def reset_activation_storage(self): + self.cnn_activations = [] + self.lstm_activations = None + self.attention_weights = None + self.gradients = None + self.last_cnn_output = None + + def forward(self, x, return_attention=False, 
track_gradients=False): + """ + Forward pass del modelo + + Args: + x: Input tensor de forma (batch, channels, seq_length) + return_attention: Si True, devuelve también weights de atención + track_gradients: Si True, registra hook para gradientes (para explicabilidad) + + Returns: + Logits de clasificación y opcionalmente pesos de atención + """ + self.reset_activation_storage() + self.input = x + + # Guardar dimensiones originales para debugging + batch_size, channels, orig_seq_len = x.shape + + # Verificar entrada coherente con la configuración del modelo + if channels != self.input_channels: + print(f"ADVERTENCIA: Número de canales de entrada ({channels}) " + f"difiere del configurado ({self.input_channels})") + + # Almacenar dimensiones después de cada bloque para debugging + dims_after_each_block = [] + + # Process through CNN blocks + for i, block in enumerate(self.cnn_blocks): + # Guardar entrada para posible conexión residual + residual = x if block["has_residual"] else None + + # Forward pass por operaciones del bloque + x = block["conv"](x) + + if "bn" in block: + x = block["bn"](x) + + if "ln" in block: + # Transponemos correctamente para aplicar layer norm + x_transposed = x.transpose(1, 2) # (batch, seq, channels) + x_normalized = block["ln"](x_transposed) + x = x_normalized.transpose(1, 2) # Volver a (batch, channels, seq) + + x = self.activation(x) + + # Aplicar conexión residual si está disponible para este bloque + if residual is not None: + x = x + residual + + # Aplicar pooling (con la lógica corregida) + if not (i == len(self.cnn_blocks) - 1 and self.residual_connections): + x = block["pool"](x) + + x = block["dropout"](x) + self.cnn_activations.append(x.detach()) + + # Guardar dimensiones actuales + dims_after_each_block.append(tuple(x.shape)) + + cnn_output = x + + # Verificar dimensiones finales CNN + final_seq_len = x.shape[2] + if self.output_seq_length is not None and final_seq_len != self.output_seq_length: + print(f"ADVERTENCIA: Longitud 
secuencia después de CNN ({final_seq_len}) " + f"difiere de la esperada ({self.output_seq_length})") + + if track_gradients and cnn_output.requires_grad: + cnn_output.register_hook(self.activations_hook) + + self.last_cnn_output = cnn_output + + # Reshape for LSTM: (batch, channels, seq) -> (batch, seq, channels) + x = cnn_output.permute(0, 2, 1) + + # LSTM + lstm_out, _ = self.lstm(x) + self.lstm_activations = lstm_out.detach() + + # Declarar vectores que usaremos para todos los tipos de atención + attention_weights = None + context_vector = None + + # Apply attention mechanism según tipo configurado + if self.attention_type == "basic": + attention_scores = self.attention(lstm_out).squeeze(-1) + attention_weights = F.softmax(attention_scores, dim=1) + context_vector = torch.bmm(attention_weights.unsqueeze(1), lstm_out).squeeze(1) + + elif self.attention_type == "scaled_dot": + Q = self.query(lstm_out) + K = self.key(lstm_out) + V = self.value(lstm_out) + + scores = torch.bmm(Q, K.transpose(1, 2)) / np.sqrt(self.attention_dim) + attention_weights = F.softmax(scores, dim=-1) + context_vector = torch.bmm(attention_weights, V).mean(dim=1) + + elif self.attention_type == "multi_head": + # MultiheadAttention devuelve (attn_output, attn_output_weights) + attn_output, attn_output_weights = self.mha(lstm_out, lstm_out, lstm_out) + attention_weights = attn_output_weights + + if self.layer_normalization: + attn_output = self.attention_ln(attn_output + lstm_out) + + context_vector = attn_output.mean(dim=1) + else: + # Caso por defecto: atención uniforme + attention_weights = torch.ones(lstm_out.shape[0], lstm_out.shape[1]).to(lstm_out.device) + attention_weights = attention_weights / lstm_out.shape[1] # Normalizar + context_vector = lstm_out.mean(dim=1) + + # Verificar que attention_weights existe + if attention_weights is None: + attention_weights = torch.ones(lstm_out.shape[0], lstm_out.shape[1]).to(lstm_out.device) + attention_weights = attention_weights / 
lstm_out.shape[1] # Normalizar + + # Guardar pesos de atención para visualización/interpretación + self.attention_weights = attention_weights.detach() + + # Verificar dimensiones del vector de contexto + if context_vector is None: + context_vector = lstm_out.mean(dim=1) + + # Asegurar dimensionalidad correcta: (batch_size, features) + if len(context_vector.shape) > 2: + print(f"ADVERTENCIA: Vector contexto tiene forma inesperada {context_vector.shape}. " + f"Aplicando mean en dim 1.") + context_vector = context_vector.mean(dim=1) + elif len(context_vector.shape) == 1: + context_vector = context_vector.unsqueeze(0) + + # Verificación final + if len(context_vector.shape) != 2: + print(f"ERROR: Vector contexto debe ser 2D pero es {context_vector.shape}") + # Intentar corregir + if len(context_vector.shape) > 2: + context_vector = context_vector.reshape(batch_size, -1) + + # Classification + out = self.classifier(context_vector) + + if return_attention: + return out, attention_weights + return out + + def get_config(self): + """Returns the current configuration as a dictionary""" + return { + "input_channels": self.input_channels, + "seq_length": self.seq_length, + "num_classes": self.num_classes, + "cnn_channels": self.cnn_channels, + "kernel_sizes": self.kernel_sizes, + "pool_type": self.pool_type, + "pool_sizes": self.pool_sizes, + "use_batch_norm": self.use_batch_norm, + "activation": self.activation_type, + "dropout": self.dropout_rate, + "cnn_dropout": self.cnn_dropout_rate, + "lstm_hidden_dim": self.lstm_hidden_dim, + "lstm_num_layers": self.lstm_num_layers, + "bidirectional": self.bidirectional, + "lstm_dropout": self.lstm_dropout_rate, + "classifier_hidden_dims": self.classifier_hidden_dims, + "attention_dim": self.attention_dim, + "attention_type": self.attention_type, + "multi_head_num": self.multi_head_num, + "residual_connections": self.residual_connections, + "layer_normalization": self.layer_normalization, + "weight_init": self.weight_init + } + + def 
count_parameters(self): + """Count and return the number of trainable parameters""" + return sum(p.numel() for p in self.parameters() if p.requires_grad) + + def get_intermediate_outputs(self, x): + """Get all intermediate activations for a given input""" + _ = self.forward(x, track_gradients=True) + return { + "cnn_activations": self.cnn_activations, + "lstm_activations": self.lstm_activations, + "attention_weights": self.attention_weights + } + + def visualize_attention(self, x, return_fig=False): + """Visualize attention weights for a given input""" + try: + import matplotlib.pyplot as plt + + _, attention_weights = self.forward(x, return_attention=True) + + if attention_weights is None: + print("No attention weights available") + return None + + batch_size = attention_weights.size(0) + seq_len = attention_weights.size(1) + + fig, axes = plt.subplots(batch_size, 1, figsize=(10, 2*batch_size)) + if batch_size == 1: + axes = [axes] + + for i, ax in enumerate(axes): + weights = attention_weights[i].cpu().detach().numpy() + ax.bar(range(seq_len), weights) + ax.set_title(f"Sample {i+1}") + ax.set_xlabel("Sequence position") + ax.set_ylabel("Attention weight") + + plt.tight_layout() + + if return_fig: + return fig + plt.show() + return None + except ImportError: + print("matplotlib is required for visualization") + return None + + def interpret(self, x, class_idx=None, methods=None): + """ + Enhanced interpretation method with multiple explainability techniques + + Args: + x: Input data tensor + class_idx: Target class indices to explain (defaults to predicted class) + methods: List of methods to use, options: ['gradcam', 'integrated_gradients', + 'occlusion', 'shap', 'feature_ablation', 'all'] + + Returns: + Dictionary with various interpretability outputs + """ + try: + # Try to import Captum components + from captum.attr import IntegratedGradients, Occlusion, GradientShap, LayerGradCam + except ImportError: + raise ImportError("This method requires the 'captum' 
package. Install with: pip install captum") + + if methods is None: + methods = ['gradcam', 'attention'] # Default methods + if 'all' in methods: + methods = ['gradcam', 'integrated_gradients', 'occlusion', 'shap', + 'feature_ablation', 'attention', 'layer_importance'] + + # Store original training state + was_training = self.training + lstm_was_training = self.lstm.training + + # Set model to evaluation mode for interpretability + self.eval() + self.lstm.train() # needed for CuDNN backward compatibility + + # Base prediction + x.requires_grad_() + self.input = x # Store input for interpretability methods + + logits, attention = self.forward(x, return_attention=True, track_gradients=True) + pred = torch.softmax(logits, dim=1) + + if class_idx is None: + class_idx = pred.argmax(dim=1) + + # Initialize results dictionary + results = { + 'prediction': pred.detach(), + 'class_idx': class_idx, + 'attention_weights': self.attention_weights, + } + + # Apply selected interpretability methods + if 'gradcam' in methods: + for i in range(x.shape[0]): + pred[i, class_idx[i]].backward(retain_graph=True if i < x.shape[0]-1 else False) + + results['feature_importance'] = self.get_feature_importance() + results['temporal_channel_importance'] = self.get_temporal_channel_importance() + results['channel_importance'] = self.get_channel_importance() + results['cnn_activations'] = self.cnn_activations + + # Integrated Gradients + if 'integrated_gradients' in methods: + ig = IntegratedGradients(self.forward_wrapper) + results['integrated_gradients'] = self._compute_integrated_gradients( + ig, x, class_idx) + + # Occlusion analysis + if 'occlusion' in methods: + occlusion = Occlusion(self.forward_wrapper) + results['occlusion'] = self._compute_occlusion(occlusion, x, class_idx) + + # SHAP (GradientSHAP implementation) + if 'shap' in methods: + gradient_shap = GradientShap(self.forward_wrapper) + results['gradient_shap'] = self._compute_gradient_shap(gradient_shap, x, class_idx) + + # 
Feature ablation (sensitivity analysis) + if 'feature_ablation' in methods: + results['feature_ablation'] = self._feature_ablation_analysis(x, class_idx) + + # Layer importance analysis + if 'layer_importance' in methods: + results['layer_importance'] = self._compute_layer_importance(x, class_idx) + + # Restore original training states + self.train(was_training) + self.lstm.train(lstm_was_training) + + # Clean up to avoid memory issues + self.input = None + torch.cuda.empty_cache() + + return results + + def forward_wrapper(self, x): + """Wrapper for Captum compatibility""" + return self.forward(x) + + def get_feature_importance(self): + """ + Grad-CAM temporal over the output of the last CNN block. + Returns tensor (batch, time) + """ + if self.gradients is None or self.last_cnn_output is None: + return None + + pooled_gradients = torch.mean(self.gradients, dim=[0, 2]) # (channels,) + cam = self.last_cnn_output.clone() + + for i in range(cam.shape[1]): + cam[:, i, :] *= pooled_gradients[i] + + heatmap = torch.mean(cam, dim=1).detach() # (batch, time) + + # Apply ReLU to highlight only positive influences + heatmap = F.relu(heatmap) + + # Normalize heatmap for better visualization + if heatmap.max() > 0: + heatmap = heatmap / heatmap.max() + + return heatmap + + def get_channel_importance(self): + """ + Channel importance: (batch, channels) + """ + if self.input is None or self.input.grad is None: + raise ValueError("Input gradients not available. Call interpret() first.") + return self.input.grad.abs().mean(dim=2).detach() + + def get_temporal_channel_importance(self): + """ + Temporal-channel importance: (batch, channels, time) + """ + if self.input is None or self.input.grad is None: + raise ValueError("Input gradients not available. 
Call interpret() first.") + return self.input.grad.abs().detach() + + def _compute_integrated_gradients(self, ig, x, class_idx): + """Compute integrated gradients attribution""" + batch_size = x.shape[0] + attributions = [] + + for i in range(batch_size): + baseline = torch.zeros_like(x[i:i+1]) + attr = ig.attribute( + x[i:i+1], baseline, target=class_idx[i].item(), n_steps=50 + ) + attributions.append(attr) + + return torch.cat(attributions).detach() + + def _compute_occlusion(self, occlusion_algo, x, class_idx): + """Compute occlusion-based feature attribution""" + batch_size = x.shape[0] + attributions = [] + + # Define sliding window parameters for temporal data + window_size = min(5, x.shape[2] // 4) # Adapt window size to input length + + for i in range(batch_size): + attr = occlusion_algo.attribute( + x[i:i+1], + sliding_window_shapes=(1, window_size), + target=class_idx[i].item(), + strides=(1, max(1, window_size // 2)) + ) + attributions.append(attr) + + return torch.cat(attributions).detach() + + def _compute_gradient_shap(self, shap_algo, x, class_idx): + """Compute GradientSHAP attributions""" + batch_size = x.shape[0] + attributions = [] + + for i in range(batch_size): + # Create random baselines (typically 10-50 for good estimates) + baselines = torch.randn(10, *x[i:i+1].shape[1:]) * 0.001 + + # Ensure baselines device matches input + baselines = baselines.to(x.device) + + attr = shap_algo.attribute( + x[i:i+1], baselines=baselines, target=class_idx[i].item() + ) + attributions.append(attr) + + return torch.cat(attributions).detach() + + def _feature_ablation_analysis(self, x, class_idx): + """Analyze model by systematically ablating input features""" + batch_size = x.shape[0] + results = [] + + for i in range(batch_size): + # Store original prediction + with torch.no_grad(): + orig_output = self.forward(x[i:i+1]) + orig_prob = torch.softmax(orig_output, dim=1)[0, class_idx[i]].item() + + # Test ablation of each channel + channel_importance = [] + for 
    def _compute_layer_importance(self, x, class_idx):
        """Compute importance of each layer using Layer GradCAM.

        Builds one LayerGradCam per conv layer found in ``self.cnn_blocks``
        and attributes each sample individually toward its target class.

        Args:
            x: Input tensor (batch first).
            class_idx: Per-sample target class index tensor.

        Returns:
            Dict mapping ``'conv1'``, ``'conv2'``, ... to channel-pooled
            attribution tensors (concatenated over the batch).

        Raises:
            ImportError: when captum is not installed.
        """
        try:
            from captum.attr import LayerGradCam
        except ImportError:
            raise ImportError("This method requires the 'captum' package.")

        batch_size = x.shape[0]
        layer_importance = {}

        # Define layers to analyze - adapted for our new ModuleList structure
        layers = {}
        for i, block in enumerate(self.cnn_blocks):
            # assumes each block is a mapping with a 'conv' entry -- TODO confirm
            layers[f'conv{i+1}'] = block['conv']

        for layer_name, layer in layers.items():
            layer_gradcam = LayerGradCam(self.forward_wrapper, layer)
            layer_attrs = []

            for i in range(batch_size):
                attr = layer_gradcam.attribute(
                    x[i:i+1], target=class_idx[i].item()
                )
                # Process attribution to create a single importance score per sample
                pooled_attr = torch.mean(attr, dim=1)

                layer_attrs.append(pooled_attr)

            layer_importance[layer_name] = torch.cat(layer_attrs).detach()

        return layer_importance

    def visualize_attributions(self, sample_idx, interpretations, time_axis=None,
                               channel_names=None, class_names=None):
        """
        Visualize the various interpretation results

        Lays out a 3x3 grid: input signal, GradCAM, attention (top row);
        channel importance, integrated gradients, feature ablation (middle);
        SHAP, occlusion, prediction summary (bottom). Panels whose data is
        absent from *interpretations* are simply skipped.

        Args:
            sample_idx: Index of the sample to visualize
            interpretations: Dictionary returned by interpret() method
            time_axis: Optional array/list with time points for x-axis
            channel_names: Optional list of channel names
            class_names: Optional list of class names

        Returns:
            The current matplotlib figure.

        Raises:
            ImportError: when matplotlib or numpy is unavailable.
        """
        try:
            import matplotlib.pyplot as plt
            import numpy as np
        except ImportError:
            raise ImportError("This method requires matplotlib and numpy for visualization")

        if not channel_names:
            channel_names = [f'Channel {i}' for i in range(self.input_channels)]

        if not class_names:
            class_idx = interpretations['class_idx'][sample_idx].item()
            class_name = f'Class {class_idx}'
        else:
            class_idx = interpretations['class_idx'][sample_idx].item()
            class_name = class_names[class_idx]

        # Set up figure
        plt.figure(figsize=(15, 12))

        # Original input visualization (top row, first column)
        plt.subplot(3, 3, 1)
        # NOTE(review): relies on self.input still holding the interpreted
        # batch; interpret() clears it afterwards -- confirm call order.
        if self.input is not None:
            input_data = self.input[sample_idx].cpu().detach().numpy()
            if time_axis is not None:
                for i in range(input_data.shape[0]):
                    plt.plot(time_axis, input_data[i], label=channel_names[i])
            else:
                for i in range(input_data.shape[0]):
                    plt.plot(input_data[i], label=channel_names[i])
            plt.legend(loc='best')
            plt.title('Input Signal')
            plt.xlabel('Time')
            plt.ylabel('Value')

        # GradCAM feature importance (top row, second column)
        if 'feature_importance' in interpretations and interpretations['feature_importance'] is not None:
            plt.subplot(3, 3, 2)
            heatmap = interpretations['feature_importance'][sample_idx].cpu().numpy()
            if time_axis is not None:
                plt.plot(time_axis, heatmap)
            else:
                plt.plot(heatmap)
            plt.title('GradCAM Feature Importance')
            plt.xlabel('Time')
            plt.ylabel('Importance')

        # Attention weights (top row, third column)
        if 'attention_weights' in interpretations and interpretations['attention_weights'] is not None:
            plt.subplot(3, 3, 3)
            attention = interpretations['attention_weights'][sample_idx].cpu().numpy()

            if time_axis is not None:
                # Need to match attention time axis to input time axis
                # (account for pooling in the network)
                x_points = np.linspace(time_axis[0], time_axis[-1], len(attention))
                plt.plot(x_points, attention)
            else:
                plt.plot(attention)
            plt.title('Attention Weights')
            plt.xlabel('Time')
            plt.ylabel('Attention')

        # Channel importance (middle row, first column)
        if 'channel_importance' in interpretations and interpretations['channel_importance'] is not None:
            plt.subplot(3, 3, 4)
            ch_importance = interpretations['channel_importance'][sample_idx].cpu().numpy()
            plt.bar(channel_names, ch_importance)
            plt.title('Channel Importance')
            plt.ylabel('Importance')
            plt.xticks(rotation=45)

        # Integrated Gradients (middle row, second column)
        if 'integrated_gradients' in interpretations:
            plt.subplot(3, 3, 5)
            ig_attr = interpretations['integrated_gradients'][sample_idx].cpu().numpy()
            ig_attr_mean = np.mean(ig_attr, axis=0)  # Average across channels for visualization

            if time_axis is not None:
                plt.plot(time_axis, ig_attr_mean)
            else:
                plt.plot(ig_attr_mean)
            plt.title('Integrated Gradients')
            plt.xlabel('Time')
            plt.ylabel('Attribution')

        # Feature Ablation (middle row, third column)
        if 'feature_ablation' in interpretations:
            plt.subplot(3, 3, 6)
            ablation_scores = interpretations['feature_ablation'][sample_idx].cpu().numpy()
            plt.bar(channel_names, ablation_scores)
            plt.title('Feature Ablation Impact')
            plt.ylabel('Probability Change')
            plt.xticks(rotation=45)

        # SHAP values (bottom row, first column)
        if 'gradient_shap' in interpretations:
            plt.subplot(3, 3, 7)
            shap_attr = interpretations['gradient_shap'][sample_idx].cpu().numpy()
            # Visualize average SHAP value over time
            shap_avg = np.mean(shap_attr, axis=0)

            if time_axis is not None:
                plt.plot(time_axis, shap_avg)
            else:
                plt.plot(shap_avg)
            plt.title('GradientSHAP Values')
            plt.xlabel('Time')
            plt.ylabel('SHAP Value')

        # Occlusion analysis (bottom row, second column)
        if 'occlusion' in interpretations:
            plt.subplot(3, 3, 8)
            occlusion_attr = interpretations['occlusion'][sample_idx].cpu().numpy()
            occlusion_avg = np.mean(occlusion_attr, axis=0)

            if time_axis is not None:
                plt.plot(time_axis, occlusion_avg)
            else:
                plt.plot(occlusion_avg)
            plt.title('Occlusion Analysis')
            plt.xlabel('Time')
            plt.ylabel('Attribution')

        # Prediction summary (bottom row, third column)
        plt.subplot(3, 3, 9)
        pred_probs = interpretations['prediction'][sample_idx].cpu().numpy()
        classes = list(range(len(pred_probs)))
        if class_names:
            classes = class_names
        plt.bar(classes, pred_probs)
        plt.title(f'Prediction: {class_name}')
        plt.ylabel('Probability')
        plt.ylim([0, 1])

        plt.tight_layout()
        return plt.gcf()
+ plt.title('Occlusion Analysis') + plt.xlabel('Time') + plt.ylabel('Attribution') + + # Prediction summary (bottom row, third column) + plt.subplot(3, 3, 9) + pred_probs = interpretations['prediction'][sample_idx].cpu().numpy() + classes = list(range(len(pred_probs))) + if class_names: + classes = class_names + plt.bar(classes, pred_probs) + plt.title(f'Prediction: {class_name}') + plt.ylabel('Probability') + plt.ylim([0, 1]) + + plt.tight_layout() + return plt.gcf() + + +# Example of creating a model with custom hyperparameters +def create_model_with_config(**kwargs): + """Helper function to create a model with specified config""" + config = { + "input_channels": 3, + "num_classes": 3, + "cnn_channels": (16, 32, 64), + "kernel_sizes": (5, 3, 3), + "pool_type": "max", + "dropout": 0.1, + "lstm_hidden_dim": 32, + "lstm_num_layers": 1, + "bidirectional": True, + "classifier_hidden_dims": [32], + "attention_type": "basic", + "residual_connections": False, + "layer_normalization": False + } + + # Update config with provided kwargs + config.update(kwargs) + + return CNN_LSTM_Classifier_Tunable(config=config) + diff --git a/src/openECGfunction.py b/src/openECGfunction.py new file mode 100644 index 0000000..64c72e7 --- /dev/null +++ b/src/openECGfunction.py @@ -0,0 +1,39 @@ +import pyedflib +from .lib.ECG_processing import ECGprocessing +import pandas as pd +def openECG(physiological_data_file, patient_id): + + f = pyedflib.EdfReader(physiological_data_file) + + signal_labels = f.getSignalLabels() + print(signal_labels) + + ecg_keywords = ['ecg', 'ekg'] + + idx = None + for i, label in enumerate(signal_labels): + label_clean = label.lower().strip() + + # Check if any ECG keyword is inside the label + if any(keyword in label_clean for keyword in ecg_keywords): + idx = i + break #first ECG channel only + + if idx is None: + raise ValueError("No ECG channel found") + + print("ECG channel:", signal_labels[idx]) + + ecg_signal = f.readSignal(idx) + fs = 
# Respiratory channel groups and the fixed feature layout they produce.
RESP_CHANNEL_GROUPS = ("Abdomen", "Chest", "Nasal", "Flow")
RESP_FEATURE_NAMES = [
    f"{group}_Peakedness_{metric}"
    for group in RESP_CHANNEL_GROUPS
    for metric in ("Max", "Min", "Mean", "Median", "Std")
] + [
    "SpO2_Max",
    "SpO2_Min",
    "SpO2_Mean",
    "SpO2_Std",
    "CET90",
    "ODI_Mean",
    "ODI_deepness",
]
RESP_FEATURE_LENGTH = len(RESP_FEATURE_NAMES)


def _normalize_label(text):
    """Lower-case *text* and collapse every non-alphanumeric run to one space."""
    spaced = [ch if ch.isalnum() else ' ' for ch in str(text).lower()]
    return ' '.join(''.join(spaced).split())


def _split_aliases(raw_aliases):
    """Split a ';'-separated alias string into a set of normalized labels."""
    aliases = set()
    for alias in str(raw_aliases).split(';'):
        if alias:
            aliases.add(_normalize_label(alias))
    return aliases


def _build_resp_alias_groups(channels):
    """Map respiratory group names to alias sets from the channel table.

    The alias rows are taken positionally from the 'resp' category; an empty
    dict is returned when fewer than 7 resp rows exist.
    """
    resp_rows = channels[channels['Category'].eq('resp')].reset_index(drop=True)
    if len(resp_rows) < 7:
        return {}
    row_for_group = {'Abdomen': 0, 'Chest': 1, 'Nasal': 2, 'Flow': 3, 'SpO2': 6}
    return {
        group: _split_aliases(resp_rows.iloc[row]['Channel_Names'])
        for group, row in row_for_group.items()
    }


def _find_resp_group(label, alias_groups):
    """Return the group whose alias set contains *label* (normalized), else None."""
    normalized = _normalize_label(label)
    return next(
        (group for group, aliases in alias_groups.items() if normalized in aliases),
        None,
    )


def _resample_signal(signal, fs, target_fs):
    """Linearly resample *signal* from *fs* to *target_fs* Hz.

    Returns (resampled_array, target_fs); empty input or an already matching
    rate is returned unchanged.
    """
    samples = np.asarray(signal, dtype=float)
    if samples.size == 0 or fs == target_fs:
        return samples, target_fs

    duration = samples.size / fs
    n_target = max(1, int(round(duration * target_fs)))
    source_times = np.linspace(0, duration, samples.size)
    target_times = np.linspace(0, duration, n_target)
    return np.interp(target_times, source_times, samples), target_fs


def _compute_resp_quality(used, hat_br):
    """Quality score for a channel: mean of the 'used' mask when finite,
    otherwise the fraction of finite breathing-rate estimates."""
    used_values = np.asarray(used, dtype=float)
    if used_values.size:
        score = float(np.nanmean(used_values))
        if np.isfinite(score):
            return score
    rates = np.asarray(hat_br, dtype=float)
    return float(np.mean(np.isfinite(rates))) if rates.size else 0.0


def _summarize_peakedness(hat_br):
    """Max/Min/Mean/Median/Std over the finite values of *hat_br*; None if none."""
    values = np.asarray(hat_br, dtype=float)
    values = values[np.isfinite(values)]
    if not values.size:
        return None
    reducers = zip(('Max', 'Min', 'Mean', 'Median', 'Std'),
                   (np.max, np.min, np.mean, np.median, np.std))
    return {name: float(fn(values)) for name, fn in reducers}


def _summarize_spo2(data, fs):
    """SpO2 summary features: rescales fraction-valued traces to percent and
    masks +-2 s around dropout samples before computing statistics."""
    if data.size == 0:
        return {}
    working = np.asarray(data, dtype=float).copy()
    # NOTE(review): the < 2 check assumes fraction-valued SpO2; the 1.055
    # rescale constant is unexplained -- confirm against the data source.
    if np.nanmax(working) < 2:
        working = np.round((working / 1.055) * 100)

    desaturation_mask = working.copy()
    threshold = 0.7
    for index, value in enumerate(working):
        if value < threshold:
            lo = int(max(0, index - fs * 2))
            hi = int(min(working.size, index + fs * 2))
            desaturation_mask[lo:hi] = np.nan

    cet90 = float(np.count_nonzero(desaturation_mask < 90) / max(working.size, 1))
    valid = desaturation_mask[np.isfinite(desaturation_mask)]
    if valid.size == 0:
        return {'CET90': cet90}

    odi_mean, odi_deepness = Resp_features.ODI_application(desaturation_mask, fs, plotflag=False, subjet=1)
    return {
        'SpO2_Max': float(np.max(valid)),
        'SpO2_Min': float(np.min(valid)),
        'SpO2_Mean': float(np.mean(valid)),
        'SpO2_Std': float(np.std(valid)),
        'CET90': cet90,
        'ODI_Mean': float(odi_mean),
        'ODI_deepness': float(odi_deepness),
    }


def processResp(physiological_data, physiological_fs, csv_path):
    """Extract the respiratory feature vector from a dict of channels.

    For each channel group, the highest-quality channel's peakedness summary
    wins; SpO2 channels feed the saturation/ODI features. Returns a float32
    vector ordered as RESP_FEATURE_NAMES (zeros where nothing matched).
    """
    alias_groups = _build_resp_alias_groups(pd.read_csv(csv_path))
    results = dict.fromkeys(RESP_FEATURE_NAMES, 0.0)
    best_quality = dict.fromkeys(RESP_CHANNEL_GROUPS, -np.inf)

    for label, signal in physiological_data.items():
        if label not in physiological_fs:
            continue

        group_name = _find_resp_group(label, alias_groups)
        if group_name is None:
            continue

        resampled, fs = _resample_signal(signal, physiological_fs[label], 25)
        resampled = np.nan_to_num(resampled, nan=0.0, posinf=0.0, neginf=0.0)

        if group_name == 'SpO2':
            results.update(_summarize_spo2(resampled, fs))
            continue

        try:
            hat_br, _, _, used = Resp_features.peakedness_application(
                resampled,
                stage=label,
                plotflag=False,
                subjet=label,
            )
        except Exception:
            continue  # best-effort: a failing channel contributes nothing

        summary = _summarize_peakedness(hat_br)
        if summary is None:
            continue

        quality = _compute_resp_quality(used, hat_br)
        if quality > best_quality[group_name]:
            best_quality[group_name] = quality
            for metric_name, metric_value in summary.items():
                results[f'{group_name}_Peakedness_{metric_name}'] = metric_value

    return np.array([results[name] for name in RESP_FEATURE_NAMES], dtype=np.float32)
+results = pd.DataFrame() +for h in hospital: + results = pd.concat([results, pd.read_csv(f"results_summaryEEG_{h}.csv")], ignore_index=True) + +demographics = pd.read_csv(os.path.join('C:/BSICoS/CincChallenge2026/CincChallenge_2026/data/training_set', "demographics.csv")) +demographics = pd.concat([demographics, pd.read_csv(os.path.join('C:/BSICoS/CincChallenge2026/CincChallenge_2026/data/supplementary_set', "demographics.csv"))], ignore_index=True) + +for index, row in results.iterrows(): + patient_id = row['Patient_ID'] + hospital_id = row['File'][4:9] # Asumiendo que los primeros 5 caracteres del nombre del archivo indican el hospital + demographics_row = demographics[(demographics['BDSPPatientID'] == patient_id) & (demographics['SiteID'] == hospital_id)] + if not demographics_row.empty: + cognitive_impairment = demographics_row['Cognitive_Impairment'].values[0] + time_to_event = demographics_row['Time_to_Event'].values[0] + results.at[index, 'Hospital'] = hospital_id + results.at[index, 'CognitiveImpairment'] = cognitive_impairment + results.at[index, 'Time_to_Event'] = time_to_event + else: + results.at[index, 'Hospital'] = hospital_id + results.at[index, 'CognitiveImpairment'] = np.nan # O cualquier valor que indique que no se encontró información + results.at[index, 'Time_to_Event'] = np.nan + +df = pd.DataFrame(results) +# Agrupar por electrodo +for elec in results['Channel'].unique(): + subset = results[results['Channel'] == elec] + + print(subset.Hospital.unique()) + # Hacer un boxplot de cada característica que separe entre pacientes con congnitive impairment y sin él + for col in subset.columns[3:-2]: + print(col) + + fig = px.box(subset, + x='CognitiveImpairment', + y=col, + color='CognitiveImpairment', + notched=True, + points="all", + hover_data=['Patient_ID', 'Channel'], + title=f"{elec} - Comparativa de {col} según Estado Cognitivo") + + fig.update_layout(template="plotly_white") + 
for hospital in ['I0002', 'I0006', 'I0004', 'I0007', 'S0001']:
    print(f"Procesando hospital: {hospital}")

    # Training-set sites vs supplementary-set sites live under different roots.
    if hospital == 'I0002' or hospital == 'I0006' or hospital == "S0001":
        datapath = 'data/training_set/Physiological_data/' + hospital
    else:
        datapath = 'data/supplementary_set/Physiological_data/' + hospital

    channels = pd.read_csv("notebooks/channel_table.csv")
    selectEEG = channels[channels['Category'].isin(['eeg'])]
    demographics = pd.read_csv(os.path.join('data/training_set', "demographics.csv"))
    selectresp = channels[channels['Category'].isin(['resp'])]
    selectECG = channels[channels['Category'].isin(['ecg'])]

    lista_dir = os.listdir(datapath)
    results = []

    for file in lista_dir:
        edf = helper_code.edfio.read_edf(os.path.join(datapath, file))

        id = file[9:-10]  # patient id embedded in the file name
        selEEG = []
        selECG = []
        selResp = []
        labels = []
        data = []

        # --- ECG: keep the first matching channel, unfiltered ---
        HayECG = False
        for i, sig in enumerate(edf.signals):
            for index in selectECG.index:
                if sig.label.lower() in selectECG['Channel_Names'][index].lower():
                    HayECG = True
                    print(f"Canal seleccionado: {sig.label}")
                    selECG.append([i, sig])
                    labels.append(sig.label)
                    data.append(sig.data)
                    break

        # --- Respiratory channels ---
        HayResp = False
        for i, sig in enumerate(edf.signals):
            for index in selectresp.index:
                if sig.label.lower() in selectresp['Channel_Names'][index].lower():
                    HayResp = True
                    fs = sig.sampling_frequency
                    if sig.label == "O2":
                        # NOTE(review): "O2" also matches the EEG electrode; it is
                        # flagged but never appended here, yet an 'O2_resp' label is
                        # referenced later -- confirm intent.
                        print(f"Warning: {sig.label} is detected as respiratory signal but has a sampling frequency higher than 100 Hz. Check the data.")
                    else:
                        print(f"Canal seleccionado: {sig.label}")
                        selResp.append([i, sig])
                        labels.append(sig.label)
                        data.append(sig.data)
                    break

        # --- EEG channels (collected for the bipolar derivations below) ---
        HayEEG = False
        for i, sig in enumerate(edf.signals):
            for index in selectEEG.index:
                if sig.label.lower() in selectEEG['Channel_Names'][index].lower():
                    print(f"Canal seleccionado: {sig.label}")
                    selEEG.append([i, sig])
                    HayEEG = True
                    break

        if HayEEG and HayECG and HayResp:
            # Bipolar EEG derivations, built only when the reference pairs exist.
            Bipolar = pd.DataFrame()
            if all(label in labels for label in ["F3", "F4", "M1", "M2"]):
                Bipolar['F3-M2'] = edf.signals[edf.labels.index("F3")].data - edf.signals[edf.labels.index("M2")].data
                Bipolar['F4-M1'] = edf.signals[edf.labels.index("F4")].data - edf.signals[edf.labels.index("M1")].data
            if all(label in labels for label in ["C3", "C4", "M1", "M2"]):
                Bipolar['C3-M2'] = edf.signals[edf.labels.index("C3")].data - edf.signals[edf.labels.index("M2")].data
                Bipolar['C4-M1'] = edf.signals[edf.labels.index("C4")].data - edf.signals[edf.labels.index("M1")].data
            if all(label in labels for label in ["O2", "O1", "M1", "M2"]):
                # BUGFIX: the original paired the O1 data with the 'O2-M2' column
                # and vice versa; derive each column from its own electrode, as
                # done for the F and C pairs above.
                Bipolar['O2-M2'] = edf.signals[edf.labels.index("O2")].data - edf.signals[edf.labels.index("M2")].data
                Bipolar['O1-M1'] = edf.signals[edf.labels.index("O1")].data - edf.signals[edf.labels.index("M1")].data

            if not Bipolar.empty:
                for col in Bipolar.columns:
                    # assumes every EEG channel shares the O2 sampling rate -- TODO confirm
                    fs = edf.signals[edf.labels.index("O2")].sampling_frequency
                    time = np.linspace(0, len(Bipolar[col]) / fs, len(Bipolar[col]))
                    fil = EEG_functions.butter_bandpass_filter(Bipolar[col], lowcut=0.3, highcut=35, fs=fs, order=4)
                    norm = (fil - np.mean(fil)) / np.std(fil)
                    data.append(norm)  # centered, unit-variance signal
                    labels.append(col)
            else:
                for i, (idx, sig) in enumerate(selEEG):
                    fs = sig.sampling_frequency
                    time = np.linspace(0, len(sig.data) / fs, len(sig.data))
                    fil = EEG_functions.butter_bandpass_filter(sig.data, lowcut=0.3, highcut=35, fs=fs, order=4)
                    norm = (fil - np.mean(fil)) / np.std(fil)
                    labels.append(sig.label)
                    data.append(norm)

            demographicsID = demographics[demographics['BDSPPatientID'] == int(id)]
            print(demographicsID)

            # One column per (channel, recording hour): e.g. "C3-M2_h0".
            columnashoras = []
            for elec in labels:
                for h in np.floor(np.arange(0, len(sig.data) / fs / 3600, 1)):
                    columnashoras.append(elec + f"_h{int(h)}")
            epochs5min = pd.DataFrame(columns=columnashoras)

            for i, elec in enumerate(labels):
                # Check fs of the current channel
                if elec == 'O2_resp':
                    fs = edf.signals[edf.labels.index('O2')].sampling_frequency
                else:
                    # NOTE(review): bipolar labels like 'F3-M2' are not in
                    # edf.labels, so this lookup raises ValueError for them --
                    # confirm how bipolar channels are meant to resolve fs.
                    fs = edf.signals[edf.labels.index(labels[i])].sampling_frequency

                if fs != 200:
                    # Resample to the common 200 Hz grid by linear interpolation.
                    duration = len(data[i]) / fs
                    time_original = np.linspace(0, duration, len(data[i]))
                    num_samples_target = int(duration * 200)
                    time_target = np.linspace(0, duration, num_samples_target)
                    data[i] = np.interp(time_target, time_original, data[i])
                    fs = 200  # update to the target rate after resampling

                epoch_length = 300  # seconds per epoch (= 5 min)
                epochs = EEG_functions.create_epochs(data[i], fs, epoch_duration=epoch_length)

                # Keep the first 5 minutes (exactly one 300 s epoch) of every hour.
                for h in np.floor(np.arange(0, len(epochs) * epoch_length / 3600, 1)):
                    start_epoch = int(h * 3600 / epoch_length)
                    end_epoch = int((h * 3600 + 5 * 60) / epoch_length)
                    if end_epoch > len(epochs):
                        end_epoch = len(epochs)
                    c = elec + '_h' + str(int(h))
                    epochs5min.loc[:, c] = epochs[start_epoch:end_epoch][0]

            # Persist only complete segmentations (300 s at 200 Hz = 60000 samples).
            if epochs5min.shape[1] > 0 and epochs5min.shape[0] == 60000:
                epochs5min.to_parquet(os.path.join('X:/bsicos01/__comun/Physionet/Data5min', f"{id}.parquet"))
            else:
                print(f"Error: No channels in epochs5min for file {file}")
                # stop program if no channels are processed
                raise ValueError(f"No channels in epochs5min for file {file}")
# Progress bar state for run_model (initialized lazily)
RUN_MODEL_PBAR = None
RUN_MODEL_PBAR_TOTAL = None
ORIGINAL_PRINT = builtins.print
PRINT_FILTER_ACTIVE = False
RUN_PROGRESS_LINE_RE = re.compile(r'^-\s+\d+/\d+:\s')
RENAME_RULES_CACHE = {}
MAX_TRAIN_WORKERS = max(1, min(4, os.cpu_count() or 1))


def build_training_metadata_cache(patient_data_file):
    """Load the demographics CSV once and index it two ways:
    (patient, session) -> row dict, and patient -> diagnosis label."""
    demographics_cache = {}
    diagnosis_cache = {}
    for row in pd.read_csv(patient_data_file).to_dict('records'):
        patient_id = row[HEADERS['bids_folder']]
        demographics_cache[(patient_id, row[HEADERS['session_id']])] = row
        diagnosis_cache[patient_id] = load_label(row)
    return demographics_cache, diagnosis_cache


def get_rename_rules(csv_path):
    """Memoized load of the channel rename rules, keyed by absolute CSV path."""
    cache_key = os.path.abspath(csv_path)
    rules = RENAME_RULES_CACHE.get(cache_key)
    if rules is None:
        rules = load_rename_rules(cache_key)
        RENAME_RULES_CACHE[cache_key] = rules
    return rules


def _coerce_feature_vector(features):
    """Flatten *features* to a 1-D float32 vector with NaN/inf replaced by 0."""
    flat = np.asarray(features, dtype=np.float32).reshape(-1)
    return np.nan_to_num(flat, nan=0.0, posinf=0.0, neginf=0.0)


def _extract_optional_features(extractor, expected_length, *args, **kwargs):
    """Run *extractor*, coerce its output, and enforce the expected length.

    Raises:
        ValueError: when the extractor yields the wrong number of features.
    """
    vector = _coerce_feature_vector(extractor(*args, **kwargs))
    if vector.size == expected_length:
        return vector
    raise ValueError(
        f"{extractor.__name__} returned {vector.size} features; expected {expected_length}."
    )


def extract_extended_physiological_features(physiological_data, physiological_fs, csv_path=DEFAULT_CSV_PATH):
    """Base physiological features followed by the resp and EEG feature blocks.

    Each optional block falls back to zeros when its extractor fails, so one
    bad modality never aborts feature extraction.
    """
    base_features = _coerce_feature_vector(
        extract_physiological_features(physiological_data, physiological_fs, csv_path=csv_path)
    )

    def _best_effort(extractor, length):
        # Best-effort: any failure in an optional extractor yields zeros.
        try:
            return _extract_optional_features(
                extractor, length, physiological_data, physiological_fs, csv_path=csv_path
            )
        except Exception:
            return np.zeros(length, dtype=np.float32)

    resp_features = _best_effort(processResp, RESP_FEATURE_LENGTH)
    eeg_features = _best_effort(processEEG, EEG_FEATURE_LENGTH)

    return np.hstack([base_features, resp_features, eeg_features]).astype(np.float32)


def process_training_record(record, data_folder, demographics_cache, diagnosis_cache, csv_path):
    """Build (patient_id, feature_vector, label, message) for one record.

    On any failure the feature/label slots are None and *message* explains the
    skip; the caller decides how to report it.
    """
    patient_id = record[HEADERS['bids_folder']]
    site_id = record[HEADERS['site_id']]
    session_id = record[HEADERS['session_id']]

    try:
        demographic_features = extract_demographic_features(
            demographics_cache.get((patient_id, session_id), {})
        )

        edf_path = os.path.join(
            data_folder,
            PHYSIOLOGICAL_DATA_SUBFOLDER,
            site_id,
            f"{patient_id}_ses-{session_id}.edf"
        )
        if not os.path.exists(edf_path):
            return patient_id, None, None, f"Missing physiological data for {patient_id}. Skipping..."

        signal_data, signal_fs = load_signal_data(edf_path)
        physiological_features = extract_extended_physiological_features(
            signal_data,
            signal_fs,
            csv_path=csv_path
        )

        annotations_path = os.path.join(
            data_folder,
            ALGORITHMIC_ANNOTATIONS_SUBFOLDER,
            site_id,
            f"{patient_id}_ses-{session_id}_caisr_annotations.edf"
        )
        if os.path.exists(annotations_path):
            annotations, _ = load_signal_data(annotations_path)
            algorithmic_features = extract_algorithmic_annotations_features(annotations)
        else:
            # No CAISR annotations for this session: use a zero block.
            algorithmic_features = np.zeros(12, dtype=np.float32)

        label = diagnosis_cache.get(patient_id)
        if label == 0 or label == 1:
            feature_vector = np.hstack([demographic_features, physiological_features, algorithmic_features])
            return patient_id, feature_vector, label, None
        return patient_id, None, None, f"Invalid label for {patient_id}. Skipping..."

    except Exception as e:
        return patient_id, None, None, f"Error processing {patient_id}: {e}"


def _close_run_model_pbar():
    """Close the shared run_model progress bar if one is open."""
    global RUN_MODEL_PBAR
    if RUN_MODEL_PBAR is None:
        return
    RUN_MODEL_PBAR.close()
    RUN_MODEL_PBAR = None


def _install_run_print_filter():
    """Replace builtins.print with one that drops '- i/N:' progress lines."""
    global PRINT_FILTER_ACTIVE
    if PRINT_FILTER_ACTIVE:
        return

    def _filtered_print(*args, **kwargs):
        joined = kwargs.get('sep', ' ').join(str(arg) for arg in args) if args else ''
        if not RUN_PROGRESS_LINE_RE.match(joined):
            return ORIGINAL_PRINT(*args, **kwargs)

    builtins.print = _filtered_print
    PRINT_FILTER_ACTIVE = True


def _restore_print():
    """Undo _install_run_print_filter(), putting the real print back."""
    global PRINT_FILTER_ACTIVE
    if PRINT_FILTER_ACTIVE:
        builtins.print = ORIGINAL_PRINT
        PRINT_FILTER_ACTIVE = False


# Make sure the progress bar is closed and print() restored at interpreter exit.
atexit.register(_close_run_model_pbar)
atexit.register(_restore_print)
features = list() labels = list() - - pbar = tqdm(range(num_records), desc="Extracting Features", unit="record", disable=not verbose) - for i in pbar: - try: - # Extract identifiers for this specific record - record = patient_metadata_list[i] - patient_id = record[HEADERS['bids_folder']] - site_id = record[HEADERS['site_id']] - session_id = record[HEADERS['session_id']] + with ThreadPoolExecutor(max_workers=MAX_TRAIN_WORKERS) as executor: + results = executor.map( + lambda record: process_training_record( + record, + data_folder, + demographics_cache, + diagnosis_cache, + csv_path + ), + patient_metadata_list + ) + + pbar = tqdm(results, total=num_records, desc="Extracting Features", unit="record", disable=not verbose) + for patient_id, feature_vector, label, message in pbar: if verbose: pbar.set_postfix({"patient": patient_id}) - # Load the patient data. - patient_data_file = os.path.join(data_folder, DEMOGRAPHICS_FILE) - patient_data = load_demographics(patient_data_file, patient_id, session_id) - demographic_features = extract_demographic_features(patient_data) - - # Load signal data. - - # Load the physiological signal. - physiological_data_file = os.path.join(data_folder, PHYSIOLOGICAL_DATA_SUBFOLDER, site_id, f"{patient_id}_ses-{session_id}.edf") - # --- Check if the file actually exists before proceeding --- - if not os.path.exists(physiological_data_file): - if verbose: - print(f" ! Missing physiological data for {patient_id}. Skipping...") - continue # skip record - physiological_data, physiological_fs = load_signal_data(physiological_data_file) - physiological_features = extract_physiological_features(physiological_data, physiological_fs, csv_path=csv_path) # This function can rename, re-reference, resample, etc. the signal data. - - # Load the algorithmic annotations. 
- algorithmic_annotations_file = os.path.join(data_folder, ALGORITHMIC_ANNOTATIONS_SUBFOLDER, site_id, f"{patient_id}_ses-{session_id}_caisr_annotations.edf") - algorithmic_annotations, algorithmic_fs = load_signal_data(algorithmic_annotations_file) - algorithmic_features = extract_algorithmic_annotations_features(algorithmic_annotations) - - # Load the human annotations; these data will not be available in the hidden validation and test sets. - human_annotations_file = os.path.join(data_folder, HUMAN_ANNOTATIONS_SUBFOLDER, site_id, f"{patient_id}_ses-{session_id}_expert_annotations.edf") - human_annotations, human_fs = load_signal_data(human_annotations_file) - human_features = extract_human_annotations_features(human_annotations) - - # Load the diagnoses; these data will not be available in the hidden validation and test sets. - diagnosis_file = os.path.join(data_folder, DEMOGRAPHICS_FILE) - label = load_diagnoses(diagnosis_file, patient_id) - - # Store the features and labels, but - # the human annotations are not available on the hidden validation and test sets, but you - # may want to consider how to use them for training. - if label == 0 or label == 1: - features.append(np.hstack([demographic_features, physiological_features, algorithmic_features])) - labels.append(label) + if message is not None: + tqdm.write(f" ! {message}") + continue - if 'physiological_data' in locals(): del physiological_data - if 'algorithmic_annotations' in locals(): del algorithmic_annotations + features.append(feature_vector) + labels.append(label) - except Exception as e: - # If an error occurs (e.g., a record is corrupted), log it and move to the next - tqdm.write(f" !!! 
Error processing record {i+1} ({patient_id}): {e}") - continue - - pbar.close() + pbar.close() features = np.asarray(features, dtype=np.float32) labels = np.asarray(labels, dtype=bool) + if features.size == 0 or features.ndim != 2 or features.shape[0] == 0: + raise ValueError('No valid training samples were extracted. Review feature extraction logs for the skipped records.') + # Train the models on the features. if verbose: print('Training the model on the data...') @@ -150,6 +282,9 @@ def train_model(data_folder, model_folder, verbose, csv_path=DEFAULT_CSV_PATH): # Load your trained models. This function is *required*. You should edit this function to add your code, but do *not* change the # arguments of this function. If you do not train one of the models, then you can return None for the model. def load_model(model_folder, verbose): + if verbose: + _install_run_print_filter() + model_filename = os.path.join(model_folder, 'model.sav') model = joblib.load(model_filename) return model @@ -157,6 +292,8 @@ def load_model(model_folder, verbose): # Run your trained model. This function is *required*. You should edit this function to add your code, but do *not* change the # arguments of this function. def run_model(model, record, data_folder, verbose): + global RUN_MODEL_PBAR, RUN_MODEL_PBAR_TOTAL + # Load the model. model = model['model'] @@ -165,6 +302,27 @@ def run_model(model, record, data_folder, verbose): site_id = record[HEADERS['site_id']] session_id = record[HEADERS['session_id']] + # Initialize tqdm progress bar lazily so it advances across run_model calls. 
+ if verbose and RUN_MODEL_PBAR is None: + patient_data_file = os.path.join(data_folder, DEMOGRAPHICS_FILE) + try: + RUN_MODEL_PBAR_TOTAL = len(find_patients(patient_data_file)) + except Exception: + RUN_MODEL_PBAR_TOTAL = None + + RUN_MODEL_PBAR = tqdm( + total=RUN_MODEL_PBAR_TOTAL, + desc="Running Model", + unit="record", + leave=True, + file=sys.stdout, + delay=0.5, + disable=not verbose + ) + + if verbose and RUN_MODEL_PBAR is not None: + RUN_MODEL_PBAR.set_postfix({"patient": patient_id}) + # Load the patient data. patient_data_file = os.path.join(data_folder, DEMOGRAPHICS_FILE) patient_data = load_demographics(patient_data_file, patient_id, session_id) @@ -175,10 +333,9 @@ def run_model(model, record, data_folder, verbose): if os.path.exists(phys_file): phys_data, phys_fs = load_signal_data(phys_file) # Ensure csv_path is accessible or defined - physiological_features = extract_physiological_features(phys_data, phys_fs) + physiological_features = extract_extended_physiological_features(phys_data, phys_fs) else: - # Fallback to zeros if file is missing (length 49) - physiological_features = np.zeros(49) + physiological_features = np.zeros(49 + RESP_FEATURE_LENGTH + EEG_FEATURE_LENGTH, dtype=np.float32) # Load Algorithmic Annotations algo_file = os.path.join(data_folder, ALGORITHMIC_ANNOTATIONS_SUBFOLDER, site_id, f"{patient_id}_ses-{session_id}_caisr_annotations.edf") @@ -187,7 +344,7 @@ def run_model(model, record, data_folder, verbose): algorithmic_features = extract_algorithmic_annotations_features(algo_data) else: # Fallback to zeros (length 12) - algorithmic_features = np.zeros(12) + algorithmic_features = np.zeros(12, dtype=np.float32) features = np.hstack([demographic_features, physiological_features, algorithmic_features]).reshape(1, -1) @@ -195,6 +352,12 @@ def run_model(model, record, data_folder, verbose): binary_output = model.predict(features)[0] probability_output = model.predict_proba(features)[0][1] + if verbose and RUN_MODEL_PBAR is not None: 
+ RUN_MODEL_PBAR.update(1) + if RUN_MODEL_PBAR_TOTAL is not None and RUN_MODEL_PBAR.n >= RUN_MODEL_PBAR_TOTAL: + RUN_MODEL_PBAR.close() + RUN_MODEL_PBAR = None + return binary_output, probability_output ################################################################################ @@ -257,7 +420,7 @@ def extract_physiological_features(physiological_data, physiological_fs, csv_pat # Step 1: Load rules and standardize names # Note: Use script-relative path or absolute path for robustness - rename_rules = load_rename_rules(os.path.abspath(csv_path)) + rename_rules = get_rename_rules(csv_path) rename_map, cols_to_drop = standardize_channel_names_rename_only(original_labels, rename_rules) # Step 2: Apply renaming to BOTH signals and their corresponding FS @@ -543,4 +706,4 @@ def count_discrete_events(key): def save_model(model_folder, model): d = {'model': model} filename = os.path.join(model_folder, 'model.sav') - joblib.dump(d, filename, protocol=0) \ No newline at end of file + joblib.dump(d, filename, protocol=0)