274 lines
9.7 KiB
Python
274 lines
9.7 KiB
Python
from __future__ import annotations
|
|
|
|
from xarray.plot.dataarray_plot import _infer_line_data, _infer_xy_labels, _assert_valid_xy
|
|
from xarray.plot.facetgrid import _easy_facetgrid
|
|
from xarray.plot.utils import (
|
|
_LINEWIDTH_RANGE,
|
|
_MARKERSIZE_RANGE,
|
|
_ensure_plottable,
|
|
_resolve_intervals_1dplot,
|
|
_update_axes,
|
|
get_axis,
|
|
label_from_attrs,
|
|
)
|
|
|
|
from matplotlib.axes import Axes
|
|
from mpl_toolkits.mplot3d.art3d import Line3D
|
|
import numpy as np
|
|
from numpy.typing import ArrayLike
|
|
|
|
from xarray.core.dataarray import DataArray
|
|
from xarray.core.types import (
|
|
AspectOptions,
|
|
ScaleOptions,
|
|
)
|
|
from xarray.plot.facetgrid import FacetGrid
|
|
|
|
|
|
def _infer_errorbar_data(
    darray: DataArray,
    xerrdarray: DataArray | None,
    yerrdarray: DataArray | None,
    x: Hashable | None,
    y: Hashable | None,
    hue: Hashable | None,
) -> tuple[
    DataArray,
    DataArray,
    DataArray | None,
    DataArray | None,
    DataArray | None,
    str,
]:
    """Select plot data, error data and hue information for an errorbar plot.

    Decides which of ``darray`` / its coordinates go on the x and y axes and
    selects (and, for 2D input, transposes) the supplied error arrays in the
    same way so they line up with the plotted values.

    Parameters
    ----------
    darray : DataArray
        1- or 2-dimensional data to plot.
    xerrdarray, yerrdarray : DataArray or None
        Error arrays for the x / y axis, or None when not supplied.
    x, y : Hashable or None
        Coordinate name to use for the x or y axis. At most one may be given.
    hue : Hashable or None
        Dimension used to split 2D data into separate series.

    Returns
    -------
    tuple
        ``(xplt, yplt, xerrplt, yerrplt, hueplt, huelabel)``. ``xerrplt`` /
        ``yerrplt`` are None when the corresponding error array was not
        given; ``hueplt`` is None and ``huelabel`` is ``""`` for 1D input.

    Raises
    ------
    ValueError
        If both ``x`` and ``y`` are given, if 2D input has none of
        ``hue``/``x``/``y``, or if ``hue`` is not a dimension when the chosen
        coordinate is multidimensional.

    Notes
    -----
    For the axis that shows a coordinate (e.g. ``x`` for 1D data), the error
    is read from the *coordinate* of the error array with the same name
    (e.g. ``xerrdarray[x]``), not from its data values.
    NOTE(review): confirm this is the intended meaning of a DataArray error
    on the coordinate axis.
    """
    ndims = len(darray.dims)

    if x is not None and y is not None:
        raise ValueError("Cannot specify both x and y kwargs for line plots.")
    if x is not None:
        _assert_valid_xy(darray, x, "x")
    if y is not None:
        _assert_valid_xy(darray, y, "y")

    # Initialise up front. The 2D branches below only assign the error array
    # that was actually supplied, so leaving these unset would raise
    # UnboundLocalError at the return when only xerr or only yerr is given.
    xerrplt: DataArray | None = None
    yerrplt: DataArray | None = None

    if ndims == 1:
        if x is not None:
            xplt = darray[x]
            yplt = darray
            if xerrdarray is not None:
                xerrplt = xerrdarray[x]
            if yerrdarray is not None:
                yerrplt = yerrdarray
        elif y is not None:
            xplt = darray
            yplt = darray[y]
            if xerrdarray is not None:
                xerrplt = xerrdarray
            if yerrdarray is not None:
                yerrplt = yerrdarray[y]
        else:  # Both x & y are None: plot the data against its only dimension
            dim = darray.dims[0]
            xplt = darray[dim]
            yplt = darray
            if xerrdarray is not None:
                xerrplt = xerrdarray[dim]
            if yerrdarray is not None:
                yerrplt = yerrdarray
        return xplt, yplt, xerrplt, yerrplt, None, ""

    if x is None and y is None and hue is None:
        raise ValueError("For 2D inputs, please specify either hue, x or y.")

    if y is None:
        if hue is not None:
            _assert_valid_xy(darray, hue, "hue")
        xname, huename = _infer_xy_labels(darray=darray, x=x, y=hue)
        xplt = darray[xname]
        if xerrdarray is not None:
            xerrplt = xerrdarray[xname]

        if xplt.ndim > 1:
            # Multidimensional x coordinate: bring everything into
            # (other dim, hue dim) order so columns correspond to hue values.
            if huename not in darray.dims:
                raise ValueError(
                    "For 2D inputs, hue must be a dimension"
                    " i.e. one of " + repr(darray.dims)
                )
            otherindex = 1 if darray.dims.index(huename) == 0 else 0
            order = (darray.dims[otherindex], huename)
            yplt = darray.transpose(*order, transpose_coords=False)
            xplt = xplt.transpose(*order, transpose_coords=False)
            if yerrdarray is not None:
                yerrplt = yerrdarray.transpose(*order, transpose_coords=False)
            if xerrplt is not None:
                xerrplt = xerrplt.transpose(*order, transpose_coords=False)
        else:
            (xdim,) = darray[xname].dims
            (huedim,) = darray[huename].dims
            yplt = darray.transpose(xdim, huedim)
            if yerrdarray is not None:
                yerrplt = yerrdarray.transpose(xdim, huedim)
    else:
        yname, huename = _infer_xy_labels(darray=darray, x=y, y=hue)
        yplt = darray[yname]
        if yerrdarray is not None:
            yerrplt = yerrdarray[yname]

        if yplt.ndim > 1:
            # Multidimensional y coordinate: same reordering as for x above.
            if huename not in darray.dims:
                raise ValueError(
                    "For 2D inputs, hue must be a dimension"
                    " i.e. one of " + repr(darray.dims)
                )
            otherindex = 1 if darray.dims.index(huename) == 0 else 0
            order = (darray.dims[otherindex], huename)
            xplt = darray.transpose(*order, transpose_coords=False)
            yplt = yplt.transpose(*order, transpose_coords=False)
            if xerrdarray is not None:
                xerrplt = xerrdarray.transpose(*order, transpose_coords=False)
            if yerrplt is not None:
                yerrplt = yerrplt.transpose(*order, transpose_coords=False)
        else:
            (ydim,) = darray[yname].dims
            (huedim,) = darray[huename].dims
            xplt = darray.transpose(ydim, huedim)
            if xerrdarray is not None:
                xerrplt = xerrdarray.transpose(ydim, huedim)

    huelabel = label_from_attrs(darray[huename])
    hueplt = darray[huename]

    return xplt, yplt, xerrplt, yerrplt, hueplt, huelabel
