Question Classification using Self-Attention Transformer — Part 2

In this blog, we will go through one way of using Self-Attention Transformer models to classify a piece of text (a question, in our case) into two different categories, each containing some number of classes.

We will look at a Self-Attention Transformer with multiple heads for classifying classes in different categories. Without further ado, let’s code out the Encoder and Decoder modules for a Multi-Head Self-Attention Transformer that classifies multiple labels via multiple fully connected linear layers. The Encoder and Decoder parts are similar to the ones shown in . You can reuse that same code and won’t need to rewrite the Encoder and Decoder modules or their supporting modules. Still, I include those modules in code below.

The code in the later part of the blog is available


The Encoder part is much simpler than the Decoder part. The Encoder contains N EncoderLayers, and each EncoderLayer contains M Self-Attention heads.

  • Encoder Module
  • Positional Encoding
  • Encoder Layer
  • Multi-Head Self Attention and Scaled Dot Product
  • Position Wise FeedForward layer
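To make the Scaled Dot Product step in the list above concrete, here is a minimal sketch in PyTorch. It is not the original module from this post: the function names `scaled_dot_product_attention` and `split_heads`, the tensor sizes, and the mask convention (0 marks a position to ignore) are all assumptions chosen for illustration.

```python
import math
import torch

def scaled_dot_product_attention(query, key, value, mask=None):
    """query/key/value: [batch, heads, seq_len, head_dim].
    mask (optional) must be broadcastable to [batch, heads, q_len, k_len]."""
    head_dim = query.size(-1)
    # Similarity of every query position with every key position,
    # scaled by sqrt(head_dim) to keep the softmax well-behaved.
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(head_dim)
    if mask is not None:
        # Assumed convention: mask == 0 marks positions (e.g. padding) to ignore.
        scores = scores.masked_fill(mask == 0, -1e10)
    attention = torch.softmax(scores, dim=-1)
    return torch.matmul(attention, value), attention

def split_heads(x, n_heads):
    """[batch, seq_len, hid_dim] -> [batch, n_heads, seq_len, hid_dim // n_heads]."""
    batch, seq_len, hid_dim = x.shape
    return x.view(batch, seq_len, n_heads, hid_dim // n_heads).permute(0, 2, 1, 3)

torch.manual_seed(0)
x = torch.randn(2, 5, 16)           # random stand-in for embedded tokens
q = k = v = split_heads(x, n_heads=4)
context, attention = scaled_dot_product_attention(q, k, v)

# Hide the last key position, as one would for a padding token.
pad_mask = torch.ones(2, 1, 1, 5)
pad_mask[..., -1] = 0
_, masked_attention = scaled_dot_product_attention(q, k, v, pad_mask)
```

In a full Multi-Head Self-Attention module, the query, key, and value would each pass through their own linear layer before being split into heads, and the heads would be merged and projected again afterwards.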


The Decoder contains N DecoderLayers, each with 2 x M attention heads. The first M heads calculate the decoder attention, i.e. attention among the positions of the target sequence. The second set of M heads calculates attention with the encoder output, i.e. between the output of the target-sequence attention and the encoder output.

Similar to the Encoder, the Decoder also has an Embedding Layer and a Positional Encoding Layer.

Let’s code out every part of the Decoder.

The Positional Encoding, Multi-Head Attention using Scaled Dot Product, and Position Wise FeedForward Layer remain the same for the Decoder. In place of the EncoderLayer, the Decoder has a DecoderLayer, which contains the 2 x M attention heads (using Scaled Dot Product to calculate attention) and a Position Wise FeedForward Layer.

  • Decoder
  • Decoder Layer

Unlike the Encoder, the Decoder has three different outputs:

  1. The Decoder Output
  2. List of Decoder Attention, i.e. the attention calculated among the values in the Target Sequence at each DecoderLayer
  3. List of Decoder-Encoder Attention, i.e. the attention calculated at each DecoderLayer between the Encoder Output and the Query, where the Query is the output of the Target Sequence attention calculation

We will flatten the Decoder Output and pass it through a FullyConnected Layer with N output neurons, one for each class.

Classification Transformer

Our end goal is to assign two different class names to a given question. To do that, we pass the features extracted by the Encoder-Decoder layers of the Self-Attention Transformer to two fully connected linear layers: one predicts the main class and the other predicts the subclass.
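The two-head idea can be sketched as follows. This is only an illustration, not the original Classification Transformer: the class name `TwoHeadClassifier`, the hidden size of 32, and the count of 6 main classes are assumptions, and a random tensor stands in for the real Decoder Output. The sequence length of 100 and the roughly 47 subclasses come from this post.

```python
import torch
import torch.nn as nn

class TwoHeadClassifier(nn.Module):
    """Hypothetical head module: flattens decoder features and feeds them
    to two independent linear layers, one per label category."""

    def __init__(self, max_len, hid_dim, n_classes, n_subclasses):
        super().__init__()
        self.class_head = nn.Linear(max_len * hid_dim, n_classes)
        self.subclass_head = nn.Linear(max_len * hid_dim, n_subclasses)

    def forward(self, decoder_output):
        # decoder_output: [batch, max_len, hid_dim] -> [batch, max_len * hid_dim]
        features = decoder_output.flatten(start_dim=1)
        return self.class_head(features), self.subclass_head(features)

# Assumed sizes: hid_dim=32 and n_classes=6 are illustrative only.
heads = TwoHeadClassifier(max_len=100, hid_dim=32, n_classes=6, n_subclasses=47)
decoder_output = torch.randn(4, 100, 32)   # random stand-in for the Decoder Output
class_logits, subclass_logits = heads(decoder_output)
```

Both heads read the same flattened features, so the Encoder and Decoder are shared and only the final linear layers are specific to each category.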

Let’s train the model

  • Dataset class for model training

The data we pre-processed in the 1st part is a list of dictionaries. Each dictionary has the keys question-tokens, question-class, and question-subclass, which hold the tokenized question, the class of the question, and the subclass of that question. In the Dataset object, we pad question-tokens to the maximum number of tokens in a question, which in our case is 100. We return the padded question-tokens under the key source_seq, and the class and subclass of that padded sequence under the keys class and subclass.
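The padding logic described above can be sketched in plain Python. This is a simplified stand-in, not the original Dataset class: it assumes the tokens are already integer ids, that the padding id is 0, and that longer questions are truncated; the names `pad_tokens` and `QuestionDataset` are hypothetical. A training version would typically subclass `torch.utils.data.Dataset` and return tensors instead of lists.

```python
def pad_tokens(token_ids, max_len=100, pad_id=0):
    """Truncate or right-pad a list of token ids to exactly max_len entries."""
    clipped = token_ids[:max_len]
    return clipped + [pad_id] * (max_len - len(clipped))

class QuestionDataset:
    """Map-style dataset over the list of dictionaries from the pre-processing step."""

    def __init__(self, records, max_len=100, pad_id=0):
        self.records = records
        self.max_len = max_len
        self.pad_id = pad_id

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        record = self.records[idx]
        return {
            "source_seq": pad_tokens(record["question-tokens"], self.max_len, self.pad_id),
            "class": record["question-class"],
            "subclass": record["question-subclass"],
        }

# Made-up records, only to show the shape of the data.
records = [
    {"question-tokens": [5, 8, 13], "question-class": 2, "question-subclass": 17},
    {"question-tokens": list(range(1, 151)), "question-class": 0, "question-subclass": 3},
]
dataset = QuestionDataset(records)
sample = dataset[0]
long_sample = dataset[1]
```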

  • Training Steps

1. Imports, Seeding and Logging

2. Utilities

The functions under the utility section include a function to select the device on which to load the model and data, a function to count the trainable model parameters, a function to calculate the accuracy of the model on a batch, and a function to load pickle files.
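The four utilities described above could look roughly like this. The function names and signatures are assumptions for illustration and may differ from the ones in the repo.

```python
import pickle
import torch
import torch.nn as nn

def get_device():
    """Use the GPU when one is available, otherwise fall back to the CPU."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

def count_parameters(model):
    """Number of parameters that receive gradients during training."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def batch_accuracy(logits, labels):
    """Fraction of rows whose highest-scoring class matches the label."""
    predictions = logits.argmax(dim=-1)
    return (predictions == labels).float().mean().item()

def load_pickle(path):
    """Load any pickled object (e.g. a label-to-index mapping) from disk."""
    with open(path, "rb") as f:
        return pickle.load(f)

device = get_device()
model = nn.Linear(10, 3)             # 10 * 3 weights + 3 biases
n_params = count_parameters(model)

logits = torch.tensor([[2.0, 0.1, 0.3],
                       [0.2, 0.1, 3.0],
                       [0.5, 1.5, 0.1],
                       [1.0, 0.2, 0.3]])
labels = torch.tensor([0, 2, 1, 1])  # the last row is predicted as 0, so it is wrong
acc = batch_accuracy(logits, labels)
```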

3. Data and Tokenizer Loading

4. Model parameters and model initialization

We will initialize the model parameters like the size of the vocabulary, padding id, class label id ([CLS]), number of classes in every category, the maximum length of the sequence formed from the data, batch size to train with, and number of workers to use in training.

5. Dataloader and Optimizer

6. Train and Save at best accuracy

In training, we add the class-prediction loss and the subclass-prediction loss and then backpropagate the combined loss through the model layers. We save a model checkpoint at the start; in the epochs after that, we save the best model based on the maximum mean accuracy of class and subclass. The loss and accuracy of prediction for both class and subclass are logged to a file and also printed in the terminal window.
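A single training step with the combined loss might be sketched as below. The tiny model is a hypothetical stand-in for the Classification Transformer so that the example runs on its own; the cross-entropy loss, the Adam optimizer, the learning rate, the vocabulary size, and the count of 6 main classes are assumptions rather than the settings used in this post.

```python
import torch
import torch.nn as nn

class TinyTwoHeadModel(nn.Module):
    """Stand-in model: mean-pooled embeddings feeding two linear heads."""

    def __init__(self, vocab_size=50, hid_dim=16, n_classes=6, n_subclasses=47):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hid_dim, padding_idx=0)
        self.class_head = nn.Linear(hid_dim, n_classes)
        self.subclass_head = nn.Linear(hid_dim, n_subclasses)

    def forward(self, source_seq):
        features = self.embedding(source_seq).mean(dim=1)
        return self.class_head(features), self.subclass_head(features)

def train_step(model, optimizer, criterion, source_seq, class_labels, subclass_labels):
    model.train()
    optimizer.zero_grad()
    class_logits, subclass_logits = model(source_seq)
    class_loss = criterion(class_logits, class_labels)
    subclass_loss = criterion(subclass_logits, subclass_labels)
    # One backward pass on the sum sends gradients from both heads
    # into the shared layers.
    loss = class_loss + subclass_loss
    loss.backward()
    optimizer.step()
    return loss.item(), class_loss.item(), subclass_loss.item()

torch.manual_seed(0)
model = TinyTwoHeadModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Random stand-in batch: 8 questions padded to length 100.
source_seq = torch.randint(1, 50, (8, 100))
class_labels = torch.randint(0, 6, (8,))
subclass_labels = torch.randint(0, 47, (8,))

total, class_part, subclass_part = train_step(
    model, optimizer, criterion, source_seq, class_labels, subclass_labels
)
```

Tracking the best model then only needs the mean of the class and subclass accuracies per epoch, compared against the best value seen so far.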

  • Logs

Inference using the trained model

As the subclass-name-to-index list has around 47 values, I have saved it in a pickle file. Everything will be available in the repo.

  • Load Classname details
  • Load Subclass details and the index-to-class and index-to-subclass mappings
  • Prediction Function
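The last step of the prediction function, turning the two sets of logits into names, can be sketched as follows. The mappings and label names here are hypothetical placeholders, not the real class names; in practice they would be loaded from the pickle files, and the logits would come from the trained model run under `torch.no_grad()`.

```python
import torch

# Hypothetical index-to-name mappings; the real ones are loaded from pickle files.
idx2class = {0: "class_a", 1: "class_b", 2: "class_c"}
idx2subclass = {0: "subclass_a", 1: "subclass_b", 2: "subclass_c", 3: "subclass_d"}

def logits_to_names(class_logits, subclass_logits, idx2class, idx2subclass):
    """Pick the highest-scoring index from each head and look up its name.
    Expects logits for a single question, shaped [1, n_labels]."""
    class_idx = class_logits.argmax(dim=-1).item()
    subclass_idx = subclass_logits.argmax(dim=-1).item()
    return idx2class[class_idx], idx2subclass[subclass_idx]

# Made-up logits for one question, standing in for the model output.
class_logits = torch.tensor([[0.1, 2.0, -1.0]])
subclass_logits = torch.tensor([[0.3, 0.2, 1.5, 0.0]])
predicted_class, predicted_subclass = logits_to_names(
    class_logits, subclass_logits, idx2class, idx2subclass
)
```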

With more data on questions and their classes and subclasses, or with a pre-trained transformer model, one can get better performance quickly by fine-tuning for a few epochs. But for certain use cases it’s hard to find a pre-trained model, and sometimes one needs to train a transformer directly on the task with a large amount of data because there isn’t enough time to go through pre-training and then fine-tuning.

In the next and last part, we will see one more approach to classifying text into multiple categories with a different number of classes in each category. The building blocks, i.e. the Encoder and Decoder, will remain the same as here; only the Classification Transformer model will change to accommodate two decoders.

The code for all the parts is available in .

If this article helped you in any way and you liked it, please show your appreciation by sharing it with your community. If there are any mistakes, feel free to point them out in the comments below.

To know more about me please click and if you find something interesting just shoot me a mail and we could have a chat over a cup of ☕️.

Machine Learning @ Quinnox | Love to write about Deep Learning for NLP and Computer Vision, Model Deployment, and ReactJS.
