You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

142 lines
5.4 KiB

10 months ago
  1. import re
  2. from pathlib import Path
  3. # flake8: noqa
  4. def parse_tmva_matrix_to_array(
  5. input_class_files: list[str],
  6. simd_type: bool = False,
  7. outdir: str = "nn_electron_training",
  8. ) -> list[str]:
  9. """Function to transform the TMVA output MLP C++ class into a more modern form.
  10. Args:
  11. input_class_file (str): Path to the .C MLP class created by TMVA.
  12. simd_type (bool, optional): If true, type in array is set to simd::float_v. Defaults to False.
  13. Returns:
  14. (str) : Path to the resulting C++ file containing the matrices.
  15. Note:
  16. The created C++ code is written to a `hpp` file in `nn_electron_training/result`.
  17. """
  18. data_type = "float" if not simd_type else "simd::float_v"
  19. # TODO: as of writing this code, constexpr is not supported for the SIMD types
  20. constness = "constexpr" if not simd_type else "const"
  21. outfiles = []
  22. for input_class_file in input_class_files:
  23. print(f"Transforming {input_class_file} ...")
  24. input_class_file = Path(input_class_file)
  25. with open(input_class_file) as f:
  26. lines = f.readlines()
  27. # the name of the outputfile is the middle part of the input class file name
  28. outfile = (
  29. outdir
  30. + "/result/"
  31. + "_".join(str(input_class_file.stem).split("_")[1:-1])
  32. + ".hpp"
  33. )
  34. # this only supports category 2 for the transformation, which is the largest/smallest for fMax_1/fMin_1
  35. min_lines = [
  36. line.replace(";", "").split("=")[-1].strip()
  37. for line in lines
  38. if "fMin_1[2]" in line and "Scal" not in line and "double" not in line
  39. ]
  40. max_lines = [
  41. line.replace(";", "").split("=")[-1].strip()
  42. for line in lines
  43. if "fMax_1[2]" in line and "Scal" not in line and "double" not in line
  44. ]
  45. minima_cpp_decl = (
  46. f"{constness} auto fMin = std::array<{data_type}, {len(min_lines)}>"
  47. + "{{"
  48. + ", ".join(min_lines)
  49. + "}};\n"
  50. )
  51. maxima_cpp_decl = (
  52. f"{constness} auto fMax = std::array<{data_type}, {len(max_lines)}>"
  53. + "{{"
  54. + ", ".join(max_lines)
  55. + "}};\n"
  56. )
  57. print(f"Found minimum and maximum values for {len(min_lines)} variables.")
  58. with open(outfile, "w") as out:
  59. out.writelines([minima_cpp_decl, maxima_cpp_decl])
  60. # this list contains all lines that define a matrix element
  61. matrix_lines = [
  62. line.replace(";", "").strip()
  63. for line in lines
  64. if "fWeightMatrix" in line
  65. and "double" not in line
  66. and "*" not in line
  67. and "= fWeightMatrix" not in line
  68. ]
  69. # there are several matrices, figure out how many and loop accordingly
  70. n_matrices = int(re.findall(re.compile(r"(\d+)\["), matrix_lines[-1])[0])
  71. print(f"Found {n_matrices} matrices: ")
  72. for matrix in range(n_matrices):
  73. # get only entries for corresponding matrix
  74. matrix_entries = [m for m in matrix_lines if f"{matrix}to{(matrix+1)}" in m]
  75. # figure out the dimensions of the matrix, by checking largest index at the end
  76. dim_string = re.findall(re.compile(r"\[(\d+)\]"), matrix_entries[-1])
  77. # actual size is last index + 1
  78. n_rows = int(dim_string[0]) + 1
  79. n_cols = int(dim_string[1]) + 1
  80. # get the name of the matrix
  81. matrix_string = matrix_entries[matrix * n_rows].split("=")[0].split("[")[0]
  82. if n_rows > 1:
  83. cpp_decl = (
  84. f"{constness} auto {matrix_string} = std::array<std::array<{data_type}, {n_cols}>, {n_rows}>"
  85. + "{{"
  86. )
  87. else:
  88. cpp_decl = (
  89. f"{constness} auto {matrix_string} = std::array<{data_type}, {n_cols}>"
  90. + "{"
  91. )
  92. rows = [[] for _ in range(n_rows)]
  93. for i_col in range(n_cols):
  94. for i_row, row in enumerate(rows):
  95. # only keep number (right side of the equality sign)
  96. entry = (
  97. matrix_entries[i_col * n_rows + i_row].split("=")[-1].strip()
  98. )
  99. row.append(entry)
  100. array_strings = ["{" + ", ".join(row) + "}" for row in rows]
  101. cpp_decl += ", ".join(array_strings) + ("}};\n" if n_rows > 1 else "};\n")
  102. print(
  103. f" {matrix+1}. {matrix_string} with {n_cols} columns and {n_rows} rows",
  104. )
  105. with open(outfile, "a") as out:
  106. out.write(cpp_decl)
  107. outfiles.append(outfile)
  108. return outfiles
  109. if __name__ == "__main__":
  110. outfiles = parse_tmva_matrix_to_array(
  111. [
  112. "nn_electron_training/result/GhostNNDataSet/weights/TMVAClassification_default_forward_ghost_mlp.class.C",
  113. "nn_electron_training/result/GhostNNDataSet/weights/TMVAClassification_veloUT_forward_ghost_mlp.class.C",
  114. ],
  115. )
  116. try:
  117. import subprocess
  118. # run clang-format for nicer looking result
  119. for outfile in outfiles:
  120. subprocess.run(
  121. [
  122. "clang-format",
  123. "-i",
  124. f"{outfile}",
  125. ],
  126. check=True,
  127. )
  128. except:
  129. pass