Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion source/framework/core/inc/TRestVolumeHits.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ class TRestVolumeHits : public TRestHits {
return TMath::Sqrt(fSigmaX[n] * fSigmaX[n] + fSigmaY[n] * fSigmaY[n]);
}

static void kMeansClustering(TRestVolumeHits* hits, TRestVolumeHits& vHits, int maxIt = 100);
static void kMeansClustering(TRestVolumeHits* hits, TRestVolumeHits& vHits, int maxIt = 100,
bool fixBoundaries = false);

// Constructor & Destructor
TRestVolumeHits();
Expand Down
19 changes: 14 additions & 5 deletions source/framework/core/src/TRestVolumeHits.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -162,13 +162,14 @@ void TRestVolumeHits::PrintHits() const {
}
}

void TRestVolumeHits::kMeansClustering(TRestVolumeHits* hits, TRestVolumeHits& vHits, int maxIt) {
void TRestVolumeHits::kMeansClustering(TRestVolumeHits* hits, TRestVolumeHits& vHits, int maxIt,
bool fixBoundaries) {
const int nodes = vHits.GetNumberOfHits();
vector<TRestVolumeHits> volHits(nodes);
// std::cout<<"Nhits "<<hits->GetNumberOfHits()<<" Nodes "<<nodes<<std::endl;
TVector3 nullVector = TVector3(0, 0, 0);
std::vector<TVector3> centroid(nodes);
std::vector<TVector3> centroidOld(nodes, nullVector);
std::vector<TVector3> centroidOld(nodes, nullVector); // used for iterations

for (int h = 0; h < nodes; h++) centroid[h] = vHits.GetPosition(h);

Expand All @@ -178,6 +179,7 @@ void TRestVolumeHits::kMeansClustering(TRestVolumeHits* hits, TRestVolumeHits& v
double minDist = 1E9;
int clIndex = -1;
for (int n = 0; n < nodes; n++) {
if (fixBoundaries && (n == 0 || n == nodes - 1)) continue; // Skip fixed nodes
TVector3 hitPos = hits->GetPosition(i);
double dist = (centroid[n] - hitPos).Mag();
if (dist < minDist) {
Expand All @@ -188,8 +190,11 @@ void TRestVolumeHits::kMeansClustering(TRestVolumeHits* hits, TRestVolumeHits& v
// cout<<minDist<<" "<<clIndex<<endl;
volHits[clIndex].AddHit(*hits, i);
}

// Update centroids and check for convergence
bool converge = true;
for (int n = 0; n < nodes; n++) {
if (fixBoundaries && (n == 0 || n == nodes - 1)) continue; // Skip fixed nodes
centroid[n] = volHits[n].GetMeanPosition();
converge &= (centroid[n] == centroidOld[n]);
centroidOld[n] = centroid[n];
Expand All @@ -202,8 +207,12 @@ void TRestVolumeHits::kMeansClustering(TRestVolumeHits* hits, TRestVolumeHits& v
vHits.RemoveHits();
const TVector3 sigma(0., 0., 0.);
for (int n = 0; n < nodes; n++) {
if (volHits[n].GetNumberOfHits() > 0)
vHits.AddHit(volHits[n].GetMeanPosition(), volHits[n].GetTotalEnergy(), 0, volHits[n].GetType(0),
sigma);
if (fixBoundaries && (n == 0 || n == nodes - 1)) {
vHits.AddHit(centroid[n], 0, 0, vHits.GetType(n), sigma);
} else {
if (volHits[n].GetNumberOfHits() > 0)
vHits.AddHit(volHits[n].GetMeanPosition(), volHits[n].GetTotalEnergy(), 0,
volHits[n].GetType(0), sigma);
}
}
}