@@ -7,62 +7,98 @@

 @pytest.mark.parametrize('tensor', tensors)
 def test_grid_cluster_cpu(tensor):
-    position = Tensor(tensor, [0, 9, 2, 8, 3])
+    position = Tensor(tensor, [2, 6])
     size = torch.LongTensor([5])
-    expected = torch.LongTensor([0, 1, 0, 1, 0])
-    output = grid_cluster(position, size)
+    expected = torch.LongTensor([0, 0])
+    output, _ = grid_cluster(position, size)
+    assert output.tolist() == expected.tolist()
+
+    expected = torch.LongTensor([0, 1])
+    output, _ = grid_cluster(position, size, offset=0)
+    assert output.tolist() == expected.tolist()
+
+    position = Tensor(tensor, [0, 17, 2, 8, 3])
+    expected = torch.LongTensor([0, 2, 0, 1, 0])
+    output, _ = grid_cluster(position, size)
+    assert output.tolist() == expected.tolist()
+
+    output, _ = grid_cluster(position, size, fake_nodes=True)
+    expected = torch.LongTensor([0, 3, 0, 1, 0])
     assert output.tolist() == expected.tolist()

     position = Tensor(tensor, [[0, 0], [9, 9], [2, 8], [2, 2], [8, 3]])
     size = torch.LongTensor([5, 5])
     expected = torch.LongTensor([0, 3, 1, 0, 2])
-    output = grid_cluster(position, size)
+    output, _ = grid_cluster(position, size)
     assert output.tolist() == expected.tolist()

-    position = Tensor(tensor, [[0, 9, 2, 2, 8], [0, 9, 8, 2, 3]]).t()
-    output = grid_cluster(position, size)
+    position = Tensor(tensor, [[0, 11, 2, 2, 8], [0, 9, 8, 2, 3]]).t()
+    output, _ = grid_cluster(position, size)
     assert output.tolist() == expected.tolist()

-    output = grid_cluster(position.expand(2, 5, 2), size)
+    output, _ = grid_cluster(position.expand(2, 5, 2), size)
     assert output.tolist() == expected.expand(2, 5).tolist()

     position = position.repeat(2, 1)
     batch = torch.LongTensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
     expected = torch.LongTensor([0, 3, 1, 0, 2, 4, 7, 5, 4, 6])
-    expected_batch = torch.LongTensor([0, 0, 0, 0, 1, 1, 1, 1])
-    output, reduced_batch = grid_cluster(position, size, batch)
+    expected_batch2 = torch.LongTensor([0, 0, 0, 0, 1, 1, 1, 1])
+    output, batch2 = grid_cluster(position, size, batch)
+    assert output.tolist() == expected.tolist()
+    assert batch2.tolist() == expected_batch2.tolist()
+
+    output, C = grid_cluster(position, size, batch, fake_nodes=True)
+    expected = torch.LongTensor([0, 5, 1, 0, 2, 6, 11, 7, 6, 8])
     assert output.tolist() == expected.tolist()
-    assert reduced_batch.tolist() == expected_batch.tolist()
+    assert C == 6


 @pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
 @pytest.mark.parametrize('tensor', tensors)
 def test_grid_cluster_gpu(tensor):  # pragma: no cover
-    position = Tensor(tensor, [0, 9, 2, 8, 3]).cuda()
+    position = Tensor(tensor, [2, 6]).cuda()
     size = torch.cuda.LongTensor([5])
-    expected = torch.cuda.LongTensor([0, 1, 0, 1, 0])
-    output = grid_cluster(position, size)
+    expected = torch.LongTensor([0, 0])
+    output, _ = grid_cluster(position, size)
+    assert output.cpu().tolist() == expected.tolist()
+
+    expected = torch.LongTensor([0, 1])
+    output, _ = grid_cluster(position, size, offset=0)
+    assert output.cpu().tolist() == expected.tolist()
+
+    position = Tensor(tensor, [0, 17, 2, 8, 3]).cuda()
+    expected = torch.LongTensor([0, 2, 0, 1, 0])
+    output, _ = grid_cluster(position, size)
+    assert output.cpu().tolist() == expected.tolist()
+
+    output, _ = grid_cluster(position, size, fake_nodes=True)
+    expected = torch.LongTensor([0, 3, 0, 1, 0])
     assert output.cpu().tolist() == expected.tolist()

     position = Tensor(tensor, [[0, 0], [9, 9], [2, 8], [2, 2], [8, 3]])
     position = position.cuda()
     size = torch.cuda.LongTensor([5, 5])
-    expected = torch.cuda.LongTensor([0, 3, 1, 0, 2])
-    output = grid_cluster(position, size)
+    expected = torch.LongTensor([0, 3, 1, 0, 2])
+    output, _ = grid_cluster(position, size)
     assert output.cpu().tolist() == expected.tolist()

-    position = Tensor(tensor, [[0, 9, 2, 2, 8], [0, 9, 8, 2, 3]])
+    position = Tensor(tensor, [[0, 11, 2, 2, 8], [0, 9, 8, 2, 3]])
     position = position.cuda().t()
-    output = grid_cluster(position, size)
+    output, _ = grid_cluster(position, size)
     assert output.cpu().tolist() == expected.tolist()

-    output = grid_cluster(position.expand(2, 5, 2), size)
+    output, _ = grid_cluster(position.expand(2, 5, 2), size)
     assert output.tolist() == expected.expand(2, 5).tolist()

     position = position.repeat(2, 1)
     batch = torch.cuda.LongTensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
     expected = torch.LongTensor([0, 3, 1, 0, 2, 4, 7, 5, 4, 6])
-    expected_batch = torch.LongTensor([0, 0, 0, 0, 1, 1, 1, 1])
-    output, reduced_batch = grid_cluster(position, size, batch)
+    expected_batch2 = torch.LongTensor([0, 0, 0, 0, 1, 1, 1, 1])
+    output, batch2 = grid_cluster(position, size, batch)
+    assert output.cpu().tolist() == expected.tolist()
+    assert batch2.cpu().tolist() == expected_batch2.tolist()
+
+    output, C = grid_cluster(position, size, batch, fake_nodes=True)
+    expected = torch.LongTensor([0, 5, 1, 0, 2, 6, 11, 7, 6, 8])
     assert output.cpu().tolist() == expected.tolist()
-    assert reduced_batch.cpu().tolist() == expected_batch.tolist()
+    assert C == 6
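The updated assertions imply that grid_cluster now returns a tuple instead of a bare tensor (the cluster assignment plus either the reduced batch vector or, with fake_nodes=True, a cell count), and that it accepts optional offset and fake_nodes keyword arguments. A minimal sketch of that inferred usage follows; the import path, the exact meaning of the second return value, and the keyword names are assumptions read off these tests, not taken from documented API:

    import torch
    from torch_cluster import grid_cluster  # assumed import path

    # Five 1D points, grid cell size 5 (mirrors the test case above).
    position = torch.Tensor([0, 17, 2, 8, 3])
    size = torch.LongTensor([5])

    # Assumed tuple return: (cluster assignment, secondary value).
    cluster, _ = grid_cluster(position, size)
    print(cluster.tolist())  # the tests above expect [0, 2, 0, 1, 0]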