Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Reduce memory consumption of AMVF #2832

Merged
merged 8 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions Core/include/Acts/Vertexing/AdaptiveMultiVertexFitter.ipp
Original file line number Diff line number Diff line change
Expand Up @@ -380,11 +380,9 @@ void Acts::AdaptiveMultiVertexFitter<
// vertex. The second template argument corresponds to the number of
// fitted vertex dimensions (i.e., 3 if we only fit spatial coordinates
// and 4 if we also fit time).
if (m_cfg.useTime) {
KalmanVertexTrackUpdater::update<input_track_t, 4>(trkAtVtx, *vtx);
} else {
KalmanVertexTrackUpdater::update<input_track_t, 3>(trkAtVtx, *vtx);
}
KalmanVertexTrackUpdater::update(trkAtVtx, vtx->fullPosition(),
vtx->fullCovariance(),
m_cfg.useTime ? 4 : 3);
}
}
}
Expand Down
9 changes: 4 additions & 5 deletions Core/include/Acts/Vertexing/KalmanVertexTrackUpdater.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@

#pragma once

#include "Acts/EventData/TrackParameters.hpp"
#include "Acts/Utilities/Logger.hpp"
#include "Acts/Utilities/Result.hpp"
#include "Acts/Vertexing/KalmanVertexUpdater.hpp"
#include "Acts/Vertexing/LinearizedTrack.hpp"
#include "Acts/Vertexing/TrackAtVertex.hpp"
#include "Acts/Vertexing/Vertex.hpp"

