forked from Advanced-Optimization/Practical1
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_model.py
More file actions
69 lines (54 loc) · 1.84 KB
/
train_model.py
File metadata and controls
69 lines (54 loc) · 1.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
import sys
from pathlib import Path
from modules.calibration import calibrate_young
from modules.lab_utils import load_dataset, LAB_PATH, fix_path
from modules.pytorch_mlp import PytorchMLPReg
# Model type used when --model-type is not given on the command line
# (must be one of the argparse choices: "pytorch" or "calibrated").
DEFAULT = "pytorch"
def train_pytorch_model(dataset_path, from_real=False, *, n_epochs=2_000):
    """Train a PyTorch MLP regressor on a dataset and save it to disk.

    The trained model is written to
    ``{LAB_PATH}/data/results/<name>.pth``, where ``<name>`` is the dataset
    file name with a trailing ``.csv`` extension removed.

    Args:
        dataset_path: Path (``pathlib.Path`` or string) to the dataset CSV.
            It is passed unchanged to ``load_dataset``.
        from_real: Forwarded to ``load_dataset``; per the CLI help, selects the
            real-world dataset instead of the synthetic one.
        n_epochs: Number of training epochs passed to ``PytorchMLPReg.train``
            (keyword-only; defaults to the previously hard-coded 2000).
    """
    x_train, y_train, x_test, y_test = load_dataset(dataset_path, from_real)

    # Remove only the literal ".csv" suffix. The previous
    # `.strip(".csv")` treated its argument as a *character set* and
    # chewed off any leading/trailing '.', 'c', 's', 'v' characters
    # (e.g. "results.csv" -> "result").
    dataset_fname = Path(dataset_path).name
    if dataset_fname.endswith(".csv"):
        dataset_fname = dataset_fname[: -len(".csv")]

    fname = f"{LAB_PATH}/data/results/{dataset_fname}.pth"
    mlp = PytorchMLPReg()
    mlp.train(x_train, y_train, x_test, y_test, n_epochs=n_epochs)
    mlp.save(fname)
    print(f"Trained model saved at {os.path.abspath(fname)}")
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Train model using dataset")
    parser.add_argument(
        "--model-type",
        type=str,
        choices=["pytorch", "calibrated"],
        default=DEFAULT,
        help="Model type: pytorch or calibrated",
    )
    parser.add_argument(
        "--dataset-path",
        type=Path,
        default=Path("data/results/blueleg_beam_sphere515.csv"),
        help="Path to dataset CSV",
    )
    parser.add_argument(
        "-r",
        "--from-real",
        action="store_true",
        help="Use real-world dataset instead of synthetic",
    )
    args = parser.parse_args()

    # fix_path signals a missing file by returning None. Resolve into a
    # separate name so the error message can still report the path the
    # user actually asked for (previously it always printed "None").
    dataset_path = fix_path(args.dataset_path)
    if dataset_path is None:
        print(f"Dataset file not found: {args.dataset_path}", file=sys.stderr)
        sys.exit(1)

    if args.model_type == "calibrated":
        calibrate_young(dataset_path, args.from_real)
    elif args.model_type == "pytorch":
        train_pytorch_model(dataset_path, args.from_real)
    else:
        # Defensive guard: argparse `choices` already rejects other values.
        print(f"Unknown model type: {args.model_type}", file=sys.stderr)
        sys.exit(1)