Verified Commit de4bc741 authored by Vladimír Štill's avatar Vladimír Štill
Browse files

CFL: Implement cached version of CYK for batch processing

parent 76284e76
Loading
Loading
Loading
Loading
+49 −4
Original line number Diff line number Diff line
@@ -536,7 +536,7 @@ class CFG:
                right_ce = may_pop(right_words - left_words)

        def try_word(maybe_ce: Optional[CFG.Word], rng: CFGRandom,
                     other: CFG, length: int) -> Optional[CFG.Word]:
                     other: CachedCYK, length: int) -> Optional[CFG.Word]:
            if maybe_ce is not None:
                return maybe_ce

@@ -555,7 +555,7 @@ class CFG:
        if max_cmp_len is None:
            max_cmp_len = min(max(pow(2, len(left.nonterminals) + 1),
                                  pow(2, len(right.nonterminals) + 1)),
                              100)
                              25)
            print(f"max_cmp_len = {max_cmp_len}")

        if full_cmp_len > 0:
@@ -586,12 +586,14 @@ class CFG:
                    return mkres()

        left_rnd = CFGRandom(left)
        left_cyk = CachedCYK(left)
        right_rnd = CFGRandom(right)
        right_cyk = CachedCYK(right)

        for length in range(full_cmp_len + 1, max_cmp_len + 1):
            for _ in range(random_samples):
                left_ce = try_word(left_ce, left_rnd, right, length)
                right_ce = try_word(right_ce, right_rnd, left, length)
                left_ce = try_word(left_ce, left_rnd, right_cyk, length)
                right_ce = try_word(right_ce, right_rnd, left_cyk, length)
                if left_ce is not None and right_ce is not None:
                    return mkres()
            print(f"Tested for length {length}")
@@ -706,3 +708,46 @@ class CFGRandom:

            sentence = random.choices(candidates, weights=weights)[0]
        return typing.cast(CFG.Word, sentence)


class CachedCYK:
    """CYK membership tester that memoizes sub-word results across queries.

    The grammar is converted to Chomsky normal form once; ``self.cache``
    maps each word (tuple of terminals) to the set of nonterminals that
    can derive it, so repeated queries in a batch share work on common
    sub-words.
    """

    def __init__(self, cfg: CFG):
        self.cfg = cfg.cnf()
        # CYK base case: seed the cache with every production whose
        # right-hand side is a single terminal (or empty, for the start
        # symbol's epsilon rule in CNF).
        self.cache: Dict[CFG.Word, Set[Nonterminal]] = dict()
        # Binary (A -> B C) productions are scanned for every split point
        # of every queried word; collect them once here instead of
        # re-filtering the full production list on each recursive call.
        self._binary = []
        for src, dst in self.cfg.productions():
            if len(dst) <= 1:
                dst = typing.cast(CFG.Word, dst)
                self.cache.setdefault(dst, set()).add(src)
            elif len(dst) == 2:
                self._binary.append((src, dst))

    def generates(self, word: Union[str, Iterable[Terminal]]) -> bool:
        """Return True iff the grammar derives *word*.

        A plain string is treated as a sequence of single-character
        terminals; any other iterable is assumed to already yield
        terminals.
        """
        if isinstance(word, str):
            word = tuple(Terminal(x) for x in word)
        else:
            word = tuple(word)
        return self.cfg.init in self._generates(word)

    def _generates(self, word: CFG.Word) -> Set[Nonterminal]:
        """Return the set of nonterminals deriving *word* (memoized)."""
        if word in self.cache:
            return self.cache[word]

        out: Set[Nonterminal] = set()
        # CYK step: for each split, a production A -> B C derives the
        # word when B derives the prefix and C derives the suffix.
        for i in range(1, len(word)):
            alpha_nterms = self._generates(word[:i])
            if not alpha_nterms:
                continue
            beta_nterms = self._generates(word[i:])
            if not beta_nterms:
                continue
            for src, (left, right) in self._binary:
                if left in alpha_nterms and right in beta_nterms:
                    out.add(src)

        self.cache[word] = out
        return out