
Build and Train an LLM with JAX

Instructor: Chris Achard

Google
  • Intermediate
  • 49 Minutes
  • 7 Video Lessons
  • 4 Code Examples

What you'll learn

  • Build a GPT-2-style language model with 20 million parameters from scratch using JAX, the open-source library behind Google’s Gemini, Veo, and Nano Banana models.

  • Learn JAX’s core primitives (automatic differentiation, JIT compilation, and vectorized mapping) and how to combine them to define, train, and checkpoint a neural network efficiently.

  • Load a pretrained MiniGPT model and run inference through a chat interface, completing the full workflow from data preprocessing and training to generating text with the trained LLM.

About this course

Introducing Build and Train an LLM with JAX, a short course built in partnership with Google and taught by Chris Achard, Developer Relations Engineer on Google’s TPU Software team.

JAX is the open-source numerical computing library that Google uses to build and train its most advanced models, including Gemini. It looks similar to NumPy, but adds automatic differentiation, just-in-time compilation, and the ability to scale training across thousands of CPUs, GPUs, and TPUs. In this course, you’ll learn JAX by building and training a language model from scratch.
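To give a flavor of those three features before the course does, here is a minimal sketch (not course code) using only JAX's public API on a toy function:

```python
import jax
import jax.numpy as jnp

# A simple scalar function: f(x) = x^2 + 3x
def f(x):
    return x ** 2 + 3 * x

# Automatic differentiation: grad(f)(x) = 2x + 3
df = jax.grad(f)
print(df(2.0))  # 7.0

# Just-in-time compilation: compile f with XLA for fast repeated calls
fast_f = jax.jit(f)
print(fast_f(2.0))  # 10.0

# Vectorized mapping: apply f over a whole batch without a Python loop
batched_f = jax.vmap(f)
print(batched_f(jnp.arange(3.0)))  # [0., 4., 10.]
```

The same three transformations (`jax.grad`, `jax.jit`, `jax.vmap`) compose freely, which is what makes the training loop later in the course both short and fast.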

You’ll implement a complete MiniGPT-style LLM with 20 million parameters—defining the architecture, loading and preprocessing training data, running the training loop, saving checkpoints, and finally chatting with your trained model through a graphical interface. Along the way, you’ll work with key tools from the JAX ecosystem: Flax/NNX for neural network layers, Grain for data loading, Optax for optimization, and Orbax for checkpointing.

In detail, you’ll:

  • Explore JAX’s core concepts—automatic differentiation, JIT compilation, and vectorized execution—and see how it compares to NumPy, PyTorch, and TensorFlow in the broader ML landscape.
  • Build the architecture of a MiniGPT-style language model using JAX and Flax/NNX, implementing token embeddings and transformer blocks and assembling them into a complete, trainable model.
  • Load and preprocess a dataset of mini stories for training, covering tokenization, batching, and structuring data for JAX’s functional execution model.
  • Implement the full training loop—computing losses, applying gradients with Optax, and using JAX transformations to keep training efficient—then save your model with Orbax checkpointing.
  • Load a pretrained MiniGPT model and run inference through a chat interface to generate stories, completing the full build-train-deploy workflow.

The steps you’ll follow to build and train MiniGPT are the same foundational steps Google uses to develop its more powerful models like Gemini. This course gives you hands-on experience with the tools and techniques at the core of modern LLM development.

Who should join?

Developers and ML practitioners who want to understand how large language models are built and trained at a foundational level. Familiarity with Python and basic machine learning concepts is recommended.

Course Outline

7 Lessons · 4 Code Examples
  • Introduction

    Video · 3 mins

  • Overview of JAX

    Video · 6 mins

  • Building the Architecture

    Video with code examples · 10 mins

  • Data Loading

    Video with code examples · 6 mins

  • Training and Saving

    Video with code examples · 8 mins

  • Final MiniGPT

    Video with code examples · 3 mins

  • Conclusion

    Video · 1 min

  • Quiz

    Reading · 10 mins

Instructor

Chris Achard

Developer Relations Engineer at Google

Additional learning features, such as quizzes and projects, are included with DeepLearning.AI Pro. Explore it today.
