Skip to content

Commit

Permalink
tidied up FIR tests
Browse files Browse the repository at this point in the history
  • Loading branch information
petiaccja committed Sep 5, 2021
1 parent 56c721d commit 0012e33
Showing 1 changed file with 74 additions and 95 deletions.
169 changes: 74 additions & 95 deletions test/Filtering/Test_FirFilter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,42 @@ template <class SignalT>
bool IsSymmetric(const SignalT& signal) {
auto beg = signal.begin();
auto end = signal.rbegin();
while (beg != end.base()) {
if (*beg != *end) {
while (beg <= end.base()) {
if (*beg != Approx(*end).margin(1e-7f)) {
return false;
}
++beg;
++end;
}
return true;
}

template <class SignalT>
bool IsAntiSymmetric(const SignalT& signal) {
auto beg = signal.begin();
auto end = signal.rbegin();
while (beg <= end.base()) {
if (*beg != Approx(-*end).margin(1e-7f)) {
return false;
}
++beg;
++end;
}
return true;
}

template <class SignalT>
auto MeasureResponse(size_t sampleRate, float frequency, const SignalT& filter) {
const float period = 1.0f / frequency;
const float length = 25.f * period;
auto testSignal = GenTestSignal(sampleRate, frequency, length);
testSignal *= BlackmanWindow<float, TIME_DOMAIN>(testSignal.Size());
const auto filteredSignal = Convolution(testSignal, filter, convolution::full);
const auto rmsTest = std::sqrt(SumSquare(testSignal));
const auto rmsFiltered = std::sqrt(SumSquare(filteredSignal));
return rmsFiltered / rmsTest;
}


//------------------------------------------------------------------------------
// Tests
Expand All @@ -46,29 +74,20 @@ TEST_CASE("Windowed Lowpass", "[FirFilter]") {
static constexpr float cutoff = 3800.f;
const auto normalizedCutoff = NormalizedFrequency(cutoff, sampleRate);


const auto impulse1 = FirFilter<float, TIME_DOMAIN>(numTaps, Lowpass(normalizedCutoff), Windowed(windows::hamming));
const auto impulse2 = FirFilter<float, TIME_DOMAIN>(numTaps, Lowpass(normalizedCutoff), Windowed(windows::hamming.operator()<float, TIME_DOMAIN>(numTaps)));
REQUIRE(IsSymmetric(impulse1));
REQUIRE(Sum(impulse1) == Approx(1));
REQUIRE(impulse1.Size() == numTaps);
REQUIRE(impulse2.Size() == numTaps);
REQUIRE(Max(Abs(impulse1 - impulse2)) < 1e-4f);

// Generate two signals just above and just below the cutoff and see their attenuation.
const auto passSignal = GenTestSignal(sampleRate, cutoff * 0.85f);
const auto rejectSignal = GenTestSignal(sampleRate, cutoff * 1.15f);

const auto filteredPassSignal = Convolution(passSignal, impulse1, convolution::full);
const auto filteredRejectSignal = Convolution(rejectSignal, impulse1, convolution::full);

const float energyPass = SumSquare(passSignal);
const float energyReject = SumSquare(rejectSignal);
const float energyFilteredPass = SumSquare(filteredPassSignal);
const float energyFilteredReject = SumSquare(filteredRejectSignal);

REQUIRE(energyFilteredPass / energyPass > 0.95f);
REQUIRE(energyFilteredPass / energyPass < 1.05f);
REQUIRE(energyFilteredReject / energyReject < 0.05f);
const float passResponse = MeasureResponse(sampleRate, cutoff * 0.85f, impulse1);
const float stopResponse = MeasureResponse(sampleRate, cutoff * 1.15f, impulse1);

REQUIRE(passResponse > 0.95f);
REQUIRE(passResponse < 1.05f);
REQUIRE(stopResponse < 0.05f);
}


Expand All @@ -92,16 +111,14 @@ TEST_CASE("Windowed arbitrary filter", "[FirFilter]") {

const auto impulse1 = FirFilter<float, TIME_DOMAIN>(numTaps, Arbitrary(response), Windowed(windows::hamming));
const auto impulse2 = FirFilter<float, TIME_DOMAIN>(numTaps, Arbitrary(response), Windowed(windows::hamming.operator()<float, TIME_DOMAIN>(numTaps)));
REQUIRE(IsSymmetric(impulse1));
REQUIRE(impulse1.Size() == numTaps);
REQUIRE(impulse2.Size() == numTaps);
REQUIRE(Max(Abs(impulse1 - impulse2)) < 1e-4f);

for (size_t i = 0; i < amplitudes.size(); ++i) {
const auto signal = GenTestSignal(sampleRate, frequencies[i] * sampleRate / 2.0f);
const auto filtered = Convolution(signal, impulse1, convolution::full);
const float energy = std::sqrt(SumSquare(signal));
const float filteredEnergy = std::sqrt(SumSquare(filtered));
REQUIRE(filteredEnergy / energy == Approx(amplitudes[i]).margin(0.05f));
const auto response = MeasureResponse(sampleRate, frequencies[i] * sampleRate / 2.0f, impulse1);
REQUIRE(response == Approx(amplitudes[i]).margin(0.05f));
}
}

Expand All @@ -112,24 +129,16 @@ TEST_CASE("Highpass", "[FirFilter]") {
const auto normalizedCutoff = NormalizedFrequency(cutoff, sampleRate);

const auto impulse = FirFilter<float, TIME_DOMAIN>(numTaps, Highpass(normalizedCutoff), Windowed(windows::hamming));
REQUIRE(IsSymmetric(impulse));
REQUIRE(Sum(impulse) < 1e-4f);
REQUIRE(impulse.Size() == numTaps);

// Generate two signals just above and just below the cutoff and see their attenuation.
const auto passSignal = GenTestSignal(sampleRate, cutoff * 1.15f);
const auto rejectSignal = GenTestSignal(sampleRate, cutoff * 0.85f);

const auto filteredPassSignal = Convolution(passSignal, impulse, convolution::full);
const auto filteredRejectSignal = Convolution(rejectSignal, impulse, convolution::full);

const float energyPass = SumSquare(passSignal);
const float energyReject = SumSquare(rejectSignal);
const float energyFilteredPass = SumSquare(filteredPassSignal);
const float energyFilteredReject = SumSquare(filteredRejectSignal);
const float passResponse = MeasureResponse(sampleRate, cutoff * 1.15f, impulse);
const float stopResponse = MeasureResponse(sampleRate, cutoff * 0.85f, impulse);

REQUIRE(energyFilteredPass / energyPass > 0.95f);
REQUIRE(energyFilteredPass / energyPass < 1.05f);
REQUIRE(energyFilteredReject / energyReject < 0.05f);
REQUIRE(passResponse > 0.95f);
REQUIRE(passResponse < 1.05f);
REQUIRE(stopResponse < 0.05f);
}


Expand All @@ -141,38 +150,21 @@ TEST_CASE("Bandpass", "[FirFilter]") {
const auto normalizedHigh = NormalizedFrequency(bandHigh, sampleRate);

const auto impulse = FirFilter<float, TIME_DOMAIN>(numTaps, Bandpass(normalizedLow, normalizedHigh), Windowed(windows::hamming));
REQUIRE(IsSymmetric(impulse));
REQUIRE(Sum(impulse) < 1e-3f);
REQUIRE(impulse.Size() == numTaps);

auto extended = impulse;
extended.Resize(44100, 0.0f);
const auto spectrum = Abs(FourierTransform(extended));

const auto passSignal1 = GenTestSignal(sampleRate, bandLow * 1.1f);
const auto passSignal2 = GenTestSignal(sampleRate, bandHigh * 0.9f);
const auto rejectSignal1 = GenTestSignal(sampleRate, bandLow * 0.9f);
const auto rejectSignal2 = GenTestSignal(sampleRate, bandHigh * 1.1f);

const auto filteredPassSignal1 = Convolution(passSignal1, impulse, convolution::full);
const auto filteredPassSignal2 = Convolution(passSignal2, impulse, convolution::full);
const auto filteredRejectSignal1 = Convolution(rejectSignal1, impulse, convolution::full);
const auto filteredRejectSignal2 = Convolution(rejectSignal2, impulse, convolution::full);

const float energyPass1 = SumSquare(passSignal1);
const float energyPass2 = SumSquare(passSignal2);
const float energyReject1 = SumSquare(rejectSignal1);
const float energyReject2 = SumSquare(rejectSignal2);
const float energyFilteredPass1 = SumSquare(filteredPassSignal1);
const float energyFilteredPass2 = SumSquare(filteredPassSignal2);
const float energyFilteredReject1 = SumSquare(filteredRejectSignal1);
const float energyFilteredReject2 = SumSquare(filteredRejectSignal2);

REQUIRE(energyFilteredPass1 / energyPass1 > 0.95f);
REQUIRE(energyFilteredPass1 / energyPass1 < 1.05f);
REQUIRE(energyFilteredReject1 / energyReject1 < 0.05f);
REQUIRE(energyFilteredPass2 / energyPass2 > 0.95f);
REQUIRE(energyFilteredPass2 / energyPass2 < 1.05f);
REQUIRE(energyFilteredReject2 / energyReject2 < 0.05f);
const float lowStopResponse = MeasureResponse(sampleRate, bandLow * 0.9f, impulse);
const float lowPassResponse = MeasureResponse(sampleRate, bandLow * 1.1f, impulse);
const float highPassResponse = MeasureResponse(sampleRate, bandHigh * 0.9f, impulse);
const float highStopResponse = MeasureResponse(sampleRate, bandHigh * 1.1f, impulse);

REQUIRE(highPassResponse > 0.95f);
REQUIRE(highPassResponse < 1.05f);
REQUIRE(highStopResponse < 0.05f);
REQUIRE(lowPassResponse > 0.95f);
REQUIRE(lowPassResponse < 1.05f);
REQUIRE(lowStopResponse < 0.05f);
}


Expand All @@ -184,45 +176,29 @@ TEST_CASE("Bandstop", "[FirFilter]") {
const auto normalizedHigh = NormalizedFrequency(bandHigh, sampleRate);

const auto impulse = FirFilter<float, TIME_DOMAIN>(numTaps, Bandstop(normalizedLow, normalizedHigh), Windowed(windows::hamming));
REQUIRE(IsSymmetric(impulse));
REQUIRE(Sum(impulse) == Approx(1).epsilon(0.005f));
REQUIRE(impulse.Size() == numTaps);

auto extended = impulse;
extended.Resize(44100, 0.0f);
const auto spectrum = Abs(FourierTransform(extended));

const auto rejectSignal1 = GenTestSignal(sampleRate, bandLow * 1.1f);
const auto rejectSignal2 = GenTestSignal(sampleRate, bandHigh * 0.9f);
const auto passSignal1 = GenTestSignal(sampleRate, bandLow * 0.9f);
const auto passSignal2 = GenTestSignal(sampleRate, bandHigh * 1.1f);

const auto filteredPassSignal1 = Convolution(passSignal1, impulse, convolution::full);
const auto filteredPassSignal2 = Convolution(passSignal2, impulse, convolution::full);
const auto filteredRejectSignal1 = Convolution(rejectSignal1, impulse, convolution::full);
const auto filteredRejectSignal2 = Convolution(rejectSignal2, impulse, convolution::full);

const float energyPass1 = SumSquare(passSignal1);
const float energyPass2 = SumSquare(passSignal2);
const float energyReject1 = SumSquare(rejectSignal1);
const float energyReject2 = SumSquare(rejectSignal2);
const float energyFilteredPass1 = SumSquare(filteredPassSignal1);
const float energyFilteredPass2 = SumSquare(filteredPassSignal2);
const float energyFilteredReject1 = SumSquare(filteredRejectSignal1);
const float energyFilteredReject2 = SumSquare(filteredRejectSignal2);

REQUIRE(energyFilteredPass1 / energyPass1 > 0.95f);
REQUIRE(energyFilteredPass1 / energyPass1 < 1.05f);
REQUIRE(energyFilteredReject1 / energyReject1 < 0.05f);
REQUIRE(energyFilteredPass2 / energyPass2 > 0.95f);
REQUIRE(energyFilteredPass2 / energyPass2 < 1.05f);
REQUIRE(energyFilteredReject2 / energyReject2 < 0.05f);
const float lowPassResponse = MeasureResponse(sampleRate, bandLow * 0.9f, impulse);
const float lowStopResponse = MeasureResponse(sampleRate, bandLow * 1.1f, impulse);
const float highStopResponse = MeasureResponse(sampleRate, bandHigh * 0.9f, impulse);
const float highPassResponse = MeasureResponse(sampleRate, bandHigh * 1.1f, impulse);

REQUIRE(highPassResponse > 0.95f);
REQUIRE(highPassResponse < 1.05f);
REQUIRE(highStopResponse < 0.05f);
REQUIRE(lowPassResponse > 0.95f);
REQUIRE(lowPassResponse < 1.05f);
REQUIRE(lowStopResponse < 0.05f);
}



TEST_CASE("Hilbert odd form", "[Hilbert]") {
const auto filter = FirFilter<float, TIME_DOMAIN>(247, Hilbert(), Windowed(windows::hamming));
REQUIRE(filter.Size() == 247);
REQUIRE(IsAntiSymmetric(filter));
const auto nonZeroSamples = Decimate(filter, 2);
const auto zeroSamples = Decimate(AsView(filter).SubSignal(1), 2);
REQUIRE(Max(zeroSamples) == 0.0f);
Expand All @@ -236,6 +212,7 @@ TEST_CASE("Hilbert odd form", "[Hilbert]") {
TEST_CASE("Hilbert even form", "[Hilbert]") {
const auto filter = FirFilter<float, TIME_DOMAIN>(246, Hilbert(), Windowed(windows::hamming));
REQUIRE(filter.Size() == 246);
REQUIRE(IsAntiSymmetric(filter));
REQUIRE(Min(Abs(filter)) > 0.0f);
const auto firstHalf = AsView(filter).SubSignal(0, filter.Size() / 2);
const auto secondHalf = AsView(filter).SubSignal(filter.Size() / 2);
Expand All @@ -246,6 +223,7 @@ TEST_CASE("Hilbert even form", "[Hilbert]") {
TEST_CASE("Hilbert odd small form", "[Hilbert]") {
const auto filter = FirFilter<float, TIME_DOMAIN>(19, Hilbert(), Windowed(windows::hamming));
REQUIRE(filter.Size() == 19);
REQUIRE(IsAntiSymmetric(filter));
const auto nonZeroSamples = Decimate(filter, 2);
const auto zeroSamples = Decimate(AsView(filter).SubSignal(1), 2);
REQUIRE(Max(zeroSamples) == 0.0f);
Expand All @@ -259,6 +237,7 @@ TEST_CASE("Hilbert odd small form", "[Hilbert]") {
TEST_CASE("Hilbert even small form", "[Hilbert]") {
const auto filter = FirFilter<float, TIME_DOMAIN>(10, Hilbert(), Windowed(windows::hamming));
REQUIRE(filter.Size() == 10);
REQUIRE(IsAntiSymmetric(filter));
REQUIRE(Min(Abs(filter)) > 0.0f);
const auto firstHalf = AsView(filter).SubSignal(0, filter.Size() / 2);
const auto secondHalf = AsView(filter).SubSignal(filter.Size() / 2);
Expand Down

0 comments on commit 0012e33

Please sign in to comment.