import glob

import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from sklearn.preprocessing import LabelEncoder, StandardScaler, OrdinalEncoder
from ambiguity_solver_network import prepareDataSet, DuplicateClassifier, Normalise
# Running sums of the per-event means and variances of the numerical input
# features (8 of them), later averaged to build the network normalisation
avg_mean = [0, 0, 0, 0, 0, 0, 0, 0]
avg_sdv = [0, 0, 0, 0, 0, 0, 0, 0]
events = 0
21 """Read the dataset from the different files, remove the pure duplicate tracks and combine the datasets"""
23 @param[in] CKS_files: DataFrame contain the data from each track files (1 file per events usually)
24 @return: combined DataFrame containing all the track, ordered by events and then by truth particle ID in each events
    data = pd.DataFrame()
    for f in CKS_files:
        datafile = pd.read_csv(f)
        # At this point no distinction is made between fake and duplicate tracks
        datafile.loc[
            datafile["good/duplicate/fake"] == "fake", "good/duplicate/fake"
        ] = "duplicate"
        datafile = prepareDataSet(datafile)
        # Combine the datasets
        data = pd.concat([data, datafile])
    return data
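
# A minimal illustration of the fake -> duplicate relabelling performed above,
# on a toy DataFrame (hypothetical values, not part of the training flow):
#
#   df = pd.DataFrame({"good/duplicate/fake": ["good", "fake", "duplicate"]})
#   df.loc[df["good/duplicate/fake"] == "fake", "good/duplicate/fake"] = "duplicate"
#   # df["good/duplicate/fake"] is now ["good", "duplicate", "duplicate"]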
41 """Prepare the data"""
43 @param[in] data: input DataFrame to be prepared
44 @return: array of the network input and the corresponding truth
    # Separate the truth from the input variables
    target_column = "good/duplicate/fake"
    y = LabelEncoder().fit(data[target_column]).transform(data[target_column])
56 "truthMatchProbability",
    # Fit the normalisation on the numerical input features
    scale = StandardScaler()
    scale.fit(input.select_dtypes("number"))
    # Accumulate the per-event normalisation parameters in the module globals
    global avg_mean, avg_sdv, events
    avg_mean = avg_mean + scale.mean_
    avg_sdv = avg_sdv + scale.var_
    events = events + 1
    # Encode the categorical features and concatenate them with the numerical ones
    x_cat = OrdinalEncoder().fit_transform(input.select_dtypes("object"))
    x = np.concatenate((x_cat, input.select_dtypes("number")), axis=1)
    return x, y
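
# For reference, LabelEncoder orders the classes alphabetically, so after the
# fake -> duplicate relabelling the truth is binary (hypothetical example):
#
#   LabelEncoder().fit(["good", "duplicate"]).transform(["good", "duplicate"])
#   # -> array([1, 0]); "duplicate" maps to 0 and "good" to 1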


def batchSplit(
    data: tuple[np.ndarray, np.ndarray, np.ndarray], batch_size: int
) -> list[tuple[np.ndarray, np.ndarray, np.ndarray]]:
79 """Split the data into batch each containing @batch_size truth particles (the number of corresponding tracks may vary)"""
81 @param[in] data: input DataFrame to be cut into batch
82 @param[in] batch_size: Number of truth particles per batch
83 @return: list of DataFrame, each element correspond to a batch
    batch = []
    pid = data[0][0]
    n_particle = 0
    id_prev = 0
    id = 0
    for index, row, truth in zip(data[0], data[1], data[2]):
        # Count the truth particles seen so far; cut a batch once batch_size is reached
        if index != pid:
            pid = index
            n_particle += 1
            if n_particle == batch_size:
                b = data[0][id_prev:id], data[1][id_prev:id], data[2][id_prev:id]
                batch.append(b)
                n_particle = 0
                id_prev = id
        id += 1
    return batch
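
# Sketch of the splitting semantics on toy values (hypothetical, for
# illustration only):
#
#   ids = np.array([1, 1, 2, 2, 3])
#   x = np.zeros((5, 2))
#   y = np.array([1, 0, 1, 0, 1])
#   batchSplit((ids, x, y), batch_size=2)
#   # -> one batch holding the four tracks of particles 1 and 2; the trailing
#   #    incomplete batch (particle 3) is never appended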


def computeLoss(
    score_good: torch.Tensor,
    score_duplicate: list[torch.Tensor],
    batch_loss: torch.Tensor,
    margin: float = 0.05,
) -> torch.Tensor:
109 """Compute one loss for each duplicate track associated with the particle"""
111 @param[in] score_good: score return by the model for the good track associated with this particle
112 @param[in] score_duplicate: list of the scores of all duplicate track associated with this particle
113 @param[in] margin: Margin used in the computation of the MarginRankingLoss
114 @return: return the updated loss
    # MarginRankingLoss with respect to the good track's score: each duplicate
    # should score at least @margin below the good track
    if score_duplicate:
        for s in score_duplicate:
            batch_loss += F.relu(s - score_good + margin) / len(score_duplicate)
    return batch_loss
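
# The loop above averages a margin ranking penalty over the duplicates. A
# minimal sketch of the equivalence with torch's built-in loss (hypothetical
# tensors, not part of the training flow):
#
#   s_good = torch.tensor([0.9, 0.9])
#   s_dup = torch.tensor([0.8, 0.92])
#   F.relu(s_dup - s_good + 0.05).mean()
#   # == nn.MarginRankingLoss(margin=0.05)(s_good, s_dup, torch.ones(2))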


def scoringBatch(
    batch: list[tuple[np.ndarray, np.ndarray, np.ndarray]], Optimiser=None
) -> tuple[int, int, float]:
125 """Run the MLP on a batch and compute the corresponding efficiency and loss. If an optimiser is specified train the MLP."""
127 @param[in] batch: list of DataFrame, each element correspond to a batch
128 @param[in] Optimiser: Optimiser for the MLP, if one is specify the network will be train on batch.
129 @return: array containing the number of particles, the number of particle where the good track was found and the loss
    nb_part = 0
    nb_good_match = 0
    loss = 0
    # Best track score seen for the current particle, and whether it belongs to the good track
    max_score = 0
    max_match = 0
    for b_data in batch:
        # ID of the current particle
        pid = b_data[0][0]
        batch_loss = 0
        score_good = 0
        score_duplicate = []
        if Optimiser:
            Optimiser.zero_grad()
        input = torch.tensor(b_data[1], dtype=torch.float32)
        prediction = duplicateClassifier(input)
        # Loop over all the tracks in the batch
        for index, pred, truth in zip(b_data[0], prediction, b_data[2]):
            # When switching to a new particle, update the loss and the counters
            if index != pid:
                if max_match == 1:
                    nb_good_match += 1
                batch_loss = computeLoss(
                    score_good, score_duplicate, batch_loss, margin=0.05
                )
                nb_part += 1
                # Reinitialise the variables for the next particle
                pid = index
                score_duplicate = []
                score_good = 0
                max_score = 0
                max_match = 0
            # Store the scores of the good and duplicate tracks
            if truth:
                score_good = pred
            else:
                score_duplicate.append(pred)
            # Keep track of the best score for the efficiency computation
            if pred == max_score:
                max_match = truth
            if pred > max_score:
                max_score = pred
                max_match = truth
        # Account for the last particle of the batch
        if max_match == 1:
            nb_good_match += 1
        batch_loss = computeLoss(score_good, score_duplicate, batch_loss, margin=0.05)
        nb_part += 1
        # Normalise the loss to the size of the batch
        batch_loss = batch_loss / len(b_data[0])
        loss += batch_loss.item()
        # Perform the gradient descent if an optimiser was specified
        if Optimiser:
            batch_loss.backward()
            Optimiser.step()
    loss = loss / len(batch)
    return nb_part, nb_good_match, loss
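
# scoringBatch serves both phases (sketch mirroring the calls made in train()
# below; `batches`, `n_train`, and `opt` are hypothetical names):
#
#   nb_part, nb_good, loss = scoringBatch(batches[:n_train], Optimiser=opt)  # train
#   nb_part, nb_good, loss = scoringBatch(batches[n_train:])                 # evaluate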


def train(
    duplicateClassifier: DuplicateClassifier,
    data: tuple[np.ndarray, np.ndarray, np.ndarray],
    epochs: int,
    batch: int,
    validation: float = 0.3,
) -> DuplicateClassifier:
206 """Training of the MLP"""
208 @param[in] duplicateClassifier: model to be trained.
209 @param[in] data: tuple containing three list. Each element of those list correspond to a given track and represent : the truth particle ID, the track parameters and the truth.
210 @param[in] epochs: number of epoch the model will be trained for.
211 @param[in] batch: size of the batch used in the training
212 @param[in] validation: Fraction of the batch used in training
213 @return: trained model
    # Prepare tensorboard logging of the training metrics
    writer = SummaryWriter()
    opt = torch.optim.Adam(duplicateClassifier.parameters())
    # Split the data into batches and reserve the last fraction for validation
    batch = batchSplit(data, batch)
    val_batch = int(len(batch) * (1 - validation))
    for epoch in range(epochs):
        print("Epoch : ", epoch, " / ", epochs)
        # Train the network on the training batches
        nb_part, nb_good_match, loss = scoringBatch(batch[:val_batch], Optimiser=opt)
        print("Loss/train : ", loss, " Eff/train : ", nb_good_match / nb_part)
        writer.add_scalar("Loss/train", loss, epoch)
        writer.add_scalar("Eff/train", nb_good_match / nb_part, epoch)
        # Evaluate the network on the validation batches
        if validation > 0.0:
            nb_part, nb_good_match, loss = scoringBatch(batch[val_batch:])
            writer.add_scalar("Loss/val", loss, epoch)
            writer.add_scalar("Eff/val", nb_good_match / nb_part, epoch)
            print("Loss/val : ", loss, " Eff/val : ", nb_good_match / nb_part)
    writer.close()
    return duplicateClassifier
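
# The scalars logged during training are written by SummaryWriter to ./runs by
# default; assuming tensorboard is installed, the curves can be inspected with:
#
#   tensorboard --logdir=runs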


# Read the CKF track files used for the training (events 0 to 79)
CKF_files = sorted(glob.glob("odd_output" + "/event0000000[0-7][0-9]-tracks_ckf.csv"))
data = readDataSet(CKF_files)

# Prepare the data
x_train, y_train = prepareTrainingData(data)
# Average the normalisation parameters over all the processed events
avg_mean = [x / events for x in avg_mean]
avg_sdv = [x / events for x in avg_sdv]
# Create the model
input_dim = np.shape(x_train)[1]
layers_dim = [10, 15, 10]
duplicateClassifier = nn.Sequential(
    Normalise(avg_mean, avg_sdv), DuplicateClassifier(input_dim, layers_dim)
)
# Train the model
input = data.index, x_train, y_train
train(duplicateClassifier, input, epochs=20, batch=128, validation=0.3)
duplicateClassifier.eval()
input_test = torch.tensor(x_train, dtype=torch.float32)
# Save the model and export it to ONNX
torch.save(duplicateClassifier, "duplicateClassifier.pt")
torch.onnx.export(
    duplicateClassifier,
    input_test[0:1],
    "duplicateClassifier.onnx",
    input_names=["x"],
    output_names=["y"],
    dynamic_axes={"x": {0: "batch_size"}, "y": {0: "batch_size"}},
)

# Read the CKF track files used for the test (events 80 to 99)
CKF_files_test = sorted(
    glob.glob("odd_output" + "/event0000000[8-9][0-9]-tracks_ckf.csv")
)
test = readDataSet(CKF_files_test)

# Prepare the data
x_test, y_test = prepareTrainingData(test)
# Write the network scores to a list
output_predict = []
x_test = torch.tensor(x_test, dtype=torch.float32)
for x in x_test:
    output_predict.append(duplicateClassifier(x))
# Print the ID, the score, and the truth for the first 100 tracks
for sample_test, sample_predict, sample_true in zip(
    test.index[0:100], output_predict[0:100], y_test[0:100]
):
    print(sample_test, sample_predict, sample_true)
# Compute the efficiency over the test events
pid = test.index[0]
nb_part = 1
nb_good_match = 0
max_match = 0
max_score = 0
for index, pred, truth in zip(test.index, output_predict, y_test):
    # When switching to a new particle, check whether the good track scored best
    if index != pid:
        if max_match == 1:
            nb_good_match += 1
        pid = index
        nb_part += 1
        max_score = 0
        max_match = 0
    # Keep track of the best score for the current particle
    if pred == max_score:
        max_match = truth
    if pred > max_score:
        max_score = pred
        max_match = truth
print("nb particles : ", nb_part)
print("nb good match : ", nb_good_match)
print("Efficiency: ", 100 * nb_good_match / nb_part, " %")