Jianshun Gao
1 year ago
4 changed files with 618 additions and 980 deletions
-
14Analyser/FitAnalyser.py
-
246ToolFunction/HomeMadeXarrayFunction.py
-
8ToolFunction/ToolFunction.py
-
1328test.ipynb
@ -0,0 +1,246 @@ |
|||||
|
from __future__ import annotations |
||||
|
|
||||
|
from xarray.plot.dataarray_plot import _infer_line_data, _infer_xy_labels, _assert_valid_xy |
||||
|
from xarray.plot.facetgrid import _easy_facetgrid |
||||
|
from xarray.plot.utils import ( |
||||
|
_LINEWIDTH_RANGE, |
||||
|
_MARKERSIZE_RANGE, |
||||
|
_ensure_plottable, |
||||
|
_resolve_intervals_1dplot, |
||||
|
_update_axes, |
||||
|
get_axis, |
||||
|
label_from_attrs, |
||||
|
) |
||||
|
|
||||
|
from matplotlib.axes import Axes |
||||
|
from mpl_toolkits.mplot3d.art3d import Line3D |
||||
|
import numpy as np |
||||
|
from numpy.typing import ArrayLike |
||||
|
|
||||
|
from xarray.core.dataarray import DataArray |
||||
|
from xarray.core.types import ( |
||||
|
AspectOptions, |
||||
|
ScaleOptions, |
||||
|
) |
||||
|
from xarray.plot.facetgrid import FacetGrid |
||||
|
|
||||
|
|
||||
|
def _infer_errorbar_data( |
||||
|
darray: DataArray, |
||||
|
xerrdarray: DataArray | None, |
||||
|
yerrdarray: DataArray | None, |
||||
|
x: Hashable | None, |
||||
|
y: Hashable | None, |
||||
|
hue: Hashable | None |
||||
|
) -> tuple[DataArray, DataArray, DataArray | None, str]: |
||||
|
ndims = len(darray.dims) |
||||
|
|
||||
|
if x is not None and y is not None: |
||||
|
raise ValueError("Cannot specify both x and y kwargs for line plots.") |
||||
|
|
||||
|
if x is not None: |
||||
|
_assert_valid_xy(darray, x, "x") |
||||
|
|
||||
|
if y is not None: |
||||
|
_assert_valid_xy(darray, y, "y") |
||||
|
|
||||
|
if ndims == 1: |
||||
|
huename = None |
||||
|
hueplt = None |
||||
|
huelabel = "" |
||||
|
xerrplt = None |
||||
|
yerrplt = None |
||||
|
|
||||
|
if x is not None: |
||||
|
xplt = darray[x] |
||||
|
if xerrdarray is not None: |
||||
|
xerrplt = xerrdarray[x] |
||||
|
yplt = darray |
||||
|
if yerrdarray is not None: |
||||
|
yerrplt = yerrdarray |
||||
|
|
||||
|
elif y is not None: |
||||
|
xplt = darray |
||||
|
if xerrdarray is not None: |
||||
|
xerrplt = xerrdarray |
||||
|
yplt = darray[y] |
||||
|
if yerrdarray is not None: |
||||
|
yerrplt = yerrdarray[y] |
||||
|
|
||||
|
else: # Both x & y are None |
||||
|
dim = darray.dims[0] |
||||
|
xplt = darray[dim] |
||||
|
yplt = darray |
||||
|
if xerrdarray is not None: |
||||
|
xerrplt = xerrdarray[dim] |
||||
|
if yerrdarray is not None: |
||||
|
yerrplt = yerrdarray |
||||
|
|
||||
|
else: |
||||
|
if x is None and y is None and hue is None: |
||||
|
raise ValueError("For 2D inputs, please specify either hue, x or y.") |
||||
|
|
||||
|
if y is None: |
||||
|
if hue is not None: |
||||
|
_assert_valid_xy(darray, hue, "hue") |
||||
|
xname, huename = _infer_xy_labels(darray=darray, x=x, y=hue) |
||||
|
xplt = darray[xname] |
||||
|
if xerrdarray is not None: |
||||
|
xerrplt = xerrdarray[xname] |
||||
|
if xplt.ndim > 1: |
||||
|
if huename in darray.dims: |
||||
|
otherindex = 1 if darray.dims.index(huename) == 0 else 0 |
||||
|
otherdim = darray.dims[otherindex] |
||||
|
yplt = darray.transpose(otherdim, huename, transpose_coords=False) |
||||
|
if yerrdarray is not None: |
||||
|
yerrplt = yerrdarray.transpose(otherdim, huename, transpose_coords=False) |
||||
|
xplt = xplt.transpose(otherdim, huename, transpose_coords=False) |
||||
|
if xerrdarray is not None: |
||||
|
xerrplt = xerrplt.transpose(otherdim, huename, transpose_coords=False) |
||||
|
else: |
||||
|
raise ValueError( |
||||
|
"For 2D inputs, hue must be a dimension" |
||||
|
" i.e. one of " + repr(darray.dims) |
||||
|
) |
||||
|
|
||||
|
else: |
||||
|
(xdim,) = darray[xname].dims |
||||
|
(huedim,) = darray[huename].dims |
||||
|
yplt = darray.transpose(xdim, huedim) |
||||
|
if yerrdarray is not None: |
||||
|
yerrplt = yerrdarray.transpose(xdim, huedim) |
||||
|
|
||||
|
else: |
||||
|
yname, huename = _infer_xy_labels(darray=darray, x=y, y=hue) |
||||
|
yplt = darray[yname] |
||||
|
if yerrdarray is not None: |
||||
|
yerrplt = yerrdarray[yname] |
||||
|
if yplt.ndim > 1: |
||||
|
if huename in darray.dims: |
||||
|
otherindex = 1 if darray.dims.index(huename) == 0 else 0 |
||||
|
otherdim = darray.dims[otherindex] |
||||
|
xplt = darray.transpose(otherdim, huename, transpose_coords=False) |
||||
|
if xerrdarray is not None: |
||||
|
xerrplt = xerrdarray.transpose(otherdim, huename, transpose_coords=False) |
||||
|
yplt = yplt.transpose(otherdim, huename, transpose_coords=False) |
||||
|
if yerrdarray is not None: |
||||
|
yerrplt = yerrplt.transpose(otherdim, huename, transpose_coords=False) |
||||
|
else: |
||||
|
raise ValueError( |
||||
|
"For 2D inputs, hue must be a dimension" |
||||
|
" i.e. one of " + repr(darray.dims) |
||||
|
) |
||||
|
|
||||
|
else: |
||||
|
(ydim,) = darray[yname].dims |
||||
|
(huedim,) = darray[huename].dims |
||||
|
xplt = darray.transpose(ydim, huedim) |
||||
|
if xerrdarray is not None: |
||||
|
xerrplt = xerrdarray.transpose(ydim, huedim) |
||||
|
|
||||
|
huelabel = label_from_attrs(darray[huename]) |
||||
|
hueplt = darray[huename] |
||||
|
|
||||
|
return xplt, yplt, xerrplt, yerrplt, hueplt, huelabel |
||||
|
|
||||
|
|
||||
|
def errorbar( |
||||
|
darray: DataArray, |
||||
|
*args: Any, |
||||
|
xerr: Hashable | DataArray | None = None, |
||||
|
yerr: Hashable | DataArray | None = None, |
||||
|
row: Hashable | None = None, |
||||
|
col: Hashable | None = None, |
||||
|
figsize: Iterable[float] | None = None, |
||||
|
aspect: AspectOptions = None, |
||||
|
size: float | None = None, |
||||
|
ax: Axes | None = None, |
||||
|
hue: Hashable | None = None, |
||||
|
x: Hashable | None = None, |
||||
|
y: Hashable | None = None, |
||||
|
xincrease: bool | None = None, |
||||
|
yincrease: bool | None = None, |
||||
|
xscale: ScaleOptions = None, |
||||
|
yscale: ScaleOptions = None, |
||||
|
xticks: ArrayLike | None = None, |
||||
|
yticks: ArrayLike | None = None, |
||||
|
xlim: ArrayLike | None = None, |
||||
|
ylim: ArrayLike | None = None, |
||||
|
add_legend: bool = True, |
||||
|
_labels: bool = True, |
||||
|
**kwargs: Any, |
||||
|
) -> list[Line3D] | FacetGrid[DataArray]: |
||||
|
# Handle facetgrids first |
||||
|
if row or col: |
||||
|
allargs = locals().copy() |
||||
|
allargs.update(allargs.pop("kwargs")) |
||||
|
allargs.pop("darray") |
||||
|
return _easy_facetgrid(darray, line, kind="line", **allargs) |
||||
|
|
||||
|
ndims = len(darray.dims) |
||||
|
if ndims == 0 or darray.size == 0: |
||||
|
# TypeError to be consistent with pandas |
||||
|
raise TypeError("No numeric data to plot.") |
||||
|
if ndims > 2: |
||||
|
raise ValueError( |
||||
|
"Line plots are for 1- or 2-dimensional DataArrays. " |
||||
|
"Passed DataArray has {ndims} " |
||||
|
"dimensions".format(ndims=ndims) |
||||
|
) |
||||
|
|
||||
|
# The allargs dict passed to _easy_facetgrid above contains args |
||||
|
if args == (): |
||||
|
args = kwargs.pop("args", ()) |
||||
|
else: |
||||
|
assert "args" not in kwargs |
||||
|
|
||||
|
ax = get_axis(figsize, size, aspect, ax) |
||||
|
|
||||
|
if isinstance(xerr, DataArray) or isinstance(yerr, DataArray): |
||||
|
xplt, yplt, xerr, yerr, hueplt, hue_label = _infer_errorbar_data(darray, xerr, yerr, x, y, hue) |
||||
|
else: |
||||
|
xplt, yplt, hueplt, hue_label = _infer_line_data(darray, x, y, hue) |
||||
|
|
||||
|
# Remove pd.Intervals if contained in xplt.values and/or yplt.values. |
||||
|
xplt_val, yplt_val, x_suffix, y_suffix, kwargs = _resolve_intervals_1dplot( |
||||
|
xplt.to_numpy(), yplt.to_numpy(), kwargs |
||||
|
) |
||||
|
xlabel = label_from_attrs(xplt, extra=x_suffix) |
||||
|
ylabel = label_from_attrs(yplt, extra=y_suffix) |
||||
|
|
||||
|
_ensure_plottable(xplt_val, yplt_val) |
||||
|
|
||||
|
primitive = ax.errorbar(xplt_val, yplt_val, *args, xerr=xerr, yerr=yerr, **kwargs) |
||||
|
|
||||
|
if _labels: |
||||
|
if xlabel is not None: |
||||
|
ax.set_xlabel(xlabel) |
||||
|
|
||||
|
if ylabel is not None: |
||||
|
ax.set_ylabel(ylabel) |
||||
|
|
||||
|
ax.set_title(darray._title_for_slice()) |
||||
|
|
||||
|
if darray.ndim == 2 and add_legend: |
||||
|
assert hueplt is not None |
||||
|
ax.legend(handles=primitive, labels=list(hueplt.to_numpy()), title=hue_label) |
||||
|
|
||||
|
# Rotate dates on xlabels |
||||
|
# Do this without calling autofmt_xdate so that x-axes ticks |
||||
|
# on other subplots (if any) are not deleted. |
||||
|
# https://stackoverflow.com/questions/17430105/autofmt-xdate-deletes-x-axis-labels-of-all-subplots |
||||
|
if np.issubdtype(xplt.dtype, np.datetime64): |
||||
|
for xlabels in ax.get_xticklabels(): |
||||
|
xlabels.set_rotation(30) |
||||
|
xlabels.set_horizontalalignment("right") |
||||
|
|
||||
|
_update_axes(ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim) |
||||
|
|
||||
|
return primitive |
||||
|
|
||||
|
|
||||
|
from xarray.plot.accessor import DataArrayPlotAccessor |
||||
|
# from xarray.plot.accessor import DatasetPlotAccessor |
||||
|
|
||||
|
def dataarray_plot_errorbar(DataArrayPlotAccessor, *args, **kwargs) -> list[Line3D] | FacetGrid[DataArray]: |
||||
|
return errorbar(DataArrayPlotAccessor._da, *args, **kwargs) |
1328
test.ipynb
File diff suppressed because one or more lines are too long
View File
File diff suppressed because one or more lines are too long
View File
Write
Preview
Loading…
Cancel
Save
Reference in new issue