import os
from pathlib import Path
from typing import List, Union

import click
import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from src.BERT_utils import predict_batches
from src.eval_utils import gen_scores_dict
from src.preprocessing_text import TextLoader, TextProcessor
from src.train_logs import load_logs


def _weighted_scores(y_true, y_pred):
    """Return weighted precision/recall/F1 plus accuracy, packed by gen_scores_dict."""
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average="weighted"
    )
    accuracy = accuracy_score(y_true, y_pred)
    return gen_scores_dict(precision, recall, f1, accuracy)


@click.command()
@click.argument("train_logs")
def main(train_logs: Union[str, os.PathLike]) -> None:
    """
    Prepares data and evaluates trained BERT model with TF

    Reloads the data with the settings stored in the train logs, predicts
    all comments with the model stored at ``path_model`` and writes two
    csv-files to ``<path_repo>/results/results_eval_BERT/``: the scores
    (overall and per topic) and a copy of the data with the predicted
    probabilities in a ``bert_probability`` column.

    :param train_logs: path to csv-file containing train logs
    """
    # Load logs
    (
        path_repo,
        path_model,
        input_data,
        text_preprocessing,
        newspaper,
        lang,
        topic,
        remove_duplicates,
        min_num_words,
        pretrained_model,
    ) = load_logs(train_logs)

    # Create the output directory up front so a missing folder does not
    # discard the results after the (expensive) prediction step.
    results_dir = Path(path_repo) / "results" / "results_eval_BERT"
    results_dir.mkdir(parents=True, exist_ok=True)
    model_name = Path(path_model).stem

    # Load data, filtered with the same settings that were logged for training
    print("Load and preprocess text")
    tl = TextLoader(input_data)
    df_de = tl.load_text_csv(
        newspaper=newspaper,
        lang=lang,
        topic=topic,
        load_subset=False,
        remove_duplicates=remove_duplicates,
        min_num_words=min_num_words,
    )

    if text_preprocessing:
        tp = TextProcessor()
        df_de["text"] = tp.fit_transform(df_de["text"])

    common_topics = tl.get_comments_per_topic(df_de)

    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
    model = TFAutoModelForSequenceClassification.from_pretrained(
        pretrained_model_name_or_path=path_model
    )

    # Predict in batches
    y_pred_all, y_prob_all = predict_batches(df_de["text"].values, model, tokenizer)

    # eval all
    results_all = _weighted_scores(df_de["label"], y_pred_all)

    # eval per topic; the first element of each entry is used as topic name
    topics = [t[0] for t in common_topics]
    # Plain numpy mask/array so the selection is positional and does not
    # depend on the index of df_de.
    y_pred_arr = np.asarray(y_pred_all)
    results_t = {}
    reject_rate_topic = []
    num_comm_topic = []
    for t in topics:
        in_topic = (df_de["topic"] == t).to_numpy()
        labels_t = df_de.loc[in_topic, "label"]
        results_t[t] = _weighted_scores(labels_t, y_pred_arr[in_topic])
        # Rejection rate: mean of the labels, expressed in percent
        reject_rate_topic.append(np.round(labels_t.mean(), 4) * 100)
        num_comm_topic.append(len(labels_t))

    # Same two statistics over the whole data set
    reject_rate_all = np.round(df_de["label"].mean(), 4) * 100
    num_comm_all = df_de.shape[0]

    # One column "all" followed by one column per topic
    df_res_all = pd.DataFrame.from_dict(results_all, orient="index", columns=["all"])
    df_res_all.loc["rejection rate"] = reject_rate_all
    df_res_all.loc["number comments"] = num_comm_all

    df_res_topic = pd.DataFrame.from_dict(results_t)
    df_res_topic.loc["rejection rate"] = reject_rate_topic
    df_res_topic.loc["number comments"] = num_comm_topic

    df_res = df_res_all.join(df_res_topic)
    # Record which input file the scores were computed on
    df_res.loc["data"] = [input_data] * df_res.shape[1]
    df_res.to_csv(results_dir / f"{model_name}.csv")

    # Save results probs
    df_prob_all = df_de.copy()
    df_prob_all["bert_probability"] = y_prob_all
    df_prob_all.to_csv(results_dir / f"{model_name}_bert_probability.csv")


# Run the evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()