# Test cases related to the functools.singledispatch decorator
# Only the singledispatchmethod tests below are still marked as xfails (presumably because mypyc
# doesn't support functools.singledispatchmethod yet)
# (Those tests should be re-enabled once mypyc supports them)

[case testSpecializedImplementationUsed]
from functools import singledispatch

# Fallback implementation: used for any argument type without a registered implementation.
@singledispatch
def fun(arg) -> bool:
    return False

# Registered for str through the annotation on the first parameter.
@fun.register
def fun_specialized(arg: str) -> bool:
    return True

def test_specialize() -> None:
    # A str argument should reach fun_specialized; an int should fall back to fun.
    assert fun('a')
    assert not fun(3)

[case testSubclassesOfExpectedTypeUseSpecialized]
from functools import singledispatch
# B derives from A and has no implementation registered for it directly.
class A: pass
class B(A): pass

# Fallback implementation; returns False so the test can tell it apart from the A implementation.
@singledispatch
def fun(arg) -> bool:
    return False

# Only the base class A is registered.
@fun.register
def fun_specialized(arg: A) -> bool:
    return True

def test_specialize() -> None:
    # Both the registered class and its subclass should dispatch to fun_specialized.
    assert fun(B())
    assert fun(A())

[case testSuperclassImplementationNotUsedWhenSubclassHasImplementation]
from functools import singledispatch
# B derives from A; both classes get their own registered implementation below.
class A: pass
class B(A): pass

@singledispatch
def fun(arg) -> bool:
    # shouldn't be using this
    assert False

# Implementation for the base class; returns False so it is distinguishable from the B one.
@fun.register
def fun_specialized(arg: A) -> bool:
    return False

# Implementation for the subclass; it should take priority over the A one for B instances.
@fun.register
def fun_specialized2(arg: B) -> bool:
    return True

def test_specialize() -> None:
    # B() must pick the B implementation (True); A() must pick the A implementation (False).
    assert fun(B())
    assert not fun(A())

[case testMultipleUnderscoreFunctionsIsntError]
from functools import singledispatch

# Fallback implementation for types with nothing registered.
@singledispatch
def fun(arg) -> str:
    return 'default'

# Every registered implementation below reuses the throwaway name `_`; each must still be
# registered separately.
@fun.register
def _(arg: str) -> str:
    return 'str'

@fun.register
def _(arg: int) -> str:
    return 'int'

# extra function to make sure all 3 underscore functions aren't treated as one OverloadedFuncDef
def a(b): pass

@fun.register
def _(arg: list) -> str:
    return 'list'

def test_singledispatch() -> None:
    # Each registered type reaches its own `_`; a dict has no implementation and uses the fallback.
    assert fun(0) == 'int'
    assert fun('a') == 'str'
    assert fun([1, 2]) == 'list'
    assert fun({'a': 'b'}) == 'default'

[case testCanRegisterCompiledClasses]
from functools import singledispatch
# A class defined in the same module as the dispatch function, used as the registered type.
class A: pass

# Fallback implementation for types with nothing registered.
@singledispatch
def fun(arg) -> bool:
    return False
# Registered for A through the parameter annotation.
@fun.register
def fun_specialized(arg: A) -> bool:
    return True

def test_singledispatch() -> None:
    # An A instance uses the registered implementation; an int uses the fallback.
    assert fun(A())
    assert not fun(1)

[case testTypeUsedAsArgumentToRegister]
from functools import singledispatch

# Fallback implementation for types with nothing registered.
@singledispatch
def fun(arg) -> bool:
    return False

# The dispatch type is passed explicitly to register(); the parameter itself is unannotated.
@fun.register(int)
def fun_specialized(arg) -> bool:
    return True

def test_singledispatch() -> None:
    # An int reaches fun_specialized; a str uses the fallback.
    assert fun(1)
    assert not fun('a')

[case testUseRegisterAsAFunction]
from functools import singledispatch

# Fallback implementation for types with nothing registered.
@singledispatch
def fun(arg) -> bool:
    return False

# Plain function that is registered below with the two-argument call form, not as a decorator.
def fun_specialized_impl(arg) -> bool:
    return True

