秒秒钟揪出张量形状错误,这个工具能防止 ML 模型训练白忙一场

观点
2021
12/28
00:37
亚设网
分享

模型吭哧吭哧训练了半天,结果发现张量形状定义错了,这一定没少让你抓狂吧。那么针对这种情况,是否存在较好的解决方法呢?

这不最近,韩国首尔大学的研究者就开发出了一款“利器”—— PyTea。

据研究人员介绍,它在训练模型前,能几秒内帮助你静态分析潜在的张量形状错误。

那么 PyTea 是如何做到的,到底靠不靠谱,让我们一探究竟吧。

PyTea 的出场方式

为什么张量形状错误这么重要?

神经网络涉及到一系列的矩阵计算,前面矩阵的列数必需匹配后面矩阵的行数,如果维度不匹配,那后面的运算就都无法运行了。

秒秒钟揪出张量形状错误,这个工具能防止 ML 模型训练白忙一场

上图代码就是一个典型的张量形状错误,[B x 120] * [80 x 10] 无法进行矩阵运算。

秒秒钟揪出张量形状错误,这个工具能防止 ML 模型训练白忙一场

无论是 PyTorch,TensorFlow 还是 Keras 在进行神经网络的训练时,大多都遵循图上的流程。

首先定义一系列神经网络层(也就是矩阵),然后合成神经网络模块……

那么为什么需要 PyTea 呢?

以往我们都是在模型读取大量数据,开始训练,代码运行到错误张量处,才可以发现张量形状定义错误。

由于模型可能十分复杂,训练数据非常庞大,所以发现错误的时间成本会很高,有时候代码放在后台训练,出了问题都不知道……

PyTea 就可以有效帮我们避免这个问题,因为它能在运行模型代码之前,就帮我们分析出形状错误。

秒秒钟揪出张量形状错误,这个工具能防止 ML 模型训练白忙一场

网友们已经在热烈讨论了。

PyTea 是如何运作的,它能否有效地检查出错误呢?

秒秒钟揪出张量形状错误,这个工具能防止 ML 模型训练白忙一场

受各种约束条件的影响,代码可能的运行路径有很多,不同的数据会走向不同的路径。

所以 PyTea 需要静态扫描所有可能的运行路径,跟踪张量变化,推断出每个张量形状精确而保守的范围。

上图就是 PyTea 的整体架构,一共分为翻译语言,收集约束条件,求解器判断和给出反馈四步。

秒秒钟揪出张量形状错误,这个工具能防止 ML 模型训练白忙一场

首先 PyTea 将原始的 Python 代码翻译成一种内核语言。PyTea 内部表示法(PyTea IR)。

秒秒钟揪出张量形状错误,这个工具能防止 ML 模型训练白忙一场

接着 PyTea 追踪 PyTea IR 每个可能的执行路径,并收集有关张量形状的约束条件。

判断约束条件是否被满足,分为线上分析和离线分析两步:

线上分析 node.js(TypeScript / JavaScript):查找张量形状数值上的不匹配和误用 API 函数的情况。如果 PyTea 发现问题,就会停止在当前位置,然后给用户报错。

秒秒钟揪出张量形状错误,这个工具能防止 ML 模型训练白忙一场

离线分析 Z3 / Python:如果线上分析没有问题,PyTea 将收集到的约束条件传给 SMT(Satisfiability Modulo Theories)求解器 Z3,求解器负责查看每条路径的约束条件是否都能被满足,如果不能,返回给用户第一条出错路径的约束条件。

秒秒钟揪出张量形状错误,这个工具能防止 ML 模型训练白忙一场

如果求解器过久没有反应,PyTea 会返回不知道是否存在问题。

然而追踪所有可能的路径是指数级别的任务,对于复杂的神经网络来说,一定会发生路径爆炸这个问题。

秒秒钟揪出张量形状错误,这个工具能防止 ML 模型训练白忙一场

比如说在这个例子中,网络的最终结构是由 24 个相同模块块构成的(第 17 行),那么可能的路径就有 16M 之多。

所以路径爆炸是一定要处理的,PyTea 是怎么做的?

PyTea 选择保守的地对路径剪枝和超时判断来处理这种路径爆炸。

什么样的路径可以被剪枝?

PyTea 给出的答案是,如果该前馈函数不改变全局值,并且它的输出值不受分支条件影响,对于每条路径都是相等的,我们就可以忽略许多完全一致的路径,来节约计算资源。

如果路径剪枝还是不行,那么就只能按超时处理了。

原理就介绍这么多了,感觉还是值得一试的,现在代码已经在 GitHub 上面开源了,快去看看吧!

使用方法

依赖库:

秒秒钟揪出张量形状错误,这个工具能防止 ML 模型训练白忙一场

安装方法:

秒秒钟揪出张量形状错误,这个工具能防止 ML 模型训练白忙一场

运行命令:

秒秒钟揪出张量形状错误,这个工具能防止 ML 模型训练白忙一场

参考链接

[1]https://github.com/ropas/pytea

[2]https://arxiv.org/abs/2112.09037

THE END
免责声明:本文系转载,版权归原作者所有;旨在传递信息,不代表 亚设网的观点和立场。

20.jpg

关于我们

微信扫一扫,加关注

Top