Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/dev/13526.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix bug preventing reading of :class:`mne.time_frequency.Spectrum` and :class:`mne.time_frequency.BaseTFR` objects created in MNE<1.8 using the deprecated subject info birthday tuple format, by `Thomas Binns`_.
8 changes: 6 additions & 2 deletions mne/time_frequency/spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,11 @@
check_fname,
)
from ..utils.misc import _pl
from ..utils.spectrum import _get_instance_type_string, _split_psd_kwargs
from ..utils.spectrum import (
_convert_old_birthday_format,
_get_instance_type_string,
_split_psd_kwargs,
)
from ..viz.topo import _plot_timeseries, _plot_timeseries_unified, _plot_topo
from ..viz.topomap import _make_head_outlines, _prepare_topomap_plot, plot_psds_topomap
from ..viz.utils import (
Expand Down Expand Up @@ -391,7 +395,7 @@ def __setstate__(self, state):
self._freqs = state["freqs"]
self._dims = state["dims"]
self._sfreq = state["sfreq"]
self.info = Info(**state["info"])
self.info = Info(**_convert_old_birthday_format(state["info"]))
self._data_type = state["data_type"]
self._nave = state.get("nave") # objs saved before #11282 won't have `nave`
self._weights = state.get("weights") # objs saved before #12747 won't have
Expand Down
27 changes: 21 additions & 6 deletions mne/time_frequency/tests/test_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import datetime
import re
from functools import partial

Expand All @@ -26,7 +27,7 @@
SpectrumArray,
combine_spectrum,
)
from mne.utils import _record_warnings
from mne.utils import _import_h5io_funcs, _record_warnings


def test_compute_psd_errors(raw):
Expand Down Expand Up @@ -178,6 +179,7 @@ def _get_inst(inst, request, *, evoked=None, average_tfr=None):
def test_spectrum_io(inst, tmp_path, request, evoked):
"""Test save/load of spectrum objects."""
pytest.importorskip("h5io")
h5py = pytest.importorskip("h5py")
fname = tmp_path / f"{inst}-spectrum.h5"
inst = _get_inst(inst, request, evoked=evoked)
if isinstance(inst, BaseEpochs):
Expand All @@ -190,12 +192,25 @@ def test_spectrum_io(inst, tmp_path, request, evoked):
orig.save(fname)
loaded = read_spectrum(fname)
assert orig == loaded
# Only check following for one type
if not isinstance(inst, BaseEpochs):
return
# Test loading with old-style birthday format
fname_subject_info = tmp_path / "subject-info.h5"
_, write_hdf5 = _import_h5io_funcs()
write_hdf5(fname_subject_info, dict(birthday=(2000, 1, 1)), title="subject_info")
with h5py.File(fname, "r+") as f:
del f["mnepython/key_info/key_subject_info"]
f["mnepython/key_info/key_subject_info"] = h5py.ExternalLink(
fname_subject_info, "subject_info"
)
loaded = read_spectrum(fname)
assert isinstance(loaded.info["subject_info"]["birthday"], datetime.date)
# Test Spectrum from EpochsSpectrum.average() can be read (gh-13521)
if isinstance(inst, BaseEpochs):
origavg = orig.average()
origavg.save(fname, overwrite=True)
loadedavg = read_spectrum(fname)
assert origavg == loadedavg
origavg = orig.average()
origavg.save(fname, overwrite=True)
loadedavg = read_spectrum(fname)
assert origavg == loadedavg


