8787# License: MIT License
8888
8989import numpy as np
90+ import os
9091import scipy
9192import scipy .linalg
92- import scipy .special as special
9393from scipy .sparse import issparse , coo_matrix , csr_matrix
94- import warnings
94+ import scipy . special as special
9595import time
96+ import warnings
97+
98+
99+ DISABLE_TORCH_KEY = 'POT_BACKEND_DISABLE_PYTORCH'
100+ DISABLE_JAX_KEY = 'POT_BACKEND_DISABLE_JAX'
101+ DISABLE_CUPY_KEY = 'POT_BACKEND_DISABLE_CUPY'
102+ DISABLE_TF_KEY = 'POT_BACKEND_DISABLE_TENSORFLOW'
103+
96104
97- try :
98- import torch
99- torch_type = torch .Tensor
100- except ImportError :
105+ if not os .environ .get (DISABLE_TORCH_KEY , False ):
106+ try :
107+ import torch
108+ torch_type = torch .Tensor
109+ except ImportError :
110+ torch = False
111+ torch_type = float
112+ else :
101113 torch = False
102114 torch_type = float
103115
104- try :
105- import jax
106- import jax .numpy as jnp
107- import jax .scipy .special as jspecial
108- from jax .lib import xla_bridge
109- jax_type = jax .numpy .ndarray
110- except ImportError :
116+ if not os .environ .get (DISABLE_JAX_KEY , False ):
117+ try :
118+ import jax
119+ import jax .numpy as jnp
120+ import jax .scipy .special as jspecial
121+ from jax .lib import xla_bridge
122+ jax_type = jax .numpy .ndarray
123+ except ImportError :
124+ jax = False
125+ jax_type = float
126+ else :
111127 jax = False
112128 jax_type = float
113129
114- try :
115- import cupy as cp
116- import cupyx
117- cp_type = cp .ndarray
118- except ImportError :
130+ if not os .environ .get (DISABLE_CUPY_KEY , False ):
131+ try :
132+ import cupy as cp
133+ import cupyx
134+ cp_type = cp .ndarray
135+ except ImportError :
136+ cp = False
137+ cp_type = float
138+ else :
119139 cp = False
120140 cp_type = float
121141
122- try :
123- import tensorflow as tf
124- import tensorflow .experimental .numpy as tnp
125- tf_type = tf .Tensor
126- except ImportError :
142+ if not os .environ .get (DISABLE_TF_KEY , False ):
143+ try :
144+ import tensorflow as tf
145+ import tensorflow .experimental .numpy as tnp
146+ tf_type = tf .Tensor
147+ except ImportError :
148+ tf = False
149+ tf_type = float
150+ else :
127151 tf = False
128152 tf_type = float
129153
132156
133157
134158# Mapping between argument types and the existing backend
135- _BACKENDS = []
159+ _BACKEND_IMPLEMENTATIONS = []
160+ _BACKENDS = {}
136161
137162
138- def register_backend ( backend ):
139- _BACKENDS .append (backend )
163+ def _register_backend_implementation ( backend_impl ):
164+ _BACKEND_IMPLEMENTATIONS .append (backend_impl )
140165
141166
142- def get_backend_list ():
143- """Returns the list of available backends"""
144- return _BACKENDS
167+ def _get_backend_instance (backend_impl ):
168+ if backend_impl .__name__ not in _BACKENDS :
169+ _BACKENDS [backend_impl .__name__ ] = backend_impl ()
170+ return _BACKENDS [backend_impl .__name__ ]
145171
146172
147- def _check_args_backend (backend , args ):
148- is_instance = set (isinstance (a , backend .__type__ ) for a in args )
173+ def _check_args_backend (backend_impl , args ):
174+ is_instance = set (isinstance (arg , backend_impl .__type__ ) for arg in args )
149175 # check that all arguments matched or not the type
150176 if len (is_instance ) == 1 :
151177 return is_instance .pop ()
152178
153- # Oterwise return an error
154- raise ValueError (str_type_error .format ([type (a ) for a in args ]))
179+ # Otherwise return an error
180+ raise ValueError (str_type_error .format ([type (arg ) for arg in args ]))
181+
182+
183+ def get_backend_list ():
184+ """Returns instances of all available backends.
185+
186+ Note that the function forces all detected implementations
187+ to be instantiated even if specific backend was not use before.
188+ Be careful as instantiation of the backend might lead to side effects,
189+ like GPU memory pre-allocation. See the documentation for more details.
190+ If you only need to know which implementations are available,
191+ use `:py:func:`ot.backend.get_available_backend_implementations`,
192+ which does not force instance of the backend object to be created.
193+ """
194+ return [
195+ _get_backend_instance (backend_impl )
196+ for backend_impl
197+ in get_available_backend_implementations ()
198+ ]
199+
200+
201+ def get_available_backend_implementations ():
202+ """Returns the list of available backend implementations."""
203+ return _BACKEND_IMPLEMENTATIONS
155204
156205
157206def get_backend (* args ):
@@ -167,9 +216,9 @@ def get_backend(*args):
167216 if not len (args ) > 0 :
168217 raise ValueError (" The function takes at least one (non-None) parameter" )
169218
170- for backend in _BACKENDS :
171- if _check_args_backend (backend , args ):
172- return backend
219+ for backend_impl in _BACKEND_IMPLEMENTATIONS :
220+ if _check_args_backend (backend_impl , args ):
221+ return _get_backend_instance ( backend_impl )
173222
174223 raise ValueError ("Unknown type of non implemented backend." )
175224
@@ -1341,7 +1390,7 @@ def matmul(self, a, b):
13411390 return np .matmul (a , b )
13421391
13431392
1344- register_backend (NumpyBackend () )
1393+ _register_backend_implementation (NumpyBackend )
13451394
13461395
13471396class JaxBackend (Backend ):
@@ -1710,7 +1759,7 @@ def matmul(self, a, b):
17101759
17111760if jax :
17121761 # Only register jax backend if it is installed
1713- register_backend (JaxBackend () )
1762+ _register_backend_implementation (JaxBackend )
17141763
17151764
17161765class TorchBackend (Backend ):
@@ -2193,7 +2242,7 @@ def matmul(self, a, b):
21932242
21942243if torch :
21952244 # Only register torch backend if it is installed
2196- register_backend (TorchBackend () )
2245+ _register_backend_implementation (TorchBackend )
21972246
21982247
21992248class CupyBackend (Backend ): # pragma: no cover
@@ -2586,7 +2635,7 @@ def matmul(self, a, b):
25862635
25872636if cp :
25882637 # Only register cp backend if it is installed
2589- register_backend (CupyBackend () )
2638+ _register_backend_implementation (CupyBackend )
25902639
25912640
25922641class TensorflowBackend (Backend ):
@@ -3006,4 +3055,4 @@ def matmul(self, a, b):
30063055
30073056if tf :
30083057 # Only register tensorflow backend if it is installed
3009- register_backend (TensorflowBackend () )
3058+ _register_backend_implementation (TensorflowBackend )
0 commit comments