@@ -435,7 +435,7 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None,
435435##############################################################################
436436
437437
438- def batch_grad_dual (M , reg , a , b , alpha , beta , batch_size , batch_alpha ,
438+ def batch_grad_dual (a , b , M , reg , alpha , beta , batch_size , batch_alpha ,
439439 batch_beta ):
440440 '''
441441 Computes the partial gradient of F_\W_varepsilon
@@ -528,7 +528,7 @@ def batch_grad_dual(M, reg, a, b, alpha, beta, batch_size, batch_alpha,
528528 return grad_alpha , grad_beta
529529
530530
531- def sgd_entropic_regularization (M , reg , a , b , batch_size , numItermax , lr ):
531+ def sgd_entropic_regularization (a , b , M , reg , batch_size , numItermax , lr ):
532532 '''
533533 Compute the sgd algorithm to solve the regularized discrete measures
534534 optimal transport dual problem
@@ -612,7 +612,7 @@ def sgd_entropic_regularization(M, reg, a, b, batch_size, numItermax, lr):
612612 k = np .sqrt (cur_iter / 100 + 1 )
613613 batch_alpha = np .random .choice (n_source , batch_size , replace = False )
614614 batch_beta = np .random .choice (n_target , batch_size , replace = False )
615- update_alpha , update_beta = batch_grad_dual (M , reg , a , b , cur_alpha ,
615+ update_alpha , update_beta = batch_grad_dual (a , b , M , reg , cur_alpha ,
616616 cur_beta , batch_size ,
617617 batch_alpha , batch_beta )
618618 cur_alpha += (lr / k ) * update_alpha
@@ -698,7 +698,7 @@ def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1,
698698 arXiv preprint arxiv:1711.02283.
699699 '''
700700
701- opt_alpha , opt_beta = sgd_entropic_regularization (M , reg , a , b , batch_size ,
701+ opt_alpha , opt_beta = sgd_entropic_regularization (a , b , M , reg , batch_size ,
702702 numItermax , lr )
703703 pi = (np .exp ((opt_alpha [:, None ] + opt_beta [None , :] - M [:, :]) / reg ) *
704704 a [:, None ] * b [None , :])
0 commit comments