Skip to content

Commit 410ce9c

Browse files
mg/Update test_scatterers (#426)
* Update test_scatterers * update scatterers test
1 parent 2ca2b1f commit 410ce9c

File tree

1 file changed

+33
-11
lines changed

1 file changed

+33
-11
lines changed

deeptrack/tests/test_scatterers.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,33 @@
1-
import sys
2-
3-
sys.path.append(".") # Adds the module to path
1+
# Use this only when running the test locally.
2+
# import sys
3+
# sys.path.append(".") # Adds the module to path
44

55
import unittest
66

7-
from deeptrack import scatterers
8-
97
import numpy as np
8+
109
from deeptrack.optics import Fluorescence, Brightfield
11-
from deeptrack.image import Image
10+
from deeptrack import scatterers
1211

12+
from deeptrack.backend import TORCH_AVAILABLE, xp
13+
from deeptrack.tests import BackendTestBase
14+
15+
if TORCH_AVAILABLE:
16+
import torch
17+
18+
19+
class TestScatterers_NumPy(BackendTestBase):
20+
BACKEND = "numpy"
21+
22+
@property
23+
def array_type(self):
24+
if self.BACKEND == "numpy":
25+
return np.ndarray
26+
elif self.BACKEND == "torch":
27+
return torch.Tensor
28+
else:
29+
raise ValueError(f"Unsupported backend: {self.BACKEND}")
1330

14-
class TestScatterers(unittest.TestCase):
1531
def test_PointParticle(self):
1632
optics = Fluorescence(
1733
NA=0.7,
@@ -27,7 +43,7 @@ def test_PointParticle(self):
2743
)
2844
imaged_scatterer = optics(scatterer)
2945
output_image = imaged_scatterer.resolve()
30-
self.assertIsInstance(output_image, np.ndarray)
46+
self.assertIsInstance(output_image, self.array_type)
3147
self.assertEqual(output_image.shape, (64, 64, 1))
3248

3349
def test_Ellipse(self):
@@ -48,7 +64,7 @@ def test_Ellipse(self):
4864
)
4965
imaged_scatterer = optics(scatterer)
5066
output_image = imaged_scatterer.resolve()
51-
self.assertIsInstance(output_image, np.ndarray)
67+
self.assertIsInstance(output_image, self.array_type)
5268
self.assertEqual(output_image.shape, (64, 64, 1))
5369

5470
def test_EllipseUpscale(self):
@@ -146,7 +162,7 @@ def test_Sphere(self):
146162
)
147163
imaged_scatterer = optics(scatterer)
148164
output_image = imaged_scatterer.resolve()
149-
self.assertIsInstance(output_image, np.ndarray)
165+
self.assertIsInstance(output_image, self.array_type)
150166
self.assertEqual(output_image.shape, (64, 64, 1))
151167

152168
def test_SphereUpscale(self):
@@ -188,7 +204,7 @@ def test_Ellipsoid(self):
188204
)
189205
imaged_scatterer = optics(scatterer)
190206
output_image = imaged_scatterer.resolve()
191-
self.assertIsInstance(output_image, np.ndarray)
207+
self.assertIsInstance(output_image, self.array_type)
192208
self.assertEqual(output_image.shape, (64, 64, 1))
193209

194210
def test_EllipsoidUpscale(self):
@@ -343,6 +359,12 @@ def test_MieStratifiedSphere(self):
343359
imaged_scatterer_1 = optics_1(scatterer)
344360
imaged_scatterer_1.update().resolve()
345361

362+
# TODO: Extending the test and setting the backend to torch
363+
# @unittest.skipUnless(TORCH_AVAILABLE, "PyTorch is not installed.")
364+
# class TestScatterers_PyTorch(TestScatterers_NumPy):
365+
# BACKEND = "torch"
366+
# pass
367+
346368

347369
if __name__ == "__main__":
348370
unittest.main()

0 commit comments

Comments
 (0)