From 4da559dc6103d91da990d440bce02b7283ffd0f6 Mon Sep 17 00:00:00 2001
From: Michael <mikemx7f@gmail.com>
Date: Wed, 25 Oct 2017 14:20:23 -0700
Subject: [PATCH] Fixed bug when there are no unknownwise residuals of a given
 ispace.

---
 API/src/o.t    | 51 +++++++++++++++++++++++++++++++++++++++++++++++++-
 API/src/util.t | 24 ++++++++++++++++++++++++
 2 files changed, 74 insertions(+), 1 deletion(-)

diff --git a/API/src/o.t b/API/src/o.t
index 334150d96..7d29c9447 100644
--- a/API/src/o.t
+++ b/API/src/o.t
@@ -1973,8 +1973,58 @@ local function classifyexpression(exp) -- what index space, or graph is this thi
     return classification,template
 end
 
+local function extract_unused_unknowns(kinds,kind_to_templates)
+    local used_nongraph_unknown_images = {}
+    local all_unknowns = {}
+    for _,k in ipairs(kinds) do
+        for _,template in ipairs(kind_to_templates[k]) do
+            for _,u in ipairs(template.unknowns) do
+                if k.kind == "CenteredFunction" then
+                    used_nongraph_unknown_images[u.image] = true
+                end
+                all_unknowns[u.image] = true
+            end
+        end
+    end
+    local unused_unknowns = terralib.newlist()
+    for _,k in ipairs(table.keys(all_unknowns)) do
+        if not used_nongraph_unknown_images[k] then
+            unused_unknowns:insert(k)
+        end
+    end
+    return unused_unknowns
+end
+
+local function insert_dummy_energies_for_unused_unknowns(kinds,kind_to_templates,unused_unknowns)
+    for _,u in ipairs(unused_unknowns) do
+        local ispace = u.type.ispace
+        local unknownaccesses = terralib.newlist()
+        for i = 1,u.type.channelcount do
+            local imacc = ImageAccess(u,ad.scalar,ispace:ZeroOffset(),i-1)
+            unknownaccesses:insert(imacc)
+        end
+        local kind = A.CenteredFunction(ispace)
+        if not kind_to_templates[kind] then
+            kinds:insert(kind)
+            kind_to_templates[kind] = terralib.newlist()
+        end
+        local exp = ad.toexp(0)
+        local template = A.ResidualTemplate(exp,unknownaccesses)
+        kind_to_templates[kind]:insert(template)
+    end
+end
+
 local function toenergyspecs(Rs)    
     local kinds,kind_to_templates = MapAndGroupBy(Rs,classifyexpression)
+    local unused_unknowns = extract_unused_unknowns(kinds,kind_to_templates)
+    if #unused_unknowns > 0 then
+        local message = "No unknownwise residuals for unknown(s) "..tostring(unused_unknowns[1])
+        for i=2,#unused_unknowns do
+            message = message..", "..tostring(unused_unknowns[i])
+        end
+        print(message..". Creating zero-valued stand-ins.")
+    end
+    insert_dummy_energies_for_unused_unknowns(kinds,kind_to_templates,unused_unknowns)
     return kinds:map(function(k) return A.EnergySpec(k,kind_to_templates[k]) end)
 end
 
@@ -2413,7 +2463,6 @@ local function extractresidualterms(...)
 end
 function ProblemSpecAD:Cost(...)
     local terms = extractresidualterms(...)
-    
     local functionspecs = List()
     local energyspecs = toenergyspecs(terms)
     for _,energyspec in ipairs(energyspecs) do
diff --git a/API/src/util.t b/API/src/util.t
index e5a575dd7..5af45567d 100644
--- a/API/src/util.t
+++ b/API/src/util.t
@@ -14,6 +14,30 @@ util.C = terralib.includecstring [[
 ]]
 local C = util.C
 
+--[[ rPrint(struct, [limit], [indent])   Recursively print arbitrary data. 
+    Set limit (default 100) to stanch infinite loops.
+    Indents tables as [KEY] VALUE, nested tables as [KEY] [KEY]...[KEY] VALUE
+    Set indent ("") to prefix each line:    Mytable [KEY] [KEY]...[KEY] VALUE
+--]]
+function util.rPrint(s, l, i) -- recursive Print (structure, limit, indent)
+    l = (l) or 100; i = i or "";    -- default item limit, indent string
+    local ts = type(s);
+    if (l<1) then print (i,ts," *snip* "); return end;
+    if (ts ~= "table") then print (i,ts,s); return end
+    print (i,ts);           -- print "table"
+    for k,v in pairs(s) do  -- print "[KEY] VALUE"
+        util.rPrint(v, l-1, i.."\t["..tostring(k).."]");
+    end
+end 
+
+function table.keys(tab)
+    local result = terralib.newlist()
+    for k,_ in pairs(tab) do
+        result:insert(k)
+    end
+    return result
+end
+
 local cuda_compute_version = 30
 local libdevice = terralib.cudahome..string.format("/nvvm/libdevice/libdevice.compute_%d.10.bc",cuda_compute_version)