diff --git a/API/src/o.t b/API/src/o.t index 334150d96..7d29c9447 100644 --- a/API/src/o.t +++ b/API/src/o.t @@ -1973,8 +1973,58 @@ local function classifyexpression(exp) -- what index space, or graph is this thi return classification,template end +local function extract_unused_unknowns(kinds,kind_to_templates) + local used_nongraph_unknown_images = {} + local all_unknowns = {} + for _,k in ipairs(kinds) do + for _,template in ipairs(kind_to_templates[k]) do + for _,u in ipairs(template.unknowns) do + if k.kind == "CenteredFunction" then + used_nongraph_unknown_images[u.image] = true + end + all_unknowns[u.image] = true + end + end + end + local unused_unknowns = terralib.newlist() + for _,k in ipairs(table.keys(all_unknowns)) do + if not used_nongraph_unknown_images[k] then + unused_unknowns:insert(k) + end + end + return unused_unknowns +end + +local function insert_dummy_energies_for_unused_unknowns(kinds,kind_to_templates,unused_unknowns) + for _,u in ipairs(unused_unknowns) do + local ispace = u.type.ispace + local unknownaccesses = terralib.newlist() + for i = 1,u.type.channelcount do + local imacc = ImageAccess(u,ad.scalar,ispace:ZeroOffset(),i-1) + unknownaccesses:insert(imacc) + end + local kind = A.CenteredFunction(ispace) + if not kind_to_templates[kind] then + kinds:insert(kind) + kind_to_templates[kind] = terralib.newlist() + end + local exp = ad.toexp(0) + local template = A.ResidualTemplate(exp,unknownaccesses) + kind_to_templates[kind]:insert(template) + end +end + local function toenergyspecs(Rs) local kinds,kind_to_templates = MapAndGroupBy(Rs,classifyexpression) + local unused_unknowns = extract_unused_unknowns(kinds,kind_to_templates) + if #unused_unknowns > 0 then + local message = "No unknownwise residuals for unknown(s) "..tostring(unused_unknowns[1]) + for i=2,#unused_unknowns do + message = message..", "..tostring(unused_unknowns[i]) + end + print(message..". Creating zero-valued stand-ins.") + end + insert_dummy_energies_for_unused_unknowns(kinds,kind_to_templates,unused_unknowns) return kinds:map(function(k) return A.EnergySpec(k,kind_to_templates[k]) end) end @@ -2413,7 +2463,6 @@ local function extractresidualterms(...) end function ProblemSpecAD:Cost(...) local terms = extractresidualterms(...) - local functionspecs = List() local energyspecs = toenergyspecs(terms) for _,energyspec in ipairs(energyspecs) do diff --git a/API/src/util.t b/API/src/util.t index e5a575dd7..5af45567d 100644 --- a/API/src/util.t +++ b/API/src/util.t @@ -14,6 +14,30 @@ util.C = terralib.includecstring [[ ]] local C = util.C +--[[ rPrint(struct, [limit], [indent]) Recursively print arbitrary data. + Set limit (default 100) to stanch infinite loops. + Indents tables as [KEY] VALUE, nested tables as [KEY] [KEY]...[KEY] VALUE + Set indent ("") to prefix each line: Mytable [KEY] [KEY]...[KEY] VALUE +--]] +function util.rPrint(s, l, i) -- recursive Print (structure, limit, indent) + l = (l) or 100; i = i or ""; -- default item limit, indent string + local ts = type(s); + if (l<1) then print (i,ts," *snip* "); return end; + if (ts ~= "table") then print (i,ts,s); return end + print (i,ts); -- print "table" + for k,v in pairs(s) do -- print "[KEY] VALUE" + util.rPrint(v, l-1, i.."\t["..tostring(k).."]"); + end +end + +function table.keys(tab) + local result = terralib.newlist() + for k,_ in pairs(tab) do + result:insert(k) + end + return result +end + local cuda_compute_version = 30 local libdevice = terralib.cudahome..string.format("/nvvm/libdevice/libdevice.compute_%d.10.bc",cuda_compute_version)