Analysis Software
Documentation for sPHENIX simulation software
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
vtxPredictionTraining.C
Go to the documentation of this file, or view the newest version of vtxPredictionTraining.C on the sPHENIX GitHub.
#include <iostream>
#include <sstream>
#include <string>

#include <TChain.h>
#include <TCut.h>
#include <TFile.h>
#include <TString.h>
#include <TTree.h>

#if not defined(__CINT__) || defined(__MAKECINT__)
// needs to be included when makecint runs (ACLIC)
#include "TMVA/Factory.h"
#include "TMVA/Tools.h"
#endif
12 
13 
14 
15 TChain* handleFile(string name, string extension, string treename, int filecount){
16  TChain *all = new TChain(treename.c_str());
17  string temp;
18  for (int i = 0; i < filecount; ++i)
19  {
20 
21  ostringstream s;
22  s<<i;
23  temp = name+string(s.str())+extension;
24  all->Add(temp.c_str());
25  }
26  return all;
27 }
28 
29 
30 void makeFactory(TTree* signalTree,std::string outfile,std::string factoryname)
31 {
32  using namespace TMVA;
33  TString jobname(factoryname.c_str());
34  TFile *targetFile = new TFile(outfile.c_str(),"RECREATE");
35  Factory *factory = new Factory(jobname,targetFile,"AnalysisType=Regression");
36  factory->AddRegressionTree(signalTree,1.0);
37  factory->AddVariable("track1_pt",'F');
38  factory->AddVariable("track2_pt",'F');
39  factory->AddVariable("track1_phi",'F');
40  factory->AddVariable("track1_phi-track2_phi","d#phi","rad");
41  factory->AddVariable("track1_eta",'F');
42  factory->AddVariable("track1_eta-track2_eta","d#eta","rad");
43  factory->AddVariable("vtx_radius","radius","[cm]");
44  factory->AddTarget("tvtx_radius","radius","[cm]");
45 
46  string track_pT_cut = "";
47 
48  //string vtx_radius_cut = "vtx_radius>0"; //can I cut based on label?
49  string tCutInitializer = track_pT_cut;
50  TCut preTraingCuts(tCutInitializer.c_str());
51  factory->PrepareTrainingAndTestTree(preTraingCuts,"nTrain_Regression=0:nTest_Regression=0");
52  factory->BookMethod(Types::kMLP,"MLP_ANN","HiddenLayers=2000");
53  factory->BookMethod(Types::kMLP,"MLP_ANN2","HiddenLayers=500,6");
54 
55 
56  factory->TrainAllMethods();
57  factory->TestAllMethods();
58  factory->EvaluateAllMethods();
59  targetFile->Write();
60  targetFile->Close();
61 }
62 
63 
65  using namespace std;
66  string treePath = "/sphenix/user/vassalli/gammasample/conversiononlineanalysis";
67  string treeExtension = ".root";
68  string outname = "vtxTrain.root";
69  unsigned int nFiles=200;
70 
71  TChain *signalTree = handleFile(treePath,treeExtension,"vtxingTree",nFiles);
72  makeFactory(signalTree,outname,"vtxFactory");
73 /* outname="cutTrainE.root";
74  makeFactory(signalTree,backETree,outname,"eback");*/
75 }