# main.py: command-line entry point for the moderation classifier.
# imports
import click

from src.preprocessing_df import DataProcessor
import moderation_classifier.split_data as split_data
import moderation_classifier.train_MNB as train_MNB
import moderation_classifier.train_BERT as train_BERT
import moderation_classifier.eval_MNB as eval_MNB
import moderation_classifier.eval_BERT as eval_BERT
import moderation_classifier.train_BERT_torch as train_BERT_torch

from typing import Union
import os

# Command-line interface.

@click.command()
@click.option("-s", "--split", is_flag=True)
@click.option("-p", "--prepare_data", is_flag=True)
@click.option("-tp", "--text_preprocessing", is_flag=True)
@click.option("-n", "--newspaper", default=None)
@click.option("-t", "--topic", default=None)
@click.option("-h", "--hsprob", default=None)
@click.option("-pm", "--pretrained_model", default=None)
@click.option("-tm", "--train_mnb", is_flag=True)
@click.option("-tb", "--train_bert", is_flag=True)
@click.option("-em", "--eval_mnb", is_flag=True)
@click.option("-eb", "--eval_bert", is_flag=True)
@click.option("-tbto", "--train_bert_torch", is_flag=True)
@click.argument("input_data")
def main(
    split: bool,
    prepare_data: bool,
    text_preprocessing: bool,
    newspaper: Union[str, None],
    topic: Union[str, None],
    hsprob: Union[str, None],
    pretrained_model: Union[str, None],
    train_mnb: bool,
    train_bert: bool,
    eval_mnb: bool,
    eval_bert: bool,
    train_bert_torch: bool,
    input_data: Union[str, os.PathLike],
):
    """
    Run moderation classifier.

    Each flag enables one independent step; several flags may be combined in
    a single invocation and the enabled steps run in the order listed below.

    :param split: Binary flag to specify if data should be split.
    :param prepare_data: Binary flag to specify if data should be prepared.
    :param text_preprocessing: Binary flag to set text preprocessing.
    :param newspaper: Name of newspaper selected for training.
    :param topic: Topic selected for training.
    :param hsprob: String holding a Python literal with the min/max values
        for hate speech probability (e.g. "[0.2, 0.8]"); None if not given.
    :param pretrained_model: Name of pretrained BERT model to use for finetuning.
    :param train_mnb: Binary flag to specify whether MNB should be trained.
    :param train_bert: Binary flag to specify whether BERT should be trained.
    :param eval_mnb: Binary flag to specify whether MNB should be evaluated.
    :param eval_bert: Binary flag to specify whether BERT should be evaluated.
    :param train_bert_torch: Binary flag to specify whether the torch BERT
        training entry point should be run.
    :param input_data: Path to input dataframe.
    """
    # click passes every declared option/argument as a keyword argument, so
    # the signature above must name each of them exactly once.

    if split:
        split_data.main(input_data)

    if prepare_data:
        dp = DataProcessor(input_data)
        dp.add_language()
        print(input_data)
        print("Prepare data")

    if train_mnb:
        train_MNB.main(input_data, newspaper, topic)

    if train_bert:
        if hsprob is not None:
            # The option arrives as a string; turn it into a Python object.
            # NOTE(review): eval() executes arbitrary code from the command
            # line. If only literals such as "[0.2, 0.8]" are expected,
            # ast.literal_eval would be the safer choice -- confirm and switch.
            hsprob = eval(hsprob)
        # NOTE(review): positional order assumed to match train_BERT.main's
        # signature -- verify against moderation_classifier/train_BERT.py.
        train_BERT.main(
            input_data, text_preprocessing, newspaper, topic, hsprob, pretrained_model
        )

    if eval_mnb:
        eval_MNB.main(input_data)

    # BERT evaluation has its own flag and must not depend on --eval_mnb.
    if eval_bert:
        eval_BERT.main(input_data)

    if train_bert_torch:
        train_BERT_torch.main(input_data)

if __name__ == "__main__":