implenment fringe removal class
This commit is contained in:
parent
afe7a5907e
commit
b70bc5faf5
216
Analyser/FringeRemoval.py
Normal file
216
Analyser/FringeRemoval.py
Normal file
@ -0,0 +1,216 @@
|
|||||||
|
import numpy as np
|
||||||
|
from scipy.linalg import lu
|
||||||
|
|
||||||
|
import xarray as xr
|
||||||
|
|
||||||
|
def fringeremoval(absimages, refimages, bgmask=None):
|
||||||
|
# Process inputs
|
||||||
|
nimgs = absimages.shape[2]
|
||||||
|
nimgsR = refimages.shape[2]
|
||||||
|
xdim = absimages.shape[1]
|
||||||
|
ydim = absimages.shape[0]
|
||||||
|
|
||||||
|
A = (absimages.reshape(xdim * ydim, nimgs).astype(np.float32))
|
||||||
|
R = (refimages.reshape(xdim * ydim, nimgsR).astype(np.float32))
|
||||||
|
optrefimages = np.zeros_like(absimages, dtype=np.float32)
|
||||||
|
|
||||||
|
if bgmask is None:
|
||||||
|
bgmask = np.ones((ydim, xdim), dtype=np.uint8)
|
||||||
|
k = np.where(bgmask.flatten() == 1)[0] # Index k specifying the background region
|
||||||
|
|
||||||
|
# Ensure there are no duplicate reference images
|
||||||
|
# R = np.unique(R, axis=1) # Comment this line if memory issues arise
|
||||||
|
|
||||||
|
# Decompose B = R * R' using LU decomposition
|
||||||
|
P, L, U = lu(R[k, :].T @ R[k, :], permute_l = False, p_indices = True)
|
||||||
|
|
||||||
|
for j in range(nimgs):
|
||||||
|
b = R[k, :].T @ A[k, j]
|
||||||
|
|
||||||
|
# Obtain coefficients c which minimize least-square residuals
|
||||||
|
c = np.linalg.solve(U, np.linalg.solve(L[P], b))
|
||||||
|
# Compute optimized reference image
|
||||||
|
optrefimages[:, :, j] = (R @ c).reshape((ydim, xdim))
|
||||||
|
|
||||||
|
return optrefimages
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidDimException(Exception):
|
||||||
|
"Raised when the program can not identify (index of images, x, y) axes."
|
||||||
|
def __init__(self, dims):
|
||||||
|
if len(dims)>3:
|
||||||
|
self.message = 'The input data must have two or three axes: (index of images(alternative), x, y)'
|
||||||
|
else:
|
||||||
|
self.message = 'Can not identify (index of images(alternative), x, y) from ' + str(dims)
|
||||||
|
super().__init__(self.message)
|
||||||
|
|
||||||
|
|
||||||
|
class DataSizeException(Exception):
|
||||||
|
"Raised when the shape of the data is not correct."
|
||||||
|
def __init__(self):
|
||||||
|
self.message = 'The input data size does not match.'
|
||||||
|
super().__init__(self.message)
|
||||||
|
|
||||||
|
|
||||||
|
class FringeRemoval():
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.nimgsR = 0
|
||||||
|
self.xdim = 0
|
||||||
|
self.ydim = 0
|
||||||
|
|
||||||
|
self._mask = None
|
||||||
|
|
||||||
|
self.reshape=True
|
||||||
|
|
||||||
|
self.P = None
|
||||||
|
self.L = None
|
||||||
|
self.U = None
|
||||||
|
|
||||||
|
def reshape_data(self, data):
|
||||||
|
|
||||||
|
if data is None:
|
||||||
|
return data
|
||||||
|
|
||||||
|
if isinstance(data, type(xr.DataArray())):
|
||||||
|
|
||||||
|
dims = data.dims
|
||||||
|
|
||||||
|
if len(dims)>3:
|
||||||
|
raise InvalidDimException(dims)
|
||||||
|
|
||||||
|
xAxis = None
|
||||||
|
yAxis = None
|
||||||
|
if len(dims) == 2:
|
||||||
|
imageAxis = ''
|
||||||
|
else:
|
||||||
|
imageAxis = None
|
||||||
|
|
||||||
|
for dim in dims:
|
||||||
|
if (dim == 'x') or ('_x' in dim):
|
||||||
|
xAxis = dim
|
||||||
|
elif (dim == 'y') or ('_y' in dim):
|
||||||
|
yAxis = dim
|
||||||
|
else:
|
||||||
|
imageAxis = dim
|
||||||
|
|
||||||
|
if (xAxis is None) or (yAxis is None) or (imageAxis is None):
|
||||||
|
raise InvalidDimException(dims)
|
||||||
|
|
||||||
|
if len(dims) == 2:
|
||||||
|
data = data.transpose(yAxis, xAxis)
|
||||||
|
else:
|
||||||
|
data = data.transpose(yAxis, xAxis, imageAxis)
|
||||||
|
|
||||||
|
data = data.to_numpy()
|
||||||
|
|
||||||
|
else:
|
||||||
|
data = np.array(data)
|
||||||
|
if len(data.shape) == 3:
|
||||||
|
data = np.swapaxes(data, 0, 2)
|
||||||
|
# data = np.swapaxes(data, 0, 1)
|
||||||
|
elif len(data.shape) == 2:
|
||||||
|
data = np.swapaxes(data, 0, 1)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
@property
|
||||||
|
def referenceImages(self):
|
||||||
|
res = self._referenceImages.reshape(self.ydim, self.xdim, self.nimgsR)
|
||||||
|
res = np.swapaxes(res, 0, 2)
|
||||||
|
return res
|
||||||
|
|
||||||
|
@referenceImages.setter
|
||||||
|
def referenceImages(self, value):
|
||||||
|
if self.reshape:
|
||||||
|
value = self.reshape_data(value)
|
||||||
|
elif isinstance(value, type(xr.DataArray())):
|
||||||
|
value = value.to_numpy()
|
||||||
|
|
||||||
|
self.nimgsR = value.shape[2]
|
||||||
|
self.xdim = value.shape[1]
|
||||||
|
self.ydim = value.shape[0]
|
||||||
|
|
||||||
|
self._referenceImages = (value.reshape(self.xdim * self.ydim, self.nimgsR).astype(np.float32))
|
||||||
|
|
||||||
|
def add_reference_images(self, data):
|
||||||
|
if self.reshape:
|
||||||
|
data = self.reshape_data(data)
|
||||||
|
elif isinstance(data, type(xr.DataArray())):
|
||||||
|
data = data.to_numpy()
|
||||||
|
|
||||||
|
if not ((data.shape[0]==self.ydim) and (data.shape[1]==self.xdim)):
|
||||||
|
raise DataSizeException
|
||||||
|
|
||||||
|
data = data.reshape(self.xdim * self.ydim)
|
||||||
|
|
||||||
|
self._referenceImages = np.append(self._referenceImages, data, axis=1)
|
||||||
|
|
||||||
|
def _remove_first_reference_images(self):
|
||||||
|
self._referenceImages = np.delete(self._referenceImages, 0, axis=1)
|
||||||
|
|
||||||
|
def update_reference_images(self, data):
|
||||||
|
self._remove_first_reference_images()
|
||||||
|
self.add_reference_images(data)
|
||||||
|
|
||||||
|
if self._mask is None:
|
||||||
|
self.mask = np.ones((self.ydim, self.xdim), dtype=np.uint8)
|
||||||
|
|
||||||
|
self.decompose_referenceImages()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mask(self):
|
||||||
|
return self._mask
|
||||||
|
|
||||||
|
@mask.setter
|
||||||
|
def mask(self, value):
|
||||||
|
if self.reshape:
|
||||||
|
value = self.reshape_data(value)
|
||||||
|
elif isinstance(value, type(xr.DataArray())):
|
||||||
|
value = value.to_numpy()
|
||||||
|
|
||||||
|
if not ((value.shape[0]==self.ydim) and (value.shape[1]==self.xdim)):
|
||||||
|
raise DataSizeException
|
||||||
|
|
||||||
|
self._mask = value
|
||||||
|
self.k = np.where(self._mask.flatten() == 1)[0]
|
||||||
|
|
||||||
|
def decompose_referenceImages(self):
|
||||||
|
self.P, self.L, self.U = lu(self._referenceImages[self.k, :].T @ self._referenceImages[self.k, :], permute_l = False, p_indices = True)
|
||||||
|
|
||||||
|
def _fringe_removal(self, absorptionImages, referenceImages=None, mask=None, reshape=None):
|
||||||
|
if not reshape is None:
|
||||||
|
self.reshape = reshape
|
||||||
|
if not referenceImages is None:
|
||||||
|
self.referenceImages = referenceImages
|
||||||
|
if not mask is None:
|
||||||
|
self.mask = mask
|
||||||
|
if self.P is None:
|
||||||
|
self.decompose_referenceImages()
|
||||||
|
|
||||||
|
absorptionImages = np.atleast_3d(absorptionImages)
|
||||||
|
|
||||||
|
if self.reshape:
|
||||||
|
absorptionImages = self.reshape_data(absorptionImages)
|
||||||
|
|
||||||
|
self.nimgs = absorptionImages.shape[2]
|
||||||
|
absorptionImages = (absorptionImages.reshape(self.xdim * self.ydim, self.nimgs).astype(np.float32))
|
||||||
|
|
||||||
|
optrefimages = np.zeros_like(absorptionImages, dtype=np.float32)
|
||||||
|
|
||||||
|
for j in range(self.nimgs):
|
||||||
|
b = self._referenceImages[self.k, :].T @ absorptionImages[self.k, j]
|
||||||
|
|
||||||
|
# Obtain coefficients c which minimize least-square residuals
|
||||||
|
c = np.linalg.solve(self.U, np.linalg.solve(self.L[self.P], b))
|
||||||
|
# Compute optimized reference image
|
||||||
|
optrefimages[:, j] = (self._referenceImages @ c)
|
||||||
|
|
||||||
|
return optrefimages
|
||||||
|
|
||||||
|
def fringe_removal(self, absorptionImages, referenceImages=None, mask=None, reshape=None):
|
||||||
|
|
||||||
|
res = self._fringe_removal(absorptionImages, referenceImages, mask, reshape)
|
||||||
|
res = res.reshape(self.ydim, self.xdim, self.nimgs)
|
||||||
|
return np.swapaxes(res, 0, 2)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user