Skip to content

Commit

Permalink
Merge pull request #620 from mlverse/bugfix/trace-constants
Browse files Browse the repository at this point in the history
Constants in traced modules
  • Loading branch information
dfalbel committed Jul 27, 2021
2 parents 38e5547 + 51f78fa commit e3b107a
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 10 deletions.
17 changes: 11 additions & 6 deletions R/trace.R
Original file line number Diff line number Diff line change
Expand Up @@ -229,12 +229,17 @@ create_script_module <- function(mod) {
module$register_module(name, create_script_module(child))
})

constants <- names(mod)[!names(mod) %in% module_ignored_names]
walk(constants, function(name) {
if (rlang::is_closure(mod[[name]])) return()
# TODO catch invalid types and raise a warning listing their names.
module$add_constant(name, mod[[name]])
})

# Let's not keep the constants in the module right now as it might cause more
# problems than benefits. In pytorch they are only added if their name is in
# `__constants__` and we are using `torch.jit.script`, not `torch.jit.trace`.

# constants <- names(mod)[!names(mod) %in% module_ignored_names]
# walk(constants, function(name) {
# if (rlang::is_closure(mod[[name]])) return()
# # TODO catch invalid types and raise a warning listing their names.
# module$add_constant(name, mod[[name]])
# })

module
}
Expand Down
5 changes: 5 additions & 0 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@

using namespace Rcpp;

#ifdef RCPP_USE_GLOBAL_ROSTREAM
Rcpp::Rostream<true>& Rcpp::Rcout = Rcpp::Rcpp_cout_get();
Rcpp::Rostream<false>& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get();
#endif

// cpp_set_lantern_allocator
void cpp_set_lantern_allocator(uint64_t threshold_call_gc);
RcppExport SEXP _torch_cpp_set_lantern_allocator(SEXP threshold_call_gcSEXP) {
Expand Down
11 changes: 11 additions & 0 deletions tests/testthat/test-nn-rnn.R
Original file line number Diff line number Diff line change
Expand Up @@ -300,3 +300,14 @@ test_that("lstm and gru works with packed sequences", {
expect_tensor_shape(unpack[[1]], c(4, 3, 4))

})

test_that("gru can be traced", {
x <- nn_gru(10, 10)
tr <- jit_trace(x, torch_randn(10, 10, 10))

v <- torch_randn(10, 10, 10)
expect_equal_to_tensor(
x(v)[[1]],
tr(v)[[1]]
)
})
4 changes: 0 additions & 4 deletions tests/testthat/test-trace.R
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,6 @@ test_that("trace a nn module", {
regexp = NA
)

expect_equal(m$constant, 1)
expect_equal(m$hello, list(torch_tensor(1), torch_tensor(2), "hello"))
expect_length(m$parameters, 5)
expect_length(m$buffers, 4)
expect_length(m$modules, 3)
Expand Down Expand Up @@ -308,8 +306,6 @@ test_that("we can save traced modules", {

m <- jit_load("tracedmodule.pt")

expect_equal(m$constant, 1)
expect_equal(m$hello, list(torch_tensor(1), torch_tensor(2), "hello"))
expect_length(m$parameters, 5)
expect_length(m$buffers, 4)
expect_length(m$modules, 3)
Expand Down

0 comments on commit e3b107a

Please sign in to comment.