Self-Attention from Scratch in R
EDIT 2022-06-24: this code is now available (with helper functions) in the R
package attention, which is on CRAN. You can install it simply using:
install.packages('attention')
See also my blog post attention on CRAN. The development takes place on GitHub.
This post describes how to implement the attention mechanism - which forms the basis of transformers - in the R language.
The code is translated from the Python original by Stefania Cristina (University of Malta) in her post The Attention Mechanism from Scratch.
We begin by generating encoder representations of four different words.
# encoder representations of four different words
word_1 = matrix(c(1,0,0), nrow=1)
word_2 = matrix(c(0,1,0), nrow=1)
word_3 = matrix(c(1,1,0), nrow=1)
word_4 = matrix(c(0,0,1), nrow=1)
Next, we stack the word embeddings into a single array (in this case a matrix).
# stacking the word embeddings into a single array
words = rbind(word_1,
              word_2,
              word_3,
              word_4)
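If you want to verify the stacking, you can inspect the dimensions of the resulting matrix (this quick check is not part of the original code):
# optional check: words is a 4 x 3 matrix, one row per word
dim(words)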
Next, we fill the weight matrices with random values drawn from [0, 3) using runif() and floored, which yields random integers between 0 and 2.
# generating the weight matrices
set.seed(0)
W_Q = matrix(floor(runif(9, min=0, max=3)),nrow=3,ncol=3)
W_K = matrix(floor(runif(9, min=0, max=3)),nrow=3,ncol=3)
W_V = matrix(floor(runif(9, min=0, max=3)),nrow=3,ncol=3)
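As an aside, base R's sample() can draw such integers directly; the snippet below is just an alternative sketch (it produces different draws than the runif() approach and is not part of the original code):
# alternative: draw the integers 0, 1 and 2 directly (not part of the original code)
W_Q_alt = matrix(sample(0:2, 9, replace=TRUE), nrow=3, ncol=3)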
In order to keep the numbers the same as in the original Python code, you can overwrite the randomly generated values with the values that Python produced.
# redefine matrices to match random numbers generated by Python in the original code
W_Q = matrix(c(2,0,2,
               2,0,0,
               2,1,2),
             nrow=3,
             ncol=3,
             byrow = TRUE)
W_K = matrix(c(2,2,2,
               0,2,1,
               0,1,1),
             nrow=3,
             ncol=3,
             byrow = TRUE)
W_V = matrix(c(1,1,0,
               0,1,1,
               0,0,0),
             nrow=3,
             ncol=3,
             byrow = TRUE)
Next, we generate the Queries (Q), Keys (K), and Values (V). The %*% operator performs matrix multiplication; you can view the R help page using help('%*%').
# generating the queries, keys and values
Q = words %*% W_Q
K = words %*% W_K
V = words %*% W_V
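As a quick sanity check (not in the original code), each of the three matrices should be 4 x 3: one query/key/value vector per word.
# optional check: Q, K and V are all 4 x 3 matrices
dim(Q)
dim(K)
dim(V)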
Following this, we score the Queries (Q) against the Key (K) vectors.
# scoring the query vectors against all key vectors
scores = Q %*% t(K)
print(scores)
[,1] [,2] [,3] [,4]
[1,] 8 2 10 2
[2,] 4 0 4 0
[3,] 12 2 14 2
[4,] 10 4 14 3
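Each entry of scores is the dot product of one query vector with one key vector. For example, the top-left entry (8) can be reproduced by hand (this check is not part of the original code):
# optional check: scores[1,1] is the dot product of the first query and the first key
sum(Q[1,] * K[1,])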
We now need to find the maximum value of each row of the scores matrix. We can do this by using apply() (see help('apply')) with the max() function on MARGIN=1 (i.e. rows). Don't worry too much about how this works; the key takeaway is that we find the maximum of each row, and by wrapping the result in as.matrix() we keep the maxima on their corresponding rows in the new maxs matrix. These row maxima will be subtracted from the scores before exponentiating, which keeps the softmax numerically stable.
maxs = as.matrix(apply(scores, MARGIN=1, max))
print(maxs)
[,1]
[1,] 10
[2,] 4
[3,] 14
[4,] 14
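Subtracting the row maximum only matters for numerical stability: the same constant is removed from every score in a row, so the softmax itself is unchanged. A quick demonstration of this shift invariance for the first row (not part of the original code):
# optional check: shifting a row by its maximum leaves the softmax unchanged
s = scores[1,] / ncol(K) ^ 0.5
all.equal(exp(s) / sum(exp(s)), exp(s - max(s)) / sum(exp(s - max(s))))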
As you can see, the value in each row of maxs is the maximum value of the corresponding row in scores.
The weights matrix will be populated using a for loop (see help('for')). Because the loop assigns into existing rows rather than creating or resizing the matrix, we first generate a zero matrix (i.e. all values are set to 0) of the right dimensions, which we then populate using the for loop.
# initialize weights matrix
weights = matrix(0, nrow=4, ncol=4)
We now populate the weights matrix using the for loop.
# computing the weights by a softmax operation
for (i in 1:dim(scores)[1]) {
  scaled = (scores[i,] - maxs[i,]) / ncol(K) ^ 0.5
  weights[i,] = exp(scaled) / sum(exp(scaled))
}
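Since each row of weights is a softmax, every row should sum to 1. A quick check (not part of the original code):
# optional check: every row of the softmax weights sums to 1
rowSums(weights)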
Finally, we compute the attention as a weighted sum of the value vectors.
# computing the attention by a weighted sum of the value vectors
attention = weights %*% V
Now we can view the results using:
print(attention)
This gives:
[,1] [,2] [,3]
[1,] 0.9852202 1.741741 0.7565203
[2,] 0.9096526 1.409653 0.5000000
[3,] 0.9985123 1.758493 0.7599811
[4,] 0.9956039 1.904073 0.9084692
As you can see, these are the same values as those computed in Python
in the original post.
The complete code is also available as a Gist on GitHub.