
Fine-Tuning Gemma 3 Using MaxText and XPK
Training large models efficiently requires powerful, scalable infrastructure. Google’s AI Hypercomputer, with its JAX framework and Cloud TPUs, provides just that, but mastering the workflow can be a challenge. This guide simplifies the process by demonstrating a complete, end-to-end example. We’ll fine-tune the Gemma 3 12B vision-language model on the ChartQA dataset using MaxText, the high-performance JAX-based LLM training library. From setting up a Google Kubernetes Engine (GKE) cluster with the XPK orchestrator to converting model checkpoints and serving with JetStream, this tutorial covers every step needed to leverage Google’s affordable and powerful compute for your AI workloads.
Let’s get started by setting up our workspace. To make the process as smooth as possible, everything you need for this tutorial is in a GitHub repository. Before you begin, make sure the following tools are installed: the gcloud CLI, Docker, uv, kubectl, and Git.
First, clone the repository and create your personal configuration file from the provided template. It’s crucial that you open the .env file and fill in your specific project details (such as your project ID and GCS bucket name), because this file configures all subsequent steps.
# set up uv
curl -LsSf https://astral.sh/uv/install.sh | sh
# set up gcloud
sudo apt-get update
sudo apt-get install apt-transport-https ca-certificates gnupg curl
curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo gpg --dearmor -o /usr/share/keyrings/cloud.google.gpg
echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | sudo tee -a /etc/apt/sources.list.d/google-cloud-sdk.list
sudo apt-get update && sudo apt-get install google-cloud-cli
gcloud init
# clone the repo and create your config file
git clone git@github.com:bdicegoogle/maxtext_chartqa.git
cd maxtext_chartqa
cp .env.example .env
# Edit .env with your configuration
# I HIGHLY recommend going through this script manually, and running it
# one line / chunk at a time!
./run_experiment.sh
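Because every later command reads its settings from .env, a quick check that the variables are actually set can save you a failed job submission. This is a minimal sketch, assuming .env is a plain KEY=value file that bash can source; `require_vars` is a hypothetical helper rather than part of the repository, and the names in the usage comment are simply the variables referenced by the commands in this tutorial.

```shell
#!/usr/bin/env bash
# Hypothetical preflight helper: print each required variable that is unset or empty
# to stderr, and return non-zero if any are missing.
require_vars() {
  local missing=0 name
  for name in "$@"; do
    # ${!name} is bash indirect expansion: the value of the variable called $name.
    if [ -z "${!name:-}" ]; then
      echo "missing: ${name}" >&2
      missing=1
    fi
  done
  return "${missing}"
}

# Example usage (assumes .env is a sourceable KEY=value file):
#   set -a; source .env; set +a
#   require_vars GOOGLE_CLOUD_PROJECT CLUSTER_NAME TPU_TYPE NUM_SLICES ZONE \
#     BASE_IMAGE MODEL_NAME HF_TOKEN MODEL_BUCKET BASE_OUTPUT_DIRECTORY || exit 1
```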
Kubernetes is great for managing complex workloads across diverse compute, but if you’re more focused on building the best possible models, it’s a lot to manage on your own. Google Kubernetes Engine (GKE) makes Kubernetes easier to handle than ever, and the XPK library makes the whole process easier still by streamlining cluster creation and job submission and by integrating with MaxText. The end result is that the tooling absorbs the complexity, so you can focus on the model building and training workflow.
To create a cluster to work on:
uv run xpk cluster create \
--cluster ${CLUSTER_NAME} \
--tpu-type=${TPU_TYPE} \
--num-slices=${NUM_SLICES} \
--zone ${ZONE} --on-demand \
--custom-cluster-arguments='--enable-shielded-nodes --shielded-secure-boot --shielded-integrity-monitoring' \
--custom-nodepool-arguments='--shielded-secure-boot --shielded-integrity-monitoring'
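Cluster creation takes a while and starts billing once it succeeds, so it can be worth printing the fully expanded command and checking it before you run it. The sketch below assumes bash; `run_or_echo` and the `DRY_RUN` switch are hypothetical conveniences, not XPK features, and the flags are copied from the command above.

```shell
#!/usr/bin/env bash
# Hypothetical helper: print a command instead of running it when DRY_RUN=1.
run_or_echo() {
  if [ "${DRY_RUN:-0}" = "1" ]; then
    # printf %q shell-quotes each argument so the output can be copy-pasted safely.
    printf '%q ' "$@"
    printf '\n'
  else
    "$@"
  fi
}

# Build the argument list as an array so values containing spaces
# (such as the custom gcloud arguments) stay single arguments.
cluster_create_args=(
  uv run xpk cluster create
  --cluster "${CLUSTER_NAME}"
  --tpu-type="${TPU_TYPE}"
  --num-slices="${NUM_SLICES}"
  --zone "${ZONE}" --on-demand
  --custom-cluster-arguments='--enable-shielded-nodes --shielded-secure-boot --shielded-integrity-monitoring'
  --custom-nodepool-arguments='--shielded-secure-boot --shielded-integrity-monitoring'
)

# Inspect the command first; rerun with DRY_RUN=0 to actually create the cluster.
DRY_RUN=1 run_or_echo "${cluster_create_args[@]}"
```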
MaxText uses its own checkpoint format for training and serving, and it distinguishes between scanned checkpoints (which stack the decoder layers so they can run inside a JAX scan loop) and unscanned ones (which keep each layer separate). So the next step of our journey is to convert the model we are interested in into that format. I’ve used google/gemma-3-12b-it for the workflow below. You can submit the conversion as an XPK job; I found it to be surprisingly RAM-intensive, so for bigger models you may want to pick a machine type with a large amount of RAM.
uv run xpk workload create \
--workload make-model-gemma12b \
--docker-image ${BASE_IMAGE} \
--cluster ${CLUSTER_NAME} \
--tpu-type ${TPU_TYPE} \
--zone ${ZONE} \
--num-slices=${NUM_SLICES} \
--command "python3 -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml \
model_name=${MODEL_NAME} \
hf_access_token=${HF_TOKEN} \
base_output_directory=${MODEL_BUCKET}/${MODEL_VARIATION}/unscanned/${idx} \
use_multimodal=${USE_MULTIMODAL} \
scan_layers=${SCAN_LAYERS}"
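The training job loads the converted weights from `$UNSCANNED_CKPT_PATH`, which this walkthrough doesn’t define explicitly. Here is a small sketch for deriving it from the conversion job’s output directory, under a loudly labeled assumption: that the converter writes a single step-0 checkpoint at `<base_output_directory>/0/items`. Confirm the real layout by listing the bucket before relying on it; `ckpt_items_path` is a hypothetical helper.

```shell
#!/usr/bin/env bash
# Hypothetical helper: join a checkpoint root, a step number, and the "items" leaf.
# ASSUMPTION: the converter writes one step-0 checkpoint under base_output_directory,
# i.e. <base_output_directory>/0/items. Check the actual layout first, e.g.:
#   gcloud storage ls "${MODEL_BUCKET}/${MODEL_VARIATION}/unscanned/${idx}/"
ckpt_items_path() {
  local root="${1%/}" step="$2"   # drop one trailing slash from the root
  printf '%s/%s/items\n' "${root}" "${step}"
}

# Example usage, reusing the output directory of the conversion job above:
#   export UNSCANNED_CKPT_PATH="$(ckpt_items_path \
#     "${MODEL_BUCKET}/${MODEL_VARIATION}/unscanned/${idx}" 0)"
```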
Now it’s time for the fun part. MaxText makes setting up the training step straightforward:
- We can use Hugging Face datasets, and MaxText will convert them into a Grain dataset. This makes it easier to integrate MaxText into existing pipelines or to use standard datasets.
- We’re using ChartQA because a config file for it has already been written, but if you look deeper into the configuration process, changing the dataset is straightforward: you specify the input and output fields in a YAML file.
To submit the training job, use XPK to run MaxText:
uv run xpk workload create \
--workload finetune-sft-gemma12b-chartqa \
--docker-image ${BASE_IMAGE} \
--cluster ${CLUSTER_NAME} \
--tpu-type ${TPU_TYPE} \
--zone ${ZONE} \
--num-slices=${NUM_SLICES} \
--env GOOGLE_CLOUD_PROJECT=$GOOGLE_CLOUD_PROJECT \
--command "python -m MaxText.sft_trainer MaxText/configs/sft-vision-chartqa.yml \
run_name=$idx \
model_name=$MODEL_NAME tokenizer_path=google/gemma-3-12b-it \
per_device_batch_size=8 \
max_prefill_predict_length=1024 max_target_length=2048 \
steps=$SFT_STEPS \
scan_layers=$SCAN_LAYERS async_checkpointing=False \
attention=dot_product \
dataset_type=hf hf_path=HuggingFaceM4/ChartQA hf_access_token=$HF_TOKEN \
base_output_directory=$BASE_OUTPUT_DIRECTORY \
load_parameters_path=$UNSCANNED_CKPT_PATH \
dtype=bfloat16 weight_dtype=bfloat16 sharding_tolerance=0.05 "
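`per_device_batch_size` is counted per JAX device, so the global batch and the token budget per step depend on how many TPU devices the job spans. A small sketch of that arithmetic; the device count is an input you supply (it depends on `TPU_TYPE` and `NUM_SLICES`), and `sft_step_budget` is a hypothetical helper. Since `max_target_length` is a per-sequence upper bound, the token figure is a capacity, not a count of non-padding tokens.

```shell
#!/usr/bin/env bash
# Hypothetical helper: global batch size and per-step token capacity.
#   $1 = per_device_batch_size, $2 = total TPU devices, $3 = max_target_length
sft_step_budget() {
  local per_device="$1" devices="$2" max_len="$3"
  local global_batch=$(( per_device * devices ))
  local max_tokens=$(( global_batch * max_len ))
  echo "global_batch=${global_batch} max_tokens_per_step=${max_tokens}"
}

# The settings above (per_device_batch_size=8, max_target_length=2048)
# on a hypothetical 16-device job:
sft_step_budget 8 16 2048   # prints global_batch=128 max_tokens_per_step=262144
```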
MaxText gives you several robust options for monitoring training progress. Most usefully, training metrics are collected and written to the base output directory, where TensorBoard can display them. You can also stream the pod’s logs with kubectl, as below.
# Stream the logs from the training pod (head -n 1 picks one pod if several match).
kubectl logs $(kubectl get pods --no-headers -o custom-columns=":metadata.name" | grep finetune-sft-gemma12b-chartqa | head -n 1) -f
# Run TensorBoard locally.
uv run tensorboard --logdir=$BASE_OUTPUT_DIRECTORY/$idx/tensorboard
You can check TPU performance through the TPU metrics collected for the pod; for this workflow, TPU usage stays consistently high. Some statistics from the training run: we averaged 243 TFLOP/s/device with ~3,200 tokens/s/device. The screenshot below is from the Google Cloud Console page for this workload; notice how high the duty cycle utilization is! We’re putting the TPUs to work.
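One way to sanity-check those two numbers against each other is the common rule of thumb that training costs about 6 FLOPs per parameter per token (2 for the forward pass, 4 for the backward pass), ignoring attention. Assuming roughly 12e9 parameters for Gemma 3 12B, the sketch below turns the measured token throughput into an estimated TFLOP/s per device; `est_train_tflops` is a hypothetical helper name.

```shell
#!/usr/bin/env bash
# Hypothetical helper: estimate training TFLOP/s per device with the 6 * N * tokens rule.
#   $1 = parameter count N, $2 = tokens/s/device
# The factor 6 covers forward (2N) plus backward (4N) matmul FLOPs; attention is ignored.
est_train_tflops() {
  awk -v n="$1" -v tps="$2" 'BEGIN { printf "%.1f\n", 6 * n * tps / 1e12 }'
}

est_train_tflops 12e9 3200   # prints 230.4
```

The estimate lands a little below the reported 243 TFLOP/s/device, which is plausible because the rule of thumb leaves out attention FLOPs, among other terms.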
JetStream is a Google AI Hypercomputer library described as a “throughput and memory optimized engine for LLM inference on XLA devices”. That’s pretty dry; in practice, it gives you an efficient way to deploy the model you just trained on TPU hardware, which offers the performance per dollar that makes the model useful out in the real world.
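# build a container for JetStream (swap in your own project's registry path)
cd inference
# one-time setup so docker can push to gcr.io with your gcloud credentials
gcloud auth configure-docker
docker build -t gcr.io/moonshot-ml/moonshot-pluto:latest .
docker push gcr.io/moonshot-ml/moonshot-pluto:latest
# export environment variables for hosting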
export TOKENIZER_PATH="assets/tokenizer.gemma3"
export LOAD_PARAMETERS_PATH="gs://moonshot-ml-output-bucket/gemma3-12b/12b/unscanned/objectdetection/2025-08-20/checkpoints/249/items"
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=1
export ICI_TENSOR_PARALLELISM=-1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=16
#run hosting.
uv run xpk workload create \
--workload serve-gemma3-12b-chartqa \
--docker-image gcr.io/moonshot-ml/moonshot-pluto:latest \
--cluster ${CLUSTER_NAME} \
--tpu-type ${TPU_TYPE} \
--zone ${ZONE} \
--num-slices=${NUM_SLICES} \
--env GOOGLE_CLOUD_PROJECT=$GOOGLE_CLOUD_PROJECT \
--command "python3 -m MaxText.maxengine_server \
MaxText/configs/base.yml \
tokenizer_path=${TOKENIZER_PATH} \
load_parameters_path=${LOAD_PARAMETERS_PATH} \
max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
max_target_length=${MAX_TARGET_LENGTH} \
model_name=${MODEL_NAME} \
ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
scan_layers=${SCAN_LAYERS} \
weight_dtype=${WEIGHT_DTYPE} \
per_device_batch_size=${PER_DEVICE_BATCH_SIZE}"
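Setting `ICI_TENSOR_PARALLELISM=-1` asks MaxText to fill that mesh axis with whatever devices are left over, so within a slice the ici_* values must multiply to the number of devices. The sketch below mirrors that rule, assuming at most one axis is -1 and the other ici_* axes stay at 1; `resolve_ici_axes` is a hypothetical helper, not MaxText code.

```shell
#!/usr/bin/env bash
# Hypothetical helper mirroring the "-1 means fill the rest" rule for ICI mesh axes.
#   $1 = devices per slice, remaining args = axis sizes (at most one may be -1)
resolve_ici_axes() {
  local devices="$1"; shift
  local product=1 unknown=-1 i
  local -a axes=("$@")
  for i in "${!axes[@]}"; do
    if [ "${axes[$i]}" -eq -1 ]; then
      [ "${unknown}" -eq -1 ] || { echo "only one axis may be -1" >&2; return 1; }
      unknown="$i"
    else
      product=$(( product * axes[i] ))
    fi
  done
  if [ "${unknown}" -ge 0 ]; then
    # The inferred axis gets devices / product of the explicit axes.
    if (( devices % product != 0 )); then
      echo "${devices} devices not divisible by ${product}" >&2; return 1
    fi
    axes[unknown]=$(( devices / product ))
  elif (( product != devices )); then
    echo "axes multiply to ${product}, expected ${devices}" >&2; return 1
  fi
  echo "${axes[*]}"
}

# fsdp=1, autoregressive=1, tensor=-1 on a hypothetical 8-device slice:
resolve_ici_axes 8 1 1 -1   # prints 1 1 8
```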
Once the model is being served, we can use JetStream’s utility scripts to interact with it:
# In one terminal, forward local port 9000 to the serving pod
kubectl port-forward $(kubectl get pods --no-headers -o custom-columns=":metadata.name" | grep serve-gemma3-12b-chartqa | head -n 1) 9000:9000
# In a second terminal, run the JetStream-enabled container
# (with --network host the container reaches localhost:9000 directly, so no -p mapping is needed)
docker run --network host -it gcr.io/moonshot-ml/moonshot-pluto:latest
# inside the container: a warmup request
python /workspace/JetStream/jetstream/tools/requester.py --tokenizer assets/tokenizer.gemma3 --text " The name of the tallest mountain in the world is .. "
# a 512-query load test (I had to change the query count manually in the file)
python /workspace/JetStream/jetstream/tools/load_tester.py --text " The name of the tallest mountain in the world is .. "
# output:
# Time taken: 22.366379976272583
# QPS: 22.891500571087327
Running the load test with 512 queries took 22.37 seconds in total, or 22.89 QPS. That’s not bad for serving a 12B model in bfloat16, and it made excellent use of the available hardware.
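QPS here is simply completed requests divided by wall-clock time, so the two numbers the load tester reports can be cross-checked. A minimal sketch; `qps` is a hypothetical helper name.

```shell
#!/usr/bin/env bash
# Hypothetical helper: queries per second = completed requests / elapsed seconds.
qps() {
  awk -v n="$1" -v secs="$2" 'BEGIN { printf "%.2f\n", n / secs }'
}

qps 512 22.366379976272583   # prints 22.89, matching the load tester's QPS line
```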
JetStream is extremely powerful, but sometimes you just want to save your model back to a Hugging Face-compatible format. MaxText has a utility for this as well:
# 5. Convert the SFT checkpoint back to Hugging Face format.
export FINAL_CKPT_STEP=$((SFT_STEPS - 1))
# checkpoints live under <base_output_directory>/<run_name>/checkpoints/<step>/items, e.g.
# gs://moonshot-ml-output-bucket/gemma3work/12b/unscanned/objectdetection/2025-08-20/checkpoints/499/items/
export FINETUNED_CKPT_PATH="${BASE_OUTPUT_DIRECTORY}/${idx}/checkpoints/${FINAL_CKPT_STEP}/items/"
export LOCAL_PATH=gs://moonshot-ml-output-bucket/gemma3-12b-chartqa
uv run xpk workload create \
--workload export-gemma3-12b-chartqa \
--docker-image ${BASE_IMAGE} \
--cluster ${CLUSTER_NAME} \
--tpu-type ${TPU_TYPE} \
--zone ${ZONE} \
--num-slices=${NUM_SLICES} \
--env GOOGLE_CLOUD_PROJECT=$GOOGLE_CLOUD_PROJECT \
--command "python3 -m MaxText.utils.ckpt_conversion.to_huggingface MaxText/configs/base.yml \
model_name=${MODEL_NAME} \
hf_access_token=${HF_TOKEN} \
load_parameters_path=${FINETUNED_CKPT_PATH} \
base_output_directory=${LOCAL_PATH} \
use_multimodal=${USE_MULTIMODAL} \
scan_layers=$SCAN_LAYERS"
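Rather than hard-coding the final step (249 and 499 in the paths above), you can pick the highest step directory that actually exists under the run’s checkpoints folder. A sketch, assuming the `<base_output_directory>/<run_name>/checkpoints/<step>/items` layout visible in this tutorial’s paths; `latest_step` is a hypothetical helper that reads a directory listing on stdin, for example from `gcloud storage ls`.

```shell
#!/usr/bin/env bash
# Hypothetical helper: print the largest numeric step from a checkpoint listing on stdin.
# Accepts lines such as "gs://bucket/run/checkpoints/249/" or a bare "249".
latest_step() {
  # 1) strip trailing slashes, 2) keep only the last path segment,
  # 3) drop anything that is not a step number, 4) take the numeric maximum.
  sed -e 's#/*$##' -e 's#.*/##' \
    | grep -E '^[0-9]+$' \
    | sort -n \
    | tail -n 1
}

# Example usage (requires gcloud and access to the bucket):
#   CKPT_DIR="${BASE_OUTPUT_DIRECTORY}/${idx}/checkpoints"
#   STEP="$(gcloud storage ls "${CKPT_DIR}/" | latest_step)"
#   export FINETUNED_CKPT_PATH="${CKPT_DIR}/${STEP}/items/"
```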
We’re winding down, so it’s time to clean up the cluster. Thankfully, XPK makes this straightforward:
yes | uv run xpk cluster delete --cluster ${CLUSTER_NAME} --zone ${ZONE}
By completing this walkthrough, you’ve done more than just fine-tune a model; you’ve executed a complete, end-to-end blueprint for training on Google’s AI infrastructure. We’ve demystified the process, showing how tools like XPK, MaxText, and JetStream work together to create a cohesive and manageable workflow.
The key takeaway is that you don’t have to sacrifice usability for performance. As we saw from the TPU utilization metrics and inference benchmarks, this toolchain is built to extract maximum value from the underlying hardware. This guide serves as your foundation. Now, try swapping in a new dataset, adjusting the cluster, or scaling to an even larger model. You have the template to build at scale.
Source Credit: https://medium.com/google-cloud/using-googles-ai-hypercomputer-b149ad3fe3e7?source=rss—-e52cf94d98af—4