# Source code for pyepr.dataset

import numpy as np
import matplotlib.pyplot as plt
import re
import scipy.fft as fft
from pyepr import __version__
from pyepr.pulses import Detection
from pyepr.classes import Parameter
from pyepr.sequences import Sequence
import pyepr.pulses as ad_pulses
import xarray as xr
from deerlab import correctphase
from deerlab import deerload
import copy

# =============================================================================

def get_all_axes(sequence):
    """Collect every swept (non-reduced) axis of a sequence.

    Scans the sequence-level attributes and each pulse's attributes for
    ``Parameter`` objects that carry an axis, skipping axes that are
    reduced (averaged out).

    Parameters
    ----------
    sequence : Sequence
        The pulse sequence to inspect.

    Returns
    -------
    dict
        Mapping of parameter name (``"pulse{i}_{name}"`` for pulse
        parameters) to ``{'axis': ndarray, 'ax_id': int}``, where
        ``ax_id`` is the index of the axis uuid in ``sequence.axes_uuid``.
    """
    # Scale factors for time-like sequence parameters; anything else
    # (including 'us') is left unscaled.
    unit_scale = {'ns': 1e-3, 'us': 1, 'ms': 1e3}
    axes = {}

    # Sequence-level parameters: both the axis and the offset value are
    # converted with the same unit factor.
    for param_name, param in sequence.__dict__.items():
        if not isinstance(param, Parameter) or param.axis == []:
            continue
        for axis in param.axis:
            if axis['uuid'] in sequence.axes_uuid and axis['uuid'] not in sequence.reduce_uuid:
                ax_id = sequence.axes_uuid.index(axis['uuid'])
                convert = unit_scale.get(param.unit, 1)
                axes[param_name] = {
                    'axis': axis['axis'] * convert + param.value * convert,
                    'ax_id': ax_id}

    # Pulse-level parameters.
    # NOTE(review): unlike sequence-level axes, pulse axes are NOT
    # unit-converted here — confirm this asymmetry is intentional.
    for i, pulse in enumerate(sequence.pulses):
        for param_name, param in pulse.__dict__.items():
            if not isinstance(param, Parameter) or param.axis == []:
                continue
            for axis in param.axis:
                if axis['uuid'] in sequence.axes_uuid and axis['uuid'] not in sequence.reduce_uuid:
                    ax_id = sequence.axes_uuid.index(axis['uuid'])
                    axes[f"pulse{i}_{param_name}"] = {
                        'axis': axis['axis'] + param.value,
                        'ax_id': ax_id}
    return axes
def get_all_fixed_param(sequence):
    """Collect all scalar (non-swept) parameters of a sequence.

    Sequence-level ``Parameter`` attributes with no axis (or whose axis is
    reduced) are stored under their own name; pulse parameters are stored
    under ``"pulse{i}_{name}"`` (``"det{i}_{name}"`` for detection events).
    The phase-cycle size is stored as ``'nPcyc'``.

    Parameters
    ----------
    sequence : Sequence
        The pulse sequence to inspect.

    Returns
    -------
    dict
        Flat mapping of parameter name to its scalar value.
    """
    fixed_param = {}

    # Sequence-level parameters.
    for param_name, param in sequence.__dict__.items():
        if param_name == 'name':
            if param is not None:
                fixed_param[f"seq_{param_name}"] = param
            continue
        if not isinstance(param, Parameter):
            continue
        if (param.axis == []) and (param.value is not None):
            fixed_param[param_name] = param.value
        # Guard `param.axis` before indexing: the original fell through to
        # `param.axis[0]` even when the axis list was empty (IndexError
        # whenever axis == [] and value is None).
        elif param.axis and param.axis[0]['uuid'] in sequence.reduce_uuid:
            fixed_param[param_name] = param.value

    # Pulse-level parameters; detection events get a "det" prefix.
    for i, pulse in enumerate(sequence.pulses):
        prefix = "det" if isinstance(pulse, Detection) else "pulse"
        for param_name, param in pulse.__dict__.items():
            if param_name == 'name':
                if param is not None:
                    fixed_param[f"{prefix}{i}_{param_name}"] = param
                continue
            if not isinstance(param, Parameter):
                continue
            if (param.axis == []) and (param.value is not None):
                fixed_param[f"{prefix}{i}_{param_name}"] = param.value

    # Number of phase-cycle steps, plus the cycle's name when present.
    fixed_param['nPcyc'] = sequence.pcyc_dets.shape[0]
    if hasattr(sequence, 'pcyc_name'):
        fixed_param['pcyc_name'] = sequence.pcyc_name

    return fixed_param
