Google Developer Experts

Experts on various Google products talking tech.

Building an Efficient DataLoader in JAX: From Scratch to Production

--

Introduction

JAX is a powerful framework for numerical computing, enabling high-performance machine learning and deep learning. However, one of its missing pieces is an efficient DataLoader similar to PyTorch’s torch.utils.data.DataLoader. To fill this gap, I built jax-dataloaders, a simple and efficient way to handle batched and shuffled datasets in JAX.

📦 PyPI Package: jax-dataloaders
📂 GitHub Repository: JAX-Dataloader

This blog will walk you through how I built jax-dataloaders from scratch—covering everything from data batching to jax.jit optimizations. Whether you’re a beginner or an expert, by the end of this post, you’ll understand how to efficiently load and preprocess data in JAX.

Why Do We Need a JAX DataLoader?

Unlike TensorFlow and PyTorch, JAX does not provide a built-in DataLoader. Most ML workflows require batching, shuffling, and efficient iteration, but achieving this in JAX is not as straightforward as in PyTorch.

Key challenges include:
✅ JAX favors pure functional programming, so stateful iterators and mutable loader objects don't compose well with its transformations such as jit.
✅ JAX needs efficient data movement between CPU and GPU/TPU.
✅ Native Python loops are slow; JAX’s strength lies in JIT compilation and vectorization.

The goal of jax-dataloaders is to solve these issues while maintaining JAX-native performance.

Building the JAX DataLoader from Scratch

1️⃣ Basic Dataset Handling

In JAX, datasets are typically stored as NumPy or JAX arrays. Instead of an object-oriented Dataset class (as in PyTorch), we can use a simple function-based approach:

import jax.numpy as jnp

def create_dataset(size=1000, feature_dim=10):
    X = jnp.arange(size * feature_dim).reshape(size, feature_dim) / 100.0
    y = jnp.arange(size) % 2  # Binary labels
    return X, y

X, y = create_dataset()
print(X.shape, y.shape)  # (1000, 10), (1000,)

2️⃣ Implementing Batch Loading

JAX prefers immutable arrays, so we avoid stateful iterators and use functional batching:

def batch_data(X, y, batch_size=32):
    dataset_size = X.shape[0]
    num_batches = dataset_size // batch_size  # samples beyond the last full batch are dropped
    return [(X[i * batch_size : (i + 1) * batch_size],
             y[i * batch_size : (i + 1) * batch_size]) for i in range(num_batches)]

batches = batch_data(X, y, batch_size=32)
print(f"Number of batches: {len(batches)}")

3️⃣ Efficient Shuffling Using JAX RNG

Instead of numpy.random.shuffle, we use JAX's purely functional random number generation:

import jax

def shuffle_dataset(X, y, seed=42):
    key = jax.random.PRNGKey(seed)
    indices = jax.random.permutation(key, X.shape[0])
    return X[indices], y[indices]

X_shuffled, y_shuffled = shuffle_dataset(X, y)
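Because the random key is explicit, reshuffling every epoch is just a matter of splitting it. Here is a minimal sketch (the epoch loop and variable names are illustrative, not part of the library) that draws a fresh permutation per epoch and reuses batch_data from above:

key = jax.random.PRNGKey(0)
for epoch in range(3):
    # Split off a fresh subkey so every epoch sees a different permutation
    key, subkey = jax.random.split(key)
    perm = jax.random.permutation(subkey, X.shape[0])
    X_epoch, y_epoch = X[perm], y[perm]
    for batch_X, batch_y in batch_data(X_epoch, y_epoch, batch_size=32):
        pass  # the training step would go here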

4️⃣ Using JAX’s vmap for Efficient Batch Processing

JAX provides vectorized mapping (vmap), which applies a function across a batch dimension without an explicit Python loop:

import jax
import jax.numpy as jnp

def normalize(x):
    return (x - jnp.mean(x, axis=0)) / jnp.std(x, axis=0)

# vmap maps normalize over the leading (sample) axis, so each sample here is
# normalized by its own mean and standard deviation across its features.
batched_normalize = jax.vmap(normalize, in_axes=0)
X_normalized = batched_normalize(X)

5️⃣ jit for Accelerating Data Processing

JAX’s jit compiles functions for maximum speed:

@jax.jit
def process_batch(X_batch):
    return X_batch * 2  # Example transformation

processed_batch = process_batch(X[:32])
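The first call compiles process_batch for the given input shape; later calls with the same shape reuse the compiled version. A quick way to see this (the timings themselves are illustrative) is to time the first and second call, using block_until_ready to account for JAX's asynchronous dispatch:

import time

start = time.perf_counter()
process_batch(X[:32]).block_until_ready()  # includes compilation time
first = time.perf_counter() - start

start = time.perf_counter()
process_batch(X[32:64]).block_until_ready()  # reuses the compiled function
second = time.perf_counter() - start

print(f"first call: {first:.4f}s, second call: {second:.6f}s")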

6️⃣ Parallelizing Data Loading with pmap

For multi-GPU workloads, JAX offers pmap:

import jax

@jax.pmap
def process_parallel(X_batch):
    return X_batch * 2

# Split the leading axis across devices: (num_devices, per_device_batch, features).
# This particular reshape assumes two visible devices.
processed_batches = process_parallel(X[:64].reshape(2, 32, -1))
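The hard-coded reshape above fails if you don't have exactly two devices. A more portable sketch (illustrative, not part of the library) splits the leading axis by jax.local_device_count() and only runs when the batch divides evenly:

n_devices = jax.local_device_count()

if 64 % n_devices == 0:
    per_device = 64 // n_devices
    # The leading axis must equal the number of mapped devices for pmap
    sharded = X[:64].reshape(n_devices, per_device, -1)
    processed_batches = process_parallel(sharded)
    print(processed_batches.shape)  # (n_devices, per_device, feature_dim)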

Introducing jax-dataloaders

Instead of manually implementing these steps, jax-dataloaders provides an optimized solution:

Installation

pip install jax-dataloaders

Usage

from jax_dataloaders import DataLoader
X, y = create_dataset(size=1024, feature_dim=10)
dataloader = DataLoader(X, y, batch_size=32, shuffle=True)

for batch_X, batch_y in dataloader:
    print(batch_X.shape, batch_y.shape)  # (32, 10), (32,)
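A typical training loop simply wraps a jitted update step around this iteration. The model, loss, and learning rate below are placeholders to keep the sketch self-contained; the only thing assumed from the library is the DataLoader constructor and iteration shown above:

import jax
import jax.numpy as jnp

params = jnp.zeros((10,))  # hypothetical linear-model weights

@jax.jit
def train_step(params, batch_X, batch_y):
    # Toy squared-error loss for a linear model; swap in your own model/loss
    def loss_fn(p):
        preds = batch_X @ p
        return jnp.mean((preds - batch_y) ** 2)
    grads = jax.grad(loss_fn)(params)
    return params - 0.01 * grads

dataloader = DataLoader(X, y, batch_size=32, shuffle=True)
for batch_X, batch_y in dataloader:
    params = train_step(params, batch_X, batch_y)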

Features of jax-dataloaders

JAX-native: Uses jax.random and jit for efficient data loading.
Fast batching: Supports vectorized batch operations with vmap.
Shuffling & multi-GPU support: Compatible with TPUs and GPUs.
Lightweight & simple: No extra dependencies, just plug & play.

Optimizing DataLoader Performance

🔹 Pre-fetching with JAX

Since JAX dispatches work to accelerators asynchronously, you can overlap data preparation with computation by pre-fetching batches in a background thread:

import queue
import threading

def prefetch(iterator, size=2):
    # Run the iterator in a background thread and buffer up to `size` items
    q = queue.Queue(maxsize=size)

    def producer():
        for item in iterator:
            q.put(item)
        q.put(None)  # Sentinel marking the end of the iterator

    thread = threading.Thread(target=producer, daemon=True)
    thread.start()
    while True:
        item = q.get()
        if item is None:
            break
        yield item

dataloader = prefetch(dataloader, size=4)
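If you also want each batch to already be on the accelerator when the training step asks for it, a common variation (sketched here, assuming a single default device) moves batches with jax.device_put inside the background iterator:

import jax

def prefetch_to_device(iterator, size=2):
    # Reuse prefetch, but transfer each batch to the default device as soon
    # as it is produced, so the host-to-device copy overlaps with compute.
    def device_iter():
        for batch_X, batch_y in iterator:
            yield jax.device_put(batch_X), jax.device_put(batch_y)
    return prefetch(device_iter(), size=size)

dataloader = prefetch_to_device(batch_data(X, y, batch_size=32), size=4)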

🔹 Streaming Large Datasets

If your dataset is too large to fit in memory, you can load and process it lazily:

def stream_batches(dataset_path, batch_size=32):
    # Lazily read the file one line at a time; each line is parsed into a
    # feature vector and passed through the jitted transformation.
    with open(dataset_path, 'r') as f:
        for line in f:
            yield process_batch(jnp.array([float(x) for x in line.split()]))
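As written, stream_batches treats every line as its own batch. If the file stores one whitespace-separated sample per line and you want real batches of batch_size samples, a small extension (a sketch under that assumption, with a hypothetical file name) can buffer lines before converting them to an array:

def stream_sample_batches(dataset_path, batch_size=32):
    buffer = []
    with open(dataset_path, 'r') as f:
        for line in f:
            buffer.append([float(x) for x in line.split()])
            if len(buffer) == batch_size:
                yield process_batch(jnp.array(buffer))
                buffer = []
    if buffer:  # emit the final partial batch, if any
        yield process_batch(jnp.array(buffer))

# Hypothetical usage: data.txt holds one whitespace-separated sample per line
for batch in stream_sample_batches("data.txt", batch_size=32):
    print(batch.shape)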

Final Thoughts

JAX is insanely fast, but you need to structure data loading properly to benefit from it. With jax-dataloaders, you now have an efficient, easy-to-use DataLoader tailored for JAX workflows.

📂 GitHub Repository: JAX-Dataloader
📦 PyPI Package: jax-dataloaders

🚀 Try it out and contribute! Let me know if you have any feature requests.
