-
Notifications
You must be signed in to change notification settings - Fork 147
/
torch7.lua
226 lines (203 loc) · 7.69 KB
/
torch7.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
-- Copyright 2015 Paul Kulchenko, ZeroBrane LLC; All rights reserved
-- Path handling for Torch and QLua is based on Torch/QLua interpreters from ZBS-torch by Soumith Chintala
-- true when the IDE runs on Windows; selects launcher names and library extensions
local win = (ide.osname == 'Windows')
-- separator between entries of PATH-style folder lists
local sep = ':'
if win then sep = ';' end
-- Debugger init code passed as the `init` option to the debugger.
-- It wraps mobdebug's `line` serializer: for userdata values the serialized
-- text is evaluated back with loadstring and that result is returned
-- (presumably so Torch userdata such as tensors show as values -- TODO confirm).
local debinit = [[
local mdb = require('mobdebug')
local line = mdb.line
mdb.line = function(v)
local r = line(v)
return type(v) == 'userdata' and loadstring("return "..r)() or r
end]]
--- Output callback that applies backspace characters found in process output.
-- Each pass drops a backspace together with the character before it, strips
-- backspaces at the very start of the text, and strips backspaces that follow
-- a CR/LF (keeping the line break). Passes repeat until no backspace remains.
-- @param s chunk of process output
-- @return the chunk with all backspaces resolved
local function fixBS(s)
  local text = s
  while string.find(text, "\b", 1, true) ~= nil do
    -- a regular character erased by the backspace that follows it
    text = string.gsub(text, "[^\b\r\n]\b", "")
    -- backspaces with nothing before them to erase
    text = string.gsub(text, "^\b+", "")
    -- backspaces cannot erase a line break: keep the break, drop the backspaces
    text = string.gsub(text, "([\r\n])\b+", "%1")
  end
  return text
