-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add LoopContractionMitigation to the StrokeModeler.
This is behind a param flag that is default disabled at least until we have values we are confident in. This also involves changing the way that the Query() point is projected onto the segment and thus is behind a default-off param flag. PiperOrigin-RevId: 658058054
- Loading branch information
1 parent
639d22f
commit 000ebd0
Showing
19 changed files
with
4,867 additions
and
333 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
71 changes: 71 additions & 0 deletions
71
ink_stroke_modeler/internal/loop_contraction_mitigation_modeler.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
#include "ink_stroke_modeler/internal/loop_contraction_mitigation_modeler.h" | ||
|
||
#include <algorithm> | ||
|
||
#include "ink_stroke_modeler/internal/utils.h" | ||
#include "ink_stroke_modeler/params.h" | ||
#include "ink_stroke_modeler/types.h" | ||
|
||
namespace ink { | ||
namespace stroke_model { | ||
|
||
namespace { | ||
|
||
float InverseLerp(float a, float b, float value) { | ||
// If the interval between `a` and `b` is 0, there is no way to get to `t` | ||
// because in the other direction the value of `t` won't impact the result. | ||
if (b - a == 0.f) { | ||
return 0.f; | ||
} | ||
return (value - a) / (b - a); | ||
} | ||
|
||
} // namespace | ||
|
||
void LoopContractionMitigationModeler::Reset( | ||
const PositionModelerParams::LoopContractionMitigationParameters& params) { | ||
speeds_.clear(); | ||
|
||
save_active_ = false; | ||
|
||
params_ = params; | ||
} | ||
|
||
float LoopContractionMitigationModeler::GetInterpolationValue() { | ||
if (speeds_.empty() || !params_.is_enabled) return 1; | ||
|
||
float sum = 0; | ||
for (const auto& speed : speeds_) { | ||
sum += speed; | ||
} | ||
float average_speed = sum / speeds_.size(); | ||
|
||
float source_ratio = std::max( | ||
0.f, | ||
std::min(1.f, InverseLerp(params_.speed_lower_bound, | ||
params_.speed_upper_bound, average_speed))); | ||
return Interp(params_.interpolation_strength_at_speed_lower_bound, | ||
params_.interpolation_strength_at_speed_upper_bound, | ||
source_ratio); | ||
} | ||
|
||
float LoopContractionMitigationModeler::Update(Vec2 velocity) { | ||
// The moving average acts as a low-pass signal filter, removing | ||
// high-frequency fluctuations in the velocity. | ||
speeds_.push_back(velocity.Magnitude()); | ||
if (speeds_.size() > params_.n_speed_samples) speeds_.pop_front(); | ||
|
||
return GetInterpolationValue(); | ||
} | ||
|
||
void LoopContractionMitigationModeler::Save() { | ||
saved_speeds_ = speeds_; | ||
save_active_ = true; | ||
} | ||
|
||
void LoopContractionMitigationModeler::Restore() { | ||
if (save_active_) speeds_ = saved_speeds_; | ||
} | ||
|
||
} // namespace stroke_model | ||
} // namespace ink |
50 changes: 50 additions & 0 deletions
50
ink_stroke_modeler/internal/loop_contraction_mitigation_modeler.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
#ifndef INK_STROKE_MODELER_INTERNAL_LOOP_CONTRACTION_MITIGATION_MODELER_H_ | ||
#define INK_STROKE_MODELER_INTERNAL_LOOP_CONTRACTION_MITIGATION_MODELER_H_ | ||
|
||
#include <deque> | ||
|
||
#include "ink_stroke_modeler/params.h" | ||
#include "ink_stroke_modeler/types.h" | ||
|
||
namespace ink { | ||
namespace stroke_model { | ||
|
||
class LoopContractionMitigationModeler { | ||
public: | ||
void Reset( | ||
const PositionModelerParams::LoopContractionMitigationParameters ¶ms); | ||
|
||
// Updates the model with the position and time from the raw inputs, and | ||
// returns the interpolation value to be used when applying the loop | ||
// contraction mitigation | ||
float Update(Vec2 velocity); | ||
|
||
// Returns the interpolation value based on the current set of available | ||
// velocities and the LoopContractionMitigationParameters. | ||
float GetInterpolationValue(); | ||
|
||
// Saves the current state of the modeler. See comment on | ||
// StrokeModeler::Save() for more details. | ||
void Save(); | ||
|
||
// Restores the saved state of the modeler. See comment on | ||
// StrokeModeler::Restore() for more details. | ||
void Restore(); | ||
|
||
private: | ||
std::deque<float> speeds_; | ||
|
||
// Use a deque + bool instead of optional<deque> for performance. A | ||
// std::deque, which has a non-trivial destructor that would deallocate its | ||
// capacity. This setup avoids extra calls to the destructor that would be | ||
// triggered by each call to std::optional::reset(). | ||
std::deque<float> saved_speeds_; | ||
bool save_active_ = false; | ||
|
||
PositionModelerParams::LoopContractionMitigationParameters params_; | ||
}; | ||
|
||
} // namespace stroke_model | ||
} // namespace ink | ||
|
||
#endif // INK_STROKE_MODELER_INTERNAL_LOOP_CONTRACTION_MITIGATION_MODELER_H_ |
134 changes: 134 additions & 0 deletions
134
ink_stroke_modeler/internal/loop_contraction_mitigation_modeler_test.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
#include "ink_stroke_modeler/internal/loop_contraction_mitigation_modeler.h" | ||
|
||
#include "gmock/gmock.h" | ||
#include "gtest/gtest.h" | ||
#include "ink_stroke_modeler/params.h" | ||
|
||
namespace ink { | ||
namespace stroke_model { | ||
namespace { | ||
|
||
using ::testing::FloatNear; | ||
using LoopContractionMitigationParameters = | ||
PositionModelerParams::LoopContractionMitigationParameters; | ||
|
||
const LoopContractionMitigationParameters kDefaultParams{ | ||
.is_enabled = true, | ||
.speed_lower_bound = 0, | ||
.speed_upper_bound = 100, | ||
.interpolation_strength_at_speed_lower_bound = 1, | ||
.interpolation_strength_at_speed_upper_bound = 0, | ||
.n_speed_samples = 5}; | ||
|
||
TEST(LoopContractionMitigationModelerTest, | ||
GetInterpolationValueOnEmptyModeler) { | ||
LoopContractionMitigationModeler modeler; | ||
modeler.Reset(kDefaultParams); | ||
|
||
EXPECT_THAT(modeler.GetInterpolationValue(), FloatNear(1, 0.01)); | ||
} | ||
|
||
TEST(LoopContractionMitigationModelerTest, UpdateWithOneSample) { | ||
LoopContractionMitigationModeler modeler; | ||
modeler.Reset(kDefaultParams); | ||
|
||
EXPECT_THAT(modeler.Update({3, 4}), FloatNear(0.95, 0.01)); | ||
EXPECT_THAT(modeler.GetInterpolationValue(), FloatNear(0.95, 0.01)); | ||
} | ||
|
||
TEST(LoopContractionMitigationModelerTest, IsEnabledFalseResultsInOne) { | ||
LoopContractionMitigationParameters params = { | ||
.is_enabled = false, | ||
.speed_lower_bound = 0, | ||
.speed_upper_bound = 10, | ||
.interpolation_strength_at_speed_lower_bound = 1, | ||
.interpolation_strength_at_speed_upper_bound = 0, | ||
.n_speed_samples = 5}; | ||
|
||
LoopContractionMitigationModeler modeler; | ||
modeler.Reset(params); | ||
|
||
EXPECT_THAT(modeler.Update({3, 4}), FloatNear(1, 0.01)); | ||
EXPECT_THAT(modeler.GetInterpolationValue(), FloatNear(1, 0.01)); | ||
} | ||
|
||
TEST(LoopContractionMitigationModelerTest, ResetClearsModeler) { | ||
LoopContractionMitigationModeler modeler; | ||
modeler.Reset(kDefaultParams); | ||
|
||
EXPECT_THAT(modeler.Update({3, 4}), FloatNear(0.95, 0.01)); | ||
EXPECT_THAT(modeler.GetInterpolationValue(), FloatNear(0.95, 0.01)); | ||
|
||
modeler.Reset(kDefaultParams); | ||
EXPECT_THAT(modeler.GetInterpolationValue(), FloatNear(1, 0.01)); | ||
} | ||
|
||
TEST(LoopContractionMitigationModelerTest, | ||
MultipleUpdatesButLessThanSampleSize) { | ||
LoopContractionMitigationModeler modeler; | ||
modeler.Reset(kDefaultParams); | ||
|
||
// Average is 5. | ||
EXPECT_THAT(modeler.Update({3, 4}), FloatNear(0.95, 0.01)); | ||
EXPECT_THAT(modeler.GetInterpolationValue(), FloatNear(0.95, 0.01)); | ||
|
||
// Average is (5 + 3) / 2=4. | ||
EXPECT_THAT(modeler.Update({0, 3}), FloatNear(0.96, 0.01)); | ||
EXPECT_THAT(modeler.GetInterpolationValue(), FloatNear(0.96, 0.01)); | ||
|
||
// Average is (5 + 3 + 10) / 3 = 6. | ||
EXPECT_THAT(modeler.Update({-10, 0}), FloatNear(0.94, 0.01)); | ||
EXPECT_THAT(modeler.GetInterpolationValue(), FloatNear(0.94, 0.01)); | ||
} | ||
|
||
TEST(LoopContractionMitigationModelerTest, MultipleUpdatesOverSampleSize) { | ||
LoopContractionMitigationModeler modeler; | ||
modeler.Reset(kDefaultParams); | ||
|
||
// Average is 5. | ||
ASSERT_THAT(modeler.Update({0, 5}), FloatNear(0.95, 0.01)); | ||
// Average is (5 + 3) / 2=4. | ||
ASSERT_THAT(modeler.Update({0, 3}), FloatNear(0.96, 0.01)); | ||
// Average is (5 + 3 + 10) / 3 = 6. | ||
ASSERT_THAT(modeler.Update({-10, 0}), FloatNear(0.94, 0.01)); | ||
// Average is (5 + 3 + 10 + 2) / 4 = 5. | ||
ASSERT_THAT(modeler.Update({0, 2}), FloatNear(0.95, 0.01)); | ||
// Average is (5 + 3 + 10 + 2 + 2) / 5 = 4.4. | ||
ASSERT_THAT(modeler.Update({0, 2}), FloatNear(0.956, 0.01)); | ||
|
||
// The next one should clear the first value. | ||
// Average is (3 + 10 + 2 + 2 + 1) / 5 = 3.6) | ||
EXPECT_THAT(modeler.Update({0, 2}), FloatNear(0.964, 0.01)); | ||
} | ||
|
||
TEST(LoopContractionMitigationModelerTest, SaveAndRestore) { | ||
LoopContractionMitigationModeler modeler; | ||
modeler.Reset(kDefaultParams); | ||
|
||
// Average is 5. | ||
ASSERT_THAT(modeler.Update({0, 5}), FloatNear(0.95, 0.01)); | ||
// Average is (5 + 3) / 2=4. | ||
ASSERT_THAT(modeler.Update({0, 3}), FloatNear(0.96, 0.01)); | ||
// Average is (5 + 3 + 10) / 3 = 6. | ||
ASSERT_THAT(modeler.Update({-10, 0}), FloatNear(0.94, 0.01)); | ||
// Average is (5 + 3 + 10 + 2) / 4 = 5. | ||
ASSERT_THAT(modeler.Update({0, 2}), FloatNear(0.95, 0.01)); | ||
// Average is (5 + 3 + 10 + 2 + 2) / 5 = 4.4. | ||
ASSERT_THAT(modeler.Update({0, 2}), FloatNear(0.956, 0.01)); | ||
|
||
modeler.Save(); | ||
|
||
// This clears the first 2 values | ||
// Average is (3 + 10 + 2 + 2 + 1) / 5 = 3.6) | ||
ASSERT_THAT(modeler.Update({0, 2}), FloatNear(0.964, 0.01)); | ||
// Average is (10 + 2 + 2 + 1 + 5) / 5 = 4. | ||
ASSERT_THAT(modeler.Update({0, 5}), FloatNear(0.96, 0.01)); | ||
|
||
modeler.Restore(); | ||
// This should return the last value from before the save. | ||
EXPECT_THAT(modeler.GetInterpolationValue(), FloatNear(0.956, 0.01)); | ||
} | ||
|
||
} // namespace | ||
} // namespace stroke_model | ||
} // namespace ink |
Oops, something went wrong.