Building an Efficient DataLoader in JAX: From Scratch to Production
Introduction
JAX is a powerful framework for numerical computing, enabling high-performance machine learning and deep learning. However, one of its missing pieces is an efficient DataLoader similar to PyTorch’s torch.utils.data.DataLoader. To fill this gap, I built jax-dataloaders, a simple and efficient way to handle batched and shuffled datasets in JAX.
📦 PyPI Package: jax-dataloaders
📂 GitHub Repository: JAX-Dataloader
This blog will walk you through how I built jax-dataloaders from scratch, covering everything from data batching to jax.jit optimizations. Whether you’re a beginner or an expert, by the end of this post you’ll understand how to efficiently load and preprocess data in JAX.
Why Do We Need a JAX DataLoader?
Unlike PyTorch (torch.utils.data) and TensorFlow (tf.data), JAX does not provide a built-in data-loading API. Most ML workflows require batching, shuffling, and efficient iteration, but achieving this in JAX is not as straightforward as in PyTorch.
Key challenges include:
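✅ JAX favors pure functions: Python-side state such as iterators, counters, or a global random seed is not tracked inside jax.jit-compiled code, and randomness requires explicitly passing PRNG keys.
✅ Data has to move efficiently from host (CPU) memory to GPU/TPU memory, or the accelerator sits idle waiting for the next batch.
✅ Per-example Python loops are slow; JAX’s strength lies in JIT compilation and vectorized operations over whole batches.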
The goal of jax-dataloaders is to solve these issues while maintaining JAX-native performance.
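The first challenge is easy to see in practice. Python side effects inside a jax.jit-compiled function run only while JAX traces it, not on every call. This is why Python-side counters or iterator state inside compiled code do not behave as they would in plain Python. The function name double and the list trace_log below are purely illustrative.

```python
import jax
import jax.numpy as jnp

trace_log = []  # Python-side state, mutated only while JAX traces the function

@jax.jit
def double(x):
    trace_log.append(x.shape)  # side effect: runs at trace time, not on every call
    return x * 2

out1 = double(jnp.ones(3))
out2 = double(jnp.ones(3))  # same shape/dtype: reuses the compiled version, no retrace

print(len(trace_log))  # 1
```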
Building the JAX DataLoader from Scratch
1️⃣ Basic Dataset Handling
In JAX, datasets are typically stored as NumPy or JAX arrays. Instead of an object-oriented Dataset class (like PyTorch’s), we can use a simple function-based approach:
import jax.numpy as jnp

def create_dataset(size=1000, feature_dim=10):
    # Features: a deterministic float matrix of shape (size, feature_dim)
    X = jnp.arange(size * feature_dim).reshape(size, feature_dim) / 100.0
    y = jnp.arange(size) % 2  # Binary labels
    return X, y

X, y = create_dataset()
print(X.shape, y.shape)  # (1000, 10) (1000,)
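When a dataset has more than two arrays, it can help to keep them in a single dict (a JAX pytree). That way, one set of indices selects matching rows from all of them at once. The sketch below assumes hypothetical helpers create_dataset_dict and take; they are not part of jax-dataloaders.

```python
import jax
import jax.numpy as jnp

def create_dataset_dict(size=1000, feature_dim=10):
    # Hypothetical variant of create_dataset that bundles arrays into one pytree
    X = jnp.arange(size * feature_dim).reshape(size, feature_dim) / 100.0
    y = jnp.arange(size) % 2
    return {"x": X, "y": y}

def take(dataset, indices):
    # Index every array in the pytree with the same indices, keeping x and y aligned
    return jax.tree_util.tree_map(lambda a: a[indices], dataset)

data = create_dataset_dict()
subset = take(data, jnp.array([0, 5, 7]))
print(subset["x"].shape, subset["y"].shape)  # (3, 10) (3,)
```

The same take call then works unchanged for shuffling or batching, however many arrays the dict holds.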
2️⃣ Implementing Batch Loading
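JAX arrays are immutable, so batching is just slicing. The function below returns a list of (X_batch, y_batch) tuples. Note that it drops leftover samples that do not fill a complete batch: with 1000 samples and batch_size=32, you get 31 batches and 8 samples are dropped.
def batch_data(X, y, batch_size=32):
    dataset_size = X.shape[0]
    num_batches = dataset_size // batch_size  # incomplete final batch is dropped
    return [(X[i * batch_size : (i + 1) * batch_size],
             y[i * batch_size : (i + 1) * batch_size]) for i in range(num_batches)]

batches = batch_data(X, y, batch_size=32)
print(f"Number of batches: {len(batches)}")  # Number of batches: 31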
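batch_data builds every batch up front in a list and always drops leftover samples. A generator avoids materializing the list and can optionally keep the final partial batch. This is a minimal sketch: iterate_batches is a hypothetical name, and the drop_last flag borrows its name from PyTorch’s DataLoader.

```python
import jax.numpy as jnp

def iterate_batches(X, y, batch_size=32, drop_last=True):
    """Yield (X_batch, y_batch) pairs lazily instead of building a list."""
    n = X.shape[0]
    # With drop_last=False the final, smaller batch is also yielded
    stop = n - n % batch_size if drop_last else n
    for start in range(0, stop, batch_size):
        yield X[start:start + batch_size], y[start:start + batch_size]

X = jnp.zeros((100, 4))
y = jnp.arange(100)
sizes = [xb.shape[0] for xb, _ in iterate_batches(X, y, batch_size=32, drop_last=False)]
print(sizes)  # [32, 32, 32, 4]
```

Keeping drop_last=True is often the safer default under jax.jit, because a smaller final batch has a new shape and triggers a recompile.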
3️⃣ Efficient Shuffling Using JAX RNG
Instead of numpy.random.shuffle, which shuffles in place using global random state, JAX provides pure functional random number generation with explicit keys:
import jax

def shuffle_dataset(X, y, seed=42):
    key = jax.random.PRNGKey(seed)
    indices = jax.random.permutation(key, X.shape[0])
    # Apply the same permutation to X and y so pairs stay aligned
    return X[indices], y[indices]

X_shuffled, y_shuffled = shuffle_dataset(X, y)
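shuffle_dataset builds its key from the same seed on every call, so calling it once per epoch would produce the identical order each time. A common pattern is to split the key once per epoch and never reuse a key. The sketch below shows this; epoch_permutations is a hypothetical helper name.

```python
import jax
import jax.numpy as jnp

def epoch_permutations(seed, num_examples, num_epochs):
    """Yield a fresh permutation per epoch by splitting the PRNG key."""
    key = jax.random.PRNGKey(seed)
    for _ in range(num_epochs):
        key, subkey = jax.random.split(key)  # consume subkey, carry key forward
        yield jax.random.permutation(subkey, num_examples)

perms = list(epoch_permutations(seed=42, num_examples=10, num_epochs=3))
```

The whole sequence of epoch orders is still reproducible from the single seed, which keeps experiments deterministic.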
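4️⃣ Using JAX’s vmap for Efficient Batch Processing
JAX provides vectorized mapping (vmap), which turns a function written for a single example into one that processes a whole batch without a Python loop:
import jax
import jax.numpy as jnp

def normalize(x):
    # x is a single example of shape (feature_dim,): standardize across its features
    return (x - jnp.mean(x, axis=0)) / jnp.std(x, axis=0)

batched_normalize = jax.vmap(normalize, in_axes=0)  # map over the batch (row) axis
X_normalized = batched_normalize(X)
Because normalize sees one row at a time, this standardizes each sample across its own features. To standardize each feature across the whole dataset instead, apply the same expression to X directly (no vmap needed), since axis=0 then runs over the samples.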
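vmap is especially handy when each example needs its own inputs, such as a separate random key for data augmentation. In the sketch below, the Gaussian-noise augmentation and the names add_noise and augment_batch are illustrative assumptions, not part of jax-dataloaders.

```python
import jax
import jax.numpy as jnp

def add_noise(key, x, scale=0.01):
    # Written for ONE example: x has shape (feature_dim,)
    return x + scale * jax.random.normal(key, x.shape)

def augment_batch(key, X_batch):
    # One independent subkey per example, then map add_noise over both arguments
    keys = jax.random.split(key, X_batch.shape[0])
    return jax.vmap(add_noise)(keys, X_batch)

X_batch = jnp.ones((32, 10))
X_aug = augment_batch(jax.random.PRNGKey(0), X_batch)
print(X_aug.shape)  # (32, 10)
```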
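5️⃣ jit for Accelerating Data Processing
JAX’s jit compiles functions with XLA. The first call traces and compiles the function, and later calls with the same input shapes and dtypes reuse the compiled version:
import jax

@jax.jit
def process_batch(X_batch):
    return X_batch * 2  # Example transformation

processed_batch = process_batch(X[:32])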
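Because jit compiles one version per input shape and dtype, a smaller final batch triggers a second compilation. One option is to pad short batches to a fixed size and carry a mask of which rows are real. This is a sketch under that assumption: scale_batch and pad_batch are hypothetical names, and pad_batch expects 2-D (samples, features) batches.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def scale_batch(X_batch):
    global trace_count
    trace_count += 1  # Python code here runs only when JAX (re)traces
    return X_batch * 2

def pad_batch(X_batch, batch_size):
    """Pad a short (n, features) batch with zero rows and return a validity mask."""
    n = X_batch.shape[0]
    padded = jnp.pad(X_batch, ((0, batch_size - n), (0, 0)))
    mask = jnp.arange(batch_size) < n  # True for real rows, False for padding
    return padded, mask

scale_batch(jnp.ones((32, 10)))
scale_batch(jnp.ones((8, 10)))  # new shape: second trace and compile
padded, mask = pad_batch(jnp.ones((8, 10)), 32)
scale_batch(padded)  # shape (32, 10) again: cached, no retrace
print(trace_count)  # 2
```

Padded rows must then be excluded from losses and metrics, for example by multiplying per-example values by mask before summing.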
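6️⃣ Parallelizing Batch Processing with pmap
For multi-device workloads (several GPUs or TPU cores), JAX offers pmap. It runs the same function on every local device, over a leading axis whose size must equal the number of devices:
import jax

@jax.pmap
def process_parallel(X_batch):
    return X_batch * 2

n_devices = jax.local_device_count()  # the leading axis must equal this
processed_batches = process_parallel(X[:32 * n_devices].reshape(n_devices, 32, -1))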
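Since pmap requires the leading axis of every input to equal the number of participating devices, it helps to have a small helper. The one sketched below reshapes a whole batch, or a dict of arrays, for however many local devices are present. The helper name shard is hypothetical, and it assumes the global batch size divides evenly by the device count.

```python
import jax
import jax.numpy as jnp

def shard(batch):
    """Reshape each leaf from (N, ...) to (num_devices, N // num_devices, ...)."""
    n_dev = jax.local_device_count()
    # Assumes N is divisible by n_dev; otherwise reshape raises an error
    return jax.tree_util.tree_map(
        lambda a: a.reshape((n_dev, a.shape[0] // n_dev) + a.shape[1:]), batch)

@jax.pmap
def process_parallel(X_batch):
    return X_batch * 2

n_dev = jax.local_device_count()
X = jnp.ones((32 * n_dev, 10))  # global batch: 32 examples per device
sharded = shard({"x": X})
out = process_parallel(sharded["x"])
print(out.shape)  # (n_dev, 32, 10), e.g. (1, 32, 10) on a single CPU
```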
Introducing jax-dataloaders
Instead of manually implementing these steps, jax-dataloaders provides an optimized solution:
Installation
pip install jax-dataloaders
Usage
from jax_dataloaders import DataLoader

X, y = create_dataset(size=1024, feature_dim=10)
dataloader = DataLoader(X, y, batch_size=32, shuffle=True)
for batch_X, batch_y in dataloader:
    print(batch_X.shape, batch_y.shape)  # (32, 10) (32,)
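To connect the usage above with the from-scratch steps, here is a minimal sketch of how shuffling and batching could be combined into an iterable class with a similar calling convention. SimpleDataLoader and its seed argument are hypothetical, and this is not how jax-dataloaders is actually implemented.

```python
import jax
import jax.numpy as jnp

class SimpleDataLoader:
    """Minimal sketch combining shuffling and batching; NOT the jax-dataloaders code."""

    def __init__(self, X, y, batch_size=32, shuffle=True, seed=0):
        self.X, self.y = X, y
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.key = jax.random.PRNGKey(seed)

    def __len__(self):
        return self.X.shape[0] // self.batch_size  # partial final batch is dropped

    def __iter__(self):
        n = self.X.shape[0]
        if self.shuffle:
            # Split the key so each pass over the data uses a new order
            self.key, subkey = jax.random.split(self.key)
            indices = jax.random.permutation(subkey, n)
        else:
            indices = jnp.arange(n)
        for i in range(len(self)):
            idx = indices[i * self.batch_size:(i + 1) * self.batch_size]
            yield self.X[idx], self.y[idx]

X = jnp.arange(1024 * 10).reshape(1024, 10) / 100.0
y = jnp.arange(1024) % 2
loader = SimpleDataLoader(X, y, batch_size=32, shuffle=True)
batches = list(loader)
print(len(batches), batches[0][0].shape, batches[0][1].shape)  # 32 (32, 10) (32,)
```

Storing the key on the instance is a deliberate compromise: the class is stateful on the Python side, but all randomness still flows through explicit JAX keys.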
Features of jax-dataloaders
✅ JAX-native: Uses jax.random and jit for efficient data loading.
✅ Fast batching: Supports vectorized batch operations with vmap.
✅ Shuffling & multi-GPU support: Compatible with TPUs and GPUs.
✅ Lightweight & simple: No extra dependencies, just plug & play.
Optimizing DataLoader Performance
🔹 Pre-fetching with JAX
Because JAX dispatches work to accelerators asynchronously, the host can prepare the next batches while the current one is being processed. A simple way to do this is a background thread that fills a bounded queue:
import queue
import threading

def prefetch(iterator, size=2):
    """Prepare up to `size` batches ahead of time in a background thread."""
    q = queue.Queue(maxsize=size)

    def producer():
        for item in iterator:
            q.put(item)
        q.put(None)  # sentinel: signals the end of the data

    thread = threading.Thread(target=producer, daemon=True)
    thread.start()
    while True:
        item = q.get()
        if item is None:
            break
        yield item

dataloader = prefetch(dataloader, size=4)
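The thread-based prefetch overlaps host-side batch preparation with device work, but it passes through whatever the wrapped iterator yields. If those are NumPy arrays, the host-to-device copy still happens only when each batch is first used. A complementary sketch, under the assumption of a single default device, keeps a few batches already transferred with jax.device_put. Because device_put is dispatched asynchronously, those copies can overlap with computation on earlier batches. The name prefetch_to_device is hypothetical here.

```python
import collections
import itertools
import jax
import numpy as np

def prefetch_to_device(iterator, size=2):
    """Keep `size` batches already transferred to the default device ahead of use."""
    iterator = iter(iterator)  # islice must consume one shared iterator
    buffer = collections.deque()

    def enqueue(n):
        # jax.device_put accepts pytrees (tuples, dicts) of arrays
        for batch in itertools.islice(iterator, n):
            buffer.append(jax.device_put(batch))

    enqueue(size)
    while buffer:
        yield buffer.popleft()
        enqueue(1)  # refill one slot for every batch handed out

host_batches = ((np.full((4, 3), i, dtype=np.float32), np.arange(4)) for i in range(5))
device_batches = list(prefetch_to_device(host_batches, size=2))
print(len(device_batches))  # 5
```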
🔹 Streaming Large Datasets
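If your dataset is too large for memory, you can load batches lazily. Here each line of a text file holds one sample as whitespace-separated numbers, and lines are grouped into fixed-size batches:
def stream_batches(dataset_path, batch_size=32):
    buffer = []
    with open(dataset_path, 'r') as f:
        for line in f:
            buffer.append([float(x) for x in line.split()])  # one sample per line
            if len(buffer) == batch_size:
                yield process_batch(jnp.array(buffer))
                buffer = []  # a final partial batch is dropped, keeping shapes fixed for jit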
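Parsing a text file line by line is simple but slow for large numeric data. If the data can be stored as a NumPy .npy file instead, np.load with mmap_mode="r" returns a memory-mapped array whose pages are read from disk on access, so the whole file never has to fit in RAM. This is a sketch assuming that storage format; stream_npy_batches is a hypothetical name.

```python
import os
import tempfile
import numpy as np
import jax.numpy as jnp

def stream_npy_batches(path, batch_size=32):
    """Yield fixed-size batches from a .npy file without loading it all into RAM."""
    data = np.load(path, mmap_mode="r")  # memory-mapped: data is read on access
    # Stop before a partial final batch so every batch has the same shape
    for start in range(0, data.shape[0] - batch_size + 1, batch_size):
        # np.array reads just this slice into memory; jnp.asarray hands it to JAX
        yield jnp.asarray(np.array(data[start:start + batch_size]))

path = os.path.join(tempfile.mkdtemp(), "features.npy")
np.save(path, np.arange(100 * 4, dtype=np.float32).reshape(100, 4))
batches = list(stream_npy_batches(path, batch_size=32))
print(len(batches), batches[0].shape)  # 3 (32, 4)
```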
Final Thoughts
JAX is insanely fast, but you need to structure data loading properly to keep the accelerator busy. With jax-dataloaders, you now have an efficient, easy-to-use DataLoader tailored for JAX workflows.
📂 GitHub Repository: JAX-Dataloader
📦 PyPI Package: jax-dataloaders
🚀 Try it out and contribute! Let me know if you have any feature requests.