Skip to content

Commit 279e2ef

Browse files
Add Centroid classification model with iterative learning methods (#113)
* Add centroid classification model with iterative learning
* [github-action] formatting fixes
* Add models import to library root

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent f508a68 commit 279e2ef

File tree

14 files changed

+265
-253
lines changed

14 files changed

+265
-253
lines changed

docs/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ Torchhd is a Python library dedicated to *Hyperdimensional Computing* (also know
1818
torchhd
1919
embeddings
2020
structures
21+
models
2122
datasets
2223
utils
2324

docs/models.rst

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
.. _models:
2+
3+
torchhd.models
4+
==================
5+
6+
.. currentmodule:: torchhd.models
7+
8+
.. autosummary::
9+
:toctree: generated/
10+
:template: class.rst
11+
12+
Centroid
13+

examples/emg_hand_gestures.py

Lines changed: 23 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import torchhd
1111
from torchhd import embeddings
12+
from torchhd.models import Centroid
1213
from torchhd.datasets import EMGHandGestures
1314

1415
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -27,31 +28,23 @@ def transform(x):
2728
return x[SUBSAMPLES]
2829

2930

30-
class Model(nn.Module):
31-
def __init__(self, num_classes, timestamps, channels):
32-
super(Model, self).__init__()
31+
class Encoder(nn.Module):
32+
def __init__(self, out_features, timestamps, channels):
33+
super(Encoder, self).__init__()
3334

34-
self.channels = embeddings.Random(channels, DIMENSIONS)
35-
self.timestamps = embeddings.Random(timestamps, DIMENSIONS)
36-
self.signals = embeddings.Level(NUM_LEVELS, DIMENSIONS, high=20)
35+
self.channels = embeddings.Random(channels, out_features)
36+
self.timestamps = embeddings.Random(timestamps, out_features)
37+
self.signals = embeddings.Level(NUM_LEVELS, out_features, high=20)
3738

38-
self.classify = nn.Linear(DIMENSIONS, num_classes, bias=False)
39-
self.classify.weight.data.fill_(0.0)
40-
41-
def encode(self, x: torch.Tensor) -> torch.Tensor:
42-
signal = self.signals(x)
39+
def forward(self, input: torch.Tensor) -> torch.Tensor:
40+
signal = self.signals(input)
4341
samples = torchhd.bind(signal, self.channels.weight.unsqueeze(0))
4442
samples = torchhd.bind(signal, self.timestamps.weight.unsqueeze(1))
4543

4644
samples = torchhd.multiset(samples)
4745
sample_hv = torchhd.ngrams(samples, n=N_GRAM_SIZE)
4846
return torchhd.hard_quantize(sample_hv)
4947

50-
def forward(self, x: torch.Tensor) -> torch.Tensor:
51-
enc = self.encode(x)
52-
logit = self.classify(enc)
53-
return logit
54-
5548

5649
def experiment(subjects=[0]):
5750
print("List of subjects " + str(subjects))
@@ -66,29 +59,32 @@ def experiment(subjects=[0]):
6659
train_ld = data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
6760
test_ld = data.DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)
6861

62+
encode = Encoder(DIMENSIONS, ds[0][0].size(-2), ds[0][0].size(-1))
63+
encode = encode.to(device)
64+
6965
num_classes = len(ds.classes)
70-
model = Model(num_classes, ds[0][0].size(-2), ds[0][0].size(-1))
66+
model = Centroid(DIMENSIONS, num_classes)
7167
model = model.to(device)
7268

7369
with torch.no_grad():
74-
for samples, labels in tqdm(train_ld, desc="Training"):
70+
for samples, targets in tqdm(train_ld, desc="Training"):
7571
samples = samples.to(device)
76-
labels = labels.to(device)
72+
targets = targets.to(device)
7773

78-
samples_hv = model.encode(samples)
79-
model.classify.weight[labels] += samples_hv
80-
81-
model.classify.weight[:] = F.normalize(model.classify.weight)
74+
sample_hv = encode(samples)
75+
model.add(sample_hv, targets)
8276

8377
accuracy = torchmetrics.Accuracy("multiclass", num_classes=num_classes)
8478

8579
with torch.no_grad():
86-
for samples, labels in tqdm(test_ld, desc="Testing"):
80+
model.normalize()
81+
82+
for samples, targets in tqdm(test_ld, desc="Testing"):
8783
samples = samples.to(device)
8884

89-
outputs = model(samples)
90-
predictions = torch.argmax(outputs, dim=-1)
91-
accuracy.update(predictions.cpu(), labels)
85+
sample_hv = encode(samples)
86+
output = model(sample_hv, dot=True)
87+
accuracy.update(output.cpu(), targets)
9288

9389
print(f"Testing accuracy of {(accuracy.compute().item() * 100):.3f}%")
9490

examples/graphhd.py

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import torchhd
1313
from torchhd import embeddings
14+
from torchhd.models import Centroid
1415

1516
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1617
print("Using {} device".format(device))
@@ -80,55 +81,50 @@ def min_max_graph_size(graph_dataset):
8081
return min_num_nodes, max_num_nodes
8182

8283

83-
class Model(nn.Module):
84-
def __init__(self, num_classes, size):
85-
super(Model, self).__init__()
84+
class Encoder(nn.Module):
85+
def __init__(self, out_features, size):
86+
super(Encoder, self).__init__()
87+
self.out_features = out_features
88+
self.node_ids = embeddings.Random(size, out_features)
8689

87-
self.node_ids = embeddings.Random(size, DIMENSIONS)
88-
89-
self.classify = nn.Linear(DIMENSIONS, num_classes, bias=False)
90-
self.classify.weight.data.fill_(0.0)
91-
92-
def encode(self, x):
90+
def forward(self, x):
9391
pr = pagerank(x)
9492
pr_sort, pr_argsort = pr.sort()
9593

96-
node_id_hvs = torch.zeros((x.num_nodes, DIMENSIONS), device=device)
94+
node_id_hvs = torch.zeros((x.num_nodes, self.out_features), device=device)
9795
node_id_hvs[pr_argsort] = self.node_ids.weight[: x.num_nodes]
9896

9997
row, col = to_undirected(x.edge_index)
10098

10199
hvs = torchhd.bind(node_id_hvs[row], node_id_hvs[col])
102100
return torchhd.multiset(hvs)
103101

104-
def forward(self, x):
105-
enc = self.encode(x)
106-
logit = self.classify(enc)
107-
return logit
108-
109102

110103
min_graph_size, max_graph_size = min_max_graph_size(graphs)
111-
model = Model(graphs.num_classes, max_graph_size)
104+
encode = Encoder(DIMENSIONS, max_graph_size)
105+
encode = encode.to(device)
106+
107+
model = Centroid(DIMENSIONS, graphs.num_classes)
112108
model = model.to(device)
113109

114110
with torch.no_grad():
115111
for samples in tqdm(train_ld, desc="Training"):
116112
samples.edge_index = samples.edge_index.to(device)
117113
samples.y = samples.y.to(device)
118114

119-
samples_hv = model.encode(samples)
120-
model.classify.weight[samples.y] += samples_hv
121-
122-
model.classify.weight[:] = F.normalize(model.classify.weight)
115+
samples_hv = encode(samples).unsqueeze(0)
116+
model.add(samples_hv, samples.y)
123117

124118
accuracy = torchmetrics.Accuracy("multiclass", num_classes=graphs.num_classes)
125119

126120
with torch.no_grad():
121+
model.normalize()
122+
127123
for samples in tqdm(test_ld, desc="Testing"):
128124
samples.edge_index = samples.edge_index.to(device)
129125

130-
outputs = model(samples)
131-
predictions = torch.argmax(outputs, dim=-1).unsqueeze(0)
132-
accuracy.update(predictions.cpu(), samples.y)
126+
samples_hv = encode(samples).unsqueeze(0)
127+
outputs = model(samples_hv, dot=True)
128+
accuracy.update(outputs.cpu(), samples.y)
133129

134130
print(f"Testing accuracy of {(accuracy.compute().item() * 100):.3f}%")

examples/language_recognition.py

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import torchhd
1111
from torchhd import embeddings
12+
from torchhd.models import Centroid
1213
from torchhd.datasets import EuropeanLanguages as Languages
1314

1415
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -54,49 +55,43 @@ def transform(x: str) -> torch.Tensor:
5455
test_ld = data.DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)
5556

