import click
from collections import Counter
from joblib import load
import numpy as np
import pandas as pd
from pathlib import Path
from sklearn.metrics import precision_recall_fscore_support

from typing import Union
import os

from src.preprocessing_text import TextLoader


def load_model(path: Union[str, os.PathLike]):
    """
    Loads the trained model pipeline from disk.
    :param path: path to the serialized (joblib) model file
    :return: fitted pipeline
    """
    pipe = load(path)

    return pipe


@click.command()
@click.argument("train_logs")
def main(train_logs: Union[str, os.PathLike]):
    """
    Prepares data and evaluates trained MNB model
    :param train_logs: path to csv-file containing train logs
    """

    # Load train logs; the test-set path is derived from the train-set path
    df = pd.read_csv(train_logs, index_col="Unnamed: 0")
    path_model = df.loc["path_model"].values[0]
    input_data = df.loc["input_data"].values[0].replace("train", "test")

    # Load model
    pipe = load_model(path_model)

    # Load test data
    tl = TextLoader(input_data)
    df_test = tl.load_text_csv(
        newspaper="tagesanzeiger",
        lang='de',
        load_subset=False,
        remove_duplicates=False,
        min_num_words=3,
    )
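    # df_test must provide the text, label and topic columns used below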
  
    X_test = df_test.text
    y_test = df_test.label

    # Make predictions on the held-out test set
    y_pred = pipe.predict(X_test)
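    # Weighted averaging weights per-class scores by class support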
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_test, y_pred, average='weighted'
    )
    accuracy = pipe.score(X_test, y_test)

    results_all = {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "accuracy": accuracy,
    }

    # Get results for the ten most frequent topics
    count_topics = Counter(df_test["topic"]).most_common(10)
    topics = [t[0] for t in count_topics]
    results_t = dict()

    for t in topics:
        X_test_t = df_test[df_test.topic == t].text
        y_test_t = df_test[df_test.topic == t].label

        y_pred_t = pipe.predict(X_test_t)
        precision, recall, f1, _ = precision_recall_fscore_support(
            y_test_t, y_pred_t, average='weighted'
        )
        accuracy = pipe.score(X_test_t, y_test_t)
        results_t[t] = {
            "accuracy": accuracy,
            "f1": f1,
            "precision": precision,
            "recall": recall,
        }

    # Compute rejection rate (mean of the binary labels) in percent
    reject_rate_all = np.round(df_test.label.mean(), 4) * 100
    reject_rate_topic = [
        np.round(df_test[df_test.topic == k].label.mean(), 4) * 100 for k in topics
    ]

    # Compute number of comments
    num_comm_all = df_test.shape[0]
    num_comm_topic = [df_test[df_test.topic == k].shape[0] for k in topics]

    # Save results
    df_res_all = pd.DataFrame.from_dict(results_all, orient="index", columns=["all"])
    df_res_all.loc["rejection rate"] = reject_rate_all
    df_res_all.loc["number comments"] = num_comm_all

    df_res_topic = pd.DataFrame.from_dict(results_t)
    df_res_topic.loc["rejection rate"] = reject_rate_topic
    df_res_topic.loc["number comments"] = num_comm_topic

    df_res = df_res_all.join(df_res_topic)
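    # Record which data file produced these metrics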
    df_res.loc["data"] = [input_data] * df_res.shape[1]

    df_res.to_csv("results/results_eval_MNB/" + Path(path_model).stem + ".csv")


if __name__ == "__main__":
    main()
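
# Example invocation (script and log filenames are illustrative):
#   python evaluate_mnb.py results/logs/train_log_MNB.csv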