Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Outsource leave-one-out splitter so it can be used across data types #98

Merged
merged 34 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
5f13501
Starting to refactor logo_split
teresamg Nov 30, 2022
7fa9b05
Fixed flake8 errors
teresamg Nov 30, 2022
5956168
Moved h5 file read
teresamg Nov 30, 2022
bc14820
Pull dwframe, bframe, bvec into estimator.py
teresamg Dec 1, 2022
200b9f9
Fixed flake8 errors
teresamg Dec 1, 2022
38b4276
Starting to refactor logo_split
teresamg Nov 30, 2022
56f9fd0
Moved h5 file read
teresamg Nov 30, 2022
6aeb4d7
Pull dwframe, bframe, bvec into estimator.py
teresamg Dec 1, 2022
1d5c6ea
Fixed flake8 errors
teresamg Dec 1, 2022
26a0dc4
Preallocate em_affines as array, not Affine object
teresamg Dec 8, 2022
c0693a1
Removed orig copy of dwdata
teresamg Dec 8, 2022
2140673
enh: move logo_split to new submodule
oesteban Dec 15, 2022
374864a
Update src/eddymotion/data/splitting.py
teresamg Dec 15, 2022
bc845d9
Update src/eddymotion/data/splitting.py
teresamg Dec 15, 2022
6191fa1
Update src/eddymotion/data/splitting.py
teresamg Dec 15, 2022
175d1dd
Update src/eddymotion/data/splitting.py
teresamg Dec 15, 2022
cb19fb1
Update src/eddymotion/data/splitting.py
teresamg Dec 15, 2022
64ccc58
Update src/eddymotion/data/splitting.py
teresamg Dec 15, 2022
88bf1dc
Update src/eddymotion/data/splitting.py
teresamg Dec 15, 2022
a3a3f1b
Fixed logo_split() call and dwdata->data
teresamg Dec 15, 2022
f18029e
Updated set_transform(), removed pbar grad_str text
teresamg Dec 15, 2022
a8538bf
fix: revise merge conflicts and get ready for final revision
oesteban Mar 27, 2024
542caa8
fix: revert accidental removal of two lines
oesteban Mar 27, 2024
82a557e
Apply suggestions from code review
esavary Mar 27, 2024
d45ec01
fix: Restore access to HDF5 file
esavary Mar 27, 2024
e776155
fix: Restore b0 argument
esavary Mar 27, 2024
b4a18ad
fix: typos
esavary Mar 27, 2024
3a96d51
Add: test for lovo_split
esavary Mar 28, 2024
1d6762c
Fix: remove unused import
esavary Mar 28, 2024
bcff16d
Apply suggestions from code review
esavary Mar 28, 2024
316eb57
Fix: return test data and gradient
esavary Mar 28, 2024
4217635
Fix: typo
esavary Mar 28, 2024
14cd4d0
Fix: masking
esavary Mar 28, 2024
9591b81
Update src/eddymotion/estimator.py
esavary Mar 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 4 additions & 54 deletions src/eddymotion/data/dmri.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,64 +70,14 @@ class DWI:
)
"""A path to an HDF5 file to store the whole dataset."""

def get_filename(self):
    """Return the path of the HDF5 file associated with this object."""
    filepath = self._filepath
    return filepath

def __len__(self):
    """Return the size of the last axis of ``dataobj`` (one entry per orientation)."""
    shape = self.dataobj.shape
    return shape[-1]

