|
1 | 1 | import math |
2 | | -from typing import Type, Union |
| 2 | +from typing import Type, Union, Callable |
3 | 3 | import torch |
4 | 4 | from torch import LongTensor, FloatTensor, Tensor |
5 | 5 | from collections import deque |
|
25 | 25 | "inverse", |
26 | 26 | "negative", |
27 | 27 | "cleanup", |
| 28 | + "create_random_permute", |
28 | 29 | "hard_quantize", |
29 | 30 | "soft_quantize", |
30 | 31 | "hamming_similarity", |
@@ -661,6 +662,50 @@ def permute(input: VSA_Model, *, shifts=1) -> VSA_Model: |
661 | 662 | return input.permute(shifts) |
662 | 663 |
|
663 | 664 |
|
| 665 | + |
| 666 | +def create_random_permute(dim: int) -> Callable[[VSA_Model, int], VSA_Model]: |
| 667 | + r"""Creates random permutation functions. |
| 668 | +
|
| 669 | + Args: |
| 670 | + dim (int): dimension of the hypervectors |
| 671 | +
|
| 672 | + Examples:: |
| 673 | +
|
| 674 | + >>> a = torchhd.random_hv(3, 10) |
| 675 | + >>> a |
| 676 | + tensor([[-1., 1., 1., 1., -1., -1., -1., -1., 1., -1.], |
| 677 | + [-1., -1., -1., 1., -1., 1., -1., -1., 1., -1.], |
| 678 | + [ 1., 1., 1., -1., -1., 1., -1., 1., 1., 1.]]) |
| 679 | + >>> p = torchhd.create_random_permute(10) |
| 680 | + >>> p(a, 2) |
| 681 | + tensor([[ 1., 1., -1., -1., -1., 1., -1., -1., 1., -1.], |
| 682 | + [ 1., -1., -1., -1., 1., 1., -1., -1., -1., -1.], |
| 683 | + [ 1., 1., 1., -1., 1., -1., -1., 1., 1., 1.]]) |
| 684 | + >>> p(a, -2) |
| 685 | + tensor([[-1., 1., 1., 1., -1., -1., -1., -1., 1., -1.], |
| 686 | + [-1., -1., -1., 1., -1., 1., -1., -1., 1., -1.], |
| 687 | + [ 1., 1., 1., -1., -1., 1., -1., 1., 1., 1.]]) |
| 688 | +
|
| 689 | + """ |
| 690 | + |
| 691 | + forward = torch.randperm(dim) |
| 692 | + backward = torch.empty_like(forward) |
| 693 | + backward[forward] = torch.arange(dim) |
| 694 | + |
| 695 | + def permute(input: VSA_Model, shifts: int = 1) -> VSA_Model: |
| 696 | + y = input |
| 697 | + if shifts > 0: |
| 698 | + for _ in range(shifts): |
| 699 | + y = y[..., forward] |
| 700 | + elif shifts < 0: |
| 701 | + for _ in range(shifts): |
| 702 | + y = y[..., backward] |
| 703 | + return y |
| 704 | + |
| 705 | + return permute |
| 706 | + |
| 707 | + |
| 708 | + |
664 | 709 | def inverse(input: VSA_Model) -> VSA_Model: |
665 | 710 | r"""Inverse for the binding operation. |
666 | 711 |
|
|
0 commit comments