torch.tensor type annotation does not match mypy #545

Closed
willfrey opened this issue Oct 30, 2020 · 10 comments
Labels
bug Something isn't working fixed in next version (main) A fix has been implemented and will appear in an upcoming version

Comments

@willfrey

willfrey commented Oct 30, 2020

Environment data

  • Language Server version: 2020.10.3
  • OS and version: macOS Catalina Version 10.15.7
  • Python version (& distribution if applicable, e.g. Anaconda): 3.6.11

Expected behaviour

Consider the following:

from typing import TYPE_CHECKING

import torch

if TYPE_CHECKING:
    reveal_type(torch.tensor)

I expect Pylance to infer that the torch.tensor symbol refers only to the torch.tensor(...) callable, as mypy does, and not to the union of the callable and the torch/tensor.py module.

Here's what I think it should be.

mypy

Revealed type is 'def (data: Any, dtype: Union[torch._C.dtype, None] =, device: Union[torch._C.device, builtins.str, None] =, requires_grad: builtins.bool =) -> torch.tensor.Tensor'

Pylance

Type of "torch.tensor" is "(data: Any, dtype: dtype | None = None, device: device | str | None = None, requires_grad: _bool = False) -> Tensor"

Actual behaviour

Here's what the inferred types are right now.

mypy

Revealed type is 'def (data: Any, dtype: Union[torch._C.dtype, None] =, device: Union[torch._C.device, builtins.str, None] =, requires_grad: builtins.bool =) -> torch.tensor.Tensor'

Pylance

Type of "torch.tensor" is "Module(".tensor") | (data: Any, dtype: dtype | None = None, device: device | str | None = None, requires_grad: _bool = False) -> Tensor"

Logs

If logs are helpful here, I can include them. I do not think they're super helpful here, though.

Code Snippet / Additional information

I know it's not ideal that there's a name conflict in the PyTorch codebase between the torch.tensor(...) callable and the torch.tensor module, but mypy appears to get the right idea, so there's a way to deduce this somehow.

I can get Pylance to narrow the type definition with an assert statement, as in the following.

from typing import TYPE_CHECKING

import torch

assert callable(torch.tensor)

if TYPE_CHECKING:
    reveal_type(torch.tensor)  # Pylance will match mypy

I can also add a # type: ignore comment at every call site to silence Pylance, but then the return type of torch.tensor(...) is reported as Unknown instead of torch.Tensor, which is very useful to have inferred.

Thank you!

@jakebailey
Member

What version of pytorch are you using? Is it the stable version or the release candidate? There are a couple of odd typing issues with both (#418, #484) that could be in play here.

I believe this would involve the code that does the assignments for module exports; I recall previous bugs in this area (but can't find them). Perhaps this is happening because tensor is imported, then overwritten with a type? The code flow should have picked the last one, I think.

@erictraut Do you have any thoughts?

@jakebailey jakebailey added needs investigation Could be an issue - needs investigation waiting for user response Requires more information from user labels Nov 2, 2020
@erictraut
Contributor

I'd need to understand what version of pytorch is being used. I'm not seeing this with the version I'm using.

Based on the symptoms, it appears that the symbol tensor is being imported first as a module and then (probably later in the file) redefined as a function. Pyright already has special-case logic to eliminate the module assignment if it comes from a from .tensor import X statement, since many users don't realize that tensor is implicitly bound to the submodule in this case.
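The implicit binding mentioned above can be reproduced at runtime without PyTorch. The following is a minimal sketch, not PyTorch's actual layout: it writes a throwaway package to a temporary directory (the names implicit_binding_demo, build_demo_package and load_demo_package are hypothetical, chosen only for illustration) whose __init__.py contains nothing but from .tensor import Tensor, then checks that the package also ends up with a tensor attribute bound to the submodule.

```python
import importlib
import sys
import tempfile
import types
from pathlib import Path

# Hypothetical package name, used only for this illustration.
PACKAGE_NAME = "implicit_binding_demo"


def build_demo_package(root: Path, name: str) -> None:
    """Write a tiny package whose __init__ only imports a class from a submodule."""
    pkg_dir = root / name
    pkg_dir.mkdir()
    (pkg_dir / "tensor.py").write_text("class Tensor:\n    pass\n")
    (pkg_dir / "__init__.py").write_text("from .tensor import Tensor\n")


def load_demo_package(name: str) -> types.ModuleType:
    """Create the package in a temporary directory and import it."""
    with tempfile.TemporaryDirectory() as tmp:
        build_demo_package(Path(tmp), name)
        sys.path.insert(0, tmp)
        try:
            importlib.invalidate_caches()
            return importlib.import_module(name)
        finally:
            sys.path.remove(tmp)


demo_pkg = load_demo_package(PACKAGE_NAME)

# __init__.py never assigns `tensor` explicitly, yet the import system
# binds the loaded submodule as an attribute of its parent package.
print(type(demo_pkg.tensor))
print(demo_pkg.Tensor is demo_pkg.tensor.Tensor)
```

A type checker that models this faithfully sees two names introduced by one import statement: the explicitly imported Tensor and the implicitly bound tensor submodule.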

@willfrey
Author

willfrey commented Nov 2, 2020

I'm on 1.7.0, which is the latest stable version available on PyPI.

@jakebailey
Member

jakebailey commented Nov 2, 2020

Thanks; I hadn't seen that 1.7 was out yet.

Reading the __init__.py, this line brings the tensor function in via a * import inside a typing.TYPE_CHECKING block, then later it does from .tensor import Tensor.

See #545 (comment)

@jakebailey jakebailey removed the waiting for user response Requires more information from user label Nov 2, 2020
@erictraut
Contributor

Yeah, that would explain it if the original definition of tensor is getting overwritten by the module later in the file. I can't think of any clean way to work around that without affecting compatibility and correctness with other libraries. The correct fix is to swap those declarations in __init__.py.

@jakebailey
Member

My analysis was wrong; I think I copied the wrong line numbers or had some other brainfart.

Line 412 does from .tensor import Tensor.

Line 513 does from torch._C._VariableFunctions import * inside an if TYPE_CHECKING: block.

I'm wondering if there's an issue with the if block specifically.

@erictraut
Contributor

I dug into this and found that we were handling the following cases inconsistently:

from typing import TYPE_CHECKING

import torch
from torch import tensor

if TYPE_CHECKING:
    reveal_type(torch.tensor)  # function | module
    reveal_type(tensor)  # function

In the first case, the type evaluator was considering all declarations in the target module and unioning their types. In the second case, the type evaluator was using only the last declaration found within the file.

I've updated the logic to consistently use only the last declaration, which produces the desired behavior. This change will be in the next release. In the meantime, you can use the from torch import tensor form as a workaround if that suits you.
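The "last declaration wins" rule matches what Python does at runtime when a name is bound twice in a package's __init__.py. Here is a rough sketch, assuming a simplified layout loosely modeled on the one discussed in this thread: a from .tensor import Tensor statement followed later by a star import that brings in a tensor function. The package name last_binding_demo, the _functions module and the helper load_package are hypothetical. Unlike the PyTorch file discussed above, the star import here is unconditional rather than inside an if TYPE_CHECKING: block, so that it actually executes.

```python
import importlib
import sys
import tempfile
import types
from pathlib import Path

# Hypothetical package contents, used only for this illustration.
PACKAGE_NAME = "last_binding_demo"
FILES = {
    "tensor.py": "class Tensor:\n    pass\n",
    "_functions.py": "def tensor(data):\n    return list(data)\n",
    "__init__.py": (
        # First binding: `tensor` is implicitly bound to the submodule.
        "from .tensor import Tensor\n"
        # Second binding: the star import rebinds `tensor` to a function.
        "from ._functions import *\n"
    ),
}


def load_package(name: str, files: dict) -> types.ModuleType:
    """Write the package to a temporary directory and import it."""
    with tempfile.TemporaryDirectory() as tmp:
        pkg_dir = Path(tmp) / name
        pkg_dir.mkdir()
        for filename, source in files.items():
            (pkg_dir / filename).write_text(source)
        sys.path.insert(0, tmp)
        try:
            importlib.invalidate_caches()
            return importlib.import_module(name)
        finally:
            sys.path.remove(tmp)


pkg = load_package(PACKAGE_NAME, FILES)

# The later binding replaced the submodule on the package namespace...
print(isinstance(pkg.tensor, types.ModuleType))
print(callable(pkg.tensor))

# ...while the submodule itself is still reachable through sys.modules.
print(isinstance(sys.modules[PACKAGE_NAME + ".tensor"], types.ModuleType))
```

Reporting a union of both declarations, as the torch.tensor case did before the fix, describes a value that attribute access on the package never yields once __init__.py has finished executing.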

@erictraut erictraut added bug Something isn't working fixed in next version (main) A fix has been implemented and will appear in an upcoming version and removed needs investigation Could be an issue - needs investigation labels Nov 3, 2020
@willfrey
Author

willfrey commented Nov 3, 2020

Thank you for taking care of this!

@jakebailey
Member

This issue has been fixed in version 2020.11.0, which we've just released. You can find the changelog here: https://github.com/microsoft/pylance-release/blob/master/CHANGELOG.md#2020110-4-november-2020

@sorenmc

sorenmc commented Dec 11, 2021

For anyone experiencing this problem with torch tensors, a way for mypy to pass is to type annotate as

my_tensor: torch.Tensor = func_returning_tensor([1,2,3])

Notice that it is Tensor with a capital T (the class), not the lowercase tensor (the factory function).
