Analysis Software
Documentation for sPHENIX simulation software
#include <acts/blob/sPHENIX/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchEdgeClassifier.hpp>
Classes
  struct Config

Public Member Functions
  TorchEdgeClassifier(const Config &cfg, std::unique_ptr<const Logger> logger)
  ~TorchEdgeClassifier()
  std::tuple<std::any, std::any, std::any> operator()(std::any nodes, std::any edges, int deviceHint = -1) override
  Config config() const

Public Member Functions inherited from Acts::EdgeClassificationBase
  virtual ~EdgeClassificationBase() = default

Private Member Functions
  const auto & logger() const

Private Attributes
  std::unique_ptr<const Acts::Logger> m_logger
  Config m_cfg
  c10::DeviceType m_deviceType
  std::unique_ptr<torch::jit::Module> m_model
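Because operator() takes and returns std::any, the base class does not need to know about PyTorch: a caller wraps its tensors in std::any and the concrete backend casts them back to the type it expects. The stdlib-only sketch below illustrates that type-erased interface. `EdgeClassifierSketchBase`, `PassThroughClassifier` and the flattened `std::vector` payloads are hypothetical stand-ins for Acts::EdgeClassificationBase, TorchEdgeClassifier and torch tensors; this is not the ACTS API.

```cpp
#include <any>
#include <cassert>
#include <cstdint>
#include <tuple>
#include <utility>
#include <vector>

// Hypothetical stand-in for Acts::EdgeClassificationBase: the interface
// only sees std::any, so it does not depend on any tensor library.
class EdgeClassifierSketchBase {
 public:
  virtual ~EdgeClassifierSketchBase() = default;
  virtual std::tuple<std::any, std::any, std::any> operator()(
      std::any nodes, std::any edges, int deviceHint = -1) = 0;
};

// Hypothetical backend. Edges are a flattened (2, n_edges) array:
// the first n_edges entries are sources, the next n_edges are targets.
class PassThroughClassifier final : public EdgeClassifierSketchBase {
 public:
  std::tuple<std::any, std::any, std::any> operator()(
      std::any nodes, std::any edges, int /*deviceHint*/ = -1) override {
    // Throws std::bad_any_cast if the caller passed a different type.
    auto edgeList = std::any_cast<std::vector<std::int64_t>>(edges);
    assert(edgeList.size() % 2 == 0);
    // Placeholder score of 1.0 per edge where a real model would run.
    std::vector<float> scores(edgeList.size() / 2, 1.0f);
    return {std::move(nodes), std::move(edgeList), std::move(scores)};
  }
};
```

The cost of this design is that type mismatches surface only at run time, as a std::bad_any_cast, rather than at compile time.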
Definition at line 26 of file TorchEdgeClassifier.hpp.
Acts::TorchEdgeClassifier::TorchEdgeClassifier(const Config &cfg, std::unique_ptr<const Logger> logger)
Definition at line 20 of file TorchEdgeClassifier.cpp.
References ACTS_DEBUG, ACTS_INFO, e, m_cfg, m_deviceType, m_model, and Acts::TorchEdgeClassifier::Config::modelPath.
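The constructor takes ownership of the logger and keeps a copy of the configuration; according to the reference list above it also sets m_deviceType and m_model. A stdlib-only sketch of that ownership pattern follows. `OwnedLogger`, `SketchConfig`, `SketchDevice`, `OwningClassifier` and the `cudaAvailable` flag are hypothetical placeholders for Acts::Logger, Config, c10::DeviceType, TorchEdgeClassifier and CUDA detection, and model loading is left out entirely.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>

// Hypothetical placeholders; not the ACTS or libtorch types.
struct OwnedLogger {
  std::string name;
};
enum class SketchDevice { CPU, CUDA };
struct SketchConfig {
  std::string modelPath;  // assumed field type
};

class OwningClassifier {
 public:
  // Takes ownership of the logger (the caller's pointer ends up empty)
  // and stores a copy of the configuration.
  OwningClassifier(const SketchConfig &cfg,
                   std::unique_ptr<const OwnedLogger> logger,
                   bool cudaAvailable)
      : m_logger(std::move(logger)),
        m_cfg(cfg),
        // Hypothetical device choice: GPU if available, else CPU.
        m_deviceType(cudaAvailable ? SketchDevice::CUDA : SketchDevice::CPU) {
    assert(m_logger != nullptr);
  }

  // Returns a copy, so callers cannot modify the stored configuration.
  SketchConfig config() const { return m_cfg; }
  SketchDevice deviceType() const { return m_deviceType; }

 private:
  std::unique_ptr<const OwnedLogger> m_logger;
  SketchConfig m_cfg;
  SketchDevice m_deviceType;
};
```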
Acts::TorchEdgeClassifier::~TorchEdgeClassifier()
Definition at line 43 of file TorchEdgeClassifier.cpp.
Config Acts::TorchEdgeClassifier::config() const [inline]
Definition at line 42 of file TorchEdgeClassifier.hpp.
References m_cfg.
const auto & Acts::TorchEdgeClassifier::logger() const [inline, private]
Definition at line 46 of file TorchEdgeClassifier.hpp.
References m_logger.
Referenced by operator()().
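The ACTS logging macros (ACTS_DEBUG, ACTS_VERBOSE, ...) call an unqualified logger() in the enclosing scope, which is why this private accessor is referenced by operator()(). The sketch below reproduces that pattern with a hypothetical `SKETCH_DEBUG` macro, `RecordingLogger` and `LoggingClassifier`; it is not the ACTS macro definition.

```cpp
#include <cassert>
#include <memory>
#include <sstream>
#include <string>
#include <utility>

// Hypothetical logger that records messages instead of printing them.
class RecordingLogger {
 public:
  void log(const std::string &msg) const { m_stream << msg << '\n'; }
  std::string text() const { return m_stream.str(); }

 private:
  mutable std::ostringstream m_stream;
};

// Hypothetical macro: like the ACTS logging macros, it calls an
// unqualified logger(), which resolves to the enclosing class's accessor.
#define SKETCH_DEBUG(x)            \
  do {                             \
    std::ostringstream sketchOs_;  \
    sketchOs_ << x;                \
    logger().log(sketchOs_.str()); \
  } while (0)

class LoggingClassifier {
 public:
  explicit LoggingClassifier(std::unique_ptr<const RecordingLogger> logger)
      : m_logger(std::move(logger)) {
    assert(m_logger != nullptr);
  }

  int classify(int nEdges) const;

 private:
  std::unique_ptr<const RecordingLogger> m_logger;
  // Returns a reference, so logging never copies the logger.
  const auto &logger() const { return *m_logger; }
};

// Defined out of class (as operator() is in the .cpp file), so the
// deduced return type of logger() is already known here.
int LoggingClassifier::classify(int nEdges) const {
  SKETCH_DEBUG("Classifying " << nEdges << " edges");
  return nEdges;
}
```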
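std::tuple<std::any, std::any, std::any> Acts::TorchEdgeClassifier::operator()(std::any nodes, std::any edges, int deviceHint = -1) [override, virtual]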
Perform edge classification.
Parameters
  nodes       Node tensor with shape (n_nodes, n_node_features)
  edges       Edge-index tensor with shape (2, n_edges)
  deviceHint  Which GPU to pick; not relevant for CPU-only builds
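The shapes above follow the common graph-neural-network layout: a dense node-feature matrix plus an edge-index matrix whose first row is taken here to hold source node indices and whose second row holds target indices. The sketch below stores both as flattened, row-major std::vector buffers. `GraphInput` and `isValid` are hypothetical helpers for illustration; the real operator() receives torch tensors wrapped in std::any.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <vector>

// Hypothetical container mirroring the documented shapes:
//   nodes: (nNodes, nNodeFeatures), row-major
//   edges: (2, nEdges), row 0 = source index, row 1 = target index
struct GraphInput {
  std::size_t nNodes = 0;
  std::size_t nNodeFeatures = 0;
  std::vector<float> nodes;
  std::vector<std::int64_t> edges;

  std::size_t nEdges() const { return edges.size() / 2; }
  std::int64_t source(std::size_t i) const { return edges[i]; }
  std::int64_t target(std::size_t i) const { return edges[nEdges() + i]; }
  float feature(std::size_t node, std::size_t f) const {
    assert(node < nNodes && f < nNodeFeatures);
    return nodes[node * nNodeFeatures + f];
  }
};

// Checks that the buffers match the declared shapes and that every
// edge refers to an existing node.
bool isValid(const GraphInput &g) {
  if (g.nodes.size() != g.nNodes * g.nNodeFeatures) return false;
  if (g.edges.size() % 2 != 0) return false;
  for (std::size_t i = 0; i < g.nEdges(); ++i) {
    for (std::int64_t idx : {g.source(i), g.target(i)}) {
      if (idx < 0 || static_cast<std::size_t>(idx) >= g.nNodes) return false;
    }
  }
  return true;
}
```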
Implements Acts::EdgeClassificationBase.
Definition at line 45 of file TorchEdgeClassifier.cpp.
References ACTS_DEBUG, ACTS_VERBOSE, Acts::TorchEdgeClassifier::Config::cut, logger(), m_cfg, m_deviceType, m_model, mask, std::move(), Acts::TorchEdgeClassifier::Config::nChunks, Acts::TorchEdgeClassifier::Config::numFeatures, output, results, and Acts::TorchEdgeClassifier::Config::undirected.
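The reference list suggests that operator() splits the work into Config::nChunks pieces and keeps only edges whose score passes Config::cut through a boolean mask. The stdlib-only sketch below shows one way those steps can fit together. `classifyEdges`, `ClassifiedEdges` and the `scoreEdge` callback (a stand-in for the TorchScript model) are hypothetical; applying a sigmoid to raw logits and keeping scores strictly above the cut are assumptions to check against TorchEdgeClassifier.cpp, and Config::undirected is not modelled.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <vector>

// Result of the sketch: surviving edges in (2, nKept) layout plus scores.
struct ClassifiedEdges {
  std::vector<std::int64_t> edges;
  std::vector<float> scores;
};

// edges: flattened (2, nEdges) array, row 0 = sources, row 1 = targets.
// scoreEdge: hypothetical model stand-in that returns a raw logit.
ClassifiedEdges classifyEdges(
    const std::vector<std::int64_t> &edges, float cut, std::size_t nChunks,
    const std::function<float(std::int64_t, std::int64_t)> &scoreEdge) {
  assert(edges.size() % 2 == 0 && nChunks > 0);
  const std::size_t nEdges = edges.size() / 2;

  // 1) Score the edges chunk by chunk (bounds per-call memory in a real model).
  std::vector<float> output;
  output.reserve(nEdges);
  const std::size_t chunkSize = (nEdges + nChunks - 1) / nChunks;
  for (std::size_t begin = 0; begin < nEdges; begin += chunkSize) {
    const std::size_t end = std::min(nEdges, begin + chunkSize);
    for (std::size_t i = begin; i < end; ++i) {
      output.push_back(scoreEdge(edges[i], edges[nEdges + i]));
    }
  }

  // 2) Map logits to (0, 1) with a sigmoid (assumed model convention).
  for (float &s : output) {
    s = 1.0f / (1.0f + std::exp(-s));
  }

  // 3) Build the mask and keep only edges whose score exceeds the cut.
  std::vector<bool> mask(nEdges);
  for (std::size_t i = 0; i < nEdges; ++i) {
    mask[i] = output[i] > cut;
  }
  ClassifiedEdges result;
  for (std::size_t row = 0; row < 2; ++row) {
    for (std::size_t i = 0; i < nEdges; ++i) {
      if (mask[i]) result.edges.push_back(edges[row * nEdges + i]);
    }
  }
  for (std::size_t i = 0; i < nEdges; ++i) {
    if (mask[i]) result.scores.push_back(output[i]);
  }
  return result;
}
```

Because each edge is scored independently in this sketch, the result does not depend on nChunks; chunking only changes how much work is done per model call.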
Config Acts::TorchEdgeClassifier::m_cfg [private]
Definition at line 48 of file TorchEdgeClassifier.hpp.
Referenced by operator()(), and TorchEdgeClassifier().
c10::DeviceType Acts::TorchEdgeClassifier::m_deviceType [private]
Definition at line 49 of file TorchEdgeClassifier.hpp.
Referenced by operator()(), and TorchEdgeClassifier().
std::unique_ptr<const Acts::Logger> Acts::TorchEdgeClassifier::m_logger [private]
Definition at line 45 of file TorchEdgeClassifier.hpp.
std::unique_ptr<torch::jit::Module> Acts::TorchEdgeClassifier::m_model [private]
Definition at line 50 of file TorchEdgeClassifier.hpp.
Referenced by operator()(), and TorchEdgeClassifier().