|
12 | 12 | # License: MIT License |
13 | 13 |
|
14 | 14 | import numpy as np |
| 15 | +import warnings |
15 | 16 |
|
16 | 17 |
|
17 | 18 | from ..utils import dist, UndefinedParameter, list_to_array |
@@ -53,6 +54,10 @@ def gromov_wasserstein(C1, C2, p=None, q=None, loss_fun='square_loss', symmetric |
53 | 54 | which can lead to copy overhead on GPU arrays. |
54 | 55 | .. note:: All computations in the conjugate gradient solver are done with |
55 | 56 | numpy to limit memory overhead. |
| 57 | + .. note:: This function will cast the computed transport plan to the data |
| 58 | + type of the provided input :math:`\mathbf{C}_1`. Casting to an integer |
| 59 | + tensor might result in a loss of precision. If this behaviour is |
| 60 | + unwanted, please make sure to provide a floating point input. |
56 | 61 |
|
57 | 62 | Parameters |
58 | 63 | ---------- |
@@ -122,7 +127,7 @@ def gromov_wasserstein(C1, C2, p=None, q=None, loss_fun='square_loss', symmetric |
122 | 127 | if q is not None: |
123 | 128 | arr.append(list_to_array(q)) |
124 | 129 | else: |
125 | | - q = unif(C2.shape[0], type_as=C2) |
| 130 | + q = unif(C2.shape[0], type_as=C1) |
126 | 131 | if G0 is not None: |
127 | 132 | G0_ = G0 |
128 | 133 | arr.append(G0) |
@@ -171,6 +176,16 @@ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): |
171 | 176 | else: |
172 | 177 | def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): |
173 | 178 | return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=0., reg=1., nx=np_, **kwargs) |
| 179 | + |
| 180 | + if not nx.is_floating_point(C10): |
| 181 | + warnings.warn( |
| 182 | + "Input structure matrix consists of integer. The transport plan will be " |
| 183 | + "casted accordingly, possibly resulting in a loss of precision. " |
| 184 | + "If this behaviour is unwanted, please make sure your input " |
| 185 | + "structure matrix consists of floating point elements.", |
| 186 | + stacklevel=2 |
| 187 | + ) |
| 188 | + |
174 | 189 | if log: |
175 | 190 | res, log = cg(p, q, 0., 1., f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs) |
176 | 191 | log['gw_dist'] = nx.from_numpy(log['loss'][-1], type_as=C10) |
@@ -216,6 +231,10 @@ def gromov_wasserstein2(C1, C2, p=None, q=None, loss_fun='square_loss', symmetri |
216 | 231 | which can lead to copy overhead on GPU arrays. |
217 | 232 | .. note:: All computations in the conjugate gradient solver are done with |
218 | 233 | numpy to limit memory overhead. |
| 234 | + .. note:: This function will cast the computed transport plan to the data |
| 235 | + type of the provided input :math:`\mathbf{C}_1`. Casting to an integer |
| 236 | + tensor might result in a loss of precision. If this behaviour is |
| 237 | + unwanted, please make sure to provide a floating point input. |
219 | 238 |
|
220 | 239 | Parameters |
221 | 240 | ---------- |
@@ -286,7 +305,7 @@ def gromov_wasserstein2(C1, C2, p=None, q=None, loss_fun='square_loss', symmetri |
286 | 305 | if p is None: |
287 | 306 | p = unif(C1.shape[0], type_as=C1) |
288 | 307 | if q is None: |
289 | | - q = unif(C2.shape[0], type_as=C2) |
| 308 | + q = unif(C2.shape[0], type_as=C1) |
290 | 309 |
|
291 | 310 | T, log_gw = gromov_wasserstein( |
292 | 311 | C1, C2, p, q, loss_fun, symmetric, log=True, armijo=armijo, G0=G0, |
@@ -344,6 +363,10 @@ def fused_gromov_wasserstein(M, C1, C2, p=None, q=None, loss_fun='square_loss', |
344 | 363 | which can lead to copy overhead on GPU arrays. |
345 | 364 | .. note:: All computations in the conjugate gradient solver are done with |
346 | 365 | numpy to limit memory overhead. |
| 366 | + .. note:: This function will cast the computed transport plan to the data |
| 367 | + type of the provided input :math:`\mathbf{M}`. Casting to an integer |
| 368 | + tensor might result in a loss of precision. If this behaviour is |
| 369 | + unwanted, please make sure to provide a floating point input. |
347 | 370 |
|
348 | 371 |
|
349 | 372 | Parameters |
@@ -409,11 +432,11 @@ def fused_gromov_wasserstein(M, C1, C2, p=None, q=None, loss_fun='square_loss', |
409 | 432 | if p is not None: |
410 | 433 | arr.append(list_to_array(p)) |
411 | 434 | else: |
412 | | - p = unif(C1.shape[0], type_as=C1) |
| 435 | + p = unif(C1.shape[0], type_as=M) |
413 | 436 | if q is not None: |
414 | 437 | arr.append(list_to_array(q)) |
415 | 438 | else: |
416 | | - q = unif(C2.shape[0], type_as=C2) |
| 439 | + q = unif(C2.shape[0], type_as=M) |
417 | 440 | if G0 is not None: |
418 | 441 | G0_ = G0 |
419 | 442 | arr.append(G0) |
@@ -465,14 +488,22 @@ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): |
465 | 488 | else: |
466 | 489 | def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): |
467 | 490 | return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=(1 - alpha) * M, reg=alpha, nx=np_, **kwargs) |
| 491 | + if not nx.is_floating_point(M0): |
| 492 | + warnings.warn( |
| 493 | + "Input feature matrix consists of integer. The transport plan will be " |
| 494 | + "casted accordingly, possibly resulting in a loss of precision. " |
| 495 | + "If this behaviour is unwanted, please make sure your input " |
| 496 | + "feature matrix consists of floating point elements.", |
| 497 | + stacklevel=2 |
| 498 | + ) |
468 | 499 | if log: |
469 | 500 | res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs) |
470 | | - log['fgw_dist'] = nx.from_numpy(log['loss'][-1], type_as=C10) |
471 | | - log['u'] = nx.from_numpy(log['u'], type_as=C10) |
472 | | - log['v'] = nx.from_numpy(log['v'], type_as=C10) |
473 | | - return nx.from_numpy(res, type_as=C10), log |
| 501 | + log['fgw_dist'] = nx.from_numpy(log['loss'][-1], type_as=M0) |
| 502 | + log['u'] = nx.from_numpy(log['u'], type_as=M0) |
| 503 | + log['v'] = nx.from_numpy(log['v'], type_as=M0) |
| 504 | + return nx.from_numpy(res, type_as=M0), log |
474 | 505 | else: |
475 | | - return nx.from_numpy(cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs), type_as=C10) |
| 506 | + return nx.from_numpy(cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs), type_as=M0) |
476 | 507 |
|
477 | 508 |
|
478 | 509 | def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss', symmetric=None, alpha=0.5, |
@@ -510,6 +541,10 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss', |
510 | 541 | which can lead to copy overhead on GPU arrays. |
511 | 542 | .. note:: All computations in the conjugate gradient solver are done with |
512 | 543 | numpy to limit memory overhead. |
| 544 | + .. note:: This function will cast the computed transport plan to the data |
| 545 | + type of the provided input :math:`\mathbf{M}`. Casting to an integer |
| 546 | + tensor might result in a loss of precision. If this behaviour is |
| 547 | + unwanted, please make sure to provide a floating point input. |
513 | 548 |
|
514 | 549 | Parameters |
515 | 550 | ---------- |
@@ -578,9 +613,9 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss', |
578 | 613 |
|
579 | 614 | # init marginals if set as None |
580 | 615 | if p is None: |
581 | | - p = unif(C1.shape[0], type_as=C1) |
| 616 | + p = unif(C1.shape[0], type_as=M) |
582 | 617 | if q is None: |
583 | | - q = unif(C2.shape[0], type_as=C2) |
| 618 | + q = unif(C2.shape[0], type_as=M) |
584 | 619 |
|
585 | 620 | T, log_fgw = fused_gromov_wasserstein( |
586 | 621 | M, C1, C2, p, q, loss_fun, symmetric, alpha, armijo, G0, log=True, |
|
0 commit comments