# Source code for pyepr.relaxation_analysis

from matplotlib.figure import Figure
import matplotlib.cm as cm
import numpy as np
from deerlab import noiselevel
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
from pyepr.sequences import Sequence
import deerlab as dl
from scipy.linalg import svd
from pyepr.colors import primary_colors
# ===========================================================================


class CarrPurcellAnalysis():

    def __init__(self, dataset, sequence: Sequence = None) -> None:
        """Analysis and calculation of Carr Purcell decay.

        Parameters
        ----------
        dataset : xarray.DataArray
            The measured decay. The time axis is the first coordinate found
            among ``tau1``, ``tau``, ``t`` and ``step``, falling back to
            ``X``. The dataset must provide the ``epr`` accessor
            (``correctphasefull``); the SNR helpers additionally read its
            ``nAvgs``, ``shots``, ``nPcyc`` and ``reptime`` attributes.
        sequence : Sequence, optional
            The sequence object describing the experiment (not currently used).

        Attributes
        ----------
        axis : xr.DataArray
            The time axis representing the interpulse delay.
        data : np.ndarray
            Real part of ``dataset.epr.correctphasefull`` (not normalised).
        dataset : xr.DataArray
            The dataset returned by ``dataset.epr.correctphasefull``.
        noise : float
            Noise level of the max-normalised signal.
        """
        for name in ('tau1', 'tau', 't', 'step'):
            if name in dataset.coords:
                self.axis = dataset[name]
                break
        else:
            self.axis = dataset['X']

        dataset = dataset.epr.correctphasefull
        self.data = dataset.data.real
        self.dataset = dataset

        # The noise is estimated on the max-normalised signal, so any SNR
        # derived from it must use normalised amplitudes as well.
        self.noise = noiselevel(self.data / np.max(self.data))

    def fit(self, type: str = "mono", **kwargs):
        """Fit the experimental CP decay.

        Parameters
        ----------
        type : str, optional
            ``"mono"`` fits a single stretched exponential and ``"double"`` a
            sum of two stretched exponentials. Any other value (e.g.
            ``"auto"``) fits both and keeps the model with the larger R².
            By default "mono".
        **kwargs
            Passed through to ``deerlab.fit``.

        Returns
        -------
        FitResult
            The deerlab fit result of the selected model. It is also stored as
            ``self.fit_result``, and the model as ``self.fit_model``.
        """
        # Fit a normalised copy. Dividing self.data in place would also
        # rescale the stored raw signal (and, through the ``.real`` view, the
        # dataset it came from), which breaks ``norm=False`` elsewhere.
        data = self.data / np.max(self.data)

        # NOTE(review): these are deerlab's module-level model objects, so the
        # name and bound assignments below are global side effects.
        monoModel = dl.bg_strexp
        monoModel.name = 'Stretched exponential'
        doubleModel = dl.bg_sumstrexp
        doubleModel.weight1.ub = 500
        doubleModel.name = "Sum of two stretched exponentials"

        if type == "mono":
            testModels = [monoModel]
        elif type == "double":
            testModels = [doubleModel]
        else:  # type == "auto": try both models
            testModels = [monoModel, doubleModel]

        results = [dl.fit(model, data, self.axis, reg=False, **kwargs)
                   for model in testModels]

        if len(results) == 1:
            best = 0
        else:
            # Select the model with the best coefficient of determination.
            best = int(np.argmax([result.stats['R2'] for result in results]))

        self.fit_result = results[best]
        self.fit_model = testModels[best]
        if len(results) > 1:
            print(f"Selected model: {self.fit_model.description}")

        return self.fit_result

    def plot(self, norm: bool = True, ci=50, axs=None, fig=None) -> Figure:
        """Plot the Carr-Purcell decay with its fit, if available.

        Parameters
        ----------
        norm : bool, optional
            Normalise the data to a maximum of 1, by default True. When False
            the raw signal is plotted and the fit is rescaled to match it.
        ci : int or None, optional
            The percentage confidence interval to plot, by default 50.
            ``None`` skips the confidence band.
        axs : matplotlib Axes, optional
            The axes to plot on. If both ``axs`` and ``fig`` are None a new
            figure and axes are created; if only ``fig`` is given, axes are
            added to it.
        fig : Figure, optional
            The figure to plot to / return.

        Returns
        -------
        Figure
            The figure (None if ``axs`` was given without ``fig``).
        """
        data_max = np.max(self.data)
        data = self.data / data_max if norm else self.data

        if axs is None and fig is None:
            fig, axs = plt.subplots()
        elif axs is None:
            axs = fig.subplots()

        if hasattr(self, "fit_result"):
            x = self.axis
            # The fit was made on max-normalised data; undo that for raw plots.
            scale = self.fit_result.scale if norm else self.fit_result.scale * data_max
            V = self.fit_result.evaluate(self.fit_model, x) * scale
            axs.plot(self.axis, data, '.', label='data', color='0.6', ms=6)
            axs.plot(x, V, label='fit', color=primary_colors[0], lw=2)
            if ci is not None:
                fitUncert = self.fit_result.propagate(self.fit_model, x)
                VCi = fitUncert.ci(ci) * scale
                lb, ub = VCi[:, 0], VCi[:, 1]
                axs.fill_between(x, lb, ub, color=primary_colors[0], alpha=0.3, label=f"{ci}% CI")

            axs.legend()
        else:
            axs.plot(self.axis, data, label='data')

        axs.set_xlabel('Time / us')
        axs.set_ylabel('Normalised Amplitude' if norm else 'Amplitude')
        return fig

    def check_decay(self, level=0.1):
        """Check where the fitted decay falls below ``level``.

        With the default ``level=0.1`` this checks that the data has decayed
        by over 90% in the first half, and by less than 90% in the first
        quarter of the axis.

        Parameters
        ----------
        level : float, optional
            The normalised amplitude threshold, by default 0.1.

        Returns
        -------
        int
            1 if the fitted decay reaches ``level`` within the first quarter
            (a longer decay is needed). 0 if it reaches ``level`` within the
            first half but not the first quarter. -1 if it does not reach
            ``level`` within the first half (the decay is too long).

        Raises
        ------
        ValueError
            If no fit result is available.
        """
        if not hasattr(self, "fit_result"):
            raise ValueError("No fit result found")

        n_points = len(self.axis)
        decay = self.fit_result.evaluate(self.fit_model, self.axis) * self.fit_result.scale
        # Keep at least one point per window so very short traces do not
        # call min() on an empty slice.
        quarter_min = decay[:max(int(n_points * 0.25), 1)].min()
        half_min = decay[:max(int(n_points * 0.50), 1)].min()

        if quarter_min <= level:
            return 1
        if half_min <= level:
            return 0
        return -1

    def find_optimal(
            self, SNR_target, target_time: float, target_step, averages=None) -> float:
        """Calculate the optimal inter pulse delay for a given total
        measurement time.

        The optimum is the axis point where the number of shots available
        per point matches the number needed to reach ``SNR_target``.
        Available shots = total shots / number of points, with
        number of points = delay / step. Needed shots =
        (SNR_target / SNR per sqrt(shot))².

        Parameters
        ----------
        SNR_target : float
            The Signal to Noise ratio target.
        target_time : float
            The target time in hours.
        target_step : float
            The target step size in ns.
        averages : int, optional
            The total number of shots taken, by default None. If None, the
            number of shots is calculated from the dataset as
            ``nAvgs * shots * nPcyc``.

        Notes
        -----
        The shot repetition time is read from ``self.dataset.reptime`` and
        multiplied by 1e-6 to give seconds (presumably stored in us).

        Returns
        -------
        float
            The calculated optimal time in us.
        """
        dataset = self.dataset
        if averages is None:
            averages = dataset.nAvgs * dataset.shots * dataset.nPcyc
        target_shrt = dataset.reptime * 1e-6

        data = np.abs(self.data)
        data = data / np.max(data)

        if hasattr(self, "fit_result"):
            calc_data = self.fit_result.evaluate(self.fit_model, self.axis.data) * self.fit_result.scale
        else:
            calc_data = data

        self.noise = noiselevel(data)
        data_snr = calc_data / self.noise
        data_snr_avgs = data_snr / np.sqrt(averages)

        target_time_s = target_time * 3600
        # The axis is in us, so the step must be converted from ns to us.
        target_step_us = target_step * 1e-3
        # Shots available per point: (total shots) * step / delay.
        g = (target_time_s * target_step_us / target_shrt) / self.axis.data
        # Shots needed per point to reach the SNR target.
        f = (SNR_target / data_snr_avgs) ** 2

        self.optimal = self.axis.data[np.argmin(np.abs(g - f))]
        return self.optimal

    def __call__(self, x, norm=True, SNR=False, source=None):
        """Evaluate the fit or data at a given x value.

        Parameters
        ----------
        x : float or array-like
            The x value(s) to evaluate at. For ``source='data'`` the nearest
            measured point is used.
        norm : bool, optional
            Normalise the data to the maximum, by default True.
        SNR : bool, optional
            Return the SNR_per_sqrt(shot) for this data point, by default
            False. If True, the data is normalised to the maximum of the data.
        source : str, optional
            The source of the data, either 'fit' or 'data', by default None.
            If None, the source is determined by the presence of a fit result.

        Returns
        -------
        float or np.ndarray
            A single value for scalar ``x``, otherwise an array.

        Raises
        ------
        ValueError
            If ``source`` is not 'fit', 'data' or None.
        """
        if source is None:
            source = 'fit' if hasattr(self, 'fit_result') else 'data'

        if source == 'fit':
            V = self.fit_result.evaluate(self.fit_model, x) * self.fit_result.scale
            # The fit is normalised to 1; SNR needs normalised values because
            # self.noise was estimated on the normalised signal.
            if not norm and SNR is not True:
                V = V * np.max(self.data)
        elif source == 'data':
            x_idx = np.abs(np.asarray(self.axis) - x).argmin()
            V = self.data[x_idx]
            if norm is True or SNR is True:
                V = V / np.max(self.data)
        else:
            raise ValueError("source must be 'fit', 'data' or None")

        if SNR is True:
            V = V / self.noise
            V = V / np.sqrt(self.dataset.nAvgs * self.dataset.shots * self.dataset.nPcyc)

        # Return a single value if x is a single value. ravel() handles both a
        # length-1 array and a numpy scalar (which cannot be indexed).
        if np.isscalar(x):
            return np.ravel(V)[0]
        return V

