Verified Commit 6b1b2652 authored by Vladimír Štill's avatar Vladimír Štill
Browse files

CFG: Add a has_simple_rules function

parent 5cee58cd
Loading
Loading
Loading
Loading
+18 −5
Original line number Original line Diff line number Diff line
@@ -326,14 +326,19 @@ class CFG:
                       new_rules, new_init)
                       new_rules, new_init)
        return CFG(self.nonterminals, self.terminals, new_rules, self.init)
        return CFG(self.nonterminals, self.terminals, new_rules, self.init)


    def remove_simple_rules(self) -> CFG:
    def get_simple_to(self) -> Dict[Nonterminal, Set[Nonterminal]]:
        simple_to: Dict[Nonterminal, Set[Nonterminal]] =\
        """
            {n: {n} for n in self.nonterminals}
        Get a mapping from nonterminal to a set of nonterminals it can be
        *nontrivially* rewritten to (i.e., it does not include the nonterminal
        itself unless there is a rule of form X → X for it")
        """
        simple_to : Dict[Nonterminal, Set[Nonterminal]] \
            = {n: set() for n in self.nonterminals}


        for tracker in ChangeTracker():
        for tracker in ChangeTracker():
            for src_orig in self.nonterminals:
            for src_orig in self.nonterminals:
                # make a copy to avoid error for changing during iteraton
                # make a copy to avoid error for changing during iteraton
                for src in list(simple_to[src_orig]):
                for src in list({src_orig} | simple_to[src_orig]):
                    for prod in self.rules.get(src, []):
                    for prod in self.rules.get(src, []):
                        if not isinstance(prod, Eps) and len(prod) == 1 \
                        if not isinstance(prod, Eps) and len(prod) == 1 \
                                and isinstance(prod[0], Nonterminal) \
                                and isinstance(prod[0], Nonterminal) \
@@ -341,9 +346,17 @@ class CFG:
                            tracker.changed()
                            tracker.changed()
                            simple_to[src_orig].add(prod[0])
                            simple_to[src_orig].add(prod[0])


        return simple_to

    def has_simple_rules(self) -> bool:
        return any_of(lambda v: len(v) != 0, self.get_simple_to().values())

    def remove_simple_rules(self) -> CFG:
        simple_to = self.get_simple_to()

        new_rules: CFG.Rules = dict()
        new_rules: CFG.Rules = dict()
        for src in self.nonterminals:
        for src in self.nonterminals:
            for esrc in simple_to[src]:
            for esrc in {src} | simple_to[src]:
                for prod in self.rules.get(esrc, []):
                for prod in self.rules.get(esrc, []):
                    if len(prod) != 1 or typing.cast(CFG.Symbols, prod)[0] \
                    if len(prod) != 1 or typing.cast(CFG.Symbols, prod)[0] \
                            in self.terminals:
                            in self.terminals:
+17 −0
Original line number Original line Diff line number Diff line
@@ -57,8 +57,10 @@ def test_remove_simple() -> None:
                 C: {(a, a, C), (a, a)}},
                 C: {(a, a, C), (a, a)}},
                S)
                S)
    print(g)
    print(g)
    assert g.has_simple_rules()
    gs = g.remove_simple_rules()
    gs = g.remove_simple_rules()
    print(gs)
    print(gs)
    assert not gs.has_simple_rules()
    assert (B, C) in gs.rules[gs.init]
    assert (B, C) in gs.rules[gs.init]
    assert (a, B) in gs.rules[gs.init]
    assert (a, B) in gs.rules[gs.init]
    assert (a,) in gs.rules[gs.init]
    assert (a,) in gs.rules[gs.init]
@@ -66,6 +68,21 @@ def test_remove_simple() -> None:
    assert (a, a) in gs.rules[gs.init]
    assert (a, a) in gs.rules[gs.init]




def test_remove_simple_2() -> None:
    g = cfl.CFG({S, A, B}, {a},
                {S: {(a, A)},
                 A: {(B,), (a, a)},
                 B: {(a,)}},
                S)
    print(g)
    assert g.has_simple_rules()
    gs = g.remove_simple_rules()
    print(gs)
    assert not gs.has_simple_rules()
    assert g.generates("aa")
    assert g.generates("aaa")


def test_is_epsilon_normal():
def test_is_epsilon_normal():
    assert g0.is_epsilon_normal_form()
    assert g0.is_epsilon_normal_form()
    assert g1.is_epsilon_normal_form()
    assert g1.is_epsilon_normal_form()