-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathBoxEModel.py
More file actions
1174 lines (1025 loc) · 69.4 KB
/
BoxEModel.py
File metadata and controls
1174 lines (1025 loc) · 69.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import tensorflow as tf
import KBUtils
import Cnst
import time
import msgpack
import msgpack_numpy as m
import numpy as np
import os
import copy
from TestFunctions import run_categorical_tests
from ModelOptions import ModelOptions
from math import ceil
from RuleParser import RuleParser, enforce_rule, recursive_tensor
import json
m.patch()
tf.logging.set_verbosity(tf.logging.ERROR)
zero = tf.constant(0.0, name="Zero")
half = tf.constant(0.5, name="Half")
one = tf.constant(1.0, name="One")
neg_zero = tf.constant(0.0, shape=[1, 1], name="Neg_Zero")
SANITY_EPS = 10 ** -8
NORM_LOG_BOUND = 1
FACTORIAL_TRICK = {1: 1, 2: 2, 3: 6, 4: 12, 5: 60, 6: 60, 7: 420, 8: 840, 9: 2520, 10: 2520}
BOTTOM_VALUE = np.nan
def sanitize_scatter(input_tensor):
    """Shift exact zeros by SANITY_EPS so downstream scatter/where masks can
    treat a value of 0 as 'untouched' without losing real zero entries."""
    is_exact_zero = tf.equal(input_tensor, 0)
    return tf.where(is_exact_zero, input_tensor + SANITY_EPS, input_tensor)
def delta_time_string(delta):
    """Render an elapsed duration in seconds as an ``H:MM:SS`` string."""
    secs = int(delta) % 60
    mins = int(delta / 60) % 60
    hrs = int(delta / 3600)
    return "%d:%02d:%02d" % (hrs, mins, secs)
def sg(x):
    """Block gradient flow through ``x`` (shorthand for ``tf.stop_gradient``)."""
    detached = tf.stop_gradient(x)
    return detached
def print_or_log(input_string, log, log_file_path="log.txt"):
    """Emit a message either to stdout or by appending it to a log file.

    Args:
        input_string: The message to emit.
        log: When truthy, append to ``log_file_path``; otherwise print to stdout.
        log_file_path: Path of the log file (opened in append mode, CRLF-terminated lines).
    """
    if not log:
        print(input_string)
    else:
        # Context manager guarantees the file handle is closed even if the
        # write raises (the original open/close pair could leak the handle).
        with open(log_file_path, "a+") as log_file:
            log_file.write(input_string + "\r\n")
def transform(input_list, transformation_function):
    """Apply ``transformation_function`` to every element and return the results as a tuple."""
    return tuple(map(transformation_function, input_list))
def q2b_loss(name: str, points, lower_corner, upper_corner, scale_mults):  # Query2Box Loss Function
    """Query2Box-style box distances.

    Returns a pair of per-dimension tensors, IN THIS ORDER:
      1. distance from the point to the box (zero when the point is inside);
      2. distance from the box centre to the point clamped into the box.
    ``scale_mults`` is accepted for signature compatibility but unused here.
    """
    with tf.name_scope(name):
        with tf.name_scope("Center"):
            centres = 0.5 * (lower_corner + upper_corner)
        # How far the point sticks out above / below the box, per dimension.
        overshoot = tf.maximum(points - upper_corner, 0.0)
        undershoot = tf.maximum(lower_corner - points, 0.0)
        dist_outside = overshoot + undershoot
        # Clamp the point into the box, then measure offset from the centre.
        clamped = tf.minimum(upper_corner, tf.maximum(lower_corner, points))
        dist_inside = centres - clamped
        return dist_outside, dist_inside
def polynomial_loss(name: str, points, lower_corner, upper_corner, scale_mults):  # Standard Loss Function
    """BoxE piecewise distance between points and boxes, per dimension.

    Inside the box the distance is |point - centre| / (width + 1), so larger
    boxes attract points more gently; outside the box it grows with the width
    factor so escaping a large box is penalised more. ``scale_mults`` is
    accepted for signature compatibility with q2b_loss but unused here.
    """
    with tf.name_scope(name):
        with tf.name_scope("Width"):
            widths = upper_corner - lower_corner
            # width + 1 keeps the inside denominator >= 1 (no divide-by-zero for degenerate boxes).
            widths_p1 = widths + one
        with tf.name_scope("Center"):
            centres = 1 / 2 * (lower_corner + upper_corner)
        with tf.name_scope("Width_Cond"):
            # Branch per dimension on whether the point lies within [low, high].
            width_cond = tf.where(tf.logical_and(lower_corner <= points, points <= upper_corner),
                                  tf.abs(points - centres) / widths_p1,
                                  widths_p1 * tf.abs(points - centres) - (widths / 2) *
                                  (widths_p1 - 1 / widths_p1))
            return width_cond
def total_box_size_reg(rel_deltas, reg_lambda, log_box_size):  # Regularization based on total box size
    """Penalise squared deviation of the total log box size from a target.

    Uses a log-sum-exp formulation: each box contributes its mean log side
    length, and the sum of exponentials is computed relative to a
    stop-gradient minimum baseline for numerical stability.
    """
    # Mean log side length per box (over the embedding dimension axis).
    rel_mean_width = tf.reduce_mean(tf.log(tf.abs(rel_deltas) + SANITY_EPS), axis=2)
    # Stop-gradient baseline keeps exp() well-conditioned (standard log-sum-exp trick).
    min_width = sg(tf.reduce_min(rel_mean_width))
    rel_width_ratios = tf.exp(rel_mean_width - min_width)
    total_multiplier = tf.log(tf.reduce_sum(rel_width_ratios) + SANITY_EPS)
    total_width = total_multiplier + min_width
    # Quadratic penalty around the target total log size, scaled by the reg weight.
    size_constraint_loss = reg_lambda * (total_width - log_box_size) ** 2
    return size_constraint_loss
def drpt(tensor, rate):  # Dropout
    """Apply dropout to ``tensor`` only when ``rate`` is strictly positive.

    The choice is made with a graph-level ``tf.cond`` because ``rate`` is a
    tensor (a variable) rather than a Python float.
    """
    should_drop = tf.greater(rate, 0)
    return tf.cond(should_drop,
                   lambda: tf.nn.dropout(tensor, rate=rate),
                   lambda: tensor)
def loss_function_q2b(batch_points, batch_mask, rel_bx_low, rel_bx_high, batch_rel_mults,
                      dim_dropout_prob=zero, order=2, alpha=0.2):
    """Query2Box loss per fact: ||dist_outside|| + alpha * ||dist_inside||.

    Bug fix: ``q2b_loss`` returns ``(dist_outside, dist_inside)`` in that
    order, but the previous unpacking swapped the two names, so ``alpha``
    down-weighted the OUTSIDE distance. Query2Box (Ren et al., 2020) applies
    ``alpha`` (< 1) to the inside distance so points already in the box are
    pulled to the centre only weakly.

    Args:
        batch_points: Point embeddings, shape (batch, arity, dim).
        batch_mask: 1/0 mask over arity positions (0 = padded position).
        rel_bx_low, rel_bx_high: Box corner tensors broadcastable to the points.
        batch_rel_mults: Unused by q2b_loss; kept for interface parity.
        dim_dropout_prob: Dropout rate tensor applied per dimension.
        order: Norm order over the embedding dimension.
        alpha: Weight on the inside-box distance term.
    Returns:
        Per-fact loss tensor of shape (batch,).
    """
    dist_outside, dist_inside = q2b_loss("Q2B_Box_Loss", batch_points, rel_bx_low, rel_bx_high,
                                         batch_rel_mults)
    out_norm = tf.norm(drpt(dist_outside, rate=dim_dropout_prob), axis=2, ord=order)
    out_masked = tf.reduce_sum(tf.multiply(out_norm, batch_mask), axis=1)
    in_norm = tf.norm(drpt(dist_inside, rate=dim_dropout_prob), axis=2, ord=order)
    in_masked = tf.reduce_sum(tf.multiply(in_norm, batch_mask), axis=1)
    total_loss = out_masked + alpha * in_masked
    return total_loss
def loss_function_poly(batch_points, batch_mask, rel_bx_low, rel_bx_high, batch_rel_mults, dim_dropout_prob=zero,
                       order=1):
    """BoxE polynomial loss per fact: per-dimension distances are normed over
    the embedding axis, then masked and summed over arity positions."""
    per_dim_dist = polynomial_loss("Poly_Box_Loss", batch_points, rel_bx_low, rel_bx_high, batch_rel_mults)
    per_position = tf.norm(drpt(per_dim_dist, rate=dim_dropout_prob), axis=2, ord=order)
    masked = tf.multiply(per_position, batch_mask)
    return tf.reduce_sum(masked, axis=1)
def compute_box(box_base, box_delta, name: str):
    """Derive (low, high) box corners from a base point and a signed delta.

    min/max ordering makes the result correct even when ``box_delta`` is
    negative (the two candidate corners simply swap roles).
    """
    corner_a = box_base - half * box_delta
    corner_b = box_base + half * box_delta
    box_low = tf.minimum(corner_a, corner_b, name=name + "_low")
    box_high = tf.maximum(corner_a, corner_b, name=name + "_high")
    return box_low, box_high