class HahnEchoRelaxationAnalysis():

    def __init__(self, dataset) -> None:
        """Analysis, fitting and plotting for the HahnEchoRelaxation Sequence.

        Parameters
        ----------
        dataset : xarray.DataArray
            The dataset to be analysed, with the time axis contained. The axis
            is the first coordinate found among ``tau1``, ``tau``, ``t`` and
            ``step``, falling back to ``X``. The dataset must provide the
            ``epr`` accessor (``correctphasefull``); ``__call__`` with
            ``SNR=True`` additionally reads ``nAvgs``, ``shots`` and ``nPcyc``.

        Attributes
        ----------
        axis : xr.DataArray
            The time axis representing the interpulse delay.
        data : np.ndarray
            Real part of ``dataset.epr.correctphasefull`` (not normalised).
        dataset : xr.DataArray
            The dataset returned by ``dataset.epr.correctphasefull``.
        noise : float
            Noise level of the max-normalised signal.
        """
        for name in ('tau1', 'tau', 't', 'step'):
            if name in dataset.coords:
                self.axis = dataset[name]
                break
        else:
            self.axis = dataset['X']

        dataset = dataset.epr.correctphasefull
        self.data = dataset.data.real
        self.dataset = dataset

        # The noise is estimated on the max-normalised signal, so any SNR
        # derived from it must use normalised amplitudes as well.
        self.noise = noiselevel(self.data / np.max(self.data))

    def fit(self, type: str = "mono", **kwargs):
        """Fit the experimental Hahn echo decay.

        Parameters
        ----------
        type : str, optional
            ``"mono"`` fits a single stretched exponential and ``"double"`` a
            sum of two stretched exponentials. Any other value (e.g.
            ``"auto"``) fits both and keeps the model with the larger R².
            By default "mono".
        **kwargs
            Passed through to ``deerlab.fit``.

        Returns
        -------
        FitResult
            The deerlab fit result of the selected model. It is also stored as
            ``self.fit_result``, and the model as ``self.fit_model``.
        """
        # Fit a normalised copy. Dividing self.data in place would also
        # rescale the stored raw signal (and, through the ``.real`` view, the
        # dataset it came from), which breaks ``norm=False`` elsewhere.
        data = self.data / np.max(self.data)

        # NOTE(review): these are deerlab's module-level model objects, so the
        # name and bound assignments below are global side effects.
        monoModel = dl.bg_strexp
        monoModel.name = 'Stretched exponential'
        doubleModel = dl.bg_sumstrexp
        doubleModel.weight1.ub = 200
        doubleModel.name = "Sum of two stretched exponentials"

        if type == "mono":
            testModels = [monoModel]
        elif type == "double":
            testModels = [doubleModel]
        else:  # type == "auto": try both models
            testModels = [monoModel, doubleModel]

        results = [dl.fit(model, data, self.axis, reg=False, **kwargs)
                   for model in testModels]

        if len(results) == 1:
            best = 0
        else:
            # Select the model with the best coefficient of determination.
            best = int(np.argmax([result.stats['R2'] for result in results]))

        self.fit_result = results[best]
        self.fit_model = testModels[best]
        if len(results) > 1:
            print(f"Selected model: {self.fit_model.description}")

        return self.fit_result

    def plot(self, norm: bool = True, ci=50, axs=None, fig=None) -> Figure:
        """Plot the Hahn echo decay with its fit, if available.

        Parameters
        ----------
        norm : bool, optional
            Normalise the data to a maximum of 1, by default True. When False
            the raw signal is plotted and the fit is rescaled to match it.
        ci : int or None, optional
            The percentage confidence interval to plot, by default 50.
            ``None`` skips the confidence band.
        axs : matplotlib Axes, optional
            The axes to plot on. If both ``axs`` and ``fig`` are None a new
            figure and axes are created; if only ``fig`` is given, axes are
            added to it.
        fig : Figure, optional
            The figure to plot to / return.

        Returns
        -------
        Figure
            The figure (None if ``axs`` was given without ``fig``).
        """
        data_max = np.max(self.data)
        data = self.data / data_max if norm else self.data

        if axs is None and fig is None:
            fig, axs = plt.subplots()
        elif axs is None:
            axs = fig.subplots()

        if hasattr(self, "fit_result"):
            x = self.axis
            # The fit was made on max-normalised data; undo that for raw plots.
            scale = self.fit_result.scale if norm else self.fit_result.scale * data_max
            V = self.fit_result.evaluate(self.fit_model, x) * scale
            axs.plot(self.axis, data, '.', label='data', color='0.6', ms=6)
            axs.plot(x, V, label='fit', color=primary_colors[0], lw=2)
            if ci is not None:
                fitUncert = self.fit_result.propagate(self.fit_model, x)
                VCi = fitUncert.ci(ci) * scale
                lb, ub = VCi[:, 0], VCi[:, 1]
                axs.fill_between(x, lb, ub, color=primary_colors[0], alpha=0.3, label=f"{ci}% CI")

            axs.legend()
        else:
            axs.plot(self.axis, data, label='data')

        axs.set_xlabel('Time / us')
        axs.set_ylabel('Normalised Amplitude' if norm else 'Amplitude')
        return fig

    def check_decay(self, level=0.1):
        """Check where the fitted decay falls below ``level``.

        With the default ``level=0.1`` this checks that the data has decayed
        by over 90% in the first half, and by less than 90% in the first
        quarter of the axis.

        Parameters
        ----------
        level : float, optional
            The normalised amplitude threshold, by default 0.1.

        Returns
        -------
        int
            1 if the fitted decay reaches ``level`` within the first quarter
            (a longer decay is needed). 0 if it reaches ``level`` within the
            first half but not the first quarter. -1 if it does not reach
            ``level`` within the first half (the decay is too long).

        Raises
        ------
        ValueError
            If no fit result is available.
        """
        if not hasattr(self, "fit_result"):
            raise ValueError("No fit result found")

        n_points = len(self.axis)
        decay = self.fit_result.evaluate(self.fit_model, self.axis) * self.fit_result.scale
        # Keep at least one point per window so very short traces do not
        # call min() on an empty slice.
        quarter_min = decay[:max(int(n_points * 0.25), 1)].min()
        half_min = decay[:max(int(n_points * 0.50), 1)].min()

        if quarter_min <= level:
            return 1
        if half_min <= level:
            return 0
        return -1

    def __call__(self, x, norm=True, SNR=False, source=None):
        """Evaluate the fit or data at a given x value.

        Parameters
        ----------
        x : float or array-like
            The x value(s) to evaluate at. For ``source='data'`` the nearest
            measured point is used.
        norm : bool, optional
            Normalise the data to the maximum, by default True.
        SNR : bool, optional
            Return the SNR_per_sqrt(shot) for this data point, by default
            False. If True, the data is normalised to the maximum of the data,
            matching the normalised signal the noise level was estimated on.
        source : str, optional
            The source of the data, either 'fit' or 'data', by default None.
            If None, the source is determined by the presence of a fit result.

        Returns
        -------
        float or np.ndarray
            A single value for scalar ``x``, otherwise an array.

        Raises
        ------
        ValueError
            If ``source`` is not 'fit', 'data' or None.
        """
        if source is None:
            source = 'fit' if hasattr(self, 'fit_result') else 'data'

        if source == 'fit':
            V = self.fit_result.evaluate(self.fit_model, x) * self.fit_result.scale
            # The fit is normalised to 1; keep it normalised for SNR.
            if not norm and SNR is not True:
                V = V * np.max(self.data)
        elif source == 'data':
            x_idx = np.abs(np.asarray(self.axis) - x).argmin()
            V = self.data[x_idx]
            if norm is True or SNR is True:
                V = V / np.max(self.data)
        else:
            raise ValueError("source must be 'fit', 'data' or None")

        if SNR is True:
            V = V / self.noise
            V = V / np.sqrt(self.dataset.nAvgs * self.dataset.shots * self.dataset.nPcyc)

        # Return a single value if x is a single value. ravel() handles both a
        # length-1 array and a numpy scalar (which cannot be indexed).
        if np.isscalar(x):
            return np.ravel(V)[0]
        return V
        
