Skip to content

Commit

Permalink
backport @defstruct from SimpleStructs.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
pluskid committed Dec 3, 2015
1 parent 67c8bb2 commit aa1b600
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 17 deletions.
82 changes: 68 additions & 14 deletions src/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,35 +161,79 @@ is available.
The macro will define a constructor that could accept
the keyword arguments.
"""
macro defstruct(name, super_name, fields)
@assert fields.head == :tuple
fields = fields.args
macro defstruct(name, fields)
_defstruct_impl(false, name, fields)
end

"""A convenient macro to define immutable structs. The same as
`@defstruct` except that the defined type is immutable.
"""
macro defimmutable(name, fields)
_defstruct_impl(true, name, fields)
end

"""Internal use only, this value is used to indicate a required value
is not specified.
"""
immutable __Undefined
end

function _defstruct_impl(is_immutable, name, fields)
if isa(fields, Expr) && fields.head == :tuple
fields = fields.args
else
fields = [fields]
end
@assert length(fields) > 0
name = esc(name)
super_name = esc(super_name)

if isa(name, Symbol)
name = esc(name)
super_name = :Any
else
@assert(isa(name, Expr) && name.head == :comparison && length(name.args) == 3 && name.args[2] == :(<:),
"name must be of form 'Name <: SuperType'")
@assert(isa(name.args[1], Symbol) && isa(name.args[3], Symbol))
super_name = esc(name.args[3])
name = esc(name.args[1])
end

field_defs = Array(Expr, length(fields)) # :(field2 :: Int)
field_names = Array(Expr, length(fields)) # :field2
field_defaults = Array(Expr, length(fields)) # :(field2 = 0)
field_types = Array(Expr, length(fields)) # Int
field_asserts = Array(Expr, length(fields)) # :(field2 >= 0)
required_field = Symbol[]

for i = 1:length(fields)
field = fields[i]
if field.head == :tuple
field_asserts[i] = esc(field.args[2])
field = field.args[1]
end
field_defs[i] = esc(field.args[1])
field_names[i] = esc(field.args[1].args[1])
field_types[i] = esc(field.args[1].args[2])
field_defaults[i] = Expr(:kw, field.args[1].args[1], esc(field.args[2]))
if field.head == :(=)
fname = field.args[1].args[1]
field_defs[i] = esc(field.args[1])
field_names[i] = esc(fname)
field_types[i] = esc(field.args[1].args[2])
field_defaults[i] = Expr(:kw, fname, esc(field.args[2]))
else
# no default value provided, required field
fname = field.args[1]
field_defs[i] = esc(field)
field_names[i] = esc(fname)
field_types[i] = esc(field.args[2])
field_defaults[i] = Expr(:kw, fname, __Undefined())
push!(required_field, fname)
end
end

# body of layer type, defining fields
type_body = Expr(:block, field_defs...)

# constructor
requires = map(required_field) do fname
:(@assert(!isa($fname, __Undefined), "value for " * string($fname) * " is required"))
end
converts = map(zip(field_names, field_types)) do param
f_name, f_type = param
:($f_name = convert($f_type, $f_name))
Expand All @@ -198,15 +242,25 @@ macro defstruct(name, super_name, fields)
:(@assert($(field_asserts[i])))
end
construct = Expr(:call, name, field_names...)
ctor_body = Expr(:block, converts..., asserts..., construct)
ctor_body = Expr(:block, requires..., converts..., asserts..., construct)
ctor_def = Expr(:call, name, Expr(:parameters, field_defaults...))
ctor = Expr(:(=), ctor_def, ctor_body)

quote
type $(name) <: $super_name
$type_body
if is_immutable
quote
immutable $(name) <: $(super_name)
$type_body
end

$ctor
end
else
quote
type $(name) <: $(super_name)
$type_body
end

$ctor
$ctor
end
end
end
2 changes: 1 addition & 1 deletion src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ function _create_kvstore(kv_type :: Base.Symbol, num_device :: Int, arg_params :
return (kv, update_on_kvstore)
end

@defstruct TrainingOptions Any (
@defstruct TrainingOptions (
initializer :: AbstractInitializer = UniformInitializer(0.01),
n_epoch :: Int = 10,
eval_data :: Union{Void, AbstractDataProvider} = nothing,
Expand Down
2 changes: 1 addition & 1 deletion src/optimizers/adam.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
@defstruct ADAMOptions AbstractOptimizerOptions (
@defstruct ADAMOptions <: AbstractOptimizerOptions (
(lr :: Real = 0.001, lr > 0),
(grad_clip :: Real = 0, grad_clip >= 0),
(weight_decay :: Real = 0.00001, weight_decay >= 0),
Expand Down
2 changes: 1 addition & 1 deletion src/optimizers/sgd.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
@defstruct SGDOptions AbstractOptimizerOptions (
@defstruct SGDOptions <: AbstractOptimizerOptions (
(lr :: Real = 0.01, lr > 0),
(momentum :: Real = 0.0, momentum >= 0),
(grad_clip :: Real = 0, grad_clip >= 0),
Expand Down

0 comments on commit aa1b600

Please sign in to comment.