@@ -435,18 +435,23 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None,
435435##############################################################################
436436
437437
438- def batch_grad_dual_alpha (M , reg , alpha , beta , batch_size , batch_alpha ,
439- batch_beta ):
438+ def batch_grad_dual (M , reg , a , b , alpha , beta , batch_size , batch_alpha ,
439+ batch_beta ):
440440 '''
441441 Computes the partial gradient of F_\W_varepsilon
442442
443443 Compute the partial gradient of the dual problem:
444444
445445 ..math:
446446 \f orall i in batch_alpha,
447- grad_alpha_i = 1 * batch_size -
448- sum_{j in batch_beta} exp((alpha_i + beta_j - M_{i,j})/reg)
449-
447+ grad_alpha_i = alpha_i * batch_size/len(beta) -
448+ sum_{j in batch_beta} exp((alpha_i + beta_j - M_{i,j})/reg)
449+ * a_i * b_j
450+
451+ \f orall j in batch_alpha,
452+ grad_beta_j = beta_j * batch_size/len(alpha) -
453+ sum_{j in batch_alpha} exp((alpha_i + beta_j - M_{i,j})/reg)
454+ * a_i * b_j
450455 where :
451456 - M is the (ns,nt) metric cost matrix
452457 - alpha, beta are dual variables in R^ixR^J
@@ -478,7 +483,7 @@ def batch_grad_dual_alpha(M, reg, alpha, beta, batch_size, batch_alpha,
478483 -------
479484
480485 grad : np.ndarray(ns,)
481- partial grad F in alpha
486+ partial grad F
482487
483488 Examples
484489 --------
@@ -510,100 +515,20 @@ def batch_grad_dual_alpha(M, reg, alpha, beta, batch_size, batch_alpha,
510515 arXiv preprint arxiv:1711.02283.
511516 '''
512517
513- grad_alpha = np .zeros (batch_size )
514- grad_alpha [:] = batch_size
515- for j in batch_beta :
516- grad_alpha -= np .exp ((alpha [batch_alpha ] + beta [j ] -
517- M [batch_alpha , j ]) / reg )
518- return grad_alpha
519-
520-
521- def batch_grad_dual_beta (M , reg , alpha , beta , batch_size , batch_alpha ,
522- batch_beta ):
523- '''
524- Computes the partial gradient of F_\W_varepsilon
525-
526- Compute the partial gradient of the dual problem:
527-
528- ..math:
529- \f orall j in batch_beta,
530- grad_beta_j = 1 * batch_size -
531- sum_{i in batch_alpha} exp((alpha_i + beta_j - M_{i,j})/reg)
532-
533- where :
534- - M is the (ns,nt) metric cost matrix
535- - alpha, beta are dual variables in R^ixR^J
536- - reg is the regularization term
537- - batch_alpha and batch_beta are list of index
538-
539- The algorithm used for solving the dual problem is the SGD algorithm
540- as proposed in [19]_ [alg.1]
541-
542- Parameters
543- ----------
544-
545- M : np.ndarray(ns, nt),
546- cost matrix
547- reg : float number,
548- Regularization term > 0
549- alpha : np.ndarray(ns,)
550- dual variable
551- beta : np.ndarray(nt,)
552- dual variable
553- batch_size : int number
554- size of the batch
555- batch_alpha : np.ndarray(bs,)
556- batch of index of alpha
557- batch_beta : np.ndarray(bs,)
558- batch of index of beta
559-
560- Returns
561- -------
562-
563- grad : np.ndarray(ns,)
564- partial grad F in beta
565-
566- Examples
567- --------
568-
569- >>> n_source = 7
570- >>> n_target = 4
571- >>> reg = 1
572- >>> numItermax = 20000
573- >>> lr = 0.1
574- >>> batch_size = 3
575- >>> log = True
576- >>> a = ot.utils.unif(n_source)
577- >>> b = ot.utils.unif(n_target)
578- >>> rng = np.random.RandomState(0)
579- >>> X_source = rng.randn(n_source, 2)
580- >>> Y_target = rng.randn(n_target, 2)
581- >>> M = ot.dist(X_source, Y_target)
582- >>> sgd_dual_pi, log = stochastic.solve_dual_entropic(a, b, M, reg,
583- batch_size,
584- numItermax, lr, log)
585- >>> print(log['alpha'], log['beta'])
586- >>> print(sgd_dual_pi)
587-
588- References
589- ----------
590-
591- [Seguy et al., 2018] :
592- International Conference on Learning Representation (2018),
593- arXiv preprint arxiv:1711.02283.
518+ G = - (np .exp ((alpha [batch_alpha , None ] + beta [None , batch_beta ] -
519+ M [batch_alpha , :][:, batch_beta ]) / reg ) * a [batch_alpha , None ] *
520+ b [None , batch_beta ])
521+ grad_beta = np .zeros (np .shape (M )[1 ])
522+ grad_alpha = np .zeros (np .shape (M )[0 ])
523+ grad_beta [batch_beta ] = (b [batch_beta ] * len (batch_alpha ) / np .shape (M )[0 ] +
524+ G .sum (0 ))
525+ grad_alpha [batch_alpha ] = (a [batch_alpha ] * len (batch_beta ) /
526+ np .shape (M )[1 ] + G .sum (1 ))
594527
595- '''
596-
597- grad_beta = np .zeros (batch_size )
598- grad_beta [:] = batch_size
599- for i in batch_alpha :
600- grad_beta -= np .exp ((alpha [i ] +
601- beta [batch_beta ] - M [i , batch_beta ]) / reg )
602- return grad_beta
528+ return grad_alpha , grad_beta
603529
604530
605- def sgd_entropic_regularization (M , reg , batch_size , numItermax , lr ,
606- alternate = True ):
531+ def sgd_entropic_regularization (M , reg , a , b , batch_size , numItermax , lr ):
607532 '''
608533 Compute the sgd algorithm to solve the regularized discrete measures
609534 optimal transport dual problem
@@ -628,6 +553,10 @@ def sgd_entropic_regularization(M, reg, batch_size, numItermax, lr,
628553 cost matrix
629554 reg : float number,
630555 Regularization term > 0
556+ alpha : np.ndarray(ns,)
557+ dual variable
558+ beta : np.ndarray(nt,)
559+ dual variable
631560 batch_size : int number
632561 size of the batch
633562 numItermax : int number
@@ -677,35 +606,17 @@ def sgd_entropic_regularization(M, reg, batch_size, numItermax, lr,
677606
678607 n_source = np .shape (M )[0 ]
679608 n_target = np .shape (M )[1 ]
680- cur_alpha = np .random .randn (n_source )
681- cur_beta = np .random .randn (n_target )
682- if alternate :
683- for cur_iter in range (numItermax ):
684- k = np .sqrt (cur_iter + 1 )
685- batch_alpha = np .random .choice (n_source , batch_size , replace = False )
686- batch_beta = np .random .choice (n_target , batch_size , replace = False )
687- grad_F_alpha = batch_grad_dual_alpha (M , reg , cur_alpha , cur_beta ,
688- batch_size , batch_alpha ,
689- batch_beta )
690- cur_alpha [batch_alpha ] += (lr / k ) * grad_F_alpha
691- grad_F_beta = batch_grad_dual_beta (M , reg , cur_alpha , cur_beta ,
692- batch_size , batch_alpha ,
693- batch_beta )
694- cur_beta [batch_beta ] += (lr / k ) * grad_F_beta
695-
696- else :
697- for cur_iter in range (numItermax ):
698- k = np .sqrt (cur_iter + 1 )
699- batch_alpha = np .random .choice (n_source , batch_size , replace = False )
700- batch_beta = np .random .choice (n_target , batch_size , replace = False )
701- grad_F_alpha = batch_grad_dual_alpha (M , reg , cur_alpha , cur_beta ,
702- batch_size , batch_alpha ,
703- batch_beta )
704- grad_F_beta = batch_grad_dual_beta (M , reg , cur_alpha , cur_beta ,
705- batch_size , batch_alpha ,
706- batch_beta )
707- cur_alpha [batch_alpha ] += (lr / k ) * grad_F_alpha
708- cur_beta [batch_beta ] += (lr / k ) * grad_F_beta
609+ cur_alpha = np .zeros (n_source )
610+ cur_beta = np .zeros (n_target )
611+ for cur_iter in range (numItermax ):
612+ k = np .sqrt (cur_iter / 100 + 1 )
613+ batch_alpha = np .random .choice (n_source , batch_size , replace = False )
614+ batch_beta = np .random .choice (n_target , batch_size , replace = False )
615+ update_alpha , update_beta = batch_grad_dual (M , reg , a , b , cur_alpha ,
616+ cur_beta , batch_size ,
617+ batch_alpha , batch_beta )
618+ cur_alpha += (lr / k ) * update_alpha
619+ cur_beta += (lr / k ) * update_beta
709620
710621 return cur_alpha , cur_beta
711622
@@ -787,7 +698,7 @@ def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1,
787698 arXiv preprint arxiv:1711.02283.
788699 '''
789700
790- opt_alpha , opt_beta = sgd_entropic_regularization (M , reg , batch_size ,
701+ opt_alpha , opt_beta = sgd_entropic_regularization (M , reg , a , b , batch_size ,
791702 numItermax , lr )
792703 pi = (np .exp ((opt_alpha [:, None ] + opt_beta [None , :] - M [:, :]) / reg ) *
793704 a [:, None ] * b [None , :])
0 commit comments