Source code for mcutils.scripts.MDE.spherical_mean

#!/usr/bin/env python
"""Computes spherical mean signal"""
import nibabel as nib
import numpy as np
from scipy.special import sph_harm
from numpy import linalg
import os.path as op
from mcutils.utils.sidecar import AcquisitionParams
from mcutils.scripts.sidecar import index as sidecar_index
from fsl.data.image import removeExt
from numpy.random import shuffle
import logging
from mcutils.utils.spherical import cart2spherical
logger = logging.getLogger(__name__)


def read_bvecs(filename):
    """
    Loads gradient orientations from a text file

    The file may store the orientations either as 3 rows or as 3 columns;
    in both cases a 2D table comes back with one orientation per row.
    A 3x3 table is ambiguous and is returned transposed
    (i.e., it is interpreted as being stored in 3 rows).

    :param filename: text file with the gradient orientations (or None)
    :return: array with one gradient orientation per row (None if no filename was given)
    """
    if filename is None:
        return None
    as_columns = np.genfromtxt(filename).T
    # keep the transposed table unless that left the 3 components along the first axis
    if as_columns.shape[0] != 3 or as_columns.shape[1] == 3:
        return as_columns
    return as_columns.T
class SPHModel(object):
    """
    Linear least-squares model of diffusion data in terms of even-degree spherical harmonics
    """

    def __init__(self, data, bvecs, order=4):
        """
        Stores the data and builds the design matrix mapping harmonic coefficients to signal

        :param data: (..., Ng) array of diffusion data
        :param bvecs: (Ng, 3) array of gradient orientations
        :param order: maximum spherical harmonic degree (only even degrees are included)
        """
        self.data = data
        even_degrees = range(0, order + 1, 2)
        # every even degree contributes (2 * degree + 1) columns, one for each order -degree..degree
        degree = np.concatenate([np.full(2 * deg + 1, deg, dtype='i4') for deg in even_degrees])
        harmonic_order = np.concatenate([np.arange(-deg, deg + 1) for deg in even_degrees])
        _, phi, theta = cart2spherical(*bvecs.T)
        harmonics = sph_harm(
            harmonic_order[np.newaxis, :],
            degree[np.newaxis, :],
            phi[:, np.newaxis],
            theta[:, np.newaxis],
        )
        # NOTE(review): the real parts of the +m and -m harmonics are equal up to sign, so these
        # columns are linearly dependent; confirm that relying on the minimum-norm
        # least-squares solution is intended.
        # The factor 2 sqrt(pi) rescales the degree-0 harmonic (1 / (2 sqrt(pi))) to one.
        self.diff_mat = np.real(harmonics) * (2 * np.sqrt(np.pi))

    @property
    def ncomponents(self):
        """Number of spherical harmonic components (columns of the design matrix)"""
        return self.diff_mat.shape[1]

    @property
    def ngrad(self):
        """Number of gradient orientations (rows of the design matrix)"""
        return len(self.diff_mat)

    def fit(self, select=None):
        """
        Least-squares fit of the spherical harmonic coefficients

        :param select: optional index (or mask) of the gradients to include in the fit (default: all)
        :return: (..., Nc) array with the best-fit coefficients
        """
        if select is None:
            signal, design = self.data, self.diff_mat
        else:
            signal, design = self.data[..., select], self.diff_mat[select]
        # lstsq expects the observations along the first axis, so flatten all voxels into columns
        flat_signal = signal.reshape(-1, design.shape[0])
        coefficients = linalg.lstsq(design, flat_signal.T, rcond=None)[0]
        return coefficients.T.reshape(signal.shape[:-1] + (self.ncomponents, ))

    def leavekout(self, k):
        """
        Estimates the prediction error by fitting without k randomly chosen gradients

        :param k: number of gradient orientations to hold out of the fit
        :return: squared difference between predicted and observed signal averaged over the held-out gradients
        """
        shuffled = np.arange(self.ngrad)
        shuffle(shuffled)
        held_out, training = shuffled[:k], shuffled[k:]
        coefficients = self.fit(training)
        predicted = np.tensordot(coefficients, self.diff_mat[held_out, :], (-1, -1))
        squared_error = (predicted - self.data[..., held_out]) ** 2
        return np.mean(squared_error, -1)