def logo_split(self, index, with_b0=False):
    """
    Build one leave-one-gradient-out (LOGO) fold.

    Parameters
    ----------
    index : :obj:`int`
        Position, along the last axis, of the orientation held out in this fold.
    with_b0 : :obj:`bool`
        When set, ``self.bzero`` is prepended to the training data and a
        matching gradient column is prepended to the training gradients.

    Returns
    -------
    (train_data, train_gradients) : :obj:`tuple`
        All orientations except ``index``, taken from the in-memory
        ``self.dataobj`` and ``self.gradients``.
    (test_data, test_gradients) : :obj:`tuple`
        The held-out orientation and its gradient, read from the HDF5 file
        at ``self._filepath``.

    """
    cache_file = self._filepath
    # Make sure the HDF5 file exists before reading the held-out frame from it
    if not Path(cache_file).exists():
        self.to_filename(cache_file)

    # Only the requested frame is pulled from the file
    with h5py.File(cache_file, "r") as in_file:
        root = in_file["/0"]
        dwframe = np.asanyarray(root["dataobj"][..., index])
        bframe = np.asanyarray(root["gradients"][..., index])

    # Selector over the last axis: False only at the held-out position
    keep = np.ones(len(self), dtype=bool)
    keep[index] = False

    train_data = self.dataobj[..., keep]
    train_gradients = self.gradients[..., keep]

    if not with_b0:
        return (
            (train_data, train_gradients),
            (dwframe, bframe),
        )

    # Prepend the b=0 reference as an extra leading volume
    bzero_volume = np.asanyarray(self.bzero)[..., np.newaxis]
    train_data = np.concatenate((bzero_volume, train_data), axis=-1)

    # Gradient column for the b=0 volume: first component 1, everything else 0
    b0vec = np.zeros((4, 1))
    b0vec[0, 0] = 1
    train_gradients = np.concatenate((b0vec, train_gradients), axis=-1)

    return (
        (train_data, train_gradients),
        (dwframe, bframe),
    )