class ReptimeAnalysis():

    def __init__(self, dataset, sequence: Sequence = None) -> None:
        """Analysis and calculation of Reptime based saturation recovery.

        Parameters
        ----------
        dataset : xr.DataArray
            The dataset to be analyzed, with a ``reptime`` coordinate. Complex
            data is passed through ``dataset.epr.correctphase``.
        sequence : Sequence, optional
            The sequence object describing the experiment. (not currently used)

        Attributes
        ----------
        axis : xr.DataArray
            The ``reptime`` coordinate.
        data : xr.DataArray
            The signal, normalised to a maximum of 1.
        """
        self.axis = dataset['reptime']

        if np.iscomplexobj(dataset.data):
            self.data = dataset.epr.correctphase
        else:
            # Work on a copy: the in-place normalisation below would otherwise
            # rescale the caller's dataset.
            self.data = dataset.copy()

        self.data.data /= np.max(self.data.data)
        self.seq = sequence

    def fit(self, type='SE', **kwargs):
        """Fit the saturation-recovery curve with ``scipy.optimize.curve_fit``.

        Parameters
        ----------
        type : str, optional
            ``'SE'`` for a stretched exponential recovery
            ``A * (1 - exp(-(t/T1)**xi))`` or ``'exp'`` for a plain
            exponential ``A * (1 - exp(-t/T1))``; case-insensitive.
            By default 'SE'.
        **kwargs
            Passed to ``curve_fit``. ``p0`` overrides the default initial
            guess (``[1, 1.8e3, 1]`` for SE, ``[1, 1.8e3]`` for exp).

        Returns
        -------
        tuple
            The ``(popt, pcov)`` result of ``curve_fit``, also stored as
            ``self.fit_result``. The model function is stored as ``self.func``.

        Raises
        ------
        ValueError
            If ``type`` is not a known model.
        """
        model_type = type.lower()
        if model_type == 'se':  # stretched exponential recovery
            def func(t, A, T1, xi):
                return A * (1 - np.exp(-(t / T1) ** xi))
            p0 = [1, 1.8e3, 1]
        elif model_type == 'exp':  # exponential recovery
            def func(t, A, T1):
                return A * (1 - np.exp(-t / T1))
            p0 = [1, 1.8e3]
        else:
            raise ValueError("type must be one of: 'SE', 'exp'")
        self.func = func

        p0 = kwargs.pop('p0', p0)
        self.fit_result = curve_fit(func, self.axis, self.data, p0=p0, **kwargs)

        return self.fit_result

    def plot(self, axs=None, fig=None):
        """Plot the normalised recovery data, and the fit with T1 marker if fitted.

        The x axis is divided by 1e3 and labelled in ms. If
        ``calc_optimal_reptime`` has been run, the optimal reptime is marked
        too.

        Parameters
        ----------
        axs : matplotlib Axes, optional
            The axes to plot on. If both ``axs`` and ``fig`` are None a new
            figure and axes are created; if only ``fig`` is given, axes are
            added to it.
        fig : Figure, optional
            The figure to plot to / return.

        Returns
        -------
        Figure
            The figure (None if ``axs`` was given without ``fig``).
        """
        if axs is None and fig is None:
            fig, axs = plt.subplots()
        elif axs is None:
            axs = fig.subplots()

        data = self.data

        axs.plot(self.axis/1e3, data, '.', label='data', color='0.6', ms=6)

        if hasattr(self, 'fit_result'):
            popt = self.fit_result[0]
            axs.plot(self.axis/1e3, self.func(self.axis, *popt), label='Fit', color=primary_colors[0], lw=2)
            # Freeze the current limits so the vertical markers do not rescale them.
            axs.set_xlim(*axs.get_xlim())
            axs.set_ylim(*axs.get_ylim())
            ylim = axs.get_ylim()
            axs.vlines(popt[1]/1e3, *ylim, linestyles='dashed', label='T1 = {:.3g} ms'.format(popt[1]/1e3), colors=primary_colors[1])

            if hasattr(self, 'optimal'):
                axs.vlines(self.optimal/1e3, *ylim, linestyles='dashed', label='Optimal = {:.3g} ms'.format(self.optimal/1e3), colors=primary_colors[2])

        axs.set_xlabel('Reptime / ms')
        axs.set_ylabel('Normalised signal')
        axs.legend()
        return fig

    def calc_optimal_reptime(self, recovery=0.9):
        """Calculate the optimal reptime from the fit.

        Parameters
        ----------
        recovery : float or None, optional
            If a number in (0, 1), return the reptime at which the fitted
            recovery reaches this fraction of its amplitude, by default 0.9.
            If None, return the axis point that maximises
            ``fit(t) / sqrt(t)``.

        Returns
        -------
        float
            The optimal reptime, also stored as ``self.optimal``.

        Raises
        ------
        ValueError
            If ``recovery`` is not None and not strictly between 0 and 1.
        """
        popt = self.fit_result[0]
        if recovery is not None:
            if not 0 < recovery < 1:
                raise ValueError("recovery must be between 0 and 1 (exclusive)")
            # 1 - exp(-(t/T1)**xi) = recovery  =>  t = T1 * ln(1/(1-recovery))**(1/xi).
            # The plain exponential model has xi = 1.
            T1 = popt[1]
            xi = popt[2] if len(popt) > 2 else 1
            self.optimal = T1 * np.log(1 / (1 - recovery)) ** (1 / xi)
        else:
            t = self.axis
            optimal_vals = self.func(t, *popt) * 1/np.sqrt(t)
            self.optimal = t[np.nanargmax(optimal_vals)]
        return self.optimal

