analyseScript/Analyser/FringeRemoval.py
2023-09-28 16:59:50 +02:00

347 lines
12 KiB
Python

import numpy as np
from scipy.linalg import lu
import xarray as xr
class InvalidDimException(Exception):
    """Raised when the program can not identify (index of images, x, y) axes.

    :param dims: The dimensions (names or shape) of the offending data; only its
        length and its ``str`` form are used to build the message.
    """

    def __init__(self, dims):
        if len(dims) <= 3:
            text = 'Can not identify (index of images(alternative), x, y) from ' + str(dims)
        else:
            text = 'The input data must have two or three axes: (index of images(alternative), x, y)'
        self.message = text
        super().__init__(self.message)
class DataSizeException(Exception):
    """Raised when the shape of the data is not correct (does not match the stored image size)."""

    def __init__(self):
        text = 'The input data size does not match.'
        self.message = text
        super().__init__(text)
class FringeRemoval():
"""
FRINGEREMOVAL - Fringe removal and noise reduction from absorption images.
Creates an optimal reference image for each absorption image in a set as
a linear combination of reference images, with coefficients chosen to
minimize the least-squares residuals between each absorption image and
the optimal reference image. The coefficients are obtained by solving a
linear set of equations using matrix inverse by LU decomposition.
Application of the algorithm is described in C. F. Ockeloen et al, Improved
detection of small atom numbers through image processing, arXiv:1007.2136 (2010).
Adapted from a MATLAB script copy provided by Guoxian Su.
Original Authors: Shannon Whitlock, Caspar Ockeloen
Reference: C. F. Ockeloen, A. F. Tauschinsky, R. J. C. Spreeuw, and
S. Whitlock, Improved detection of small atom numbers through
image processing, arXiv:1007.2136
May 2009;
"""
def __init__(self) -> None:
"""Initialize the class
"""
self.nimgsR = 0 # The number of the reference images
self.xdim = 0 # The shape of x axis
self.ydim = 0 # The shape of y axis
self._mask = None # The mask array to choose the region of interest for fringes removal
self._center = None # Set the mask array by center and span
self._span = None
self.reshape=True # If it is necessary to reshape the data from (index of images(alternative), x, y) to (y, x, index of images(alternative))
self.P = None
self.L = None
self.U = None
@property
def center(self):
"""The getter of the center of region of insterest (ROI)
:return: The center of region of insterest (ROI)
:rtype: tuple
"""
return self._center
@center.setter
def center(self, value):
"""The setter of the center of region of insterest (ROI)
:param value: The center of region of insterest (ROI)
:type value: tuple
"""
self._mask = None
self._center = value
@property
def span(self):
"""The getter of the span of region of insterest (ROI)
:return: The span of region of insterest (ROI)
:rtype: tuple
"""
return self._span
@span.setter
def span(self, value):
"""The setter of the span of region of insterest (ROI)
:param value: The span of region of insterest (ROI)
:type value: tuple
"""
self._mask = None
self._span = value
def reshape_data(self, data):
"""The function is to reshape the data to the correct shape.
In order to minimize the calculation time, the data has to have a shape of (y, x, index of images(alternative)).
However, usually the input data has a shape of (index of images(alternative), x, y).
It can also convert the xarray DataArray and Dataset to numpy array.
:param data: The input data.
:type data: xarray, numpy array or list
:raises InvalidDimException: Raised when the program can not identify (index of images, x, y) axes.
:raises InvalidDimException: Raised when the shape of the data is not correct.
:return: The data with correct shape
:rtype: xarray, numpy array or list
"""
if data is None:
return data
if isinstance(data, type(xr.DataArray())):
dims = data.dims
if len(dims)>3:
raise InvalidDimException(dims)
xAxis = None
yAxis = None
if len(dims) == 2:
imageAxis = ''
else:
imageAxis = None
for dim in dims:
if (dim == 'x') or ('_x' in dim):
xAxis = dim
elif (dim == 'y') or ('_y' in dim):
yAxis = dim
else:
imageAxis = dim
if (xAxis is None) or (yAxis is None) or (imageAxis is None):
raise InvalidDimException(dims)
if len(dims) == 2:
data = data.transpose(yAxis, xAxis)
else:
data = data.transpose(yAxis, xAxis, imageAxis)
data = data.to_numpy()
else:
data = np.array(data)
if len(data.shape) == 3:
data = np.swapaxes(data, 0, 2)
# data = np.swapaxes(data, 0, 1)
elif len(data.shape) == 2:
data = np.swapaxes(data, 0, 1)
return data
def _reshape_absorption_images(self, data):
if data is None:
return data
if isinstance(data, type(xr.DataArray())):
dims = data.dims
if len(dims)>3:
raise InvalidDimException(dims)
xAxis = None
yAxis = None
if len(dims) == 2:
imageAxis = ''
else:
imageAxis = None
for dim in dims:
if (dim == 'x') or ('_x' in dim):
xAxis = dim
elif (dim == 'y') or ('_y' in dim):
yAxis = dim
else:
imageAxis = dim
if (xAxis is None) or (yAxis is None) or (imageAxis is None):
raise InvalidDimException(dims)
if len(dims) == 2:
data = data.transpose(yAxis, xAxis)
else:
data = data.transpose(yAxis, xAxis, imageAxis)
self.nimgs = len(data[imageAxis])
data = data.stack(axis=[yAxis, xAxis])
else:
data = np.array(data)
if len(data.shape) == 3:
data = np.swapaxes(data, 0, 2)
# data = np.swapaxes(data, 0, 1)
elif len(data.shape) == 2:
data = np.swapaxes(data, 0, 1)
return data
@property
def referenceImages(self):
res = self._referenceImages.reshape(self.ydim, self.xdim, self.nimgsR)
res = np.swapaxes(res, 0, 2)
return res
@referenceImages.setter
def referenceImages(self, value):
if value is None:
self._referenceImages = None
return
if self.reshape:
value = self.reshape_data(value)
elif isinstance(value, type(xr.DataArray())):
value = value.to_numpy()
self.nimgsR = value.shape[2]
self.xdim = value.shape[1]
self.ydim = value.shape[0]
self._referenceImages = (value.reshape(self.xdim * self.ydim, self.nimgsR).astype(np.float32))
def add_reference_images(self, data):
"""Add a new reference images
:param data: The new reference image.
:type data: xarray, numpy array or list
:raises DataSizeException: Raised when the shape of the data is not correct.
"""
if self.reshape:
data = self.reshape_data(data)
elif isinstance(data, type(xr.DataArray())):
data = data.to_numpy()
if not ((data.shape[0]==self.ydim) and (data.shape[1]==self.xdim)):
raise DataSizeException
data = data.reshape(self.xdim * self.ydim)
self._referenceImages = np.append(self._referenceImages, data, axis=1)
def _remove_first_reference_images(self):
"""Remove the first reference images
"""
self._referenceImages = np.delete(self._referenceImages, 0, axis=1)
def update_reference_images(self, data):
"""Update the reference images set by removing the first one and adding a new one at the end.
:param data: The new reference image.
:type data: xarray, numpy array or list
"""
self._remove_first_reference_images()
self.add_reference_images(data)
self.decompose_referenceImages()
@property
def mask(self):
return self._mask
@mask.setter
def mask(self, value):
if self.reshape:
value = self.reshape_data(value)
elif isinstance(value, type(xr.DataArray())):
value = value.to_numpy()
if not ((value.shape[0]==self.ydim) and (value.shape[1]==self.xdim)):
raise DataSizeException
self._mask = value
self._center = None
self._span = None
self.k = np.where(self._mask.flatten() == 1)[0]
def _auto_mask(self):
mask = np.ones((self.ydim, self.xdim), dtype=np.uint8)
if not self._center is None:
x_start = int(self._center[0] - self._span[0] / 2)
x_end = int(self._center[0] + self._span[0] / 2)
y_end = int(self._center[1] + self._span[1] / 2)
y_start = int(self._center[1] - self._span[1] / 2)
mask[y_start:y_end, x_start:x_end] = 0
return mask
def decompose_referenceImages(self):
if self._mask is None:
self.mask = self._auto_mask()
self.P, self.L, self.U = lu(self._referenceImages[self.k, :].T @ self._referenceImages[self.k, :], permute_l = False, p_indices = True)
def _fringe_removal(self, absorptionImages):
b = self.temp @ absorptionImages[self.k]
c = np.linalg.solve(self.U, np.linalg.solve(self.L[self.P], b))
optrefimages = (self._referenceImages @ c)
return optrefimages
def fringe_removal(self, absorptionImages, referenceImages=None, mask=None, reshape=None, dask='parallelized'):
"""
This function will generate a 'fake' background images, which can help to remove the fringes.
Important: Please substract the drak images from the both of images with atoms and without atoms before using this function!!!
:param absorptionImages: A set of images with atoms in absorption imaging
:type absorptionImages: xarray, numpy array or list
:param referenceImages: A set of images without atoms in absorption imaging, defaults to None
:type referenceImages: xarray, numpy array or list, optional
:param mask: An array to choose the region of interest for fringes removal, defaults to None, defaults to None
:type mask: numpy array, optional
:param reshape: If it needs to reshape the data, defaults to None
:type reshape: bool, optional
:param dask: Please refer to xarray.apply_ufunc()
:type dask: {"forbidden", "allowed", "parallelized"}, optional
:return: The 'fake' background to help removing the fringes
:rtype: xarray array
"""
if not reshape is None:
self.reshape = reshape
if not referenceImages is None:
self.referenceImages = referenceImages
if not mask is None:
self.mask = mask
if self.P is None:
self.decompose_referenceImages()
if self.reshape:
absorptionImages = self._reshape_absorption_images(absorptionImages)
self.temp = self._referenceImages[self.k, :].T
optrefimages = xr.apply_ufunc(self._fringe_removal, absorptionImages, input_core_dims=[['axis']], output_core_dims=[['axis']], dask=dask, vectorize=True, output_dtypes=float)
return optrefimages.unstack()