def test_spectrum_copy(raw_spectrum):
Expand Down
19 changes: 18 additions & 1 deletion mne/time_frequency/tests/test_tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
tfr_multitaper,
write_tfrs,
)
from mne.utils import catch_logging, grand_average
from mne.utils import _import_h5io_funcs, catch_logging, grand_average
from mne.utils._testing import _get_suptitle
from mne.viz.utils import (
_channel_type_prettyprint,
Expand Down Expand Up @@ -620,6 +620,7 @@ def test_tfr_io(inst, average_tfr, request, tmp_path):
"""Test TFR I/O."""
pytest.importorskip("h5io")
pd = pytest.importorskip("pandas")
h5py = pytest.importorskip("h5py")

tfr = _get_inst(inst, request, average_tfr=average_tfr)
fname = tmp_path / "temp_tfr.hdf5"
Expand Down Expand Up @@ -679,6 +680,22 @@ def test_tfr_io(inst, average_tfr, request, tmp_path):
tfravg.save(fname, overwrite=True)
tfravg_loaded = read_tfrs(fname)
assert tfravg == tfravg_loaded
# test loading with old-style birthday format
fname_multi = tmp_path / "temp_multi_tfr.hdf5"
write_tfrs(fname_multi, tfr) # also check for multiple files from write_tfrs
fname_subject_info = tmp_path / "subject-info.hdf5"
_, write_hdf5 = _import_h5io_funcs()
write_hdf5(fname_subject_info, dict(birthday=(2000, 1, 1)), title="subject_info")
for this_fname in (fname, fname_multi):
with h5py.File(this_fname, "r+") as f:
if f.get("mnepython/key_info/key_subject_info"):
path = "mnepython/key_info/key_subject_info"
else: # multi-files on linux have different path to attrs
path = "mnepython/idx_0/idx_1/key_info/key_subject_info"
del f[path]
f[path] = h5py.ExternalLink(fname_subject_info, "subject_info")
tfr_loaded = read_tfrs(this_fname)
assert isinstance(tfr_loaded.info["subject_info"]["birthday"], datetime.date)
# test with taper dimension and weights
n_tapers = 3 # anything >= 1 should do
weights = np.ones((n_tapers, tfr.shape[2])) # tapers x freqs
Expand Down
6 changes: 3 additions & 3 deletions mne/time_frequency/tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
verbose,
warn,
)
from ..utils.spectrum import _get_instance_type_string
from ..utils.spectrum import _convert_old_birthday_format, _get_instance_type_string
from ..viz.topo import _imshow_tfr, _imshow_tfr_unified, _plot_topo
from ..viz.topomap import (
_add_colorbar,
Expand Down Expand Up @@ -1433,7 +1433,7 @@ def __setstate__(self, state):
self._dims = defaults["dims"]
self._raw_times = np.asarray(defaults["times"], dtype=np.float64)
self._baseline = defaults["baseline"]
self.info = Info(**defaults["info"])
self.info = Info(**_convert_old_birthday_format(defaults["info"]))
self._data_type = defaults["data_type"]
self._decim = defaults["decim"]
self.preload = True
Expand Down Expand Up @@ -4141,7 +4141,7 @@ def _read_multiple_tfrs(tfr_data, condition=None, *, verbose=None):
if key != condition:
continue
tfr = dict(tfr)
tfr["info"] = Info(tfr["info"])
tfr["info"] = Info(_convert_old_birthday_format(tfr["info"]))
tfr["info"]._check_consistency()
if "metadata" in tfr:
tfr["metadata"] = _prepare_read_metadata(tfr["metadata"])
Expand Down
11 changes: 11 additions & 0 deletions mne/utils/spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

from datetime import datetime
from inspect import currentframe, getargvalues, signature

from ..utils import warn
Expand Down Expand Up @@ -102,3 +103,13 @@ def _split_psd_kwargs(*, plot_fun=None, kwargs=None):
for k in plot_kwargs:
del kwargs[k]
return kwargs, plot_kwargs


def _convert_old_birthday_format(info):
"""Convert deprecated birthday tuple to datetime."""
subject_info = info.get("subject_info")
if subject_info is not None:
birthday = subject_info.get("birthday")
if isinstance(birthday, tuple):
info["subject_info"]["birthday"] = datetime(*birthday)
return info