Predictive Hacks

How to Fine-Tune an NLP Classification Model with Transformers and HuggingFace


This tutorial walks through training a custom NLP classification model with transformers: we start from a pre-trained model and then fine-tune it using transfer learning. We will work with the HuggingFace library called “transformers”.

Classification Model

For demonstration purposes, we will build a classification model that predicts whether an email is “ham” or “spam”. In another tutorial, we built an Email Spam Detector using Scikit-Learn and TF-IDF. Feel free to have a look at that tutorial to get the data and compare the two approaches.

For this tutorial, we will work with Amazon SageMaker Studio Lab, which is free. Alternatively, you can work with Colab or locally. The tutorial is reproducible so that you can code along.

Install the Required Libraries

For this tutorial, you need to install the following libraries:

!pip install transformers
!pip install datasets
!pip install numpy
!pip install pandas

Load the Data

Assume that you have the train and test datasets stored as CSV files. Let’s see how we can load them as datasets. Notice that HuggingFace expects the data as a DatasetDict (dataset dictionary).

import datasets
from datasets import load_dataset, load_from_disk

dataset = load_dataset('csv', data_files={'train': 'train_spam.csv', 'test': 'test_spam.csv'})



DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 3900
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 1672
    })
})
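
The training code expects an integer column named “label”. If your raw file stores the class as a string instead, a small conversion step before loading can produce it. The sketch below is a minimal example under stated assumptions: the raw column name `category`, the helper `encode_labels` and the sample rows are all hypothetical, and the mapping follows this tutorial's convention of 0 for ham and 1 for spam.

```python
import csv
import io

# Assumed mapping, following the tutorial's convention: 0 = ham, 1 = spam.
LABEL_TO_ID = {"ham": 0, "spam": 1}

def encode_labels(raw_csv_text):
    """Turn rows of (text, category) into rows of (text, label) with integer labels."""
    reader = csv.DictReader(io.StringIO(raw_csv_text))
    rows = []
    for row in reader:
        category = row["category"].strip().lower()  # tolerate "Ham", " SPAM", etc.
        rows.append({"text": row["text"], "label": LABEL_TO_ID[category]})
    return rows

# Hypothetical raw data; the column name "category" is an assumption.
raw = "text,category\nSee you at noon,ham\nWIN a free prize now,spam\n"
encoded = encode_labels(raw)
print(encoded)
```

The encoded rows can then be written back to a CSV with the columns `text` and `label`, which is the layout shown in the dataset summary above.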

Fine-Tune the Model

Keep in mind that the “target” variable should be called “label” and should be numeric. In this dataset, we are dealing with a binary problem: 0 (Ham) or 1 (Spam). We will start with the “distilbert-base-cased” checkpoint and then fine-tune it. First, we will load the tokenizer and tokenize both splits.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased")

def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)

# apply the tokenizer to every split of the dataset, in batches
tokenized_datasets = dataset.map(tokenize_function, batched=True)


Loading cached processed dataset at ham_spam_dataset/train/cache-3a436b86c79a53fe.arrow
Loading cached processed dataset at ham_spam_dataset/test/cache-9524e6b19881902e.arrow
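
To make the two tokenizer options concrete, here is a toy sketch of what `padding="max_length"` and `truncation=True` do to a sequence of token IDs. It is not the real tokenizer: the IDs, the pad ID of 0, the tiny `max_length` of 6 and the helper `pad_or_truncate` are all assumptions for illustration, and the real tokenizer also keeps special tokens (such as the closing separator) when it truncates, which this toy ignores.

```python
def pad_or_truncate(token_ids, max_length, pad_id=0):
    """Return (input_ids, attention_mask), each exactly max_length items long."""
    ids = token_ids[:max_length]          # truncation: drop anything past max_length
    mask = [1] * len(ids)                 # 1 marks a real token
    padding = max_length - len(ids)       # how many pad tokens are needed
    return ids + [pad_id] * padding, mask + [0] * padding  # 0 marks padding

# Hypothetical token IDs, one short and one long sequence.
short_ids, short_mask = pad_or_truncate([101, 7, 8, 102], max_length=6)
long_ids, long_mask = pad_or_truncate([101, 1, 2, 3, 4, 5, 6, 102], max_length=6)

print(short_ids, short_mask)
print(long_ids, long_mask)
```

Every example ends up with the same length, which is what lets the examples be stacked into fixed-size batches. The attention mask tells the model which positions are padding.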

Then we will load the model for sequence classification.

from transformers import AutoModelForSequenceClassification
checkpoint = "distilbert-base-cased"
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)

Note that we set “num_labels=2”. If you are dealing with more classes, you have to adjust the number accordingly.

Since we want to report the accuracy of the model, we can add the following function.

import numpy as np
from datasets import load_metric

metric = load_metric("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)
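
To see what this function computes, the same logic can be written in plain Python: take the index of the largest logit in each row as the predicted class, then count how often it matches the label. The logits and labels below are made-up values for illustration, and the helper names are hypothetical.

```python
def argmax(row):
    """Index of the largest value in a list."""
    return max(range(len(row)), key=row.__getitem__)

def accuracy(logits, labels):
    """Share of rows whose highest-scoring class equals the true label."""
    predictions = [argmax(row) for row in logits]
    correct = sum(p == y for p, y in zip(predictions, labels))
    return correct / len(labels)

# Hypothetical logits for four emails: column 0 = ham, column 1 = spam.
toy_logits = [[2.1, -1.0], [-0.3, 1.7], [0.4, 0.2], [-1.2, 0.9]]
toy_labels = [0, 1, 1, 1]

toy_accuracy = accuracy(toy_logits, toy_labels)
print(toy_accuracy)  # 0.75: the third email is predicted as ham but labeled spam
```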

Train the Model

Now we are ready to train the model. We will train for only one epoch, but feel free to add more; I would suggest 3 to 5. The Trainer accepts a wide range of arguments. We keep the default values, but I encourage you to have a look at the documentation, since it is often important to experiment with arguments like batch size, learning rate and so on.

from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(output_dir="test_trainer", evaluation_strategy="epoch", num_train_epochs=1)

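trainer = Trainer(
    model=model, args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    compute_metrics=compute_metrics,
)
trainer.train()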

The model ran for one epoch and reached an accuracy of 98.3%.
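
When experimenting with batch size and the number of epochs, it helps to know how many optimizer steps a run will take. The sketch below is a back-of-the-envelope calculation, assuming a single device, no gradient accumulation, and a default batch size of 8 examples (the library default at the time of writing, which is worth checking against your installed version). The helper name `optimizer_steps` is hypothetical.

```python
import math

def optimizer_steps(num_rows, batch_size=8, epochs=1):
    """Number of parameter updates: one per batch, repeated for each epoch."""
    batches_per_epoch = math.ceil(num_rows / batch_size)  # the last batch may be smaller
    return batches_per_epoch * epochs

one_epoch = optimizer_steps(3900)                      # 3900 training rows, as in our dataset
five_epochs = optimizer_steps(3900, epochs=5)
bigger_batches = optimizer_steps(3900, batch_size=16)

print(one_epoch, five_epochs, bigger_batches)  # 488 2440 244
```

Doubling the batch size halves the number of updates per epoch, which is one reason the learning rate and batch size are usually tuned together.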

Save the Model

I suggest saving the model and the tokenizer under the same path so that both can be loaded from one place. Bear in mind that, in our case, we have not fine-tuned the tokenizer.


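model.save_pretrained("CustomModels/CustomHamSpam")
tokenizer.save_pretrained("CustomModels/CustomHamSpam")
# alternatively save the trainer
# trainer.save_model("CustomModels/CustomHamSpam")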



Load the Model

We can load the model and the tokenizer as follows.

# load the model and the tokenizer from the same path
from transformers import AutoModelForSequenceClassification, AutoTokenizer

load_model = AutoModelForSequenceClassification.from_pretrained("CustomModels/CustomHamSpam")

load_tokenizer = AutoTokenizer.from_pretrained("CustomModels/CustomHamSpam")

Make Predictions

We can make predictions using the TextClassificationPipeline. Let’s see if this email is HAM or SPAM:

XXXMobileMovieClub: To use your credit, click the WAP link in the next txt message or click here>> http://wap.

from transformers import TextClassificationPipeline

model = load_model
tokenizer = load_tokenizer
pipe = TextClassificationPipeline(model=model, tokenizer=tokenizer, return_all_scores=True)
# outputs a list with one inner list of dicts (one dict per label) for each input text
pipe("XXXMobileMovieClub: To use your credit, click the WAP link in the next txt message or click here>> http://wap.")


[[{'label': 'LABEL_0', 'score': 0.009705818258225918},
  {'label': 'LABEL_1', 'score': 0.9902942180633545}]]
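
The generic names LABEL_0 and LABEL_1 correspond to the numeric labels used in training, so 0 is ham and 1 is spam. A small post-processing step can turn the pipeline output into a readable answer. This is a minimal sketch: the helper `top_prediction` and the mapping dictionary are hypothetical, and the input is the output shown above, copied in by hand.

```python
# Assumed mapping, based on the training labels: 0 = ham, 1 = spam.
ID_TO_NAME = {"LABEL_0": "ham", "LABEL_1": "spam"}

def top_prediction(scores):
    """Pick the highest-scoring entry and translate its label to a readable name."""
    best = max(scores, key=lambda item: item["score"])
    return ID_TO_NAME[best["label"]], best["score"]

# Output of the pipeline call above, for one input text.
pipeline_output = [[{'label': 'LABEL_0', 'score': 0.009705818258225918},
                    {'label': 'LABEL_1', 'score': 0.9902942180633545}]]

name, score = top_prediction(pipeline_output[0])
print(name, round(score, 4))  # spam 0.9903
```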

Alternatively, we can use the pipeline function as follows:

from transformers import pipeline
my_pipeline = pipeline("text-classification", model=load_model, tokenizer=load_tokenizer)
data = ["I love you", "XXXMobileMovieClub: To use your credit, click the WAP link in the next txt message or click here>> http://wap."]
# returns the top label and its score for each input text
my_pipeline(data)



[{'label': 'LABEL_0', 'score': 0.9980890154838562},
 {'label': 'LABEL_1', 'score': 0.9902942180633545}]

As we can see, the email “I love you” is labeled as 0 (i.e. HAM), and the second one, which we saw earlier, is labeled as 1 (i.e. SPAM).
