Commit 2194e48f authored by Vít Novotný's avatar Vít Novotný
Browse files

Fix typos in indentations and missing `.items()`

parent 551ef0b0
Pipeline #147708 passed with stage
in 10 minutes and 4 seconds
......@@ -196,5 +196,9 @@ def get_label_weights(labels: Iterable[Label], dataset: Iterable[BioNerTags]) ->
for label in bio_ner_tags.split():
assert label in label_counts
label_counts[label] += 1
label_weights = {label: count**-1 if count > 0 else 0 for label, count in label_counts.items()}
label_weights = {
label: count**-1 if count > 0 else 0
for label, count
in label_counts.items()
}
return label_weights
......@@ -155,7 +155,11 @@ class BIOTokenWeightedClassification(TokenClassification):
labels: torch.LongTensor,
inputs: Optional[Union[BatchEncoding, Dict[str, torch.Tensor]]] = None,
attention_mask: Optional[torch.LongTensor] = None) -> torch.FloatTensor:
weights_dict = {self.labels_map[label]: weight for label, weight in self.label_weights}
weights_dict = {
self.labels_map[label]: weight
for label, weight
in self.label_weights.items()
}
weights = []
for expected_label_id, (label_id, weight) in enumerate(sorted(weights_dict.items())):
assert expected_label_id == label_id
......@@ -167,7 +171,7 @@ class BIOTokenWeightedClassification(TokenClassification):
active_loss = attention_mask.view(-1) == 1
active_logits = logit_outputs.view(-1, len(self.labels_map))
active_labels = torch.where(
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
)
loss = loss_fct(active_logits, active_labels)
else:
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment