Skip to content

Commit ddab3ee

Browse files
committed
Fixing issue with no graph edges available between two labeled nodes.
1 parent 1c46ee1 commit ddab3ee

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

neural_structured_learning/research/gam/trainer/trainer_classification.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,14 @@ def edge_iterator(self, data, batch_size, labeling):
615615
else:
616616
raise ValueError('Unsupported value for parameter `labeling`.')
617617

618+
if len(edges) == 0:
619+
indices = np.zeros(shape=(0,), dtype=np.int32)
620+
features = np.zeros(shape=[0,] + list(data.features_shape),
621+
dtype=np.float32)
622+
labels = np.zeros(shape=(0,), dtype=np.int64)
623+
while True:
624+
yield (indices, indices, features, features, labels, labels)
625+
618626
edges = np.stack([(e.src, e.tgt) for e in edges])
619627
iterator = batch_iterator(
620628
inputs=edges,

0 commit comments

Comments
 (0)