Course Hive
Build and Train an LLM with JAX

DeepLearning.AI Courses - Build and Train an LLM with JAX

5.0 (2)
18 learners

This course includes

  • 5.5 hours of video
  • Certificate of completion
  • Access on mobile and TV

Full Transcript

Learn more: https://bit.ly/4rce49q

Introducing Build and Train an LLM with JAX, a short course built in partnership with Google and taught by Chris Achard, Developer Relations Engineer on Google's TPU Software team.

JAX is the open-source numerical computing library that Google uses to build and train its most advanced models, including Gemini. Its API looks similar to NumPy's, but it adds automatic differentiation, just-in-time (JIT) compilation, and the ability to scale training across thousands of CPUs, GPUs, and TPUs.

In this course, you'll learn JAX by building and training a language model from scratch. You'll implement a complete MiniGPT-style LLM with 20 million parameters: defining the architecture, loading and preprocessing training data, running the training loop, saving checkpoints, and finally chatting with your trained model through a graphical interface. Along the way, you'll work with key tools from the JAX ecosystem: Flax/NNX for neural network layers, Grain for data loading, Optax for optimization, and Orbax for checkpointing.

In detail, you'll:

  • Explore JAX's core concepts, such as automatic differentiation, JIT compilation, and vectorized execution, and see how JAX compares to NumPy, PyTorch, and TensorFlow in the broader ML landscape.
  • Build the architecture of a MiniGPT-style language model using JAX and Flax/NNX, assembling token embeddings and transformer blocks into a complete, trainable model.
  • Load and preprocess a dataset of mini stories for training, covering tokenization, batching, and structuring data for JAX's functional execution model.
  • Implement the full training loop: compute losses, apply gradients with Optax, and use JAX transformations to keep training efficient, then save your model with Orbax checkpointing.
  • Load a pretrained MiniGPT model and run inference through a chat interface to generate stories, completing the full build-train-deploy workflow.

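To make the three core concepts named above a little more concrete, here is a minimal sketch of automatic differentiation (`jax.grad`), JIT compilation (`jax.jit`), and vectorized execution (`jax.vmap`). The function `f` and the variable names are hypothetical and chosen only for illustration; they are not taken from the course material.

```python
import jax
import jax.numpy as jnp


def f(x):
    # A simple scalar function: f(x) = x^2 + 3x (illustrative only).
    return x ** 2 + 3.0 * x


# Automatic differentiation: df computes f'(x) = 2x + 3.
df = jax.grad(f)

# JIT compilation: the same function, compiled with XLA on first call.
fast_f = jax.jit(f)

# Vectorized execution: apply the scalar derivative to each element of a batch.
batched_df = jax.vmap(df)

x = jnp.array(2.0)
grad_at_2 = df(x)            # derivative at x = 2

xs = jnp.arange(4.0)         # [0., 1., 2., 3.]
grads = batched_df(xs)       # derivative at each element of xs
jit_values = fast_f(xs)      # same values as f(xs), via the compiled path
```

The transformations compose: `jax.vmap(jax.grad(f))` is itself an ordinary function that could be wrapped in `jax.jit` again.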
The steps you'll follow to build and train MiniGPT are the same foundational steps Google uses to develop its more powerful models like Gemini. This course gives you hands-on experience with the tools and techniques at the core of modern LLM development. Enroll now: https://bit.ly/4rce49q
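As a rough illustration of the training-loop step described above (compute a loss, take gradients, update parameters), the sketch below trains a toy next-token predictor. It makes several assumptions that are not from the course: it uses a hand-written gradient-descent update in plain JAX rather than Optax, a two-matrix model rather than MiniGPT, and hypothetical names and sizes (`VOCAB`, `DIM`, `LEARNING_RATE`, `train_step`).

```python
import jax
import jax.numpy as jnp

VOCAB, DIM = 16, 8        # hypothetical toy sizes
LEARNING_RATE = 0.1       # hypothetical step size


def init_params(key):
    # Two small weight matrices: a token embedding and an output projection.
    k1, k2 = jax.random.split(key)
    return {
        "embed": jax.random.normal(k1, (VOCAB, DIM)) * 0.1,
        "out": jax.random.normal(k2, (DIM, VOCAB)) * 0.1,
    }


def loss_fn(params, tokens):
    # Next-token prediction: predict token t+1 from token t alone.
    inputs, targets = tokens[:, :-1], tokens[:, 1:]
    hidden = params["embed"][inputs]                 # (batch, seq-1, DIM)
    logits = hidden @ params["out"]                  # (batch, seq-1, VOCAB)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    picked = jnp.take_along_axis(log_probs, targets[..., None], axis=-1)
    return -picked.mean()                            # mean cross-entropy


@jax.jit
def train_step(params, tokens):
    # Compute the loss and its gradients, then apply plain gradient descent.
    loss, grads = jax.value_and_grad(loss_fn)(params, tokens)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return new_params, loss


# A tiny synthetic batch: the repeating pattern 0, 1, 2, 3, 0, 1, 2, 3.
tokens = jnp.tile(jnp.arange(8) % 4, (4, 1))

params = init_params(jax.random.PRNGKey(0))
initial_loss = loss_fn(params, tokens)
for _ in range(200):
    params, _ = train_step(params, tokens)
final_loss = loss_fn(params, tokens)
```

Parameters are passed in and returned rather than mutated in place, which is what lets `jax.jit` compile the whole step as a pure function.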