Expand All @@ -35,9 +37,8 @@ namespace KalmanVertexTrackUpdater {
///
/// @param track Track to update
/// @param vtx Vertex `track` belongs to
template <typename input_track_t, unsigned int nDimVertex>
void update(TrackAtVertex<input_track_t>& track,
const Vertex<input_track_t>& vtx);
void update(TrackAtVertexRef track, const Vector4& vtxPosFull,
const SquareMatrix4& vtxCovFull, unsigned int nDimVertex);

namespace detail {

Expand All @@ -61,5 +62,3 @@ inline BoundMatrix calculateTrackCovariance(

} // Namespace KalmanVertexTrackUpdater
} // Namespace Acts

#include "Acts/Vertexing/KalmanVertexTrackUpdater.ipp"
88 changes: 35 additions & 53 deletions Core/include/Acts/Vertexing/KalmanVertexUpdater.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#pragma once

#include "Acts/Definitions/Algebra.hpp"
#include "Acts/Utilities/Result.hpp"
#include "Acts/Vertexing/TrackAtVertex.hpp"
#include "Acts/Vertexing/Vertex.hpp"
Expand Down Expand Up @@ -68,6 +69,25 @@ template <typename input_track_t, unsigned int nDimVertex>
void updateVertexWithTrack(Vertex<input_track_t>& vtx,
TrackAtVertex<input_track_t>& trk);

namespace detail {
void updateVertexWithTrack(Vector4& vtxPos, SquareMatrix4& vtxCov,
std::pair<double, double>& fitQuality,
TrackAtVertexRef trk, int sign,
unsigned int nDimVertex);

// These two functions only exist so we can compile calculateUpdate in a
// compilation unit
void calculateUpdate3(const Vector4& vtxPos, const SquareMatrix4& vtxCov,
const Acts::LinearizedTrack& linTrack,
const double trackWeight, const int sign,
Cache<3>& cache);

void calculateUpdate4(const Vector4& vtxPos, const SquareMatrix4& vtxCov,
const Acts::LinearizedTrack& linTrack,
const double trackWeight, const int sign,
Cache<4>& cache);
} // namespace detail

/// @brief Calculates updated vertex position and covariance as well as the
/// updated track momentum when adding/removing linTrack. Saves the result in
/// cache.
Expand All @@ -83,61 +103,23 @@ void updateVertexWithTrack(Vertex<input_track_t>& vtx,
/// @note Tracks are removed during the smoothing procedure to compute
/// the chi2 of the track wrt the updated vertex position
/// @param[out] cache A cache to store the results of this function
template <typename input_track_t, unsigned int nDimVertex>
void calculateUpdate(const Acts::Vertex<input_track_t>& vtx,
template <unsigned int nDimVertex>
void calculateUpdate(const Vector4& vtxPos, const SquareMatrix4& vtxCov,
const Acts::LinearizedTrack& linTrack,
const double trackWeight, const int sign,
Cache<nDimVertex>& cache);

namespace detail {

/// @brief Calculates the update of the vertex position chi2 after
/// adding/removing the track
///
/// @tparam input_track_t Track object type
/// @tparam nDimVertex number of dimensions of the vertex. Can be 3 (if we only
/// fit its spatial coordinates) or 4 (if we also fit time).
///
/// @param oldVtx Vertex before the track was added/removed
/// @param cache Cache containing updated vertex position
///
/// @return Chi2
template <typename input_track_t, unsigned int nDimVertex>
double vertexPositionChi2Update(const Vertex<input_track_t>& oldVtx,
const Cache<nDimVertex>& cache);

/// @brief Calculates chi2 of refitted track parameters
/// w.r.t. updated vertex
///
/// @tparam input_track_t Track object type
/// @tparam nDimVertex number of dimensions of the vertex. Can be 3 (if we only
/// fit its spatial coordinates) or 4 (if we also fit time).
///
/// @param linTrack Linearized version of track
/// @param cache Cache containing some quantities needed in
/// this function
///
/// @return Chi2
template <typename input_track_t, unsigned int nDimVertex>
double trackParametersChi2(const LinearizedTrack& linTrack,
const Cache<nDimVertex>& cache);

/// @brief Adds or removes (depending on `sign`) tracks from vertex
/// and updates the vertex
///
/// @tparam input_track_t Track object type
/// @tparam nDimVertex number of dimensions of the vertex. Can be 3 (if we only
/// fit its spatial coordinates) or 4 (if we also fit time).
///
/// @param vtx Vertex to be updated
/// @param trk Track to be added to/removed from vtx
/// @param sign +1 (add track) or -1 (remove track)
/// @note Tracks are removed during the smoothing procedure to compute
/// the chi2 of the track wrt the updated vertex position
template <typename input_track_t, unsigned int nDimVertex>
void update(Vertex<input_track_t>& vtx, TrackAtVertex<input_track_t>& trk,
int sign);
} // Namespace detail
Cache<nDimVertex>& cache) {
static_assert(nDimVertex == 3 || nDimVertex == 4,
"The vertex dimension must either be 3 (when fitting the "
"spatial coordinates) or 4 (when fitting the spatial "
"coordinates + time).");
if constexpr (nDimVertex == 3) {
detail::calculateUpdate3(vtxPos, vtxCov, linTrack, trackWeight, sign,
cache);
} else if constexpr (nDimVertex == 4) {
detail::calculateUpdate4(vtxPos, vtxCov, linTrack, trackWeight, sign,
cache);
}
}

} // Namespace KalmanVertexUpdater
} // Namespace Acts
Expand Down
191 changes: 4 additions & 187 deletions Core/include/Acts/Vertexing/KalmanVertexUpdater.ipp
Original file line number Diff line number Diff line change
Expand Up @@ -13,191 +13,8 @@
template <typename input_track_t, unsigned int nDimVertex>
void Acts::KalmanVertexUpdater::updateVertexWithTrack(
Vertex<input_track_t>& vtx, TrackAtVertex<input_track_t>& trk) {
detail::update<input_track_t, nDimVertex>(vtx, trk, 1);
}

