Skip to content

Commit 67c32ea

Browse files
authored
Dev (#13)
* Update README.md * Dev (#20) * fix: fix bugs of torch-version ckp loader * refactor: add sync_timestep for hp loader * fix: minor changes for version compatibility * Dev (#21) * fix: fix bugs of torch-version ckp loader * refactor: add sync_timestep for hp loader * fix: minor changes for version compatibility * fix: a bug of sorting in torch-version checkpoint loading * Dev (#22) * fix: fix bugs of torch-version ckp loader * refactor: add sync_timestep for hp loader * fix: minor changes for version compatibility * fix: a bug of sorting in torch-version checkpoint loading * refactor: robust multi-key plot implementation * feat: supoort pretty plotter * refactor(log plotter): record scores of the log plotter * fix(exp_loader): add parameter ckp_index * update readme * Dev (#23) * fix: fix bugs of torch-version ckp loader * refactor: add sync_timestep for hp loader * fix: minor changes for version compatibility * fix: a bug of sorting in torch-version checkpoint loading * refactor: robust multi-key plot implementation * feat: supoort pretty plotter * refactor(log plotter): record scores of the log plotter * fix(exp_loader): add parameter ckp_index * update readme * rm unsolved merge * Dev (#24) * fix: fix bugs of torch-version ckp loader * refactor: add sync_timestep for hp loader * fix: minor changes for version compatibility * fix: a bug of sorting in torch-version checkpoint loading * refactor: robust multi-key plot implementation * feat: supoort pretty plotter * refactor(log plotter): record scores of the log plotter * fix(exp_loader): add parameter ckp_index * update readme * rm unsolved merge * feat: tf-v2 compatible * refactor: add timestep recorder. refactor on exp_loader * test: add test data * feat(plot): track the hyper-parameter from the exp_manager instead of the experiment name. 
refactor the plot_func for better readability * Dev (#25) * fix: fix bugs of torch-version ckp loader * refactor: add sync_timestep for hp loader * fix: minor changes for version compatibility * fix: a bug of sorting in torch-version checkpoint loading * refactor: robust multi-key plot implementation * feat: supoort pretty plotter * refactor(log plotter): record scores of the log plotter * fix(exp_loader): add parameter ckp_index * update readme * rm unsolved merge * feat: tf-v2 compatible * refactor: add timestep recorder. refactor on exp_loader * test: add test data * feat(plot): track the hyper-parameter from the exp_manager instead of the experiment name. refactor the plot_func for better readability * test(plot): add user cases and documents * test(plot): add user cases
1 parent 3b8bb7c commit 67c32ea

5 files changed

Lines changed: 148 additions & 32 deletions

File tree

RLA/easy_log/exp_loader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,10 @@ def load_from_record_date(self, var_prefix: Optional[str] = None, variable_list:
8888
load_res = {}
8989
if var_prefix is not None:
9090
loaded_tester.new_saver(var_prefix=var_prefix, max_to_keep=1)
91-
_, load_res = loaded_tester.load_checkpoint()
91+
_, load_res = loaded_tester.load_checkpoint(ckp_index)
9292
else:
9393
loaded_tester.new_saver(max_to_keep=1)
94-
_, load_res = loaded_tester.load_checkpoint()
94+
_, load_res = loaded_tester.load_checkpoint(ckp_index)
9595
hist_variables = {}
9696
if variable_list is not None:
9797
for v in variable_list:

RLA/easy_log/tester.py

Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -543,11 +543,31 @@ def update_fph(self, cum_epochs):
543543
# self.last_record_fph_time = cur_time
544544
logger.dump_tabular()
545545

546-
def time_record(self, name):
546+
def time_record(self, name:str):
547+
"""
548+
[deprecated] see RLA.easy_log.time_used_recorder
549+
record the time consumed by a code snippet. call this function to start a recorder.
550+
"name" is an identifier that distinguishes recorders, so that several snippets can be timed at the same time.
551+
call time_record_end to end a recorder.
552+
:param name: identifier of your code snippet.
553+
:type name: str
554+
:return:
555+
:rtype:
556+
"""
547557
assert name not in self._rc_start_time
548558
self._rc_start_time[name] = time.time()
549559

550-
def time_record_end(self, name):
560+
def time_record_end(self, name:str):
561+
"""
562+
[deprecated] see RLA.easy_log.time_used_recorder
563+
record the time consumed by a code snippet. call this function to end a recorder started by time_record; the elapsed time is recorded under "time_used/<name>".
564+
"name" is an identifier that distinguishes recorders, so that several snippets can be timed at the same time.
565+
call time_record with the same name first to start the recorder.
566+
:param name: identifier of your code snippet.
567+
:type name: str
568+
:return:
569+
:rtype:
570+
"""
551571
end_time = time.time()
552572
start_time = self._rc_start_time[name]
553573
logger.record_tabular("time_used/{}".format(name), end_time - start_time)
@@ -566,23 +586,46 @@ def new_saver(self, max_to_keep, var_prefix=None):
566586
import tensorflow as tf
567587
if var_prefix is None:
568588
var_prefix = ''
569-
var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, var_prefix)
570-
logger.info("save variable :")
571-
for v in var_list:
572-
logger.info(v)
573-
self.saver = tf.train.Saver(var_list=var_list, max_to_keep=max_to_keep, filename=self.checkpoint_dir, save_relative_paths=True)
589+
try:
590+
var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, var_prefix)
591+
logger.info("save variable :")
592+
for v in var_list:
593+
logger.info(v)
594+
self.saver = tf.train.Saver(var_list=var_list, max_to_keep=max_to_keep, filename=self.checkpoint_dir,
595+
save_relative_paths=True)
596+
597+
except AttributeError as e:
598+
self.max_to_keep = max_to_keep
599+
# tf.compat.v1.disable_eager_execution()
600+
# tf = tf.compat.v1
601+
# var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, var_prefix)
574602
elif self.dl_framework == FRAMEWORK.torch:
575603
self.max_to_keep = max_to_keep
576604
else:
577605
raise NotImplementedError
578606

579-
def save_checkpoint(self, model_dict: Optional[dict]=None, related_variable: Optional[dict]=None):
607+
def save_checkpoint(self, model_dict: Optional[dict] = None, related_variable: Optional[dict] = None):
580608
if self.dl_framework == FRAMEWORK.tensorflow:
581609
import tensorflow as tf
582610
iter = self.time_step_holder.get_time()
583611
cpt_name = osp.join(self.checkpoint_dir, 'checkpoint')
584612
logger.info("save checkpoint to ", cpt_name, iter)
585-
self.saver.save(tf.get_default_session(), cpt_name, global_step=iter)
613+
try:
614+
self.saver.save(tf.get_default_session(), cpt_name, global_step=iter)
615+
except AttributeError as e:
616+
if model_dict is None:
617+
logger.warn("call save_checkpoints without passing a model_dict")
618+
return
619+
if self.checkpoint_keep_list is None:
620+
self.checkpoint_keep_list = []
621+
iter = self.time_step_holder.get_time()
622+
# tf.compat.v1.disable_eager_execution()
623+
# tf = tf.compat.v1
624+
# self.saver.save(tf.get_default_session(), cpt_name, global_step=iter)
625+
626+
tf.train.Checkpoint(**model_dict).save(tester.checkpoint_dir + "checkpoint-{}".format(iter))
627+
self.checkpoint_keep_list.append(iter)
628+
self.checkpoint_keep_list = self.checkpoint_keep_list[-1 * self.max_to_keep:]
586629
elif self.dl_framework == FRAMEWORK.torch:
587630
import torch
588631
if self.checkpoint_keep_list is None:
@@ -602,6 +645,7 @@ def save_checkpoint(self, model_dict: Optional[dict]=None, related_variable: Opt
602645
for k, v in related_variable.items():
603646
self.add_custom_data(k, v, type(v), mode='replace')
604647
self.add_custom_data(DEFAULT_X_NAME, time_step_holder.get_time(), int, mode='replace')
648+
self.serialize_object_and_save()
605649

606650
def load_checkpoint(self, ckp_index=None):
607651
if self.dl_framework == FRAMEWORK.tensorflow:
@@ -613,6 +657,7 @@ def load_checkpoint(self, ckp_index=None):
613657
ckpt_path = tf.train.latest_checkpoint(cpt_name)
614658
else:
615659
ckpt_path = tf.train.latest_checkpoint(cpt_name, ckp_index)
660+
logger.info("load ckpt_path {}".format(ckpt_path))
616661
self.saver.restore(tf.get_default_session(), ckpt_path)
617662
max_iter = ckpt_path.split('-')[-1]
618663
return int(max_iter), None

RLA/easy_plot/plot_func_v2.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,30 +11,31 @@
1111

1212
from RLA import logger
1313
from RLA.const import DEFAULT_X_NAME
14-
from RLA.query_tool import experiment_data_query
14+
from RLA.query_tool import experiment_data_query, extract_valid_index
15+
1516
from RLA.easy_plot import plot_util
1617
from RLA.easy_log.const import LOG, ARCHIVE_TESTER, OTHER_RESULTS
1718

1819

1920

20-
def default_key_to_legend(parse_list, y_name):
21-
task_split_key = '.'.join(parse_list)
21+
def default_key_to_legend(parse_dict, split_keys, y_name):
22+
task_split_key = '.'.join(f'{k}={parse_dict[k]}' for k in split_keys)
2223
return task_split_key + ' eval:' + y_name
2324

2425

2526
def plot_func(data_root:str, task_table_name:str, regs:list, split_keys:list, metrics:list,
2627
use_buf=False, verbose=True,
2728
xlim: Optional[tuple] = None,
2829
xlabel: Optional[str] = DEFAULT_X_NAME, ylabel: Optional[str] = None,
29-
scale_dict: Optional[dict] = None, replace_legend_keys: Optional[list] = None,
30+
scale_dict: Optional[dict] = None, regs2legends: Optional[list] = None,
3031
key_to_legend_fn: Optional[Callable] = default_key_to_legend,
3132
save_name: Optional[str] = None, *args, **kwargs):
3233
"""
3334
A high-level matplotlib plotter.
3435
The function is to load your experiments and plot curves.
3536
You can group several experiments into a single figure through this function.
3637
It is completed by loading experiments satisfying [data_root, task_table_name, regs] pattern,
37-
grouping by "split_keys" or by the "regs" terms (see replace_legend_keys), and plotting the customized "metrics".
38+
grouping by "split_keys" or by the "regs" terms (see regs2legends), and plotting the customized "metrics".
3839
3940
The function supports several options to customize the figure, including xlim, xlabel, ylabel, key_to_legend_fn, etc.
4041
The function also supports several options to post-process your log data, including resample, smooth_step, scale_dict, key_to_legend_fn, etc.
@@ -61,7 +62,13 @@ def plot_func(data_root:str, task_table_name:str, regs:list, split_keys:list, me
6162
:param scale_dict: a dict of functions, used to map the values of the metrics through customized functions.
6263
e.g., set metrics = ['return'], scale_dict = {'return': lambda x: np.log(x)}, then we will plot a log-scale return.
6364
:type scale_dict: Optional[dict]
64-
:param args: set the label of the y axes.
65+
:param regs2legends: use regex-to-legend mode to plot the figure. Each item in regs will be grouped into a curve.
66+
In this regex-to-legend mode, you should define the legend name for each curve (one legend name per item in regs). See test_reg_map_mode in test/test_plot.py for details.
67+
:type regs2legends: Optional[list] = None
68+
:param key_to_legend_fn: we give a default function to stringify the k-v pairs. You can pass your own function through key_to_legend_fn.
69+
See default_key_to_legend for the default way and test_customize_legend_name_mode in test/test_plot.py for details.
70+
:type key_to_legend_fn: Optional[Callable] = default_key_to_legend
71+
:param args/kwargs: send other parameters to plot_util.plot_results
6572
6673
:return:
6774
:rtype:
@@ -98,17 +105,17 @@ def plot_func(data_root:str, task_table_name:str, regs:list, split_keys:list, me
98105
if ylabel is None:
99106
ylabel = metrics
100107

101-
if replace_legend_keys is not None:
102-
assert len(replace_legend_keys) == len(regs) and len(metrics) == 1, \
108+
if regs2legends is not None:
109+
assert len(regs2legends) == len(regs) and len(metrics) == 1, \
103110
"In manual legend-key mode, the number of keys should be one-to-one matched with regs"
104-
# if len(replace_legend_keys) == len(regs):
111+
# if len(regs2legends) == len(regs):
105112
group_fn = lambda r: split_by_reg(taskpath=r, reg_group=reg_group, y_names=y_names)
106113
else:
107114
group_fn = lambda r: picture_split(taskpath=r, split_keys=split_keys, y_names=y_names,
108115
key_to_legend_fn=key_to_legend_fn)
109116
_, _, lgd, texts, g2lf, score_results = \
110117
plot_util.plot_results(results, xy_fn= lambda r, y_names: csv_to_xy(r, DEFAULT_X_NAME, y_names, final_scale_dict),
111-
group_fn=group_fn, average_group=True, ylabel=ylabel, xlabel=xlabel, replace_legend_keys=replace_legend_keys, *args, **kwargs)
118+
group_fn=group_fn, average_group=True, ylabel=ylabel, xlabel=xlabel, regs2legends=regs2legends, *args, **kwargs)
112119
print("--- complete process ---")
113120
if save_name is not None:
114121
import os
@@ -127,25 +134,28 @@ def plot_func(data_root:str, task_table_name:str, regs:list, split_keys:list, me
127134
def split_by_reg(taskpath, reg_group, y_names):
128135
task_split_key = "None"
129136
for i , reg_k in enumerate(reg_group.keys()):
130-
if taskpath.dirname in reg_group[reg_k]:
131-
assert task_split_key == "None", "one experiment should belong to only one reg_group"
132-
task_split_key = str(i)
137+
for result in reg_group[reg_k]:
138+
if taskpath.dirname == result.dirname:
139+
assert task_split_key == "None", "one experiment should belong to only one reg_group"
140+
task_split_key = str(i)
133141
assert len(y_names) == 1
134142
return task_split_key, y_names
135143

136144

137145
def split_by_task(taskpath, split_keys, y_names, key_to_legend_fn):
138146
pair_delimiter = '&'
139147
kv_delimiter = '='
140-
parse_list = []
148+
parse_dict = {}
141149
for split_key in split_keys:
142150
if split_key in taskpath.hyper_param:
143-
parse_list.append(split_key + '=' + str(taskpath.hyper_param[split_key]))
151+
parse_dict[split_key] = str(taskpath.hyper_param[split_key])
152+
# parse_list.append(split_key + '=' + str(taskpath.hyper_param[split_key]))
144153
else:
145-
parse_list.append(split_key + '=NF')
154+
parse_dict[split_key] = 'NF'
155+
# parse_list.append(split_key + '=NF')
146156
param_keys = []
147157
for y_name in y_names:
148-
param_keys.append(key_to_legend_fn(parse_list, y_name))
158+
param_keys.append(key_to_legend_fn(parse_dict, split_keys, y_name))
149159
return param_keys, y_names
150160

151161

RLA/easy_plot/plot_util.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ def plot_results(
305305
ylabel=None,
306306
title=None,
307307
replace_legend_keys=None,
308-
replace_legend_sort=None,
308+
regs2legends=None,
309309
pretty=False,
310310
bound_line=None,
311311
colors=None,
@@ -505,6 +505,8 @@ def allequal(qs):
505505
legend_lines = legend_lines[sorted_index]
506506
if replace_legend_keys is not None:
507507
legend_keys = np.array(replace_legend_keys)
508+
if regs2legends is not None:
509+
legend_keys = np.array(regs2legends)
508510
# if replace_legend_sort is not None:
509511
# sorted_index = replace_legend_sort
510512
# else:

test/test_plot.py

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,85 @@
11
# Created by xionghuichen at 2022/8/10
22
# Email: chenxh@lamda.nju.edu.cn
33
from test._base import BaseTest
4+
import numpy as np
45
from RLA.easy_log.log_tools import DeleteLogTool, Filter
56
from RLA.easy_log.log_tools import ArchiveLogTool, ViewLogTool
67
from RLA.easy_log.tester import exp_manager
7-
8+
from RLA import plot_func
89
import os
910

1011
class ScriptTest(BaseTest):
11-
def test_plot(self):
12-
from RLA import plot_func
12+
13+
def get_basic_info(self):
1314
data_root = 'test_data_root'
1415
task = 'demo_task'
16+
return data_root, task
17+
18+
def test_plot_basic(self):
19+
data_root, task = self.get_basic_info()
20+
1521
regs = [
1622
'2022/03/01/21-[12]*'
1723
]
1824
_ = plot_func(data_root=data_root, task_table_name=task, regs=regs, split_keys=['learning_rate'],
1925
metrics=['perf/mse'])
26+
# customize the figure
2027
_ = plot_func(data_root=data_root, task_table_name=task, regs=regs, split_keys=['learning_rate'],
2128
metrics=['perf/mse'], ylim=(0, 0.1))
2229
_ = plot_func(data_root=data_root, task_table_name=task, regs=regs, split_keys=['learning_rate'],
2330
metrics=['perf/mse'], ylim=(0, 0.1), xlabel='epochs', ylabel='reward ratio', )
31+
32+
33+
def test_pretty_plot(self):
34+
data_root, task = self.get_basic_info()
35+
36+
regs = [
37+
'2022/03/01/21-[12]*'
38+
]
39+
# save image
40+
_ = plot_func(data_root=data_root, task_table_name=task, regs=regs, split_keys=['learning_rate'],
41+
metrics=['perf/mse'], ylim=(0, 0.1), xlabel='epochs', ylabel='reward ratio',
42+
shaded_range=False, show_number=False, pretty=True)
2443
_ = plot_func(data_root=data_root, task_table_name=task, regs=regs, split_keys=['learning_rate'],
2544
metrics=['perf/mse'], ylim=(0, 0.1), xlabel='epochs', ylabel='reward ratio',
45+
shaded_range=False, pretty=True, save_name='saved_image.png')
46+
47+
def test_reg_map_mode(self):
48+
# reg-map mode.
49+
data_root, task = self.get_basic_info()
50+
regs = [
51+
'2022/03/01/21-[12]*learning_rate=0.01*',
52+
'2022/03/01/21-[12]*learning_rate=0.00*',
53+
]
54+
_ = plot_func(data_root=data_root, task_table_name=task, regs=regs, split_keys=['learning_rate'],
55+
metrics=['perf/mse'], regs2legends=['lr=0.01', 'lr<=0.001'],
2656
shaded_range=False, pretty=True)
57+
58+
def test_customize_legend_name_mode(self):
59+
data_root, task = self.get_basic_info()
60+
regs = [
61+
'2022/03/01/21-[12]*'
62+
]
63+
64+
def my_key_to_legend(parse_dict, split_keys, y_name):
65+
66+
task_split_key = '.'.join(f'{k}={parse_dict[k]}' for k in split_keys)
67+
task_split_key = task_split_key.replace('learning_rate', 'α')
68+
return task_split_key
69+
70+
_ = plot_func(data_root=data_root, task_table_name=task, regs=regs, split_keys=['learning_rate'],
71+
metrics=['perf/mse'],
72+
key_to_legend_fn=my_key_to_legend,
73+
shaded_range=False, pretty=True, show_number=False)
74+
75+
def test_post_process(self):
76+
data_root, task = self.get_basic_info()
77+
regs = [
78+
'2022/03/01/21-[12]*'
79+
]
80+
81+
_ = plot_func(data_root=data_root, task_table_name=task, regs=regs, split_keys=['learning_rate'],
82+
metrics=['perf/mse'],
83+
scale_dict={'perf/mse': lambda x: np.log(x)},
84+
ylabel='RMSE',
85+
shaded_range=False, pretty=True, show_number=False)

0 commit comments

Comments
 (0)