20 lines
840 B
Python
20 lines
840 B
Python
|
import numpy as np
|
||
|
|
||
|
from hep_ml.reweight import BinsReweighter
|
||
|
|
||
|
|
||
|
def select_feature(feature: np.ndarray, limits: tuple[float, float]) -> tuple[np.ndarray, list]:
    """Keep only the entries of ``feature`` lying strictly inside ``limits``.

    Args:
        feature: one-dimensional array of values. Any array-like (e.g. a list)
            is accepted; it is converted with ``np.asarray``.
        limits: ``(low, high)`` pair. An entry is kept when
            ``low < value < high``. Both bounds are exclusive, and NaN entries
            are never kept.

    Returns:
        A tuple ``(selected, indices)``:

        - ``selected`` is the array of kept values, in their original order.
        - ``indices`` is a list of the integer positions of those values in
          ``feature``. You can use it to apply the same selection to a
          parallel array.

    Raises:
        ValueError: if ``feature`` is not one-dimensional.
    """
    values = np.asarray(feature)
    if values.ndim != 1:
        raise ValueError(f"select_feature expects a 1-D feature, got shape {values.shape}")

    low, high = limits[0], limits[1]
    # One vectorised strict-inequality mask instead of a per-element Python loop.
    # Comparisons with NaN evaluate to False, so NaN entries are excluded.
    mask = (values > low) & (values < high)
    # .tolist() yields plain Python ints, matching the previous enumerate-based indices.
    selection_indices = np.flatnonzero(mask).tolist()
    return values[mask], selection_indices
|
||
|
|
||
|
|
||
|
def reweight_feature(original_feature: list, target_feature: list, n_bins: int, n_neighs: int = 2):
    """Compute weights for ``original_feature`` with a hep_ml ``BinsReweighter``.

    A ``BinsReweighter`` is built with the given binning parameters. It is
    fitted on the original and target samples, with every original entry
    starting at weight 1. The weights it then predicts for the original sample
    are returned.

    Args:
        original_feature: sample to be reweighted.
        target_feature: sample whose distribution is the reweighting target.
        n_bins: value forwarded as ``BinsReweighter(n_bins=...)``.
        n_neighs: value forwarded as ``BinsReweighter(n_neighs=...)``.

    Returns:
        The result of ``BinsReweighter.predict_weights`` for
        ``original_feature``. This is presumably one weight per original entry;
        confirm against the hep_ml docs.
    """
    n_entries = len(original_feature)
    # Uniform starting weights. The same array is passed to both fit and predict.
    unit_weights = np.ones(n_entries)

    reweighter = BinsReweighter(n_bins=n_bins, n_neighs=n_neighs)
    reweighter.fit(
        original=original_feature,
        target=target_feature,
        original_weight=unit_weights,
    )

    reweighted = reweighter.predict_weights(
        original=original_feature,
        original_weight=unit_weights,
    )
    return reweighted
|