|
28 | 28 | from torch import Tensor |
29 | 29 | from torch.nn.parameter import Parameter |
30 | 30 | import torch.nn.init as init |
31 | | -import torch.utils.data as data |
32 | | -from tqdm import tqdm |
33 | | - |
34 | 31 |
|
35 | 32 | import torchhd.functional as functional |
36 | | -import torchhd.datasets as datasets |
37 | 33 | import torchhd.embeddings as embeddings |
38 | 34 |
|
39 | 35 |
|
@@ -71,6 +67,7 @@ class Centroid(nn.Module): |
71 | 67 | >>> output.size() |
72 | 68 | torch.Size([128, 30]) |
73 | 69 | """ |
| 70 | + |
74 | 71 | __constants__ = ["in_features", "out_features"] |
75 | 72 | in_features: int |
76 | 73 | out_features: int |
@@ -108,6 +105,30 @@ def add(self, input: Tensor, target: Tensor, lr: float = 1.0) -> None: |
108 | 105 | """Adds the input vectors scaled by the lr to the target prototype vectors.""" |
109 | 106 | self.weight.index_add_(0, target, input, alpha=lr) |
110 | 107 |
|
| 108 | + @torch.no_grad() |
| 109 | + def add_adapt(self, input: Tensor, target: Tensor, lr: float = 1.0) -> None: |
| 110 | + r"""Only updates the prototype vectors on wrongly predicted inputs. |
| 111 | +
|
| 112 | + Implements the iterative training method as described in `AdaptHD: Adaptive Efficient Training for Brain-Inspired Hyperdimensional Computing <https://ieeexplore.ieee.org/document/8918974>`_. |
| 113 | +
|
| 114 | + Subtracts the input from the mispredicted class prototype scaled by the learning rate |
| 115 | + and adds the input to the target prototype scaled by the learning rate. |
| 116 | + """ |
| 117 | + logit = self(input) |
| 118 | + pred = logit.argmax(1) |
| 119 | + is_wrong = target != pred |
| 120 | + |
| 121 | + # cancel update if all predictions were correct |
| 122 | + if is_wrong.sum().item() == 0: |
| 123 | + return |
| 124 | + |
| 125 | + input = input[is_wrong] |
| 126 | + target = target[is_wrong] |
| 127 | + pred = pred[is_wrong] |
| 128 | + |
| 129 | + self.weight.index_add_(0, target, input, alpha=lr) |
| 130 | + self.weight.index_add_(0, pred, input, alpha=-lr) |
| 131 | + |
111 | 132 | @torch.no_grad() |
112 | 133 | def add_online(self, input: Tensor, target: Tensor, lr: float = 1.0) -> None: |
113 | 134 | r"""Only updates the prototype vectors on wrongly predicted inputs. |
@@ -137,23 +158,30 @@ def add_online(self, input: Tensor, target: Tensor, lr: float = 1.0) -> None: |
137 | 158 | alpha1 = 1.0 - logit.gather(1, target.unsqueeze(1)) |
138 | 159 | alpha2 = logit.gather(1, pred.unsqueeze(1)) - 1.0 |
139 | 160 |
|
140 | | - self.weight.index_add_(0, target, lr * alpha1 * input) |
141 | | - self.weight.index_add_(0, pred, lr * alpha2 * input) |
| 161 | + self.weight.index_add_(0, target, alpha1 * input, alpha=lr) |
| 162 | + self.weight.index_add_(0, pred, alpha2 * input, alpha=lr) |
142 | 163 |
|
143 | | - @torch.no_grad() |
144 | 164 | def normalize(self, eps=1e-12) -> None: |
145 | 165 | """Transforms all the class prototype vectors into unit vectors. |
146 | 166 |
|
147 | 167 | After calling this, inferences can be made more efficiently by specifying ``dot=True`` in the forward pass. |
148 | 168 | Training further after calling this method is not advised. |
149 | 169 | """ |
150 | 170 | norms = self.weight.norm(dim=1, keepdim=True) |
| 171 | + |
| 172 | + if torch.isclose(norms, torch.zeros_like(norms), equal_nan=True).any(): |
| 173 | + import warnings |
| 174 | + |
| 175 | + warnings.warn( |
| 176 | + "The norm of a prototype vector is nearly zero upon normalizing, this could indicate a bug." |
| 177 | + ) |
| 178 | + |
151 | 179 | norms.clamp_(min=eps) |
152 | 180 | self.weight.div_(norms) |
153 | 181 |
|
154 | 182 | def extra_repr(self) -> str: |
155 | 183 | return "in_features={}, out_features={}".format( |
156 | | - self.in_features, self.out_features is not None |
| 184 | + self.in_features, self.out_features |
157 | 185 | ) |
158 | 186 |
|
159 | 187 |
|
|
0 commit comments