Compare commits

..

4 Commits

Author SHA1 Message Date
7a0d102850 delete personal scripts 2023-09-15 11:21:38 +02:00
81bd1ea4ee Merge branch 'add_rotation' 2023-09-15 11:20:03 +02:00
62ad8157b9 Merge branch 'add_rotation' 2023-09-14 14:15:35 +02:00
Gao
aa13f41dd9 regular backup 2023-09-08 14:20:25 +02:00
11 changed files with 379 additions and 8360 deletions

View File

@ -1,347 +0,0 @@
import numpy as np
from scipy.linalg import lu
import xarray as xr
class InvalidDimException(Exception):
    """Raised when the (index of images, x, y) axes cannot be identified.

    The text of the error is also kept on the instance as ``message``.
    """

    def __init__(self, dims):
        """Build the error message from the offending dimension names.

        :param dims: The dimension names of the rejected data.
        :type dims: tuple or list
        """
        too_many_axes = len(dims) > 3
        if too_many_axes:
            text = 'The input data must have two or three axes: (index of images(alternative), x, y)'
        else:
            text = 'Can not identify (index of images(alternative), x, y) from ' + str(dims)
        self.message = text
        super().__init__(text)
class DataSizeException(Exception):
    """Raised when the shape of the supplied data does not match the expected one.

    The text of the error is also kept on the instance as ``message``.
    """

    def __init__(self):
        """Create the exception with its fixed message."""
        text = 'The input data size does not match.'
        self.message = text
        super().__init__(text)
class FringeRemoval():
"""
FRINGEREMOVAL - Fringe removal and noise reduction from absorption images.
Creates an optimal reference image for each absorption image in a set as
a linear combination of reference images, with coefficients chosen to
minimize the least-squares residuals between each absorption image and
the optimal reference image. The coefficients are obtained by solving a
linear set of equations using matrix inverse by LU decomposition.
Application of the algorithm is described in C. F. Ockeloen et al, Improved
detection of small atom numbers through image processing, arXiv:1007.2136 (2010).
Adapted from a MATLAB script copy provided by Guoxian Su.
Original Authors: Shannon Whitlock, Caspar Ockeloen
Reference: C. F. Ockeloen, A. F. Tauschinsky, R. J. C. Spreeuw, and
S. Whitlock, Improved detection of small atom numbers through
image processing, arXiv:1007.2136
May 2009;
"""
def __init__(self) -> None:
"""Initialize the class
"""
self.nimgsR = 0 # The number of the reference images
self.xdim = 0 # The shape of x axis
self.ydim = 0 # The shape of y axis
self._mask = None # The mask array to choose the region of interest for fringes removal
self._center = None # Set the mask array by center and span
self._span = None
self.reshape=True # If it is necessary to reshape the data from (index of images(alternative), x, y) to (y, x, index of images(alternative))
self.P = None
self.L = None
self.U = None
@property
def center(self):
"""The getter of the center of region of insterest (ROI)
:return: The center of region of insterest (ROI)
:rtype: tuple
"""
return self._center
@center.setter
def center(self, value):
"""The setter of the center of region of insterest (ROI)
:param value: The center of region of insterest (ROI)
:type value: tuple
"""
self._mask = None
self._center = value
@property
def span(self):
"""The getter of the span of region of insterest (ROI)
:return: The span of region of insterest (ROI)
:rtype: tuple
"""
return self._span
@span.setter
def span(self, value):
"""The setter of the span of region of insterest (ROI)
:param value: The span of region of insterest (ROI)
:type value: tuple
"""
self._mask = None
self._span = value
def reshape_data(self, data):
"""The function is to reshape the data to the correct shape.
In order to minimize the calculation time, the data has to have a shape of (y, x, index of images(alternative)).
However, usually the input data has a shape of (index of images(alternative), x, y).
It can also convert the xarray DataArray and Dataset to numpy array.
:param data: The input data.
:type data: xarray, numpy array or list
:raises InvalidDimException: Raised when the program can not identify (index of images, x, y) axes.
:raises InvalidDimException: Raised when the shape of the data is not correct.
:return: The data with correct shape
:rtype: xarray, numpy array or list
"""
if data is None:
return data
if isinstance(data, type(xr.DataArray())):
dims = data.dims
if len(dims)>3:
raise InvalidDimException(dims)
xAxis = None
yAxis = None
if len(dims) == 2:
imageAxis = ''
else:
imageAxis = None
for dim in dims:
if (dim == 'x') or ('_x' in dim):
xAxis = dim
elif (dim == 'y') or ('_y' in dim):
yAxis = dim
else:
imageAxis = dim
if (xAxis is None) or (yAxis is None) or (imageAxis is None):
raise InvalidDimException(dims)
if len(dims) == 2:
data = data.transpose(yAxis, xAxis)
else:
data = data.transpose(yAxis, xAxis, imageAxis)
data = data.to_numpy()
else:
data = np.array(data)
if len(data.shape) == 3:
data = np.swapaxes(data, 0, 2)
# data = np.swapaxes(data, 0, 1)
elif len(data.shape) == 2:
data = np.swapaxes(data, 0, 1)
return data
def _reshape_absorption_images(self, data):
if data is None:
return data
if isinstance(data, type(xr.DataArray())):
dims = data.dims
if len(dims)>3:
raise InvalidDimException(dims)
xAxis = None
yAxis = None
if len(dims) == 2:
imageAxis = ''
else:
imageAxis = None
for dim in dims:
if (dim == 'x') or ('_x' in dim):
xAxis = dim
elif (dim == 'y') or ('_y' in dim):
yAxis = dim
else:
imageAxis = dim
if (xAxis is None) or (yAxis is None) or (imageAxis is None):
raise InvalidDimException(dims)
if len(dims) == 2:
data = data.transpose(yAxis, xAxis)
else:
data = data.transpose(yAxis, xAxis, imageAxis)
self.nimgs = len(data[imageAxis])
data = data.stack(axis=[yAxis, xAxis])
else:
data = np.array(data)
if len(data.shape) == 3:
data = np.swapaxes(data, 0, 2)
# data = np.swapaxes(data, 0, 1)
elif len(data.shape) == 2:
data = np.swapaxes(data, 0, 1)
return data
@property
def referenceImages(self):
res = self._referenceImages.reshape(self.ydim, self.xdim, self.nimgsR)
res = np.swapaxes(res, 0, 2)
return res
@referenceImages.setter
def referenceImages(self, value):
if value is None:
self._referenceImages = None
return
if self.reshape:
value = self.reshape_data(value)
elif isinstance(value, type(xr.DataArray())):
value = value.to_numpy()
self.nimgsR = value.shape[2]
self.xdim = value.shape[1]
self.ydim = value.shape[0]
self._referenceImages = (value.reshape(self.xdim * self.ydim, self.nimgsR).astype(np.float32))
def add_reference_images(self, data):
"""Add a new reference images
:param data: The new reference image.
:type data: xarray, numpy array or list
:raises DataSizeException: Raised when the shape of the data is not correct.
"""
if self.reshape:
data = self.reshape_data(data)
elif isinstance(data, type(xr.DataArray())):
data = data.to_numpy()
if not ((data.shape[0]==self.ydim) and (data.shape[1]==self.xdim)):
raise DataSizeException
data = data.reshape(self.xdim * self.ydim)
self._referenceImages = np.append(self._referenceImages, data, axis=1)
def _remove_first_reference_images(self):
"""Remove the first reference images
"""
self._referenceImages = np.delete(self._referenceImages, 0, axis=1)
def update_reference_images(self, data):
"""Update the reference images set by removing the first one and adding a new one at the end.
:param data: The new reference image.
:type data: xarray, numpy array or list
"""
self._remove_first_reference_images()
self.add_reference_images(data)
self.decompose_referenceImages()
@property
def mask(self):
return self._mask
@mask.setter
def mask(self, value):
if self.reshape:
value = self.reshape_data(value)
elif isinstance(value, type(xr.DataArray())):
value = value.to_numpy()
if not ((value.shape[0]==self.ydim) and (value.shape[1]==self.xdim)):
raise DataSizeException
self._mask = value
self._center = None
self._span = None
self.k = np.where(self._mask.flatten() == 1)[0]
def _auto_mask(self):
mask = np.ones((self.ydim, self.xdim), dtype=np.uint8)
if not self._center is None:
x_start = int(self._center[0] - self._span[0] / 2)
x_end = int(self._center[0] + self._span[0] / 2)
y_end = int(self._center[1] + self._span[1] / 2)
y_start = int(self._center[1] - self._span[1] / 2)
mask[y_start:y_end, x_start:x_end] = 0
return mask
def decompose_referenceImages(self):
if self._mask is None:
self.mask = self._auto_mask()
self.P, self.L, self.U = lu(self._referenceImages[self.k, :].T @ self._referenceImages[self.k, :], permute_l = False, p_indices = True)
def _fringe_removal(self, absorptionImages):
b = self.temp @ absorptionImages[self.k]
c = np.linalg.solve(self.U, np.linalg.solve(self.L[self.P], b))
optrefimages = (self._referenceImages @ c)
return optrefimages
def fringe_removal(self, absorptionImages, referenceImages=None, mask=None, reshape=None, dask='parallelized'):
"""
This function will generate a 'fake' background images, which can help to remove the fringes.
Important: Please substract the drak images from the both of images with atoms and without atoms before using this function!!!
:param absorptionImages: A set of images with atoms in absorption imaging
:type absorptionImages: xarray, numpy array or list
:param referenceImages: A set of images without atoms in absorption imaging, defaults to None
:type referenceImages: xarray, numpy array or list, optional
:param mask: An array to choose the region of interest for fringes removal, defaults to None, defaults to None
:type mask: numpy array, optional
:param reshape: If it needs to reshape the data, defaults to None
:type reshape: bool, optional
:param dask: Please refer to xarray.apply_ufunc()
:type dask: {"forbidden", "allowed", "parallelized"}, optional
:return: The 'fake' background to help removing the fringes
:rtype: xarray array
"""
if not reshape is None:
self.reshape = reshape
if not referenceImages is None:
self.referenceImages = referenceImages
if not mask is None:
self.mask = mask
if self.P is None:
self.decompose_referenceImages()
if self.reshape:
absorptionImages = self._reshape_absorption_images(absorptionImages)
self.temp = self._referenceImages[self.k, :].T
optrefimages = xr.apply_ufunc(self._fringe_removal, absorptionImages, input_core_dims=[['axis']], output_core_dims=[['axis']], dask=dask, vectorize=True, output_dtypes=float)
return optrefimages.unstack()

View File

@ -2,10 +2,6 @@ import numpy as np
import xarray as xr import xarray as xr
import copy import copy
from DataContainer.ReadData import read_hdf5_file
from Analyser.FringeRemoval import FringeRemoval
from ToolFunction.ToolFunction import get_scanAxis
class ImageAnalyser(): class ImageAnalyser():
"""A class for operate with and analyse images """A class for operate with and analyse images
@ -19,15 +15,11 @@ class ImageAnalyser():
'background': 'background', 'background': 'background',
'dark': 'dark', 'dark': 'dark',
'OD':'OD', 'OD':'OD',
'optimumBackground':'optimumBackground'
} }
self._center = None self._center = None
self._span = None self._span = None
self._fraction = None self._fraction = None
self._fringeRemoval = FringeRemoval()
self.fringeRemovalReferenceImages = None
@property @property
def image_name(self): def image_name(self):
"""The getter of the names of three standard images for absorption images """The getter of the names of three standard images for absorption images
@ -158,7 +150,7 @@ class ImageAnalyser():
res.attrs = copy.copy(dataArray.attrs) res.attrs = copy.copy(dataArray.attrs)
return res return res
def crop_image(self, dataSet, center=None, span=None, fringeRemoval=False): def crop_image(self, dataSet, center=None, span=None):
"""Crop the image according to the region of interest (ROI). """Crop the image according to the region of interest (ROI).
:param dataSet: The images :param dataSet: The images
@ -167,8 +159,6 @@ class ImageAnalyser():
:type center: tuple, optional :type center: tuple, optional
:param span: the span of region of insterest (ROI), defaults to None :param span: the span of region of insterest (ROI), defaults to None
:type span: tuple, optional :type span: tuple, optional
:param fringeRemoval: If also crop the reference background images for finges removal function, defaults to False
:type fringeRemoval: bool, optional
:return: The croped images :return: The croped images
:rtype: xarray DataArray or DataSet :rtype: xarray DataArray or DataSet
""" """
@ -177,6 +167,8 @@ class ImageAnalyser():
center = self._center center = self._center
if span is None: if span is None:
span = self._span span = self._span
if not x in
x_start = int(center[0] - span[0] / 2) x_start = int(center[0] - span[0] / 2)
x_end = int(center[0] + span[0] / 2) x_end = int(center[0] + span[0] / 2)
@ -202,15 +194,15 @@ class ImageAnalyser():
dataSet[key].attrs['y_center'] = center[1] dataSet[key].attrs['y_center'] = center[1]
dataSet[key].attrs['x_span'] = span[0] dataSet[key].attrs['x_span'] = span[0]
dataSet[key].attrs['y_span'] = span[1] dataSet[key].attrs['y_span'] = span[1]
if fringeRemoval:
scanAxis = list(get_scanAxis(self.fringeRemovalReferenceImages))
if not scanAxis[1] is None:
self._fringeRemoval.referenceImages = self.fringeRemovalReferenceImages.isel(x=slice(x_start, x_end), y=slice(y_start, y_end)).stack(_imgIdx=scanAxis)
else:
self._fringeRemoval.referenceImages = self.fringeRemovalReferenceImages.isel(x=slice(x_start, x_end), y=slice(y_start, y_end))
return dataSet.isel(x=slice(x_start, x_end), y=slice(y_start, y_end)) res = dataSet.isel(x=slice(x_start, x_end), y=slice(y_start, y_end))
res = res.assign_coords(
{
'x': np.linspace(x_start, x_end - 1, span[0]),
'y': np.linspace(y_start, y_end - 1, span[1]),
}
)
return res
def get_OD(self, imageAtom, imageBackground, imageDrak): def get_OD(self, imageAtom, imageBackground, imageDrak):
"""Calculate the OD image for absorption imaging. """Calculate the OD image for absorption imaging.
@ -237,32 +229,6 @@ class ImageAnalyser():
return imageOD[0] return imageOD[0]
else: else:
return imageOD return imageOD
def get_OD_no_dark(self, imageAtom, imageBackground):
    """Calculate the OD image for absorption imaging without dark images.

    The OD is -log(|imageAtom / imageBackground|). Pixels equal to zero in
    either image are treated as one so that neither the division nor the
    logarithm diverges. The input arrays are left untouched.

    :param imageAtom: The image with atoms
    :type imageAtom: numpy array
    :param imageBackground: The image without atoms
    :type imageBackground: numpy array
    :return: The OD images (a single value when the result has length one)
    :rtype: numpy array
    """
    background = np.atleast_1d(imageBackground)
    atoms = np.atleast_1d(imageAtom)

    # np.where builds new arrays, so the caller's data is not overwritten.
    background = np.where(background == 0, 1, background)
    atoms = np.where(atoms == 0, 1, atoms)

    imageOD = np.abs(np.divide(atoms, background))
    imageOD = -np.log(imageOD)

    if len(imageOD) == 1:
        return imageOD[0]
    return imageOD
