Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow SDR classifier to handle multiple category #1339

Merged
merged 4 commits into from
Jun 7, 2017
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 55 additions & 37 deletions src/nupic/algorithms/SDRClassifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,17 +84,28 @@ namespace nupic
}

void SDRClassifier::compute(
UInt recordNum, const vector<UInt>& patternNZ, UInt bucketIdx,
Real64 actValue, bool category, bool learn, bool infer,
UInt recordNum, const vector<UInt>& patternNZ, const vector<UInt>& bucketIdxList,
const vector<Real64>& actValueList, bool category, bool learn, bool infer,
ClassifierResult* result)
{
// update pattern history
patternNZHistory_.emplace_back(patternNZ.begin(), patternNZ.end());
recordNumHistory_.push_back(recordNum);
if (patternNZHistory_.size() > maxSteps_)
// ensures that recordNum increases monotonically
UInt lastRecordNum = -1;
if (recordNumHistory_.size() > 0)
{
patternNZHistory_.pop_front();
recordNumHistory_.pop_front();
lastRecordNum = recordNumHistory_[recordNumHistory_.size()-1];
if(recordNum < lastRecordNum)
NTA_THROW << "the record number has to increase monotonically";
}
// update pattern history if this is a new record
if(recordNumHistory_.size() ==0 || recordNum > lastRecordNum)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: add space after if and ==

{
patternNZHistory_.emplace_back(patternNZ.begin(), patternNZ.end());
recordNumHistory_.push_back(recordNum);
if (patternNZHistory_.size() > maxSteps_)
{
patternNZHistory_.pop_front();
recordNumHistory_.pop_front();
}
}

// if input pattern has greater index than previously seen, update
Expand All @@ -116,38 +127,43 @@ namespace nupic
// if in inference mode, compute likelihood and update return value
if (infer)
{
infer_(patternNZ, bucketIdx, actValue, result);
infer_(patternNZ, actValueList, result);
}

// update weights if in learning mode
if (learn)
{
// if bucket is greater, update maxBucketIdx_ and augment weight
// matrix with zero-padding
if (bucketIdx > maxBucketIdx_)
for(size_t categoryI=0; categoryI < bucketIdxList.size(); categoryI++)
{
maxBucketIdx_ = bucketIdx;
for (const auto& step : steps_)
UInt bucketIdx = bucketIdxList[categoryI];
Real64 actValue = actValueList[categoryI];
// if bucket is greater, update maxBucketIdx_ and augment weight
// matrix with zero-padding
if (bucketIdx > maxBucketIdx_)
{
Matrix& weights = weightMatrix_.at(step);
weights.resize(maxInputIdx_ + 1, maxBucketIdx_ + 1);
maxBucketIdx_ = bucketIdx;
for (const auto& step : steps_)
{
Matrix& weights = weightMatrix_.at(step);
weights.resize(maxInputIdx_ + 1, maxBucketIdx_ + 1);
}
}
}

// update rolling averages of bucket values
while (actualValues_.size() <= maxBucketIdx_)
{
actualValues_.push_back(0.0);
actualValuesSet_.push_back(false);
}
if (!actualValuesSet_[bucketIdx] || category)
{
actualValues_[bucketIdx] = actValue;
actualValuesSet_[bucketIdx] = true;
} else {
actualValues_[bucketIdx] =
((1.0 - actValueAlpha_) * actualValues_[bucketIdx]) +
(actValueAlpha_ * actValue);
// update rolling averages of bucket values
while (actualValues_.size() <= maxBucketIdx_)
{
actualValues_.push_back(0.0);
actualValuesSet_.push_back(false);
}
if (!actualValuesSet_[bucketIdx] || category)
{
actualValues_[bucketIdx] = actValue;
actualValuesSet_[bucketIdx] = true;
} else {
actualValues_[bucketIdx] =
((1.0 - actValueAlpha_) * actualValues_[bucketIdx]) +
(actValueAlpha_ * actValue);
}
}

// compute errors and update weights
Expand All @@ -162,7 +178,7 @@ namespace nupic
// update weights
if (binary_search(steps_.begin(), steps_.end(), nSteps))
{
vector<Real64> error = calculateError_(bucketIdx,
vector<Real64> error = calculateError_(bucketIdxList,
learnPatternNZ, nSteps);
Matrix& weights = weightMatrix_.at(nSteps);
for (auto& bit : learnPatternNZ)
Expand All @@ -184,8 +200,8 @@ namespace nupic
return s.str().size();
}

void SDRClassifier::infer_(const vector<UInt>& patternNZ, UInt bucketIdx,
Real64 actValue, ClassifierResult* result)
void SDRClassifier::infer_(const vector<UInt>& patternNZ,
const vector<Real64>& actValue, ClassifierResult* result)
{
// add the actual values to the return value. For buckets that haven't
// been seen yet, the actual value doesn't matter since it will have
Expand All @@ -204,7 +220,7 @@ namespace nupic
{
(*actValueVector)[i] = 0;
} else {
(*actValueVector)[i] = actValue;
(*actValueVector)[i] = actValue[0];
}
}
}
Expand All @@ -227,7 +243,7 @@ namespace nupic
}
}

