Multi-lib backend for POT

The goal is to write backend-agnostic code. Whether you're using Numpy, PyTorch,
Jax, or Cupy, POT code should work nonetheless.
To achieve that, POT provides backend classes which implement functions in their respective backend
imitating Numpy API. As a convention, we use nx instead of np to refer to the backend.

4444 jax = False
4545 jax_type = float
4646
47+ try :
48+ import cupy as cp
49+ import cupyx
50+ cp_type = cp .ndarray
51+ except ImportError :
52+ cp = False
53+ cp_type = float
54+
4755str_type_error = "All array should be from the same type/backend. Current types are : {}"
4856
4957
@@ -57,6 +65,9 @@ def get_backend_list():
5765 if jax :
5866 lst .append (JaxBackend ())
5967
68+ if cp :
69+ lst .append (CupyBackend ())
70+
6071 return lst
6172
6273
@@ -78,6 +89,8 @@ def get_backend(*args):
7889 return TorchBackend ()
7990 elif isinstance (args [0 ], jax_type ):
8091 return JaxBackend ()
92+ elif isinstance (args [0 ], cp_type ):
93+ return CupyBackend ()
8194 else :
8295 raise ValueError ("Unknown type of non implemented backend." )
8396
@@ -94,7 +107,8 @@ def to_numpy(*args):
94107class Backend ():
95108 """
96109 Backend abstract class.
97- Implementations: :py:class:`JaxBackend`, :py:class:`NumpyBackend`, :py:class:`TorchBackend`
110+ Implementations: :py:class:`JaxBackend`, :py:class:`NumpyBackend`, :py:class:`TorchBackend`,
111+ :py:class:`CupyBackend`
98112
99113 - The `__name__` class attribute refers to the name of the backend.
100114 - The `__type__` class attribute refers to the data structure used by the backend.
@@ -1500,3 +1514,287 @@ def assert_same_dtype_device(self, a, b):
15001514
15011515 assert a_dtype == b_dtype , "Dtype discrepancy"
15021516 assert a_device == b_device , f"Device discrepancy. First input is on { str (a_device )} , whereas second input is on { str (b_device )} "
1517+
1518+
1519+ class CupyBackend (Backend ): # pragma: no cover
1520+ """
1521+ CuPy implementation of the backend
1522+
1523+ - `__name__` is "cupy"
1524+ - `__type__` is cp.ndarray
1525+ """
1526+
1527+ __name__ = 'cupy'
1528+ __type__ = cp_type
1529+ __type_list__ = None
1530+
1531+ rng_ = None
1532+
1533+ def __init__ (self ):
1534+ self .rng_ = cp .random .RandomState ()
1535+
1536+ self .__type_list__ = [
1537+ cp .array (1 , dtype = cp .float32 ),
1538+ cp .array (1 , dtype = cp .float64 )
1539+ ]
1540+
1541+ def to_numpy (self , a ):
1542+ return cp .asnumpy (a )
1543+
1544+ def from_numpy (self , a , type_as = None ):
1545+ if type_as is None :
1546+ return cp .asarray (a )
1547+ else :
1548+ with cp .cuda .Device (type_as .device ):
1549+ return cp .asarray (a , dtype = type_as .dtype )
1550+
1551+ def set_gradients (self , val , inputs , grads ):
1552+ # No gradients for cupy
1553+ return val
1554+
1555+ def zeros (self , shape , type_as = None ):
1556+ if isinstance (shape , (list , tuple )):
1557+ shape = tuple (int (i ) for i in shape )
1558+ if type_as is None :
1559+ return cp .zeros (shape )
1560+ else :
1561+ with cp .cuda .Device (type_as .device ):
1562+ return cp .zeros (shape , dtype = type_as .dtype )
1563+
1564+ def ones (self , shape , type_as = None ):
1565+ if isinstance (shape , (list , tuple )):
1566+ shape = tuple (int (i ) for i in shape )
1567+ if type_as is None :
1568+ return cp .ones (shape )
1569+ else :
1570+ with cp .cuda .Device (type_as .device ):
1571+ return cp .ones (shape , dtype = type_as .dtype )
1572+
1573+ def arange (self , stop , start = 0 , step = 1 , type_as = None ):
1574+ return cp .arange (start , stop , step )
1575+
1576+ def full (self , shape , fill_value , type_as = None ):
1577+ if isinstance (shape , (list , tuple )):
1578+ shape = tuple (int (i ) for i in shape )
1579+ if type_as is None :
1580+ return cp .full (shape , fill_value )
1581+ else :
1582+ with cp .cuda .Device (type_as .device ):
1583+ return cp .full (shape , fill_value , dtype = type_as .dtype )
1584+
1585+ def eye (self , N , M = None , type_as = None ):
1586+ if type_as is None :
1587+ return cp .eye (N , M )
1588+ else :
1589+ with cp .cuda .Device (type_as .device ):
1590+ return cp .eye (N , M , dtype = type_as .dtype )
1591+
1592+ def sum (self , a , axis = None , keepdims = False ):
1593+ return cp .sum (a , axis , keepdims = keepdims )
1594+
1595+ def cumsum (self , a , axis = None ):
1596+ return cp .cumsum (a , axis )
1597+
1598+ def max (self , a , axis = None , keepdims = False ):
1599+ return cp .max (a , axis , keepdims = keepdims )
1600+
1601+ def min (self , a , axis = None , keepdims = False ):
1602+ return cp .min (a , axis , keepdims = keepdims )
1603+
1604+ def maximum (self , a , b ):
1605+ return cp .maximum (a , b )
1606+
1607+ def minimum (self , a , b ):
1608+ return cp .minimum (a , b )
1609+
1610+ def abs (self , a ):
1611+ return cp .abs (a )
1612+
1613+ def exp (self , a ):
1614+ return cp .exp (a )
1615+
1616+ def log (self , a ):
1617+ return cp .log (a )
1618+
1619+ def sqrt (self , a ):
1620+ return cp .sqrt (a )
1621+
1622+ def power (self , a , exponents ):
1623+ return cp .power (a , exponents )
1624+
1625+ def dot (self , a , b ):
1626+ return cp .dot (a , b )
1627+
1628+ def norm (self , a ):
1629+ return cp .sqrt (cp .sum (cp .square (a )))
1630+
1631+ def any (self , a ):
1632+ return cp .any (a )
1633+
1634+ def isnan (self , a ):
1635+ return cp .isnan (a )
1636+
1637+ def isinf (self , a ):
1638+ return cp .isinf (a )
1639+
1640+ def einsum (self , subscripts , * operands ):
1641+ return cp .einsum (subscripts , * operands )
1642+
1643+ def sort (self , a , axis = - 1 ):
1644+ return cp .sort (a , axis )
1645+
1646+ def argsort (self , a , axis = - 1 ):
1647+ return cp .argsort (a , axis )
1648+
1649+ def searchsorted (self , a , v , side = 'left' ):
1650+ if a .ndim == 1 :
1651+ return cp .searchsorted (a , v , side )
1652+ else :
1653+ # this is a not very efficient way to make numpy
1654+ # searchsorted work on 2d arrays
1655+ ret = cp .empty (v .shape , dtype = int )
1656+ for i in range (a .shape [0 ]):
1657+ ret [i , :] = cp .searchsorted (a [i , :], v [i , :], side )
1658+ return ret
1659+
1660+ def flip (self , a , axis = None ):
1661+ return cp .flip (a , axis )
1662+
1663+ def outer (self , a , b ):
1664+ return cp .outer (a , b )
1665+
1666+ def clip (self , a , a_min , a_max ):
1667+ return cp .clip (a , a_min , a_max )
1668+
1669+ def repeat (self , a , repeats , axis = None ):
1670+ return cp .repeat (a , repeats , axis )
1671+
1672+ def take_along_axis (self , arr , indices , axis ):
1673+ return cp .take_along_axis (arr , indices , axis )
1674+
1675+ def concatenate (self , arrays , axis = 0 ):
1676+ return cp .concatenate (arrays , axis )
1677+
1678+ def zero_pad (self , a , pad_width ):
1679+ return cp .pad (a , pad_width )
1680+
1681+ def argmax (self , a , axis = None ):
1682+ return cp .argmax (a , axis = axis )
1683+
1684+ def mean (self , a , axis = None ):
1685+ return cp .mean (a , axis = axis )
1686+
1687+ def std (self , a , axis = None ):
1688+ return cp .std (a , axis = axis )
1689+
1690+ def linspace (self , start , stop , num ):
1691+ return cp .linspace (start , stop , num )
1692+
1693+ def meshgrid (self , a , b ):
1694+ return cp .meshgrid (a , b )
1695+
1696+ def diag (self , a , k = 0 ):
1697+ return cp .diag (a , k )
1698+
1699+ def unique (self , a ):
1700+ return cp .unique (a )
1701+
1702+ def logsumexp (self , a , axis = None ):
1703+ # Taken from
1704+ # https://github.com/scipy/scipy/blob/v1.7.1/scipy/special/_logsumexp.py#L7-L127
1705+ a_max = cp .amax (a , axis = axis , keepdims = True )
1706+
1707+ if a_max .ndim > 0 :
1708+ a_max [~ cp .isfinite (a_max )] = 0
1709+ elif not cp .isfinite (a_max ):
1710+ a_max = 0
1711+
1712+ tmp = cp .exp (a - a_max )
1713+ s = cp .sum (tmp , axis = axis )
1714+ out = cp .log (s )
1715+ a_max = cp .squeeze (a_max , axis = axis )
1716+ out += a_max
1717+ return out
1718+
1719+ def stack (self , arrays , axis = 0 ):
1720+ return cp .stack (arrays , axis )
1721+
1722+ def reshape (self , a , shape ):
1723+ return cp .reshape (a , shape )
1724+
1725+ def seed (self , seed = None ):
1726+ if seed is not None :
1727+ self .rng_ .seed (seed )
1728+
1729+ def rand (self , * size , type_as = None ):
1730+ if type_as is None :
1731+ return self .rng_ .rand (* size )
1732+ else :
1733+ with cp .cuda .Device (type_as .device ):
1734+ return self .rng_ .rand (* size , dtype = type_as .dtype )
1735+
1736+ def randn (self , * size , type_as = None ):
1737+ if type_as is None :
1738+ return self .rng_ .randn (* size )
1739+ else :
1740+ with cp .cuda .Device (type_as .device ):
1741+ return self .rng_ .randn (* size , dtype = type_as .dtype )
1742+
1743+ def coo_matrix (self , data , rows , cols , shape = None , type_as = None ):
1744+ data = self .from_numpy (data )
1745+ rows = self .from_numpy (rows )
1746+ cols = self .from_numpy (cols )
1747+ if type_as is None :
1748+ return cupyx .scipy .sparse .coo_matrix (
1749+ (data , (rows , cols )), shape = shape
1750+ )
1751+ else :
1752+ with cp .cuda .Device (type_as .device ):
1753+ return cupyx .scipy .sparse .coo_matrix (
1754+ (data , (rows , cols )), shape = shape , dtype = type_as .dtype
1755+ )
1756+
1757+ def issparse (self , a ):
1758+ return cupyx .scipy .sparse .issparse (a )
1759+
1760+ def tocsr (self , a ):
1761+ if self .issparse (a ):
1762+ return a .tocsr ()
1763+ else :
1764+ return cupyx .scipy .sparse .csr_matrix (a )
1765+
1766+ def eliminate_zeros (self , a , threshold = 0. ):
1767+ if threshold > 0 :
1768+ if self .issparse (a ):
1769+ a .data [self .abs (a .data ) <= threshold ] = 0
1770+ else :
1771+ a [self .abs (a ) <= threshold ] = 0
1772+ if self .issparse (a ):
1773+ a .eliminate_zeros ()
1774+ return a
1775+
1776+ def todense (self , a ):
1777+ if self .issparse (a ):
1778+ return a .toarray ()
1779+ else :
1780+ return a
1781+
1782+ def where (self , condition , x , y ):
1783+ return cp .where (condition , x , y )
1784+
1785+ def copy (self , a ):
1786+ return a .copy ()
1787+
1788+ def allclose (self , a , b , rtol = 1e-05 , atol = 1e-08 , equal_nan = False ):
1789+ return cp .allclose (a , b , rtol = rtol , atol = atol , equal_nan = equal_nan )
1790+
1791+ def dtype_device (self , a ):
1792+ return a .dtype , a .device
1793+
1794+ def assert_same_dtype_device (self , a , b ):
1795+ a_dtype , a_device = self .dtype_device (a )
1796+ b_dtype , b_device = self .dtype_device (b )
1797+
1798+ # cupy has implicit type conversion so
1799+ # we automatically validate the test for type
1800+ assert a_device == b_device , f"Device discrepancy. First input is on { str (a_device )} , whereas second input is on { str (b_device )} "