diff --git a/CaloCluster/data/calo_cluster_net_v2_stage1.norm.json b/CaloCluster/data/calo_cluster_net_v2_stage1.norm.json new file mode 100644 index 0000000000..4b5fd6428c --- /dev/null +++ b/CaloCluster/data/calo_cluster_net_v2_stage1.norm.json @@ -0,0 +1,59 @@ +{ + "schema_version": 1, + "node_features": [ + "log_e", + "t", + "x", + "y", + "r", + "e_rel" + ], + "edge_features": [ + "dx", + "dy", + "d", + "dt", + "dlog_e", + "asym_e", + "logsum_e", + "dr" + ], + "node_mean": [ + 2.4069125652313232, + 834.6543579101562, + -23.62708282470703, + 70.98532104492188, + 455.12542724609375, + 0.38432541489601135 + ], + "node_std": [ + 0.7527871131896973, + 390.6362609863281, + 325.2221374511719, + 315.6965026855469, + 62.38364028930664, + 0.29344040155410767 + ], + "edge_mean": [ + 0.0, + 0.0, + 95.80323791503906, + 0.0, + 0.0, + 0.0, + 3.102428674697876, + 0.0 + ], + "edge_std": [ + 108.05096435546875, + 107.51701354980469, + 118.56077575683594, + 4.873984336853027, + 1.2146189212799072, + 0.5318130850791931, + 0.6094895005226135, + 47.51533889770508 + ], + "node_count": 348548, + "edge_count": 831668 +} diff --git a/CaloCluster/data/calo_cluster_net_v2_stage1.onnx b/CaloCluster/data/calo_cluster_net_v2_stage1.onnx new file mode 100644 index 0000000000..248ffa918f Binary files /dev/null and b/CaloCluster/data/calo_cluster_net_v2_stage1.onnx differ diff --git a/CaloCluster/data/simple_edge_net_v2.onnx b/CaloCluster/data/simple_edge_net_v2.onnx new file mode 100644 index 0000000000..8fabf3bd5c Binary files /dev/null and b/CaloCluster/data/simple_edge_net_v2.onnx differ diff --git a/CaloCluster/fcl/from_mcs-gnn-prod.fcl b/CaloCluster/fcl/from_mcs-gnn-prod.fcl new file mode 100644 index 0000000000..5844b58921 --- /dev/null +++ b/CaloCluster/fcl/from_mcs-gnn-prod.fcl @@ -0,0 +1,51 @@ +# Production-style FCL that runs both BFS clustering (existing) and +# GNN clustering (new) on MCS art-format input. 
Demonstrates the +# `CaloClusterGNN` prolog pattern from +# Offline/CaloCluster/fcl/prolog.fcl -- production-reco FCLs include +# both the BFS sequence and the GNN sequence to emit two +# CaloClusterCollections side by side. +# +# Usage (working_dir, with u092): +# mu2e -c Offline/CaloCluster/fcl/from_mcs-gnn-prod.fcl \ +# -s -T mcs.gnn.art -n 100 +# +# The output art file carries: +# * caloClusterMaker :: CaloClusterCollection ("") -- BFS, untouched +# * caloClusterMakerGNN :: CaloClusterCollection ("GNN") -- GNN +# +# Task 16h (production FHiCL wiring). + +#include "Offline/fcl/minimalMessageService.fcl" +#include "Offline/fcl/standardServices.fcl" +#include "Offline/CaloCluster/fcl/prolog.fcl" + +process_name : GnnProd + +source : { module_type : RootInput } + +services : @local::Services.Reco + +physics : { + producers : { + caloHitGraphMakerGNN : @local::CaloClusterGNN.caloHitGraphMakerGNN + caloClusterMakerGNN : @local::CaloClusterGNN.caloClusterMakerGNN + } + + GnnPath : [ @sequence::CaloClusterGNN.Reco ] + OutPath : [ out ] + + trigger_paths : [ GnnPath ] + end_paths : [ OutPath ] +} + +outputs : { + out : { + module_type : RootOutput + fileName : "mcs.gnn.art" + # CaloHitGraphCollection is a transient data product (no ROOT + # dictionary by design -- see Offline/RecoDataProducts/inc/CaloHitGraph.hh + # and offline_integration.md 2.2). Drop it from the output art file. + outputCommands : [ "keep *", + "drop *_caloHitGraphMakerGNN_*_*" ] + } +} diff --git a/CaloCluster/fcl/from_mcs-gnn-test.fcl b/CaloCluster/fcl/from_mcs-gnn-test.fcl new file mode 100644 index 0000000000..4e421f680b --- /dev/null +++ b/CaloCluster/fcl/from_mcs-gnn-test.fcl @@ -0,0 +1,72 @@ +# Smoke + parity test for the GNN clustering split design. 
+# +# Reads MCS art files (which carry CaloHitCollection produced by +# CaloHitMaker in the Reconstruct process), runs the two new +# EDProducers, and dumps the GNN cluster assignments per +# event-disk to a flat TTree for byte-comparison against the +# Python pipeline. +# +# Usage (build node, with u092): +# mu2e -c Offline/CaloCluster/fcl/from_mcs-gnn-test.fcl \ +# -T parity_dump.root -n 100 +# +# Companion Python script: +# calorimeter/GNN/scripts/compare_parity_dump.py +# +# Task 16g, Stage 3. + +#include "Offline/fcl/minimalMessageService.fcl" +#include "Offline/fcl/standardServices.fcl" + +process_name : GnnTest + +source : { module_type : RootInput } + +services : @local::Services.Reco + +physics : { + + producers : { + + caloHitGraphMakerGNN : { + module_type : "CaloHitGraphMaker" + caloHitCollection : "CaloHitMaker" + normSidecar : "Offline/CaloCluster/data/calo_cluster_net_v2_stage1.norm.json" + rMax : 210.0 + dtMax : 25.0 + kMin : 3 + kMax : 20 + } + + caloClusterMakerGNN : { + module_type : "CaloClusterMakerGNN" + caloHitGraphCollection : "caloHitGraphMakerGNN" + modelPath : "Offline/CaloCluster/data/calo_cluster_net_v2_stage1.onnx" + expectedModelVersion : "calo-cluster-net-v2-stage1" + expectedNodeFeatures : ["log_e","t","x","y","r","e_rel"] + expectedEdgeFeatures : ["dx","dy","d","dt","dlog_e","asym_e","logsum_e","dr"] + tauEdge : 0.20 + bfsExpandCut : 10.0 + minHits : 2 + minEnergyMeV : 10.0 + outputInstance : "GNN" + } + } + + analyzers : { + + parityDump : { + module_type : "CaloHitGraphParityDump" + caloHitCollection : "CaloHitMaker" + caloClusterCollection : "caloClusterMakerGNN:GNN" + } + } + + GnnPath : [ caloHitGraphMakerGNN, caloClusterMakerGNN ] + DumpPath: [ parityDump ] + + trigger_paths : [ GnnPath ] + end_paths : [ DumpPath ] +} + +services.TFileService.fileName : "parity_dump.root" diff --git a/CaloCluster/fcl/prolog.fcl b/CaloCluster/fcl/prolog.fcl index ce69466169..9fa47de88a 100644 --- a/CaloCluster/fcl/prolog.fcl +++ 
b/CaloCluster/fcl/prolog.fcl @@ -55,4 +55,68 @@ CaloCluster : { @table::CaloCluster Reco : [ CaloProtoClusterMaker, CaloClusterMaker, CaloClusterFastMaker ] } + +# --------------------------------------------------------------------- +# GNN clustering (split design -- see calorimeter/GNN/docs/offline_integration.md). +# +# Two producers run alongside the existing BFS chain (CaloProtoClusterMaker +# + CaloClusterMaker), reading the same CaloHitCollection. The graph +# producer emits a transient CaloHitGraphCollection that the cluster +# producer consumes. The cluster producer's output ships under instance +# name "GNN" so downstream consumers select via (module_label, "GNN") +# and existing BFS-reading analyses keep working unchanged. +# +# Production FCLs that want both BFS and GNN clustering should append: +# physics.producers.caloHitGraphMakerGNN : @local::CaloClusterGNN.caloHitGraphMakerGNN +# physics.producers.caloClusterMakerGNN : @local::CaloClusterGNN.caloClusterMakerGNN +# physics. : [ ..., CaloHitMaker, CaloProtoClusterMaker, CaloClusterMaker, +# caloHitGraphMakerGNN, caloClusterMakerGNN ] +# +# Or use the bundled sequence: +# physics. : [ ..., CaloClusterGNN.Reco ] +# +# Frozen recipe values (CCN+BFS10, calorimeter/GNN/docs/findings.md 7.4): +# tauEdge=0.20 bfsExpandCut=10.0 minHits=2 minEnergyMeV=10.0 +# +# To swap in SimpleEdgeNet for an A/B comparison job, declare a second +# instance with model_path/expected_model_version pointing at sen.onnx +# and tauEdge=0.26 (see offline_integration.md 2.2). 
+ +CaloClusterGNN : { + caloHitGraphMakerGNN : + { + module_type : CaloHitGraphMaker + caloHitCollection : CaloHitMaker + normSidecar : "Offline/CaloCluster/data/calo_cluster_net_v2_stage1.norm.json" + rMax : 210.0 + dtMax : 25.0 + kMin : 3 + kMax : 20 + } + + caloClusterMakerGNN : + { + module_type : CaloClusterMakerGNN + caloHitGraphCollection : caloHitGraphMakerGNN + modelPath : "Offline/CaloCluster/data/calo_cluster_net_v2_stage1.onnx" + expectedModelVersion : "calo-cluster-net-v2-stage1" + expectedNodeFeatures : ["log_e","t","x","y","r","e_rel"] + expectedEdgeFeatures : ["dx","dy","d","dt","dlog_e","asym_e","logsum_e","dr"] + tauEdge : 0.20 + bfsExpandCut : 10.0 + minHits : 2 + minEnergyMeV : 10.0 + outputInstance : "GNN" + } +} + +CaloClusterGNN : { @table::CaloClusterGNN + producers : { + caloHitGraphMakerGNN : { @table::CaloClusterGNN.caloHitGraphMakerGNN } + caloClusterMakerGNN : { @table::CaloClusterGNN.caloClusterMakerGNN } + } + + Reco : [ caloHitGraphMakerGNN, caloClusterMakerGNN ] +} + END_PROLOG diff --git a/CaloCluster/inc/GnnClusterAssembler.hh b/CaloCluster/inc/GnnClusterAssembler.hh new file mode 100644 index 0000000000..b601dec15c --- /dev/null +++ b/CaloCluster/inc/GnnClusterAssembler.hh @@ -0,0 +1,60 @@ +#ifndef CaloCluster_GnnClusterAssembler_hh +#define CaloCluster_GnnClusterAssembler_hh +// +// C++ port of calorimeter/GNN/src/inference/cluster_reco.py for the +// CCN+BFS10 recipe (the winning configuration in +// docs/findings.md §7.4). +// +// Steps applied to the directed edge logits emitted by the ONNX model: +// 1. Sigmoid → per-edge probabilities. +// 2. Symmetrise: for each unordered pair {i, j}, take the mean of +// p_ij and p_ji. +// 3. Threshold at tauEdge. +// 4. BFS traversal seeded from highest-energy hits — hits with +// energy >= bfsExpandCut continue the BFS; lower-energy hits join +// but cannot recruit. Mirrors Offline's ClusterFinder ExpandCut. +// 5. 
Cleanup: drop clusters with fewer than minHits hits or total +// energy below minEnergyMeV. +// 6. Relabel to contiguous IDs. +// +// Returns labels[N] where labels[i] = cluster ID >= 0 or -1 (dropped). +// + +#include +#include + +namespace mu2e { + + class GnnClusterAssembler + { + public: + struct Config + { + double tauEdge = 0.20; // probability threshold (model-specific) + double bfsExpandCut = 10.0; // MeV — BFS-style ExpandCut + unsigned minHits = 2; // drop clusters smaller than this + double minEnergyMeV = 10.0; // drop clusters below this total energy + }; + + explicit GnnClusterAssembler(const Config& cfg) : cfg_(cfg) {} + + // nNodes : number of hits in the graph + // edgeIndex : flat (2 * E) int64s, src row first then dst row, + // matching the CaloHitGraph layout + // edgeLogits : pre-sigmoid logits emitted by the ONNX model (size E) + // hitEnergiesMeV: per-node raw energies in MeV (size N) + // + // Returns a vector of length N: labels[i] = cluster ID (>= 0) or + // -1 (unclustered after min_hits / min_energy_mev cleanup). + std::vector assemble(int nNodes, + const std::vector& edgeIndex, + const std::vector& edgeLogits, + const std::vector& hitEnergiesMeV) const; + + private: + Config cfg_; + }; + +} + +#endif diff --git a/CaloCluster/inc/GnnGraphBuilder.hh b/CaloCluster/inc/GnnGraphBuilder.hh new file mode 100644 index 0000000000..021c7d11a1 --- /dev/null +++ b/CaloCluster/inc/GnnGraphBuilder.hh @@ -0,0 +1,84 @@ +#ifndef CaloCluster_GnnGraphBuilder_hh +#define CaloCluster_GnnGraphBuilder_hh +// +// C++ port of calorimeter/GNN/src/data/graph_builder.py. +// +// Builds one CaloHitGraph per calorimeter disk per event: +// 1. Collect CaloHits per disk; look up (x, y) in the disk-local +// frame from the Calorimeter geometry service. +// 2. Brute-force pairwise distance loop with r_max cut for the +// radius graph (faithful to scipy.spatial.cKDTree.query_pairs). +// 3. Time filter |dt| < dt_max ns. +// 4. 
kNN fallback for nodes with degree < k_min after the radius+time pass. +// 5. Per-source-node degree cap at k_max (keep the k_max nearest dsts). +// 6. Compute 6 node features and 8 edge features. +// 7. Z-score normalise using the train-split stats from the JSON +// sidecar passed at construction (loaded via loadStatsFromJson). +// +// Feature column order is canonical and matches the model's +// metadata_props (see calorimeter/GNN/docs/onnx_deployment.md): +// +// nodes : log_e, t, x, y, r, e_rel +// edges : dx, dy, d, dt, dlog_e, asym_e, logsum_e, dr +// + +#include "Offline/CalorimeterGeom/inc/Calorimeter.hh" +#include "Offline/RecoDataProducts/inc/CaloHit.hh" +#include "Offline/RecoDataProducts/inc/CaloHitGraph.hh" + +#include "canvas/Persistency/Common/Ptr.h" + +#include +#include + +namespace mu2e { + + class GnnGraphBuilder + { + public: + // Per-feature normalisation stats (z-score: (x - mean) / std). + struct Stats + { + std::vector nodeMean; // size 6 + std::vector nodeStd; // size 6 + std::vector edgeMean; // size 8 + std::vector edgeStd; // size 8 + }; + + struct Config + { + double rMax = 210.0; // mm — radius graph cut + double dtMax = 25.0; // ns — time-coincidence cut + unsigned kMin = 3; // kNN fallback floor + unsigned kMax = 20; // per-source-node degree cap + }; + + GnnGraphBuilder(const Config& cfg, const Stats& stats) + : cfg_(cfg), stats_(stats) {} + + // Load Stats from the JSON sidecar produced by + // calorimeter/GNN/scripts/export_norm_stats.py. Throws + // cet::exception on missing keys, wrong sizes, or canonical + // node/edge feature-name mismatches. + static Stats loadStatsFromJson(const std::string& jsonPath); + + // Build one CaloHitGraph for one disk. 
+ // diskID — destination disk for the emitted graph + // hits — pointers to the CaloHits on this disk + // ptrs — art::Ptr back to each hit, parallel to `hits` + // cal — geometry handle for crystal positions + // out — populated in place (cleared first) + void buildGraph(int diskID, + const std::vector& hits, + const std::vector>& ptrs, + const Calorimeter& cal, + CaloHitGraph& out) const; + + private: + Config cfg_; + Stats stats_; + }; + +} + +#endif diff --git a/CaloCluster/src/CaloClusterMakerGNN_module.cc b/CaloCluster/src/CaloClusterMakerGNN_module.cc new file mode 100644 index 0000000000..a68b0d5525 --- /dev/null +++ b/CaloCluster/src/CaloClusterMakerGNN_module.cc @@ -0,0 +1,344 @@ +// +// CaloClusterMakerGNN — second half of the GNN clustering split design +// (calorimeter/GNN/docs/offline_integration.md §1.2). +// +// Constructor: load the ONNX session for one model artifact and assert +// the model's metadata_props match what the FHiCL config expects. Bail +// loudly on mismatch — silent tensor-layout drift after a retraining +// must not be possible. +// +// produce(): consumes CaloHitGraphCollection (from CaloHitGraphMaker), +// runs ONNX inference per disk, then assembles CaloClusters via the +// CCN+BFS10 recipe. The assembly logic (16f) is not yet implemented; +// produce() currently emits an empty CaloClusterCollection so the +// module compiles and links. +// +// The module is model-agnostic — the same C++ class instances run +// SimpleEdgeNet or CaloClusterNet (or any future model with the same +// tensor I/O signature). 
FHiCL parameters distinguish them: +// - model_path +// - expected_model_version +// - expected_node_features / expected_edge_features +// - tau_edge / bfs_expand_cut / min_hits / min_energy_mev +// - output_instance +// + +#include "art/Framework/Core/EDProducer.h" +#include "art/Framework/Core/ModuleMacros.h" +#include "art/Framework/Principal/Event.h" +#include "art/Framework/Principal/Handle.h" +#include "cetlib_except/exception.h" +#include "fhiclcpp/types/Atom.h" +#include "fhiclcpp/types/Sequence.h" + +#include "Offline/CaloCluster/inc/ClusterUtils.hh" +#include "Offline/CaloCluster/inc/GnnClusterAssembler.hh" +#include "Offline/CalorimeterGeom/inc/Calorimeter.hh" +#include "Offline/ConfigTools/inc/ConfigFileLookupPolicy.hh" +#include "Offline/GeometryService/inc/GeomHandle.hh" +#include "Offline/RecoDataProducts/inc/CaloCluster.hh" +#include "Offline/RecoDataProducts/inc/CaloHit.hh" +#include "Offline/RecoDataProducts/inc/CaloHitGraph.hh" + +// Member-init order matters: env must outlive session_options, and +// session_options must outlive session. RAII resources lifetime +// captured in declaration order at the bottom of this class. +#include "onnxruntime/core/session/onnxruntime_cxx_api.h" + +#include +#include +#include +#include +#include +#include + + +namespace { + + // Helper: split a comma-separated string into trimmed tokens. + std::vector splitCsv(const std::string& s) + { + std::vector out; + std::stringstream ss(s); + std::string tok; + while (std::getline(ss, tok, ',')) { + // strip surrounding whitespace + const auto a = tok.find_first_not_of(" \t"); + const auto b = tok.find_last_not_of(" \t"); + if (a == std::string::npos) continue; + out.push_back(tok.substr(a, b - a + 1)); + } + return out; + } + + // Read an entry from the loaded model's metadata_props map by key. + // Throws cet::exception if the key is absent. 
+ std::string readMetadataProp(const Ort::ModelMetadata& meta, + Ort::AllocatorWithDefaultOptions& alloc, + const char* key) + { + auto raw = meta.LookupCustomMetadataMapAllocated(key, alloc); + if (!raw) { + throw cet::exception("CaloClusterMakerGNN") + << "ONNX metadata_props missing required key '" << key << "'"; + } + return std::string(raw.get()); + } + +} + + +namespace mu2e { + + class CaloClusterMakerGNN : public art::EDProducer + { + public: + struct Config + { + using Name = fhicl::Name; + using Comment = fhicl::Comment; + + fhicl::Atom caloHitGraphCollection { + Name("caloHitGraphCollection"), + Comment("CaloHitGraphCollection input tag (from CaloHitGraphMaker)") }; + fhicl::Atom modelPath { + Name("modelPath"), + Comment("Relative path to the .onnx artifact; resolved by ConfigFileLookupPolicy") }; + fhicl::Atom expectedModelVersion { + Name("expectedModelVersion"), + Comment("Required value of metadata_props['model_version'] in the .onnx") }; + fhicl::Sequence expectedNodeFeatures { + Name("expectedNodeFeatures"), + Comment("Required canonical node feature names; asserted against metadata_props") }; + fhicl::Sequence expectedEdgeFeatures { + Name("expectedEdgeFeatures"), + Comment("Required canonical edge feature names; asserted against metadata_props") }; + fhicl::Atom tauEdge { + Name("tauEdge"), + Comment("Edge probability threshold (model-specific: 0.20 for CCN, 0.26 for SEN)") }; + fhicl::Atom bfsExpandCut { + Name("bfsExpandCut"), + Comment("BFS-style ExpandCut: hits below this energy join clusters but cannot recruit (MeV)"), + 10.0 }; + fhicl::Atom minHits { + Name("minHits"), + Comment("Drop clusters with fewer than this many hits"), + 2 }; + fhicl::Atom minEnergyMeV { + Name("minEnergyMeV"), + Comment("Drop clusters with less than this much total energy (MeV)"), + 10.0 }; + fhicl::Atom outputInstance { + Name("outputInstance"), + Comment("Instance name on the emitted CaloClusterCollection (\"GNN\" by default)"), + std::string("GNN") }; + }; + + 
explicit CaloClusterMakerGNN(const art::EDProducer::Table& config); + + void produce(art::Event& event) override; + + private: + art::ProductToken graphToken_; + + std::string outputInstance_; + std::vector expectedNodeFeatures_; + std::vector expectedEdgeFeatures_; + + // CCN+BFS10 (or whatever the swappable model needs) recipe. + GnnClusterAssembler assembler_; + + // ONNX Runtime resources. Member declaration order is intentional: + // env_ must outlive sessionOptions_ which must outlive session_. + Ort::Env env_; + Ort::SessionOptions sessionOptions_; + Ort::Session session_; + Ort::AllocatorWithDefaultOptions allocator_; + + // Cached input/output names (returned by the session, owned by + // RAII smart pointers from ONNX Runtime). + std::vector inputNameHolders_; + std::vector outputNameHolders_; + std::vector inputNames_; + std::vector outputNames_; + }; + + + CaloClusterMakerGNN::CaloClusterMakerGNN( + const art::EDProducer::Table& config) : + art::EDProducer{config}, + graphToken_{consumes(config().caloHitGraphCollection())}, + outputInstance_ (config().outputInstance()), + expectedNodeFeatures_(config().expectedNodeFeatures()), + expectedEdgeFeatures_(config().expectedEdgeFeatures()), + assembler_({config().tauEdge(), + config().bfsExpandCut(), + config().minHits(), + config().minEnergyMeV()}), + env_(ORT_LOGGING_LEVEL_WARNING, "CaloClusterMakerGNN"), + sessionOptions_(), + session_{nullptr} + { + ConfigFileLookupPolicy lookup; + const std::string modelPath = lookup(config().modelPath()); + + // Move-construct the session in place (Ort::Session has no + // copy-assign; the trick is to assign via std::move from a + // temporary built with the resolved path). + session_ = Ort::Session(env_, modelPath.c_str(), sessionOptions_); + + // Cache the input/output name strings — Run() takes const char* + // arrays and we want them to live as long as the session does. 
+ { + const std::size_t nIn = session_.GetInputCount(); + const std::size_t nOut = session_.GetOutputCount(); + inputNameHolders_.reserve(nIn); + outputNameHolders_.reserve(nOut); + inputNames_.reserve(nIn); + outputNames_.reserve(nOut); + for (std::size_t i = 0; i < nIn; ++i) { + inputNameHolders_.emplace_back(session_.GetInputNameAllocated(i, allocator_)); + inputNames_.push_back(inputNameHolders_.back().get()); + } + for (std::size_t i = 0; i < nOut; ++i) { + outputNameHolders_.emplace_back(session_.GetOutputNameAllocated(i, allocator_)); + outputNames_.push_back(outputNameHolders_.back().get()); + } + } + + // Validate metadata_props against FHiCL expectations. + Ort::ModelMetadata meta = session_.GetModelMetadata(); + const std::string gotVersion = readMetadataProp(meta, allocator_, "model_version"); + const std::string gotNodes = readMetadataProp(meta, allocator_, "node_features"); + const std::string gotEdges = readMetadataProp(meta, allocator_, "edge_features"); + + if (gotVersion != config().expectedModelVersion()) { + throw cet::exception("CaloClusterMakerGNN") + << "model_version mismatch: ONNX has '" << gotVersion + << "', FHiCL expected '" << config().expectedModelVersion() + << "' (modelPath=" << modelPath << ")"; + } + const auto gotNodeNames = splitCsv(gotNodes); + if (gotNodeNames != expectedNodeFeatures_) { + throw cet::exception("CaloClusterMakerGNN") + << "node_features mismatch in " << modelPath + << ": ONNX has '" << gotNodes + << "', FHiCL expected " << expectedNodeFeatures_.size() << " entries"; + } + const auto gotEdgeNames = splitCsv(gotEdges); + if (gotEdgeNames != expectedEdgeFeatures_) { + throw cet::exception("CaloClusterMakerGNN") + << "edge_features mismatch in " << modelPath + << ": ONNX has '" << gotEdges + << "', FHiCL expected " << expectedEdgeFeatures_.size() << " entries"; + } + + produces(outputInstance_); + } + + + void CaloClusterMakerGNN::produce(art::Event& event) + { + auto graphHandle = event.getHandle(graphToken_); + 
auto out = std::make_unique(); + + if (!graphHandle.isValid() || graphHandle->empty()) { + event.put(std::move(out), outputInstance_); + return; + } + + const Calorimeter& cal = *(GeomHandle()); + const auto memInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, + OrtMemTypeDefault); + + for (const auto& graph : *graphHandle) { + if (graph.nNodes <= 0 || graph.nEdges <= 0) continue; + + const int N = graph.nNodes; + const int E = graph.nEdges; + + // Wrap the graph tensors as Ort::Value views — no copy. + const std::array xShape {N, 6}; + const std::array eiShape {2, E}; + const std::array eaShape {E, 8}; + std::array inputs{ + Ort::Value::CreateTensor(memInfo, + const_cast(graph.x.data()), graph.x.size(), + xShape.data(), xShape.size()), + Ort::Value::CreateTensor(memInfo, + const_cast(graph.edgeIndex.data()), graph.edgeIndex.size(), + eiShape.data(), eiShape.size()), + Ort::Value::CreateTensor(memInfo, + const_cast(graph.edgeAttr.data()), graph.edgeAttr.size(), + eaShape.data(), eaShape.size()) + }; + + auto outputs = session_.Run(Ort::RunOptions{nullptr}, + inputNames_.data(), inputs.data(), inputs.size(), + outputNames_.data(), outputNames_.size()); + + // Read edge_logits into a local buffer (don't mutate the + // runtime's output). + const Ort::Value& logitsVal = outputs.front(); + const float* logitsPtr = logitsVal.GetTensorData(); + const std::size_t logitsCount = + logitsVal.GetTensorTypeAndShapeInfo().GetElementCount(); + std::vector edgeLogits(logitsPtr, logitsPtr + logitsCount); + + // Per-node energies (raw MeV) for the seed selection / expand + // cut / min-energy cleanup. Pulled from the original CaloHits. + std::vector energiesMeV(N); + for (int i = 0; i < N; ++i) { + energiesMeV[i] = graph.caloHitPtrs[i]->energyDep(); + } + + // Run the CCN+BFS10 assembly on the directed-edge logits. + const std::vector labels = assembler_.assemble( + N, graph.edgeIndex, edgeLogits, energiesMeV); + + // Bucket nodes by cluster ID. 
+ std::map> clusters; + for (int i = 0; i < N; ++i) { + if (labels[i] >= 0) clusters[labels[i]].push_back(i); + } + + for (const auto& [cid, nodeIdxs] : clusters) { + // Energy aggregates + seed identification. + float energy = 0.0f, energyErrSq = 0.0f; + int seedIdx = nodeIdxs.front(); + float seedEnergy = energiesMeV[seedIdx]; + CaloHitPtrVector caloHits; + caloHits.reserve(nodeIdxs.size()); + for (int idx : nodeIdxs) { + const auto& h = *graph.caloHitPtrs[idx]; + energy += h.energyDep(); + energyErrSq += h.energyDepErr() * h.energyDepErr(); + caloHits.push_back(graph.caloHitPtrs[idx]); + if (h.energyDep() > seedEnergy) { + seedEnergy = h.energyDep(); + seedIdx = idx; + } + } + const auto& seedHit = *graph.caloHitPtrs[seedIdx]; + const float time = seedHit.time(); + const float timeErr = seedHit.timeErr(); + + // 3D centroid via the existing ClusterUtils helper (linear + // energy weighting, matching the BFS module's convention). + ClusterUtils utils(cal, caloHits, ClusterUtils::Linear); + const CLHEP::Hep3Vector cog = utils.cog3Vector(); + + out->emplace_back(graph.diskID, time, timeErr, + energy, std::sqrt(energyErrSq), + cog, caloHits, + static_cast(nodeIdxs.size()), + /*isSplit=*/false); + } + } + + event.put(std::move(out), outputInstance_); + } + +} + +DEFINE_ART_MODULE(mu2e::CaloClusterMakerGNN) diff --git a/CaloCluster/src/CaloHitGraphMaker_module.cc b/CaloCluster/src/CaloHitGraphMaker_module.cc new file mode 100644 index 0000000000..c050e42dbf --- /dev/null +++ b/CaloCluster/src/CaloHitGraphMaker_module.cc @@ -0,0 +1,155 @@ +// +// CaloHitGraphMaker — first half of the GNN clustering split design +// (see calorimeter/GNN/docs/offline_integration.md §1.2). +// +// Per event, partitions the input CaloHitCollection by disk and runs +// the C++ port of the Python graph builder (GnnGraphBuilder) once per +// disk. Emits a CaloHitGraphCollection with one entry per disk that +// has at least one CaloHit. 
The emitted graphs already carry +// z-score-normalised feature tensors so the downstream cluster module +// has no normalisation responsibility. +// +// Frozen hyper-parameters and feature column order live in the +// canonical norm sidecar (calorimeter/GNN/scripts/export_norm_stats.py) +// and are validated against canonical names at sidecar load time. +// FHiCL parameters expose the graph-construction knobs so swapping +// future models with different feature lists is config-driven. +// + +#include "art/Framework/Core/EDProducer.h" +#include "art/Framework/Core/ModuleMacros.h" +#include "art/Framework/Principal/Event.h" +#include "art/Framework/Principal/Handle.h" +#include "fhiclcpp/types/Atom.h" + +#include "Offline/CaloCluster/inc/GnnGraphBuilder.hh" +#include "Offline/CalorimeterGeom/inc/Calorimeter.hh" +#include "Offline/CalorimeterGeom/inc/Crystal.hh" +#include "Offline/ConfigTools/inc/ConfigFileLookupPolicy.hh" +#include "Offline/GeometryService/inc/GeomHandle.hh" +#include "Offline/RecoDataProducts/inc/CaloHit.hh" +#include "Offline/RecoDataProducts/inc/CaloHitGraph.hh" + +#include +#include +#include + + +namespace mu2e { + + class CaloHitGraphMaker : public art::EDProducer + { + public: + struct Config + { + using Name = fhicl::Name; + using Comment = fhicl::Comment; + + fhicl::Atom caloHitCollection { + Name("caloHitCollection"), + Comment("CaloHit collection to read") }; + fhicl::Atom normSidecar { + Name("normSidecar"), + Comment("Relative path to the JSON norm sidecar; resolved by ConfigFileLookupPolicy") }; + fhicl::Atom rMax { + Name("rMax"), + Comment("Spatial radius cut for the radius graph (mm)"), + 210.0 }; + fhicl::Atom dtMax { + Name("dtMax"), + Comment("Maximum |dt| between connected hits (ns)"), + 25.0 }; + fhicl::Atom kMin { + Name("kMin"), + Comment("kNN-fallback floor for nodes with degree < kMin"), + 3 }; + fhicl::Atom kMax { + Name("kMax"), + Comment("Per-source-node degree cap"), + 20 }; + }; + + explicit CaloHitGraphMaker(const 
art::EDProducer::Table& config) : + art::EDProducer{config}, + caloHitToken_{consumes(config().caloHitCollection())}, + builder_(makeBuilder(config())) + { + produces(); + } + + void produce(art::Event& event) override; + + private: + static GnnGraphBuilder makeBuilder(const Config& cfg); + + art::ProductToken caloHitToken_; + GnnGraphBuilder builder_; + }; + + + GnnGraphBuilder + CaloHitGraphMaker::makeBuilder(const Config& cfg) + { + GnnGraphBuilder::Config gcfg; + gcfg.rMax = cfg.rMax(); + gcfg.dtMax = cfg.dtMax(); + gcfg.kMin = cfg.kMin(); + gcfg.kMax = cfg.kMax(); + ConfigFileLookupPolicy lookup; + const std::string statsPath = lookup(cfg.normSidecar()); + auto stats = GnnGraphBuilder::loadStatsFromJson(statsPath); + return GnnGraphBuilder(gcfg, stats); + } + + + void CaloHitGraphMaker::produce(art::Event& event) + { + auto hitsHandle = event.getHandle(caloHitToken_); + auto out = std::make_unique(); + + const auto& hits = *hitsHandle; + if (hits.empty()) { + event.put(std::move(out)); + return; + } + + const Calorimeter& cal = *(GeomHandle()); + + // Partition CaloHits by disk. Two passes so nDisks isn't hardcoded: + // first pass to discover the disk set, then bucket pointers + Ptrs. + std::vector diskIDs; + std::vector> hitsByDisk; + std::vector>> ptrsByDisk; + + for (std::size_t i = 0; i < hits.size(); ++i) { + const CaloHit& h = hits[i]; + const int disk = cal.crystal(h.crystalID()).diskID(); + + // Find/register the bucket for this disk. + std::size_t bucket = diskIDs.size(); + for (std::size_t k = 0; k < diskIDs.size(); ++k) { + if (diskIDs[k] == disk) { bucket = k; break; } + } + if (bucket == diskIDs.size()) { + diskIDs.push_back(disk); + hitsByDisk.emplace_back(); + ptrsByDisk.emplace_back(); + } + hitsByDisk[bucket].push_back(&h); + ptrsByDisk[bucket].emplace_back(hitsHandle, i); + } + + // Build one graph per disk that has hits. 
+ out->resize(diskIDs.size()); + for (std::size_t k = 0; k < diskIDs.size(); ++k) { + builder_.buildGraph(diskIDs[k], + hitsByDisk[k], ptrsByDisk[k], + cal, (*out)[k]); + } + + event.put(std::move(out)); + } + +} + +DEFINE_ART_MODULE(mu2e::CaloHitGraphMaker) diff --git a/CaloCluster/src/CaloHitGraphParityDump_module.cc b/CaloCluster/src/CaloHitGraphParityDump_module.cc new file mode 100644 index 0000000000..59ae604d48 --- /dev/null +++ b/CaloCluster/src/CaloHitGraphParityDump_module.cc @@ -0,0 +1,159 @@ +// +// CaloHitGraphParityDump — dumps per-event CaloHit metadata and the +// GNN cluster assignments to a flat TTree, so a Python comparison +// script can replay the same events through the Python pipeline and +// assert byte-exact agreement on cluster labels. +// +// Output TTree (in the TFileService output): +// +// eventID : ULong64_t +// diskID : Int_t +// nHits : Int_t (CaloHits on this disk) +// crystalID[nHits] : std::vector (per hit) +// time_ns[nHits] : std::vector +// eDep_MeV[nHits] : std::vector +// gnnLabel[nHits] : std::vector (-1 if hit was dropped +// by min_hits / min_E cut; +// 0..K-1 otherwise) +// +// One entry per event-disk. The Python comparison script reads the +// TTree, rebuilds the per-disk graph from the dumped CaloHit info, +// runs the Python CaloClusterNet + cluster_reco, and asserts that +// label vectors match. +// +// Used as the Stage 3 parity check in Task 16g. 
+// + +#include "art/Framework/Core/EDAnalyzer.h" +#include "art/Framework/Core/ModuleMacros.h" +#include "art/Framework/Principal/Event.h" +#include "art/Framework/Principal/Handle.h" +#include "art_root_io/TFileService.h" +#include "fhiclcpp/types/Atom.h" + +#include "Offline/CalorimeterGeom/inc/Calorimeter.hh" +#include "Offline/GeometryService/inc/GeomHandle.hh" +#include "Offline/RecoDataProducts/inc/CaloCluster.hh" +#include "Offline/RecoDataProducts/inc/CaloHit.hh" + +#include "TTree.h" + +#include +#include +#include + + +namespace mu2e { + + class CaloHitGraphParityDump : public art::EDAnalyzer + { + public: + struct Config + { + using Name = fhicl::Name; + using Comment = fhicl::Comment; + fhicl::Atom caloHitCollection { + Name("caloHitCollection"), + Comment("CaloHit collection (input to the GNN graph maker)") }; + fhicl::Atom caloClusterCollection { + Name("caloClusterCollection"), + Comment("CaloCluster collection emitted by CaloClusterMakerGNN, e.g. caloClusterMakerGNN:GNN") }; + }; + + explicit CaloHitGraphParityDump(const art::EDAnalyzer::Table& config); + + void analyze(const art::Event& event) override; + void beginJob() override; + + private: + art::ProductToken hitToken_; + art::ProductToken clusterToken_; + + TTree* tree_ = nullptr; + std::uint64_t bEventID_ = 0; + int bDiskID_ = -1; + int bNHits_ = 0; + std::vector bCrystalID_; + std::vector bTime_; + std::vector bEDep_; + std::vector bGnnLabel_; + }; + + + CaloHitGraphParityDump::CaloHitGraphParityDump( + const art::EDAnalyzer::Table& config) : + art::EDAnalyzer{config}, + hitToken_ {consumes(config().caloHitCollection())}, + clusterToken_{consumes(config().caloClusterCollection())} + { + } + + + void CaloHitGraphParityDump::beginJob() + { + art::ServiceHandle tfs; + tree_ = tfs->make("parity", "GNN-cluster parity dump"); + tree_->Branch("eventID", &bEventID_); + tree_->Branch("diskID", &bDiskID_); + tree_->Branch("nHits", &bNHits_); + tree_->Branch("crystalID", &bCrystalID_); + 
tree_->Branch("time_ns", &bTime_); + tree_->Branch("eDep_MeV", &bEDep_); + tree_->Branch("gnnLabel", &bGnnLabel_); + } + + + void CaloHitGraphParityDump::analyze(const art::Event& event) + { + auto hitsHandle = event.getHandle(hitToken_); + auto clustersHandle= event.getHandle(clusterToken_); + if (!hitsHandle.isValid() || !clustersHandle.isValid()) return; + + const Calorimeter& cal = *(GeomHandle()); + const auto& hits = *hitsHandle; + const auto& clusters = *clustersHandle; + + // Build per-CaloHit-pointer → cluster-id-on-this-disk map. + // The cluster index is its position within the per-disk subset of + // the CaloClusterCollection (matching what the Python pipeline + // emits, since each disk-graph is reconstructed independently). + std::map> diskClusterByHit; // disk -> hit* -> labelIdx + std::map nextLabelByDisk; + for (const auto& cluster : clusters) { + const int disk = cluster.diskID(); + const int label = nextLabelByDisk[disk]++; + for (const auto& hp : cluster.caloHitsPtrVector()) { + diskClusterByHit[disk][hp.get()] = label; + } + } + + // Partition CaloHits by disk and emit one TTree entry per disk. + std::map> hitsByDisk; + for (const auto& h : hits) { + const int d = cal.crystal(h.crystalID()).diskID(); + hitsByDisk[d].push_back(&h); + } + + bEventID_ = static_cast(event.id().event()); + for (const auto& [disk, diskHits] : hitsByDisk) { + bDiskID_ = disk; + bNHits_ = static_cast(diskHits.size()); + bCrystalID_.clear(); bCrystalID_.reserve(diskHits.size()); + bTime_ .clear(); bTime_.reserve(diskHits.size()); + bEDep_ .clear(); bEDep_.reserve(diskHits.size()); + bGnnLabel_.clear(); bGnnLabel_.reserve(diskHits.size()); + const auto& hitMap = diskClusterByHit[disk]; + for (const CaloHit* h : diskHits) { + bCrystalID_.push_back(h->crystalID()); + bTime_ .push_back(h->time()); + bEDep_ .push_back(h->energyDep()); + auto it = hitMap.find(h); + bGnnLabel_.push_back(it != hitMap.end() ? 
it->second : -1); + } + tree_->Fill(); + } + } + +} + +DEFINE_ART_MODULE(mu2e::CaloHitGraphParityDump) diff --git a/CaloCluster/src/GnnClusterAssembler.cc b/CaloCluster/src/GnnClusterAssembler.cc new file mode 100644 index 0000000000..2df5aca4d6 --- /dev/null +++ b/CaloCluster/src/GnnClusterAssembler.cc @@ -0,0 +1,141 @@ +// +// Implementation of GnnClusterAssembler. +// See header for design notes / Python reference. +// + +#include "Offline/CaloCluster/inc/GnnClusterAssembler.hh" + +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace mu2e { + + std::vector + GnnClusterAssembler::assemble( + int nNodes, + const std::vector& edgeIndex, + const std::vector& edgeLogits, + const std::vector& hitEnergiesMeV) const + { + std::vector labels(nNodes, -1); + if (nNodes == 0) return labels; + + const int E = static_cast(edgeLogits.size()); + // edgeIndex is flat (2E): src row first (indices 0..E-1) then dst + // row (indices E..2E-1). Sanity check. + if (static_cast(edgeIndex.size()) != 2 * E) { + // Defensive: malformed input; nothing to cluster. + return labels; + } + + // -------- 1. Sigmoid + symmetrise. + // For each unordered pair {i, j}, accumulate the directed + // probabilities; the threshold check uses their mean. + std::map, std::pair> pairScores; + for (int e = 0; e < E; ++e) { + const int s = static_cast(edgeIndex[e]); + const int d = static_cast(edgeIndex[E + e]); + const std::pair key{std::min(s, d), std::max(s, d)}; + const double p = 1.0 / (1.0 + std::exp(-static_cast(edgeLogits[e]))); + auto it = pairScores.find(key); + if (it == pairScores.end()) { + pairScores.emplace(key, std::make_pair(p, 1)); + } else { + it->second.first += p; + it->second.second += 1; + } + } + + // -------- 2. Threshold + adjacency list. 
+    std::vector<std::vector<int>> adjList(nNodes);
+    for (const auto& [key, sumCount] : pairScores) {
+      const double avg = sumCount.first / static_cast<double>(sumCount.second);
+      if (avg < cfg_.tauEdge) continue;
+      adjList[key.first].push_back(key.second);
+      adjList[key.second].push_back(key.first);
+    }
+
+    // -------- 3. BFS traversal with bfsExpandCut.
+    // Seed selection: process nodes in descending energy order, so
+    // each new cluster is rooted at its highest-energy hit (matching
+    // Offline ClusterFinder semantics and the Python reference).
+    std::vector<int> seedOrder(nNodes);
+    std::iota(seedOrder.begin(), seedOrder.end(), 0);
+    std::sort(seedOrder.begin(), seedOrder.end(),
+              [&](int a, int b) {
+                return hitEnergiesMeV[a] > hitEnergiesMeV[b];
+              });
+
+    int clusterId = 0;
+    for (int seed : seedOrder) {
+      if (labels[seed] >= 0) continue;
+      std::deque<int> queue;
+      queue.push_back(seed);
+      labels[seed] = clusterId;
+      while (!queue.empty()) {
+        const int node = queue.front();
+        queue.pop_front();
+        // Hits below bfsExpandCut join the cluster but cannot recruit
+        // further neighbours. They become leaves in the traversal.
+        if (hitEnergiesMeV[node] < cfg_.bfsExpandCut) continue;
+        for (int neigh : adjList[node]) {
+          if (labels[neigh] < 0) {
+            labels[neigh] = clusterId;
+            queue.push_back(neigh);
+          }
+        }
+      }
+      ++clusterId;
+    }
+
+    // -------- 4. Cleanup: minHits.
+    {
+      std::map<int, int> hitCount;
+      for (int label : labels) {
+        if (label >= 0) hitCount[label]++;
+      }
+      for (int& label : labels) {
+        if (label >= 0 && hitCount[label] < static_cast<int>(cfg_.minHits)) {
+          label = -1;
+        }
+      }
+    }
+
+    // -------- 5. Cleanup: minEnergyMeV.
+    {
+      std::map<int, double> totalE;
+      for (int i = 0; i < nNodes; ++i) {
+        if (labels[i] >= 0) totalE[labels[i]] += hitEnergiesMeV[i];
+      }
+      for (int& label : labels) {
+        if (label >= 0 && totalE[label] < cfg_.minEnergyMeV) {
+          label = -1;
+        }
+      }
+    }
+
+    // -------- 6. Relabel to contiguous IDs in first-appearance order.
+    {
+      std::map<int, int> remap;
+      int nextId = 0;
+      for (int label : labels) {
+        if (label >= 0 && remap.find(label) == remap.end()) {
+          remap[label] = nextId++;
+        }
+      }
+      for (int& label : labels) {
+        if (label >= 0) label = remap[label];
+      }
+    }
+
+    return labels;
+  }
+
+}
diff --git a/CaloCluster/src/GnnGraphBuilder.cc b/CaloCluster/src/GnnGraphBuilder.cc
new file mode 100644
index 0000000000..cbaeb5e414
--- /dev/null
+++ b/CaloCluster/src/GnnGraphBuilder.cc
@@ -0,0 +1,332 @@
+//
+// Implementation of GnnGraphBuilder.
+// See header for design notes / Python reference.
+//
+
+#include "Offline/CaloCluster/inc/GnnGraphBuilder.hh"
+
+#include "Offline/CalorimeterGeom/inc/Calorimeter.hh"
+#include "Offline/CalorimeterGeom/inc/Crystal.hh"
+
+#include "cetlib_except/exception.h"
+#include "nlohmann/json.hpp"
+
+#include <algorithm>
+#include <cmath>
+#include <cstddef>
+#include <fstream>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+namespace {
+
+  // Canonical feature names — must agree with the norm sidecar emitted
+  // by calorimeter/GNN/scripts/export_norm_stats.py and with the
+  // model_version-stamped metadata_props of the .onnx artifacts.
+  const std::vector<std::string> kNodeFeatures{
+    "log_e", "t", "x", "y", "r", "e_rel"
+  };
+  const std::vector<std::string> kEdgeFeatures{
+    "dx", "dy", "d", "dt", "dlog_e", "asym_e", "logsum_e", "dr"
+  };
+
+}
+
+namespace mu2e {
+
+  GnnGraphBuilder::Stats
+  GnnGraphBuilder::loadStatsFromJson(const std::string& jsonPath)
+  {
+    std::ifstream in(jsonPath);
+    if (!in.is_open()) {
+      throw cet::exception("GnnGraphBuilder")
+        << "cannot open norm sidecar: " << jsonPath;
+    }
+    nlohmann::json j;
+    try {
+      in >> j;
+    } catch (const std::exception& e) {
+      throw cet::exception("GnnGraphBuilder")
+        << "JSON parse error in " << jsonPath << ": " << e.what();
+    }
+
+    auto require = [&](const char* k) {
+      if (!j.contains(k)) {
+        throw cet::exception("GnnGraphBuilder")
+          << "norm sidecar missing key '" << k << "': " << jsonPath;
+      }
+    };
+    require("schema_version");
+    require("node_features");
+    require("edge_features");
+    require("node_mean");
+    require("node_std");
+    require("edge_mean");
+    require("edge_std");
+
+    const int schema = j["schema_version"].get<int>();
+    if (schema != 1) {
+      throw cet::exception("GnnGraphBuilder")
+        << "unsupported norm sidecar schema_version " << schema
+        << " (expected 1): " << jsonPath;
+    }
+
+    auto nodeNames = j["node_features"].get<std::vector<std::string>>();
+    auto edgeNames = j["edge_features"].get<std::vector<std::string>>();
+    if (nodeNames != kNodeFeatures) {
+      throw cet::exception("GnnGraphBuilder")
+        << "node_features mismatch in " << jsonPath;
+    }
+    if (edgeNames != kEdgeFeatures) {
+      throw cet::exception("GnnGraphBuilder")
+        << "edge_features mismatch in " << jsonPath;
+    }
+
+    Stats s;
+    s.nodeMean = j["node_mean"].get<std::vector<float>>();
+    s.nodeStd = j["node_std"].get<std::vector<float>>();
+    s.edgeMean = j["edge_mean"].get<std::vector<float>>();
+    s.edgeStd = j["edge_std"].get<std::vector<float>>();
+
+    if (s.nodeMean.size() != kNodeFeatures.size()
+        || s.nodeStd.size() != kNodeFeatures.size()) {
+      throw cet::exception("GnnGraphBuilder")
+        << "norm sidecar node stats size mismatch in " << jsonPath;
+    }
+    if (s.edgeMean.size() != kEdgeFeatures.size()
+        || s.edgeStd.size() !=
kEdgeFeatures.size()) { + throw cet::exception("GnnGraphBuilder") + << "norm sidecar edge stats size mismatch in " << jsonPath; + } + for (float v : s.nodeStd) { + if (v <= 0.0f) { + throw cet::exception("GnnGraphBuilder") + << "non-positive node_std entry in " << jsonPath; + } + } + for (float v : s.edgeStd) { + if (v <= 0.0f) { + throw cet::exception("GnnGraphBuilder") + << "non-positive edge_std entry in " << jsonPath; + } + } + return s; + } + + + void GnnGraphBuilder::buildGraph( + int diskID, + const std::vector& hits, + const std::vector>& ptrs, + const Calorimeter& cal, + CaloHitGraph& out) const + { + const std::size_t n = hits.size(); + out.diskID = diskID; + out.nNodes = static_cast(n); + out.caloHitPtrs = ptrs; + out.x.assign(6 * n, 0.0f); + out.edgeIndex.clear(); + out.edgeAttr.clear(); + out.nEdges = 0; + if (n == 0) return; + + // Per-hit cached arrays: position, time, energy, radial. + std::vector xs(n), ys(n), ts(n), es(n), rs(n); + double eMax = 0.0; + for (std::size_t i = 0; i < n; ++i) { + const auto& h = *hits[i]; + const auto& cry = cal.crystal(h.crystalID()); + const auto& pos = cry.localPosition(); + xs[i] = pos.x(); + ys[i] = pos.y(); + ts[i] = h.time(); + es[i] = h.energyDep(); + rs[i] = std::sqrt(xs[i] * xs[i] + ys[i] * ys[i]); + if (es[i] > eMax) eMax = es[i]; + } + + // Node features (6) + z-score normalisation. + for (std::size_t i = 0; i < n; ++i) { + const float feats[6] = { + static_cast(std::log1p(es[i])), + static_cast(ts[i]), + static_cast(xs[i]), + static_cast(ys[i]), + static_cast(rs[i]), + static_cast(eMax > 0.0 ? es[i] / eMax : 0.0) + }; + for (int k = 0; k < 6; ++k) { + out.x[6 * i + k] = (feats[k] - stats_.nodeMean[k]) / stats_.nodeStd[k]; + } + } + if (n == 1) return; + + // ----- 1. 
radius graph: brute-force pairwise i> pairs; + pairs.reserve(n * 4); + for (std::size_t i = 0; i < n; ++i) { + for (std::size_t j = i + 1; j < n; ++j) { + const double dx = xs[i] - xs[j]; + const double dy = ys[i] - ys[j]; + if (dx * dx + dy * dy > rMax2) continue; + if (std::abs(ts[i] - ts[j]) > cfg_.dtMax) continue; + pairs.emplace_back(static_cast(i), static_cast(j)); + } + } + + // Encode (s, d) pair as a single key for dedup. n is the count of + // nodes; s and d are < n, so s*n + d uniquely identifies a directed + // edge. + auto encode = [n](int s, int d) -> long long { + return static_cast(s) * static_cast(n) + + static_cast(d); + }; + + std::unordered_set seen; + std::vector srcL, dstL; + seen.reserve(4 * pairs.size() + 16); + srcL.reserve(2 * pairs.size()); + dstL.reserve(2 * pairs.size()); + + auto add_edge = [&](int s, int d) { + const long long k = encode(s, d); + if (seen.insert(k).second) { + srcL.push_back(s); + dstL.push_back(d); + } + }; + + // Match the Python order: forward edges first (i->j for all pairs), + // then backward edges (j->i). Keeps the pre-sort layout faithful + // to the Python reference. + for (const auto& p : pairs) add_edge(p.first, p.second); + for (const auto& p : pairs) add_edge(p.second, p.first); + + // ----- 2. degree before kNN fallback. + std::vector degree(n, 0); + for (int s : srcL) degree[s]++; + + // ----- 3. kNN fallback for under-connected nodes. + for (std::size_t i = 0; i < n; ++i) { + if (degree[i] >= static_cast(cfg_.kMin)) continue; + + // Sorted candidate list (excluding self) by spatial distance. + std::vector> cand; + cand.reserve(n - 1); + for (std::size_t j = 0; j < n; ++j) { + if (j == i) continue; + const double dx = xs[i] - xs[j]; + const double dy = ys[i] - ys[j]; + cand.emplace_back(std::sqrt(dx * dx + dy * dy), + static_cast(j)); + } + std::sort(cand.begin(), cand.end()); + + // Mirror Python: tree.query returns k_query nearest including + // self at j_pos=0. 
Self is silently skipped, so effectively + // (k_query - 1) actual neighbours are inspected. + const int kQuery = std::min(static_cast(cfg_.kMin) * 3, + static_cast(n)); + const std::size_t maxN = + std::min(std::max(0, kQuery - 1), cand.size()); + int added = 0; + for (std::size_t idx = 0; idx < maxN; ++idx) { + const int j = cand[idx].second; + if (std::abs(ts[i] - ts[static_cast(j)]) > cfg_.dtMax) + continue; + add_edge(static_cast(i), j); + add_edge(j, static_cast(i)); + ++added; + if (degree[i] + added >= static_cast(cfg_.kMin)) break; + } + } + + // ----- 4. Sort edges lexicographically by (src, dst) to match + // the Python deduplicate-by-encoded-value ordering. + { + std::vector> edges; + edges.reserve(srcL.size()); + for (std::size_t k = 0; k < srcL.size(); ++k) { + edges.emplace_back(srcL[k], dstL[k]); + } + std::sort(edges.begin(), edges.end()); + for (std::size_t k = 0; k < edges.size(); ++k) { + srcL[k] = edges[k].first; + dstL[k] = edges[k].second; + } + } + + // ----- 5. Per-source-node degree cap at k_max (keep nearest dsts). 
+ if (cfg_.kMax > 0) { + std::vector> bySrc(n); + for (std::size_t k = 0; k < srcL.size(); ++k) { + bySrc[srcL[k]].push_back(static_cast(k)); + } + std::vector keep(srcL.size(), true); + for (std::size_t s = 0; s < n; ++s) { + if (bySrc[s].size() <= cfg_.kMax) continue; + auto& idx = bySrc[s]; + std::sort(idx.begin(), idx.end(), [&](int a, int b) { + const int da = dstL[a], db = dstL[b]; + const double dxa = xs[s] - xs[da]; + const double dya = ys[s] - ys[da]; + const double dxb = xs[s] - xs[db]; + const double dyb = ys[s] - ys[db]; + return (dxa * dxa + dya * dya) < (dxb * dxb + dyb * dyb); + }); + for (std::size_t k = cfg_.kMax; k < idx.size(); ++k) { + keep[idx[k]] = false; + } + } + std::vector sl, dl; + sl.reserve(srcL.size()); + dl.reserve(srcL.size()); + for (std::size_t k = 0; k < srcL.size(); ++k) { + if (keep[k]) { + sl.push_back(srcL[k]); + dl.push_back(dstL[k]); + } + } + srcL.swap(sl); + dstL.swap(dl); + } + + // ----- 6. Build flat edge_index (src row first, dst row second). + const int E = static_cast(srcL.size()); + out.nEdges = E; + out.edgeIndex.resize(2 * E); + for (int e = 0; e < E; ++e) { + out.edgeIndex[e] = srcL[e]; + out.edgeIndex[E + e] = dstL[e]; + } + + // ----- 7. Edge features (8) + z-score normalisation. + out.edgeAttr.resize(8 * E); + for (int e = 0; e < E; ++e) { + const int s = srcL[e]; + const int d = dstL[e]; + const double dx = xs[s] - xs[d]; + const double dy = ys[s] - ys[d]; + const double dist = std::sqrt(dx * dx + dy * dy); + const double dt = ts[s] - ts[d]; + const double log_e_s = std::log1p(es[s]); + const double log_e_d = std::log1p(es[d]); + const double dlog_e = log_e_s - log_e_d; + const double e_sum = es[s] + es[d]; + const double e_asym = (e_sum > 0.0) + ? 
(es[s] - es[d]) / e_sum : 0.0; + const double logsum_e = std::log1p(e_sum); + const double dr = rs[s] - rs[d]; + const double feats[8] = {dx, dy, dist, dt, + dlog_e, e_asym, logsum_e, dr}; + for (int k = 0; k < 8; ++k) { + out.edgeAttr[8 * e + k] = static_cast( + (feats[k] - stats_.edgeMean[k]) / stats_.edgeStd[k]); + } + } + } + +} diff --git a/CaloCluster/src/SConscript b/CaloCluster/src/SConscript index 1e19df7b31..aec9e2b32e 100644 --- a/CaloCluster/src/SConscript +++ b/CaloCluster/src/SConscript @@ -43,8 +43,14 @@ mainlib = helper.make_mainlib ( [ 'mu2e_GeometryService', rootlibs, ] ) +helper.make_bin( "testGnnClusterAssembler", + [ mainlib, + 'cetlib_except', + ] ) + helper.make_plugins( [ mainlib, 'mu2e_Mu2eUtilities', + 'mu2e_ConfigTools', 'mu2e_GlobalConstantsService', 'mu2e_GeometryService', 'mu2e_SeedService', @@ -70,6 +76,7 @@ helper.make_plugins( [ mainlib, 'TMVA', 'xerces-c', #needed for MVA 'boost_filesystem', + 'onnxruntime', # CaloClusterMakerGNN_module — central muse install via u092 rootlibs ], ) diff --git a/CaloCluster/src/testGnnClusterAssembler_main.cc b/CaloCluster/src/testGnnClusterAssembler_main.cc new file mode 100644 index 0000000000..8c6982bd6b --- /dev/null +++ b/CaloCluster/src/testGnnClusterAssembler_main.cc @@ -0,0 +1,143 @@ +// +// testGnnClusterAssembler_main — Stage-2 parity test for the C++ +// GnnClusterAssembler (Task 16g). +// +// Reads a JSON parity payload (produced by +// calorimeter/GNN/scripts/dump_parity_payloads.py), replays +// GnnClusterAssembler::assemble for each graph, and asserts the +// emitted cluster_labels are byte-identical to the Python reference +// labels stored in the same payload. Exits non-zero on any mismatch. 
+// +// Usage: +// testGnnClusterAssembler +// +// Default path (relative to MUSE_WORK_DIR): +// ../projects/calorimeter/GNN/tests/parity/calo_cluster_net_v2_stage1.parity.json +// + +#include "Offline/CaloCluster/inc/GnnClusterAssembler.hh" + +#include "nlohmann/json.hpp" + +#include +#include +#include +#include +#include +#include + + +namespace { + + constexpr const char* kDefaultPayload = + "/exp/mu2e/app/users/wzhou2/projects/calorimeter/GNN/tests/parity/calo_cluster_net_v2_stage1.parity.json"; + +} + + +int main(int argc, char* argv[]) +{ + const std::string path = (argc > 1) ? argv[1] : kDefaultPayload; + + std::ifstream in(path); + if (!in.is_open()) { + std::cerr << "[FAIL] cannot open " << path << "\n"; + return 2; + } + nlohmann::json j; + try { in >> j; } + catch (const std::exception& e) { + std::cerr << "[FAIL] JSON parse error in " << path + << ": " << e.what() << "\n"; + return 2; + } + + if (j.value("schema_version", 0) != 1) { + std::cerr << "[FAIL] unsupported schema_version " + << j.value("schema_version", 0) << "\n"; + return 2; + } + + mu2e::GnnClusterAssembler::Config cfg; + cfg.tauEdge = j.at("tau_edge").get(); + cfg.bfsExpandCut = j.at("bfs_expand_cut").get(); + cfg.minHits = j.at("min_hits").get(); + cfg.minEnergyMeV = j.at("min_energy_mev").get(); + mu2e::GnnClusterAssembler asm_(cfg); + + const auto& graphs = j.at("graphs"); + std::cout << "Loaded " << graphs.size() << " graphs from " << path << "\n"; + std::cout << "Recipe: tauEdge=" << cfg.tauEdge + << " bfsExpandCut=" << cfg.bfsExpandCut + << " minHits=" << cfg.minHits + << " minEnergyMeV=" << cfg.minEnergyMeV << "\n"; + std::cout << "model_version=" + << j.value("model_version", std::string("?")) << "\n"; + + std::size_t nGraphs = 0; + std::size_t nMismatchGraphs = 0; + std::size_t nMismatchNodes = 0; + std::size_t maxMismatchPerGraph = 0; + + for (const auto& g : graphs) { + const int N = g.at("n_nodes").get(); + const int E = g.at("n_edges").get(); + auto edgeIndex = 
g.at("edge_index").get>(); + auto edgeLogits = g.at("edge_logits").get>(); + auto energies = g.at("energies").get>(); + auto refLabels = g.at("cluster_labels").get>(); + + if (static_cast(edgeLogits.size()) != E + || static_cast(energies.size()) != N + || static_cast(edgeIndex.size()) != 2 * E + || static_cast(refLabels.size()) != N) { + std::cerr << "[FAIL] graph " << nGraphs + << ": shape mismatch in payload\n"; + return 2; + } + + auto outLabels = asm_.assemble(N, edgeIndex, edgeLogits, energies); + if (static_cast(outLabels.size()) != N) { + std::cerr << "[FAIL] graph " << nGraphs + << ": assemble returned size " << outLabels.size() + << " (expected " << N << ")\n"; + return 2; + } + + std::size_t diffs = 0; + for (int i = 0; i < N; ++i) { + if (outLabels[i] != refLabels[i]) ++diffs; + } + if (diffs > 0) { + ++nMismatchGraphs; + nMismatchNodes += diffs; + if (diffs > maxMismatchPerGraph) maxMismatchPerGraph = diffs; + if (nMismatchGraphs <= 5) { + std::cerr << "[diff] graph " << nGraphs + << " (N=" << N << ", E=" << E + << "): " << diffs << " mismatched node(s)\n"; + std::cerr << " py : "; + for (int v : refLabels) std::cerr << v << " "; + std::cerr << "\n cpp: "; + for (int v : outLabels) std::cerr << v << " "; + std::cerr << "\n"; + } + } + ++nGraphs; + } + + std::cout << "\n=== Summary ===\n"; + std::cout << "graphs: " << nGraphs << "\n"; + std::cout << "mismatch graphs: " << nMismatchGraphs << "\n"; + std::cout << "mismatch nodes: " << nMismatchNodes << "\n"; + std::cout << "max diffs/graph: " << maxMismatchPerGraph << "\n"; + + if (nMismatchGraphs > 0) { + std::cerr << "[FAIL] cluster-label parity broken on " + << nMismatchGraphs << " / " << nGraphs << " graph(s)\n"; + return 1; + } + std::cout << "[PASS] all " << nGraphs + << " graphs match Python cluster_labels byte-exactly\n"; + return 0; +} diff --git a/RecoDataProducts/inc/CaloHitGraph.hh b/RecoDataProducts/inc/CaloHitGraph.hh new file mode 100644 index 0000000000..c8897845a5 --- /dev/null +++ 
b/RecoDataProducts/inc/CaloHitGraph.hh @@ -0,0 +1,50 @@ +#ifndef RecoDataProducts_CaloHitGraph_hh +#define RecoDataProducts_CaloHitGraph_hh +// +// Per-disk graph carrying the inputs the GNN clustering ONNX model +// expects, plus per-node back-references to the source CaloHits. +// +// Transient data product: emitted by CaloHitGraphMaker, consumed by +// CaloClusterMakerGNN in the same job. Not registered in +// classes_def.xml (no ROOT serialisation). +// +// Tensor layout matches the ONNX model interface in +// calorimeter/GNN/docs/onnx_deployment.md (§2 / §7): +// +// x : nNodes * 6 floats, row-major (one row per hit) +// edgeIndex : 2 * nEdges int64s, row-major (src row first, dst row second) +// edgeAttr : nEdges * 8 floats, row-major (one row per directed edge) +// +// Feature column order is fixed: +// nodes : log_e, t, x, y, r, e_rel +// edges : dx, dy, d, dt, dlog_e, asym_e, logsum_e, dr +// The cluster module asserts these names against the loaded model's +// metadata_props (16j handshake). +// + +#include "Offline/RecoDataProducts/inc/CaloHit.hh" +#include "canvas/Persistency/Common/Ptr.h" + +#include +#include + +namespace mu2e { + + struct CaloHitGraph + { + int nNodes = 0; + int nEdges = 0; + int diskID = -1; + + std::vector x; // size 6 * nNodes + std::vector edgeIndex; // size 2 * nEdges + std::vector edgeAttr; // size 8 * nEdges + + std::vector> caloHitPtrs; // size nNodes + }; + + using CaloHitGraphCollection = std::vector; + +} + +#endif