Skip to content

Commit

Permalink
multithread
Browse files Browse the repository at this point in the history
  • Loading branch information
Meakk committed Oct 14, 2023
1 parent d90f4d3 commit 3ceadb2
Showing 1 changed file with 38 additions and 12 deletions.
50 changes: 38 additions & 12 deletions library/src/image.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <vtkJPEGWriter.h>
#include <vtkPNGWriter.h>
#include <vtkPointData.h>
#include <vtkSMPTools.h>
#include <vtkSmartPointer.h>
#include <vtkTIFFWriter.h>
#include <vtksys/SystemTools.hxx>
Expand Down Expand Up @@ -272,22 +273,47 @@ double image::psnr(const image& reference) const
throw psnr_exception("One image has a channel type different then BYTE");
}

unsigned char* contentRef = static_cast<unsigned char*>(reference.getContent());
unsigned char* contentThis = static_cast<unsigned char*>(this->getContent());
struct MSEWorker
{
vtkSMPThreadLocal<double> LocalMSE;
double ReducedMSE = 0.0;
unsigned char* Buffer1;
unsigned char* Buffer2;

MSEWorker(unsigned char* buffer1, unsigned char* buffer2)
: Buffer1(buffer1)
, Buffer2(buffer2)
{
}

void Initialize() { this->LocalMSE.Local() = 0.0; }

void operator()(vtkIdType begin, vtkIdType end)
{
for (vtkIdType idx = begin; idx < end; ++idx)
{
double diff =
static_cast<double>(this->Buffer1[idx]) - static_cast<double>(this->Buffer2[idx]);
this->LocalMSE.Local() += diff * diff;
}
}

Check warning on line 299 in library/src/image.cxx

View check run for this annotation

Codecov / codecov/patch

library/src/image.cxx#L299

Added line #L299 was not covered by tests

void Reduce()
{
for (double localMSE : this->LocalMSE)
{
this->ReducedMSE += localMSE;
}
}
};

double mse = 0.0;
unsigned int totalSize = this->getHeight() * this->getWidth() * this->getChannelCount();

// todo: multi-thread using SMP
for (unsigned int i = 0; i < totalSize; i++)
{
unsigned char valRef = (*contentRef++);
unsigned char valThis = (*contentThis++);
double diff = static_cast<double>(valRef) - static_cast<double>(valThis);
mse += diff * diff;
}
MSEWorker worker(static_cast<unsigned char*>(reference.getContent()),
static_cast<unsigned char*>(this->getContent()));
vtkSMPTools::For(0, totalSize, worker);

mse /= static_cast<double>(totalSize);
double mse = worker.ReducedMSE / static_cast<double>(totalSize);

if (mse < 1e-9)
{
Expand Down

0 comments on commit 3ceadb2

Please sign in to comment.