您当前的位置:首页 > 电脑百科 > 程序开发 > 语言 > Python

PyTorch团队重写「分割一切」模型,比原始实现快八倍

时间:2023-11-23 12:28:01  来源:机器之心  作者:

编辑:陈萍

我们该如何优化 Meta 的「分割一切」模型,PyTorch 团队撰写的这篇博客由浅入深的帮你解答。

从年初到现在,生成式 AI 发展迅猛。但很多时候,我们又不得不面临一个难题:如何加快生成式 AI 的训练、推理等,尤其是在使用 PyTorch 的情况下。

本文 PyTorch 团队的研究者为我们提供了一个解决方案。文章重点介绍了如何使用纯原生 PyTorch 加速生成式 AI 模型,此外,文章还介绍了 PyTorch 新功能,以及如何组合这些功能的实际示例。

结果如何呢?PyTorch 团队表示,他们重写了 Meta 的「分割一切」 (SAM) 模型,从而使代码比原始实现快 8 倍,并且没有损失准确率,所有这些都是使用原生 PyTorch 进行优化的。

PyTorch团队重写「分割一切」模型,比原始实现快八倍

博客地址:https://pytorch.org/blog/accelerating-generative-ai/

看完本文,你将了解到:

  • Torch.compile:PyTorch 模型编译器, PyTorch 2.0 加入了一个新的函数,叫做 torch.compile (),能够通过一行代码对已有的模型进行加速;
  • GPU 量化:通过降低运算精度来加速模型;
  • SDPA(Scaled Dot Product Attention ):内存高效的注意力实现方式;
  • 半结构化 (2:4) 稀疏性:一种针对 GPU 优化的稀疏内存格式;
  • Nested Tensor:Nested Tensor 把 {tensor, mask} 打包在一起,将非均匀大小的数据批处理到单个张量中,例如不同大小的图像;
  • Triton 自定义操作:使用 Triton Python/ target=_blank class=infotextkey>Python DSL 编写 GPU 操作,并通过自定义操作符注册轻松将其集成到 PyTorch 的各种组件中。

PyTorch团队重写「分割一切」模型,比原始实现快八倍

PyTorch 原生特性所带来的吞吐量增加以及减少的内存开销。

SAM 由 Meta 提出,关于这项研究的更多内容请参考「CV 不存在了?Meta 发布「分割一切」AI 模型,CV 或迎来 GPT-3 时刻」。

PyTorch团队重写「分割一切」模型,比原始实现快八倍

接下来,文章介绍了 SAM 优化过程,包括性能分析、瓶颈识别,以及如何将这些新功能整合进 PyTorch 以解决 SAM 面临的这些问题。除此以外,本文还介绍了 PyTorch 的一些新特性:torch.compile、SDPA、Triton kernels、Nested Tensor 以及 semi-structured sparsity(半结构化稀疏)。

本文内容逐层深入,文章的最后会介绍快速版 SAM,感兴趣的小伙伴可以去 Github 上下载,此外,本文还通过 Perfetto UI 对这些数据进行了可视化,以此来阐释 PyTorch 每项特性的应用价值。

GitHub 地址:https://github.com/pytorch-labs/segment-anything-fast

对分割一切模型 SAM 的重写

该研究表示,本文利用的 SAM 基线数据类型为 float32 dtype、batch 大小为 1,使用 PyTorch Profiler 查看内核跟踪的结果如下:

PyTorch团队重写「分割一切」模型,比原始实现快八倍

本文发现 SAM 有两个地方可以优化:

第一个是对 aten::index 的长调用,这是由张量索引操作(例如 [])产生的底层调用导致的。然而实际上 GPU 花费在 aten::index 上的时间相对较低,原因在于 aten::index 在启动两个内核的过程中,两者之间发生了阻塞 cudaStreamSynchronize。这意味着 CPU 会等待 GPU 完成处理,直到启动第二个内核。因而为了优化 SAM,本文认为应该致力于消除导致空闲时间的阻塞 GPU 同步。

