diff --git a/Analyser/FringeRemoval.py b/Analyser/FringeRemoval.py new file mode 100644 index 0000000..fe52934 --- /dev/null +++ b/Analyser/FringeRemoval.py @@ -0,0 +1,216 @@ +import numpy as np +from scipy.linalg import lu + +import xarray as xr + +def fringeremoval(absimages, refimages, bgmask=None): + # Process inputs + nimgs = absimages.shape[2] + nimgsR = refimages.shape[2] + xdim = absimages.shape[1] + ydim = absimages.shape[0] + + A = (absimages.reshape(xdim * ydim, nimgs).astype(np.float32)) + R = (refimages.reshape(xdim * ydim, nimgsR).astype(np.float32)) + optrefimages = np.zeros_like(absimages, dtype=np.float32) + + if bgmask is None: + bgmask = np.ones((ydim, xdim), dtype=np.uint8) + k = np.where(bgmask.flatten() == 1)[0] # Index k specifying the background region + + # Ensure there are no duplicate reference images + # R = np.unique(R, axis=1) # Comment this line if memory issues arise + + # Decompose B = R * R' using LU decomposition + P, L, U = lu(R[k, :].T @ R[k, :], permute_l = False, p_indices = True) + + for j in range(nimgs): + b = R[k, :].T @ A[k, j] + + # Obtain coefficients c which minimize least-square residuals + c = np.linalg.solve(U, np.linalg.solve(L[P], b)) + # Compute optimized reference image + optrefimages[:, :, j] = (R @ c).reshape((ydim, xdim)) + + return optrefimages + + +class InvalidDimException(Exception): + "Raised when the program can not identify (index of images, x, y) axes." + def __init__(self, dims): + if len(dims)>3: + self.message = 'The input data must have two or three axes: (index of images(alternative), x, y)' + else: + self.message = 'Can not identify (index of images(alternative), x, y) from ' + str(dims) + super().__init__(self.message) + + +class DataSizeException(Exception): + "Raised when the shape of the data is not correct." + def __init__(self): + self.message = 'The input data size does not match.' + super().__init__(self.message) + + +class FringeRemoval(): + + def __init__(self) -> None: + self.nimgsR = 0 + self.xdim = 0 + self.ydim = 0 + + self._mask = None + + self.reshape=True + + self.P = None + self.L = None + self.U = None + + def reshape_data(self, data): + + if data is None: + return data + + if isinstance(data, type(xr.DataArray())): + + dims = data.dims + + if len(dims)>3: + raise InvalidDimException(dims) + + xAxis = None + yAxis = None + if len(dims) == 2: + imageAxis = '' + else: + imageAxis = None + + for dim in dims: + if (dim == 'x') or ('_x' in dim): + xAxis = dim + elif (dim == 'y') or ('_y' in dim): + yAxis = dim + else: + imageAxis = dim + + if (xAxis is None) or (yAxis is None) or (imageAxis is None): + raise InvalidDimException(dims) + + if len(dims) == 2: + data = data.transpose(yAxis, xAxis) + else: + data = data.transpose(yAxis, xAxis, imageAxis) + + data = data.to_numpy() + + else: + data = np.array(data) + if len(data.shape) == 3: + data = np.swapaxes(data, 0, 2) + # data = np.swapaxes(data, 0, 1) + elif len(data.shape) == 2: + data = np.swapaxes(data, 0, 1) + + return data + + @property + def referenceImages(self): + res = self._referenceImages.reshape(self.ydim, self.xdim, self.nimgsR) + res = np.swapaxes(res, 0, 2) + return res + + @referenceImages.setter + def referenceImages(self, value): + if self.reshape: + value = self.reshape_data(value) + elif isinstance(value, type(xr.DataArray())): + value = value.to_numpy() + + self.nimgsR = value.shape[2] + self.xdim = value.shape[1] + self.ydim = value.shape[0] + + self._referenceImages = (value.reshape(self.xdim * self.ydim, self.nimgsR).astype(np.float32)) + + def add_reference_images(self, data): + if self.reshape: + data = self.reshape_data(data) + elif isinstance(data, type(xr.DataArray())): + data = data.to_numpy() + + if not ((data.shape[0]==self.ydim) and (data.shape[1]==self.xdim)): + raise DataSizeException + + data = data.reshape(self.xdim * self.ydim) + + self._referenceImages = np.append(self._referenceImages, data, axis=1) + + def _remove_first_reference_images(self): + self._referenceImages = np.delete(self._referenceImages, 0, axis=1) + + def update_reference_images(self, data): + self._remove_first_reference_images() + self.add_reference_images(data) + + if self._mask is None: + self.mask = np.ones((self.ydim, self.xdim), dtype=np.uint8) + + self.decompose_referenceImages() + + @property + def mask(self): + return self._mask + + @mask.setter + def mask(self, value): + if self.reshape: + value = self.reshape_data(value) + elif isinstance(value, type(xr.DataArray())): + value = value.to_numpy() + + if not ((value.shape[0]==self.ydim) and (value.shape[1]==self.xdim)): + raise DataSizeException + + self._mask = value + self.k = np.where(self._mask.flatten() == 1)[0] + + def decompose_referenceImages(self): + self.P, self.L, self.U = lu(self._referenceImages[self.k, :].T @ self._referenceImages[self.k, :], permute_l = False, p_indices = True) + + def _fringe_removal(self, absorptionImages, referenceImages=None, mask=None, reshape=None): + if not reshape is None: + self.reshape = reshape + if not referenceImages is None: + self.referenceImages = referenceImages + if not mask is None: + self.mask = mask + if self.P is None: + self.decompose_referenceImages() + + absorptionImages = np.atleast_3d(absorptionImages) + + if self.reshape: + absorptionImages = self.reshape_data(absorptionImages) + + self.nimgs = absorptionImages.shape[2] + absorptionImages = (absorptionImages.reshape(self.xdim * self.ydim, self.nimgs).astype(np.float32)) + + optrefimages = np.zeros_like(absorptionImages, dtype=np.float32) + + for j in range(self.nimgs): + b = self._referenceImages[self.k, :].T @ absorptionImages[self.k, j] + + # Obtain coefficients c which minimize least-square residuals + c = np.linalg.solve(self.U, np.linalg.solve(self.L[self.P], b)) + # Compute optimized reference image + optrefimages[:, j] = (self._referenceImages @ c) + + return optrefimages + + def fringe_removal(self, absorptionImages, referenceImages=None, mask=None, reshape=None): + + res = self._fringe_removal(absorptionImages, referenceImages, mask, reshape) + res = res.reshape(self.ydim, self.xdim, self.nimgs) + return np.swapaxes(res, 0, 2) + \ No newline at end of file