You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

323 lines
9.3 KiB

1 year ago
  1. import xarray as xr
  2. import numpy as np
  3. from collections import OrderedDict
  4. from functools import partial
  5. import copy
  6. import glob
  7. import os
  8. from datetime import datetime
def _read_globals_attrs(variable_attrs, context=None):
    """Combine attributes from different variables according to combine_attrs.

    Used as the ``combine_attrs`` callback of ``xr.open_mfdataset``: it merges
    the per-file "globals" attribute dicts, keeps values that are equivalent
    across all files, and collects the values that differ per file as "scan
    axes".

    Returns ``None`` when there is nothing to merge; otherwise a dict holding
    the constant attributes, one value array per scan axis, plus the
    bookkeeping keys ``"scanAxis"`` (axis names) and ``"scanAxisLength"``
    (number of collected values per axis).
    """
    if not variable_attrs:
        # no attributes to merge
        return None
    from xarray.core.utils import equivalent
    result = {}
    dropped_attrs = OrderedDict()
    # Pass 1: keep attributes equivalent in every file in `result`; record the
    # names of attributes whose values differ across files in `dropped_attrs`.
    for attrs in variable_attrs:
        result.update(
            {
                key: value
                for key, value in attrs.items()
                if key not in result and key not in dropped_attrs.keys()
            }
        )
        result = {
            key: value
            for key, value in result.items()
            if key not in attrs or equivalent(attrs[key], value)
        }
        dropped_attrs.update(
            {
                key: []
                for key in attrs if key not in result
            }
        )
    # Pass 2: gather the per-file values of every dropped attribute.
    # NOTE(review): attrs[key] raises KeyError if a dropped key is absent from
    # one file — this assumes all files share the same globals keys; confirm.
    for attrs in variable_attrs:
        dropped_attrs.update(
            {
                key: np.append(dropped_attrs[key], attrs[key])
                for key in dropped_attrs.keys()
            }
        )
    # Pass 3: deduplicate dropped attributes that vary identically. The first
    # of an equivalent group becomes a scan axis; later equivalent ones are
    # stored in `result` as an alias pointing at that axis name.
    scan_attrs = OrderedDict()
    scan_length = []
    for attrs_key in dropped_attrs.keys():
        flag = True  # stays True while no equivalent axis has been found
        for key in scan_attrs.keys():
            if equivalent(scan_attrs[key], dropped_attrs[attrs_key]):
                flag = False
                result.update({attrs_key: key})
                break
        if flag:
            scan_attrs.update({
                attrs_key: dropped_attrs[attrs_key]
            })
            scan_length = np.append(scan_length, len(dropped_attrs[attrs_key]))
    # Expose each scan axis' value array directly on the merged attrs.
    result.update(
        {
            key: value
            for key, value in scan_attrs.items()
        }
    )
    result.update(
        {
            "scanAxis": list(scan_attrs.keys()),
            "scanAxisLength": scan_length,
        }
    )
    # if result['scanAxis'] == []:
    #     result['scanAxis'] = ['runs',]
    return result
  72. def _read_shot_number_from_hdf5(x):
  73. filePath = x.encoding["source"]
  74. shotNum = filePath.split("_")[-1].split("_")[-1].split(".")[0]
  75. return x.assign(shotNum=shotNum)
  76. def _assign_scan_axis_partial(x, datesetOfGlobal, fullFilePath):
  77. scanAxis = datesetOfGlobal.scanAxis
  78. filePath = x.encoding["source"].replace("\\", "/")
  79. shotNum = np.where(fullFilePath==filePath)
  80. shotNum = np.squeeze(shotNum)
  81. # shotNum = filePath.split("_")[-1].split("_")[-1].split(".")[0]
  82. x = x.assign(shotNum=shotNum)
  83. x = x.expand_dims(list(scanAxis))
  84. return x.assign_coords(
  85. {
  86. key: np.atleast_1d(np.atleast_1d(datesetOfGlobal.attrs[key])[int(shotNum)])
  87. for key in scanAxis
  88. }
  89. )
def _update_globals_attrs(variable_attrs, context=None):
    # Placeholder: intended counterpart to _read_globals_attrs for updating
    # merged global attributes; not yet implemented.
    pass
