import numpy as np
from scipy.linalg import lu
import xarray as xr


class InvalidDimException(Exception):
    """Raised when the program can not identify (index of images, x, y) axes."""

    def __init__(self, dims):
        if len(dims) > 3:
            self.message = 'The input data must have two or three axes: (index of images(alternative), x, y)'
        else:
            self.message = 'Can not identify (index of images(alternative), x, y) from ' + str(dims)
        super().__init__(self.message)


class DataSizeException(Exception):
    """Raised when the shape of the data is not correct."""

    def __init__(self):
        self.message = 'The input data size does not match.'
        super().__init__(self.message)


class FringeRemoval():
    """
    FRINGEREMOVAL - Fringe removal and noise reduction from absorption images.
    Creates an optimal reference image for each absorption image in a set as
    a linear combination of reference images, with coefficients chosen to
    minimize the least-squares residuals between each absorption image and
    the optimal reference image. The coefficients are obtained by solving a
    linear set of equations using matrix inverse by LU decomposition.

    Application of the algorithm is described in C. F. Ockeloen et al, Improved
    detection of small atom numbers through image processing, arXiv:1007.2136 (2010).

    Adapted from a MATLAB script copy provided by Guoxian Su.

    Original Authors: Shannon Whitlock, Caspar Ockeloen
    Reference: C. F. Ockeloen, A. F. Tauschinsky, R. J. C. Spreeuw, and
               S. Whitlock, Improved detection of small atom numbers through
               image processing, arXiv:1007.2136
    May 2009;
    """

    def __init__(self) -> None:
        """Initialize the class
        """
        self.nimgsR = 0  # The number of the reference images
        self.nimgs = 0  # The number of the absorption images seen by the last reshape
        self.xdim = 0  # The shape of x axis
        self.ydim = 0  # The shape of y axis

        # Reference images stored as a (ydim * xdim, nimgsR) float32 matrix
        self._referenceImages = None

        self._mask = None  # The mask array to choose the region of interest for fringes removal
        self.k = None  # Flat pixel indices where the mask equals 1
        self._center = None  # Set the mask array by center and span
        self._span = None

        # If it is necessary to reshape the data from
        # (index of images(alternative), x, y) to (y, x, index of images(alternative))
        self.reshape = True

        self.temp = None  # Transposed masked reference matrix used by _fringe_removal

        # LU factors of the masked reference Gram matrix; None means "not computed / stale"
        self.P = None
        self.L = None
        self.U = None

    def _invalidate_decomposition(self):
        """Drop the cached LU factors so that they are recomputed on next use.
        """
        self.P = None
        self.L = None
        self.U = None

    @property
    def center(self):
        """The getter of the center of region of interest (ROI)

        :return: The center of region of interest (ROI)
        :rtype: tuple
        """
        return self._center

    @center.setter
    def center(self, value):
        """The setter of the center of region of interest (ROI)

        :param value: The center of region of interest (ROI)
        :type value: tuple
        """
        self._mask = None
        self._center = value
        # The mask will change, therefore the old LU factors are not valid anymore.
        self._invalidate_decomposition()

    @property
    def span(self):
        """The getter of the span of region of interest (ROI)

        :return: The span of region of interest (ROI)
        :rtype: tuple
        """
        return self._span

    @span.setter
    def span(self, value):
        """The setter of the span of region of interest (ROI)

        :param value: The span of region of interest (ROI)
        :type value: tuple
        """
        self._mask = None
        self._span = value
        # The mask will change, therefore the old LU factors are not valid anymore.
        self._invalidate_decomposition()

    @staticmethod
    def _identify_axes(dims):
        """Identify the x, y and image-index dimension names of a DataArray.

        A dimension is taken as the x (y) axis if it is named 'x' ('y') or
        contains '_x' ('_y'). Any other dimension is taken as the image axis.

        :param dims: The dimension names of the data
        :type dims: tuple
        :raises InvalidDimException: Raised when the program can not identify (index of images, x, y) axes.
        :return: (name of x axis, name of y axis, name of image axis). The name
            of the image axis is '' for two-dimensional data.
        :rtype: tuple
        """
        if len(dims) > 3:
            raise InvalidDimException(dims)

        xAxis = None
        yAxis = None
        # Two-dimensional data has no image axis, '' marks it as "not needed".
        imageAxis = '' if len(dims) == 2 else None

        for dim in dims:
            if (dim == 'x') or ('_x' in dim):
                xAxis = dim
            elif (dim == 'y') or ('_y' in dim):
                yAxis = dim
            else:
                imageAxis = dim

        if (xAxis is None) or (yAxis is None) or (imageAxis is None):
            raise InvalidDimException(dims)

        return xAxis, yAxis, imageAxis

    @staticmethod
    def _swap_to_internal_layout(data):
        """Convert array-like data from (index of images(alternative), x, y) to (y, x, index of images(alternative)).

        :param data: The input data
        :type data: numpy array or list
        :return: The data with swapped axes. Data which is neither two nor
            three dimensional is returned without swapping.
        :rtype: numpy array
        """
        data = np.array(data)
        if data.ndim == 3:
            data = np.swapaxes(data, 0, 2)
        elif data.ndim == 2:
            data = np.swapaxes(data, 0, 1)
        return data

    def reshape_data(self, data):
        """The function is to reshape the data to the correct shape.

        In order to minimize the calculation time, the data has to have a shape of (y, x, index of images(alternative)).
        However, usually the input data has a shape of (index of images(alternative), x, y).
        It can also convert the xarray DataArray to numpy array.

        :param data: The input data.
        :type data: xarray, numpy array or list
        :raises InvalidDimException: Raised when the program can not identify (index of images, x, y) axes.
        :return: The data with correct shape, or None if the input is None
        :rtype: numpy array
        """
        if data is None:
            return data

        if isinstance(data, xr.DataArray):
            xAxis, yAxis, imageAxis = self._identify_axes(data.dims)
            if len(data.dims) == 2:
                data = data.transpose(yAxis, xAxis)
            else:
                data = data.transpose(yAxis, xAxis, imageAxis)
            return data.to_numpy()

        return self._swap_to_internal_layout(data)

    def _reshape_absorption_images(self, data):
        """Reshape the absorption images for the fringe removal.

        An xarray DataArray is transposed to (y, x, index of images(alternative))
        and its (y, x) axes are stacked to a new dimension called 'axis', which
        is the core dimension used in fringe_removal(). The number of the images
        is stored in self.nimgs.

        Other inputs are converted to a numpy array with swapped axes.
        NOTE(review): fringe_removal() addresses the data by the named dimension
        'axis', so presumably only DataArray input works there - confirm.

        :param data: The absorption images
        :type data: xarray, numpy array or list
        :raises InvalidDimException: Raised when the program can not identify (index of images, x, y) axes.
        :return: The reshaped absorption images, or None if the input is None
        :rtype: xarray or numpy array
        """
        if data is None:
            return data

        if isinstance(data, xr.DataArray):
            xAxis, yAxis, imageAxis = self._identify_axes(data.dims)
            if len(data.dims) == 2:
                # A single image has no image axis to measure.
                data = data.transpose(yAxis, xAxis)
                self.nimgs = 1
            else:
                data = data.transpose(yAxis, xAxis, imageAxis)
                self.nimgs = data.sizes[imageAxis]
            return data.stack(axis=[yAxis, xAxis])

        return self._swap_to_internal_layout(data)

    @property
    def referenceImages(self):
        """The getter of the reference images

        :return: The reference images with a shape of (index of images, x, y),
            or None if no reference image is set
        :rtype: numpy array
        """
        if self._referenceImages is None:
            return None
        # Use the stored matrix itself to get the number of images.
        res = self._referenceImages.reshape(self.ydim, self.xdim, -1)
        return np.swapaxes(res, 0, 2)

    @referenceImages.setter
    def referenceImages(self, value):
        """The setter of the reference images

        The images are stored as a float32 matrix with a shape of
        (ydim * xdim, number of images). The old LU factors are dropped.

        :param value: The reference images. A single two-dimensional image is also accepted.
        :type value: xarray, numpy array or list
        """
        # Any change of the reference set makes the old LU factors invalid.
        self._invalidate_decomposition()

        if value is None:
            self._referenceImages = None
            return

        if self.reshape:
            value = self.reshape_data(value)
        elif isinstance(value, xr.DataArray):
            value = value.to_numpy()
        else:
            value = np.asarray(value)

        if value.ndim == 2:
            # A single image: add the image axis at the end.
            value = value[:, :, np.newaxis]

        self.nimgsR = value.shape[2]
        self.xdim = value.shape[1]
        self.ydim = value.shape[0]
        self._referenceImages = (value.reshape(self.xdim * self.ydim, self.nimgsR).astype(np.float32))

    def add_reference_images(self, data):
        """Add a new reference images

        :param data: The new reference image.
        :type data: xarray, numpy array or list
        :raises DataSizeException: Raised when the shape of the data is not correct.
        """
        if self.reshape:
            data = self.reshape_data(data)
        elif isinstance(data, xr.DataArray):
            data = data.to_numpy()
        elif data is not None:
            data = np.asarray(data)

        if (data is None) or (data.ndim not in (2, 3)):
            raise DataSizeException

        if not ((data.shape[0] == self.ydim) and (data.shape[1] == self.xdim)):
            raise DataSizeException

        # Keep the pixel axis first and the image axis second, so that the new
        # column(s) have the same number of dimensions as the stored matrix.
        data = data.reshape(self.xdim * self.ydim, -1).astype(np.float32)

        if self._referenceImages is None:
            self._referenceImages = data
        else:
            self._referenceImages = np.append(self._referenceImages, data, axis=1)

        self.nimgsR = self._referenceImages.shape[1]
        self._invalidate_decomposition()

    def _remove_first_reference_images(self):
        """Remove the first reference images
        """
        if (self._referenceImages is None) or (self._referenceImages.shape[1] == 0):
            return
        self._referenceImages = np.delete(self._referenceImages, 0, axis=1)
        self.nimgsR = self._referenceImages.shape[1]
        self._invalidate_decomposition()

    def update_reference_images(self, data):
        """Update the reference images set by removing the first one and adding a new one at the end.

        :param data: The new reference image.
        :type data: xarray, numpy array or list
        :raises DataSizeException: Raised when the shape of the data is not correct.
        """
        # Add first: if the new image is rejected, the old set stays complete.
        self.add_reference_images(data)
        self._remove_first_reference_images()
        self.decompose_referenceImages()

    def _set_mask_array(self, value):
        """Store a mask which is already in the internal (y, x) layout.

        :param value: The mask array
        :type value: numpy array
        """
        self._mask = value
        self.k = np.where(self._mask.flatten() == 1)[0]

    @property
    def mask(self):
        """The getter of the mask array

        :return: The mask array, the pixels with value 1 are used for the fit
        :rtype: numpy array
        """
        return self._mask

    @mask.setter
    def mask(self, value):
        """The setter of the mask array

        Setting a mask clears the center and the span of the ROI and drops the
        old LU factors. Setting None only clears the mask.

        :param value: The mask array, the pixels with value 1 are used for the fit
        :type value: xarray, numpy array or list
        :raises DataSizeException: Raised when the shape of the data is not correct.
        """
        self._invalidate_decomposition()

        if value is None:
            self._mask = None
            self.k = None
            return

        if self.reshape:
            value = self.reshape_data(value)
        elif isinstance(value, xr.DataArray):
            value = value.to_numpy()
        else:
            value = np.asarray(value)

        if (value.ndim < 2) or (value.shape[0] != self.ydim) or (value.shape[1] != self.xdim):
            raise DataSizeException

        self._set_mask_array(value)
        self._center = None
        self._span = None

    def _auto_mask(self):
        """Generate a mask in the internal (y, x) layout from the center and the span of the ROI.

        The pixels inside of the ROI are 0, all the other pixels are 1. If no
        center is set, all the pixels are 1.

        :raises ValueError: Raised when the center is set but the span is not.
        :return: The mask array with a shape of (ydim, xdim)
        :rtype: numpy array
        """
        mask = np.ones((self.ydim, self.xdim), dtype=np.uint8)

        if self._center is None:
            return mask

        if self._span is None:
            raise ValueError('The span of the ROI must be set together with the center.')

        # Clamp at zero, a negative slice index would count from the other edge.
        x_start = max(0, int(self._center[0] - self._span[0] / 2))
        x_end = max(0, int(self._center[0] + self._span[0] / 2))
        y_start = max(0, int(self._center[1] - self._span[1] / 2))
        y_end = max(0, int(self._center[1] + self._span[1] / 2))

        mask[y_start:y_end, x_start:x_end] = 0

        return mask

    def decompose_referenceImages(self):
        """Do the LU decomposition of the masked reference images.

        The decomposed matrix is R.T @ R, where R holds the reference images
        restricted to the pixels chosen by the mask. If no mask is set, it is
        generated from the center and the span of the ROI.
        """
        if self._mask is None:
            # The auto mask is already in the internal layout, it must not go
            # through the public setter (which reshapes and clears center/span).
            self._set_mask_array(self._auto_mask())

        masked = self._referenceImages[self.k, :]
        self.P, self.L, self.U = lu(masked.T @ masked, permute_l=False, p_indices=True)

    def _fringe_removal(self, absorptionImages):
        """Calculate the optimal reference image of one flattened absorption image.

        :param absorptionImages: One absorption image flattened along (y, x)
        :type absorptionImages: numpy array
        :return: The flattened optimal reference image
        :rtype: numpy array
        """
        b = self.temp @ absorptionImages[self.k]
        # With p_indices=True the decomposed matrix equals L[P] @ U.
        c = np.linalg.solve(self.U, np.linalg.solve(self.L[self.P], b))
        optrefimages = (self._referenceImages @ c)
        return optrefimages

    def fringe_removal(self, absorptionImages, referenceImages=None, mask=None, reshape=None, dask='parallelized'):
        """
        This function will generate a 'fake' background images, which can help to remove the fringes.

        Important: Please subtract the dark images from the both of images with atoms and without atoms before using this function!!!

        :param absorptionImages: A set of images with atoms in absorption imaging
        :type absorptionImages: xarray, numpy array or list
        :param referenceImages: A set of images without atoms in absorption imaging, defaults to None
        :type referenceImages: xarray, numpy array or list, optional
        :param mask: An array to choose the region of interest for fringes removal, defaults to None
        :type mask: numpy array, optional
        :param reshape: If it needs to reshape the data, defaults to None
        :type reshape: bool, optional
        :param dask: Please refer to xarray.apply_ufunc()
        :type dask: {"forbidden", "allowed", "parallelized"}, optional
        :return: The 'fake' background to help removing the fringes
        :rtype: xarray array
        """
        if reshape is not None:
            self.reshape = reshape
        if referenceImages is not None:
            self.referenceImages = referenceImages
        if mask is not None:
            self.mask = mask

        # P is None when nothing is decomposed yet or when the inputs changed.
        if self.P is None:
            self.decompose_referenceImages()

        if self.reshape:
            absorptionImages = self._reshape_absorption_images(absorptionImages)

        self.temp = self._referenceImages[self.k, :].T

        optrefimages = xr.apply_ufunc(
            self._fringe_removal,
            absorptionImages,
            input_core_dims=[['axis']],
            output_core_dims=[['axis']],
            dask=dask,
            vectorize=True,
            output_dtypes=[float],
        )

        return optrefimages.unstack()