Skip to content

Commit 909e12d

Browse files
authored
Fix the nonlinear bug and add mnist example (#110)
* fix the nonlinear bug and add mnist example * remove unused variable and add one comment
1 parent 8e21849 commit 909e12d

File tree

2 files changed

+78
-0
lines changed

2 files changed

+78
-0
lines changed

examples/mnist_nonlinear.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# This is an example of using nonlinear encoding on the MNIST dataset
2+
import torch
3+
import torch.nn as nn
4+
import torch.nn.functional as F
5+
import torchvision
6+
from torchvision.datasets import MNIST
7+
8+
# Note: this example requires the torchmetrics library: https://torchmetrics.readthedocs.io
9+
import torchmetrics
10+
from tqdm import tqdm
11+
12+
import torchhd
13+
from torchhd import embeddings
14+
15+
16+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17+
print("Using {} device".format(device))
18+
19+
DIMENSIONS = 10000
20+
IMG_SIZE = 28
21+
BATCH_SIZE = 1 # for GPUs with enough memory we can process multiple images at ones
22+
23+
transform = torchvision.transforms.ToTensor()
24+
25+
train_ds = MNIST("../data", train=True, transform=transform, download=True)
26+
train_ld = torch.utils.data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
27+
28+
test_ds = MNIST("../data", train=False, transform=transform, download=True)
29+
test_ld = torch.utils.data.DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)
30+
31+
class Model(nn.Module):
32+
def __init__(self, num_classes, size):
33+
super(Model, self).__init__()
34+
35+
self.flatten = torch.nn.Flatten()
36+
37+
self.nonlinear_projection = embeddings.Sinusoid(size * size, DIMENSIONS)
38+
39+
self.classify = nn.Linear(DIMENSIONS, num_classes, bias=False)
40+
self.classify.weight.data.fill_(0.0)
41+
42+
def encode(self, x):
43+
x = self.flatten(x)
44+
sample_hv = self.nonlinear_projection(x)
45+
return torchhd.hard_quantize(sample_hv)
46+
47+
def forward(self, x):
48+
enc = self.encode(x)
49+
logit = self.classify(enc)
50+
return logit
51+
52+
53+
model = Model(len(train_ds.classes), IMG_SIZE)
54+
model = model.to(device)
55+
56+
with torch.no_grad():
57+
for samples, labels in tqdm(train_ld, desc="Training"):
58+
samples = samples.to(device)
59+
labels = labels.to(device)
60+
61+
samples_hv = model.encode(samples)
62+
model.classify.weight[labels] += samples_hv
63+
64+
model.classify.weight[:] = F.normalize(model.classify.weight)
65+
66+
accuracy = torchmetrics.Accuracy()
67+
68+
69+
with torch.no_grad():
70+
for samples, labels in tqdm(test_ld, desc="Testing"):
71+
samples = samples.to(device)
72+
73+
outputs = model(samples)
74+
predictions = torch.argmax(outputs, dim=-1)
75+
accuracy.update(predictions.cpu(), labels)
76+
77+
print(f"Testing accuracy of {(accuracy.compute().item() * 100):.3f}%")

torchhd/embeddings.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,7 @@ def __init__(
409409

410410
def reset_parameters(self) -> None:
411411
nn.init.normal_(self.weight, 0, 1)
412+
self.weight.data.copy_(F.normalize(self.weight.data))
412413
nn.init.uniform_(self.bias, 0, 2 * math.pi)
413414

414415
def forward(self, input: torch.Tensor) -> torch.Tensor:

0 commit comments

Comments
 (0)