|
|
|
|
|
|
def errorbar(
    darray: DataArray,
    *args: Any,
    xerr: Hashable | DataArray | None = None,
    yerr: Hashable | DataArray | None = None,
    row: Hashable | None = None,
    col: Hashable | None = None,
    figsize: Iterable[float] | None = None,
    aspect: AspectOptions = None,
    size: float | None = None,
    ax: Axes | None = None,
    hue: Hashable | None = None,
    x: Hashable | None = None,
    y: Hashable | None = None,
    xincrease: bool | None = None,
    yincrease: bool | None = None,
    xscale: ScaleOptions = None,
    yscale: ScaleOptions = None,
    xticks: ArrayLike | None = None,
    yticks: ArrayLike | None = None,
    xlim: ArrayLike | None = None,
    ylim: ArrayLike | None = None,
    add_legend: bool = True,
    _labels: bool = True,
    **kwargs: Any,
) -> list[Line3D] | FacetGrid[DataArray]:
    """Plot a 1D or 2D DataArray with error bars using ``Axes.errorbar``.

    Parameters
    ----------
    darray : DataArray
        1- or 2-dimensional data to plot. 2D data is drawn as one errorbar
        series per value of the hue dimension.
    *args
        At most one positional argument, used as the matplotlib format string
        (``fmt``).
    xerr, yerr : DataArray or scalar / array-like, optional
        Errors along x / y. DataArray errors are selected and transposed in
        the same way as ``darray``. Other values are passed to matplotlib as
        given.
    row, col : Hashable, optional
        Facet the plot into a grid of subplots along these dimensions.
    figsize, aspect, size, ax
        Figure / axes selection, passed to ``get_axis``.
    hue, x, y : Hashable, optional
        Dimension / coordinate names selecting what goes on each axis.
    xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim
        Axis settings applied through ``_update_axes``.
    add_legend : bool
        Add a legend of hue values for 2D data.
    **kwargs
        Forwarded to ``Axes.errorbar``. For 2D data ``fmt`` may be a list or
        tuple with one format string per series.

    Returns
    -------
    The ``ErrorbarContainer`` from ``Axes.errorbar`` for 1D data, a tuple of
    them (one per series) for 2D data, or a FacetGrid when ``row``/``col``
    is given.

    Raises
    ------
    TypeError
        If there is no data to plot, or more than one positional argument
        (or a positional format string together with ``fmt=``) is given.
    ValueError
        If ``darray`` has more than two dimensions.
    """
    # Handle facetgrids first. This must run before any other local variable
    # is created so that locals() holds only the call arguments.
    if row or col:
        allargs = locals().copy()
        allargs.update(allargs.pop("kwargs"))
        allargs.pop("darray")
        # Each facet is drawn by calling errorbar itself on a subset.
        # NOTE(review): DataArray xerr/yerr are passed through whole, not
        # subset per facet; confirm facets with DataArray errors line up.
        return _easy_facetgrid(darray, errorbar, kind="line", **allargs)

    ndims = len(darray.dims)
    if ndims == 0 or darray.size == 0:
        # TypeError to be consistent with pandas
        raise TypeError("No numeric data to plot.")
    if ndims > 2:
        raise ValueError(
            "Line plots are for 1- or 2-dimensional DataArrays. "
            "Passed DataArray has {ndims} "
            "dimensions".format(ndims=ndims)
        )

    # When called per facet, positional args arrive as kwargs["args"]
    # (from the allargs dict built above).
    if args == ():
        args = kwargs.pop("args", ())
    else:
        assert "args" not in kwargs

    # Axes.errorbar's third positional parameter is ``yerr``. Since xerr and
    # yerr are always passed by keyword below, forwarding *args positionally
    # would always fail with "multiple values for argument 'yerr'". Interpret
    # a single positional argument as the format string instead.
    if args:
        if len(args) > 1 or "fmt" in kwargs:
            raise TypeError(
                "errorbar accepts at most one positional argument (the "
                "format string), and not together with the 'fmt' keyword."
            )
        kwargs["fmt"] = args[0]

    ax = get_axis(figsize, size, aspect, ax)

    if isinstance(xerr, DataArray) or isinstance(yerr, DataArray):
        # Only DataArray errors need aligning with the plotted data. Scalar or
        # array-like errors are kept exactly as the caller gave them.
        xerr_da = xerr if isinstance(xerr, DataArray) else None
        yerr_da = yerr if isinstance(yerr, DataArray) else None
        xplt, yplt, xerr_aligned, yerr_aligned, hueplt, hue_label = (
            _infer_errorbar_data(darray, xerr_da, yerr_da, x, y, hue)
        )
        if xerr_aligned is not None:
            xerr = xerr_aligned.to_numpy()
        if yerr_aligned is not None:
            yerr = yerr_aligned.to_numpy()
    else:
        xplt, yplt, hueplt, hue_label = _infer_line_data(darray, x, y, hue)

    # Remove pd.Intervals if contained in xplt.values and/or yplt.values.
    xplt_val, yplt_val, x_suffix, y_suffix, kwargs = _resolve_intervals_1dplot(
        xplt.to_numpy(), yplt.to_numpy(), kwargs
    )
    xlabel = label_from_attrs(xplt, extra=x_suffix)
    ylabel = label_from_attrs(yplt, extra=y_suffix)

    _ensure_plottable(xplt_val, yplt_val)

    if np.ndim(xplt_val) < 2 and np.ndim(yplt_val) < 2:
        primitive = ax.errorbar(xplt_val, yplt_val, xerr=xerr, yerr=yerr, **kwargs)
    else:
        # 2D data: matplotlib's errorbar draws a single series, so draw one
        # per column (hue value). The 2D values and any error array of the
        # same shape are sliced column by column. 1D values are shared by
        # every series. This covers 2D y (x given / inferred), 2D x (y given)
        # and 2D x with 2D y (multidimensional coordinate).
        shape2d = np.shape(yplt_val) if np.ndim(yplt_val) == 2 else np.shape(xplt_val)

        def _column(values: Any, i: int) -> Any:
            if values is not None and np.shape(values) == shape2d:
                return np.asarray(values)[:, i]
            return values

        fmt = kwargs.get("fmt")
        per_series_fmt = isinstance(fmt, (list, tuple))

        series = []
        for i in range(shape2d[1]):
            series_kwargs = dict(kwargs)
            if per_series_fmt:
                series_kwargs["fmt"] = fmt[i]
            series.append(
                ax.errorbar(
                    _column(xplt_val, i),
                    _column(yplt_val, i),
                    xerr=_column(xerr, i),
                    yerr=_column(yerr, i),
                    **series_kwargs,
                )
            )
        primitive = tuple(series)

    if _labels:
        if xlabel is not None:
            ax.set_xlabel(xlabel)
        if ylabel is not None:
            ax.set_ylabel(ylabel)
        ax.set_title(darray._title_for_slice())

    if darray.ndim == 2 and add_legend:
        assert hueplt is not None
        ax.legend(handles=primitive, labels=list(hueplt.to_numpy()), title=hue_label)

    # Rotate dates on xlabels
    # Do this without calling autofmt_xdate so that x-axes ticks
    # on other subplots (if any) are not deleted.
    # https://stackoverflow.com/questions/17430105/autofmt-xdate-deletes-x-axis-labels-of-all-subplots
    if np.issubdtype(xplt.dtype, np.datetime64):
        for xlabels in ax.get_xticklabels():
            xlabels.set_rotation(30)
            xlabels.set_horizontalalignment("right")

    _update_axes(ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim)

    return primitive
|
|
|
|
|
|
from xarray.plot.accessor import DataArrayPlotAccessor
|
|
# from xarray.plot.accessor import DatasetPlotAccessor
|
|
|
|
def dataarray_plot_errorbar(DataArrayPlotAccessor, *args, **kwargs) -> list[Line3D] | FacetGrid[DataArray]:
    """Draw an errorbar plot of the DataArray held by a plot accessor.

    The first argument acts as ``self``: it is an accessor object whose
    ``_da`` attribute is the DataArray to plot. All remaining positional and
    keyword arguments are forwarded unchanged to ``errorbar``, and its result
    is returned.

    NOTE(review): the first parameter's name shadows the
    ``DataArrayPlotAccessor`` class imported just above. Presumably this
    function is meant to be attached to that class as a method; confirm this,
    and consider renaming the parameter to ``self``.
    """
    target = DataArrayPlotAccessor._da
    result = errorbar(target, *args, **kwargs)
    return result