N-Shot Learning:用最少的数据训练最多的模型

作者&投稿:枝李 (若有异议请与网页底部的电邮联系)
~

作 者 | Heet Sankesara

翻 译 | 天字一号(郑州大学)、邺调(江苏 科技 大学)

审 校 | 唐里、Pita

如果将AI比作电力的话,那么数据就是创造电力的煤。

不幸的是,正如我们看到可用煤是消耗品一样,许多 AI 应用程序可供访问的数据很少或根本就没有数据。

新技术已经弥补了物质资源的不足;同样需要新的技术来允许在数据很少时,保证程序的正常运行。这是正在成为一个非常受欢迎的领域,核心问题:N-shot Learning

1. N-Shot Learning

你可能会问,什么是shot?好问题,shot只用一个样本来训练,在N-shot学习中,我们有N个训练的样本。术语“小样本学习”中的“小”通常在0-5之间,也就是说,训练一个没有样本的模型被称为 zero-shot ,一个样本就是 one-shot 学习,以此类推。

1-1 为什么需要N-Shot?

我们在 ImageNet 中的分类错误率已经小于 4% 了,为什么我们需要这个?

首先,ImageNet 的数据集包含了许多用于机器学习的示例,但在医学影像、药物发现和许多其他 AI 可能至关重要的领域中并不总是如此。典型的深度学习架构依赖于大量数据训练才能获得足够可靠的结果。例如,ImageNet 需要对数百张热狗图像进行训练,然后才能判断一幅新图像准确判断是否为热狗。一些数据集,就像7月4日庆祝活动后的冰箱缺乏热狗一样,是非常缺乏图像的。

机器学习有许多案例数据是都非常稀缺,这就是N-Shot技术的用武之地。我们需要训练一个包含数百万甚至数十亿个参数(全部随机初始化)的深度学习模型,但可用于训练的图像不超过 5 个图像。简单地说,我们的模型必须使用非常有限的热狗图像进行训练。

要处理像这个这样复杂的问题,我们首先需要清楚N-Shot的定义。

对我来说,最有趣的子领域是Zero-shot learning,该领域的目标是不需要一张训练图像,就能够对未知类别进行分类。

没有任何数据可以利用的话怎么进行训练和学习呢?

想一下这种情况,你能对一个没有见过的物体进行分类吗?

