The T5 language model has proved hugely popular since it first appeared in Hugging Face Transformers, and there has been sustained demand to make T5 runnable at float16 precision.
Until now, running T5 at reduced precision has required hardware that supports bfloat16, the format the model was originally trained in. This has limited its use to select CPUs, TPUs from v2 onwards, and GPUs from the A100 onwards.
The best alternative, using float32, typically exceeds hardware memory limits or simply takes too long to execute compared with running in float16.
With the release of FLAN-T5, we were keen to offer these models running on our IPUs, which means using float16.
In this blog post, we are delighted to present our FLAN-T5 for IPU solution. While it has been developed specifically for the T5 model, the methods are reusable and can help you in similar scenarios.
Porting T5 to float16 on IPU
Identifying dynamic parts of the computational graph
Before running the model, we need to carry out a quick visual inspection of the model code to look for parts that won't compile into a static graph. We found dynamic branching of the graph in the T5Block. Coincidentally, the branches that are created clamp the data if it has already overflowed in float16:
We chose to remove the dynamic condition, torch.isinf(hidden_states).any(), from this branch* because:
- We cannot statically compile this dynamic branching condition.
- While clamping the hidden states only treats the symptom of float16 issues, it is still needed for training and so cannot be removed entirely. See the "FeedForward's down projection" section for details on how we treated the cause for inference.
*This change has also been made in the latest version of Transformers.
Enabling Poplar’s floating-point exception detection
Our Poplar backend has floating-point exception detection built-in, which makes tracking down the source of numerical issues far more straightforward. The process consists of the following steps:
- Enable floating-point exceptions in your application. In PopTorch, you can use opts.Precision.enableFloatingPointExceptions(True) (for more information, see the PopTorch User Guide).
- Run your application with graph profiling enabled:
POPLAR_ENGINE_OPTIONS='{"autoReport.all":"true", "autoReport.outputExecutionProfile": "false", "autoReport.directory":"./report"}'
For more details, see the section on capturing IPU reports in the Graph Analyser User Guide*.
- If a floating-point exception is triggered, a poptorch_error.log file will be generated. Open this file and scroll down to (or search for) Backtrace. Find the ID nearest the top of the backtrace, denoted by (Id: 1234), and search for it in the graph profile's program tree. From here you should be able to examine the debug information of the offending operation and figure out where in the model it came from.
*Note that we use "autoReport.outputExecutionProfile": "false" to avoid the overhead of profiling the execution. We can do this because we are only interested in the program tree.
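The profiling and log-reading steps above can be scripted. The sketch below builds the same POPLAR_ENGINE_OPTIONS value with json.dumps (which avoids shell-quoting mistakes like a stray quote) and pulls the first (Id: N) that follows "Backtrace" out of the log text. The helper names are hypothetical, and the log layout is an assumption based only on the description above; check it against a real poptorch_error.log. Enabling the exceptions themselves still goes through opts.Precision.enableFloatingPointExceptions(True) in PopTorch, which is not shown here.

```python
import json
import os
import re

# Same options as the profiling step above; the execution profile is skipped
# because only the program tree is needed.
ENGINE_OPTIONS = {
    "autoReport.all": "true",
    "autoReport.outputExecutionProfile": "false",
    "autoReport.directory": "./report",
}


def enable_graph_profiling() -> None:
    # Hypothetical helper: call at the top of the script, before any IPU compilation.
    os.environ["POPLAR_ENGINE_OPTIONS"] = json.dumps(ENGINE_OPTIONS)


def first_backtrace_id(log_text: str):
    """Return the first '(Id: N)' after 'Backtrace' as an int, or None.

    Hypothetical helper; assumes IDs appear in the '(Id: 1234)' form.
    """
    start = log_text.find("Backtrace")
    if start == -1:
        return None
    match = re.search(r"\(Id: (\d+)\)", log_text[start:])
    return int(match.group(1)) if match else None
```

For example, first_backtrace_id(open("poptorch_error.log").read()) gives the ID to search for in the program tree.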
Using this method, we tracked down and resolved the floating-point exceptions described below.
Resolving the floating-point exceptions
Attention Masking
The first two exceptions were found in the attention masking. In two places the attention mask was "inverted" and used additively: the mask value was set to torch.finfo(torch.float16).min (i.e. -65504) and the pass value was set to 0. This was done so that, when the masked attention values are passed to softmax, they have minimal relevance in the resulting output. However, if the value being masked was itself sufficiently negative, adding -65504 to it pushed the result below the most negative finite float16 value, and you would end up with negative infinity:
We solved these two exceptions by simply scaling the mask down by 25% (0.75 × -65504 = -49128), meaning that attention values could be as low as -16376 without the mask causing an overflow.
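A minimal sketch of the overflow and the fix, using a hypothetical additive_mask helper (not the Transformers code): True in keep means "attend", masked positions receive scale × torch.finfo(torch.float16).min, and the scores are plain float16 tensors on CPU.

```python
import torch

F16_MIN = torch.finfo(torch.float16).min  # -65504.0


def additive_mask(keep: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    # "Inverted" additive mask: 0 where attention is kept,
    # scale * F16_MIN where it is masked.
    return (~keep).to(torch.float16) * (scale * F16_MIN)


keep = torch.tensor([False, True])
scores = torch.tensor([-1000.0, 2.0], dtype=torch.float16)

# Full-strength mask: -1000 + -65504 lies below the float16 range, giving -inf.
overflowed = scores + additive_mask(keep)

# Mask scaled by 0.75: the masked score stays finite but is still very negative.
safe = scores + additive_mask(keep, scale=0.75)
```

Because 0.75 × -65504 is itself rounded in float16 (to -49120), the headroom is in practice slightly larger than the -16376 computed above, so that figure is a conservative bound.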
GeLU approximated by tanh
The third exception was found in the explicit definition of the tanh GeLU approximation used by the FLAN-T5 model (the original T5 model used ReLU activations). The formula

$$\mathrm{GeLU}(x) \approx 0.5\,x\left(1 + \tanh\left(\sqrt{2/\pi}\,\left(x + 0.044715\,x^{3}\right)\right)\right)$$

cubes the input, which overflows in float16 once the absolute value of the input exceeds roughly 40 (the cube root of 65504 is about 40.3). We fixed this by switching to ReLU when the absolute value of the input is larger than 39, which is a safe approximation to make since GeLU is practically identical to ReLU once the absolute value of the input exceeds 5.
Pre-norm residual connections
The fourth exception was found in the residual additions in the encoder's FeedForward (FF) layers: when the output of the FF network was added to its input, the addition overflowed. We solved this by:
- Casting the embeddings input to the first encoder block to float32
- For the SelfAttention and FeedForward layers in every encoder block:
  - Casting to float16 after LayerNorm* so that the bulk of the compute still happens in float16
  - Casting to float32 after the dropout, before adding to the float32 residual connection
- Casting the output of final_layer_norm* (after all the encoder blocks) back to float16, ready for the decoder, which runs entirely in float16
*This actually happened automatically, because T5's LayerNorm implementation casts its output to the dtype of its (float16) weights.
The following diagrams are colour-coded to show the precision of the data:
The T5 encoder consists of a chain of blocks; each block contains a SelfAttention layer and a FeedForward layer:
Each of these layers has the same fundamental structure, the only difference being the Attention/Hidden layer:
After the casting changes described in the second step above, these layers look like:
This prevents overflow in the pre-norm residuals that get passed all the way through the encoder.
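The casting scheme can be sketched as a single pre-norm layer. This is a simplified illustration under stated assumptions, not our IPU code: RMSNormToHalf is a hypothetical stand-in for T5's LayerNorm (scale only, no mean subtraction or bias) that computes in float32 and returns float16, and sublayer stands in for the Attention or Hidden part, which runs in float16.

```python
import torch
import torch.nn as nn


class RMSNormToHalf(nn.Module):
    """T5-style norm computed in float32, output cast to float16 (hypothetical)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.float()
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        return (self.weight * x).to(torch.float16)  # cast after the norm


class Float32ResidualLayer(nn.Module):
    """Pre-norm layer whose residual stream stays in float32 (hypothetical)."""

    def __init__(self, dim: int, sublayer: nn.Module, dropout: float = 0.1):
        super().__init__()
        self.norm = RMSNormToHalf(dim)
        self.sublayer = sublayer  # Attention or Hidden part, computed in float16
        self.dropout = nn.Dropout(dropout)

    def forward(self, residual: torch.Tensor) -> torch.Tensor:
        h = self.sublayer(self.norm(residual))  # bulk of the compute in float16
        h = self.dropout(h).float()             # cast to float32 after dropout
        return residual + h                     # float32 add cannot overflow here
```

Only the residual stream and the norm statistics use float32, so the memory and compute savings of float16 are largely preserved.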
FeedForward’s down projection
The final floating-point exception was found in the down projection in the Hidden part of the encoder's FeedForward layer. In the code this is the wo layer, which we shall refer to as DownProject for clarity. With the casting changes above, the FeedForward layer and its Hidden component look like this:
We were able to resolve the overflow in DownProject by scaling down its input and then scaling up its output once it was safely in float32 again.
The scaling factor was chosen by examining the standard deviation of the activations coming out of DownProject and identifying a suitable power of two that would tame these activations. We want to use a power of two because then only the exponents of the float16 activations need to change, avoiding lossy modification of the mantissa.
We found that the standard deviation was ~4400 and so we chose 8 as our scaling factor to reduce the standard deviation to ~550. After implementing this scaling, the FeedForward layer and its Hidden component look like this:
The latest version of Transformers solves this problem differently, by keeping this layer in float32 at all times.
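The power-of-two scaling can be sketched as follows. The rule in choose_pow2_scale (smallest power of two that brings the standard deviation under a target) and its target_std parameter are hypothetical; the text only reports the outcome (~4400 reduced to ~550 with a factor of 8). ScaledDownProject is likewise a hypothetical wrapper, and it relies on wo having no bias (as in T5), since only then does scaling the input scale the output by the same factor.

```python
import torch
import torch.nn as nn


def choose_pow2_scale(activation_std: float, target_std: float) -> int:
    # Hypothetical rule: smallest power of two with activation_std / scale <= target_std.
    scale = 1
    while activation_std / scale > target_std:
        scale *= 2
    return scale


class ScaledDownProject(nn.Module):
    """Wraps a bias-free wo layer: scale input down, scale output up in float32."""

    def __init__(self, wo: nn.Linear, scale: int = 8):
        super().__init__()
        if wo.bias is not None:
            raise ValueError("scaling only commutes with a bias-free linear layer")
        self.wo = wo
        self.scale = scale  # power of two: only exponents change, mantissas are kept

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        out = self.wo(hidden / self.scale)  # smaller activations avoid float16 overflow
        return out.float() * self.scale     # undo the scaling once safely in float32
```

With a hypothetical target of 1000, choose_pow2_scale(4400.0, 1000.0) returns 8, matching the factor used above.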
Validation
Since we've changed a few things in the model, you're probably wondering whether it still performs as it is supposed to. We wondered this too, and so validated it on a subset* of the MMLU benchmark on CPU in float32 and on IPU in float16. The CPU and IPU achieved overall averages of 49.3% and 49.4% respectively, indicating that our changes have not degraded the performance of the original model.
*Our current implementation of FLAN-T5-XL has a maximum input length of 896 tokens, so we used the subset of MMLU where the examples did not exceed this length.
Conclusion
With this, we now have a FLAN-T5-XL implementation that can be used for inference on IPU in float16. Please head over to Paperspace to try it out for yourself!