rad len
This commit is contained in:
parent
c1234b8063
commit
ca20ccfe6e
BIN
methods/__pycache__/fit_linear_regression_model.cpython-310.pyc
Normal file
BIN
methods/__pycache__/fit_linear_regression_model.cpython-310.pyc
Normal file
Binary file not shown.
91
methods/fit_linear_regression_model.py
Normal file
91
methods/fit_linear_regression_model.py
Normal file
@ -0,0 +1,91 @@
|
||||
import awkward as ak
|
||||
from sklearn.preprocessing import PolynomialFeatures
|
||||
from sklearn.linear_model import LinearRegression
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.metrics import mean_squared_error
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _columns_to_remove(poly_features, keep, keep_only_linear_in, remove):
    """Work out which columns of the polynomial design matrix to drop.

    Returns something ``np.delete`` accepts as its ``obj`` argument
    (a list of column indices).
    """
    if remove:
        if keep:
            raise NotImplementedError(
                "Simultaneous removing and keeping is not implemented."
            )
        # NOTE: ``keep_only_linear_in`` is ignored when ``remove`` is given.
        if all(isinstance(entry, str) for entry in remove):
            # Monomial names have to be translated to column indices,
            # np.delete cannot handle strings.
            unwanted = set(remove)
            return [i for i, f in enumerate(poly_features) if str(f) in unwanted]
        # Anything else is assumed to already be column indices.
        return remove
    if keep:
        wanted = set(keep)
        return [i for i, f in enumerate(poly_features) if str(f) not in wanted]
    if keep_only_linear_in:
        # A monomial looks like "a b^2 c": factors separated by spaces, with
        # "^k" appended for powers > 1. It is linear in the feature exactly
        # when the bare name is one of its factors. Comparing whole factors
        # (instead of substrings) keeps e.g. "eta" from matching "theta".
        # Assumes feature names contain no spaces.
        return [
            i
            for i, f in enumerate(poly_features)
            if keep_only_linear_in not in str(f).split(" ")
        ]
    return []


def fit_linear_regression_model(
    array: ak.Array,
    target_feat: str,
    features: list[str],
    degree: int,
    keep: list[str] | None = None,
    keep_only_linear_in: str = "",
    remove: list[str] | None = None,
    include_bias: bool = False,
    fit_intercept: bool = False,
    test_size=0.2,
    random_state=42,
) -> tuple[LinearRegression, np.ndarray, np.ndarray, np.ndarray]:
    """Wrapper around sklearn's LinearRegression with PolynomialFeatures.

    The data is split into a training and a test sample, expanded into
    polynomial features, optionally reduced to a subset of monomials and then
    fitted. Fit parameters, r2 score and RMSE on the test sample are printed.

    Args:
        array (ak.Array): The data.
        target_feat (str): Target feature to be fitted.
        features (list[str]): Features the target depends on.
        degree (int): Highest order of the polynomial.
        keep (list[str], optional): Monomials to keep, named as returned by
            ``PolynomialFeatures.get_feature_names_out``. Defaults to None.
        keep_only_linear_in (str, optional): Keep only terms that are linear
            in this feature. Only used if neither ``keep`` nor ``remove`` is
            given. Defaults to "".
        remove (list[str], optional): Monomials to remove, given by name
            (column indices are accepted as well). Defaults to None.
        include_bias (bool, optional): Include bias term in polynomial.
            Defaults to False.
        fit_intercept (bool, optional): Fit zeroth order. Defaults to False.
        test_size (float, optional): Fraction of data used for testing.
            Defaults to 0.2.
        random_state (int, optional): Seed for the train/test split.
            Defaults to 42.

    Raises:
        NotImplementedError: Simultaneous removing and keeping is not implemented.

    Returns:
        tuple[LinearRegression, np.ndarray, np.ndarray, np.ndarray]: The fitted
        linear regression object, the names of the kept monomials, the test
        targets and the model predictions for the test sample.
    """
    data = np.column_stack([ak.to_numpy(array[feat]) for feat in features])
    target = ak.to_numpy(array[target_feat])
    X_train, X_test, y_train, y_test = train_test_split(
        data,
        target,
        test_size=test_size,
        random_state=random_state,
    )

    poly = PolynomialFeatures(degree=degree, include_bias=include_bias)
    X_train_model = poly.fit_transform(X_train)
    # The transformer is fitted on the training sample only; the test sample
    # is merely transformed.
    X_test_model = poly.transform(X_test)
    poly_features = poly.get_feature_names_out(input_features=features)

    drop = _columns_to_remove(poly_features, keep, keep_only_linear_in, remove)
    X_train_model = np.delete(X_train_model, drop, axis=1)
    X_test_model = np.delete(X_test_model, drop, axis=1)
    poly_features = np.delete(poly_features, drop)

    lin_reg = LinearRegression(fit_intercept=fit_intercept)
    lin_reg.fit(X_train_model, y_train)
    y_pred_test = lin_reg.predict(X_test_model)

    # ``squared=False`` was deprecated in scikit-learn 1.4 and removed in 1.6;
    # taking the root explicitly works with every version and is identical
    # for a single target.
    rmse = np.sqrt(mean_squared_error(y_test, y_pred_test))

    print(f"Parameterisation for {target_feat}:")
    print("intercept=", lin_reg.intercept_)
    print("coef=", dict(zip(poly_features, lin_reg.coef_)))
    print("r2 score=", lin_reg.score(X_test_model, y_test))
    print("RMSE =", rmse)
    print()
    return (lin_reg, poly_features, y_test, y_pred_test)
|
File diff suppressed because one or more lines are too long
346
trackinglosses_rad_length_beginVelo.ipynb
Normal file
346
trackinglosses_rad_length_beginVelo.ipynb
Normal file
File diff suppressed because one or more lines are too long
317
trackinglosses_rad_length_endVelo.ipynb
Normal file
317
trackinglosses_rad_length_endVelo.ipynb
Normal file
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user