You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

140 lines
5.4 KiB

10 months ago
  1. import re
  2. from pathlib import Path
  3. def parse_tmva_matrix_to_array(
  4. input_class_files: list[str],
  5. simd_type: bool = False,
  6. outdir: str = "neural_net_training",
  7. ) -> list[str]:
  8. """Function to transform the TMVA output MLP C++ class into a more modern form.
  9. Args:
  10. input_class_file (str): Path to the .C MLP class created by TMVA.
  11. simd_type (bool, optional): If true, type in array is set to simd::float_v. Defaults to False.
  12. Returns:
  13. (str) : Path to the resulting C++ file containing the matrices.
  14. Note:
  15. The created C++ code is written to a `hpp` file in `neural_net_training/result`.
  16. """
  17. data_type = "float" if not simd_type else "simd::float_v"
  18. # TODO: as of writing this code, constexpr is not supported for the SIMD types
  19. constness = "constexpr" if not simd_type else "const"
  20. outfiles = []
  21. for input_class_file in input_class_files:
  22. print(f"Transforming {input_class_file} ...")
  23. input_class_file = Path(input_class_file)
  24. with open(input_class_file) as f:
  25. lines = f.readlines()
  26. # the name of the outputfile is the middle part of the input class file name
  27. outfile = (
  28. outdir
  29. + "/result/"
  30. + "_".join(str(input_class_file.stem).split("_")[1:-1])
  31. + ".hpp"
  32. )
  33. # this only supports category 2 for the transformation, which is the largest/smallest for fMax_1/fMin_1
  34. min_lines = [
  35. line.replace(";", "").split("=")[-1].strip()
  36. for line in lines
  37. if "fMin_1[2]" in line and "Scal" not in line and "double" not in line
  38. ]
  39. max_lines = [
  40. line.replace(";", "").split("=")[-1].strip()
  41. for line in lines
  42. if "fMax_1[2]" in line and "Scal" not in line and "double" not in line
  43. ]
  44. minima_cpp_decl = (
  45. f"{constness} auto fMin = std::array<{data_type}, {len(min_lines)}>"
  46. + "{{"
  47. + ", ".join(min_lines)
  48. + "}};\n"
  49. )
  50. maxima_cpp_decl = (
  51. f"{constness} auto fMax = std::array<{data_type}, {len(max_lines)}>"
  52. + "{{"
  53. + ", ".join(max_lines)
  54. + "}};\n"
  55. )
  56. print(f"Found minimum and maximum values for {len(min_lines)} variables.")
  57. with open(outfile, "w") as out:
  58. out.writelines([minima_cpp_decl, maxima_cpp_decl])
  59. # this list contains all lines that define a matrix element
  60. matrix_lines = [
  61. line.replace(";", "").strip()
  62. for line in lines
  63. if "fWeightMatrix" in line
  64. and "double" not in line
  65. and "*" not in line
  66. and "= fWeightMatrix" not in line
  67. ]
  68. # there are several matrices, figure out how many and loop accordingly
  69. n_matrices = int(re.findall(re.compile(r"(\d+)\["), matrix_lines[-1])[0])
  70. print(f"Found {n_matrices} matrices: ")
  71. for matrix in range(n_matrices):
  72. # get only entries for corresponding matrix
  73. matrix_entries = [m for m in matrix_lines if f"{matrix}to{(matrix+1)}" in m]
  74. # figure out the dimensions of the matrix, by checking largest index at the end
  75. dim_string = re.findall(re.compile(r"\[(\d+)\]"), matrix_entries[-1])
  76. # actual size is last index + 1
  77. n_rows = int(dim_string[0]) + 1
  78. n_cols = int(dim_string[1]) + 1
  79. # get the name of the matrix
  80. matrix_string = matrix_entries[matrix * n_rows].split("=")[0].split("[")[0]
  81. if n_rows > 1:
  82. cpp_decl = (
  83. f"{constness} auto {matrix_string} = std::array<std::array<{data_type}, {n_cols}>, {n_rows}>"
  84. + "{{"
  85. )
  86. else:
  87. cpp_decl = (
  88. f"{constness} auto {matrix_string} = std::array<{data_type}, {n_cols}>"
  89. + "{"
  90. )
  91. rows = [[] for _ in range(n_rows)]
  92. for i_col in range(n_cols):
  93. for i_row, row in enumerate(rows):
  94. # only keep number (right side of the equality sign)
  95. entry = (
  96. matrix_entries[i_col * n_rows + i_row].split("=")[-1].strip()
  97. )
  98. row.append(entry)
  99. array_strings = ["{" + ", ".join(row) + "}" for row in rows]
  100. cpp_decl += ", ".join(array_strings) + ("}};\n" if n_rows > 1 else "};\n")
  101. print(
  102. f" {matrix+1}. {matrix_string} with {n_cols} columns and {n_rows} rows",
  103. )
  104. with open(outfile, "a") as out:
  105. out.write(cpp_decl)
  106. outfiles.append(outfile)
  107. return outfiles
  108. if __name__ == "__main__":
  109. outfiles = parse_tmva_matrix_to_array(
  110. [
  111. "neural_net_training/result/GhostNNDataSet/weights/TMVAClassification_default_forward_ghost_mlp.class.C",
  112. "neural_net_training/result/GhostNNDataSet/weights/TMVAClassification_veloUT_forward_ghost_mlp.class.C",
  113. ],
  114. )
  115. try:
  116. import subprocess
  117. # run clang-format for nicer looking result
  118. for outfile in outfiles:
  119. subprocess.run(
  120. [
  121. "clang-format",
  122. "-i",
  123. f"{outfile}",
  124. ],
  125. check=True,
  126. )
  127. except:
  128. pass