Analysis Software
Documentation for sPHENIX simulation software
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
train.C
Go to the documentation of this file. Or view the newest version in sPHENIX GitHub for file train.C
#include <iostream>
#include <sstream>
#include <string>

#include <TChain.h>
#include <TFile.h>
#include <TString.h>
#include <TTree.h>

#if not defined(__CINT__) || defined(__MAKECINT__)
// needs to be included when makecint runs (ACLIC)
#include "TMVA/Factory.h"
#include "TMVA/Tools.h"
#endif
12 
13 
14 
15 TChain* handleFile(string name, string extension, string treename, int filecount){
16  TChain *all = new TChain(treename.c_str());
17  string temp;
18  for (int i = 0; i < filecount; ++i)
19  {
20 
21  ostringstream s;
22  s<<i;
23  temp = name+string(s.str())+extension;
24  all->Add(temp.c_str());
25  }
26  return all;
27 }
28 
29 
30 void makeFactory(TTree* signalTree, TTree* backTree,std::string outfile,std::string factoryname, TTree* bgTree2=NULL)
31 {
32  using namespace TMVA;
33  TString jobname(factoryname.c_str());
34  TFile *targetFile = new TFile(outfile.c_str(),"RECREATE");
35  Factory *factory = new Factory(jobname,targetFile);
36  factory->AddSignalTree(signalTree,1.0);
37  factory->AddBackgroundTree(backTree,1.0);
38  if(bgTree2){
39  factory->AddBackgroundTree(bgTree2,1.0);
40  }
41  factory->AddSpectator("track_layer",'I');
42  factory->AddSpectator("track_pT",'F');
43  factory->AddSpectator("track_dca",'F');
44  factory->AddSpectator("cluster_prob",'F');
45  factory->AddSpectator("abs(track_deta)",'F');
46  factory->AddSpectator("abs(cluster_deta)",'F');
47  factory->AddSpectator("abs(cluster_dphi)",'F');
48  factory->AddSpectator("abs(track_dlayer)",'I');
49  factory->AddSpectator("approach_dist",'F');
50  factory->AddVariable("vtx_radius",'F');
51 // factory->AddVariable("vtx_chi2",'F');
52  //factory->AddVariable("vtxTrackRZ_dist",'F');
53  //factory->AddVariable("abs(vtxTrackRPhi_dist-vtxTrackRZ_dist)",'F');
54  factory->AddVariable("photon_m",'F');
55  factory->AddVariable("photon_pT",'F');
56 
57  string track_layer_cut = "track_layer>-1.";
58  string track_pT_cut = "track_pT>2.0";
59  string track_dca_cut = "50>track_dca>0";
60  string em_prob_cut = "cluster_prob>-0.1";
61  string track_deta_cut = ".0082>=track_deta";
62  string track_dlayer_cut = "2>=abs(track_dlayer)";
63  string approach_dist_cut = "69.34>approach_dist>0";
64  string vtx_radius_cut = "vtx_radius>0";
65  //do I need photon cuts?
66  string tCutInitializer = track_pT_cut+"&&"+em_prob_cut+"&&"+track_layer_cut;//+"&&"+track_deta_cut+"&&"+track_dlayer_cut+"&&"+approach_dist_cut+"&&"+vtx_radius_cut;
67  TCut preTraingCuts(tCutInitializer.c_str());
68 
69  factory->PrepareTrainingAndTestTree(preTraingCuts,"nTrain_Signal=0:nTrain_Background=0:nTest_Signal=0:nTest_Background=0");
70  //for track training
71  //factory->BookMethod(Types::kCuts,"Cuts");
72  //for pair training
73  //factory->BookMethod(Types::kCuts,"Cuts","CutRangeMin[0]=0:CutRangeMax[0]=1:CutRangeMin[1]=-100:CutRangeMax[1]=100:CutRangeMin[2]=0:CutRangeMax[2]=100");
74  //for vtx training
75  factory->BookMethod(Types::kCuts,"Cuts");
76  factory->TrainAllMethods();
77  factory->TestAllMethods();
78  factory->EvaluateAllMethods();
79  targetFile->Write();
80  targetFile->Close();
81 }
82 
83 
84 int train(){
85  using namespace std;
86  string treePath = "/sphenix/user/vassalli/RecoConversionTests/truthconversionembededonlineanalysis";
87  string treeExtension = ".root";
88  string outname = "cutTrainA.root";
89  unsigned int nFiles=100;
90 
91  /*TChain *backVtxTree = new TChain("vtxBackTree");
92  TChain *signalTree = new TChain("cutTreeSignal");
93  backVtxTree->Add(treePath.c_str());
94  signalTree->Add(treePath.c_str());*/
95  TChain *signalTree = handleFile(treePath,treeExtension,"cutTreeSignal",nFiles);
96  TChain *backtrackTree = handleFile(treePath,treeExtension,"trackBackTree",nFiles);
97  TChain *backpairTree = handleFile(treePath,treeExtension,"pairBackTree",nFiles);
98  TChain *backVtxTree = handleFile(treePath,treeExtension,"vtxBackTree",nFiles);
99  //makeFactory(signalTree,backtrackTree,outname,"trackback");
100  //makeFactory(signalTree,backpairTree,outname,"pairback");
101  makeFactory(signalTree,backVtxTree,outname,"vtxback");
102 /* outname="cutTrainE.root";
103  makeFactory(signalTree,backETree,outname,"eback");*/
104 }