def create_dataset_from_sequence(data, sequence: Sequence, extra_params=None):
    """Build an ``xarray.DataArray`` from raw data and its sequence.

    Swept axes become coordinates on the default dims ('X','Y','Z','T');
    all fixed parameters of the sequence become attributes.

    Parameters
    ----------
    data : numpy.ndarray
        The measured data, up to 4-dimensional.
    sequence : Sequence
        The pulse sequence that produced the data.
    extra_params : dict, optional
        Additional attributes to attach to the dataset.

    Returns
    -------
    xarray.DataArray
        The annotated dataset, including the package version.
    """
    ndims = data.ndim
    default_labels = ['X', 'Y', 'Z', 'T']
    dims = default_labels[:ndims]

    axes = get_all_axes(sequence)
    coords = {name: (default_labels[info['ax_id']], info['axis'])
              for name, info in axes.items()}

    attr = get_all_fixed_param(sequence)
    # `extra_params` defaults to None rather than a mutable `{}` default.
    if extra_params:
        attr.update(extra_params)
    attr.update({'autoDEER_Version': __version__})

    return xr.DataArray(data, dims=dims, coords=coords, attrs=attr)
def create_dataset_from_axes(data, axes, params: dict = None, axes_labels=None):
    """Build an ``xarray.DataArray`` from data and explicit axes.

    Parameters
    ----------
    data : numpy.ndarray
        The measured data, up to 4-dimensional.
    axes : numpy.ndarray or list of numpy.ndarray
        One coordinate array per data dimension.
    params : dict, optional
        Attributes to attach to the dataset. The caller's dict is not
        modified (the original mutated it and crashed when it was None).
    axes_labels : list of str, optional
        Names for the leading dims/coords; padded with the defaults
        ('X','Y','Z','T') when fewer labels than dimensions are given.

    Returns
    -------
    xarray.DataArray
        The annotated dataset, including the package version.
    """
    ndims = data.ndim
    # Copy caller-provided label lists so pop() below can't mutate them.
    if axes_labels is None:
        default_labels = ['X', 'Y', 'Z', 'T']
    elif len(axes_labels) < ndims:
        default_labels = list(axes_labels) + ['X', 'Y', 'Z', 'T']
    else:
        default_labels = list(axes_labels)
    dims = default_labels[:ndims]

    if not isinstance(axes, list):
        axes = [axes]
    # Consume labels in order, one per coordinate array.
    coords = {default_labels.pop(0): ax for ax in axes}

    # Work on a copy so the caller's params dict is left untouched;
    # tolerate params=None (the original raised AttributeError here).
    attrs = {} if params is None else dict(params)
    attrs.update({'autoDEER_Version': __version__})

    return xr.DataArray(data, dims=dims, coords=coords, attrs=attrs)
def create_dataset_from_bruker(filepath):
    """Load a Bruker file with ``deerload`` into an ``xarray.DataArray``.

    Axis units from the Bruker description block pick the coordinate
    names: 'G' axes become 'B' (scaled by 1e3), 'ns' axes become 't';
    anything else falls back to the positional dim name ('X','Y',...).

    Parameters
    ----------
    filepath : str
        Path to the Bruker data file.

    Returns
    -------
    xarray.DataArray
        Data with named coordinates and the key spectrometer settings
        (LO, B, reptime, nAvgs, shots) as attributes.
    """
    axes, data, params = deerload(filepath, plot=False, full_output=True)
    if not isinstance(axes, list):
        axes = [axes]

    ndims = data.ndim
    default_labels = ['X', 'Y', 'Z', 'T']
    dims = default_labels[:ndims]

    # Map each axis unit string to a coordinate label.
    labels = []
    for i in range(ndims):
        axis_string = params['DESC'][f'{default_labels[i]}UNI']
        if "'" in axis_string:
            axis_string = axis_string.replace("'", "")
        if axis_string == 'G':
            labels.append('B')
            axes[i] = axes[i] * 1e3
        elif axis_string == 'ns':
            labels.append('t')
        else:
            labels.append(None)

    # Count occurrences of each label.
    label_count = {label: labels.count(label) for label in labels}
    for i in range(ndims):
        if labels[i] is None:
            # Unrecognised unit: use the positional dim name directly.
            # (The original kept None, which broke both duplicate handling
            # — 'X' + None raised TypeError — and the coord key.)
            labels[i] = default_labels[i]
        elif label_count[labels[i]] > 1:
            # Disambiguate duplicate labels with the dim-name prefix.
            labels[i] = default_labels[i] + labels[i]

    coords = {labels[i]: (default_labels[i], ax) for i, ax in enumerate(axes)}

    attr = {}
    attr['LO'] = float(params['SPL']['MWFQ']) / 1e9          # Hz -> GHz
    attr['B'] = float(params['SPL']['B0VL']) * 1e4           # T -> G (presumably; confirm)
    attr['reptime'] = float(params['DSL']['ftEpr']['ShotRepTime'].replace('us', ''))
    attr['nAvgs'] = int(params['DSL']['recorder']['NbScansAcc'])
    attr['shots'] = int(params['DSL']['ftEpr']['ShotsPLoop'])
    attr.update({'autoDEER_Version': __version__})

    return xr.DataArray(data, dims=dims, coords=coords, attrs=attr)