def get_Ncount(self, dataSet, dim=['x', 'y'], **kwargs): def get_Ncount(self, dataSet, dim=['x', 'y'], **kwargs):
"""Sum all the value in the image to give the Ncount. """Sum all the value in the image to give the Ncount.
@ -276,13 +242,11 @@ class ImageAnalyser():
""" """
return dataSet.sum(dim=['x', 'y'], **kwargs) return dataSet.sum(dim=['x', 'y'], **kwargs)
def get_absorption_images(self, dataSet, fringeRemoval=False, dask='allowed', keep_attrs=True, **kwargs): def get_absorption_images(self, dataSet, dask='allowed', keep_attrs=True, **kwargs):
"""Calculate the OD images for absorption imaging. """Calculate the OD images for absorption imaging.
:param dataSet: The data from absorption imaging. :param dataSet: The data from absorption imaging.
:type dataSet: xarray DataSet :type dataSet: xarray DataSet
:param fringeRemoval: If use fringe removal function, defaults to False
:type fringeRemoval: bool, optional
:param dask: over write of the same argument in xarray.apply_ufunc, defaults to 'allowed' :param dask: over write of the same argument in xarray.apply_ufunc, defaults to 'allowed'
:type dask: str, optional :type dask: str, optional
:param keep_attrs: over write of the same argument in xarray.apply_ufunc, defaults to True :param keep_attrs: over write of the same argument in xarray.apply_ufunc, defaults to True
@ -298,33 +262,11 @@ class ImageAnalyser():
} }
) )
if fringeRemoval: dataSet = dataSet.assign(
{
dataSetAtoms = dataSet[self._image_name['atoms']] - dataSet[self._image_name['dark']] self._image_name['OD']: xr.apply_ufunc(self.get_OD, dataSet[self._image_name['atoms']], dataSet[self._image_name['background']], dataSet[self._image_name['dark']], **kwargs)
}
scanAxis = list(get_scanAxis(dataSet)) )
if not scanAxis[1] is None:
OptimumRef = self._fringeRemoval.fringe_removal(dataSetAtoms.stack(_imgIdx=scanAxis))
else:
OptimumRef = self._fringeRemoval.fringe_removal(dataSetAtoms)
dataSet = dataSet.assign(
{
self._image_name['optimumBackground']: OptimumRef
}
)
dataSet = dataSet.assign(
{
self._image_name['OD']: xr.apply_ufunc(self.get_OD_no_dark, dataSetAtoms, dataSet[self._image_name['optimumBackground']], **kwargs)
}
)
else:
dataSet = dataSet.assign(
{
self._image_name['OD']: xr.apply_ufunc(self.get_OD, dataSet[self._image_name['atoms']], dataSet[self._image_name['background']], dataSet[self._image_name['dark']], **kwargs)
}
)
# dataSet[self._image_name['OD']].attrs.update(dataSet.attrs) # dataSet[self._image_name['OD']].attrs.update(dataSet.attrs)
@ -349,77 +291,6 @@ class ImageAnalyser():
) )
xr.apply_ufunc(self.get_OD, dataSet[self._image_name['atoms']], dataSet[self._image_name['background']], dataSet[self._image_name['dark']], **kwargs) xr.apply_ufunc(self.get_OD, dataSet[self._image_name['atoms']], dataSet[self._image_name['background']], dataSet[self._image_name['dark']], **kwargs)
@property
def fringeRemovalCenter(self):
    """Read the ROI center held by the fringe removal helper.

    :return: The center of the region of interest (ROI)
    :rtype: tuple
    """
    helper = self._fringeRemoval
    return helper.center

