There has been a lot of hype recently around the GPT-3 language model developed by OpenAI. I will dig deeper into the architecture of this language model using PyTorch.
OpenAI released a public API that allows people to access their GPT language model. Paul Graham claims it might be the next Altair, which justifies me spending the afternoon looking into it.
Maybe. It may be the Altair.
— Paul Graham (@paulg) July 19, 2020
I will be following the transformer tutorial from PyTorch (since I use it at work), which implements the famous paper Attention Is All You Need. AFAIK this 2017 paper was the seminal work that started the new trend in NN architectures (BERT, ELM) that eventually led to the GPT models from OpenAI. To go over the paper there is a nice 1h Paper Reading Group by Rachael Tatman and also a famous blog post, The Illustrated Transformer by Jay Alammar. Another interesting blog post is The Annotated Transformer from Harvard NLP, which is the text of the paper plus additional PyTorch code written from the ground up.
The tutorial is about training an nn.TransformerEncoder to assign a probability to a given word (or a sequence of words) following a sequence of words. Given the Encoder and Decoder, building the model is pretty simple, as explained in the define-the-model section.
We take our text input, and we add two parts to it.
mask = self._generate_square_subsequent_mask(len(src))
The most straightforward is the mask, which ensures that the predictions for position i depend only on the known outputs at positions less than i, by masking the later ones.
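To make the mask concrete, here is a minimal standalone sketch of what such a function might produce. The function name and the use of torch.triu are my own choices for illustration, not necessarily the tutorial's exact code. The mask is additive: 0.0 leaves an attention score untouched, while -inf removes it once the softmax is applied.

```python
import torch


def generate_square_subsequent_mask(size: int) -> torch.Tensor:
    """Additive attention mask: 0.0 where attention is allowed, -inf where it is blocked."""
    # Fill a square matrix with -inf and keep only the strictly upper triangle,
    # i.e. the entries (i, j) with j > i, which are the future positions.
    # Everything on or below the diagonal becomes 0.0.
    return torch.triu(torch.full((size, size), float("-inf")), diagonal=1)


mask = generate_square_subsequent_mask(4)
print(mask)
```

Row i of the printed matrix has -inf in every column j > i, so position i cannot look at anything that comes after it.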
src = self.pos_encoder(src)
The second is the positional encoding: we can't just pass plain text to the NN, we need to encode it with relative and absolute positional data that might look like "this is the third word in the sentence, the one before cat". The actual encoding is more complex, and is done using sine and cosine functions of different frequencies.
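The sine and cosine scheme can be sketched in a few lines. This follows the formula from the paper (even dimensions use sine, odd dimensions use cosine, with frequencies that shrink geometrically). The function name and the toy sizes are assumptions for illustration, and the sketch assumes an even d_model.

```python
import math

import torch


def sinusoidal_positional_encoding(max_len: int, d_model: int) -> torch.Tensor:
    """Return a (max_len, d_model) table of encodings; d_model is assumed to be even."""
    position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
    # One frequency per pair of dimensions: 1 / 10000^(2i / d_model)
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)
    )
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
    pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
    return pe


pe = sinusoidal_positional_encoding(max_len=50, d_model=16)
embeddings = torch.zeros(10, 16)  # stand-in for 10 token embeddings
encoded = embeddings + pe[:10]    # the encoding is simply added to the embeddings
```

Because the encoding is added rather than concatenated, the embedding size stays the same and every position gets its own distinct pattern.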
So once the two parts above are in place, we just run our data through the Encoder and Decoder and we have our Transformer! Nice :) But of course I want to understand what is inside this nn.TransformerEncoder, so before running the training I will have a look inside it.
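As a rough sketch of that pipeline, the snippet below wires an embedding, an nn.TransformerEncoder and a linear decoder together and runs one toy batch through them. All sizes are made-up toy values, the positional encoding is left out for brevity, and this is not the tutorial's actual model class.

```python
import torch
from torch import nn

# Toy sizes, chosen only to keep the example small (assumptions, not the tutorial's values)
vocab_size, d_model, nhead, num_layers, seq_len = 100, 16, 2, 2, 10

embedding = nn.Embedding(vocab_size, d_model)
encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=32)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
decoder = nn.Linear(d_model, vocab_size)  # maps each position back to vocabulary scores

tokens = torch.randint(0, vocab_size, (seq_len, 1))  # (seq_len, batch) of token ids
mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)

hidden = encoder(embedding(tokens), mask)  # (seq_len, batch, d_model)
logits = decoder(hidden)                   # (seq_len, batch, vocab_size)
probs = torch.softmax(logits, dim=-1)      # a distribution over the vocabulary per position
```

With untrained weights the probabilities are meaningless, but the shapes show the idea: for every position the model outputs a distribution over the next word.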
To follow what happens inside the encoder we can look at the nn.TransformerEncoder source code, which has pretty decent docs. At the same time it is pretty helpful to keep an eye both on Self-Attention at a High Level from jalammar, which has a nice plain-English explanation of what's going on, and on the Attention code from the original paper. We are also at part two of the Kaggle Reading Group.
The gist of the transformer is in the TransformerEncoderLayer.forward method. There we take the input sequence, pass it through the self-attention, and then pass the result through the feed-forward network. The self-attention layer is implemented in torch.nn.MultiheadAttention, where we pass query, key and value in the forward pass, and we get out the embeddings.
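A stripped-down version of such a layer might look like the sketch below. The class name TinyEncoderLayer and the sizes are hypothetical, dropout is omitted, and the residual connections plus layer normalisation follow the paper's description rather than being copied from the PyTorch source.

```python
import torch
from torch import nn


class TinyEncoderLayer(nn.Module):
    """Simplified encoder layer: self-attention followed by a feed-forward network."""

    def __init__(self, d_model: int = 16, nhead: int = 2, dim_feedforward: int = 32):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, src, src_mask=None):
        # Self-attention: the same tensor is used as query, key and value
        attn_out, _ = self.self_attn(src, src, src, attn_mask=src_mask)
        src = self.norm1(src + attn_out)  # residual connection + normalisation
        # Position-wise feed-forward network
        ff_out = self.linear2(torch.relu(self.linear1(src)))
        return self.norm2(src + ff_out)   # residual connection + normalisation


layer = TinyEncoderLayer()
src = torch.randn(10, 1, 16)  # (seq_len, batch, d_model), the default layout
out = layer(src)
```

The output has the same shape as the input, which is what lets nn.TransformerEncoder stack several of these layers on top of each other.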
self.self_attn(src, src, src)
Reading the code, the most surprising line was self.self_attn(src, src, src), where the three arguments are respectively query, key and value. The same tensor is passed three times because the query component in the self-attention equation is the matrix multiplication of the input sentence (src) by the query weights, and the same goes for key and value with their own weights. The rest of the module just uses the self-attention equation to output the embeddings.
The more I dive into the code (i.e. the implementation details from the paper), the more everything looks like alchemy: for some not-well-understood reasons it is better not to have only one attention layer but 8 parallel heads that are identical in architecture but obtain different weights through training. The method of reconciling these 8 heads is also peculiar: we concatenate the 8 head outputs and learn an additional weight matrix that reduces the result to a standard-size embedding that can be sent to the feed-forward network.
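To see why passing src three times makes sense, here is a single-head sketch of the self-attention equation, softmax(Q K^T / sqrt(d_k)) V, written by hand. The weights are random stand-ins for learned parameters, biases are ignored, and the variable names are mine, so this only illustrates the idea and is not what nn.MultiheadAttention does line by line.

```python
import math

import torch

torch.manual_seed(0)
seq_len, d_model = 5, 8
src = torch.randn(seq_len, d_model)  # one sequence, no batch dimension

# Three separate projections (random here; learned in a real model)
w_q = torch.randn(d_model, d_model)
w_k = torch.randn(d_model, d_model)
w_v = torch.randn(d_model, d_model)

# Query, key and value are all computed from the same input
query, key, value = src @ w_q, src @ w_k, src @ w_v

# Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
scores = query @ key.T / math.sqrt(d_model)  # (seq_len, seq_len)
weights = torch.softmax(scores, dim=-1)      # each row sums to 1
output = weights @ value                     # (seq_len, d_model)
```

Each row of weights says how much one position draws from every other position when building its new embedding.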
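The concatenate-then-project step can be sketched by hand as follows. The sizes are toy values and the weights are random stand-ins for learned parameters. A real implementation computes all heads in one batched operation rather than a Python loop, so treat this as an illustration of the shapes only.

```python
import math

import torch

torch.manual_seed(0)
seq_len, d_model, num_heads = 5, 64, 8
head_dim = d_model // num_heads  # 8 dimensions per head

src = torch.randn(seq_len, d_model)

head_outputs = []
for _ in range(num_heads):
    # Each head has its own projection weights (random here, learned in practice)
    w_q = torch.randn(d_model, head_dim)
    w_k = torch.randn(d_model, head_dim)
    w_v = torch.randn(d_model, head_dim)
    q, k, v = src @ w_q, src @ w_k, src @ w_v
    weights = torch.softmax(q @ k.T / math.sqrt(head_dim), dim=-1)
    head_outputs.append(weights @ v)  # (seq_len, head_dim)

# Concatenate the heads side by side, then mix them with one more weight matrix
concatenated = torch.cat(head_outputs, dim=-1)    # (seq_len, num_heads * head_dim)
w_o = torch.randn(num_heads * head_dim, d_model)  # output projection
combined = concatenated @ w_o                     # (seq_len, d_model)
```

Since each head works in a smaller subspace (head_dim instead of d_model), the 8 heads together cost roughly as much as one full-size head.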
After some theory I want to get an intuitive understanding of what happens inside a transformer. A good approach is to examine sentence embeddings using an encoder like BERT (the name stands for Bidirectional Encoder Representations from Transformers). A good repo is bert-as-service, which offers some helpful code snippets. The one I will try to use is this inner layers dimensionality reduction script: it takes one of the internal transformer layers of BERT of size [N*H] and uses PCA to project it onto a 2D plane to visualise how different sentences have different embeddings.
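The projection step itself is small. Below is a sketch that uses random vectors as a stand-in for real BERT outputs (N sentences, H dimensions each) and computes PCA through an SVD in PyTorch. This is my own substitution for illustration and not the script from the repo.

```python
import torch

torch.manual_seed(0)
num_sentences, hidden_size = 20, 768  # N sentences, H-dimensional vectors
embeddings = torch.randn(num_sentences, hidden_size)  # random stand-in for BERT outputs

# PCA via SVD: centre the data, then project onto the top two principal directions
centered = embeddings - embeddings.mean(dim=0, keepdim=True)
_, _, vh = torch.linalg.svd(centered, full_matrices=False)
projected = centered @ vh[:2].T  # (num_sentences, 2), ready for a scatter plot
```

Each row of projected is one sentence as a point on the 2D plane. With real embeddings, sentences about similar topics would be expected to land close to each other.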
The results look good. I embedded sentences from the music and computer pages of Wikipedia and I plotted their BERT representation on a 2D plane. The topic distribution looks coherent. You can find my implementation notebook here.