Build and Train an LLM with JAX
Instructor: Chris Achard
- Intermediate
- 49 Minutes
- 7 Video Lessons
- 4 Code Examples
What you'll learn
Build a GPT-2 style language model with 20 million parameters from scratch using JAX, the open-source library behind Google’s Gemini, Veo, and Nano Banana models.
Learn JAX’s core primitives (automatic differentiation, JIT compilation, and vectorized mapping) and how to combine them to define, train, and checkpoint a neural network efficiently.
Load a pretrained MiniGPT model and run inference through a chat interface, completing the full workflow from data preprocessing and training to generating text with the trained LLM.
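As a rough illustration of the three primitives named above, here is a sketch that uses `jax.grad`, `jax.jit`, and `jax.vmap` on a toy one-dimensional linear model. The model, data, and function names are assumptions made for illustration and are not course code. The sketch computes a batch gradient and then composes the primitives to get per-example gradients.

```python
import jax
import jax.numpy as jnp

def example_loss(params, x, y):
    """Squared error of a toy 1-D linear model on a single (x, y) example."""
    w, b = params
    return (w * x + b - y) ** 2

def batch_loss(params, xs, ys):
    """Mean squared error over a batch of examples."""
    w, b = params
    return jnp.mean((w * xs + b - ys) ** 2)

# jax.grad differentiates with respect to the first argument (the params tuple);
# jax.jit compiles the resulting gradient function with XLA.
batch_grad = jax.jit(jax.grad(batch_loss))

# Composing the primitives: the gradient of ONE example, vmapped over the batch.
# in_axes=(None, 0, 0) shares params across examples and maps xs/ys over axis 0.
per_example_grads = jax.jit(jax.vmap(jax.grad(example_loss), in_axes=(None, 0, 0)))

params = (jnp.array(1.0), jnp.array(0.0))
xs = jnp.array([0.0, 1.0, 2.0])
ys = jnp.array([1.0, 3.0, 5.0])

dw, db = batch_grad(params, xs, ys)                # one gradient for the whole batch
dw_each, db_each = per_example_grads(params, xs, ys)  # one gradient per example
print(dw, db)
print(dw_each, db_each)
```

Averaging the per-example gradients recovers the batch gradient, which is a quick way to check that the vmapped version is doing what you expect.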
About this course
Introducing Build and Train an LLM with JAX, a short course built in partnership with Google and taught by Chris Achard, Developer Relations Engineer on Google’s TPU Software team.
JAX is the open-source numerical computing library that Google uses to build and train its most advanced models, including Gemini. It looks similar to NumPy, but adds automatic differentiation, just-in-time compilation, and the ability to scale training across thousands of CPUs, GPUs, and TPUs. In this course, you’ll learn JAX by building and training a language model from scratch.
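To show what "similar to NumPy" means in practice, here is a minimal sketch in which the same expression runs under both NumPy and `jax.numpy`. The arrays and values are arbitrary choices for illustration. The sketch also shows one notable difference: JAX arrays are immutable, so updates use the `.at[...]` syntax, which returns a new array.

```python
import numpy as np
import jax.numpy as jnp

# The same expression, written once with NumPy and once with jax.numpy
a_np = np.arange(6.0).reshape(2, 3)
a_jx = jnp.arange(6.0).reshape(2, 3)

out_np = np.mean(a_np @ a_np.T, axis=1)
out_jx = jnp.mean(a_jx @ a_jx.T, axis=1)

# Difference: NumPy arrays can be modified in place...
a_np[0, 0] = 10.0

# ...but JAX arrays are immutable, so item assignment raises a TypeError.
try:
    a_jx[0, 0] = 10.0
    mutation_failed = False
except TypeError:
    mutation_failed = True

# The functional alternative returns an updated copy and leaves a_jx unchanged.
b_jx = a_jx.at[0, 0].set(10.0)

print(out_np, out_jx)
print(a_jx[0, 0], b_jx[0, 0])
```

Immutability is part of what lets JAX treat functions as pure, which its transformations (such as `jax.jit` and `jax.grad`) rely on.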
You’ll implement a complete MiniGPT-style LLM with 20 million parameters—defining the architecture, loading and preprocessing training data, running the training loop, saving checkpoints, and finally chatting with your trained model through a graphical interface. Along the way, you’ll work with key tools from the JAX ecosystem: Flax/NNX for neural network layers, Grain for data loading, Optax for optimization, and Orbax for checkpointing.
In detail, you’ll:
- Explore JAX’s core concepts—automatic differentiation, JIT compilation, and vectorized execution—and see how it compares to NumPy, PyTorch, and TensorFlow in the broader ML landscape.
- Build the architecture of a MiniGPT-style language model using JAX and Flax/NNX, implementing token embeddings and transformer blocks and assembling them into a complete, trainable model.
- Load and preprocess a dataset of mini stories for training, covering tokenization, batching, and structuring data for JAX’s functional execution model.
- Implement the full training loop—computing losses, applying gradients with Optax, and using JAX transformations to keep training efficient—then save your model with Orbax checkpointing.
- Load a pretrained MiniGPT model and run inference through a chat interface to generate stories, completing the full build-train-deploy workflow.
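Transformer blocks are built around self-attention. The sketch below shows a generic single-head causal self-attention in plain JAX. It rests on stated assumptions: random weights, no multi-head split, and no layer norm or MLP. It is not the course's MiniGPT implementation, which may use Flax/NNX layers instead. The causal mask lets position i attend only to positions up to i, which is what makes next-token prediction possible, and `jax.vmap` adds the batch dimension.

```python
import jax
import jax.numpy as jnp

def causal_self_attention(x, wq, wk, wv):
    """Single-head causal self-attention for one sequence.

    x: (seq_len, d_model); wq, wk, wv: (d_model, d_head).
    Returns an array of shape (seq_len, d_head).
    """
    q, k, v = x @ wq, x @ wk, x @ wv
    scores = q @ k.T / jnp.sqrt(q.shape[-1])                   # (seq_len, seq_len)
    seq_len = x.shape[0]
    mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))  # position i sees j <= i
    scores = jnp.where(mask, scores, -jnp.inf)                 # block future positions
    weights = jax.nn.softmax(scores, axis=-1)                  # each row sums to 1
    return weights @ v

# vmap over a batch of sequences; the weight matrices are shared (in_axes=None)
batched_attention = jax.jit(jax.vmap(causal_self_attention, in_axes=(0, None, None, None)))

d_model, d_head, seq_len, batch = 16, 8, 5, 2
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(k1, (batch, seq_len, d_model))
wq = jax.random.normal(k2, (d_model, d_head)) / jnp.sqrt(d_model)
wk = jax.random.normal(k3, (d_model, d_head)) / jnp.sqrt(d_model)
wv = jax.random.normal(k4, (d_model, d_head)) / jnp.sqrt(d_model)

out = batched_attention(x, wq, wk, wv)  # shape (batch, seq_len, d_head)

# Causality check: changing the last token must not affect earlier outputs
x_changed = x.at[:, -1, :].set(0.0)
out_changed = batched_attention(x_changed, wq, wk, wv)
print(out.shape)
```

Masking with `-inf` before the softmax gives future positions exactly zero weight. The diagonal is always unmasked, so no row is ever entirely `-inf`.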
The steps you’ll follow to build and train MiniGPT are the same foundational steps Google uses to develop its more powerful models like Gemini. This course gives you hands-on experience with the tools and techniques at the core of modern LLM development.
Who should join?
Developers and ML practitioners who want to understand how large language models are built and trained at a foundational level. Familiarity with Python and basic machine learning concepts is recommended.
Course Outline
7 Lessons・4 Code Examples
Introduction
Video・3 mins
Overview of JAX
Video・6 mins
Building the Architecture
Video with code examples・10 mins
Data Loading
Video with code examples・6 mins
Training and Saving
Video with code examples・8 mins
Final MiniGPT
Video with code examples・3 mins
Conclusion
Video・1 min
Quiz
Reading・10 mins
Additional learning features, such as quizzes and projects, are included with DeepLearning.AI Pro.

