-
Notifications
You must be signed in to change notification settings - Fork 0
/
Nncachebuilder.lua
197 lines (171 loc) · 6.53 KB
/
Nncachebuilder.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
-- Nncachebuilder.lua
-- cache for nearest 256 neighbors
-- API overview (dead code: illustrative only, never executed)
if false then
   nncb = Nncachebuilder(allXs, nShards)
   for n = 1, nShards do
      -- serialize cache object (a table) to file <prefix>nncache-shard-n.txt
      nncb:createShard(n, 'filePathPrefix')
   end
   -- create serialized cache object into file <prefix>nncache-merged.txt
   Nncachebuilder.mergeShards(nShards, 'filePathPrefix')

   -- ILLUSTRATIVE USE
   -- read the serialized merged cache from file system
   cache = Nncache.loadUsingPrefix('filePathPrefix')
   -- now cache[27] is a 1D tensor of the sorted indices closest to obs # 27
   -- in the original xs

   -- use the cache to smooth values
   selected = setSelected() -- to select observations in Xs and Ys
   -- use the original allXs to create smoothed estimates
   knnSmoother = KnnSmoother(allXs, allYs, selected, cache)
   estimate = knnSmoother:estimate(queryIndex, k)
end
--------------------------------------------------------------------------------
-- CONSTRUCTION
--------------------------------------------------------------------------------
torch.class('Nncachebuilder')

function Nncachebuilder:__init(allXs, nShards)
   -- Remember the observation matrix and shard count; start with an empty cache.
   -- allXs   : 2D Tensor of observations (one row per observation)
   -- nShards : positive integer, number of shards the work is split into
   local v, isVerbose = makeVerbose(false, 'Nncachebuilder:__init')
   verify(v, isVerbose,
          {{allXs, 'allXs', 'isTensor2D'},
           {nShards, 'nShards', 'isIntegerPositive'}})

   -- row indices must be representable as 32-bit signed integers
   local maxInt32 = 2147483647 -- 2^31 - 1
   assert(allXs:size(1) <= maxInt32,
          'more than 2^31 - 1 rows in the tensor')

   self._allXs = allXs
   self._nShards = nShards
   self._cache = Nncache()
end -- __init
--------------------------------------------------------------------------------
-- PUBLIC CLASS METHODS
--------------------------------------------------------------------------------
function Nncachebuilder.format()
-- Return the torch serialization format used to write cache files.
-- format used to serialize the cache
return 'ascii' -- 'binary' is faster
end -- format
function Nncachebuilder.maxNeighbors()
   -- Return the fixed number of neighbor indices kept per observation;
   -- this is the length of each cache[index] tensor.
   return 256
end -- maxNeighbors
function Nncachebuilder.mergedFileSuffix()
   -- Return the trailing portion of the merged-cache file name;
   -- the full path is <filePathPrefix> .. this suffix.
   return 'nncache-merged.txt'
