@@ -535,18 +535,18 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
535535
536536 Parameters
537537 ----------
538- measures_locations : list of N (k_i,d) numpy.ndarray
538+ measures_locations : list of N (k_i,d) array-like
539539 The discrete support of a measure supported on :math:`k_i` locations of a `d`-dimensional space
540540 (:math:`k_i` can be different for each element of the list)
541- measures_weights : list of N (k_i,) numpy.ndarray
541+ measures_weights : list of N (k_i,) array-like
542542 Numpy arrays where each numpy array has :math:`k_i` non-negatives values summing to one
543543 representing the weights of each discrete input measure
544544
545- X_init : (k,d) np.ndarray
545+ X_init : (k,d) array-like
546546 Initialization of the support locations (on `k` atoms) of the barycenter
547- b : (k,) np.ndarray
547+ b : (k,) array-like
548548 Initialization of the weights of the barycenter (non-negatives, sum to 1)
549- weights : (N,) np.ndarray
549+ weights : (N,) array-like
550550 Initialization of the coefficients of the barycenter (non-negatives, sum to 1)
551551
552552 numItermax : int, optional
@@ -564,7 +564,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
564564
565565 Returns
566566 -------
567- X : (k,d) np.ndarray
567+ X : (k,d) array-like
568568 Support locations (on k atoms) of the barycenter
569569
570570
@@ -577,15 +577,17 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
577577
578578 """
579579
580+ nx = get_backend (* measures_locations ,* measures_weights ,X_init )
581+
580582 iter_count = 0
581583
582584 N = len (measures_locations )
583585 k = X_init .shape [0 ]
584586 d = X_init .shape [1 ]
585587 if b is None :
586- b = np .ones ((k ,)) / k
588+ b = nx .ones ((k ,), type_as = X_init ) / k
587589 if weights is None :
588- weights = np .ones ((N ,)) / N
590+ weights = nx .ones ((N ,), type_as = X_init ) / N
589591
590592 X = X_init
591593
@@ -596,15 +598,15 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
596598
597599 while (displacement_square_norm > stopThr and iter_count < numItermax ):
598600
599- T_sum = np .zeros ((k , d ))
601+ T_sum = nx .zeros ((k , d ),type_as = X_init )
602+
600603
601- for (measure_locations_i , measure_weights_i , weight_i ) in zip (measures_locations , measures_weights ,
602- weights .tolist ()):
604+ for (measure_locations_i , measure_weights_i , weight_i ) in zip (measures_locations , measures_weights , weights ):
603605 M_i = dist (X , measure_locations_i )
604606 T_i = emd (b , measure_weights_i , M_i , numThreads = numThreads )
605- T_sum = T_sum + weight_i * np . reshape ( 1. / b , ( - 1 , 1 )) * np . matmul (T_i , measure_locations_i )
607+ T_sum = T_sum + weight_i * 1. / b [:, None ] * nx . dot (T_i , measure_locations_i )
606608
607- displacement_square_norm = np .sum (np . square (T_sum - X ))
609+ displacement_square_norm = nx .sum ((T_sum - X )** 2 )
608610 if log :
609611 displacement_square_norms .append (displacement_square_norm )
610612
0 commit comments