import datetime
import os
from pathlib import Path
from typing import Union

import click
import pandas as pd
from datasets import load_dataset
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
from transformers import (
    AutoTokenizer,
    DataCollatorWithPadding,
    TFAutoModelForSequenceClassification,
)
from transformers.keras_callbacks import KerasMetricCallback

from src.preprocessing_text import TextProcessor
from src.prepare_bert_tf import df2dict, compute_metrics, prepare_training
from src.train_logs import save_logs


@click.command()
@click.argument("input_data", required=True)
@click.argument("text_preprocessing", required=False)
@click.argument("newspaper", required=False)
@click.argument("topic", required=False)
@click.argument("hsprob", required=False)
@click.argument("pretrained_model", required=True)
def main(
    input_data: Union[str, os.PathLike],
    text_preprocessing: bool,
    newspaper: str,
    topic: str,
    hsprob: list,
    pretrained_model: str,
):
"""
Prepares data and trains BERT model with TF
:param input_data: path to input data
:param text_preprocessing: Binary flag to set text preprocessing.
:param newspaper: Name of newspaper selected for training.
:param topic: Topic selected for training.
:param hsprob: List with min max values for hate speech probability
:param pretrained_model: Name of pretrained BERT model to use for finetuning.
"""

    def preprocess_function(examples):
        """
        Prepares the tokenizer for mapping over the dataset.
        """
        return tokenizer(examples["text"], truncation=True)
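    # truncation=True caps each example at the model's maximum sequence length
    # (512 tokens for distilbert-base-uncased); padding is deliberately left to
    # the data collator below, so batches are only padded to their longest member.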

    # Load data; the IMDb dataset stands in here for the newspaper corpus
    print("Load and preprocess text")
    imdb = load_dataset("imdb")
    d = {"text": imdb["train"]["text"], "label": imdb["train"]["label"]}
    imdb_df = pd.DataFrame(data=d)

    if text_preprocessing:
        tp = TextProcessor(lowercase=True)
        text_proc = tp.fit_transform(imdb_df["text"])
        imdb_df["text"] = text_proc.values

    # Prepare data for modeling
    ds = df2dict(imdb_df)
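    # df2dict is a project-local helper; from its use below it evidently
    # returns a datasets.DatasetDict with "train" and "test" splits.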
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
    tokenized_text = ds.map(preprocess_function, batched=True)
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="tf")
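    # DataCollatorWithPadding pads each batch on the fly to the length of its
    # longest example and returns TensorFlow tensors (return_tensors="tf").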

    # Training
    print("Train model")
    id2label = {0: "NEGATIVE", 1: "POSITIVE"}
    label2id = {"NEGATIVE": 0, "POSITIVE": 1}
    optimizer, _ = prepare_training(tokenized_text)
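    # prepare_training is a project-local helper; presumably it derives the
    # number of training steps from the dataset and builds the optimizer
    # (e.g. via transformers.create_optimizer with a learning-rate schedule).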
    model = TFAutoModelForSequenceClassification.from_pretrained(
        pretrained_model, num_labels=2, id2label=id2label, label2id=label2id
    )

    tf_train_set = model.prepare_tf_dataset(
        tokenized_text["train"],
        shuffle=True,
        batch_size=16,
        collate_fn=data_collator,
    )
    tf_validation_set = model.prepare_tf_dataset(
        tokenized_text["test"],
        shuffle=False,
        batch_size=16,
        collate_fn=data_collator,
    )
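    # prepare_tf_dataset wraps each split in a tf.data.Dataset, using the
    # collator to batch and pad the tokenized examples.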

    model.compile(optimizer=optimizer)
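    # No loss is passed: recent versions of transformers let their TF models
    # fall back to the built-in task loss when compiled without an explicit one.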

    # Define checkpoints and callbacks
    time_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    checkpoint_filepath = Path("../").joinpath("tmp/checkpoint/" + time_stamp)
    metric_callback = KerasMetricCallback(
        metric_fn=compute_metrics, eval_dataset=tf_validation_set
    )
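    # KerasMetricCallback runs compute_metrics on the validation set at the
    # end of every epoch and adds the results to the Keras logs.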
    checkpoint_callback = ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_loss",
        save_best_only=True,
        save_weights_only=False,
        mode="min",
        save_freq="epoch",
        initial_value_threshold=None,
    )
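    # With save_best_only=True and mode="min", a checkpoint is written only
    # when val_loss improves on the best value seen so far.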
    log_dir = "logs/fit/" + time_stamp
    tensorboard_callback = TensorBoard(log_dir=log_dir, histogram_freq=1)
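    # Training curves can be inspected with: tensorboard --logdir logs/fit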
    callbacks = [metric_callback, checkpoint_callback, tensorboard_callback]

    # Fit model
    print("Fit model")
    model.fit(
        x=tf_train_set,
        validation_data=tf_validation_set,
        epochs=5,
        verbose=2,
        callbacks=callbacks,
    )

    # Save model
    print("Save model")
    p_repo = Path("../")
    path_model = p_repo.joinpath("saved_models/" + time_stamp)
    model.save_pretrained(path_model)
    tokenizer.save_pretrained(path_model)
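    # The saved directory can later be reloaded for inference, e.g. with
    # TFAutoModelForSequenceClassification.from_pretrained(path_model) and
    # AutoTokenizer.from_pretrained(path_model).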

    # Save model logs
    save_logs(
        path_repo=p_repo,
        path_model=path_model,
        input_data=input_data,
        text_preprocessing=text_preprocessing,
        newspaper=newspaper,
        lang=None,
        topic=topic,
        hsprob=hsprob,
        remove_duplicates=None,
        min_num_words=None,
        model_name="BERT",
        pretrained_model=pretrained_model,
    )
    print("Done")


if __name__ == "__main__":
    # Demo invocation with hard-coded values; main.callback is the plain
    # function wrapped by click, so it can be called directly with Python
    # objects instead of going through the CLI parser.
    input_data = "imdb"
    text_preprocessing = False
    newspaper = None
    topic = None
    hsprob = None
    pretrained_model = "distilbert-base-uncased"
    main.callback(
        input_data=input_data,
        text_preprocessing=text_preprocessing,
        newspaper=newspaper,
        topic=topic,
        hsprob=hsprob,
        pretrained_model=pretrained_model,
    )