end
-- Lua module search templates, relative to the Torch root (used by setEnv
-- to extend LUA_PATH and, optionally, package.path).
local lpaths = {
  "install/share/lua/5.1/?.lua",
  "install/share/lua/5.1/?/init.lua",
  "share/lua/5.1/?.lua",
  "share/lua/5.1/?/init.lua",
  "./?.lua",
  "./?/init.lua",
}
-- Native module search templates, relative to the Torch root; setEnv
-- appends the platform library extension ('dll' or 'so') to each entry.
local cpaths = {
  "install/lib/lua/5.1/?.",
  "install/lib/lua/5.1/loadall.",
  "install/lib/?.",
  "lib/lua/5.1/?.",
  "lib/lua/5.1/loadall.",
  "lib/?.",
  "?.",
}
--- Prepend Torch-specific search paths to the environment.
-- Sets LUA_PATH, LUA_CPATH and PATH (through wx) so that processes started
-- afterwards find modules and binaries under `torchroot`. When `usepackage`
-- is true, the same Lua/C templates are also prepended to package.path and
-- package.cpath of the running IDE.
-- @param torchroot root folder of the Torch installation
-- @param usepackage when true, also update package.path/package.cpath
-- @return table of previous values: LUA_PATH, LUA_CPATH, PATH and, when
--   usepackage is set, path/cpath (false if the package field was nil);
--   intended to be handed to unsetEnv
local function setEnv(torchroot, usepackage)
  -- ipairs keeps the templates in their listed (search) order;
  -- pairs gives no ordering guarantee
  local lparts = {}
  for i, val in ipairs(lpaths) do
    lparts[i] = MergeFullPath(torchroot, val) .. ";"
  end
  local tluapath = table.concat(lparts)
  local _, luapath = wx.wxGetEnv("LUA_PATH")
  wx.wxSetEnv("LUA_PATH", tluapath .. (luapath or ""))

  local ext = win and 'dll' or 'so'
  local cparts = {}
  for i, val in ipairs(cpaths) do
    cparts[i] = MergeFullPath(torchroot, val .. ext) .. ";"
  end
  local tluacpath = table.concat(cparts)
  local _, luacpath = wx.wxGetEnv("LUA_CPATH")
  wx.wxSetEnv("LUA_CPATH", tluacpath .. (luacpath or ""))

  local _, path = wx.wxGetEnv("PATH")
  -- guard a missing PATH the same way LUA_PATH/LUA_CPATH are guarded above
  local curpath = path or ""
  wx.wxSetEnv("PATH", torchroot .. (#curpath > 0 and sep .. curpath or ""))

  local env = {LUA_PATH = luapath, LUA_CPATH = luacpath, PATH = path}
  if usepackage then -- also assign package variables if requested
    env.path, package.path = package.path or false, tluapath .. (package.path or "")
    env.cpath, package.cpath = package.cpath or false, tluacpath .. (package.cpath or "")
  end
  return env
end
--- Restore the state captured by setEnv.
-- A key naming an existing `package` field (`path`, `cpath`) restores that
-- field, where a false value clears it. Any other key is an environment
-- variable: it is set back to its previous value, or removed when that value
-- was empty or false. Keys whose previous value was nil are not present in
-- the table and are therefore left as they are.
-- @param env table returned by setEnv
local function unsetEnv(env)
  -- env is keyed by name, not a sequence: ipairs would visit nothing and
  -- leave the modified environment in place, so walk it with pairs
  for name, val in pairs(env) do
    if package[name] then
      package[name] = val or nil
    elseif val and #val > 0 then
      wx.wxSetEnv(name, val)
    else
      wx.wxUnsetEnv(name)
    end
  end
end
--- Locate an executable by searching a list of folders.
-- Folders are searched in order: those in the PATH environment variable,
-- those in `env`, then `$HOME/bin` (when HOME is set).
-- @param cmd executable file name to look for (e.g. 'th' or 'qlua')
-- @param env optional extra folder list, separated like PATH
-- @param envname optional name of the variable `env` came from; only used in
--   the error message (defaults to 'TORCH_BIN')
-- @return full path of the first match, or nothing after printing the list
--   of searched folders
local function findCmd(cmd, env, envname)
  local home = os.getenv('HOME')
  local path = (os.getenv('PATH') or '')..sep
    ..(env or '')..sep
    ..(home and home .. '/bin' or '')
  local paths = {}
  for p in path:gmatch("[^"..sep.."]+") do
    local res = GetFullPathIfExists(p, cmd)
    -- the first match wins; the folder list is only needed for the error message
    if res then return res end
    paths[#paths + 1] = p
  end
  ide:Print(("Can't find %s in any of the folders in PATH or %s: "):format(cmd, envname or 'TORCH_BIN')
    ..table.concat(paths, ", "))
end
-- Interpreter that runs the current file (or the debugger start code) with
-- qlua, temporarily exposing the Torch paths derived from the qlua location.
local qluaInterpreter = {
  name = "QLua-LuaJIT",
  description = "Qt hooks for luajit",
  api = {"baselib", "qlua"},
  frun = function(self,wfilename,rundebug)
    local qlua = ide.config.path.qlua or findCmd('qlua', os.getenv('QLUA_BIN'), 'QLUA_BIN')
    if not qlua then return end
    -- make sure the root is using the qlua executable location:
    -- the Torch root is the parent of the folder holding qlua
    local torchroot = MergeFullPath(GetPathWithSep(qlua), "../")
    local env = setEnv(torchroot)
    local filepath = wfilename:GetFullPath()
    local script
    if rundebug then
      ide:GetDebugger():SetOptions({runstart = ide.config.debugger.runonstart == true, init = debinit})
      script = rundebug
    else
      -- if running on Windows and can't open the file, this may mean that
      -- the file path includes unicode characters that need special handling
      local fh = io.open(filepath, "r")
      if fh then fh:close() end
      if win and not fh and wfilename:FileExists() then
        -- use the module returned by require rather than relying on a global
        local ok, winapi = pcall(require, "winapi")
        if ok and type(winapi) == "table" then
          winapi.set_encoding(winapi.CP_UTF8)
          filepath = winapi.short_path(filepath)
        end
      end
      script = ('dofile [[%s]]'):format(filepath)
    end
    local code = ([[xpcall(function() io.stdout:setvbuf('no'); %s end,function(err) print(debug.traceback(err)) end)]]):format(script)
    local cmd = '"'..qlua..'" -e "'..code..'"'
    -- CommandLineRun(cmd,wdir,tooutput,nohide,stringcallback,uid,endcallback)
    local pid = CommandLineRun(cmd,self:fworkdir(wfilename),true,false,fixBS)
    -- restore the IDE environment; presumably the child process has already
    -- been spawned with the modified one at this point
    unsetEnv(env)
    return pid
  end,
  hasdebugger = true,
  fattachdebug = function(self)
    ide:GetDebugger():SetOptions({
      runstart = ide.config.debugger.runonstart == true,
      init = debinit
    })
  end,
  scratchextloop = true,
}
-- Interpreter that runs the current file with the Torch `th` launcher, or with
-- the IDE's `luadeb` interpreter when the configured torch path is a folder.
local torchInterpreter = {
  name = "Torch-7",
  description = "Torch machine learning package",
  api = {"baselib", "torch"},
  frun = function(self,wfilename,rundebug)
    -- check if the path is configured
    local torch = ide.config.path.torch or findCmd(win and 'th.bat' or 'th', os.getenv('TORCH_BIN'))
    if not torch then return end
    local filepath = wfilename:GetFullPath()
    if rundebug then
      ide:GetDebugger():SetOptions({runstart = ide.config.debugger.runonstart == true, init = debinit})
      -- update arg to point to the proper file
      rundebug = ('if arg then arg[0] = [[%s]] end '):format(filepath)..rundebug
      -- the debugger start code goes into a temporary file in the current folder
      local tmpfile = wx.wxFileName()
      tmpfile:AssignTempFileName(".")
      filepath = tmpfile:GetFullPath()
      local f = io.open(filepath, "w")
      if not f then
        ide:Print("Can't open temporary file '"..filepath.."' for writing.")
        -- don't leave behind the temporary file AssignTempFileName may have created
        if #filepath > 0 and wx.wxFileExists(filepath) then wx.wxRemoveFile(filepath) end
        return
      end
      f:write("io.stdout:setvbuf('no'); " .. rundebug)
      f:close()
    else
      -- if running on Windows and can't open the file, this may mean that
      -- the file path includes unicode characters that need special handling
      local fh = io.open(filepath, "r")
      if fh then fh:close() end
      if win and not fh and wfilename:FileExists() then
        -- use the module returned by require rather than relying on a global
        local ok, winapi = pcall(require, "winapi")
        if ok and type(winapi) == "table" then
          winapi.set_encoding(winapi.CP_UTF8)
          filepath = winapi.short_path(filepath)
        end
      end
    end
    -- doesn't need set environment with setEnv as it's already done in onInterpreterLoad
    local params = ide.config.arg.any or ide.config.arg.torch7 or ''
    -- a folder configured as the torch path means "run with the luadeb
    -- interpreter's executable"; otherwise the launcher itself is started
    local uselua = wx.wxDirExists(torch)
    local cmd = ([["%s" "%s" %s]]):format(
      uselua and ide:GetInterpreters().luadeb:GetExePath("") or torch, filepath, params)
    -- CommandLineRun(cmd,wdir,tooutput,nohide,stringcallback,uid,endcallback)
    return CommandLineRun(cmd,self:fworkdir(wfilename),true,false,fixBS,nil,
      -- remove the temporary debugger script once the process ends
      function() if rundebug then wx.wxRemoveFile(filepath) end end)
  end,
  hasdebugger = true,
  fattachdebug = function(self)
    ide:GetDebugger():SetOptions({
      runstart = ide.config.debugger.runonstart == true,
      init = debinit
    })
  end,
  scratchextloop = true,
  takeparameters = true,
}
-- Plugin descriptor: registers both interpreters and, while the Torch
-- interpreter is active, keeps the Torch paths in the environment and in
-- package.path/package.cpath.
return {
  name = "Torch7",
  description = "Implements integration with torch7 environment.",
  author = "Paul Kulchenko",
  version = 0.58,
  dependencies = "1.40",
  onRegister = function(self)
    ide:AddInterpreter("torch", torchInterpreter)
    ide:AddInterpreter("qlua", qluaInterpreter)
  end,
  onUnRegister = function(self)
    ide:RemoveInterpreter("torch")
    ide:RemoveInterpreter("qlua")
  end,
  onInterpreterLoad = function(self, interpreter)
    if interpreter:GetFileName() == "torch" then
      local th = ide.config.path.torch or findCmd(win and 'th.bat' or 'th', os.getenv('TORCH_BIN'))
      if th then
        -- a configured folder is itself the root; otherwise the root is the
        -- parent of the folder holding the th launcher
        local root
        if wx.wxDirExists(th) then
          root = th
        else
          root = MergeFullPath(GetPathWithSep(th), "../")
        end
        -- remember the previous values so onInterpreterClose can restore them
        interpreter.env = setEnv(root, true)
      end
    end
  end,
  onInterpreterClose = function(self, interpreter)
    if interpreter:GetFileName() == "torch" and interpreter.env then
      unsetEnv(interpreter.env)
    end
  end,
}