from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
import tensorflow as tf

import click
import numpy as np
import os
import pandas as pd
from pathlib import Path
from typing import Union

from sklearn.metrics import precision_recall_fscore_support, accuracy_score

from src.preprocessing_text import TextLoader, TextProcessor
from src.train_logs import load_logs
from src.BERT_utils import predict_batches
from src.eval_utils import gen_scores_dict


@click.command()
@click.argument("train_logs")
def main(train_logs: Union[str, os.PathLike]):
    """
    Prepares the data and evaluates a trained BERT model with TensorFlow.
    :param train_logs: path to the CSV file containing the train logs
    """

    # Load logs
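    # The logs record the data filters used during training as well as the paths to
    # the repository, the fine-tuned model, and the pretrained base model (name or path).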
    (
        path_repo,
        path_model,
        input_data,
        text_preprocessing,
        newspaper,
        lang,
        topic,
        hsprob,
        remove_duplicates,
        min_num_words,
        pretrained_model,
    ) = load_logs(train_logs)


    # Load the data and keep only the comments matching the train-log filters (e.g. the Tagesanzeiger subset)
    print("Load and preprocess text")
    tl = TextLoader(input_data)
    df_de = tl.load_text_csv(
        newspaper=newspaper,
        lang=lang,
        topic=topic,
        hsprob=hsprob,
        load_subset=False,
        remove_duplicates=remove_duplicates,
        min_num_words=min_num_words,
    )
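    # df_de is expected to contain at least `text`, `label`, and `topic` columns,
    # which are used below for prediction and per-topic evaluation.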

    # Optionally apply the same text preprocessing that was used during training
    if text_preprocessing:
        tp = TextProcessor()
        text_proc = tp.fit_transform(df_de.text)
        df_de.text = text_proc
    common_topics = tl.get_comments_per_topic(df_de)
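    # get_comments_per_topic appears to return (topic, count) pairs for the most
    # frequent topics; only the topic names are used further below.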

    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
    model = TFAutoModelForSequenceClassification.from_pretrained(
        pretrained_model_name_or_path=path_model
    )
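    # The tokenizer is loaded from the pretrained base model, while the fine-tuned
    # classification weights are loaded from the model path recorded in the logs.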

    # Predict in batches; returns the predicted label and class probability for each comment
    y_pred_all, y_prob_all = predict_batches(df_de.text.values, model, tokenizer)

    # Evaluate on all comments (precision/recall/F1 weighted by class support)
    precision, recall, f1, _ = precision_recall_fscore_support(
        df_de.label, y_pred_all, average="weighted"
    )
    accuracy = accuracy_score(df_de.label, y_pred_all)

    results_all = gen_scores_dict(precision, recall, f1, accuracy)
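    # gen_scores_dict is assumed to pack the four metrics into a dict keyed by metric
    # name; these dicts are converted into DataFrame columns further below.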

    # Evaluate per topic
    topics = [t[0] for t in common_topics]
    results_t = dict()

    for t in topics:
        y_test_t = df_de[df_de.topic == t].label
        y_pred_t = y_pred_all[df_de.topic == t]
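        # Note: indexing y_pred_all with the df_de mask assumes predict_batches
        # preserves the row order of df_de.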

        precision, recall, f1, _ = precision_recall_fscore_support(
            y_test_t, y_pred_t, average="weighted"
        )
        accuracy = accuracy_score(y_test_t, y_pred_t)

        results_t[t] = gen_scores_dict(precision, recall, f1, accuracy)

    # Compute the rejection rate (share of rejected comments, in percent)
    reject_rate_all = np.round(df_de.label.mean(), 4) * 100
    reject_rate_topic = [
        np.round(df_de[df_de.topic == k].label.mean(), 4) * 100 for k in topics
    ]

    # Compute the number of comments
    num_comm_all = df_de.shape[0]
    num_comm_topic = [df_de[df_de.topic == k].shape[0] for k in topics]

    # Save results for the predicted labels
    df_res_all = pd.DataFrame.from_dict(results_all, orient="index", columns=["all"])
    df_res_all.loc["rejection rate"] = reject_rate_all
    df_res_all.loc["number comments"] = num_comm_all

    df_res_topic = pd.DataFrame.from_dict(results_t)
    df_res_topic.loc["rejection rate"] = reject_rate_topic
    df_res_topic.loc["number comments"] = num_comm_topic

    df_res = df_res_all.join(df_res_topic)
    df_res.loc["data"] = [input_data] * df_res.shape[1]

    df_res.to_csv(
        Path(path_repo) / "results" / "results_eval_BERT" / f"{Path(path_model).stem}.csv"
    )

    # Save the per-comment BERT probabilities
    df_prob_all = df_de.copy()
    df_prob_all["bert_probability"] = y_prob_all
    df_prob_all.to_csv(
        Path(path_repo)
        / "results"
        / "results_eval_BERT"
        / f"{Path(path_model).stem}_bert_probability.csv"
    )


if __name__ == "__main__":
    main()