-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtrain.py
More file actions
39 lines (31 loc) · 1.1 KB
/
train.py
File metadata and controls
39 lines (31 loc) · 1.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import argparse
import os
import logging
import torch
from workflow import json
from vae import train
# Drop ignite log records below WARNING (e.g. its per-epoch INFO messages).
logging.getLogger('ignite').setLevel(logging.WARNING)
# Ask TensorFlow's C++ backend to suppress all but fatal log messages.
# NOTE(review): TensorFlow reads this variable when it initializes, so it only
# takes effect if TF is first imported after this line. If `vae` (or another
# import above) pulls TF in transitively, consider moving this above the imports.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'


def build_parser():
    """Return the argument parser for the training hyper-parameters.

    Returns:
        An ``argparse.ArgumentParser`` whose options become the keys of the
        run configuration (in declaration order).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', default=20, type=int)
    parser.add_argument('--eval_batch_size', default=20, type=int)
    parser.add_argument('--learning_rate', default=1e-3, type=float)
    parser.add_argument('--max_epochs', default=200, type=int)
    parser.add_argument('--n_batches_per_epoch', default=200, type=int)
    parser.add_argument('--n_batches_per_step', default=1, type=int)
    # NOTE(review): argparse applies `type` only to string defaults, so the
    # default stays the int 40 while a value given on the CLI parses to float.
    # Confirm which type `train` expects before tightening this to int.
    parser.add_argument('--patience', default=40, type=float)
    parser.add_argument('--n_workers', default=0, type=int)
    # Seed value handed to `train` via the config; defaults to 1 as before.
    parser.add_argument('--seed', default=1, type=int)
    return parser


def parse_cli_args(parser):
    """Parse command-line options with *parser*.

    Under IPython/Jupyter the kernel injects its own argv entries, so unknown
    arguments are ignored there; in a plain interpreter they are an error.

    Args:
        parser: the ``argparse.ArgumentParser`` to parse with.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    try:
        __IPYTHON__  # noqa: F821 -- name exists only when running under IPython
    except NameError:
        return parser.parse_args()
    return parser.parse_known_args()[0]


def build_config(args):
    """Return the run configuration as a new dict.

    Contains every parsed option plus two runtime settings:
    ``use_cuda`` (whether torch reports a usable CUDA device) and ``run_id``
    (the ``RUN_ID`` environment variable, or ``None`` if unset).

    Args:
        args: parsed ``argparse.Namespace``.
    """
    # Copy so that mutating the config does not also mutate `args`.
    config = dict(vars(args))
    config.update(
        use_cuda=torch.cuda.is_available(),
        run_id=os.getenv('RUN_ID'),
    )
    return config


def main():
    """Parse options, save them with ``workflow.json.write`` to 'config.json', then train."""
    config = build_config(parse_cli_args(build_parser()))
    json.write(config, 'config.json')
    train(config)


if __name__ == '__main__':
    main()