Verified commit bd1b9bc0, authored by Vladimír Štill
Browse files

CFL: Implement Context-Free Grammars (including CYK)

Based on the old CFG evaluator from KSI
parent dcfaaaeb
from __future__ import annotations
from typing import Set, Dict, List, Union, Optional, Tuple, Iterable, \
Callable, TypeVar
import typing
from typing_extensions import Final
from copy import deepcopy
# Generic element type used in the signatures of the all_of / any_of helpers.
T = TypeVar("T")
class Terminal:
    """Terminal symbol of a grammar; identity is determined by ``name`` only."""

    def __init__(self, name: str):
        self.name = name

    def __eq__(self, obj):
        # Equal only to another Terminal carrying the same name.
        return isinstance(obj, Terminal) and obj.name == self.name

    def __hash__(self):
        # Hash follows the name so that equal terminals collide in sets/dicts.
        return hash(self.name)

    def __repr__(self) -> str:
        return "Terminal({})".format(self.name)
class Nonterminal:
    """Nonterminal symbol of a grammar; compared, hashed and ordered by name."""

    def __init__(self, name: str):
        self.name = name

    def __eq__(self, obj):
        # Equal only to another Nonterminal carrying the same name.
        return isinstance(obj, Nonterminal) and obj.name == self.name

    def __hash__(self):
        # Hash follows the name so that equal nonterminals collide in sets.
        return hash(self.name)

    def __lt__(self, other : Nonterminal) -> bool:
        # Ordering by name is what makes sorted() over nonterminals work.
        mine, theirs = self.name, other.name
        return mine < theirs

    def __repr__(self) -> str:
        return "Nonterminal({})".format(self.name)
def all_of(pred: Callable[[T], bool], it: Iterable[T]) -> bool:
    """Return True when ``pred`` holds for every element of ``it``.

    Evaluation is lazy and stops at the first failing element; an empty
    iterable yields True.
    """
    return all(pred(element) for element in it)
def any_of(pred: Callable[[T], bool], it: Iterable[T]) -> bool:
    """Return True when ``pred`` holds for at least one element of ``it``.

    Evaluation is lazy and stops at the first matching element; an empty
    iterable yields False.
    """
    return any(pred(element) for element in it)
class GeneratesResult:
    """Outcome of a membership query against a grammar.

    Attributes (all assigned once in the constructor):
        value:     whether the queried word is generated.
        cnf_cfg:   the grammar the answer was computed on.
        cyk_table: the CYK table, or None when no table was supplied.

    The object is truthy exactly when ``value`` is.
    """

    def __init__(self, value: bool, cnf_cfg: CFG,
                 cyk_table: Optional[List[List[Set[Nonterminal]]]] = None):
        self.cyk_table: Final = cyk_table
        self.cnf_cfg: Final = cnf_cfg
        self.value: Final = value

    def __bool__(self) -> bool:
        # Lets callers write ``if grammar.generates(word): ...``.
        return self.value