end -- mergedFileSuffix
--------------------------------------------------------------------------------
-- PRIVATE CLASS METHODS
--------------------------------------------------------------------------------
function Nncachebuilder._shardFilePath(filePathPrefix, shardNumber)
-- Return the file path for one shard's serialized cache:
-- <filePathPrefix>shard-<shardNumber>.txt
-- NOTE(review): the header comment promises <prefix>nncache-shard-n.txt,
-- but this produces <prefix>shard-n.txt; confirm which name is intended
-- before changing, since existing on-disk shard files use this format.
return filePathPrefix .. string.format('shard-%d.txt', shardNumber)
end -- _shardFilePath
--------------------------------------------------------------------------------
-- PUBLIC INSTANCE METHODS
--------------------------------------------------------------------------------
function Nncachebuilder:createShard(shardNumber, filePathPrefix, chatty)
-- create an Nncache holding all the nearest neighbors in the shard
-- write this Nncache to disk
-- return file path where written
-- ARGS
-- shardNumber    : positive integer <= self._nShards; which shard to build
-- filePathPrefix : string; prefix of the output file path
-- chatty         : optional boolean, default true; if true, print progress
local v, isVerbose = makeVerbose(false, 'createShard')
verify(v, isVerbose,
{{shardNumber, 'shardNumber', 'isIntegerPositive'},
{filePathPrefix, 'filePathPrefix', 'isString'}})
-- set default for chatty [true]
v('chatty', chatty)
if chatty == nil then
chatty = true
end
v('chatty', chatty)
v('self', self)
assert(shardNumber <= self._nShards)
local tc = TimerCpu()
local cache = Nncache()
local count = 0
local shard = 0
-- approximate observations per shard (may be fractional); used only
-- for the time-remaining estimate printed below
local roughCount = self._allXs:size(1) / self._nShards
for obsIndex = 1, self._allXs:size(1) do
-- round-robin shard assignment: observation obsIndex belongs to shard
-- ((obsIndex - 1) mod nShards) + 1
shard = shard + 1
if shard > self._nShards then
shard = 1
end
if shard == shardNumber then
-- observation in shard, so create its neighbors indices
local query = self._allXs[obsIndex]:clone()
collectgarbage()
local _, allIndices = Nnw.nearest(self._allXs,
query)
-- NOTE: creating a view of the storage seems like a good idea
-- but fails when the tensor is serialized out
-- copy only the first maxNeighbors() (or fewer) sorted neighbor indices
local n = math.min(Nncachebuilder.maxNeighbors(), self._allXs:size(1))
local firstIndices = torch.Tensor(n)
for i = 1, n do
firstIndices[i] = allIndices[i]
end
cache:setLine(obsIndex, firstIndices)
count = count + 1
if false then -- debug tracing, disabled
v('count', count)
v('obsIndex', obsIndex)
v('firstIndices', firstIndices)
end
-- progress report every 10,000 entries; == 1 makes the first report
-- appear immediately after the first entry is built
if count % 10000 == 1 and chatty then
local rate = tc:cumSeconds() / count
print(string.format(
'Nncachebuilder:createShard: create %d indices' ..
' at %f CPU sec each',
count, rate))
local remaining = roughCount - count
print(string.format('need %f CPU hours to finish remaining %d',
rate * remaining / 60 / 60, remaining))
--halt()
end
end
end
v('count', count)
-- halt()
-- write by serializing
local filePath = Nncachebuilder._shardFilePath(filePathPrefix, shardNumber)
v('filePath', filePath)
cache:save(filePath)
return filePath
end -- createShard
function Nncachebuilder.mergeShards(nShards, filePathPrefix, chatty)
   -- Merge the per-shard cache files into a single Nncache and serialize it.
   -- ARGS
   -- nShards        : positive integer; number of shard files to read
   -- filePathPrefix : string; same prefix used when the shards were written
   -- chatty         : optional boolean, default true; if true, print progress
   -- RETURNS
   -- countAll       : number of records inserted into the merged cache
   -- mergedFilePath : file path where merged cache data were written
   local v, isVerbose = makeVerbose(false, 'mergeShards')
   verify(v, isVerbose,
          {{nShards, 'nShards', 'isIntegerPositive'},
           {filePathPrefix, 'filePathPrefix', 'isString'}})

   -- set default for chatty [true]
   if chatty == nil then
      chatty = true
   end

   local cache = Nncache()
   local countAll = 0
   for n = 1, nShards do
      local path = Nncachebuilder._shardFilePath(filePathPrefix, n)
      if chatty then
         print('reading shard cache file ', path)
      end
      local shard = Nncache.load(path)
      affirm.isTable(shard, 'Nncache')

      -- insert all shard elements into the merged cache
      local countShard = 0
      local function insert(key, value)
         cache:setLine(key, value)
         countShard = countShard + 1
      end
      shard:apply(insert)
      if chatty then
         print('number records inserted from shard', countShard)
      end
      -- BUG FIX: countAll was never updated, so the total printed and
      -- returned was always 0; accumulate each shard's contribution
      countAll = countAll + countShard
   end
   if chatty then
      print('number of records inserted from all shards', countAll)
   end

   local mergedFilePath = filePathPrefix .. Nncachebuilder.mergedFileSuffix()
   if chatty then
      print('writing merged cache file', mergedFilePath)
   end
   torch.save(mergedFilePath, cache, Nncachebuilder.format())
   return countAll, mergedFilePath
end -- mergeShards