This commit is contained in:
Jianshun Gao 2023-05-10 19:03:03 +02:00
parent d88b5d2fec
commit dfcb2b6a69

View File

@ -300,7 +300,7 @@ class FitAnalyser():
def _guess_2D(self, data, x, y, **kwargs): def _guess_2D(self, data, x, y, **kwargs):
return self.fitModel.guess(data=data, x=x, y=y, **kwargs) return self.fitModel.guess(data=data, x=x, y=y, **kwargs)
def guess(self, dataArray, x=None, y=None, guess_kwargs={}, input_core_dims=None, dask='parallelized', vectorize=True, keep_attrs=True, **kwargs): def guess(self, dataArray, x=None, y=None, guess_kwargs={}, input_core_dims=None, dask='parallelized', vectorize=True, keep_attrs=True, daskKwargs=None, **kwargs):
kwargs.update( kwargs.update(
{ {
@ -308,9 +308,13 @@ class FitAnalyser():
"vectorize": vectorize, "vectorize": vectorize,
"input_core_dims": input_core_dims, "input_core_dims": input_core_dims,
'keep_attrs': keep_attrs, 'keep_attrs': keep_attrs,
} }
) )
if not daskKwargs is None:
kwargs.update({"dask_gufunc_kwargs": daskKwargs})
if input_core_dims is None: if input_core_dims is None:
kwargs.update( kwargs.update(
{ {
@ -393,7 +397,7 @@ class FitAnalyser():
def _fit_2D(self, data, params, x, y): def _fit_2D(self, data, params, x, y):
return self.fitModel.fit(data=data, x=x, y=y, params=params) return self.fitModel.fit(data=data, x=x, y=y, params=params)
def fit(self, dataArray, paramsArray, x=None, y=None, input_core_dims=None, dask='parallelized', vectorize=True, keep_attrs=True, **kwargs): def fit(self, dataArray, paramsArray, x=None, y=None, input_core_dims=None, dask='parallelized', vectorize=True, keep_attrs=True, daskKwargs=None, **kwargs):
kwargs.update( kwargs.update(
{ {
@ -404,6 +408,9 @@ class FitAnalyser():
} }
) )
if not daskKwargs is None:
kwargs.update({"dask_gufunc_kwargs": daskKwargs})
if isinstance(paramsArray, type(self.fitModel.make_params())): if isinstance(paramsArray, type(self.fitModel.make_params())):
if input_core_dims is None: if input_core_dims is None:
@ -528,7 +535,7 @@ class FitAnalyser():
res = self.fitModel.eval(x=x, y=y, **fitResult.best_values) res = self.fitModel.eval(x=x, y=y, **fitResult.best_values)
return res.reshape(shape) return res.reshape(shape)
def eval(self, fitResultArray, x=None, y=None, output_core_dims=None, prefix="", dask='parallelized', vectorize=True, **kwargs): def eval(self, fitResultArray, x=None, y=None, output_core_dims=None, prefix="", dask='parallelized', vectorize=True, daskKwargs=None, **kwargs):
kwargs.update( kwargs.update(
{ {
@ -538,6 +545,9 @@ class FitAnalyser():
} }
) )
if daskKwargs is None:
daskKwargs = {}
if self.fitDim == 1: if self.fitDim == 1:
if output_core_dims is None: if output_core_dims is None:
@ -548,14 +558,18 @@ class FitAnalyser():
) )
output_core_dims = [prefix+'x'] output_core_dims = [prefix+'x']
kwargs.update( daskKwargs.update(
{ {
"dask_gufunc_kwargs": {
'output_sizes': { 'output_sizes': {
output_core_dims[0]: np.size(x), output_core_dims[0]: np.size(x),
}, },
'meta': np.ndarray((0,0), dtype=float) 'meta': np.ndarray((0,0), dtype=float)
}, }
)
kwargs.update(
{
"dask_gufunc_kwargs": daskKwargs,
} }
) )
@ -571,15 +585,19 @@ class FitAnalyser():
) )
output_core_dims = [prefix+'x', prefix+'y'] output_core_dims = [prefix+'x', prefix+'y']
kwargs.update( daskKwargs.update(
{ {
"dask_gufunc_kwargs": {
'output_sizes': { 'output_sizes': {
output_core_dims[0]: np.size(x), output_core_dims[0]: np.size(x),
output_core_dims[1]: np.size(y), output_core_dims[1]: np.size(y),
}, },
'meta': np.ndarray((0,0), dtype=float) 'meta': np.ndarray((0,0), dtype=float)
}, },
)
kwargs.update(
{
"dask_gufunc_kwargs": daskKwargs,
} }
) )