template <typename input_track_t, unsigned int nDimVertex>
void Acts::KalmanVertexUpdater::detail::update(
Vertex<input_track_t>& vtx, TrackAtVertex<input_track_t>& trk, int sign) {
if constexpr (nDimVertex != 3 && nDimVertex != 4) {
throw std::invalid_argument(
"The vertex dimension must either be 3 (when fitting the spatial "
"coordinates) or 4 (when fitting the spatial coordinates + time).");
}

double trackWeight = trk.trackWeight;

// Set up cache where entire content is set to 0
Cache<nDimVertex> cache;

// Calculate update and save result in cache
calculateUpdate(vtx, trk.linearizedState, trackWeight, sign, cache);

// Get fit quality parameters wrt to old vertex
std::pair fitQuality = vtx.fitQuality();
double chi2 = fitQuality.first;
double ndf = fitQuality.second;

// Chi2 of the track parameters
double trkChi2 =
detail::trackParametersChi2<input_track_t>(trk.linearizedState, cache);

// Update of the chi2 of the vertex position
double vtxPosChi2Update =
detail::vertexPositionChi2Update<input_track_t>(vtx, cache);

// Calculate new chi2
chi2 += sign * (vtxPosChi2Update + trackWeight * trkChi2);

// Calculate ndf
ndf += sign * trackWeight * 2.;

// Updating the vertex
if constexpr (nDimVertex == 3) {
vtx.setPosition(cache.newVertexPos);
vtx.setCovariance(cache.newVertexCov);
} else if constexpr (nDimVertex == 4) {
vtx.setFullPosition(cache.newVertexPos);
vtx.setFullCovariance(cache.newVertexCov);
}
vtx.setFitQuality(chi2, ndf);

if (sign == 1) {
// Update track
trk.chi2Track = trkChi2;
trk.ndf = 2 * trackWeight;
}
// Remove trk from current vertex by setting its weight to 0
else if (sign == -1) {
trk.trackWeight = 0.;
} else {
throw std::invalid_argument(
"Sign for adding/removing track must be +1 (add) or -1 (remove).");
}
}

template <typename input_track_t, unsigned int nDimVertex>
void Acts::KalmanVertexUpdater::calculateUpdate(
const Acts::Vertex<input_track_t>& vtx,
const Acts::LinearizedTrack& linTrack, const double trackWeight,
const int sign, Cache<nDimVertex>& cache) {
constexpr unsigned int nBoundParams = nDimVertex + 2;
using ParameterVector = ActsVector<nBoundParams>;
using ParameterMatrix = ActsSquareMatrix<nBoundParams>;
// Retrieve variables from the track linearization. The comments indicate the
// corresponding symbol used in Ref. (1).
// A_k
const ActsMatrix<nBoundParams, nDimVertex> posJac =
linTrack.positionJacobian.block<nBoundParams, nDimVertex>(0, 0);
// B_k
const ActsMatrix<nBoundParams, 3> momJac =
linTrack.momentumJacobian.block<nBoundParams, 3>(0, 0);
// p_k
const ParameterVector trkParams =
linTrack.parametersAtPCA.head<nBoundParams>();
// c_k
const ParameterVector constTerm = linTrack.constantTerm.head<nBoundParams>();
// TODO we could use `linTrack.weightAtPCA` but only if we would always fit
// time.
// G_k
// Note that, when removing a track, G_k -> - G_k, see Ref. (1).
// Further note that, as we use the weighted formalism, the track weight
// matrix (i.e., the inverse track covariance matrix) should be multiplied
// with the track weight from the AMVF formalism. Here, we choose to
// consider these two multiplicative factors directly in the updates of
// newVertexWeight and newVertexPos.
const ParameterMatrix trkParamWeight =
linTrack.covarianceAtPCA.block<nBoundParams, nBoundParams>(0, 0)
.inverse();

// Retrieve current position of the vertex and its current weight matrix
const ActsVector<nDimVertex> oldVtxPos =
vtx.fullPosition().template head<nDimVertex>();
// C_{k-1}^-1
cache.oldVertexWeight =
(vtx.fullCovariance().template block<nDimVertex, nDimVertex>(0, 0))
.inverse();

// W_k
cache.wMat = (momJac.transpose() * (trkParamWeight * momJac)).inverse();

// G_k^B = G_k - G_k*B_k*W_k*B_k^(T)*G_k
ParameterMatrix gBMat = trkParamWeight - trkParamWeight * momJac *
cache.wMat * momJac.transpose() *
trkParamWeight;

// C_k^-1
cache.newVertexWeight = cache.oldVertexWeight + sign * trackWeight *
posJac.transpose() *
gBMat * posJac;
// C_k
cache.newVertexCov = cache.newVertexWeight.inverse();

// \tilde{x_k}
cache.newVertexPos =
cache.newVertexCov * (cache.oldVertexWeight * oldVtxPos +
sign * trackWeight * posJac.transpose() * gBMat *
(trkParams - constTerm));
}

