diff --git a/moderation_classifier/main.py b/moderation_classifier/main.py index cab59e9c8b0797883f77e176ae2acd1b7e0166ba..df05af20a638dda4fa1ea3195968605966965f5e 100644 --- a/moderation_classifier/main.py +++ b/moderation_classifier/main.py @@ -5,6 +5,7 @@ from src.preprocessing_df import DataProcessor import moderation_classifier.split_data as split_data import moderation_classifier.train_MNB as train_MNB import moderation_classifier.train_BERT as train_BERT +import moderation_classifier.eval_MNB as eval_MNB import moderation_classifier.eval_BERT as eval_BERT import moderation_classifier.train_BERT_torch as train_BERT_torch @@ -18,9 +19,9 @@ import os @click.option("-tp", "--text_preprocessing", is_flag=True) @click.option("-tm", "--train_mnb", is_flag=True) @click.option("-tb", "--train_bert", is_flag=True) +@click.option("-em", "--eval_mnb", is_flag=True) @click.option("-eb", "--eval_bert", is_flag=True) @click.option("-tbto", "--train_bert_torch", is_flag=True) -@click.option("-mt", "--model_timestemp") @click.argument("input_data") def main( split: bool, @@ -28,9 +29,9 @@ def main( text_preprocessing: bool, train_mnb: bool, train_bert: bool, + eval_mnb: bool, eval_bert: bool, train_bert_torch: bool, - model_timestemp: Union[str, os.PathLike], input_data: Union[str, os.PathLike], ): """ @@ -40,6 +41,7 @@ def main( :param text_preprocessing: Binary flag to set text preprocessing. :param train_mnb: Binary flag to specify whether MNB should be trained. :param train_bert: Binary flag to specify whether BERT should be trained. + :param eval_mnb: Binary flag to specify whether MNB should be evaluated. :param eval_bert: Binary flag to specify whether BERT should be evaluated. :param model_timestemp: Path to trained model. :param input_data: Path to input dataframe. @@ -60,8 +62,11 @@ def main( if train_bert: train_BERT.main(input_data, text_preprocessing) + if eval_mnb: + eval_MNB.main(input_data) + if eval_bert: - eval_BERT.main(input_data, model_timestemp) + eval_BERT.main(input_data) if train_bert_torch: train_BERT_torch.main(input_data)