import numpy as np from scipy.linalg import lu import xarray as xr class InvalidDimException(Exception): "Raised when the program can not identify (index of images, x, y) axes." def __init__(self, dims): if len(dims)>3: self.message = 'The input data must have two or three axes: (index of images(alternative), x, y)' else: self.message = 'Can not identify (index of images(alternative), x, y) from ' + str(dims) super().__init__(self.message) class DataSizeException(Exception): "Raised when the shape of the data is not correct." def __init__(self): self.message = 'The input data size does not match.' super().__init__(self.message) class FringeRemoval(): """A class for fringes removal """ def __init__(self) -> None: """Initialize the class """ self.nimgsR = 0 # The number of the reference images self.xdim = 0 # The shape of x axis self.ydim = 0 # The shape of y axis self._mask = None # The mask array to choose the region of interest for fringes removal self.reshape=True # If it is necessary to reshape the data from (index of images(alternative), x, y) to (y, x, index of images(alternative)) self.P = None self.L = None self.U = None def reshape_data(self, data): """The function is to reshape the data to the correct shape. In order to minimize the calculation time, the data has to have a shape of (y, x, index of images(alternative)). However, usually the input data has a shape of (index of images(alternative), x, y). It can also convert the xarray DataArray and Dataset to numpy array. :param data: The input data. :type data: xarray, numpy array or list :raises InvalidDimException: Raised when the program can not identify (index of images, x, y) axes. :raises InvalidDimException: Raised when the shape of the data is not correct. :return: The data with correct shape :rtype: xarray, numpy array or list """ if data is None: return data if isinstance(data, type(xr.DataArray())): dims = data.dims if len(dims)>3: raise InvalidDimException(dims) xAxis = None yAxis = None if len(dims) == 2: imageAxis = '' else: imageAxis = None for dim in dims: if (dim == 'x') or ('_x' in dim): xAxis = dim elif (dim == 'y') or ('_y' in dim): yAxis = dim else: imageAxis = dim if (xAxis is None) or (yAxis is None) or (imageAxis is None): raise InvalidDimException(dims) if len(dims) == 2: data = data.transpose(yAxis, xAxis) else: data = data.transpose(yAxis, xAxis, imageAxis) data = data.to_numpy() else: data = np.array(data) if len(data.shape) == 3: data = np.swapaxes(data, 0, 2) # data = np.swapaxes(data, 0, 1) elif len(data.shape) == 2: data = np.swapaxes(data, 0, 1) return data @property def referenceImages(self): res = self._referenceImages.reshape(self.ydim, self.xdim, self.nimgsR) res = np.swapaxes(res, 0, 2) return res @referenceImages.setter def referenceImages(self, value): if self.reshape: value = self.reshape_data(value) elif isinstance(value, type(xr.DataArray())): value = value.to_numpy() self.nimgsR = value.shape[2] self.xdim = value.shape[1] self.ydim = value.shape[0] self._referenceImages = (value.reshape(self.xdim * self.ydim, self.nimgsR).astype(np.float32)) def add_reference_images(self, data): """Add a new reference images :param data: The new reference image. :type data: xarray, numpy array or list :raises DataSizeException: Raised when the shape of the data is not correct. """ if self.reshape: data = self.reshape_data(data) elif isinstance(data, type(xr.DataArray())): data = data.to_numpy() if not ((data.shape[0]==self.ydim) and (data.shape[1]==self.xdim)): raise DataSizeException data = data.reshape(self.xdim * self.ydim) self._referenceImages = np.append(self._referenceImages, data, axis=1) def _remove_first_reference_images(self): """Remove the first reference images """ self._referenceImages = np.delete(self._referenceImages, 0, axis=1) def update_reference_images(self, data): """Update the reference images set by removing the first one and adding a new one at the end. :param data: The new reference image. :type data: xarray, numpy array or list """ self._remove_first_reference_images() self.add_reference_images(data) if self._mask is None: self.mask = np.ones((self.ydim, self.xdim), dtype=np.uint8) self.decompose_referenceImages() @property def mask(self): return self._mask @mask.setter def mask(self, value): if self.reshape: value = self.reshape_data(value) elif isinstance(value, type(xr.DataArray())): value = value.to_numpy() if not ((value.shape[0]==self.ydim) and (value.shape[1]==self.xdim)): raise DataSizeException self._mask = value self.k = np.where(self._mask.flatten() == 1)[0] def decompose_referenceImages(self): self.P, self.L, self.U = lu(self._referenceImages[self.k, :].T @ self._referenceImages[self.k, :], permute_l = False, p_indices = True) def solve_coefficient(self): pass def _fringe_removal(self, absorptionImages, referenceImages=None, mask=None, reshape=None, dask='forbidden'): if not reshape is None: self.reshape = reshape if not referenceImages is None: self.referenceImages = referenceImages if not mask is None: self.mask = mask if self.P is None: self.decompose_referenceImages() absorptionImages = np.atleast_3d(absorptionImages) if self.reshape: absorptionImages = self.reshape_data(absorptionImages) self.nimgs = absorptionImages.shape[2] absorptionImages = (absorptionImages.reshape(self.xdim * self.ydim, self.nimgs).astype(np.float32)) optrefimages = np.zeros_like(absorptionImages, dtype=np.float32) if dask=='forbidden': for j in range(self.nimgs): b = self._referenceImages[self.k, :].T @ absorptionImages[self.k, j] # Obtain coefficients c which minimize least-square residuals c = np.linalg.solve(self.U, np.linalg.solve(self.L[self.P], b)) # Compute optimized reference image optrefimages[:, j] = (self._referenceImages @ c) else: pass return optrefimages def fringe_removal(self, absorptionImages, referenceImages=None, mask=None, reshape=None, dask='forbidden'): """ This function will generate a 'fake' background images, which can help to remove the fringes. Important: Please substract the drak images from the both of images with atoms and without atoms before using this function!!! :param absorptionImages: A set of images with atoms in absorption imaging :type absorptionImages: xarray, numpy array or list :param referenceImages: A set of images without atoms in absorption imaging, defaults to None :type referenceImages: xarray, numpy array or list, optional :param mask: An array to choose the region of interest for fringes removal, defaults to None, defaults to None :type mask: numpy array, optional :param reshape: If it needs to reshape the data, defaults to None :type reshape: bool, optional :param dask: Please refer to xarray.apply_ufunc() :type dask: {"forbidden", "allowed", "parallelized"}, optional :return: The 'fake' background to help removing the fringes :rtype: numpy array """ res = self._fringe_removal(absorptionImages, referenceImages, mask, reshape, dask) res = res.reshape(self.ydim, self.xdim, self.nimgs) if self.reshape: return np.swapaxes(res, 0, 2) else: return res