 import numpy as np
 
 
-from ..utils import unif
+from ..utils import unif, check_random_state
 from ..backend import get_backend
 from ._gw import gromov_wasserstein, fused_gromov_wasserstein
 
 
 def gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0., ps=None, q=None, epochs=20, batch_size=32, learning_rate=1., Cdict_init=None, projection='nonnegative_symmetric', use_log=True,
-                                           tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, **kwargs):
+                                           tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, random_state=None, **kwargs):
     r"""
     Infer Gromov-Wasserstein linear dictionary :math:`\{ (\mathbf{C_{dict}[d]}, q) \}_{d \in [D]}` from the list of structures :math:`\{ (\mathbf{C_s},\mathbf{p_s}) \}_s`
 
@@ -81,6 +81,9 @@ def gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0., ps=None, q=None, e
         Maximum number of iterations for the Conjugate Gradient. Default is 200.
     verbose : bool, optional
         Print the reconstruction loss every epoch. Default is False.
+    random_state : int, RandomState instance or None, default=None
+        Determines random number generation. Pass an int for reproducible
+        output across multiple function calls.
 
     Returns
     -------
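
With the new `random_state` parameter, two calls sharing an integer seed return the same dictionary. A minimal sketch on toy data (the sizes, `D`, `nt`, `epochs` and `batch_size` below are arbitrary; `batch_size` must not exceed the dataset size, since batches are drawn without replacement):

    import numpy as np
    from ot.gromov import gromov_wasserstein_dictionary_learning

    # Toy dataset: five symmetric structure matrices of varying sizes.
    data_rng = np.random.RandomState(0)
    Cs = [0.5 * (M + M.T) for M in (data_rng.rand(n, n) for n in (10, 12, 8, 10, 11))]

    Cdict_a, _ = gromov_wasserstein_dictionary_learning(
        Cs, D=3, nt=6, epochs=2, batch_size=5, random_state=42)
    Cdict_b, _ = gromov_wasserstein_dictionary_learning(
        Cs, D=3, nt=6, epochs=2, batch_size=5, random_state=42)
    np.testing.assert_allclose(Cdict_a, Cdict_b)  # same seed, same dictionary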
@@ -90,6 +93,7 @@ def gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0., ps=None, q=None, e
         The dictionary leading to the best loss over an epoch is saved and returned.
     log: dict
         If use_log is True, contains loss evolutions by batches and epochs.
+
     References
     -------
     .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online
@@ -110,10 +114,11 @@ def gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0., ps=None, q=None, e
         q = unif(nt)
     else:
         q = nx.to_numpy(q)
+    rng = check_random_state(random_state)
     if Cdict_init is None:
         # Initialize randomly structures of dictionary atoms based on samples
         dataset_means = [C.mean() for C in Cs]
-        Cdict = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt))
+        Cdict = rng.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt))
     else:
         Cdict = nx.to_numpy(Cdict_init).copy()
         assert Cdict.shape == (D, nt, nt)
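
`check_random_state`, now imported from `..utils`, normalizes the three accepted forms of `random_state` (None, int, `RandomState` instance) into a single generator. A minimal sketch of the sklearn-style convention such helpers follow (illustrative only; the actual implementation lives in `ot.utils`):

    import numbers
    import numpy as np

    def check_random_state_sketch(seed):
        """Turn seed into a np.random.RandomState instance (illustrative sketch)."""
        if seed is None or seed is np.random:
            return np.random.mtrand._rand       # module-level global RNG
        if isinstance(seed, numbers.Integral):
            return np.random.RandomState(seed)  # fresh generator with a fixed seed
        if isinstance(seed, np.random.RandomState):
            return seed                         # already a generator: pass through
        raise ValueError('%r cannot be used to seed a RandomState instance' % seed)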
@@ -141,7 +146,7 @@ def gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0., ps=None, q=None, e
 
         for _ in range(iter_by_epoch):
             # batch sampling
-            batch = np.random.choice(range(dataset_size), size=batch_size, replace=False)
+            batch = rng.choice(range(dataset_size), size=batch_size, replace=False)
             cumulated_loss_over_batch = 0.
             unmixings = np.zeros((batch_size, D))
             Cs_embedded = np.zeros((batch_size, nt, nt))
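
Drawing batches from `rng` rather than the module-level `np.random` is what makes batch order seedable: the global `np.random.choice` reads hidden shared state that any other library call can advance. Quick illustration:

    import numpy as np

    # Two generators built from the same seed replay the same batch sequence,
    # independently of anything else that uses np.random in the meantime.
    rng1 = np.random.RandomState(7)
    rng2 = np.random.RandomState(7)
    b1 = rng1.choice(range(100), size=32, replace=False)
    b2 = rng2.choice(range(100), size=32, replace=False)
    assert (b1 == b2).all()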
@@ -469,7 +474,8 @@ def _linesearch_gromov_wasserstein_unmixing(w, grad_w, x, Cdict, Cembedded, cons
 
 def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., ps=None, q=None, epochs=20, batch_size=32, learning_rate_C=1., learning_rate_Y=1.,
                                                  Cdict_init=None, Ydict_init=None, projection='nonnegative_symmetric', use_log=False,
-                                                 tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, **kwargs):
+                                                 tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False,
+                                                 random_state=None, **kwargs):
     r"""
     Infer Fused Gromov-Wasserstein linear dictionary :math:`\{ (\mathbf{C_{dict}[d]}, \mathbf{Y_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}` from the list of S attributed structures :math:`\{ (\mathbf{C_s}, \mathbf{Y_s},\mathbf{p_s}) \}_s`
 
@@ -548,6 +554,9 @@ def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., p
         Maximum number of iterations for the Conjugate Gradient. Default is 200.
     verbose : bool, optional
         Print the reconstruction loss every epoch. Default is False.
+    random_state : int, RandomState instance or None, default=None
+        Determines random number generation. Pass an int for reproducible
+        output across multiple function calls.
 
     Returns
     -------
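
The fused variant reproduces the same way; a sketch with toy attributed structures (feature dimension `d` and all hyperparameters below are arbitrary):

    import numpy as np
    from ot.gromov import fused_gromov_wasserstein_dictionary_learning

    data_rng = np.random.RandomState(0)
    sizes, d = (10, 12, 8, 10, 11), 4
    Cs = [0.5 * (M + M.T) for M in (data_rng.rand(n, n) for n in sizes)]
    Ys = [data_rng.rand(n, d) for n in sizes]  # one feature matrix per structure

    Cdict_a, Ydict_a, _ = fused_gromov_wasserstein_dictionary_learning(
        Cs, Ys, D=3, nt=6, alpha=0.5, epochs=2, batch_size=5, random_state=42)
    Cdict_b, Ydict_b, _ = fused_gromov_wasserstein_dictionary_learning(
        Cs, Ys, D=3, nt=6, alpha=0.5, epochs=2, batch_size=5, random_state=42)
    np.testing.assert_allclose(Cdict_a, Cdict_b)  # structures match
    np.testing.assert_allclose(Ydict_a, Ydict_b)  # features match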
@@ -560,6 +569,7 @@ def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., p
         The dictionary leading to the best loss over an epoch is saved and returned.
     log: dict
         If use_log is True, contains loss evolutions by batches and epochs.
+
     References
     -------
     .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online
@@ -583,17 +593,18 @@ def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., p
     else:
         q = nx.to_numpy(q)
 
+    rng = check_random_state(random_state)
     if Cdict_init is None:
         # Initialize randomly structures of dictionary atoms based on samples
         dataset_means = [C.mean() for C in Cs]
-        Cdict = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt))
+        Cdict = rng.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt))
     else:
         Cdict = nx.to_numpy(Cdict_init).copy()
         assert Cdict.shape == (D, nt, nt)
     if Ydict_init is None:
         # Initialize randomly features of dictionary atoms based on samples distribution by feature component
         dataset_feature_means = np.stack([F.mean(axis=0) for F in Ys])
-        Ydict = np.random.normal(loc=dataset_feature_means.mean(axis=0), scale=dataset_feature_means.std(axis=0), size=(D, nt, d))
+        Ydict = rng.normal(loc=dataset_feature_means.mean(axis=0), scale=dataset_feature_means.std(axis=0), size=(D, nt, d))
     else:
         Ydict = nx.to_numpy(Ydict_init).copy()
         assert Ydict.shape == (D, nt, d)
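
As the branches above show, the RNG is only consumed when an init is missing: passing explicit `Cdict_init` and `Ydict_init` skips the Gaussian draws entirely, leaving `random_state` to govern batch sampling alone. Sketch with hypothetical shapes:

    import numpy as np

    D, nt, d = 3, 6, 4                   # hypothetical atom count and sizes
    Cdict_init = np.ones((D, nt, nt))    # fixed atom structures, RNG untouched
    Ydict_init = np.zeros((D, nt, d))    # fixed atom features, RNG untouched
    # With both inits supplied, random_state only drives the batch sampling below.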
@@ -626,7 +637,7 @@ def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., p
         for _ in range(iter_by_epoch):
 
             # Batch iterations
-            batch = np.random.choice(range(dataset_size), size=batch_size, replace=False)
+            batch = rng.choice(range(dataset_size), size=batch_size, replace=False)
             cumulated_loss_over_batch = 0.
             unmixings = np.zeros((batch_size, D))
             Cs_embedded = np.zeros((batch_size, nt, nt))
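
One usage caveat behind the docstring's "pass an int for reproducible output": reusing the same `RandomState` instance across calls does not reproduce, because its internal state advances during the first call. Sketch:

    import numpy as np
    from ot.gromov import gromov_wasserstein_dictionary_learning

    data_rng = np.random.RandomState(0)
    Cs = [0.5 * (M + M.T) for M in (data_rng.rand(10, 10) for _ in range(5))]

    rs = np.random.RandomState(5)
    C1, _ = gromov_wasserstein_dictionary_learning(
        Cs, D=3, nt=6, epochs=1, batch_size=5, random_state=rs)
    C2, _ = gromov_wasserstein_dictionary_learning(
        Cs, D=3, nt=6, epochs=1, batch_size=5, random_state=rs)
    # C1 and C2 generally differ: rs advanced during the first call.
    # Passing random_state=5 to both calls would make them identical.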