def compute_box_np(box_base, box_delta):
    """NumPy twin of ``compute_box``: (low, high) corners from base and delta,
    robust to negative deltas via elementwise min/max."""
    half_delta = 0.5 * box_delta
    corner_a, corner_b = box_base - half_delta, box_base + half_delta
    return np.minimum(corner_a, corner_b), np.maximum(corner_a, corner_b)
# Loss computation with optional stop-gradients on the negative stream.
def loss_computation_with_sg(batch_points, batch_mask, batch_rel_bases, batch_rel_deltas, batch_rel_mults,
                             bounded_pt_space: bool, bounded_box_space: bool, bound_scale: float, sgrad=Cnst.NO_STOPS,
                             original_batch_size=None, dim_drpt_prob=zero, loss_order=2, loss_fct=Cnst.POLY_LOSS):
    """Compute (positive_loss, negative_loss) for a stacked batch.

    The batch layout places the ``original_batch_size`` positive facts first,
    followed by the negative samples (this matches the concat order produced
    by uniform_neg_sampling). When ``sgrad`` requests stop-gradients, the
    computation is split into separate positive/negative streams so gradients
    can be blocked on the negative relation parameters only.

    Args:
        batch_points: Point representations for all (pos + neg) facts.
        batch_mask: Arity mask (0 for padded positions).
        batch_rel_bases, batch_rel_deltas, batch_rel_mults: Relation box parameters per fact.
        bounded_pt_space / bounded_box_space: If True, squash points / box corners with bound_scale * tanh.
        bound_scale: Multiplier applied to the tanh-bounded values.
        sgrad: Pair of flags; Cnst.STOP at index 0/1 stops gradients on negative bases/deltas.
        original_batch_size: Number of positive facts at the front of the batch (None => no split).
        dim_drpt_prob: Per-dimension dropout rate tensor.
        loss_order: Norm order passed to the loss function.
        loss_fct: Selects polynomial (Cnst.POLY_LOSS) vs Query2Box loss.
    Returns:
        Tuple (pos_loss, neg_loss) of per-fact loss tensors.
    """
    obs = original_batch_size
    loss_function = loss_function_poly if loss_fct == Cnst.POLY_LOSS else loss_function_q2b
    if sgrad == Cnst.NO_STOPS or obs is None:  # No stop-gradients: one shared computation, then split.
        with tf.name_scope("Standard_Loss_Computation"):
            rel_bx_low, rel_bx_high = compute_box(batch_rel_bases, batch_rel_deltas, name="batch_rel_box")
            batch_points = bound_scale * tf.tanh(batch_points) if bounded_pt_space else batch_points
            if bounded_box_space:
                rel_bx_low, rel_bx_high = transform([rel_bx_low, rel_bx_high], lambda x: bound_scale * tf.tanh(x))
            loss_pos_neg = loss_function(batch_points=batch_points, batch_mask=batch_mask, rel_bx_low=rel_bx_low,
                                         rel_bx_high=rel_bx_high, batch_rel_mults=batch_rel_mults, order=loss_order,
                                         dim_dropout_prob=dim_drpt_prob)
            with tf.name_scope("Values_Split"):
                # If the batch is larger than obs, rows [obs:] are negatives; otherwise there are none.
                pos_loss = tf.cond(obs < tf.shape(loss_pos_neg)[0], lambda: loss_pos_neg[:obs], lambda: loss_pos_neg)
                neg_loss = tf.cond(obs < tf.shape(loss_pos_neg)[0], lambda: loss_pos_neg[obs:], lambda: neg_zero)
    else:  # Stop-gradients requested: split into two streams so sg() can target negatives only.
        relation_all = [batch_rel_bases, batch_rel_deltas, batch_rel_mults]
        with tf.name_scope("Loss_With_Stop_Grads"):
            rel_bases_pos, rel_deltas_pos, rel_mults_pos = transform(relation_all, lambda x: x[:obs])
            # Per-component stop-gradient flags for the negative stream.
            rel_bases_neg = sg(batch_rel_bases[obs:]) if sgrad[0] == Cnst.STOP else batch_rel_bases[obs:]
            rel_deltas_neg = sg(batch_rel_deltas[obs:]) if sgrad[1] == Cnst.STOP else batch_rel_deltas[obs:]
            batch_points_pos = batch_points[:obs]
            batch_mask_pos = batch_mask[:obs]
            batch_points_neg = batch_points[obs:]
            batch_mask_neg = batch_mask[obs:]
            rel_mults_neg = batch_rel_mults[obs:]
        with tf.name_scope("Positive_Loss"):
            rel_bx_low_pos, rel_bx_high_pos = compute_box(rel_bases_pos, rel_deltas_pos, name="batch_rel_box_pos")
            batch_points_pos = bound_scale * tf.tanh(batch_points_pos) if bounded_pt_space else batch_points_pos
            if bounded_box_space:
                rel_bx_low_pos, rel_bx_high_pos = transform([rel_bx_low_pos, rel_bx_high_pos],
                                                            lambda x: bound_scale * tf.tanh(x))
            pos_loss = loss_function(batch_points=batch_points_pos, dim_dropout_prob=dim_drpt_prob,
                                     rel_bx_low=rel_bx_low_pos, rel_bx_high=rel_bx_high_pos,
                                     batch_rel_mults=rel_mults_pos, order=loss_order, batch_mask=batch_mask_pos)
        with tf.name_scope("Negative_Loss"):
            rel_bx_low_neg, rel_bx_high_neg = compute_box(rel_bases_neg, rel_deltas_neg, name="batch_rel_box_neg")
            batch_points_neg = bound_scale * tf.tanh(batch_points_neg) if bounded_pt_space else batch_points_neg
            if bounded_box_space:
                rel_bx_low_neg, rel_bx_high_neg = transform([rel_bx_low_neg, rel_bx_high_neg],
                                                            lambda x: bound_scale * tf.tanh(x))
            neg_loss = loss_function(batch_points=batch_points_neg, dim_dropout_prob=dim_drpt_prob,
                                     rel_bx_low=rel_bx_low_neg, rel_bx_high=rel_bx_high_neg,
                                     batch_rel_mults=rel_mults_neg, order=loss_order, batch_mask=batch_mask_neg)
    return pos_loss, neg_loss
# Negative Sampling function
def uniform_neg_sampling(nb_neg_examples_per_pos, batch_components, nb_entities, max_arity, return_replacements=False):
    """Append uniform negative samples to a batch of facts.

    Each fact row is laid out as [relation, entity_1 .. entity_max_arity, extra],
    with ``nb_entities`` used as the padding entity id. For every positive fact,
    ``nb_neg_examples_per_pos`` corrupted copies are produced by replacing one
    uniformly chosen (non-padded) entity position with a uniformly chosen
    entity different from the original one.

    Returns the batch with negatives concatenated after the positives, plus
    (optionally) the scatter indices of the replaced positions, offset into
    the concatenated batch.
    """
    if nb_neg_examples_per_pos == 0:
        return batch_components
    batch_size = tf.shape(batch_components)[0]
    batch_components_ent = batch_components[:, 1:-1]
    # Arity per fact = number of entity slots not equal to the padding id.
    pre_arities = tf.where(tf.equal(batch_components_ent, nb_entities), tf.zeros_like(batch_components_ent),
                           tf.ones_like(batch_components_ent))
    arities = tf.reduce_sum(pre_arities, axis=1, keepdims=True)
    random_count = batch_size * nb_neg_examples_per_pos
    arities_tiled = tf.tile(arities, [nb_neg_examples_per_pos, 1])
    with tf.name_scope("Position_Replacement_Choices"):
        # Draw from FACTORIAL_TRICK[max_arity] categories (divisible by every arity <= max_arity)
        # so that floormod by the actual arity stays uniform over positions.
        replacement_choice = tf.transpose(tf.random.categorical(tf.zeros([1, FACTORIAL_TRICK[max_arity]]), random_count,
                                                                dtype=tf.int32))
        replacement_choice = tf.floormod(replacement_choice, arities_tiled)
    with tf.name_scope("Replacement_Entity_Choice"):
        # Draw from nb_entities - 1 candidates; the skipped slot is re-inserted below
        # by incrementing draws >= the original entity, guaranteeing a *different* entity.
        new_entities_pre = tf.transpose(tf.random.categorical(tf.zeros([1, nb_entities - 1]),
                                                              random_count, dtype=tf.int32))[:, 0]
    with tf.name_scope("Negative_Samples"):
        # (row, column) indices of the entity slots to overwrite; +1 skips the relation column.
        nd_indices = tf.concat([tf.expand_dims(tf.range(random_count), axis=-1), replacement_choice + 1], axis=1)
        negative_samples_pre = tf.tile(batch_components, [nb_neg_examples_per_pos, 1])[:, :-1]
        negative_samples_with_val = tf.concat([negative_samples_pre, tf.fill([random_count, 1], -1)], axis=1)
        replaced_ents = tf.gather_nd(negative_samples_with_val, nd_indices)
        increment_entities = tf.cast(tf.greater_equal(new_entities_pre, replaced_ents), tf.int32)
        new_entities = new_entities_pre + increment_entities
        # scatter_nd leaves untouched cells at 0, so values are stored as id+1 and
        # decoded with -1 after the tf.where merge.
        negative_sample_update = tf.scatter_nd(nd_indices, new_entities + 1, tf.shape(negative_samples_with_val))
        negative_samples = tf.where(tf.greater(negative_sample_update, 0), negative_sample_update - 1,
                                    negative_samples_with_val)
    batch_components_neg = tf.concat([batch_components, negative_samples], axis=0)
    # Same indices shifted by batch_size: positions of the replacements within the concatenated batch.
    nd_indices_post = tf.concat([batch_size + tf.expand_dims(tf.range(random_count), axis=-1),
                                 replacement_choice + 1], axis=1)
    if return_replacements:
        return batch_components_neg, nd_indices_post
    else:
        return batch_components_neg
