Commit c845bf17 authored by Franziska Oschmann

Add scripts to run training for IMDB

parent 0f7bac13
#!/bin/bash
module load gcc/8.2.0 python_gpu/3.10.4 eth_proxy
source ../pp_env_tf_python310/bin/activate
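# Note: sbatch exports the caller's environment by default (--export=ALL),
# so the modules and virtualenv loaded above are visible inside the job
# started by --wrap.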
sbatch --mem-per-cpu=12g \
    --gpus=1 \
    --gres=gpumem:12g \
    --time=30:00:00 \
    --wrap "python ../moderation_classifier/train_BERT.py"
import datetime
import os
from pathlib import Path
from typing import Optional, Union

import pandas as pd
from datasets import load_dataset
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
from transformers import (
    AutoTokenizer,
    DataCollatorWithPadding,
    TFAutoModelForSequenceClassification,
)
from transformers.keras_callbacks import KerasMetricCallback

from src.preprocessing_text import TextProcessor
from src.prepare_bert_tf import df2dict, compute_metrics, prepare_training
from src.train_logs import save_logs


def main(
    input_data: Union[str, os.PathLike],
    text_preprocessing: bool,
    newspaper: Optional[str],
    topic: Optional[str],
    hsprob: Optional[list],
    pretrained_model: str,
):
"""
Prepares data and trains BERT model with TF
:param input_data: path to input data
:param text_preprocessing: Binary flag to set text preprocessing.
:param newspaper: Name of newspaper selected for training.
:param topic: Topic selected for training.
:param hsprob: List with min max values for hate speech probability
:param pretrained_model: Name of pretrained BERT model to use for finetuning.
"""
def preprocess_function(examples):
"""
Prepares tokenizer for mapping
"""
return tokenizer(examples["text"], truncation=True)
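    # truncation=True cuts inputs to the tokenizer's model_max_length
    # (512 for distilbert-base-uncased); padding is left to the
    # DataCollatorWithPadding below, so batches are padded dynamically.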
    # Load IMDB data
    print("Load and preprocess text")
imdb = load_dataset("imdb")
d = {'text': imdb['train']['text'], 'label': imdb['train']['label']}
imdb_df = pd.DataFrame(data=d)
if text_preprocessing:
tp = TextProcessor(lowercase=True)
text_proc = tp.fit_transform(imdb_df['text'])
imdb_df['text'] = text_proc.values
# Prepare data for modeling
ds = df2dict(imdb_df)
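    # df2dict (imported from src.prepare_bert_tf, not shown in this commit) is
    # assumed to turn the DataFrame into a datasets.DatasetDict with "train"
    # and "test" splits, roughly:
    #   from datasets import Dataset
    #   ds = Dataset.from_pandas(imdb_df).train_test_split(test_size=0.2)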
tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
tokenized_text = ds.map(preprocess_function, batched=True)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="tf")
# Training
print("Train model")
id2label = {0: "NEGATIVE", 1: "POSITIVE"}
label2id = {"NEGATIVE": 0, "POSITIVE": 1}
optimizer, _ = prepare_training(tokenized_text)
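    # prepare_training (not shown here) is assumed to build a Keras optimizer
    # with a decaying learning-rate schedule, roughly equivalent to:
    #   from transformers import create_optimizer
    #   steps = (len(tokenized_text["train"]) // 16) * 5  # batches * epochs
    #   optimizer, _ = create_optimizer(
    #       init_lr=2e-5, num_train_steps=steps, num_warmup_steps=0
    #   )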
model = TFAutoModelForSequenceClassification.from_pretrained(
pretrained_model, num_labels=2, id2label=id2label, label2id=label2id
)
tf_train_set = model.prepare_tf_dataset(
tokenized_text["train"],
shuffle=True,
batch_size=16,
collate_fn=data_collator,
)
tf_validation_set = model.prepare_tf_dataset(
tokenized_text["test"],
shuffle=False,
batch_size=16,
collate_fn=data_collator,
)
model.compile(optimizer=optimizer)
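    # No loss is passed to compile(): transformers TF models then fall back to
    # their internal loss computation, and "loss"/"val_loss" still appear in
    # the Keras logs, so the checkpoint monitor below works.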
# Define checkpoint
    time_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    checkpoint_filepath = Path("../").joinpath("tmp/checkpoint/" + time_stamp)
metric_callback = KerasMetricCallback(
metric_fn=compute_metrics, eval_dataset=tf_validation_set
)
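    # compute_metrics (not shown here) must follow the KerasMetricCallback
    # contract: take a (predictions, labels) tuple and return a dict of
    # metric values, e.g. roughly:
    #   def compute_metrics(eval_pred):
    #       logits, labels = eval_pred
    #       return {"accuracy": float((logits.argmax(-1) == labels).mean())}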
checkpoint_callback = ModelCheckpoint(
checkpoint_filepath,
monitor="val_loss",
save_best_only=True,
save_weights_only=False,
mode="min",
save_freq="epoch",
initial_value_threshold=None,
)
    log_dir = "logs/fit/" + time_stamp
tensorboard_callback = TensorBoard(log_dir=log_dir, histogram_freq=1)
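    # Inspect the training curves with: tensorboard --logdir logs/fit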
callbacks = [metric_callback, checkpoint_callback, tensorboard_callback]
# Fit model
print("Train model")
model.fit(
x=tf_train_set,
validation_data=tf_validation_set,
epochs=5,
verbose=2,
callbacks=callbacks,
)
# Save model
print("Save model")
    p_repo = Path("../")
    path_model = p_repo.joinpath("saved_models/" + time_stamp)
model.save_pretrained(path_model)
tokenizer.save_pretrained(path_model)
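    # The saved model can later be reloaded with
    # TFAutoModelForSequenceClassification.from_pretrained(path_model)
    # and AutoTokenizer.from_pretrained(path_model).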
# Save model logs
save_logs(
path_repo=p_repo,
path_model=path_model,
input_data=input_data,
        text_preprocessing=text_preprocessing,
newspaper=newspaper,
lang=None,
topic=topic,
hsprob=hsprob,
remove_duplicates=None,
min_num_words=None,
model_name="BERT",
pretrained_model=pretrained_model,
)
print("Done")
if __name__ == "__main__":
    input_data = "imdb"
    text_preprocessing = False
    newspaper = None
    topic = None
    hsprob = None
    pretrained_model = "distilbert-base-uncased"

    main(
        input_data=input_data,
        text_preprocessing=text_preprocessing,
        newspaper=newspaper,
        topic=topic,
        hsprob=hsprob,
        pretrained_model=pretrained_model,
    )