A 7.5k-star GitHub project: a collection of PyTorch implementations of various vision Transformers

#News ·2025-01-02

Transformers crossing over into CV tasks is nothing new in the past year or two.

Since Google proposed the Vision Transformer (ViT) in October 2020, various vision Transformers have begun to play a role in image synthesis, point cloud processing, vision-language modeling, and other fields.

Since then, PyTorch implementations of the Vision Transformer have become a research hotspot. There are many great projects on GitHub, and today I'm going to introduce one of them.

The project is called vit-pytorch. It is a Vision Transformer implementation that demonstrates a simple way to achieve SOTA results in visual classification with PyTorch, using only a single Transformer encoder.

The project currently has 7.5k stars and was created by Phil Wang, who has 147 repositories on GitHub.


Project address: https://github.com/lucidrains/vit-pytorch

The project author also provides a GIF demo in the repository.

Project introduction

Let's first walk through the installation, usage, required parameters, and distillation steps of vit-pytorch.

The first step is to install:

$ pip install vit-pytorch 


The second step is usage:

import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)
preds = v(img) # (1, 1000)
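The output preds is a tensor of raw logits. As a quick follow-up, here is a minimal inference sketch (assuming the v and img defined above; the softmax/topk post-processing is a common convention, not part of the library):

import torch.nn.functional as F

with torch.no_grad():
    preds = v(img)                      # (1, 1000) raw logits
    probs = F.softmax(preds, dim = -1)  # (1, 1000) class probabilities
    top5 = probs.topk(5)                # values and indices of the 5 most likely classes
print(top5.indices)                     # arbitrary until the model is trained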

The third step is the required parameters, which include the following:

  • image_size: image size
  • patch_size: size of each patch; image_size must be divisible by patch_size (see the sketch after this list)
  • num_classes: number of classification categories
  • dim: the final output dimension of the tensor after the linear transformation nn.Linear(..., dim)
  • depth: number of Transformer blocks
  • heads: number of heads in the multi-head attention layer
  • mlp_dim: dimension of the MLP (feedforward) layer
  • channels: number of image channels
  • dropout: dropout rate
  • emb_dropout: embedding dropout rate
  • ……
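To make the image_size/patch_size relationship concrete, here is a minimal sketch of the patch arithmetic for the configuration used above (the variable names are illustrative, not part of the library):

image_size, patch_size = 256, 32

# image_size must be divisible by patch_size
assert image_size % patch_size == 0

num_patches = (image_size // patch_size) ** 2  # 8 x 8 grid = 64 patch tokens
patch_dim = 3 * patch_size ** 2                # 3072 values per flattened RGB patch
print(num_patches, patch_dim)                  # 64 3072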

Finally, distillation, which follows the procedure from the Facebook AI and Sorbonne University paper "Training data-efficient image transformers & distillation through attention" (DeiT).


Paper address: https://arxiv.org/pdf/2012.12877.pdf

The code to distill from ResNet50 (or any teacher network) to a vision transformer is as follows:

import torch
from torchvision.models import resnet50
from vit_pytorch.distill import DistillableViT, DistillWrapper

teacher = resnet50(pretrained = True)

v = DistillableViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

distiller = DistillWrapper(
    student = v,
    teacher = teacher,
    temperature = 3,   # temperature of distillation
    alpha = 0.5,       # trade between main loss and distillation loss
    hard = False       # whether to use soft or hard distillation
)

img = torch.randn(2, 3, 256, 256)
labels = torch.randint(0, 1000, (2,))

loss = distiller(img, labels)
loss.backward()

# after lots of training above ...

pred = v(img) # (2, 1000)
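The "lots of training" the comment elides would be an ordinary PyTorch optimization loop over the distillation loss. A minimal sketch, assuming a hypothetical train_loader that yields (image, label) batches:

import torch

# distiller.parameters() includes the student's parameters
# as well as the wrapper's own distillation parameters
optimizer = torch.optim.Adam(distiller.parameters(), lr = 3e-4)

for img, labels in train_loader:   # train_loader is a hypothetical DataLoader
    optimizer.zero_grad()
    loss = distiller(img, labels)  # combined classification + distillation loss
    loss.backward()
    optimizer.step()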


In addition to the Vision Transformer, the project also provides PyTorch implementations of other ViT variant models such as Deep ViT, CaiT, Token-to-Token ViT, PiT, etc.
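These variants follow the same constructor pattern as ViT. For example, per the project README, swapping in Deep ViT is essentially an import change (a sketch; check the repository for the exact, current module paths):

import torch
from vit_pytorch.deepvit import DeepViT

v = DeepViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

preds = v(torch.randn(1, 3, 256, 256))  # (1, 1000)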


Readers interested in the PyTorch implementations of these ViT models can refer to the original project.
