diff --git a/doc/changes/dev/13518.bugfix.rst b/doc/changes/dev/13518.bugfix.rst new file mode 100644 index 00000000000..f7d73a2cb04 --- /dev/null +++ b/doc/changes/dev/13518.bugfix.rst @@ -0,0 +1 @@ +Raise warning implementation in interpolate_bads() method when interpolation of channels with invalid position is done, by `Himanshu Mahor`_. diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 08c35dd310e..4825fccd882 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -827,6 +827,7 @@ def interpolate_bads( origin="auto", method=None, exclude=(), + on_bad_position="raise", verbose=None, ): """Interpolate bad MEG and EEG channels. @@ -872,6 +873,10 @@ def interpolate_bads( exclude : list | tuple The channels to exclude from interpolation. If excluded a bad channel will stay in bads. + on_bad_position : "raise" | "warn" | "ignore" + What to do when one or more sensor positions are invalid (zero or NaN). + If ``"warn"`` or ``"ignore"``, channels with invalid positions will be + filled with :data:`~numpy.nan`. %(verbose)s Returns @@ -896,6 +901,23 @@ def interpolate_bads( _check_preload(self, "interpolation") _validate_type(method, (dict, str, None), "method") + + invalid_chs = [] + for ch in self.info["bads"]: + loc = self.info["chs"][self.ch_names.index(ch)]["loc"][:3] + if np.allclose(loc, 0.0, atol=1e-16) or np.isnan(loc).any(): + invalid_chs.append(ch) + + if invalid_chs: + msg = ( + f"Channel(s) {invalid_chs} have invalid sensor position(s). " + "Interpolation cannot proceed correctly. If you want to continue " + "despite missing positions, set on_bad_position='warn' or 'ignore', " + "which outputs all NaN values (np.nan) for the interpolated " + "channel(s)." + ) + _on_missing(on_bad_position, msg) + method = _handle_default("interpolation_method", method) ch_types = self.get_channel_types(unique=True) # figure out if we have "mag" for "meg", "hbo" for "fnirs", ... to filter the diff --git a/mne/channels/tests/test_interpolation.py b/mne/channels/tests/test_interpolation.py index de09a97c306..3f7699622ce 100644 --- a/mne/channels/tests/test_interpolation.py +++ b/mne/channels/tests/test_interpolation.py @@ -420,6 +420,27 @@ def test_nan_interpolation(raw): raw.interpolate_bads(method="nan", reset_bads=False) assert raw.info["bads"] == ch_to_interp + store = raw.info["chs"][1]["loc"] + # for on_bad_position="raise" + raw.info["bads"] = ch_to_interp + raw.info["chs"][1]["loc"] = np.full(12, np.nan) + with pytest.raises(ValueError, match="have invalid sensor position"): + # DOES NOT interpolates at all. So raw.info["bads"] remains as is + raw.interpolate_bads(on_bad_position="raise") + + # for on_bad_position="warn" + with pytest.warns(RuntimeWarning, match="have invalid sensor position"): + # this DOES the interpolation BUT with a warning + # so raw.info["bad"] will be empty again, + # and interpolated channel with be all np.nan + raw.interpolate_bads(on_bad_position="warn") + + # for on_bad_position="ignore" + raw.info["bads"] = ch_to_interp + assert raw.interpolate_bads(on_bad_position="ignore") + assert np.isnan(bad_chs).all, "Interpolated channel should be all NaN" + raw.info["chs"][1]["loc"] = store + # make sure other channels are untouched raw.drop_channels(ch_to_interp) good_chs = raw.get_data()