Analysis Software
Documentation for sPHENIX simulation software
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
exatrkx.py
Go to the documentation of this file. Or view the newest version in sPHENIX GitHub for file exatrkx.py
1 #!/usr/bin/env python3
2 from pathlib import Path
3 from typing import Optional, Union
4 
5 import acts.examples
6 import acts
7 from acts import UnitConstants as u
8 
9 
10 if "__main__" == __name__:
11  import os
12  import sys
13  from digitization import runDigitization
14  from acts.examples.reconstruction import addExaTrkX, ExaTrkXBackend
15 
16  backend = ExaTrkXBackend.Torch
17 
18  if "onnx" in sys.argv:
19  backend = ExaTrkXBackend.Onnx
20  if "torch" in sys.argv:
21  backend = ExaTrkXBackend.Torch
22 
23  srcdir = Path(__file__).resolve().parent.parent.parent.parent
24 
25  detector, trackingGeometry, decorators = acts.examples.GenericDetector.create()
26 
27  field = acts.ConstantBField(acts.Vector3(0, 0, 2 * u.T))
28 
29  inputParticlePath = Path("particles.root")
30  if not inputParticlePath.exists():
31  inputParticlePath = None
32 
33  srcdir = Path(__file__).resolve().parent.parent.parent.parent
34 
35  geometrySelection = (
36  srcdir
37  / "Examples/Algorithms/TrackFinding/share/geoSelection-genericDetector.json"
38  )
39  assert geometrySelection.exists()
40 
41  digiConfigFile = (
42  srcdir
43  / "Examples/Algorithms/Digitization/share/default-smearing-config-generic.json"
44  )
45  assert digiConfigFile.exists()
46 
47  if backend == ExaTrkXBackend.Torch:
48  modelDir = Path.cwd() / "torchscript_models"
49  assert (modelDir / "embed.pt").exists()
50  assert (modelDir / "filter.pt").exists()
51  assert (modelDir / "gnn.pt").exists()
52  else:
53  modelDir = Path.cwd() / "onnx_models"
54  assert (modelDir / "embedding.onnx").exists()
55  assert (modelDir / "filtering.onnx").exists()
56  assert (modelDir / "gnn.onnx").exists()
57 
58  s = acts.examples.Sequencer(events=2, numThreads=1)
59  s.config.logLevel = acts.logging.INFO
60 
61  rnd = acts.examples.RandomNumbers()
62  outputDir = Path(os.getcwd())
63 
64  s = runDigitization(
65  trackingGeometry,
66  field,
67  outputDir,
68  digiConfigFile=digiConfigFile,
69  particlesInput=inputParticlePath,
70  outputRoot=True,
71  outputCsv=True,
72  s=s,
73  )
74 
75  addExaTrkX(
76  s,
77  trackingGeometry,
78  geometrySelection,
79  modelDir,
80  outputDir,
81  backend=backend,
82  )
83 
84  s.run()