from __future__ import annotations from xarray.plot.dataarray_plot import _infer_line_data, _infer_xy_labels, _assert_valid_xy from xarray.plot.facetgrid import _easy_facetgrid from xarray.plot.utils import ( _LINEWIDTH_RANGE, _MARKERSIZE_RANGE, _ensure_plottable, _resolve_intervals_1dplot, _update_axes, get_axis, label_from_attrs, ) from matplotlib.axes import Axes from mpl_toolkits.mplot3d.art3d import Line3D import numpy as np from numpy.typing import ArrayLike from xarray.core.dataarray import DataArray from xarray.core.types import ( AspectOptions, ScaleOptions, ) from xarray.plot.facetgrid import FacetGrid def _infer_errorbar_data( darray: DataArray, xerrdarray: DataArray | None, yerrdarray: DataArray | None, x: Hashable | None, y: Hashable | None, hue: Hashable | None ) -> tuple[DataArray, DataArray, DataArray | None, str]: ndims = len(darray.dims) if x is not None and y is not None: raise ValueError("Cannot specify both x and y kwargs for line plots.") if x is not None: _assert_valid_xy(darray, x, "x") if y is not None: _assert_valid_xy(darray, y, "y") if ndims == 1: huename = None hueplt = None huelabel = "" xerrplt = None yerrplt = None if x is not None: xplt = darray[x] if xerrdarray is not None: xerrplt = xerrdarray[x] yplt = darray if yerrdarray is not None: yerrplt = yerrdarray elif y is not None: xplt = darray if xerrdarray is not None: xerrplt = xerrdarray yplt = darray[y] if yerrdarray is not None: yerrplt = yerrdarray[y] else: # Both x & y are None dim = darray.dims[0] xplt = darray[dim] yplt = darray if xerrdarray is not None: xerrplt = xerrdarray[dim] if yerrdarray is not None: yerrplt = yerrdarray else: if x is None and y is None and hue is None: raise ValueError("For 2D inputs, please specify either hue, x or y.") if y is None: if hue is not None: _assert_valid_xy(darray, hue, "hue") xname, huename = _infer_xy_labels(darray=darray, x=x, y=hue) xplt = darray[xname] if xerrdarray is not None: xerrplt = xerrdarray[xname] if xplt.ndim > 1: if huename in darray.dims: otherindex = 1 if darray.dims.index(huename) == 0 else 0 otherdim = darray.dims[otherindex] yplt = darray.transpose(otherdim, huename, transpose_coords=False) if yerrdarray is not None: yerrplt = yerrdarray.transpose(otherdim, huename, transpose_coords=False) xplt = xplt.transpose(otherdim, huename, transpose_coords=False) if xerrdarray is not None: xerrplt = xerrplt.transpose(otherdim, huename, transpose_coords=False) else: raise ValueError( "For 2D inputs, hue must be a dimension" " i.e. one of " + repr(darray.dims) ) else: (xdim,) = darray[xname].dims (huedim,) = darray[huename].dims yplt = darray.transpose(xdim, huedim) if yerrdarray is not None: yerrplt = yerrdarray.transpose(xdim, huedim) else: yname, huename = _infer_xy_labels(darray=darray, x=y, y=hue) yplt = darray[yname] if yerrdarray is not None: yerrplt = yerrdarray[yname] if yplt.ndim > 1: if huename in darray.dims: otherindex = 1 if darray.dims.index(huename) == 0 else 0 otherdim = darray.dims[otherindex] xplt = darray.transpose(otherdim, huename, transpose_coords=False) if xerrdarray is not None: xerrplt = xerrdarray.transpose(otherdim, huename, transpose_coords=False) yplt = yplt.transpose(otherdim, huename, transpose_coords=False) if yerrdarray is not None: yerrplt = yerrplt.transpose(otherdim, huename, transpose_coords=False) else: raise ValueError( "For 2D inputs, hue must be a dimension" " i.e. one of " + repr(darray.dims) ) else: (ydim,) = darray[yname].dims (huedim,) = darray[huename].dims xplt = darray.transpose(ydim, huedim) if xerrdarray is not None: xerrplt = xerrdarray.transpose(ydim, huedim) huelabel = label_from_attrs(darray[huename]) hueplt = darray[huename] return xplt, yplt, xerrplt, yerrplt, hueplt, huelabel def errorbar( darray: DataArray, *args: Any, xerr: Hashable | DataArray | None = None, yerr: Hashable | DataArray | None = None, row: Hashable | None = None, col: Hashable | None = None, figsize: Iterable[float] | None = None, aspect: AspectOptions = None, size: float | None = None, ax: Axes | None = None, hue: Hashable | None = None, x: Hashable | None = None, y: Hashable | None = None, xincrease: bool | None = None, yincrease: bool | None = None, xscale: ScaleOptions = None, yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, xlim: ArrayLike | None = None, ylim: ArrayLike | None = None, add_legend: bool = True, _labels: bool = True, **kwargs: Any, ) -> list[Line3D] | FacetGrid[DataArray]: # Handle facetgrids first if row or col: allargs = locals().copy() allargs.update(allargs.pop("kwargs")) allargs.pop("darray") return _easy_facetgrid(darray, line, kind="line", **allargs) ndims = len(darray.dims) if ndims == 0 or darray.size == 0: # TypeError to be consistent with pandas raise TypeError("No numeric data to plot.") if ndims > 2: raise ValueError( "Line plots are for 1- or 2-dimensional DataArrays. " "Passed DataArray has {ndims} " "dimensions".format(ndims=ndims) ) # The allargs dict passed to _easy_facetgrid above contains args if args == (): args = kwargs.pop("args", ()) else: assert "args" not in kwargs ax = get_axis(figsize, size, aspect, ax) if isinstance(xerr, DataArray) or isinstance(yerr, DataArray): xplt, yplt, xerr, yerr, hueplt, hue_label = _infer_errorbar_data(darray, xerr, yerr, x, y, hue) else: xplt, yplt, hueplt, hue_label = _infer_line_data(darray, x, y, hue) # Remove pd.Intervals if contained in xplt.values and/or yplt.values. xplt_val, yplt_val, x_suffix, y_suffix, kwargs = _resolve_intervals_1dplot( xplt.to_numpy(), yplt.to_numpy(), kwargs ) xlabel = label_from_attrs(xplt, extra=x_suffix) ylabel = label_from_attrs(yplt, extra=y_suffix) _ensure_plottable(xplt_val, yplt_val) fmt = None if 'fmt' in kwargs and isinstance(kwargs['fmt'], list): fmt = kwargs["fmt"] if len(np.shape(xplt_val)) == len(np.shape(yplt_val)): primitive = ax.errorbar(xplt_val, yplt_val, *args, xerr=xerr, yerr=yerr, **kwargs) else: primitive = np.empty(np.shape(yplt_val)[1], dtype=object) if not yerr is None: if not fmt is None: for i in range(np.shape(yplt_val)[1]): kwargs.update({'fmt': fmt[i]}) primitive[i] = ax.errorbar(xplt_val, yplt_val[:, i], *args, xerr=xerr, yerr=yerr[:, i], **kwargs) else: for i in range(np.shape(yplt_val)[1]): primitive[i] = ax.errorbar(xplt_val, yplt_val[:, i], *args, xerr=xerr, yerr=yerr[:, i], **kwargs) else: if not fmt is None: for i in range(np.shape(yplt_val)[1]): kwargs.update({'fmt': fmt[i]}) primitive[i] = ax.errorbar(xplt_val, yplt_val[:, i], *args, xerr=xerr, yerr=yerr, **kwargs) else: for i in range(np.shape(yplt_val)[1]): primitive[i] = ax.errorbar(xplt_val, yplt_val[:, i], *args, xerr=xerr, yerr=yerr, **kwargs) primitive = tuple(primitive) if _labels: if xlabel is not None: ax.set_xlabel(xlabel) if ylabel is not None: ax.set_ylabel(ylabel) ax.set_title(darray._title_for_slice()) if darray.ndim == 2 and add_legend: assert hueplt is not None ax.legend(handles=primitive, labels=list(hueplt.to_numpy()), title=hue_label) # Rotate dates on xlabels # Do this without calling autofmt_xdate so that x-axes ticks # on other subplots (if any) are not deleted. # https://stackoverflow.com/questions/17430105/autofmt-xdate-deletes-x-axis-labels-of-all-subplots if np.issubdtype(xplt.dtype, np.datetime64): for xlabels in ax.get_xticklabels(): xlabels.set_rotation(30) xlabels.set_horizontalalignment("right") _update_axes(ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim) return primitive from xarray.plot.accessor import DataArrayPlotAccessor # from xarray.plot.accessor import DatasetPlotAccessor def dataarray_plot_errorbar(DataArrayPlotAccessor, *args, **kwargs) -> list[Line3D] | FacetGrid[DataArray]: return errorbar(DataArrayPlotAccessor._da, *args, **kwargs)