Commit 0886acf5 authored by Vít Novotný's avatar Vít Novotný
Browse files

Transfer cross-entropy loss weight tensor to GPU

parent 2194e48f
Pipeline #147716 passed with stage
in 9 minutes and 41 seconds
......@@ -164,7 +164,9 @@ class BIOTokenWeightedClassification(TokenClassification):
for expected_label_id, (label_id, weight) in enumerate(sorted(weights_dict.items())):
assert expected_label_id == label_id
weights.append(weight)
loss_fct = torch.nn.CrossEntropyLoss(weight=torch.tensor(weights))
weight_tensor = torch.tensor(weights)
weight_tensor = weight_tensor.to(labels.device)
loss_fct = torch.nn.CrossEntropyLoss(weight=weight_tensor)
# Only keep active parts of the loss
if attention_mask is not None:
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment