python/mypy

Mypy is unable to detect if a builtin class supports a generic protocol. Unexpected behavior with custom classes too.

AgarwalPragy opened this issue

Bug Report

Mypy is unable to detect if a class supports a generic protocol.

To Reproduce

Create a Wrapper[T] that supports arithmetic on its wrapped values; that is:

if
    A + B -> C
then
    Wrapper[A] + Wrapper[B] -> Wrapper[C]
    Wrapper[A] + B -> Wrapper[C]
    A + Wrapper[B] -> Wrapper[C]
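Each of these rules rests on Python's own dispatch for A + B. The interpreter first calls A.__add__(B). If that method is missing or returns NotImplemented, and the two operand types differ, it falls back to B.__radd__(A). Below is a simplified model of that dispatch, written for illustration only. binary_add is a hypothetical helper, not a Python or mypy API, and it omits the subclass-priority rule noted in its docstring.

```python
from typing import Any


def binary_add(a: Any, b: Any) -> Any:
    """Simplified model of how Python evaluates a + b (hypothetical helper).

    Omits one rule of the real protocol: when type(b) is a proper subclass
    of type(a) and provides its own __radd__, Python tries that first.
    """
    add = getattr(type(a), "__add__", None)
    if add is not None:
        result = add(a, b)
        if result is not NotImplemented:
            return result
    # The reflected method is only consulted when the operand types differ.
    radd = getattr(type(b), "__radd__", None)
    if radd is not None and type(a) is not type(b):
        result = radd(b, a)
        if result is not NotImplemented:
            return result
    raise TypeError(
        f"unsupported operand type(s) for +: "
        f"{type(a).__name__!r} and {type(b).__name__!r}"
    )


print(binary_add(1, 2))    # 3, via int.__add__
print(binary_add(10, 1j))  # (10+1j): int.__add__ returns NotImplemented, complex.__radd__ answers
print(binary_add(1j, 10))  # (10+1j), via complex.__add__
```

In these terms, Wrapper[A] + B may be resolved either through A.__add__ or through B.__radd__, depending on which of the two methods the operand types provide.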

https://mypy-play.net/?mypy=latest&python=3.12&flags=strict&gist=f78c1a5896a21c3c76a9f4380a02ae69

from __future__ import annotations

from typing import TypeVar, Protocol, Generic
from typing_extensions import reveal_type


_T_contra = TypeVar('_T_contra', contravariant=True)
_T_co = TypeVar('_T_co', covariant=True)


class SupportsAdd(Protocol[_T_contra, _T_co]):
    def __add__(self, other: _T_contra, /) -> _T_co: ...

class SupportsRAdd(Protocol[_T_contra, _T_co]):
    def __radd__(self, other: _T_contra, /) -> _T_co: ...


_T = TypeVar('_T', covariant=True)
_T_left = TypeVar('_T_left', covariant=True)
_T_right = TypeVar('_T_right', covariant=True)
_T_out = TypeVar('_T_out', covariant=True)


class Wrapper(Generic[_T]):
    val: _T

    def __init__(self, val: _T):
        self.val = val

    def __add__(self: Wrapper[SupportsAdd[_T_right, _T_out]], right: _T_right | Wrapper[_T_right]) -> Wrapper[_T_out]:
        if isinstance(right, Wrapper):
            return Wrapper(self.val + right.val)
        return Wrapper(self.val + right)
    
    def __radd__(self: Wrapper[SupportsRAdd[_T_left, _T_out]], left: _T_left | Wrapper[_T_left]) -> Wrapper[_T_out]:
        if isinstance(left, Wrapper):
            return Wrapper(left.val + self.val)
        return Wrapper(left + self.val)

    def __repr__(self) -> str:
        return f'Wrapper({self.val})'




class Snake:
    # Snake + Legs = Lizard
    def __add__(self, other: Legs) -> Lizard:
        if not isinstance(other, Legs):
            return NotImplemented
        return Lizard()
    
    # Legs + Snake = Dragon
    def __radd__(self, other: Legs) -> Dragon:
        if not isinstance(other, Legs):
            return NotImplemented
        return Dragon()

class Legs: ...
class Lizard: ...
class Dragon: ...



# What works as expected
reveal_type(Wrapper(10) + Wrapper(1j))  # Wrapper[complex]
reveal_type(10 + Wrapper(1j))  # Wrapper[complex]
reveal_type(Wrapper(1j) + Wrapper(10))  # Wrapper[complex]
reveal_type(Wrapper(1j) + 10)  # Wrapper[complex]
reveal_type(Wrapper(10.0) + Wrapper(2))   # Wrapper[float]
reveal_type(Wrapper(10.0) + 2)   # Wrapper[float]
reveal_type(Snake() + Legs())  # Lizard
reveal_type(Legs() + Snake())  # Dragon
reveal_type(Wrapper(Snake()) + Wrapper(Legs()))  # Wrapper[Lizard]
reveal_type(Wrapper(Snake()) + Legs())  # Wrapper[Lizard]
reveal_type(Wrapper(Legs()) + Wrapper(Snake()))  # Wrapper[Dragon]
reveal_type(Legs() + Wrapper(Snake()))  # Wrapper[Dragon]
Wrapper(Snake()) + Wrapper(Snake())  # correctly gives error: Argument 1 to "Wrapper" has incompatible type "Snake"; expected "Legs"  [arg-type]
Wrapper(10.0) + Wrapper('haha')  # correctly gives error: Argument 1 to "Wrapper" has incompatible type "str"; expected "float"  [arg-type]


# What doesn't work as expected
reveal_type(Wrapper(10) + 1j)  # unexpected error: Unsupported operand types for + ("Wrapper[int]" and "complex")  [operator]
reveal_type(1j + Wrapper(10))  # unexpected error: Unsupported operand types for + ("complex" and "Wrapper[int]")  [operator]
reveal_type(Snake() + Wrapper(Legs()))  # unexpected error: Unsupported operand types for + ("Snake" and "Wrapper[Legs]")  [operator]
reveal_type(Wrapper(Legs()) + Snake())  # unexpected error: Unsupported operand types for + ("Wrapper[Legs]" and "Snake")  [operator]

Expected Behavior

Wrapper[int] + complex => Wrapper[complex]
complex + Wrapper[int] => Wrapper[complex]
Snake + Wrapper[Legs] => Wrapper[Lizard]
Wrapper[Legs] + Snake => Wrapper[Dragon]
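At runtime, Python produces exactly the four expected results above. The sketch below checks this with a trimmed copy of the reproducer. The protocol-typed self annotations are dropped, but the method bodies are unchanged, so it demonstrates runtime behavior only and says nothing about what mypy infers.

```python
from __future__ import annotations

from typing import Any


class Wrapper:
    # Trimmed copy of the reproducer's Wrapper: same method bodies,
    # protocol-typed self annotations removed.
    def __init__(self, val: Any) -> None:
        self.val = val

    def __add__(self, right: Any) -> Wrapper:
        if isinstance(right, Wrapper):
            return Wrapper(self.val + right.val)
        return Wrapper(self.val + right)

    def __radd__(self, left: Any) -> Wrapper:
        if isinstance(left, Wrapper):
            return Wrapper(left.val + self.val)
        return Wrapper(left + self.val)

    def __repr__(self) -> str:
        return f"Wrapper({self.val})"


class Legs: ...
class Lizard: ...
class Dragon: ...


class Snake:
    def __add__(self, other: Legs) -> Lizard:  # Snake + Legs = Lizard
        if not isinstance(other, Legs):
            return NotImplemented
        return Lizard()

    def __radd__(self, other: Legs) -> Dragon:  # Legs + Snake = Dragon
        if not isinstance(other, Legs):
            return NotImplemented
        return Dragon()


# The four cases mypy rejects in the reproducer.
results = {
    "Wrapper[int] + complex": Wrapper(10) + 1j,
    "complex + Wrapper[int]": 1j + Wrapper(10),
    "Snake + Wrapper[Legs]": Snake() + Wrapper(Legs()),
    "Wrapper[Legs] + Snake": Wrapper(Legs()) + Snake(),
}

# Prints one line per case, e.g. "Snake + Wrapper[Legs] -> Wrapper[Lizard]".
for label, result in results.items():
    print(f"{label} -> Wrapper[{type(result.val).__name__}]")
```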

Actual Behavior

error: Unsupported operand types for + ("Wrapper[int]" and "complex")  [operator]
error: Unsupported operand types for + ("complex" and "Wrapper[int]")  [operator]
error: Unsupported operand types for + ("Snake" and "Wrapper[Legs]")  [operator]
error: Unsupported operand types for + ("Wrapper[Legs]" and "Snake")  [operator]
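One possible explanation comes from the signatures in the reproducer and the typeshed stubs, not from mypy's internals. Wrapper.__add__ only type-checks when the wrapped value's own __add__ accepts the right operand. Wrapper.__radd__ only type-checks when the wrapped value's own __radd__ accepts the left operand. In all four failing cases, the inner operation is resolved by the other operand's method instead, as these runtime checks show. The classes are trimmed copies of the reproducer's.

```python
from __future__ import annotations


# Trimmed copies of the reproducer's classes.
class Legs: ...  # defines neither __add__ nor __radd__
class Lizard: ...
class Dragon: ...


class Snake:
    def __add__(self, other: Legs) -> Lizard:  # Snake + Legs = Lizard
        if not isinstance(other, Legs):
            return NotImplemented
        return Lizard()

    def __radd__(self, other: Legs) -> Dragon:  # Legs + Snake = Dragon
        if not isinstance(other, Legs):
            return NotImplemented
        return Dragon()


# Wrapper(10) + 1j needs 10 + 1j: int.__add__ declines, complex.__radd__ answers.
assert (10).__add__(1j) is NotImplemented
assert (1j).__radd__(10) == 10 + 1j

# 1j + Wrapper(10) needs 1j + 10: complex.__add__ answers,
# while int.__radd__ declines a complex left operand.
assert (1j).__add__(10) == 10 + 1j
assert (10).__radd__(1j) is NotImplemented

# Snake() + Wrapper(Legs()) needs Snake() + Legs(): Legs has no __radd__,
# so only Snake.__add__ can answer.
assert not hasattr(Legs(), "__radd__")
assert isinstance(Snake() + Legs(), Lizard)

# Wrapper(Legs()) + Snake() needs Legs() + Snake(): Legs has no __add__,
# so only Snake.__radd__ can answer.
assert not hasattr(Legs(), "__add__")
assert isinstance(Legs() + Snake(), Dragon)
```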

Is there any workaround by any chance?

I'm trying to create a library for symbolic computation, and I need to handle cases like Expression[Snake] + Expression[Legs] => Expression[Lizard] whenever Snake + Legs => Lizard.