Source code for pyepr.pulses

from pyepr.classes import Parameter
from pyepr.utils import build_table, sop, autoEPRDecoder
from memoization import cached
import numpy as np
from scipy.integrate import cumulative_trapezoid
import numbers
import uuid
import json
import base64
import matplotlib.pyplot as plt
import scipy.fft as fft
from pyepr import __version__
import copy
from functools import reduce
from itertools import accumulate

class Pulse:
    """Represents a general experimental pulse."""

    def __init__(self, *, tp, t=None, scale=None, flipangle=None, pcyc=[0],
                 name=None, **kwargs) -> None:
        """The class for a general pulse.

        Parameters
        ----------
        tp : float or Parameter
            The pulse length in ns.
        t : float or Parameter, optional
            The pulse start time in ns. When omitted ``self.t`` is ``None``.
        scale : float or Parameter, optional
            The arbitrary experimental pulse amplitude, 0-1. When omitted a
            ``scale`` Parameter holding ``None`` is created.
        flipangle : optional
            The target flip angle of the spins. The ``flipangle`` attribute
            is only created when this is not ``None``.
        pcyc : list or dict or None, optional
            The phase cycle. A dict must provide the keys ``"phases"`` and
            ``"dets"``; any other non-``None`` value is used as the list of
            phases with every detection sign set to one. ``None`` disables
            phase cycling. By default ``[0]``.
        name : str, optional
            The name of the pulse.
        **kwargs
            Accepted so that subclasses can forward extra keyword arguments;
            they are not used here.
        """
        # The ``[0]`` default is safe to share: _addPhaseCycle stores a copy.
        if t is None:
            self.t = None
        elif isinstance(t, numbers.Number):
            self.t = Parameter("t", t, "ns", "Start time of pulse")
        elif isinstance(t, Parameter):
            self.t = t

        if isinstance(tp, numbers.Number):
            self.tp = Parameter("tp", tp, "ns", "Length of the pulse")
        elif isinstance(tp, Parameter):
            self.tp = tp

        if isinstance(scale, numbers.Number):
            self.scale = Parameter("scale", scale, None, "Amplitude of pulse")
        elif isinstance(scale, Parameter):
            self.scale = scale
        elif scale is None:
            self.scale = Parameter("scale", None, None, "Amplitude of pulse")

        self.name = name
        self.Progression = False

        if flipangle is not None:
            self.flipangle = Parameter(
                "flipangle", flipangle, None,
                "The target flip angle of the spins")

        if pcyc is None:
            self.pcyc = None
        elif type(pcyc) is dict:
            self._addPhaseCycle(pcyc["phases"], detections=pcyc["dets"])
        else:
            self._addPhaseCycle(pcyc, detections=None)

    @property
    def bandwidth(self):
        """The bandwidth of the pulse in GHz.

        Exposed as a property because the rest of the module reads it as
        ``self.bandwidth.value``.

        Returns
        -------
        Parameter
            Zero for pulses that define ``freq``; otherwise the absolute
            difference between ``final_freq`` and ``init_freq``.

        Raises
        ------
        ValueError
            If the pulse defines neither ``freq`` nor both ``init_freq`` and
            ``final_freq``.
        """
        BW = Parameter("Bandwidth", 0, "GHz", "Bandwidth of pulse")
        if hasattr(self, "freq"):
            BW.value = 0
        elif hasattr(self, "init_freq") and hasattr(self, "final_freq"):
            BW.value = np.abs(self.final_freq.value - self.init_freq.value)
        else:
            raise ValueError(
                "Pulse has no frequency information to derive a bandwidth from")
        return BW
    def _addPhaseCycle(self, phases, detections=None):
            Adds a phase cycle to the pulse sequence.

                phases (list): List of phases to add to the phase cycle.
                detections (list, optional): List of detection signs. Defaults to None. If None then all cycles are summed.

            if detections is None:
                detections = np.ones(len(phases))
            self.pcyc = {"Phases": list(phases), "DetSigns": list(detections)}

    def _buildFMAM(self, func, ax=None):
        Builds the amplitude modulation (AM) and frequency modulation (FM) of a given function.

            func: A function that takes in an array of values and returns two arrays, representing the AM and FM of the function.

            Two arrays representing the AM and FM of the function.
        """ = np.linspace(,, 1001)
        dt =[1][0]
        self.AM, self.FM = func(
        FM_arg = 2*np.pi*cumulative_trapezoid(self.FM, initial=0) * dt
        self.complex = self.AM * (np.cos(FM_arg) +1j* np.sin(FM_arg))
        return self.AM, self.FM
    def build_shape(self, ax=None):
        """Build the complex pulse shape, zero-padded to fill a time axis.

        Parameters
        ----------
        ax : np.ndarray, optional
            Evenly spaced time axis in ns. Defaults to ``self.ax``.

        Returns
        -------
        np.ndarray
            Complex pulse shape with as many points as ``ax``; the pulse is
            centred and the remainder is zero.

        Raises
        ------
        ValueError
            If the pulse is longer than the supplied axis.
        """
        if ax is None:
            ax = self.ax
        dt = ax[1] - ax[0]
        tp = self.tp.value
        pulse_points = int(np.around(tp / dt))
        remaining_points = ax.shape[0] - pulse_points
        if remaining_points < 0:
            raise ValueError("The pulse is longer than the supplied axis")

        # Evaluate the shape on a zero-centred axis spanning the pulse. An
        # all-zero axis would sample every shaped pulse at t=0 only.
        pulse_axis = np.linspace(-tp / 2, tp / 2, pulse_points)
        AM, FM = self.func(pulse_axis)

        pad_width = (remaining_points // 2, remaining_points - remaining_points // 2)
        AM = np.pad(AM, pad_width, 'constant', constant_values=0)
        FM = np.pad(FM, pad_width, 'constant', constant_values=0)
        FM_arg = 2 * np.pi * cumulative_trapezoid(FM, initial=0) * dt
        return AM * (np.cos(FM_arg) + 1j * np.sin(FM_arg))

    def build_table(self):
        """Build a table of all non-static Parameters in the object.

        Returns
        -------
        dict
            A dictionary with the keys ``"Variable"``, ``"axis"`` and
            ``"uuid"``. Each value is a list with one entry per axis of each
            non-static Parameter.
        """
        progTable = {"Variable": [], "axis": [], "uuid": []}
        for arg in vars(self):
            attr = getattr(self, arg)
            if type(attr) is Parameter and not attr.is_static():
                for i, axes in enumerate(attr.axis):
                    # NOTE(review): the loop body was lost from this copy of
                    # the file. It is assumed each axis entry is a mapping
                    # with a "uuid" key -- confirm against Parameter.
                    progTable["Variable"].append(arg)
                    progTable["axis"].append(i)
                    progTable["uuid"].append(axes["uuid"])
        return progTable

    def is_static(self):
        """Check if all parameters in the pulse object are static.

        Returns
        -------
        bool
            True if every Parameter attribute is static, False otherwise.
        """
        for arg in vars(self):
            attr = getattr(self, arg)
            if type(attr) is Parameter and not attr.is_static():
                return False
        return True

    def isDelayFocused(self):
        """Does the pulse contain a specified time, `t`?

        If so then it is not delay focused.

        Returns
        -------
        bool
            True when ``self.t`` is ``None``, False otherwise.
        """
        return self.t is None
    def isPulseFocused(self):
        """Does the pulse contain a specified time, `t`?

        If so then it is pulse focused.

        Returns
        -------
        bool
            True when ``self.t`` is not ``None``, False otherwise.
        """
        return self.t is not None

    def plot(self, pad=1000):
        """Plots the time domain representation of this pulse.

        Parameters
        ----------
        pad : int, optional
            The number of zeros to pad the data with, by default 1000
        """
        dt = self.ax[1] - self.ax[0]
        if self.isPulseFocused():
            # Shift the pulse axis to its position in the sequence.
            tax = self.t.value + self.ax
        else:
            tax = self.ax

        # Extend the axis by `pad` points on each side with spacing dt. The
        # leading extension stops one step before tax[0] so that the first
        # pulse point is not duplicated.
        fwd_ex = np.linspace(tax[0] - dt * pad, tax[0] - dt, pad)
        rev_ex = np.linspace(
            tax[-1] + dt, tax[-1] + dt * pad, pad, endpoint=True)
        tax = np.concatenate((fwd_ex, tax, rev_ex))
        data = np.pad(self.complex, pad)

        plt.plot(tax, np.real(data), label="Re")
        plt.plot(tax, np.imag(data), label="Im")
        plt.xlabel("Time (ns)")
        plt.ylabel("Amplitude (A.U.)")
    def _calc_fft(self, pad=10000):
        data = np.pad(self.complex, int(pad))
        pulse_fft = fft.fftshift(fft.fft(data))
        n = data.shape[0]
        dt =[1][0]
        axis_fft = fft.fftshift(fft.fftfreq(n, dt))

        return axis_fft, pulse_fft
    @property
    def amp_factor(self):
        """The B1 amplitude factor (nutation frequency) for the pulse in GHz.

        Exposed as a property because the excitation-profile code reads it as
        ``self.amp_factor.value``.

        Returns
        -------
        Parameter
            The flip angle divided by ``2*pi`` times the integral of the
            amplitude modulation over the pulse axis.
        """
        # np.trapz was renamed to np.trapezoid in NumPy 2; use whichever exists.
        integrate = getattr(np, "trapezoid", None) or np.trapz
        # NOTE(review): the integration axis was lost from this copy of the
        # file; ``self.ax`` is assumed.
        amp_factor_value = self.flipangle.value / (
            2 * np.pi * integrate(self.AM, self.ax))
        return Parameter("amp_factor", amp_factor_value, "GHz", "Amplitude factor for the pulse")

    # @cached(thread_safe=False)
    def exciteprofile_old(self, freqs=None, resonator = None):
        """Excitation profile

        Generates the excitation profiles for this pulse.

        This function is ported from EasySpin (https://easyspin.org) [1-2],
        and based upon the method from Gunnar Jeschke, Stefan Pribitzer and
        Andrin Doll[3].

        References
        ----------
        [1] Stefan Stoll, Arthur Schweiger
        EasySpin, a comprehensive software package for spectral simulation and analysis in EPR
        J. Magn. Reson. 178(1), 42-55 (2006)

        [2] Stefan Stoll, R. David Britt
        General and efficient simulation of pulse EPR spectra
        Phys. Chem. Chem. Phys. 11, 6614-6625 (2009)

        [3] Jeschke, G., Pribitzer, S. & Doll, A.
        Coherence Transfer by Passage Pulses in Electron Paramagnetic
        Resonance Spectroscopy.
        J. Phys. Chem. B 119, 13570-13582 (2015)

        Parameters
        ----------
        freqs: np.ndarray, optional
            The frequency axis. Caution: A larger number of points will linearly
            increase computation time. When omitted, 100 points spanning four
            times the half-maximum width of the pulse spectrum are used.
        resonator: ad.ResonatorProfile, optional
            If given, the pulse amplitude is limited by the resonator profile.

        Returns
        -------
        Mx: np.ndarray
            The magnetisation in the X direction.
        My: np.ndarray
            The magnetisation in the Y direction.
        Mz: np.ndarray
            The magnetisation in the Z direction.
        """
        eye = np.identity(2)

        if freqs is None:
            # Centre the offsets on the half-maximum region of the spectrum.
            fxs, fs = self._calc_fft()
            fs = np.abs(fs)
            points = np.argwhere(fs>(fs.max()*0.5))[:,0]
            BW = fxs[points[-1]] - fxs[points[0]]
            c_freq = np.mean([fxs[points[-1]], fxs[points[0]]])
            offsets = np.linspace(-2*BW, 2*BW, 100) + c_freq
        else:
            offsets = freqs

        # NOTE(review): the right-hand side was lost from this copy of the
        # file; the pulse time axis ``self.ax`` is assumed.
        t = self.ax
        nOffsets = offsets.shape[0]

        ISignal = np.real(self.complex) * self.amp_factor.value
        QSignal = np.imag(self.complex) * self.amp_factor.value

        if resonator is not None:
            FM = self.FM
            amp_factor = np.interp(FM, resonator.freqs-resonator.LO_c, resonator.profile)
            # Never exceed the amplitude the pulse itself asks for.
            amp_factor = np.min([amp_factor,np.ones_like(amp_factor)*self.amp_factor.value],axis=0)
            ISignal = np.real(self.complex) * amp_factor
            QSignal = np.imag(self.complex) * amp_factor

        npoints = ISignal.shape[0]
        Sx = sop([1/2],"x")
        Sy = sop([1/2],"y")
        Sz = sop([1/2],"z")
        Density0 = -Sz
        Mag = np.zeros((3,nOffsets), dtype=np.complex128)
        for iOffset in range(0,nOffsets):
            UPulse = eye
            Ham0 = offsets[iOffset]*Sz

            # Accumulate the propagator one time step at a time.
            for it in range(0,npoints):
                Ham = ISignal[it] * Sx + QSignal[it] * Sy + Ham0
                M = -2j * np.pi * (t[1]-t[0]) * Ham
                q = np.sqrt(M[0,0] ** 2 - np.abs(M[0,1]) ** 2)
                if np.abs(q) < 1e-10:
                    # First-order expansion avoids dividing by q ~ 0.
                    dU = eye + M
                else:
                    dU = np.cosh(q)*eye + (np.sinh(q)/q)*M
                UPulse = dU @ UPulse

            Density = UPulse @ Density0 @ UPulse.conjugate().T
            # NOTE(review): Mx and My use a matrix product while Mz uses an
            # element-wise product. Kept exactly as found; confirm which form
            # is intended before relying on the sign of My.
            Mag[0, iOffset] = -2 * (Sx @ Density.T).sum().real
            Mag[1, iOffset] = -2 * (Sy @ Density.T).sum().real
            Mag[2, iOffset] = -2 * (Sz * Density.T).sum().real
        return Mag[0,:], Mag[1,:], Mag[2,:]
    def exciteprofile(self, freqs=None, resonator=None, trajectory=False):
        """Excitation profile

        Generates the excitation profiles for this pulse, using the two-level system model developed by Jeschke et al. [1].
        And then optimised in the PulseShape software package by Maxx Tessmer [2], reproduced under the GNU GPL v3 license.

        References
        ----------
        [1] Jeschke, G., Pribitzer, S. & Doll, A.
            Coherence Transfer by Passage Pulses in Electron Paramagnetic
            Resonance Spectroscopy.
            J. Phys. Chem. B 119, 13570-13582 (2015)

        Parameters
        ----------
        freqs: np.ndarray, optional
            The frequency axis. Caution: A larger number of points will linearly
            increase computation time. When omitted, 100 points spanning four
            times the half-maximum width of the pulse spectrum are used.
        resonator: ad.ResonatorProfile, optional
            The resonator profile to apply resonator compensation.
        trajectory: bool, optional
            If True, the function will return the magnetisation at each time point. Default is False.

        Returns
        -------
        Mag: np.ndarray
            The magnetisation in the X, Y, and Z directions (last axis). With
            ``trajectory=True`` there is an additional time axis.
        """
        if freqs is None:
            # Centre the offsets on the half-maximum region of the spectrum.
            fxs, fs = self._calc_fft()
            fs = np.abs(fs)
            points = np.argwhere(fs>(fs.max()*0.5))[:,0]
            BW = fxs[points[-1]] - fxs[points[0]]
            c_freq = np.mean([fxs[points[-1]], fxs[points[0]]])
            offsets = np.linspace(-2*BW, 2*BW, 100) + c_freq
        else:
            offsets = freqs

        # NOTE(review): ``t`` was used below without being defined in this
        # copy of the file; the pulse time axis ``self.ax`` is assumed.
        t = self.ax

        ISignal = np.real(self.complex) * self.amp_factor.value
        QSignal = np.imag(self.complex) * self.amp_factor.value
        if resonator is not None:
            FM = self.FM
            amp_factor = np.interp(FM, resonator.freqs-resonator.LO_c, resonator.profile)
            # Never exceed the amplitude the pulse itself asks for.
            amp_factor = np.min([amp_factor,np.ones_like(amp_factor)*self.amp_factor.value],axis=0)
            ISignal = np.real(self.complex) * amp_factor
            QSignal = np.imag(self.complex) * amp_factor

        # Start from unit magnetisation along +z for every offset.
        M0 = np.asarray([0, 0, 1], dtype=float)
        if len(M0.shape) == 1:
            Mmag = np.linalg.norm(M0)
            M0 /= Mmag
            M0 = np.tile(M0, (len(offsets), 1))
            Mmag = np.array([Mmag for i in range(len(offsets))])
        else:
            Mmag = np.linalg.norm(M0, axis=1)
            M0 /= Mmag[:, None]

        Sx = sop([1/2],"x")
        Sy = sop([1/2],"y")
        Sz = sop([1/2],"z")

        density0 = 0.5 * np.array(([[1 + M0[:, 2], M0[:, 0] - 1j * M0[:, 1]],
                                [M0[:, 0] + 1j * M0[:, 1], 1 - M0[:, 2]]]))
        # (2, 2, offsets) -> (offsets, 2, 2)
        density0 = np.moveaxis(density0, 2, 0)

        dt = t[1] - t[0]

        # Hamiltonian for every (offset, time) pair.
        H = offsets[:, None, None] * Sz
        H = H[:, None, :, :] + ISignal[:, None, None] * Sx + QSignal[:, None, None] * Sy
        M = -2j * np.pi * dt * H
        q = np.sqrt(M[:, :, 0, 0]**2 - np.abs(M[:, :, 0, 1])**2)

        # Entries with q ~ 0 divide by zero here; they are overwritten with
        # the first-order expansion just below, so the warnings are silenced.
        with np.errstate(divide="ignore", invalid="ignore"):
            dUs = np.cosh(q)[:, :, None, None] * np.eye(2, dtype=complex) + (np.sinh(q) / q)[:, :, None, None] * M
        mask = np.abs(q) < 1e-10
        dUs[mask] = np.eye(2, dtype=complex) + M[mask]

        if not trajectory:
            Upulses = np.empty((len(dUs), 2, 2), dtype=complex)
            for i in range(len(dUs)):
                Upulses[i] = reduce(lambda x, y: y@x, dUs[i, :-1])
            density = np.einsum('ijk,ikl,ilm->ijm', Upulses, density0, Upulses.conj().transpose((0, 2, 1)))
            density = density.transpose((0, 2, 1))

            Mag = np.zeros((len(offsets), 3))
            Mag[..., 0] =  2 * density[..., 0, 1].real             # 2 * (Sx[None, :, :] * density).sum(axis=(1, 2)).real
            Mag[..., 1] = -2 * density[..., 1, 0].imag             # 2 * (Sy[None, :, :] * density).sum(axis=(1, 2)).real
            Mag[..., 2] =  density[..., 0, 0] - density[..., 1, 1] # 2 * (Sz[None, :, :] * density).sum(axis=(1, 2)).real
            return np.squeeze(Mag * Mmag[:, None])
        else:
            Upulses = np.empty((len(dUs), len(t), 2, 2), dtype=complex)
            for i in range(len(dUs)):
                # Identity first, so the trajectory starts at the initial state.
                Upulses[i] = [np.eye(2)] +  list((accumulate(dUs[i, :-1], lambda x, y: y @ x)))

            density = np.einsum('hijk,hkl,hilm->hijm', Upulses, density0, Upulses.conj().transpose((0, 1, 3, 2)))
            density = density.transpose((0, 1, 3, 2))

            Mag = np.zeros((len(offsets), len(t), 3))
            Mag[..., 0] = 2 * density[..., 0, 1].real # 2 * (Sx[None, None, :, :] * density).sum(axis=(2, 3)).real
            Mag[..., 1] = -2 * density[..., 1, 0].imag # 2 * (Sy[None, None, :, :] * density).sum(axis=(2, 3)).real
            Mag[..., 2] = density[..., 0, 0] - density[..., 1, 1] # 2 * (Sz[None, None, :, :] * density).sum(axis=(2, 3)).real

            return Mag

    def plot_fft(self):
        """Plot the real part, imaginary part and magnitude of the pulse spectrum."""
        ax, ft = self._calc_fft(1e4)
        plt.plot(ax, np.real(ft), label="Re")
        plt.plot(ax, np.imag(ft), label="Im")
        # The magnitude trace was previously mislabelled as "Im".
        plt.plot(ax, np.abs(ft), label="Abs")
        plt.xlabel("Frequency (GHz)")
        plt.ylabel("Amplitude (A.U.)")

    def _pcyc_str(self):
        if self.pcyc is not None:
            pcyc_table = "["
            dets = self.pcyc["DetSigns"]
            phases = self.pcyc["Phases"]
            for i in range(0, len(phases)):
                if dets[i] == 1:
                    pcyc_table += "+"
                elif dets[i] == -1:
                    pcyc_table += "-"
                elif dets[i] == 1j:
                    pcyc_table += "+i"
                elif dets[i] == -1j:
                    pcyc_table += "-i"

                if phases[i] == 0:
                    pcyc_table += "(+x)"
                elif phases[i] == np.pi:
                    pcyc_table += "(-x)"
                elif phases[i] == np.pi/2:
                    pcyc_table += "(+y)"
                elif phases[i] == -np.pi/2:
                    pcyc_table += "(-y)"
                elif phases[i] == 3*np.pi/2:
                    pcyc_table += "(-y)"
                pcyc_table += " "
            pcyc_table = pcyc_table[:-1]
            pcyc_table += "]"
            pcyc_table = ""

        return pcyc_table

    def __str__(self):
        """Return a printable summary of the pulse.

        The summary has a header, an overview line (time position, length,
        scale, static flag), a table of every Parameter attribute, the
        progression table when ``self.Progression`` is True, the phase cycle
        when one is set, and a footer with the package version.
        """
        rule = "#" * 79 + "\n"

        # Build header line
        header = rule + "autoDEER Pulse Definition" + "\n" + rule

        # Build Overviews
        # NOTE(review): the empty f-string fields in this copy of the file are
        # assumed to have been ``self.tp.value`` (the "Length (ns)" column).
        is_detection = type(self) is Detection
        # Detection events have no scale, so that column is left empty.
        scale_col = "" if is_detection else f"{self.scale.value}"
        if self.isPulseFocused():
            overview_params = "Time Pos (ns) \tLength (ns) \tScale (0-1)" +\
                "\t Static \n"
            overview_params += f"{self.t.value}" + "\t\t"
        else:
            overview_params = "Length (ns) \tScale (0-1)" +\
                "\t Static \n"
        overview_params += f"{self.tp.value}" + "\t\t" + scale_col + "\t\t" +\
            f"{self.is_static()}" + "\n" + rule

        # Build Pulse Parameter Table
        param_table =  f"{'Name':^12} \t {'Value':^12} \t {'Unit':^7} \t Description\n"
        param_table += f"{'----':^12} \t {'-----':^12} \t {'----':^7} \t -----------\n"
        for attr_name in dir(self):
            try:
                attr = getattr(self, attr_name)
            except (AttributeError, ValueError, NotImplementedError):
                # Computed attributes may be undefined for this pulse type;
                # they are simply left out of the table.
                continue
            if type(attr) is not Parameter:
                continue

            if attr.name == "flipangle":
                if attr.value == np.pi:
                    value = f"{'π':>12}"
                elif attr.value == np.pi/2:
                    value = f"{'π/2':>12}"
                elif attr.value == 'Hard':
                    value = f"{'Hard':>12}"
                else:
                    value = f"{attr.value:>5.5g}"
            elif attr.value is None:
                value = f"{'N/A':>12}"
            else:
                value = f"{attr.value:>5.5g}"

            # Use local fallbacks so that printing never alters the Parameter.
            unit = "None" if attr.unit is None else attr.unit
            description = "" if attr.description is None else attr.description

            param_table += f"{attr.name:>12} \t {value:>12} \t" +\
                f"{unit:>7} \t {description} \n"

        # Build Pulse Progression Table
        if self.Progression is True:
            prog_table = rule + "Progressive Variables:" + "\n"
            prog_table += "Variable \t Axis \t Length \n"
            prog_table += "-------- \t ---- \t ------ \n"
            prog_table += f"{self.Prog_var} \t\t {self.Prog_axis} \t" +\
                f"{self.Prog_n} \n"
        else:
            prog_table = ""

        # Build Pulse Phase Cycle Table
        if self.pcyc is not None:
            pcyc_table = rule + "Phase Cycle:" + "\t"
            pcyc_table += self._pcyc_str()
            pcyc_table += "\n"
        else:
            pcyc_table = ""

        # Build Footer
        footer = rule +\
            f"Built by autoDEER Version: {__version__}" + "\n" + "#" * 79

        # Combine All
        return header + overview_params + param_table + prog_table +\
            pcyc_table + footer

    def copy(self, clear=False, **kwargs):
        """Creates a deep-copy of the pulse. I.e. Every parameter object is
        re-created at another memory space.

        Parameters can be changed at this stage by adding them as keyword-
        arguments (kwargs).

        Parameters
        ----------
        clear : bool, optional
            If True, non-static Parameters of the copy are reset before the
            keyword arguments are applied. By default False.
        **kwargs
            New values, keyed by attribute name. Numbers and Parameters update
            an existing Parameter in place; ``pcyc`` accepts the same forms as
            the constructor; ``t``, ``scale`` and ``flipangle`` are created
            when the copy does not have them yet.

        Returns
        -------
        Pulse
            A deep copy of the pulse

        Raises
        ------
        ValueError
            If a new ``t`` is neither numeric nor a Parameter.
        """
        new_pulse = copy.deepcopy(self)

        if clear:
            for arg in vars(new_pulse):
                attr = getattr(new_pulse, arg)
                if type(attr) is Parameter and not attr.is_static():
                    # NOTE(review): this statement was lost from this copy of
                    # the file; Parameter.remove_dynamic() is assumed to be
                    # the reset call -- confirm against the Parameter class.
                    attr.remove_dynamic()

        for arg in kwargs:
            new_value = kwargs[arg]
            if (hasattr(new_pulse, arg)) and (getattr(new_pulse, arg) is not None):
                attr = getattr(new_pulse, arg)
                if type(attr) is Parameter:
                    if isinstance(new_value, numbers.Number):
                        attr.value = new_value
                    elif isinstance(new_value, Parameter):
                        attr.value = new_value.value
                        attr.axis = new_value.axis

                elif arg == "pcyc":
                    # Test the new value: the existing one is known to be set.
                    if new_value is None:
                        new_pulse.pcyc = None
                    elif type(new_value) is dict:
                        new_pulse._addPhaseCycle(new_value["phases"], detections=new_value["dets"])
                    else:
                        new_pulse._addPhaseCycle(new_value, detections=None)
                else:
                    # Rebinding a local name would silently drop the value.
                    setattr(new_pulse, arg, new_value)
            elif arg == "t":
                if type(new_value) is Parameter:
                    new_value.name = "t"
                    setattr(new_pulse, arg, new_value)
                elif isinstance(new_value, numbers.Number):
                    setattr(new_pulse, arg,
                            Parameter("t", new_value, "ns",
                                      "Start time of pulse"))
                else:
                    raise ValueError("new parameters must be either numeric or Parameters")
            elif arg == "scale":
                setattr(new_pulse, arg,
                        Parameter("scale", new_value, None,
                                  "Amplitude of pulse"))
            elif arg == "flipangle":
                setattr(new_pulse, arg,
                        Parameter("flipangle", new_value, None,
                                  "The target flip angle of the spins"))
        return new_pulse

    def _to_dict(self):
        """Convert the pulse to a dictionary.

        Returns
        -------
        dict
            The keys ``"version"``, ``"type"`` and ``"subclass"`` plus one
            entry per instance attribute. Parameter attributes are converted
            with their own ``_to_dict``; everything else is stored as is.
        """
        to_return = {"version": __version__, "type": "Pulse", "subclass": str(type(self))}

        for key, var in vars(self).items():
            if isinstance(var, Parameter):
                to_return[key] = var._to_dict()
            else:
                to_return[key] = var

        return to_return
    def _to_json(self):
        class autoEPREncoder(json.JSONEncoder):
            def default(self, obj):
                if isinstance(obj, np.ndarray):
                    if (len(obj) > 0 ) and isinstance(obj[0], str):
                        return list(obj)
                    data = np.ascontiguousarray(
                    data_b64 = base64.b64encode(data)
                    return dict(__ndarray__=str(data_b64),
                if isinstance(obj, complex):
                    return str(obj)
                if isinstance(obj, numbers.Number):
                    return str(obj)
                if isinstance(obj, uuid.UUID):
                    return_dict = {"__uuid__": str(obj)}
                    return return_dict
                if isinstance(obj, Parameter):
                    return obj._to_dict()
                if isinstance(obj, Pulse):
                    return obj._to_dict()
                    return json.JSONEncoder.default(self, obj)
        return json.dumps(self._to_dict(), cls=autoEPREncoder, indent=4)

    def save(self, filename):
        """Save the Pulse to a JSON file.

        Parameters
        ----------
        filename : str
            Path to the JSON file.

        Raises
        ------
        TypeError
            If the object cannot be serialized to JSON.

        Example
        -------
        >>> obj = Pulse(tp=16)
        >>> obj.save("my_pulse.json")
        """
        # Serialise first so a failure does not leave an empty file behind.
        json_string = self._to_json()
        with open(filename, "w") as f:
            f.write(json_string)

    @classmethod
    def _from_dict(cls, dct):
        """Build a pulse from a dictionary.

        Parameters
        ----------
        dct : dict
            Must contain ``"tp"`` (a Parameter dictionary) and ``"name"``.
            Nested dictionaries that have a ``"type"`` key are converted with
            ``Parameter._from_dict``; the ``"type"`` and ``"version"`` entries
            are skipped; everything else is set as an attribute unchanged.

        Returns
        -------
        Pulse
            A new instance of ``cls``.
        """
        new_pulse = cls(tp=Parameter._from_dict(dct["tp"]), name=dct["name"])
        for key, var in dct.items():
            if isinstance(var, dict) and ("type" in var):
                setattr(new_pulse, key, Parameter._from_dict(var))
            elif key in ("type", "version"):
                # Metadata only; not an attribute of the pulse.
                continue
            else:
                setattr(new_pulse, key, var)
        return new_pulse

    @classmethod
    def _from_json(cls, JSONstring):
        """Build a pulse from a JSON string.

        The string is decoded with ``autoEPRDecoder`` as the object hook and
        the resulting dictionary is passed to ``cls._from_dict``.
        """
        dct = json.loads(JSONstring, object_hook=autoEPRDecoder)
        return cls._from_dict(dct)
    @classmethod
    def load(cls, filename):
        """Load a Pulse object from a JSON file.

        Parameters
        ----------
        filename : str
            Path to the JSON file.

        Returns
        -------
        obj : Pulse
            The Pulse loaded from the JSON file.

        Raises
        ------
        FileNotFoundError
            If the file does not exist.

        Example
        -------
        >>> obj = Pulse.load("my_pulse.json")
        """
        with open(filename, "r") as f:
            file_buffer = f.read()
        return cls._from_json(file_buffer)

# =============================================================================

#                                Subclasses

# =============================================================================

class Detection(Pulse):
    """Represents a detection pulse."""

    def __init__(self, *, tp, t=None, freq=0,**kwargs) -> None:
        """A general detection pulse.

        Parameters
        ----------
        tp : float
            The **total** time of the detection event. The detection event will
            be symmetrical about the centre time.
        t : float, optional
            The **centre** time of the detection event
        freq: float, optional
            The detection frequency in GHz, not all spectrometers support
            this functionality, by default 0
        """
        Pulse.__init__(self, t=t, tp=tp, **kwargs)
        # A detection event has neither an amplitude nor a phase cycle.
        self.scale = None
        self.freq = Parameter("freq", freq, "GHz", "The detection frequency, if supported")
        self.pcyc = None

# =============================================================================

class Delay(Pulse):
    """Represents an inter-pulse delay pulse."""

    def __init__(self, *, tp, t=None) -> None:
        """An inter-pulse delay.

        ``Pulse.__init__`` is not called: only ``t``, ``tp``, ``pcyc`` and
        ``scale`` are set.

        Parameters
        ----------
        tp : float
            The length of the delay in ns.
        t : float, optional
            The start time of the delay in ns.
        """
        if t is not None:
            self.t = Parameter("Time", t, "ns", "Start time of pulse")
        else:
            self.t = None
        self.tp = Parameter("Length", tp, "ns", "Length of the pulse")
        self.pcyc = None
        self.scale = None

# =============================================================================

class RectPulse(Pulse):
    """Represents a rectangular monochromatic pulse."""

    def __init__(
            self, tp=16, freq=0, t=None, flipangle=None, **kwargs) -> None:
        """A rectangular monochromatic pulse.

        Parameters
        ----------
        tp : float, optional
            Pulse Length in ns, by default 16
        freq : float, optional
            Frequency in GHz, by default 0
        t : float, optional
            Time position in ns, by default None
        flipangle : float, optional
            The flip angle in radians, by default None
        """
        Pulse.__init__(
            self, tp=tp, t=t, flipangle=flipangle,name='RectPulse', **kwargs)
        self.freq = Parameter("freq", freq, "GHz", "Frequency of the Pulse")

    def func(self, ax):
        """Return unit amplitude and a constant frequency for every point of `ax`."""
        nx = ax.shape[0]
        AM = np.ones(nx)
        FM = np.zeros(nx) + self.freq.value
        return AM, FM

class GaussianPulse(Pulse):
    """Represents a Gaussian monochromatic pulse."""

    def __init__(self, *, tp=32,FWHM=16, freq=0, **kwargs) -> None:
        """Represents a Gaussian monochromatic pulse.

        Parameters
        ----------
        tp : float
            Pulse length in ns, by default 32
        FWHM : float,
            The full width at half maximum of the pulse in ns, by default 16
        freq : float, optional
            The frequency of the pulse in GHz, by default 0
        """
        Pulse.__init__(self, tp=tp,name='GaussianPulse', **kwargs)
        self.freq = Parameter("Freq", freq, "GHz", "Frequency of the Pulse")
        self.FWHM = Parameter("FWHM", FWHM, "ns", "Full Width at Half Maximum of the Pulse")

    def func(self, ax):
        """Return a Gaussian amplitude centred on ``ax == 0`` and a constant frequency."""
        nx = ax.shape[0]
        # Convert the full width at half maximum to a standard deviation.
        sigma = self.FWHM.value / (2 * np.sqrt(2 * np.log(2)))
        AM = np.exp(-ax**2 / (2 * sigma**2))
        FM = np.zeros(nx) + self.freq.value
        return AM, FM
# =============================================================================
class FrequencySweptPulse(Pulse):
    """A general parent class for Frequency Swept Pulses."""

    def __init__(self, *, tp, t=None, scale=None, flipangle=None, pcyc=[0], name=None, **kwargs) -> None:
        """Create a frequency-swept pulse.

        The sweep is given through keyword arguments, either as ``init_freq``
        and ``final_freq``, or as a bandwidth ``BW`` combined with one of the
        two. All frequencies are in GHz. The remaining parameters are passed
        on to ``Pulse.__init__``.

        Raises
        ------
        ValueError
            If ``BW`` is given without ``init_freq`` or ``final_freq``, or if
            no frequency information is given at all.
        """
        super().__init__(tp=tp, t=t, scale=scale, flipangle=flipangle, pcyc=pcyc, name=name, **kwargs)

        if "BW" in kwargs:
            BW = kwargs["BW"]
            if "init_freq" in kwargs:
                self.init_freq = Parameter("init_freq", kwargs["init_freq"], "GHz", "Initial frequency of pulse")
                self.final_freq = Parameter("final_freq", self.init_freq.value + BW, "GHz", "Final frequency of pulse")
            elif "final_freq" in kwargs:
                self.final_freq = Parameter("final_freq", kwargs["final_freq"], "GHz", "Final frequency of pulse")
                self.init_freq = Parameter("init_freq", self.final_freq.value - BW, "GHz", "Initial frequency of pulse")
            else:
                raise ValueError("Bandwidth must be combined with either an initial or final frequency")
        elif ("init_freq" in kwargs) and ("final_freq" in kwargs):
            self.init_freq = Parameter("init_freq", kwargs["init_freq"], "GHz", "Initial frequency of pulse")
            self.final_freq = Parameter("final_freq", kwargs["final_freq"], "GHz", "Final frequency of pulse")
        else:
            raise ValueError("Frequency information must be provided")

    @property
    def Qcrit(self):
        """ The critical Q factor for the pulse, capped at 5."""
        # A flip angle of pi makes the logarithm infinite; the cap below
        # handles it, so the divide warning is silenced.
        with np.errstate(divide="ignore"):
            Qcrit = (2/np.pi)*np.log(2/(1+np.cos(self.flipangle.value)))
        Qcrit = np.min([Qcrit,5])
        return Parameter("Qcrit", Qcrit, None, "Critical Q factor for the pulse")

    @property
    def sweeprate(self):
        """ The sweep rate of the pulse in GHz/ns"""
        raise NotImplementedError("This property must be implemented in the subclass")

    @property
    def amp_factor(self):
        """ The B1 amplitude factor (nutation frequency) for the pulse in GHz"""
        amp_factor_value=  np.sqrt(2*np.pi*self.Qcrit.value*self.sweeprate.value)/(2*np.pi)
        return Parameter("amp_factor", amp_factor_value, "GHz", "Amplitude factor for the pulse")

class HSPulse(FrequencySweptPulse):
    """Represents a hyperbolic secant frequency-swept pulse."""

    def __init__(self, *, tp=128, order1=1, order2=6, beta=20, **kwargs) -> None:
        """Create a hyperbolic secant pulse.

        Parameters
        ----------
        tp : float, optional
            Pulse length in ns, by default 128
        order1 : int, optional
            Order used for the first half of the amplitude envelope, by default 1
        order2 : int, optional
            Order used for the second half of the amplitude envelope, by default 6
        beta : float, optional
            The beta (truncation) parameter of the pulse, by default 20
        **kwargs
            Frequency information and any other arguments accepted by
            ``FrequencySweptPulse``.
        """
        FrequencySweptPulse.__init__(self, tp=tp, name='HSPulse', **kwargs)
        self.order1 = Parameter(
            "order1", order1, None, "Order 1 of the HS Pulse")
        self.order2 = Parameter(
            "order2", order2, None, "Order 2 of the HS Pulse")
        self.beta = Parameter("beta", beta, None, "Beta of the HS Pulse")

    def func(self, ax):
        """Return the amplitude and frequency modulation on the axis `ax`.

        The pulse length is taken as the span of `ax`. The two halves of the
        axis use ``order1`` and ``order2`` respectively for the envelope.
        """
        beta = self.beta.value
        order1 = self.order1.value
        order2 = self.order2.value
        tp = ax.max() - ax.min()
        ti = ax
        nx = ax.shape[0]

        sech = lambda x: 1/np.cosh(x)
        cut = round(nx/2)
        AM = np.zeros_like(ti)
        AM[:cut] = sech(beta*0.5*(2*ti[:cut]/tp)**order1)
        AM[cut:] = sech(beta*0.5*(2*ti[cut:]/tp)**order2)

        # Bandwidth the tanh sweep would reach if it were not truncated.
        BWinf = (self.bandwidth.value) / np.tanh(beta/2)

        freq = (BWinf/2) * np.tanh((beta/tp*ti)) + np.mean([self.init_freq.value,self.final_freq.value])
        return AM, freq

    @property
    def sweeprate(self):
        """ The sweep rate of the pulse in GHz/ns"""
        # NOTE(review): the denominator was cut off after "(2*" in this copy
        # of the file; ``2 * tp`` is assumed from the GHz/ns unit.
        sweeprate_value = self.beta.value * self.bandwidth.value / (2*self.tp.value)
        return Parameter("sweeprate", sweeprate_value, "GHz/ns", "Sweep rate of the pulse")

# =============================================================================

class ChirpPulse(FrequencySweptPulse):
    """Represents a linear frequency-swept pulse."""

    def __init__(self, *, tp=128, **kwargs) -> None:
        """Create a chirp pulse.

        Parameters
        ----------
        tp : float, optional
            Pulse length in ns, by default 128
        **kwargs
            Frequency information and any other arguments accepted by
            ``FrequencySweptPulse``.
        """
        FrequencySweptPulse.__init__(self, tp=tp,name='ChirpPulse', **kwargs)

    def func(self, ax):
        """Return unit amplitude and a linear sweep from ``init_freq`` to ``final_freq``."""
        nx = ax.shape[0]
        AM = np.ones(nx)

        FM = np.linspace(
            self.init_freq.value, self.final_freq.value, nx)

        return AM, FM

    @property
    def sweeprate(self):
        """ The sweep rate of the pulse in GHz/ns"""
        # NOTE(review): the divisor was lost from this copy of the file; the
        # pulse length ``tp`` is assumed from the GHz/ns unit.
        sweeprate_value = self.bandwidth.value / self.tp.value
        return Parameter("sweeprate", sweeprate_value, "GHz/ns", "Sweep rate of the pulse")
# =============================================================================

class SincPulse(Pulse):
    """Represents a sinc shaped monochromatic pulse."""

    def __init__(self, *, tp=128, freq=0, order=6, window=None, **kwargs) -> None:
        """Represents a sinc shaped monochromatic pulse.

        Parameters
        ----------
        tp : int
            Pulse length in ns, by default 128
        freq : int, optional
            The frequency of the pulse, by default 0
        order : int, optional
            The order of this sinc function, by default 6
        window : optional
            The window function, by default None. It is stored but not
            applied by ``func``.
        """
        Pulse.__init__(self, tp=tp,name='SincPulse', **kwargs)
        self.freq = Parameter("Freq", freq, "GHz", "Frequency of the Pulse")

        self.order = Parameter("Order", order, None, "The sinc pulse order")
        self.window = Parameter(
            "Window", window, None, "The type of window function")

    def func(self, ax):
        """Return a sinc amplitude, ``np.sinc(order * ax)``, and a constant frequency."""
        nx = ax.shape[0]
        FM = np.zeros(nx) + self.freq.value
        AM = np.sinc(self.order.value * ax)

        return AM, FM
# class GaussianPulse(Pulse):
#     """
#     Represents a Gaussian monochromatic pulse.
#     """

#     def __init__(self, *, tp=128, freq=0, **kwargs) -> None:
#         """    Represents a Gaussian monochromatic pulse.

#         Parameters
#         ----------
#         tp : int
#             Pulse length in ns, by default 128
#         freq : int, optional
#             The frequency of the pulse, by default 0
#         """

#         Pulse.__init__(self, tp=tp,name='GaussianPulse', **kwargs)
#         self.freq = Parameter("Freq", freq, "GHz", "Frequency of the Pulse")

#         self._buildFMAM(self.func)
#         pass

#     def func(self, ax):
#         nx = ax.shape[0]
#         FM = np.zeros(nx) + self.freq.value
#         AM = np.exp(-ax**2/(((**2)/(-1*np.log(0.001))))

#         return AM, FM
# =============================================================================

def build_default_pulses(AWG=True,SPFU=False,tp=12):
    """Build the default set of pulses.

    Parameters
    ----------
    AWG : bool, optional
        If True the pump pulse is a hyperbolic secant pulse and the detection
        event is 512 ns long; otherwise the pump pulse is rectangular and the
        detection event is 128 ns long. By default True.
    SPFU : bool, optional
        If True the pi pulses are twice as long as `tp`. By default False.
    tp : float, optional
        Length of the excitation pulse in ns, by default 12

    Returns
    -------
    dict
        The pulses under the keys ``'exc_pulse'``, ``'ref_pulse'``,
        ``'pump_pulse'`` and ``'det_event'``.
    """
    tp_pi = tp * 2 if SPFU else tp

    pulses = {
        'exc_pulse': RectPulse(tp=tp, freq=0, flipangle=np.pi/2, scale=0),
        'ref_pulse': RectPulse(tp=tp_pi, freq=0, flipangle=np.pi, scale=0),
    }

    if AWG:
        pulses['pump_pulse'] = HSPulse(
            tp=120, init_freq=-0.25, final_freq=-0.03, flipangle=np.pi,
            scale=0, order1=6, order2=1, beta=10)
        pulses['det_event'] = Detection(tp=512, freq=0)
    else:
        pulses['pump_pulse'] = RectPulse(
            tp=tp_pi, freq=-0.07, flipangle=np.pi, scale=0)
        pulses['det_event'] = Detection(tp=128, freq=0)

    return pulses