import numpy as np
from scipy.linalg import lu
import xarray as xr


class InvalidDimException(Exception):
    """Raised when the program can not identify (index of images, x, y) axes."""

    def __init__(self, dims):
        """
        :param dims: The dimension names (or shape) of the offending data.
        :type dims: tuple
        """
        if len(dims) > 3:
            self.message = 'The input data must have two or three axes: (index of images(alternative), x, y)'
        else:
            self.message = 'Can not identify (index of images(alternative), x, y) from ' + str(dims)
        super().__init__(self.message)


class DataSizeException(Exception):
    """Raised when the shape of the data is not correct."""

    def __init__(self):
        self.message = 'The input data size does not match.'
        super().__init__(self.message)


def _identify_axes(dims):
    """Identify the y, x and index-of-images dimensions of an xarray DataArray.

    A dimension is taken as x (y) when its name is exactly ``'x'`` (``'y'``) or
    contains ``'_x'`` (``'_y'``); any other dimension is the index of images.

    :param dims: The dimension names of the DataArray.
    :type dims: tuple
    :raises InvalidDimException: Raised when there are more than three dimensions
        or the (index of images, x, y) axes can not be identified.
    :return: ``(yAxis, xAxis, imageAxis)``; ``imageAxis`` is None for 2D data.
    :rtype: tuple
    """
    if len(dims) > 3:
        raise InvalidDimException(dims)
    xAxis = None
    yAxis = None
    # 2D data has no index-of-images axis, so it does not have to be found.
    imageAxis = '' if len(dims) == 2 else None
    for dim in dims:
        if (dim == 'x') or ('_x' in dim):
            xAxis = dim
        elif (dim == 'y') or ('_y' in dim):
            yAxis = dim
        else:
            imageAxis = dim
    if (xAxis is None) or (yAxis is None) or (imageAxis is None):
        raise InvalidDimException(dims)
    return yAxis, xAxis, (imageAxis if len(dims) == 3 else None)