def create_uniform_var(name, shape, min_val, max_val):
    """Create a trainable variable initialised uniformly in [min_val, max_val),
    nested under both a name scope and a variable scope called ``name``."""
    with tf.name_scope(name):
        with tf.variable_scope(name):
            initial = tf.random_uniform(shape, min_val, max_val)
            return tf.get_variable(name="init_unif", initializer=initial)
def instantiate_box_embeddings(name: str, scale_mult_shape, rel_tbl_shape, base_norm_shapes, sqrt_dim,
                               hard_size: bool, total_size: float, relation_stats, fixed_width: bool):
    """Create the relation box parameters: base points, deltas, and scale multipliers.

    Deltas are formed as scale_multiples * base_norm_shapes, so the (normalised)
    shape tensor fixes the box proportions and the multipliers fix the size.

    Args:
        name: Scope/name prefix for the created tensors and variables.
        scale_mult_shape: Shape of the per-relation scale multiplier tensor.
        rel_tbl_shape: Shape of the relation base-point table.
        base_norm_shapes: Normalised box shape tensor multiplied into the deltas.
        sqrt_dim: sqrt(embedding_dim); bounds the uniform init range of base points.
        hard_size: If True, softmax-normalise multipliers to sum to total_size.
        total_size: Target total size used when hard_size is set.
        relation_stats: Precomputed multipliers (from fact statistics) or None to learn/fix them.
        fixed_width: If True (and no stats), start multipliers at zeros instead of a learnable variable.
    Returns:
        (embedding_base_points, embedding_deltas, scale_multiples).
    """
    with tf.name_scope(name):
        if relation_stats is not None:
            # Hard-coded sizes derived from dataset statistics.
            scale_multiples = relation_stats
        else:
            if fixed_width:
                scale_multiples = tf.zeros(scale_mult_shape)
            else:
                scale_multiples = create_uniform_var("scale_multiples_" + name, scale_mult_shape, -1.0, 1.0)
        if hard_size:
            # Softmax over relations makes the multipliers a budget summing to total_size.
            scale_multiples = total_size * tf.nn.softmax(scale_multiples, axis=0)
        else:
            # elu + 1 keeps multipliers strictly positive while staying smooth.
            scale_multiples = tf.nn.elu(scale_multiples) + one
        embedding_base_points = create_uniform_var(name + "_base_point", rel_tbl_shape, -0.5 / sqrt_dim,
                                                   0.5 / sqrt_dim)
        embedding_deltas = tf.multiply(scale_multiples, base_norm_shapes, name=name + "delta")
        return embedding_base_points, embedding_deltas, scale_multiples
def apply_flag_sg(input_tensor, flag):
    """Blend a gradient-carrying copy and a gradient-stopped copy of the tensor.

    With flag == 1 gradients flow normally; with flag == 0 they are fully
    blocked; intermediate values interpolate between the two.
    """
    flowing = tf.multiply(flag, input_tensor)
    blocked = tf.multiply(one - flag, sg(input_tensor))
    return flowing + blocked
def product_normalise(input_tensor, bounded_norm=True):
    """Normalise so each slice's geometric mean of absolute values is ~1.

    The geometric mean is computed in log space over axis 2 (the embedding
    dimension). With ``bounded_norm`` the per-dimension log magnitudes are
    additionally rescaled so they stay within [-NORM_LOG_BOUND, NORM_LOG_BOUND].
    """
    step1_tensor = tf.abs(input_tensor)
    step2_tensor = step1_tensor + SANITY_EPS  # guard against log(0)
    log_norm_tensor = tf.log(step2_tensor)
    step3_tensor = tf.reduce_mean(log_norm_tensor, axis=2, keepdims=True)
    norm_volume = tf.exp(step3_tensor)  # geometric mean of |values|
    pre_norm_out = input_tensor / norm_volume
    if not bounded_norm:
        return pre_norm_out
    else:
        # Scale log magnitudes so the extremes land exactly on +/- NORM_LOG_BOUND
        # when they exceed it; min/max clamps avoid dividing by values inside the bound.
        minsize_tensor = tf.minimum(tf.reduce_min(log_norm_tensor, axis=2, keepdims=True), -NORM_LOG_BOUND)
        maxsize_tensor = tf.maximum(tf.reduce_max(log_norm_tensor, axis=2, keepdims=True), NORM_LOG_BOUND)
        minsize_ratio = -NORM_LOG_BOUND / minsize_tensor
        maxsize_ratio = NORM_LOG_BOUND / maxsize_tensor
        size_norm_ratio = tf.minimum(minsize_ratio, maxsize_ratio)
        normed_tensor = log_norm_tensor * size_norm_ratio
        # NOTE(review): this branch returns exp(scaled log|x|), i.e. the SIGN of the
        # input is dropped — presumably shapes are treated as positive; confirm.
        return tf.exp(normed_tensor)
def add_padding(input_tensor):
    """Append one all-zero row, providing a dedicated padding index for lookups."""
    pad_row = tf.zeros([1, tf.shape(input_tensor)[1]])
    return tf.concat([input_tensor, pad_row], axis=0)
def apply_sg_neg(input_tensor, replacement_indices):
    """Stop gradients only at the positions named by ``replacement_indices``.

    The stop-gradient values are scattered into a zero tensor and merged back
    with tf.where; zeros therefore mean "untouched", which is why exact zeros
    in the input are first nudged by sanitize_scatter.
    """
    # Nudge exact zeros so a genuine zero value cannot be mistaken for
    # scatter_nd's untouched-cell default of 0 in the mask below.
    input_tensor_2 = sanitize_scatter(input_tensor)
    values = sg(tf.gather_nd(input_tensor_2, replacement_indices))
    replacement_mask = tf.scatter_nd(replacement_indices, values, tf.shape(input_tensor))
    # NOTE(review): the > 0 test means positions whose stop-gradient value is
    # negative fall through to the original tensor — verify this is intended.
    return tf.where(tf.greater(replacement_mask, 0), replacement_mask, input_tensor)
