Commit 798d9a0f authored by Vít Novotný's avatar Vít Novotný
Browse files

Add weighting to `CrossEntropyLoss()`

parent fc384e4a
Pipeline #147704 passed with stage
in 10 minutes and 32 seconds
%% Cell type:markdown id:c4d8e902-f647-4359-9923-26908044e5a9 tags:
# Find all entities
In this notebook, we will produce a dataset for training, validating, and testing a model with masked language modeling (MLM) and named entity recognition (NER) objectives. We will do this by finding occurences of all entities from regests in all OCR documents.
%% Cell type:markdown id:0d70b697-762b-4bc0-b593-9888cb68e320 tags:
## Preliminaries
We will begin with a bit of boilerplate, logging information and setting up the computational environment.
%% Cell type:code id:23867144-6a79-4b35-9ce9-bf412791f7da tags:
``` python
import socket
```
%% Cell type:code id:af368029-288e-448a-9a6e-86864df83a8a tags:
``` python
hostname = socket.gethostname()
print(hostname)
```
%% Output
apollo.fi.muni.cz
%% Cell type:code id:6e10cff3-e0b8-4790-9922-8a5d88fe543a tags:
``` python
! python -V
```
%% Output
Python 3.8.5
Python 3.8.10
%% Cell type:markdown id:3612629c-27c3-4490-bc1e-472a4e9b8f88 tags:
Install the current version of the package and its dependencies.
%% Cell type:code id:5a226150-9853-439e-9a5e-04312a006720 tags:
``` python
%%capture
! pip install .
```
%% Cell type:markdown id:fa13770b-06e1-498c-a677-200e1d097711 tags:
Set up logging to display informational messages.
%% Cell type:code id:e2c5d1c7-5e51-4840-808a-448ac7f457d3 tags:
``` python
import logging
import sys
```
%% Cell type:code id:f96236e4-110b-4261-9d71-19058dc4685a tags:
``` python
logging.basicConfig(stream=sys.stdout, level=logging.INFO, format='%(message)s')
```
%% Cell type:markdown id:f583b627-9936-4273-9b4f-3a41dedb917a tags:
## Pre-filter, pre-type, and canonicalize the entities
First, we will produce a spreadsheet with all entities. We will pass the spreadsheet to layman annotators.
%% Cell type:markdown id:f1115b5e-d80c-43b0-97a2-59d117eb8d36 tags:
We will load all entities.
%% Cell type:code id:d987f79c-7352-492a-8a22-93a5fe97653c tags:
``` python
from ahisto_named_entity_search.entity import load_entities, Entity, Place, Person
```
%% Cell type:code id:b5b7fce4-74a0-474b-a47a-a970e57e3a83 tags:
``` python
all_entities = load_entities()
```
%% Output
Loading entities: 100%|█████████| 4182/4182 [00:06<00:00, 622.83it/s]
Loading entities: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4182/4182 [00:06<00:00, 680.07it/s]
Loaded 20508 entities: 4350 places (21.21%) and 16158 persons (78.79%).
%% Cell type:code id:dc6506c3-f110-468f-874a-a9edc6bc5b2a tags:
``` python
place_entities = [entity for entity in all_entities if isinstance(entity, Place)]
```
%% Cell type:code id:64a61a23-2ea0-4b3f-a7d5-d109714675c9 tags:
``` python
person_entities = [entity for entity in all_entities if isinstance(entity, Person)]
```
%% Cell type:code id:b565bace-d1d4-4e0d-aeec-c16431eed9d2 tags:
``` python
from ahisto_named_entity_search.entity import (Patchset, EntityType,
default_entity_map_persons as entity_map_persons,
default_entity_map_places as entity_map_places)
```
%% Cell type:code id:7f28168f-1426-4003-a2c3-4f0e7e579147 tags:
``` python
Patchset.create_patchset_template('patchset-template-persons.xlsx', person_entities, all_entities)
Patchset.create_patchset_template('patchset-template-places.xlsx', place_entities, all_entities)
```
%% Cell type:markdown id:c5520d80-fa57-441a-89d3-b3ac4823ab7c tags:
Next, we will produce the spreadsheet and pass the spreadsheet to annotators.
%% Cell type:code id:0a56034d-5b7a-4bcc-b04a-3150c01ec0ff tags:
``` python
person_patchset = Patchset.from_file('patchset-persons.xlsx', person_entities)
place_patchset = Patchset.from_file('patchset-places.xlsx', place_entities)
```
%% Cell type:code id:7202ede1-86ef-4a67-b2b1-27320b1412cc tags:
``` python
all_canonical_entities = {
**person_patchset.apply(person_entities, entity_map_persons),
**place_patchset.apply(place_entities, entity_map_places),
}
```
%% Output
Skipped 0, filtered out 2767, and altered 3662 out of 16158 entities
Skipped 0, filtered out 796, and altered 367 out of 4350 entities
%% Cell type:code id:961bd6de-7a7b-4375-80ad-afc6a85341a6 tags:
``` python
unique_canonical_entities = set(all_canonical_entities.values())
```
%% Cell type:code id:7c8bd43c-ece4-46bc-95c5-ae85600b2fba tags:
``` python
print(f'We selected {len(unique_canonical_entities)} canonical entities.')
```
%% Output
We selected 15304 canonical entities.
%% Cell type:markdown id:313a1998-2bf1-4782-93d5-cff56cf43ec2 tags:
The annotators will filter out non-entities, and will provide types (person or place) and canonical forms of the remaining entities by fixing the most jarring typos.
%% Cell type:markdown id:d42147a9-b3c2-4e7b-acd2-fb2c97473a24 tags:
## Find pre-filtered canonical entities
Next, we will find the pre-filtered and canonicalized entities in OCR document texts using Manatee and fuzzy regexes.
%% Cell type:markdown id:d1f54110-7b49-42f9-8634-13901bb8dc4c tags:
We will load all documents.
%% Cell type:code id:586ffe7e-0213-46f9-8be1-d0ed9aae1106 tags:
``` python
from ahisto_named_entity_search.document import Document, load_documents
```
%% Cell type:code id:5e48481f-be2d-4ee7-ac75-b8859255f464 tags:
``` python
documents = load_documents()
```
%% Output
Loading documents: 100%|██| 268669/268669 [00:08<00:00, 33427.83it/s]
Loading documents: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 268669/268669 [00:05<00:00, 48691.74it/s]
%% Cell type:markdown id:9b1168e8-a385-4579-9cb8-910d3d360e73 tags:
We will order entities by difficulty and split them into chunks.
%% Cell type:code id:5335fc8a-a297-45b8-becb-e93840bbb967 tags:
``` python
def difficulty_sort_key(entity: Entity):
difficulty = len(str(entity))
return (difficulty, entity)
```
%% Cell type:code id:d64ff033-3ac0-4a31-b173-545c633abb3e tags:
``` python
sorted_unique_canonical_entities = sorted(unique_canonical_entities, key=difficulty_sort_key)
```
%% Cell type:code id:14a1bff0-85c2-411f-8a14-19e7489f7336 tags:
``` python
from more_itertools import chunked
```
%% Cell type:code id:57428c6f-d6de-4fd6-97a5-9ba959ba6eda tags:
``` python
chunk_size = 100
```
%% Cell type:code id:ca10b249-eb32-4014-bf8b-380c8cfefa5d tags:
``` python
unique_canonical_entity_chunks = list(enumerate(chunked(sorted_unique_canonical_entities, chunk_size)))
```
%% Cell type:markdown id:6427ce08-326b-4d9b-86fa-2d320024de06 tags:
We will find the pre-filtered and canonicalized entities in OCR document texts using Manatee and fuzzy regexes.
%% Cell type:code id:da00c0e3-ff5a-4346-ab1f-14240e1adc69 tags:
``` python
from ahisto_named_entity_search.index import RemoteManateeIndex, FuzzyRegexIndex
```
%% Output
/home/xnovot32/ahisto-named-entity-search/venv-apollo/lib/python3.8/site-packages/huggingface_hub/snapshot_download.py:6: FutureWarning: snapshot_download.py has been made private and will no longer be available from version 0.11. Please use `from huggingface_hub import snapshot_download` to import the only public function in this module. Other members of the file may be changed without a deprecation notice.
warnings.warn(
%% Cell type:code id:1434374b-0be6-4a2b-8522-474dd9397f67 tags:
``` python
from json import JSONDecodeError
```
%% Cell type:code id:f4dbc257-7ed1-4bce-8993-5fb36f1fbc7a tags:
``` python
from ahisto_named_entity_search.search import Search
from ahisto_named_entity_search.search import SearchResultList
```
%% Cell type:code id:899d4cb9-f94e-47b2-b6b1-ac2e4dd62021 tags:
``` python
from tqdm import tqdm
```
%% Cell type:markdown id:47b59b2e-8bf6-40b3-9266-6d5c1cd670e6 tags:
### Manatee
%% Cell type:code id:fa8e050b-8b6b-4f58-87f3-30e2ce3fe565 tags:
``` python
manatee_result_lists = []
for chunk_number, entities in tqdm(unique_canonical_entity_chunks, desc='Finding entities using Manatee'):
filename = f'manatee_{chunk_number + 1:03}'
try:
manatee_result_list = SearchResultList.load(filename, entities)
except (IOError, JSONDecodeError):
manatee_result_list = Search(RemoteManateeIndex(documents.values())).search(entities)
manatee_result_list.save(filename)
manatee_result_lists.append(manatee_result_list)
```
%% Output
Finding entities using Manatee: 100%|█| 154/154 [04:38<00:00, 1.81s/
Finding entities using Manatee: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 154/154 [04:16<00:00, 1.67s/it]
%% Cell type:code id:9cecab30-2303-4135-924a-3d8c5d3250f0 tags:
``` python
combined_manatee_result_list = sum(manatee_result_lists)
```
%% Cell type:code id:3bbd9c78-8692-4094-8e71-6af55366fdc4 tags:
``` python
print(combined_manatee_result_list)
```
%% Output
Retrieved 340990 results for 15304 entities (22.28 on average, 0 at minimum) in a day using Combined.
%% Cell type:code id:029b567c-5134-49c0-8fdb-76bbece50143 tags:
``` python
assert len(combined_manatee_result_list) == len(unique_canonical_entities)
```
%% Cell type:markdown id:22f8ebc3-9846-4503-8dd9-be0b2d63e72d tags:
### Fuzzy regexes
%% Cell type:code id:2163c2b5-49d9-45b7-9c14-26d9e5b835b7 tags:
``` python
fuzzy_regex_result_lists = []
for chunk_number, entities in tqdm(unique_canonical_entity_chunks, desc='Finding entities using fuzzy regexes'):
filename = f'fuzzy-regex_{chunk_number + 1:03}'
try:
fuzzy_regex_result_list = SearchResultList.load(filename, entities)
except (IOError, JSONDecodeError):
fuzzy_regex_result_list = Search(FuzzyRegexIndex(documents.values())).search(entities)
fuzzy_regex_result_list.save(filename)
fuzzy_regex_result_lists.append(fuzzy_regex_result_list)
```
%% Output
Finding entities using fuzzy regexes: 100%|█| 154/154 [5:18:22<00:00,
Finding entities using fuzzy regexes: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 154/154 [4:23:43<00:00, 102.75s/it]
%% Cell type:code id:dd737e7a-7997-443e-996a-dec417021df0 tags:
``` python
combined_fuzzy_regex_result_list = sum(fuzzy_regex_result_lists)
```
%% Cell type:code id:1d433466-975c-4265-92bf-904f30e120aa tags:
``` python
print(combined_fuzzy_regex_result_list)
```
%% Output
Retrieved 24951727 results for 15304 entities (1630.41 on average, 0 at minimum) in 6 months using Combined.
%% Cell type:code id:1bdf3d6e-6168-4b7a-b74c-9fd4c31b9200 tags:
``` python
assert len(combined_fuzzy_regex_result_list) == len(unique_canonical_entities)
```
%% Cell type:markdown id:d15503e4-a6db-4416-a699-159924d56fe6 tags:
## Filter and type entities
The search takes a while. In parallel to the search, we will produce a confirmatory spreadsheet with pre-filtered canonical entities. We will pass the spreadsheet to expert annotators.
%% Cell type:code id:9e3c32ca-7d4b-4c24-bdf8-a434abfea59d tags:
``` python
def concretize_entity_types(patchset: Patchset, default_entity_type: EntityType) -> None:
for patch in patchset.patches.values():
if patch.entity_type is None:
patch.entity_type = default_entity_type
```
%% Cell type:code id:eb53954d-6284-4b62-bbe4-8747a79ebe85 tags:
``` python
concretize_entity_types(person_patchset, EntityType.PERSON)
concretize_entity_types(place_patchset, EntityType.PLACE)
```
%% Cell type:code id:98ab40f6-18da-44be-910d-d37ea662c95a tags:
``` python
patches = {**person_patchset.patches, **place_patchset.patches}
patchset = Patchset.from_patches(patches)
```
%% Cell type:code id:84f9e976-39c7-4247-afc6-560050b292f5 tags:
``` python
from ahisto_named_entity_search.entity import PatchsetConfirmatory
```
%% Cell type:code id:6eef4810-1927-4284-a466-cb008a74f820 tags:
``` python
PatchsetConfirmatory.create_confirmatory_template(
patchset, 'patchset-confirmatory-template.xlsx', all_entities, all_entities)
```
%% Cell type:markdown id:5f424820-1c6c-45d9-864d-39ab58b5250d tags:
The annotators will filter out remaining non-entities, and confirm the type (person or place) of the remaining entities.
%% Cell type:code id:0b76f9fa-4360-4299-a6ac-50ad593c9b1a tags:
``` python
patchset_confirmatory = PatchsetConfirmatory('patchset-confirmatory.xlsx', patchset, all_entities)
```
%% Cell type:code id:537e6bf8-afbd-4bc5-85ea-5fe0c36ecb00 tags:
``` python
confirmed_person_patchset = patchset_confirmatory.apply(person_patchset, person_entities)
confirmed_place_patchset = patchset_confirmatory.apply(place_patchset, place_entities)
```
%% Output
Skipped 2765, filtered out 205, and altered 1666 out of 16158 patches
Skipped 796, filtered out 16, and altered 46 out of 4350 patches
%% Cell type:code id:978825ff-9c92-46e1-b6a2-345dc6f13c05 tags:
``` python
all_confirmed_canonical_entities = {
**confirmed_person_patchset.apply(person_entities, entity_map_persons),
**confirmed_place_patchset.apply(place_entities, entity_map_places),
}
```
%% Output
Skipped 2970, filtered out 2, and altered 4982 out of 16158 entities
Skipped 812, filtered out 0, and altered 363 out of 4350 entities
%% Cell type:code id:6a0dc154-a54c-4038-807b-aa54d400cc2d tags:
``` python
unique_confirmed_canonical_entities = sorted(set(all_confirmed_canonical_entities.values()))
```
%% Cell type:code id:be0fe9ff-f6f9-41de-bcab-1f2ed46b0519 tags:
``` python
print(f'We selected {len(unique_confirmed_canonical_entities)} confirmed canonical entities.')
```
%% Output
We selected 15100 confirmed canonical entities.
%% Cell type:markdown id:edd8b312-1ded-4192-9f06-ccdf7b63e202 tags:
## Create dataset
Finally, we will create the datasets for for training, validating, and testing a model with masked language modeling (MLM) and named entity recognition (NER) objectives.
%% Cell type:markdown id:8ddf0a28-f817-4211-9de9-573b7dbd8a03 tags:
### Masked Language Modeling
To improve the suitability of a language model for our domain, we will produce a dataset for training and validating a language model.
%% Cell type:markdown id:9521943b-e2aa-484f-b099-1c9b75c80e6a tags:
First, we will shuffle and split entities into 90% for training and 10% for validation.
%% Cell type:code id:d72b453e-84af-4137-b0ca-23c6d4c9bee4 tags:
``` python
from random import Random
```
%% Cell type:code id:21192d4a-1e9f-4142-931c-78987db2c58c tags:
``` python
random_seed = 42
```