-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathscript.py
More file actions
120 lines (108 loc) · 4.89 KB
/
script.py
File metadata and controls
120 lines (108 loc) · 4.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import click
import logging
import argparse
import torch
import toml
from rgnn.train.trainer import Trainer, test_model
from torch_geometric.loader import DataLoader
from rgnn.models.reaction_models.utils import get_scaler
from rgnn.common.registry import registry
from sklearn.model_selection import train_test_split
def train_reaction_model(settings):
with open(settings, "r") as f:
config = toml.load(f)
task = config["task"]
logger_config = config["logger"]
train_config = config["train"]
model_config = config["model"]
if task not in os.listdir():
os.makedirs(task, exist_ok=True)
if "model" not in os.listdir(task):
os.mkdir(task + "/model")
batch_size = train_config.get("batch_size", 8)
device = train_config.get("device", "cuda")
num_epoch = train_config.get("num_epoch", 20)
start_epoch = train_config.get("start_epoch", 0)
save_model_name = train_config.get("save_model_name", "reaction_model.pth.tar")
save_model_path = f"{task}/model"
# Setup logger
log_filename = f"{task}/{logger_config['filename']}.log"
logger = setup_logger(logger_config["name"], log_filename)
logger.info(f"Training reaction model in: {os.path.realpath(task)}")
# Load dataset
total_dataset = torch.load(train_config["trainset_path"])
val_size = train_config.get("val_size", 0.2)
trainset, valset = train_test_split(total_dataset, test_size=val_size, random_state=42)
train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(valset, batch_size=batch_size)
if train_config.get("testset_path", None) is not None:
testset = torch.load(train_config["testset_path"])
else:
testset = None
# Setup model
train_labels = train_config["train_labels"]
means, stddevs = get_scaler(train_labels, trainset)
means_str = "\t".join([f"{key}: {value.item():.4f}" for key, value in means.items()])
stddevs_str = "\t".join([f"{key}: {value.item():.4f}" for key, value in stddevs.items()])
logger.info(f"Means: {means_str}")
logger.info(f"Stddevs: {stddevs_str}")
model_name = model_config.pop("@name")
model_config.update({"means": means, "stddevs": stddevs})
reaction_model = registry.get_reaction_model_class(model_name)(**model_config)
trainable_params = filter(lambda p: p.requires_grad, reaction_model.parameters())
# Setup loss function, optimizer and scheduler
loss_params = train_config["loss_params"]
loss_fn = registry.get_loss_class(train_config["loss_fn"])(**loss_params)
optimizer_params = train_config.get("optimizer_params", {"lr": 3e-4, "weight_decay": 0.0})
optimizer = registry.get_optimizer_class(train_config.get("optimizer", "adam"))(trainable_params, **optimizer_params)
scheduler_params = train_config.get("scheduler_params", {"patience": 10, "factor": 0.1})
scheduler = registry.get_scheduler_class(train_config.get("scheduler", "reduce_lr_on_plateau"))(optimizer=optimizer, **scheduler_params)
trainer = Trainer(
save_dir=save_model_path,
model=reaction_model,
loss_fn=loss_fn,
optimizer=optimizer,
scheduler=scheduler,
train_loader=train_loader,
validation_loader=val_loader,
)
trainer.train(
logger=logger,
device=device,
n_epochs=num_epoch,
start_epoch=start_epoch,
best_model_name=save_model_name,
)
if testset is not None:
test_results = test_model(
dataset=testset,
model=reaction_model,
device=device,
batch_size=batch_size,
)
for key, value in test_results.items():
pred = value["pred"]
label = value["label"]
mae = torch.abs(pred - label).mean()
mse = torch.square(pred - label).mean()
rmse = torch.sqrt(mse)
r2 = 1 - torch.square(pred - label).mean() / torch.square(label - label.mean()).mean()
logger.info(f"{key}: MAE: {mae:.4f}, MSE: {mse:.4f}, RMSE: {rmse:.4f}, R2: {r2:.4f}")
with open(f"{task}/test_results.txt", "a") as f:
f.write(f"{key}: MAE: {mae:.4f}, MSE: {mse:.4f}, RMSE: {rmse:.4f}, R2: {r2:.4f}\n")
with open(f"{task}/test_results.txt", "a") as f:
f.write("--------------------------------\n")
def setup_logger(name, log_file, level=logging.INFO):
formatter = logging.Formatter("%(asctime)s - %(name)s | %(message)s", "%Y-%m-%d %H:%M:%S")
file_handler = logging.FileHandler(log_file, mode="w")
file_handler.setFormatter(formatter)
logger = logging.getLogger(name)
logger.setLevel(level)
logger.addHandler(file_handler)
return logger
if __name__ == "__main__":
args = argparse.ArgumentParser()
args.add_argument("-c", "--config", required=True, help="config file path")
args = args.parse_args()
train_reaction_model(args.config)