# register(type, function) form: registers fun_specialized_impl for int.
fun.register(int, fun_specialized_impl)

def test_singledispatch() -> None:
    # An int reaches fun_specialized_impl; a str uses the fallback.
    assert fun(0)
    assert not fun('a')

[case testRegisterDoesntChangeFunction]
from functools import singledispatch

# Fallback implementation for types with nothing registered.
@singledispatch
def fun(arg) -> bool:
    return False

# Registered for int only.
@fun.register(int)
def fun_specialized(arg) -> bool:
    return True

def test_singledispatch() -> None:
    # Calling the registered function directly bypasses dispatch: it still returns True for a
    # str even though it was registered only for int.
    assert fun_specialized('a')

# TODO: turn this into a mypy error
[case testNoneIsntATypeWhenUsedAsArgumentToRegister]
from functools import singledispatch

# Fallback implementation for types with nothing registered.
@singledispatch
def fun(arg) -> bool:
    return False

# Registration with a `None` annotation is attempted at module level. A TypeError from the
# registration is tolerated; this case contains no assertions, so it only requires that this
# code runs without raising anything else.
try:
    @fun.register
    def fun_specialized(arg: None) -> bool:
        return True
except TypeError:
    pass

[case testRegisteringTheSameFunctionSeveralTimes]
from functools import singledispatch

# Fallback implementation for types with nothing registered.
@singledispatch
def fun(arg) -> bool:
    return False

# Two stacked register() decorators: the same function is registered for both str and int.
@fun.register(int)
@fun.register(str)
def fun_specialized(arg) -> bool:
    return True

def test_singledispatch() -> None:
    # Both registered types reach fun_specialized; a list uses the fallback.
    assert fun(0)
    assert fun('a')
    assert not fun([1, 2])

[case testTypeIsAnABC]
from functools import singledispatch
from collections.abc import Mapping

# Fallback implementation for types with nothing registered.
@singledispatch
def fun(arg) -> bool:
    return False

# Registered for the abstract base class Mapping rather than a concrete class.
@fun.register
def fun_specialized(arg: Mapping) -> bool:
    return True

def test_singledispatch() -> None:
    # An int is not a Mapping and uses the fallback; a dict should match the Mapping
    # implementation.
    assert not fun(1)
    assert fun({'a': 'b'})

[case testSingleDispatchMethod-xfail]
from functools import singledispatchmethod
# Method-based dispatch using functools.singledispatchmethod.
class A:
    # Fallback method for argument types with nothing registered.
    @singledispatchmethod
    def fun(self, arg) -> str:
        return 'default'

    # Registered for int through the annotation on `arg`.
    @fun.register
    def fun_int(self, arg: int) -> str:
        return 'int'

    # Registered for str through the annotation on `arg`.
    @fun.register
    def fun_str(self, arg: str) -> str:
        return 'str'

def test_singledispatchmethod() -> None:
    # Dispatch is expected to follow the type of the argument passed after self.
    x = A()
    assert x.fun(5) == 'int'
    assert x.fun('a') == 'str'
    assert x.fun([1, 2]) == 'default'

[case testSingleDispatchMethodWithOtherDecorator-xfail]
from functools import singledispatchmethod
# singledispatchmethod stacked on top of staticmethod; the dispatching decorator is outermost.
class A:
    # Fallback static method for argument types with nothing registered.
    @singledispatchmethod
    @staticmethod
    def fun(arg) -> str:
        return 'default'

    # Registered for int; register is applied outside staticmethod here as well.
    @fun.register
    @staticmethod
    def fun_int(arg: int) -> str:
        return 'int'

    # Registered for str.
    @fun.register
    @staticmethod
    def fun_str(arg: str) -> str:
        return 'str'

def test_singledispatchmethod() -> None:
    # The static methods take no self, so dispatch is expected to follow the only argument.
    x = A()
    assert x.fun(5) == 'int'
    assert x.fun('a') == 'str'
    assert x.fun([1, 2]) == 'default'

[case testSingledispatchTreeSumAndEqual]
from functools import singledispatch

