fix bugs (can not use xarray.sutck with dask)

This commit is contained in:
Jianshun Gao 2023-05-15 17:05:16 +02:00
parent dfcb2b6a69
commit 9854f5a660
2 changed files with 634 additions and 638 deletions

View File

@ -298,6 +298,7 @@ class FitAnalyser():
return self.fitModel.guess(data=data, x=x, **kwargs)
def _guess_2D(self, data, x, y, **kwargs):
data = data.flatten(order='F')
return self.fitModel.guess(data=data, x=x, y=y, **kwargs)
def guess(self, dataArray, x=None, y=None, guess_kwargs={}, input_core_dims=None, dask='parallelized', vectorize=True, keep_attrs=True, daskKwargs=None, **kwargs):
@ -374,9 +375,9 @@ class FitAnalyser():
_x = _x.flatten()
_y = _y.flatten()
dataArray = dataArray.stack(_z=(kwargs["input_core_dims"][0][0], kwargs["input_core_dims"][0][1]))
# dataArray = dataArray.stack(_z=(kwargs["input_core_dims"][0][0], kwargs["input_core_dims"][0][1]))
kwargs["input_core_dims"][0] = ['_z']
# kwargs["input_core_dims"][0] = ['_z']
guess_kwargs.update(
{
@ -391,10 +392,10 @@ class FitAnalyser():
)
def _fit_1D(self, data, params, x):
# try:
return self.fitModel.fit(data=data, x=x, params=params)
def _fit_2D(self, data, params, x, y):
data = data.flatten(order='F')
return self.fitModel.fit(data=data, x=x, y=y, params=params)
def fit(self, dataArray, paramsArray, x=None, y=None, input_core_dims=None, dask='parallelized', vectorize=True, keep_attrs=True, daskKwargs=None, **kwargs):
@ -464,9 +465,9 @@ class FitAnalyser():
_x = _x.flatten()
_y = _y.flatten()
dataArray = dataArray.stack(_z=(kwargs["input_core_dims"][0][0], kwargs["input_core_dims"][0][1]))
# dataArray = dataArray.stack(_z=(kwargs["input_core_dims"][0][0], kwargs["input_core_dims"][0][1]))
kwargs["input_core_dims"][0] = ['_z']
# kwargs["input_core_dims"][0] = ['_z']
return xr.apply_ufunc(self._fit_2D, dataArray, kwargs={'params':paramsArray,'x':_x, 'y':_y},
output_dtypes=[type(lmfit.model.ModelResult(self.fitModel, self.fitModel.make_params()))],
@ -520,9 +521,9 @@ class FitAnalyser():
_x = _x.flatten()
_y = _y.flatten()
dataArray = dataArray.stack(_z=(kwargs["input_core_dims"][0][0], kwargs["input_core_dims"][0][1]))
# dataArray = dataArray.stack(_z=(kwargs["input_core_dims"][0][0], kwargs["input_core_dims"][0][1]))
kwargs["input_core_dims"][0] = ['_z']
# kwargs["input_core_dims"][0] = ['_z']
return xr.apply_ufunc(self._fit_2D, dataArray, paramsArray, kwargs={'x':_x, 'y':_y},
output_dtypes=[type(lmfit.model.ModelResult(self.fitModel, self.fitModel.make_params()))],
@ -533,7 +534,7 @@ class FitAnalyser():
def _eval_2D(self, fitResult, x, y, shape):
res = self.fitModel.eval(x=x, y=y, **fitResult.best_values)
return res.reshape(shape)
return res.reshape(shape, order='F')
def eval(self, fitResultArray, x=None, y=None, output_core_dims=None, prefix="", dask='parallelized', vectorize=True, daskKwargs=None, **kwargs):

File diff suppressed because one or more lines are too long