You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

51 lines
1.8 KiB

from sklearn.linear_model import LinearRegression
def parse_regression_coef_to_array(
    model: "LinearRegression",
    poly_features: list[str],
    array_name: str,
    rows: "list[str] | None" = None,
) -> list[str]:
    """Convert a fitted model's coefficients into a C++ array declaration.

    Only non-zero terms are emitted. When the fitted intercept is non-zero it
    is stored first (``param[0]``); the non-zero coefficients follow in
    feature order.

    Args:
        model (LinearRegression): A fitted linear regression model. Assumes a
            one-dimensional ``coef_`` (single target) -- TODO confirm for
            multi-output use.
        poly_features (list[str]): Names of the polynomial features, one per
            entry of ``model.coef_``.
        array_name (str): Name of the generated C++ array.
        rows (list[str] | None, optional): Pre-formatted C++ row initializers.
            When non-empty, a 2-D
            ``std::array<std::array<float, n_col>, len(rows)>`` is declared
            from them instead of a 1-D array built from ``model``.
            Defaults to None (no rows).

    Returns:
        list[str]: ``[comment, cpp_decl]``. ``comment`` is a ``//`` line that
        maps each ``param[i]`` to its feature. ``cpp_decl`` is the
        ``static constexpr`` declaration. Both strings end with a newline.

    Raises:
        ValueError: If ``poly_features`` and ``model.coef_`` differ in length.
    """
    coefs = model.coef_
    if len(poly_features) != len(coefs):
        raise ValueError(
            f"poly_features has {len(poly_features)} entries but model.coef_ "
            f"has {len(coefs)}"
        )

    # Plain bool: a numpy bool is not accepted as an integer offset by enumerate.
    has_intercept = bool(model.intercept_ != 0.0)
    # Zero coefficients are dropped from both the comment and the array.
    kept = [
        (name, coef) for name, coef in zip(poly_features, coefs) if coef != 0.0
    ]

    terms = ["param[0]"] if has_intercept else []
    terms += [
        f"param[{idx}]*{name}"
        for idx, (name, _) in enumerate(kept, start=int(has_intercept))
    ]
    feature_comment = "// " + " + ".join(terms) + "\n"

    # Count exactly what is emitted. Using model.fit_intercept here would
    # over-count whenever the fitted intercept is exactly 0.0.
    n_col = len(kept) + int(has_intercept)

    if not rows:
        values = [str(model.intercept_) + "f"] if has_intercept else []
        values += [str(coef) + "f" for _, coef in kept]
        cpp_decl = (
            f"static constexpr std::array<float, {n_col}> {array_name}"
            + "{"
            + ",".join(values)
            + "};\n"
        )
    else:
        # Outer brace pair for std::array, inner pair for its underlying
        # C array. Each row string is assumed to carry its own braces.
        cpp_decl = (
            f"static constexpr std::array<std::array<float, {n_col}>, "
            f"{len(rows)}> {array_name}"
            + "{{"
            + ",".join(rows)
            + "}};\n"
        )
    return [feature_comment, cpp_decl]