def update_hdf5_file():
    # Placeholder: intended to write changes back to an HDF5 file; not yet
    # implemented.
    pass
  94. def read_hdf5_file(filePath, group=None, datesetOfGlobal=None, preprocess=None, join="outer", parallel=True, engine="h5netcdf", phony_dims="access", excludeAxis=[], maxFileNum=None, **kwargs):
  95. filePath = np.sort(np.atleast_1d(filePath))
  96. filePathAbs = []
  97. for i in range(len(filePath)):
  98. filePathAbs.append(os.path.abspath(filePath[i]).replace("\\", "/"))
  99. fullFilePath = []
  100. for i in range(len(filePathAbs)):
  101. fullFilePath.append(list(np.sort(glob.glob(filePathAbs[i]))))
  102. fullFilePath = np.array(fullFilePath).flatten()
  103. for i in range(len(fullFilePath)):
  104. fullFilePath[i] = fullFilePath[i].replace("\\", "/")
  105. if not maxFileNum is None:
  106. fullFilePath = fullFilePath[0:int(maxFileNum)]
  107. kwargs.update(
  108. {
  109. 'join': join,
  110. 'parallel': parallel,
  111. 'engine': engine,
  112. 'phony_dims': phony_dims,
  113. 'group': group
  114. }
  115. )
  116. if datesetOfGlobal is None:
  117. datesetOfGlobal = xr.open_mfdataset(
  118. fullFilePath,
  119. group="globals",
  120. concat_dim="fileNum",
  121. combine="nested",
  122. preprocess=_read_shot_number_from_hdf5,
  123. engine="h5netcdf",
  124. phony_dims="access",
  125. combine_attrs=_read_globals_attrs,
  126. parallel=True, )
  127. datesetOfGlobal.attrs['scanAxis'] = np.setdiff1d(datesetOfGlobal.attrs['scanAxis'], excludeAxis)
  128. _assgin_scan_axis = partial(_assign_scan_axis_partial, datesetOfGlobal=datesetOfGlobal, fullFilePath=fullFilePath)
  129. if preprocess is None:
  130. kwargs.update({'preprocess':_assgin_scan_axis})
  131. else:
  132. kwargs.update({'preprocess':preprocess})
  133. ds = xr.open_mfdataset(fullFilePath, **kwargs)
  134. newDimKey = np.append(['x', 'y', 'z'], [ chr(i) for i in range(97, 97+23)])
  135. oldDimKey = np.sort(
  136. [
  137. key
  138. for key in ds.dims
  139. if not key in datesetOfGlobal.scanAxis
  140. ]
  141. )
  142. renameDict = {
  143. oldDimKey[j]: newDimKey[j]
  144. for j in range(len(oldDimKey))
  145. }
  146. ds = ds.rename_dims(renameDict)
  147. ds.attrs = copy.deepcopy(datesetOfGlobal.attrs)
  148. return ds
  149. def _assign_scan_axis_partial_and_remove_everything(x, datesetOfGlobal, fullFilePath):
  150. scanAxis = datesetOfGlobal.scanAxis
  151. filePath = x.encoding["source"].replace("\\", "/")
  152. shotNum = np.where(fullFilePath==filePath)
  153. shotNum = np.squeeze(shotNum)
  154. runTime = _read_run_time_from_hdf5(x)
  155. x = xr.Dataset(data_vars={'runTine':runTime})
  156. x = x.expand_dims(list(scanAxis))
  157. return x.assign_coords(
  158. {
  159. key: np.atleast_1d(np.atleast_1d(datesetOfGlobal.attrs[key])[int(shotNum)])
  160. for key in scanAxis
  161. }
  162. )
  163. def _read_run_time_from_hdf5(x):
  164. runTime = datetime.strptime(x.attrs['run time'], '%Y%m%dT%H%M%S')
  165. return runTime
  166. def read_hdf5_run_time(filePath, group=None, datesetOfGlobal=None, preprocess=None, join="outer", parallel=True, engine="h5netcdf", phony_dims="access", excludeAxis=[], maxFileNum=None, **kwargs):
  167. filePath = np.sort(np.atleast_1d(filePath))
  168. filePathAbs = []
  169. for i in range(len(filePath)):
  170. filePathAbs.append(os.path.abspath(filePath[i]).replace("\\", "/"))
  171. fullFilePath = []
  172. for i in range(len(filePathAbs)):
  173. fullFilePath.append(list(np.sort(glob.glob(filePathAbs[i]))))
  174. fullFilePath = np.array(fullFilePath).flatten()
  175. for i in range(len(fullFilePath)):
  176. fullFilePath[i] = fullFilePath[i].replace("\\", "/")
  177. if not maxFileNum is None:
  178. fullFilePath = fullFilePath[0:int(maxFileNum)]
  179. kwargs.update(
  180. {
  181. 'join': join,
  182. 'parallel': parallel,
  183. 'engine': engine,
  184. 'phony_dims': phony_dims,
  185. 'group': group
  186. }
  187. )
  188. if datesetOfGlobal is None:
  189. datesetOfGlobal = xr.open_mfdataset(
  190. fullFilePath,
  191. group="globals",
  192. concat_dim="fileNum",
  193. combine="nested",
  194. preprocess=_read_shot_number_from_hdf5,
  195. engine="h5netcdf",
  196. phony_dims="access",
  197. combine_attrs=_read_globals_attrs,
  198. parallel=True, )
  199. datesetOfGlobal.attrs['scanAxis'] = np.setdiff1d(datesetOfGlobal.attrs['scanAxis'], excludeAxis)
  200. _assgin_scan_axis = partial(_assign_scan_axis_partial_and_remove_everything, datesetOfGlobal=datesetOfGlobal, fullFilePath=fullFilePath)
  201. if preprocess is None:
  202. kwargs.update({'preprocess':_assgin_scan_axis})
  203. else:
  204. kwargs.update({'preprocess':preprocess})
  205. ds = xr.open_mfdataset(fullFilePath, **kwargs)
  206. newDimKey = np.append(['x', 'y', 'z'], [ chr(i) for i in range(97, 97+23)])
  207. oldDimKey = np.sort(
  208. [
  209. key
  210. for key in ds.dims
  211. if not key in datesetOfGlobal.scanAxis
  212. ]
  213. )
  214. renameDict = {
  215. oldDimKey[j]: newDimKey[j]
  216. for j in range(len(oldDimKey))
  217. }
  218. ds = ds.rename_dims(renameDict)
  219. ds.attrs = copy.deepcopy(datesetOfGlobal.attrs)
  220. return ds
  221. def read_hdf5_global(filePath, preprocess=None, join="outer", combine="nested", parallel=True, engine="h5netcdf", phony_dims="access", excludeAxis=[], maxFileNum=None, **kwargs):
  222. filePath = np.sort(np.atleast_1d(filePath))
  223. filePathAbs = []
  224. for i in range(len(filePath)):
  225. filePathAbs.append(os.path.abspath(filePath[i]).replace("\\", "/"))
  226. fullFilePath = []
  227. for i in range(len(filePathAbs)):
  228. fullFilePath.append(list(np.sort(glob.glob(filePathAbs[i]))))
  229. fullFilePath = np.array(fullFilePath).flatten()
  230. for i in range(len(fullFilePath)):
  231. fullFilePath[i] = fullFilePath[i].replace("\\", "/")
  232. if not maxFileNum is None:
  233. fullFilePath = fullFilePath[0:int(maxFileNum)]
  234. kwargs.update(
  235. {
  236. 'join': join,
  237. 'parallel': parallel,
  238. 'engine': engine,
  239. 'phony_dims': phony_dims,
  240. 'group': "globals",
  241. 'preprocess': _read_shot_number_from_hdf5,
  242. 'combine_attrs': _read_globals_attrs,
  243. 'combine':combine,
  244. 'concat_dim': "fileNum",
  245. }
  246. )
  247. datesetOfGlobal = xr.open_mfdataset(fullFilePath, **kwargs)
  248. datesetOfGlobal.attrs['scanAxis'] = np.setdiff1d(datesetOfGlobal.attrs['scanAxis'], excludeAxis)
  249. return datesetOfGlobal