Just as transformer-based models have revolutionized NLP, they are now making waves in computer vision. The Vision Transformer (ViT), introduced by researchers at Google Brain in a paper first posted in October 2020, tokenizes an image by splitting it into a grid of fixed-size patches, embedding each patch, and feeding the resulting sequence to a standard transformer encoder. As a result, ViT can be pre-trained and fine-tuned in much the same way as an NLP model.
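To make the patch idea concrete, here is a minimal sketch in plain Python of the splitting step only. It assumes a 224x224 RGB image and 16x16 patches, the sizes suggested by the checkpoint name used later in this tutorial. The function name split_into_patches is made up for illustration; the real model also applies a learned embedding to each patch, which is not shown here.

```python
def split_into_patches(image, patch_size):
    """Split an H x W x C image (nested lists) into flattened patches.

    Patches are returned in row-major order (left to right, top to bottom).
    """
    height, width = len(image), len(image[0])
    if height % patch_size or width % patch_size:
        raise ValueError("image sides must be divisible by patch_size")
    patches = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            # Flatten every channel value inside this patch into one vector.
            patch = [
                value
                for row in image[top:top + patch_size]
                for pixel in row[left:left + patch_size]
                for value in pixel
            ]
            patches.append(patch)
    return patches


# Toy all-zero 224 x 224 RGB image; patch size 16 is an assumption (see above).
image = [[[0, 0, 0] for _ in range(224)] for _ in range(224)]
patches = split_into_patches(image, patch_size=16)
print(len(patches), len(patches[0]))  # 196 768
```

Under these assumptions the image becomes a sequence of 196 tokens (a 14x14 grid), each holding 16 * 16 * 3 = 768 raw values.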
In this tutorial, we'll use Hugging Face's datasets and transformers libraries to fine-tune a pre-trained ViT on an image classification dataset. First, install the required packages:
pip install datasets transformers
Load a Dataset
We'll use the beans dataset, which contains images of healthy and unhealthy bean leaves. Load it and inspect its structure:
from datasets import load_dataset
ds = load_dataset('beans')
ds
Each example has three features: image (a PIL Image), image_file_path (the path to the image file), and labels (an integer class label). Let's look at the training example at index 400:
ex = ds['train'][400]
ex
The output shows a PIL image, its file path, and label ID 1. To see the class name, use ClassLabel.int2str():
labels = ds['train'].features['labels']
labels.int2str(ex['labels'])
This returns 'bean_rust'. The dataset has three classes: angular leaf spot, bean rust, and healthy.
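The same mapping can be written out as two plain dictionaries, which is handy when you want readable class names next to integer predictions. This is a sketch: the class_names list is assumed to mirror what labels.names returns for this dataset, and the dictionary names id2label and label2id are just a common convention.

```python
# Assumed to match ds['train'].features['labels'].names for the beans dataset.
class_names = ['angular_leaf_spot', 'bean_rust', 'healthy']

# Integer ID -> class name, and the reverse lookup.
id2label = {i: name for i, name in enumerate(class_names)}
label2id = {name: i for i, name in id2label.items()}

print(id2label[1])  # bean_rust
```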
Load the ViT Image Processor
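A pretrained ViT expects its inputs to be preprocessed the same way they were during pretraining: same resolution, same normalization. The image processor saved alongside a checkpoint captures those settings. Load the ViTImageProcessor for the pretrained model google/vit-base-patch16-224-in21k: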
from transformers import ViTImageProcessor
model_name_or_path = 'google/vit-base-patch16-224-in21k'
processor = ViTImageProcessor.from_pretrained(model_name_or_path)
Printing the processor reveals its configuration: it resizes images to 224x224, normalizes with mean and std of 0.5, and uses bilinear resampling.
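To process an image, call the processor on it, for example on the image from the example loaded earlier:
processor(ex['image'], return_tensors='pt')
This returns a dict-like object whose pixel_values entry is a tensor of shape (1, 3, 224, 224): batch size, channels, height, width.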
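To see what the normalization step does to individual values, here is a small sketch of the arithmetic. It assumes the processor first rescales 8-bit pixel values by 1/255 and then normalizes with the mean and std of 0.5 mentioned above; the function name rescale_and_normalize is hypothetical and not part of the library.

```python
def rescale_and_normalize(value, mean=0.5, std=0.5):
    """Map one 8-bit channel value (0-255) to the range the model sees."""
    scaled = value / 255.0        # assumed rescale step: [0, 255] -> [0, 1]
    return (scaled - mean) / std  # with mean = std = 0.5: [0, 1] -> [-1, 1]


print(rescale_and_normalize(0), rescale_and_normalize(255))  # -1.0 1.0
```

Under these assumptions, every entry of pixel_values lies between -1 and 1.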
Process the Dataset
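Now apply the processor to the entire dataset. Define a function that takes a batch of examples, converts the images to pixel_values with the processor, and keeps the labels alongside them. Then apply it with the dataset's map() method, passing batched=True so that each call handles a whole batch rather than a single example.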
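One way such a function might look is sketched below. The helper name make_transform is made up for illustration; it takes the processor as an argument so the batch function does not rely on a global variable. The sketch assumes the column names image and labels described earlier.

```python
def make_transform(processor):
    """Build a batch function suitable for map(..., batched=True)."""
    def transform(example_batch):
        # With batched=True, example_batch['image'] is a list of PIL images.
        inputs = processor(list(example_batch['image']), return_tensors='pt')
        # Carry the integer labels along unchanged.
        inputs['labels'] = example_batch['labels']
        return inputs
    return transform


# Usage, assuming `ds` and `processor` from the earlier steps:
# prepared_ds = ds.map(make_transform(processor), batched=True)
```

map() processes and stores the results up front. If you would rather preprocess lazily, as examples are accessed, the datasets library also offers with_transform(), which accepts the same kind of batch function.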