Browse Source

not finished

joschka_dev
Jianshun Gao 1 year ago
parent
commit
0925cf63b6
  1. 14
      Analyser/FitAnalyser.py
  2. 246
      ToolFunction/HomeMadeXarrayFunction.py
  3. 8
      ToolFunction/ToolFunction.py
  4. 1328
      test.ipynb

14
Analyser/FitAnalyser.py

@ -293,13 +293,13 @@ class FitAnalyser():
x = dataArray['x'].to_numpy() x = dataArray['x'].to_numpy()
else: else:
if isinstance(x, str): if isinstance(x, str):
x = dataArray[x].to_numpy()
if input_core_dims is None: if input_core_dims is None:
kwargs.update( kwargs.update(
{ {
"input_core_dims": [[x]], "input_core_dims": [[x]],
} }
) )
x = dataArray[x].to_numpy()
if self.fitDim == 1: if self.fitDim == 1:
@ -327,8 +327,8 @@ class FitAnalyser():
) )
else: else:
if isinstance(y, str): if isinstance(y, str):
y = dataArray[y].to_numpy()
kwargs["input_core_dims"][0] = np.append(kwargs["input_core_dims"][0], y) kwargs["input_core_dims"][0] = np.append(kwargs["input_core_dims"][0], y)
y = dataArray[y].to_numpy()
elif input_core_dims is None: elif input_core_dims is None:
kwargs.update( kwargs.update(
{ {
@ -386,13 +386,13 @@ class FitAnalyser():
x = dataArray['x'].to_numpy() x = dataArray['x'].to_numpy()
else: else:
if isinstance(x, str): if isinstance(x, str):
x = dataArray[x].to_numpy()
if input_core_dims is None: if input_core_dims is None:
kwargs.update( kwargs.update(
{ {
"input_core_dims": [[x], []], "input_core_dims": [[x], []],
} }
) )
x = dataArray[x].to_numpy()
if isinstance(paramsArray, type(self.fitModel.make_params())): if isinstance(paramsArray, type(self.fitModel.make_params())):
@ -414,8 +414,8 @@ class FitAnalyser():
) )
else: else:
if isinstance(y, str): if isinstance(y, str):
y = dataArray[y].to_numpy()
kwargs["input_core_dims"][0] = np.append(kwargs["input_core_dims"][0], y) kwargs["input_core_dims"][0] = np.append(kwargs["input_core_dims"][0], y)
y = dataArray[y].to_numpy()
elif input_core_dims is None: elif input_core_dims is None:
kwargs.update( kwargs.update(
{ {
@ -492,8 +492,7 @@ class FitAnalyser():
if output_core_dims is None: if output_core_dims is None:
kwargs.update( kwargs.update(
{ {
"output_core_dims": [[prefix+'x']],
"output_dtypes": float,
"output_core_dims": prefix+'x',
} }
) )
output_core_dims = [prefix+'x'] output_core_dims = [prefix+'x']
@ -527,6 +526,7 @@ class FitAnalyser():
output_core_dims[0]: np.size(x), output_core_dims[0]: np.size(x),
output_core_dims[1]: np.size(y), output_core_dims[1]: np.size(y),
}, },
# 'output_dtypes': float,
# 'output_dtypes': { # 'output_dtypes': {
# output_core_dims[0]: float, # output_core_dims[0]: float,
# output_core_dims[1]: float, # output_core_dims[1]: float,
@ -535,6 +535,8 @@ class FitAnalyser():
} }
) )
# del kwargs['output_dtypes']
_x, _y = np.meshgrid(x, y) _x, _y = np.meshgrid(x, y)
_x = _x.flatten() _x = _x.flatten()
_y = _y.flatten() _y = _y.flatten()

246
ToolFunction/HomeMadeXarrayFunction.py

