Skip to content

Commit

Permalink
Merge pull request #537 from htm-community/tm_conn
Browse files Browse the repository at this point in the history
TM Connections integration
  • Loading branch information
breznak authored Jul 18, 2019
2 parents 8e11cdb + 24983cf commit 6475fba
Show file tree
Hide file tree
Showing 10 changed files with 336 additions and 278 deletions.
28 changes: 19 additions & 9 deletions bindings/py/cpp_src/bindings/algorithms/py_Connections.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,14 @@ R"(Compatibility Warning: This classes API is unstable and may change without wa
[](const Connections &self) { return self.getConnectedThreshold(); });

py_Connections.def("createSegment", &Connections::createSegment,
py::arg("cell"));
py::arg("cell"),
py::arg("maxSegmentsPerCell") = 0
);

py_Connections.def("destroySegment", &Connections::destroySegment);

py_Connections.def("iteration", &Connections::iteration);

py_Connections.def("createSynapse", &Connections::createSynapse,
py::arg("segment"),
py::arg("presynaticCell"),
Expand Down Expand Up @@ -93,24 +97,26 @@ R"(Compatibility Warning: This classes API is unstable and may change without wa
py_Connections.def("reset", &Connections::reset);

py_Connections.def("computeActivity",
[](Connections &self, SDR &activePresynapticCells) {
[](Connections &self, SDR &activePresynapticCells, bool learn=true) {
// Allocate buffer to return & make a python destructor object for it.
auto activeConnectedSynapses =
new std::vector<SynapseIdx>( self.segmentFlatListLength(), 0u );
auto destructor = py::capsule( activeConnectedSynapses,
[](void *dataPtr) {
delete reinterpret_cast<std::vector<SynapseIdx>*>(dataPtr); });
// Call the C++ method.
self.computeActivity(*activeConnectedSynapses, activePresynapticCells.getSparse());
// Wrap vector in numpy array.

// Call the C++ method.
self.computeActivity(*activeConnectedSynapses, activePresynapticCells.getSparse(), learn);

// Wrap vector in numpy array.
return py::array(activeConnectedSynapses->size(),
activeConnectedSynapses->data(),
destructor);
},
R"(Returns numActiveConnectedSynapsesForSegment)");

py_Connections.def("computeActivityFull",
[](Connections &self, SDR &activePresynapticCells) {
[](Connections &self, SDR &activePresynapticCells, bool learn=true) {
// Allocate buffer to return & make a python destructor object for it.
auto activeConnectedSynapses =
new std::vector<SynapseIdx>( self.segmentFlatListLength(), 0u );
Expand All @@ -123,9 +129,13 @@ R"(Returns numActiveConnectedSynapsesForSegment)");
auto potentialDestructor = py::capsule( activePotentialSynapses,
[](void *dataPtr) {
delete reinterpret_cast<std::vector<SynapseIdx>*>(dataPtr); });
// Call the C++ method.
self.computeActivity(*activeConnectedSynapses, *activePotentialSynapses,
activePresynapticCells.getSparse());

// Call the C++ method.
self.computeActivity(*activeConnectedSynapses,
*activePotentialSynapses,
activePresynapticCells.getSparse(),
learn);

// Wrap vector in numpy array.
return py::make_tuple(
py::array(activeConnectedSynapses->size(),
Expand Down
27 changes: 24 additions & 3 deletions src/examples/hotgym/HelloSPTP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "htm/types/Sdr.hpp"
#include "htm/utils/Random.hpp"
#include "htm/utils/MovingAverage.hpp"
#include "htm/utils/SdrMetrics.hpp"

namespace examples {

Expand Down Expand Up @@ -79,6 +80,13 @@ Real64 BenchmarkHotgym::run(UInt EPOCHS, bool useSPlocal, bool useSPglobal, bool
SDR outTM(spGlobal.getColumnDimensions());
Real an = 0.0f, anLikely = 0.0f; //for anomaly:
MovingAverage avgAnom10(1000); //chose the window large enough so there's (some) periodicity in the patter, so TM can learn something

//metrics
Metrics statsInput(input, 1000);
Metrics statsSPlocal(outSPlocal, 1000);
Metrics statsSPglobal(outSPglobal, 1000);
Metrics statsTM(outTM, 1000);

/*
* For example: fn = sin(x) -> periodic >= 2Pi ~ 6.3 && x+=0.01 -> 630 steps to 1st period -> window >= 630
*/
Expand Down Expand Up @@ -147,13 +155,26 @@ Real64 BenchmarkHotgym::run(UInt EPOCHS, bool useSPlocal, bool useSPglobal, bool
if (e == EPOCHS - 1) {
tAll.stop();

//print connections stats
cout << "\nInput :\n" << statsInput
<< "\nSP(local) " << spLocal.connections
<< "\nSP(local) " << statsSPlocal
<< "\nSP(global) " << spGlobal.connections
<< "\nSP(global) " << statsSPglobal
<< "\nTM " << tm.connections
<< "\nTM " << statsTM
<< "\n";

// output values
cout << "Epoch = " << e << endl;
cout << "Anomaly = " << an << endl;
cout << "Anomaly (avg) = " << avgAnom10.getCurrentAvg() << endl;
cout << "Anomaly (Likelihood) = " << anLikely << endl;
cout << "SP (g)= " << outSP << endl;
cout << "SP (l)= " << outSPlocal <<endl;
cout << "TM= " << outTM << endl;

//timers
cout << "==============TIMERS============" << endl;
cout << "Init:\t" << tInit.getElapsed() << endl;
cout << "Random:\t" << tRng.getElapsed() << endl;
Expand Down Expand Up @@ -184,12 +205,12 @@ Real64 BenchmarkHotgym::run(UInt EPOCHS, bool useSPlocal, bool useSPglobal, bool

SDR goldTM({COLS});
const SDR_sparse_t deterministicTM{
51, 62, 72, 77, 102, 155, 287, 306, 337, 340, 370, 493, 542, 952, 1089, 1110, 1115, 1193, 1463, 1488, 1507, 1518, 1547, 1626, 1668, 1694, 1781, 1803, 1805, 1827, 1841, 1858,1859, 1860, 1861, 1862, 1878, 1881, 1915, 1918, 1923, 1929, 1933, 1939, 1941, 1953, 1955, 1956, 1958, 1961, 1965, 1968, 1975, 1976, 1980, 1981, 1985, 1986, 1987, 1991, 1992, 1994, 1997, 2002, 2006, 2008, 2012, 2013, 2040, 2042
62, 77, 85, 322, 340, 432, 952, 1120, 1488, 1502, 1512, 1518, 1547, 1627, 1633, 1668, 1727, 1729, 1797, 1803, 1805, 1812, 1858, 1859, 1896, 1918, 1923, 1925, 1929, 1931, 1939, 1941, 1942, 1944, 1950, 1953, 1955, 1956, 1965, 1966, 1967, 1968, 1974, 1980, 1987, 1996, 2006, 2008, 2011, 2027, 2030, 2042, 2046
};
goldTM.setSparse(deterministicTM);

const float goldAn = 0.745098f;
const float goldAnAvg = 0.408286f;
const float goldAn = 0.627451f;
const float goldAnAvg = 0.407265f;

if(EPOCHS == 5000) { //these hand-written values are only valid for EPOCHS = 5000 (default), but not for debug and custom runs.
NTA_CHECK(input == goldEnc) << "Deterministic output of Encoder failed!\n" << input << "should be:\n" << goldEnc;
Expand Down
Loading

0 comments on commit 6475fba

Please sign in to comment.