From 52ee3dbc5304eb65908931fc3ebb74f2839485e0 Mon Sep 17 00:00:00 2001 From: Franziska Oschmann <franziskaoschmann@staff-net-vpn-dhcp-1778.intern.ethz.ch> Date: Mon, 10 Jul 2023 10:21:52 +0200 Subject: [PATCH] Adjust main.py for loading of logs --- moderation_classifier/main.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/moderation_classifier/main.py b/moderation_classifier/main.py index cab59e9..df05af2 100644 --- a/moderation_classifier/main.py +++ b/moderation_classifier/main.py @@ -5,6 +5,7 @@ from src.preprocessing_df import DataProcessor import moderation_classifier.split_data as split_data import moderation_classifier.train_MNB as train_MNB import moderation_classifier.train_BERT as train_BERT +import moderation_classifier.eval_MNB as eval_MNB import moderation_classifier.eval_BERT as eval_BERT import moderation_classifier.train_BERT_torch as train_BERT_torch @@ -18,9 +19,9 @@ import os @click.option("-tp", "--text_preprocessing", is_flag=True) @click.option("-tm", "--train_mnb", is_flag=True) @click.option("-tb", "--train_bert", is_flag=True) +@click.option("-em", "--eval_mnb", is_flag=True) @click.option("-eb", "--eval_bert", is_flag=True) @click.option("-tbto", "--train_bert_torch", is_flag=True) -@click.option("-mt", "--model_timestemp") @click.argument("input_data") def main( split: bool, @@ -28,9 +29,9 @@ def main( text_preprocessing: bool, train_mnb: bool, train_bert: bool, + eval_mnb: bool, eval_bert: bool, train_bert_torch: bool, - model_timestemp: Union[str, os.PathLike], input_data: Union[str, os.PathLike], ): """ @@ -40,6 +41,7 @@ def main( :param text_preprocessing: Binary flag to set text preprocessing. :param train_mnb: Binary flag to specify whether MNB should be trained. :param train_bert: Binary flag to specify whether BERT should be trained. + :param eval_mnb: Binary flag to specify whether MNB should be evaluated. :param eval_bert: Binary flag to specify whether BERT should be evaluated. :param model_timestemp: Path to trained model. :param input_data: Path to input dataframe. @@ -60,8 +62,11 @@ def main( if train_bert: train_BERT.main(input_data, text_preprocessing) + if eval_mnb: + eval_MNB.main(input_data) + if eval_bert: - eval_BERT.main(input_data, model_timestemp) + eval_BERT.main(input_data) if train_bert_torch: train_BERT_torch.main(input_data) -- GitLab