def corrupt_batch(batch, idx, nb_entities, hash_table, filtered=True):  # Generate Corrupted Data for validation
    """Produce the idx-th systematic corruption of a batch for ranking evaluation.

    ``idx`` ranges over max_arity * nb_entities and encodes both the entity
    slot to corrupt (idx // nb_entities) and the replacement entity
    (idx % nb_entities). With ``filtered``, corruptions that produce a fact
    already present in the KB (looked up via ``hash_table``) are restored to
    the original entity, yielding the filtered evaluation setting.
    """
    nb_batch_facts = tf.shape(batch)[0]
    replacement_ents = tf.fill([nb_batch_facts], tf.cast(idx % nb_entities, tf.int32))  # the replacement entity id
    rep_ar_pos = tf.fill([nb_batch_facts, 1], tf.cast(idx // nb_entities, tf.int32))  # slot to corrupt (output_type not available)
    # (row, column) indices of the cells to overwrite; +1 skips the relation column.
    replacement_idx = tf.concat([tf.expand_dims(tf.range(nb_batch_facts), axis=-1), rep_ar_pos + 1], axis=1)
    # scatter_nd default is 0, so values are stored as id+1 and decoded with -1 below.
    replacement_mask = tf.scatter_nd(replacement_idx, replacement_ents + 1, tf.shape(batch))
    new_batch = tf.where(tf.greater(replacement_mask, 0), replacement_mask - 1, batch)  # Replace everywhere
    if filtered:
        # Filtering Mechanism: restore corruptions that collide with true facts.
        input_keys = tf.strings.reduce_join(tf.strings.as_string(new_batch[:, :-1]), axis=1,
                                            separator=Cnst.FACT_DELIMITER)
        fact_exists = hash_table.lookup(input_keys)  # Get which corrupted facts are in the KB
        fact_exists_bool = tf.greater(fact_exists, 0)  # 0 implies not in KB, so keep replaced, else restore
        original_ents = batch[:, idx // nb_entities + 1]  # Get the original batch values
        original_ents_filt = tf.boolean_mask(original_ents, fact_exists_bool)
        replacement_idx_filt = tf.boolean_mask(replacement_idx, fact_exists_bool)  # Which of the replaced exists?
        # Replacement of entities with values where applicable: + 1 to avoid 0 index
        replacement_mask = tf.scatter_nd(replacement_idx_filt, original_ents_filt + 1, tf.shape(new_batch))
        new_batch = tf.where(tf.greater(replacement_mask, 0), replacement_mask - 1, new_batch)
    return new_batch
class BoxEMulti:
def __init__(self, kb_name, options: ModelOptions, suffix: str = ""):
self.options = options
self.embedding_dim = options.embedding_dim
self.neg_sampling_opt = options.neg_sampling_opt # Negative Sampling mode
self.adv_temp = options.adversarial_temp # Adversarial Temperature (if applicable)
self.nb_neg_examples_per_pos = options.nb_neg_examples_per_pos # Number of negative samples per fact
self.nb_neg = tf.Variable(self.nb_neg_examples_per_pos, dtype=tf.int32)
self.learning_rate = options.learning_rate
self.stop_gradient = options.stop_gradient # Use stop gradient
self.sg_neg = options.stop_gradient_negated # Stop gradient for negative examples
self.margin = options.loss_margin # Margin for loss function
self.reg_lambda = options.regularisation_lambda # Regularization options
self.reg_points = options.regularisation_points
self.total_log_box_size = options.total_log_box_size
self.batch_size = options.batch_size
self.use_bumps = options.use_bumps # Enable entity bumps
self.hard_total_size = options.hard_total_size # Fix a total both size
self.shared_shape = options.shared_shape # Fix a common box shape
self.learnable_shape = options.learnable_shape # Shape a fixed value or
self.fixed_width = options.fixed_width # Fix the overall box volume
self.param_directory = "weights_" + kb_name + "/values.ckpt"
self.saver = None
self.sess = None
self.use_tensorboard = options.use_tensorboard
self.bounded_pt_space = options.bounded_pt_space # Map points with tanh activation
self.bounded_box_space = options.bounded_box_space # Map boxes with tanh
self.bound_scale = options.space_bound # Multiplier to apply on ]-1,1[ range.
self.obj_fct = options.obj_fct # Objective Function
self.loss_fct = options.loss_fct # Loss Function
self.loss_ord = options.loss_norm_ord # Loss Norm Order (1,2,etc..)
self.dim_dropout_prob = tf.Variable(initial_value=options.dim_dropout_prob, shape=(), name='dim_drpt_prob')
self.dim_dropout_flt = options.dim_dropout_prob
# KB setting
self.kb_name = kb_name
kb_metadata = KBUtils.load_kb_metadata_multi(kb_name)
self.nb_entities = kb_metadata[0]
self.nb_relations = kb_metadata[1]
self.hard_code_size = options.hard_code_size # Set all boxes to fixed sizes based on fact statistics
self.sqrt_dim = tf.sqrt(self.embedding_dim + 0.0)
self.gradient_clip = options.gradient_clip
self.bounded_norm = options.bounded_norm
self.max_arity = kb_metadata[2]
self.augment_inv = options.augment_inv
self.original_nb_rel = self.nb_relations
if self.augment_inv: # Data Augmentation (define inverse relation and train on inverse fact)
if self.max_arity > 2:
print("Unable to use data augmentation, dataset is not a knowledge graph. Setting Aug to False")
self.augment_inv = False
else:
self.nb_relations = 2 * self.nb_relations
if options.hard_code_size:
relation_stats = KBUtils.compute_statistics(kb_name)
relation_stats = relation_stats ** (1 / self.embedding_dim)
else:
relation_stats = None
self.lr_decay = options.learning_rate_decay
self.lr_decay_period = options.decay_period
self.rule_dir = options.rule_dir
self.normed_bumps = options.normed_bumps
with tf.name_scope('Entity_Embeddings' + suffix): # Instantiate Entity Embeddings
entity_table_shape = [self.nb_entities, self.embedding_dim]
self.entity_points = create_uniform_var("entity_embeddings" + suffix, entity_table_shape,
-0.5 / self.sqrt_dim, 0.5 / self.sqrt_dim)
self.entities_with_pad = add_padding(self.entity_points)
if self.use_bumps: # Translational Bumps
self.entity_bumps = create_uniform_var("entity_bump_embeddings" + suffix, entity_table_shape,
-0.5 / self.sqrt_dim, 0.5 / self.sqrt_dim)
if self.normed_bumps: # Normalization of bumps option
self.entity_bumps = tf.math.l2_normalize(self.entity_bumps, axis=1)
self.bumps_with_pad = add_padding(self.entity_bumps)
rel_tbl_shape = [self.nb_relations, self.max_arity, self.embedding_dim]
scale_multiples_shape = [self.nb_relations, self.max_arity, 1]
tile_shape = [self.nb_relations, 1, 1]
with tf.name_scope('Relation_Embeddings' + suffix): # Relation Embedding Instantiation
if self.shared_shape: # Shared box shape
base_shape = [1, self.max_arity, self.embedding_dim]
tile_var = True
else: # Variable box shape
base_shape = rel_tbl_shape
tile_var = False
if self.learnable_shape: # If shape is learnable, define variables accordingly
self.rel_shapes = create_uniform_var("rel_shape" + suffix,
base_shape, -0.5 / self.sqrt_dim, 0.5 / self.sqrt_dim)
self.norm_rel_shapes = product_normalise(self.rel_shapes, self.bounded_norm)
else: # Otherwise set all boxes as one-hypercubes
self.norm_rel_shapes = tf.ones(base_shape, name="rel_shape" + suffix)
if tile_var:
self.norm_rel_shapes = tf.tile(self.norm_rel_shapes, tile_shape)
self.total_size = np.exp(options.total_log_box_size) if self.hard_total_size else -1
self.rel_bases, self.rel_deltas, self.rel_multiples = \
instantiate_box_embeddings("rel" + suffix, scale_multiples_shape, rel_tbl_shape,
self.norm_rel_shapes, self.sqrt_dim, self.hard_total_size,
self.total_size, relation_stats, self.fixed_width)
if self.rule_dir: # Rule Injection logic
self.rel_bx_lows, self.rel_bx_highs = compute_box(self.rel_bases, self.rel_deltas, name="rel_bx")
if self.bounded_box_space:
self.rel_bx_lows = self.bound_scale * tf.tanh(self.rel_bx_lows)
self.rel_bx_highs = self.bound_scale * tf.tanh(self.rel_bx_highs)
if self.sg_neg or self.stop_gradient[0] == Cnst.STOP or self.stop_gradient[1] == Cnst.STOP:
self.stop_gradient = Cnst.NO_STOPS
self.sg_neg = False
print("Stop Gradients Not Implemented Yet in Rule Injection Mode. No Stop Gradients Applied...")
rule_parser = RuleParser(self.rule_dir) # Parse the rules
parsed_rules = rule_parser.get_parsed_rules()
self.rule_boxes = [self.rel_bx_lows]
for rule_i in parsed_rules: # Iterate over them linearly (hence rules must be ordered appropriately)
recursive_tensor(rule_i[0], self.rel_bx_lows, self.rel_bx_highs)
recursive_tensor(rule_i[1], self.rel_bx_lows, self.rel_bx_highs)
self.rel_bx_lows, self.rel_bx_highs = enforce_rule(rule_i, self.rel_bx_lows, self.rel_bx_highs)
self.rule_boxes.append(self.rel_bx_lows) # Keep a store of box configuration over rule injection
with tf.name_scope("Training_Data_Pipeline"): # Data setup
tr_np_arr = KBUtils.load_kb_file(Cnst.DEFAULT_KB_MULTI_DIR + str(kb_name) + "/train" + Cnst.KB_FORMAT)
if not options.restricted_training:
self.nb_training_facts = tr_np_arr.shape[0]
else:
tr_np_arr = tr_np_arr[:options.restriction, :]
self.nb_training_facts = options.restriction
if self.augment_inv:
tr_np_arr_augmentation = np.zeros_like(tr_np_arr)
tr_np_arr_augmentation[:, 0] = tr_np_arr[:, 0] + self.original_nb_rel
tr_np_arr_augmentation[:, 1] = tr_np_arr[:, 2]
tr_np_arr_augmentation[:, 2] = tr_np_arr[:, 1]
tr_np_arr_augmentation[:, 3] = tr_np_arr[:, 3]
tr_np_arr = np.concatenate([tr_np_arr, tr_np_arr_augmentation], axis=0)
self.nb_training_facts = 2 * self.nb_training_facts
self.nb_tr_batches = ceil(self.nb_training_facts / self.batch_size)
self.tr_dataset = tf.data.Dataset.from_tensor_slices(tr_np_arr)
self.tr_dataset = self.tr_dataset.shuffle(self.nb_training_facts, reshuffle_each_iteration=True)
self.tr_dataset = self.tr_dataset.batch(self.batch_size)
# Negative Sampling
if self.neg_sampling_opt == Cnst.UNIFORM or self.neg_sampling_opt == Cnst.SELFADV:
self.tr_dataset = self.tr_dataset.map(lambda facts: uniform_neg_sampling(self.nb_neg_examples_per_pos,
batch_components=facts,
nb_entities=self.nb_entities,
max_arity=self.max_arity,
return_replacements=self.sg_neg
), num_parallel_calls=8)
self.tr_dataset.prefetch(1)
hash_tbl = KBUtils.create_kb_filter_tf(self.kb_name) # Creating filter for evaluation
with tf.name_scope("Val_Data_Pipeline"):
vl_np_arr = KBUtils.load_kb_file(Cnst.DEFAULT_KB_MULTI_DIR + str(kb_name) + "/valid" + Cnst.KB_FORMAT)
self.nb_vl_facts = vl_np_arr.shape[0]
self.vl_dataset = tf.data.Dataset.from_tensor_slices(vl_np_arr)
self.vl_dataset = self.vl_dataset.batch(self.nb_vl_facts)
self.vl_dataset.prefetch(16)
self.vl_dataset_corr = tf.data.Dataset.from_tensor_slices(vl_np_arr).batch(self.nb_vl_facts) \
.repeat(self.max_arity * self.nb_entities)
rep_idx = tf.data.Dataset.range(self.max_arity * self.nb_entities) # int64, output_type only in TF2, cast
self.vl_dataset_corr = tf.data.Dataset.zip((self.vl_dataset_corr, rep_idx))
self.vl_dataset_corr = self.vl_dataset_corr.map(lambda batch, idx: corrupt_batch(batch, idx,
self.nb_entities, hash_tbl),
num_parallel_calls=8)
self.vl_dataset_corr.prefetch(16)
with tf.name_scope("Tr_Tst_Data_Pipeline"):
tr_ts_np_arr = tr_np_arr[:3 * self.batch_size, :]
self.nb_tr_ts_facts = tr_ts_np_arr.shape[0]
self.tr_ts_dataset = tf.data.Dataset.from_tensor_slices(tr_ts_np_arr)
self.tr_ts_dataset = self.tr_ts_dataset.batch(self.nb_tr_ts_facts)
self.tr_ts_dataset.prefetch(16)
# Trying something different here...
self.tr_ts_dataset_corr = tf.data.Dataset.from_tensor_slices(tr_ts_np_arr).batch(self.nb_tr_ts_facts)\
.repeat(self.max_arity * self.nb_entities)
rep_idx = tf.data.Dataset.range(self.max_arity * self.nb_entities) # int64, output_type only in TF2, cast
self.tr_ts_dataset_corr = tf.data.Dataset.zip((self.tr_ts_dataset_corr, rep_idx))
self.tr_ts_dataset_corr = self.tr_ts_dataset_corr.map(lambda batch,
idx: corrupt_batch(batch, idx, self.nb_entities,
hash_tbl),
num_parallel_calls=8)
self.tr_ts_dataset_corr.prefetch(16)
with tf.name_scope("Tst_Data_Pipeline"):
ts_np_arr = KBUtils.load_kb_file(Cnst.DEFAULT_KB_MULTI_DIR + str(kb_name) + "/test" + Cnst.KB_FORMAT)
self.nb_ts_facts = ts_np_arr.shape[0]
self.ts_dataset = tf.data.Dataset.from_tensor_slices(ts_np_arr)
self.ts_dataset = self.ts_dataset.batch(self.nb_ts_facts)
self.ts_dataset.prefetch(16)
self.ts_dataset_corr = tf.data.Dataset.from_tensor_slices(ts_np_arr).batch(self.nb_ts_facts) \
.repeat(self.max_arity * self.nb_entities)
rep_idx = tf.data.Dataset.range(self.max_arity * self.nb_entities) # int64, output_type only in TF2, cast
self.ts_dataset_corr = tf.data.Dataset.zip((self.ts_dataset_corr, rep_idx))
self.ts_dataset_corr = self.ts_dataset_corr.map(lambda batch, idx: corrupt_batch(batch, idx,
self.nb_entities,
hash_tbl),
num_parallel_calls=8)
self.ts_dataset_corr.prefetch(16)
with tf.name_scope("Iterator"): # Data Iterator
self.iterator = tf.data.Iterator.from_structure(self.tr_dataset.output_types, self.tr_dataset.output_shapes)
self.next_batch = self.iterator.get_next()
if self.sg_neg and self.nb_neg_examples_per_pos > 0:
self.batch_components, self.replaced_indices = self.next_batch
else:
self.batch_components = self.next_batch
self.original_batch_size = tf.div(tf.shape(self.batch_components)[0], 1 + self.nb_neg)
self.training_init_op = self.iterator.make_initializer(self.tr_dataset)
self.valid_init_op = self.iterator.make_initializer(self.vl_dataset)
self.valid_corr_init_op = self.iterator.make_initializer(self.vl_dataset_corr)
self.test_init_op = self.iterator.make_initializer(self.ts_dataset)
self.test_corr_init_op = self.iterator.make_initializer(self.ts_dataset_corr)
self.tr_test_init_op = self.iterator.make_initializer(self.tr_ts_dataset)
self.tr_test_corr_init_op = self.iterator.make_initializer(self.tr_ts_dataset_corr)
with tf.name_scope("Batch_Points"): # Batch Lookups
self.batch_points = tf.nn.embedding_lookup(self.entities_with_pad,
self.batch_components[:, 1: self.max_arity + 1],
name="batch_pts")
if self.use_bumps:
self.batch_bumps = tf.nn.embedding_lookup(self.bumps_with_pad,
self.batch_components[:, 1: self.max_arity + 1],
name="bump_pts")
with tf.name_scope("Bumps"): # Application of bumps
self.batch_bump_sum = tf.reduce_sum(self.batch_bumps, axis=1, keepdims=True)
self.batch_point_representations = self.batch_points
if self.use_bumps:
self.batch_point_representations += self.batch_bump_sum - self.batch_bumps
self.batch_components_ent = self.batch_components[:, 1:-1]
self.batch_mask = tf.where(tf.equal(self.batch_components_ent, self.nb_entities),
tf.zeros_like(self.batch_components_ent, dtype=tf.float32),
tf.ones_like(self.batch_components_ent, dtype=tf.float32))
with tf.name_scope("Batch_Rel_Params"):
self.batch_rel_bases = tf.nn.embedding_lookup(self.rel_bases, self.batch_components[:, 0],
name='batch_rel_bases')
self.batch_rel_deltas = tf.nn.embedding_lookup(self.rel_deltas, self.batch_components[:, 0],
name='batch_rel_deltas')
self.batch_rel_mults = tf.nn.embedding_lookup(self.rel_multiples, self.batch_components[:, 0],
name='batch_rel_multiples')
if self.rule_dir:
self.batch_rel_bx_lows = tf.nn.embedding_lookup(self.rel_bx_lows, self.batch_components[:, 0],
name='batch_rel_bx_lows')
self.batch_rel_bx_highs = tf.nn.embedding_lookup(self.rel_bx_highs, self.batch_components[:, 0],
name='batch_rel_bx_highs')
loss_function = loss_function_poly if self.loss_fct == Cnst.POLY_LOSS else loss_function_q2b
with tf.name_scope("Standard_Loss_Computation"):
obs = self.original_batch_size
if self.bounded_pt_space:
self.batch_point_representations = self.bound_scale * tf.tanh(self.batch_point_representations)
loss_pos_neg = loss_function(batch_points=self.batch_point_representations, batch_mask=self.batch_mask,
rel_bx_low=self.batch_rel_bx_lows, rel_bx_high=self.batch_rel_bx_highs,
batch_rel_mults=self.batch_rel_mults, order=self.loss_ord,
dim_dropout_prob=self.dim_dropout_prob)
with tf.name_scope("Values_Split"):
self.positive_loss = tf.cond(obs < tf.shape(loss_pos_neg)[0], lambda: loss_pos_neg[:obs],
lambda: loss_pos_neg)
self.negative_loss = tf.cond(obs < tf.shape(loss_pos_neg)[0], lambda: loss_pos_neg[obs:],
lambda: neg_zero)
else:
if self.sg_neg and self.nb_neg_examples_per_pos > 0:
with tf.name_scope("Negated_Replacement_Stop_Grad"):
self.batch_rel_bases = apply_sg_neg(self.batch_rel_bases, self.replaced_indices)
self.batch_rel_deltas = apply_sg_neg(self.batch_rel_deltas, self.replaced_indices)
self.batch_rel_mults = apply_sg_neg(self.batch_rel_mults, self.replaced_indices)
self.positive_loss, self.negative_loss = \
loss_computation_with_sg(
batch_points=self.batch_point_representations, batch_mask=self.batch_mask, loss_fct=self.loss_fct,
batch_rel_deltas=self.batch_rel_deltas, batch_rel_mults=self.batch_rel_mults,
bounded_box_space=self.bounded_box_space, bound_scale=self.bound_scale, sgrad=self.stop_gradient,
bounded_pt_space=self.bounded_pt_space, original_batch_size=self.original_batch_size,
loss_order=self.loss_ord, dim_drpt_prob=self.dim_dropout_prob, batch_rel_bases=self.batch_rel_bases)
if self.obj_fct == Cnst.NEG_SAMP:
self.loss_pos = tf.log(tf.nn.sigmoid(self.margin - self.positive_loss) + SANITY_EPS)
elif self.obj_fct == Cnst.MARGIN_BASED:
self.loss_pos = self.positive_loss
if self.nb_neg_examples_per_pos > 0:
if self.neg_sampling_opt == Cnst.UNIFORM:
if self.obj_fct == Cnst.NEG_SAMP: # Standard Objective
self.loss_neg = tf.log(tf.nn.sigmoid(self.negative_loss - self.margin) + SANITY_EPS)
self.loss_n_term = tf.reduce_sum(self.loss_neg) / self.nb_neg_examples_per_pos
elif self.obj_fct == Cnst.MARGIN_BASED: # Objective used in TransE
self.reshaped_neg_dists = tf.reshape(self.negative_loss, [self.nb_neg_examples_per_pos,
self.original_batch_size])
self.reshaped_neg_dists = tf.transpose(self.reshaped_neg_dists, perm=[1, 0],
name='transposed_neg', conjugate=False)
self.loss_neg = tf.reduce_mean(self.reshaped_neg_dists, axis=1)
self.loss_n_term = tf.reduce_sum(
self.loss_neg)
elif self.neg_sampling_opt == Cnst.SELFADV:
self.reshaped_neg_dists = tf.reshape(self.negative_loss, [self.nb_neg_examples_per_pos,
self.original_batch_size])
self.reshaped_neg_dists = tf.transpose(self.reshaped_neg_dists, perm=[1, 0],
name='transposed_neg', conjugate=False)
self.softmax_pre_scores = tf.negative(self.reshaped_neg_dists, name="Negated_Dists") * self.adv_temp
self.neg_softmax = sg(tf.nn.softmax(self.softmax_pre_scores, axis=1, name="softmax_weights"))
if self.obj_fct == Cnst.NEG_SAMP:
self.loss_neg_batch = tf.log(tf.nn.sigmoid(self.reshaped_neg_dists - self.margin) + SANITY_EPS)
self.loss_neg = tf.multiply(self.neg_softmax, self.loss_neg_batch, name="Self-Adversarial_Loss")
elif self.obj_fct == Cnst.MARGIN_BASED:
self.loss_neg = tf.multiply(self.neg_softmax, self.reshaped_neg_dists, name="Self-Adversarial_Loss")
self.loss_n_term = tf.reduce_sum(self.loss_neg)
else:
self.loss_n_term = tf.constant(0.0)
self.loss_p_term = tf.reduce_sum(self.loss_pos)
if self.reg_lambda > 0 and not self.hard_total_size:
if self.fixed_width:
print("Box size regularization with fixed widths is redundant, so regularization has been disabled")
self.reg_lambda = -1
self.reg_loss = 0.0
else:
self.reg_loss = total_box_size_reg(rel_deltas=self.rel_deltas, reg_lambda=self.reg_lambda,
log_box_size=self.total_log_box_size)
else:
self.reg_loss = 0.0
if self.obj_fct == Cnst.NEG_SAMP:
self.loss = - self.loss_n_term - self.loss_p_term + self.reg_loss
elif self.obj_fct == Cnst.MARGIN_BASED:
self.loss = tf.reduce_sum(tf.maximum(0.0, self.margin + self.loss_pos - self.loss_neg))
if self.reg_points > 0:
self.loss += self.reg_points * (tf.nn.l2_loss(self.batch_point_representations) +
tf.nn.l2_loss(self.batch_rel_bases))
if self.use_tensorboard:
with tf.name_scope('Loss_Terms'):
if self.obj_fct == Cnst.NEG_SAMP:
self.pos_loss_summary = tf.summary.scalar('pos_loss', - self.loss_p_term)
self.neg_loss_summary = tf.summary.scalar('neg_loss', - self.loss_n_term)
elif self.obj_fct == Cnst.MARGIN_BASED:
self.pos_loss_summary = tf.summary.scalar('pos_loss', self.loss_p_term)
self.neg_loss_summary = tf.summary.scalar('neg_loss', self.loss_n_term)
self.reg_loss_summary = tf.summary.scalar('reg_loss', self.reg_loss)
self.total_loss_summary = tf.summary.scalar('loss', self.loss)
self.loss_summaries = tf.summary.merge([self.pos_loss_summary, self.neg_loss_summary,
self.reg_loss_summary, self.total_loss_summary])
self.global_step = tf.Variable(0, trainable=False)
if self.lr_decay > 0:
decay_step = self.lr_decay_period * self.nb_tr_batches
self.lr_with_decay = tf.train.inverse_time_decay(self.learning_rate, global_step=self.global_step,
decay_rate=self.lr_decay, decay_steps=decay_step)
self.optimiser = tf.train.AdamOptimizer(learning_rate=self.lr_with_decay)
else:
self.optimiser = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
if self.gradient_clip > 0:
gradients, variables = zip(*self.optimiser.compute_gradients(self.loss))
gradients, _ = tf.clip_by_global_norm(gradients, self.gradient_clip)
self.train_op = self.optimiser.apply_gradients(zip(gradients, variables), name='minimize',
global_step=self.global_step)
else:
self.train_op = self.optimiser.minimize(self.loss, name='minimize', global_step=self.global_step)
self.scores = tf.expand_dims(self.positive_loss, axis=-1)
if self.use_tensorboard: # TensorBoard configuration
self.sess = tf.Session()
self.summary_writer = None
self.average_epoch_loss = tf.placeholder(tf.float32, shape=None, name='per_epoch_loss')
self.epoch_loss_summary = tf.summary.scalar('Average Epoch Loss', self.average_epoch_loss)
with tf.name_scope('Training_Acc'):
self.train_cat_acc = tf.placeholder(tf.float32, shape=None, name='training_cat_acc')
self.train_cat_acc_summary = tf.summary.scalar('Training Cat Accuracy', self.train_cat_acc)
self.train_mr = tf.placeholder(tf.float32, shape=None, name='train_mean_rank')
self.train_mr_summary = tf.summary.scalar('Training Mean Rank', self.train_mr)
self.train_mrr = tf.placeholder(tf.float32, shape=None, name='train_mean_reciprocal_rank')
self.train_mrr_summary = tf.summary.scalar('Training Mean Reciprocal Rank', self.train_mrr)
self.train_h_at_1 = tf.placeholder(tf.float32, shape=None, name='train_hits_at_1')
self.train_h_at_1_summary = tf.summary.scalar('Train Hits@1', self.train_h_at_1)
self.train_h_at_3 = tf.placeholder(tf.float32, shape=None, name='train_hits_at_3')
self.train_h_at_3_summary = tf.summary.scalar('Train Hits@3', self.train_h_at_3)
self.train_h_at_5 = tf.placeholder(tf.float32, shape=None, name='train_hits_at_5')
self.train_h_at_5_summary = tf.summary.scalar('Train Hits@5', self.train_h_at_5)
self.train_h_at_10 = tf.placeholder(tf.float32, shape=None, name='train_hits_at_10')
self.train_h_at_10_summary = tf.summary.scalar('Train Hits@10', self.train_h_at_10)
self.train_summaries = tf.summary.merge([self.train_cat_acc_summary, self.train_mr_summary,
self.train_mrr_summary, self.train_h_at_1_summary,
self.train_h_at_3_summary, self.train_h_at_5_summary,
self.train_h_at_10_summary])
with tf.name_scope("Validation_Acc"):
self.valid_cat_acc = tf.placeholder(tf.float32, shape=None, name='valid_cat_acc')
self.valid_cat_acc_summary = tf.summary.scalar('Validation Cat Accuracy', self.valid_cat_acc)
self.valid_mr = tf.placeholder(tf.float32, shape=None, name='valid_mean_rank')
self.valid_mr_summary = tf.summary.scalar('Validation Mean Rank', self.valid_mr)
self.valid_mrr = tf.placeholder(tf.float32, shape=None, name='valid_mean_reciprocal_rank')
self.valid_mrr_summary = tf.summary.scalar('Valid Mean Reciprocal Rank', self.valid_mrr)
self.valid_h_at_1 = tf.placeholder(tf.float32, shape=None, name='valid_hits_at_1')
self.valid_h_at_1_summary = tf.summary.scalar('Valid Hits@1', self.valid_h_at_1)
self.valid_h_at_3 = tf.placeholder(tf.float32, shape=None, name='valid_hits_at_3')
self.valid_h_at_3_summary = tf.summary.scalar('Valid Hits@3', self.valid_h_at_3)
self.valid_h_at_5 = tf.placeholder(tf.float32, shape=None, name='valid_hits_at_5')
self.valid_h_at_5_summary = tf.summary.scalar('Valid Hits@5', self.valid_h_at_5)
self.valid_h_at_10 = tf.placeholder(tf.float32, shape=None, name='valid_hits_at_10')
self.valid_h_at_10_summary = tf.summary.scalar('Valid Hits@10', self.valid_h_at_10)
self.valid_summaries = tf.summary.merge([self.valid_cat_acc_summary, self.valid_mr_summary,
self.valid_mrr_summary, self.valid_h_at_1_summary,
self.valid_h_at_3_summary, self.valid_h_at_5_summary,
self.valid_h_at_10_summary])
def get_rule_boxes(self):
    """Return the learned rule-box bounds, or None when not in rule mode.

    Returns:
        Tuple (rel_bx_low, rel_bx_high) of arrays fetched from the session,
        or None if the model was built without a rule directory.
    """
    # Guard clause: rule boxes only exist when rule mode was enabled.
    if not self.rule_dir:
        print("Not in Rule Mode!")
        return None
    lows, highs = self.sess.run([self.rel_bx_lows, self.rel_bx_highs])
    return lows, highs
def create_feed_dict(self, batch_components=None):
    """Build the feed dictionary for a session run.

    Args:
        batch_components: optional batch array; when given it is bound to
            the model's ``batch_components`` placeholder.

    Returns:
        dict mapping placeholders to values (empty when nothing is given).
    """
    if batch_components is None:
        return {}
    return {self.batch_components: batch_components}
def check_scale_mults(self, reload_params=True, param_loc=None):
    """Fetch the current relation scaling multiples from the session.

    Args:
        reload_params: when True, restore parameters first via load_params.
        param_loc: optional checkpoint path forwarded to load_params.
    """
    if reload_params:
        self.load_params(param_loc)
    return self.sess.run(self.rel_multiples)
def check_shapes(self, reload_params=True, param_loc=None):
    """Fetch the current normalised relation shapes from the session.

    Args:
        reload_params: when True, restore parameters first via load_params.
        param_loc: optional checkpoint path forwarded to load_params.
    """
    if reload_params:
        self.load_params(param_loc)
    return self.sess.run(self.norm_rel_shapes)
def check_box_pos(self, reload_params=True, param_loc=None):
    """Fetch the current relation box base positions from the session.

    Args:
        reload_params: when True, restore parameters first via load_params.
        param_loc: optional checkpoint path forwarded to load_params.
    """
    if reload_params:
        self.load_params(param_loc)
    return self.sess.run(self.rel_bases)
def check_reg_loss(self, reload_params=True, param_loc=None):
    """Evaluate the current box-size regularization loss.

    Args:
        reload_params: when True, restore parameters first via load_params.
        param_loc: optional checkpoint path forwarded to load_params.

    Returns:
        The regularization loss value as a float.
    """
    if reload_params:
        self.load_params(param_loc)
    # When regularization is disabled, self.reg_loss is set to the plain
    # Python float 0.0 during graph construction, not a graph tensor;
    # sess.run would raise a TypeError on it, so return it directly.
    if isinstance(self.reg_loss, float):
        return self.reg_loss
    return self.sess.run(self.reg_loss)
def score_forward_pass(self, batch_components, reload_params=True, param_loc=None):
    """Compute scores (positive-loss values) for a batch of facts.

    Args:
        batch_components: array of fact tuples to score.
        reload_params: when True, restore parameters first via load_params.
        param_loc: optional checkpoint path forwarded to load_params.

    Returns:
        numpy array of scores, one per input fact.
    """
    feed = self.create_feed_dict(batch_components)
    if self.nb_neg_examples_per_pos > 0:
        # Models built with negatives need the true batch size fed in,
        # and (with stop-gradient negatives) a dummy replaced-index pair.
        feed[self.original_batch_size] = batch_components.shape[0]
        if self.sg_neg:
            feed[self.replaced_indices] = np.array([[0, 0]])
    # Disable dimension dropout for a clean evaluation pass.
    feed[self.dim_dropout_prob] = 0.0
    if reload_params:
        self.load_params(param_loc)
    return self.sess.run(self.scores, feed_dict=feed)
def get_mask(self, batch_components, reload_params=True, param_loc=None):
    """Evaluate the arity mask for a batch of facts.

    Args:
        batch_components: array of fact tuples.
        reload_params: when True, restore parameters first via load_params.
        param_loc: optional checkpoint path forwarded to load_params.

    Returns:
        numpy array holding the batch mask.
    """
    feed = self.create_feed_dict(batch_components)
    if self.nb_neg_examples_per_pos > 0:
        # Mirror score_forward_pass: feed the true batch size, plus a
        # dummy replaced-index pair when stop-gradient negatives are on.
        feed[self.original_batch_size] = batch_components.shape[0]
        if self.sg_neg:
            feed[self.replaced_indices] = np.array([[0, 0]])
    # Disable dimension dropout for a clean evaluation pass.
    feed[self.dim_dropout_prob] = 0.0
    if reload_params:
        self.load_params(param_loc)
    return self.sess.run(self.batch_mask, feed_dict=feed)
def compute_box_volume(self, reload_params=True, param_loc=None):
    """Compute the geometric-mean edge width of every relation box.

    Args:
        reload_params: when True, restore parameters first via load_params.
        param_loc: optional checkpoint path forwarded to load_params.

    Returns:
        numpy array with one geometric-mean width per relation box.
    """
    if reload_params:
        self.load_params(param_loc)
    bases, deltas = self.sess.run([self.rel_bases, self.rel_deltas])
    low, high = compute_box_np(bases, deltas)
    if self.bounded_box_space:
        # Map corners into the bounded space, mirroring the graph-side tanh.
        low = self.bound_scale * np.tanh(low)
        high = self.bound_scale * np.tanh(high)
    # Geometric mean of widths == exp(mean(log(width))); EPS guards log(0).
    log_widths = np.mean(np.log(high - low + SANITY_EPS), axis=-1, keepdims=False)
    return np.exp(log_widths)
def categorical_forward_pass(self, batch_components, reload_params=True, param_loc=None):
    """Binary containment check per fact: 1 when every unmasked point of a
    fact lies inside its relation box in all dimensions, else 0.

    Args:
        batch_components: array of fact tuples to evaluate.
        reload_params: when True, restore parameters first via load_params.
        param_loc: optional checkpoint path forwarded to load_params.

    Returns:
        Integer numpy array (0/1) with one entry per input fact.
    """
    feed_dict = self.create_feed_dict(batch_components)
    if self.nb_neg_examples_per_pos > 0:
        feed_dict[self.original_batch_size] = batch_components.shape[0]
        if self.sg_neg:
            feed_dict[self.replaced_indices] = np.array([[0, 0]])
    if reload_params:
        self.load_params(param_loc)
    points, mask, r_bases, r_deltas = self.sess.run([self.batch_point_representations, self.batch_mask,
                                                     self.batch_rel_bases, self.batch_rel_deltas],
                                                    feed_dict=feed_dict)
    # True where the arity position is padding (mask <= 0), i.e. unused.
    mask_bool = (mask <= 0.0)
    # Recover box corners from base position and delta (numpy mirror of the graph op).
    r_low, r_high = compute_box_np(r_bases, r_deltas)
    # Per-dimension containment test for every point.
    points_inside = np.logical_and(points >= r_low, points <= r_high)
    # Masked (padding) positions count as inside so they never veto containment.
    points_inside_masked = np.logical_or(points_inside, np.expand_dims(mask_bool, axis=-1))
    # Collapse dimensions -> per-position, then positions -> per-fact; *1 casts bool to int.
    points_inside_boxes = np.all(points_inside_masked, axis=2)
    scores = np.all(points_inside_boxes, axis=1) * 1
    return scores
def load_params(self, param_loc=None):
    """Restore model parameters from a checkpoint, creating the saver and
    session on first use.

    Args:
        param_loc: checkpoint path; defaults to self.param_directory.

    Falls back to fresh variable initialization when restoring fails
    (e.g. no checkpoint written yet); the error is printed, not raised.
    """
    if self.saver is None:
        # First call: build the saver and a session to restore into.
        self.saver = tf.train.Saver(name="Saver")
        self.sess = tf.Session()
    if param_loc is None:
        param_loc = self.param_directory
    try:
        self.saver.restore(self.sess, param_loc)
        self.sess.run(tf.tables_initializer())
    except Exception as e:
        # Deliberate best-effort: report the failure, start from fresh weights.
        print(e)
        self.sess.run(tf.global_variables_initializer())
        self.sess.run(tf.tables_initializer())
def get_pair_points(self, batch_components, reload_params=True, param_loc=None):
    """Fetch the point representations for a batch of facts.

    Args:
        batch_components: array of fact tuples.
        reload_params: when True, restore parameters first via load_params.
        param_loc: optional checkpoint path forwarded to load_params.

    Returns:
        A single-element list wrapping the points array (the fetch list has
        one element; the list wrapping is preserved for existing callers).
    """
    feed = self.create_feed_dict(batch_components)
    if reload_params:
        self.load_params(param_loc)
    return self.sess.run([self.batch_points], feed_dict=feed)
def get_relation_average_norm_width(self):
    """Compute the arithmetic-mean relation box width.

    NOTE(review): sess.run is given a single-element fetch list, so
    rel_deltas is a 1-element list; numpy then treats it as an array of
    shape (1, nb_relations, dim), and the two means collapse relations
    too, yielding a length-1 array (one overall average) rather than a
    per-relation vector — confirm this is the intended return shape.
    """
    rel_batch = np.arange(self.nb_relations)
    # One dummy row per relation; column 0 carries the relation id.
    batch_components = np.zeros([self.nb_relations, self.max_arity + 2])
    batch_components[:, 0] = rel_batch
    feed_dict = self.create_feed_dict(batch_components=batch_components)
    self.load_params(None)
    rel_deltas = self.sess.run([self.rel_deltas], feed_dict=feed_dict)
    width_arith_mean = np.mean(np.mean(rel_deltas, axis=2), axis=1)
    return width_arith_mean
def get_entity_embeddings(self, reload_params=True, param_loc=None):
    """Fetch the entity point embeddings.

    Args:
        reload_params: when True, restore parameters first via load_params.
        param_loc: optional checkpoint path forwarded to load_params.

    Returns:
        A single-element list wrapping the entity embedding array (the
        fetch list has one element, so sess.run returns a list).
    """
    if reload_params:
        self.load_params(param_loc)
    return self.sess.run([self.entity_points], feed_dict={})
def forward_pass(self, batch_components, reload_params=True, param_loc=None):
    """Run a forward pass and return points plus relation box parameters.

    Args:
        batch_components: array of fact tuples to evaluate.
        reload_params: when True, restore parameters first via load_params.
        param_loc: optional checkpoint path forwarded to load_params.

    Returns:
        Tuple (batch_points, rel_bases, rel_deltas) of numpy arrays.
    """
    feed_dict = self.create_feed_dict(batch_components)
    if reload_params:
        self.load_params(param_loc)
    # batch_mask used to be fetched here as well but was never used or
    # returned, so it has been dropped from the fetch list.
    batch_points, rel_bases, rel_deltas = self.sess.run(
        [self.batch_points, self.batch_rel_bases, self.batch_rel_deltas],
        feed_dict=feed_dict)
    return batch_points, rel_bases, rel_deltas
def validate(self, hits_at=None, verbose=True, dataset=Cnst.VALID):  # Validation Function
    """Rank-based evaluation: compute MR, MRR and Hits@K over a split.

    Args:
        hits_at: list of K cutoffs for Hits@K; defaults to [1, 3, 5, 10].
        verbose: when True, print progress counts and final metrics.
        dataset: split selector (Cnst.VALID, Cnst.TEST, otherwise the
            train-as-test split).

    Returns:
        Tuple (mean_rank, mean_reciprocal_rank, hits_at_values).
    """
    # Select the iterator init ops and fact count matching the split.
    init_op = self.valid_init_op if dataset == Cnst.VALID else self.test_init_op if dataset == Cnst.TEST \
        else self.tr_test_init_op
    corr_init_op = self.valid_corr_init_op if dataset == Cnst.VALID else self.test_corr_init_op \
        if dataset == Cnst.TEST else self.tr_test_corr_init_op
    nb_facts = self.nb_vl_facts if dataset == Cnst.VALID else self.nb_ts_facts if dataset == Cnst.TEST \
        else self.nb_tr_ts_facts
    if hits_at is None:
        hits_at = [1, 3, 5, 10]
    # Switch the graph into evaluation mode: no negatives, no dropout.
    self.sess.run(tf.assign(self.nb_neg, 0))
    self.sess.run(tf.assign(self.dim_dropout_prob, 0.0))  # Disable dropout during testing
    self.sess.run(init_op)
    # Scores of the true (uncorrupted) facts; corrupted facts rank against these.
    reference_scores, batch_mask = self.sess.run([self.scores, self.batch_mask])
    # One rank per (fact, arity position); ranks start at 1.
    ranks = np.full((nb_facts, self.max_arity), 1)
    self.sess.run(corr_init_op)
    entities_seen = 0
    # Stream corrupted facts: for each arity position, every entity substitution.
    while entities_seen < self.nb_entities * self.max_arity:
        try:
            # Arity position currently being corrupted.
            current_ar = entities_seen // self.nb_entities
            scores = self.sess.run(self.scores)
            nb_ent = scores.shape[0] // nb_facts
            reshaped_scores = np.reshape(scores, (nb_ent, nb_facts, 1))
            if entities_seen // self.nb_entities < (entities_seen + nb_ent) // self.nb_entities:
                # Batch straddles the boundary between two arity positions:
                # split the strictly-better counts between current and next.
                ents_to_go = self.nb_entities - (entities_seen % self.nb_entities)
                rank_ind_low = np.sum((reshaped_scores < reference_scores)[:ents_to_go, :, :] * 1,
                                      axis=0, keepdims=False)
                ranks[:, current_ar] += rank_ind_low[:, 0]
                rank_ind_high = np.sum((reshaped_scores < reference_scores)[ents_to_go:, :, :] * 1,
                                       axis=0, keepdims=False)
                if current_ar + 1 < self.max_arity:
                    ranks[:, current_ar + 1] += rank_ind_high[:, 0]
            else:
                # Count corrupted facts scoring strictly better (lower loss).
                rank_indicator = np.sum((reshaped_scores < reference_scores) * 1, axis=0, keepdims=False)
                ranks[:, current_ar] += rank_indicator[:, 0]
            entities_seen += nb_ent
            if verbose:
                print(entities_seen)
        except tf.errors.OutOfRangeError:
            break
    # Only positions actually in use (mask > 0) contribute to the metrics.
    all_ranks = ranks[batch_mask > 0]
    mean_rank = np.mean(all_ranks)
    mean_reciprocal_rank = np.mean(1 / all_ranks)
    hits_at_values = []
    for x in hits_at:
        hits_at_values.append(np.mean((all_ranks <= x) * 1))
    if verbose:
        print("MR:" + str(mean_rank))
        print("MRR:" + str(mean_reciprocal_rank))
        for i in range(len(hits_at)):
            print("Hits@" + str(hits_at[i]) + ":" + str(hits_at_values[i]))
    # Restore the training configuration before handing control back.
    self.sess.run(tf.assign(self.nb_neg, self.nb_neg_examples_per_pos))
    self.sess.run(tf.assign(self.dim_dropout_prob, self.dim_dropout_flt))  # Restore dropout after eval complete
    self.sess.run(self.training_init_op)
    return mean_rank, mean_reciprocal_rank, hits_at_values
def set_up_valid_net(self):
    """Build a clone of this model with negative sampling disabled, for
    use as a separate validation network.

    Returns:
        A new BoxEMulti instance built under the "Valid_Net" name scope.
    """
    # Copy the options so the training model's settings are untouched.
    opts = copy.deepcopy(self.options)
    opts.nb_neg_examples_per_pos = 0
    with tf.name_scope("Valid_Net"):
        return BoxEMulti(self.kb_name, opts, "_val")
def train_with_valid(self, separate_valid_model=True, print_period=1, epoch_ckpt=50, save_period=1000,
num_epochs=1000, reset_weights=True, loss_file_name="losses", log_to_file=True,
log_file_name="training_log.txt", viz_mode=False):
if separate_valid_model:
valid_model = self.set_up_valid_net()
else:
valid_model = self
if self.use_tensorboard:
if not os.path.exists('summaries'):
os.mkdir('summaries')
summary_descriptor = str(self.kb_name) + "_" + str(self.stop_gradient) + "_" + str(self.learning_rate) + \
"_" + "nb_neg-" + str(self.nb_neg_examples_per_pos) + "_" + "loss_margin-" \
+ str(self.margin) + "_" + "emb_dim-" + str(self.embedding_dim) + "_ " + \
"neg_opt-" + str(self.neg_sampling_opt) + "_" \
+ time.asctime()
if not os.path.exists(os.path.join('summaries', summary_descriptor)):
os.mkdir(os.path.join('summaries', summary_descriptor))
self.summary_writer = tf.summary.FileWriter(os.path.join('summaries', summary_descriptor), self.sess.graph)
if log_to_file:
open(log_file_name, "w")
losses = []
self.sess = tf.Session()
if self.saver is None:
self.saver = tf.train.Saver(name="Saver")
self.sess.run(tf.tables_initializer())
if reset_weights:
self.sess.run(tf.global_variables_initializer())
else:
try:
self.saver.restore(self.sess, self.param_directory)
except ValueError:
self.sess.run(tf.global_variables_initializer())
self.sess.run(tf.tables_initializer())
batch_total_count = 0
try:
if not os.path.exists('training_ckpts'):
os.mkdir('training_ckpts')
tim = time.time()
print_or_log("BoxEMulti: ", log_to_file, log_file_name)
print_or_log("Training for " + str(self.kb_name) + ":", log_to_file, log_file_name)
print_or_log("Embedding Dimension: " + str(self.embedding_dim), log_to_file, log_file_name)
print_or_log("Checkpoint Frequency: " + str(epoch_ckpt), log_to_file, log_file_name)
print_or_log("Number of Epochs: " + str(num_epochs), log_to_file, log_file_name)
print_or_log("Learning Rate: " + str(self.learning_rate), log_to_file, log_file_name)
print_or_log("LR Decay: " + str(self.lr_decay) + "/" + "Period: " + str(self.lr_decay_period),