@fringeRemovalCenter.setter
def fringeRemovalCenter(self, value):
    """Forward a new ROI center to the fringe removal helper.

    :param value: The center of the region of interest (ROI)
    :type value: tuple
    """
    helper = self._fringeRemoval
    helper.center = value
@property
def fringeRemovalSpan(self):
    """Read the ROI span held by the fringe removal helper.

    :return: The span of the region of interest (ROI)
    :rtype: tuple
    """
    helper = self._fringeRemoval
    return helper.span

@fringeRemovalSpan.setter
def fringeRemovalSpan(self, value):
    """Forward a new ROI span to the fringe removal helper.

    :param value: The span of the region of interest (ROI)
    :type value: tuple
    """
    helper = self._fringeRemoval
    helper.span = value
def load_fringe_removal_background_from_hdf5(self, img_dir, SequenceName, date, shotNum, group, crop=False, load=False, **kwargs):
    """Load the reference background images from the hdf5 files of one single shot.

    The files matching ``img_dir + SequenceName/date/shotNum/*.h5`` are read,
    the dark image is subtracted from the background image and the result is
    stored in ``self.fringeRemovalReferenceImages``.

    :param img_dir: The path of the folder storing data.
    :type img_dir: str
    :param SequenceName: The name of the sequence
    :type SequenceName: str
    :param date: The date when the shot was taken in 'YYYY/MM/DD'.
    :type date: str
    :param shotNum: The number of the shot
    :type shotNum: str
    :param group: The name of the group storing the images
    :type group: str
    :param crop: If crop the data, defaults to False
    :type crop: bool, optional
    :param load: If load the data into RAM, defaults to False
    :type load: bool, optional
    """
    filePath = img_dir + SequenceName + "/" + date + "/" + shotNum + "/*.h5"
    rawData = read_hdf5_file(filePath, group, **kwargs)

    # The subtraction below yields a new object, so remember the scan axis first.
    scanAxis = rawData.scanAxis
    background = rawData[self._image_name['background']] - rawData[self._image_name['dark']]
    background.attrs['scanAxis'] = scanAxis

    if crop:
        background = self.crop_image(background)

    self.fringeRemovalReferenceImages = background.load() if load else background