template <typename input_track_t, unsigned int nDimVertex>
double Acts::KalmanVertexUpdater::detail::vertexPositionChi2Update(
const Vertex<input_track_t>& oldVtx, const Cache<nDimVertex>& cache) {
ActsVector<nDimVertex> posDiff =
cache.newVertexPos - oldVtx.fullPosition().template head<nDimVertex>();

// Calculate and return corresponding chi2
return posDiff.transpose() * (cache.oldVertexWeight * posDiff);
}

template <typename input_track_t, unsigned int nDimVertex>
double Acts::KalmanVertexUpdater::detail::trackParametersChi2(
const LinearizedTrack& linTrack, const Cache<nDimVertex>& cache) {
constexpr unsigned int nBoundParams = nDimVertex + 2;
using ParameterVector = ActsVector<nBoundParams>;
using ParameterMatrix = ActsSquareMatrix<nBoundParams>;
// A_k
const ActsMatrix<nBoundParams, nDimVertex> posJac =
linTrack.positionJacobian.block<nBoundParams, nDimVertex>(0, 0);
// B_k
const ActsMatrix<nBoundParams, 3> momJac =
linTrack.momentumJacobian.block<nBoundParams, 3>(0, 0);
// p_k
const ParameterVector trkParams =
linTrack.parametersAtPCA.head<nBoundParams>();
// c_k
const ParameterVector constTerm = linTrack.constantTerm.head<nBoundParams>();
// TODO we could use `linTrack.weightAtPCA` but only if we would always fit
// time.
// G_k
const ParameterMatrix trkParamWeight =
linTrack.covarianceAtPCA.block<nBoundParams, nBoundParams>(0, 0)
.inverse();

// A_k * \tilde{x_k}
const ParameterVector posJacVtxPos = posJac * cache.newVertexPos;

// \tilde{q_k}
Vector3 newTrkMom = cache.wMat * momJac.transpose() * trkParamWeight *
(trkParams - constTerm - posJacVtxPos);

// Correct phi and theta for possible periodicity changes
// Commented out because of broken ATHENA tests.
// TODO: uncomment
/*
const auto correctedPhiTheta =
Acts::detail::normalizePhiTheta(newTrkMom(0), newTrkMom(1));
newTrkMom(0) = correctedPhiTheta.first; // phi
newTrkMom(1) = correctedPhiTheta.second; // theta
*/

// \tilde{p_k}
ParameterVector linearizedTrackParameters =
constTerm + posJacVtxPos + momJac * newTrkMom;

// r_k
ParameterVector paramDiff = trkParams - linearizedTrackParameters;

// Return chi2
return paramDiff.transpose() * (trkParamWeight * paramDiff);
std::pair<double, double> fitQuality = vtx.fitQuality();
detail::updateVertexWithTrack(vtx.fullPosition(), vtx.fullCovariance(),
fitQuality, trk, 1, nDimVertex);
vtx.setFitQuality(fitQuality);
}
20 changes: 20 additions & 0 deletions Core/include/Acts/Vertexing/TrackAtVertex.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,24 @@ struct TrackAtVertex {
bool isLinearized = false;
};

struct TrackAtVertexRef {
paulgessinger marked this conversation as resolved.
Show resolved Hide resolved
BoundTrackParameters& fittedParams;
double& chi2Track;
double& ndf;
double& vertexCompatibility;
double& trackWeight;
LinearizedTrack& linearizedState;
bool isLinearized;

template <typename input_track_t>
TrackAtVertexRef(TrackAtVertex<input_track_t>& track)
: fittedParams(track.fittedParams),
chi2Track(track.chi2Track),
ndf(track.ndf),
vertexCompatibility(track.vertexCompatibility),
trackWeight(track.trackWeight),
linearizedState(track.linearizedState),
isLinearized(track.isLinearized) {}
};

} // namespace Acts
2 changes: 2 additions & 0 deletions Core/include/Acts/Vertexing/Vertex.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,14 @@ class Vertex {

/// @return Returns 4-position
const Vector4& fullPosition() const;
Vector4& fullPosition();

/// @return Returns position covariance
SquareMatrix3 covariance() const;

/// @return Returns 4x4 covariance
const SquareMatrix4& fullCovariance() const;
SquareMatrix4& fullCovariance();

/// @return Returns vector of tracks associated with the vertex
const std::vector<TrackAtVertex<input_track_t>>& tracks() const;
Expand Down
Loading
Loading