#!/usr/bin/env python
# -*- coding: utf8 -*-
"""Sequence to Sequence Learning for Twitter/Cornell Chatbot.
References
----------
http://suriyadeepan.github.io/2016-12-31-practical-seq2seq/
"""
import time
import click
import numpy as np
import tensorflow as tf
import tensorlayer as tl
from tqdm import tqdm
from sklearn.utils import shuffle
from tensorlayer.layers import DenseLayer, EmbeddingInputlayer, Seq2Seq, retrieve_seq_length_op2
from data.twitter import data
sess_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
"""
Training model [optional args]
"""
@click.command()
@click.option('-dc', '--data-corpus', default='twitter', help='Data corpus to use for training and inference',)
@click.option('-bs', '--batch-size', default=32, help='Batch size for training on minibatches',)
@click.option('-n', '--num-epochs', default=50, help='Number of epochs for training',)
@click.option('-lr', '--learning-rate', default=0.001, help='Learning rate to use when training model',)
@click.option('-inf', '--inference-mode', is_flag=True, help='Flag for INFERENCE mode',)
def train(data_corpus, batch_size, num_epochs, learning_rate, inference_mode):
metadata, trainX, trainY, testX, testY, validX, validY = initial_setup(data_corpus)
# Parameters
src_len = len(trainX)
tgt_len = len(trainY)
assert src_len == tgt_len
n_step = src_len // batch_size
src_vocab_size = len(metadata['idx2w']) # 8002 (0~8001)
emb_dim = 1024
word2idx = metadata['w2idx'] # dict word 2 index
idx2word = metadata['idx2w'] # list index 2 word
unk_id = word2idx['unk'] # 1
pad_id = word2idx['_'] # 0
start_id = src_vocab_size # 8002
end_id = src_vocab_size + 1 # 8003
word2idx.update({'start_id': start_id})
word2idx.update({'end_id': end_id})
idx2word = idx2word + ['start_id', 'end_id']
src_vocab_size = tgt_vocab_size = src_vocab_size + 2
""" A data for Seq2Seq should look like this:
input_seqs : ['how', 'are', 'you', '<PAD_ID'>]
decode_seqs : ['<START_ID>', 'I', 'am', 'fine', '<PAD_ID'>]
target_seqs : ['I', 'am', 'fine', '<END_ID>', '<PAD_ID'>]
target_mask : [1, 1, 1, 1, 0]
"""
    # Preprocessing preview: transform one training pair for inspection
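    # For illustration, assuming a reply encoded as [12, 45, 7] with
    # start_id = 8002 and end_id = 8003, the helpers below produce:
    #   target_seqs: [12, 45, 7, 8003]   (end id appended)
    #   decode_seqs: [8002, 12, 45, 7]   (start id prepended, last token kept)
    #   target_mask: [1, 1, 1, 1]        (1 for every non-pad position)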
target_seqs = tl.prepro.sequences_add_end_id([trainY[10]], end_id=end_id)[0]
decode_seqs = tl.prepro.sequences_add_start_id([trainY[10]], start_id=start_id, remove_last=False)[0]
target_mask = tl.prepro.sequences_get_mask([target_seqs])[0]
if not inference_mode:
print("encode_seqs", [idx2word[id] for id in trainX[10]])
print("target_seqs", [idx2word[id] for id in target_seqs])
print("decode_seqs", [idx2word[id] for id in decode_seqs])
print("target_mask", target_mask)
print(len(target_seqs), len(decode_seqs), len(target_mask))
# Init Session
tf.reset_default_graph()
sess = tf.Session(config=sess_config)
# Training Data Placeholders
encode_seqs = tf.placeholder(dtype=tf.int64, shape=[batch_size, None], name="encode_seqs")
decode_seqs = tf.placeholder(dtype=tf.int64, shape=[batch_size, None], name="decode_seqs")
target_seqs = tf.placeholder(dtype=tf.int64, shape=[batch_size, None], name="target_seqs")
target_mask = tf.placeholder(dtype=tf.int64, shape=[batch_size, None], name="target_mask")
net_out, _ = create_model(encode_seqs, decode_seqs, src_vocab_size, emb_dim, is_train=True, reuse=False)
net_out.print_params(False)
# Inference Data Placeholders
encode_seqs2 = tf.placeholder(dtype=tf.int64, shape=[1, None], name="encode_seqs")
decode_seqs2 = tf.placeholder(dtype=tf.int64, shape=[1, None], name="decode_seqs")
net, net_rnn = create_model(encode_seqs2, decode_seqs2, src_vocab_size, emb_dim, is_train=False, reuse=True)
y = tf.nn.softmax(net.outputs)
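    # y is the per-step softmax distribution over the extended vocabulary
    # (src_vocab_size including the start/end ids); it is sampled from during inference.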
# Loss Function
loss = tl.cost.cross_entropy_seq_with_mask(logits=net_out.outputs, target_seqs=target_seqs,
input_mask=target_mask, return_details=False, name='cost')
# Optimizer
train_op = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)
# Init Vars
sess.run(tf.global_variables_initializer())
# Load Model
tl.files.load_and_assign_npz(sess=sess, name='model.npz', network=net)
"""
Inference using pre-trained model
"""
def inference(seed):
seed_id = [word2idx.get(w, unk_id) for w in seed.split(" ")]
# Encode and get state
state = sess.run(net_rnn.final_state_encode,
{encode_seqs2: [seed_id]})
# Decode, feed start_id and get first word [https://github.com/zsdonghao/tensorlayer/blob/master/example/tutorial_ptb_lstm_state_is_tuple.py]
o, state = sess.run([y, net_rnn.final_state_decode],
{net_rnn.initial_state_decode: state,
decode_seqs2: [[start_id]]})
w_id = tl.nlp.sample_top(o[0], top_k=3)
w = idx2word[w_id]
# Decode and feed state iteratively
sentence = [w]
for _ in range(30): # max sentence length
o, state = sess.run([y, net_rnn.final_state_decode],
{net_rnn.initial_state_decode: state,
decode_seqs2: [[w_id]]})
w_id = tl.nlp.sample_top(o[0], top_k=2)
w = idx2word[w_id]
if w_id == end_id:
break
sentence = sentence + [w]
return sentence
if inference_mode:
print('Inference Mode')
print('--------------')
while True:
input_seq = input('Enter Query: ')
sentence = inference(input_seq)
print(" >", ' '.join(sentence))
else:
seeds = ["happy birthday have a nice day",
"donald trump won last nights presidential debate according to snap online polls"]
for epoch in range(num_epochs):
trainX, trainY = shuffle(trainX, trainY, random_state=0)
total_loss, n_iter = 0, 0
for X, Y in tqdm(tl.iterate.minibatches(inputs=trainX, targets=trainY, batch_size=batch_size, shuffle=False),
total=n_step, desc='Epoch[{}/{}]'.format(epoch + 1, num_epochs), leave=False):
X = tl.prepro.pad_sequences(X)
_target_seqs = tl.prepro.sequences_add_end_id(Y, end_id=end_id)
_target_seqs = tl.prepro.pad_sequences(_target_seqs)
_decode_seqs = tl.prepro.sequences_add_start_id(Y, start_id=start_id, remove_last=False)
_decode_seqs = tl.prepro.pad_sequences(_decode_seqs)
_target_mask = tl.prepro.sequences_get_mask(_target_seqs)
## Uncomment to view the data here
# for i in range(len(X)):
# print(i, [idx2word[id] for id in X[i]])
# print(i, [idx2word[id] for id in Y[i]])
# print(i, [idx2word[id] for id in _target_seqs[i]])
# print(i, [idx2word[id] for id in _decode_seqs[i]])
# print(i, _target_mask[i])
# print(len(_target_seqs[i]), len(_decode_seqs[i]), len(_target_mask[i]))
_, loss_iter = sess.run([train_op, loss], {encode_seqs: X, decode_seqs: _decode_seqs,
target_seqs: _target_seqs, target_mask: _target_mask})
total_loss += loss_iter
n_iter += 1
# printing average loss after every epoch
print('Epoch [{}/{}]: loss {:.4f}'.format(epoch + 1, num_epochs, total_loss / n_iter))
# inference after every epoch
for seed in seeds:
print("Query >", seed)
for _ in range(5):
sentence = inference(seed)
print(" >", ' '.join(sentence))
# saving the model
tl.files.save_npz(net.all_params, name='model.npz', sess=sess)
# session cleanup
sess.close()
"""
Creates the LSTM Model
"""
def create_model(encode_seqs, decode_seqs, src_vocab_size, emb_dim, is_train=True, reuse=False):
with tf.variable_scope("model", reuse=reuse):
        # For a chatbot you can share one embedding layer between encoder and decoder;
        # for translation you may want two separate embedding layers.
with tf.variable_scope("embedding") as vs:
net_encode = EmbeddingInputlayer(
inputs = encode_seqs,
vocabulary_size = src_vocab_size,
embedding_size = emb_dim,
name = 'seq_embedding')
vs.reuse_variables()
net_decode = EmbeddingInputlayer(
inputs = decode_seqs,
vocabulary_size = src_vocab_size,
embedding_size = emb_dim,
name = 'seq_embedding')
net_rnn = Seq2Seq(net_encode, net_decode,
cell_fn = tf.nn.rnn_cell.LSTMCell,
n_hidden = emb_dim,
initializer = tf.random_uniform_initializer(-0.1, 0.1),
encode_sequence_length = retrieve_seq_length_op2(encode_seqs),
decode_sequence_length = retrieve_seq_length_op2(decode_seqs),
initial_state_encode = None,
dropout = (0.5 if is_train else None),
n_layer = 3,
return_seq_2d = True,
name = 'seq2seq')
net_out = DenseLayer(net_rnn, n_units=src_vocab_size, act=tf.identity, name='output')
return net_out, net_rnn
"""
Initial Setup
"""
def initial_setup(data_corpus):
metadata, idx_q, idx_a = data.load_data(PATH='data/{}/'.format(data_corpus))
(trainX, trainY), (testX, testY), (validX, validY) = data.split_dataset(idx_q, idx_a)
trainX = tl.prepro.remove_pad_sequences(trainX.tolist())
trainY = tl.prepro.remove_pad_sequences(trainY.tolist())
testX = tl.prepro.remove_pad_sequences(testX.tolist())
testY = tl.prepro.remove_pad_sequences(testY.tolist())
validX = tl.prepro.remove_pad_sequences(validX.tolist())
validY = tl.prepro.remove_pad_sequences(validY.tolist())
return metadata, trainX, trainY, testX, testY, validX, validY
def main():
try:
train()
except KeyboardInterrupt:
print('Aborted!')
if __name__ == '__main__':
main()