def fit_sph_harm(diffusion_data, bvecs, order=None):
    """
    Fits spherical harmonics to the diffusion data

    If no order is given, the even orders 0 to 10 are tried in turn. Each order is scored by
    the mean squared leave-k-out prediction error (k = ceil(Ng / 10) gradients held out,
    averaged over 4 random splits). As soon as an order no longer lowers that error
    the fit for the previous order is returned; if the error keeps decreasing
    the fit for the highest tested order is returned.

    :param diffusion_data: (Nx, Ny, Nz, Ng) array of diffusion data
    :param bvecs: (Ng, 3) array of diffusion-encoding gradients
    :param order: maximum spherical harmonic order (None to select the order by cross-validation)
    :return: (Nx, Ny, Nz, Nc) array with the components
    """
    if order is not None:
        logger.info('Computing spherical mean using spherical harmonics up to order {}'.format(order))
        return SPHModel(diffusion_data, bvecs, order).fit()

    ngrad = bvecs.shape[0]
    k_out = int(np.ceil(ngrad / 10))
    # despite the name this is a prediction error: lower is better
    accuracy = np.inf
    solution = None
    for test_order in range(0, 12, 2):
        logger.debug('testing prediction accuracy of unseen data for spherical harmonics up to order %d', test_order)
        model = SPHModel(diffusion_data, bvecs, test_order)
        new_accuracy = np.mean([model.leavekout(k_out) for _ in range(4)])
        logger.debug('New accuracy: %f (compared with previous of %f', new_accuracy, accuracy)
        # the lowest order is always accepted, so that there is a previous fit to fall back on
        if solution is not None and new_accuracy + 1e-10 >= accuracy:
            logger.info('Maximum accuracy reached with order {}'.format(test_order - 2))
            return solution
        solution = model.fit()
        accuracy = new_accuracy
    # every tested order improved the prediction; return the fit for the last one rather than None
    logger.warning('Maximum accuracy still not reached at order {}'.format(test_order))
    return solution
def run_single(diffusion_img: nib.Nifti1Image, method='sample', bvecs=None, order=None):
    """
    Computes the mean across a diffusion image

    A 3D image (or a 4D image with a single volume) is returned as its own mean without looking at `method`.
    Otherwise the mean is only computed in voxels where all volumes are finite and at least one is non-zero;
    all other voxels are set to zero.

    :param diffusion_img: 4D image with diffusion data (shape=(Nx, Ny, Nz, Ng))
    :param method: method used to compute the spherical mean:

        - 'sample': mean across the volumes
        - 'shm': first component of a spherical harmonic fit (requires `bvecs`)
    :param bvecs: (Ng, 3) array of gradient orientations
    :param order: maximum spherical harmonic order
    :return: image with the spherical mean
    :raises ValueError: if the method is not recognised or if `bvecs` is missing for the 'shm' method
    """
    diffusion_data = np.asanyarray(diffusion_img.dataobj)
    single_volume = diffusion_data.ndim == 3 or (diffusion_data.ndim == 4 and diffusion_data.shape[-1] == 1)
    if single_volume:
        logger.info("Only single volume provided, this is returned as its own mean")
        mean_data = diffusion_data[..., 0] if diffusion_data.ndim == 4 else diffusion_data
    else:
        # check the method before the gradients, so that the error names the actual problem
        if method not in ('sample', 'shm'):
            raise ValueError(f"Unrecognized method: {method}")
        if method == 'shm' and bvecs is None:
            raise ValueError("Need to provide bvecs for method %s" % method)

        logger.info("Computing {} mean over {} volumes".format(method, diffusion_data.shape[-1]))
        mask = np.isfinite(diffusion_data).all(-1) & (diffusion_data != 0).any(-1)
        mean_data = np.zeros(diffusion_data.shape[:-1])
        if method == 'sample':
            mean_data[mask] = np.mean(diffusion_data[mask], -1)
        else:
            mean_data[mask] = fit_sph_harm(diffusion_data[mask], bvecs, order=order)[..., 0]
    return nib.Nifti1Image(mean_data, affine=None, header=diffusion_img.header)
