implenment parallel computing

This commit is contained in:
Jianshun Gao 2023-09-28 15:04:56 +02:00
parent b17c05bf30
commit 6d841ef992

View File

@ -32,6 +32,8 @@ class FringeRemoval():
self.ydim = 0 # The shape of y axis self.ydim = 0 # The shape of y axis
self._mask = None # The mask array to choose the region of interest for fringes removal self._mask = None # The mask array to choose the region of interest for fringes removal
self._center = None # Set the mask array by center and span
self._span = None
self.reshape=True # If it is necessary to reshape the data from (index of images(alternative), x, y) to (y, x, index of images(alternative)) self.reshape=True # If it is necessary to reshape the data from (index of images(alternative), x, y) to (y, x, index of images(alternative))
@ -39,6 +41,44 @@ class FringeRemoval():
self.L = None self.L = None
self.U = None self.U = None
@property
def center(self):
"""The getter of the center of region of insterest (ROI)
:return: The center of region of insterest (ROI)
:rtype: tuple
"""
return self._center
@center.setter
def center(self, value):
"""The setter of the center of region of insterest (ROI)
:param value: The center of region of insterest (ROI)
:type value: tuple
"""
self._mask = None
self._center = value
@property
def span(self):
"""The getter of the span of region of insterest (ROI)
:return: The span of region of insterest (ROI)
:rtype: tuple
"""
return self._span
@span.setter
def span(self, value):
"""The setter of the span of region of insterest (ROI)
:param value: The span of region of insterest (ROI)
:type value: tuple
"""
self._mask = None
self._span = value
def reshape_data(self, data): def reshape_data(self, data):
"""The function is to reshape the data to the correct shape. """The function is to reshape the data to the correct shape.
In order to minimize the calculation time, the data has to have a shape of (y, x, index of images(alternative)). In order to minimize the calculation time, the data has to have a shape of (y, x, index of images(alternative)).
@ -98,6 +138,55 @@ class FringeRemoval():
return data return data
def _reshape_absorption_images(self, data):
if data is None:
return data
if isinstance(data, type(xr.DataArray())):
dims = data.dims
if len(dims)>3:
raise InvalidDimException(dims)
xAxis = None
yAxis = None
if len(dims) == 2:
imageAxis = ''
else:
imageAxis = None
for dim in dims:
if (dim == 'x') or ('_x' in dim):
xAxis = dim
elif (dim == 'y') or ('_y' in dim):
yAxis = dim
else:
imageAxis = dim
if (xAxis is None) or (yAxis is None) or (imageAxis is None):
raise InvalidDimException(dims)
if len(dims) == 2:
data = data.transpose(yAxis, xAxis)
else:
data = data.transpose(yAxis, xAxis, imageAxis)
self.nimgs = len(data[imageAxis])
data = data.stack(axis=[yAxis, xAxis])
else:
data = np.array(data)
if len(data.shape) == 3:
data = np.swapaxes(data, 0, 2)
# data = np.swapaxes(data, 0, 1)
elif len(data.shape) == 2:
data = np.swapaxes(data, 0, 1)
return data
@property @property
def referenceImages(self): def referenceImages(self):
res = self._referenceImages.reshape(self.ydim, self.xdim, self.nimgsR) res = self._referenceImages.reshape(self.ydim, self.xdim, self.nimgsR)
@ -151,9 +240,6 @@ class FringeRemoval():
self._remove_first_reference_images() self._remove_first_reference_images()
self.add_reference_images(data) self.add_reference_images(data)
if self._mask is None:
self.mask = np.ones((self.ydim, self.xdim), dtype=np.uint8)
self.decompose_referenceImages() self.decompose_referenceImages()
@property @property
@ -173,46 +259,32 @@ class FringeRemoval():
self._mask = value self._mask = value
self.k = np.where(self._mask.flatten() == 1)[0] self.k = np.where(self._mask.flatten() == 1)[0]
def _auto_mask(self):
mask = np.ones((self.ydim, self.xdim), dtype=np.uint8)
if not self._center is None:
x_start = int(self._center[0] - self._span[0] / 2)
x_end = int(self._center[0] + self._span[0] / 2)
y_end = int(self._center[1] + self._span[1] / 2)
y_start = int(self._center[1] - self._span[1] / 2)
mask[y_start:y_end, x_start:x_end] = 0
return mask
def decompose_referenceImages(self): def decompose_referenceImages(self):
if self._mask is None:
self.mask = self._auto_mask()
self.P, self.L, self.U = lu(self._referenceImages[self.k, :].T @ self._referenceImages[self.k, :], permute_l = False, p_indices = True) self.P, self.L, self.U = lu(self._referenceImages[self.k, :].T @ self._referenceImages[self.k, :], permute_l = False, p_indices = True)
def solve_coefficient(self): def _fringe_removal(self, absorptionImages):
pass
def _fringe_removal(self, absorptionImages, referenceImages=None, mask=None, reshape=None, dask='forbidden'): b = self.temp @ absorptionImages[self.k]
if not reshape is None: c = np.linalg.solve(self.U, np.linalg.solve(self.L[self.P], b))
self.reshape = reshape optrefimages = (self._referenceImages @ c)
if not referenceImages is None:
self.referenceImages = referenceImages
if not mask is None:
self.mask = mask
if self.P is None:
self.decompose_referenceImages()
absorptionImages = np.atleast_3d(absorptionImages)
if self.reshape:
absorptionImages = self.reshape_data(absorptionImages)
self.nimgs = absorptionImages.shape[2]
absorptionImages = (absorptionImages.reshape(self.xdim * self.ydim, self.nimgs).astype(np.float32))
optrefimages = np.zeros_like(absorptionImages, dtype=np.float32)
if dask=='forbidden':
for j in range(self.nimgs):
b = self._referenceImages[self.k, :].T @ absorptionImages[self.k, j]
# Obtain coefficients c which minimize least-square residuals
c = np.linalg.solve(self.U, np.linalg.solve(self.L[self.P], b))
# Compute optimized reference image
optrefimages[:, j] = (self._referenceImages @ c)
else:
pass
return optrefimages return optrefimages
def fringe_removal(self, absorptionImages, referenceImages=None, mask=None, reshape=None, dask='forbidden'): def fringe_removal(self, absorptionImages, referenceImages=None, mask=None, reshape=None, dask='parallelized'):
""" """
This function will generate a 'fake' background images, which can help to remove the fringes. This function will generate a 'fake' background images, which can help to remove the fringes.
@ -229,13 +301,22 @@ class FringeRemoval():
:param dask: Please refer to xarray.apply_ufunc() :param dask: Please refer to xarray.apply_ufunc()
:type dask: {"forbidden", "allowed", "parallelized"}, optional :type dask: {"forbidden", "allowed", "parallelized"}, optional
:return: The 'fake' background to help removing the fringes :return: The 'fake' background to help removing the fringes
:rtype: numpy array :rtype: xarray array
""" """
if not reshape is None:
self.reshape = reshape
if not referenceImages is None:
self.referenceImages = referenceImages
if not mask is None:
self.mask = mask
if self.P is None:
self.decompose_referenceImages()
res = self._fringe_removal(absorptionImages, referenceImages, mask, reshape, dask)
res = res.reshape(self.ydim, self.xdim, self.nimgs)
if self.reshape: if self.reshape:
return np.swapaxes(res, 0, 2) absorptionImages = self._reshape_absorption_images(absorptionImages)
else:
return res
self.temp = self._referenceImages[self.k, :].T
optrefimages = xr.apply_ufunc(self._fringe_removal, absorptionImages, input_core_dims=[['axis']], output_core_dims=[['axis']], dask=dask, vectorize=True, output_dtypes=float)
return optrefimages.unstack()