-
Notifications
You must be signed in to change notification settings - Fork 21
Expand file tree
/
Copy pathtrain.py
More file actions
26 lines (21 loc) · 691 Bytes
/
train.py
File metadata and controls
26 lines (21 loc) · 691 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import tensorflow as tf
from models import DIRNet
from config import get_config
from data import MNISTDataHandler
from ops import mkdir
def main():
    """Train a DIRNet model on batch pairs drawn from ``MNISTDataHandler``.

    Loads the training configuration via ``get_config(is_train=True)``,
    calls ``mkdir`` on ``config.tmp_dir`` and ``config.ckpt_dir``, and then
    runs ``config.iteration`` training steps.  Each step samples a pair of
    batches of size ``config.batch_size``, passes them to ``reg.fit`` and
    prints the value it returns.  Every 1000 steps the current batch pair is
    handed to ``reg.deploy`` (with ``config.tmp_dir``) and ``reg.save`` is
    called with ``config.ckpt_dir``.

    The TensorFlow session is used as a context manager so it is closed
    when training finishes or when any step raises, instead of being left
    open for the lifetime of the process.
    """
    # How often (in training steps) to run deploy + save.
    checkpoint_interval = 1000

    with tf.Session() as sess:
        config = get_config(is_train=True)
        mkdir(config.tmp_dir)
        mkdir(config.ckpt_dir)

        reg = DIRNet(sess, config, "DIRNet", is_train=True)
        dh = MNISTDataHandler("MNIST_data", is_train=True)

        # Steps are counted from 1 so the printed index and the periodic
        # check use the same number.
        for step in range(1, config.iteration + 1):
            batch_x, batch_y = dh.sample_pair(config.batch_size)
            loss = reg.fit(batch_x, batch_y)
            print("iter {:>6d} : {}".format(step, loss))

            if step % checkpoint_interval == 0:
                # NOTE(review): the source's indentation was lost; save is
                # assumed to belong to this periodic branch alongside
                # deploy -- confirm against the original file.
                reg.deploy(config.tmp_dir, batch_x, batch_y)
                reg.save(config.ckpt_dir)
# Start training only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()