vector<Real64> SDRClassifier::calculateError_(UInt bucketIdx,
vector<Real64> SDRClassifier::calculateError_(const vector<UInt>& bucketIdxList,
const vector<UInt> patternNZ, UInt step)
{
// compute predicted likelihoods
Expand All @@ -244,7 +260,9 @@ namespace nupic

// compute target likelihoods
vector<Real64> targetDistribution (maxBucketIdx_ + 1, 0.0);
targetDistribution[bucketIdx] = 1.0;
Real64 numCategories = (Real64)bucketIdxList.size();
for(size_t i=0; i<bucketIdxList.size(); i++)
targetDistribution[bucketIdxList[i]] = 1.0 / numCategories;

axby(-1.0, likelihoods, 1.0, targetDistribution);
return likelihoods;
Expand Down
12 changes: 6 additions & 6 deletions src/nupic/algorithms/SDRClassifier.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ namespace nupic
* used when predicting each bucket.
*/
virtual void compute(
UInt recordNum, const vector<UInt>& patternNZ, UInt bucketIdx,
Real64 actValue, bool category, bool learn, bool infer,
UInt recordNum, const vector<UInt>& patternNZ, const vector<UInt>& bucketIdxList,
const vector<Real64>& actValueList, bool category, bool learn, bool infer,
ClassifierResult* result);

/**
Expand Down Expand Up @@ -161,12 +161,12 @@ namespace nupic

private:
// Helper function for inference mode
void infer_(const vector<UInt>& patternNZ, UInt bucketIdx,
Real64 actValue, ClassifierResult* result);
void infer_(const vector<UInt>& patternNZ,
const vector<Real64>& actValue, ClassifierResult* result);

// Helper function to compute the error signal in learning mode
vector<Real64> calculateError_(UInt bucketIdx, const vector<UInt>,
UInt step);
vector<Real64> calculateError_(const vector<UInt>& bucketIdxList,
const vector<UInt> patternNZ, UInt step);

// The list of prediction steps to learn and infer.
vector<UInt> steps_;
Expand Down
23 changes: 15 additions & 8 deletions src/nupic/bindings/algorithms.i
Original file line number Diff line number Diff line change
Expand Up @@ -1446,25 +1446,31 @@ void forceRetentionOfImageSensorLiteLibrary(void) {
noneSentinel = 3.14159

if type(classification["actValue"]) in (int, float):
actValue = classification["actValue"]
actValueList = [classification["actValue"]]
bucketIdxList = [classification["bucketIdx"]]
category = False
elif classification["actValue"] is None:
# Use the sentinel value so we know if it gets used in actualValues
# returned.
actValue = noneSentinel
actValueList = [noneSentinel]
# Turn learning off this step.
learn = False
category = False
# This does not get used when learning is disabled anyway.
classification["bucketIdx"] = 0
bucketIdxList = [0]
isNone = True
elif type(classification["actValue"]) is list:
actValueList = classification["actValue"]
bucketIdxList = classification["bucketIdx"]
category = False
else:
actValue = int(classification["bucketIdx"])
actValueList = [int(classification["bucketIdx"])]
bucketIdxList = [classification["bucketIdx"]]
category = True

result = self.convertedCompute(
recordNum, patternNZ, int(classification["bucketIdx"]),
actValue, category, learn, infer)
recordNum, patternNZ, bucketIdxList,
actValueList, category, learn, infer)

if isNone:
for i, v in enumerate(result["actualValues"]):
Expand Down Expand Up @@ -1549,11 +1555,12 @@ void forceRetentionOfImageSensorLiteLibrary(void) {
}

PyObject* convertedCompute(UInt recordNum, const vector<UInt>& patternNZ,
UInt bucketIdx, Real64 actValue, bool category,
const vector<UInt>& bucketIdxList,
const vector<Real64>& actValueList, bool category,
bool learn, bool infer)
{
ClassifierResult result;
self->compute(recordNum, patternNZ, bucketIdx, actValue, category,
self->compute(recordNum, patternNZ, bucketIdxList, actValueList, category,
learn, infer, &result);
PyObject* d = PyDict_New();
for (map<Int, vector<Real64>*>::const_iterator it = result.begin();
Expand Down
Loading