Hugging Face's Diffusers library now supports Flax, enabling lightning-fast inference on Google TPUs. This integration allows users to run Stable Diffusion on TPU hardware available in Colab, Kaggle, or Google Cloud Platform, leveraging parallel processing across 8 devices to generate multiple images simultaneously.
Setup
To begin, ensure you are using a TPU backend. In Colab, select Runtime > Change runtime type and choose TPU under Hardware accelerator. Then install the required version of diffusers:
pip install diffusers==0.5.1
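Once the runtime is set, you can confirm that JAX actually sees the accelerator. This quick check uses plain JAX and assumes nothing beyond a working install:

```python
import jax

# On a Colab TPU runtime this reports "tpu" and 8 devices;
# on a CPU-only machine it falls back to "cpu" and 1.
print(jax.devices()[0].platform)
print(jax.device_count())
```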
Import dependencies:
import numpy as np
import jax
import jax.numpy as jnp
from pathlib import Path
from jax import pmap
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
from huggingface_hub import notebook_login
from diffusers import FlaxStableDiffusionPipeline
Model Loading
Before downloading the model, you must accept the CreativeML OpenRAIL-M license on the Hugging Face Hub. The license prohibits generating illegal or harmful content, but otherwise grants free use of the outputs and permits commercial use and redistribution of the model, provided the same use restrictions are passed along. Run notebook_login() to authenticate with your Hub access token.
Load the pipeline using bfloat16 precision for efficiency:
dtype = jnp.bfloat16
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="bf16",
    dtype=dtype,
)
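Using the bf16 weights roughly halves device memory compared to float32, since bfloat16 stores each parameter in 2 bytes instead of 4. A tiny illustration of the per-element sizes:

```python
import jax.numpy as jnp

# bfloat16 uses 2 bytes per value vs 4 for float32, so the same
# parameter tree occupies about half the memory in bf16.
bf16 = jnp.zeros((4,), dtype=jnp.bfloat16)
f32 = jnp.zeros((4,), dtype=jnp.float32)
print(bf16.dtype.itemsize, f32.dtype.itemsize)  # 2 4
```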
Inference
A TPU runtime typically exposes 8 devices, so we replicate the prompt across all of them to generate 8 images in parallel. First, repeat the prompt once per device and tokenize it:
prompt = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic"
prompt = [prompt] * jax.device_count()
prompt_ids = pipeline.prepare_inputs(prompt)
Replicate parameters and shard inputs:
p_params = replicate(params)
prompt_ids = shard(prompt_ids) # shape: (8, 1, 77)
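Under the hood, shard simply reshapes the leading batch axis into (num_devices, batch_per_device, ...). A pure-NumPy sketch of that reshape, assuming 8 devices (shard_sketch is an illustrative stand-in, not the Flax function):

```python
import numpy as np

def shard_sketch(x: np.ndarray, n_devices: int) -> np.ndarray:
    # Split the leading axis across devices:
    # (batch, ...) -> (n_devices, batch // n_devices, ...)
    return x.reshape((n_devices, -1) + x.shape[1:])

ids = np.zeros((8, 77), dtype=np.int32)  # 8 prompts, 77 tokens each
print(shard_sketch(ids, 8).shape)        # (8, 1, 77)
```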
Prepare a random number generator for reproducibility:
rng = jax.random.PRNGKey(0)
rng = jax.random.split(rng, jax.device_count())
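split turns one key into a stack of independent keys, one per device, so each device samples different latent noise and therefore produces a different image. With 8 devices, the split yields a stack of 8 keys:

```python
import jax

rng = jax.random.PRNGKey(0)
# One independent PRNG key per device; with the default raw-key
# representation each key is a pair of uint32 values.
rngs = jax.random.split(rng, 8)
print(rngs.shape[0])  # 8
```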
Generate images using pmap, which compiles the function and runs it in parallel across all devices (the first call is slow due to compilation; subsequent calls reuse the compiled code):
@pmap
def generate(prompt_ids, p_params, rng):
    return pipeline(
        prompt_ids=prompt_ids,
        params=p_params,
        prng_seed=rng,
        num_inference_steps=50,
        guidance_scale=7.5,
    ).images
images = generate(prompt_ids, p_params, rng)
Each device produces one image from its own PRNG key, so a single call yields 8 distinct results. The whole pipeline runs efficiently on TPU hardware, making Stable Diffusion practical for batch generation.
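The returned array still carries the device axis, with shape (devices, batch_per_device, height, width, 3) and values in [0, 1]. To view the results, flatten the device axis and convert to PIL. A sketch using a small random stand-in array (the real array comes from generate above, at 512x512):

```python
import numpy as np
from PIL import Image

# Stand-in for the pipeline output: 8 devices x 1 image each; 64x64
# here instead of the real 512x512 to keep the example light.
images = np.random.rand(8, 1, 64, 64, 3).astype(np.float32)
images = images.reshape((-1,) + images.shape[-3:])  # (8, 64, 64, 3)
pil_images = [
    Image.fromarray((img * 255).round().astype(np.uint8)) for img in images
]
print(len(pil_images), pil_images[0].size)  # 8 (64, 64)
```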