This technical overview explores the specific methods and tools within the JAX and MaxText ecosystems designed to refine training efficiency and reach peak performance on Ironwood hardware.
Key optimization strategies for Ironwood
1. Leverage native FP8 with MaxText
Ironwood is the first TPU generation with native 8-bit floating point (FP8) support in its Matrix Multiply Units (MXUs). By utilizing FP8 precision for weights, activations, and gradients, users can theoretically double throughput compared to Brain Floating Point 16 (BF16). When FP8 recipes are configured correctly, this increased efficiency is achievable without compromising model quality.
To implement these FP8 training recipes, users can start with the Qwix library. This functionality is enabled by specifying the relevant flags within the MaxText configuration.
See our blog post, Inside the optimization of FP8 training on Ironwood, in the Google Developer forums for more details.
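To build intuition for what an FP8 recipe does under the hood, the sketch below simulates per-tensor FP8 (E4M3) scaling in plain JAX. The helper names are illustrative and this is not the Qwix API; in practice Qwix and the MaxText flags handle quantization automatically, and on Ironwood the MXU consumes FP8 operands directly rather than upcasting.

```python
import jax
import jax.numpy as jnp

# E4M3 has a maximum representable magnitude of 448.
F8_MAX = 448.0

def quantize_fp8(x):
    """Scale a tensor into the FP8 E4M3 range and cast (per-tensor scaling)."""
    scale = jnp.maximum(jnp.max(jnp.abs(x)) / F8_MAX, 1e-12)
    return (x / scale).astype(jnp.float8_e4m3fn), scale

def fp8_matmul(a, b):
    """Matmul on FP8-quantized operands; the output is rescaled to float32."""
    a_f8, sa = quantize_fp8(a)
    b_f8, sb = quantize_fp8(b)
    # Upcast here so the sketch runs on any backend; on Ironwood the MXU
    # performs the matmul natively in FP8.
    out = jnp.dot(a_f8.astype(jnp.float32), b_f8.astype(jnp.float32))
    return out * (sa * sb)

a = jax.random.normal(jax.random.PRNGKey(0), (128, 256))
b = jax.random.normal(jax.random.PRNGKey(1), (256, 64))
approx = fp8_matmul(a, b)  # close to jnp.dot(a, b), up to quantization error
exact = jnp.dot(a, b)
```

The per-tensor scale keeps values inside E4M3's narrow dynamic range; production recipes typically refine this with delayed or per-channel scaling.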
2. Accelerate with Tokamax kernels
Tokamax is a library of high-performance JAX kernels optimized for TPUs. These kernels are designed to mitigate specific bottlenecks through the following mechanisms:
- Splash Attention: This mechanism addresses the I/O limitations inherent in standard attention processes. By keeping computations within on-chip SRAM, it is particularly effective for processing long context lengths, where memory bandwidth typically becomes a constraint.
- Megablox Grouped Matrix Multiplication (GMM): This manages the "ragged" tensors often found in Mixture of Experts (MoE) models. By utilizing GMM, the system avoids inefficient padding and ensures higher utilization of the MXU.
- Kernel tuning: The Tokamax library includes utilities for hyperparameter optimization. These tools allow for the adjustment of tile sizes and other configurations to align with the specific memory hierarchy of the Ironwood TPU.
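The grouped-matmul access pattern behind Megablox can be illustrated in plain JAX. This sketch is not the Tokamax kernel (which fuses this work on-chip); it only shows why ragged groups avoid padding: each expert's contiguous slice of tokens is multiplied by that expert's weights alone, so an empty or undersized group costs nothing extra. The function name is illustrative.

```python
import numpy as np
import jax.numpy as jnp

def grouped_matmul(tokens, group_sizes, expert_weights):
    """Multiply each contiguous group of token rows by its expert's weights.

    tokens:         (num_tokens, d_model), sorted so each expert's tokens
                    are contiguous.
    group_sizes:    (num_experts,) tokens routed to each expert (ragged).
    expert_weights: (num_experts, d_model, d_ff)
    """
    outputs = []
    start = 0
    for e, size in enumerate(group_sizes):
        # Each group meets only its own expert -- no padding of groups
        # to a common size is required.
        outputs.append(tokens[start:start + size] @ expert_weights[e])
        start += size
    return jnp.concatenate(outputs, axis=0)

tokens = jnp.ones((10, 4))           # 10 tokens, d_model = 4
group_sizes = np.array([3, 0, 7])    # ragged: expert 1 received no tokens
weights = jnp.stack([jnp.eye(4) * (e + 1) for e in range(3)])
out = grouped_matmul(tokens, group_sizes, weights)
```

The dense alternative would pad every group to the largest group size, wasting MXU cycles on zeros; the ragged formulation keeps utilization high.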
3. Offload collectives to SparseCore
The fourth-generation SparseCores in Ironwood are processors specifically designed to manage irregular memory access patterns. By using specific XLA flags, users can offload collective communication operations—such as All-Gather and Reduce-Scatter—directly to the SparseCore.
This offloading mechanism allows the TensorCores to remain dedicated to primary model computations while communication tasks execute in parallel. This functional overlap is a critical strategy for hiding communication latency and ensuring consistent data throughput to the MXUs.
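In practice this is configured through XLA flags passed to the TPU runtime, typically via `LIBTPU_INIT_ARGS` before launching MaxText. The flag names below are an assumption based on published SparseCore offload configurations; verify them against the XLA flags documented for your libtpu release.

```shell
# Illustrative: offload All-Gather and Reduce-Scatter collectives to the
# SparseCore so TensorCores stay dedicated to model computation.
# Confirm these flag names against your libtpu/XLA version.
export LIBTPU_INIT_ARGS="--xla_tpu_enable_sparse_core_collective_offload_all_gather=true \
--xla_tpu_enable_sparse_core_collective_offload_reduce_scatter=true"
```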
4. Fine-tune the memory pipeline on VMEM
VMEM, a critical part of the TPU memory architecture, is a fast on-chip SRAM designed to optimize kernel performance. You can improve overall execution speed by tuning how VMEM is allocated between the current operation and the prefetching of future weights. For example, increasing the VMEM reserved for the current scope lets the kernel use larger tile sizes, which can improve kernel performance by removing potential memory stalls.
Refer to TPU Pipelining for more on TPU memory architecture.
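The scoped-VMEM budget is also controlled through an XLA flag. The value shown below is purely illustrative, as a sketch of the tuning workflow; the right setting depends on your model, so profile each change rather than copying a number.

```shell
# Illustrative: raise the scoped-VMEM limit (in KiB) so the compiler can
# choose larger kernel tiles. Tune this value for your model and re-profile.
export LIBTPU_INIT_ARGS="${LIBTPU_INIT_ARGS} --xla_tpu_scoped_vmem_limit_kib=65536"
```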
5. Choose optimal sharding strategies
Lastly, MaxText supports various parallelism techniques, which are available on all TPUs. The best choice depends on model size, architecture (dense vs. MoE), and sequence length. Selecting the proper sharding strategy can significantly improve model performance:
- Fully Sharded Data Parallelism (FSDP): This is the preferred strategy for training large models that exceed the memory capacity of a single chip. FSDP shards model weights, gradients, and optimizer states across multiple chips. Increasing the per-device batch size and introducing more compute can hide the latency of the All-Gather operations and improve efficiency.
- Tensor Parallelism (TP): Shards individual tensors. Given Ironwood's high arithmetic intensity, TP is most effective for very large model dimensions. Leveraging TP with a dimension of 2 can take advantage of the fast die-to-die interconnect on Ironwood's dual-chiplet design.
- Expert Parallelism (EP): Helpful for MoE models to distribute experts across devices.
- Context Parallelism (CP): Necessary for very long sequences, sharding activations along the sequence dimension.
- Hybrid approaches: Combining strategies is often required to balance compute, memory, and communication on large-scale runs.
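A hybrid layout of this kind can be expressed directly with JAX's sharding API. The sketch below combines an FSDP axis of 4 with a tensor-parallel axis of 2 on a simulated 8-device mesh (the axis names and sizes are illustrative, not a MaxText default); the same `PartitionSpec` mechanism is what MaxText configures under the hood.

```python
import os
# Simulate 8 devices on CPU so the example runs anywhere (illustrative only).
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 2D mesh: 4-way FSDP combined with 2-way tensor parallelism.
devices = np.array(jax.devices()).reshape(4, 2)
mesh = Mesh(devices, axis_names=("fsdp", "tensor"))

# Shard a (1024, 512) weight: rows across the fsdp axis, columns across
# the tensor axis, so each device holds a (256, 256) shard.
w = jnp.zeros((1024, 512))
w_sharded = jax.device_put(w, NamedSharding(mesh, P("fsdp", "tensor")))
```

Changing the `PartitionSpec` is all it takes to move between pure FSDP (`P("fsdp", None)`), pure TP (`P(None, "tensor")`), and hybrid layouts, which is why these strategies compose cleanly at scale.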
See the Optimizing Frontier Model Training on TPU v7x Ironwood post in the Developer forums for more detail on techniques 2-5 above.
The Ironwood advantage: System-level performance
These optimization techniques, coupled with Ironwood’s architectural strengths like the high-speed 3D Torus Inter-Chip Interconnect (ICI) and massive HBM capacity, create a highly performant platform for training frontier models. The tight co-design across hardware, compilers (XLA), and frameworks (JAX, MaxText) ensures you can extract maximum performance from your AI Infrastructure.
Ready to accelerate your AI journey? Explore the resources below to dive deeper into each optimization method.
Further reading
A special thanks to Hina Jajoo and Amanda Liang for their contributions to this blog post.
Source Credit: https://cloud.google.com/blog/products/compute/training-large-models-on-ironwood-tpus/
