Skip to content

Commit 7ba65c2

Browse files
committed
Addressed @matthew-brett's about coroutine and ArraySequence.extend method.
1 parent 5faf62e commit 7ba65c2

File tree

6 files changed

+169
-108
lines changed

6 files changed

+169
-108
lines changed

nibabel/streamlines/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,6 @@ def save(tractogram, filename, **kwargs):
127127
if len(kwargs) > 0:
128128
msg = ("A 'TractogramFile' object was provided, no need for"
129129
" keyword arguments.")
130-
raise ValueError(msg)
130+
raise ValueError(msg)
131131

132132
tractogram_file.save(filename)

nibabel/streamlines/array_sequence.py

Lines changed: 35 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import numbers
12
import numpy as np
23

34

@@ -37,7 +38,7 @@ def __init__(self, iterable=None):
3738
"""
3839
# Create new empty `ArraySequence` object.
3940
self._is_view = False
40-
self._data = np.array(0)
41+
self._data = np.array([])
4142
self._offsets = np.array([], dtype=np.intp)
4243
self._lengths = np.array([], dtype=np.intp)
4344

@@ -79,8 +80,7 @@ def __init__(self, iterable=None):
7980
self._lengths = np.asarray(lengths)
8081

8182
# Clear unused memory.
82-
if self._data.ndim != 0:
83-
self._data.resize((offset,) + self.common_shape)
83+
self._data.resize((offset,) + self.common_shape)
8484

8585
@property
8686
def is_array_sequence(self):
@@ -89,13 +89,10 @@ def is_array_sequence(self):
8989
@property
9090
def common_shape(self):
9191
""" Matching shape of the elements in this array sequence. """
92-
if self._data.ndim == 0:
93-
return ()
94-
9592
return self._data.shape[1:]
9693

9794
def append(self, element):
98-
""" Appends :obj:`element` to this array sequence.
95+
""" Appends `element` to this array sequence.
9996
10097
Parameters
10198
----------
@@ -108,28 +105,28 @@ def append(self, element):
108105
If you need to add multiple elements you should consider
109106
`ArraySequence.extend`.
110107
"""
111-
if self._data.ndim == 0:
112-
self._data = np.asarray(element).copy()
113-
self._offsets = np.array([0])
114-
self._lengths = np.array([len(element)])
115-
return
108+
element = np.asarray(element)
116109

117-
if element.shape[1:] != self.common_shape:
110+
if self.common_shape != () and element.shape[1:] != self.common_shape:
118111
msg = "All dimensions, except the first one, must match exactly"
119112
raise ValueError(msg)
120113

121-
self._offsets = np.r_[self._offsets, len(self._data)]
122-
self._lengths = np.r_[self._lengths, len(element)]
123-
self._data = np.append(self._data, element, axis=0)
114+
next_offset = self._data.shape[0]
115+
size = (self._data.shape[0] + element.shape[0],) + element.shape[1:]
116+
self._data.resize(size)
117+
self._data[next_offset:] = element
118+
self._offsets = np.r_[self._offsets, next_offset]
119+
self._lengths = np.r_[self._lengths, element.shape[0]]
124120

125121
def extend(self, elements):
126122
""" Appends all `elements` to this array sequence.
127123
128124
Parameters
129125
----------
130-
elements : list of ndarrays or :class:`ArraySequence` object
131-
If list of ndarrays, each ndarray will be concatenated along the
132-
first dimension then appended to the data of this ArraySequence.
126+
elements : iterable of ndarrays or :class:`ArraySequence` object
127+
If iterable of ndarrays, each ndarray will be concatenated along
128+
the first dimension then appended to the data of this
129+
ArraySequence.
133130
If :class:`ArraySequence` object, its data are simply appended to
134131
the data of this ArraySequence.
135132
@@ -138,35 +135,31 @@ def extend(self, elements):
138135
The shape of the elements to be added must match the one of the
139136
data of this :class:`ArraySequence` except for the first dimension.
140137
"""
138+
if not is_array_sequence(elements):
139+
self.extend(ArraySequence(elements))
140+
return
141+
141142
if len(elements) == 0:
142143
return
143144

