diff --git a/Analyser/FringeRemoval.py b/Analyser/FringeRemoval.py index fe52934..a2da3c3 100644 --- a/Analyser/FringeRemoval.py +++ b/Analyser/FringeRemoval.py @@ -3,38 +3,6 @@ from scipy.linalg import lu import xarray as xr -def fringeremoval(absimages, refimages, bgmask=None): - # Process inputs - nimgs = absimages.shape[2] - nimgsR = refimages.shape[2] - xdim = absimages.shape[1] - ydim = absimages.shape[0] - - A = (absimages.reshape(xdim * ydim, nimgs).astype(np.float32)) - R = (refimages.reshape(xdim * ydim, nimgsR).astype(np.float32)) - optrefimages = np.zeros_like(absimages, dtype=np.float32) - - if bgmask is None: - bgmask = np.ones((ydim, xdim), dtype=np.uint8) - k = np.where(bgmask.flatten() == 1)[0] # Index k specifying the background region - - # Ensure there are no duplicate reference images - # R = np.unique(R, axis=1) # Comment this line if memory issues arise - - # Decompose B = R * R' using LU decomposition - P, L, U = lu(R[k, :].T @ R[k, :], permute_l = False, p_indices = True) - - for j in range(nimgs): - b = R[k, :].T @ A[k, j] - - # Obtain coefficients c which minimize least-square residuals - c = np.linalg.solve(U, np.linalg.solve(L[P], b)) - # Compute optimized reference image - optrefimages[:, :, j] = (R @ c).reshape((ydim, xdim)) - - return optrefimages - - class InvalidDimException(Exception): "Raised when the program can not identify (index of images, x, y) axes." def __init__(self, dims): @@ -53,66 +21,82 @@ class DataSizeException(Exception): class FringeRemoval(): + """A class for fringes removal + """ def __init__(self) -> None: - self.nimgsR = 0 - self.xdim = 0 - self.ydim = 0 + """Initialize the class + """ + self.nimgsR = 0 # The number of the reference images + self.xdim = 0 # The shape of x axis + self.ydim = 0 # The shape of y axis - self._mask = None + self._mask = None # The mask array to choose the region of interest for fringes removal - self.reshape=True + self.reshape=True # If it is necessary to reshape the data from (index of images(alternative), x, y) to (y, x, index of images(alternative)) self.P = None self.L = None self.U = None - - def reshape_data(self, data): - if data is None: - return data + def reshape_data(self, data): + """The function is to reshape the data to the correct shape. + In order to minimize the calculation time, the data has to have a shape of (y, x, index of images(alternative)). + However, usually the input data has a shape of (index of images(alternative), x, y). + It can also convert the xarray DataArray and Dataset to numpy array. - if isinstance(data, type(xr.DataArray())): - - dims = data.dims - - if len(dims)>3: - raise InvalidDimException(dims) - - xAxis = None - yAxis = None - if len(dims) == 2: - imageAxis = '' - else: - imageAxis = None - - for dim in dims: - if (dim == 'x') or ('_x' in dim): - xAxis = dim - elif (dim == 'y') or ('_y' in dim): - yAxis = dim - else: - imageAxis = dim - - if (xAxis is None) or (yAxis is None) or (imageAxis is None): - raise InvalidDimException(dims) - - if len(dims) == 2: - data = data.transpose(yAxis, xAxis) - else: - data = data.transpose(yAxis, xAxis, imageAxis) - - data = data.to_numpy() - - else: - data = np.array(data) - if len(data.shape) == 3: - data = np.swapaxes(data, 0, 2) - # data = np.swapaxes(data, 0, 1) - elif len(data.shape) == 2: - data = np.swapaxes(data, 0, 1) - + :param data: The input data. + :type data: xarray, numpy array or list + :raises InvalidDimException: Raised when the program can not identify (index of images, x, y) axes. + :raises InvalidDimException: Raised when the shape of the data is not correct. + :return: The data with correct shape + :rtype: xarray, numpy array or list + """ + + if data is None: return data + + if isinstance(data, type(xr.DataArray())): + + dims = data.dims + + if len(dims)>3: + raise InvalidDimException(dims) + + xAxis = None + yAxis = None + if len(dims) == 2: + imageAxis = '' + else: + imageAxis = None + + for dim in dims: + if (dim == 'x') or ('_x' in dim): + xAxis = dim + elif (dim == 'y') or ('_y' in dim): + yAxis = dim + else: + imageAxis = dim + + if (xAxis is None) or (yAxis is None) or (imageAxis is None): + raise InvalidDimException(dims) + + if len(dims) == 2: + data = data.transpose(yAxis, xAxis) + else: + data = data.transpose(yAxis, xAxis, imageAxis) + + data = data.to_numpy() + + else: + data = np.array(data) + if len(data.shape) == 3: + data = np.swapaxes(data, 0, 2) + # data = np.swapaxes(data, 0, 1) + elif len(data.shape) == 2: + data = np.swapaxes(data, 0, 1) + + return data @property def referenceImages(self): @@ -132,8 +116,15 @@ class FringeRemoval(): self.ydim = value.shape[0] self._referenceImages = (value.reshape(self.xdim * self.ydim, self.nimgsR).astype(np.float32)) - + def add_reference_images(self, data): + """Add a new reference images + + :param data: The new reference image. + :type data: xarray, numpy array or list + :raises DataSizeException: Raised when the shape of the data is not correct. + """ + if self.reshape: data = self.reshape_data(data) elif isinstance(data, type(xr.DataArray())): @@ -147,9 +138,16 @@ class FringeRemoval(): self._referenceImages = np.append(self._referenceImages, data, axis=1) def _remove_first_reference_images(self): + """Remove the first reference images + """ self._referenceImages = np.delete(self._referenceImages, 0, axis=1) def update_reference_images(self, data): + """Update the reference images set by removing the first one and adding a new one at the end. + + :param data: The new reference image. + :type data: xarray, numpy array or list + """ self._remove_first_reference_images() self.add_reference_images(data) @@ -178,7 +176,10 @@ class FringeRemoval(): def decompose_referenceImages(self): self.P, self.L, self.U = lu(self._referenceImages[self.k, :].T @ self._referenceImages[self.k, :], permute_l = False, p_indices = True) - def _fringe_removal(self, absorptionImages, referenceImages=None, mask=None, reshape=None): + def solve_coefficient(self): + pass + + def _fringe_removal(self, absorptionImages, referenceImages=None, mask=None, reshape=None, dask='forbidden'): if not reshape is None: self.reshape = reshape if not referenceImages is None: @@ -198,19 +199,43 @@ class FringeRemoval(): optrefimages = np.zeros_like(absorptionImages, dtype=np.float32) - for j in range(self.nimgs): - b = self._referenceImages[self.k, :].T @ absorptionImages[self.k, j] + if dask=='forbidden': + for j in range(self.nimgs): + b = self._referenceImages[self.k, :].T @ absorptionImages[self.k, j] - # Obtain coefficients c which minimize least-square residuals - c = np.linalg.solve(self.U, np.linalg.solve(self.L[self.P], b)) - # Compute optimized reference image - optrefimages[:, j] = (self._referenceImages @ c) + # Obtain coefficients c which minimize least-square residuals + c = np.linalg.solve(self.U, np.linalg.solve(self.L[self.P], b)) + # Compute optimized reference image + optrefimages[:, j] = (self._referenceImages @ c) + else: + pass return optrefimages - def fringe_removal(self, absorptionImages, referenceImages=None, mask=None, reshape=None): + def fringe_removal(self, absorptionImages, referenceImages=None, mask=None, reshape=None, dask='forbidden'): + """ + This function will generate a 'fake' background images, which can help to remove the fringes. - res = self._fringe_removal(absorptionImages, referenceImages, mask, reshape) + Important: Please substract the drak images from the both of images with atoms and without atoms before using this function!!! + + :param absorptionImages: A set of images with atoms in absorption imaging + :type absorptionImages: xarray, numpy array or list + :param referenceImages: A set of images without atoms in absorption imaging, defaults to None + :type referenceImages: xarray, numpy array or list, optional + :param mask: An array to choose the region of interest for fringes removal, defaults to None, defaults to None + :type mask: numpy array, optional + :param reshape: If it needs to reshape the data, defaults to None + :type reshape: bool, optional + :param dask: Please refer to xarray.apply_ufunc() + :type dask: {"forbidden", "allowed", "parallelized"}, optional + :return: The 'fake' background to help removing the fringes + :rtype: numpy array + """ + + res = self._fringe_removal(absorptionImages, referenceImages, mask, reshape, dask) res = res.reshape(self.ydim, self.xdim, self.nimgs) - return np.swapaxes(res, 0, 2) + if self.reshape: + return np.swapaxes(res, 0, 2) + else: + return res \ No newline at end of file