# train_BERT_imdb.py: fine-tune a pretrained BERT model on the IMDB dataset with TensorFlow.
from transformers import AutoTokenizer
from transformers import DataCollatorWithPadding
from transformers import TFAutoModelForSequenceClassification
from transformers.keras_callbacks import KerasMetricCallback

from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.callbacks import TensorBoard

from datasets import load_dataset

import click
import datetime
import os
import pandas as pd
from pathlib import Path
import spacy
from typing import Union

from src.preprocessing_text import TextLoader, TextProcessor
from src.prepare_bert_tf import df2dict, compute_metrics, prepare_training
from src.train_logs import save_logs


@click.argument("input_data", required=True)
@click.argument("text_preprocessing", required=False)
@click.argument("newspaper", required=False)
@click.argument("topic", required=False)
@click.argument("pretrained_model", required=True)
def main(
    input_data: Union[str, os.PathLike],
    text_preprocessing: bool,
    newspaper: str,
    topic: str,
    hsprob: list,
    pretrained_model: str,
):
    """
    Prepares the IMDB data and fine-tunes a pretrained BERT model with TF.

    Training texts are always loaded from the Hugging Face ``imdb`` dataset
    (its ``train`` split). ``input_data``, ``newspaper``, ``topic`` and
    ``hsprob`` do not influence training; they are only forwarded to
    ``save_logs`` for bookkeeping.

    NOTE(review): the ``click.argument`` decorators have no effect without a
    ``@click.command`` decorator, so ``main`` is currently only usable as a
    plain function (as in the ``__main__`` guard). ``hsprob`` has no click
    argument at all.

    :param input_data: Name/path of the input data (recorded in the logs only).
    :param text_preprocessing: If truthy, texts are passed through
        ``TextProcessor(lowercase=True)`` before tokenization.
    :param newspaper: Name of newspaper selected for training (logged only).
    :param topic: Topic selected for training (logged only).
    :param hsprob: List with min max values for hate speech probability
        (logged only).
    :param pretrained_model: Name of pretrained BERT model to use for finetuning.
    """

    def preprocess_function(examples):
        """Tokenize a batch of examples with truncation enabled."""
        # `tokenizer` is resolved from the enclosing scope at call time,
        # i.e. after it has been assigned below.
        return tokenizer(examples["text"], truncation=True)

    # Load the IMDB training split into a DataFrame
    print("Load and preprocess text")
    imdb = load_dataset("imdb")
    imdb_df = pd.DataFrame(
        data={"text": imdb["train"]["text"], "label": imdb["train"]["label"]}
    )

    if text_preprocessing:
        tp = TextProcessor(lowercase=True)
        text_proc = tp.fit_transform(imdb_df["text"])
        imdb_df["text"] = text_proc.values

    # Prepare data for modeling.
    # Presumably df2dict returns a DatasetDict with "train"/"test" splits
    # (both are indexed below) - confirm against src.prepare_bert_tf.
    ds = df2dict(imdb_df)

    tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
    tokenized_text = ds.map(preprocess_function, batched=True)
    # Padding is applied per batch at collation time, not during `map`
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="tf")

    # Build the model
    print("Prepare model")
    id2label = {0: "NEGATIVE", 1: "POSITIVE"}
    # Derived from id2label so the two mappings can never disagree
    label2id = {label: idx for idx, label in id2label.items()}

    optimizer, _ = prepare_training(tokenized_text)
    model = TFAutoModelForSequenceClassification.from_pretrained(
        pretrained_model,
        num_labels=len(id2label),
        id2label=id2label,
        label2id=label2id,
    )

    batch_size = 16
    tf_train_set = model.prepare_tf_dataset(
        tokenized_text["train"],
        shuffle=True,
        batch_size=batch_size,
        collate_fn=data_collator,
    )
    tf_validation_set = model.prepare_tf_dataset(
        tokenized_text["test"],
        shuffle=False,
        batch_size=batch_size,
        collate_fn=data_collator,
    )

    # No loss is passed: Transformers TF models fall back to their built-in
    # task loss when compiled without one.
    model.compile(optimizer=optimizer)

    # Output locations are timestamped so repeated runs don't overwrite each other
    time_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    p_repo = Path("..")
    path_checkpoint = p_repo / "tmp" / "checkpoint" / time_stamp

    # Callbacks: validation metrics, best-val-loss checkpoint, TensorBoard logs
    metric_callback = KerasMetricCallback(
        metric_fn=compute_metrics, eval_dataset=tf_validation_set
    )
    checkpoint_callback = ModelCheckpoint(
        path_checkpoint,
        monitor="val_loss",
        save_best_only=True,
        save_weights_only=False,
        mode="min",
        save_freq="epoch",
        initial_value_threshold=None,
    )
    # NOTE(review): unlike the checkpoint/model paths, log_dir is relative to
    # the current working directory, not to p_repo - confirm this is intended.
    log_dir = "logs/fit/" + time_stamp
    tensorboard_callback = TensorBoard(log_dir=log_dir, histogram_freq=1)

    callbacks = [metric_callback, checkpoint_callback, tensorboard_callback]

    # Fit model
    print("Train model")
    model.fit(
        x=tf_train_set,
        validation_data=tf_validation_set,
        epochs=5,
        verbose=2,
        callbacks=callbacks,
    )

    # Save model and tokenizer side by side so they can be reloaded together
    print("Save model")
    path_model = p_repo / "saved_models" / time_stamp
    model.save_pretrained(path_model)
    tokenizer.save_pretrained(path_model)

    # Save model logs; record whether preprocessing actually ran
    # (previously hard-coded to True).
    save_logs(
        path_repo=p_repo,
        path_model=path_model,
        input_data=input_data,
        text_preprocessing=bool(text_preprocessing),
        newspaper=newspaper,
        lang=None,
        topic=topic,
        hsprob=hsprob,
        remove_duplicates=None,
        min_num_words=None,
        model_name="BERT",
        pretrained_model=pretrained_model,
    )

    print("Done")


if __name__ == "__main__":
    # Default run configuration: the IMDB dataset, no extra text
    # preprocessing, no newspaper/topic/hate-speech filters, and the
    # uncased DistilBERT checkpoint as the starting model.
    run_config = {
        "input_data": "imdb",
        "text_preprocessing": False,
        "newspaper": None,
        "topic": None,
        "hsprob": None,
        "pretrained_model": "distilbert-base-uncased",
    }
    main(**run_config)