# eval_BERT.py
# Evaluation script for a fine-tuned TensorFlow BERT sequence-classification model.
# (Header reconstructed: the original was obscured by web-scrape residue.)
import os
from pathlib import Path
from typing import Union

import click
import numpy as np
import pandas as pd
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from src.preprocessing_text import TextLoader, TextProcessor
from src.train_logs import load_logs
from src.BERT_utils import predict_batches
from src.eval_utils import gen_scores_dict

@click.command()
@click.argument("train_logs")
def main(train_logs: Union[str, os.PathLike]):
    """
    Prepares data and evaluates a trained BERT model with TF.

    :param train_logs: path to csv-file containing train logs
        (path_repo, path_model, input_data, text_preprocessing,
        newspaper, lang, topic, hsprob, remove_duplicates,
        min_num_words, pretrained_model)
    """
    # NOTE(review): the unpacking below was reconstructed from a garbled
    # scrape using the variable names referenced later in this function —
    # confirm field order against src.train_logs.load_logs.
    (
        path_repo,
        path_model,
        input_data,
        text_preprocessing,
        newspaper,
        lang,
        topic,
        hsprob,
        remove_duplicates,
        min_num_words,
        pretrained_model,
    ) = load_logs(train_logs)

    # Load data and extract only text from tagesanzeiger
    print("Load and preprocess text")
    tl = TextLoader(input_data)
    # NOTE(review): the loader call was lost in the scrape; only its trailing
    # keyword arguments survived. Method name reconstructed — verify against
    # src.preprocessing_text.TextLoader.
    df_de = tl.load_text_csv(
        newspaper=newspaper,
        lang=lang,
        topic=topic,
        hsprob=hsprob,
        remove_duplicates=remove_duplicates,
        min_num_words=min_num_words,
    )

    # Optional text cleaning before tokenization
    if text_preprocessing:
        tp = TextProcessor()
        df_de.text = tp.fit_transform(df_de.text)

    # Most common topics; items are (topic, count)-style pairs (t[0] is the
    # topic name, see below)
    common_topics = tl.get_comments_per_topic(df_de)

    # Load tokenizer and the fine-tuned model weights
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
    model = TFAutoModelForSequenceClassification.from_pretrained(
        pretrained_model_name_or_path=path_model
    )
    y_pred_all, y_prob_all = predict_batches(df_de.text.values, model, tokenizer)

    # Overall scores. (Removed a leftover `pdb.set_trace()` debug breakpoint
    # that would have halted any non-interactive run.)
    precision, recall, f1, _ = precision_recall_fscore_support(
        df_de.label, y_pred_all, average="weighted"
    )
    accuracy = accuracy_score(df_de.label, y_pred_all)
    results_all = gen_scores_dict(precision, recall, f1, accuracy)

    # Eval per topic
    topics = [t[0] for t in common_topics]
    results_t = dict()
    for t in topics:
        y_test_t = df_de[df_de.topic == t].label
        y_pred_t = y_pred_all[df_de.topic == t]
        precision, recall, f1, _ = precision_recall_fscore_support(
            y_test_t, y_pred_t, average="weighted"
        )
        accuracy = accuracy_score(y_test_t, y_pred_t)
        results_t[t] = gen_scores_dict(precision, recall, f1, accuracy)

    # Rejection rate = mean of the binary label, expressed in percent
    reject_rate_all = np.round(df_de.label.mean(), 4) * 100
    reject_rate_topic = [
        np.round(df_de[df_de.topic == k].label.mean(), 4) * 100 for k in topics
    ]

    # Compute number comments
    num_comm_all = df_de.shape[0]
    num_comm_topic = [df_de[df_de.topic == k].shape[0] for k in topics]

    # Save results labels
    df_res_all = pd.DataFrame.from_dict(results_all, orient="index", columns=["all"])
    df_res_all.loc["rejection rate"] = reject_rate_all
    df_res_all.loc["number comments"] = num_comm_all
    df_res_topic = pd.DataFrame.from_dict(results_t)
    df_res_topic.loc["rejection rate"] = reject_rate_topic
    df_res_topic.loc["number comments"] = num_comm_topic

    df_res = df_res_all.join(df_res_topic)
    df_res.loc["data"] = [input_data] * df_res.shape[1]
    df_res.to_csv(
        path_repo + "/results/results_eval_BERT/" + Path(path_model).stem + ".csv"
    )

    # Save results probs
    df_prob_all = df_de.copy()
    df_prob_all["bert_probability"] = y_prob_all
    df_prob_all.to_csv(
        path_repo
        + "/results/results_eval_BERT/"
        + Path(path_model).stem
        + "_bert_probability.csv"
    )


# Script entry point: click parses the TRAIN_LOGS argument from sys.argv.
if __name__ == "__main__":
    main()