Skip to content

Commit

Permalink
workaround for apache#55
Browse files Browse the repository at this point in the history
  • Loading branch information
pluskid committed Jan 16, 2016
1 parent 8edb94b commit 4c52eb8
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 0 deletions.
7 changes: 7 additions & 0 deletions src/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -940,6 +940,13 @@ function _import_ndarray_functions(;gen_docs=false)
_use_vars = Expr(:ref, :MX_handle, [symbol("in$i") for i=1:n_used_vars]...)
_scalars = Expr(:ref, :MX_float, [symbol("sca$i") for i=1:n_scalars]...)
_mut_vars = Expr(:ref, :MX_handle, [symbol("out$i") for i=1:n_mutate_vars]...)

# XXX: hacky way of solving the problem that the arguments of `dot` should be swapped
# See https://github.com/dmlc/MXNet.jl/issues/55
if func_name == :dot
_use_vars.args[2:end] = flipdim(_use_vars.args[2:end], 1)
end

stmt_call = Expr(:call, :_invoke_mxfunction, func_handle, _use_vars, _scalars, _mut_vars)
if n_mutate_vars == 1
stmt_ret = :(return out1)
Expand Down
12 changes: 12 additions & 0 deletions test/unittest/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,17 @@ function test_nd_as_jl()
@test reldiff(copy(z)[:,2:end], copy(x)[:,2:end]) < 1e-6
end

function test_dot()
dims1 = (2, 3)
dims2 = (3, 8)
info("NDArray::dot")

x = mx.zeros(dims1)
y = mx.zeros(dims2)
z = mx.dot(x, y)
@test size(z) == (2, 8)
end


################################################################################
# Run tests
Expand All @@ -276,5 +287,6 @@ test_saveload()
test_clip()
test_sqrt()
test_nd_as_jl()
test_dot()

end

0 comments on commit 4c52eb8

Please sign in to comment.