import numpy as np from scipy.linalg import lu import xarray as xr def fringeremoval(absimages, refimages, bgmask=None): # Process inputs nimgs = absimages.shape[2] nimgsR = refimages.shape[2] xdim = absimages.shape[1] ydim = absimages.shape[0] A = (absimages.reshape(xdim * ydim, nimgs).astype(np.float32)) R = (refimages.reshape(xdim * ydim, nimgsR).astype(np.float32)) optrefimages = np.zeros_like(absimages, dtype=np.float32) if bgmask is None: bgmask = np.ones((ydim, xdim), dtype=np.uint8) k = np.where(bgmask.flatten() == 1)[0] # Index k specifying the background region # Ensure there are no duplicate reference images # R = np.unique(R, axis=1) # Comment this line if memory issues arise # Decompose B = R * R' using LU decomposition P, L, U = lu(R[k, :].T @ R[k, :], permute_l = False, p_indices = True) for j in range(nimgs): b = R[k, :].T @ A[k, j] # Obtain coefficients c which minimize least-square residuals c = np.linalg.solve(U, np.linalg.solve(L[P], b)) # Compute optimized reference image optrefimages[:, :, j] = (R @ c).reshape((ydim, xdim)) return optrefimages class InvalidDimException(Exception): "Raised when the program can not identify (index of images, x, y) axes." def __init__(self, dims): if len(dims)>3: self.message = 'The input data must have two or three axes: (index of images(alternative), x, y)' else: self.message = 'Can not identify (index of images(alternative), x, y) from ' + str(dims) super().__init__(self.message) class DataSizeException(Exception): "Raised when the shape of the data is not correct." def __init__(self): self.message = 'The input data size does not match.' super().__init__(self.message) class FringeRemoval(): def __init__(self) -> None: self.nimgsR = 0 self.xdim = 0 self.ydim = 0 self._mask = None self.reshape=True self.P = None self.L = None self.U = None def reshape_data(self, data): if data is None: return data if isinstance(data, type(xr.DataArray())): dims = data.dims if len(dims)>3: raise InvalidDimException(dims) xAxis = None yAxis = None if len(dims) == 2: imageAxis = '' else: imageAxis = None for dim in dims: if (dim == 'x') or ('_x' in dim): xAxis = dim elif (dim == 'y') or ('_y' in dim): yAxis = dim else: imageAxis = dim if (xAxis is None) or (yAxis is None) or (imageAxis is None): raise InvalidDimException(dims) if len(dims) == 2: data = data.transpose(yAxis, xAxis) else: data = data.transpose(yAxis, xAxis, imageAxis) data = data.to_numpy() else: data = np.array(data) if len(data.shape) == 3: data = np.swapaxes(data, 0, 2) # data = np.swapaxes(data, 0, 1) elif len(data.shape) == 2: data = np.swapaxes(data, 0, 1) return data @property def referenceImages(self): res = self._referenceImages.reshape(self.ydim, self.xdim, self.nimgsR) res = np.swapaxes(res, 0, 2) return res @referenceImages.setter def referenceImages(self, value): if self.reshape: value = self.reshape_data(value) elif isinstance(value, type(xr.DataArray())): value = value.to_numpy() self.nimgsR = value.shape[2] self.xdim = value.shape[1] self.ydim = value.shape[0] self._referenceImages = (value.reshape(self.xdim * self.ydim, self.nimgsR).astype(np.float32)) def add_reference_images(self, data): if self.reshape: data = self.reshape_data(data) elif isinstance(data, type(xr.DataArray())): data = data.to_numpy() if not ((data.shape[0]==self.ydim) and (data.shape[1]==self.xdim)): raise DataSizeException data = data.reshape(self.xdim * self.ydim) self._referenceImages = np.append(self._referenceImages, data, axis=1) def _remove_first_reference_images(self): self._referenceImages = np.delete(self._referenceImages, 0, axis=1) def update_reference_images(self, data): self._remove_first_reference_images() self.add_reference_images(data) if self._mask is None: self.mask = np.ones((self.ydim, self.xdim), dtype=np.uint8) self.decompose_referenceImages() @property def mask(self): return self._mask @mask.setter def mask(self, value): if self.reshape: value = self.reshape_data(value) elif isinstance(value, type(xr.DataArray())): value = value.to_numpy() if not ((value.shape[0]==self.ydim) and (value.shape[1]==self.xdim)): raise DataSizeException self._mask = value self.k = np.where(self._mask.flatten() == 1)[0] def decompose_referenceImages(self): self.P, self.L, self.U = lu(self._referenceImages[self.k, :].T @ self._referenceImages[self.k, :], permute_l = False, p_indices = True) def _fringe_removal(self, absorptionImages, referenceImages=None, mask=None, reshape=None): if not reshape is None: self.reshape = reshape if not referenceImages is None: self.referenceImages = referenceImages if not mask is None: self.mask = mask if self.P is None: self.decompose_referenceImages() absorptionImages = np.atleast_3d(absorptionImages) if self.reshape: absorptionImages = self.reshape_data(absorptionImages) self.nimgs = absorptionImages.shape[2] absorptionImages = (absorptionImages.reshape(self.xdim * self.ydim, self.nimgs).astype(np.float32)) optrefimages = np.zeros_like(absorptionImages, dtype=np.float32) for j in range(self.nimgs): b = self._referenceImages[self.k, :].T @ absorptionImages[self.k, j] # Obtain coefficients c which minimize least-square residuals c = np.linalg.solve(self.U, np.linalg.solve(self.L[self.P], b)) # Compute optimized reference image optrefimages[:, j] = (self._referenceImages @ c) return optrefimages def fringe_removal(self, absorptionImages, referenceImages=None, mask=None, reshape=None): res = self._fringe_removal(absorptionImages, referenceImages, mask, reshape) res = res.reshape(self.ydim, self.xdim, self.nimgs) return np.swapaxes(res, 0, 2)