Skip to content

Commit

Permalink
numenta/nupic.core-legacy#1380: Fix SP tests with correct dtype values
Browse files Browse the repository at this point in the history
  • Loading branch information
lscheinkman committed Jan 16, 2018
1 parent 17db904 commit e470860
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 5 deletions.
6 changes: 3 additions & 3 deletions tests/unit/nupic/algorithms/sp_overlap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,10 @@ def frequency(self,
maxval=maxVal, periodic=False, forced=True) # forced: it's strongly recommended to use w>=21, in the example we force skip the check for readibility
for y in xrange(numColors):
temp = enc.encode(rnd.random()*maxVal)
colors.append(numpy.array(temp, dtype=realDType))
colors.append(numpy.array(temp, dtype=numpy.uint32))
else:
for y in xrange(numColors):
sdr = numpy.zeros(n, dtype=realDType)
sdr = numpy.zeros(n, dtype=numpy.uint32)
# Randomly setting w out of n bits to 1
sdr[rnd.sample(xrange(n), w)] = 1
colors.append(sdr)
Expand All @@ -144,7 +144,7 @@ def frequency(self,
for i in xrange(numColors):
# TODO: See https://github.com/numenta/nupic/issues/2072
spInput = colors[i]
onCells = numpy.zeros(columnDimensions)
onCells = numpy.zeros(columnDimensions, dtype=numpy.uint32)
spImpl.compute(spInput, True, onCells)
spOutput.append(onCells.tolist())
activeCoincIndices = set(onCells.nonzero()[0])
Expand Down
18 changes: 18 additions & 0 deletions tests/unit/nupic/algorithms/spatial_pooler_cpp_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,24 @@ def testUpdateDutyCycles(self):
self.assertEqual(list(resultOverlapArr2), list(trueOverlapArr2))


def testComputeParametersValidation(self):
sp = SpatialPooler(inputDimensions=[5], columnDimensions=[5])
inputGood = np.ones(5, dtype=uintDType)
outGood = np.zeros(5, dtype=uintDType)
inputBad = np.ones(5, dtype=realDType)
outBad = np.zeros(5, dtype=realDType)

# Validate good parameters
sp.compute(inputGood, False, outGood)

# Validate bad input
with self.assertRaises(RuntimeError):
sp.compute(inputBad, False, outGood)

# Validate bad output
with self.assertRaises(RuntimeError):
sp.compute(inputGood, False, outBad)


if __name__ == "__main__":
unittest.main()
4 changes: 2 additions & 2 deletions tests/unit/nupic/algorithms/spatial_pooler_py_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def setUp(self):

def testCompute(self):
# Check that there are no errors in call to compute
inputVector = numpy.ones(5)
activeArray = numpy.zeros(5)
inputVector = numpy.ones(5, dtype=uintType)
activeArray = numpy.zeros(5, dtype=uintType)
self.sp.compute(inputVector, True, activeArray)


Expand Down

0 comments on commit e470860

Please sign in to comment.