diff --git a/euler/train_model_cluster_imdb.sh b/euler/train_model_cluster_imdb.sh
new file mode 100644
index 0000000000000000000000000000000000000000..913c0caafc7eae02c392d28d0969ca5cc709aac7
--- /dev/null
+++ b/euler/train_model_cluster_imdb.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+
+module load gcc/8.2.0 python_gpu/3.10.4 eth_proxy
+source ../pp_env_tf_python310/bin/activate
+
+# Submit the IMDB fine-tuning job to Slurm
+sbatch --mem-per-cpu=12g \
+    --gpus=1 \
+    --gres=gpumem:12g \
+    --time=30:00:00 \
+    --wrap "python ../moderation_classifier/train_BERT_imdb.py"
diff --git a/moderation_classifier/train_BERT_imdb.py b/moderation_classifier/train_BERT_imdb.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba282b5a881f38326c3eea1200cb2190675387e9
--- /dev/null
+++ b/moderation_classifier/train_BERT_imdb.py
@@ -0,0 +1,171 @@
+from transformers import AutoTokenizer
+from transformers import DataCollatorWithPadding
+from transformers import TFAutoModelForSequenceClassification
+from transformers.keras_callbacks import KerasMetricCallback
+
+from tensorflow.keras.callbacks import ModelCheckpoint
+from tensorflow.keras.callbacks import TensorBoard
+
+from datasets import load_dataset
+
+import click
+import datetime
+import os
+import pandas as pd
+from pathlib import Path
+from typing import Union
+
+from src.preprocessing_text import TextProcessor
+from src.prepare_bert_tf import df2dict, compute_metrics, prepare_training
+from src.train_logs import save_logs
+
+
+# NOTE: without @click.command() these argument declarations have no effect;
+# main() is called directly in the __main__ block below.
+@click.argument("input_data", required=True)
+@click.argument("text_preprocessing", required=False)
+@click.argument("newspaper", required=False)
+@click.argument("topic", required=False)
+@click.argument("hsprob", required=False)
+@click.argument("pretrained_model", required=True)
+def main(
+    input_data: Union[str, os.PathLike],
+    text_preprocessing: bool,
+    newspaper: str,
+    topic: str,
+    hsprob: list,
+    pretrained_model: str,
+):
+    """
+    Prepares the IMDB data and trains a BERT model with TensorFlow.
+    :param input_data: Path to input data.
+    :param text_preprocessing: Binary flag to enable text preprocessing.
+    :param newspaper: Name of newspaper selected for training.
+    :param topic: Topic selected for training.
+    :param hsprob: List with min/max values for hate speech probability.
+    :param pretrained_model: Name of pretrained BERT model to use for fine-tuning.
+    """
+
+    def preprocess_function(examples):
+        """
+        Tokenizes the text column; used with Dataset.map.
+        """
+        return tokenizer(examples["text"], truncation=True)
+
+    # Load the IMDB dataset and keep the text and labels of the train split
+    print("Load and preprocess text")
+    imdb = load_dataset("imdb")
+
+    d = {"text": imdb["train"]["text"], "label": imdb["train"]["label"]}
+    imdb_df = pd.DataFrame(data=d)
+
+    if text_preprocessing:
+        tp = TextProcessor(lowercase=True)
+        text_proc = tp.fit_transform(imdb_df["text"])
+        imdb_df["text"] = text_proc.values
+
+    # Prepare data for modeling
+    ds = df2dict(imdb_df)
+
+    tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
+    tokenized_text = ds.map(preprocess_function, batched=True)
+    data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="tf")
+
+    # Training
+    print("Train model")
+    id2label = {0: "NEGATIVE", 1: "POSITIVE"}
+    label2id = {"NEGATIVE": 0, "POSITIVE": 1}
+
+    optimizer, _ = prepare_training(tokenized_text)
+    model = TFAutoModelForSequenceClassification.from_pretrained(
+        pretrained_model, num_labels=2, id2label=id2label, label2id=label2id
+    )
+
+    tf_train_set = model.prepare_tf_dataset(
+        tokenized_text["train"],
+        shuffle=True,
+        batch_size=16,
+        collate_fn=data_collator,
+    )
+
+    tf_validation_set = model.prepare_tf_dataset(
+        tokenized_text["test"],
+        shuffle=False,
+        batch_size=16,
+        collate_fn=data_collator,
+    )
+
+    model.compile(optimizer=optimizer)
+
+    # Define callbacks: metrics, checkpointing, and TensorBoard logging
+    time_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
+    checkpoint_filepath = Path("../").joinpath("tmp/checkpoint/" + time_stamp)
+    metric_callback = KerasMetricCallback(
+        metric_fn=compute_metrics, eval_dataset=tf_validation_set
+    )
+    checkpoint_callback = ModelCheckpoint(
+        checkpoint_filepath,
+        monitor="val_loss",
+        save_best_only=True,
+        save_weights_only=False,
+        mode="min",
+        save_freq="epoch",
+        initial_value_threshold=None,
+    )
+    log_dir = "logs/fit/" + time_stamp
+    tensorboard_callback = TensorBoard(log_dir=log_dir, histogram_freq=1)
+
+    callbacks = [metric_callback, checkpoint_callback, tensorboard_callback]
+
+    # Fit model
+    print("Fit model")
+    model.fit(
+        x=tf_train_set,
+        validation_data=tf_validation_set,
+        epochs=5,
+        verbose=2,
+        callbacks=callbacks,
+    )
+
+    # Save model and tokenizer
+    print("Save model")
+    p_repo = Path("../")
+    path_model = p_repo.joinpath("saved_models/" + time_stamp)
+    model.save_pretrained(path_model)
+    tokenizer.save_pretrained(path_model)
+
+    # Save model logs
+    save_logs(
+        path_repo=p_repo,
+        path_model=path_model,
+        input_data=input_data,
+        text_preprocessing=text_preprocessing,
+        newspaper=newspaper,
+        lang=None,
+        topic=topic,
+        hsprob=hsprob,
+        remove_duplicates=None,
+        min_num_words=None,
+        model_name="BERT",
+        pretrained_model=pretrained_model,
+    )
+
+    print("Done")
+
+
+if __name__ == "__main__":
+    input_data = "imdb"
+    text_preprocessing = False
+    newspaper = None
+    topic = None
+    hsprob = None
+    pretrained_model = "distilbert-base-uncased"
+
+    main(
+        input_data=input_data,
+        text_preprocessing=text_preprocessing,
+        newspaper=newspaper,
+        topic=topic,
+        hsprob=hsprob,
+        pretrained_model=pretrained_model,
+    )