diff --git a/Analyser/FringeRemoval.py b/Analyser/FringeRemoval.py index a2da3c3..b675569 100644 --- a/Analyser/FringeRemoval.py +++ b/Analyser/FringeRemoval.py @@ -32,12 +32,52 @@ class FringeRemoval(): self.ydim = 0 # The shape of y axis self._mask = None # The mask array to choose the region of interest for fringes removal + self._center = None # Set the mask array by center and span + self._span = None self.reshape=True # If it is necessary to reshape the data from (index of images(alternative), x, y) to (y, x, index of images(alternative)) self.P = None self.L = None self.U = None + + @property + def center(self): + """The getter of the center of region of insterest (ROI) + + :return: The center of region of insterest (ROI) + :rtype: tuple + """ + return self._center + + @center.setter + def center(self, value): + """The setter of the center of region of insterest (ROI) + + :param value: The center of region of insterest (ROI) + :type value: tuple + """ + self._mask = None + self._center = value + + @property + def span(self): + """The getter of the span of region of insterest (ROI) + + :return: The span of region of insterest (ROI) + :rtype: tuple + """ + return self._span + + @span.setter + def span(self, value): + """The setter of the span of region of insterest (ROI) + + :param value: The span of region of insterest (ROI) + :type value: tuple + """ + self._mask = None + self._span = value def reshape_data(self, data): """The function is to reshape the data to the correct shape. @@ -98,6 +138,55 @@ class FringeRemoval(): return data + def _reshape_absorption_images(self, data): + + if data is None: + return data + + if isinstance(data, type(xr.DataArray())): + + dims = data.dims + + if len(dims)>3: + raise InvalidDimException(dims) + + xAxis = None + yAxis = None + if len(dims) == 2: + imageAxis = '' + else: + imageAxis = None + + for dim in dims: + if (dim == 'x') or ('_x' in dim): + xAxis = dim + elif (dim == 'y') or ('_y' in dim): + yAxis = dim + else: + imageAxis = dim + + if (xAxis is None) or (yAxis is None) or (imageAxis is None): + raise InvalidDimException(dims) + + if len(dims) == 2: + data = data.transpose(yAxis, xAxis) + else: + data = data.transpose(yAxis, xAxis, imageAxis) + + self.nimgs = len(data[imageAxis]) + + data = data.stack(axis=[yAxis, xAxis]) + + else: + data = np.array(data) + if len(data.shape) == 3: + data = np.swapaxes(data, 0, 2) + # data = np.swapaxes(data, 0, 1) + elif len(data.shape) == 2: + data = np.swapaxes(data, 0, 1) + + return data + @property def referenceImages(self): res = self._referenceImages.reshape(self.ydim, self.xdim, self.nimgsR) @@ -150,9 +239,6 @@ class FringeRemoval(): """ self._remove_first_reference_images() self.add_reference_images(data) - - if self._mask is None: - self.mask = np.ones((self.ydim, self.xdim), dtype=np.uint8) self.decompose_referenceImages() @@ -173,46 +259,32 @@ class FringeRemoval(): self._mask = value self.k = np.where(self._mask.flatten() == 1)[0] + def _auto_mask(self): + mask = np.ones((self.ydim, self.xdim), dtype=np.uint8) + if not self._center is None: + x_start = int(self._center[0] - self._span[0] / 2) + x_end = int(self._center[0] + self._span[0] / 2) + y_end = int(self._center[1] + self._span[1] / 2) + y_start = int(self._center[1] - self._span[1] / 2) + + mask[y_start:y_end, x_start:x_end] = 0 + + return mask + def decompose_referenceImages(self): + if self._mask is None: + self.mask = self._auto_mask() self.P, self.L, self.U = lu(self._referenceImages[self.k, :].T @ self._referenceImages[self.k, :], permute_l = False, p_indices = True) - def solve_coefficient(self): - pass - - def _fringe_removal(self, absorptionImages, referenceImages=None, mask=None, reshape=None, dask='forbidden'): - if not reshape is None: - self.reshape = reshape - if not referenceImages is None: - self.referenceImages = referenceImages - if not mask is None: - self.mask = mask - if self.P is None: - self.decompose_referenceImages() + def _fringe_removal(self, absorptionImages): - absorptionImages = np.atleast_3d(absorptionImages) - - if self.reshape: - absorptionImages = self.reshape_data(absorptionImages) + b = self.temp @ absorptionImages[self.k] + c = np.linalg.solve(self.U, np.linalg.solve(self.L[self.P], b)) + optrefimages = (self._referenceImages @ c) - self.nimgs = absorptionImages.shape[2] - absorptionImages = (absorptionImages.reshape(self.xdim * self.ydim, self.nimgs).astype(np.float32)) - - optrefimages = np.zeros_like(absorptionImages, dtype=np.float32) - - if dask=='forbidden': - for j in range(self.nimgs): - b = self._referenceImages[self.k, :].T @ absorptionImages[self.k, j] - - # Obtain coefficients c which minimize least-square residuals - c = np.linalg.solve(self.U, np.linalg.solve(self.L[self.P], b)) - # Compute optimized reference image - optrefimages[:, j] = (self._referenceImages @ c) - else: - pass - return optrefimages - def fringe_removal(self, absorptionImages, referenceImages=None, mask=None, reshape=None, dask='forbidden'): + def fringe_removal(self, absorptionImages, referenceImages=None, mask=None, reshape=None, dask='parallelized'): """ This function will generate a 'fake' background images, which can help to remove the fringes. @@ -229,13 +301,22 @@ class FringeRemoval(): :param dask: Please refer to xarray.apply_ufunc() :type dask: {"forbidden", "allowed", "parallelized"}, optional :return: The 'fake' background to help removing the fringes - :rtype: numpy array + :rtype: xarray array """ + if not reshape is None: + self.reshape = reshape + if not referenceImages is None: + self.referenceImages = referenceImages + if not mask is None: + self.mask = mask + if self.P is None: + self.decompose_referenceImages() - res = self._fringe_removal(absorptionImages, referenceImages, mask, reshape, dask) - res = res.reshape(self.ydim, self.xdim, self.nimgs) if self.reshape: - return np.swapaxes(res, 0, 2) - else: - return res - \ No newline at end of file + absorptionImages = self._reshape_absorption_images(absorptionImages) + + self.temp = self._referenceImages[self.k, :].T + + optrefimages = xr.apply_ufunc(self._fringe_removal, absorptionImages, input_core_dims=[['axis']], output_core_dims=[['axis']], dask=dask, vectorize=True, output_dtypes=float) + + return optrefimages.unstack() \ No newline at end of file