-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy path xp_bare_thg.py
More file actions
105 lines (90 loc) · 3.22 KB
/
xp_bare_thg.py
File metadata and controls
105 lines (90 loc) · 3.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
from typing import Optional

import numpy as np
from sacred import Experiment
from sacred.commands import print_config
from sacred.observers import FileStorageObserver, TelegramObserver
from sacred.run import Run
from sacred.utils import apply_backspaces_and_linefeeds

from conivel.datas.dekker import DekkerDataset
from conivel.datas.ontonotes import OntonotesDataset
from conivel.datas.the_hunger_games.the_hunger_games import TheHungerGamesDataset
from conivel.predict import predict
from conivel.score import score_ner
from conivel.train import train_ner_model
from conivel.utils import (
    RunLogScope,
    pretrained_bert_for_token_classification,
    sacred_archive_huggingface_model,
)
script_dir = os.path.abspath(os.path.dirname(__file__))
ex = Experiment()
ex.captured_out_filter = apply_backspaces_and_linefeeds # type: ignore
ex.observers.append(FileStorageObserver("runs"))
if os.path.isfile(f"{script_dir}/telegram_observer_config.json"):
ex.observers.append(
TelegramObserver.from_config(f"{script_dir}/telegram_observer_config.json")
)
@ex.config
def config():
# -- common parameters
batch_size: int
# wether models should be saved or not
save_models: bool = True
# number of experiment repeats
runs_nb: int = 5
# -- NER training parameters
# number of epochs for NER training
ner_epochs_nb: int = 2
# learning rate for NER training
ner_lr: float = 2e-5
@ex.automain
def main(
_run: Run,
batch_size: int,
save_models: bool,
runs_nb: int,
ner_epochs_nb: int,
ner_lr: float,
):
print_config(_run)
dekker = DekkerDataset()
the_hunger_games = TheHungerGamesDataset(cut_into_chapters=False)
precision_matrix = np.zeros((runs_nb,))
recall_matrix = np.zeros((runs_nb,))
f1_matrix = np.zeros((runs_nb,))
metrics_matrices = [
("precision", precision_matrix),
("recall", recall_matrix),
("f1", f1_matrix),
]
for run_i in range(runs_nb):
with RunLogScope(_run, f"run{run_i}"):
model = pretrained_bert_for_token_classification(
"bert-base-cased", dekker.tag_to_id
)
model = train_ner_model(
model,
dekker,
dekker,
_run=_run,
epochs_nb=ner_epochs_nb,
batch_size=batch_size,
learning_rate=ner_lr,
quiet=True,
)
if save_models:
sacred_archive_huggingface_model(_run, model, "model") # type: ignore
preds = predict(model, the_hunger_games, batch_size=batch_size).tags
precision, recall, f1 = score_ner(
the_hunger_games.sents(), preds, ignored_classes={"LOC", "ORG"}
)
_run.log_scalar(f"test_precision", precision)
precision_matrix[run_i] = precision
_run.log_scalar("test_recall", recall)
recall_matrix[run_i] = recall
_run.log_scalar("test_f1", f1)
f1_matrix[run_i] = f1
# global mean metrics
for name, matrix in metrics_matrices:
for op_name, op in [("mean", np.mean), ("stdev", np.std)]:
_run.log_scalar(
f"{op_name}_test_{name}",
op(matrix),
)