
Commit 63bd0e6

Add functional creation of graph (#85)
* Add functional creation of graph
* Do not test half dtype for similarities
1 parent 380b097 commit 63bd0e6

6 files changed: +143 additions, −20 deletions

docs/functional.rst
torchhd/functional.py
torchhd/structures.py
torchhd/tests/structures/test_graph.py
torchhd/tests/test_encodings.py
torchhd/tests/test_similarities.py

docs/functional.rst

Lines changed: 1 addition & 0 deletions
@@ -63,6 +63,7 @@ Encodings
     hash_table
     cross_product
     ngrams
+    graph


 Utilities

torchhd/functional.py

Lines changed: 51 additions & 12 deletions
@@ -27,6 +27,7 @@
     "bind_sequence",
     "ngrams",
     "hash_table",
+    "graph",
     "map_range",
     "value_to_index",
     "index_to_value",
@@ -41,7 +42,7 @@ def identity_hv(
     device=None,
     requires_grad=False,
 ) -> Tensor:
-    """Creates a set of identity hypervector.
+    """Creates a set of identity hypervectors.

     When bound with a random-hypervector :math:`x`, the result is :math:`x`.
@@ -174,24 +175,20 @@ def random_hv(
     if dtype == torch.uint8:
         raise ValueError("Unsigned integer hypervectors are not supported.")

+    size = (num_embeddings, embedding_dim)
     if dtype in {torch.complex64, torch.complex128}:
         dtype = torch.float if dtype == torch.complex64 else torch.double

-        angle = torch.empty(num_embeddings, embedding_dim, dtype=dtype, device=device)
+        angle = torch.empty(size, dtype=dtype, device=device)
         angle.uniform_(-math.pi, math.pi)
-        magnitude = torch.ones(
-            num_embeddings, embedding_dim, dtype=dtype, device=device
-        )
+        magnitude = torch.ones(size, dtype=dtype, device=device)

         result = torch.polar(magnitude, angle)
         result.requires_grad = requires_grad
         return result

     select = torch.empty(
-        (
-            num_embeddings,
-            embedding_dim,
-        ),
+        size,
         dtype=torch.bool,
     ).bernoulli_(1.0 - sparsity, generator=generator)
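The complex branch refactored above builds each random hypervector as a unit-magnitude phasor: a phase drawn uniformly from [-pi, pi) combined with a magnitude of one via torch.polar. A minimal standalone sketch of that idea in plain PyTorch (illustrative names, not the torchhd API):

import math
import torch

# Unit-magnitude phasors with uniformly random phase (illustrative sketch).
num_embeddings, embedding_dim = 4, 10
size = (num_embeddings, embedding_dim)

angle = torch.empty(size, dtype=torch.float).uniform_(-math.pi, math.pi)
magnitude = torch.ones(size, dtype=torch.float)

hv = torch.polar(magnitude, angle)  # complex64 tensor of shape (4, 10)
assert hv.dtype == torch.complex64
assert torch.allclose(hv.abs(), torch.ones(size))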

@@ -1031,7 +1028,7 @@ def hash_table(keys: Tensor, values: Tensor) -> Tensor:

     .. math::

-        \bigoplus_{i = 0}^{m - 1} K_i \otimes V_i
+        \bigoplus_{i = 0}^{n - 1} K_i \otimes V_i

     Args:
         keys (Tensor): The keys hypervectors, must be the same shape as values.
@@ -1066,7 +1063,7 @@ def bundle_sequence(input: Tensor) -> Tensor:

     .. math::

-        \bigoplus_{i=0}^{m-1} \Pi^{m - i - 1}(V_i)
+        \bigoplus_{i=0}^{n-1} \Pi^{n - i - 1}(V_i)

     Args:
         input (Tensor): The hypervector values.
@@ -1105,7 +1102,7 @@ def bind_sequence(input: Tensor) -> Tensor:

     .. math::

-        \bigotimes_{i=0}^{m-1} \Pi^{m - i - 1}(V_i)
+        \bigotimes_{i=0}^{n-1} \Pi^{n - i - 1}(V_i)

     Args:
         input (Tensor): The hypervector values.
@@ -1141,6 +1138,48 @@ def bind_sequence(input: Tensor) -> Tensor:
     return multibind(permuted)


+def graph(input: Tensor, *, directed=False) -> Tensor:
+    r"""Graph from node hypervector pairs.
+
+    If ``directed=False`` this computes:
+
+    .. math::
+
+        \bigoplus_{i = 0}^{n - 1} V_{0,i} \otimes V_{1,i}
+
+    If ``directed=True`` this computes:
+
+    .. math::
+
+        \bigoplus_{i = 0}^{n - 1} V_{0,i} \otimes \Pi(V_{1,i})
+
+    Args:
+        input (Tensor): tensor containing pairs of node hypervectors that share an edge.
+        directed (bool, optional): specify if the graph is directed or not. Default: ``False``.
+
+    Shapes:
+        - Input: :math:`(*, 2, n, d)`
+        - Output: :math:`(*, d)`
+
+    Examples::
+
+        >>> edges = torch.tensor([[0, 0, 1, 2], [1, 2, 2, 3]])
+        >>> node_embedding = embeddings.Random(4, 10000)
+        >>> edges_hv = node_embedding(edges)
+        >>> graph = functional.graph(edges_hv)
+        >>> neighbors = unbind(graph, node_embedding.weight[0])
+        >>> cosine_similarity(neighbors, node_embedding.weight)
+        tensor([0.0006, 0.5017, 0.4997, 0.0048])
+
+    """
+    to_nodes = input[..., 0, :, :]
+    from_nodes = input[..., 1, :, :]
+
+    if directed:
+        from_nodes = permute(from_nodes)
+
+    return multiset(bind(to_nodes, from_nodes))
+
+
 def map_range(
     input: Tensor,
     in_min: float,
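For MAP (bipolar) hypervectors the new encoding reduces to an elementwise multiply per edge followed by a sum over the edge dimension. The sketch below mirrors functional.graph in plain PyTorch under that assumption (and assumes permute is a cyclic shift by one position), then repeats the docstring's neighbor-recovery example with ad-hoc random ±1 vectors instead of embeddings.Random, so the exact similarity values will differ:

import torch

# Sketch of functional.graph for bipolar hypervectors, assuming:
# bind = elementwise multiply, multiset = sum over edges,
# permute = cyclic shift by one position (torch.roll).
def graph_sketch(edges_hv: torch.Tensor, directed: bool = False) -> torch.Tensor:
    to_nodes = edges_hv[..., 0, :, :]    # shape (*, n, d)
    from_nodes = edges_hv[..., 1, :, :]  # shape (*, n, d)
    if directed:
        from_nodes = torch.roll(from_nodes, shifts=1, dims=-1)
    return (to_nodes * from_nodes).sum(dim=-2)  # shape (*, d)

# Four random bipolar node hypervectors and the edges 0-1, 0-2, 1-2, 2-3
# from the docstring example above.
d = 10000
nodes = torch.randint(0, 2, (4, d)).float() * 2 - 1
edges = torch.tensor([[0, 0, 1, 2], [1, 2, 2, 3]])
edges_hv = nodes[edges]  # shape (2, 4, d)

g = graph_sketch(edges_hv)

# Unbinding node 0 (multiplying by it again) leaves a noisy bundle of its
# neighbors, so nodes 1 and 2 score highest.
neighbors = g * nodes[0]
print(torch.nn.functional.cosine_similarity(neighbors.unsqueeze(0), nodes))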

torchhd/structures.py

Lines changed: 20 additions & 0 deletions
@@ -911,6 +911,26 @@ def clear(self) -> None:
         """
         self.value.fill_(0.0)

+    @classmethod
+    def from_edges(cls, input: Tensor, directed=False):
+        """Creates a graph from a tensor of node hypervector pairs.
+
+        See: :func:`~torchhd.functional.graph`.
+
+        Args:
+            input (Tensor): tensor containing pairs of node hypervectors that share an edge.
+            directed (bool, optional): specify if the graph is directed or not. Default: ``False``.
+
+        Examples::
+
+            >>> edges = torch.tensor([[0, 0, 1, 2], [1, 2, 2, 3]])
+            >>> node_embedding = embeddings.Random(4, 10000)
+            >>> edges_hv = node_embedding(edges)
+            >>> graph = structures.Graph.from_edges(edges_hv)
+
+        """
+        value = functional.graph(input, directed=directed)
+        return cls(value, directed=directed)
+

 class Tree:
     """Hypervector-based tree data structure.

torchhd/tests/structures/test_graph.py

Lines changed: 22 additions & 0 deletions
@@ -179,3 +179,25 @@ def test_clear(self):
         assert torch.equal(
             G.value, torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
         )
+
+    def test_from_edges(self):
+        generator = torch.Generator()
+        generator.manual_seed(seed)
+
+        hv = functional.random_hv(4, 8, generator=generator)
+        edges = torch.empty(2, 3, 8)
+        edges[0, 0] = hv[0]
+        edges[1, 0] = hv[1]
+        edges[0, 1] = hv[0]
+        edges[1, 1] = hv[2]
+        edges[0, 2] = hv[1]
+        edges[1, 2] = hv[2]
+
+        G = structures.Graph.from_edges(edges)
+        neighbors = G.node_neighbors(hv[0])
+        neighbor_similarity = functional.cosine_similarity(neighbors, hv)
+
+        assert neighbor_similarity[0] < torch.tensor(0.5)
+        assert neighbor_similarity[1] > torch.tensor(0.5)
+        assert neighbor_similarity[2] > torch.tensor(0.5)
+        assert neighbor_similarity[3] < torch.tensor(0.5)

torchhd/tests/test_encodings.py

Lines changed: 41 additions & 0 deletions
@@ -320,3 +320,44 @@ def test_device(self):
         hv = functional.random_hv(11, 10000, device=device)
         res = functional.bind_sequence(hv)
         assert res.device == device
+
+
+class TestGraph:
+    def test_value(self):
+        hv = torch.zeros(2, 4, 1000)
+        res = functional.graph(hv)
+        assert torch.all(res == 0).item()
+
+        g = torch.tensor(
+            [
+                [[1, -1, -1, 1], [-1, -1, 1, 1], [-1, 1, 1, 1]],
+                [[-1, -1, 1, 1], [-1, 1, 1, 1], [1, -1, -1, 1]],
+            ]
+        )
+        res = functional.graph(g)
+        assert torch.all(res == torch.tensor([-1, -1, -1, 3])).item()
+        assert res.dtype == g.dtype
+
+        res = functional.graph(g, directed=True)
+        assert torch.all(res == torch.tensor([-1, 3, 1, 1])).item()
+        assert res.dtype == g.dtype
+
+    @pytest.mark.parametrize("dtype", torch_dtypes)
+    def test_dtype(self, dtype):
+        hv = torch.zeros(5, 2, 23, 1000, dtype=dtype)
+
+        if dtype == torch.uint8:
+            with pytest.raises(ValueError):
+                functional.graph(hv)
+
+            return
+
+        res = functional.graph(hv)
+        assert res.dtype == dtype
+
+    def test_device(self):
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        hv = torch.zeros(5, 2, 23, 1000, device=device)
+        res = functional.graph(hv)
+        assert res.device == device
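The expected tensors in test_value can be checked by hand: for bipolar hypervectors each edge contributes the elementwise product of its two node hypervectors, and the directed variant first shifts the destination hypervectors by one position (assuming permute is a cyclic shift). A short verification:

import torch

to_nodes = torch.tensor([[1, -1, -1, 1], [-1, -1, 1, 1], [-1, 1, 1, 1]])
from_nodes = torch.tensor([[-1, -1, 1, 1], [-1, 1, 1, 1], [1, -1, -1, 1]])

# Undirected: sum the elementwise products over the three edges.
undirected = (to_nodes * from_nodes).sum(dim=0)
print(undirected)  # tensor([-1, -1, -1,  3])

# Directed: cyclically shift the destination hypervectors by one, then bind.
shifted = torch.roll(from_nodes, shifts=1, dims=-1)
directed = (to_nodes * shifted).sum(dim=0)
print(directed)    # tensor([-1,  3,  1,  1])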

torchhd/tests/test_similarities.py

Lines changed: 8 additions & 8 deletions
@@ -15,7 +15,7 @@
 class TestDotSimilarity:
     @pytest.mark.parametrize("dtype", torch_dtypes)
     def test_shape(self, dtype):
-        if not supported_dtype(dtype):
+        if not supported_dtype(dtype) or dtype == torch.half:
             return

         generator = torch.Generator()

@@ -43,7 +43,7 @@ def test_shape(self, dtype):

     @pytest.mark.parametrize("dtype", torch_dtypes)
     def test_value(self, dtype):
-        if not supported_dtype(dtype):
+        if not supported_dtype(dtype) or dtype == torch.half:
             return

         generator = torch.Generator()

@@ -113,7 +113,7 @@ def test_value(self, dtype):

     @pytest.mark.parametrize("dtype", torch_dtypes)
     def test_dtype(self, dtype):
-        if not supported_dtype(dtype):
+        if not supported_dtype(dtype) or dtype == torch.half:
             return

         generator = torch.Generator()

@@ -134,7 +134,7 @@ def test_dtype(self, dtype):

     @pytest.mark.parametrize("dtype", torch_dtypes)
     def test_device(self, dtype):
-        if not supported_dtype(dtype):
+        if not supported_dtype(dtype) or dtype == torch.half:
             return

         generator = torch.Generator()

@@ -153,7 +153,7 @@ def test_device(self, dtype):
 class TestCosSimilarity:
     @pytest.mark.parametrize("dtype", torch_dtypes)
     def test_shape(self, dtype):
-        if not supported_dtype(dtype):
+        if not supported_dtype(dtype) or dtype == torch.half:
             return

         generator = torch.Generator()

@@ -181,7 +181,7 @@ def test_shape(self, dtype):

     @pytest.mark.parametrize("dtype", torch_dtypes)
     def test_value(self, dtype):
-        if not supported_dtype(dtype):
+        if not supported_dtype(dtype) or dtype == torch.half:
             return

         generator = torch.Generator()

@@ -250,7 +250,7 @@ def test_value(self, dtype):

     @pytest.mark.parametrize("dtype", torch_dtypes)
     def test_dtype(self, dtype):
-        if not supported_dtype(dtype):
+        if not supported_dtype(dtype) or dtype == torch.half:
             return

         generator = torch.Generator()

@@ -264,7 +264,7 @@ def test_dtype(self, dtype):

     @pytest.mark.parametrize("dtype", torch_dtypes)
     def test_device(self, dtype):
-        if not supported_dtype(dtype):
+        if not supported_dtype(dtype) or dtype == torch.half:
             return

         generator = torch.Generator()
