Ragged Paged Attention: A High-Performance and Flexible LLM Inference Kernel for TPU
Summary: arXiv:2604.15464v1
Large Language Model (LLM) deployment is increasingly moving to cost-efficient accelerators such as Google’s Tensor Processing Units (TPUs), where both performance and total cost of ownership (TCO) matter. Despite this trend, existing LLM inference kernels and serving systems are designed primarily around GPU architectures. As a result, there is no established methodology for mapping LLM workloads efficiently onto TPUs, particularly for the dynamic and ragged execution patterns that are common in modern serving.
Introduction to Ragged Paged Attention
To address these challenges, we introduce Ragged Paged Attention (RPA), a high-performance and flexible attention kernel for TPUs. RPA is implemented using Pallas, the kernel-authoring extension of JAX, and Mosaic, the compiler backend that lowers Pallas kernels to TPU.
Key Techniques of RPA
The RPA kernel incorporates three key techniques to enhance its performance and flexibility:
- Fine-Grained Tiling: Enables efficient dynamic slicing over ragged memory, where each sequence in a batch has a different length.
- Custom Software Pipeline: Fuses key-value (KV) cache updates with the attention computation instead of running the update as a separate pass, reducing latency and improving throughput.
- Distribution-Aware Compilation Strategy: Generates specialized kernels for decode, prefill, and mixed workloads, so each workload type runs a kernel tuned to it.
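To make the terms "ragged" and "paged" concrete, here is a minimal pure-Python sketch of what such a kernel computes, not how RPA tiles or pipelines it on TPU. All names (`PAGE_SIZE`, `page_table`, `cu_q_lens`, `kv_lens`, `classify_batch`) are assumptions for illustration and are not taken from RPA's API. The sketch assumes a single attention head and a KV cache that already contains the new tokens, so the fused cache update is not modeled. The `classify_batch` helper shows one plausible way to label a batch as decode, prefill, or mixed; the actual selection logic in RPA may differ.

```python
import math

PAGE_SIZE = 2  # tokens per KV page (hypothetical value for illustration)

def gather_kv(pages, page_ids, kv_len):
    """Collect the first kv_len (key, value) rows of one sequence from its pages."""
    needed = -(-kv_len // PAGE_SIZE)  # ceil(kv_len / PAGE_SIZE)
    rows = []
    for pid in page_ids[:needed]:
        rows.extend(pages[pid])
    return rows[:kv_len]  # drop unused slots in the last page

def ragged_paged_attention(queries, pages, page_table, cu_q_lens, kv_lens):
    """Reference semantics only: causal single-head attention over a paged KV cache.

    queries:    flat list of query vectors for all sequences (ragged, concatenated)
    pages:      list of pages; each page is a list of (key, value) vector pairs
    page_table: page_table[s] lists the page ids owned by sequence s
    cu_q_lens:  cumulative query lengths; sequence s owns queries[cu[s]:cu[s+1]]
    kv_lens:    number of valid KV tokens per sequence (new tokens included)
    """
    out = []
    for s, kv_len in enumerate(kv_lens):
        q_start, q_end = cu_q_lens[s], cu_q_lens[s + 1]
        q_len = q_end - q_start
        kv = gather_kv(pages, page_table[s], kv_len)
        for i in range(q_len):
            q = queries[q_start + i]
            # Query i sits at absolute position kv_len - q_len + i, so it may
            # attend to KV positions 0 through that position (causal mask).
            visible = kv[: kv_len - q_len + i + 1]
            scale = 1.0 / math.sqrt(len(q))
            scores = [scale * sum(a * b for a, b in zip(q, k)) for k, _ in visible]
            m = max(scores)  # subtract the max for a numerically stable softmax
            weights = [math.exp(x - m) for x in scores]
            total = sum(weights)
            dim = len(visible[0][1])
            out.append([
                sum(w * v[d] for w, (_, v) in zip(weights, visible)) / total
                for d in range(dim)
            ])
    return out

def classify_batch(cu_q_lens, kv_lens):
    """Label a batch as 'decode', 'prefill', or 'mixed' from its length distribution."""
    q_lens = [cu_q_lens[s + 1] - cu_q_lens[s] for s in range(len(kv_lens))]
    if all(q == 1 for q in q_lens):
        return "decode"
    if all(q == kv for q, kv in zip(q_lens, kv_lens)):
        return "prefill"
    return "mixed"

# Tiny hypothetical batch: sequence 0 decodes one token, sequence 1 prefills two.
zero = [0.0, 0.0]
pages = [
    [(zero, [1.0, 0.0]), (zero, [0.0, 1.0])],  # page 0: sequence 0, tokens 0-1
    [(zero, [2.0, 0.0]), (zero, [4.0, 0.0])],  # page 1: sequence 1, tokens 0-1
    [(zero, [1.0, 1.0]), (zero, [9.0, 9.0])],  # page 2: sequence 0, token 2 + unused slot
]
page_table = [[0, 2], [1]]
cu_q_lens = [0, 1, 3]
kv_lens = [3, 2]
queries = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]]
out = ragged_paged_attention(queries, pages, page_table, cu_q_lens, kv_lens)
```

The batch is ragged because the two sequences contribute different numbers of query tokens to one flat `queries` list, and paged because sequence 0's KV tokens live in non-contiguous pages 0 and 2. A production kernel would process tiles of this layout in parallel rather than loop over tokens.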
Performance Evaluation
In an evaluation with the Llama 3 8B model on TPU7x, RPA achieved up to 86% memory bandwidth utilization (MBU) in decode and 73% model FLOPs utilization (MFU) in prefill. Decode is typically bound by memory bandwidth and prefill by compute, which is why each phase is reported with a different metric.
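For readers unfamiliar with the two metrics, the sketch below shows how MBU and MFU are commonly computed: achieved rate divided by the hardware peak. The function names and all numbers are placeholders chosen for the arithmetic; they are not TPU7x specifications and not measurements from the paper, whose exact accounting of bytes and FLOPs may differ.

```python
def mbu(bytes_moved, seconds, peak_bytes_per_s):
    """Memory bandwidth utilization: achieved bytes/s over peak bytes/s."""
    return bytes_moved / seconds / peak_bytes_per_s

def mfu(model_flops, seconds, peak_flops_per_s):
    """Model FLOPs utilization: achieved model FLOPs/s over peak FLOPs/s."""
    return model_flops / seconds / peak_flops_per_s

# Placeholder numbers only (NOT real hardware specs or paper results).
decode_mbu = mbu(bytes_moved=50e9, seconds=0.1, peak_bytes_per_s=1e12)
prefill_mfu = mfu(model_flops=2.5e14, seconds=1.0, peak_flops_per_s=1e15)
```

With these placeholders, a decode step that moves 50 GB in 0.1 s on a device with 1 TB/s peak bandwidth runs at half of peak, and a prefill step that performs 2.5e14 model FLOPs in 1 s on a device with 1e15 FLOPs/s peak runs at a quarter of peak.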
Integration and Impact
RPA has been integrated as the primary TPU backend in both vLLM and SGLang, providing a production-grade foundation for efficient TPU inference. The work also offers practical insights into kernel design for TPUs.
Conclusion
Ragged Paged Attention targets the ragged, dynamic execution patterns of modern LLM serving, which existing GPU-oriented kernels do not map well onto TPUs. Its combination of fine-grained tiling, a fused software pipeline, and workload-specialized kernels gives TPU deployments a production-ready attention kernel that balances performance and cost.