5657

57-
class Model(nn.Module):
58-
def __init__(self, num_classes, size):
59-
super(Model, self).__init__()
58+
class Encoder(nn.Module):
59+
def __init__(self, out_features, size):
60+
super(Encoder, self).__init__()
61+
self.symbol = embeddings.Random(size, out_features, padding_idx=PADDING_IDX)
6062

61-
self.symbol = embeddings.Random(size, DIMENSIONS, padding_idx=PADDING_IDX)
62-
63-
self.classify = nn.Linear(DIMENSIONS, num_classes, bias=False)
64-
self.classify.weight.data.fill_(0.0)
65-
66-
def encode(self, x):
63+
def forward(self, x):
6764
symbols = self.symbol(x)
6865
sample_hv = torchhd.ngrams(symbols, n=3)
6966
return torchhd.hard_quantize(sample_hv)
7067

71-
def forward(self, x):
72-
enc = self.encode(x)
73-
logit = self.classify(enc)
74-
return logit
7568

69+
encode = Encoder(DIMENSIONS, NUM_TOKENS)
70+
encode = encode.to(device)
7671

7772
num_classes = len(train_ds.classes)
78-
model = Model(num_classes, NUM_TOKENS)
73+
model = Centroid(DIMENSIONS, num_classes)
7974
model = model.to(device)
8075

