diff --git a/source/framework/core/inc/TRestVolumeHits.h b/source/framework/core/inc/TRestVolumeHits.h index 39281f125..94f0f442b 100644 --- a/source/framework/core/inc/TRestVolumeHits.h +++ b/source/framework/core/inc/TRestVolumeHits.h @@ -69,7 +69,8 @@ class TRestVolumeHits : public TRestHits { return TMath::Sqrt(fSigmaX[n] * fSigmaX[n] + fSigmaY[n] * fSigmaY[n]); } - static void kMeansClustering(TRestVolumeHits* hits, TRestVolumeHits& vHits, int maxIt = 100); + static void kMeansClustering(TRestVolumeHits* hits, TRestVolumeHits& vHits, int maxIt = 100, + bool fixBoundaries = false); // Constructor & Destructor TRestVolumeHits(); diff --git a/source/framework/core/src/TRestVolumeHits.cxx b/source/framework/core/src/TRestVolumeHits.cxx index 331d6b5d4..7f0c1c93d 100644 --- a/source/framework/core/src/TRestVolumeHits.cxx +++ b/source/framework/core/src/TRestVolumeHits.cxx @@ -162,13 +162,14 @@ void TRestVolumeHits::PrintHits() const { } } -void TRestVolumeHits::kMeansClustering(TRestVolumeHits* hits, TRestVolumeHits& vHits, int maxIt) { +void TRestVolumeHits::kMeansClustering(TRestVolumeHits* hits, TRestVolumeHits& vHits, int maxIt, + bool fixBoundaries) { const int nodes = vHits.GetNumberOfHits(); vector volHits(nodes); // std::cout<<"Nhits "<GetNumberOfHits()<<" Nodes "< centroid(nodes); - std::vector centroidOld(nodes, nullVector); + std::vector centroidOld(nodes, nullVector); // used for iterations for (int h = 0; h < nodes; h++) centroid[h] = vHits.GetPosition(h); @@ -178,6 +179,7 @@ void TRestVolumeHits::kMeansClustering(TRestVolumeHits* hits, TRestVolumeHits& v double minDist = 1E9; int clIndex = -1; for (int n = 0; n < nodes; n++) { + if (fixBoundaries && (n == 0 || n == nodes - 1)) continue; // Skip fixed nodes TVector3 hitPos = hits->GetPosition(i); double dist = (centroid[n] - hitPos).Mag(); if (dist < minDist) { @@ -188,8 +190,11 @@ void TRestVolumeHits::kMeansClustering(TRestVolumeHits* hits, TRestVolumeHits& v // cout< 0) - vHits.AddHit(volHits[n].GetMeanPosition(), volHits[n].GetTotalEnergy(), 0, volHits[n].GetType(0), - sigma); + if (fixBoundaries && (n == 0 || n == nodes - 1)) { + vHits.AddHit(centroid[n], 0, 0, vHits.GetType(n), sigma); + } else { + if (volHits[n].GetNumberOfHits() > 0) + vHits.AddHit(volHits[n].GetMeanPosition(), volHits[n].GetTotalEnergy(), 0, + volHits[n].GetType(0), sigma); + } } }