def load_fringe_removal_background_from_database(self=None):
    """Placeholder for loading the reference background images from the database.

    Not implemented yet: the call does nothing and returns None. ``self``
    defaults to None so that the method can be called on an instance (which
    used to raise a TypeError because the parameter was missing) as well as
    without any argument, as before.
    """
    return None

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -2,17 +2,12 @@ from collections import OrderedDict
import numpy as np import numpy as np
import pymongo import pymongo
from pymongo import MongoClient
import xarray_mongodb import xarray_mongodb
import bson import bson
import builtins import builtins
import xarray as xr import xarray as xr
# import sys
# #sys.path.insert(0, 'C:/Users/Fabrizio Klassen/PycharmProjects/DyLabDataViewer/src/bin/Analyser/AnalyserScript')
# import sys
# sys.path.append('../')
# from bin.Analyser.AnalyserScript.ToolFunction.ToolFunction import get_date
from ToolFunction.ToolFunction import get_date from ToolFunction.ToolFunction import get_date
@ -23,7 +18,7 @@ npArrayType = type(np.array([0]))
class MongoDB: class MongoDB:
"""A class for communicate with our MongoDB. """A class for communicate with our MongoDB.
""" """
def __init__(self, mongoClient, mongoDB, date=None) -> None: def __init__(self, mongoClient, mongoDB, date=None) -> None:
"""Initialize the class with given handle to our MongoDB client and database. """Initialize the class with given handle to our MongoDB client and database.
@ -37,11 +32,11 @@ class MongoDB:
self.mongoClient = mongoClient self.mongoClient = mongoClient
self.mongoDB = mongoDB self.mongoDB = mongoDB
self.xdb = xarray_mongodb.XarrayMongoDB(mongoDB) self.xdb = xarray_mongodb.XarrayMongoDB(mongoDB)
if date is None: if date is None:
date= get_date() date= get_date()
self.set_date(date) self.set_date(date)
def _convert_numpy_type(self, data): def _convert_numpy_type(self, data):
"""Convert from numpy type to normal python type. """Convert from numpy type to normal python type.
@ -50,10 +45,6 @@ class MongoDB:
:return: The converted data :return: The converted data
:rtype: normal python data type :rtype: normal python data type
""" """
if data is None:
return None
for key in data: for key in data:
typeKey = type(data[key]) typeKey = type(data[key])
if typeKey in npTypeDict: if typeKey in npTypeDict:
@ -66,7 +57,7 @@ class MongoDB:
except: except:
pass pass
return data return data
def _break_dataSet(self, dataSet, scanAxis=None): def _break_dataSet(self, dataSet, scanAxis=None):
"""Stack the scan axes of data """Stack the scan axes of data
@ -77,16 +68,16 @@ class MongoDB:
:return: The stacked xarray DataSet or DataArray stored the data :return: The stacked xarray DataSet or DataArray stored the data
:rtype: xarray DataSet or DataArray :rtype: xarray DataSet or DataArray
""" """
if scanAxis is None: if scanAxis is None:
scanAxis = dataSet.attrs['scanAxis'] scanAxis = dataSet.attrs['scanAxis']
dataArray = dataSet.shotNum dataArray = dataSet.shotNum
stackedDataArray = dataArray.stack(_scanAxis=tuple(scanAxis)) stackedDataArray = dataArray.stack(_scanAxis=tuple(scanAxis))
return stackedDataArray return stackedDataArray
def set_date(self, date): def set_date(self, date):
"""Set the date of data """Set the date of data
@ -97,8 +88,8 @@ class MongoDB:
self.year = int(date[0]) self.year = int(date[0])
self.month = int(date[1]) self.month = int(date[1])
self.day = int(date[2]) self.day = int(date[2])
def create_global(self, shotNum, dataSet=None, date=None, overwrite=True, runNum=None): def create_global(self, shotNum, dataSet=None, date=None):
"""Creat a the global document in MongoDB """Creat a the global document in MongoDB
:param shotNum: The shot number :param shotNum: The shot number
@ -107,66 +98,40 @@ class MongoDB:
:type dataSet: xarray DataSet, optional :type dataSet: xarray DataSet, optional
:param date: the date of the data, defaults to None :param date: the date of the data, defaults to None
:type date: str, optional :type date: str, optional
:param overwrite: If overwrite the exist global document, defaults to True
:type overwrite: bool, optional
""" """
if not date is None: if not date is None:
self.set_date(date) self.set_date(date)
if runNum is None:
data = {
'year': self.year,
'month': self.month,
'day': self.day,
'shotNum': shotNum,
}
runNum = 0
else:
data = {
'year': self.year,
'month': self.month,
'day': self.day,
'shotNum': shotNum,
'runNum': runNum,
}
if overwrite:
self.mongoDB['global'].delete_many(data)
else:
res = self.mongoDB['global'].find_one(data)
if not res is None:
if not len(res)==0:
return
data = { data = {
'year': self.year, 'year': self.year,
'month': self.month, 'month': self.month,
'day': self.day, 'day': self.day,
'shotNum': shotNum, 'shotNum': shotNum,
'runNum': runNum, }
self.mongoDB['global'].delete_many(data)
data = {
'year': self.year,
'month': self.month,
'day': self.day,
'shotNum': shotNum,
'runNum': 0,
'global_parameters' : {}, 'global_parameters' : {},
} }
if dataSet is None:
self.mongoDB['global'].insert_one(data)
return
if ('scanAxis' in dataSet.attrs) and len(dataSet.attrs['scanAxis'])==0:
del dataSet.attrs['scanAxis']
del dataSet.attrs['scanAxisLength']
global_parameters = self._convert_numpy_type(dataSet.attrs) global_parameters = self._convert_numpy_type(dataSet.attrs)
data['global_parameters'].update(global_parameters)
if not dataSet is None:
data['global_parameters'].update(global_parameters)
data = self._convert_numpy_type(data) data = self._convert_numpy_type(data)
if 'scanAxis' in dataSet.attrs: if 'scanAxis' in dataSet.attrs:
del data['global_parameters']['scanAxis'] del data['global_parameters']['scanAxis']
del data['global_parameters']['scanAxisLength'] del data['global_parameters']['scanAxisLength']
scanAxis = dataSet.attrs['scanAxis'] scanAxis = dataSet.attrs['scanAxis']
data['global_parameters'].update( data['global_parameters'].update(
{ {
@ -174,40 +139,40 @@ class MongoDB:
for key in scanAxis for key in scanAxis
} }
) )
stackedDataArray = self._break_dataSet(dataSet) stackedDataArray = self._break_dataSet(dataSet)
try: try:
stackedDataArray.load() stackedDataArray.load()
except: except:
pass pass
stackedDataArray = stackedDataArray.groupby('_scanAxis') stackedDataArray = stackedDataArray.groupby('_scanAxis')
for i in stackedDataArray: for i in stackedDataArray:
stackedDataArray_single = i[1] stackedDataArray_single = i[1]
data.update( data.update(
{ {
'runNum': int(stackedDataArray_single.item()) 'runNum': int(stackedDataArray_single.item())
} }
) )
data['global_parameters'].update( data['global_parameters'].update(
{ {
key: stackedDataArray_single[key].item() key: stackedDataArray_single[key].item()
for key in scanAxis for key in scanAxis
} }
) )
if '_id' in data: if '_id' in data:
del data['_id'] del data['_id']
self.mongoDB['global'].insert_one(data) self.mongoDB['global'].insert_one(data)
else: else:
self.mongoDB['global'].insert_one(data) self.mongoDB['global'].insert_one(data)
def _add_data_normal(self, shotNum, runNum, data): def _add_data_normal(self, shotNum, runNum, data):
"""Write the data directly to the global document """Write the data directly to the global document
@ -218,21 +183,21 @@ class MongoDB:
:param data: The data to be written :param data: The data to be written
:type data: normal python data type :type data: normal python data type
""" """
if runNum is None: if runNum is None:
runNum = 0 runNum = 0
filter = { filter = {
'year': self.year, 'year': self.year,
'month': self.month, 'month': self.month,
'day': self.day, 'day': self.day,
'shotNum': shotNum, 'shotNum': shotNum,
'runNum': runNum, 'runNum': runNum,
} }
self.mongoDB['global'].update_one(filter, {"$set": data}, upsert=False) self.mongoDB['global'].update_one(filter, {"$set": data}, upsert=False)
def _add_data_xarray_dataArray(self, shotNum, dataArray, name=None, scanAxis=None, runNum=None): def _add_data_xarray_dataArray(self, shotNum, dataArray, name=None, scanAxis=None):
"""Write the data in a type of xarray DataArray to the MongoDb. """Write the data in a type of xarray DataArray to the MongoDb.
:param shotNum: The shot number :param shotNum: The shot number
@ -247,45 +212,13 @@ class MongoDB:
if scanAxis is None: if scanAxis is None:
scanAxis = list(dataArray.coords) scanAxis = list(dataArray.coords)
if name is None:
name = dataArray.name
dataArray.attrs = self._convert_numpy_type(dataArray.attrs) dataArray.attrs = self._convert_numpy_type(dataArray.attrs)
if scanAxis is None or len(scanAxis) == 0:
if runNum is None:
return
filter = {
'year': self.year,
'month': self.month,
'day': self.day,
'shotNum': shotNum,
'runNum': runNum,
}
mongoID, _ = self.xdb.put(dataArray)
data_label = {
name:
{
'name': name,
'mongoID': mongoID,
'engine': 'xarray',
'dtype': 'dataArray',
}
}
self.mongoDB['global'].update_one(filter, {"$set": data_label}, upsert=False)
return
stackedDataArray = dataArray.stack(_scanAxis=tuple(scanAxis)) stackedDataArray = dataArray.stack(_scanAxis=tuple(scanAxis))
stackedDataArray = stackedDataArray.groupby('_scanAxis') stackedDataArray = stackedDataArray.groupby('_scanAxis')
filter = { filter = {
'year': self.year, 'year': self.year,
'month': self.month, 'month': self.month,
'day': self.day, 'day': self.day,
@ -293,31 +226,31 @@ class MongoDB:
} }
for i in stackedDataArray: for i in stackedDataArray:
stackedDataArray_single = i[1].drop('_scanAxis') stackedDataArray_single = i[1].drop('_scanAxis')
global_parameters = { global_parameters = {
'global_parameters.' + key: stackedDataArray_single[key].item() 'global_parameters.' + key: stackedDataArray_single[key].item()
for key in scanAxis for key in scanAxis
} }
filter.update(global_parameters) filter.update(global_parameters)
mongoID, _ = self.xdb.put(stackedDataArray_single) mongoID, _ = self.xdb.put(stackedDataArray_single)
data_label = { data_label = {
name: dataArray.name:
{ {
'name': name, 'name': dataArray.name,
'mongoID': mongoID, 'mongoID': mongoID,
'engine': 'xarray', 'engine': 'xarray',
'dtype': 'dataArray', 'dtype': 'dataArray',
} }
} }
self.mongoDB['global'].update_one(filter, {"$set": data_label}, upsert=False) self.mongoDB['global'].update_one(filter, {"$set": data_label}, upsert=False)
def _add_data_xarray_dataSet(self, shotNum, dataSet, name, scanAxis=None, runNum=None): def _add_data_xarray_dataSet(self, shotNum, dataSet, name, scanAxis=None):
"""Write the data in a type of xarray DataSet to the MongoDb. """Write the data in a type of xarray DataSet to the MongoDb.
:param shotNum: The shot number :param shotNum: The shot number
@ -329,69 +262,40 @@ class MongoDB:
:param scanAxis: The scan axes of the data, defaults to None :param scanAxis: The scan axes of the data, defaults to None
:type scanAxis: array like, optional :type scanAxis: array like, optional
""" """
if scanAxis is None: if scanAxis is None:
scanAxis = list(dataSet.coords) scanAxis = list(dataSet.coords)
dataSet.attrs = self._convert_numpy_type(dataSet.attrs) dataSet.attrs = self._convert_numpy_type(dataSet.attrs)
for key in list(dataSet.data_vars): for key in list(dataSet.data_vars):
dataSet[key].attrs = self._convert_numpy_type(dataSet[key].attrs) dataSet[key].attrs = self._convert_numpy_type(dataSet[key].attrs)
if scanAxis is None or len(scanAxis) == 0:
if runNum is None:
return
filter = {
'year': self.year,
'month': self.month,
'day': self.day,
'shotNum': shotNum,
'runNum': runNum,
}
mongoID, _ = self.xdb.put(dataSet)
data_label = {
name:
{
'name': name,
'mongoID': mongoID,
'engine': 'xarray',
'dtype': 'dataSet',
}
}
self.mongoDB['global'].update_one(filter, {"$set": data_label}, upsert=False)
return
stackedDataSet = dataSet.stack(_scanAxis=tuple(scanAxis)) stackedDataSet = dataSet.stack(_scanAxis=tuple(scanAxis))
stackedDataSet = stackedDataSet.groupby('_scanAxis') stackedDataSet = stackedDataSet.groupby('_scanAxis')
filter = { filter = {
'year': self.year, 'year': self.year,
'month': self.month, 'month': self.month,
'day': self.day, 'day': self.day,
'shotNum': shotNum, 'shotNum': shotNum,
} }
for i in stackedDataSet: for i in stackedDataSet:
stackedDataSet_single = i[1].drop('_scanAxis') stackedDataSet_single = i[1].drop('_scanAxis')
global_parameters = { global_parameters = {
'global_parameters.' + key: stackedDataSet_single[key].item() 'global_parameters.' + key: stackedDataSet_single[key].item()
for key in scanAxis for key in scanAxis
} }
filter.update(global_parameters) filter.update(global_parameters)
mongoID, _ = self.xdb.put(stackedDataSet_single) mongoID, _ = self.xdb.put(dataSet)
data_label = { data_label = {
name: name:
{ {
'name': name, 'name': name,
'mongoID': mongoID, 'mongoID': mongoID,
@ -399,9 +303,9 @@ class MongoDB:
'dtype': 'dataSet', 'dtype': 'dataSet',
} }
} }
self.mongoDB['global'].update_one(filter, {"$set": data_label}, upsert=False) self.mongoDB['global'].update_one(filter, {"$set": data_label}, upsert=False)
def _add_data_additional(self, shotNum, runNum, data, name): def _add_data_additional(self, shotNum, runNum, data, name):
"""Write the data in an additional document """Write the data in an additional document
@ -414,22 +318,22 @@ class MongoDB:
:param name: The name of the data :param name: The name of the data
:type name: str :type name: str
""" """
if runNum is None: if runNum is None:
runNum = 0 runNum = 0
filter = { filter = {
'year': self.year, 'year': self.year,
'month': self.month, 'month': self.month,
'day': self.day, 'day': self.day,
'shotNum': shotNum, 'shotNum': shotNum,
'runNum': runNum, 'runNum': runNum,
} }
mongoID = self.mongoDB.additional.insert_one(data).inserted_id mongoID = self.mongoDB.additional.insert_one(data).inserted_id
data_label = { data_label = {
name: name:
{ {
'name': name, 'name': name,
'mongoID': mongoID, 'mongoID': mongoID,
@ -437,9 +341,9 @@ class MongoDB:
'dtype': 'dict', 'dtype': 'dict',
} }
} }
self.mongoDB['global'].update_one(filter, {"$set": data_label}, upsert=False) self.mongoDB['global'].update_one(filter, {"$set": data_label}, upsert=False)
def add_data(self, shotNum, data, runNum=None, date=None, name=None, engine='normal'): def add_data(self, shotNum, data, runNum=None, date=None, name=None, engine='normal'):
"""Write a new data to MongoDB """Write a new data to MongoDB
@ -458,17 +362,17 @@ class MongoDB:
""" """
if not date is None: if not date is None:
self.set_date(date) self.set_date(date)
if engine == 'normal': if engine == 'normal':
self._add_data_normal(shotNum=shotNum, runNum=runNum, data=data) self._add_data_normal(shotNum=shotNum, runNum=runNum, data=data)
elif engine == 'xarray': elif engine == 'xarray':
if isinstance(data, type(xr.Dataset())): if isinstance(data, type(xr.Dataset())):
self._add_data_xarray_dataSet(shotNum=shotNum, dataSet=data, runNum=runNum, name=name) self._add_data_xarray_dataSet(shotNum=shotNum, dataSet=data, name=name)
else: else:
self._add_data_xarray_dataArray(shotNum=shotNum, dataArray=data, runNum=runNum, name=name) self._add_data_xarray_dataArray(shotNum=shotNum, dataArray=data, name=name)
elif engine == 'additional': elif engine == 'additional':
self._add_data_additional(shotNum=shotNum, runNum=runNum, data=data, name=name) self._add_data_additional(shotNum=shotNum, runNum=runNum, data=data, name=name)
def read_global_single(self, shotNum, runNum, date=None):
    """Read the global document of specified shot and run from MongoDB.

    :param shotNum: The shot number
    :param runNum: The run number
    :param date: The date of the data; when given it is passed to
        ``set_date`` before the query is built, defaults to None
    :return: The global document, i.e. whatever ``find_one`` on the
        'global' collection returns for the year/month/day/shot/run query
    :rtype: dict
    """
    if date is not None:
        self.set_date(date)

    # Named ``query`` rather than ``filter`` so the builtin is not shadowed.
    query = {
        'year': self.year,
        'month': self.month,
        'day': self.day,
        'shotNum': shotNum,
        'runNum': runNum,
    }

    return self.mongoDB['global'].find_one(query)
