1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414"""Data containers for Graph Agreement Models."""
15- import abc
1615import collections
1716import logging
1817import os
1918import pickle
19+ import scipy
2020
2121import numpy as np
2222import tensorflow as tf
2323
24-
25- class Dataset (object ):
26- """Interface for different types of datasets."""
27- __metaclass__ = abc .ABCMeta
28-
29- @abc .abstractmethod
30- def num_train (self ):
31- pass
32-
33- @abc .abstractmethod
34- def num_val (self ):
35- pass
36-
37- @abc .abstractmethod
38- def num_test (self ):
39- pass
40-
41- @abc .abstractmethod
42- def num_unlabeled (self ):
43- pass
44-
45- @abc .abstractmethod
46- def copy_labels (self ):
47- pass
48-
49- def save_to_pickle (self , file_path ):
50- pickle .dump (self , open (file_path , 'w' ))
51-
52- @staticmethod
53- def load_from_pickle (file_path ):
54- dataset = pickle .load (open (file_path , 'r' ))
55- return dataset
24+ from gam .data .preprocessing import split_train_val
5625
5726
58- class FixedDataset ( Dataset ):
59- """A dataset containing features of fixed size .
27+ class Dataset ( object ):
28+ """A container for datasets .
6029
6130 In this dataset, each sample has the same number of features.
6231 This class manages different splits of the data for train, validation, test
6332 and unlabeled. These sets of samples are disjoint.
6433 """
6534
66- def __init__ (self ,
67- x_train ,
68- y_train ,
69- x_val ,
70- y_val ,
71- x_test ,
72- y_test ,
73- x_unlabeled ,
74- y_unlabeled = None ,
75- num_classes = None ,
35+ def __init__ (self , name , features , labels , indices_train , indices_test ,
36+ indices_val , indices_unlabeled , num_classes = None ,
7637 feature_preproc_fn = lambda x : x ):
77- n_train = x_train .shape [0 ]
78- n_val = x_val .shape [0 ]
79- n_test = x_test .shape [0 ]
80- n_unlabeled = x_unlabeled .shape [0 ]
38+ self .name = name
39+ self .features = features
40+ self .labels = labels
41+
42+ self .indices_train = indices_train
43+ self .indices_val = indices_val
44+ self .indices_test = indices_test
45+ self .indices_unlabeled = indices_unlabeled
46+ self .feature_preproc_fn = feature_preproc_fn
8147
82- if y_unlabeled is None :
83- y_unlabeled = np . zeros ( shape = ( n_unlabeled ,), dtype = y_train . dtype )
48+ self . num_val = self . indices_val . shape [ 0 ]
49+ self . num_test = self . indices_test . shape [ 0 ]
8450
85- # Concatenate samples.
86- self .features = np .concatenate ((x_train , x_val , x_unlabeled , x_test ))
87- self .labels = np .concatenate ((y_train , y_val , y_unlabeled , y_test ))
51+ self .num_samples = labels .shape [0 ]
52+ self .features_shape = features .shape [1 :]
53+ self .num_features = np .prod (features .shape [1 :])
54+ self .num_classes = 1 + max (labels ) if num_classes is None else num_classes
8855
89- self ._num_features = np .prod (self .features .shape [1 :])
90- self ._num_classes = 1 + max (self .labels ) if num_classes is None else \
91- num_classes
92- self ._num_samples = n_train + n_val + n_unlabeled + n_test
@staticmethod
def build_from_splits(name, inputs_train, labels_train, inputs_val,
                      labels_val, inputs_test, labels_test, inputs_unlabeled,
                      labels_unlabeled=None, num_classes=None,
                      feature_preproc_fn=lambda x: x):
  """Builds a Dataset by stacking per-split feature and label arrays.

  The splits are concatenated in the order train, validation, unlabeled,
  test, and every split is assigned the contiguous range of indices it
  occupies in the concatenated arrays.

  Arguments:
    name: Dataset name, forwarded to the Dataset constructor.
    inputs_train: Array with the features of the training samples.
    labels_train: Array-like with the labels of the training samples.
    inputs_val: Array with the features of the validation samples.
    labels_val: Array-like with the labels of the validation samples.
    inputs_test: Array with the features of the test samples.
    labels_test: Array-like with the labels of the test samples.
    inputs_unlabeled: Array with the features of the unlabeled samples.
    labels_unlabeled: Optional labels for the unlabeled samples. When None,
      zero-filled placeholders are created with the dtype and per-sample
      shape of `labels_train`.
    num_classes: Forwarded to the Dataset constructor.
    feature_preproc_fn: Forwarded to the Dataset constructor.

  Returns:
    A Dataset holding the concatenated features and labels together with
    the index ranges of the four splits.
  """
  num_train = inputs_train.shape[0]
  num_val = inputs_val.shape[0]
  num_unlabeled = inputs_unlabeled.shape[0]
  num_test = inputs_test.shape[0]

  if labels_unlabeled is None:
    # Read dtype and per-sample shape from the array itself instead of from
    # its first element, so that an empty training set, a plain Python
    # sequence, or multi-dimensional labels do not break the placeholder.
    labels_train_arr = np.asarray(labels_train)
    labels_unlabeled = np.zeros(
        shape=(num_unlabeled,) + labels_train_arr.shape[1:],
        dtype=labels_train_arr.dtype)

  features = np.concatenate(
      (inputs_train, inputs_val, inputs_unlabeled, inputs_test))
  labels = np.concatenate(
      (labels_train, labels_val, labels_unlabeled, labels_test))

  # Offsets of each split inside the concatenated arrays.
  val_start = num_train
  unlabeled_start = val_start + num_val
  test_start = unlabeled_start + num_unlabeled

  indices_train = np.arange(num_train)
  indices_val = np.arange(val_start, unlabeled_start)
  indices_unlabeled = np.arange(unlabeled_start, test_start)
  indices_test = np.arange(test_start, test_start + num_test)

  return Dataset(name=name,
                 features=features,
                 labels=labels,
                 indices_train=indices_train,
                 indices_test=indices_test,
                 indices_val=indices_val,
                 indices_unlabeled=indices_unlabeled,
                 num_classes=num_classes,
                 feature_preproc_fn=feature_preproc_fn)
9390
94- self .indices_train = np .arange (n_train )
95- self .indices_val = np .arange (n_train , n_train + n_val )
96- self .indices_unlabeled = np .arange (n_train + n_val , n_train + n_val + n_unlabeled )
97- self .indices_test = np .arange (n_train + n_val + n_unlabeled , self ._num_samples )
@staticmethod
def build_from_features(name, features, labels, indices_train, indices_test,
                        indices_val=None, indices_unlabeled=None,
                        percent_val=0.2, seed=None, num_classes=None,
                        feature_preproc_fn=lambda x: x):
  """Builds a Dataset from full feature/label arrays plus split indices.

  When no validation indices are supplied, a validation split is produced by
  calling `split_train_val` with a `np.random.RandomState` seeded by `seed`.

  Arguments:
    name: Dataset name, forwarded to the Dataset constructor.
    features: Feature array, forwarded to the Dataset constructor.
    labels: Label array, forwarded to the Dataset constructor.
    indices_train: Array of training sample indices.
    indices_test: Array of test sample indices.
    indices_val: Optional array of validation sample indices.
    indices_unlabeled: Optional array of unlabeled sample indices. It is
      forwarded unchanged, including when it is None.
    percent_val: Forwarded to `split_train_val` when a validation split has
      to be created.
    seed: Seed for the random number generator used for the split.
    num_classes: Forwarded to the Dataset constructor.
    feature_preproc_fn: Forwarded to the Dataset constructor.

  Returns:
    A Dataset built from the provided arrays and indices.
  """
  if indices_val is None:
    rng = np.random.RandomState(seed=seed)
    # NOTE(review): the split runs over the positions 0..len(indices_train)-1
    # rather than over the values stored in `indices_train`. The two only
    # coincide when the training samples are the first rows -- confirm
    # against split_train_val and the callers.
    train_positions = np.arange(indices_train.shape[0])
    indices_train, indices_val = split_train_val(
        train_positions, percent_val, rng)

  # NOTE(review): `indices_unlabeled` may still be None at this point; verify
  # that the Dataset methods which use it can cope with that.
  dataset_kwargs = dict(
      name=name,
      features=features,
      labels=labels,
      indices_train=indices_train,
      indices_test=indices_test,
      indices_val=indices_val,
      indices_unlabeled=indices_unlabeled,
      num_classes=num_classes,
      feature_preproc_fn=feature_preproc_fn)
  return Dataset(**dataset_kwargs)
98110
99- self .feature_preproc_fn = feature_preproc_fn
100111
def copy_labels(self):
  """Returns a new array holding a copy of the stored labels."""
  labels_snapshot = np.copy(self.labels)
  return labels_snapshot
103114
104- def update_labels (self , indices_samples , new_labels ):
105- """Updates the labels of the samples with the provided indices.
106-
107- Arguments:
108- indices_samples: A list of integers representing sample indices.
109- new_labels: A list of integers representing the new labels of th samples
110- in indices_samples.
111- """
112- indices_samples = np .asarray (indices_samples )
113- new_labels = np .asarray (new_labels )
114- self .labels [indices_samples ] = new_labels
115-
116115 def get_features (self , indices ):
117116 """Returns the features of the samples with the provided indices."""
118117 f = self .features [indices ]
@@ -122,23 +121,6 @@ def get_features(self, indices):
def get_labels(self, indices):
  """Returns the stored labels of the samples with the provided indices."""
  selected_labels = self.labels[indices]
  return selected_labels
124123
125- @property
126- def num_features (self ):
127- return self ._num_features
128-
129- @property
130- def features_shape (self ):
131- """Returns the shape of the input features, not including batch size."""
132- return self .features .shape [1 :]
133-
134- @property
135- def num_classes (self ):
136- return self ._num_classes
137-
138- @property
139- def num_samples (self ):
140- return self ._num_samples
141-
def num_train(self):
  """Returns the number of entries in `indices_train`."""
  train_count = self.indices_train.shape[0]
  return train_count
144126
@@ -181,13 +163,134 @@ def label_samples(self, indices_samples, new_labels):
181163 # indices, without checking if they already exist.
182164 self .indices_train = np .concatenate ((self .indices_train , indices_samples ),
183165 axis = 0 )
184- # Remove the recently labeled samples from the unlabeled set.
166+ # Remove the recently labeled samples from the unlabeled set.
185167 indices_samples = set (indices_samples )
186168 self .indices_unlabeled = np .asarray (
187169 [u for u in self .indices_unlabeled if u not in indices_samples ])
188170
171+ def update_labels (self , indices , new_labels ):
172+ """Updates the labels of the samples with the provided indices.
173+
174+ Arguments:
175+ indices: A list of integers representing sample indices.
176+ new_labels: A list of integers representing the new labels of th samples
177+ in indices_samples.
178+ """
179+ indices = np .asarray (indices )
180+ new_labels = np .asarray (new_labels )
181+ self .labels [indices ] = new_labels
182+
def save_to_pickle(self, file_path):
  """Serializes this object to a file using pickle.

  Arguments:
    file_path: Path of the file to write. An existing file is overwritten.
  """
  # Pickle produces bytes, so the file must be opened in binary mode. The
  # context manager guarantees the handle is flushed and closed.
  with open(file_path, 'wb') as pickle_file:
    pickle.dump(self, pickle_file)
189185
190- class CotrainDataset (Dataset ):
186+ @staticmethod
187+ def load_from_pickle (file_path ):
188+ dataset = pickle .load (open (file_path , 'r' ))
189+ return dataset
190+
191+
class GraphDataset(Dataset):
  """Data container for SSL datasets that also stores graph edges."""

  class Edge(object):
    """An edge from node `src` to node `tgt` with an optional weight."""

    def __init__(self, src, tgt, weight=None):
      self.src, self.tgt, self.weight = src, tgt, weight

  def __init__(self, name, features, labels, edges, indices_train, indices_test,
               indices_val=None, indices_unlabeled=None, percent_val=0.2,
               seed=None, num_classes=None, feature_preproc_fn=lambda x: x):
    """Stores the edges and initializes the underlying Dataset.

    Arguments:
      name: Dataset name, forwarded to the parent constructor.
      features: Feature array, forwarded to the parent constructor.
      labels: Label array, forwarded to the parent constructor.
      edges: Iterable of edge objects exposing `src` and `tgt` attributes.
      indices_train: Array of training sample indices.
      indices_test: Array of test sample indices.
      indices_val: Optional array of validation sample indices. When None, a
        validation split is created with `split_train_val`.
      indices_unlabeled: Optional array of unlabeled sample indices,
        forwarded unchanged to the parent constructor.
      percent_val: Forwarded to `split_train_val` when a validation split has
        to be created.
      seed: Seed of the random number generator used for the split.
      num_classes: Forwarded to the parent constructor.
      feature_preproc_fn: Forwarded to the parent constructor.
    """
    self.edges = edges

    if indices_val is None:
      rng = np.random.RandomState(seed=seed)
      # NOTE(review): the split runs over positions 0..len(indices_train)-1,
      # not over the values stored in `indices_train` -- confirm against
      # split_train_val and the callers.
      train_positions = np.arange(indices_train.shape[0])
      indices_train, indices_val = split_train_val(
          train_positions, percent_val, rng)

    parent_kwargs = dict(
        name=name,
        features=features,
        labels=labels,
        indices_train=indices_train,
        indices_test=indices_test,
        indices_val=indices_val,
        indices_unlabeled=indices_unlabeled,
        num_classes=num_classes,
        feature_preproc_fn=feature_preproc_fn)
    super().__init__(**parent_kwargs)

  def get_edges(self, src_labeled=None, tgt_labeled=None,
                label_must_match=False):
    """Returns the edges that satisfy the requested endpoint filters.

    A node counts as labeled when its index is among the indices returned by
    `get_indices_train()`.

    Arguments:
      src_labeled: None to accept any source node; otherwise a boolean that
        the labeled status of the source node must equal.
      tgt_labeled: None to accept any target node; otherwise a boolean that
        the labeled status of the target node must equal.
      label_must_match: If True, only edges whose two endpoints have equal
        stored labels are kept.

    Returns:
      A list with the selected edges, in the order they appear in
      `self.edges`.
    """
    is_labeled = np.full((self.num_samples,), False)
    is_labeled[self.get_indices_train()] = True

    def _endpoint_ok(node, required_status):
      # A None requirement places no constraint on the endpoint.
      if required_status is None:
        return True
      return required_status == is_labeled[node]

    selected_edges = []
    for edge in self.edges:
      if not _endpoint_ok(edge.src, src_labeled):
        continue
      if not _endpoint_ok(edge.tgt, tgt_labeled):
        continue
      if label_must_match and not (
          self.get_labels(edge.src) == self.get_labels(edge.tgt)):
        continue
      selected_edges.append(edge)
    return selected_edges
238+
239+
class PlanetoidDataset(GraphDataset):
  """Data container for Planetoid datasets."""

  def __init__(self, name, adj, features, train_mask, val_mask, test_mask,
               labels, row_normalize=False):
    """Converts Planetoid-style inputs into the GraphDataset format.

    Arguments:
      name: Dataset name, forwarded to the parent constructor.
      adj: Adjacency matrix. It is converted with `scipy.sparse.coo_matrix`
        and every stored entry becomes one edge from its row index to its
        column index, weighted by the stored value.
      features: Feature matrix exposing `.todense()` and `.sum(1)`
        (presumably a scipy sparse matrix -- TODO confirm against callers).
      train_mask: Boolean array marking the training samples.
      val_mask: Boolean array marking the validation samples.
      test_mask: Boolean array marking the test samples.
      labels: Array whose argmax over the last axis gives the class of each
        sample (presumably one-hot encoded -- TODO confirm).
      row_normalize: If True, the features are row-normalized with
        `preprocess_features`; otherwise they are only converted to a dense
        matrix. In both cases they are cast to float32.
    """
    # Imported here because `import scipy` alone does not guarantee that the
    # `scipy.sparse` subpackage has been loaded.
    import scipy.sparse

    # Extract train, val, test, unlabeled indices. Samples that belong to
    # none of the three masks are treated as unlabeled.
    train_indices = np.where(train_mask)[0]
    test_indices = np.where(test_mask)[0]
    val_indices = np.where(val_mask)[0]
    unlabeled_mask = np.logical_not(train_mask | test_mask | val_mask)
    unlabeled_indices = np.where(unlabeled_mask)[0]

    # Extract node features.
    if row_normalize:
      features = self.preprocess_features(features)
    else:
      features = features.todense()
    features = np.float32(features)

    # Extract labels. The vectorized max avoids a Python-level loop over
    # every label.
    labels = np.argmax(labels, axis=-1)
    num_classes = labels.max() + 1

    # Extract edges.
    adj = scipy.sparse.coo_matrix(adj)
    edges = [self.Edge(src, tgt, val)
             for src, tgt, val in zip(adj.row, adj.col, adj.data)]

    # Convert to Dataset format.
    super().__init__(
        name=name,
        features=features,
        labels=labels,
        edges=edges,
        indices_train=train_indices,
        indices_test=test_indices,
        indices_val=val_indices,
        indices_unlabeled=unlabeled_indices,
        num_classes=num_classes,
        feature_preproc_fn=lambda x: x)

  @staticmethod
  def preprocess_features(features):
    """Row-normalizes a feature matrix.

    Every row is scaled by the inverse of its sum. Rows whose inverse sum is
    infinite (e.g. all-zero rows) are left unscaled at zero.

    Arguments:
      features: Feature matrix exposing `.sum(1)` and accepted by the `dot`
        of a scipy sparse matrix, with a result exposing `.todense()`
        (presumably a scipy sparse matrix -- TODO confirm against callers).

    Returns:
      The dense row-normalized feature matrix.
    """
    # Imported here because `import scipy` alone does not guarantee that the
    # `scipy.sparse` subpackage has been loaded.
    import scipy.sparse

    rowsum = np.array(features.sum(1))
    if not np.issubdtype(rowsum.dtype, np.floating):
      # NumPy refuses to raise integers to a negative power, so integer
      # valued features have to be promoted to floats first.
      rowsum = rowsum.astype(np.float64)
    # Division by zero is expected for empty rows and handled right below,
    # so silence the corresponding runtime warning.
    with np.errstate(divide='ignore'):
      r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = scipy.sparse.diags(r_inv)
    features = r_mat_inv.dot(features)
    return features.todense()
291+
292+
293+ class CotrainDataset (object ):
191294 """A wrapper around a Dataset object, adding co-training functionality.
192295
193296 Attributes:
@@ -430,3 +533,10 @@ def restore_state_from_file(self, path):
430533
def copy_labels(self):
  """Returns the result of `copy_labels()` on the wrapped dataset."""
  wrapped_dataset = self.dataset
  return wrapped_dataset.copy_labels()
536+
def get_edges(self, src_labeled=None, tgt_labeled=None,
              label_must_match=False):
  """Delegates edge selection to the wrapped dataset's `get_edges`.

  Arguments:
    src_labeled: Forwarded unchanged as the `src_labeled` keyword.
    tgt_labeled: Forwarded unchanged as the `tgt_labeled` keyword.
    label_must_match: Forwarded unchanged as the `label_must_match` keyword.

  Returns:
    Whatever the wrapped dataset's `get_edges` returns.
  """
  edge_filters = dict(
      src_labeled=src_labeled,
      tgt_labeled=tgt_labeled,
      label_must_match=label_must_match)
  return self.dataset.get_edges(**edge_filters)
0 commit comments