第二个是 SAM 在矩阵乘法中花费了大量的 GPU 时间(上图中的深绿色),这在 Transformers 中很常见。如果能够减少 SAM 模型在矩阵乘法上花费的 GPU 时间,我们就可以显着加快 SAM 的速度。

接下来本文用 SAM 的吞吐量 (img/s) 和内存开销 (GiB) 来建立基线。之后就是优化过程了。

PyTorch团队重写「分割一切」模型,比原始实现快八倍

Bfloat16 半精度(加上 GPU 同步和批处理)

为了解决上述问题,即让矩阵乘法花费的时间更少,本文转向 bfloat16。Bfloat16 是常用的半精度类型,通过降低每个参数和激活的精度,能够节省大量的计算时间和内存。

PyTorch团队重写「分割一切」模型,比原始实现快八倍

用 bfloat16 替换 padding 类型

此外,为了移除 GPU 同步,本文发现有两个位置可以优化。

PyTorch团队重写「分割一切」模型,比原始实现快八倍

PyTorch团队重写「分割一切」模型,比原始实现快八倍

具体来说(参考上图更容易理解,出现的变量名都在代码中),该研究发现在 SAM 的图像编码器中,有充当坐标缩放器(coordinate scalers)的变量 q_coords 和 k_coords,这些变量都是在 CPU 上分配和处理的。然而,一旦这些变量被用来在 rel_pos_resized 中建立索引,这些索引操作就会自动的将这些变量移动到 GPU 上,这种复制会导致 GPU 同步。为了解决上述问题,该研究注意到可以使用 torch.where 重写这部分内容来解决问题,如上所示。

内核跟踪

在应用了这些更改之后,本文注意到单个内核调用之间有着显著的时间间隔,尤其在小批量(这里为 1)时更为突出。为了更深入的了解这一现象,本文开始对批大小为 8 的 SAM 推理进行性能分析:

PyTorch团队重写「分割一切」模型,比原始实现快八倍

在查看每个内核所花费的时间时,本文观察到 SAM 的大部分 GPU 时间都花费在逐元素内核(elementwise kernels)和 softmax 操作上。

现在可以看到矩阵乘法的相对开销小了很多。

PyTorch团队重写「分割一切」模型,比原始实现快八倍

将 GPU 同步和 bfloat16 优化结合在一起,SAM 性能提高了 3 倍。

PyTorch团队重写「分割一切」模型,比原始实现快八倍

Torch.compile(+graph breaks 和 CUDA graphs)

本文发现在深入研究 SAM 的过程中有很多小的操作,他们认为使用编译器来融合操作有很大的好处,因而 PyTorch 对 torch.compile 做了以下优化:

  • 将 nn.LayerNorm 或 nn.GELU 等操作序列融合成一个单一的 GPU 内核;
  • 融合紧跟在矩阵乘法内核之后的操作,以减少 GPU 内核调用的数量。

通过这些优化,该研究减少了 GPU 全局内存往返次数(roundtrips),从而加快了推理速度。我们现在可以在 SAM 的图像编码器上尝试 torch.compile。为了最大限度地提高性能,本文使用了一些高级编译技术:

PyTorch团队重写「分割一切」模型,比原始实现快八倍

内核跟踪

PyTorch团队重写「分割一切」模型,比原始实现快八倍

结果显示,torch.compile 工作得很好。

PyTorch团队重写「分割一切」模型,比原始实现快八倍

可以观察到 softmax 占了很大一部分时间,然后是各种 GEMM 变体。以下测量的是批大小为 8 及以上的变化。

PyTorch团队重写「分割一切」模型,比原始实现快八倍

SDPA: scaled_dot_product_attention

接下来,本文又对 SDPA(scaled_dot_product_attention)进行了实验,研究的重点是注意力机制。一般来讲,原生注意力机制在时间和内存上随序列长度呈二次方扩展。PyTorch 的 SDPA 操作基于 Flash Attention、FlashAttentionV2 和 xFormer 的内存高效注意力原理构建,可以显着加快 GPU 注意力。与 torch.compile 相结合,这个操作允许在 MultiheadAttention 的变体中表达和融合一个共同的模式。经过一小部分更改后,现在模型可以使用 scaled_dot_product_attention。

PyTorch团队重写「分割一切」模型,比原始实现快八倍

内核跟踪

现在可以看到内存高效的注意力内核占用了 GPU 上大量的计算时间:

PyTorch团队重写「分割一切」模型,比原始实现快八倍

使用 PyTorch 的原生 scaled_dot_product_attention,可以显著增加批处理大小。下图为批大小为 32 及以上的变化。

PyTorch团队重写「分割一切」模型,比原始实现快八倍

之后,该研究又实验了 Triton,NestedTensor 、批处理 Predict_torch, int8 量化,半结构化 (2:4) 稀疏性等操作。

例如本文使用自定义 positional Triton 内核,观察到批大小为 32 的测量结果。

PyTorch团队重写「分割一切」模型,比原始实现快八倍

使用 Nested Tensor,批大小为 32 及以上的变化。

PyTorch团队重写「分割一切」模型,比原始实现快八倍

添加量化后,批大小为 32 及以上变化的测量结果。

PyTorch团队重写「分割一切」模型,比原始实现快八倍

文章的最后是半结构化稀疏性。该研究表示,矩阵乘法仍然是需要面对的一个瓶颈。解决的办法是使用稀疏化来近似矩阵乘法。通过稀疏矩阵(即将值归零)可以使用更少的位来存储权重和激活张量。该研究将张量中哪些权重设置为零的过程称为剪枝。剪枝掉较小的权重可以潜在地减小模型大小,而不会显着损失准确率。

剪枝的方法多种多样,从完全非结构化到高度结构化。虽然非结构化剪枝理论上对精度的影响最小,但 GPU 在进行大型密集矩阵乘法方面尽管非常高效,然而在稀疏情况下可能还会遭受显着的性能下降。PyTorch 最近支持的一种剪枝方法旨在寻求平衡,称为半结构化(或 2:4)稀疏性。这种稀疏存储将原始张量减少了 50%,同时产生密集张量输出。参见下图的说明。

PyTorch团队重写「分割一切」模型,比原始实现快八倍

为了使用这种稀疏存储格式和相关的快速内核,接下来要做的是剪枝权重。本文在 2:4 的稀疏度下选择最小的两个权重进行剪枝,将权重从默认的 PyTorch(“strided”)布局更改为这种新的半结构化稀疏布局很容易。要实现 Apply_sparse (model),只需要 32 行 Python 代码:

PyTorch团队重写「分割一切」模型,比原始实现快八倍

在 2:4 的稀疏度下,本文观察到 vit_b 和批大小为 32 时的 SAM 峰值性能:

PyTorch团队重写「分割一切」模型,比原始实现快八倍

最后,一句话总结这篇文章:本文介绍了迄今为止在 PyTorch 上最快的 Segment Anything 实现方式,借助官方发布的一系列新功能,本文在纯 PyTorch 中重写了原始 SAM,并且没有损失准确率。

感兴趣的读者可以查看原博客了解更多内容。

参考链接:https://pytorch.org/blog/accelerating-generative-ai



Tags:PyTorch   点击:()  评论:()
声明:本站部分内容及图片来自互联网,转载是出于传递更多信息之目的,内容观点仅代表作者本人,不构成投资建议。投资者据此操作,风险自担。如有任何标注错误或版权侵犯请与我们联系,我们将及时更正、删除。
▌相关推荐
突破Pytorch核心点,优化器 !!
今儿咱们聊聊pytorch中的优化器。优化器在深度学习中的选择直接影响模型的训练效果和速度。不同的优化器适用于不同的问题,其性能的差异可能导致模型更快、更稳定地收敛,或者...【详细内容】
2024-01-05  Search: PyTorch  点击:(90)  评论:(0)  加入收藏
突破Pytorch核心点,CNN !!!
创建卷积神经网络(CNN),很多初学者不太熟悉,今儿咱们来大概说说,给一个完整的案例进行说明。CNN 用于图像分类、目标检测、图像生成等任务。它的关键思想是通过卷积层和池化层来...【详细内容】
2024-01-03  Search: PyTorch  点击:(86)  评论:(0)  加入收藏
PyTorch团队重写「分割一切」模型,比原始实现快八倍
编辑:陈萍我们该如何优化 Meta 的「分割一切」模型,PyTorch 团队撰写的这篇博客由浅入深的帮你解答。从年初到现在,生成式 AI 发展迅猛。但很多时候,我们又不得不面临一个难题:如...【详细内容】
2023-11-23  Search: PyTorch  点击:(250)  评论:(0)  加入收藏
基于Pytorch的从零开始的目标检测
引言目标检测是计算机视觉中一个非常流行的任务,在这个任务中,给定一个图像,你预测图像中物体的包围盒(通常是矩形的) ,并且识别物体的类型。在这个图像中可能有多个对象,而且现...【详细内容】
2023-11-10  Search: PyTorch  点击:(201)  评论:(0)  加入收藏
深度学习中实现PyTorch和NumPy之间的数据转换知多少?
在深度学习中,PyTorch和NumPy是两个常用的工具,用于处理和转换数据。PyTorch是一个基于Python的科学计算库,用于构建神经网络和深度学习模型。NumPy是一个用于科学计算的Python...【详细内容】
2023-10-13  Search: PyTorch  点击:(67)  评论:(0)  加入收藏
Star量近8万,大火AutoGPT星标超PyTorch,网友:看清它的局限性
机器之心编辑部英伟达 AI 科学家 Jim Fan 表示,「AutoGPT 只是一项有趣的实验,虽然火爆但并不意味着可以投入生产。」他的观点得到了很多人的附和和现身说法。仿佛一夜之间,AI...【详细内容】
2023-04-18  Search: PyTorch  点击:(171)  评论:(0)  加入收藏
PyTorch将塑造生成式人工智能系统(GPT-4及以上)的未来
PyTorch不仅用于研究,还用于生产目的,每天有数十亿个请求得到服务和训练。...【详细内容】
2023-04-13  Search: PyTorch  点击:(171)  评论:(0)  加入收藏
微信基于 PyTorch 的大规模推荐系统训练实践
本文将介绍微信基于 PyTorch 进行的大规模推荐系统训练。推荐系统和其它一些深度学习领域不同,仍在使用 Tensorflow 作为训练框架,被广大开发者诟病。虽然也有使用 PyTorch 进...【详细内容】
2023-04-04  Search: PyTorch  点击:(236)  评论:(0)  加入收藏
PyTorch张量的四种乘法运算
在PyTorch中有四种类型的乘法运算(位置乘法、点积、矩阵与向量乘法、矩阵乘法),非常容易搞混,我们一起来看看这四种乘法运算的区别。位置乘法先构建两个张量a,b他们都是4行5列。a...【详细内容】
2023-03-21  Search: PyTorch  点击:(249)  评论:(0)  加入收藏
PyTorch 并行训练 DistributedDataParallel 完整代码示例
使用大型数据集训练大型深度神经网络 (DNN) 的问题是深度学习领域的主要挑战。 随着 DNN 和数据集规模的增加,训练这些模型的计算和内存需求也会增加。 这使得在计算资源有限...【详细内容】
2023-02-19  Search: PyTorch  点击:(275)  评论:(0)  加入收藏
▌简易百科推荐
Python 可视化:Plotly 库使用基础
当使用 Plotly 进行数据可视化时,我们可以通过以下示例展示多种绘图方法,每个示例都会有详细的注释和说明。1.创建折线图import plotly.graph_objects as go# 示例1: 创建简单...【详细内容】
2024-04-01  Python技术    Tags:Python   点击:(8)  评论:(0)  加入收藏
Python 办公神器:教你使用 Python 批量制作 PPT
介绍本文将介绍如何使用openpyxl和pptx库来批量制作PPT奖状。本文假设你已经安装了python和这两个库。本文的场景是:一名基层人员,要给一次比赛活动获奖的500名选手制作奖状,并...【详细内容】
2024-03-26  Python技术  微信公众号  Tags:Python   点击:(15)  评论:(0)  加入收藏
Python实现工厂模式、抽象工厂,单例模式
工厂模式是一种常见的设计模式,它可以帮助我们创建对象的过程更加灵活和可扩展。在Python中,我们可以使用函数和类来实现工厂模式。一、Python中实现工厂模式工厂模式是一种常...【详细内容】
2024-03-07  Python都知道  微信公众号  Tags:Python   点击:(31)  评论:(0)  加入收藏
不可不学的Python技巧:字典推导式使用全攻略
Python的字典推导式是一种优雅而强大的工具,用于创建字典(dict)。这种方法不仅代码更加简洁,而且执行效率高。无论你是Python新手还是有经验的开发者,掌握字典推导式都将是你技能...【详细内容】
2024-02-22  子午Python  微信公众号  Tags:Python技巧   点击:(32)  评论:(0)  加入收藏
如何进行Python代码的代码重构和优化?
Python是一种高级编程语言,它具有简洁、易于理解和易于维护的特点。然而,代码重构和优化对于保持代码质量和性能至关重要。什么是代码重构?代码重构是指在不改变代码外部行为的...【详细内容】
2024-02-22  编程技术汇    Tags:Python代码   点击:(32)  评论:(0)  加入收藏
Python开发者必备的八个PyCharm插件
在编写代码的过程中,括号几乎无处不在,以至于有时我们会拼命辨别哪个闭合括号与哪个开头的括号相匹配。这款插件能帮助解决这个众所周知的问题。前言在PyCharm中浏览插件列表...【详细内容】
2024-01-26  Python学研大本营  微信公众号  Tags:PyCharm插件   点击:(84)  评论:(0)  加入收藏
Python的Graphlib库,再也不用手敲图结构了
Python中的graphlib库是一个功能强大且易于使用的工具。graphlib提供了许多功能,可以帮助您创建、操作和分析图形对象。本文将介绍graphlib库的主要用法,并提供一些示例代码和...【详细内容】
2024-01-26  科学随想录  微信公众号  Tags:Graphlib库   点击:(86)  评论:(0)  加入收藏
Python分布式爬虫打造搜索引擎
简单分布式爬虫结构主从模式是指由一台主机作为控制节点负责所有运行网络爬虫的主机进行管理,爬虫只需要从控制节点那里接收任务,并把新生成任务提交给控制节点就可以了,在这个...【详细内容】
2024-01-25  大雷家吃饭    Tags:Python   点击:(58)  评论:(0)  加入收藏
使用Python进行数据分析,需要哪些步骤?
Python是一门动态的、面向对象的脚本语言,同时也是一门简约,通俗易懂的编程语言。Python入门简单,代码可读性强,一段好的Python代码,阅读起来像是在读一篇外语文章。Python这种特...【详细内容】
2024-01-15  程序员不二    Tags:Python   点击:(161)  评论:(0)  加入收藏
Python语言的特点及应用场景, 同其它语言对比优势
Python语言作为一种高级编程语言,具有许多独特的特点和优势,这使得它在众多编程语言中脱颖而出。在本文中,我们将探讨Python语言的特点、应用场景以及与其他语言的对比优势。一...【详细内容】
2024-01-09    今日头条  Tags:Python语言   点击:(251)  评论:(0)  加入收藏
站内最新
站内热门
站内头条