Skip to content

Commit

Permalink
Add LoopContractionMitigation to the StrokeModeler.
Browse files Browse the repository at this point in the history
This is behind a param flag that is default disabled at least until we have values we are confident in.
This also involves changing the way that the Query() point is projected onto the segment and thus is behind a default-off param flag.

PiperOrigin-RevId: 658058054
  • Loading branch information
Ink Open Source authored and copybara-github committed Aug 22, 2024
1 parent 639d22f commit 000ebd0
Show file tree
Hide file tree
Showing 19 changed files with 4,867 additions and 333 deletions.
32 changes: 30 additions & 2 deletions ink_stroke_modeler/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,10 @@ cc_library(
":params",
":types",
"//ink_stroke_modeler/internal:internal_types",
"//ink_stroke_modeler/internal:loop_contraction_mitigation_modeler",
"//ink_stroke_modeler/internal:position_modeler",
"//ink_stroke_modeler/internal:stylus_state_modeler",
"//ink_stroke_modeler/internal:utils",
"//ink_stroke_modeler/internal:wobble_smoother",
"//ink_stroke_modeler/internal/prediction:input_predictor",
"//ink_stroke_modeler/internal/prediction:kalman_predictor",
Expand Down Expand Up @@ -87,10 +89,22 @@ cc_test(
deps = [
":params",
":stroke_modeler",
":type_matchers",
":types",
"@com_google_absl//absl/status",
"@com_google_googletest//:gtest_main",
],
)

cc_test(
name = "stroke_modeler_with_new_projection_test",
srcs = ["stroke_modeler_with_new_projection_test.cc"],
deps = [
":params",
":stroke_modeler",
":type_matchers",
":types",
"//ink_stroke_modeler/internal:type_matchers",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
],
)
Expand All @@ -109,6 +123,20 @@ cc_library(
],
)

cc_library(
name = "type_matchers",
testonly = 1,
srcs = ["type_matchers.cc"],
hdrs = ["type_matchers.h"],
deps = [
":types",
"//:gtest_for_library_testonly",
"//ink_stroke_modeler/internal:type_matchers",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_for_library",
],
)

cc_test(
name = "types_test",
srcs = ["types_test.cc"],
Expand Down
37 changes: 36 additions & 1 deletion ink_stroke_modeler/internal/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,23 @@ cc_test(
deps = [
":internal_types",
":stylus_state_modeler",
":type_matchers",
"//ink_stroke_modeler:numbers",
"//ink_stroke_modeler:params",
"//ink_stroke_modeler:type_matchers",
"//ink_stroke_modeler:types",
"@com_google_googletest//:gtest_main",
],
)

cc_test(
name = "stylus_state_modeler_with_new_projection_test",
srcs = ["stylus_state_modeler_with_new_projection_test.cc"],
deps = [
":stylus_state_modeler",
"//ink_stroke_modeler:numbers",
"//ink_stroke_modeler:params",
"//ink_stroke_modeler:type_matchers",
"//ink_stroke_modeler:types",
"@com_google_googletest//:gtest_main",
],
)
Expand Down Expand Up @@ -170,3 +184,24 @@ cc_test(
"@com_google_googletest//:gtest_main",
],
)

cc_library(
name = "loop_contraction_mitigation_modeler",
srcs = ["loop_contraction_mitigation_modeler.cc"],
hdrs = ["loop_contraction_mitigation_modeler.h"],
deps = [
":utils",
"//ink_stroke_modeler:params",
"//ink_stroke_modeler:types",
],
)

cc_test(
name = "loop_contraction_mitigation_modeler_test",
srcs = ["loop_contraction_mitigation_modeler_test.cc"],
deps = [
":loop_contraction_mitigation_modeler",
"//ink_stroke_modeler:params",
"@com_google_googletest//:gtest_main",
],
)
71 changes: 71 additions & 0 deletions ink_stroke_modeler/internal/loop_contraction_mitigation_modeler.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#include "ink_stroke_modeler/internal/loop_contraction_mitigation_modeler.h"

#include <algorithm>

#include "ink_stroke_modeler/internal/utils.h"
#include "ink_stroke_modeler/params.h"
#include "ink_stroke_modeler/types.h"