144-
if self._data.ndim == 0:
145-
elem = np.asarray(elements[0])
146-
self._data = np.zeros((0, elem.shape[1]), dtype=elem.dtype)
145+
if (self.common_shape != () and
146+
elements.common_shape != self.common_shape):
147+
msg = "All dimensions, except the first one, must match exactly"
148+
raise ValueError(msg)
147149

148150
next_offset = self._data.shape[0]
151+
self._data.resize((self._data.shape[0] + sum(elements._lengths),
152+
elements._data.shape[1]))
149153

150-
if is_array_sequence(elements):
151-
self._data.resize((self._data.shape[0] + sum(elements._lengths),
152-
self._data.shape[1]))
153-
154-
offsets = []
155-
for offset, length in zip(elements._offsets, elements._lengths):
156-
offsets.append(next_offset)
157-
chunk = elements._data[offset:offset + length]
158-
self._data[next_offset:next_offset + length] = chunk
159-
next_offset += length
160-
161-
self._lengths = np.r_[self._lengths, elements._lengths]
162-
self._offsets = np.r_[self._offsets, offsets]
154+
offsets = []
155+
for offset, length in zip(elements._offsets, elements._lengths):
156+
offsets.append(next_offset)
157+
chunk = elements._data[offset:offset + length]
158+
self._data[next_offset:next_offset + length] = chunk
159+
next_offset += length
163160

164-
else:
165-
self._data = np.concatenate([self._data] + list(elements), axis=0)
166-
lengths = list(map(len, elements))
167-
self._lengths = np.r_[self._lengths, lengths]
168-
self._offsets = np.r_[self._offsets,
169-
np.cumsum([next_offset] + lengths)[:-1]]
161+
self._lengths = np.r_[self._lengths, elements._lengths]
162+
self._offsets = np.r_[self._offsets, offsets]
170163