def run_multi(diffusion_img: nib.Nifti1Image, indices, method='sample', bvecs=None, order=None):
    """
    Computes the spherical mean for each b-value

    :param diffusion_img: image with the diffusion data
    :param indices: integer array with shell ids (zero-based)
    :param method: method used to compute the spherical mean
    :param bvecs: (Ng, 3) array of gradient orientation
    :param order: maximum spherical harmonic order
    :return: image with spherical mean for each shell
    """
    indices = np.asarray(indices)
    nshells = int(indices.max()) + 1
    # `get_data()` has been removed from nibabel; read the array once through `dataobj` instead
    data = np.asanyarray(diffusion_img.dataobj)
    res = np.zeros(diffusion_img.shape[:3] + (nshells, ))
    for idx_shell in range(nshells):
        in_shell = indices == idx_shell
        shell_mean = run_single(
            nib.Nifti1Image(data[..., in_shell], diffusion_img.affine),
            method=method,
            bvecs=None if bvecs is None else bvecs[in_shell, :],
            order=order,
        )
        res[..., idx_shell] = np.asanyarray(shell_mean.dataobj)
    return nib.Nifti1Image(res, affine=None, header=diffusion_img.header)
def run_from_args(args):
    """
    Runs the spherical mean computation from the parsed command line arguments

    Without a sidecar (`args.xps`) or b-values (`args.bvals`) a single mean across all volumes
    is written to `args.out`. Otherwise one mean per group of volumes (as returned by
    `sidecar_index.get_indices`) is written, together with the grouped acquisition parameters:
    a `.json` file if a sidecar was provided, a `.bvals` text file otherwise.

    :param args: namespace with the arguments defined in `add_to_parser`
    """
    has_sidecar = args.xps is not None
    has_bvals = args.bvals is not None

    if has_sidecar:
        if has_bvals or args.bvecs is not None:
            logger.warning("Sidecar file has been provided, so bvals and bvecs will be ignored")
        xps = AcquisitionParams.read(args.xps)
    elif has_bvals:
        xps = AcquisitionParams(b=np.loadtxt(args.bvals).flatten())

    bvecs = read_bvecs(args.bvecs)

    if not (has_sidecar or has_bvals):
        # nothing to group the volumes by: average across all of them
        mean_img = run_single(
            diffusion_img=nib.load(args.data),
            method=args.method,
            bvecs=bvecs,
            order=args.shm_order,
        )
        mean_img.to_filename(args.out)
        return

    indices = sidecar_index.get_indices(xps, args)
    run_multi(
        diffusion_img=nib.load(args.data),
        method=args.method,
        bvecs=bvecs,
        order=args.shm_order,
        indices=indices,
    ).to_filename(args.out)

    # store the grouped acquisition parameters next to the output image
    out_dir, out_name = op.split(args.out)
    if has_sidecar:
        sidecar_out = op.join(out_dir, removeExt(out_name) + '.json')
        xps.groupby(indices, drop=('theta', 'phi', 'psi')).write(sidecar_out)
    else:
        bvals_out = op.join(out_dir, removeExt(out_name) + '.bvals')
        np.savetxt(bvals_out, xps.groupby(indices)['b'], fmt='%d')
def add_to_parser(parser):
    """
    Registers the command line arguments of the spherical mean script

    Adds the positional arguments (data, out, method), the optional gradient/sidecar files
    and the spherical harmonic order, followed by the arguments from `sidecar_index.add_index_params`.

    :param parser: parser to add the arguments to (presumably an argparse parser; only `add_argument` is used here)
    """
    argument_table = (
        (("data", ), dict(help='aligned 4D diffusion data file (post-eddy)')),
        (("out", ), dict(help='output filename')),
        (("method", ), dict(
            choices=('sample', 'shm'),
            help='method used to calculate the spherical mean:\n'
                 '1. sample: computes the sample mean\n'
                 '2. shm: returns the first component of a fit of spherical harmonics',
        )),
        (("-r", "--bvecs"), dict(help='text file with the gradient orientations')),
        (("-b", "--bvals"), dict(
            help='text file with the diffusion weighting. '
                 'If provided computes the spherical mean for each b-shell.',
        )),
        (("-x", "--xps"), dict(
            help='Acquisition parameter file (.json or .mat) containing the b-matrices',
        )),
        (('--shm_order', ), dict(
            type=int,
            default=None,
            help='maximum spherical harmonic order (should be multiple of 2). '
                 'Defaults to testing which order has the highest accuracy.',
        )),
    )
    for flags, options in argument_table:
        parser.add_argument(*flags, **options)
    sidecar_index.add_index_params(parser)