You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

51 lines
1.8 KiB

10 months ago
  1. from sklearn.linear_model import LinearRegression
  2. def parse_regression_coef_to_array(
  3. model: LinearRegression,
  4. poly_features: list[str],
  5. array_name: str,
  6. rows: list[str] = [],
  7. ) -> list[str]:
  8. """Convenient function to parse the model coefficients into a cpp code string.
  9. Args:
  10. model (LinearRegression): A fitted linear regression model.
  11. poly_features (list[str]): A list with the names of the polynomial features.
  12. array_name (str): The name of the created cpp array.
  13. rows (list[str], optional): In case of a matrix, list of rows. Defaults to [].
  14. Returns:
  15. list[str]: List of strings, first entry is the comment, second the cpp code.
  16. """
  17. intercept = model.intercept_ != 0.0
  18. indices = [i for i in range(len(poly_features)) if model.coef_[i] != 0.0]
  19. feature_comment = (
  20. ("// param[0] + " if intercept else "// ")
  21. + " + ".join(
  22. [
  23. f"param[{idx}]*{poly_features[param_index]}"
  24. for idx, param_index in enumerate(indices, start=intercept)
  25. ],
  26. )
  27. + "\n"
  28. )
  29. n_col = sum(model.coef_ != 0.0) + model.fit_intercept
  30. if not rows:
  31. cpp_decl = f"static constexpr std::array<float, {n_col}> {array_name}"
  32. cpp_decl += (
  33. "{"
  34. + (str(model.intercept_) + "f," if intercept else "")
  35. + ",".join([str(coef) + "f" for coef in model.coef_ if coef != 0.0])
  36. + "};\n"
  37. )
  38. return [feature_comment, cpp_decl]
  39. else:
  40. n_row = len(rows)
  41. cpp_decl = (
  42. f"static constexpr std::array<std::array<float, {n_col}>, {n_row}> {array_name}"
  43. + "{{"
  44. )
  45. cpp_decl += ",".join(rows)
  46. cpp_decl += "}};\n"
  47. return [feature_comment, cpp_decl]