
In the code above, Nick pointed out at least two different ways in which models are combined: some loss terms are linear combinations (e.g., the AF loss plus the ESM pseudo-log-likelihood loss), while others compose models serially (e.g., the first Boltz-1 term folds the sequence with Boltz-1 and then computes the sequence likelihood after inverse folding with another model, ProteinMPNN).
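The two combination patterns can be illustrated with a minimal JAX sketch. Everything here is a toy stand-in: `toy_structure_loss`, `toy_plm_loss`, `toy_fold`, and `toy_inverse_fold_nll` are hypothetical functions invented for illustration, not the real AF, ESM, Boltz-1, or ProteinMPNN models, and the weights are arbitrary.

```python
import jax
import jax.numpy as jnp

SEQ_LEN, N_TOKENS = 8, 20  # toy sequence length, amino-acid alphabet size
k_w, k_lm, k_seq = jax.random.split(jax.random.PRNGKey(0), 3)
W = jax.random.normal(k_w, (N_TOKENS, 3))                      # fixed toy "weights"
LM_LOGP = jax.nn.log_softmax(jax.random.normal(k_lm, (N_TOKENS,)))

def toy_structure_loss(seq_probs):
    """Hypothetical stand-in for a structure-based loss term."""
    return jnp.mean((seq_probs @ W) ** 2)

def toy_plm_loss(seq_probs):
    """Hypothetical stand-in for a language-model likelihood term."""
    return -jnp.mean(jnp.sum(seq_probs * LM_LOGP, axis=-1))

def toy_fold(seq_probs):
    """Hypothetical 'folding' model: sequence -> per-residue 3D features."""
    return seq_probs @ W

def toy_inverse_fold_nll(coords, seq_probs):
    """Hypothetical 'inverse folding' model: NLL of the sequence given coords."""
    logp = jax.nn.log_softmax(coords @ W.T, axis=-1)
    return -jnp.mean(jnp.sum(seq_probs * logp, axis=-1))

def serial_term(seq_probs):
    # Serial composition: the output of one model feeds the next.
    return toy_inverse_fold_nll(toy_fold(seq_probs), seq_probs)

def combined_loss(seq_probs, w_struct=1.0, w_plm=0.5, w_serial=0.1):
    # Linear combination: independent terms are weighted and summed.
    return (w_struct * toy_structure_loss(seq_probs)
            + w_plm * toy_plm_loss(seq_probs)
            + w_serial * serial_term(seq_probs))

seq_probs = jax.nn.softmax(jax.random.normal(k_seq, (SEQ_LEN, N_TOKENS)), axis=-1)
loss = combined_loss(seq_probs)
grad = jax.grad(combined_loss)(seq_probs)  # gradient flows through both patterns
```

Because every term is an ordinary differentiable function, `jax.grad` differentiates through the serial chain and the weighted sum alike.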
To make this work, they embraced the JAX ecosystem, even translating models from PyTorch themselves — a prime example being their JAX translation of the Boltz-2 structure prediction model.
This approach gives what April called an “expressive language for protein design,” where models can be composed, added, and transformed to define a final objective. April said that the most incredible part is that this entire, complex graph of models “can be wrapped in a single jax.jit call that gives great performance” — something they found very difficult to do in other frameworks.
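As a rough illustration of what such an "expressive language" might look like, the sketch below defines a hypothetical `LossTerm` wrapper that supports addition and scalar weighting, then compiles the whole objective with one `jax.jit` call. The class, the `compose` helper, and the toy terms are assumptions made for this example and are not Mosaic's actual API.

```python
from dataclasses import dataclass
from typing import Callable

import jax
import jax.numpy as jnp

@dataclass(frozen=True)
class LossTerm:
    """Hypothetical wrapper that lets loss functions be added and scaled."""
    fn: Callable

    def __call__(self, x):
        return self.fn(x)

    def __add__(self, other):
        return LossTerm(lambda x: self.fn(x) + other.fn(x))

    def __rmul__(self, weight):
        return LossTerm(lambda x: weight * self.fn(x))

def compose(outer, inner):
    """Serial composition: run `inner`, then score its output with `outer`."""
    return LossTerm(lambda x: outer(inner(x)))

W = jax.random.normal(jax.random.PRNGKey(0), (4, 4))  # fixed toy "model" weights

smoothness = LossTerm(lambda x: jnp.mean((x[1:] - x[:-1]) ** 2))
magnitude = LossTerm(lambda x: jnp.mean(x ** 2))
toy_model = lambda x: jnp.tanh(x @ W)  # stand-in for a pre-trained network

# Terms are added, weighted, and composed into one objective.
objective = 1.0 * smoothness + 0.5 * magnitude + 0.1 * compose(magnitude, toy_model)

def total(x):
    return objective(x)

# A single jit call compiles the entire graph, including its gradient.
value_and_grad_fn = jax.jit(jax.value_and_grad(total))

x = jax.random.normal(jax.random.PRNGKey(1), (8, 4))
value, grad = value_and_grad_fn(x)
```

The design choice worth noting is that the combinators only build Python closures; JAX traces through them once, so the abstraction adds no runtime cost after compilation.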
Instead of a typical training run that optimizes a model’s weights, their workflow inverts the process to optimize the input itself, using a collection of fixed, pre-trained neural networks as a complex, multi-objective loss function. The approach is mechanically analogous to Google’s DeepDream. Just as DeepDream takes a fixed, pre-trained image classifier and uses gradient ascent to iteratively modify an input image’s pixels to maximize a chosen layer’s activation, Escalante’s method starts with a random protein sequence and iteratively modifies it. This sequence is fed through a committee of “expert” models, each one a pre-trained scorer for a different desirable property, like binding affinity or stability. The outputs from all the models are combined into a single, differentiable objective function. The gradient of this final score with respect to the input sequence is then computed via backpropagation, and an optimizer uses it to update the sequence, nudging it in a direction that better satisfies the collective requirements of all the models. This cycle repeats, evolving the random initial input into a novel, optimized protein sequence that the entire ensemble of models “believes” is ideal.
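The loop described above can be sketched in a few lines of JAX. This is a minimal sketch under stated assumptions: the two "expert" scorers are hypothetical linear preference tables rather than real pre-trained networks, the sequence is represented by a softmax relaxation over per-position logits (an assumption, since the source does not say how the discrete sequence is made differentiable), and plain gradient descent stands in for whatever optimizer is actually used.

```python
import jax
import jax.numpy as jnp

SEQ_LEN, N_TOKENS = 12, 20
k_a, k_b, k_init = jax.random.split(jax.random.PRNGKey(0), 3)

# Fixed "pre-trained" scorers: hypothetical per-position residue preferences.
PREF_A = jax.random.normal(k_a, (SEQ_LEN, N_TOKENS))
PREF_B = jax.random.normal(k_b, (SEQ_LEN, N_TOKENS))

def scorer_a(probs):
    return -jnp.mean(jnp.sum(probs * PREF_A, axis=-1))

def scorer_b(probs):
    return -jnp.mean(jnp.sum(probs * PREF_B, axis=-1))

def objective(logits):
    # Relax the discrete sequence to per-position probabilities.
    probs = jax.nn.softmax(logits, axis=-1)
    # Combine the committee of scorers into one differentiable loss.
    return scorer_a(probs) + 0.5 * scorer_b(probs)

@jax.jit
def step(logits, lr):
    # Gradient with respect to the INPUT; the scorers' parameters stay fixed.
    loss, grad = jax.value_and_grad(objective)(logits)
    return logits - lr * grad, loss

logits = 0.01 * jax.random.normal(k_init, (SEQ_LEN, N_TOKENS))  # random start
initial_loss = objective(logits)
for _ in range(300):
    logits, loss = step(logits, 2.0)
final_loss = objective(logits)

designed_sequence = jnp.argmax(logits, axis=-1)  # discrete residue indices
```

Only `logits` is ever updated, which is the inversion the text describes: the models act as a fixed loss, and the input is the trainable quantity.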
Nick said that the choice of JAX was critical for this process. Its ability to compile and automatically differentiate complex code makes it ideal for optimizing the sophisticated loss functions at the heart of Mosaic, Escalante’s library of protein design tools. Furthermore, the framework’s native integration with TPU hardware via the XLA compiler allowed them to easily scale these workloads.
Escalante samples many candidate protein designs for a given problem by optimizing the loss function. Each sampling job might generate 1,000 to 50,000 potential designs, which are then ranked and filtered. By the end of the process, they test only about 10 designs in the wet lab. This has led them to adopt a unique infrastructure pattern. Using Google Kubernetes Engine (GKE), they instantly spin up 2,000 to 4,000 spot TPUs, run their optimization jobs for about half an hour, and then shut them all down.
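The sample-then-rank pattern can be sketched on a single device with `jax.vmap`. This is only a toy illustration: the objective is a hypothetical stand-in, ranking by final loss is an assumption (the source does not describe the actual ranking and filtering criteria), and the real jobs are distributed across thousands of spot TPUs rather than batched on one chip.

```python
import jax
import jax.numpy as jnp

SEQ_LEN, N_TOKENS = 10, 20
N_DESIGNS, TOP_K = 1000, 10

# Hypothetical fixed scorer: per-position residue preferences.
PREF = jax.random.normal(jax.random.PRNGKey(42), (SEQ_LEN, N_TOKENS))

def objective(logits):
    probs = jax.nn.softmax(logits, axis=-1)
    return -jnp.mean(jnp.sum(probs * PREF, axis=-1))

def design_one(key):
    """Run one short optimization from a random start; return (sequence, loss)."""
    logits = jax.random.normal(key, (SEQ_LEN, N_TOKENS))

    def body(_, lg):
        return lg - 2.0 * jax.grad(objective)(lg)

    logits = jax.lax.fori_loop(0, 50, body, logits)
    return jnp.argmax(logits, axis=-1), objective(logits)

# One PRNG key per candidate design; vmap runs all optimizations as a batch.
keys = jax.random.split(jax.random.PRNGKey(0), N_DESIGNS)
sequences, scores = jax.jit(jax.vmap(design_one))(keys)

# Rank by loss (lower is better) and keep a short list for follow-up.
order = jnp.argsort(scores)
shortlist = sequences[order[:TOP_K]]
shortlist_scores = scores[order[:TOP_K]]
```

Each design depends only on its own key, so the same function could be sharded across many independent workers without coordination, which fits short-lived spot capacity.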
Nick also shared the compelling economics driving this choice. Given current spot pricing, adopting Cloud TPU v6e (Trillium) over an H100 GPU translated to a gain of 3.65x in performance per dollar for their large-scale jobs. He stressed that this cost-effectiveness is critical for their long-term goal of designing protein binders against the entire human proteome, a task that requires immense computational scale.
To build their system, they rely on key libraries within the JAX ecosystem like Equinox and Optax. Nick prefers Equinox because it feels like “vanilla JAX,” calling its concept of representing a model as a simple PyTree “beautiful and easy to reason about.” Optax, meanwhile, gives them the flexibility to easily swap in different optimization algorithms for their design loops.
They emphasized that this entire stack — JAX’s functional core, its powerful ecosystem libraries, and the scalable TPU hardware — is what makes their research possible.
Explore the JAX scientific computing ecosystem
Stories like this are part of a vibrant new ecosystem in computational biology, one spawned in major part by the Nobel Prize-winning work on AlphaFold from Google DeepMind’s Demis Hassabis and John Jumper. Escalante’s work is a powerful example of how the foundational ideas of predicting protein structure are now being used to actively design new ones.
We are excited to see community contributions like Escalante’s Mosaic library, which contains the tools for their protein design work and is now available on GitHub. It’s a fantastic addition to the landscape of JAX-native scientific tools.
Work like this highlights a growing trend: JAX is much more than a framework for deep learning. Its powerful system of program transformations, like grad and jit, makes it a foundational library for the paradigm of differentiable programming, empowering a new generation of scientific discovery. The JAX team at Google is committed to supporting and growing this ecosystem, and that starts with hearing directly from you.
- Share your story: Are you using JAX to tackle a challenging problem?
- Help guide our roadmap: Are there new features or capabilities that would unlock your next breakthrough?
Your feature requests are essential for guiding the evolution of JAX. Please reach out to the team to share your work or discuss what you need from JAX via GitHub.
Our sincere thanks to April and Nick for sharing their insightful journey with us. We’re excited to see how they and other researchers continue to leverage JAX to solve the world’s most complex scientific problems.
Source Credit: https://cloud.google.com/blog/topics/customers/escalante-uses-jax-on-tpus-for-ai-driven-protein-design/