# Minimal binary tree: Leaf is an empty terminal, Node carries a value and two subtrees.
class Tree:
    pass
class Leaf(Tree):
    pass
class Node(Tree):
    def __init__(self, value: int, left: Tree, right: Tree) -> None:
        self.value = value
        self.left = left
        self.right = right

# Sum of all Node values in a tree; dispatches on the concrete type of x.
@singledispatch
def calc_sum(x: Tree) -> int:
    # Only reached for a type with no registered implementation.
    raise TypeError('invalid type for x')

@calc_sum.register
def _(x: Leaf) -> int:
    return 0

@calc_sum.register
def _(x: Node) -> int:
    # Recursive calls go back through the dispatcher for each subtree.
    return x.value + calc_sum(x.left) + calc_sum(x.right)

# Structural equality of two trees; dispatches on the type of to_compare only, so each
# implementation inspects `known` itself with isinstance.
@singledispatch
def equal(to_compare: Tree, known: Tree) -> bool:
    # Only reached for a type with no registered implementation. The message names the
    # parameter that is actually dispatched on.
    raise TypeError('invalid type for to_compare')

@equal.register
def _(to_compare: Leaf, known: Tree) -> bool:
    return isinstance(known, Leaf)

@equal.register
def _(to_compare: Node, known: Tree) -> bool:
    # A Node can only equal another Node with the same value and equal subtrees.
    if not isinstance(known, Node):
        return False
    if to_compare.value != known.value:
        return False
    return equal(to_compare.left, known.left) and equal(to_compare.right, known.right)

def build(n: int) -> Tree:
    """Return a complete tree of depth n whose Nodes at height k hold the value k."""
    if n == 0:
        return Leaf()
    return Node(n, build(n - 1), build(n - 1))

# Left without a return annotation: it reaches into `.right` / `.value` on values that build()
# declares as plain Tree, which has no such attributes (presumably deliberate, so that this
# function is not type checked).
def test_sum_and_equal():
    tree = build(5)
    tree2 = build(5)
    # Change one Node's value from 2 to 10, so the sum grows by 8 and the trees differ.
    tree2.right.right.right.value = 10
    assert calc_sum(tree) == 57
    assert calc_sum(tree2) == 65
    assert equal(tree, tree)
    assert not equal(tree, tree2)
    # A tree of a different depth is not equal either.
    tree3 = build(4)
    assert not equal(tree, tree3)

[case testSimulateMypySingledispatch]
from functools import singledispatch
from mypy_extensions import trait
from typing import Iterator, Union, TypeVar, Any, List, Type
# based on use of singledispatch in stubtest.py
# Simple error record carrying only a message.
class Error:
    def __init__(self, msg: str) -> None:
        self.msg = msg

# Small node hierarchy built from traits and concrete classes.
@trait
class Node: pass

class MypyFile(Node): pass
class TypeInfo(Node): pass


@trait
class SymbolNode(Node): pass
@trait
class Expression(Node): pass
# Concrete class inheriting from two traits.
class TypeVarLikeExpr(SymbolNode, Expression): pass
class TypeVarExpr(TypeVarLikeExpr): pass
# No verify() implementation is registered for TypeAlias, so it uses the fallback.
class TypeAlias(SymbolNode): pass

# Sentinel type and instance for a value that is missing.
class Missing: pass
MISSING = Missing()

T = TypeVar("T")

# Generic alias: either a value of type T or the Missing sentinel.
MaybeMissing = Union[T, Missing]

# Fallback implementation: a generator yielding a single 'unknown node type' error.
@singledispatch
def verify(stub: Node, a: MaybeMissing[Any], b: List[str]) -> Iterator[Error]:
    yield Error('unknown node type')

# Registered by passing the class to register(); the first parameter is annotated as well.
@verify.register(MypyFile)
def verify_mypyfile(stub: MypyFile, a: MaybeMissing[int], b: List[str]) -> Iterator[Error]:
    if isinstance(a, Missing):
        yield Error("shouldn't be missing")
        return
    if not isinstance(a, int):
        # this check should be unnecessary because of the type signature and the previous check,
        # but stubtest.py has this check
        yield Error("should be an int")
        return
    # Delegate back through the dispatcher with a different node type.
    yield from verify(TypeInfo(), str, ['abc', 'def'])