8176
with torch.no_grad():
8277
for samples, labels in tqdm(train_ld, desc="Training"):
8378
samples = samples.to(device)
8479
labels = labels.to(device)
8580

86-
samples_hv = model.encode(samples)
87-
model.classify.weight[labels] += samples_hv
88-
89-
model.classify.weight[:] = F.normalize(model.classify.weight)
81+
samples_hv = encode(samples)
82+
model.add(samples_hv, labels)
9083

9184
accuracy = torchmetrics.Accuracy("multiclass", num_classes=num_classes)
9285

9386
with torch.no_grad():
87+
model.normalize()
88+
9489
for samples, labels in tqdm(test_ld, desc="Testing"):
9590
samples = samples.to(device)
9691
labels = labels.to(device)
9792

98-
outputs = model(samples)
99-
predictions = torch.argmax(outputs, dim=-1)
100-
accuracy.update(predictions.cpu(), labels)
93+
samples_hv = encode(samples)
94+
outputs = model(samples_hv, dot=True)
95+
accuracy.update(outputs.cpu(), labels)
10196

10297
print(f"Testing accuracy of {(accuracy.compute().item() * 100):.3f}%")

examples/mnist.py

Lines changed: 17 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from tqdm import tqdm
1010

1111
import torchhd
12+
from torchhd.models import Centroid
1213
from torchhd import embeddings
1314

1415

@@ -29,52 +30,45 @@
2930
test_ld = torch.utils.data.DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)
3031

3132

32-
class Model(nn.Module):
33-
def __init__(self, num_classes, size):
34-
super(Model, self).__init__()
35-
33+
class Encoder(nn.Module):
34+
def __init__(self, out_features, size, levels):
35+
super(Encoder, self).__init__()
3636
self.flatten = torch.nn.Flatten()
37+
self.position = embeddings.Random(size * size, out_features)
38+
self.value = embeddings.Level(levels, out_features)
3739

38-
self.position = embeddings.Random(size * size, DIMENSIONS)
39-
self.value = embeddings.Level(NUM_LEVELS, DIMENSIONS)
40-
41-
self.classify = nn.Linear(DIMENSIONS, num_classes, bias=False)
42-
self.classify.weight.data.fill_(0.0)
43-
44-
def encode(self, x):
40+
def forward(self, x):
4541
x = self.flatten(x)
4642
sample_hv = torchhd.bind(self.position.weight, self.value(x))
4743
sample_hv = torchhd.multiset(sample_hv)
4844
return torchhd.hard_quantize(sample_hv)
4945

50-
def forward(self, x):
51-
enc = self.encode(x)
52-
logit = self.classify(enc)
53-
return logit
5446

47+
encode = Encoder(DIMENSIONS, IMG_SIZE, NUM_LEVELS)
48+
encode = encode.to(device)
5549

5650
num_classes = len(train_ds.classes)
57-
model = Model(num_classes, IMG_SIZE)
51+
model = Centroid(DIMENSIONS, num_classes)
5852
model = model.to(device)
5953

6054
with torch.no_grad():
6155
for samples, labels in tqdm(train_ld, desc="Training"):
6256
samples = samples.to(device)
6357
labels = labels.to(device)
6458

65-
samples_hv = model.encode(samples)
66-
model.classify.weight[labels] += samples_hv
67-
68-
model.classify.weight[:] = F.normalize(model.classify.weight)
59+
samples_hv = encode(samples)
60+
model.add(samples_hv, labels)
6961

7062
accuracy = torchmetrics.Accuracy("multiclass", num_classes=num_classes)
7163

7264
with torch.no_grad():
65+
model.normalize()
66+
7367
for samples, labels in tqdm(test_ld, desc="Testing"):
7468
samples = samples.to(device)
7569

76-
outputs = model(samples)
77-
predictions = torch.argmax(outputs, dim=-1)
78-
accuracy.update(predictions.cpu(), labels)
70+
samples_hv = encode(samples)
71+
outputs = model(samples_hv, dot=True)
72+
accuracy.update(outputs.cpu(), labels)
7973

8074
print(f"Testing accuracy of {(accuracy.compute().item() * 100):.3f}%")

0 commit comments

Comments
 (0)