1- import sys
2-
3- sys .path .append ("." ) # Adds the module to path
1+ # Use this only when running the test locally.
2+ # import sys
3+ # sys.path.append(".") # Adds the module to path
44
55import unittest
66
7- from deeptrack import scatterers
8-
97import numpy as np
8+
109from deeptrack .optics import Fluorescence , Brightfield
11- from deeptrack . image import Image
10+ from deeptrack import scatterers
1211
12+ from deeptrack .backend import TORCH_AVAILABLE , xp
13+ from deeptrack .tests import BackendTestBase
14+
15+ if TORCH_AVAILABLE :
16+ import torch
17+
18+
19+ class TestScatterers_NumPy (BackendTestBase ):
20+ BACKEND = "numpy"
21+
22+ @property
23+ def array_type (self ):
24+ if self .BACKEND == "numpy" :
25+ return np .ndarray
26+ elif self .BACKEND == "torch" :
27+ return torch .Tensor
28+ else :
29+ raise ValueError (f"Unsupported backend: { self .BACKEND } " )
1330
14- class TestScatterers (unittest .TestCase ):
1531 def test_PointParticle (self ):
1632 optics = Fluorescence (
1733 NA = 0.7 ,
@@ -27,7 +43,7 @@ def test_PointParticle(self):
2743 )
2844 imaged_scatterer = optics (scatterer )
2945 output_image = imaged_scatterer .resolve ()
30- self .assertIsInstance (output_image , np . ndarray )
46+ self .assertIsInstance (output_image , self . array_type )
3147 self .assertEqual (output_image .shape , (64 , 64 , 1 ))
3248
3349 def test_Ellipse (self ):
@@ -48,7 +64,7 @@ def test_Ellipse(self):
4864 )
4965 imaged_scatterer = optics (scatterer )
5066 output_image = imaged_scatterer .resolve ()
51- self .assertIsInstance (output_image , np . ndarray )
67+ self .assertIsInstance (output_image , self . array_type )
5268 self .assertEqual (output_image .shape , (64 , 64 , 1 ))
5369
5470 def test_EllipseUpscale (self ):
@@ -146,7 +162,7 @@ def test_Sphere(self):
146162 )
147163 imaged_scatterer = optics (scatterer )
148164 output_image = imaged_scatterer .resolve ()
149- self .assertIsInstance (output_image , np . ndarray )
165+ self .assertIsInstance (output_image , self . array_type )
150166 self .assertEqual (output_image .shape , (64 , 64 , 1 ))
151167
152168 def test_SphereUpscale (self ):
@@ -188,7 +204,7 @@ def test_Ellipsoid(self):
188204 )
189205 imaged_scatterer = optics (scatterer )
190206 output_image = imaged_scatterer .resolve ()
191- self .assertIsInstance (output_image , np . ndarray )
207+ self .assertIsInstance (output_image , self . array_type )
192208 self .assertEqual (output_image .shape , (64 , 64 , 1 ))
193209
194210 def test_EllipsoidUpscale (self ):
@@ -343,6 +359,12 @@ def test_MieStratifiedSphere(self):
343359 imaged_scatterer_1 = optics_1 (scatterer )
344360 imaged_scatterer_1 .update ().resolve ()
345361
362+ # TODO: Extending the test and setting the backend to torch
363+ # @unittest.skipUnless(TORCH_AVAILABLE, "PyTorch is not installed.")
364+ # class TestScatterers_PyTorch(TestScatterers_NumPy):
365+ # BACKEND = "torch"
366+ # pass
367+
346368
347369if __name__ == "__main__" :
348370 unittest .main ()
0 commit comments