analyseScript/ToolFunction/HomeMadeXarrayFunction.py
2023-05-08 16:57:58 +02:00

274 lines
9.7 KiB
Python

from __future__ import annotations

from collections.abc import Hashable, Iterable
from typing import Any

from matplotlib.axes import Axes
from mpl_toolkits.mplot3d.art3d import Line3D
import numpy as np
from numpy.typing import ArrayLike
from xarray.core.dataarray import DataArray
from xarray.core.types import (
    AspectOptions,
    ScaleOptions,
)
from xarray.plot.dataarray_plot import _infer_line_data, _infer_xy_labels, _assert_valid_xy
from xarray.plot.facetgrid import _easy_facetgrid
from xarray.plot.facetgrid import FacetGrid
from xarray.plot.utils import (
    _LINEWIDTH_RANGE,
    _MARKERSIZE_RANGE,
    _ensure_plottable,
    _resolve_intervals_1dplot,
    _update_axes,
    get_axis,
    label_from_attrs,
)
def _infer_errorbar_data(
    darray: DataArray,
    xerrdarray: DataArray | None,
    yerrdarray: DataArray | None,
    x: Hashable | None,
    y: Hashable | None,
    hue: Hashable | None,
) -> tuple[
    DataArray,
    DataArray,
    DataArray | None,
    DataArray | None,
    DataArray | None,
    str,
]:
    """Work out what goes on each axis of an error-bar plot.

    Resolves ``x``/``y``/``hue`` for ``darray`` and applies the same selection
    or transposition to the optional error arrays, so that every returned
    error array lines up with the data plotted on its axis.

    Parameters
    ----------
    darray
        One- or two-dimensional data to plot.
    xerrdarray, yerrdarray
        Error arrays for the x and y direction, or None for no error bars in
        that direction.
        NOTE(review): presumably these share ``darray``'s dimensions and
        coordinates -- confirm against callers.
    x, y
        Name of the coordinate to put on the x or y axis. At most one of the
        two may be given.
    hue
        For 2-D input, the name that separates the individual series.

    Returns
    -------
    tuple
        ``(xplt, yplt, xerrplt, yerrplt, hueplt, huelabel)``. ``xerrplt`` and
        ``yerrplt`` are None when the matching error array was not supplied;
        ``hueplt`` is None and ``huelabel`` is ``""`` for 1-D input.

    Raises
    ------
    ValueError
        If both ``x`` and ``y`` are given, if none of ``x``, ``y`` and ``hue``
        is given for 2-D input, or if the chosen axis coordinate is
        multi-dimensional and ``hue`` is not a dimension of ``darray``.
    """
    ndims = len(darray.dims)

    if x is not None and y is not None:
        raise ValueError("Cannot specify both x and y kwargs for line plots.")

    if x is not None:
        _assert_valid_xy(darray, x, "x")

    if y is not None:
        _assert_valid_xy(darray, y, "y")

    # Bind both error results up front: the 2-D branches below only assign
    # them when the matching error array exists, and the return statement
    # needs the names on every path.
    xerrplt = None
    yerrplt = None

    if ndims == 1:
        huename = None
        hueplt = None
        huelabel = ""

        if x is not None:
            xplt = darray[x]
            if xerrdarray is not None:
                xerrplt = xerrdarray[x]
            yplt = darray
            if yerrdarray is not None:
                yerrplt = yerrdarray
        elif y is not None:
            xplt = darray
            if xerrdarray is not None:
                xerrplt = xerrdarray
            yplt = darray[y]
            if yerrdarray is not None:
                yerrplt = yerrdarray[y]
        else:  # Both x & y are None
            dim = darray.dims[0]
            xplt = darray[dim]
            yplt = darray
            if xerrdarray is not None:
                xerrplt = xerrdarray[dim]
            if yerrdarray is not None:
                yerrplt = yerrdarray
    else:
        if x is None and y is None and hue is None:
            raise ValueError("For 2D inputs, please specify either hue, x or y.")

        if y is None:
            if hue is not None:
                _assert_valid_xy(darray, hue, "hue")

            xname, huename = _infer_xy_labels(darray=darray, x=x, y=hue)
            xplt = darray[xname]
            if xerrdarray is not None:
                xerrplt = xerrdarray[xname]

            if xplt.ndim > 1:
                if huename not in darray.dims:
                    raise ValueError(
                        "For 2D inputs, hue must be a dimension"
                        " i.e. one of " + repr(darray.dims)
                    )
                # Put the hue dimension last so each column is one series.
                otherindex = 1 if darray.dims.index(huename) == 0 else 0
                otherdim = darray.dims[otherindex]
                yplt = darray.transpose(otherdim, huename, transpose_coords=False)
                if yerrdarray is not None:
                    yerrplt = yerrdarray.transpose(
                        otherdim, huename, transpose_coords=False
                    )
                xplt = xplt.transpose(otherdim, huename, transpose_coords=False)
                if xerrdarray is not None:
                    xerrplt = xerrplt.transpose(
                        otherdim, huename, transpose_coords=False
                    )
            else:
                (xdim,) = darray[xname].dims
                (huedim,) = darray[huename].dims
                yplt = darray.transpose(xdim, huedim)
                if yerrdarray is not None:
                    yerrplt = yerrdarray.transpose(xdim, huedim)
        else:
            yname, huename = _infer_xy_labels(darray=darray, x=y, y=hue)
            yplt = darray[yname]
            if yerrdarray is not None:
                yerrplt = yerrdarray[yname]

            if yplt.ndim > 1:
                if huename not in darray.dims:
                    raise ValueError(
                        "For 2D inputs, hue must be a dimension"
                        " i.e. one of " + repr(darray.dims)
                    )
                # Put the hue dimension last so each column is one series.
                otherindex = 1 if darray.dims.index(huename) == 0 else 0
                otherdim = darray.dims[otherindex]
                xplt = darray.transpose(otherdim, huename, transpose_coords=False)
                if xerrdarray is not None:
                    xerrplt = xerrdarray.transpose(
                        otherdim, huename, transpose_coords=False
                    )
                yplt = yplt.transpose(otherdim, huename, transpose_coords=False)
                if yerrdarray is not None:
                    yerrplt = yerrplt.transpose(
                        otherdim, huename, transpose_coords=False
                    )
            else:
                (ydim,) = darray[yname].dims
                (huedim,) = darray[huename].dims
                xplt = darray.transpose(ydim, huedim)
                if xerrdarray is not None:
                    xerrplt = xerrdarray.transpose(ydim, huedim)

        huelabel = label_from_attrs(darray[huename])
        hueplt = darray[huename]

    return xplt, yplt, xerrplt, yerrplt, hueplt, huelabel
