You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

346 lines
12 KiB

1 year ago
1 year ago
1 year ago
1 year ago
1 year ago
1 year ago
1 year ago
1 year ago
1 year ago
1 year ago
1 year ago
1 year ago
1 year ago
1 year ago
1 year ago
1 year ago
1 year ago
1 year ago
1 year ago
1 year ago
  1. import numpy as np
  2. from scipy.linalg import lu
  3. import xarray as xr
  4. class InvalidDimException(Exception):
  5. "Raised when the program can not identify (index of images, x, y) axes."
  6. def __init__(self, dims):
  7. if len(dims)>3:
  8. self.message = 'The input data must have two or three axes: (index of images(alternative), x, y)'
  9. else:
  10. self.message = 'Can not identify (index of images(alternative), x, y) from ' + str(dims)
  11. super().__init__(self.message)
  12. class DataSizeException(Exception):
  13. "Raised when the shape of the data is not correct."
  14. def __init__(self):
  15. self.message = 'The input data size does not match.'
  16. super().__init__(self.message)
  17. class FringeRemoval():
  18. """
  19. FRINGEREMOVAL - Fringe removal and noise reduction from absorption images.
  20. Creates an optimal reference image for each absorption image in a set as
  21. a linear combination of reference images, with coefficients chosen to
  22. minimize the least-squares residuals between each absorption image and
  23. the optimal reference image. The coefficients are obtained by solving a
  24. linear set of equations using matrix inverse by LU decomposition.
  25. Application of the algorithm is described in C. F. Ockeloen et al, Improved
  26. detection of small atom numbers through image processing, arXiv:1007.2136 (2010).
  27. Adapted from a MATLAB script copy provided by Guoxian Su.
  28. Original Authors: Shannon Whitlock, Caspar Ockeloen
  29. Reference: C. F. Ockeloen, A. F. Tauschinsky, R. J. C. Spreeuw, and
  30. S. Whitlock, Improved detection of small atom numbers through
  31. image processing, arXiv:1007.2136
  32. May 2009;
  33. """
  34. def __init__(self) -> None:
  35. """Initialize the class
  36. """
  37. self.nimgsR = 0 # The number of the reference images
  38. self.xdim = 0 # The shape of x axis
  39. self.ydim = 0 # The shape of y axis
  40. self._mask = None # The mask array to choose the region of interest for fringes removal
  41. self._center = None # Set the mask array by center and span
  42. self._span = None
  43. self.reshape=True # If it is necessary to reshape the data from (index of images(alternative), x, y) to (y, x, index of images(alternative))
  44. self.P = None
  45. self.L = None
  46. self.U = None
  47. @property
  48. def center(self):
  49. """The getter of the center of region of insterest (ROI)
  50. :return: The center of region of insterest (ROI)
  51. :rtype: tuple
  52. """
  53. return self._center
  54. @center.setter
  55. def center(self, value):
  56. """The setter of the center of region of insterest (ROI)
  57. :param value: The center of region of insterest (ROI)
  58. :type value: tuple
  59. """
  60. self._mask = None
  61. self._center = value
  62. @property
  63. def span(self):
  64. """The getter of the span of region of insterest (ROI)
  65. :return: The span of region of insterest (ROI)
  66. :rtype: tuple
  67. """
  68. return self._span
  69. @span.setter
  70. def span(self, value):
  71. """The setter of the span of region of insterest (ROI)
  72. :param value: The span of region of insterest (ROI)
  73. :type value: tuple
  74. """
  75. self._mask = None
  76. self._span = value
  77. def reshape_data(self, data):
  78. """The function is to reshape the data to the correct shape.
  79. In order to minimize the calculation time, the data has to have a shape of (y, x, index of images(alternative)).
  80. However, usually the input data has a shape of (index of images(alternative), x, y).
  81. It can also convert the xarray DataArray and Dataset to numpy array.
  82. :param data: The input data.
  83. :type data: xarray, numpy array or list
  84. :raises InvalidDimException: Raised when the program can not identify (index of images, x, y) axes.
  85. :raises InvalidDimException: Raised when the shape of the data is not correct.
  86. :return: The data with correct shape
  87. :rtype: xarray, numpy array or list
  88. """
  89. if data is None:
  90. return data
  91. if isinstance(data, type(xr.DataArray())):
  92. dims = data.dims
  93. if len(dims)>3:
  94. raise InvalidDimException(dims)
  95. xAxis = None
  96. yAxis = None
  97. if len(dims) == 2:
  98. imageAxis = ''
  99. else:
  100. imageAxis = None
  101. for dim in dims:
  102. if (dim == 'x') or ('_x' in dim):
  103. xAxis = dim
  104. elif (dim == 'y') or ('_y' in dim):
  105. yAxis = dim
  106. else:
  107. imageAxis = dim
  108. if (xAxis is None) or (yAxis is None) or (imageAxis is None):
  109. raise InvalidDimException(dims)
  110. if len(dims) == 2:
  111. data = data.transpose(yAxis, xAxis)
  112. else:
  113. data = data.transpose(yAxis, xAxis, imageAxis)
  114. data = data.to_numpy()
  115. else:
  116. data = np.array(data)
  117. if len(data.shape) == 3:
  118. data = np.swapaxes(data, 0, 2)
  119. # data = np.swapaxes(data, 0, 1)
  120. elif len(data.shape) == 2:
  121. data = np.swapaxes(data, 0, 1)
  122. return data
  123. def _reshape_absorption_images(self, data):
  124. if data is None:
  125. return data
  126. if isinstance(data, type(xr.DataArray())):
  127. dims = data.dims
  128. if len(dims)>3:
  129. raise InvalidDimException(dims)
  130. xAxis = None
  131. yAxis = None
  132. if len(dims) == 2:
  133. imageAxis = ''
  134. else:
  135. imageAxis = None
  136. for dim in dims:
  137. if (dim == 'x') or ('_x' in dim):
  138. xAxis = dim
  139. elif (dim == 'y') or ('_y' in dim):
  140. yAxis = dim
  141. else:
  142. imageAxis = dim
  143. if (xAxis is None) or (yAxis is None) or (imageAxis is None):
  144. raise InvalidDimException(dims)
  145. if len(dims) == 2:
  146. data = data.transpose(yAxis, xAxis)
  147. else:
  148. data = data.transpose(yAxis, xAxis, imageAxis)
  149. self.nimgs = len(data[imageAxis])
  150. data = data.stack(axis=[yAxis, xAxis])
  151. else:
  152. data = np.array(data)
  153. if len(data.shape) == 3:
  154. data = np.swapaxes(data, 0, 2)
  155. # data = np.swapaxes(data, 0, 1)
  156. elif len(data.shape) == 2:
  157. data = np.swapaxes(data, 0, 1)
  158. return data
  159. @property
  160. def referenceImages(self):
  161. res = self._referenceImages.reshape(self.ydim, self.xdim, self.nimgsR)
  162. res = np.swapaxes(res, 0, 2)
  163. return res
  164. @referenceImages.setter
  165. def referenceImages(self, value):
  166. if value is None:
  167. self._referenceImages = None
  168. return
  169. if self.reshape:
  170. value = self.reshape_data(value)
  171. elif isinstance(value, type(xr.DataArray())):
  172. value = value.to_numpy()
  173. self.nimgsR = value.shape[2]
  174. self.xdim = value.shape[1]
  175. self.ydim = value.shape[0]
  176. self._referenceImages = (value.reshape(self.xdim * self.ydim, self.nimgsR).astype(np.float32))
  177. def add_reference_images(self, data):
  178. """Add a new reference images
  179. :param data: The new reference image.
  180. :type data: xarray, numpy array or list
  181. :raises DataSizeException: Raised when the shape of the data is not correct.
  182. """
  183. if self.reshape:
  184. data = self.reshape_data(data)
  185. elif isinstance(data, type(xr.DataArray())):
  186. data = data.to_numpy()
  187. if not ((data.shape[0]==self.ydim) and (data.shape[1]==self.xdim)):
  188. raise DataSizeException
  189. data = data.reshape(self.xdim * self.ydim)
  190. self._referenceImages = np.append(self._referenceImages, data, axis=1)
  191. def _remove_first_reference_images(self):
  192. """Remove the first reference images
  193. """
  194. self._referenceImages = np.delete(self._referenceImages, 0, axis=1)
  195. def update_reference_images(self, data):
  196. """Update the reference images set by removing the first one and adding a new one at the end.
  197. :param data: The new reference image.
  198. :type data: xarray, numpy array or list
  199. """
  200. self._remove_first_reference_images()
  201. self.add_reference_images(data)
  202. self.decompose_referenceImages()
  203. @property
  204. def mask(self):
  205. return self._mask
  206. @mask.setter
  207. def mask(self, value):
  208. if self.reshape:
  209. value = self.reshape_data(value)
  210. elif isinstance(value, type(xr.DataArray())):
  211. value = value.to_numpy()
  212. if not ((value.shape[0]==self.ydim) and (value.shape[1]==self.xdim)):
  213. raise DataSizeException
  214. self._mask = value
  215. self._center = None
  216. self._span = None
  217. self.k = np.where(self._mask.flatten() == 1)[0]
  218. def _auto_mask(self):
  219. mask = np.ones((self.ydim, self.xdim), dtype=np.uint8)
  220. if not self._center is None:
  221. x_start = int(self._center[0] - self._span[0] / 2)
  222. x_end = int(self._center[0] + self._span[0] / 2)
  223. y_end = int(self._center[1] + self._span[1] / 2)
  224. y_start = int(self._center[1] - self._span[1] / 2)
  225. mask[y_start:y_end, x_start:x_end] = 0
  226. return mask
  227. def decompose_referenceImages(self):
  228. if self._mask is None:
  229. self.mask = self._auto_mask()
  230. self.P, self.L, self.U = lu(self._referenceImages[self.k, :].T @ self._referenceImages[self.k, :], permute_l = False, p_indices = True)
  231. def _fringe_removal(self, absorptionImages):
  232. b = self.temp @ absorptionImages[self.k]
  233. c = np.linalg.solve(self.U, np.linalg.solve(self.L[self.P], b))
  234. optrefimages = (self._referenceImages @ c)
  235. return optrefimages
  236. def fringe_removal(self, absorptionImages, referenceImages=None, mask=None, reshape=None, dask='parallelized'):
  237. """
  238. This function will generate a 'fake' background images, which can help to remove the fringes.
  239. Important: Please substract the drak images from the both of images with atoms and without atoms before using this function!!!
  240. :param absorptionImages: A set of images with atoms in absorption imaging
  241. :type absorptionImages: xarray, numpy array or list
  242. :param referenceImages: A set of images without atoms in absorption imaging, defaults to None
  243. :type referenceImages: xarray, numpy array or list, optional
  244. :param mask: An array to choose the region of interest for fringes removal, defaults to None, defaults to None
  245. :type mask: numpy array, optional
  246. :param reshape: If it needs to reshape the data, defaults to None
  247. :type reshape: bool, optional
  248. :param dask: Please refer to xarray.apply_ufunc()
  249. :type dask: {"forbidden", "allowed", "parallelized"}, optional
  250. :return: The 'fake' background to help removing the fringes
  251. :rtype: xarray array
  252. """
  253. if not reshape is None:
  254. self.reshape = reshape
  255. if not referenceImages is None:
  256. self.referenceImages = referenceImages
  257. if not mask is None:
  258. self.mask = mask
  259. if self.P is None:
  260. self.decompose_referenceImages()
  261. if self.reshape:
  262. absorptionImages = self._reshape_absorption_images(absorptionImages)
  263. self.temp = self._referenceImages[self.k, :].T
  264. optrefimages = xr.apply_ufunc(self._fringe_removal, absorptionImages, input_core_dims=[['axis']], output_core_dims=[['axis']], dask=dask, vectorize=True, output_dtypes=float)
  265. return optrefimages.unstack()