-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplot.py
More file actions
27 lines (22 loc) · 850 Bytes
/
plot.py
File metadata and controls
27 lines (22 loc) · 850 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
import matplotlib.pyplot as plt
import numpy as np
# Plot the accuracy (top) and loss (bottom) histories stored in the
# pretrained ResNet checkpoints and save the figure to plot.png.
fig, ax = plt.subplots(2, 1, figsize=(15, 15))

# Each network gets one colour on both axes. Its training series is drawn
# solid and its validation series dashed in the same colour, so the two can be
# told apart. Only the training line carries a legend label, so the legend
# lists each network once.
for net in ('ResNet20', 'ResNet32', 'ResNet44', 'ResNet56'):
    # map_location='cpu' lets a checkpoint saved from a GPU run be opened on a
    # machine without CUDA; only the stored history is used here.
    # NOTE(review): assumes checkpoint['history'] holds per-epoch sequences
    # under 'acc', 'val_acc', 'loss', 'val_loss' - confirm against the
    # training script that writes these files.
    checkpoint = torch.load(f'pretrained/{net}.pth', map_location='cpu')
    history = checkpoint['history']
    for axis, metric in ((ax[0], 'acc'), (ax[1], 'loss')):
        (train_line,) = axis.plot(history[metric], linewidth=0.5, label=net)
        axis.plot(history['val_' + metric], linewidth=0.5, linestyle='--',
                  color=train_line.get_color())

for axis, name in ((ax[0], 'accuracy'), (ax[1], 'loss')):
    axis.set_title(f'model {name} (solid: train, dashed: validation)')
    axis.set_ylabel(name)
    axis.set_xlabel('epoch')
    axis.legend()
    # Pass the visibility flag positionally. The keyword was `b` before
    # Matplotlib 3.5 and `visible` afterwards (`b` was deprecated, then
    # removed), so the positional form works on every version.
    axis.grid(True, linestyle='--')

plt.savefig('plot.png', dpi=100)