def set_transform(self, index, affine, order=3):
"""Set an affine, and update data object and gradients."""
reference = namedtuple("ImageGrid", ("shape", "affine"))(
Expand Down
84 changes: 84 additions & 0 deletions src/eddymotion/data/splitting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
#
# Copyright 2022 The NiPreps Developers <nipreps@gmail.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# We support and encourage derived works from this project, please read
# about our expectations at
#
# https://www.nipreps.org/community/licensing/
#
"""Data splitting helpers."""
from pathlib import Path
import numpy as np
import h5py


def lovo_split(dataset, index, with_b0=False):
    """
    Produce one fold of LOVO (leave-one-volume-out).

    Both partitions are read from the HDF5 file returned by
    ``dataset.get_filename()``. If that file does not exist yet, it is
    written first with ``dataset.to_filename()``.

    Parameters
    ----------
    dataset : :obj:`eddymotion.data.dmri.DWI`
        DWI object
    index : :obj:`int`
        Index of the DWI orientation to be left out in this fold.
    with_b0 : :obj:`bool`
        Insert the *b=0* reference at the beginning of the training dataset.

    Returns
    -------
    (train_data, train_gradients) : :obj:`tuple`
        Training DWI and corresponding gradients: every volume stored in the
        file except the one at ``index``. When ``with_b0`` is set,
        ``dataset.bzero`` and a matching gradient column come first.
    (test_data, test_gradients) : :obj:`tuple`
        The left-out volume and its gradient. Both are selected with a boolean
        mask, so they keep a trailing axis of length one.

    Raises
    ------
    ValueError
        If ``with_b0`` is requested but ``dataset.bzero`` is ``None``.

    """
    # Fail early with a clear message; otherwise the concatenation further
    # down raises an opaque dimension-mismatch ValueError.
    if with_b0 and dataset.bzero is None:
        raise ValueError(
            "with_b0=True requires a b=0 reference, but dataset.bzero is None."
        )

    filename = dataset.get_filename()
    if not Path(filename).exists():
        dataset.to_filename(filename)

    # NOTE(review): data are taken from the file, not from the in-memory
    # arrays of ``dataset`` -- confirm the file reflects any in-memory updates
    # the caller expects to train on.
    with h5py.File(filename, "r") as in_file:
        root = in_file["/0"]
        data = np.asanyarray(root["dataobj"])
        gradients = np.asanyarray(root["gradients"])

    # Boolean selector over the last axis: True only at the left-out volume
    mask = np.zeros(data.shape[-1], dtype=bool)
    mask[index] = True

    train_data = data[..., ~mask]
    train_gradients = gradients[..., ~mask]
    test_data = data[..., mask]
    test_gradients = gradients[..., mask]

    if with_b0:
        train_data = np.concatenate(
            (np.asanyarray(dataset.bzero)[..., np.newaxis], train_data),
            axis=-1,
        )
        # Gradient column for the b=0 volume: first component 1, rest 0
        b0vec = np.zeros((4, 1))
        b0vec[0, 0] = 1
        train_gradients = np.concatenate(
            (b0vec, train_gradients),
            axis=-1,
        )

    return (
        (train_data, train_gradients),
        (test_data, test_gradients),
    )
3 changes: 2 additions & 1 deletion src/eddymotion/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from pkg_resources import resource_filename as pkg_fn
from tqdm import tqdm

from eddymotion.data.splitting import lovo_split
from eddymotion.model import ModelFactory


Expand Down Expand Up @@ -150,7 +151,7 @@ def fit(
pbar.set_description_str(
f"Pass {i_iter + 1}/{n_iter} | Fit and predict b-index <{i}>"
)
data_train, data_test = dwdata.logo_split(i, with_b0=True)
data_train, data_test = lovo_split(dwdata, i, with_b0=True)
grad_str = f"{i}, {data_test[1][:3]}, b={int(data_test[1][3])}"
pbar.set_description_str(f"[{grad_str}], {n_jobs} jobs")

Expand Down
3 changes: 2 additions & 1 deletion test/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import pytest

from eddymotion import model
from eddymotion.data.splitting import lovo_split
from eddymotion.data.dmri import DWI


Expand Down Expand Up @@ -94,7 +95,7 @@ def test_two_initialisations(datadir):
dmri_dataset = DWI.from_filename(datadir / "dwi.h5")

# Split data into test and train set
data_train, data_test = dmri_dataset.logo_split(10)
data_train, data_test = lovo_split(dmri_dataset, 10)

# Direct initialisation
model1 = model.AverageDWModel(
Expand Down
62 changes: 62 additions & 0 deletions test/test_splitting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
#
# Copyright 2021 The NiPreps Developers <nipreps@gmail.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# We support and encourage derived works from this project, please read
# about our expectations at
#
# https://www.nipreps.org/community/licensing/
#
"""Unit test testing the lovo_split function."""
import numpy as np
from eddymotion.data.dmri import DWI
from eddymotion.data.splitting import lovo_split


def test_lovo_split(datadir):
    """
    Check that ``lovo_split`` separates the left-out volume from the rest.

    The dataset is zeroed and a single volume (and its gradient) is set to
    one, so the test partition must contain only ones and the training
    partition only zeros.

    Parameters
    ----------
    datadir : path-like
        The directory containing the test data.

    """
    data = DWI.from_filename(datadir / "dwi.h5")

    # Set zeros in dataobj and gradients of the dwi object
    data.dataobj[:] = 0
    data.gradients[:] = 0

    # Pick the left-out index with a seeded generator so that a failure can
    # be reproduced on the next run
    rng = np.random.default_rng(1234)
    index = rng.integers(len(data))

    # Set 1 in dataobj and gradients of the dwi object at this specific index
    data.dataobj[..., index] = 1
    data.gradients[..., index] = 1

    # Apply the lovo_split function at the specified index
    train, test = lovo_split(data, index)
    train_data, train_gradients = train
    test_data, test_gradients = test

    # Size checks first: np.all() is vacuously true on empty arrays, so the
    # value checks below would not notice an empty partition
    assert test_data.shape[-1] == 1
    assert test_gradients.shape[-1] == 1
    assert train_data.shape[-1] == len(data) - 1
    assert train_gradients.shape[-1] == len(data) - 1

    # Check if the test data contains only 1s
    # and the train data contains only 0s after the split
    assert np.all(test_data == 1)
    assert np.all(train_data == 0)
    assert np.all(test_gradients == 1)
    assert np.all(train_gradients == 0)