143 lines
5.4 KiB
Python
143 lines
5.4 KiB
Python
import re
|
|
from pathlib import Path
|
|
|
|
# flake8: noqa
|
|
|
|
|
|
# Matches the number directly in front of the first index bracket, e.g. the
# "3" in "fWeightMatrix2to3[0][12]"; on the last weight line this is the
# number of matrices.
_LAST_LAYER_RE = re.compile(r"(\d+)\[")
# Matches every bracketed integer index, e.g. "0" and "12" in "...[0][12]".
_INDEX_RE = re.compile(r"\[(\d+)\]")


def _bound_values(lines: list[str], tag: str) -> list[str]:
    """Return the right-hand sides of all plain assignments containing *tag*."""
    return [
        line.replace(";", "").split("=")[-1].strip()
        for line in lines
        # skip the declarations ("double ...") and the scale computations
        if tag in line and "Scal" not in line and "double" not in line
    ]


def _matrix_declaration(
    entries: list[str], constness: str, data_type: str
) -> tuple[str, int, int, str]:
    """Return ``(name, n_rows, n_cols, declaration)`` for one weight matrix."""
    # figure out the dimensions of the matrix by checking the largest index,
    # which is found on the last entry; the actual size is last index + 1
    indices = _INDEX_RE.findall(entries[-1])
    if len(indices) < 2:
        raise ValueError(f"Cannot read matrix indices from '{entries[-1]}'")
    n_rows = int(indices[0]) + 1
    n_cols = int(indices[1]) + 1
    if len(entries) < n_rows * n_cols:
        raise ValueError(
            f"Expected {n_rows * n_cols} entries for a {n_rows}x{n_cols} matrix, "
            f"found {len(entries)}"
        )
    # every entry of this matrix carries the same name, so the first one will do
    name = entries[0].split("=")[0].split("[")[0]
    # entries are assumed to be listed with the first (row) index varying
    # fastest; only the number right of the equality sign is kept
    rows = [
        [
            entries[i_col * n_rows + i_row].split("=")[-1].strip()
            for i_col in range(n_cols)
        ]
        for i_row in range(n_rows)
    ]
    body = ", ".join("{" + ", ".join(row) + "}" for row in rows)
    if n_rows > 1:
        declaration = (
            f"{constness} auto {name} = std::array<std::array<{data_type}, {n_cols}>, {n_rows}>"
            + "{{"
            + body
            + "}};\n"
        )
    else:
        # a single row is emitted as a flat array
        declaration = (
            f"{constness} auto {name} = std::array<{data_type}, {n_cols}>"
            + "{"
            + body
            + "};\n"
        )
    return name, n_rows, n_cols, declaration


def parse_tmva_matrix_to_array(
    input_class_files: list[str],
    simd_type: bool = False,
    outdir: str = "nn_electron_training",
) -> list[str]:
    """Function to transform the TMVA output MLP C++ classes into a more modern form.

    For every input file the ``fMin_1[2]`` / ``fMax_1[2]`` assignments and all
    ``fWeightMatrix<i>to<i+1>`` element assignments are collected and written
    out again as ``std::array`` declarations.

    Args:
        input_class_files (list[str]): Paths to the .C MLP classes created by TMVA.
        simd_type (bool, optional): If true, type in array is set to simd::float_v
            and the arrays are declared ``const`` instead of ``constexpr``.
            Defaults to False.
        outdir (str, optional): Base directory of the output; files are written
            to ``<outdir>/result``. Defaults to "nn_electron_training".

    Returns:
        (list[str]) : Paths to the resulting C++ files containing the matrices,
        one per input file and in the same order.

    Raises:
        ValueError: If an input file contains no weight matrix entries, or the
            entries of a matrix are missing or incomplete.

    Note:
        The created C++ code is written to a `hpp` file in `<outdir>/result`;
        the directory is created if it does not exist. The file name is the
        middle part of the input class file name (everything between the first
        and the last underscore-separated token of the stem).

    """
    data_type = "simd::float_v" if simd_type else "float"
    # TODO: as of writing this code, constexpr is not supported for the SIMD types
    constness = "const" if simd_type else "constexpr"
    outfiles = []
    for input_class_file in input_class_files:
        print(f"Transforming {input_class_file} ...")
        class_path = Path(input_class_file)
        with open(class_path) as f:
            lines = f.readlines()
        # the name of the outputfile is the middle part of the input class file name
        outfile = (
            outdir + "/result/" + "_".join(class_path.stem.split("_")[1:-1]) + ".hpp"
        )
        # this only supports category 2 for the transformation, which is the largest/smallest for fMax_1/fMin_1
        min_values = _bound_values(lines, "fMin_1[2]")
        max_values = _bound_values(lines, "fMax_1[2]")
        declarations = [
            f"{constness} auto fMin = std::array<{data_type}, {len(min_values)}>"
            + "{{"
            + ", ".join(min_values)
            + "}};\n",
            f"{constness} auto fMax = std::array<{data_type}, {len(max_values)}>"
            + "{{"
            + ", ".join(max_values)
            + "}};\n",
        ]
        print(f"Found minimum and maximum values for {len(min_values)} variables.")

        # this list contains all lines that define a matrix element
        matrix_lines = [
            line.replace(";", "").strip()
            for line in lines
            if "fWeightMatrix" in line
            and "double" not in line
            and "*" not in line
            and "= fWeightMatrix" not in line
        ]
        if not matrix_lines:
            raise ValueError(f"No weight matrix entries found in {input_class_file}")

        # there are several matrices, figure out how many and loop accordingly
        layer_numbers = _LAST_LAYER_RE.findall(matrix_lines[-1])
        if not layer_numbers:
            raise ValueError(
                f"Cannot determine the number of matrices from '{matrix_lines[-1]}'"
            )
        n_matrices = int(layer_numbers[0])
        print(f"Found {n_matrices} matrices: ")
        for matrix in range(n_matrices):
            # get only entries for corresponding matrix; the full name plus the
            # opening bracket keeps e.g. "1to2" from also matching "11to12"
            tag = f"fWeightMatrix{matrix}to{matrix + 1}["
            matrix_entries = [m for m in matrix_lines if tag in m]
            if not matrix_entries:
                raise ValueError(
                    f"No entries for matrix {matrix}to{matrix + 1} in {input_class_file}"
                )
            matrix_string, n_rows, n_cols, cpp_decl = _matrix_declaration(
                matrix_entries, constness, data_type
            )
            print(
                f" {matrix+1}. {matrix_string} with {n_cols} columns and {n_rows} rows",
            )
            declarations.append(cpp_decl)

        # write everything in one go so that a parsing failure never leaves a
        # partially written header behind
        Path(outfile).parent.mkdir(parents=True, exist_ok=True)
        with open(outfile, "w") as out:
            out.writelines(declarations)
        outfiles.append(outfile)
    return outfiles
|
|
|
|
|
|
if __name__ == "__main__":
    # only needed when run as a script, so it is imported here
    import subprocess

    # convert the two TMVA MLP classes and collect the generated header paths
    outfiles = parse_tmva_matrix_to_array(
        [
            "nn_electron_training/result/GhostNNDataSet/weights/TMVAClassification_default_forward_ghost_mlp.class.C",
            "nn_electron_training/result/GhostNNDataSet/weights/TMVAClassification_veloUT_forward_ghost_mlp.class.C",
        ],
    )

    # run clang-format for nicer looking result; this is best effort only, so a
    # missing or failing clang-format is reported instead of aborting the script
    try:
        for outfile in outfiles:
            subprocess.run(
                [
                    "clang-format",
                    "-i",
                    f"{outfile}",
                ],
                check=True,
            )
    except (OSError, subprocess.CalledProcessError) as err:
        print(f"Skipping clang-format: {err}")
|