Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hide PyTorch trace compilation warnings #185

Merged
Merged 3 commits on Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 4 additions & 13 deletions torchquad/integration/grid_integrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
_linspace_with_grads,
expand_func_values_and_squeeze_integral,
_setup_integration_domain,
_torch_trace_without_warnings,
)


Expand Down Expand Up @@ -208,8 +209,6 @@ def compiled_integrate(fn, integration_domain):
elif backend == "torch":
# Torch requires explicit tracing with example inputs.
def do_compile(example_integrand):
import torch

# Define traceable first and third steps
def step1(integration_domain):
grid_points, hs, n_per_dim = self.calculate_grid(
Expand All @@ -218,7 +217,7 @@ def step1(integration_domain):
return (
grid_points,
hs,
torch.Tensor([n_per_dim]),
anp.array([n_per_dim], like="torch"),
) # n_per_dim is constant

dim = int(integration_domain.shape[0])
Expand All @@ -229,7 +228,7 @@ def step3(function_values, hs, integration_domain):
)

# Trace the first step
step1 = torch.jit.trace(step1, (integration_domain,))
step1 = _torch_trace_without_warnings(step1, (integration_domain,))

# Get example input for the third step
grid_points, hs, n_per_dim = step1(integration_domain)
Expand All @@ -241,15 +240,7 @@ def step3(function_values, hs, integration_domain):
)

# Trace the third step
# Avoid the warnings about a .grad attribute access of a
# non-leaf Tensor
if hs.requires_grad:
hs = hs.detach()
hs.requires_grad = True
if function_values.requires_grad:
function_values = function_values.detach()
function_values.requires_grad = True
step3 = torch.jit.trace(
step3 = _torch_trace_without_warnings(
step3, (function_values, hs, integration_domain)
)

Expand Down
21 changes: 11 additions & 10 deletions torchquad/integration/monte_carlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
from loguru import logger

from .base_integrator import BaseIntegrator
from .utils import _setup_integration_domain, expand_func_values_and_squeeze_integral
from .utils import (
_setup_integration_domain,
expand_func_values_and_squeeze_integral,
_torch_trace_without_warnings,
)
from .rng import RNG


Expand Down Expand Up @@ -195,8 +199,6 @@ def compiled_integrate(fn, integration_domain):
elif backend == "torch":
# Torch requires explicit tracing with example inputs.
def do_compile(example_integrand):
import torch

# Define traceable first and third steps
def step1(integration_domain):
return self.calculate_sample_points(
Expand All @@ -206,7 +208,9 @@ def step1(integration_domain):
step3 = self.calculate_result

# Trace the first step (which is non-deterministic)
step1 = torch.jit.trace(step1, (integration_domain,), check_trace=False)
step1 = _torch_trace_without_warnings(
step1, (integration_domain,), check_trace=False
)

# Get example input for the third step
sample_points = step1(integration_domain)
Expand All @@ -215,12 +219,9 @@ def step1(integration_domain):
)

# Trace the third step
if function_values.requires_grad:
# Avoid the warning about a .grad attribute access of a
# non-leaf Tensor
function_values = function_values.detach()
function_values.requires_grad = True
step3 = torch.jit.trace(step3, (function_values, integration_domain))
step3 = _torch_trace_without_warnings(
step3, (function_values, integration_domain)
)

# Define a compiled integrate function
def compiled_integrate(fn, integration_domain):
Expand Down
15 changes: 15 additions & 0 deletions torchquad/integration/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,3 +261,18 @@ def wrap(*args, **kwargs):
return f(*args, **kwargs)

return wrap


def _torch_trace_without_warnings(*args, **kwargs):
"""Execute `torch.jit.trace` on the passed arguments and hide tracer warnings

PyTorch can show warnings about traces being potentially incorrect because
the Python3 control flow is not completely recorded.
This function can be used to hide the warnings in situations where they are
false positives.
"""
import torch

with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
return torch.jit.trace(*args, **kwargs)
Loading