dataset.py
import json
import os
import argparse
from scipy.io import loadmat
import numpy as np
import networkx as nx
from networkx.readwrite import json_graph
from data_preprocess import DataPreprocess
import random
import graph_utils


class Dataset:
"""
this class receives input from graphsage format with predefined folder structure, the data folder must contains these files:
G.json, id2idx.json, features.npy (optional)
Arguments:
- data_dir: Data directory which contains files mentioned above.
"""

    def __init__(self, data_dir, noise_rate):
        self.data_dir = data_dir
        self.noise_rate = noise_rate
        self._load_id2idx()
        self._load_G()
        self._load_features()
        graph_utils.construct_adjacency(self.G, self.id2idx, sparse=False,
                                        file_path=self.data_dir + "/edges.edgelist")
        # self.load_edge_features()
        print("Dataset info:")
        print("- Nodes: ", len(self.G.nodes()))
        print("- Edges: ", len(self.G.edges()))

    def _load_G(self):
        G_data = json.load(open(os.path.join(self.data_dir, "G.json")))
        # Map link endpoints from integer indices back to original node ids.
        G_data['links'] = [{'source': self.idx2id[link['source']],
                            'target': self.idx2id[link['target']]}
                           for link in G_data['links']]
        if self.noise_rate != 0:
            n_edges = len(G_data['links'])
            num_noise_edges = int(n_edges * self.noise_rate)
            # random.sample needs a sequence, so collect the node ids in a list.
            all_nodes = [node['id'] for node in G_data['nodes']]
            for _ in range(num_noise_edges):
                # Randomly pick two distinct nodes.
                u, v = random.sample(all_nodes, 2)
                # Re-sample until the noise edge neither duplicates an earlier
                # noise edge nor already exists in the original graph.
                while any((link['source'] == u and link['target'] == v) or
                          (link['source'] == v and link['target'] == u)
                          for link in G_data['links']):
                    u, v = random.sample(all_nodes, 2)
                # Add the noise edge.
                G_data['links'].append({'source': u, 'target': v})
        self.G = json_graph.node_link_graph(G_data)
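
    # Example: with noise_rate = 0.1 and a graph of 1,000 edges, _load_G appends
    # int(1000 * 0.1) = 100 random, non-duplicate noise edges before the networkx
    # graph is built.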

    def _load_id2idx(self):
        id2idx_file = os.path.join(self.data_dir, 'id2idx.json')
        self.id2idx = json.load(open(id2idx_file))
        self.idx2id = {v: k for k, v in self.id2idx.items()}

    def _load_features(self):
        self.features = None
        feats_path = os.path.join(self.data_dir, 'feats.npy')
        if os.path.isfile(feats_path):
            self.features = np.load(feats_path)
        else:
            self.features = None
        return self.features

    def load_edge_features(self):
        self.edge_features = None
        feats_path = os.path.join(self.data_dir, 'edge_feats.mat')
        if os.path.isfile(feats_path):
            edge_feats = loadmat(feats_path)['edge_feats']
            self.edge_features = np.zeros((len(edge_feats[0]),
                                           len(self.G.nodes()),
                                           len(self.G.nodes())))
            for idx, matrix in enumerate(edge_feats[0]):
                self.edge_features[idx] = matrix.toarray()
        else:
            self.edge_features = None
        return self.edge_features
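
    # Note on load_edge_features (above): 'edge_feats.mat' is assumed to hold an
    # array of N x N sparse matrices under the key 'edge_feats'; each one is
    # densified into a slice of a (K, N, N) tensor, which can be memory-heavy
    # for large graphs.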

    def get_adjacency_matrix(self, sparse=False):
        # Pass the caller's sparse flag through instead of hard-coding False.
        return graph_utils.construct_adjacency(self.G, self.id2idx, sparse=sparse,
                                               file_path=self.data_dir + "/edges.edgelist")

    def get_nodes_degrees(self):
        return graph_utils.build_degrees(self.G, self.id2idx)

    def get_nodes_clustering(self):
        return graph_utils.build_clustering(self.G, self.id2idx)

    def get_edges(self):
        return graph_utils.get_edges(self.G, self.id2idx)

    def check_id2idx(self):
        # print("Checking format of dataset")
        for i, node in enumerate(self.G.nodes()):
            if self.id2idx[node] != i:
                print("Failed at node %s" % str(node))
                return False
        # print("Pass")
        return True


def parse_args():
    parser = argparse.ArgumentParser(description="Test loading dataset")
    parser.add_argument('--source_dataset', default="/home/trunght/dataspace/graph/douban/online/graphsage/")
    parser.add_argument('--target_dataset', default="/home/trunght/dataspace/graph/douban/offline/graphsage/")
    parser.add_argument('--groundtruth', default="/home/trunght/dataspace/graph/douban/dictionaries/groundtruth")
    parser.add_argument('--output_dir', default="/home/trunght/dataspace/graph/douban/statistics/")
    return parser.parse_args()


def main(args):
    # noise_rate=0 keeps the loaded graphs unchanged (no injected noise edges).
    source_dataset = Dataset(args.source_dataset, noise_rate=0)
    target_dataset = Dataset(args.target_dataset, noise_rate=0)
    groundtruth = graph_utils.load_gt(args.groundtruth, source_dataset.id2idx,
                                      target_dataset.id2idx, "dict")
    DataPreprocess.evaluateDataset(source_dataset, target_dataset, groundtruth, args.output_dir)


if __name__ == "__main__":
    args = parse_args()
    main(args)
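

# Minimal usage sketch (hypothetical path; assumes the folder follows the layout
# documented in the Dataset docstring):
#
#   ds = Dataset("/path/to/graphsage_dir", noise_rate=0.1)  # inject 10% noise edges
#   adj = ds.get_adjacency_matrix()                         # dense adjacency matrix
#   degrees = ds.get_nodes_degrees()
#   assert ds.check_id2idx()                                # node order matches id2idx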