class FringeRemoval():
    """A class for fringes removal.

    Reference images (taken without atoms) are stored as a matrix with one
    flattened (y, x) image per column. For each absorption image, the
    combination of reference images that best matches it on the masked pixels
    (least squares, solved via an LU decomposition of the normal equations)
    is returned as a 'fake' background.

    Only pixels where the mask equals 1 are used for the fit. The automatic
    mask sets the region of interest given by ``center``/``span`` to 0, which
    excludes that region from the fit.
    """

    def __init__(self) -> None:
        """Initialize the class
        """
        self.nimgsR = 0  # The number of the reference images
        self.nimgs = 0  # The number of absorption images in the last xarray input
        self.xdim = 0  # The shape of x axis
        self.ydim = 0  # The shape of y axis
        self._referenceImages = None  # Reference images, shape (ydim * xdim, nimgsR)
        self._mask = None  # The mask array to choose the region of interest for fringes removal
        self.k = None  # Flattened pixel indices where the mask equals 1
        self._center = None  # Set the mask array by center and span
        self._span = None
        # If it is necessary to reshape the data from
        # (index of images(alternative), x, y) to (y, x, index of images(alternative))
        self.reshape = True
        self.temp = None  # Transposed masked reference matrix used by _fringe_removal
        # LU decomposition (row indices, L, U) of the masked reference Gram matrix
        self.P = None
        self.L = None
        self.U = None

    def _reset_decomposition(self):
        """Drop the cached LU decomposition so that it is recomputed on next use."""
        self.P = None
        self.L = None
        self.U = None

    @property
    def center(self):
        """The getter of the center of region of interest (ROI)

        :return: The center of region of interest (ROI), as (x, y)
        :rtype: tuple
        """
        return self._center

    @center.setter
    def center(self, value):
        """The setter of the center of region of interest (ROI)

        Clears the current mask and the cached decomposition, so both are
        rebuilt from the new ROI on next use.

        :param value: The center of region of interest (ROI), as (x, y)
        :type value: tuple
        """
        self._mask = None
        self._center = value
        self._reset_decomposition()

    @property
    def span(self):
        """The getter of the span of region of interest (ROI)

        :return: The span of region of interest (ROI), as (x, y)
        :rtype: tuple
        """
        return self._span

    @span.setter
    def span(self, value):
        """The setter of the span of region of interest (ROI)

        Clears the current mask and the cached decomposition, so both are
        rebuilt from the new ROI on next use.

        :param value: The span of region of interest (ROI), as (x, y)
        :type value: tuple
        """
        self._mask = None
        self._span = value
        self._reset_decomposition()

    def reshape_data(self, data):
        """Reshape the data to the layout used internally.

        To minimize the calculation time, the data has to have a shape of
        (y, x, index of images(alternative)). The input usually has a shape of
        (index of images(alternative), x, y). An xarray DataArray is transposed
        by its dimension names and converted to a numpy array. Any other input
        is converted with ``np.array`` and has its axes swapped.

        :param data: The input data.
        :type data: xarray DataArray, numpy array or list
        :raises InvalidDimException: Raised when the program can not identify
            (index of images, x, y) axes of a DataArray.
        :return: The data with correct shape, or None if ``data`` is None
        :rtype: numpy array
        """
        if data is None:
            return data
        if isinstance(data, xr.DataArray):
            yAxis, xAxis, imageAxis = _identify_axes(data.dims)
            if imageAxis is None:
                data = data.transpose(yAxis, xAxis)
            else:
                data = data.transpose(yAxis, xAxis, imageAxis)
            return data.to_numpy()
        data = np.array(data)
        if data.ndim == 3:
            data = np.swapaxes(data, 0, 2)
        elif data.ndim == 2:
            data = np.swapaxes(data, 0, 1)
        return data

    def _reshape_absorption_images(self, data):
        """Convert absorption images to a DataArray with the pixels stacked on 'axis'.

        Plain arrays are interpreted as (index of images(alternative), x, y) and
        wrapped into a DataArray, so they follow the same path as xarray input.

        :param data: The absorption images.
        :type data: xarray DataArray, numpy array or list
        :raises InvalidDimException: Raised when the axes can not be identified.
        :return: The images with (y, x) stacked into one 'axis' dimension
        :rtype: xarray DataArray
        """
        if data is None:
            return data
        if not isinstance(data, xr.DataArray):
            data = np.asarray(data)
            if data.ndim == 3:
                data = xr.DataArray(data, dims=('index', 'x', 'y'))
            elif data.ndim == 2:
                data = xr.DataArray(data, dims=('x', 'y'))
            else:
                raise InvalidDimException(data.shape)
        yAxis, xAxis, imageAxis = _identify_axes(data.dims)
        if imageAxis is None:
            data = data.transpose(yAxis, xAxis)
        else:
            data = data.transpose(yAxis, xAxis, imageAxis)
            self.nimgs = data.sizes[imageAxis]
        # Flatten (y, x) in this order so that pixel positions match the rows
        # of the reference-image matrix built by the referenceImages setter.
        return data.stack(axis=[yAxis, xAxis])

    @property
    def referenceImages(self):
        """The reference images in the (index of images, x, y) layout.

        :return: The reference images, or None if none have been set
        :rtype: numpy array
        """
        if self._referenceImages is None:
            return None
        res = self._referenceImages.reshape(self.ydim, self.xdim, self.nimgsR)
        return np.swapaxes(res, 0, 2)

    @referenceImages.setter
    def referenceImages(self, value):
        """Set the reference images and drop the cached decomposition.

        A single 2D image is accepted as a set of one reference image. If the
        image size changes, the current mask is discarded as well.

        :param value: The reference images.
        :type value: xarray DataArray, numpy array or list
        :raises DataSizeException: Raised when the data is neither 2D nor 3D.
        """
        if self.reshape:
            value = self.reshape_data(value)
        elif isinstance(value, xr.DataArray):
            value = value.to_numpy()
        value = np.asarray(value)
        if value.ndim == 2:
            # A single reference image: add a trailing index-of-images axis.
            value = value[:, :, np.newaxis]
        elif value.ndim != 3:
            raise DataSizeException()
        ydim, xdim, nimgsR = value.shape
        if (ydim, xdim) != (self.ydim, self.xdim):
            # The old mask (and its pixel indices) do not fit the new images.
            self._mask = None
            self.k = None
        self.ydim, self.xdim, self.nimgsR = ydim, xdim, nimgsR
        self._referenceImages = value.reshape(xdim * ydim, nimgsR).astype(np.float32)
        self._reset_decomposition()

    def add_reference_images(self, data):
        """Add new reference images at the end of the reference set.

        If no reference images are set yet, ``data`` becomes the reference set.

        :param data: One new reference image (2D) or several (3D).
        :type data: xarray DataArray, numpy array or list
        :raises DataSizeException: Raised when the shape of the data is not correct.
        """
        if self._referenceImages is None:
            self.referenceImages = data
            return
        if self.reshape:
            data = self.reshape_data(data)
        elif isinstance(data, xr.DataArray):
            data = data.to_numpy()
        data = np.asarray(data)
        if not ((data.ndim in (2, 3)) and (data.shape[0] == self.ydim) and (data.shape[1] == self.xdim)):
            raise DataSizeException()
        # One flattened image per column; a 2D input gives a single column.
        data = data.reshape(self.xdim * self.ydim, -1).astype(np.float32)
        self._referenceImages = np.append(self._referenceImages, data, axis=1)
        self.nimgsR = self._referenceImages.shape[1]
        self._reset_decomposition()

    def _remove_first_reference_images(self):
        """Remove the first reference image (no-op if none are set)."""
        if self._referenceImages is None:
            return
        self._referenceImages = np.delete(self._referenceImages, 0, axis=1)
        self.nimgsR = self._referenceImages.shape[1]
        self._reset_decomposition()

    def update_reference_images(self, data):
        """Update the reference images set by removing the first one and adding
        the new one(s) at the end, then recompute the decomposition.

        :param data: The new reference image(s).
        :type data: xarray DataArray, numpy array or list
        """
        self._remove_first_reference_images()
        self.add_reference_images(data)
        self.decompose_referenceImages()

    @property
    def mask(self):
        """The mask array in (y, x) layout; pixels equal to 1 are used for the fit."""
        return self._mask

    @mask.setter
    def mask(self, value):
        """Set the mask and drop the cached decomposition.

        Setting None clears the mask, so the automatic mask is used on next use.

        :param value: The mask, in the same layout as the input images.
        :type value: xarray DataArray, numpy array or list
        :raises DataSizeException: Raised when the mask does not match the images.
        """
        if value is None:
            self._mask = None
            self.k = None
            self._reset_decomposition()
            return
        if self.reshape:
            value = self.reshape_data(value)
        elif isinstance(value, xr.DataArray):
            value = value.to_numpy()
        value = np.asarray(value)
        if not ((value.ndim >= 2) and (value.shape[0] == self.ydim) and (value.shape[1] == self.xdim)):
            raise DataSizeException()
        self._mask = value
        self.k = np.where(self._mask.flatten() == 1)[0]
        self._reset_decomposition()

    def _auto_mask(self):
        """Build a (ydim, xdim) mask of ones with the ROI (center/span) set to 0."""
        mask = np.ones((self.ydim, self.xdim), dtype=np.uint8)
        if self._center is not None:
            half_x = self._span[0] / 2
            half_y = self._span[1] / 2
            # Clamp at 0: a negative slice bound would wrap around to the far
            # edge of the image and leave the region of interest unmasked.
            x_start = max(int(self._center[0] - half_x), 0)
            x_end = max(int(self._center[0] + half_x), 0)
            y_start = max(int(self._center[1] - half_y), 0)
            y_end = max(int(self._center[1] + half_y), 0)
            mask[y_start:y_end, x_start:x_end] = 0
        return mask

    def decompose_referenceImages(self):
        """LU-decompose the Gram matrix of the reference images on the masked pixels.

        If no mask is set, the automatic mask built from center/span is used.
        """
        if self._mask is None:
            # The automatic mask is already in (y, x) layout. Assign it directly
            # rather than through the `mask` setter, which would transpose it.
            self._mask = self._auto_mask()
            self.k = np.where(self._mask.flatten() == 1)[0]
        refs = self._referenceImages[self.k, :]
        self.P, self.L, self.U = lu(refs.T @ refs, permute_l=False, p_indices=True)

    def _fringe_removal(self, absorptionImages):
        """Build the optimal reference image for flattened absorption image(s).

        Solves the normal equations (R_k^T R_k) c = R_k^T a_k, where R_k and a_k
        are the masked rows of the reference matrix and of the absorption data.
        Returns R @ c over all pixels.

        :param absorptionImages: Flattened image(s); the first axis is the pixel index.
        :type absorptionImages: numpy array
        :return: The optimal reference image(s)
        :rtype: numpy array
        """
        b = self.temp @ absorptionImages[self.k]
        # With p_indices=True the factorisation satisfies A == L[P] @ U.
        c = np.linalg.solve(self.U, np.linalg.solve(self.L[self.P], b))
        return self._referenceImages @ c

    def fringe_removal(self, absorptionImages, referenceImages=None, mask=None, reshape=None, dask='parallelized'):
        """
        This function will generate a 'fake' background images, which can help to remove the fringes.
        Important: Please subtract the dark images from both the images with atoms and the images without atoms before using this function!!!

        :param absorptionImages: A set of images with atoms in absorption imaging
        :type absorptionImages: xarray, numpy array or list
        :param referenceImages: A set of images without atoms in absorption imaging, defaults to None
        :type referenceImages: xarray, numpy array or list, optional
        :param mask: An array to choose the region of interest for fringes removal, defaults to None
        :type mask: numpy array, optional
        :param reshape: If it needs to reshape the data, defaults to None
        :type reshape: bool, optional
        :param dask: Please refer to xarray.apply_ufunc()
        :type dask: {"forbidden", "allowed", "parallelized"}, optional
        :return: The 'fake' background to help removing the fringes, unstacked
            back to the y and x dimensions. With reshape disabled and plain-array
            input, the array produced by apply_ufunc is returned unchanged.
        :rtype: xarray array
        """
        if reshape is not None:
            self.reshape = reshape
        if referenceImages is not None:
            self.referenceImages = referenceImages
        if mask is not None:
            self.mask = mask
        if self.P is None:
            self.decompose_referenceImages()
        if self.reshape:
            absorptionImages = self._reshape_absorption_images(absorptionImages)
        self.temp = self._referenceImages[self.k, :].T
        optrefimages = xr.apply_ufunc(
            self._fringe_removal,
            absorptionImages,
            input_core_dims=[['axis']],
            output_core_dims=[['axis']],
            dask=dask,
            vectorize=True,
            # apply_ufunc expects a list with one dtype per output.
            output_dtypes=[float],
        )
        # Only xarray results carry the stacked 'axis' index to unstack.
        if isinstance(optrefimages, (xr.DataArray, xr.Dataset)):
            return optrefimages.unstack()
        return optrefimages