diff --git a/scripts/build_species_list.py b/scripts/build_species_list.py
new file mode 100644
index 0000000..27ff142
--- /dev/null
+++ b/scripts/build_species_list.py
@@ -0,0 +1,195 @@
+#!/usr/bin/env python
+"""
+Bridge script: reads a DwC-A file to extract verbatimScientificName (or another
+name column) and joins it onto an existing annotations CSV produced by the
+ami-dataset clean-dataset step.
+
+This is needed because load_dwca_data() in src/dataset_tools/utils.py does not
+include verbatimScientificName in its column selection.
+
+Outputs:
+  - Augmented annotations CSV with the label column added
+  - Category map JSON (species_name -> integer_id)
+"""
+
+import argparse
+import json
+import sys
+
+import pandas as pd
+from dwca.read import DwCAReader
+
+
+def print_dwca_summary(occ_df: pd.DataFrame) -> None:
+    """Print summary statistics for the DwC-A occurrence data."""
+    print("\n=== DwC-A Summary ===")
+    print(f" Total occurrences: {len(occ_df)}")
+
+    for col in [
+        "verbatimScientificName",
+        "scientificName",
+        "species",
+        "family",
+        "order",
+    ]:
+        if col in occ_df.columns:
+            n_unique = occ_df[col].dropna().nunique()
+            n_missing = occ_df[col].isna().sum()
+            print(f" Unique {col}: {n_unique} (missing: {n_missing})")
+
+    if "eventDate" in occ_df.columns:
+        dates = pd.to_datetime(occ_df["eventDate"], errors="coerce").dropna()
+        if len(dates) > 0:
+            print(f" Date range: {dates.min().date()} to {dates.max().date()}")
+
+    for coord_col in ["decimalLatitude", "decimalLongitude"]:
+        if coord_col in occ_df.columns:
+            vals = pd.to_numeric(occ_df[coord_col], errors="coerce").dropna()
+            if len(vals) > 0:
+                lo = "{:.2f}".format(vals.min())
+                hi = "{:.2f}".format(vals.max())
+                print(f" {coord_col}: {lo} to {hi}")
+
+    print()
+
+
+def report_missing_labels(occ_df: pd.DataFrame, label_column: str) -> None:
+    """Report occurrences where the label column is missing/empty."""
+    missing_mask = occ_df[label_column].isna() | (
+        occ_df[label_column].astype(str).str.strip() == ""
+    )
+    n_missing = missing_mask.sum()
+
+    if n_missing == 0:
+        print(f"All occurrences have a value for '{label_column}'.")
+        return
+
+    msg = f"\nWARNING: {n_missing} occurrences missing '{label_column}'"
+    print(msg)
+    alt_cols = [
+        c
+        for c in [
+            "scientificName",
+            "species",
+            "acceptedScientificName",
+            "verbatimScientificName",
+        ]
+        if c in occ_df.columns and c != label_column
+    ]
+
+    missing_rows = occ_df[missing_mask].head(20)
+    for _, row in missing_rows.iterrows():
+        alt_info = ", ".join(f"{c}={row.get(c, 'N/A')}" for c in alt_cols)
+        print(f" coreid={row['id']}: {alt_info}")
+
+    if n_missing > 20:
and {n_missing - 20} more") + print() + + +def main(): + parser = argparse.ArgumentParser( + description="Augment annotations CSV with species names from DwC-A" + ) + parser.add_argument( + "--dwca-file", required=True, help="Path to Darwin Core Archive zip file" + ) + parser.add_argument( + "--annotations-csv", + required=True, + help="CSV from clean-dataset step (must have 'coreid' column)", + ) + parser.add_argument( + "--output-csv", required=True, help="Path to save augmented annotations CSV" + ) + parser.add_argument( + "--category-map-json", required=True, help="Path to save category map JSON" + ) + parser.add_argument( + "--label-column", + default="verbatimScientificName", + help="DwC-A column to use as species label (default: verbatimScientificName)", + ) + args = parser.parse_args() + + # --- Read DwC-A --- + print(f"Reading DwC-A: {args.dwca_file}") + with DwCAReader(args.dwca_file) as dwca: + occ_df = dwca.pd_read( + "occurrence.txt", parse_dates=True, on_bad_lines="skip", low_memory=False + ) + media_df = dwca.pd_read( + "multimedia.txt", parse_dates=True, on_bad_lines="skip", low_memory=False + ) + + print(f" Occurrences: {len(occ_df)}, Multimedia records: {len(media_df)}") + print_dwca_summary(occ_df) + + # --- Check label column exists --- + if args.label_column not in occ_df.columns: + available = [ + c for c in occ_df.columns if "name" in c.lower() or "species" in c.lower() + ] + print(f"ERROR: Column '{args.label_column}' missing from occurrence data.") + print(f" Available name-related columns: {available}") + sys.exit(1) + + report_missing_labels(occ_df, args.label_column) + + # --- Build coreid -> label mapping --- + # The 'id' column in occurrence.txt corresponds to 'coreid' in multimedia/annotations + name_map = occ_df[["id", args.label_column]].drop_duplicates() + name_map = name_map.rename(columns={"id": "coreid"}) + + # --- Read annotations CSV --- + print(f"Reading annotations: {args.annotations_csv}") + annotations = pd.read_csv(args.annotations_csv) + print(f" Rows: {len(annotations)}") + + if "coreid" not in annotations.columns: + print("ERROR: annotations CSV does not have a 'coreid' column.") + print(f" Available columns: {list(annotations.columns)}") + sys.exit(1) + + # --- Join --- + # Ensure coreid types match for the merge + annotations["coreid"] = annotations["coreid"].astype(str) + name_map["coreid"] = name_map["coreid"].astype(str) + + merged = annotations.merge(name_map, on="coreid", how="left") + + # --- Drop rows with missing labels --- + missing_mask = merged[args.label_column].isna() | ( + merged[args.label_column].astype(str).str.strip() == "" + ) + n_dropped = missing_mask.sum() + if n_dropped > 0: + print(f"WARNING: Dropping {n_dropped} rows with missing '{args.label_column}'") + merged = merged[~missing_mask].copy() + + print(f" Rows after join and filter: {len(merged)}") + + # --- Save augmented CSV --- + merged.to_csv(args.output_csv, index=False) + print(f"Saved augmented annotations: {args.output_csv}") + + # --- Build and save category map --- + species_list = sorted(merged[args.label_column].unique()) + category_map = {name: idx for idx, name in enumerate(species_list)} + + with open(args.category_map_json, "w") as f: + json.dump(category_map, f, indent=2) + print(f"Saved category map ({len(category_map)} species): {args.category_map_json}") + + # --- Per-species image count summary --- + print(f"\n=== Per-species image counts ({args.label_column}) ===") + counts = merged[args.label_column].value_counts().sort_index() + for species, 
diff --git a/scripts/train_species_classifier.sh b/scripts/train_species_classifier.sh
new file mode 100755
index 0000000..ee295f5
--- /dev/null
+++ b/scripts/train_species_classifier.sh
@@ -0,0 +1,309 @@
+#!/bin/bash
+#
+# Single-script species classifier training pipeline.
+# Orchestrates the full flow from a GBIF Darwin Core Archive (DwC-A)
+# to a trained ConvNeXt-Tiny species classification model.
+#
+# Usage:
+#   bash scripts/train_species_classifier.sh
+#
+# Prerequisites:
+#   - uv installed and project dependencies synced (uv sync --extra dev)
+#   - DwC-A zip file at DWCA_FILE path below
+#
+
+set -euo pipefail
+
+# ============================================================
+# Configuration — edit these paths as needed
+# ============================================================
+DWCA_FILE="0007113-260208012135463.zip"
+DATASET_PATH="dataset-out"
+OUTPUT_DIR="output/species_classifier"
+LABEL_COLUMN="verbatimScientificName"
+MIN_INSTANCES=3
+# MIN_INSTANCES=0  # for tiny datasets (<200 images) where few species meet the threshold
+NUM_WORKERS=8
+
+# Training hyperparameters
+MODEL_TYPE="convnext_tiny_in22k"
+TOTAL_EPOCHS=35
+WARMUP_EPOCHS=3
+EARLY_STOPPING=5
+LOSS_FUNCTION="cross_entropy"
+LABEL_SMOOTHING=0.1
+LR_SCHEDULER="cosine"
+LEARNING_RATE=0.001
+BATCH_SIZE=64
+IMAGE_INPUT_SIZE=128
+
+# Webdataset settings
+RESIZE_MIN_SIZE=450
+MAX_SHARD_SIZE=$((100 * 1024 * 1024))  # 100 MB
+
+# Weights & Biases (optional — leave empty to disable)
+WANDB_ENTITY=""
+WANDB_PROJECT=""
+WANDB_RUN_NAME=""
+
+# ============================================================
+# Derived paths (generally don't need to edit)
+# ============================================================
+VERIFIED_CSV="${OUTPUT_DIR}/verified_images.csv"
+CLEAN_CSV="${VERIFIED_CSV%.csv}_clean.csv"
+AUGMENTED_CSV="${OUTPUT_DIR}/annotations_with_species.csv"
+CATEGORY_MAP="${OUTPUT_DIR}/category_map.json"
+SPLIT_PREFIX="${OUTPUT_DIR}/split"
+TRAIN_CSV="${SPLIT_PREFIX}/train.csv"
+VAL_CSV="${SPLIT_PREFIX}/val.csv"
+TEST_CSV="${SPLIT_PREFIX}/test.csv"
+TRAIN_WBDS_DIR="${OUTPUT_DIR}/webdataset_train"
+VAL_WBDS_DIR="${OUTPUT_DIR}/webdataset_val"
+TEST_WBDS_DIR="${OUTPUT_DIR}/webdataset_test"
+MODEL_SAVE_DIR="${OUTPUT_DIR}/model"
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+PROJECT_ROOT="$(dirname "$SCRIPT_DIR")"
+
+# ============================================================
+# Helper functions
+# ============================================================
+step_header() {
+    echo ""
+    echo "============================================================"
+    echo "STEP $1: $2"
+    echo "============================================================"
+    echo ""
+}
+
+count_tar_files() {
+    # Count .tar files in a directory
+    local dir="$1"
+    find "$dir" -maxdepth 1 -name "*.tar" | wc -l
+}
+
+build_shard_pattern() {
+    # Build a brace-expansion webdataset pattern from tar files in a directory
"output/webdataset_train/shard-{000000..000003}.tar" + local dir="$1" + local prefix="$2" + local count + count=$(count_tar_files "$dir") + if [ "$count" -eq 0 ]; then + echo "ERROR: No tar files found in $dir" >&2 + return 1 + fi + local last + last=$(printf "%06d" $((count - 1))) + echo "${dir}/${prefix}-{000000..${last}}.tar" +} + +# ============================================================ +# Ensure output directories exist +# ============================================================ +mkdir -p "$DATASET_PATH" "$OUTPUT_DIR" "$SPLIT_PREFIX" "$TRAIN_WBDS_DIR" "$VAL_WBDS_DIR" "$TEST_WBDS_DIR" "$MODEL_SAVE_DIR" + +# Keep track of total time +TOTAL_SECONDS=$SECONDS + +# ============================================================ +# Step 1: Fetch images from DwC-A +# ============================================================ +if [ -d "$DATASET_PATH" ] && [ "$(find "$DATASET_PATH" -name '*.jpg' 2>/dev/null | head -1)" ]; then + echo "SKIP Step 1: Images already exist in $DATASET_PATH" +else + step_header 1 "Fetch images from DwC-A" + uv run ami-dataset fetch-images \ + --dataset-path "$DATASET_PATH" \ + --dwca-file "$DWCA_FILE" \ + --num-workers "$NUM_WORKERS" + + echo "Step 1 complete. Time elapsed: $((SECONDS - TOTAL_SECONDS))s" +fi + +# ============================================================ +# Step 2: Verify downloaded images +# ============================================================ +if [ -f "$VERIFIED_CSV" ]; then + echo "SKIP Step 2: Verified CSV already exists at $VERIFIED_CSV" +else + step_header 2 "Verify downloaded images" + uv run ami-dataset verify-images \ + --dataset-path "$DATASET_PATH" \ + --dwca-file "$DWCA_FILE" \ + --results-csv "$VERIFIED_CSV" \ + --num-workers "$NUM_WORKERS" + + echo "Step 2 complete. Time elapsed: $((SECONDS - TOTAL_SECONDS))s" +fi + +# ============================================================ +# Step 3: Clean dataset (filter thumbnails, duplicates, etc.) +# ============================================================ +if [ -f "$CLEAN_CSV" ]; then + echo "SKIP Step 3: Clean CSV already exists at $CLEAN_CSV" +else + step_header 3 "Clean dataset" + uv run ami-dataset clean-dataset \ + --dwca-file "$DWCA_FILE" \ + --verified-data-csv "$VERIFIED_CSV" \ + --remove-non-adults false + + echo "Step 3 complete. Time elapsed: $((SECONDS - TOTAL_SECONDS))s" +fi + +# ============================================================ +# Step 4: Augment CSV with species names from DwC-A +# ============================================================ +if [ -f "$AUGMENTED_CSV" ] && [ -f "$CATEGORY_MAP" ]; then + echo "SKIP Step 4: Augmented CSV and category map already exist" +else + step_header 4 "Build species list and augment annotations" + uv run python "$SCRIPT_DIR/build_species_list.py" \ + --dwca-file "$DWCA_FILE" \ + --annotations-csv "$CLEAN_CSV" \ + --output-csv "$AUGMENTED_CSV" \ + --category-map-json "$CATEGORY_MAP" \ + --label-column "$LABEL_COLUMN" + + echo "Step 4 complete. 
Time elapsed: $((SECONDS - TOTAL_SECONDS))s" +fi + +# ============================================================ +# Step 5: Split dataset (stratified train/val/test) +# ============================================================ +if [ -f "$TRAIN_CSV" ] && [ -f "$VAL_CSV" ] && [ -f "$TEST_CSV" ]; then + echo "SKIP Step 5: Split CSVs already exist" +else + step_header 5 "Split dataset into train/val/test" + uv run ami-dataset split-dataset \ + --dataset-csv "$AUGMENTED_CSV" \ + --split-prefix "$SPLIT_PREFIX" \ + --category-key "$LABEL_COLUMN" \ + --max-instances -1 \ + --min-instances "$MIN_INSTANCES" + # For tiny datasets (<200 images): add these flags to lower the per-species threshold + # --val-frac 0.3 --test-frac 0.2 + + echo "Step 5 complete. Time elapsed: $((SECONDS - TOTAL_SECONDS))s" +fi + +# ============================================================ +# Step 6: Create webdatasets (train, val, test) +# ============================================================ +WBDS_ARGS=( + --dataset-path "$DATASET_PATH" + --image-path-column "image_path" + --label-column "$LABEL_COLUMN" + --category-map-json "$CATEGORY_MAP" + --resize-min-size "$RESIZE_MIN_SIZE" + --max-shard-size "$MAX_SHARD_SIZE" +) + +# 6a: Training webdataset +if [ "$(count_tar_files "$TRAIN_WBDS_DIR")" -gt 0 ]; then + echo "SKIP Step 6a: Training webdataset shards already exist" +else + step_header "6a" "Create training webdataset" + uv run ami-dataset create-webdataset \ + --annotations-csv "$TRAIN_CSV" \ + --webdataset-pattern "${TRAIN_WBDS_DIR}/shard-%06d.tar" \ + "${WBDS_ARGS[@]}" + + echo " Created $(count_tar_files "$TRAIN_WBDS_DIR") training shards" +fi + +# 6b: Validation webdataset +if [ "$(count_tar_files "$VAL_WBDS_DIR")" -gt 0 ]; then + echo "SKIP Step 6b: Validation webdataset shards already exist" +else + step_header "6b" "Create validation webdataset" + uv run ami-dataset create-webdataset \ + --annotations-csv "$VAL_CSV" \ + --webdataset-pattern "${VAL_WBDS_DIR}/shard-%06d.tar" \ + "${WBDS_ARGS[@]}" + + echo " Created $(count_tar_files "$VAL_WBDS_DIR") validation shards" +fi + +# 6c: Test webdataset +if [ "$(count_tar_files "$TEST_WBDS_DIR")" -gt 0 ]; then + echo "SKIP Step 6c: Test webdataset shards already exist" +else + step_header "6c" "Create test webdataset" + uv run ami-dataset create-webdataset \ + --annotations-csv "$TEST_CSV" \ + --webdataset-pattern "${TEST_WBDS_DIR}/shard-%06d.tar" \ + "${WBDS_ARGS[@]}" + + echo " Created $(count_tar_files "$TEST_WBDS_DIR") test shards" +fi + +echo "Step 6 complete. 
Time elapsed: $((SECONDS - TOTAL_SECONDS))s" + +# ============================================================ +# Step 7: Train the species classifier +# ============================================================ +step_header 7 "Train species classifier" + +# Compute num_classes from category map +NUM_CLASSES=$(uv run python -c "import json; print(len(json.load(open('${CATEGORY_MAP}'))))") +echo "Number of classes: $NUM_CLASSES" + +# Build webdataset shard patterns +TRAIN_PATTERN=$(build_shard_pattern "$TRAIN_WBDS_DIR" "shard") +VAL_PATTERN=$(build_shard_pattern "$VAL_WBDS_DIR" "shard") +TEST_PATTERN=$(build_shard_pattern "$TEST_WBDS_DIR" "shard") + +echo "Train pattern: $TRAIN_PATTERN" +echo "Val pattern: $VAL_PATTERN" +echo "Test pattern: $TEST_PATTERN" + +# Build optional wandb args +WANDB_ARGS=() +if [ -n "$WANDB_ENTITY" ]; then + WANDB_ARGS+=(--wandb_entity "$WANDB_ENTITY") +fi +if [ -n "$WANDB_PROJECT" ]; then + WANDB_ARGS+=(--wandb_project "$WANDB_PROJECT") +fi +if [ -n "$WANDB_RUN_NAME" ]; then + WANDB_ARGS+=(--wandb_run_name "$WANDB_RUN_NAME") +fi + +uv run ami-classification train-model \ + --model_type "$MODEL_TYPE" \ + --num_classes "$NUM_CLASSES" \ + --total_epochs "$TOTAL_EPOCHS" \ + --warmup_epochs "$WARMUP_EPOCHS" \ + --early_stopping "$EARLY_STOPPING" \ + --train_webdataset "$TRAIN_PATTERN" \ + --val_webdataset "$VAL_PATTERN" \ + --test_webdataset "$TEST_PATTERN" \ + --image_input_size "$IMAGE_INPUT_SIZE" \ + --batch_size "$BATCH_SIZE" \ + --learning_rate "$LEARNING_RATE" \ + --learning_rate_scheduler "$LR_SCHEDULER" \ + --loss_function_type "$LOSS_FUNCTION" \ + --label_smoothing "$LABEL_SMOOTHING" \ + --model_save_directory "$MODEL_SAVE_DIR" \ + "${WANDB_ARGS[@]}" + +# ============================================================ +# Done +# ============================================================ +echo "" +echo "============================================================" +echo "Pipeline complete!" +echo "Total time: $(( (SECONDS - TOTAL_SECONDS) / 60 )) minutes" +echo "" +echo "Outputs:" +echo " Category map: $CATEGORY_MAP" +echo " Train CSV: $TRAIN_CSV" +echo " Val CSV: $VAL_CSV" +echo " Test CSV: $TEST_CSV" +echo " Train shards: $TRAIN_WBDS_DIR/ ($(count_tar_files "$TRAIN_WBDS_DIR") files)" +echo " Val shards: $VAL_WBDS_DIR/ ($(count_tar_files "$VAL_WBDS_DIR") files)" +echo " Test shards: $TEST_WBDS_DIR/ ($(count_tar_files "$TEST_WBDS_DIR") files)" +echo " Model: $MODEL_SAVE_DIR/" +echo "============================================================"