def _nth_series(values, index, per_series):
    """Return column ``index`` of ``values`` when it holds one column per series.

    ``values`` is returned unchanged when it is None, when it is not
    two-dimensional, or when ``per_series`` is false (the axis it belongs to
    is shared by every series).
    """
    if per_series and values is not None and np.ndim(values) == 2:
        return values[:, index]
    return values


def errorbar(
    darray: DataArray,
    *args: Any,
    xerr: Hashable | DataArray | None = None,
    yerr: Hashable | DataArray | None = None,
    row: Hashable | None = None,
    col: Hashable | None = None,
    figsize: Iterable[float] | None = None,
    aspect: AspectOptions = None,
    size: float | None = None,
    ax: Axes | None = None,
    hue: Hashable | None = None,
    x: Hashable | None = None,
    y: Hashable | None = None,
    xincrease: bool | None = None,
    yincrease: bool | None = None,
    xscale: ScaleOptions = None,
    yscale: ScaleOptions = None,
    xticks: ArrayLike | None = None,
    yticks: ArrayLike | None = None,
    xlim: ArrayLike | None = None,
    ylim: ArrayLike | None = None,
    add_legend: bool = True,
    _labels: bool = True,
    **kwargs: Any,
) -> list[Line3D] | FacetGrid[DataArray]:
    """Plot a 1-D or 2-D DataArray with ``Axes.errorbar``.

    Parameters
    ----------
    darray
        Data to plot; must have one or two dimensions and be non-empty.
    *args
        Extra positional arguments forwarded to ``Axes.errorbar``.
    xerr, yerr
        Error sizes. DataArray values are aligned with the plotted data via
        ``_infer_errorbar_data``; anything else is handed to
        ``Axes.errorbar`` as given.
    row, col
        When either is set, plotting is delegated to ``_easy_facetgrid``.
        NOTE(review): DataArray-valued ``xerr``/``yerr`` are forwarded to the
        facet grid unchanged -- confirm whether they need per-facet slicing.
    figsize, aspect, size, ax
        Forwarded to ``get_axis`` to obtain the axes to draw on.
    hue, x, y
        Names selecting the series dimension and the x or y coordinate.
    xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim
        Forwarded to ``_update_axes`` after drawing.
    add_legend
        Add a legend (one entry per hue value) for 2-D input.
    _labels
        Set the axis labels and the title.
    **kwargs
        Forwarded to ``Axes.errorbar``. A ``fmt`` given as a list supplies
        one format string per series when several series are drawn.

    Returns
    -------
    The object returned by ``Axes.errorbar`` when x and y have the same rank,
    otherwise a tuple with one such object per series; the result of
    ``_easy_facetgrid`` when ``row`` or ``col`` is given.
    NOTE(review): the return annotation is kept from the original signature
    and does not describe these values -- confirm and narrow it.

    Raises
    ------
    TypeError
        If ``darray`` has no dimensions or no elements.
    ValueError
        If ``darray`` has more than two dimensions.
    """
    # Handle facetgrids first
    if row or col:
        # Snapshot the call arguments before any other local name is bound.
        allargs = locals().copy()
        allargs.update(allargs.pop("kwargs"))
        allargs.pop("darray")
        # Each facet is drawn by this function itself.
        return _easy_facetgrid(darray, errorbar, kind="line", **allargs)

    ndims = len(darray.dims)
    if ndims == 0 or darray.size == 0:
        # TypeError to be consistent with pandas
        raise TypeError("No numeric data to plot.")
    if ndims > 2:
        raise ValueError(
            "Line plots are for 1- or 2-dimensional DataArrays. "
            f"Passed DataArray has {ndims} dimensions"
        )

    # The allargs dict passed to _easy_facetgrid above contains args
    if not args:
        args = kwargs.pop("args", ())
    else:
        assert "args" not in kwargs

    ax = get_axis(figsize, size, aspect, ax)

    # Only DataArray errors need aligning with the data; scalars and plain
    # arrays are left for matplotlib to interpret.
    xerr_da = xerr if isinstance(xerr, DataArray) else None
    yerr_da = yerr if isinstance(yerr, DataArray) else None
    if xerr_da is not None or yerr_da is not None:
        xplt, yplt, xerr_aligned, yerr_aligned, hueplt, hue_label = (
            _infer_errorbar_data(darray, xerr_da, yerr_da, x, y, hue)
        )
        if xerr_da is not None:
            xerr = xerr_aligned
        if yerr_da is not None:
            yerr = yerr_aligned
    else:
        xplt, yplt, hueplt, hue_label = _infer_line_data(darray, x, y, hue)

    # Remove pd.Intervals if contained in xplt.values and/or yplt.values.
    xplt_val, yplt_val, x_suffix, y_suffix, kwargs = _resolve_intervals_1dplot(
        xplt.to_numpy(), yplt.to_numpy(), kwargs
    )
    xlabel = label_from_attrs(xplt, extra=x_suffix)
    ylabel = label_from_attrs(yplt, extra=y_suffix)

    _ensure_plottable(xplt_val, yplt_val)

    # A list-valued ``fmt`` carries one format string per series.
    fmt = kwargs["fmt"] if isinstance(kwargs.get("fmt"), list) else None

    if np.ndim(xplt_val) == np.ndim(yplt_val):
        primitive = ax.errorbar(
            xplt_val, yplt_val, *args, xerr=xerr, yerr=yerr, **kwargs
        )
    else:
        # One operand is 2-D with one column per series, the other is a 1-D
        # axis shared by all series; draw the columns one at a time.
        x_per_series = np.ndim(xplt_val) == 2
        y_per_series = np.ndim(yplt_val) == 2
        n_series = np.shape(yplt_val if y_per_series else xplt_val)[1]

        containers = []
        for i in range(n_series):
            series_kwargs = dict(kwargs)
            if fmt is not None:
                series_kwargs["fmt"] = fmt[i]
            containers.append(
                ax.errorbar(
                    _nth_series(xplt_val, i, x_per_series),
                    _nth_series(yplt_val, i, y_per_series),
                    *args,
                    xerr=_nth_series(xerr, i, x_per_series),
                    yerr=_nth_series(yerr, i, y_per_series),
                    **series_kwargs,
                )
            )
        primitive = tuple(containers)

    if _labels:
        if xlabel is not None:
            ax.set_xlabel(xlabel)

        if ylabel is not None:
            ax.set_ylabel(ylabel)

        ax.set_title(darray._title_for_slice())

    if darray.ndim == 2 and add_legend:
        assert hueplt is not None
        ax.legend(handles=primitive, labels=list(hueplt.to_numpy()), title=hue_label)

    # Rotate dates on xlabels
    # Do this without calling autofmt_xdate so that x-axes ticks
    # on other subplots (if any) are not deleted.
    # https://stackoverflow.com/questions/17430105/autofmt-xdate-deletes-x-axis-labels-of-all-subplots
    if np.issubdtype(xplt.dtype, np.datetime64):
        for xlabels in ax.get_xticklabels():
            xlabels.set_rotation(30)
            xlabels.set_horizontalalignment("right")

    _update_axes(ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim)

    return primitive
from xarray.plot.accessor import DataArrayPlotAccessor
# from xarray.plot.accessor import DatasetPlotAccessor
def dataarray_plot_errorbar(DataArrayPlotAccessor, *args, **kwargs) -> list[Line3D] | FacetGrid[DataArray]:
    """Call ``errorbar`` on the DataArray held by a plot accessor.

    The first parameter receives an object exposing a ``_da`` attribute (inside
    this function the name shadows the imported ``DataArrayPlotAccessor``
    class). All remaining positional and keyword arguments are passed to
    ``errorbar`` untouched, and its result is returned as is.

    NOTE(review): presumably meant to be attached to the accessor class as a
    method -- confirm where it is bound.
    """
    wrapped_array = DataArrayPlotAccessor._da
    return errorbar(wrapped_array, *args, **kwargs)