Skip to content

Commit ae38ec8

Browse files
Add test for create random permute function (#131)
* Add tests for create random permute function * [github-action] formatting fixes * Send permute function to device * Make sure device type matches * Pass device parameters * [github-action] formatting fixes --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent c823cc1 commit ae38ec8

20 files changed

+438
-167
lines changed

torchhd/functional.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -716,7 +716,10 @@ def permute(input: VSATensor, *, shifts=1) -> VSATensor:
716716
return input.permute(shifts)
717717

718718

719-
def create_random_permute(dim: int) -> Callable[[VSATensor, int], VSATensor]:
719+
class create_random_permute(torch.nn.Module):
720+
forward_indices: LongTensor
721+
backward_indices: LongTensor
722+
720723
r"""Creates random permutation functions.
721724
722725
Args:
@@ -741,21 +744,28 @@ def create_random_permute(dim: int) -> Callable[[VSATensor, int], VSATensor]:
741744
742745
"""
743746

744-
forward = torch.randperm(dim)
745-
backward = torch.empty_like(forward)
746-
backward[forward] = torch.arange(dim)
747+
def __init__(self, dim: int) -> None:
748+
super().__init__()
749+
750+
forward = torch.randperm(dim)
751+
backward = torch.empty_like(forward)
752+
backward[forward] = torch.arange(dim)
753+
754+
self.register_buffer("forward_indices", forward)
755+
self.register_buffer("backward_indices", backward)
747756

748-
def permute(input: VSATensor, shifts: int = 1) -> VSATensor:
757+
def __call__(self, input: VSATensor, shifts: int = 1) -> VSATensor:
749758
y = input
759+
750760
if shifts > 0:
751-
for _ in range(shifts):
752-
y = y[..., forward]
761+
for _ in range(abs(shifts)):
762+
y = y[..., self.forward_indices]
763+
753764
elif shifts < 0:
754-
for _ in range(shifts):
755-
y = y[..., backward]
756-
return y
765+
for _ in range(abs(shifts)):
766+
y = y[..., self.backward_indices]
757767

758-
return permute
768+
return y.clone()
759769

760770

761771
def inverse(input: VSATensor) -> VSATensor:

torchhd/structures.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -944,7 +944,7 @@ def node_neighbors(self, input: VSATensor, outgoing=True) -> VSATensor:
944944
return functional.bind(self.value, input.inverse())
945945

946946
def contains(self, input: VSATensor) -> Tensor:
947-
"""Returns the cosine similarity of the input vector against the graph.
947+
"""Returns the normalized dot similarity of the input vector against the graph.
948948
949949
Args:
950950
input (Tensor): Hypervector to compare against the multiset.
@@ -955,7 +955,7 @@ def contains(self, input: VSATensor) -> Tensor:
955955
>>> G.contains(e)
956956
tensor(1.)
957957
"""
958-
return functional.cosine_similarity(input, self.value)
958+
return functional.dot_similarity(input, self.value) / self.value.size(-1)
959959

960960
def clear(self) -> None:
961961
"""Empties the graph.

torchhd/tests/basis_hv/test_circular_hv.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def test_device(self, dtype):
135135

136136
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
137137
hv = functional.circular(3, 52, device=device, dtype=dtype)
138-
assert hv.device == device
138+
assert hv.device.type == device.type
139139

140140
@pytest.mark.parametrize("dtype", torch_dtypes)
141141
@pytest.mark.parametrize("vsa", vsa_tensors)
@@ -148,7 +148,7 @@ def test_device(self, dtype, vsa):
148148

149149
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
150150
hv = functional.circular(3, 52, vsa, device=device, dtype=dtype)
151-
assert hv.device == device
151+
assert hv.device.type == device.type
152152

153153
def test_uses_default_dtype(self):
154154
hv = functional.circular(3, 52, "BSC")

torchhd/tests/basis_hv/test_empty_hv.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def test_device(self, dtype, vsa):
7777

7878
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7979
hv = functional.empty(3, 52, vsa, device=device, dtype=dtype)
80-
assert hv.device == device
80+
assert hv.device.type == device.type
8181

8282
def test_uses_default_dtype(self):
8383
hv = functional.empty(3, 52, "BSC")

torchhd/tests/basis_hv/test_identity_hv.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def test_device(self, dtype, vsa):
8282

8383
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8484
hv = functional.identity(3, 52, vsa, device=device, dtype=dtype)
85-
assert hv.device == device
85+
assert hv.device.type == device.type
8686

8787
def test_uses_default_dtype(self):
8888
hv = functional.identity(3, 52, "BSC")

torchhd/tests/basis_hv/test_level_hv.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
from ..utils import torch_dtypes, supported_dtype, vsa_tensors
3131

32-
seed = 2147483644
32+
seed = 2147483643
3333

3434

3535
class Testlevel:
@@ -129,7 +129,7 @@ def test_device(self, dtype, vsa):
129129

130130
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
131131
hv = functional.level(3, 52, vsa, device=device, dtype=dtype)
132-
assert hv.device == device
132+
assert hv.device.type == device.type
133133

134134
def test_uses_default_dtype(self):
135135
hv = functional.level(3, 52, "BSC")

torchhd/tests/basis_hv/test_random_hv.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def test_device(self, dtype, vsa):
137137

138138
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
139139
hv = functional.random(3, 52, vsa, device=device, dtype=dtype)
140-
assert hv.device == device
140+
assert hv.device.type == device.type
141141

142142
def test_uses_default_dtype(self):
143143
hv = functional.random(3, 52, "BSC")

torchhd/tests/structures/test_distinct_sequence.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,15 @@
3434
class TestBindSequence:
3535
def test_creation_dim(self):
3636
S = structures.BindSequence(10000)
37-
assert torch.equal(S.value, torch.ones(10000))
37+
assert torch.allclose(S.value, torch.ones(10000))
3838

3939
def test_creation_tensor(self):
4040
generator = torch.Generator()
4141
generator.manual_seed(seed)
4242
hv = functional.random(len(letters), 10000, generator=generator)
4343

4444
S = structures.BindSequence(hv[0])
45-
assert torch.equal(S.value, hv[0])
45+
assert torch.allclose(S.value, hv[0])
4646

4747
def test_generator(self):
4848
generator = torch.Generator()

torchhd/tests/structures/test_fsa.py

Lines changed: 68 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import string
2727

2828
from torchhd import structures, functional
29+
from torchhd import MAPTensor
2930

3031
seed = 2147483644
3132
seed1 = 2147483643
@@ -35,7 +36,7 @@
3536
class TestFSA:
3637
def test_creation_dim(self):
3738
F = structures.FiniteStateAutomata(10000)
38-
assert torch.equal(F.value, torch.zeros(10000))
39+
assert torch.allclose(F.value, torch.zeros(10000))
3940

4041
def test_generator(self):
4142
generator = torch.Generator()
@@ -49,37 +50,81 @@ def test_generator(self):
4950
assert (hv1 == hv2).min().item()
5051

5152
def test_add_transition(self):
52-
generator = torch.Generator()
53-
generator1 = torch.Generator()
54-
generator.manual_seed(seed)
55-
generator1.manual_seed(seed1)
56-
tokens = functional.random(10, 10, generator=generator)
57-
actions = functional.random(10, 10, generator=generator1)
53+
tokens = MAPTensor(
54+
[
55+
[1.0, 1.0, -1.0, 1.0, 1.0, 1.0, -1.0, 1.0, -1.0, 1.0],
56+
[1.0, -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0, 1.0],
57+
[-1.0, 1.0, -1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0],
58+
[-1.0, -1.0, 1.0, 1.0, -1.0, -1.0, 1.0, -1.0, 1.0, -1.0],
59+
[1.0, -1.0, 1.0, -1.0, 1.0, -1.0, -1.0, 1.0, -1.0, -1.0],
60+
[1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0, -1.0, 1.0, -1.0],
61+
[1.0, 1.0, 1.0, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0, -1.0],
62+
[-1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0, -1.0, 1.0],
63+
[-1.0, -1.0, 1.0, -1.0, -1.0, -1.0, 1.0, -1.0, 1.0, 1.0],
64+
[-1.0, -1.0, 1.0, -1.0, 1.0, 1.0, 1.0, -1.0, -1.0, 1.0],
65+
]
66+
)
67+
states = MAPTensor(
68+
[
69+
[1.0, 1.0, -1.0, -1.0, 1.0, 1.0, -1.0, -1.0, 1.0, 1.0],
70+
[-1.0, -1.0, -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0],
71+
[1.0, 1.0, -1.0, -1.0, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0],
72+
[1.0, -1.0, 1.0, -1.0, 1.0, -1.0, -1.0, -1.0, 1.0, 1.0],
73+
[1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 1.0],
74+
[-1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0, 1.0, -1.0],
75+
[1.0, -1.0, -1.0, -1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0],
76+
[1.0, -1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0],
77+
[1.0, 1.0, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
78+
[1.0, 1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, 1.0],
79+
]
80+
)
5881

5982
F = structures.FiniteStateAutomata(10)
6083

61-
F.add_transition(tokens[0], actions[1], actions[2])
62-
assert torch.equal(
84+
F.add_transition(tokens[0], states[1], states[2])
85+
assert torch.allclose(
6386
F.value,
64-
torch.tensor([1.0, 1.0, -1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0, 1.0]),
87+
MAPTensor([-1.0, -1.0, 1.0, -1.0, -1.0, 1.0, 1.0, 1.0, -1.0, -1.0]),
6588
)
66-
F.add_transition(tokens[1], actions[1], actions[3])
67-
assert torch.equal(
68-
F.value, torch.tensor([0.0, 0.0, -2.0, 2.0, 0.0, 2.0, 0.0, -2.0, -2.0, 0.0])
89+
F.add_transition(tokens[1], states[1], states[3])
90+
assert torch.allclose(
91+
F.value, MAPTensor([-2.0, 0.0, 2.0, 0.0, -2.0, 2.0, 0.0, 0.0, 0.0, -2.0])
6992
)
70-
F.add_transition(tokens[2], actions[1], actions[3])
71-
assert torch.equal(
93+
F.add_transition(tokens[2], states[1], states[3])
94+
assert torch.allclose(
7295
F.value,
73-
torch.tensor([1.0, 1.0, -3.0, 1.0, 1.0, 3.0, -1.0, -1.0, -1.0, 1.0]),
96+
MAPTensor([-1.0, -1.0, 1.0, -1.0, -3.0, 1.0, -1.0, 1.0, -1.0, -1.0]),
7497
)
7598

7699
def test_transition(self):
77-
generator = torch.Generator()
78-
generator1 = torch.Generator()
79-
generator.manual_seed(seed)
80-
generator1.manual_seed(seed1)
81-
tokens = functional.random(10, 10, generator=generator)
82-
states = functional.random(10, 10, generator=generator1)
100+
tokens = MAPTensor(
101+
[
102+
[1.0, 1.0, -1.0, 1.0, 1.0, 1.0, -1.0, 1.0, -1.0, 1.0],
103+
[1.0, -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0, 1.0],
104+
[-1.0, 1.0, -1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0],
105+
[-1.0, -1.0, 1.0, 1.0, -1.0, -1.0, 1.0, -1.0, 1.0, -1.0],
106+
[1.0, -1.0, 1.0, -1.0, 1.0, -1.0, -1.0, 1.0, -1.0, -1.0],
107+
[1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0, -1.0, 1.0, -1.0],
108+
[1.0, 1.0, 1.0, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0, -1.0],
109+
[-1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0, -1.0, 1.0],
110+
[-1.0, -1.0, 1.0, -1.0, -1.0, -1.0, 1.0, -1.0, 1.0, 1.0],
111+
[-1.0, -1.0, 1.0, -1.0, 1.0, 1.0, 1.0, -1.0, -1.0, 1.0],
112+
]
113+
)
114+
states = MAPTensor(
115+
[
116+
[1.0, 1.0, -1.0, -1.0, 1.0, 1.0, -1.0, -1.0, 1.0, 1.0],
117+
[-1.0, -1.0, -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0],
118+
[1.0, 1.0, -1.0, -1.0, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0],
119+
[1.0, -1.0, 1.0, -1.0, 1.0, -1.0, -1.0, -1.0, 1.0, 1.0],
120+
[1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 1.0],
121+
[-1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0, 1.0, -1.0],
122+
[1.0, -1.0, -1.0, -1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0],
123+
[1.0, -1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0],
124+
[1.0, 1.0, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
125+
[1.0, 1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, 1.0],
126+
]
127+
)
83128

84129
F = structures.FiniteStateAutomata(10)
85130

@@ -121,6 +166,6 @@ def test_clear(self):
121166
F.add_transition(tokens[2], states[1], states[5])
122167

123168
F.clear()
124-
assert torch.equal(
169+
assert torch.allclose(
125170
F.value, torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
126171
)

torchhd/tests/structures/test_graph.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -28,22 +28,22 @@
2828
from torchhd import structures, functional
2929
from torchhd.tensors.map import MAPTensor
3030

31-
seed = 2147483644
31+
seed = 2147483645
3232
letters = list(string.ascii_lowercase)
3333

3434

3535
class TestGraph:
3636
def test_creation_dim(self):
3737
G = structures.Graph(10000, directed=True)
38-
assert torch.equal(G.value, torch.zeros(10000))
38+
assert torch.allclose(G.value, torch.zeros(10000))
3939

4040
def test_creation_tensor(self):
4141
generator = torch.Generator()
4242
generator.manual_seed(seed)
4343
hv = functional.random(len(letters), 10000, generator=generator)
4444
g = functional.bind(hv[0], hv[1])
4545
G = structures.Graph(g)
46-
assert torch.equal(G.value, g)
46+
assert torch.allclose(G.value, g)
4747

4848
def test_generator(self):
4949
generator = torch.Generator()
@@ -68,22 +68,22 @@ def test_add_edge(self):
6868
).as_subclass(MAPTensor)
6969

7070
G.add_edge(hv[0], hv[1])
71-
assert torch.equal(
71+
assert torch.allclose(
7272
G.value, torch.tensor([-1.0, -1.0, -1.0, 1.0, -1.0, -1.0, 1.0, -1.0])
7373
)
7474
G.add_edge(hv[2], hv[3])
75-
assert torch.equal(
75+
assert torch.allclose(
7676
G.value, torch.tensor([-2.0, -2.0, 0.0, 2.0, -2.0, 0.0, 2.0, -2.0])
7777
)
7878

7979
GD = structures.Graph(8, directed=True)
8080

8181
GD.add_edge(hv[0], hv[1])
82-
assert torch.equal(
82+
assert torch.allclose(
8383
GD.value, torch.tensor([-1.0, 1.0, -1.0, -1.0, -1.0, -1.0, 1.0, -1.0])
8484
)
8585
GD.add_edge(hv[2], hv[3])
86-
assert torch.equal(
86+
assert torch.allclose(
8787
GD.value, torch.tensor([0.0, 0.0, 0.0, -2.0, 0.0, -2.0, 2.0, -2.0])
8888
)
8989

@@ -99,23 +99,23 @@ def test_encode_edge(self):
9999
).as_subclass(MAPTensor)
100100

101101
e1 = G.encode_edge(hv[0], hv[1])
102-
assert torch.equal(
102+
assert torch.allclose(
103103
e1, torch.tensor([-1.0, -1.0, -1.0, 1.0, -1.0, -1.0, 1.0, -1.0])
104104
)
105105
e2 = G.encode_edge(hv[2], hv[3])
106-
assert torch.equal(
106+
assert torch.allclose(
107107
e2, torch.tensor([-1.0, -1.0, 1.0, 1.0, -1.0, 1.0, 1.0, -1.0])
108108
)
109109

110110
GD = structures.Graph(8, directed=True)
111111

112112
e1 = GD.encode_edge(hv[0], hv[1])
113-
assert torch.equal(
113+
assert torch.allclose(
114114
e1, torch.tensor([-1.0, 1.0, -1.0, -1.0, -1.0, -1.0, 1.0, -1.0])
115115
)
116116
e2 = GD.encode_edge(hv[2], hv[3])
117117
print(e2)
118-
assert torch.equal(
118+
assert torch.allclose(
119119
e2, torch.tensor([1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0])
120120
)
121121

@@ -157,8 +157,8 @@ def test_node_neighbors(self):
157157
def test_contains(self):
158158
generator = torch.Generator()
159159
generator.manual_seed(seed)
160-
hv = functional.random(4, 8, generator=generator)
161-
G = structures.Graph(8)
160+
hv = functional.random(4, 1000, generator=generator)
161+
G = structures.Graph(1000)
162162

163163
e1 = G.encode_edge(hv[0], hv[1])
164164
e2 = G.encode_edge(hv[0], hv[2])
@@ -172,7 +172,7 @@ def test_contains(self):
172172
assert G.contains(e2) > torch.tensor([0.6])
173173
assert G.contains(e3) < torch.tensor(0.6)
174174

175-
GD = structures.Graph(8, directed=True)
175+
GD = structures.Graph(1000, directed=True)
176176

177177
ee1 = GD.encode_edge(hv[0], hv[1])
178178
ee2 = GD.encode_edge(hv[0], hv[2])
@@ -200,16 +200,16 @@ def test_clear(self):
200200

201201
G.clear()
202202

203-
assert torch.equal(
203+
assert torch.allclose(
204204
G.value, torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
205205
)
206206

207207
def test_from_edges(self):
208208
generator = torch.Generator()
209209
generator.manual_seed(seed)
210210

211-
hv = functional.random(4, 8, generator=generator)
212-
edges = torch.empty(2, 3, 8).as_subclass(MAPTensor)
211+
hv = functional.random(4, 1000, generator=generator)
212+
edges = torch.empty(2, 3, 1000).as_subclass(MAPTensor)
213213
edges[0, 0] = hv[0]
214214
edges[1, 0] = hv[1]
215215
edges[0, 1] = hv[0]

0 commit comments

Comments
 (0)