def read_global_all(self, shotNum, date=None): def read_global_all(self, shotNum, date=None):
"""Read the global document of all runs in the specified shot from MongoDB, and extract the scan axes. """Read the global document of all runs in the specified shot from MongoDB, and extract the scan axes.
@ -505,28 +409,29 @@ class MongoDB:
:return: The global document :return: The global document
:rtype: dict :rtype: dict
""" """
from xarray.core.utils import equivalent from xarray.core.utils import equivalent
if not date is None: if not date is None:
self.set_date(date) self.set_date(date)
filter = { filter = {
'year': self.year, 'year': self.year,
'month': self.month, 'month': self.month,
'day': self.day, 'day': self.day,
'shotNum': shotNum, 'shotNum': shotNum,
} }
result = {} result = {}
dropped_attrs = OrderedDict() dropped_attrs = OrderedDict()
docs = self.mongoDB['global'].find(filter).sort('runNum') docs = self.mongoDB['global'].find(filter).sort('runNum')
docs = [doc['global_parameters'] for doc in docs] docs = [doc['global_parameters'] for doc in docs]
for doc in docs: for doc in docs:
global_parameters = doc global_parameters = doc
result.update( result.update(
{ {
key: value key: value
@ -534,34 +439,34 @@ class MongoDB:
if key not in result and key not in dropped_attrs.keys() if key not in result and key not in dropped_attrs.keys()
} }
) )
result = { result = {
key: value key: value
for key, value in result.items() for key, value in result.items()
if key not in global_parameters or equivalent(global_parameters[key], value) if key not in global_parameters or equivalent(global_parameters[key], value)
} }
dropped_attrs.update( dropped_attrs.update(
{ {
key: [] key: []
for key in global_parameters if key not in result for key in global_parameters if key not in result
} }
) )
for doc in docs: for doc in docs:
global_parameters = doc global_parameters = doc
dropped_attrs.update( dropped_attrs.update(
{ {
key: np.append(dropped_attrs[key], global_parameters[key]) key: np.append(dropped_attrs[key], global_parameters[key])
for key in dropped_attrs.keys() for key in dropped_attrs.keys()
} }
) )
scan_attrs = OrderedDict() scan_attrs = OrderedDict()
scan_length = [] scan_length = []
for attrs_key in dropped_attrs.keys(): for attrs_key in dropped_attrs.keys():
flag = True flag = True
for key in scan_attrs.keys(): for key in scan_attrs.keys():
@ -590,7 +495,7 @@ class MongoDB:
) )
return result return result
def _load_data_single(self, mongoID, engine):
    """Load the document stored under the given MongoDB ID.

    :param mongoID: The ``_id`` of the stored document
    :param engine: The engine recorded for the data, 'xarray' or 'additional'
    :type engine: str
    :return: ``self.xdb.get(mongoID)`` for the 'xarray' engine, the matching
        document of the 'additional' collection for the 'additional' engine,
        and None for any other engine
    """
    if engine == 'xarray':
        return self.xdb.get(mongoID)

    if engine == 'additional':
        return self.mongoDB.additional.find_one({'_id': mongoID})

    # Any other engine falls through; the None result is kept explicit so
    # callers relying on it keep working.
    return None
