[BUG] Enormously Long Training Time #41

@rwang5203

Hi,
I am trying to train MolScribe from scratch on 8 NVIDIA GeForce RTX 3090 GPUs, and every epoch takes much longer than expected (Epoch 1 - Time: 55482s, roughly 15.4 hours). What could be causing this?
The only change I made was adding some test CSV files to the training bash script train_uspto_joint_chartok_1m680k.sh:

#!/bin/bash

NUM_NODES=1
NUM_GPUS_PER_NODE=8
NODE_RANK=0

BATCH_SIZE=320
ACCUM_STEP=1

MASTER_PORT=$(shuf -n 1 -i 10000-65535)

DATESTR=$(date -u -d '+8 hours' +"%Y-%m-%d_%H-%M-%S")
SAVE_PATH=output/${DATESTR}/swin_base_char_aux_1m680k
mkdir -p ${SAVE_PATH}

set -x

torchrun \
    --nproc_per_node=$NUM_GPUS_PER_NODE --nnodes=$NUM_NODES --node_rank $NODE_RANK --master_addr localhost --master_port $MASTER_PORT \
    train.py \
    --data_path data \
    --train_file pubchem/train_1m.csv \
    --aux_file uspto_mol/train_680k.csv --coords_file aux_file \
    --valid_file real/acs.csv \
    --test_file synthetic/chemdraw.csv,synthetic/indigo.csv,real/acs.csv,real/CLEF.csv,real/JPO.csv,real/staker.csv,real/UOB.csv,real/USPTO.csv,perturb_by_arrows/data_mod/acs_a.csv,perturb_by_arrows/data_mod/CLEF_a.csv,perturb_by_arrows/data_mod/JPO_a.csv,perturb_by_arrows/data_mod/staker_a.csv,perturb_by_arrows/data_mod/UOB_a.csv,perturb_by_arrows/data_mod/USPTO_a.csv,perturb_by_imgtransform/perturb/acs_p.csv,perturb_by_imgtransform/perturb/CLEF_p.csv,perturb_by_imgtransform/perturb/JPO_p.csv,perturb_by_imgtransform/perturb/staker_p.csv,perturb_by_imgtransform/perturb/UOB_p.csv,perturb_by_imgtransform/perturb/USPTO_p.csv \
    --vocab_file molscribe/vocab/vocab_chars.json \
    --formats chartok_coords,edges \
    --dynamic_indigo --augment --mol_augment \
    --include_condensed \
    --coord_bins 64 --sep_xy \
    --input_size 384 \
    --encoder swin_base \
    --decoder transformer \
    --encoder_lr 4e-4 \
    --decoder_lr 4e-4 \
    --save_path $SAVE_PATH --save_mode all \
    --label_smoothing 0.1 \
    --epochs 30 \
    --batch_size $((BATCH_SIZE / NUM_GPUS_PER_NODE / ACCUM_STEP)) \
    --gradient_accumulation_steps $ACCUM_STEP \
    --use_checkpoint \
    --warmup 0.02 \
    --print_freq 200 \
    --do_train --do_valid --do_test \
    --fp16 --backend gloo 2>&1
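
For scale, the reported epoch time can be turned into a rough per-step cost. This is only a back-of-the-envelope sketch: the sample counts are assumptions inferred from the file names (train_1m.csv as about 1,000,000 rows, train_680k.csv as about 680,000 rows), and it assumes both files are consumed once per epoch. Neither assumption is confirmed by the logs.

```shell
#!/bin/bash
# Values taken from the script and the log line above.
EPOCH_SECONDS=55482        # "Epoch 1 - Time: 55482s"
BATCH_SIZE=320             # global batch size
NUM_GPUS_PER_NODE=8
ACCUM_STEP=1

# ASSUMPTION: row counts guessed from the file names, not measured.
NUM_SAMPLES=$((1000000 + 680000))

# Same expression the script passes to --batch_size.
PER_GPU_BATCH=$((BATCH_SIZE / NUM_GPUS_PER_NODE / ACCUM_STEP))
STEPS_PER_EPOCH=$((NUM_SAMPLES / BATCH_SIZE))

# Bash has no floating point, so awk does the divisions.
EPOCH_HOURS=$(awk -v s="$EPOCH_SECONDS" 'BEGIN { printf "%.1f", s / 3600 }')
SEC_PER_STEP=$(awk -v s="$EPOCH_SECONDS" -v n="$STEPS_PER_EPOCH" 'BEGIN { printf "%.1f", s / n }')

echo "per_gpu_batch=$PER_GPU_BATCH steps_per_epoch=$STEPS_PER_EPOCH"
echo "epoch_hours=$EPOCH_HOURS sec_per_step=$SEC_PER_STEP"
```

Under those assumptions this works out to 40 samples per GPU, 5250 steps per epoch, and about 10.6 seconds per step.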