@verify.register(TypeInfo)
def verify_typeinfo(stub: TypeInfo, a: MaybeMissing[Type[Any]], b: List[str]) -> Iterator[Error]:
    yield Error('in TypeInfo')
    yield Error('hello')

@verify.register(TypeVarExpr)
def verify_typevarexpr(stub: TypeVarExpr, a: MaybeMissing[Any], b: List[str]) -> Iterator[Error]:
    # The never-executed yield makes this a generator function that produces no items.
    if False:
        yield None

def verify_list(stub, a, b) -> List[str]:
    """Run verify() and collect the messages of the errors it yields.

    The arguments are forwarded to verify() unchanged; the result is the `msg` attribute of
    each yielded error, in the order they were yielded.
    """
    return [err.msg for err in verify(stub, a, b)]

def test_verify() -> None:
    # TypeAlias has no registered implementation, so the fallback is used.
    assert verify_list(TypeAlias(), 'a', ['a', 'b']) == ['unknown node type']
    # MypyFile implementation: the Missing sentinel produces an error and stops early.
    assert verify_list(MypyFile(), MISSING, ['a', 'b']) == ["shouldn't be missing"]
    # MypyFile implementation with an int delegates to the TypeInfo implementation.
    assert verify_list(MypyFile(), 5, ['a', 'b']) == ['in TypeInfo', 'hello']
    assert verify_list(TypeInfo(), str, ['a', 'b']) == ['in TypeInfo', 'hello']
    # The TypeVarExpr implementation yields nothing.
    assert verify_list(TypeVarExpr(), 'a', ['x', 'y']) == []


[case testArgsInRegisteredImplNamedDifferentlyFromMainFunction]
from functools import singledispatch

# Fallback implementation; its parameter is named `a`.
@singledispatch
def f(a) -> bool:
    return False

# Registered implementation whose parameter is named `b`, not `a`.
@f.register
def g(b: int) -> bool:
    return True

def test_singledispatch():
    # Positional calls should dispatch correctly despite the differing parameter names.
    assert f(5)
    assert not f('a')

[case testKeywordArguments]
from functools import singledispatch

# Fallback implementation with a keyword-only argument defaulting to 0.
@singledispatch
def f(arg, *, kwarg: int = 0) -> int:
    return kwarg + 10

# int implementation with the same keyword-only argument but a different default (5).
@f.register
def g(arg: int, *, kwarg: int = 5) -> int:
    return kwarg - 10

def test_keywords():
    # Fallback: kwarg defaults to 0, explicit values are forwarded.
    assert f('a') == 10
    assert f('a', kwarg=3) == 13
    assert f('a', kwarg=7) == 17

    # int implementation: its own default (5) applies, explicit values are forwarded.
    assert f(1) == -5
    assert f(1, kwarg=4) == -6
    assert f(1, kwarg=6) == -4

[case testGeneratorAndMultipleTypesOfIterable]
from functools import singledispatch
from typing import *

# Fallback implementation is a generator function.
@singledispatch
def f(arg: Any) -> Iterable[int]:
    yield 1

# The str implementation returns a list instead, under the same declared return type.
@f.register
def g(arg: str) -> Iterable[int]:
    return [0]

def test_iterables():
    # The fallback returns a generator object, which is not equal to a list until consumed.
    assert f(1) != [1]
    assert list(f(1)) == [1]
    # The str implementation returns the list itself.
    assert f('a') == [0]

[case testRegisterUsedAtSameTimeAsOtherDecorators]
from functools import singledispatch
from typing import TypeVar

# NOTE: A and B are not referenced anywhere else in this case.
class A: pass
class B: pass

T = TypeVar('T')

# Identity decorator: returns the function it receives unchanged.
def decorator(f: T) -> T:
    return f

# Fallback implementation for types with nothing registered.
@singledispatch
def f(arg) -> int:
    return 0

