"""Store and load shot/run documents in MongoDB.

Each run is described by one document in the ``global`` collection, addressed
by ``year``/``month``/``day``/``shotNum``/``runNum``. Bulk payloads are stored
separately (through ``xarray_mongodb`` or the ``additional`` collection) and
referenced from the global document by their Mongo ID.
"""
import builtins

import bson
import numpy as np
import pymongo
import xarray as xr
import xarray_mongodb

from ToolFunction.ToolFunction import get_date

# numpy scalar types whose ``np.sctypeDict`` alias is also a builtin name
# (e.g. 'int', 'float', 'bool'), mapped to that builtin type.
npTypeDict = {
    v: getattr(builtins, k)
    for k, v in np.sctypeDict.items()
    if k in vars(builtins)
}
# The concrete ``numpy.ndarray`` type.
npArrayType = np.ndarray


class MongoDB:
    """Read and write shot/run documents in a MongoDB database.

    Global documents are addressed by the currently selected date (see
    :meth:`set_date`) plus ``shotNum`` and ``runNum``. Payloads stored
    elsewhere are referenced from a global document through labels of the
    form ``{name: {'name', 'mongoID', 'engine', 'dtype'}}``.
    """

    def __init__(self, mongoClient, mongoDB, date=None) -> None:
        """
        Args:
            mongoClient: Client owning the database; stored on the instance
                but not used by the methods in this class.
            mongoDB: Database handle. Its ``global`` and ``additional``
                collections are used, and it backs the ``XarrayMongoDB`` store.
            date: ``"YYYY/MM/DD"`` string; defaults to ``get_date()``
                (presumably today's date in that format — verify).
        """
        self.mongoClient = mongoClient
        self.mongoDB = mongoDB
        self.xdb = xarray_mongodb.XarrayMongoDB(mongoDB)

        if date is None:
            date = get_date()
        self.set_date(date)

    def _convert_numpy_type(self, data):
        """Convert numpy values in *data* to plain Python objects, in place.

        numpy scalars become Python scalars (``.item()``), arrays become
        (nested) lists (``.tolist()``), and any other value with a working
        ``.item()`` is converted as well. Values that cannot be converted
        are left unchanged.

        Args:
            data: A mutable mapping, e.g. an xarray ``attrs`` dict.

        Returns:
            The same *data* object, for convenience.
        """
        for key, value in data.items():
            if isinstance(value, np.generic):
                data[key] = value.item()
            elif isinstance(value, np.ndarray):
                data[key] = value.tolist()
            else:
                try:
                    data[key] = value.item()
                except (AttributeError, TypeError, ValueError):
                    # Not numpy-like (str, list, dict, None, ...) or not a
                    # single-element container: keep the value as is.
                    pass
        return data

    def _break_dataSet(self, dataSet, scanAxis=None):
        """Stack the scan axes of ``dataSet.shotNum`` into one ``_scanAxis`` dim.

        Args:
            dataSet: Dataset with a ``shotNum`` variable. When *scanAxis* is
                None, ``dataSet.attrs['scanAxis']`` must name the scan axes.
            scanAxis: Names of the dimensions to stack.

        Returns:
            ``dataSet.shotNum`` stacked along a new ``_scanAxis`` dimension.
        """
        if scanAxis is None:
            scanAxis = dataSet.attrs['scanAxis']
        return dataSet.shotNum.stack(_scanAxis=tuple(scanAxis))

    def set_date(self, date):
        """Select the date used by subsequent queries and inserts.

        Args:
            date: String of the form ``"YYYY/MM/DD"``.
        """
        parts = date.split("/")
        self.year = int(parts[0])
        self.month = int(parts[1])
        self.day = int(parts[2])

    def _shot_filter(self, shotNum):
        """Return a query matching *shotNum* on the currently selected date."""
        return {
            'year': self.year,
            'month': self.month,
            'day': self.day,
            'shotNum': shotNum,
        }

    def create_global(self, shotNum, dataSet=None, date=None):
        """Insert the ``global`` document(s) for one shot.

        Without *dataSet*, a single document with ``runNum`` 0 and empty
        ``global_parameters`` is inserted. With *dataSet*, its attrs become
        the ``global_parameters``. If those attrs contain ``scanAxis``, one
        document is inserted per scan point instead. Each such document gets
        ``runNum`` from ``dataSet.shotNum`` at that point. Its
        ``global_parameters`` hold the scan-axis coordinate values in place
        of the ``scanAxis``/``scanAxisLength`` attrs.

        Args:
            shotNum: Shot number to store.
            dataSet: Optional xarray Dataset. Its ``attrs`` are converted in
                place by :meth:`_convert_numpy_type`.
            date: Optional ``"YYYY/MM/DD"``; if given it also stays selected
                for later calls.
        """
        if date is not None:
            self.set_date(date)

        data = self._shot_filter(shotNum)
        data['runNum'] = 0
        data['global_parameters'] = {}

        if dataSet is None:
            # No dataset: store a bare document for run 0.
            self.mongoDB['global'].insert_one(self._convert_numpy_type(data))
            return

        data['global_parameters'].update(self._convert_numpy_type(dataSet.attrs))
        data = self._convert_numpy_type(data)

        if 'scanAxis' not in dataSet.attrs:
            self.mongoDB['global'].insert_one(data)
            return

        parameters = data['global_parameters']
        # Scan bookkeeping attrs are replaced by the per-run axis values.
        parameters.pop('scanAxis', None)
        parameters.pop('scanAxisLength', None)
        scanAxis = dataSet.attrs['scanAxis']
        parameters.update({key: 0 for key in scanAxis})

        runs = self._break_dataSet(dataSet).groupby('_scanAxis')
        for _point, run in runs:
            data['runNum'] = int(run.item())
            parameters.update({key: run[key].item() for key in scanAxis})
            # insert_one writes an '_id' into the dict; drop it so every run
            # is inserted as a new document.
            data.pop('_id', None)
            self.mongoDB['global'].insert_one(data)

    def _add_data_normal(self, shotNum, runNum, data):
        """``$set`` the fields of *data* on one existing global document.

        Args:
            shotNum: Shot number.
            runNum: Run number; None means run 0.
            data: Mapping of field names to values.
        """
        query = self._shot_filter(shotNum)
        query['runNum'] = 0 if runNum is None else runNum
        self.mongoDB['global'].update_one(query, {"$set": data}, upsert=False)

    def _store_xarray_groups(self, shotNum, xrObject, scanAxis, name, dtype):
        """Store *xrObject* one scan point at a time and label matching runs.

        *xrObject* is stacked over *scanAxis* and split per point. Each piece
        is saved with ``XarrayMongoDB.put``. The global document whose
        ``global_parameters`` equal that point's coordinates then gets a
        ``{name: {...}}`` label pointing at the stored piece.

        Raises:
            ValueError: If *name* is None (it is used as a document field).
        """
        if name is None:
            raise ValueError(
                "A name is required to label xarray data; pass name= "
                "or set DataArray.name."
            )
        query = self._shot_filter(shotNum)
        groups = xrObject.stack(_scanAxis=tuple(scanAxis)).groupby('_scanAxis')
        for _point, group in groups:
            piece = group.drop_vars('_scanAxis')
            query.update(
                {'global_parameters.' + key: piece[key].item() for key in scanAxis}
            )
            mongoID = self.xdb.put(piece)[0]
            label = {
                name: {
                    'name': name,
                    'mongoID': mongoID,
                    'engine': 'xarray',
                    'dtype': dtype,
                }
            }
            self.mongoDB['global'].update_one(query, {"$set": label}, upsert=False)

    def _add_data_xarray_dataArray(self, shotNum, dataArray, scanAxis=None, name=None):
        """Store a DataArray per scan point and label the matching runs.

        Args:
            shotNum: Shot number.
            dataArray: Data to store; its ``attrs`` are converted in place.
            scanAxis: Coordinates to split on; defaults to all coordinates.
            name: Label name; defaults to ``dataArray.name``.
        """
        if scanAxis is None:
            scanAxis = list(dataArray.coords)
        if name is None:
            name = dataArray.name
        dataArray.attrs = self._convert_numpy_type(dataArray.attrs)
        self._store_xarray_groups(shotNum, dataArray, scanAxis, name, 'dataArray')

    def _add_data_xarray_dataSet(self, shotNum, dataSet, name, scanAxis=None):
        """Store a Dataset per scan point and label the matching runs.

        Args:
            shotNum: Shot number.
            dataSet: Data to store. Its attrs and each data variable's attrs
                are converted in place.
            name: Label name under which the reference is stored.
            scanAxis: Coordinates to split on; defaults to all coordinates.
        """
        if scanAxis is None:
            scanAxis = list(dataSet.coords)
        dataSet.attrs = self._convert_numpy_type(dataSet.attrs)
        for key in list(dataSet.data_vars):
            dataSet[key].attrs = self._convert_numpy_type(dataSet[key].attrs)
        # Each scan point's own sub-Dataset is stored, not the full Dataset.
        self._store_xarray_groups(shotNum, dataSet, scanAxis, name, 'dataSet')

    def _add_data_additional(self, shotNum, runNum, data, name):
        """Insert *data* into ``additional`` and label it on a global document.

        Args:
            shotNum: Shot number.
            runNum: Run number; None means run 0.
            data: Document to insert (pymongo adds an ``_id`` to it).
            name: Label name under which the reference is stored.
        """
        query = self._shot_filter(shotNum)
        query['runNum'] = 0 if runNum is None else runNum
        mongoID = self.mongoDB.additional.insert_one(data).inserted_id
        label = {
            name: {
                'name': name,
                'mongoID': mongoID,
                'engine': 'additional',
                'dtype': 'dict',
            }
        }
        self.mongoDB['global'].update_one(query, {"$set": label}, upsert=False)

    def add_data(self, shotNum, data, runNum=None, date=None, name=None, engine='normal'):
        """Attach *data* to existing global document(s).

        Args:
            shotNum: Shot number.
            data: Payload. For ``'normal'`` a mapping of fields to ``$set``;
                for ``'xarray'`` an xarray Dataset or DataArray; for
                ``'additional'`` a document.
            runNum: Run number for ``'normal'``/``'additional'`` (None = 0).
                The ``'xarray'`` engine matches runs by scan coordinates.
            date: Optional ``"YYYY/MM/DD"``; stays selected for later calls.
            name: Label name (required for Datasets; DataArrays fall back to
                their ``.name``).
            engine: ``'normal'``, ``'xarray'`` or ``'additional'``.

        Raises:
            ValueError: For an unknown *engine*.
        """
        if date is not None:
            self.set_date(date)

        if engine == 'normal':
            self._add_data_normal(shotNum=shotNum, runNum=runNum, data=data)
        elif engine == 'xarray':
            if isinstance(data, xr.Dataset):
                self._add_data_xarray_dataSet(shotNum=shotNum, dataSet=data, name=name)
            else:
                self._add_data_xarray_dataArray(shotNum=shotNum, dataArray=data, name=name)
        elif engine == 'additional':
            self._add_data_additional(shotNum=shotNum, runNum=runNum, data=data, name=name)
        else:
            raise ValueError(
                f"Unknown engine {engine!r}; expected 'normal', 'xarray' or 'additional'."
            )

    def read_global_single(self, shotNum, runNum, date=None):
        """Return the global document for one run, or None if absent."""
        if date is not None:
            self.set_date(date)
        query = self._shot_filter(shotNum)
        query['runNum'] = runNum
        return self.mongoDB['global'].find_one(query)

    def _load_data_single(self, mongoID, engine):
        """Fetch a referenced payload; returns None for an unknown engine."""
        if engine == 'xarray':
            return self.xdb.get(mongoID)
        if engine == 'additional':
            return self.mongoDB.additional.find_one({'_id': mongoID})
        return None

    def load_data_single(self, shotNum=None, runNum=None, globalDict=None, date=None, field=None):
        """Return one run's global fields with referenced payloads loaded.

        Args:
            shotNum: Shot number (used only when *globalDict* is None).
            runNum: Run number (used only when *globalDict* is None; None = 0,
                matching the write paths).
            globalDict: Already-fetched global document; not modified.
            date: Optional ``"YYYY/MM/DD"``; stays selected for later calls.
            field: Iterable of keys to return; defaults to every key.

        Returns:
            A new dict mapping each requested key to its value. Labels that
            carry a ``mongoID`` are replaced by the loaded payload. Returns
            None when no matching global document exists.
        """
        if date is not None:
            self.set_date(date)

        if globalDict is None:
            if runNum is None:
                runNum = 0
            globalDict = self.read_global_single(shotNum=shotNum, runNum=runNum)
            if globalDict is None:
                return None

        if field is None:
            field = list(globalDict)

        res = {}
        for key in field:
            value = globalDict[key]
            if isinstance(value, dict) and 'mongoID' in value:
                value = self._load_data_single(mongoID=value['mongoID'], engine=value['engine'])
            res[key] = value
        return res

    def load_data(self, shotNum=None, globalDict=None, date=None, field=None):
        """Not implemented; currently returns None."""
        # TODO: presumably meant to load every run of a shot — confirm intent.
        return None