夜空中的仙后座(图源:https:// www .star-registration .com /constellation/cassiopeia)

是的,如果你对这个物体的外表、属性和功能有充足的信息的话,你是可以实现的。想一想,当你还是一个孩子的时候,是怎么理解这个世界的。在了解了火星的颜色和晚上的位置后,你可以在夜空中找到火星。或者你可以通过了解仙后座在天空中"基本上是一个畸形的'W'"这个信息中识别仙后座。

根据今年NLP的趋势,Zero-shot learning 将变得更加有效(https://blog.floydhub .com /ten-trends-in-deep-learning-nlp/#9-zero-shot-learning-will-become-more-effective)。

计算机利用图像的元数据执行相同的任务。元数据只不过是与图像关联的功能。以下是该领域的几篇论文,这些论文取得了优异的成绩。

在one-shot learning中,我们每个类别只有一个示例。现在的任务是使用一个影像进行训练,最终完成将测试影像划分为各个类。为了实现这一目标,目前已经出现了很多不同的架构,例如Siamese Neural Networks(https:// www .cs.cmu.edu/~rsalakhu/papers/oneshot1.pdf),它带来了重大进步,并达到了卓越的结果。然后紧接着是matching networks(https://ar xi v.org/pdf/1606.04080.pdf),这也帮助我们在这一领域实现了巨大的飞跃。

小样本学习只是one-shot learning 的灵活应用。在小样本学习中,我们有多个训练示例(通常为两到五个图像,尽管上述one-shot learning中的大多数模型也可用于小样本学习)。

在2019年计算机视觉和模式识别会议上,介绍了 Meta-Transfer Learning for Few-Shot Learning(https://ar xi v.org/pdf/ 181 2.02391v3.pdf)。这一模式为今后的研究开创了先例;它给出了最先进的结果,并为更复杂的元迁移学习方法铺平了道路。

这些元学习和强化学习算法中有许多都是与典型的深度学习算法相结合,并产生了显著的结果。原型网络是最流行的深度学习算法之一,并经常用于小样本学习 。

在本文中,我们将使用原型网络完成小样本学习,并了解其工作原理。

2. 原型网络背后的思想

上图为原型网络函数的示意图。编码器将图像进行编码映射到嵌入空间(黑圈)中的矢量中,支持图像用于定义原型(星形)。利用原型和编码查询图像之间的距离进行分类。图源:https:// www .semanticscholar.org/paper/Gaussian-Prototypical-Networks-for-Few-Shot-on-Fort/feaecb5f7a8d29636650db7c0b480f55d098a6a7/figure/1

与典型的深度学习体系结构不同,原型网络不直接对图像进行分类,而是通过在度量空间(https://en.wikipedia.org/wiki/Metric_space)中寻找图像之间的映射关系。

对于任何需要复习数学的人来说,度量空间都涉及"距离"的概念。它没有一个可区分的"起源"点。相反,在度量空间中,我们只计算一个点与另一个点的距离。因此,这里缺少了矢量空间中加法和标量乘法(因为与矢量不同,点仅表示坐标,添加两个坐标或缩放坐标毫无意义!)请查看此链接,详细了解矢量空间和度量空间之间的差异:https://math.stackexchange .com /questions/1 149 40/what-is-the-difference-between-metric-spaces-and-vector-spaces。

现在,我们已经学习了这一背景,我们可以开始了解原型网络是怎样不直接对图像进行分类,而是通过在度量空间中寻找图像之间的映射关系。如上图所示,同一类的图像经过编码器的映射之后,彼此之间的距离非常接近,而不同类的图像之间具有较长的距离。这意味着,每当给出新示例时,网络只需检查与新示例的图像最近的集合,并将该示例图像分到其相应的类。原型网络中将图像映射到度量空间的基础模型可以被称为"Image2Vector"模型,这是一种基于卷积神经网络 (CNN) 的体系结构。

现在,对于那些对 CNN 不了解的人,您可以在此处阅读更多内容:

简单地说,他们的目标是训练分类器。然后,该分类器可以对在训练期间不可用的新类进行概括,并且只需要每个新类的少量示例。因此,训练集包含一组类的图像,而我们的测试集包含另一组类的图像,这与前一组完全不相关。在该模型中,示例被随机分为支持集和查询集。

很少有镜头原型ck被计算为每个类的嵌入式支持示例的平均值。编码器映射新图像(x)并将其分类到最接近的类,如上图中的c2(图源:https://ar xi v.org/pdf/ 1703 .05 175 .pdf)。

在少镜头学习的情况下,训练迭代被称为一个片段。一个小插曲不过是我们训练网络一次,计算损失并反向传播错误的一个步骤。在每一集中,我们从训练集中随机选择NC类。对于每一类,我们随机抽取ns图像。这些图像属于支持集,学习模型称为ns-shot模型。另一个随机采样的nq图像属于查询集。这里nc、ns和nq只是模型中的超参数,其中nc是每次迭代的类数,ns是每个类的支持示例数,nq是每个类的查询示例数。

之后,我们通过“image2vector”模型从支持集图像中检索d维点。该模型利用图像在度量空间中的对应点对图像进行编码。对于每个类,我们现在有多个点,但是我们需要将它们表示为每个类的一个点。因此,我们计算每个类的几何中心,即点的平均值。之后,我们还需要对查询图像进行分类。

为此,我们首先需要将查询集中的每个图像编码为一个点。然后,计算每个质心到每个查询点的距离。最后,预测每个查询图像位于最靠近它的类中。一般来说,模型就是这样工作的。

但现在的问题是,这个“image2vector”模型的架构是什么?

论文汇总 Image2Vector 向量的结构

对于所有实际应用中,一般都会使用 4-5 CNN 模块。如上图所示,每个模块由一个 CNN 层组成,然后是批处理规范化,然后是 ReLu 激活函数,最后通向最大池层。在所有模块之后,剩余的输出将被展平并返回。这是本文中使用的网络结构(https://ar xi v.org/pdf/ 1703 .05 175 v2.pdf),您可以使用任何任何你喜欢的体系结构。有必要知道,虽然我们称之为"Image2Vector"模型,但它实际上将图像转换为度量空间中的 64 维的点。要更好地了解差异,请查看 math stack exchange(https://math.stackexchange .com /questions/ 64 5672/what-is-the-difference-between-a-point-and-a-vector)。

负log概率的原理,图源:https://ljvmiranda921.github.io/notebook/2017/08/13/softmax-and-the-negative-log-likelihood/#nll

现在,已经知道了模型是如何工作的,您可能更想知道我们将如何计算损失函数。我们需要一个足够强大的损失函数,以便我们的模型能够快速高效地学习。原型网络使用log-softmax损失,这只不过是对 softmax 损失取了对数。当模型无法预测正确的类时,log-softmax 的效果会严重惩罚模型,而这正是我们需要的。要了解有关损失函数的更多情况,请访问此处。这里是关于 softmax 和 log-softmax 的很好的讨论。

Omniglot数据集中的部分示例(图源:https://github .com /brendenlake/omniglot)

该网络在 Omniglot 数据集(https://github .com /brendenlake/omniglot)上进行了训练。Omniglot 数据集是专门为开发更类似于人类学习的算法而设计。它包含 50个不同的字母表,共计1623 个不同的手写字符。为了增加类的数量,所有图像分别旋转 90、 180 和 270 度,每次旋转后的图像都当做一个新类。因此,类的总数达到 了 64 92(1,623 + 4)类别。我们将 4200 个类别的图像作为训练数据,其余部分则用于测试。对于每个集合,我们根据 64 个随机选择的类中的每个示例对模型进行了训练。我们训练了模型 1 小时,获得了约 88% 的准确率。官方文件声称,经过几个小时的训练和调整一些参数,准确率达到99.7%。

是时候亲自动手实践了!

您可以通过访问以下链接轻松运行代码:

代码地址: https://github .com /Hsankesara/Prototypical-Networks

运行地址: https://floydhub .com /run?template=https://github .com /Hsankesara/Prototypical-Networks

让我们深入学习一下代码!(向左←滑动可查看完整代码)

以上的代码是 Image2Vector CNN结构的一个实现。它的输入图像的维度为28*28*3,返回特征向量的长度为 64 。

上面的代码片段是原型网中单个结构的实现。如果你有任何疑问,只需在评论中询问或在这里创建一个问题,非常欢迎您的参与和评论。

网络概述。图源:https://youtu.be/wcKL05DomBU

代码的结构与解释算法的格式相同。我们为原型网络函数提供以下输入:输入图像数据、输入标签、每次迭代的类数(即 Nc )、每个类的支持示例数(即 Ns )和每个类的查询示例数(即 Nq )。函数返回 Queryx ,它是从每个查询点到每个平均点的距离矩阵, Queryy 是包含与 Queryx 对应的标签的向量。 Queryy 存储 Queryx 的图像实际所属的类。在上面的图像中,我们可以看到,使用3个类,即 Nc =3,并且对于每个类,总共有5个示例用于训练,即 Ns =5。上面的s表示包含这15个( Ns * Nc )图像的支持集, X 表示查询集。注意,支持集和查询集都通过 f ,它只不过是我们的“image2vector”函数。它在度量空间中映射所有图像。让我们一步一步地把整个过程分解。

首先,我们从输入数据中随机选择 Nc 类。对于每个类,我们使用random_sample_cls函数从图像中随机选择一个支持集和一个查询集。在上图中,s是支持集,x是查询集。现在我们选择了类( C1 、C2 和 C3 ),我们通过“image2vector”模型传递所有支持集示例,并使用get_centroid函数计算每个类的质心。在附近的图像中也可以观察到这一点。每个质心代表一个类,将用于对查询进行分类。

网络中的质心计算。图源:https://youtu.be/wcKL05DomBU

在计算每个类的质心之后,我们现在必须预测其中一个类的查询图像。为此,我们需要与每个查询对应的实际标签,这些标签是使用get_query_y函数获得的。 Queryy 是分类数据,该函数将该分类文本数据转换为一个热向量,该热向量在列点对应的图像实际所属的行标签中仅为“1”,在列中为“0”。

之后,我们需要对应于每个 Queryx 图像的点来对其进行分类。我们使用“image2vector”模型得到这些点,现在我们需要对它们进行分类。为此,我们计算 Queryx 中每个点到每个类中心的距离。这给出了一个矩阵,其中索引 ij 表示与第 i 个查询图像对应的点到第 j 类中心的距离。我们使用get_query_x函数构造矩阵并将矩阵保存在 Queryx 变量中。在附近的图像中也可以看到同样的情况。对于查询集中的每个示例,将计算它与 C1、C2 和 C3 之间的距离。在这种情况下, X 最接近 C2 ,因此我们可以说 X 被预测属于 C2 类。

以编程方式,我们可以使用一个简单的ARMmin函数来做同样的事情,即找出图像被预测的类。然后使用预测类和实际类计算损失并反向传播错误。

如果你想使用经过训练的模型,或者只需要重新训练自己,这里是我的实现。您可以使用它作为API,并使用几行代码来训练模型。你可以在这里找到这个网络。

3. 资源列表

这里有些资源可以帮你更全面的了解本文内容:

4. 局限性

尽管原型网络的结果不错,但它们仍然有局限性。首先是缺乏泛化,它在Omniglot数据集上表现很好,因为其中的所有图像都是一个字符的图像,因此共享一些相似的特征。然而,如果我们试图用这个模型来分类不同品种的猫,它不会给我们准确的结果。猫和字符图像几乎没有共同的特征,可以用来将图像映射到相应度量空间的共同特征的数量可以忽略不计。

原型网络的另一个限制是只使用均值来确定中心,而忽略了支持集中的方差,这在图像有噪声的情况下阻碍了模型的分类能力。利用高斯原网络(https://ar xi v.org/abs/ 1708 .02 73 5)类中的方差,利用高斯公式对嵌入点进行建模,克服了这一局限性。

5. 结论

小概率学习是近年来研究的热点之一。有许多使用原型网络的新方法,比如这种元学习方法,效果很好。研究人员也在 探索 强化学习,这也有很大的潜力。这个模型最好的地方在于它简单易懂,并且能给出令人难以置信的结果。

via https://blog.floydhub .com /n-shot-learning/

本文由雷锋字幕组成员翻译,雷锋字幕组是由AI爱好者组成的字幕翻译团队;团队成员有大数据专家、算法工程师、图像处理工程师、产品经理、产品运营、IT咨询人、在校师生;志愿者们来自IBM、AVL、Adobe、阿里、百度等知名企业,北大、清华、港大、中科院、南卡罗莱纳大学、早稻田大学等海内外高校研究所。了解字幕组 请加 微信 ~




睢阳区14733017835: 如何用深度学习查找相似问题 -
艾狄奥迪: 如果用现有的深度学习去实现这一点,那就需要大量的事故数据,但这方面的数据供给非常有限,而采集数据又难度很大.首先,没有人能够准确预测何时何地会发生何种事故,因此无法系统地提前部署以采集真实事故数据;其次,从法律上来...

睢阳区14733017835: 自己学习深度学习时,有哪些途径寻找数据集 -
艾狄奥迪: 一般大型数据集都伴随着比赛,比如图像分类数据集伴随着ImageNet比赛,图像检测/图像分割等数据则是伴随着Pascal VOC和COCO等比赛,文本识别与检测伴随着ICDAR比赛,还有很多这样的.其实你想要什么样类型的数据集,百度一下基本会有.

睢阳区14733017835: 如何使用计算器开N次方根 -
艾狄奥迪: 使用计算器开N次方根的计算方法非常简单,按照下面的操作就可以完成了,具体操作步骤如下所述:1、首先打开科学计算器,如图所示.2、然后输入需要开次方的数字,如图所示. 3、接下来按一下计算器上的“^”这个按钮,如图所示. 4、然后输入需要开几次方的数字,这里以2次方为例输入“2”,如图所示. 5、最后点击一下等于符号,这样就计算出2的2次方了,如图所示.

睢阳区14733017835: 为什么bpl贝叶斯框架可以做到one - shot learning -
艾狄奥迪: 贝叶斯定理用数学的方法来解释生活中大家都知道的常识 形式最简单的定理往往是最好的定理,比如说中心极限定理,这样的定理往往会成为某一个领域的理论基础.机器学习的各种算法中使用的方法,最常见的就是贝叶斯定理. 贝叶斯定理...

睢阳区14733017835: 为什么在求样本方差是自由度为n - 1?求大神帮助 -
艾狄奥迪: 由于 则在求离差平均和时, 只有 n-1 个数据可以自由取值, 所以自由度为 n-1 . 样本方差的分母用 n-1 ,其原因可以从多方面来解释. 从实际应用的角度看,当我们用样本方差 估计总体方 差σ2 时, 是σ 2 的无偏估计量.... 一组数据中可以自由取值的数据的个数2? 当样本数据的个数为 n 时,若样本均值x 确定后,?? 只有n-1个数据可以自由取值,其中必有一个数据则不能自由取值3? 例如,样本有3个数值,即x1=2,x2=4,x3=9,希望采纳

睢阳区14733017835: c++编程题 编写函数实现求n!,主程序要求输入n的值,用函数调用的方式求n!的值 -
艾狄奥迪: 限于整型数据的数据范围,所以实际上这个程序只能求n<=12的情况.如果想求更大的,必须为这个阶乘自定义一个数据类型. #include<iostream> using namespace std; int fact(int n); int main(){int n; cout<<"求n!,请输入n:\n";cin>>n; ...

睢阳区14733017835: n个五边形要用多少根火柴棒 -
艾狄奥迪: n=1时用5根 n=2时最少用9根,可以共用一条边 n>2时用(n-2)X3+9根,每多加一个五边形,可以有两条共用边,所以只用加3根

睢阳区14733017835: 编写程序用递归法实现将一个整数n转化成字符串 -
艾狄奥迪: #include int myfun( int n, char*p ) { int i=0 if(n>10) { i = myfun(n/10, p ); } *(p+i) = n%10; i++; return i; } main() { int n; char a[20]; printf( "input n:\n"); scanf("%d",&n); int i=my( n,a); a[i] = '\0'; printf("n=%s\n",a); getch(); }

睢阳区14733017835: 编写c++程序,递归函数与非递归函数.编写一个函数,求从n个不同的数中取r个数的所有选择的个数.其个数值为:C rn=n!/( r!*(n - r)!). -
艾狄奥迪: #includelong con(int a){ int i; long add=1; for(i=1;i<=a;i++) add*=i; return add; } long fun(int n,int r){ return con(n)/(con(r)*con(n-r)); } void main(){ int n,r; cout<<"请输入数字的个数n与抽取个数r:"<>n>>r;//输入时用空格隔开 cout<<"个不同的数中取r个数的所有选择的个数.其个数值为:"; cout<<

睢阳区14733017835: 1. 有n个学生,每个学生有3门课程的成绩,n不小于3.要求: (1) 编写函数计算每个学生的平均成绩; ... -
艾狄奥迪: #include<stdio.h> #include<alloc.h> #include<string.h> struct student_inf{char name[11];float math;float english;float computer;float average; }stu; void scmsinput(struct student_inf *psf,int n) {int i;for(i=0;i<n;i++){printf("\n输入第%d个学生...

本站内容来自于网友发表,不代表本站立场,仅表示其个人看法,不对其真实性、正确性、有效性作任何的担保
相关事宜请发邮件给我们
© 星空见康网