# register is stacked on top of another decorator; `decorator` is applied first.
@f.register
@decorator
def h(arg: str) -> int:
    return 2

def test_singledispatch():
    # An int uses the fallback; a str reaches h.
    assert f(1) == 0
    assert f('a') == 2

[case testDecoratorModifiesFunction]
from functools import singledispatch
from typing import Callable, Any

# Used in the test as an argument type with no registered implementation.
class A: pass

# Decorator that replaces the function with a wrapper multiplying its result by 7.
def decorator(f: Callable[[Any], int]) -> Callable[[Any], int]:
    def wrapper(x) -> int:
        return f(x) * 7
    return wrapper

# Fallback implementation for types with nothing registered.
@singledispatch
def f(arg) -> int:
    return 10

# `decorator` is applied first, so the wrapper it returns is what gets registered.
# NOTE(review): the wrapper has no `str` annotation of its own; the expected result below
# relies on the dispatch type being taken from h's `arg: str` annotation - confirm.
@f.register
@decorator
def h(arg: str) -> int:
    return 5


def test_singledispatch():
    # A str goes through the wrapper around h: 5 * 7 == 35. An A instance uses the fallback.
    assert f('a') == 35
    assert f(A()) == 10

[case testMoreSpecificTypeBeforeLessSpecificType]
from functools import singledispatch
# B derives from A.
class A: pass
class B(A): pass

# Fallback implementation for types with nothing registered.
@singledispatch
def f(arg) -> str:
    return 'default'

# The subclass implementation is registered first...
@f.register
def g(arg: B) -> str:
    return 'b'

# ...and the base class implementation second; this must not shadow the B implementation.
@f.register
def h(arg: A) -> str:
    return 'a'

def test_singledispatch():
    # Each class reaches its own implementation regardless of registration order.
    assert f(B()) == 'b'
    assert f(A()) == 'a'
    assert f(5) == 'default'

[case testMultipleRelatedClassesBeingRegistered]
from functools import singledispatch

# Three-level inheritance chain: C -> B -> A.
class A: pass
class B(A): pass
class C(B): pass

# Fallback implementation for types with nothing registered.
@singledispatch
def f(arg) -> str: return 'default'

# Implementations are registered in the order A, C, B, which differs from the inheritance order.
@f.register
def _(arg: A) -> str: return 'a'

@f.register
def _(arg: C) -> str: return 'c'

@f.register
def _(arg: B) -> str: return 'b'

def test_singledispatch():
    # Every class reaches the implementation registered for exactly that class.
    assert f(A()) == 'a'
    assert f(B()) == 'b'
    assert f(C()) == 'c'
    assert f(1) == 'default'

[case testRegisteredImplementationsInDifferentFiles]
from other_a import f, A, B, C
# f and the classes come from other_a, which also registers an implementation for B. The
# implementations for A and C are added here, in a different module.
@f.register
def a(arg: A) -> int:
    return 2

@f.register
def _(arg: C) -> int:
    return 3

def test_singledispatch():
    # B uses the implementation registered in other_a; A and C use the ones registered above.
    assert f(B()) == 1
    assert f(A()) == 2
    assert f(C()) == 3
    # No implementation for int: fallback defined in other_a.
    assert f(1) == 0

[file other_a.py]
from functools import singledispatch

# Three-level inheritance chain: C -> B -> A.
class A: pass
class B(A): pass
class C(B): pass

# Dispatch function that the main module imports and extends with more implementations.
@singledispatch
def f(arg) -> int:
    return 0

# The only implementation registered in this module (for B).
@f.register
def g(arg: B) -> int:
    return 1

[case testOrderCanOnlyBeDeterminedFromMRONotIsinstanceChecks]
from mypy_extensions import trait
from functools import singledispatch

# Two unrelated traits, and two classes inheriting from both in opposite base order.
@trait
class A: pass
@trait
class B: pass
# AB lists A first; BA lists B first. Instances of either class are instances of both traits.
class AB(A, B): pass
class BA(B, A): pass

# Fallback implementation for arguments with no registered implementation.
@singledispatch
def f(arg) -> str:
    return "default"