def load_data_single(self, shotNum=None, runNum=None, globalDict=None, date=None, field=None):
    """Load the documents that a global document links to.

    Goes through the given global document, finds every entry that is a dict
    carrying a 'mongoID', and loads the document it points to.

    :param shotNum: The shot number, only used when ``globalDict`` is None
    :param runNum: The run number, only used when ``globalDict`` is None
    :param globalDict: The global document; when None it is read with
        ``read_global_single``, defaults to None
    :type globalDict: dict, optional
    :param date: The date of the data; when given it is passed to
        ``set_date`` first, defaults to None
    :param field: The keys to inspect; when None every key of the global
        document is inspected, defaults to None
    :return: A mapping from key to loaded data, containing only the keys
        whose entry links to a stored document. Keys of ``field`` that are
        absent from the global document are skipped, and an empty dict is
        returned when no global document is found.
    :rtype: dict
    """
    if date is not None:
        self.set_date(date)

    if globalDict is None:
        globalDict = self.read_global_single(shotNum=shotNum, runNum=runNum)

    # No global document for this shot/run: there is nothing to load.
    if globalDict is None:
        return {}

    if field is None:
        field = globalDict

    res = {}

    for key in field:
        # ``get`` so that a requested key missing from the document is
        # skipped like any other entry that is not a link.
        entry = globalDict.get(key)
        if isinstance(entry, dict) and 'mongoID' in entry:
            res[key] = self._load_data_single(
                mongoID=entry['mongoID'],
                engine=entry['engine'],
            )

    return res