@ -0,0 +1,246 @@
from __future__ import annotations
from xarray.plot.dataarray_plot import _infer_line_data, _infer_xy_labels, _assert_valid_xy
from xarray.plot.facetgrid import _easy_facetgrid
from xarray.plot.utils import (
_LINEWIDTH_RANGE,
_MARKERSIZE_RANGE,
_ensure_plottable,
_resolve_intervals_1dplot,
_update_axes,
get_axis,
label_from_attrs,
)
from matplotlib.axes import Axes
from mpl_toolkits.mplot3d.art3d import Line3D
import numpy as np
from numpy.typing import ArrayLike
from xarray.core.dataarray import DataArray
from xarray.core.types import (
AspectOptions,
ScaleOptions,
)
from xarray.plot.facetgrid import FacetGrid
def _infer_errorbar_data(
darray: DataArray,
xerrdarray: DataArray | None,
yerrdarray: DataArray | None,
x: Hashable | None,
y: Hashable | None,
hue: Hashable | None
) -> tuple[DataArray, DataArray, DataArray | None, str]:
ndims = len(darray.dims)
if x is not None and y is not None:
raise ValueError("Cannot specify both x and y kwargs for line plots.")
if x is not None:
_assert_valid_xy(darray, x, "x")
if y is not None:
_assert_valid_xy(darray, y, "y")
if ndims == 1:
huename = None
hueplt = None
huelabel = ""
xerrplt = None
yerrplt = None
if x is not None:
xplt = darray[x]
if xerrdarray is not None:
xerrplt = xerrdarray[x]
yplt = darray
if yerrdarray is not None:
yerrplt = yerrdarray
elif y is not None:
xplt = darray
if xerrdarray is not None:
xerrplt = xerrdarray
yplt = darray[y]
if yerrdarray is not None:
yerrplt = yerrdarray[y]
else: # Both x & y are None
dim = darray.dims[0]
xplt = darray[dim]
yplt = darray
if xerrdarray is not None:
xerrplt = xerrdarray[dim]
if yerrdarray is not None:
yerrplt = yerrdarray
else:
if x is None and y is None and hue is None:
raise ValueError("For 2D inputs, please specify either hue, x or y.")
if y is None:
if hue is not None:
_assert_valid_xy(darray, hue, "hue")
xname, huename = _infer_xy_labels(darray=darray, x=x, y=hue)
xplt = darray[xname]
if xerrdarray is not None:
xerrplt = xerrdarray[xname]
if xplt.ndim > 1:
if huename in darray.dims:
otherindex = 1 if darray.dims.index(huename) == 0 else 0
otherdim = darray.dims[otherindex]
yplt = darray.transpose(otherdim, huename, transpose_coords=False)
if yerrdarray is not None:
yerrplt = yerrdarray.transpose(otherdim, huename, transpose_coords=False)
xplt = xplt.transpose(otherdim, huename, transpose_coords=False)
if xerrdarray is not None:
xerrplt = xerrplt.transpose(otherdim, huename, transpose_coords=False)
else:
raise ValueError(
"For 2D inputs, hue must be a dimension"
" i.e. one of " + repr(darray.dims)
)
else:
(xdim,) = darray[xname].dims
(huedim,) = darray[huename].dims
yplt = darray.transpose(xdim, huedim)
if yerrdarray is not None:
yerrplt = yerrdarray.transpose(xdim, huedim)
else:
yname, huename = _infer_xy_labels(darray=darray, x=y, y=hue)
yplt = darray[yname]
if yerrdarray is not None:
yerrplt = yerrdarray[yname]
if yplt.ndim > 1:
if huename in darray.dims:
otherindex = 1 if darray.dims.index(huename) == 0 else 0
otherdim = darray.dims[otherindex]
xplt = darray.transpose(otherdim, huename, transpose_coords=False)
if xerrdarray is not None:
xerrplt = xerrdarray.transpose(otherdim, huename, transpose_coords=False)
yplt = yplt.transpose(otherdim, huename, transpose_coords=False)
if yerrdarray is not None:
yerrplt = yerrplt.transpose(otherdim, huename, transpose_coords=False)
else:
raise ValueError(
"For 2D inputs, hue must be a dimension"
" i.e. one of " + repr(darray.dims)
)
else:
(ydim,) = darray[yname].dims
(huedim,) = darray[huename].dims
xplt = darray.transpose(ydim, huedim)
if xerrdarray is not None:
xerrplt = xerrdarray.transpose(ydim, huedim)
huelabel = label_from_attrs(darray[huename])
hueplt = darray[huename]
return xplt, yplt, xerrplt, yerrplt, hueplt, huelabel
def errorbar(
darray: DataArray,
*args: Any,
xerr: Hashable | DataArray | None = None,
yerr: Hashable | DataArray | None = None,
row: Hashable | None = None,
col: Hashable | None = None,
figsize: Iterable[float] | None = None,
aspect: AspectOptions = None,
size: float | None = None,
ax: Axes | None = None,
hue: Hashable | None = None,
x: Hashable | None = None,
y: Hashable | None = None,
xincrease: bool | None = None,
yincrease: bool | None = None,
xscale: ScaleOptions = None,
yscale: ScaleOptions = None,
xticks: ArrayLike | None = None,
yticks: ArrayLike | None = None,
xlim: ArrayLike | None = None,
ylim: ArrayLike | None = None,
add_legend: bool = True,
_labels: bool = True,
**kwargs: Any,
) -> list[Line3D] | FacetGrid[DataArray]:
# Handle facetgrids first
if row or col:
allargs = locals().copy()
allargs.update(allargs.pop("kwargs"))
allargs.pop("darray")
return _easy_facetgrid(darray, line, kind="line", **allargs)
ndims = len(darray.dims)
if ndims == 0 or darray.size == 0:
# TypeError to be consistent with pandas
raise TypeError("No numeric data to plot.")
if ndims > 2:
raise ValueError(
"Line plots are for 1- or 2-dimensional DataArrays. "
"Passed DataArray has {ndims} "
"dimensions".format(ndims=ndims)
)
# The allargs dict passed to _easy_facetgrid above contains args
if args == ():
args = kwargs.pop("args", ())
else:
assert "args" not in kwargs
ax = get_axis(figsize, size, aspect, ax)
if isinstance(xerr, DataArray) or isinstance(yerr, DataArray):
xplt, yplt, xerr, yerr, hueplt, hue_label = _infer_errorbar_data(darray, xerr, yerr, x, y, hue)
else:
xplt, yplt, hueplt, hue_label = _infer_line_data(darray, x, y, hue)
# Remove pd.Intervals if contained in xplt.values and/or yplt.values.
xplt_val, yplt_val, x_suffix, y_suffix, kwargs = _resolve_intervals_1dplot(
xplt.to_numpy(), yplt.to_numpy(), kwargs
)
xlabel = label_from_attrs(xplt, extra=x_suffix)
ylabel = label_from_attrs(yplt, extra=y_suffix)
_ensure_plottable(xplt_val, yplt_val)
primitive = ax.errorbar(xplt_val, yplt_val, *args, xerr=xerr, yerr=yerr, **kwargs)
if _labels:
if xlabel is not None:
ax.set_xlabel(xlabel)
if ylabel is not None:
ax.set_ylabel(ylabel)
ax.set_title(darray._title_for_slice())
if darray.ndim == 2 and add_legend:
assert hueplt is not None
ax.legend(handles=primitive, labels=list(hueplt.to_numpy()), title=hue_label)
# Rotate dates on xlabels
# Do this without calling autofmt_xdate so that x-axes ticks
# on other subplots (if any) are not deleted.
# https://stackoverflow.com/questions/17430105/autofmt-xdate-deletes-x-axis-labels-of-all-subplots
if np.issubdtype(xplt.dtype, np.datetime64):
for xlabels in ax.get_xticklabels():
xlabels.set_rotation(30)
xlabels.set_horizontalalignment("right")
_update_axes(ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim)
return primitive
from xarray.plot.accessor import DataArrayPlotAccessor
# from xarray.plot.accessor import DatasetPlotAccessor
def dataarray_plot_errorbar(DataArrayPlotAccessor, *args, **kwargs) -> list[Line3D] | FacetGrid[DataArray]:
return errorbar(DataArrayPlotAccessor._da, *args, **kwargs)

8
ToolFunction/ToolFunction.py

@ -14,7 +14,7 @@ def remove_bad_shots(dataArray, **kwargs):
def auto_rechunk(dataSet): def auto_rechunk(dataSet):
kwargs = { kwargs = {
key: "auto" key: "auto"
for key in dataSet.dims.keys()
for key in dataSet.dims
} }
return dataSet.chunk(**kwargs) return dataSet.chunk(**kwargs)
@ -30,3 +30,9 @@ def get_h5_file_path(folderpath, maxFileNum=None, filename='*.h5',):
def get_date(): def get_date():
today = date.today() today = date.today()
return today.strftime("%y/%m/%d") return today.strftime("%y/%m/%d")
def resolve_fit_result(fitResult):
return

1328
test.ipynb
File diff suppressed because one or more lines are too long
View File

Loading…
Cancel
Save