namespace ink {
namespace stroke_model {

namespace {

float InverseLerp(float a, float b, float value) {
// If the interval between `a` and `b` is 0, there is no way to get to `t`
// because in the other direction the value of `t` won't impact the result.
if (b - a == 0.f) {
return 0.f;
}
return (value - a) / (b - a);
}

} // namespace

void LoopContractionMitigationModeler::Reset(
const PositionModelerParams::LoopContractionMitigationParameters& params) {
speeds_.clear();

save_active_ = false;

params_ = params;
}

float LoopContractionMitigationModeler::GetInterpolationValue() {
if (speeds_.empty() || !params_.is_enabled) return 1;

float sum = 0;
for (const auto& speed : speeds_) {
sum += speed;
}
float average_speed = sum / speeds_.size();

float source_ratio = std::max(
0.f,
std::min(1.f, InverseLerp(params_.speed_lower_bound,
params_.speed_upper_bound, average_speed)));
return Interp(params_.interpolation_strength_at_speed_lower_bound,
params_.interpolation_strength_at_speed_upper_bound,
source_ratio);
}

float LoopContractionMitigationModeler::Update(Vec2 velocity) {
// The moving average acts as a low-pass signal filter, removing
// high-frequency fluctuations in the velocity.
speeds_.push_back(velocity.Magnitude());
if (speeds_.size() > params_.n_speed_samples) speeds_.pop_front();

return GetInterpolationValue();
}

void LoopContractionMitigationModeler::Save() {
saved_speeds_ = speeds_;
save_active_ = true;
}

void LoopContractionMitigationModeler::Restore() {
if (save_active_) speeds_ = saved_speeds_;
}

} // namespace stroke_model
} // namespace ink
50 changes: 50 additions & 0 deletions ink_stroke_modeler/internal/loop_contraction_mitigation_modeler.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#ifndef INK_STROKE_MODELER_INTERNAL_LOOP_CONTRACTION_MITIGATION_MODELER_H_
#define INK_STROKE_MODELER_INTERNAL_LOOP_CONTRACTION_MITIGATION_MODELER_H_

#include <deque>

#include "ink_stroke_modeler/params.h"
#include "ink_stroke_modeler/types.h"

namespace ink {
namespace stroke_model {

class LoopContractionMitigationModeler {
public:
void Reset(
const PositionModelerParams::LoopContractionMitigationParameters &params);

// Updates the model with the position and time from the raw inputs, and
// returns the interpolation value to be used when applying the loop
// contraction mitigation
float Update(Vec2 velocity);

// Returns the interpolation value based on the current set of available
// velocities and the LoopContractionMitigationParameters.
float GetInterpolationValue();

// Saves the current state of the modeler. See comment on
// StrokeModeler::Save() for more details.
void Save();

// Restores the saved state of the modeler. See comment on
// StrokeModeler::Restore() for more details.
void Restore();

private:
std::deque<float> speeds_;

// Use a deque + bool instead of optional<deque> for performance. A
// std::deque, which has a non-trivial destructor that would deallocate its
// capacity. This setup avoids extra calls to the destructor that would be
// triggered by each call to std::optional::reset().
std::deque<float> saved_speeds_;
bool save_active_ = false;

PositionModelerParams::LoopContractionMitigationParameters params_;
};

} // namespace stroke_model
} // namespace ink

#endif // INK_STROKE_MODELER_INTERNAL_LOOP_CONTRACTION_MITIGATION_MODELER_H_
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
#include "ink_stroke_modeler/internal/loop_contraction_mitigation_modeler.h"

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "ink_stroke_modeler/params.h"

namespace ink {
namespace stroke_model {
namespace {

using ::testing::FloatNear;
using LoopContractionMitigationParameters =
PositionModelerParams::LoopContractionMitigationParameters;

const LoopContractionMitigationParameters kDefaultParams{
.is_enabled = true,
.speed_lower_bound = 0,
.speed_upper_bound = 100,
.interpolation_strength_at_speed_lower_bound = 1,
.interpolation_strength_at_speed_upper_bound = 0,
.n_speed_samples = 5};

TEST(LoopContractionMitigationModelerTest,
GetInterpolationValueOnEmptyModeler) {
LoopContractionMitigationModeler modeler;
modeler.Reset(kDefaultParams);

EXPECT_THAT(modeler.GetInterpolationValue(), FloatNear(1, 0.01));
}

TEST(LoopContractionMitigationModelerTest, UpdateWithOneSample) {
LoopContractionMitigationModeler modeler;
modeler.Reset(kDefaultParams);

EXPECT_THAT(modeler.Update({3, 4}), FloatNear(0.95, 0.01));
EXPECT_THAT(modeler.GetInterpolationValue(), FloatNear(0.95, 0.01));
}

TEST(LoopContractionMitigationModelerTest, IsEnabledFalseResultsInOne) {
LoopContractionMitigationParameters params = {
.is_enabled = false,
.speed_lower_bound = 0,
.speed_upper_bound = 10,
.interpolation_strength_at_speed_lower_bound = 1,
.interpolation_strength_at_speed_upper_bound = 0,
.n_speed_samples = 5};

LoopContractionMitigationModeler modeler;
modeler.Reset(params);

EXPECT_THAT(modeler.Update({3, 4}), FloatNear(1, 0.01));
EXPECT_THAT(modeler.GetInterpolationValue(), FloatNear(1, 0.01));
}

TEST(LoopContractionMitigationModelerTest, ResetClearsModeler) {
LoopContractionMitigationModeler modeler;
modeler.Reset(kDefaultParams);

EXPECT_THAT(modeler.Update({3, 4}), FloatNear(0.95, 0.01));
EXPECT_THAT(modeler.GetInterpolationValue(), FloatNear(0.95, 0.01));

modeler.Reset(kDefaultParams);
EXPECT_THAT(modeler.GetInterpolationValue(), FloatNear(1, 0.01));
}

TEST(LoopContractionMitigationModelerTest,
MultipleUpdatesButLessThanSampleSize) {
LoopContractionMitigationModeler modeler;
modeler.Reset(kDefaultParams);

// Average is 5.
EXPECT_THAT(modeler.Update({3, 4}), FloatNear(0.95, 0.01));
EXPECT_THAT(modeler.GetInterpolationValue(), FloatNear(0.95, 0.01));

// Average is (5 + 3) / 2=4.
EXPECT_THAT(modeler.Update({0, 3}), FloatNear(0.96, 0.01));
EXPECT_THAT(modeler.GetInterpolationValue(), FloatNear(0.96, 0.01));

// Average is (5 + 3 + 10) / 3 = 6.
EXPECT_THAT(modeler.Update({-10, 0}), FloatNear(0.94, 0.01));
EXPECT_THAT(modeler.GetInterpolationValue(), FloatNear(0.94, 0.01));
}

TEST(LoopContractionMitigationModelerTest, MultipleUpdatesOverSampleSize) {
LoopContractionMitigationModeler modeler;
modeler.Reset(kDefaultParams);

// Average is 5.
ASSERT_THAT(modeler.Update({0, 5}), FloatNear(0.95, 0.01));
// Average is (5 + 3) / 2=4.
ASSERT_THAT(modeler.Update({0, 3}), FloatNear(0.96, 0.01));
// Average is (5 + 3 + 10) / 3 = 6.
ASSERT_THAT(modeler.Update({-10, 0}), FloatNear(0.94, 0.01));
// Average is (5 + 3 + 10 + 2) / 4 = 5.
ASSERT_THAT(modeler.Update({0, 2}), FloatNear(0.95, 0.01));
// Average is (5 + 3 + 10 + 2 + 2) / 5 = 4.4.
ASSERT_THAT(modeler.Update({0, 2}), FloatNear(0.956, 0.01));

// The next one should clear the first value.
// Average is (3 + 10 + 2 + 2 + 1) / 5 = 3.6)
EXPECT_THAT(modeler.Update({0, 2}), FloatNear(0.964, 0.01));
}

TEST(LoopContractionMitigationModelerTest, SaveAndRestore) {
LoopContractionMitigationModeler modeler;
modeler.Reset(kDefaultParams);

// Average is 5.
ASSERT_THAT(modeler.Update({0, 5}), FloatNear(0.95, 0.01));
// Average is (5 + 3) / 2=4.
ASSERT_THAT(modeler.Update({0, 3}), FloatNear(0.96, 0.01));
// Average is (5 + 3 + 10) / 3 = 6.
ASSERT_THAT(modeler.Update({-10, 0}), FloatNear(0.94, 0.01));
// Average is (5 + 3 + 10 + 2) / 4 = 5.
ASSERT_THAT(modeler.Update({0, 2}), FloatNear(0.95, 0.01));
// Average is (5 + 3 + 10 + 2 + 2) / 5 = 4.4.
ASSERT_THAT(modeler.Update({0, 2}), FloatNear(0.956, 0.01));

modeler.Save();

// This clears the first 2 values
// Average is (3 + 10 + 2 + 2 + 1) / 5 = 3.6)
ASSERT_THAT(modeler.Update({0, 2}), FloatNear(0.964, 0.01));
// Average is (10 + 2 + 2 + 1 + 5) / 5 = 4.
ASSERT_THAT(modeler.Update({0, 5}), FloatNear(0.96, 0.01));

modeler.Restore();
// This should return the last value from before the save.
EXPECT_THAT(modeler.GetInterpolationValue(), FloatNear(0.956, 0.01));
}

} // namespace
} // namespace stroke_model
} // namespace ink
Loading

0 comments on commit 000ebd0

Please sign in to comment.