diff --git a/R/trace.R b/R/trace.R index ea195de5e2..f0489ab757 100644 --- a/R/trace.R +++ b/R/trace.R @@ -229,12 +229,17 @@ create_script_module <- function(mod) { module$register_module(name, create_script_module(child)) }) - constants <- names(mod)[!names(mod) %in% module_ignored_names] - walk(constants, function(name) { - if (rlang::is_closure(mod[[name]])) return() - # TODO catch invalid types and raise a warning listing their names. - module$add_constant(name, mod[[name]]) - }) + + # Let's not keep the constants in the module right now as it might cause more + # problems than benefits. In pytorch they are only added if their name is in + # `__constants__` and we are using `torch.jit.script`, not `torch.jit.trace`. + + # constants <- names(mod)[!names(mod) %in% module_ignored_names] + # walk(constants, function(name) { + # if (rlang::is_closure(mod[[name]])) return() + # # TODO catch invalid types and raise a warning listing their names. + # module$add_constant(name, mod[[name]]) + # }) module } diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index 6c56cb9c1c..86acf5f204 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -6,6 +6,11 @@ using namespace Rcpp; +#ifdef RCPP_USE_GLOBAL_ROSTREAM +Rcpp::Rostream& Rcpp::Rcout = Rcpp::Rcpp_cout_get(); +Rcpp::Rostream& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get(); +#endif + // cpp_set_lantern_allocator void cpp_set_lantern_allocator(uint64_t threshold_call_gc); RcppExport SEXP _torch_cpp_set_lantern_allocator(SEXP threshold_call_gcSEXP) { diff --git a/tests/testthat/test-nn-rnn.R b/tests/testthat/test-nn-rnn.R index 21ae4df991..de54edd659 100644 --- a/tests/testthat/test-nn-rnn.R +++ b/tests/testthat/test-nn-rnn.R @@ -300,3 +300,14 @@ test_that("lstm and gru works with packed sequences", { expect_tensor_shape(unpack[[1]], c(4, 3, 4)) }) + +test_that("gru can be traced", { + x <- nn_gru(10, 10) + tr <- jit_trace(x, torch_randn(10, 10, 10)) + + v <- torch_randn(10, 10, 10) + expect_equal_to_tensor( + x(v)[[1]], + tr(v)[[1]] + ) +}) diff --git a/tests/testthat/test-trace.R b/tests/testthat/test-trace.R index 85ff5cf4a4..7b806db458 100644 --- a/tests/testthat/test-trace.R +++ b/tests/testthat/test-trace.R @@ -242,8 +242,6 @@ test_that("trace a nn module", { regexp = NA ) - expect_equal(m$constant, 1) - expect_equal(m$hello, list(torch_tensor(1), torch_tensor(2), "hello")) expect_length(m$parameters, 5) expect_length(m$buffers, 4) expect_length(m$modules, 3) @@ -308,8 +306,6 @@ test_that("we can save traced modules", { m <- jit_load("tracedmodule.pt") - expect_equal(m$constant, 1) - expect_equal(m$hello, list(torch_tensor(1), torch_tensor(2), "hello")) expect_length(m$parameters, 5) expect_length(m$buffers, 4) expect_length(m$modules, 3)