Skip to content
This repository has been archived by the owner on Oct 30, 2019. It is now read-only.

Add inception-resnet-v2 model #64

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions checkpoints.lua
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ function checkpoint.latest(opt)
return latest, optimState
end

function checkpoint.save(epoch, model, optimState, bestModel)
function checkpoint.save(epoch, model, optimState, bestModel, opt)
-- Don't save the DataParallelTable for easier loading on other machines
if torch.type(model) == 'nn.DataParallelTable' then
model = model:get(1)
Expand All @@ -34,16 +34,16 @@ function checkpoint.save(epoch, model, optimState, bestModel)
local modelFile = 'model_' .. epoch .. '.t7'
local optimFile = 'optimState_' .. epoch .. '.t7'

torch.save(modelFile, model)
torch.save(optimFile, optimState)
torch.save('latest.t7', {
torch.save(paths.concat(opt.resume, modelFile), model)
torch.save(paths.concat(opt.resume, optimFile), optimState)
torch.save(paths.concat(opt.resume, 'latest.t7'), {
epoch = epoch,
modelFile = modelFile,
optimFile = optimFile,
})

if bestModel then
torch.save('model_best.t7', model)
torch.save(paths.concat(opt.resume, 'model_best.t7'), model)
end
end

Expand Down
6 changes: 5 additions & 1 deletion dataloader.lua
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@ function DataLoader:__init(dataset, opt, split)
end
torch.setnumthreads(1)
_G.dataset = dataset
_G.preprocess = dataset:preprocess()
if opt.netType == 'inception-resnet-v2' or opt.netType == 'inception-resnet-v2-aux' then
_G.preprocess = dataset:preprocess(328, 299)
else
_G.preprocess = dataset:preprocess(256, 224)
end
return dataset:size()
end

Expand Down
189 changes: 189 additions & 0 deletions datasets/create-imagenet-lmdb.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
--[[
A script to convert images to an LMDB dataset

References
1. https://github.com/facebook/fb.resnet.torch/blob/master/datasets/init.lua
2. https://github.com/eladhoffer/ImageNet-Training/blob/master/CreateLMDBs.lua
]]--

local ffi = require 'ffi'
local image = require 'image'

-- Define local functions
-- Checks whether the cached image list at `cachePath` can be reused.
-- Returns false only when the cache records a `basedir` and that value
-- differs from `opt.data`; returns true in every other case (including a
-- cache that has no `basedir` field).
local function isvalid(opt, cachePath)
   local cached = torch.load(cachePath)
   local recordedDir = cached.basedir
   return not (recordedDir and recordedDir ~= opt.data)
end

-- Extracts the base dataset name from an identifier such as 'imagenet-lmdb'.
-- The string is cut on '-'; the second piece must be 'lmdb', otherwise an
-- error is raised. Returns the first piece.
local function split(dataset)
   local parts = {}
   for piece in string.gmatch(dataset, '([^-]+)') do
      parts[#parts + 1] = piece
   end

   if parts[2] ~= 'lmdb' then
      -- Level 0: raise the message unchanged, without position information.
      error(string.format('opt.dataset should be <datset>-lmdb form; opt.dataset = %s', dataset), 0)
   end

   return parts[1]
end

-- Loads the image at `path` via image.load(path, 3, 'float').
-- If image.load raises, the file is read whole and decoded from memory with
-- image.decompress instead. Raises if the file cannot be opened.
local function _loadImage(path)
   local ok, input = pcall(function()
      return image.load(path, 3, 'float')
   end)

   -- Sometimes image.load fails because the file extension does not match the
   -- image format. In that case, use image.decompress on a ByteTensor.
   if not ok then
      -- Open in binary mode: the content is raw image bytes, and text mode may
      -- translate line endings on some platforms and corrupt the data.
      local f = io.open(path, 'rb')
      assert(f, 'Error reading: ' .. tostring(path))
      local data = f:read('*a')
      f:close()

      -- Copy the file bytes into a ByteTensor of the same length.
      local b = torch.ByteTensor(string.len(data))
      ffi.copy(b:data(), data, b:size(1))

      input = image.decompress(b, 3, 'float')
   end

   return input
end

-- Builds the record for entry `i`: the path stored in row `i` of `imagePath`
-- is resolved against `basedir` and loaded with _loadImage, and the label is
-- read from `imageClass[i]`.
-- Returns a table { input = <loaded image>, target = <class> }.
local function getItem(basedir, imageClass, imagePath, i)
   local relPath = ffi.string(imagePath[i]:data())
   local sample = {
      input = _loadImage(paths.concat(basedir, relPath)),
      target = imageClass[i],
   }
   return sample
end



-- Conversion settings.
--   gen       : directory holding the cached image list (<gen>/<name>.t7)
--   dataset   : '<name>-lmdb' identifier; <name> selects the cache/generator
--   data      : image root, compared against the cache's recorded basedir
--   data_lmdb : directory under which train_lmdb / val_lmdb are created
--   shuffle   : true -> random write order, false -> sequential order
local opt = {
   gen = 'gen',
   dataset = 'imagenet-lmdb',
   data = '/media/data1/image/ilsvrc15/ILSVRC2015/Data/CLS-LOC',
   data_lmdb = '/media/data1/image',
   shuffle = true,
}
--print(opt)

-- Load imageInfo
-- The image list is cached at <opt.gen>/<name>.t7. When the cache is missing
-- or was built for a different opt.data, it is regenerated by running
-- <name>-gen.lua before being loaded.
local datasetName = split(opt.dataset)
local cachePath = paths.concat(opt.gen, datasetName .. '.t7')
if not paths.filep(cachePath) or not isvalid(opt, cachePath) then
   -- Create the directory cachePath actually points into (opt.gen), rather
   -- than a hard-coded 'gen'.
   paths.mkdir(opt.gen)

   local script = paths.dofile(datasetName .. '-gen.lua')
   script.exec(opt, cachePath)
end
local imageInfo = torch.load(cachePath)
--print(imageInfo)

-- Create LMDB
-- One environment per split, each located at <opt.data_lmdb>/<name> and
-- named after its directory.
local lmdb = require 'lmdb'

local trainName = 'train_lmdb'
local train_env = lmdb.env({
   Path = paths.concat(opt.data_lmdb, trainName),
   Name = trainName,
})

-- NOTE(review): nothing visible in this script opens or writes val_env.
local valName = 'val_lmdb'
local val_env = lmdb.env({
   Path = paths.concat(opt.data_lmdb, valName),
   Name = valName,
})

-- Print the first stored training path as a quick look at the cache contents.
local path = ffi.string(imageInfo.train.imagePath[1]:data())
--local image = self:_loadImage(paths.concat(self.dir, path))
--local class = self.imageInfo.imageClass[i]
print(path)

-- Number of training entries = first dimension of the path tensor.
local n_images = imageInfo.train.imagePath:size(1)
print(string.format("n_image: %d", n_images))

-- Order in which source entries are written: a random permutation when
-- opt.shuffle is set, otherwise 1..n_images.
local idxs
if opt.shuffle then
   idxs = torch.randperm(n_images)
else
   idxs = torch.range(1, n_images)
end
-- tostring() is required: '%s' raises on a boolean argument under Lua 5.1.
print(string.format("opt.shuffle: %s, idxs[1]: %d", tostring(opt.shuffle), idxs[1]))

local basedir = paths.concat(imageInfo.basedir, 'train')
--local item = getItem(basedir, imageInfo.train.imageClass, imageInfo.train.imagePath, idxs[1])
--print(item.target)
----print(item.input)
--print(#item.input)
--print(item.input[1][1][1])

-- Number of records to write. Defaults to every training image; set
-- opt.max_images to cap it (e.g. for a quick trial run). The count is clamped
-- to n_images so idxs is never indexed out of range.
local n_write = n_images
if opt.max_images then
   n_write = math.min(opt.max_images, n_images)
end

-- Commit in batches so a single transaction does not hold every image.
local commit_every = 100

train_env:open()
local txn = train_env:txn()
local cursor = txn:cursor()
for i = 1, n_write do
   -- Record i of the database holds source entry idxs[i].
   local item = getItem(basedir, imageInfo.train.imageClass, imageInfo.train.imagePath, idxs[i])

   -- Keys are zero-padded 7-digit sequence numbers.
   cursor:put(string.format("%07d", i), item, lmdb.C.MDB_NODUPDATA)
   if i % commit_every == 0 then
      txn:commit()
      print(train_env:stat())
      collectgarbage()
      -- Start a fresh transaction and cursor for the next batch.
      txn = train_env:txn()
      cursor = txn:cursor()
   end
   -- Report progress against the number of records actually being written.
   xlua.progress(i, n_write)
end
-- Commit whatever remains from the last (possibly partial) batch.
txn:commit()
train_env:close()

--[[
local sys = require 'sys'
local n_test = 5000
sys.tic()
-------Read-------
train_env:open()
print(train_env:stat()) -- Current status
local reader = train_env:txn(true) --Read-only transaction
--local y = torch.Tensor(10,3,256,256)
local y = {}

local idxs = torch.randperm(n_test)
for i=1,n_test do
local item = reader:get(string.format("%07d", idxs[i]))
if i % 1000 == 0 then
print(string.format('%d: %d', i, idxs[i]))
print(#item.input)
end
--print(item)
--print(#item.input)
--print(item.input[1][1][1])
end
reader:abort()
train_env:close()
print(sys.toc())
collectgarbage()

sys.tic()
-------Read-------
train_env:open()
print(train_env:stat()) -- Current status
local reader = train_env:txn(true) --Read-only transaction
--local y = torch.Tensor(10,3,256,256)
local y = {}

for i=1,n_test do
local item = reader:get(string.format("%07d", i))
if i % 1000 == 0 then
print(string.format('%d: %d', i, i))
print(#item.input)
end
--print(item)
--print(#item.input)
--print(item.input[1][1][1])
end
reader:abort()
train_env:close()
print(sys.toc())
]]--
87 changes: 87 additions & 0 deletions datasets/imagenet-lmdb.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
--
-- Copyright (c) 2016, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the BSD-style license found in the
-- LICENSE file in the root directory of this source tree. An additional grant
-- of patent rights can be found in the PATENTS file in the same directory.
--
-- ImageNet dataset loader
--

local image = require 'image'
local paths = require 'paths'
local t = require 'datasets/transforms'
local lmdb = require 'lmdb'

-- Table handed to torch.class; the class is read back from it
-- (M.ImagenetLMDBDataset) when the file returns.
local M = {}
-- Declare the class under the name 'resnet.ImagenetLMDBDataset'.
local ImagenetLMDBDataset = torch.class('resnet.ImagenetLMDBDataset', M)

-- Opens the LMDB environment for one split and prepares it for reading.
--   imageInfo : table indexed by split name; imageInfo[split].imageClass
--               gives the expected number of items
--   opt       : options table; opt.data is the directory that contains
--               '<split>_lmdb'
--   split     : split name, used to build the database name
-- Raises if the environment cannot be opened or if its entry count does not
-- match the image list.
function ImagenetLMDBDataset:__init(imageInfo, opt, split)
   self.imageInfo = imageInfo[split]
   self.opt = opt
   self.split = split

   local dbName = string.format('%s_lmdb', split)
   self.env = lmdb.env{
      Path = paths.concat(opt.data, dbName),
      Name = dbName,
   }
   -- Every later call goes through self.env; there is no bare `env` in scope.
   assert(self.env:open(), 'directory does not exist: ' .. dbName)
   self.stat = self.env:stat() -- Current status

   self.reader = self.env:txn(true) --Read-only transaction

   -- Number of records stored in the database.
   -- NOTE(review): assumes env:stat() exposes the count as `entries` -- confirm
   -- against the installed lmdb.torch version.
   local n_images = self.stat.entries
   local n_expected = self.imageInfo.imageClass:size(1)
   if n_images ~= n_expected then
      error(string.format('Something wrong with lmdb. The lmdb db should have %d number of items, but it has %s', n_expected, tostring(n_images)))
   end

   -- Random mapping from dataset index to stored record number.
   self.idxs = torch.randperm(n_images)
end

-- Returns the record for dataset index `i`.
-- The index is mapped through self.idxs to a record number, formatted as a
-- zero-padded 7-digit key, and looked up with the read-only transaction held
-- in self.reader.
function ImagenetLMDBDataset:get(i)
   local key = string.format("%07d", self.idxs[i])
   -- The transaction lives on the instance; a bare `reader` is undefined here.
   local item = self.reader:get(key)
   return item
end

-- Number of items in this split, taken from the first dimension of the
-- image list's imageClass.
function ImagenetLMDBDataset:size()
   local labels = self.imageInfo.imageClass
   return labels:size(1)
end

-- Computed from random subset of ImageNet training images
-- meanstd is handed to t.ColorNormalize in preprocess(); each list has three
-- entries (presumably one per colour channel -- confirm against transforms).
local meanstd = {
mean = { 0.485, 0.456, 0.406 },
std = { 0.229, 0.224, 0.225 },
}
-- pca.eigval / pca.eigvec are handed to t.Lighting in the 'train' pipeline of
-- preprocess().
local pca = {
eigval = torch.Tensor{ 0.2175, 0.0188, 0.0045 },
eigvec = torch.Tensor{
{ -0.5675, 0.7192, 0.4009 },
{ -0.5808, -0.0045, -0.8140 },
{ -0.5836, -0.6948, 0.4203 },
},
}

-- Builds the transform pipeline for this dataset's split.
--   minSize  : size passed to t.Scale for the 'val' split (default 256)
--   cropSize : crop size for both splits (default 224)
-- Both arguments are optional so existing zero-argument callers keep the
-- previous 256/224 behaviour, while callers that pass sizes (e.g. 328/299)
-- are honoured instead of being silently ignored.
-- Raises on a split other than 'train' or 'val'.
function ImagenetLMDBDataset:preprocess(minSize, cropSize)
   minSize = minSize or 256
   cropSize = cropSize or 224

   if self.split == 'train' then
      return t.Compose{
         t.RandomSizedCrop(cropSize),
         t.ColorJitter({
            brightness = 0.4,
            contrast = 0.4,
            saturation = 0.4,
         }),
         t.Lighting(0.1, pca.eigval, pca.eigvec),
         t.ColorNormalize(meanstd),
         t.HorizontalFlip(0.5),
      }
   elseif self.split == 'val' then
      -- opt.tenCrop selects TenCrop; otherwise a single centre crop is used.
      local Crop = self.opt.tenCrop and t.TenCrop or t.CenterCrop
      return t.Compose{
         t.Scale(minSize),
         t.ColorNormalize(meanstd),
         Crop(cropSize),
      }
   else
      error('invalid split: ' .. self.split)
   end
end

return M.ImagenetLMDBDataset
11 changes: 7 additions & 4 deletions datasets/imagenet.lua
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,13 @@ local pca = {
},
}

function ImagenetDataset:preprocess()
function ImagenetDataset:preprocess(minSize, cropSize)
-- minSize : 256, cropSize : 224 for resnet
-- minSize : 328, cropSize : 299 for inception-resnet-v2

if self.split == 'train' then
return t.Compose{
t.RandomSizedCrop(224),
t.RandomSizedCrop(cropSize),
t.ColorJitter({
brightness = 0.4,
contrast = 0.4,
Expand All @@ -93,9 +96,9 @@ function ImagenetDataset:preprocess()
elseif self.split == 'val' then
local Crop = self.opt.tenCrop and t.TenCrop or t.CenterCrop
return t.Compose{
t.Scale(256),
t.Scale(minSize),
t.ColorNormalize(meanstd),
Crop(224),
Crop(cropSize),
}
else
error('invalid split: ' .. self.split)
Expand Down
Loading