@@ -122,26 +122,39 @@ def load_data_realistic_ssl(dataset_name, data_path, label_map_path):
122122 return data
123123
124124
125- def load_from_planetoid_files (dataset_str , path ):
126- """Loads input data from gcn/data directory.
127-
128- This function is copied and adapted from https://github.com/tkipf/gcn/blob/master/gcn/utils.py.
129-
130- ind.dataset_str.x => the feature vectors of the training instances as scipy.sparse.csr.csr_matrix object;
131- ind.dataset_str.tx => the feature vectors of the test instances as scipy.sparse.csr.csr_matrix object;
132- ind.dataset_str.allx => the feature vectors of both labeled and unlabeled training instances
133- (a superset of ind.dataset_str.x) as scipy.sparse.csr.csr_matrix object;
134- ind.dataset_str.y => the one-hot labels of the labeled training instances as numpy.ndarray object;
135- ind.dataset_str.ty => the one-hot labels of the test instances as numpy.ndarray object;
136- ind.dataset_str.ally => the labels for instances in ind.dataset_str.allx as numpy.ndarray object;
137- ind.dataset_str.graph => a dict in the format {index: [index_of_neighbor_nodes]} as collections.defaultdict
138- object;
139- ind.dataset_str.test.index => the indices of test instances in graph, for the inductive setting as list object.
140-
141- All objects above must be saved using python pickle module.
142-
143- :param dataset_str: Dataset name
144- :return: All data input files loaded (as well the training/test data).
125+ def load_from_planetoid_files (dataset_name , path ):
126+ """Loads Planetoid data in GCN format, as released with the GCN code.
127+
128+ This function is adapted from https://github.com/tkipf/gcn.
129+
130+ This function assumes that the following files can be found at the location
131+ specified by `path`:
132+ ind.dataset_str.x => the feature vectors of the training instances
133+ as scipy.sparse.csr.csr_matrix object.
134+ ind.dataset_str.tx => the feature vectors of the test instances as
135+ scipy.sparse.csr.csr_matrix object.
136+ ind.dataset_str.allx => the feature vectors of both labeled and
137+ unlabeled training instances (a superset of
138+ ind.dataset_str.x) as
139+ scipy.sparse.csr.csr_matrix object.
140+ ind.dataset_str.y => the one-hot labels of the labeled training
141+ instances as numpy.ndarray object.
142+ ind.dataset_str.ty => the one-hot labels of the test instances as
143+ numpy.ndarray object.
144+ ind.dataset_str.ally => the labels for instances in
145+ ind.dataset_str.allx as numpy.ndarray object.
146+ ind.dataset_str.graph => a dict in the format
147+ {index: [index_of_neighbor_nodes]} as
148+ collections.defaultdict object.
149+ ind.dataset_str.test.index => the indices of test instances in graph, for
150+ the inductive setting as list object.
151+
152+ Arguments:
153+ dataset_name: A string representing the dataset name (e.g., `cora`).
154+ path: Path to the directory containing the files.
155+
156+ Returns:
157+ All data input files loaded (as well as the training/test data).
145158 """
146159 def _sample_mask (idx , l ):
147160 """Create mask."""
@@ -159,7 +172,7 @@ def _parse_index_file(filename):
159172 names = ['x' , 'y' , 'tx' , 'ty' , 'allx' , 'ally' , 'graph' ]
160173 objects = []
161174 for i in range (len (names )):
162- filename = "ind.{}.{}" .format (dataset_str , names [i ])
175+ filename = "ind.{}.{}" .format (dataset_name , names [i ])
163176 filename = os .path .join (path , filename )
164177 with open (filename , 'rb' ) as f :
165178 if sys .version_info > (3 , 0 ):
@@ -168,15 +181,16 @@ def _parse_index_file(filename):
168181 objects .append (pkl .load (f ))
169182
170183 x , y , tx , ty , allx , ally , graph = tuple (objects )
171- filename = "ind.{}.test.index" .format (dataset_str )
184+ filename = "ind.{}.test.index" .format (dataset_name )
172185 filename = os .path .join (path , filename )
173186 test_idx_reorder = _parse_index_file (filename )
174187 test_idx_range = np .sort (test_idx_reorder )
175188
176- if dataset_str == 'citeseer' :
189+ if dataset_name == 'citeseer' :
177190 # Fix citeseer dataset (there are some isolated nodes in the graph).
178191 # Find isolated nodes, add them as zero-vecs into the right position.
179- test_idx_range_full = range (min (test_idx_reorder ), max (test_idx_reorder )+ 1 )
192+ test_idx_range_full = range (min (test_idx_reorder ),
193+ max (test_idx_reorder )+ 1 )
180194 tx_extended = sp .lil_matrix ((len (test_idx_range_full ), x .shape [1 ]))
181195 tx_extended [test_idx_range - min (test_idx_range ), :] = tx
182196 tx = tx_extended
@@ -187,7 +201,6 @@ def _parse_index_file(filename):
187201 features = sp .vstack ((allx , tx )).tolil ()
188202 features [test_idx_reorder , :] = features [test_idx_range , :]
189203 adj = nx .adjacency_matrix (nx .from_dict_of_lists (graph ))
190- #adj = nx.adjacency_matrix(nx.from_dict_of_lists(graph, create_using=nx.DiGraph))
191204
192205 labels = np .vstack ((ally , ty ))
193206 labels [test_idx_reorder , :] = labels [test_idx_range , :]
@@ -207,19 +220,20 @@ def _parse_index_file(filename):
207220 y_val [val_mask , :] = labels [val_mask , :]
208221 y_test [test_mask , :] = labels [test_mask , :]
209222
210- return adj , features , y_train , y_val , y_test , train_mask , val_mask , test_mask , labels
223+ return (adj , features , y_train , y_val , y_test , train_mask , val_mask ,
224+ test_mask , labels )
211225
212226
213227def load_data_planetoid (name , path , splits_path = None , row_normalize = False ):
214- # Load from file.
215228 if splits_path is None :
216- adj , features , y_train , y_val , y_test , train_mask , val_mask , test_mask , \
217- labels = load_from_planetoid_files (name , path )
229+ # Load from file in Planetoid format.
230+ adj , features , y_train , y_val , y_test , train_mask , val_mask , test_mask ,\
231+ labels = load_from_planetoid_files (name , path )
218232 else :
219- # Otherwise load from splits path.
233+ # Otherwise load from a path where we saved a pickle with random splits.
220234 logging .info ('Loading from splits path: %s' , splits_path )
221- adj , features , y_train , y_val , y_test , train_mask , val_mask , test_mask , \
222- labels = pickle .load (open (splits_path , "rb" ))
235+ adj , features , y_train , y_val , y_test , train_mask , val_mask , test_mask ,\
236+ labels = pickle .load (open (splits_path , "rb" ))
223237
224238 return PlanetoidDataset (name , adj , features , train_mask , val_mask , test_mask ,
225239 labels , row_normalize = row_normalize )
0 commit comments