-
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathWarmRestart_example.py
More file actions
44 lines (33 loc) · 1.14 KB
/
WarmRestart_example.py
File metadata and controls
44 lines (33 loc) · 1.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
-------------------------------------------------
Description : utils.lrs_scheduler WarmRestart and warm_restart example
Email : autuanliu@163.com
Date:2018/04/01
"""
from models.utils.utils_imports import *
from models.utils.lrs_scheduler import WarmRestart, warm_restart
from models.vislib.line_plot import line
class Net(nn.Module):
    """Minimal model that gives the optimizer in this example something to schedule.

    It holds two 1x1 convolutions (1 input channel -> 1 output channel each),
    so the optimizer below can put each layer in its own parameter group.
    """

    def __init__(self):
        super().__init__()
        # Creation order is kept (conv1 then conv2) so parameter
        # initialisation consumes the RNG in the same order.
        self.conv1 = nn.Conv2d(1, 1, 1)
        self.conv2 = nn.Conv2d(1, 1, 1)

    def forward(self, x):
        """Return ``conv2(relu(conv1(x)))``."""
        hidden = F.relu(self.conv1(x))
        out = self.conv2(hidden)
        return out
net = Net()
# Two parameter groups: conv1 uses the optimizer-wide default lr (0.05),
# conv2 overrides it with lr=0.5, so the scheduler drives two curves.
param_groups = [
    {'params': net.conv1.parameters()},
    {'params': net.conv2.parameters(), 'lr': 0.5},
]
opt = optim.SGD(param_groups, lr=0.05)
# Cosine annealing with warm restarts: the first cycle lasts T_max=20 epochs
# and each following cycle is T_mult=2 times longer.
# Plain (non-restarting) alternative for comparison:
# scheduler = lr_scheduler.CosineAnnealingLR(opt, T_max=20, eta_min=0)
scheduler = WarmRestart(opt, T_max=20, T_mult=2, eta_min=1e-10)
# Run the schedule for 200 epochs, recording the first parameter group's lr
# each epoch, then plot it.
vis_data = []
for epoch in range(200):
    # Read the learning rates the optimizer will actually use this epoch,
    # computed once. Calling scheduler.get_lr() outside of step() (as the
    # original did, twice) is deprecated in newer PyTorch and may not match
    # the applied value.
    lrs = [group['lr'] for group in opt.param_groups]
    print(lrs)
    vis_data.append(lrs[0])
    # PyTorch >= 1.1 expects optimizer.step() before scheduler.step().
    # Stepping the scheduler first skips the schedule's initial value
    # (the constructor has already set it) and emits a warning.
    # No backward pass runs here, so this step leaves the weights unchanged.
    # It is kept so the call order matches a real training loop.
    opt.step()
    scheduler.step()
    # for warm_restart
    # scheduler = warm_restart(scheduler, T_mult=2)
line(vis_data)
plt.show()