@xr.register_dataarray_accessor("epr")
class EPRAccessor:
    """EPR convenience accessor for ``xarray.DataArray`` (``da.epr.*``)."""

    def __init__(self, xarray_obj):
        self._obj = xarray_obj

    def save(self, filename, type='netCDF'):
        """Save the dataset as a netCDF (HDF5) file at ``{filename}.h5``.

        The keyword is named ``type`` for backward compatibility with
        existing callers, despite shadowing the builtin.
        """
        if type == 'netCDF':
            # The filename argument was previously ignored (garbled
            # hard-coded path); it is now used as documented.
            self._obj.to_netcdf(f"{filename}.h5",
                                engine='h5netcdf', invalid_netcdf=True)

    @property
    def correctphase(self):
        """Return a phase-corrected copy of the dataset (real-valued)."""
        new_obj = copy.deepcopy(self._obj)
        if np.iscomplexobj(self._obj.data):
            new_obj.data = correctphase(self._obj.data)
        else:
            # The original built a UserWarning without emitting it.
            import warnings
            warnings.warn("Data is not complex, phase correction not applied")
        return new_obj

    @property
    def normalise(self):
        """Normalise the data to a maximum magnitude of 1, in place."""
        self._obj.data = self._obj.data / np.abs(self._obj.data).max()
        return self._obj

    @property
    def correctphasefull(self):
        """Return a phase-corrected copy keeping both quadratures."""
        new_obj = copy.deepcopy(self._obj)
        if np.iscomplexobj(self._obj.data):
            Re, Im, _ = correctphase(self._obj.data, full_output=True)
            new_obj.data = Re + 1j * Im
        else:
            # The original built a UserWarning without emitting it.
            import warnings
            warnings.warn("Data is not complex, phase correction not applied")
        return new_obj

    @property
    def SNR(self):
        """Estimate the signal-to-noise ratio via ``deerlab.der_snr``."""
        from deerlab import der_snr
        norm_data = self._obj.data / np.abs(self._obj.data).max()
        return 1 / der_snr(norm_data)

    @property
    def fft(self):
        """Return the FFT of the dataset with frequency coordinates.

        Assumes each coordinate is uniformly spaced with at least two
        points (the spacing is taken from its first two entries).
        """
        new_obj = copy.deepcopy(self._obj)
        new_obj.data = fft.fftshift(fft.fft(self._obj.data))
        new_coords = {}
        for key, coord in new_obj.coords.items():
            spacing = coord[1].data - coord[0].data
            new_coords[key] = (
                coord.dims[0],
                fft.fftshift(fft.fftfreq(coord.size, spacing)))
        return new_obj.assign_coords(**new_coords)

    @property
    def MeasurementTime(self):
        """Calculate the total measurement time in seconds."""
        # reptime is stored in us (see get_all_fixed_param / bruker loader),
        # hence the 1e-6 factor.
        return (self._obj.reptime * 1e-6 * self._obj.shots
                * self._obj.nAvgs * self._obj.nPcyc)

    @property
    def sequence(self):
        """Reconstruct a ``Sequence`` object from the dataset metadata.

        Scalar parameters are read from the attributes; swept parameters
        are rebuilt from the coordinates as ``Parameter`` objects with
        value/dim/step. Pulses that cannot be rebuilt are skipped.
        """
        dataset_attrs = self._obj.attrs
        dataset_coords = self._obj.coords

        # Count pulse and detection events from the attribute names.
        pulses = len([key for key in dataset_attrs
                      if re.match(r"pulse\d+_name$", key)])
        det_events = len([key for key in dataset_attrs
                          if re.match(r"det\d+_t$", key)])
        n_events = pulses + det_events

        # Sequence-level parameters: scalar from attrs, swept from coords.
        seq_param_types = ['seq_name', 'B', 'LO', 'reptime', 'shots',
                          'averages', 'det_window']
        seq_params = {}
        for param_type in seq_param_types:
            if param_type in dataset_attrs:
                seq_params[param_type] = dataset_attrs[param_type]
            elif param_type in dataset_coords:
                coord = dataset_coords[param_type]
                min_value = coord.min()  # renamed: was shadowing builtin min
                dim = coord.shape[0]
                step = coord[1] - coord[0]
                seq_params[param_type] = Parameter(
                    name=param_type, value=min_value, dim=dim, step=step)
        seq_params['name'] = seq_params.pop('seq_name')

        pulses_obj = []
        for i in range(n_events):
            if f"pulse{i}_t" in dataset_attrs:
                pulse_type = dataset_attrs[f"pulse{i}_name"]
                key = "pulse"
            elif f"det{i}_t" in dataset_attrs:
                pulse_type = 'Detection'
                key = "det"
            else:
                # No event with this index: skip instead of silently
                # reusing the previous iteration's key/pulse_type.
                continue

            param_types = ['t', 'tp', 'freq', 'flipangle', 'scale',
                           'order1', 'order2', 'init_freq', 'BW',
                           'final_freq', 'beta']
            params = {}
            for param_type in param_types:
                if f"{key}{i}_{param_type}" in dataset_attrs:
                    params[param_type] = dataset_attrs[f"{key}{i}_{param_type}"]
                elif f"{key}{i}_{param_type}" in dataset_coords:
                    coord = dataset_coords[f"{key}{i}_{param_type}"]
                    min_value = coord.min()
                    dim = coord.shape[0]
                    step = coord[1] - coord[0]
                    params[param_type] = Parameter(
                        name=param_type, value=min_value, dim=dim, step=step)
            try:
                pulse_build = getattr(ad_pulses, pulse_type)
                pulses_obj.append(pulse_build(**params))
            except Exception:
                # Best effort: unknown pulse types or incompatible
                # parameters are skipped (was a bare except).
                continue

        sequence = Sequence(**seq_params)
        sequence.pulses = pulses_obj
        return sequence

    def merge(self, other):
        """Merge two datasets into one dataset.

        Handles the following cases:
        1. Both datasets have the same parameters but different axes
           and are 1D.

        Parameters
        ----------
        other : xarray.DataArray
            The dataset to merge with this one.

        Returns
        -------
        xarray.DataArray
            The concatenated dataset, sorted along the first coordinate
            (preserving this dataset's sort direction) and carrying this
            dataset's attributes.

        Raises
        ------
        ValueError
            If either dataset is not 1D, or a shared key parameter
            differs between the two.
        """
        dataarray1: xr.DataArray = self._obj
        dataarray2: xr.DataArray = other

        # Only the 1D case is supported.
        if len(dataarray1.dims) != 1 or len(dataarray2.dims) != 1:
            raise ValueError("Both datasets must be 1D")

        # Acquisition parameters must agree where both datasets have them.
        keys_check = ['B', 'LO', 'reptime', 'shots', 'nAvgs', 'nPcyc',
                      'pcyc_name']
        for key in keys_check:
            if key in dataarray1.attrs and key in dataarray2.attrs:
                if dataarray1.attrs[key] != dataarray2.attrs[key]:
                    raise ValueError(
                        f"Datasets have different values for {key}, cannot merge")
            elif key in dataarray1.attrs:
                print(f"Parameter {key} not found in dataset 2")
            elif key in dataarray2.attrs:
                print(f"Parameter {key} not found in dataset 1")

        new_data = np.concatenate((dataarray1.data, dataarray2.data), axis=0)
        new_coords = {}
        for key, coord in dataarray1.coords.items():
            new_coords[key] = (
                coord.dims,
                np.concatenate((coord, dataarray2.coords[key]), axis=0))

        # Sort everything by the first coordinate, keeping the original
        # ascending/descending direction of dataset 1.
        first_coord = new_coords[list(new_coords.keys())[0]][1]
        sort_dir = first_coord[-1] - first_coord[0]
        sort_idx = np.argsort(first_coord)
        if sort_dir < 0:
            sort_idx = np.flip(sort_idx)

        new_data = new_data[sort_idx]
        for key in new_coords:
            new_coords[key] = (new_coords[key][0], new_coords[key][1][sort_idx])

        return xr.DataArray(new_data, dims=dataarray1.dims,
                            coords=new_coords, attrs=dataarray1.attrs)