# Self-Attention from Scratch in R

EDIT 2022-06-24: this code is now available (with helper functions) in the `R` package `attention`, which is on CRAN. You can install it with:

``````r
install.packages('attention')
``````

See also my blog post "attention on CRAN". Development takes place on GitHub.

This post describes how to implement the attention mechanism, which forms the basis of transformers, in the R language.

The code is translated from the Python original by Stefania Cristina (University of Malta) in her post "The Attention Mechanism from Scratch".

We begin by generating encoder representations of four different words.

``````r
# encoder representations of four different words
word_1 = matrix(c(1,0,0), nrow=1)
word_2 = matrix(c(0,1,0), nrow=1)
word_3 = matrix(c(1,1,0), nrow=1)
word_4 = matrix(c(0,0,1), nrow=1)
``````

Next, we stack the word embeddings into a single array (in this case a matrix).

``````r
# stacking the word embeddings into a single array
words = rbind(word_1,
              word_2,
              word_3,
              word_4)
``````
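With more than a handful of words, listing every variable inside `rbind()` gets tedious. The following is a minimal sketch that assumes the embeddings are kept in a list. The name `embeddings` is only illustrative. `do.call()` passes the list elements to `rbind()` as separate arguments, which produces the same 4 x 3 matrix.

``````r
# the four 1 x 3 encoder representations, collected in a list (illustrative name)
embeddings = list(matrix(c(1,0,0), nrow=1),
                  matrix(c(0,1,0), nrow=1),
                  matrix(c(1,1,0), nrow=1),
                  matrix(c(0,0,1), nrow=1))

# do.call() evaluates rbind(embeddings[[1]], ..., embeddings[[4]])
words = do.call(rbind, embeddings)

# one row per word, one column per embedding dimension
print(dim(words))  # 4 3
``````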

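Next, we generate the three weight matrices `W_Q`, `W_K` and `W_V`, filled with random integers from the set {0, 1, 2}. Each entry is a uniform draw on `[0, 3)`, rounded down with `floor()`.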

``````r
# generating the weight matrices
set.seed(0)
W_Q = matrix(floor(runif(9, min=0, max=3)), nrow=3, ncol=3)
W_K = matrix(floor(runif(9, min=0, max=3)), nrow=3, ncol=3)
W_V = matrix(floor(runif(9, min=0, max=3)), nrow=3, ncol=3)
``````
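`runif()` never returns its endpoints, so `floor(runif(9, min=0, max=3))` can only yield 0, 1 or 2, each with equal probability. The value 3 never occurs. The sketch below checks this. It also shows an alternative that samples the three integers directly with `sample()`. That alternative uses the random number stream differently, so its matrix will not match the ones above even with the same seed. The name `W_alt` is illustrative.

``````r
set.seed(0)
W_Q = matrix(floor(runif(9, min=0, max=3)), nrow=3, ncol=3)

# every entry is one of 0, 1, 2
print(all(W_Q %in% 0:2))    # TRUE

# alternative: draw the integers 0, 1, 2 directly, with replacement
W_alt = matrix(sample(0:2, 9, replace=TRUE), nrow=3, ncol=3)
print(all(W_alt %in% 0:2))  # TRUE
``````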

To obtain the same numbers as the original `Python` code, overwrite the randomly generated matrices with the values that `Python` produced.

``````r
# redefine matrices to match random numbers generated by Python in the original code
W_Q = matrix(c(2,0,2,
               2,0,0,
               2,1,2),
             nrow=3,
             ncol=3,
             byrow = TRUE)
W_K = matrix(c(2,2,2,
               0,2,1,
               0,1,1),
             nrow=3,
             ncol=3,
             byrow = TRUE)
W_V = matrix(c(1,1,0,
               0,1,1,
               0,0,0),
             nrow=3,
             ncol=3,
             byrow = TRUE)
``````

Next, we generate the Queries (`Q`), Keys (`K`), and Values (`V`). The `%*%` operator performs the matrix multiplication. You can view the `R` help page using `help('%*%')`.

``````r
# generating the queries, keys and values
Q = words %*% W_Q
K = words %*% W_K
V = words %*% W_V
``````
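`words` stores one word per row, so `words %*% W_Q` projects each word independently: row i of `Q` equals `word_i %*% W_Q`. The short sketch below checks this for the third word, reusing the `W_Q` values from the post. `word_3 = (1, 1, 0)` has ones in positions 1 and 2, so its query is simply the sum of rows 1 and 2 of `W_Q`.

``````r
words = rbind(c(1,0,0), c(0,1,0), c(1,1,0), c(0,0,1))
W_Q = matrix(c(2,0,2,
               2,0,0,
               2,1,2), nrow=3, byrow=TRUE)
Q = words %*% W_Q

# row 3 of Q is the projection of word_3 on its own
word_3 = matrix(c(1,1,0), nrow=1)
print(Q[3, ])          # 4 0 2
print(word_3 %*% W_Q)  # the same values, as a 1 x 3 matrix

# word_3 selects rows 1 and 2 of W_Q and adds them
print(W_Q[1, ] + W_Q[2, ])
``````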

Following this, we score the Queries (`Q`) against the Key (`K`) vectors.

``````r
# scoring the query vectors against all key vectors
scores = Q %*% t(K)
print(scores)
``````
``````
     [,1] [,2] [,3] [,4]
[1,]    8    2   10    2
[2,]    4    0    4    0
[3,]   12    2   14    2
[4,]   10    4   14    3
``````
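Each entry `scores[i, j]` is the dot product of query i with key j. The sketch below checks one entry by hand. It also uses base R's `tcrossprod(x, y)`, which computes `x %*% t(y)` in a single call.

``````r
words = rbind(c(1,0,0), c(0,1,0), c(1,1,0), c(0,0,1))
W_Q = matrix(c(2,0,2, 2,0,0, 2,1,2), nrow=3, byrow=TRUE)
W_K = matrix(c(2,2,2, 0,2,1, 0,1,1), nrow=3, byrow=TRUE)
Q = words %*% W_Q
K = words %*% W_K
scores = Q %*% t(K)

# query 1 = (2, 0, 2), key 3 = (2, 4, 3): 2*2 + 0*4 + 2*3 = 10
print(sum(Q[1, ] * K[3, ]))  # 10
print(scores[1, 3])          # 10

# tcrossprod(Q, K) is Q %*% t(K)
print(all.equal(tcrossprod(Q, K), scores))  # TRUE
``````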

We now need the maximum value of each row of the `scores` matrix. We can get it with `apply()` (see `help('apply')`), applying the `max()` function over `MARGIN=1` (i.e. the rows). Don't worry too much about how this works. The key takeaway is that we find the maximum of each row. Wrapping the result in `as.matrix()` keeps each maximum on its corresponding row of the new `maxs` matrix.

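``````r
# maximum of each row of scores, kept as a 4 x 1 matrix
# (the argument name is MARGIN, in capitals)
maxs = as.matrix(apply(scores, MARGIN=1, max))
print(maxs)
``````
``````
     [,1]
[1,]   10
[2,]    4
[3,]   14
[4,]   14
``````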

As you can see, the value for each row in `maxs` is the maximum value of the corresponding row in `scores`.
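Subtracting each row's maximum does not change the softmax result. The common factor `exp(-m)` cancels between numerator and denominator, so `exp(x - m) / sum(exp(x - m))` equals `exp(x) / sum(exp(x))`. Dividing by the square root of the key dimension afterwards does not break this, because `(x - m) / sqrt(d_k)` is only a constant shift of `x / sqrt(d_k)`. The subtraction matters in practice because it keeps `exp()` from overflowing. The sketch below shows both effects on a single row of scores. The helper names `softmax_naive` and `softmax_stable` are illustrative.

``````r
# illustrative helpers: plain softmax and max-shifted softmax
softmax_naive = function(x) exp(x) / sum(exp(x))
softmax_stable = function(x) {
  z = exp(x - max(x))  # the largest exponent becomes exp(0) = 1
  z / sum(z)
}

score_row = c(8, 2, 10, 2)  # first row of scores
print(all.equal(softmax_naive(score_row), softmax_stable(score_row)))  # TRUE

# exp(1010) overflows to Inf, so the naive version returns NaN
big = score_row + 1000
print(softmax_naive(big))   # NaN NaN NaN NaN
print(softmax_stable(big))  # same weights as for score_row
``````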

The weights matrix will be populated using a `for` loop (see `help('for')`). The loop only fills in the rows of an existing matrix. We therefore first create a matrix of zeros with the right dimensions, which the loop then populates.

``````r
# initialize weights matrix
weights = matrix(0, nrow=4, ncol=4)
``````

We now populate the `weights` matrix using the `for` loop.

``````r
# computing the weights by a softmax operation
for (i in seq_len(nrow(scores))) {
  # subtract the row maximum for numerical stability, then scale by sqrt(d_k)
  scaled = (scores[i,] - maxs[i,]) / sqrt(ncol(K))
  weights[i,] = exp(scaled) / sum(exp(scaled))
}
``````
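The loop can also be written without explicit iteration. The sketch below uses base R's `sweep()`: it first subtracts the row maxima, then divides each row by its sum. `rowSums()` confirms that every row of `weights` sums to 1. The name `d_k` for the key dimension is an assumption for readability.

``````r
words = rbind(c(1,0,0), c(0,1,0), c(1,1,0), c(0,0,1))
W_Q = matrix(c(2,0,2, 2,0,0, 2,1,2), nrow=3, byrow=TRUE)
W_K = matrix(c(2,2,2, 0,2,1, 0,1,1), nrow=3, byrow=TRUE)
Q = words %*% W_Q
K = words %*% W_K
scores = Q %*% t(K)

# subtract each row's maximum (sweep's default FUN is "-"), then scale by sqrt(d_k)
d_k = ncol(K)
z = exp(sweep(scores, 1, apply(scores, 1, max)) / sqrt(d_k))

# divide each row by its own sum
weights = sweep(z, 1, rowSums(z), "/")

print(rowSums(weights))  # 1 1 1 1
print(round(weights, 4))
``````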

Finally, we compute the attention as a weighted sum of the value vectors.

``````r
# computing the attention by a weighted sum of the value vectors
attention = weights %*% V
``````
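The whole computation fits in a single function. The sketch below takes the queries, keys and values and should return the same matrix as the step-by-step code. The function name `scaled_dot_product_attention` is hypothetical and is not necessarily the interface of the `attention` package.

``````r
# hypothetical helper: softmax(Q K^T / sqrt(d_k)) V, softmax taken row by row
scaled_dot_product_attention = function(Q, K, V) {
  scores = Q %*% t(K)
  scaled = sweep(scores, 1, apply(scores, 1, max)) / sqrt(ncol(K))
  weights = exp(scaled)
  # R recycles the row sums down each column, so row i is divided by its own sum
  weights = weights / rowSums(weights)
  weights %*% V
}

words = rbind(c(1,0,0), c(0,1,0), c(1,1,0), c(0,0,1))
W_Q = matrix(c(2,0,2, 2,0,0, 2,1,2), nrow=3, byrow=TRUE)
W_K = matrix(c(2,2,2, 0,2,1, 0,1,1), nrow=3, byrow=TRUE)
W_V = matrix(c(1,1,0, 0,1,1, 0,0,0), nrow=3, byrow=TRUE)

attention = scaled_dot_product_attention(words %*% W_Q, words %*% W_K, words %*% W_V)
print(attention)
``````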

Now we can view the results using:

``````r
print(attention)
``````

This gives:

``````
          [,1]     [,2]      [,3]
[1,] 0.9852202 1.741741 0.7565203
[2,] 0.9096526 1.409653 0.5000000
[3,] 0.9985123 1.758493 0.7599811
[4,] 0.9956039 1.904073 0.9084692
``````

As you can see, these are the same values as those computed in `Python` in the original post.
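The second row can also be checked by hand. Its scores are (4, 0, 4, 0), and the key dimension is 3. The rows of `V = words %*% W_V` are v1 = (1, 1, 0), v2 = (0, 1, 1), v3 = (1, 2, 1) and v4 = (0, 0, 0). Softmax gives equal weight a to the two keys scoring 4 and equal weight b to the two scoring 0, so:

$$
\begin{aligned}
\operatorname{softmax}\!\left(\tfrac{(4,\,0,\,4,\,0)}{\sqrt{3}}\right) &= (a,\; b,\; a,\; b),
\qquad a = \frac{e^{4/\sqrt{3}}}{2e^{4/\sqrt{3}} + 2},\quad b = \frac{1}{2e^{4/\sqrt{3}} + 2},\quad a + b = \tfrac{1}{2} \\
\text{attention}_2 &= a\,v_1 + b\,v_2 + a\,v_3 + b\,v_4 \\
&= a\,(1,1,0) + b\,(0,1,1) + a\,(1,2,1) + b\,(0,0,0) \\
&= (2a,\; 3a + b,\; a + b) = \left(2a,\; 2a + \tfrac{1}{2},\; \tfrac{1}{2}\right) \\
2a &= \frac{1}{1 + e^{-4/\sqrt{3}}} \approx 0.90965
\end{aligned}
$$

This agrees with the second row printed above (0.9096526, 1.409653, 0.5).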

The complete code is also available as a Gist on GitHub.
