|
20 | 20 | import ot |
21 | 21 |
|
22 | 22 |
|
23 | | -############################################################################## |
| 23 | +# |
24 | 24 | # Sample two Gaussian distributions (2D and 3D) |
25 | 25 | # --------------------------------------------- |
26 | 26 | # |
|
43 | 43 | xt = np.random.randn(n_samples, 3).dot(P) + mu_t |
44 | 44 |
|
45 | 45 |
|
46 | | -############################################################################## |
| 46 | +# |
47 | 47 | # Plotting the distributions |
48 | 48 | # -------------------------- |
49 | 49 |
|
|
56 | 56 | pl.show() |
57 | 57 |
|
58 | 58 |
|
59 | | -############################################################################## |
| 59 | +# |
60 | 60 | # Compute distance kernels, normalize them and then display |
61 | 61 | # --------------------------------------------------------- |
62 | 62 |
|
|
74 | 74 | pl.imshow(C2) |
75 | 75 | pl.show() |
76 | 76 |
|
77 | | -############################################################################## |
| 77 | +# |
78 | 78 | # Compute Gromov-Wasserstein plans and distance |
79 | 79 | # --------------------------------------------- |
80 | 80 |
|
81 | | -#%% |
82 | 81 | p = ot.unif(n_samples) |
83 | 82 | q = ot.unif(n_samples) |
84 | 83 |
|
85 | | -gw0,log0 = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', verbose=True,log=True) |
| 84 | +gw0, log0 = ot.gromov.gromov_wasserstein( |
| 85 | + C1, C2, p, q, 'square_loss', verbose=True, log=True) |
86 | 86 |
|
87 | | -gw,log= ot.gromov.entropic_gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=5e-4,log=True,verbose=True) |
| 87 | +gw, log = ot.gromov.entropic_gromov_wasserstein( |
| 88 | + C1, C2, p, q, 'square_loss', epsilon=5e-4, log=True, verbose=True) |
88 | 89 |
|
89 | 90 |
|
90 | 91 | print('Gromov-Wasserstein distances: ' + str(log0['gw_dist'])) |
91 | 92 | print('Entropic Gromov-Wasserstein distances: ' + str(log['gw_dist'])) |
92 | 93 |
|
93 | 94 |
|
94 | | -pl.figure(1,(10,5)) |
| 95 | +pl.figure(1, (10, 5)) |
95 | 96 |
|
96 | | -pl.subplot(1,2,1) |
| 97 | +pl.subplot(1, 2, 1) |
97 | 98 | pl.imshow(gw0, cmap='jet') |
98 | | -pl.colorbar() |
99 | 99 | pl.title('Gromov Wasserstein') |
100 | 100 |
|
101 | | -pl.subplot(1,2,2) |
102 | | -pl.imshow(gw0, cmap='jet') |
103 | | -pl.colorbar() |
| 101 | +pl.subplot(1, 2, 2) |
| 102 | +pl.imshow(gw, cmap='jet') |
104 | 103 | pl.title('Entropic Gromov Wasserstein') |
105 | 104 |
|
106 | 105 | pl.show() |
0 commit comments