Skip to content

Commit 7204ee5

Browse files
committed
fix(exp_manager): add log_name_formatter
1 parent c04cb15 commit 7204ee5

3 files changed

Lines changed: 29 additions & 22 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,6 @@ RLA.egg-info**
1111
**/arc/**
1212
**/.ipynb_checkpoints/*
1313
**/.DS_Store
14+
test/test_data_root/*
1415
test/target_data_root/*
1516
**/private_config.py

RLA/easy_log/tester.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,24 @@ def _copy_run_file(run_file, code_dir):
524524
else:
525525
raise NotImplementedError
526526

527+
def log_name_formatter(self, prefix, record_date):
528+
"""
529+
return a unified and unique name for the experiment log.
530+
:param prefix: prefix location to store the log data.
531+
:param record_date: the timestamp of the experiment log.
532+
:return: a unify and unique name
533+
"""
534+
version_num = self.get_version_num()
535+
if version_num is None:
536+
name_format = '{prefix}/{date}/{timestep} {ip} {info}'
537+
elif version_num == LOG_NAME_FORMAT_VERSION.V1:
538+
name_format = '{prefix}/{date}/{timestep}_{ip}_{info}'
539+
else:
540+
raise RuntimeError("unknown version name", version_num)
541+
date = record_date.strftime("%Y/%m/%d")
542+
return name_format.format(prefix=prefix, date=date, timestep=self.record_date_to_str(record_date),
543+
ip=str(self.ipaddr), info=self.info)
544+
527545
def record_date_to_str(self, record_date):
528546
return str(record_date.strftime("%H-%M-%S-%f"))
529547

@@ -534,29 +552,14 @@ def get_version_num(self):
534552
def __create_file_directory(self, prefix, ext='', is_file=True, record_date=None):
535553
if record_date is None:
536554
record_date = self.record_date
555+
name = self.log_name_formatter(prefix, record_date)
537556
directory = str(record_date.strftime("%Y/%m/%d"))
538557
directory = osp.join(prefix, directory)
539-
version_num = self.get_version_num()
540-
541-
if version_num is None:
542-
name_format = '{dir}/{timestep} {ip} {info}{ext}'
543-
elif version_num == LOG_NAME_FORMAT_VERSION.V1:
544-
name_format = '{dir}/{timestep}_{ip}_{info}{ext}'
545-
else:
546-
raise RuntimeError("unknown version name", version_num)
547-
548558
if is_file:
549559
os.makedirs(directory, exist_ok=True)
550-
file_name = name_format.format(dir=directory, timestep=self.record_date_to_str(record_date),
551-
ip=str(self.ipaddr),
552-
info=self.info,
553-
ext=ext)
560+
file_name = name + ext
554561
else:
555-
directory = (name_format + '/').format(dir=directory,
556-
timestep=self.record_date_to_str(record_date),
557-
ip=str(self.ipaddr),
558-
info=self.info,
559-
ext=ext)
562+
directory = name + '/'
560563
os.makedirs(directory, exist_ok=True)
561564
file_name = ''
562565
return directory, file_name

test/test_proj/proj/test_manager.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ def target_func(x):
1515
DATABASE_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
1616
CODE_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
1717

18+
1819
class ManagerTest(BaseTest):
1920

2021
def _load_rla_config(self):
@@ -54,7 +55,6 @@ def test_log_tf(self):
5455

5556
exp_manager.new_saver(var_prefix='', max_to_keep=1)
5657
# synthetic target function.
57-
5858
for i in range(0, 100):
5959
exp_manager.time_step_holder.set_time(i)
6060
x_input = np.random.normal(0, 3, [64, kwargs["input_size"]])
@@ -66,8 +66,6 @@ def test_log_tf(self):
6666
if i % 20 == 0:
6767
exp_manager.save_checkpoint()
6868
if i % 10 == 0:
69-
logger.ma_record_tabular("perf/mse-long", np.mean(mse_loss.detach().cpu().numpy()), 10, freq=25)
70-
logger.record_tabular("y_out-long", np.mean(y), freq=25)
7169
def plot_func():
7270
import matplotlib.pyplot as plt
7371
testX = np.repeat(np.expand_dims(np.arange(-10, 10, 0.1), axis=-1), repeats=kwargs["input_size"], axis=-1)
@@ -139,7 +137,12 @@ def test_sent_to_master(self):
139137
exp_manager.set_hyper_param(**kwargs)
140138
exp_manager.add_record_param(['input_size'])
141139
yaml = self._load_rla_config()
142-
from test.test_proj.proj import private_config
140+
try:
141+
from test.test_proj.proj import private_config
142+
except ImportError as e:
143+
print("[WARN] for this test, you should config your username, password, and the remote root firstly.")
144+
return
145+
# raise RuntimeError
143146
yaml['DL_FRAMEWORK'] = 'torch'
144147
yaml['SEND_LOG_FILE'] = True
145148
yaml['REMOTE_SETTING']['ftp_server'] = '127.0.0.1'

0 commit comments

Comments
 (0)