Commit c912c849 authored by Filip Kučerák's avatar Filip Kučerák
Browse files

typing repairs

parent 16bc0e74
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
import sys
sys.path.insert(0, '../..')

from kdtesting import *
from kdtesting import Command, CommandSeq, CommandGroup, ContextRefresher, \
                    ArbitraryContextProvider, VariableA, ConstantA, NumberA

context_provider = ArbitraryContextProvider({
    "number": NumberA(1, 10)
+15 −16
Original line number Diff line number Diff line
@@ -35,6 +35,7 @@ def header(string: str, lines: bool = False) -> str:


Generatable = Any
UnknownType = Any
ValueGenerator = Callable[[], Generatable]


@@ -568,7 +569,8 @@ class Tester:
    def run_test(self, test: Test[Generatable]) -> bool:
        assert False, "run_test not implemented"

    def log(self, priority: int, *args, **kwargs) -> None:
    def log(self, priority: int, *args: UnknownType, **kwargs: UnknownType) \
            -> None:
        if priority <= self.verbosity:
            print(*args, **kwargs)

@@ -671,7 +673,7 @@ class Tester:


class ContextProvider:
    def __init__(self):
    def __init__(self) -> None:
        pass

    def get(self, symbol: str) -> Generatable:
@@ -687,7 +689,7 @@ class ContextProvider:

class ArbitraryContextProvider(ContextProvider):

    def __init__(self, arbitraries: Dict[str, Arbitrary]):
    def __init__(self, arbitraries: Dict[str, Arbitrary[Generatable]]):
        super().__init__()
        self.arbitraries = arbitraries
        self.values: Dict[str, Generatable] = {}
@@ -700,12 +702,14 @@ class ArbitraryContextProvider(ContextProvider):
    def get(self, symbol: str) -> Generatable:
        return self.values[symbol]

    def shrink(self, symbol: str, value: Generatable) -> Generator[Generatable, None, None]:
    def shrink(self, symbol: str, value: Generatable) -> \
            Generator[Generatable, None, None]:
        yield self.arbitraries[symbol].shrink(value)

    def to_str(self, symbol: str, value: Generatable) -> str:
        return self.arbitraries[symbol].to_str(value)


class ContextRefresher(ArbitraryWrapper[T]):

    def __init__(self, arbitrary: Arbitrary[T],
@@ -808,7 +812,8 @@ class CommandSeq(TupleA):
        self.str_del = '\n'
        self.str_end = ''

    def get_symbol(self, symbol: str, values: Tuple[Generatable]) -> List[Generatable]:
    def get_symbol(self, symbol: str, values: Tuple[Generatable]) \
            -> List[Generatable]:
        res = []
        for i, val in enumerate(values):
            res += self.commands[i].get_symbol(symbol, val)
@@ -849,11 +854,6 @@ class CommandSequence(ListA[Tuple[int, Generatable]]):
        self.commands = commands
        self.start_str: CommandAppend = lambda s, x: ""
        self.end_str: CommandAppend = lambda s, x: ""
        self.context_provider = context_provider

    def get(self) -> List[Tuple[int, Generatable]]:
        self.context_provider.refresh_values()
        return super().get()

    def to_str(self, values: List[Tuple[int, Generatable]]) -> str:
        body = ""
@@ -884,7 +884,7 @@ def str_id(x: str) -> str:
    return x


ProcessOutput = Any
ProcessOutput = UnknownType
ProcessOutputTest = Callable[[ProcessOutput, ProcessOutput], bool]


@@ -917,11 +917,11 @@ def stdout_equality() -> ProcessOutputTest:
        ProcessOutputTest: test that compares the raw outputs
    """
    def res_tester(test_out: ProcessOutput, exp_out: ProcessOutput) -> bool:
        return test_out.stdout == exp_out.stdout
        return bool(test_out.stdout == exp_out.stdout)
    return res_tester


OutputTest = Callable[[Any], bool]
OutputTest = Callable[[ProcessOutput], bool]


class OutputTester(Tester):
@@ -1042,9 +1042,8 @@ class RefImplTester(Tester):
_VALGRIND_PROBLEM = 11


def valgrind_no_leaks(output: Any) -> bool:

    return output.returncode != _VALGRIND_PROBLEM
def valgrind_no_leaks(output: ProcessOutput) -> bool:
    return bool(output.returncode != _VALGRIND_PROBLEM)


def create_valgrind_tester(program: str,