@@ -944,16 +944,17 @@ def eigh(self, a):
944944 """
945945 raise NotImplementedError ()
946946
947- def kl_div (self , p , q , eps = 1e-16 ):
947+ def kl_div (self , p , q , mass = False , eps = 1e-16 ):
948948 r"""
949- Computes the Kullback-Leibler divergence.
949+ Computes the (Generalized) Kullback-Leibler divergence.
950950
951951 This function follows the api from :any:`scipy.stats.entropy`.
952952
953953 Parameter eps is used to avoid numerical errors and is added in the log.
954954
955955 .. math::
956- KL(p,q) = \sum_i p(i) \log (\frac{p(i)}{q(i)}+\epsilon)
956+ KL(p,q) = \langle \mathbf{p}, \log(\mathbf{p} / \mathbf{q} + \epsilon) \rangle
957+ + \mathbb{1}_{\text{mass=True}} \langle \mathbf{q} - \mathbf{p}, \mathbf{1} \rangle
957958
958959 See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.entropy.html
959960 """
@@ -1352,8 +1353,11 @@ def sqrtm(self, a):
13521353 def eigh (self , a ):
13531354 return np .linalg .eigh (a )
13541355
1355- def kl_div (self , p , q , eps = 1e-16 ):
1356- return np .sum (p * np .log (p / q + eps ))
1356+ def kl_div (self , p , q , mass = False , eps = 1e-16 ):
1357+ value = np .sum (p * np .log (p / q + eps ))
1358+ if mass :
1359+ value = value + np .sum (q - p )
1360+ return value
13571361
13581362 def isfinite (self , a ):
13591363 return np .isfinite (a )
@@ -1751,8 +1755,11 @@ def sqrtm(self, a):
17511755 def eigh (self , a ):
17521756 return jnp .linalg .eigh (a )
17531757
1754- def kl_div (self , p , q , eps = 1e-16 ):
1755- return jnp .sum (p * jnp .log (p / q + eps ))
1758+ def kl_div (self , p , q , mass = False , eps = 1e-16 ):
1759+ value = jnp .sum (p * jnp .log (p / q + eps ))
1760+ if mass :
1761+ value = value + jnp .sum (q - p )
1762+ return value
17561763
17571764 def isfinite (self , a ):
17581765 return jnp .isfinite (a )
@@ -2238,8 +2245,11 @@ def sqrtm(self, a):
22382245 def eigh (self , a ):
22392246 return torch .linalg .eigh (a )
22402247
2241- def kl_div (self , p , q , eps = 1e-16 ):
2242- return torch .sum (p * torch .log (p / q + eps ))
2248+ def kl_div (self , p , q , mass = False , eps = 1e-16 ):
2249+ value = torch .sum (p * torch .log (p / q + eps ))
2250+ if mass :
2251+ value = value + torch .sum (q - p )
2252+ return value
22432253
22442254 def isfinite (self , a ):
22452255 return torch .isfinite (a )
@@ -2639,8 +2649,11 @@ def sqrtm(self, a):
26392649 def eigh (self , a ):
26402650 return cp .linalg .eigh (a )
26412651
2642- def kl_div (self , p , q , eps = 1e-16 ):
2643- return cp .sum (p * cp .log (p / q + eps ))
2652+ def kl_div (self , p , q , mass = False , eps = 1e-16 ):
2653+ value = cp .sum (p * cp .log (p / q + eps ))
2654+ if mass :
2655+ value = value + cp .sum (q - p )
2656+ return value
26442657
26452658 def isfinite (self , a ):
26462659 return cp .isfinite (a )
@@ -3063,8 +3076,11 @@ def sqrtm(self, a):
30633076 def eigh (self , a ):
30643077 return tf .linalg .eigh (a )
30653078
3066- def kl_div (self , p , q , eps = 1e-16 ):
3067- return tnp .sum (p * tnp .log (p / q + eps ))
3079+ def kl_div (self , p , q , mass = False , eps = 1e-16 ):
3080+ value = tnp .sum (p * tnp .log (p / q + eps ))
3081+ if mass :
3082+ value = value + tnp .sum (q - p )
3083+ return value
30683084
30693085 def isfinite (self , a ):
30703086 return tnp .isfinite (a )
0 commit comments