from datasets import load_dataset
from evaluate import evaluator
from transformers import pipeline, AutoTokenizer, TFAutoModelForSequenceClassification
import tensorflow as tf

import click
import evaluate
import numpy as np
import os
import pandas as pd
from pathlib import Path
import timeit
from tqdm import tqdm
from typing import Union

from sklearn.metrics import precision_recall_fscore_support, accuracy_score

from src.preprocessing_text import TextLoader, TextProcessor
from src.prepare_bert_tf import df2dict


def _parse_bool(value) -> bool:
    """Interpret a flag read back from the train-log CSV as a boolean.

    The log column also holds path strings, so ``pd.read_csv`` can return a
    logged boolean as the *string* ``"False"`` -- which is truthy. Real
    booleans are passed through; anything else is compared as lower-case text.
    """
    if isinstance(value, (bool, np.bool_)):
        return bool(value)
    return str(value).strip().lower() in {"true", "1", "yes"}


def _predict_labels(texts, tokenizer, model, batch_size: int) -> np.ndarray:
    """Classify *texts* in batches of *batch_size* and return the argmax class ids.

    :param texts: non-empty list of strings to classify
    :param tokenizer: HuggingFace tokenizer producing TF tensors
    :param model: TF sequence-classification model returning ``.logits``
    :param batch_size: number of texts per forward pass (>= 1)
    :return: 1-D array of predicted class indices, in the order of *texts*
    """
    batches = [texts[i:i + batch_size] for i in range(0, len(texts), batch_size)]
    predictions = []
    for batch in tqdm(batches):
        inputs = tokenizer(batch, return_tensors="tf", padding=True, truncation=True)
        logits = model(**inputs).logits
        predictions.append(tf.argmax(logits, axis=1).numpy())
    return np.concatenate(predictions)


def _classification_metrics(y_true, y_pred) -> dict:
    """Return weighted precision/recall/F1 and accuracy, keyed by metric name."""
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average="weighted"
    )
    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "accuracy": accuracy_score(y_true, y_pred),
    }


@click.command()
@click.argument("train_logs")
@click.option(
    "--batch-size",
    default=100,
    show_default=True,
    type=click.IntRange(min=1),
    help="Number of comments tokenized and classified per forward pass.",
)
def main(train_logs: Union[str, os.PathLike], batch_size: int = 100):
    """
    Prepares data and evaluates trained BERT model with TF

    Reads the training log, loads the matching test split (the input path with
    "train" replaced by "test"), predicts labels with the saved model and writes
    overall and per-topic metrics to
    ``<path_repo>/results/results_eval_BERT/<model name>.csv``.

    :param train_logs: path to csv-file containing train logs
    :param batch_size: number of comments per model forward pass
    :raises ValueError: if the test split contains no comments
    """
    # Load logs (single-column CSV indexed by setting name)
    logs = pd.read_csv(train_logs, index_col="Unnamed: 0")
    path_repo = logs.loc["path_repo"].values[0]
    path_model = logs.loc["path_model"].values[0]
    # NOTE(review): replaces *every* "train" in the path, including directory names.
    input_data = logs.loc["input_data"].values[0].replace("train", "test")
    text_preprocessing = _parse_bool(logs.loc["text_preprocessing"].values[0])

    # Load data and extract only text from tagesanzeiger
    print("Load and preprocess text")
    tl = TextLoader(input_data)
    df_de = tl.load_text_csv(
        newspaper="tagesanzeiger",
        lang='de',
        load_subset=False,
        remove_duplicates=True,
        min_num_words=3,
    )

    if text_preprocessing:
        tp = TextProcessor()
        df_de.text = tp.fit_transform(df_de.text)
    common_topics = tl.get_comments_per_topic(df_de)

    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained("bert-base-german-cased")
    model = TFAutoModelForSequenceClassification.from_pretrained(
        pretrained_model_name_or_path=path_model
    )

    # Predict all comments batch-wise
    text_list = list(df_de.text.values)
    if not text_list:
        raise ValueError(f"No comments to evaluate in {input_data}")
    y_pred_all = _predict_labels(text_list, tokenizer, model, batch_size)

    # Eval all
    results_all = _classification_metrics(df_de.label, y_pred_all)
    reject_rate_all = np.round(df_de.label.mean(), 4) * 100
    num_comm_all = df_de.shape[0]

    # Eval per topic; the mask is computed once and reused for every statistic
    topics = [t[0] for t in common_topics]
    results_t = dict()
    reject_rate_topic = []
    num_comm_topic = []
    for t in topics:
        topic_mask = df_de.topic == t
        y_test_t = df_de.label[topic_mask]
        y_pred_t = y_pred_all[topic_mask.to_numpy()]
        results_t[t] = _classification_metrics(y_test_t, y_pred_t)
        reject_rate_topic.append(np.round(y_test_t.mean(), 4) * 100)
        num_comm_topic.append(int(topic_mask.sum()))

    # Assemble results: one "all" column plus one column per topic
    df_res_all = pd.DataFrame.from_dict(results_all, orient="index", columns=["all"])
    df_res_all.loc["rejection rate"] = reject_rate_all
    df_res_all.loc["number comments"] = num_comm_all

    df_res_topic = pd.DataFrame.from_dict(results_t)
    df_res_topic.loc["rejection rate"] = reject_rate_topic
    df_res_topic.loc["number comments"] = num_comm_topic

    # Left join: row order follows df_res_all
    df_res = df_res_all.join(df_res_topic)
    df_res.loc["data"] = [input_data] * df_res.shape[1]

    # Save results, creating the output directory if it does not exist yet
    out_dir = Path(path_repo) / "results" / "results_eval_BERT"
    out_dir.mkdir(parents=True, exist_ok=True)
    df_res.to_csv(out_dir / f"{Path(path_model).stem}.csv")


if __name__ == "__main__":
    # Run the evaluation only when this file is executed as a script, not when it is imported.
    main()