-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_improved.py
More file actions
36 lines (27 loc) · 886 Bytes
/
train_improved.py
File metadata and controls
36 lines (27 loc) · 886 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from __future__ import annotations

import sys
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
from types import ModuleType
def _load_train_module() -> ModuleType:
train_path = Path(__file__).with_name("train.py")
spec = spec_from_file_location("train_notebook_compat", train_path)
if spec is None or spec.loader is None:
raise ImportError(f"Unable to load training module from {train_path}.")
module = module_from_spec(spec)
spec.loader.exec_module(module)
return module
# Execute train.py once, at import time of this module; every public name
# below is taken directly from that loaded module.
_TRAIN = _load_train_module()

# Re-export train.py's entry points under their original names so callers can
# import them from this module instead.
(
    CONFIG,
    apply_cli_overrides,
    main,
    parse_args,
    validate_config,
) = (
    _TRAIN.CONFIG,
    _TRAIN.apply_cli_overrides,
    _TRAIN.main,
    _TRAIN.parse_args,
    _TRAIN.validate_config,
)

__all__ = [
    "CONFIG",
    "apply_cli_overrides",
    "main",
    "parse_args",
    "validate_config",
]

if __name__ == "__main__":
    # Running this file as a script simply invokes train.py's main().
    main()