def detect_ESEEM(dataset, type='deuteron', threshold=1.5):
    """Detect if the dataset is an ESEEM experiment.

    The magnitude FFT of the dataset is inspected at the nuclear Larmor
    frequency of the chosen nucleus. ESEEM is reported when the FFT value at
    that frequency exceeds the mean of its +/-8-bin neighbourhood by more than
    ``threshold``.

    Parameters
    ----------
    dataset : xr.DataArray
        The dataset to be analyzed. It must provide ``B`` (the magnetic field)
        and the ``epr.fft`` accessor, whose result has an ``X`` frequency
        coordinate.
    type : str, optional
        The type of ESEEM experiment, either deuteron or proton, by default
        'deuteron'.
    threshold : float, optional
        The SNR threshold for detection, by default 1.5.

    Returns
    -------
    bool
        True if ESEEM is detected, False if not.

    Raises
    ------
    ValueError
        If ``type`` is not 'deuteron' or 'proton'.
    """
    # Gyromagnetic ratios in units of 1e7 rad s^-1 T^-1.
    gyromagnetic_ratio = {'deuteron': 4.10663, 'proton': 26.75221}
    if type not in gyromagnetic_ratio:
        raise ValueError('type must be deuteron or proton')

    # Larmor frequency nu = gamma * B / (2 pi). With gamma in 1e7 rad/s/T and
    # B * 1e-4 in T, the factor 10 / (2 pi) yields MHz. (Using pi / 2 in its
    # place relies on pi**2 ~= 10 and is ~1.3 % low.)
    # NOTE(review): assumes dataset.B is in gauss and fft X is in MHz — confirm.
    larmor_freq = gyromagnetic_ratio[type] * dataset.B * 1e-4 * 10 / (2 * np.pi)

    fft_data = np.abs(dataset.epr.fft)
    index = int(np.abs(fft_data.X - larmor_freq).argmin().data)
    # Local mean over +/-8 bins (peak included). Clamp the start so a peak
    # close to zero frequency does not produce a negative start index, which
    # would wrap around and give an empty window (NaN mean).
    window = fft_data[max(index - 8, 0):index + 8]
    # The common 2/N amplitude normalisation cancels in this ratio.
    peak_to_noise = fft_data[index] / window.mean()

    return bool(peak_to_noise > threshold)
