You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

245 lines
8.5 KiB

1 year ago
from __future__ import annotations

from collections.abc import Hashable, Iterable
from typing import Any

import numpy as np
from matplotlib.axes import Axes
from mpl_toolkits.mplot3d.art3d import Line3D
from numpy.typing import ArrayLike

from xarray.core.dataarray import DataArray
from xarray.core.types import (
    AspectOptions,
    ScaleOptions,
)
from xarray.plot.dataarray_plot import _assert_valid_xy, _infer_line_data, _infer_xy_labels
from xarray.plot.facetgrid import FacetGrid
from xarray.plot.facetgrid import _easy_facetgrid
from xarray.plot.utils import (
    _LINEWIDTH_RANGE,
    _MARKERSIZE_RANGE,
    _ensure_plottable,
    _resolve_intervals_1dplot,
    _update_axes,
    get_axis,
    label_from_attrs,
)
  23. def _infer_errorbar_data(
  24. darray: DataArray,
  25. xerrdarray: DataArray | None,
  26. yerrdarray: DataArray | None,
  27. x: Hashable | None,
  28. y: Hashable | None,
  29. hue: Hashable | None
  30. ) -> tuple[DataArray, DataArray, DataArray | None, str]:
  31. ndims = len(darray.dims)
  32. if x is not None and y is not None:
  33. raise ValueError("Cannot specify both x and y kwargs for line plots.")
  34. if x is not None:
  35. _assert_valid_xy(darray, x, "x")
  36. if y is not None:
  37. _assert_valid_xy(darray, y, "y")
  38. if ndims == 1:
  39. huename = None
  40. hueplt = None
  41. huelabel = ""
  42. xerrplt = None
  43. yerrplt = None
  44. if x is not None:
  45. xplt = darray[x]
  46. if xerrdarray is not None:
  47. xerrplt = xerrdarray[x]
  48. yplt = darray
  49. if yerrdarray is not None:
  50. yerrplt = yerrdarray
  51. elif y is not None:
  52. xplt = darray
  53. if xerrdarray is not None:
  54. xerrplt = xerrdarray
  55. yplt = darray[y]
  56. if yerrdarray is not None:
  57. yerrplt = yerrdarray[y]
  58. else: # Both x & y are None
  59. dim = darray.dims[0]
  60. xplt = darray[dim]
  61. yplt = darray
  62. if xerrdarray is not None:
  63. xerrplt = xerrdarray[dim]
  64. if yerrdarray is not None:
  65. yerrplt = yerrdarray
  66. else:
  67. if x is None and y is None and hue is None:
  68. raise ValueError("For 2D inputs, please specify either hue, x or y.")
  69. if y is None:
  70. if hue is not None:
  71. _assert_valid_xy(darray, hue, "hue")
  72. xname, huename = _infer_xy_labels(darray=darray, x=x, y=hue)
  73. xplt = darray[xname]
  74. if xerrdarray is not None:
  75. xerrplt = xerrdarray[xname]
  76. if xplt.ndim > 1:
  77. if huename in darray.dims:
  78. otherindex = 1 if darray.dims.index(huename) == 0 else 0
  79. otherdim = darray.dims[otherindex]
  80. yplt = darray.transpose(otherdim, huename, transpose_coords=False)
  81. if yerrdarray is not None:
  82. yerrplt = yerrdarray.transpose(otherdim, huename, transpose_coords=False)
  83. xplt = xplt.transpose(otherdim, huename, transpose_coords=False)
  84. if xerrdarray is not None:
  85. xerrplt = xerrplt.transpose(otherdim, huename, transpose_coords=False)
  86. else:
  87. raise ValueError(
  88. "For 2D inputs, hue must be a dimension"
  89. " i.e. one of " + repr(darray.dims)
  90. )
  91. else:
  92. (xdim,) = darray[xname].dims
  93. (huedim,) = darray[huename].dims
  94. yplt = darray.transpose(xdim, huedim)
  95. if yerrdarray is not None:
  96. yerrplt = yerrdarray.transpose(xdim, huedim)
  97. else:
  98. yname, huename = _infer_xy_labels(darray=darray, x=y, y=hue)
  99. yplt = darray[yname]
  100. if yerrdarray is not None:
  101. yerrplt = yerrdarray[yname]
  102. if yplt.ndim > 1:
  103. if huename in darray.dims:
  104. otherindex = 1 if darray.dims.index(huename) == 0 else 0
  105. otherdim = darray.dims[otherindex]
  106. xplt = darray.transpose(otherdim, huename, transpose_coords=False)
  107. if xerrdarray is not None:
  108. xerrplt = xerrdarray.transpose(otherdim, huename, transpose_coords=False)
  109. yplt = yplt.transpose(otherdim, huename, transpose_coords=False)
  110. if yerrdarray is not None:
  111. yerrplt = yerrplt.transpose(otherdim, huename, transpose_coords=False)
  112. else:
  113. raise ValueError(
  114. "For 2D inputs, hue must be a dimension"
  115. " i.e. one of " + repr(darray.dims)
  116. )
  117. else:
  118. (ydim,) = darray[yname].dims
  119. (huedim,) = darray[huename].dims
  120. xplt = darray.transpose(ydim, huedim)
  121. if xerrdarray is not None:
  122. xerrplt = xerrdarray.transpose(ydim, huedim)
  123. huelabel = label_from_attrs(darray[huename])
  124. hueplt = darray[huename]
  125. return xplt, yplt, xerrplt, yerrplt, hueplt, huelabel
  126. def errorbar(
  127. darray: DataArray,
  128. *args: Any,
  129. xerr: Hashable | DataArray | None = None,
  130. yerr: Hashable | DataArray | None = None,
  131. row: Hashable | None = None,
  132. col: Hashable | None = None,
  133. figsize: Iterable[float] | None = None,
  134. aspect: AspectOptions = None,
  135. size: float | None = None,
  136. ax: Axes | None = None,
  137. hue: Hashable | None = None,
  138. x: Hashable | None = None,
  139. y: Hashable | None = None,
  140. xincrease: bool | None = None,
  141. yincrease: bool | None = None,
  142. xscale: ScaleOptions = None,
  143. yscale: ScaleOptions = None,
  144. xticks: ArrayLike | None = None,
  145. yticks: ArrayLike | None = None,
  146. xlim: ArrayLike | None = None,
  147. ylim: ArrayLike | None = None,
  148. add_legend: bool = True,
  149. _labels: bool = True,
  150. **kwargs: Any,
  151. ) -> list[Line3D] | FacetGrid[DataArray]:
  152. # Handle facetgrids first
  153. if row or col:
  154. allargs = locals().copy()
  155. allargs.update(allargs.pop("kwargs"))
  156. allargs.pop("darray")
  157. return _easy_facetgrid(darray, line, kind="line", **allargs)
  158. ndims = len(darray.dims)
  159. if ndims == 0 or darray.size == 0:
  160. # TypeError to be consistent with pandas
  161. raise TypeError("No numeric data to plot.")
  162. if ndims > 2:
  163. raise ValueError(
  164. "Line plots are for 1- or 2-dimensional DataArrays. "
  165. "Passed DataArray has {ndims} "
  166. "dimensions".format(ndims=ndims)
  167. )
  168. # The allargs dict passed to _easy_facetgrid above contains args
  169. if args == ():
  170. args = kwargs.pop("args", ())
  171. else:
  172. assert "args" not in kwargs
  173. ax = get_axis(figsize, size, aspect, ax)
  174. if isinstance(xerr, DataArray) or isinstance(yerr, DataArray):
  175. xplt, yplt, xerr, yerr, hueplt, hue_label = _infer_errorbar_data(darray, xerr, yerr, x, y, hue)
  176. else:
  177. xplt, yplt, hueplt, hue_label = _infer_line_data(darray, x, y, hue)
  178. # Remove pd.Intervals if contained in xplt.values and/or yplt.values.
  179. xplt_val, yplt_val, x_suffix, y_suffix, kwargs = _resolve_intervals_1dplot(
  180. xplt.to_numpy(), yplt.to_numpy(), kwargs
  181. )
  182. xlabel = label_from_attrs(xplt, extra=x_suffix)
  183. ylabel = label_from_attrs(yplt, extra=y_suffix)
  184. _ensure_plottable(xplt_val, yplt_val)
  185. primitive = ax.errorbar(xplt_val, yplt_val, *args, xerr=xerr, yerr=yerr, **kwargs)
  186. if _labels:
  187. if xlabel is not None:
  188. ax.set_xlabel(xlabel)
  189. if ylabel is not None:
  190. ax.set_ylabel(ylabel)
  191. ax.set_title(darray._title_for_slice())
  192. if darray.ndim == 2 and add_legend:
  193. assert hueplt is not None
  194. ax.legend(handles=primitive, labels=list(hueplt.to_numpy()), title=hue_label)
  195. # Rotate dates on xlabels
  196. # Do this without calling autofmt_xdate so that x-axes ticks
  197. # on other subplots (if any) are not deleted.
  198. # https://stackoverflow.com/questions/17430105/autofmt-xdate-deletes-x-axis-labels-of-all-subplots
  199. if np.issubdtype(xplt.dtype, np.datetime64):
  200. for xlabels in ax.get_xticklabels():
  201. xlabels.set_rotation(30)
  202. xlabels.set_horizontalalignment("right")
  203. _update_axes(ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim)
  204. return primitive
  205. from xarray.plot.accessor import DataArrayPlotAccessor
  206. # from xarray.plot.accessor import DatasetPlotAccessor
  207. def dataarray_plot_errorbar(DataArrayPlotAccessor, *args, **kwargs) -> list[Line3D] | FacetGrid[DataArray]:
  208. return errorbar(DataArrayPlotAccessor._da, *args, **kwargs)