# link_pred_tasker.py — 147 lines (113 loc) · 4.92 KB
# (forked from Sunnan191/EviSEC)
import torch
import taskers_utils as tu
import utils as u
class Link_Pred_Tasker():
	'''
	Tasker that builds the inputs required to train a link-prediction model.

	It receives a dataset object which should expose ``edges``, ``num_nodes``
	and ``max_time`` and — unless 1-hot/2-hot degree features are requested
	via ``args`` — ``feats_per_node``, ``nodes_feats`` and
	``prepare_node_feats``. This keeps the tasker independent of the dataset
	being used, as long as those attributes have the same structure.

	``get_sample`` (required by the edge-classification trainer) returns a
	dict with:
		- idx: the time step the prediction is made for
		- hist_adj_list: the input adjacency matrices up to ``idx``, each a
		  normalized sparse tensor (unweighted for link prediction)
		- hist_ndFeats_list: the node-feature inputs for the GCN models,
		  one entry per history step
		- label_sp: sparse target edges at ``idx + 1``; ``'idx'`` is an
		  M x 2 matrix of node pairs and ``'vals'`` is 1 for an existing
		  edge, 0 for a sampled non-existing one
		- node_mask_list: per-step masks of active nodes
		- hist_adj_list_u: the same adjacencies, unnormalized

	At test (or development) time a larger number of non-existing edges is
	sampled than during training (``negative_mult_test`` vs
	``negative_mult_training``).
	'''

	def __init__(self, args, dataset, ood_mode=None):
		self.data = dataset
		# Targets live at time idx+1, so the last usable input step is one
		# before the dataset's final time step.
		self.max_time = dataset.max_time - 1
		self.args = args
		self.num_classes = 2  # edge exists / does not exist

		if not (args.use_2_hot_node_feats or args.use_1_hot_node_feats):
			self.feats_per_node = dataset.feats_per_node  # e.g. 3 for sbm

		self.get_node_feats = self.build_get_node_feats(args, dataset, ood_mode)
		self.prepare_node_feats = self.build_prepare_node_feats(args, dataset, ood_mode)
		self.is_static = False

	def build_prepare_node_feats(self, args, dataset, ood_mode):
		'''Return a callable that converts raw node features into the sparse
		tensor form expected by the model.

		For degree-based (1-hot/2-hot) features the conversion is done here;
		otherwise the dataset supplies its own preparation routine.
		'''
		if args.use_2_hot_node_feats or args.use_1_hot_node_feats:
			def prepare_node_feats(node_feats):
				return u.sparse_prepare_tensor(node_feats,
							       torch_size=[dataset.num_nodes,
									   self.feats_per_node])
			return prepare_node_feats
		return self.data.prepare_node_feats

	def build_get_node_feats(self, args, dataset, ood_mode):
		'''Return a callable mapping an adjacency dict to node features.

		Also sets ``self.feats_per_node`` for the degree-based encodings.
		When ``ood_mode == "FI"``, the 1-hot degree indices are replaced
		with random degrees (feature corruption for OOD experiments).
		'''
		if args.use_2_hot_node_feats:
			max_deg_out, max_deg_in = tu.get_max_degs(args, dataset)
			self.feats_per_node = max_deg_out + max_deg_in

			def get_node_feats(adj):
				return tu.get_2_hot_deg_feats(adj,
							      max_deg_out,
							      max_deg_in,
							      dataset.num_nodes)
		elif args.use_1_hot_node_feats:
			# NOTE(review): only the first (out-degree) bound is used; the
			# in-degree bound is discarded — confirm this is intended.
			max_deg, _ = tu.get_max_degs(args, dataset)
			self.feats_per_node = max_deg

			def get_node_feats(adj):
				feats = tu.get_1_hot_deg_feats(adj, max_deg, dataset.num_nodes)
				if ood_mode == "FI":
					# Corrupt the degree column with random values in [1, max_deg).
					feats["idx"][:, 1] = torch.randint(1, max_deg,
									   feats["idx"][:, 1].size(),
									   dtype=torch.int64)
				return feats
		else:
			def get_node_feats(adj):
				return dataset.nodes_feats
		return get_node_feats

	def get_sample(self, idx, test, **kwargs):
		'''Build one sample for predicting the edges at time ``idx + 1``.

		Args:
			idx: last history time step; inputs cover
				``[idx - num_hist_steps, idx]``.
			test: if truthy, sample negatives with the (larger) test
				multiplier instead of the training one.
			all_edges (kwarg, optional): if truthy, label against every
				non-existing edge instead of a sampled subset.

		Returns: dict — see the class docstring for the keys.
		'''
		hist_adj_list = []
		hist_adj_list_unnormalized = []
		hist_ndFeats_list = []
		hist_mask_list = []
		# Nodes seen in the history window feed smart negative sampling;
		# when it is off the sampler ignores this argument entirely.
		existing_nodes = [] if self.args.smart_neg_sampling else None

		for i in range(idx - self.args.num_hist_steps, idx + 1):
			cur_adj = tu.get_sp_adj(edges=self.data.edges,
						time=i,
						weighted=True,
						time_window=self.args.adj_mat_time_window)
			if self.args.smart_neg_sampling:
				existing_nodes.append(cur_adj['idx'].unique())

			node_mask = tu.get_node_mask(cur_adj, self.data.num_nodes)
			node_feats = self.get_node_feats(cur_adj)

			# Keep the unnormalized adjacency before rebinding to the
			# normalized one.
			hist_adj_list_unnormalized.append(cur_adj)
			cur_adj = tu.normalize_adj(adj=cur_adj, num_nodes=self.data.num_nodes)
			hist_adj_list.append(cur_adj)
			hist_ndFeats_list.append(node_feats)
			hist_mask_list.append(node_mask)

		# Target edges at the next time step; unweighted for link prediction.
		label_adj = tu.get_sp_adj(edges=self.data.edges,
					  time=idx + 1,
					  weighted=False,
					  time_window=self.args.adj_mat_time_window)

		# More negatives are sampled at test/dev time than during training.
		neg_mult = self.args.negative_mult_test if test else self.args.negative_mult_training

		if self.args.smart_neg_sampling:
			existing_nodes = torch.cat(existing_nodes)

		if kwargs.get('all_edges', False):
			non_existing_adj = tu.get_all_non_existing_edges(adj=label_adj,
									 tot_nodes=self.data.num_nodes)
		else:
			non_existing_adj = tu.get_non_existing_edges(adj=label_adj,
								     number=label_adj['vals'].size(0) * neg_mult,
								     tot_nodes=self.data.num_nodes,
								     smart_sampling=self.args.smart_neg_sampling,
								     existing_nodes=existing_nodes)

		# Positives first, then the sampled negatives (vals 1s then 0s).
		label_adj['idx'] = torch.cat([label_adj['idx'], non_existing_adj['idx']])
		label_adj['vals'] = torch.cat([label_adj['vals'], non_existing_adj['vals']])

		return {'idx': idx,
			'hist_adj_list': hist_adj_list,
			'hist_ndFeats_list': hist_ndFeats_list,
			'label_sp': label_adj,
			'node_mask_list': hist_mask_list,
			'hist_adj_list_u': hist_adj_list_unnormalized}