Source code for autodeer.dataset


import numpy as np
import matplotlib.pyplot as plt
import re
import os
import uuid
import json
import base64
import numbers
from scipy.signal import hilbert
from scipy.io import savemat
import scipy.fft as fft
from autodeer.utils import autoEPRDecoder
from autodeer import __version__
from autodeer.pulses import Pulse,Detection
from autodeer.classes import Parameter, Interface
from autodeer.sequences import Sequence, DEERSequence
import autodeer.pulses as ad_pulses
import autodeer.sequences as ad_seqs
import xarray as xr
from deerlab import correctphase
from deerlab import deerload
import copy

# =============================================================================

[docs] def get_all_axes(sequence): # loop over all parameters and get the axis axes = {} for param_name in sequence.__dict__: param = sequence.__dict__[param_name] if not isinstance(param, Parameter): continue if param.axis == []: continue for axis in param.axis: if axis['uuid'] in sequence.axes_uuid and not axis['uuid'] in sequence.reduce_uuid: ax_id = sequence.axes_uuid.index(axis['uuid']) if param.unit == 'ns': convert = 1e-3 elif param.unit == 'us': convert = 1 elif param.unit == 'ms': convert = 1e3 else: convert = 1 axes[param_name] = {'axis': axis['axis']*convert + param.value*convert, 'ax_id':ax_id} for i,pulse in enumerate(sequence.pulses): for param_name in pulse.__dict__: param = pulse.__dict__[param_name] if not isinstance(param, Parameter): continue if param.axis == []: continue for axis in param.axis: if axis['uuid'] in sequence.axes_uuid and not axis['uuid'] in sequence.reduce_uuid: ax_id = sequence.axes_uuid.index(axis['uuid']) axes[f"pulse{i}_{param_name}"] = {'axis': axis['axis'] + param.value, 'ax_id':ax_id} return axes
[docs] def get_all_fixed_param(sequence): fixed_param = {} for param_name in sequence.__dict__: param = sequence.__dict__[param_name] if param_name == 'name': if param is not None: fixed_param[f"seq_{param_name}"] = param continue elif not isinstance(param, Parameter): continue else: if (param.axis == []) and (param.value is not None): fixed_param[param_name] = param.value elif param.axis[0]['uuid'] in sequence.reduce_uuid: fixed_param[param_name] = param.value for i,pulse in enumerate(sequence.pulses): if isinstance(pulse, Detection): type="det" else: type="pulse" for param_name in pulse.__dict__: param = pulse.__dict__[param_name] if param_name == 'name': if param is not None: fixed_param[f"{type}{i}_{param_name}"] = param continue elif not isinstance(param, Parameter): continue else: if (param.axis == []) and (param.value is not None): fixed_param[f"{type}{i}_{param_name}"] = param.value fixed_param['nPcyc'] = sequence.pcyc_dets.shape[0] if isinstance(sequence,DEERSequence): fixed_param['pcyc_name'] = sequence.pcyc_name return fixed_param
[docs] def create_dataset_from_sequence(data, sequence: Sequence,extra_params={}): ndims = data.ndim default_labels = ['X','Y','Z','T'] dims = default_labels[:ndims] axes = get_all_axes(sequence) coords = {a:(default_labels[b['ax_id']],b['axis']) for a,b in axes.items()} attr = get_all_fixed_param(sequence) attr.update(extra_params) return xr.DataArray(data, dims=dims, coords=coords,attrs=attr)
[docs] def create_dataset_from_axes(data, axes, params: dict = None,axes_labels=None): ndims = data.ndim if axes_labels is None: default_labels = ['X','Y','Z','T'] elif len(axes_labels) < ndims: default_labels = axes_labels + ['X','Y','Z','T'] else: default_labels = axes_labels dims = default_labels[:ndims] if not isinstance(axes, list): axes = [axes] coords = {default_labels.pop(0):a for a in axes} return xr.DataArray(data, dims=dims, coords=coords, attrs=params)
[docs] def create_dataset_from_bruker(filepath): axes, data, params = deerload(filepath, plot=False, full_output=True) if not isinstance(axes, list): axes = [axes] ndims = data.ndim default_labels = ['X','Y','Z','T'] dims = default_labels[:ndims] labels = [] for i in range(ndims): ax_label = default_labels[i] axis_string = params['DESC'][f'{ax_label}UNI'] if "'" in axis_string: axis_string = axis_string.replace("'", "") if axis_string == 'G': labels.append('B') axes[i] = axes[i] * 1e3 elif axis_string == 'ns': labels.append('t') else: labels.append(None) # Count occurens of each label label_count = {i:labels.count(i) for i in labels} # Add a number to each label if count > 1 for i in range(ndims): if label_count[labels[i]] > 1: labels[i] = default_labels[i] + labels[i] coords = {labels[i]:(default_labels[i],a) for i,a in enumerate(axes)} attr = {} attr['LO'] = float(params['SPL']['MWFQ']) / 1e9 attr['B'] = float(params['SPL']['B0VL']) * 1e4 attr['reptime'] = float(params['DSL']['ftEpr']['ShotRepTime'].replace('us','')) attr['nAvgs'] = int(params['DSL']['recorder']['NbScansAcc']) attr['shots'] = int(params['DSL']['ftEpr']['ShotsPLoop']) return xr.DataArray(data, dims=dims, coords=coords, attrs=attr)
@xr.register_dataarray_accessor("epr")
[docs] class EPRAccessor: def __init__(self, xarray_obj):
[docs] self._obj = xarray_obj
[docs] def save(self, filename,type='netCDF'): if type == 'netCDF': self._obj.to_netcdf(f"{filename}.h5",engine='h5netcdf',invalid_netcdf=True)
@property
[docs] def correctphase(self): new_obj = copy.deepcopy(self._obj) if np.iscomplexobj(self._obj.data): corr_data = correctphase(self._obj.data) new_obj.data = corr_data else: UserWarning("Data is not complex, phase correction not applied") return new_obj
@property
[docs] def normalise(self): self._obj.data = self._obj.data / np.abs(self._obj.data).max() return self._obj
@property
[docs] def correctphasefull(self): new_obj = copy.deepcopy(self._obj) if np.iscomplexobj(self._obj.data): Re,Im,_ = correctphase(self._obj.data,full_output=True) new_obj.data = Re + 1j*Im else: UserWarning("Data is not complex, phase correction not applied") return new_obj
@property
[docs] def SNR(self): from deerlab import der_snr norm_data = self._obj.data / np.abs(self._obj.data).max() return 1/der_snr(norm_data)
@property
[docs] def fft(self): new_obj = copy.deepcopy(self._obj) new_obj.data = fft.fftshift(fft.fft(self._obj.data)) new_coords = {} for key, coord in new_obj.coords.items(): new_coords[key] = (coord.dims[0],fft.fftshift(fft.fftfreq(coord.size, coord[1].data-coord[0].data))) new_obj = new_obj.assign_coords(**new_coords) return new_obj
@property
[docs] def sequence(self): dataset_attrs = self._obj.attrs dataset_coords = self._obj.coords pulses = len([key for key in dataset_attrs.keys() if re.match(r"pulse\d+_name$", key)]) seq_param_types = ['seq_name','B','LO','reptime','shots','averages','det_window'] seq_params = {} for param_type in seq_param_types: if param_type in dataset_attrs: seq_params[param_type] = dataset_attrs[param_type] elif param_type in dataset_coords: coord = dataset_coords[param_type] min = coord.min() dim = coord.shape[0] step = coord[1] - coord[0] seq_params[param_type] = Parameter(name = param_type, value = min, dim=dim, step=step) seq_params['name'] = seq_params.pop('seq_name') pulses_obj = [] for i in range(pulses): pulse_type = dataset_attrs[f"pulse{i}_name"] param_types = ['t','tp','freq','flipangle','scale','order1','order2','init_freq','BW','final_freq','beta'] params = {} for param_type in param_types: if f"pulse{i}_{param_type}" in dataset_attrs: params[param_type] = dataset_attrs[f"pulse{i}_{param_type}"] elif f"pulse{i}_{param_type}" in dataset_coords: coord = dataset_coords[f"pulse{i}_{param_type}"] min = coord.min() dim = coord.shape[0] step = coord[1] - coord[0] params[param_type] = Parameter(name = param_type, value = min, dim=dim, step=step) try: pulse_build = getattr(ad_pulses,pulse_type) pulses_obj.append(pulse_build(**params)) except: continue sequence = Sequence(**seq_params) sequence.pulses = pulses_obj return sequence