# One implementation per trait.
@f.register
def fa(arg: A) -> str:
    return "a"

@f.register
def fb(arg: B) -> str:
    return "b"

def test_singledispatch():
    # Both instances match both traits, so the expected result follows the order of the bases:
    # AB lists A first, BA lists B first.
    assert f(AB()) == "a"
    assert f(BA()) == "b"

[case testCallingFunctionBeforeAllImplementationsRegistered]
from functools import singledispatch

# B derives from A.
class A: pass
class B(A): pass

# Fallback implementation for types with nothing registered.
@singledispatch
def f(arg) -> str:
    return 'default'

# Module-level checks are interleaved with registrations, so f is called before all of its
# implementations exist. Nothing is registered yet: everything uses the fallback.
assert f(A()) == 'default'
assert f(B()) == 'default'
assert f(1) == 'default'

@f.register
def g(arg: A) -> str:
    return 'a'

# After registering A: A and its subclass B both use g.
assert f(A()) == 'a'
assert f(B()) == 'a'
assert f(1) == 'default'

@f.register
def _(arg: B) -> str:
    return 'b'

# After registering B: B switches to its own implementation; A is unaffected.
assert f(A()) == 'a'
assert f(B()) == 'b'
assert f(1) == 'default'


[case testDynamicallyRegisteringFunctionFromInterpretedCode]
from functools import singledispatch

# Four-level inheritance chain: D -> C -> B -> A. D is not referenced by the other files of
# this case.
class A: pass
class B(A): pass
class C(B): pass
class D(C): pass

# Fallback implementation for types with nothing registered.
@singledispatch
def f(arg) -> str:
    return "default"

# The only implementation registered in this module (for B); the A and C implementations are
# added from register_impl.py.
@f.register
def _(arg: B) -> str:
    return 'b'

[file register_impl.py]
from native import f, A, B, C

# Registers additional implementations on f from outside the module that defines it, using
# both registration forms: an explicit type argument...
@f.register(A)
def a(arg) -> str:
    return 'a'

# ...and a parameter annotation.
@f.register
def c(arg: C) -> str:
    return 'c'

[file driver.py]
from native import f, A, B, C
from register_impl import a, c
# We need a custom driver here because register_impl has to be run before we test this (so that the
# additional implementations are registered)
# Each class should reach the implementation registered for exactly that class.
assert f(C()) == 'c'
assert f(A()) == 'a'
assert f(B()) == 'b'
# The registered functions remain directly callable and ignore the argument's type when
# called without going through f.
assert a(C()) == 'a'
assert c(A()) == 'c'

[case testMalformedDynamicRegisterCall]
from functools import singledispatch

# Dispatch function with no registered implementations; register.py attempts an invalid
# registration on it.
@singledispatch
def f(arg) -> None:
    pass
[file register.py]
from native import f
from testutil import assertRaises

# Registering a function with no parameters (and so no annotation to take a type from) is
# expected to raise TypeError with this message.
with assertRaises(TypeError, 'Invalid first argument to `register()`'):
    @f.register
    def _():
        pass

[file driver.py]
import register

[case testCacheClearedWhenNewFunctionRegistered]
from functools import singledispatch

# Dispatch function with only a fallback; all other implementations are registered from
# register.py.
@singledispatch
def f(arg) -> str:
    return 'default'

[file register.py]
from native import f
# Three unrelated classes, one per registration form exercised below.
class A: pass
class B: pass
class C: pass

# For each form, f is called with the type before registering (so a fallback result has
# already been produced for it, and possibly cached) and again afterwards; the second call
# must see the new implementation.

# annotated function
assert f(A()) == 'default'
@f.register
def _(arg: A) -> str:
    return 'a'
assert f(A()) == 'a'

# type passed as argument
assert f(B()) == 'default'
@f.register(B)
def _(arg: B) -> str:
    return 'b'
assert f(B()) == 'b'

# 2 argument form
assert f(C()) == 'default'
def c(arg) -> str:
    return 'c'
f.register(C, c)
assert f(C()) == 'c'


[file driver.py]
import register
