large language model - Trouble understanding the formula for estimating dense self-attention FLOPs per token given as 6LH(2QT)


In Appendix B of the PaLM paper (arXiv:2204.02311) the authors describe a metric called "Model FLOPs Utilization (MFU)" and give a formula for estimating it. Its computation uses the stated fact that "The matmuls in dense self-attention add 6LH(2QT) FLOPs per token ... where L, H, Q, and T are the number of layers, the number of heads, the head dimension, and the sequence length respectively." I need help understanding how the authors arrived at the 6LH(2QT) value.

My thought process is as follows. For every layer and head (LH) we have to compute self-attention. Self-attention requires query-key dot products. Each dot product requires Q (i.e. head-dimension) multiplications and Q-1 ~ Q additions, so approximately 2Q operations. For a given token, this query-key product happens T (sequence length) times, since the query vector is multiplied against the key vector of every position. So for the forward pass of self-attention I count LH(2QT) operations per token. But the backward pass should be taken into account too. During backpropagation, for every layer and head (LH) we need to compute gradients for the embeddings (Q), the query weights (w_q), the key weights (w_k), and the value weights (w_v), which would give 4*(LH(2QT)). Combining the forward pass computation (LH(2QT)) with the backward pass computation (4*(LH(2QT))) I got 5LH(2QT) operations per token, which is less than the formula in the paper.
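
For concreteness, here is the tally described above written as a quick Python sketch; the values chosen for L, H, Q, and T are placeholders, not PaLM's actual hyperparameters:

    # Placeholder hyperparameters: layers, heads, head dimension, sequence length.
    L, H, Q, T = 32, 16, 128, 2048

    # Forward pass: one query-key dot product (~2Q FLOPs) against T keys,
    # for every layer and head.
    forward = L * H * (2 * Q * T)

    # Backward pass as estimated above: four gradient computations of the same size.
    backward = 4 * forward

    my_total = forward + backward          # 5 * LH(2QT)
    paper_total = 6 * L * H * (2 * Q * T)  # the paper's 6LH(2QT)

    print(my_total, paper_total)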

I would really appreciate a detailed and easy-to-understand explanation. Thank you.


asked Feb 12 at 21:22 by cangozpi

1 Answer


This is the way that I understand it.

Let me first illustrate the forward pass.

Self-attention is Softmax(QK^T)V (ignoring the scaling factor in the FLOP calculation, and apologies for using the same symbol Q for both the query and the head dimension!).

Since we only care about the FLOPs of a single token, our query Q has size (1xQ). K and V have size (TxQ), which the query uses to interact with the neighboring tokens.

If we focus on just one head of one layer, we can ignore L and H for now. QK^T is a multiplication between a (1xQ) and a (QxT) matrix, which takes ~2QT operations and yields a single vector of size (1xT).

But there is still the product between Softmax(QK^T) and V. This is a product between a (1xT) vector and a (TxQ) matrix, which again takes ~2QT operations.
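
Here is a minimal sketch of these two matmuls with toy sizes (the values of Q and T are arbitrary; the point is just to make the shapes and FLOP counts concrete):

    import numpy as np

    Q, T = 64, 2048                 # head dimension and sequence length (toy values)

    q = np.random.randn(1, Q)       # the query for a single token, shape (1, Q)
    K = np.random.randn(T, Q)       # keys for all T positions
    V = np.random.randn(T, Q)       # values for all T positions

    scores = q @ K.T                # (1, Q) @ (Q, T) -> (1, T): ~2*Q*T FLOPs
    weights = np.exp(scores)
    weights /= weights.sum()        # softmax, ignored in the FLOP count
    out = weights @ V               # (1, T) @ (T, Q) -> (1, Q): another ~2*Q*T FLOPs

    flops_forward_one_head = 2 * (2 * Q * T)
    print(out.shape, flops_forward_one_head)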

Combining both steps, we get 2(2QT) FLOPs for one head of one layer. Scaling by the number of heads (H) and the number of layers (L) gives 2LH(2QT) for the forward pass. If we take the backward pass to cost roughly twice the FLOPs of the forward pass, we get:

2LH(2QT) × (1 + 2) = 6LH(2QT) = 12LHQT FLOPs per token.
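
A quick sanity check of the bookkeeping (again with placeholder values for L, H, Q, and T, not any particular model's):

    L, H, Q, T = 32, 16, 128, 2048      # placeholder values

    forward = 2 * L * H * (2 * Q * T)   # QK^T plus weights @ V, over all layers and heads
    total = 3 * forward                 # backward pass counted as ~2x the forward pass
    assert total == 6 * L * H * (2 * Q * T) == 12 * L * H * Q * T
    print(total)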

