Supercharging Computational Fluid Dynamics with TPUs and JAX: An Alternative Workflow for Physicists and Engineers
If you work in computational physics — whether it’s weather modeling, astrophysics, or Computational Fluid Dynamics (CFD) — you might be ignoring one of the most powerful tools currently at your disposal. While the tech world obsesses over Generative AI, the underlying infrastructure powering that revolution holds a distinct, often overlooked opportunity for physical sciences. By repurposing modern AI libraries like JAX, researchers can inherit billions of dollars of investment in hardware optimization and automatic differentiation.
Currently, your workflow probably looks something like this:
1. You have an idea for a new physical model or boundary condition.
2. You prototype it quickly in Python using NumPy or MATLAB to validate the math on a coarse grid.
3. It works, but it’s too slow for a “real” science run.
4. The bottleneck: You hand your prototype off to a research software engineer to rewrite it in C++, Fortran, or CUDA so it can run on an HPC cluster.
5. Wait.
6. Inspect results.
7. Adjust the model and iterate from #1.
There is a massive gap between the flexibility of Python and the performance required for high-resolution simulations. We’ve accepted that crossing this gap requires changing languages. We’ve also accepted that manual iteration is required to tune engineering designs.
But what if it didn’t? What if the hardware designed to accelerate AI could also accelerate fundamental physics, allowing you to keep your Python syntax while running at supercomputer speeds?
Take a closer look at the iterative cycle in Step #7, where you manually tune a design (say, an obstacle’s cross-section) to find the shape with the least turbulence. That manual loop is effectively gradient-based optimization done by hand, the same process AI researchers automate with backpropagation when they train neural networks. The difference? If you define a differentiable loss function, the framework can perform that tuning for you, automatically homing in on a better solution without manual guessing.
Enter JAX and Google’s Tensor Processing Units (TPUs).
Why TPUs for Physics?
Most researchers hear “TPU” and think of Large Language Models or generative art. While TPUs are specialized for AI, it’s important to understand why.
At their core, TPUs are massive matrix multiplication engines paired with incredible amounts of High Bandwidth Memory (HBM).
It turns out that many scientific methods, particularly structured grid solvers like Finite Difference or the Lattice Boltzmann Method (LBM), have the same computational profile as deep learning operations:
- They involve massive, dense tensor operations.
- They are bandwidth-bound (moving data from RAM to the processor is usually the bottleneck).
A TPU doesn’t know if it’s multiplying weights for a neural network or calculating particle distribution functions for fluid flow. It just sees massive matrices and crunches them efficiently.
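To make the "dense tensor operations" point concrete, here is a minimal sketch of a finite-difference diffusion update written purely as whole-array operations. The function names, the periodic boundaries, and the `alpha` coefficient are assumptions chosen for illustration; none of this is taken from the example notebook.

```python
import jax.numpy as jnp

def laplacian(field):
    # Five-point stencil with periodic boundaries, built from array shifts
    # instead of explicit loops over grid cells.
    return (jnp.roll(field, 1, axis=0) + jnp.roll(field, -1, axis=0)
            + jnp.roll(field, 1, axis=1) + jnp.roll(field, -1, axis=1)
            - 4.0 * field)

def diffuse(field, alpha=0.1):
    # One explicit diffusion step; alpha is an illustrative coefficient.
    return field + alpha * laplacian(field)

# A single "hot" cell on an otherwise empty 8x8 grid.
field = jnp.zeros((8, 8)).at[4, 4].set(1.0)
after = diffuse(field)
print(after[3:6, 3:6])
```

Every cell is updated by the same arithmetic on shifted copies of one array, which is the access pattern that accelerator hardware handles well.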
By using JAX, a library that looks like NumPy but compiles through XLA (Accelerated Linear Algebra), we can target this hardware directly from Python. We can also reuse its built-in differentiation capabilities.
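As a rough illustration of "looks like NumPy, compiles to XLA", the sketch below writes the same reduction twice: once in NumPy and once in JAX with `jax.jit`. The kinetic-energy function and the tiny test arrays are made up for this example.

```python
import numpy as np
import jax
import jax.numpy as jnp

def kinetic_energy_np(u, v):
    # Plain NumPy: runs eagerly on the CPU.
    return 0.5 * np.sum(u**2 + v**2)

@jax.jit
def kinetic_energy_jax(u, v):
    # Same expression; jax.jit traces it once and compiles it with XLA
    # for whatever backend is available (CPU, GPU, or TPU).
    return 0.5 * jnp.sum(u**2 + v**2)

u = np.linspace(0.0, 1.0, 16, dtype=np.float32).reshape(4, 4)
v = np.ones((4, 4), dtype=np.float32)

e_np = kinetic_energy_np(u, v)
e_jax = kinetic_energy_jax(u, v)
print("devices:", jax.devices())
print("numpy:", float(e_np), "jax:", float(e_jax))
```

The only changes are the `jnp` prefix and the decorator, so moving a NumPy prototype over is often mostly mechanical. In-place array mutation and data-dependent Python control flow are the usual places where more rework is needed.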
The video below goes over automatic differentiation in JAX.
https://medium.com/media/a139c62141241475b7c53d7b56f6d1ca/href
While geared toward AI, these same capabilities are valuable for computational modeling of physical systems.
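For readers who prefer code to video, here is a minimal sketch of the idea: `jax.grad` turns a loss function into its derivative, and a plain gradient-descent loop then does the "tuning" automatically. The quadratic loss is a hypothetical stand-in for a real physical objective, and the learning rate and step count are arbitrary.

```python
import jax

def loss(shape_param):
    # Hypothetical stand-in for "turbulence as a function of a shape
    # parameter": a smooth bowl whose minimum sits at 3.0.
    return (shape_param - 3.0) ** 2

# jax.grad returns a new function that evaluates d(loss)/d(shape_param).
grad_loss = jax.grad(loss)

def tune(initial, learning_rate=0.1, steps=100):
    # Plain gradient descent: repeatedly step against the gradient.
    param = initial
    for _ in range(steps):
        param = param - learning_rate * grad_loss(param)
    return param

tuned = tune(0.0)
print(float(tuned))
```

In a real study the loss would wrap an entire simulation, but the calling pattern stays the same.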
The Challenge: A Real-World CFD Example
To drive this discussion, we’ll use a realistic 2D CFD simulation that applies the Lattice Boltzmann Method to model turbulent flow past a cylinder (the classic von Kármán vortex street).
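For orientation, a single Lattice Boltzmann update can be sketched in a few lines of JAX. This is a generic D2Q9 scheme with BGK collision and periodic streaming, not the notebook's code. The cylinder, inflow and outflow boundaries are omitted, and the relaxation time `tau` and grid size are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

# D2Q9 lattice: 9 discrete velocities and their standard weights.
VELOCITIES = [(0, 0), (1, 0), (0, 1), (-1, 0), (0, -1),
              (1, 1), (-1, 1), (-1, -1), (1, -1)]
C = jnp.array(VELOCITIES)
W = jnp.array([4 / 9] + [1 / 9] * 4 + [1 / 36] * 4)

def equilibrium(rho, ux, uy):
    # Second-order equilibrium distribution, shape (9, nx, ny).
    cu = C[:, 0, None, None] * ux + C[:, 1, None, None] * uy
    usq = ux**2 + uy**2
    return W[:, None, None] * rho * (1.0 + 3.0 * cu + 4.5 * cu**2 - 1.5 * usq)

@jax.jit
def lbm_step(f, tau=0.6):
    # Macroscopic density and velocity from the distributions.
    rho = f.sum(axis=0)
    ux = (C[:, 0, None, None] * f).sum(axis=0) / rho
    uy = (C[:, 1, None, None] * f).sum(axis=0) / rho
    # BGK collision: relax toward local equilibrium.
    f = f - (f - equilibrium(rho, ux, uy)) / tau
    # Streaming: shift each population along its own velocity (periodic).
    return jnp.stack([jnp.roll(f[i], shift=(cx, cy), axis=(0, 1))
                      for i, (cx, cy) in enumerate(VELOCITIES)])

nx, ny = 16, 16
rho0 = jnp.ones((nx, ny)).at[8, 8].add(0.1)  # small density bump
f0 = equilibrium(rho0, 0.05 * jnp.ones((nx, ny)), jnp.zeros((nx, ny)))
f = f0
for _ in range(10):
    f = lbm_step(f)
print("mass before:", float(f0.sum()), "after:", float(f.sum()))
```

Collision and periodic streaming both conserve total mass, so comparing the two printed values is a quick sanity check on a sketch like this.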
Get the example notebook: https://github.com/bernieongewe/Differentiable-Fluid-Dynamics-on-TPUs
The example demonstrates two things:
- Scaling potential: The JAX example deliberately uses a coarse grid with few iterations so you can run it without access to a TPU. However, if you increase the granularity and iterations, you quickly find that execution slows to a crawl, if it doesn’t crash altogether. For instance, a 4096×4096 grid is nearly 17 million cells, with 9 distribution values per cell (one per discrete velocity direction), updated at every time step. Just storing the state and its working copies can require several gigabytes of RAM, and the CPU spends a significant amount of its time waiting to fetch this data from DDR RAM.
- Differentiable physics with JAX: We use JAX’s automatic differentiation to iteratively reshape the cross-section, homing in on a profile that causes the least turbulence.
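The memory figures above are easy to check with back-of-the-envelope arithmetic. The sketch below assumes one value per velocity direction per cell and compares float32 with float64 storage. The count of three working copies is an illustrative assumption, not a measurement from the notebook.

```python
nx = ny = 4096
directions = 9  # one distribution value per discrete velocity

cells = nx * ny
values = cells * directions

def to_gib(num_bytes):
    return num_bytes / 2**30

one_copy_f32 = to_gib(values * 4)  # 4 bytes per float32
one_copy_f64 = to_gib(values * 8)  # 8 bytes per float64 (NumPy default)

# Assumption: a naive update holds roughly three arrays of this size
# at once (current state, equilibrium, streamed result).
working_copies = 3

print(f"cells: {cells:,}")
print(f"values per time step: {values:,}")
print(f"one copy, float32: {one_copy_f32:.2f} GiB")
print(f"one copy, float64: {one_copy_f64:.2f} GiB")
print(f"{working_copies} copies, float64: {working_copies * one_copy_f64:.2f} GiB")
```

Under these assumptions a single float64 copy is a little over 1 GiB, so a few temporaries are enough to reach several gigabytes, all of which are read and written on every time step.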
Feel free to experiment with grid size and iterations on CPUs and TPUs. You should see results similar to those in the example video below:
https://medium.com/media/34da01f6527d14c521749054b1b0b72c/href
Take note of how widely the performance varies with the framework (JAX vs NumPy) and processor (TPU vs CPU).
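If you want to measure the difference yourself, a small timing harness along these lines is one way to do it. The smoothing kernel, grid size, and step count are placeholders chosen so the script finishes quickly. Two details matter for a fair comparison: run the JAX function once beforehand so compilation is not timed, and call `block_until_ready()` because JAX dispatches work asynchronously.

```python
import time
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp

def smooth_np(field, steps):
    # Repeated four-neighbour averaging with periodic boundaries (NumPy).
    for _ in range(steps):
        field = 0.25 * (np.roll(field, 1, 0) + np.roll(field, -1, 0)
                        + np.roll(field, 1, 1) + np.roll(field, -1, 1))
    return field

@partial(jax.jit, static_argnames="steps")
def smooth_jax(field, steps):
    # Same update, compiled by XLA as a single loop.
    def body(_, f):
        return 0.25 * (jnp.roll(f, 1, 0) + jnp.roll(f, -1, 0)
                       + jnp.roll(f, 1, 1) + jnp.roll(f, -1, 1))
    return jax.lax.fori_loop(0, steps, body, field)

rng = np.random.default_rng(0)
field = rng.random((256, 256), dtype=np.float32)
steps = 50

start = time.perf_counter()
out_np = smooth_np(field, steps)
t_np = time.perf_counter() - start

smooth_jax(field, steps=steps).block_until_ready()  # warm-up: compile once
start = time.perf_counter()
out_jax = smooth_jax(field, steps=steps).block_until_ready()
t_jax = time.perf_counter() - start

print("backend:", jax.default_backend())
print(f"NumPy: {t_np:.4f} s   JAX: {t_jax:.4f} s")
```

Timings depend heavily on the backend and on problem size, so treat any single number as indicative only. Small grids like this one mostly measure overhead rather than hardware throughput.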
The Real Advantage: Differentiable Physics
While the speed is great, JAX offers something even more valuable to engineers: automatic differentiation.
In traditional CFD, you run a simulation to see what happens. But engineers often want to know how to optimize it. For example: “What shape of cylinder minimizes turbulence?”
Answering this usually requires deriving complex “adjoint equations” and writing specialized solvers.
Because JAX is differentiable, we can define a “loss function” (e.g., total turbulence) and ask JAX to calculate the gradient of the entire physics simulation with respect to the obstacle’s shape. Under the hood, reverse-mode automatic differentiation propagates sensitivities backward through every time step of the simulation, so no hand-derived adjoint solver is needed.
With just a few lines of code, we generate this “Sensitivity Map”:

- The Blue Area: The math is telling us, “If you add solid material here (creating a familiar nose cone), turbulence goes down.”
- The Red Area: “Adding material here makes it worse.”
This is physical insight, generated automatically by the framework.
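To show the mechanics behind a map like this, here is a deliberately simplified sketch. It differentiates a toy time-stepped solver with respect to a per-cell "solid material" mask. The diffusion update, the damping-style obstacle model, and the variance-based loss are all hypothetical stand-ins for the real LBM solver and turbulence metric. Only the `jax.grad` and `jax.lax.scan` pattern is the point.

```python
import jax
import jax.numpy as jnp

def step(field, mask, alpha=0.1):
    # Toy physics: one diffusion step, then damping where mask is "solid".
    lap = (jnp.roll(field, 1, 0) + jnp.roll(field, -1, 0)
           + jnp.roll(field, 1, 1) + jnp.roll(field, -1, 1) - 4.0 * field)
    return (field + alpha * lap) * (1.0 - mask)

def loss(mask, field0, n_steps=20):
    # Roll the simulation forward, then score the final state.
    def body(field, _):
        return step(field, mask), None
    final, _ = jax.lax.scan(body, field0, None, length=n_steps)
    # Hypothetical proxy for "unsteadiness": spread around the mean.
    return jnp.sum((final - final.mean()) ** 2)

field0 = jnp.zeros((16, 16)).at[8, 8].set(1.0)  # initial disturbance
mask0 = jnp.zeros((16, 16))                     # no solid material yet

# Gradient of the scalar loss with respect to every mask cell:
# one sensitivity value per grid point.
sensitivity = jax.grad(loss)(mask0, field0)
print(sensitivity.shape)
print("cells where adding material lowers the loss:",
      int((sensitivity < 0).sum()))
```

Negative entries mark cells where adding material would reduce the loss, and positive entries mark cells where it would increase it. Plotting `sensitivity` with a diverging colormap gives a picture of the same kind as the one described above.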
Conclusion: Own Your Workflow
The combination of JAX and TPUs offers a new path for scientific computing. It allows domain experts — physicists, meteorologists, engineers — to write high-level Python code that automatically leverages supercomputer-class hardware.
You no longer need to wait for the “C++ middleman.” You can go from a mathematical idea to a massive, differentiable, high-performance simulation in a single afternoon.
Supercharging Computational Fluid Dynamics with TPUs and JAX: An Alternative Workflow for… was originally published in Google Cloud – Community on Medium.
Source Credit: https://medium.com/google-cloud/beyond-ai-supercharging-computational-fluid-dynamics-with-tpus-and-jax-dabef1dff928?source=rss—-e52cf94d98af—4