def load_data(self, shotNum, data_key=None, globalDict=None, date=None):
    """Load observables of all runs of the given shot.

    The documents of a given shot can carry a variety of data types, e.g.
    optical density, N_count, centerx etc. To avoid loading all of them and
    taking too much RAM, loading can be restricted to the observables named
    in ``data_key``.

    :param shotNum: The shot number
    :type shotNum: str
    :param data_key: The observable(s) to load, a single name or an iterable
        of names; when None, every key of the first run's global document
        except the bookkeeping keys (date, shot/run number,
        'global_parameters', '_id') is loaded, defaults to None
    :type data_key: str or list, optional
    :param globalDict: All global parameters plus 'scanAxis' and
        'scanAxisLength'; when None it is read with ``read_global_all``,
        defaults to None
    :type globalDict: dict, optional
    :param date: The date of the data ('YYYY/MM/DD'), defaults to None
    :type date: str, optional
    :return: The date, the global parameters (scan axes last) and one entry
        per loaded observable as returned by ``_load_data``
    :rtype: dict
    """
    if date is not None:
        self.set_date(date)

    # Collect global parameters and scan axes.
    if globalDict is None:
        globalDict = self.read_global_all(shotNum=shotNum, date=date)

    scan_keys = ('scanAxis', 'scanAxisLength')

    # The scan axes are appended after the data has been loaded, so they are
    # left out of the global parameters for now.
    res = {
        'year': self.year,
        'month': self.month,
        'day': self.day,
        'global_parameters': {
            key: value
            for key, value in globalDict.items()
            if key not in scan_keys
        },
    }

    if data_key is None:
        query = {
            'year': self.year,
            'month': self.month,
            'day': self.day,
            'shotNum': shotNum,
        }
        docs = self.mongoDB['global'].find(query).sort('runNum')
        try:
            first_doc = docs[0]
        except IndexError:
            # The shot has no documents, hence no observables to load.
            first_doc = {}
        bookkeeping = ('year', 'month', 'day', 'shotNum', 'runNum', 'global_parameters', '_id')
        data_key = [key for key in first_doc if key not in bookkeeping]
    elif isinstance(data_key, str):
        # A bare name would otherwise be iterated character by character.
        data_key = [data_key]

    for key in data_key:
        res[key] = self._load_data(shotNum=shotNum, data_key=key, globalDict=globalDict)

    res['global_parameters'].update({key: globalDict[key] for key in scan_keys})

    return res
def _load_data(self, shotNum, data_key, globalDict):
    """Load one observable from all runs of the specified shot.

    Goes through the global documents of the shot in ``runNum`` order. When
    the entry under ``data_key`` links to a stored document (a dict carrying
    a 'mongoID'), the linked document is loaded; otherwise the entry itself
    is used.

    :param shotNum: The shot number
    :type shotNum: str
    :param data_key: The key of the observable in the global documents
    :type data_key: str
    :param globalDict: All global parameters plus the scan axes; read for
        'scanAxis' and, per scan axis, the values of that axis indexed by
        run position
    :type globalDict: dict
    :return: The data of all runs, combined with ``xr.combine_by_coords``
        when the last run was stored with the 'xarray' engine and the
        combination succeeds, otherwise a list with one item per run
    """
    query = {
        'year': self.year,
        'month': self.month,
        'day': self.day,
        'shotNum': shotNum,
    }

    docs = self.mongoDB['global'].find(query).sort('runNum')

    data = []
    # Defined up front so a shot without documents does not leave it unbound.
    engine = None

    for run_index, doc in enumerate(docs):
        entry = doc[data_key]
        if isinstance(entry, dict) and ('mongoID' in entry):
            engine = entry['engine']
            single_data = self._load_data_single(mongoID=entry['mongoID'], engine=engine)
            # Only xarray objects carry coords/dims; documents of the
            # 'additional' engine are returned as they are.
            if engine == 'xarray':
                for axis in globalDict['scanAxis']:
                    if axis not in single_data.coords:
                        single_data = single_data.assign_coords({axis: globalDict[axis][run_index]})
                    if axis not in single_data.dims:
                        single_data = single_data.expand_dims(axis)
        else:
            engine = None
            single_data = entry
        data.append(single_data)

    # Combine data along coordinate axes. This is best effort: when the runs
    # cannot be combined, the plain list is returned instead.
    if engine == 'xarray':
        try:
            data = xr.combine_by_coords(data)
        except Exception:
            pass

    return data

Binary file not shown.

Binary file not shown.

File diff suppressed because one or more lines are too long

View File

@ -1,77 +0,0 @@
asteval==0.9.31
asttokens==2.4.0
backcall==0.2.0
bokeh==3.2.2
click==8.1.7
cloudpickle==2.2.1
colorama==0.4.6
comm==0.1.4
contourpy==1.1.1
cycler==0.11.0
dask==2023.9.2
debugpy==1.8.0
decorator==5.1.1
distributed==2023.9.2
dnspython==2.4.2
exceptiongroup==1.1.3
executing==1.2.0
finufft==2.1.0
fonttools==4.42.1
fsspec==2023.9.1
future==0.18.3
h5netcdf==1.2.0
h5py==3.9.0
importlib-metadata==6.8.0
importlib-resources==6.1.0
ipykernel==6.25.2
ipython==8.15.0
jedi==0.19.0
Jinja2==3.1.2
jupyter_client==8.3.1
jupyter_core==5.3.1
kiwisolver==1.4.5
lmfit==1.2.2
locket==1.0.0
MarkupSafe==2.1.3
matplotlib==3.8.0
matplotlib-inline==0.1.6
msgpack==1.0.6
nest-asyncio==1.5.8
numpy==1.26.0
packaging==23.1
pandas==2.1.1
parso==0.8.3
partd==1.4.0
pickleshare==0.7.5
Pillow==10.0.1
platformdirs==3.10.0
prompt-toolkit==3.0.39
psutil==5.9.5
pure-eval==0.2.2
Pygments==2.16.1
pymongo==4.5.0
pyparsing==3.1.1
python-dateutil==2.8.2
pytz==2023.3.post1
pywin32==306
PyYAML==6.0.1
pyzmq==25.1.1
scipy==1.11.2
six==1.16.0
sortedcontainers==2.4.0
stack-data==0.6.2
tblib==2.0.0
toolz==0.12.0
tornado==6.3.3
traitlets==5.10.0
typing_extensions==4.8.0
tzdata==2023.3
uncertainties==3.1.7
urllib3==2.0.5
wcwidth==0.2.6
xarray==2023.8.0
xarray-mongodb==0.2.1
xrft==1.0.1
xyzservices==2023.7.0
zict==3.0.0
zipp==3.17.0

File diff suppressed because it is too large Load Diff