Skip to content

Commit

Permalink
Fixed bug when there are no unknownwise residuals of a given ispace.
Browse files Browse the repository at this point in the history
  • Loading branch information
Mx7f committed Oct 25, 2017
1 parent 14587e8 commit 4da559d
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 1 deletion.
51 changes: 50 additions & 1 deletion API/src/o.t
Original file line number Diff line number Diff line change
Expand Up @@ -1973,8 +1973,58 @@ local function classifyexpression(exp) -- what index space, or graph is this thi
return classification,template
end

local function extract_unused_unknowns(kinds,kind_to_templates)
local used_nongraph_unknown_images = {}
local all_unknowns = {}
for _,k in ipairs(kinds) do
for _,template in ipairs(kind_to_templates[k]) do
for _,u in ipairs(template.unknowns) do
if k.kind == "CenteredFunction" then
used_nongraph_unknown_images[u.image] = true
end
all_unknowns[u.image] = true
end
end
end
local unused_unknowns = terralib.newlist()
for _,k in ipairs(table.keys(all_unknowns)) do
if not used_nongraph_unknown_images[k] then
unused_unknowns:insert(k)
end
end
return unused_unknowns
end

local function insert_dummy_energies_for_unused_unknowns(kinds,kind_to_templates,unused_unknowns)
for _,u in ipairs(unused_unknowns) do
local ispace = u.type.ispace
local unknownaccesses = terralib.newlist()
for i = 1,u.type.channelcount do
local imacc = ImageAccess(u,ad.scalar,ispace:ZeroOffset(),i-1)
unknownaccesses:insert(imacc)
end
local kind = A.CenteredFunction(ispace)
if not kind_to_templates[kind] then
kinds:insert(kind)
kind_to_templates[kind] = terralib.newlist()
end
local exp = ad.toexp(0)
local template = A.ResidualTemplate(exp,unknownaccesses)
kind_to_templates[kind]:insert(template)
end
end

local function toenergyspecs(Rs)
local kinds,kind_to_templates = MapAndGroupBy(Rs,classifyexpression)
local unused_unknowns = extract_unused_unknowns(kinds,kind_to_templates)
if #unused_unknowns > 0 then
local message = "No unknownwise residuals for unknown(s) "..tostring(unused_unknowns[1])
for i=2,#unused_unknowns do
message = message..", "..tostring(unused_unknowns[i])
end
print(message..". Creating zero-valued stand-ins.")
end
insert_dummy_energies_for_unused_unknowns(kinds,kind_to_templates,unused_unknowns)
return kinds:map(function(k) return A.EnergySpec(k,kind_to_templates[k]) end)
end

Expand Down Expand Up @@ -2413,7 +2463,6 @@ local function extractresidualterms(...)
end
function ProblemSpecAD:Cost(...)
local terms = extractresidualterms(...)

local functionspecs = List()
local energyspecs = toenergyspecs(terms)
for _,energyspec in ipairs(energyspecs) do
Expand Down
24 changes: 24 additions & 0 deletions API/src/util.t
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,30 @@ util.C = terralib.includecstring [[
]]
local C = util.C

--[[ rPrint(struct, [limit], [indent]) Recursively print arbitrary data.
Set limit (default 100) to stanch infinite loops.
Indents tables as [KEY] VALUE, nested tables as [KEY] [KEY]...[KEY] VALUE
Set indent ("") to prefix each line: Mytable [KEY] [KEY]...[KEY] VALUE
--]]
function util.rPrint(s, l, i) -- recursive Print (structure, limit, indent)
l = (l) or 100; i = i or ""; -- default item limit, indent string
local ts = type(s);
if (l<1) then print (i,ts," *snip* "); return end;
if (ts ~= "table") then print (i,ts,s); return end
print (i,ts); -- print "table"
for k,v in pairs(s) do -- print "[KEY] VALUE"
util.rPrint(v, l-1, i.."\t["..tostring(k).."]");
end
end

function table.keys(tab)
local result = terralib.newlist()
for k,_ in pairs(tab) do
result:insert(k)
end
return result
end

local cuda_compute_version = 30
local libdevice = terralib.cudahome..string.format("/nvvm/libdevice/libdevice.compute_%d.10.bc",cuda_compute_version)

Expand Down

0 comments on commit 4da559d

Please sign in to comment.