Implementing the Self-Attention Mechanism from Scratch in PyTorch!
Attention Mechanism - PyTorch - Transformers
You may have seen different explanations of how the self-attention mechanism works, but it is quite likely that it is still a little abstract for you. Personally, everything becomes much clearer when I see things being coded. So, I am going to show you how to implement the self-attention layer from scratch in PyTorch.
I made a previous video to explain the logic behind it, so check it out if you haven't already. Here, I am just going to focus on the coding aspect! Let's get started!
So, let’s create an Attention class with two arguments: d_in, the size of the input vectors, and d_out, the size of the output vectors:
import torch
import torch.nn as nn
import torch.nn.functional as F
class Attention(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
We then add, as attributes of the class, three linear layers that will be used to project the input tensor to the Keys, Queries, and Values:
class Attention(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        # will be used to project from the input
        # tensor to the Keys, Queries, and Values
        self.Q = nn.Linear(d_in, d_out)
        self.K = nn.Linear(d_in, d_out)
        self.V = nn.Linear(d_in, d_out)
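If you want to see what these projections do to the shapes, here is a quick sanity check, reusing the imports from above; the batch size, sequence length, and dimensions are arbitrary values picked for illustration:

# Illustrative shapes only: batch of 2, 5 tokens, d_in=16, d_out=8
proj = nn.Linear(16, 8)
x = torch.randn(2, 5, 16)   # (batch, seq_len, d_in)
print(proj(x).shape)        # torch.Size([2, 5, 8]) -> (batch, seq_len, d_out)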
Now, we can focus on the forward function. First, as promised, we project the input tensor into Keys, Queries, and Values:
class Attention(nn.Module):
    ...
    def forward(self, x):
        # we project the input tensor into Keys, Queries, and Values
        queries = self.Q(x)
        keys = self.K(x)
        values = self.V(x)
        ...
Then, we create an interaction matrix between the Keys and the Queries by performing a batched matrix multiplication, which captures how the different tokens interact with each other. We also scale the scores by the square root of the key/query dimension (d_out here) so that their magnitude stays roughly independent of the size of the queries and keys:
class Attention(nn.Module):
    ...
    def forward(self, x):
        # we project the input tensor into Keys, Queries, and Values
        queries = self.Q(x)
        keys = self.K(x)
        values = self.V(x)
        # interaction matrix between Keys and Queries
        scores = torch.bmm(queries, keys.transpose(1, 2))
        scores = scores / (self.d_out ** 0.5)
        ...
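If the shapes are confusing, here is a small stand-alone example of what torch.bmm produces here; the tensors are random stand-ins with arbitrary dimensions, not actual model outputs:

# Illustrative shapes only: batch of 2, 5 tokens, d_out=8
queries = torch.randn(2, 5, 8)
keys = torch.randn(2, 5, 8)
scores = torch.bmm(queries, keys.transpose(1, 2))
print(scores.shape)  # torch.Size([2, 5, 5]) -> one score per (query token, key token) pair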
We can now turn the scores into the self-attention weights by mapping them to probability-like values with the softmax transformation:
class Attention(nn.Module):
    ...
    def forward(self, x):
        # we project the input tensor into Keys, Queries, and Values
        queries = self.Q(x)
        keys = self.K(x)
        values = self.V(x)
        # interaction matrix between Keys and Queries
        scores = torch.bmm(queries, keys.transpose(1, 2))
        scores = scores / (self.d_out ** 0.5)
        # we create the attention weights with a softmax
        attention = F.softmax(scores, dim=2)
        ...
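A quick way to convince yourself that these are probability-like weights: each row of the attention matrix sums to 1. The scores below are random stand-ins used purely for illustration:

# Each row of the attention matrix sums to 1 (random scores for illustration)
scores = torch.randn(2, 5, 5)
attention = F.softmax(scores, dim=2)
print(attention.sum(dim=2))  # all ones, up to floating-point precision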
The attention weights are then used to compute a weighted average of the Values. The resulting tensor represents the hidden states that flow through the rest of the Transformer model:
class Attention(nn.Module):
    ...
    def forward(self, x):
        # we project the input tensor into Keys, Queries, and Values
        queries = self.Q(x)
        keys = self.K(x)
        values = self.V(x)
        # interaction matrix between Keys and Queries
        scores = torch.bmm(queries, keys.transpose(1, 2))
        scores = scores / (self.d_out ** 0.5)
        # we create the attention weights with a softmax
        attention = F.softmax(scores, dim=2)
        # the attention weights are used to compute
        # a weighted average of the Values
        hidden_states = torch.bmm(attention, values)
        return hidden_states
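To check that everything fits together, here is a minimal usage sketch; the batch size, sequence length, and dimensions are arbitrary values chosen for the example:

# Minimal sanity check with arbitrary dimensions
attn = Attention(d_in=16, d_out=8)
x = torch.randn(2, 5, 16)   # (batch, seq_len, d_in)
out = attn(x)
print(out.shape)            # torch.Size([2, 5, 8]) -> (batch, seq_len, d_out)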
Easy peasy!