class CFG:
    """Context-free grammar over ``Terminal`` and ``Nonterminal`` symbols.

    A grammar is given by a set of nonterminals, a set of terminals, a
    mapping from nonterminals to sets of productions and an initial
    nonterminal.  A production is a tuple of symbols; the empty tuple stands
    for epsilon.  The transformation methods return a grammar (a new one, or
    ``self`` when nothing has to change) and do not modify the receiver.
    """

    Symbol = Union[Terminal, Nonterminal]
    Production = Tuple[Symbol, ...]
    Rules = Dict[Nonterminal, Set[Production]]

    def __init__(self, nonterminals: Set[Nonterminal],
                 terminals: Set[Terminal],
                 rules: CFG.Rules,
                 init: Nonterminal):
        """Build a grammar from deep copies of the given components.

        Nonterminals mapped to an empty set of productions are removed from
        ``rules`` (they stay in ``nonterminals``).  Consistency of the
        components is then verified by ``_check`` (raises AssertionError).
        """
        self.nonterminals: Final = deepcopy(nonterminals)
        self.terminals: Final = deepcopy(terminals)
        self.rules: Final = deepcopy(rules)
        self.init: Final = deepcopy(init)

        # normalize rules: avoid any rules leading to empty set of productions
        # (collect first, the dict must not change size while iterated)
        for src in [src for src, prods in self.rules.items() if not prods]:
            del self.rules[src]
        self._check()

    def productions(self) -> Iterable[Tuple[Nonterminal, CFG.Production]]:
        """Yield every rule as a ``(left-hand side, production)`` pair."""
        for src, prods in self.rules.items():
            for prod in prods:
                yield (src, prod)

    @staticmethod
    def empty(terminals: Set[Terminal]) -> CFG:
        """Return a grammar with initial nonterminal ``S`` and no rules."""
        S = Nonterminal("S")
        return CFG({S}, deepcopy(terminals), dict(), S)

    def _check(self) -> None:
        """Assert that the rules mention only declared symbols.

        Raises AssertionError when the initial nonterminal, a left-hand side
        or any symbol of a production is not declared in the grammar.
        """
        assert self.init in self.nonterminals,\
            f"Initial nonterminal {self.init.name} must be in nonterminals"

        for nterm, prod in self.productions():
            assert nterm in self.nonterminals,\
                f"A rule for {nterm.name} exists, "\
                f"but it is not in nonterminals"
            for x in prod:
                if isinstance(x, Nonterminal):
                    assert x in self.nonterminals,\
                        f"A rule containing nonterminal {x.name} found, "\
                        f"but it is not in nonterminals"
                else:
                    # repr() is used here because an object of a foreign
                    # type need not have a ``name`` attribute at all
                    assert isinstance(x, Terminal),\
                        f"Neither terminal nor nonterminal symbol found: "\
                        f"{x!r}"
                    assert x in self.terminals,\
                        f"A rule containing terminal {x.name} found, "\
                        f"but it is not in terminals"

    def reduced(self) -> CFG:
        """Return the grammar after ``normalized`` and ``remove_unreachable``."""
        return self.normalized().remove_unreachable()

    def normalized(self) -> CFG:
        """Keep only nonterminals that can derive a word over terminals.

        A nonterminal is kept when some production of it consists solely of
        terminals and already kept nonterminals (fixpoint iteration).
        """
        generating: Set[Nonterminal] = set()
        added = True
        while added:
            added = False
            for src, prods in self.rules.items():
                if src in generating:
                    continue
                for prod in prods:
                    if all(x in generating or x in self.terminals
                           for x in prod):
                        generating.add(src)
                        added = True
                        break

        return self.restrict_symbols(generating | self.terminals)

    def restrict_symbols(self, symbols: Set[Union[Terminal, Nonterminal]])\
            -> CFG:
        """Return the grammar restricted to ``symbols``.

        Productions mentioning a symbol outside ``symbols`` are dropped.  If
        the initial nonterminal itself is not in ``symbols`` the result is an
        empty grammar (see ``empty``) over the remaining terminals.
        """
        if self.init not in symbols:
            return CFG.empty(self.terminals & symbols)

        nonterminals = self.nonterminals & symbols
        terminals = self.terminals & symbols
        rules: CFG.Rules = dict()
        for src, prods in self.rules.items():
            if src not in symbols:
                continue
            new_prods: Set[CFG.Production] = {
                prod for prod in prods if all(x in symbols for x in prod)}
            if new_prods:
                rules[src] = new_prods

        return CFG(nonterminals, terminals, rules, self.init)

    def remove_unreachable(self) -> CFG:
        """Drop symbols that cannot be reached from the initial nonterminal.

        A grammar without any rules is returned unchanged.
        """
        if len(self.rules) == 0:
            return self

        reachable: Set[CFG.Symbol] = {self.init}
        # worklist of reached symbols whose productions were not expanded yet
        pending: List[CFG.Symbol] = [self.init]
        while pending:
            src = pending.pop()
            if not isinstance(src, Nonterminal):
                continue
            for prod in self.rules.get(src, ()):
                for symbol in prod:
                    if symbol not in reachable:
                        reachable.add(symbol)
                        pending.append(symbol)

        return self.restrict_symbols(reachable)

    def is_epsilon_normal_form(self) -> bool:
        """Tell whether epsilon is used only in the allowed way.

        True when there is no epsilon production at all, or when only the
        initial nonterminal has one and the initial nonterminal does not
        occur on the right-hand side of any production.
        """
        has_eps = False
        has_non_start_eps = False
        recursive_start = False
        for src, prods in self.rules.items():
            if () in prods:
                has_eps = True
                if src != self.init:
                    has_non_start_eps = True
            if any(self.init in prod for prod in prods):
                recursive_start = True
        return not has_eps or (not has_non_start_eps and not recursive_start)

    def epsilon_normal_form(self) -> CFG:
        """Return a grammar in epsilon normal form (see
        ``is_epsilon_normal_form``); ``self`` if it already is in that form.

        Erasable nonterminals (those deriving epsilon) are computed first;
        every production is then replaced by all its non-empty variants with
        any subset of erasable occurrences left out.  If the initial
        nonterminal is erasable, a fresh initial nonterminal ``N`` with rules
        ``N -> epsilon | old_init`` is added.
        """
        if self.is_epsilon_normal_form():
            return self

        # fixpoint: a nonterminal is erasable if one of its productions
        # consists of erasable nonterminals only (vacuously true for epsilon)
        erasable: Set[Nonterminal] = set()
        added = True
        while added:
            added = False
            for src, prod in self.productions():
                if src in erasable:
                    continue
                if all(x in erasable for x in prod):
                    erasable.add(src)
                    added = True

        def drop(prod: CFG.Production) -> Iterable[CFG.Production]:
            # all variants of prod with erasable symbols optionally omitted
            if len(prod) == 0:
                yield ()
                return
            head = prod[0]
            for tail in drop(prod[1:]):
                if head in erasable:
                    yield tail
                yield (head,) + tail

        new_rules: CFG.Rules = dict()
        for src, prod in self.productions():
            for new_prod in drop(prod):
                if new_prod:
                    new_rules.setdefault(src, set()).add(new_prod)

        if self.init in erasable:
            # pick an unused name first; a Nonterminal hashes by its name, so
            # the name is fixed before the object is created
            name = "S"
            while Nonterminal(name) in self.nonterminals:
                name += "'"
            new_init = Nonterminal(name)
            new_rules[new_init] = {(), (self.init,)}
            return CFG(self.nonterminals | {new_init}, self.terminals,
                       new_rules, new_init)
        return CFG(self.nonterminals, self.terminals, new_rules, self.init)

    def remove_simple_rules(self) -> CFG:
        """Return a grammar without rules of the form ``A -> B``.

        For every nonterminal ``A`` the set of nonterminals reachable from it
        through chains of such rules is computed (transitively), and ``A``
        receives all the other productions of those nonterminals.
        """
        simple_to: Dict[Nonterminal, Set[Nonterminal]] =\
            {n: {n} for n in self.nonterminals}

        added = True
        while added:
            added = False
            for src in self.nonterminals:
                # follow simple rules of everything src already leads to, so
                # that chains like A -> B, B -> C are closed transitively;
                # iterate over a snapshot because the set grows in the loop
                for mid in list(simple_to[src]):
                    for prod in self.rules.get(mid, ()):
                        if len(prod) == 1 \
                                and prod[0] in self.nonterminals \
                                and prod[0] not in simple_to[src]:
                            added = True
                            simple_to[src].add(
                                typing.cast(Nonterminal, prod[0]))

        new_rules: CFG.Rules = dict()
        for src in self.nonterminals:
            for esrc in simple_to[src]:
                # a nonterminal may have no productions at all
                for prod in self.rules.get(esrc, ()):
                    if len(prod) != 1 or prod[0] in self.terminals:
                        new_rules.setdefault(src, set()).add(prod)
        return CFG(self.nonterminals, self.terminals, new_rules, self.init)

    def proper(self) -> CFG:
        """Apply ``epsilon_normal_form``, ``remove_simple_rules`` and
        ``reduced``, in this order."""
        return self.epsilon_normal_form().remove_simple_rules().reduced()

    def cnf(self) -> CFG:
        """Return a grammar in Chomsky normal form built from ``proper()``.

        Productions longer than one symbol are split in halves recursively;
        each half that is not a single nonterminal is replaced by a helper
        nonterminal named ``<symbols>`` (reused for identical halves).
        """
        prop = self.proper()
        new_nonterminals = deepcopy(prop.nonterminals)
        new_rules: CFG.Rules = dict()
        # helper nonterminal already introduced for a given production part
        new_nterms: Dict[CFG.Production, Nonterminal] = dict()

        def get_cnf_name(prod_part: CFG.Production) -> Nonterminal:
            # angle brackets inside names are replaced so that the result
            # stays recognizable as one bracketed name
            name = '<' + ''.join(x.name.replace('>', '_').replace('<', '_')
                                 for x in prod_part) + '>'
            while Nonterminal(name) in new_nonterminals:
                name += "'"
            nterm = Nonterminal(name)
            new_nonterminals.add(nterm)
            return nterm

        def mk_cnf_prod(src: Nonterminal, prod: CFG.Production) -> None:
            if src not in new_rules:
                new_rules[src] = set()

            if len(prod) == 0 or len(prod) == 1:
                # proper() leaves no A -> B rules behind
                assert len(prod) == 0 or prod[0] in prop.terminals
                new_rules[src].add(prod)
                return

            half = len(prod) // 2
            tgt: List[CFG.Symbol] = []
            for part in (prod[:half], prod[half:]):
                if len(part) == 1 and part[0] in prop.nonterminals:
                    tgt.append(part[0])
                elif part in new_nterms:
                    tgt.append(new_nterms[part])
                else:
                    tgt_nt = get_cnf_name(part)
                    mk_cnf_prod(tgt_nt, part)
                    tgt.append(tgt_nt)
                    new_nterms[part] = tgt_nt
            new_rules[src].add(tuple(tgt))

        for src, prod in prop.productions():
            mk_cnf_prod(src, prod)
        return CFG(new_nonterminals, prop.terminals, new_rules, prop.init)

    def is_cnf(self) -> bool:
        """Tell whether every rule has a Chomsky normal form shape.

        Allowed rules are ``A -> B C`` (two nonterminals), ``A -> a`` (one
        terminal) and ``S -> epsilon`` for the initial nonterminal ``S``,
        the last one only if ``S`` occurs in no right-hand side.
        """
        init_has_eps = () in self.rules.get(self.init, ())
        for src, prod in self.productions():
            # X -> ???+
            if len(prod) > 2:
                return False
            # X -> ?? with some terminal symbol
            if len(prod) == 2 and (prod[0] not in self.nonterminals
                                   or prod[1] not in self.nonterminals):
                return False
            # X -> Y
            if len(prod) == 1 and prod[0] not in self.terminals:
                return False
            # X -> \e for X other than the initial nonterminal; CYK in
            # generates() ignores such rules, so they must not be accepted
            if len(prod) == 0 and src != self.init:
                return False
            # X -> *S* and S -> \e
            if init_has_eps and self.init in prod:
                return False
        return True

    def generates(self, word: Union[str, Iterable[Terminal]])\
            -> GeneratesResult:
        """Decide by the CYK algorithm whether the grammar generates ``word``.

        ``word`` is either a string (every character becomes a ``Terminal``)
        or an iterable of terminals.  The grammar is converted by ``cnf()``
        unless ``is_cnf()`` already holds.  In the returned table, cell
        ``[j - 1][i]`` holds the nonterminals deriving the part of the word
        of length ``j`` starting at position ``i``; no table is returned for
        the empty word or when the initial nonterminal has no rules.
        """
        cnf = self if self.is_cnf() else self.cnf()

        if isinstance(word, str):
            symbols: List[Terminal] = [Terminal(x) for x in word]
        else:
            symbols = list(word)

        if cnf.init not in cnf.rules:
            return GeneratesResult(False, cnf)

        n = len(symbols)
        if n == 0:
            return GeneratesResult(() in cnf.rules[cnf.init], cnf)

        # split the rules by shape once instead of rescanning in the loops
        unit = [(src, prod[0]) for src, prod in cnf.productions()
                if len(prod) == 1]
        binary = [(src, prod[0], prod[1]) for src, prod in cnf.productions()
                  if len(prod) == 2]

        table: List[List[Set[Nonterminal]]] = \
            [[set() for _ in range(n - i)] for i in range(n)]

        for i, symbol in enumerate(symbols):
            T_i1 = table[0][i]
            for src, tgt in unit:
                if symbol == tgt:
                    T_i1.add(src)

        for j in range(2, n + 1):
            for i in range(n - j + 1):
                T_ij = table[j - 1][i]
                for k in range(1, j):
                    left = table[k - 1][i]
                    right = table[j - k - 1][i + k]
                    for src, first, second in binary:
                        if first in left and second in right:
                            T_ij.add(src)

        return GeneratesResult(cnf.init in table[n - 1][0], cnf, table)

    def to_string(self) -> str:
        """Render the rules, one nonterminal per line as ``A -> x | y``.

        The initial nonterminal comes first, the others follow sorted by
        name; alternatives are sorted and epsilon is written as ``ε``.
        Nonterminals without any production are left out.
        """
        nonterms = sorted(self.nonterminals)
        nonterms.remove(self.init)
        nonterms.insert(0, self.init)
        out = []
        for r in nonterms:
            prods = self.rules.get(r)
            if not prods:
                # nothing to print; indexing self.rules here would fail
                continue
            to = sorted("".join(x.name for x in prd) if prd else "ε"
                        for prd in prods)
            out.append(f"{r.name} -> {' | '.join(to)}")
        return "\n".join(out)

    def __str__(self) -> str:
        return self.to_string()
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment