您当前的位置:首页 > 电脑百科 > 程序开发 > 编程百科

使用JAX实现完整的Vision Transformer

时间:2023-02-06 14:16:12  来源:今日头条  作者:deephub

本文将展示如何使用JAX/Flax实现Vision Transformer (ViT),以及如何使用JAX/Flax训练ViT。

Vision Transformer

在实现Vision Transformer时,首先要记住这张图。

以下是论文描述的ViT执行过程。

从输入图像中提取补丁图像,并将其转换为平面向量。

投影到 Transformer Encoder 来处理的维度

预先添加一个可学习的嵌入([class]标记),并添加一个位置嵌入。

由 Transformer Encoder 进行编码处理

使用[class]令牌作为输出,输入到MLP进行分类。

细节实现

下面,我们将使用JAX/Flax创建每个模块。

1、图像到展平的图像补丁

下面的代码从输入图像中提取图像补丁。这个过程通过卷积来实现,内核大小为patch_size * patch_size, stride为patch_size * patch_size,以避免重复。

class Patches(nn.Module):
patch_size: int
embed_dim: int

def setup(self):
self.conv = nn.Conv(
features=self.embed_dim,
kernel_size=(self.patch_size, self.patch_size),
strides=(self.patch_size, self.patch_size),
padding='VALID'
)

def __call__(self, images):
patches = self.conv(images)
b, h, w, c = patches.shape
patches = jnp.reshape(patches, (b, h*w, c))
return patches

2和3、对展平补丁块的线性投影/添加[CLS]标记/位置嵌入

Transformer Encoder 对所有层使用相同的尺寸大小hidden_dim。上面创建的补丁块向量被投影到hidden_dim维度向量上。与BERT一样,有一个CLS令牌被添加到序列的开头,还增加了一个可学习的位置嵌入来保存位置信息。

class PatchEncoder(nn.Module):
hidden_dim: int

@nn.compact
def __call__(self, x):
assert x.ndim == 3
n, seq_len, _ = x.shape
# Hidden dim
x = nn.Dense(self.hidden_dim)(x)
# Add cls token
cls = self.param('cls_token', nn.initializers.zeros, (1, 1, self.hidden_dim))
cls = jnp.tile(cls, (n, 1, 1))
x = jnp.concatenate([cls, x], axis=1)
# Add position embedding
pos_embed = self.param(
'position_embedding',
nn.initializers.normal(stddev=0.02), # From BERT
(1, seq_len + 1, self.hidden_dim)
)
return x + pos_embed

4、Transformer encoder

如上图所示,编码器由多头自注意(MSA)和MLP交替层组成。Norm层 (LN)在MSA和MLP块之前,残差连接在块之后。

class TransformerEncoder(nn.Module):
embed_dim: int
hidden_dim: int
n_heads: int
drop_p: float
mlp_dim: int

def setup(self):
self.mha = MultiHeadSelfAttention(self.hidden_dim, self.n_heads, self.drop_p)
self.mlp = MLP(self.mlp_dim, self.drop_p)
self.layer_norm = nn.LayerNorm(epsilon=1e-6)

def __call__(self, inputs, trAIn=True):
# Attention Block
x = self.layer_norm(inputs)
x = self.mha(x, train)
x = inputs + x
# MLP block
y = self.layer_norm(x)
y = self.mlp(y, train)

return x + y

MLP是一个两层网络。激活函数是GELU。本文将Dropout应用于Dense层之后。

class MLP(nn.Module):
mlp_dim: int
drop_p: float
out_dim: Optional[int] = None

@nn.compact
def __call__(self, inputs, train=True):
actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim
x = nn.Dense(features=self.mlp_dim)(inputs)
x = nn.gelu(x)
x = nn.Dropout(rate=self.drop_p, deterministic=not train)(x)
x = nn.Dense(features=actual_out_dim)(x)
x = nn.Dropout(rate=self.drop_p, deterministic=not train)(x)
return x

多头自注意(MSA)

qkv的形式应为[B, N, T, D],如Single Head中计算权重和注意力后,应输出回原维度[B, T, C=N*D]。

class MultiHeadSelfAttention(nn.Module):
hidden_dim: int
n_heads: int
drop_p: float

def setup(self):
self.q.NET = nn.Dense(self.hidden_dim)
self.k_net = nn.Dense(self.hidden_dim)
self.v_net = nn.Dense(self.hidden_dim)

self.proj_net = nn.Dense(self.hidden_dim)

self.att_drop = nn.Dropout(self.drop_p)
self.proj_drop = nn.Dropout(self.drop_p)

def __call__(self, x, train=True):
B, T, C = x.shape # batch_size, seq_length, hidden_dim
N, D = self.n_heads, C // self.n_heads # num_heads, head_dim
q = self.q_net(x).reshape(B, T, N, D).transpose(0, 2, 1, 3) # (B, N, T, D)
k = self.k_net(x).reshape(B, T, N, D).transpose(0, 2, 1, 3)
v = self.v_net(x).reshape(B, T, N, D).transpose(0, 2, 1, 3)

# weights (B, N, T, T)
weights = jnp.matmul(q, jnp.swapaxes(k, -2, -1)) / math.sqrt(D)
normalized_weights = nn.softmax(weights, axis=-1)

# attention (B, N, T, D)
attention = jnp.matmul(normalized_weights, v)
attention = self.att_drop(attention, deterministic=not train)

# gather heads
attention = attention.transpose(0, 2, 1, 3).reshape(B, T, N*D)

# project
out = self.proj_drop(self.proj_net(attention), deterministic=not train)

return out

5、使用CLS嵌入进行分类

最后MLP头(分类头)。

class ViT(nn.Module):
patch_size: int
embed_dim: int
hidden_dim: int
n_heads: int
drop_p: float
num_layers: int
mlp_dim: int
num_classes: int

def setup(self):
self.patch_extracter = Patches(self.patch_size, self.embed_dim)
self.patch_encoder = PatchEncoder(self.hidden_dim)
self.dropout = nn.Dropout(self.drop_p)
self.transformer_encoder = TransformerEncoder(self.embed_dim, self.hidden_dim, self.n_heads, self.drop_p, self.mlp_dim)
self.cls_head = nn.Dense(features=self.num_classes)

def __call__(self, x, train=True):
x = self.patch_extracter(x)
x = self.patch_encoder(x)
x = self.dropout(x, deterministic=not train)
for i in range(self.num_layers):
x = self.transformer_encoder(x, train)
# MLP head
x = x[:, 0] # [CLS] token
x = self.cls_head(x)
return x

使用JAX/Flax训练

现在已经创建了模型,下面就是使用JAX/Flax来训练。

数据集

这里我们直接使用 torchvision的CIFAR10.

首先是一些工具函数

def image_to_numpy(img):
img = np.array(img, dtype=np.float32)
img = (img / 255. - DATA_MEANS) / DATA_STD
return img

def numpy_collate(batch):
if isinstance(batch[0], np.ndarray):
return np.stack(batch)
elif isinstance(batch[0], (tuple, list)):
transposed = zip(*batch)
return [numpy_collate(samples) for samples in transposed]
else:
return np.array(batch)

然后是训练和测试的dataloader

test_transform = image_to_numpy
train_transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomResizedCrop((IMAGE_SIZE, IMAGE_SIZE), scale=CROP_SCALES, ratio=CROP_RATIO),
image_to_numpy
])

# Validation set should not use the augmentation.
train_dataset = CIFAR10('data', train=True, transform=train_transform, download=True)
val_dataset = CIFAR10('data', train=True, transform=test_transform, download=True)
train_set, _ = torch.utils.data.random_split(train_dataset, [45000, 5000], generator=torch.Generator().manual_seed(SEED))
_, val_set = torch.utils.data.random_split(val_dataset, [45000, 5000], generator=torch.Generator().manual_seed(SEED))
test_set = CIFAR10('data', train=False, transform=test_transform, download=True)

train_loader = torch.utils.data.DataLoader(
train_set, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, num_workers=2, persistent_workers=True, collate_fn=numpy_collate,
)
val_loader = torch.utils.data.DataLoader(
val_set, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, num_workers=2, persistent_workers=True, collate_fn=numpy_collate,
)
test_loader = torch.utils.data.DataLoader(
test_set, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, num_workers=2, persistent_workers=True, collate_fn=numpy_collate,
)

初始化模型

初始化ViT模型

def initialize_model(
seed=42,
patch_size=16, embed_dim=192, hidden_dim=192,
n_heads=3, drop_p=0.1, num_layers=12, mlp_dim=768, num_classes=10
):
main_rng = jax.random.PRNGKey(seed)
x = jnp.ones(shape=(5, 32, 32, 3))
# ViT
model = ViT(
patch_size=patch_size,
embed_dim=embed_dim,
hidden_dim=hidden_dim,
n_heads=n_heads,
drop_p=drop_p,
num_layers=num_layers,
mlp_dim=mlp_dim,
num_classes=num_classes
)
main_rng, init_rng, drop_rng = random.split(main_rng, 3)
params = model.init({'params': init_rng, 'dropout': drop_rng}, x, train=True)['params']
return model, params, main_rng

vit_model, vit_params, vit_rng = initialize_model()

创建TrainState

在Flax中常见的模式是创建管理训练的状态的类,包括轮次、优化器状态和模型参数等等。还可以通过在Apply_fn中指定apply_fn来减少学习循环中的函数参数列表,apply_fn对应于模型的前向传播。

def create_train_state(
model, params, learning_rate
):
optimizer = optax.adam(learning_rate)
return train_state.TrainState.create(
apply_fn=model.apply,
tx=optimizer,
params=params
)

state = create_train_state(vit_model, vit_params, 3e-4)

循环训练

def train_model(train_loader, val_loader, state, rng, num_epochs=100):
best_eval = 0.0
for epoch_idx in tqdm(range(1, num_epochs + 1)):
state, rng = train_epoch(train_loader, epoch_idx, state, rng)
if epoch_idx % 1 == 0:
eval_acc = eval_model(val_loader, state, rng)
logger.add_scalar('val/acc', eval_acc, global_step=epoch_idx)
if eval_acc >= best_eval:
best_eval = eval_acc
save_model(state, step=epoch_idx)
logger.flush()
# Evaluate after training
test_acc = eval_model(test_loader, state, rng)
print(f'test_acc: {test_acc}')

def train_epoch(train_loader, epoch_idx, state, rng):
metrics = defaultdict(list)
for batch in tqdm(train_loader, desc='Training', leave=False):
state, rng, loss, acc = train_step(state, rng, batch)
metrics['loss'].append(loss)
metrics['acc'].append(acc)
for key in metrics.keys():
arg_val = np.stack(jax.device_get(metrics[key])).mean()
logger.add_scalar('train/' + key, arg_val, global_step=epoch_idx)
print(f'[epoch {epoch_idx}] {key}: {arg_val}')
return state, rng

验证

def eval_model(data_loader, state, rng):
# Test model on all images of a data loader and return avg loss
correct_class, count = 0, 0
for batch in data_loader:
rng, acc = eval_step(state, rng, batch)
correct_class += acc * batch[0].shape[0]
count += batch[0].shape[0]
eval_acc = (correct_class / count).item()
return eval_acc

训练步骤

在train_step中定义损失函数,计算模型参数的梯度,并根据梯度更新参数;在value_and_gradients方法中,计算状态的梯度。在apply_gradients中,更新TrainState。交叉熵损失是通过apply_fn(与model.apply相同)计算logits来计算的,apply_fn是在创建TrainState时指定的。

@jax.jit
def train_step(state, rng, batch):
loss_fn = lambda params: calculate_loss(params, state, rng, batch, train=True)
# Get loss, gradients for loss, and other outputs of loss function
(loss, (acc, rng)), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
# Update parameters and batch statistics
state = state.apply_gradients(grads=grads)
return state, rng, loss, acc

计算损失

def calculate_loss(params, state, rng, batch, train):
imgs, labels = batch
rng, drop_rng = random.split(rng)
logits = state.apply_fn({'params': params}, imgs, train=train, rngs={'dropout': drop_rng})
loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=labels).mean()
acc = (logits.argmax(axis=-1) == labels).mean()
return loss, (acc, rng)

结果

训练结果如下所示。在Colab pro的标准GPU上,训练时间约为1.5小时。

test_acc: 0.7704000473022461

如果你对JAX感兴趣,请看这里是本文的完整代码:

https://avoid.overfit.cn/post/926b7965ba56464ba151cbbfb6a98a93

作者:satojkovic



Tags:JAX   点击:()  评论:()
声明:本站部分内容及图片来自互联网,转载是出于传递更多信息之目的,内容观点仅代表作者本人,不构成投资建议。投资者据此操作,风险自担。如有任何标注错误或版权侵犯请与我们联系,我们将及时更正、删除。
▌相关推荐
Ajax是什么?JavaScript中如何使用Ajax技术进行网络请求?
在web初期阶段,前端想要获取后端服务信息需要刷新整个页面,这种方式既耗时又让用户体验十分糟糕,那么怎么解决这个问题呢?想要搭建起前端和后端的快速通道,这个时候就需要使用Aja...【详细内容】
2023-06-28  Search: JAX  点击:(233)  评论:(0)  加入收藏
Java抓取前端Ajax的秘诀:andXML和XML
在当今互联网时代,前端技术的发展已经越来越快,越来越多的网站采用了Ajax技术来实现前端渲染。这种技术可以使得页面更加流畅,用户体验更好,但是它也给后端爬虫带来了很大的挑战...【详细内容】
2023-05-09  Search: JAX  点击:(380)  评论:(0)  加入收藏
使用JAX实现完整的Vision Transformer
本文将展示如何使用JAX/Flax实现Vision Transformer (ViT),以及如何使用JAX/Flax训练ViT。Vision Transformer在实现Vision Transformer时,首先要记住这张图。以下是论文描述...【详细内容】
2023-02-06  Search: JAX  点击:(213)  评论:(0)  加入收藏
在 node 中使用 jquery ajax
对于前端同学来说,ajax 请求应该不会陌生。jquery 真的ajax请求做了封装,可以通过下面的方式发送一个请求并获取相应结果:$.ajax({ url: "https://echo.apipost.cn/get.php"...【详细内容】
2022-09-01  Search: JAX  点击:(408)  评论:(0)  加入收藏
Spring Boot与JAX-RS框架Jersey的完美搭配
编辑推荐:本文来自于oschina,本文主要介绍了一个在 spring boot 项目添加Jeresy框架的详细过程,希望对大家能有所帮助。Jeresy是一个轻量级的JAX-RS框架添加Jeresy 2.x的依赖c...【详细内容】
2022-06-05  Search: JAX  点击:(356)  评论:(0)  加入收藏
Ajax 之战:XMLHttpRequest 与 Fetch API
Ajax 是大多数 web 应用程序背后的核心技术,它允许页面向 web 服务发出异步请求,因此数据可以不经过页面往返服务器无刷新显示数据。术语 Ajax 不是一种技术,相反,它指的是从客...【详细内容】
2022-05-06  Search: JAX  点击:(299)  评论:(0)  加入收藏
在 JS 中如何使用 Ajax 来进行请求
本人已经过原 Danny Markov 授权翻译在本教程中,我们将学习如何使用 JS 进行AJAX调用。1.AJAX术语AJAX 表示 异步的 JavaScript 和 XML。AJAX 在 JS 中用于发出异步网络请求...【详细内容】
2021-01-12  Search: JAX  点击:(453)  评论:(0)  加入收藏
ajax请求controller ajax跨域报错处理
报错:Access to XMLHttpRequest at 'http://localhost:8080/SpringBootServer/testfile' from origin 'null' has been blocked by CORS policy: No 'Ac...【详细内容】
2020-09-03  Search: JAX  点击:(219)  评论:(0)  加入收藏
ajax跨域完全讲解
跨域产生的原因: 浏览器限制。如果浏览器发现请求是跨域的时候,就会做校验,如果校验不通过就会报跨域的错误 跨域。发出去的请求只要域名、端口、协议中的任意一个与当前域不同...【详细内容】
2020-07-02  Search: JAX  点击:(382)  评论:(0)  加入收藏
MathJax的基本使用
MathJax是一个开放源代码的JavaScript显示引擎,适用于所有现代浏览器中的LaTeX、MathML和AsciMath表示法。MathJax官网为 https://www.mathjax.org 其开源文档地址为 https:/...【详细内容】
2020-05-04  Search: JAX  点击:(381)  评论:(0)  加入收藏
▌简易百科推荐
Netflix 是如何管理 2.38 亿会员的
作者 | Surabhi Diwan译者 | 明知山策划 | TinaNetflix 高级软件工程师 Surabhi Diwan 在 2023 年旧金山 QCon 大会上发表了题为管理 Netflix 的 2.38 亿会员 的演讲。她在...【详细内容】
2024-04-08    InfoQ  Tags:Netflix   点击:(3)  评论:(0)  加入收藏
即将过时的 5 种软件开发技能!
作者 | Eran Yahav编译 | 言征出品 | 51CTO技术栈(微信号:blog51cto) 时至今日,AI编码工具已经进化到足够强大了吗?这未必好回答,但从2023 年 Stack Overflow 上的调查数据来看,44%...【详细内容】
2024-04-03    51CTO  Tags:软件开发   点击:(8)  评论:(0)  加入收藏
跳转链接代码怎么写?
在网页开发中,跳转链接是一项常见的功能。然而,对于非技术人员来说,编写跳转链接代码可能会显得有些困难。不用担心!我们可以借助外链平台来简化操作,即使没有编程经验,也能轻松实...【详细内容】
2024-03-27  蓝色天纪    Tags:跳转链接   点击:(15)  评论:(0)  加入收藏
中台亡了,问题到底出在哪里?
曾几何时,中台一度被当做“变革灵药”,嫁接在“前台作战单元”和“后台资源部门”之间,实现企业各业务线的“打通”和全域业务能力集成,提高开发和服务效率。但在中台如火如荼之...【详细内容】
2024-03-27  dbaplus社群    Tags:中台   点击:(11)  评论:(0)  加入收藏
员工写了个比删库更可怕的Bug!
想必大家都听说过删库跑路吧,我之前一直把它当一个段子来看。可万万没想到,就在昨天,我们公司的某位员工,竟然写了一个比删库更可怕的 Bug!给大家分享一下(不是公开处刑),希望朋友们...【详细内容】
2024-03-26  dbaplus社群    Tags:Bug   点击:(8)  评论:(0)  加入收藏
我们一起聊聊什么是正向代理和反向代理
从字面意思上看,代理就是代替处理的意思,一个对象有能力代替另一个对象处理某一件事。代理,这个词在我们的日常生活中也不陌生,比如在购物、旅游等场景中,我们经常会委托别人代替...【详细内容】
2024-03-26  萤火架构  微信公众号  Tags:正向代理   点击:(14)  评论:(0)  加入收藏
看一遍就理解:IO模型详解
前言大家好,我是程序员田螺。今天我们一起来学习IO模型。在本文开始前呢,先问问大家几个问题哈~什么是IO呢?什么是阻塞非阻塞IO?什么是同步异步IO?什么是IO多路复用?select/epoll...【详细内容】
2024-03-26  捡田螺的小男孩  微信公众号  Tags:IO模型   点击:(10)  评论:(0)  加入收藏
为什么都说 HashMap 是线程不安全的?
做Java开发的人,应该都用过 HashMap 这种集合。今天就和大家来聊聊,为什么 HashMap 是线程不安全的。1.HashMap 数据结构简单来说,HashMap 基于哈希表实现。它使用键的哈希码来...【详细内容】
2024-03-22  Java技术指北  微信公众号  Tags:HashMap   点击:(12)  评论:(0)  加入收藏
如何从头开始编写LoRA代码,这有一份教程
选自 lightning.ai作者:Sebastian Raschka机器之心编译编辑:陈萍作者表示:在各种有效的 LLM 微调方法中,LoRA 仍然是他的首选。LoRA(Low-Rank Adaptation)作为一种用于微调 LLM(大...【详细内容】
2024-03-21  机器之心Pro    Tags:LoRA   点击:(13)  评论:(0)  加入收藏
这样搭建日志中心,传统的ELK就扔了吧!
最近客户有个新需求,就是想查看网站的访问情况。由于网站没有做google的统计和百度的统计,所以访问情况,只能通过日志查看,通过脚本的形式给客户导出也不太实际,给客户写个简单的...【详细内容】
2024-03-20  dbaplus社群    Tags:日志   点击:(6)  评论:(0)  加入收藏
站内最新
站内热门
站内头条