analyseScript/Analyser/FringeRemoval.py

322 lines
11 KiB
Python

import numpy as np
from scipy.linalg import lu
import xarray as xr
class InvalidDimException(Exception):
    """Raised when the (index of images, x, y) axes of the input can not be identified."""

    def __init__(self, dims):
        # More than three axes is a shape problem; otherwise the axis names are the problem.
        too_many_axes = len(dims) > 3
        self.message = (
            'The input data must have two or three axes: (index of images(alternative), x, y)'
            if too_many_axes
            else 'Can not identify (index of images(alternative), x, y) from ' + str(dims)
        )
        super().__init__(self.message)
class DataSizeException(Exception):
    """Raised when the shape of the input data does not match the expected shape."""

    def __init__(self):
        message = 'The input data size does not match.'
        self.message = message
        super().__init__(message)
class FringeRemoval():
"""A class for fringes removal
"""
def __init__(self) -> None:
"""Initialize the class
"""
self.nimgsR = 0 # The number of the reference images
self.xdim = 0 # The shape of x axis
self.ydim = 0 # The shape of y axis
self._mask = None # The mask array to choose the region of interest for fringes removal
self._center = None # Set the mask array by center and span
self._span = None
self.reshape=True # If it is necessary to reshape the data from (index of images(alternative), x, y) to (y, x, index of images(alternative))
self.P = None
self.L = None
self.U = None
@property
def center(self):
"""The getter of the center of region of insterest (ROI)
:return: The center of region of insterest (ROI)
:rtype: tuple
"""
return self._center
@center.setter
def center(self, value):
"""The setter of the center of region of insterest (ROI)
:param value: The center of region of insterest (ROI)
:type value: tuple
"""
self._mask = None
self._center = value
@property
def span(self):
"""The getter of the span of region of insterest (ROI)
:return: The span of region of insterest (ROI)
:rtype: tuple
"""
return self._span
@span.setter
def span(self, value):
"""The setter of the span of region of insterest (ROI)
:param value: The span of region of insterest (ROI)
:type value: tuple
"""
self._mask = None
self._span = value
def reshape_data(self, data):
"""The function is to reshape the data to the correct shape.
In order to minimize the calculation time, the data has to have a shape of (y, x, index of images(alternative)).
However, usually the input data has a shape of (index of images(alternative), x, y).
It can also convert the xarray DataArray and Dataset to numpy array.
:param data: The input data.
:type data: xarray, numpy array or list
:raises InvalidDimException: Raised when the program can not identify (index of images, x, y) axes.
:raises InvalidDimException: Raised when the shape of the data is not correct.
:return: The data with correct shape
:rtype: xarray, numpy array or list
"""
if data is None:
return data
if isinstance(data, type(xr.DataArray())):
dims = data.dims
if len(dims)>3:
raise InvalidDimException(dims)
xAxis = None
yAxis = None
if len(dims) == 2:
imageAxis = ''
else:
imageAxis = None
for dim in dims:
if (dim == 'x') or ('_x' in dim):
xAxis = dim
elif (dim == 'y') or ('_y' in dim):
yAxis = dim
else:
imageAxis = dim
if (xAxis is None) or (yAxis is None) or (imageAxis is None):
raise InvalidDimException(dims)
if len(dims) == 2:
data = data.transpose(yAxis, xAxis)
else:
data = data.transpose(yAxis, xAxis, imageAxis)
data = data.to_numpy()
else:
data = np.array(data)
if len(data.shape) == 3:
data = np.swapaxes(data, 0, 2)
# data = np.swapaxes(data, 0, 1)
elif len(data.shape) == 2:
data = np.swapaxes(data, 0, 1)
return data
def _reshape_absorption_images(self, data):
if data is None:
return data
if isinstance(data, type(xr.DataArray())):
dims = data.dims
if len(dims)>3:
raise InvalidDimException(dims)
xAxis = None
yAxis = None
if len(dims) == 2:
imageAxis = ''
else:
imageAxis = None
for dim in dims:
if (dim == 'x') or ('_x' in dim):
xAxis = dim
elif (dim == 'y') or ('_y' in dim):
yAxis = dim
else:
imageAxis = dim
if (xAxis is None) or (yAxis is None) or (imageAxis is None):
raise InvalidDimException(dims)
if len(dims) == 2:
data = data.transpose(yAxis, xAxis)
else:
data = data.transpose(yAxis, xAxis, imageAxis)
self.nimgs = len(data[imageAxis])
data = data.stack(axis=[yAxis, xAxis])
else:
data = np.array(data)
if len(data.shape) == 3:
data = np.swapaxes(data, 0, 2)
# data = np.swapaxes(data, 0, 1)
elif len(data.shape) == 2:
data = np.swapaxes(data, 0, 1)
return data
@property
def referenceImages(self):
res = self._referenceImages.reshape(self.ydim, self.xdim, self.nimgsR)
res = np.swapaxes(res, 0, 2)
return res
@referenceImages.setter
def referenceImages(self, value):
if self.reshape:
value = self.reshape_data(value)
elif isinstance(value, type(xr.DataArray())):
value = value.to_numpy()
self.nimgsR = value.shape[2]
self.xdim = value.shape[1]
self.ydim = value.shape[0]
self._referenceImages = (value.reshape(self.xdim * self.ydim, self.nimgsR).astype(np.float32))
def add_reference_images(self, data):
"""Add a new reference images
:param data: The new reference image.
:type data: xarray, numpy array or list
:raises DataSizeException: Raised when the shape of the data is not correct.
"""
if self.reshape:
data = self.reshape_data(data)
elif isinstance(data, type(xr.DataArray())):
data = data.to_numpy()
if not ((data.shape[0]==self.ydim) and (data.shape[1]==self.xdim)):
raise DataSizeException
data = data.reshape(self.xdim * self.ydim)
self._referenceImages = np.append(self._referenceImages, data, axis=1)
def _remove_first_reference_images(self):
"""Remove the first reference images
"""
self._referenceImages = np.delete(self._referenceImages, 0, axis=1)
def update_reference_images(self, data):
"""Update the reference images set by removing the first one and adding a new one at the end.
:param data: The new reference image.
:type data: xarray, numpy array or list
"""
self._remove_first_reference_images()
self.add_reference_images(data)
self.decompose_referenceImages()
@property
def mask(self):
return self._mask
@mask.setter
def mask(self, value):
if self.reshape:
value = self.reshape_data(value)
elif isinstance(value, type(xr.DataArray())):
value = value.to_numpy()
if not ((value.shape[0]==self.ydim) and (value.shape[1]==self.xdim)):
raise DataSizeException
self._mask = value
self.k = np.where(self._mask.flatten() == 1)[0]
def _auto_mask(self):
mask = np.ones((self.ydim, self.xdim), dtype=np.uint8)
if not self._center is None:
x_start = int(self._center[0] - self._span[0] / 2)
x_end = int(self._center[0] + self._span[0] / 2)
y_end = int(self._center[1] + self._span[1] / 2)
y_start = int(self._center[1] - self._span[1] / 2)
mask[y_start:y_end, x_start:x_end] = 0
return mask
def decompose_referenceImages(self):
if self._mask is None:
self.mask = self._auto_mask()
self.P, self.L, self.U = lu(self._referenceImages[self.k, :].T @ self._referenceImages[self.k, :], permute_l = False, p_indices = True)
def _fringe_removal(self, absorptionImages):
b = self.temp @ absorptionImages[self.k]
c = np.linalg.solve(self.U, np.linalg.solve(self.L[self.P], b))
optrefimages = (self._referenceImages @ c)
return optrefimages
def fringe_removal(self, absorptionImages, referenceImages=None, mask=None, reshape=None, dask='parallelized'):
"""
This function will generate a 'fake' background images, which can help to remove the fringes.
Important: Please substract the drak images from the both of images with atoms and without atoms before using this function!!!
:param absorptionImages: A set of images with atoms in absorption imaging
:type absorptionImages: xarray, numpy array or list
:param referenceImages: A set of images without atoms in absorption imaging, defaults to None
:type referenceImages: xarray, numpy array or list, optional
:param mask: An array to choose the region of interest for fringes removal, defaults to None, defaults to None
:type mask: numpy array, optional
:param reshape: If it needs to reshape the data, defaults to None
:type reshape: bool, optional
:param dask: Please refer to xarray.apply_ufunc()
:type dask: {"forbidden", "allowed", "parallelized"}, optional
:return: The 'fake' background to help removing the fringes
:rtype: xarray array
"""
if not reshape is None:
self.reshape = reshape
if not referenceImages is None:
self.referenceImages = referenceImages
if not mask is None:
self.mask = mask
if self.P is None:
self.decompose_referenceImages()
if self.reshape:
absorptionImages = self._reshape_absorption_images(absorptionImages)
self.temp = self._referenceImages[self.k, :].T
optrefimages = xr.apply_ufunc(self._fringe_removal, absorptionImages, input_core_dims=[['axis']], output_core_dims=[['axis']], dask=dask, vectorize=True, output_dtypes=float)
return optrefimages.unstack()