171164
def copy(self):
172165
""" Creates a copy of this :class:`ArraySequence` object. """
@@ -210,7 +203,7 @@ def __getitem__(self, idx):
210203
Otherwise, returns a :class:`ArraySequence` object which is view
211204
of the selected sequences.
212205
"""
213-
if isinstance(idx, (int, np.integer)):
206+
if isinstance(idx, (numbers.Integral, np.integer)):
214207
start = self._offsets[idx]
215208
return self._data[start:start + self._lengths[idx]]
216209

nibabel/streamlines/tests/test_array_sequence.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import os
2+
import sys
23
import unittest
34
import tempfile
45
import numpy as np
56

67
from nose.tools import assert_equal, assert_raises, assert_true
78
from nibabel.testing import assert_arrays_equal
89
from numpy.testing import assert_array_equal
9-
from nibabel.externals.six.moves import zip, zip_longest
1010

1111
from ..array_sequence import ArraySequence, is_array_sequence
1212

@@ -32,7 +32,8 @@ def check_empty_arr_seq(seq):
3232
assert_equal(len(seq), 0)
3333
assert_equal(len(seq._offsets), 0)
3434
assert_equal(len(seq._lengths), 0)
35-
assert_equal(seq._data.ndim, 0)
35+
# assert_equal(seq._data.ndim, 0)
36+
assert_equal(seq._data.ndim, 1)
3637
assert_true(seq.common_shape == ())
3738

3839

@@ -138,6 +139,11 @@ def test_arraysequence_append(self):
138139
seq.append(element)
139140
check_arr_seq(seq, SEQ_DATA['data'] + [element])
140141

142+
# Append a list of list.
143+
seq = SEQ_DATA['seq'].copy() # Copy because of in-place modification.
144+
seq.append(element.tolist())
145+
check_arr_seq(seq, SEQ_DATA['data'] + [element])
146+
141147
# Append to an empty ArraySequence.
142148
seq = ArraySequence()
143149
seq.append(element)
@@ -164,6 +170,11 @@ def test_arraysequence_extend(self):
164170
seq.extend(new_data)
165171
check_arr_seq(seq, SEQ_DATA['data'] + new_data)
166172

173+
# Extend with a generator.
174+
seq = SEQ_DATA['seq'].copy() # Copy because of in-place modification.
175+
seq.extend((d for d in new_data))
176+
check_arr_seq(seq, SEQ_DATA['data'] + new_data)
177+
167178
# Extend with another `ArraySequence` object.
168179
seq = SEQ_DATA['seq'].copy() # Copy because of in-place modification.
169180
seq.extend(ArraySequence(new_data))
@@ -195,6 +206,9 @@ def test_arraysequence_getitem(self):
195206
for i, e in enumerate(SEQ_DATA['seq']):
196207
assert_array_equal(SEQ_DATA['seq'][i], e)
197208

209+
if sys.version_info < (3,):
210+
assert_array_equal(SEQ_DATA['seq'][long(i)], e)
211+
198212
# Get all items using indexing (creates a view).
199213
indices = list(range(len(SEQ_DATA['seq'])))
200214
seq_view = SEQ_DATA['seq'][indices]

nibabel/streamlines/tests/test_streamlines.py

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -226,23 +226,28 @@ def test_save_tractogram_file(self):
226226
assert_true(issubclass(w[0].category, ExtensionWarning))
227227
assert_true("extension" in str(w[0].message))
228228

229+
with InTemporaryDirectory():
230+
nib.streamlines.save(trk_file, "dummy.trk")
231+
tfile = nib.streamlines.load("dummy.trk", lazy_load=False)
232+
assert_tractogram_equal(tfile.tractogram, tractogram)
233+
229234
def test_save_empty_file(self):
230235
tractogram = Tractogram()
231236
for ext, cls in nib.streamlines.FORMATS.items():
232237
with InTemporaryDirectory():
233-
with open('streamlines' + ext, 'w+b') as f:
234-
nib.streamlines.save(tractogram, f.name)
235-
tfile = nib.streamlines.load(f, lazy_load=False)
236-
assert_tractogram_equal(tfile.tractogram, tractogram)
238+
filename = 'streamlines' + ext
239+
nib.streamlines.save(tractogram, filename)
240+
tfile = nib.streamlines.load(filename, lazy_load=False)
241+
assert_tractogram_equal(tfile.tractogram, tractogram)
237242

238243
def test_save_simple_file(self):
239244
tractogram = Tractogram(DATA['streamlines'])
240245
for ext, cls in nib.streamlines.FORMATS.items():
241246
with InTemporaryDirectory():
242-
with open('streamlines' + ext, 'w+b') as f:
243-
nib.streamlines.save(tractogram, f.name)
244-
tfile = nib.streamlines.load(f, lazy_load=False)
245-
assert_tractogram_equal(tfile.tractogram, tractogram)
247+
filename = 'streamlines' + ext
248+
nib.streamlines.save(tractogram, filename)
249+
tfile = nib.streamlines.load(filename, lazy_load=False)
250+
assert_tractogram_equal(tfile.tractogram, tractogram)
246251

247252
def test_save_complex_file(self):
248253
complex_tractogram = Tractogram(DATA['streamlines'],
@@ -251,18 +256,19 @@ def test_save_complex_file(self):
251256

252257
for ext, cls in nib.streamlines.FORMATS.items():
253258
with InTemporaryDirectory():
254-
with open('streamlines' + ext, 'w+b') as f:
255-
with clear_and_catch_warnings(record=True,
256-
modules=[trk]) as w:
257-
nib.streamlines.save(complex_tractogram, f.name)
258-
259-
# If streamlines format does not support saving data
260-
# per point or data per streamline, a warning message
261-
# should be issued.
262-
if not (cls.support_data_per_point() and
263-
cls.support_data_per_streamline()):
264-
assert_equal(len(w), 1)
265-
assert_true(issubclass(w[0].category, Warning))
259+
filename = 'streamlines' + ext
260+
261+
with clear_and_catch_warnings(record=True,
262+
modules=[trk]) as w:
263+
nib.streamlines.save(complex_tractogram, filename)
264+
265+
# If streamlines format does not support saving data
266+
# per point or data per streamline, a warning message
267+
# should be issued.
268+
if not (cls.support_data_per_point() and
269+
cls.support_data_per_streamline()):
270+
assert_equal(len(w), 1)
271+
assert_true(issubclass(w[0].category, Warning))
266272

267273
tractogram = Tractogram(DATA['streamlines'])
268274

@@ -272,7 +278,7 @@ def test_save_complex_file(self):
272278
if cls.support_data_per_streamline():
273279
tractogram.data_per_streamline = DATA['data_per_streamline']
274280

275-
tfile = nib.streamlines.load(f, lazy_load=False)
281+
tfile = nib.streamlines.load(filename, lazy_load=False)
276282
assert_tractogram_equal(tfile.tractogram, tractogram)
277283

278284
def test_load_unknown_format(self):

nibabel/streamlines/tests/test_tractogram.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import sys
12
import unittest
23
import numpy as np
34
import warnings
@@ -241,6 +242,22 @@ def test_tractogram_creation(self):
241242
DATA['data_per_streamline'],
242243
DATA['data_per_point'])
243244

245+
# Create a tractogram from another tractogram attributes.
246+
tractogram2 = Tractogram(tractogram.streamlines,
247+
tractogram.data_per_streamline,
248+
tractogram.data_per_point)
249+
250+
assert_tractogram_equal(tractogram2, tractogram)
251+
252+
# Create a tractogram from a LazyTractogram object.
253+
tractogram = LazyTractogram(DATA['streamlines_func'],
254+
DATA['data_per_streamline_func'],
255+
DATA['data_per_point_func'])
256+
257+
tractogram2 = Tractogram(tractogram.streamlines,
258+
tractogram.data_per_streamline,
259+
tractogram.data_per_point)
260+
244261
# Inconsistent number of scalars between streamlines
245262
wrong_data = [[(1, 0, 0)]*1,
246263
[(0, 1, 0), (0, 1)],
@@ -264,6 +281,9 @@ def test_tractogram_getitem(self):
264281
for i, t in enumerate(DATA['tractogram']):
265282
assert_tractogram_item_equal(DATA['tractogram'][i], t)
266283

284+
if sys.version_info < (3,):
285+
assert_tractogram_item_equal(DATA['tractogram'][long(i)], t)
286+
267287
# Get one TractogramItem out of two.
268288
tractogram_view = DATA['simple_tractogram'][::2]
269289
check_tractogram(tractogram_view, DATA['streamlines'][::2])
@@ -411,7 +431,7 @@ def test_lazy_tractogram_creation(self):
411431
'mean_colors': (x for x in DATA['mean_colors'])}
412432

413433
# Creating LazyTractogram with generators is not allowed as
414-
# generators get exhausted and are not reusable unlike coroutines.
434+
# generators get exhausted and are not reusable unlike generator function.
415435
assert_raises(TypeError, LazyTractogram, streamlines)
416436
assert_raises(TypeError, LazyTractogram,
417437
data_per_streamline=data_per_streamline)
@@ -430,7 +450,7 @@ def test_lazy_tractogram_creation(self):
430450
assert_true(check_iteration(tractogram))
431451
assert_equal(len(tractogram), len(DATA['streamlines']))
432452

433-
# Coroutines get re-called and creates new iterators.
453+
# Generator functions get re-called and creates new iterators.
434454
for i in range(2):
435455
assert_tractogram_equal(tractogram, DATA['tractogram'])
436456

@@ -441,7 +461,7 @@ def test_lazy_tractogram_create_from(self):
441461
tractogram = LazyTractogram.create_from(_empty_data_gen)
442462
check_tractogram(tractogram)
443463

444-
# Create `LazyTractogram` from a coroutine yielding TractogramItem
464+
# Create `LazyTractogram` from a generator function yielding TractogramItem.
445465
data = [DATA['streamlines'], DATA['fa'], DATA['colors'],
446466
DATA['mean_curvature'], DATA['mean_torsion'],
447467
DATA['mean_colors']]
@@ -530,13 +550,13 @@ def test_lazy_tractogram_copy(self):
530550
# Check we copied the data and not simply created new references.
531551
assert_true(tractogram is not DATA['lazy_tractogram'])
532552

533-
# When copying LazyTractogram, coroutines generating streamlines should
534-
# be the same.
553+
# When copying LazyTractogram, the generator function yielding streamlines
554+
# should stay the same.
535555
assert_true(tractogram._streamlines
536556
is DATA['lazy_tractogram']._streamlines)
537557

538558
# Copying LazyTractogram, creates new internal LazyDict objects,
539-
# but coroutines contained in it should be the same.
559+
# but generator functions contained in it should stay the same.
540560
assert_true(tractogram._data_per_streamline
541561
is not DATA['lazy_tractogram']._data_per_streamline)
542562
assert_true(tractogram._data_per_point

0 commit comments

Comments
 (0)