import os
import sys
import argparse
import numpy as np
from pickle import dump
from keras import backend as K
from keras.utils import plot_model
from tools.data_tools import GetData
from keras.optimizers import SGD
from sklearn.model_selection import train_test_split
from tools.language_model_tools import LanguageModel
from keras.callbacks import ReduceLROnPlateau, ModelCheckpoint, EarlyStopping
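
# Example invocation (assumes the GloVe vectors are unzipped under glove.6B/,
# matching glove_dir in get_embedding_matrix below; the epoch count is illustrative):
#
#   python train_model.py --epoch 50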


def argument_parser():
    ap = argparse.ArgumentParser()
    ap.add_argument("-e", "--epoch", required=True,
                    help="Choose the number of epochs.")
    return vars(ap.parse_args())


def get_embedding_matrix(embedding_dim, word_index, vocabulary_size):
    glove_dir = 'glove.6B'
    embeddings_index = {}
    # Parse the pre-trained GloVe file: each line is a word followed by its vector.
    with open(os.path.join(glove_dir, 'glove.6B.{}d.txt'.format(embedding_dim)),
              encoding='utf-8') as f:
        for line in f:
            values = line.split()
            word = values[0]
            coefs = np.asarray(values[1:], dtype='float32')
            embeddings_index[word] = coefs
    # Words not found in the embedding index keep all-zero rows.
    embedding_matrix = np.zeros((vocabulary_size, embedding_dim))
    for word, i in word_index.items():
        if i < vocabulary_size:
            embedding_vector = embeddings_index.get(word)
            if embedding_vector is not None:
                embedding_matrix[i] = embedding_vector
    return embedding_matrix
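
# For context, an embedding matrix built as above is typically handed to a Keras
# Embedding layer roughly like this (a sketch; the actual layer construction lives
# in tools.language_model_tools.LanguageModel, whose internals are not shown here):
#
#   Embedding(vocabulary_size, embedding_dim,
#             weights=[embedding_matrix],
#             input_length=SEQUENCE_MAX_LEN,
#             trainable=False)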


def main():
    SEQUENCE_MAX_LEN = 4
    args = argument_parser()
    try:
        NUM_EPOCHS = int(args["epoch"])
    except ValueError:
        print("\nError: Epoch should be an integer.")
        print("Exiting!\n")
        sys.exit(1)
    input_file_path = os.path.join("input_files", "complete_seinfeld_scripts.csv")
    # Get the training sequences and vocabulary
    data = GetData(SEQUENCE_MAX_LEN)
    X, y, tokenizer, word_index, vocabulary_size = data.get_data(input_file_path)
    # Save the tokenizer, word index and sequence length for use at serving time
    for_serve = [tokenizer, word_index, SEQUENCE_MAX_LEN]
    dump(for_serve, open('saved_models/for_server.pkl', 'wb'))
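    # At serving time the pickle above can be read back along these lines
    # (a sketch; the actual serving code is not part of this file):
    #
    #   from pickle import load
    #   tokenizer, word_index, seq_max_len = load(open('saved_models/for_server.pkl', 'rb'))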
    # GloVe vector dimensionality; options are 50, 100, 200 and 300
    embedding_dim = 100
    embedding_matrix = get_embedding_matrix(embedding_dim, word_index, vocabulary_size)
    # Build the model
    language_model = LanguageModel(vocabulary_size, SEQUENCE_MAX_LEN, embedding_dim, embedding_matrix)
    model = language_model.build_model()
    model_and_weights = os.path.join("saved_models", "model_and_weights.hdf5")
    # If weights exist, load them before training
    if os.path.isfile(model_and_weights):
        print("Old weights found!")
        try:
            model.load_weights(model_and_weights)
            print("Old weights loaded successfully!")
        except Exception:
            print("Old weights couldn't be loaded, continuing with fresh weights!")
    learning_rate = 1.0e-4
    decay_rate = learning_rate / NUM_EPOCHS
    model.compile(optimizer=SGD(lr=learning_rate, decay=decay_rate, momentum=0.9),
                  loss='categorical_crossentropy')
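    # With the legacy Keras SGD, `decay` shrinks the learning rate per update:
    #   lr_t = learning_rate / (1 + decay_rate * iterations)
    # Setting decay_rate = learning_rate / NUM_EPOCHS is a common heuristic for a
    # gentle schedule, not something prescribed by Keras itself.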
    # Print the model summary
    model.summary()
    # Plot the model architecture
    model_path = os.path.join("plots", "model.pdf")
    plot_model(model, to_file=model_path, show_shapes=True)
    # Stop training when the monitored quantity (val_loss by default) has stopped improving
    early_stop = EarlyStopping(patience=10, verbose=1)
    # Reduce the learning rate when the monitored quantity has stopped improving
    reduce_lr = ReduceLROnPlateau(factor=0.2, patience=3, cooldown=3, verbose=1)
    # Save the model after every epoch in which the monitored quantity improves
    check_point = ModelCheckpoint(filepath=model_and_weights, verbose=1, save_best_only=True)
    # Split the data into train and validation sets (85/15)
    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.15)
    history = model.fit(X_train, y_train, batch_size=500, epochs=NUM_EPOCHS, verbose=1,
                        validation_data=(X_val, y_val),
                        callbacks=[check_point, early_stop, reduce_lr])
    # Plot the loss history
    loss_path = os.path.join("plots", "loss_vs_epoch.pdf")
    language_model.plot_loss_history(history, loss_path)
    # Needed to avoid the "'TF_DeleteStatus'" attribute error at interpreter exit
    K.clear_session()


if __name__ == "__main__":
    main()