Commit e1ba220c authored by Vít Starý Novotný
Browse files

Count approximate absolute frequencies of entities

parent c6b5fff2
Loading
Loading
Loading
Loading
Loading
+34 −17
Original line number Diff line number Diff line
@@ -177,11 +177,13 @@ class TableRow:
    # Context size for a row (presumably the number of words kept on each
    # side of the entity -- TODO confirm at the place where it is used).
    CONTEXT_SIZE = 10
    # Column titles of the produced table; the order mirrors the cell values
    # assembled in `format`.
    HEADER = ('Left context', 'Entity', 'Right context',
              'Paragraph start distance', 'Sentence start distance',
              'Sentence end distance', 'Paragraph end distance',
              'Language', 'Absolute frequency in the collection (approximate)')

    def __init__(self, entities: Iterable[Entity], left_context: Iterable[Word], right_context: Iterable[Word],
                 sentence_start_distance: int, sentence_end_distance: int,
                 paragraph_start_distance: int, paragraph_end_distance: int):
                 paragraph_start_distance: int, paragraph_end_distance: int,
                 get_frequency: Callable[['TableRow'], int]):
        self.entities = tuple(entities)
        assert len(self.entities) > 0
        assert all(isinstance(entity, Entity) for entity in self.entities)
@@ -191,6 +193,11 @@ class TableRow:
        self.sentence_end_distance = sentence_end_distance
        self.paragraph_start_distance = paragraph_start_distance
        self.paragraph_end_distance = paragraph_end_distance
        self.get_frequency = get_frequency

    @property
    def frequency(self) -> int:
        """Return the frequency reported for this row by the injected callback.

        Nothing is cached on the row: every access calls
        ``self.get_frequency`` with the row itself as the only argument.
        """
        lookup = self.get_frequency
        return lookup(self)

    @property
    def formatted(self) -> List[Union[str, TextBlock]]:
@@ -217,7 +224,8 @@ class TableRow:
    def format(self, format_words: Callable[[Iterable[Word]], Any]):
        """Build the list of cell values for this row.

        :param format_words: callable used to render a sequence of words; it
            is applied separately to the left context, the entities and the
            right context.
        :return: a list holding the three rendered word groups, the four
            distances converted to strings, the language of the first entity
            and the frequency converted to a string.
        """
        row = [
            format_words(self.left_context),
            format_words(self.entities),
            format_words(self.right_context),
            str(self.paragraph_start_distance),
            str(self.sentence_start_distance),
            str(self.sentence_end_distance),
            str(self.paragraph_end_distance),
            # The language is read from the first entity only.
            self.entities[0].language,
            str(self.frequency),
        ]
        return row

    def __str__(self):
@@ -236,9 +244,16 @@ class TableRow:
            return self.entities < other.entities


def produce_table_rows(paragraphs: Iterable[List[Word]]) -> Iterable[TableRow]:
def produce_table_rows(paragraphs: Iterable[List[Word]]) -> List[TableRow]:
    num_skipped_duplicates = 0
    seen_entities = set()  # do not produce duplicate rows for entities
    seen_fingerprints = Counter()  # do not produce duplicate rows for entities

    def get_frequency(table_row: TableRow):
        """Return the count currently stored in ``seen_fingerprints`` for the
        fingerprint of the given row's entities."""
        return seen_fingerprints[Word.fingerprint(table_row.entities)]

    table_rows = []
    for paragraph in paragraphs:
        sentences = extract_sentences(' '.join(str(word) for word in paragraph), include_prefix=True, include_suffix=True)
        sentence_boundaries = [0]
@@ -258,13 +273,14 @@ def produce_table_rows(paragraphs: Iterable[List[Word]]) -> Iterable[TableRow]:
                sentence_end_distance = sentence_boundaries[1] - index
                paragraph_end_distance = len(paragraph) - index
                assert sentence_end_distance <= paragraph_end_distance
                entity_text = Word.fingerprint(entities)
                if entity_text not in seen_entities:
                    seen_entities.add(entity_text)
                    yield TableRow(entities, left_context, right_context, sentence_start_distance, sentence_end_distance,
                                   paragraph_start_distance, paragraph_end_distance)
                entity_fingerprint = Word.fingerprint(entities)
                if entity_fingerprint not in seen_fingerprints:
                    table_row = TableRow(entities, left_context, right_context, sentence_start_distance, sentence_end_distance,
                                         paragraph_start_distance, paragraph_end_distance, get_frequency)
                    table_rows.append(table_row)
                else:
                    num_skipped_duplicates += 1
                seen_fingerprints[entity_fingerprint] += 1

            while index >= sentence_boundaries[1]:
                sentence_boundaries.pop(0)
@@ -282,14 +298,16 @@ def produce_table_rows(paragraphs: Iterable[List[Word]]) -> Iterable[TableRow]:
            right_context = []
            sentence_end_distance = 0
            paragraph_end_distance = 0
            entity_text = Word.fingerprint(entities)
            if entity_text not in seen_entities:
                seen_entities.add(entity_text)
                yield TableRow(entities, left_context, right_context, sentence_start_distance, sentence_end_distance,
                               paragraph_start_distance, paragraph_end_distance)
            entity_fingerprint = Word.fingerprint(entities)
            if entity_fingerprint not in seen_fingerprints:
                table_row = TableRow(entities, left_context, right_context, sentence_start_distance, sentence_end_distance,
                                     paragraph_start_distance, paragraph_end_distance, get_frequency)
                table_rows.append(table_row)
            else:
                num_skipped_duplicates += 1
    print(f'Produced rows for {len(seen_entities)} unique entities, skipping {num_skipped_duplicates} duplicates')
            seen_fingerprints[entity_fingerprint] += 1
    print(f'Produced rows for {len(seen_fingerprints)} unique entities, skipping {num_skipped_duplicates} duplicates')
    return table_rows


def write_csv_file(output_csv_file: Path, table_rows: Iterable[TableRow]) -> None:
@@ -335,7 +353,6 @@ def take_sample(table_rows: Iterable[TableRow], k: int, random_seed: int = 42) -
def main(input_vert_file: Path, num_input_lines: int, output_basename: Path, num_output_xlsx_rows: int) -> None:
    """Read paragraphs, build table rows, and write them out as CSV and XLSX.

    :param input_vert_file: path handed to ``read_vert_file``.
    :param num_input_lines: second argument of ``read_vert_file``
        (presumably a limit on the lines read -- TODO confirm).
    :param output_basename: base path; the outputs are written to this path
        with its suffix replaced by ``.csv`` and ``.xlsx`` respectively.
    :param num_output_xlsx_rows: sample size passed to ``take_sample`` for
        the XLSX output.
    """
    paragraphs = read_vert_file(input_vert_file, num_input_lines)
    # produce_table_rows already builds and returns a list, so no extra
    # list() copy is needed before the rows are consumed twice below.
    table_rows = produce_table_rows(paragraphs)
    write_csv_file(output_basename.with_suffix('.csv'), table_rows)
    # Only the sample goes into the XLSX file; the CSV file receives all rows.
    sampled_rows = take_sample(table_rows, num_output_xlsx_rows)
    write_xlsx_file(output_basename.with_suffix('.xlsx'), sampled_rows)