import re
import subprocess
from pathlib import Path

# flake8: noqa

# Digits directly in front of an opening bracket; in a line such as
# ``fWeightMatrix2to3[0][8] = ...`` this is the index of the target layer.
_LAYER_COUNT_RE = re.compile(r"(\d+)\[")
# Every bracketed integer index, e.g. ``[0][8]`` -> ``["0", "8"]``.
_INDEX_RE = re.compile(r"\[(\d+)\]")


def _assigned_values(lines: list[str], marker: str) -> list[str]:
    """Return the right-hand sides of the assignment lines containing ``marker``."""
    return [
        line.replace(";", "").split("=")[-1].strip()
        for line in lines
        # declarations ("double ...") and the scaling code ("Scal...") are not data
        if marker in line and "Scal" not in line and "double" not in line
    ]


def _flat_array_decl(
    constness: str, name: str, data_type: str, values: list[str]
) -> str:
    """Format ``values`` as a one-dimensional ``std::array`` declaration."""
    return (
        f"{constness} auto {name} = std::array<{data_type}, {len(values)}>"
        + "{{"
        + ", ".join(values)
        + "}};\n"
    )


def _matrix_decl(
    constness: str, data_type: str, matrix_entries: list[str]
) -> tuple[str, str, int, int]:
    """Format the entries of one weight matrix as a ``std::array`` declaration.

    Returns the declaration, the matrix name and the number of rows and columns.
    """
    # the entries are sorted, so the last one carries the largest indices;
    # the actual size is the last index + 1
    dims = _INDEX_RE.findall(matrix_entries[-1])
    n_rows = int(dims[0]) + 1
    n_cols = int(dims[1]) + 1
    # all entries belong to the same matrix, so the first one provides the name
    matrix_string = matrix_entries[0].split("=")[0].split("[")[0]

    # the first index runs fastest in the input, i.e. the entry for
    # (i_row, i_col) is found at position i_col * n_rows + i_row;
    # only the number on the right side of the equality sign is kept
    rows = [
        [
            matrix_entries[i_col * n_rows + i_row].split("=")[-1].strip()
            for i_col in range(n_cols)
        ]
        for i_row in range(n_rows)
    ]
    array_strings = ["{" + ", ".join(row) + "}" for row in rows]

    if n_rows > 1:
        cpp_type = f"std::array<std::array<{data_type}, {n_cols}>, {n_rows}>"
        opening, closing = "{{", "}};\n"
    else:
        # a single row is written as a plain one-dimensional array
        cpp_type = f"std::array<{data_type}, {n_cols}>"
        opening, closing = "{", "};\n"
    cpp_decl = (
        f"{constness} auto {matrix_string} = {cpp_type}"
        + opening
        + ", ".join(array_strings)
        + closing
    )
    return cpp_decl, matrix_string, n_rows, n_cols


def parse_tmva_matrix_to_array(
    input_class_files: list[str],
    simd_type: bool = False,
    outdir: str = "nn_electron_training",
) -> list[str]:
    """Transform TMVA output MLP C++ classes into a more modern form.

    For every input file the minimum/maximum values of the input variables
    and all weight matrices are extracted and written as ``std::array``
    declarations to a ``hpp`` file.

    Args:
        input_class_files: Paths to the .C MLP classes created by TMVA.
        simd_type: If true, the element type of the arrays is set to
            ``simd::float_v`` (declared ``const``), otherwise ``float``
            (declared ``constexpr``). Defaults to False.
        outdir: Base directory of the output. Defaults to
            "nn_electron_training".

    Returns:
        Paths to the resulting C++ files, one per input file, in input order.

    Raises:
        IndexError: If an input file does not contain the expected
            ``fWeightMatrix`` assignments.

    Note:
        The created C++ code is written to `<outdir>/result/<name>.hpp`,
        where `<name>` is the middle part of the input file name (everything
        between the first and the last underscore-separated piece of the
        stem). The directory is created if it does not exist and an existing
        file is overwritten.
    """
    data_type = "simd::float_v" if simd_type else "float"
    # TODO: as of writing this code, constexpr is not supported for the SIMD types
    constness = "const" if simd_type else "constexpr"

    outfiles = []
    for input_class_file in input_class_files:
        print(f"Transforming {input_class_file} ...")
        input_path = Path(input_class_file)
        with open(input_path) as f:
            lines = f.readlines()

        # the name of the outputfile is the middle part of the input class file name
        outfile = (
            outdir + "/result/" + "_".join(input_path.stem.split("_")[1:-1]) + ".hpp"
        )

        # this only supports category 2 for the transformation,
        # which is the largest/smallest for fMax_1/fMin_1
        min_lines = _assigned_values(lines, "fMin_1[2]")
        max_lines = _assigned_values(lines, "fMax_1[2]")
        declarations = [
            _flat_array_decl(constness, "fMin", data_type, min_lines),
            _flat_array_decl(constness, "fMax", data_type, max_lines),
        ]
        print(f"Found minimum and maximum values for {len(min_lines)} variables.")

        # this list contains all lines that define a matrix element
        matrix_lines = [
            line.replace(";", "").strip()
            for line in lines
            if "fWeightMatrix" in line
            and "double" not in line
            and "*" not in line
            and "= fWeightMatrix" not in line
        ]
        # there are several matrices, the last line names the last target layer
        n_matrices = int(_LAYER_COUNT_RE.findall(matrix_lines[-1])[0])
        print(f"Found {n_matrices} matrices: ")
        for matrix in range(n_matrices):
            # match the full name including the bracket, otherwise e.g.
            # "0to1" would also select the entries of "10to11"
            prefix = f"fWeightMatrix{matrix}to{matrix + 1}["
            matrix_entries = [m for m in matrix_lines if prefix in m]
            cpp_decl, matrix_string, n_rows, n_cols = _matrix_decl(
                constness, data_type, matrix_entries
            )
            print(
                f" {matrix+1}. {matrix_string} with {n_cols} columns and {n_rows} rows",
            )
            declarations.append(cpp_decl)

        # make sure the target directory exists before writing
        Path(outfile).parent.mkdir(parents=True, exist_ok=True)
        with open(outfile, "w") as out:
            out.writelines(declarations)
        outfiles.append(outfile)
    return outfiles


if __name__ == "__main__":
    outfiles = parse_tmva_matrix_to_array(
        [
            "nn_electron_training/result/GhostNNDataSet/weights/TMVAClassification_default_forward_ghost_mlp.class.C",
            "nn_electron_training/result/GhostNNDataSet/weights/TMVAClassification_veloUT_forward_ghost_mlp.class.C",
        ],
    )
    # run clang-format for nicer looking result; this is optional, so a
    # missing or failing clang-format must not abort the script
    try:
        for outfile in outfiles:
            subprocess.run(
                [
                    "clang-format",
                    "-i",
                    outfile,
                ],
                check=True,
            )
    except (OSError, subprocess.CalledProcessError) as err:
        print(f"Skipping clang-format: {err}")