tracking-parametrisation-tuner/parameterisations/utils/parse_tmva_matrix_to_array_electron.py
2023-12-19 13:00:59 +01:00

143 lines
5.4 KiB
Python

import re
from pathlib import Path
# flake8: noqa
# Matches the layer indices in a weight matrix element name, e.g. "fWeightMatrix0to1[".
_WEIGHT_MATRIX_NAME = re.compile(r"fWeightMatrix(\d+)to(\d+)\[")
# Matches one bracketed element index, e.g. "[3]" in "fWeightMatrix0to1[3][5]".
_ELEMENT_INDEX = re.compile(r"\[(\d+)\]")


def _transformation_values(lines: list[str], array_name: str) -> list[str]:
    """Return the right-hand sides of all category-2 assignments to ``array_name``.

    Only category 2 of ``fMin_1``/``fMax_1`` is supported. Declarations
    (lines containing ``double``) and lines mentioning ``Scal`` are skipped.
    """
    key = f"{array_name}[2]"
    return [
        line.replace(";", "").split("=")[-1].strip()
        for line in lines
        if key in line and "Scal" not in line and "double" not in line
    ]


def _cpp_array_decl(
    constness: str, name: str, data_type: str, rows: list[list[str]]
) -> str:
    """Return a C++ ``std::array`` declaration initialised with ``rows``.

    A single row yields a 1D ``std::array``; several rows yield an array of arrays.
    """
    n_cols = len(rows[0])
    if len(rows) > 1:
        array_type = f"std::array<std::array<{data_type}, {n_cols}>, {len(rows)}>"
        init = "{" + ", ".join("{" + ", ".join(row) + "}" for row in rows) + "}"
    else:
        array_type = f"std::array<{data_type}, {n_cols}>"
        init = "{" + ", ".join(rows[0]) + "}"
    # extra pair of braces for the aggregate initialisation of std::array
    return f"{constness} auto {name} = {array_type}{{{init}}};\n"


def _weight_matrix_decls(lines: list[str], constness: str, data_type: str) -> list[str]:
    """Return one C++ declaration per ``fWeightMatrixXto(X+1)`` matrix in ``lines``.

    Raises:
        ValueError: If no weight matrix elements are found, or a matrix is incomplete.
    """
    # keep only element assignments: skip declarations, products and aliases
    matrix_lines = [
        line.replace(";", "").strip()
        for line in lines
        if "fWeightMatrix" in line
        and "double" not in line
        and "*" not in line
        and "= fWeightMatrix" not in line
    ]
    last_match = _WEIGHT_MATRIX_NAME.search(matrix_lines[-1]) if matrix_lines else None
    if last_match is None:
        raise ValueError("no fWeightMatrix element assignments found")
    # the last element belongs to the last matrix; its target layer index is the matrix count
    n_matrices = int(last_match.group(2))
    print(f"Found {n_matrices} matrices: ")
    decls = []
    for matrix in range(n_matrices):
        # full name plus "[" so that e.g. 1to2 does not also match 21to22 or 1to23
        prefix = f"fWeightMatrix{matrix}to{matrix + 1}["
        entries = [m for m in matrix_lines if prefix in m]
        if not entries:
            raise ValueError(f"no elements found for {prefix[:-1]}")
        # the last element carries the largest indices; the actual size is index + 1
        last_row, last_col = (int(i) for i in _ELEMENT_INDEX.findall(entries[-1])[:2])
        n_rows, n_cols = last_row + 1, last_col + 1
        if len(entries) < n_rows * n_cols:
            raise ValueError(
                f"{prefix[:-1]} has {len(entries)} elements, expected {n_rows * n_cols}"
            )
        # all elements share the same name, so take it from the first one
        name = entries[0].split("=")[0].split("[")[0].strip()
        # elements are expected column by column (row index runs fastest)
        rows = [
            [
                entries[i_col * n_rows + i_row].split("=")[-1].strip()
                for i_col in range(n_cols)
            ]
            for i_row in range(n_rows)
        ]
        decls.append(_cpp_array_decl(constness, name, data_type, rows))
        print(f" {matrix + 1}. {name} with {n_cols} columns and {n_rows} rows")
    return decls


def parse_tmva_matrix_to_array(
    input_class_files: list[str],
    simd_type: bool = False,
    outdir: str = "nn_electron_training",
) -> list[str]:
    """Transform TMVA MLP C++ classes into headers with ``std::array`` definitions.

    For every input class, two things are extracted and written as
    ``std::array`` declarations:

    * the min/max values of the input transformation (category 2 of
      ``fMin_1``/``fMax_1``), written as ``fMin`` and ``fMax``;
    * all ``fWeightMatrixXtoY`` weight matrices.

    Args:
        input_class_files (list[str]): Paths to the ``.C`` MLP classes created by TMVA.
        simd_type (bool, optional): If true, the array element type is
            ``simd::float_v`` and the arrays are declared ``const`` instead of
            ``constexpr``. Defaults to False.
        outdir (str, optional): Base directory. The headers are written to
            ``{outdir}/result``, which is created if missing.
            Defaults to "nn_electron_training".

    Returns:
        list[str]: Paths to the generated ``.hpp`` files, one per input class.

    Raises:
        ValueError: If an input class contains no (or incomplete) weight matrices.

    Note:
        The output file name is the middle part of the input file name, e.g.
        ``TMVAClassification_default_forward_ghost_mlp.class.C`` becomes
        ``default_forward_ghost.hpp``.
    """
    data_type = "float" if not simd_type else "simd::float_v"
    # TODO: as of writing this code, constexpr is not supported for the SIMD types
    constness = "constexpr" if not simd_type else "const"
    outfiles = []
    for input_class_file in input_class_files:
        print(f"Transforming {input_class_file} ...")
        input_path = Path(input_class_file)
        with open(input_path, encoding="utf-8") as f:
            lines = f.readlines()
        # the name of the output file is the middle part of the input class file name
        out_name = "_".join(input_path.stem.split("_")[1:-1])
        out_path = Path(outdir) / "result" / f"{out_name}.hpp"

        min_values = _transformation_values(lines, "fMin_1")
        max_values = _transformation_values(lines, "fMax_1")
        print(f"Found minimum and maximum values for {len(min_values)} variables.")
        decls = [
            _cpp_array_decl(constness, "fMin", data_type, [min_values]),
            _cpp_array_decl(constness, "fMax", data_type, [max_values]),
            *_weight_matrix_decls(lines, constness, data_type),
        ]

        out_path.parent.mkdir(parents=True, exist_ok=True)
        # everything is parsed before writing, so an error leaves no partial header
        with open(out_path, "w", encoding="utf-8") as out:
            out.writelines(decls)
        outfiles.append(str(out_path))
    return outfiles
if __name__ == "__main__":
    import subprocess

    outfiles = parse_tmva_matrix_to_array(
        [
            "nn_electron_training/result/GhostNNDataSet/weights/TMVAClassification_default_forward_ghost_mlp.class.C",
            "nn_electron_training/result/GhostNNDataSet/weights/TMVAClassification_veloUT_forward_ghost_mlp.class.C",
        ],
    )
    # run clang-format for nicer looking result; this is purely cosmetic, so
    # problems are reported but do not abort the script
    for outfile in outfiles:
        try:
            subprocess.run(["clang-format", "-i", outfile], check=True)
        except FileNotFoundError:
            # clang-format is not installed: no point in trying the other files
            print("clang-format not found, output files are left unformatted.")
            break
        except subprocess.CalledProcessError as err:
            print(f"clang-format failed on {outfile} with exit code {err.returncode}.")