# Default colours for plot_1Drelax, one per plotted analysis (cycled if more).
cmap = ['#D95B6F', '#42A399']
def plot_1Drelax(*args, fig=None, axs=None, cmap=cmap):
    """Create a superimposed plot of relaxation data and fits.

    Each analysis is plotted against its axis multiplied by a sequence-dependent
    factor, giving the total sequence length. The factor is 2 for
    ``T2RelaxationSequence`` (labelled 'Hahn Echo') and 4 for every other
    sequence (labelled 'CP-2').

    Parameters
    ----------
    args : ad.Analysis
        The 1D relaxation analyses to be plotted. Each needs ``axis``, ``data``
        and ``dataset.seq_name``. A fit is drawn only if the analysis has one:
        either a scipy ``func`` or a deerlab ``fit_result``/``fit_model``.
    fig : Figure, optional
        The figure to plot to, by default None.
    axs : Axes, optional
        The axes to plot to, by default None.
    cmap : list, optional
        The colours to use, cycled if there are more analyses than colours,
        by default ad.cmap.

    Returns
    -------
    Figure
        The figure (None if ``axs`` was given without ``fig``).
    """
    if fig is None and axs is None:
        fig, axs = plt.subplots(1, 1, figsize=(5, 5))
    elif axs is None:
        axs = fig.subplots(1, 1)

    for i, arg in enumerate(args):
        color = cmap[i % len(cmap)]
        if arg.dataset.seq_name == 'T2RelaxationSequence':
            xscale, label = 2, 'Hahn Echo'
        else:  # CarrPurcellSequence, DEERSequence, 5pDEER and anything else
            xscale, label = 4, 'CP-2'

        axs.plot(arg.axis*xscale, arg.data/arg.data.max(), '.', label=label, alpha=0.5, color=color, mec='none')

        if hasattr(arg, 'func'):
            print('The scipy fitting elements are being deprecated, please use DeerLab fitting')
            V = arg.func(arg.axis, *arg.fit_result[0])
        elif hasattr(arg, 'fit_result'):
            V = arg.fit_model(arg.axis, *arg.fit_result.param[:-1])*arg.fit_result.scale
        else:
            continue  # not fitted: plot the data only

        axs.plot(arg.axis*xscale, V, '-', alpha=1, color=color, lw=2)

    axs.legend()
    axs.set_xlabel(r'Total Sequence Length / $\mu s$')
    axs.set_ylabel('Signal / $ A.U. $')
    return fig