import copy
import glob
import os
from collections import OrderedDict
from functools import partial

import numpy as np
import xarray as xr


def _read_globals_attrs(variable_attrs, context=None):
    """Merge the 'globals' attributes of all files.

    Attributes that are identical in every file are kept as ordinary
    attributes. Attributes whose values differ between files are treated as
    scan axes: their per-file values are collected into arrays, and their
    names are listed under 'scanAxis' (with the number of values per axis
    under 'scanAxisLength').
    """
    if not variable_attrs:
        # no attributes to merge
        return None

    from xarray.core.utils import equivalent

    result = {}
    dropped_attrs = OrderedDict()
    for attrs in variable_attrs:
        # Keep only the attributes that are equivalent in every file seen so
        # far; anything that differs is moved to dropped_attrs.
        result.update(
            {
                key: value
                for key, value in attrs.items()
                if key not in result and key not in dropped_attrs.keys()
            }
        )
        result = {
            key: value
            for key, value in result.items()
            if key not in attrs or equivalent(attrs[key], value)
        }
        dropped_attrs.update(
            {
                key: []
                for key in attrs if key not in result
            }
        )

    # Collect the per-file values of every dropped (i.e. varying) attribute.
    for attrs in variable_attrs:
        dropped_attrs.update(
            {
                key: np.append(dropped_attrs[key], attrs[key])
                for key in dropped_attrs.keys()
            }
        )

    # Deduplicate scan axes whose value arrays are identical: the duplicate
    # only stores the name of the axis it coincides with.
    scan_attrs = OrderedDict()
    scan_length = []
    for attrs_key in dropped_attrs.keys():
        flag = True
        for key in scan_attrs.keys():
            if equivalent(scan_attrs[key], dropped_attrs[attrs_key]):
                flag = False
                result.update({attrs_key: key})
                break
        if flag:
            scan_attrs.update({
                attrs_key: dropped_attrs[attrs_key]
            })
            scan_length = np.append(scan_length, len(dropped_attrs[attrs_key]))

    result.update(
        {
            key: value
            for key, value in scan_attrs.items()
        }
    )
    result.update(
        {
            "scanAxis": list(scan_attrs.keys()),
            "scanAxisLength": scan_length,
        }
    )
    # if result['scanAxis'] == []:
    #     result['scanAxis'] = ['runs', ]
    return result
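
# Example (hypothetical attribute names): for two files whose "globals"
# attributes are {"t_MOT": 1.0, "power": 5.0} and {"t_MOT": 2.0, "power": 5.0},
# _read_globals_attrs returns
#     {"power": 5.0, "t_MOT": array([1., 2.]),
#      "scanAxis": ["t_MOT"], "scanAxisLength": array([2.])}
# i.e. "power" is kept as a shared attribute, while "t_MOT" becomes a scan
# axis with its per-file values collected into an array.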


def _read_shot_number_from_hdf5(x):
    """Attach the shot number, parsed from the file name, to the dataset."""
    filePath = x.encoding["source"]
    shotNum = filePath.split("_")[-1].split(".")[0]
    return x.assign(shotNum=shotNum)


def _assign_scan_axis_partial(x, datesetOfGlobal, fullFilePath):
    """Expand a single-shot dataset along the scan axes.

    Meant to be bound with functools.partial and used as the `preprocess`
    callback of xr.open_mfdataset: the shot number is taken from the file's
    position in `fullFilePath`, and one coordinate per scan axis is assigned
    from the merged global attributes in `datesetOfGlobal`.
    """
    scanAxis = datesetOfGlobal.scanAxis
    filePath = x.encoding["source"].replace("\\", "/")
    shotNum = np.where(fullFilePath == filePath)
    shotNum = np.squeeze(shotNum)
    # shotNum = filePath.split("_")[-1].split("_")[-1].split(".")[0]
    x = x.assign(shotNum=shotNum)
    x = x.expand_dims(list(scanAxis))
    return x.assign_coords(
        {
            key: np.atleast_1d(np.atleast_1d(datesetOfGlobal.attrs[key])[int(shotNum)])
            for key in scanAxis
        }
    )


def _update_globals_attrs(variable_attrs, context=None):
    pass


def update_hdf5_file():
    pass


def read_hdf5_file(filePath, group=None, datesetOfGlobal=None, preprocess=None,
                   join="outer", parallel=True, engine="h5netcdf",
                   phony_dims="access", excludeAxis=[], maxFileNum=None,
                   **kwargs):
    """Read a set of HDF5 shot files into a single xarray.Dataset.

    The 'globals' group of every file is merged first to determine the scan
    axes; the requested group of each file is then opened with
    xr.open_mfdataset, expanded along those axes, and the remaining (phony)
    dimensions are renamed to 'x', 'y', 'z', 'a', 'b', ...
    """
    # Resolve the given path(s) or glob pattern(s) into a sorted, flat list
    # of absolute file paths with forward slashes.
    filePath = np.sort(np.atleast_1d(filePath))
    filePathAbs = []
    for i in range(len(filePath)):
        filePathAbs.append(os.path.abspath(filePath[i]).replace("\\", "/"))
    fullFilePath = []
    for i in range(len(filePathAbs)):
        fullFilePath.append(list(np.sort(glob.glob(filePathAbs[i]))))
    fullFilePath = np.array(fullFilePath).flatten()
    for i in range(len(fullFilePath)):
        fullFilePath[i] = fullFilePath[i].replace("\\", "/")

    if maxFileNum is not None:
        fullFilePath = fullFilePath[0:int(maxFileNum)]

    kwargs.update(
        {
            'join': join,
            'parallel': parallel,
            'engine': engine,
            'phony_dims': phony_dims,
            'group': group,
        }
    )

    # Merge the 'globals' group of every file to determine the scan axes.
    if datesetOfGlobal is None:
        datesetOfGlobal = xr.open_mfdataset(
            fullFilePath,
            group="globals",
            concat_dim="fileNum",
            combine="nested",
            preprocess=_read_shot_number_from_hdf5,
            engine="h5netcdf",
            phony_dims="access",
            combine_attrs=_read_globals_attrs,
            parallel=True,
        )
    datesetOfGlobal.attrs['scanAxis'] = np.setdiff1d(datesetOfGlobal.attrs['scanAxis'], excludeAxis)

    _assign_scan_axis = partial(_assign_scan_axis_partial, datesetOfGlobal=datesetOfGlobal, fullFilePath=fullFilePath)

    if preprocess is None:
        kwargs.update({'preprocess': _assign_scan_axis})
    else:
        kwargs.update({'preprocess': preprocess})

    ds = xr.open_mfdataset(fullFilePath, **kwargs)

    # Rename the automatically generated dimensions to short names:
    # 'x', 'y', 'z', then 'a', 'b', 'c', ...
    newDimKey = np.append(['x', 'y', 'z'], [chr(i) for i in range(97, 97 + 23)])
    oldDimKey = np.sort(
        [
            key
            for key in ds.dims
            if key not in datesetOfGlobal.scanAxis
        ]
    )
    renameDict = {
        oldDimKey[j]: newDimKey[j]
        for j in range(len(oldDimKey))
    }
    ds = ds.rename_dims(renameDict)
    ds.attrs = copy.deepcopy(datesetOfGlobal.attrs)
    return ds
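

# A minimal usage sketch. The glob pattern, group name, and excluded axis
# below are hypothetical placeholders; point them at HDF5 shot files that
# contain a "globals" group. parallel=True relies on dask, and the default
# engine requires h5netcdf.
if __name__ == "__main__":
    dataSet = read_hdf5_file(
        "./data/*.h5",            # hypothetical shot files
        group="images",           # hypothetical HDF5 group holding the data
        excludeAxis=["runNum"],   # hypothetical global to ignore as a scan axis
        maxFileNum=10,
    )
    print(dataSet.attrs["scanAxis"])
    print(dataSet.dims)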