2626import string
2727
2828from torchhd import structures , functional
29+ from torchhd import MAPTensor
2930
3031seed = 2147483644
3132seed1 = 2147483643
3536class TestFSA :
3637 def test_creation_dim (self ):
3738 F = structures .FiniteStateAutomata (10000 )
38- assert torch .equal (F .value , torch .zeros (10000 ))
39+ assert torch .allclose (F .value , torch .zeros (10000 ))
3940
4041 def test_generator (self ):
4142 generator = torch .Generator ()
@@ -49,37 +50,81 @@ def test_generator(self):
4950 assert (hv1 == hv2 ).min ().item ()
5051
5152 def test_add_transition (self ):
52- generator = torch .Generator ()
53- generator1 = torch .Generator ()
54- generator .manual_seed (seed )
55- generator1 .manual_seed (seed1 )
56- tokens = functional .random (10 , 10 , generator = generator )
57- actions = functional .random (10 , 10 , generator = generator1 )
53+ tokens = MAPTensor (
54+ [
55+ [1.0 , 1.0 , - 1.0 , 1.0 , 1.0 , 1.0 , - 1.0 , 1.0 , - 1.0 , 1.0 ],
56+ [1.0 , - 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , - 1.0 , 1.0 ],
57+ [- 1.0 , 1.0 , - 1.0 , - 1.0 , 1.0 , - 1.0 , 1.0 , - 1.0 , 1.0 , - 1.0 ],
58+ [- 1.0 , - 1.0 , 1.0 , 1.0 , - 1.0 , - 1.0 , 1.0 , - 1.0 , 1.0 , - 1.0 ],
59+ [1.0 , - 1.0 , 1.0 , - 1.0 , 1.0 , - 1.0 , - 1.0 , 1.0 , - 1.0 , - 1.0 ],
60+ [1.0 , 1.0 , 1.0 , 1.0 , - 1.0 , - 1.0 , - 1.0 , - 1.0 , 1.0 , - 1.0 ],
61+ [1.0 , 1.0 , 1.0 , 1.0 , - 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , - 1.0 ],
62+ [- 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , - 1.0 , - 1.0 , 1.0 ],
63+ [- 1.0 , - 1.0 , 1.0 , - 1.0 , - 1.0 , - 1.0 , 1.0 , - 1.0 , 1.0 , 1.0 ],
64+ [- 1.0 , - 1.0 , 1.0 , - 1.0 , 1.0 , 1.0 , 1.0 , - 1.0 , - 1.0 , 1.0 ],
65+ ]
66+ )
67+ states = MAPTensor (
68+ [
69+ [1.0 , 1.0 , - 1.0 , - 1.0 , 1.0 , 1.0 , - 1.0 , - 1.0 , 1.0 , 1.0 ],
70+ [- 1.0 , - 1.0 , - 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , - 1.0 ],
71+ [1.0 , 1.0 , - 1.0 , - 1.0 , 1.0 , - 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ],
72+ [1.0 , - 1.0 , 1.0 , - 1.0 , 1.0 , - 1.0 , - 1.0 , - 1.0 , 1.0 , 1.0 ],
73+ [1.0 , 1.0 , 1.0 , 1.0 , - 1.0 , - 1.0 , - 1.0 , - 1.0 , - 1.0 , 1.0 ],
74+ [- 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , - 1.0 , 1.0 , - 1.0 ],
75+ [1.0 , - 1.0 , - 1.0 , - 1.0 , 1.0 , 1.0 , 1.0 , - 1.0 , - 1.0 , - 1.0 ],
76+ [1.0 , - 1.0 , 1.0 , 1.0 , 1.0 , - 1.0 , - 1.0 , - 1.0 , - 1.0 , - 1.0 ],
77+ [1.0 , 1.0 , 1.0 , - 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ],
78+ [1.0 , 1.0 , 1.0 , - 1.0 , 1.0 , - 1.0 , 1.0 , - 1.0 , 1.0 , 1.0 ],
79+ ]
80+ )
5881
5982 F = structures .FiniteStateAutomata (10 )
6083
61- F .add_transition (tokens [0 ], actions [1 ], actions [2 ])
62- assert torch .equal (
84+ F .add_transition (tokens [0 ], states [1 ], states [2 ])
85+ assert torch .allclose (
6386 F .value ,
64- torch . tensor ([ 1.0 , 1.0 , - 1.0 , 1.0 , 1.0 , 1.0 , - 1.0 , - 1.0 , - 1.0 , 1.0 ]),
87+ MAPTensor ([ - 1.0 , - 1.0 , 1.0 , - 1.0 , - 1.0 , 1.0 , 1.0 , 1.0 , - 1.0 , - 1.0 ]),
6588 )
66- F .add_transition (tokens [1 ], actions [1 ], actions [3 ])
67- assert torch .equal (
68- F .value , torch . tensor ([ 0 .0 , 0.0 , - 2.0 , 2 .0 , 0 .0 , 2.0 , 0.0 , - 2 .0 , - 2 .0 , 0 .0 ])
89+ F .add_transition (tokens [1 ], states [1 ], states [3 ])
90+ assert torch .allclose (
91+ F .value , MAPTensor ([ - 2 .0 , 0.0 , 2.0 , 0 .0 , - 2 .0 , 2.0 , 0.0 , 0 .0 , 0 .0 , - 2 .0 ])
6992 )
70- F .add_transition (tokens [2 ], actions [1 ], actions [3 ])
71- assert torch .equal (
93+ F .add_transition (tokens [2 ], states [1 ], states [3 ])
94+ assert torch .allclose (
7295 F .value ,
73- torch . tensor ([ 1.0 , 1.0 , - 3 .0 , 1.0 , 1 .0 , 3 .0 , - 1.0 , - 1.0 , - 1.0 , 1.0 ]),
96+ MAPTensor ([ - 1.0 , - 1.0 , 1 .0 , - 1.0 , - 3 .0 , 1 .0 , - 1.0 , 1.0 , - 1.0 , - 1.0 ]),
7497 )
7598
7699 def test_transition (self ):
77- generator = torch .Generator ()
78- generator1 = torch .Generator ()
79- generator .manual_seed (seed )
80- generator1 .manual_seed (seed1 )
81- tokens = functional .random (10 , 10 , generator = generator )
82- states = functional .random (10 , 10 , generator = generator1 )
100+ tokens = MAPTensor (
101+ [
102+ [1.0 , 1.0 , - 1.0 , 1.0 , 1.0 , 1.0 , - 1.0 , 1.0 , - 1.0 , 1.0 ],
103+ [1.0 , - 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , - 1.0 , 1.0 ],
104+ [- 1.0 , 1.0 , - 1.0 , - 1.0 , 1.0 , - 1.0 , 1.0 , - 1.0 , 1.0 , - 1.0 ],
105+ [- 1.0 , - 1.0 , 1.0 , 1.0 , - 1.0 , - 1.0 , 1.0 , - 1.0 , 1.0 , - 1.0 ],
106+ [1.0 , - 1.0 , 1.0 , - 1.0 , 1.0 , - 1.0 , - 1.0 , 1.0 , - 1.0 , - 1.0 ],
107+ [1.0 , 1.0 , 1.0 , 1.0 , - 1.0 , - 1.0 , - 1.0 , - 1.0 , 1.0 , - 1.0 ],
108+ [1.0 , 1.0 , 1.0 , 1.0 , - 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , - 1.0 ],
109+ [- 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , - 1.0 , - 1.0 , 1.0 ],
110+ [- 1.0 , - 1.0 , 1.0 , - 1.0 , - 1.0 , - 1.0 , 1.0 , - 1.0 , 1.0 , 1.0 ],
111+ [- 1.0 , - 1.0 , 1.0 , - 1.0 , 1.0 , 1.0 , 1.0 , - 1.0 , - 1.0 , 1.0 ],
112+ ]
113+ )
114+ states = MAPTensor (
115+ [
116+ [1.0 , 1.0 , - 1.0 , - 1.0 , 1.0 , 1.0 , - 1.0 , - 1.0 , 1.0 , 1.0 ],
117+ [- 1.0 , - 1.0 , - 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , - 1.0 ],
118+ [1.0 , 1.0 , - 1.0 , - 1.0 , 1.0 , - 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ],
119+ [1.0 , - 1.0 , 1.0 , - 1.0 , 1.0 , - 1.0 , - 1.0 , - 1.0 , 1.0 , 1.0 ],
120+ [1.0 , 1.0 , 1.0 , 1.0 , - 1.0 , - 1.0 , - 1.0 , - 1.0 , - 1.0 , 1.0 ],
121+ [- 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , - 1.0 , 1.0 , - 1.0 ],
122+ [1.0 , - 1.0 , - 1.0 , - 1.0 , 1.0 , 1.0 , 1.0 , - 1.0 , - 1.0 , - 1.0 ],
123+ [1.0 , - 1.0 , 1.0 , 1.0 , 1.0 , - 1.0 , - 1.0 , - 1.0 , - 1.0 , - 1.0 ],
124+ [1.0 , 1.0 , 1.0 , - 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ],
125+ [1.0 , 1.0 , 1.0 , - 1.0 , 1.0 , - 1.0 , 1.0 , - 1.0 , 1.0 , 1.0 ],
126+ ]
127+ )
83128
84129 F = structures .FiniteStateAutomata (10 )
85130
@@ -121,6 +166,6 @@ def test_clear(self):
121166 F .add_transition (tokens [2 ], states [1 ], states [5 ])
122167
123168 F .clear ()
124- assert torch .equal (
169+ assert torch .allclose (
125170 F .value , torch .tensor ([0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ])
126171 )
0 commit comments