How to use a pre-trained BERT model for classification using Transfer learning

CoffeeBeans_BrewingInnovations
5 min read · Aug 17, 2022

Transfer learning rests on the idea that a model pre-trained in one setting can be reused in a different setting for a specific task. Here I will do sentiment analysis with a pre-trained BERT model, following the pre-trained model approach, and then train it further on a custom dataset.

Source: Educative.io

Key Motivation for Transfer Learning: in everyday life we don't learn every new task from scratch; we build our understanding of it on past experience. It is also difficult to gather the huge amount of data needed to train a deep learning model from scratch, so transfer learning comes in very handy when you don't have much data to train a new model.

This article focuses solely on the practical implementation of transfer learning: we take a pre-trained model and train it further according to our requirements.

Let’s walk through the script:

We will start by using pip to install sentencepiece, datasets, and transformers.
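A quick way to confirm the setup from Python is to check that each package is importable. This is a minimal sketch: the package list mirrors the step above, and the `pip` command in the comment is the usual form, not necessarily the exact one the original notebook used. The helper name `missing_packages` is hypothetical.

```python
import importlib.util
from typing import Iterable, List

# Packages from the setup step. Install them first, e.g. in a notebook cell:
#   !pip install sentencepiece datasets transformers
REQUIRED_PACKAGES = ["sentencepiece", "datasets", "transformers"]


def missing_packages(names: Iterable[str]) -> List[str]:
    """Return the top-level package names that cannot be found on the import path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages(REQUIRED_PACKAGES)
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All required packages are installed.")
```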

Next, we define a class containing the function that encodes our data, which in our case is tweets.

For more details on the different attributes of tokenizer.encode_plus, see the following link.

You can also go through this link.
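The original encoding class is not reproduced here. In practice the encoding is a call such as `tokenizer.encode_plus(tweet, add_special_tokens=True, max_length=128, padding="max_length", truncation=True, return_attention_mask=True)`, where `max_length=128` is an assumed value. The pure-Python sketch below only illustrates what truncation, padding, and the attention mask do to one tweet. It assumes the `bert-base-uncased` special-token ids ([CLS] = 101, [SEP] = 102, [PAD] = 0), and the function name `encode_like_encode_plus` is hypothetical. The real method also returns `token_type_ids`, which this sketch omits.

```python
from typing import Dict, List

# Special-token ids of the bert-base-uncased vocabulary (assumed checkpoint).
CLS_ID = 101
SEP_ID = 102
PAD_ID = 0


def encode_like_encode_plus(token_ids: List[int], max_length: int) -> Dict[str, List[int]]:
    """Mimic the output shape of encode_plus with padding="max_length" and truncation=True.

    token_ids are word-piece ids already produced by a tokenizer (hypothetical input).
    """
    if max_length < 2:
        raise ValueError("max_length must leave room for [CLS] and [SEP]")
    # Truncate so that [CLS] + tokens + [SEP] fits into max_length.
    body = token_ids[: max_length - 2]
    input_ids = [CLS_ID] + body + [SEP_ID]
    # 1 marks real tokens, 0 marks padding the model should ignore.
    attention_mask = [1] * len(input_ids)
    padding = max_length - len(input_ids)
    input_ids += [PAD_ID] * padding
    attention_mask += [0] * padding
    return {"input_ids": input_ids, "attention_mask": attention_mask}


if __name__ == "__main__":
    # Hypothetical word-piece ids for a short tweet.
    encoded = encode_like_encode_plus([2023, 2003, 2307], max_length=8)
    print(encoded["input_ids"])       # [101, 2023, 2003, 2307, 102, 0, 0, 0]
    print(encoded["attention_mask"])  # [1, 1, 1, 1, 1, 0, 0, 0]
```

Padding to a fixed `max_length` gives every tweet the same shape so tweets can be batched, and the zeros in the attention mask tell BERT to ignore the padded positions.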

The next steps involve defining the hyperparameters and the model we are going to train.

You are not restricted to this particular model; choose one that fits your problem statement. It is good practice, though, to check its baseline performance before you train it further on your custom data.
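The article's hyperparameters are not listed in full, so the sketch below gathers them in one place under stated assumptions: only the batch size (8) and the epoch count (10) come from this article; the checkpoint name, `max_length`, and learning rate are assumed. With transformers installed, the model itself would typically be loaded with `AutoModelForSequenceClassification.from_pretrained(hp.model_name, num_labels=hp.num_labels)`.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class Hyperparameters:
    # Checkpoint name is an assumption; the article does not name its BERT variant.
    model_name: str = "bert-base-uncased"
    num_labels: int = 3          # classes 0, 1 and 2
    max_length: int = 128        # assumed maximum tokens per tweet
    batch_size: int = 8          # value used in the article
    epochs: int = 10             # value used in the article
    learning_rate: float = 2e-5  # assumed; a common fine-tuning value for BERT

    def steps_per_epoch(self, num_examples: int) -> int:
        # The last batch may be smaller than batch_size, hence the ceiling.
        return math.ceil(num_examples / self.batch_size)

    def total_steps(self, num_examples: int) -> int:
        return self.steps_per_epoch(num_examples) * self.epochs


if __name__ == "__main__":
    hp = Hyperparameters()
    # 10,000 is a hypothetical dataset size for illustration.
    print(hp.steps_per_epoch(10_000))  # 1250
    print(hp.total_steps(10_000))      # 12500
```

Keeping the values in one frozen object makes it harder to train with one setting and evaluate or export with another.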

Next, we define a relabeling function that sets the labels for the classification problem. Here I am using three categories: 0, 1, and 2.
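A minimal relabeling sketch, assuming the raw dataset stores sentiment as text. The names `negative`, `neutral`, and `positive` and their order are hypothetical; the article only states that there are three classes, 0, 1, and 2.

```python
from typing import Dict, List

# Hypothetical mapping: the article only says the classes are 0, 1 and 2.
LABEL2ID: Dict[str, int] = {"negative": 0, "neutral": 1, "positive": 2}
ID2LABEL: Dict[int, str] = {i: name for name, i in LABEL2ID.items()}


def relabel(raw_label: str) -> int:
    """Map a raw sentiment string (case and surrounding spaces ignored) to its class id."""
    key = raw_label.strip().lower()
    if key not in LABEL2ID:
        raise ValueError(f"unknown label: {raw_label!r}")
    return LABEL2ID[key]


def relabel_all(raw_labels: List[str]) -> List[int]:
    return [relabel(label) for label in raw_labels]
```

The two dictionaries can also be passed as `id2label` and `label2id` to `from_pretrained`, so the saved model config carries readable class names.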

Next, we convert the data into the format the trainer expects.
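One common shape for this step is a map-style dataset that returns one dictionary per example. This is a sketch under assumptions: the class name `TweetDataset` is hypothetical, and it returns plain lists so that it runs without PyTorch; a production version would usually subclass `torch.utils.data.Dataset`. The key detail is that the label goes under `labels`, the argument name Hugging Face classification models use to compute their loss.

```python
from typing import Any, Dict, List


class TweetDataset:
    """Minimal map-style dataset: one dict per tweet, ready for batching by a collator."""

    def __init__(self, encodings: Dict[str, List[List[int]]], labels: List[int]):
        # Every encoding column (input_ids, attention_mask, ...) needs one row per label.
        if any(len(column) != len(labels) for column in encodings.values()):
            raise ValueError("every encoding column must have one row per label")
        self.encodings = encodings
        self.labels = labels

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        item: Dict[str, Any] = {key: values[idx] for key, values in self.encodings.items()}
        # Classification models compute their loss from the "labels" argument.
        item["labels"] = self.labels[idx]
        return item
```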

The next step is model training. This is the most crucial step of the whole process: we pass the data to the trainer along with the pre-trained model.

Note that I set the number of epochs to 10 and the batch size to 8. Because my dataset was large, one epoch took approximately 3 hours to train. You can experiment with fewer epochs or a smaller dataset, but if you want to train a good model, I suggest using at least 8 to 10 epochs for satisfactory results.
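To make the epoch trade-off concrete, here is a small sketch that expresses the article's settings as keyword arguments for `transformers.TrainingArguments` and estimates wall-clock time. The output directory and learning rate are assumptions, the helper names are hypothetical, and the estimate simply assumes every epoch takes about as long as the 3 hours reported above.

```python
from typing import Dict


def training_kwargs(output_dir: str, epochs: int = 10, batch_size: int = 8) -> Dict[str, object]:
    """Keyword arguments accepted by transformers.TrainingArguments.

    Usage (requires transformers): TrainingArguments(**training_kwargs("out"))
    """
    return {
        "output_dir": output_dir,
        "num_train_epochs": epochs,
        "per_device_train_batch_size": batch_size,
        "per_device_eval_batch_size": batch_size,
        "learning_rate": 2e-5,  # assumed value, not from the article
    }


def estimated_training_hours(epochs: int, hours_per_epoch: float) -> float:
    """Wall-clock budget, assuming every epoch takes roughly the same time."""
    return epochs * hours_per_epoch


if __name__ == "__main__":
    # The article reports roughly 3 hours per epoch on its (large) dataset.
    for epochs in (3, 8, 10):
        print(epochs, "epochs ->", estimated_training_hours(epochs, 3.0), "hours")
```

At roughly 3 hours per epoch, 10 epochs is about 30 hours and 8 epochs about 24 hours, which is why trimming epochs or data is the quickest way to iterate. The arguments would then go to `Trainer(model=model, args=TrainingArguments(**kwargs), train_dataset=..., eval_dataset=...)`.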

Also, make sure to evaluate your results.
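One way to evaluate is to pass a `compute_metrics` function to the `Trainer`, which calls it with the model's logits and the true labels. The sketch below computes accuracy and macro-averaged F1 in plain Python. The choice of metrics is an assumption, since the article does not say which ones it used, and in practice a library such as scikit-learn would usually compute them.

```python
from typing import Any, Dict, Sequence


def argmax(row: Sequence[float]) -> int:
    return max(range(len(row)), key=lambda i: row[i])


def compute_metrics(eval_pred: Any, num_labels: int = 3) -> Dict[str, float]:
    """Accuracy and macro-F1 from a (logits, labels) pair."""
    logits, labels = eval_pred
    preds = [argmax(row) for row in logits]
    labels = [int(y) for y in labels]
    accuracy = sum(p == y for p, y in zip(preds, labels)) / len(labels)

    f1_scores = []
    for c in range(num_labels):
        tp = sum(p == c and y == c for p, y in zip(preds, labels))
        fp = sum(p == c and y != c for p, y in zip(preds, labels))
        fn = sum(p != c and y == c for p, y in zip(preds, labels))
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        f1_scores.append(f1)
    return {"accuracy": accuracy, "macro_f1": sum(f1_scores) / num_labels}
```

Macro F1 averages the per-class scores, so a rare sentiment class counts as much as a common one, which plain accuracy can hide.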

The next concept I want to share is the ONNX model format.

ONNX (Open Neural Network Exchange) is an open format for representing machine learning models, which makes it possible to move models between frameworks. In short, you can train your custom model in whichever framework you are comfortable with, such as TensorFlow or PyTorch, and then convert it to the common ONNX format. This gives developers a lot of flexibility. You can also run ONNX models with optimized inference engines (ONNX Runtime, ncnn, TVM, etc.) to maximize performance.

(I will write a separate blog later for ONNX in detail)

According to the official documentation:

Source: https://github.com/onnx/onnx

Every developer working with TensorFlow, PyTorch, or any other deep learning framework should know ONNX. Interoperability is important for keeping production pipelines intact: if we need to change a model, or retrain it in a different framework, and then use it in production, ONNX comes in very handy because production only has to support a single format instead of several frameworks.

Below is the code I used to convert the retrained model to ONNX format.
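As an assumption-laden sketch of one conversion route, PyTorch's `torch.onnx.export` can export a BERT classifier directly; the helper below only builds its keyword arguments. The input/output names, the opset version (14), and the helper name `onnx_export_kwargs` are assumptions, not taken from the article. Marking the batch and sequence axes as dynamic lets the exported graph accept inputs of any size rather than only the shape of the dummy example used during export.

```python
from typing import Dict, List


def onnx_export_kwargs(opset_version: int = 14) -> Dict[str, object]:
    """Keyword arguments for torch.onnx.export of a BERT sequence classifier (assumed names)."""
    input_names: List[str] = ["input_ids", "attention_mask"]
    output_names: List[str] = ["logits"]
    # Batch size and sequence length vary at inference time, so mark them dynamic.
    dynamic_axes: Dict[str, Dict[int, str]] = {
        name: {0: "batch", 1: "sequence"} for name in input_names
    }
    dynamic_axes["logits"] = {0: "batch"}  # logits shape: (batch, num_labels)
    return {
        "input_names": input_names,
        "output_names": output_names,
        "dynamic_axes": dynamic_axes,
        "opset_version": opset_version,
    }


# With PyTorch installed, the export would look roughly like this, where input_ids and
# attention_mask are a dummy batch produced by the tokenizer:
#   torch.onnx.export(model, (input_ids, attention_mask), "model.onnx", **onnx_export_kwargs())
```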

Note that defining the feature is really important, because it specifies the task (here, classification) for which the model is exported to ONNX. The last steps are converting the model to ONNX format, checking the export (or loading the model), and then using it in a production or research environment.
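The mention of a feature suggests the `transformers.onnx` exporter of that period, whose `--feature` flag (for example `sequence-classification`) selects the task head; that is an inference, since the original code is not shown. Whichever exporter is used, a classifier's raw logits still need post-processing once the model loads. The sketch below assumes the exported graph has a single `logits` output and reuses hypothetical `negative`/`neutral`/`positive` names for classes 0, 1, and 2; the `onnxruntime` calls appear only in comments so the block runs without it.

```python
import math
from typing import Dict, List, Sequence

ID2LABEL = {0: "negative", 1: "neutral", 2: "positive"}  # hypothetical label names


def softmax(logits: Sequence[float]) -> List[float]:
    # Subtract the max before exponentiating for numerical stability.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def postprocess(batch_logits: Sequence[Sequence[float]]) -> List[Dict[str, object]]:
    """Turn a batch of logits into the most likely label and its probability."""
    results = []
    for row in batch_logits:
        probs = softmax(row)
        best = max(range(len(probs)), key=lambda i: probs[i])
        results.append({"label": ID2LABEL[best], "score": probs[best]})
    return results


# With onnxruntime installed, the logits would come from something like
# (ids and mask as int64 numpy arrays from the tokenizer):
#   session = onnxruntime.InferenceSession("model.onnx")
#   (logits,) = session.run(["logits"], {"input_ids": ids, "attention_mask": mask})
#   print(postprocess(logits.tolist()))
```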

If you spot any errors, please suggest changes and I will correct them promptly.


CoffeeBeans empowers organizations to transform their business through the use of advanced technologies. Building data-driven solutions that drive innovation.