您当前的位置:首页 > 电脑百科 > 人工智能

用 Java 训练深度学习模型,原来能这么简单

时间:2020-11-03 10:10:21  来源:  作者:

作者:DJL-Keerthan&Lanking

 

HelloGitHub 推出的《讲解开源项目》 系列。这一期是由亚马逊工程师:Keerthan Vasist,为我们讲解 DJL(完全由 JAVA 构建的深度学习平台)系列的第 4 篇。

一、前言

很长时间以来,Java 都是一个很受企业欢迎的编程语言。得益于丰富的生态以及完善维护的包和框架,Java 拥有着庞大的开发者社区。尽管深度学习应用的不断演进和落地,提供给 Java 开发者的框架和库却十分短缺。现今主要流行的深度学习模型都是用 Python 编译和训练的。对于 Java 开发者而言,如果要进军深度学习界,就需要重新学习并接受一门新的编程语言同时还要学习深度学习的复杂知识。这使得大部分 Java 开发者学习和转型深度学习开发变得困难重重。

 

为了减少 Java 开发者学习深度学习的成本,AWS 构建了 Deep Java Library (DJL),一个为 Java 开发者定制的开源深度学习框架。它为 Java 开发者对接主流深度学习框架提供了一个桥梁。

用 Java 训练深度学习模型,原来能这么简单

 

在这篇文章中,我们会尝试用 DJL 构建一个深度学习模型并用它训练 MNIST 手写数字识别任务。

二、什么是深度学习?

在我们正式开始之前,我们先来了解一下机器学习和深度学习的基本概念。

 

机器学习是一个通过利用统计学知识,将数据输入到计算机中进行训练并完成特定目标任务的过程。这种归纳学习的方法可以让计算机学习一些特征并进行一系列复杂的任务,比如识别照片中的物体。由于需要写复杂的逻辑以及测量标准,这些任务在传统计算科学领域中很难实现。

 

深度学习是机器学习的一个分支,主要侧重于对于人工神经网络的开发。人工神经网络是通过研究人脑如何学习和实现目标的过程中归纳而得出一套计算逻辑。它通过模拟部分人脑神经间信息传递的过程,从而实现各类复杂的任务。深度学习中的“深度”来源于我们会在人工神经网络中编织构建出许多层(layer)从而进一步对数据信息进行更深层的传导。深度学习技术应用范围十分广泛,现在被用来做目标检测、动作识别、机器翻译、语意分析等各类现实应用中。

三、训练 MNIST 手写数字识别

3.1 项目配置

你可以用如下的 gradle 配置来引入依赖项。在这个案例中,我们用 DJL 的 api 包 (核心 DJL 组件) 和 basicdataset 包 (DJL 数据集) 来构建神经网络和数据集。这个案例中我们使用了 MXNet 作为深度学习引擎,所以我们会引入 mxnet-engine 和 mxnet-native-auto 两个包。这个案例也可以运行在 PyTorch 引擎下,只需要替换成对应的软件包即可。

plugins {
    id 'java'
}
repositories {                           
    jcenter()
}
dependencies {
    implementation platform("ai.djl:bom:0.8.0")
    implementation "ai.djl:api"
    implementation "ai.djl:basicdataset"
    // MXNet
    runtimeOnly "ai.djl.mxnet:mxnet-engine"
    runtimeOnly "ai.djl.mxnet:mxnet-native-auto"
}

3.2 NDArray 和 NDManager

NDArray 是 DJL 存储数据结构和数学运算的基本结构。一个 NDArray 表达了一个定长的多维数组。NDArray 的使用方法类似于 Python 中的 numpy.ndarray。

 

NDManager 是 NDArray 的老板。它负责管理 NDArray 的产生和回收过程,这样可以帮助我们更好的对 Java 内存进行优化。每一个 NDArray 都会是由一个 NDManager 创造出来,同时它们会在 NDManager 关闭时一同关闭。

 

NDManager 和 NDArray 都是由 Java 的 AutoClosable 构建,这样可以确保在运行结束时及时进行回收。想了解更多关于它们的用法和实践,请参阅我们前一期文章:DJL 之 Java 玩转多维数组,就像 NumPy 一样

 

Model

在 DJL 中,训练和推理都是从 Model class 开始构建的。我们在这里主要讲训练过程中的构建方法。下面我们为 Model 创建一个新的目标。因为 Model 也是继承了 AutoClosable 结构体,我们会用一个 try block 实现:

try (Model model = Model.newInstance()) {
    ...
    // 主体训练代码
    ...
}

准备数据

MNIST(Modified National Institute of Standards and Technology)数据库包含大量手写数字的图,通常被用来训练图像处理系统。DJL 已经将 MNIST 的数据集收录到了 basicdataset 数据集里,每个 MNIST 的图的大小是 28 x 28。如果你有自己的数据集,你也可以通过 DJL 数据集导入教程来导入数据集到你的训练任务中。

int batchSize = 32; // 批大小
Mnist trainingDataset = Mnist.builder()
        .optUsage(Usage.TRAIN) // 训练集
        .setSampling(batchSize, true)
        .build();
Mnist validationDataset = Mnist.builder()
        .optUsage(Usage.TEST) // 验证集
        .setSampling(batchSize, true)
        .build();

这段代码分别制作出了训练和验证集。同时我们也随机排列了数据集从而更好的训练。除了这些配置以外,你也可以添加对于图片的进一步处理,比如设置图片大小,对图片进行归一化等处理。

制作 model(建立 Block)

当你的数据集准备就绪后,我们就可以构建神经网络了。在 DJL 中,神经网络是由 Block(代码块)构成的。一个 Block 是一个具备多种神经网络特性的结构。它们可以代表 一个操作, 神经网络的一部分,甚至是一个完整的神经网络。然后 Block 可以顺序执行或者并行。同时 Block 本身也可以带参数和子 Block。这种嵌套结构可以帮助我们构造一个复杂但又不失维护性的神经网络。在训练过程中,每个 Block 中附带的参数会被实时更新,同时也包括它们的各个子 Block。这种递归更新的过程可以确保整个神经网络得到充分训练。

 

当我们构建这些 Block 的过程中,最简单的方式就是将它们一个一个的嵌套起来。直接使用准备好 DJL 的 Block 种类,我们就可以快速制作出各类神经网络。

 

根据几种基本的神经网络工作模式,我们提供了几种 Block 的变体。SequentialBlock 是为了应对顺序执行每一个子 Block 构造而成的。它会将前一个子 Block 的输出作为下一个 Block 的输入 继续执行到底。与之对应的,是 ParallelBlock 它用于将一个输入并行输入到每一个子 Block 中,同时将输出结果根据特定的合并方程合并起来。最后我们说一下 LambdaBlock,它是帮助用户进行快速操作的一个 Block,其中并不具备任何参数,所以也没有任何部分在训练过程中更新。

用 Java 训练深度学习模型,原来能这么简单

 

我们来尝试创建一个基本的 多层感知机(MLP)神经网络吧。多层感知机是一个简单的前向型神经网络,它只包含了几个全连接层 (LinearBlock)。那么构建这个网络,我们可以直接使用 SequentialBlock。

int input = 28 * 28; // 输入层大小
int output = 10; // 输出层大小
int[] hidden = new int[] {128, 64}; // 隐藏层大小
SequentialBlock sequentialBlock = new SequentialBlock();
sequentialBlock.add(Blocks.batchFlattenBlock(input));
for (int hiddenSize : hidden) {
    // 全连接层
    sequentialBlock.add(Linear.builder().setUnits(hiddenSize).build());
    // 激活函数
    sequentialBlock.add(activation);
}
sequentialBlock.add(Linear.builder().setUnits(output).build());

当然 DJL 也提供了直接就可以拿来用的 MLP Block :

Block block = new Mlp(
        Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH,
        Mnist.NUM_CLASSES,
        new int[] {128, 64});

训练

当我们准备好数据集和神经网络之后,就可以开始训练模型了。在深度学习中,一般会由下面几步来完成一个训练过程:

用 Java 训练深度学习模型,原来能这么简单

 

  • 初始化:我们会对每一个 Block 的参数进行初始化,初始化每个参数的函数都是由 设定的 Initializer 决定的。
  • 前向传播:这一步将输入数据在神经网络中逐层传递,然后产生输出数据。
  • 计算损失:我们会根据特定的损失函数 Loss 来计算输出和标记结果的偏差。
  • 反向传播:在这一步中,你可以利用损失反向求导算出每一个参数的梯度。
  • 更新权重:我们会根据选择的优化器(Optimizer)更新每一个在 Block 上参数的值。

DJL 利用了 Trainer 结构体精简了整个过程。开发者只需要创建 Trainer 并指定对应的 Initializer、Loss 和 Optimizer 即可。这些参数都是由 TrainingConfig 设定的。下面我们来看一下具体的参数设置:

  • TrainingListener:这个是对训练过程设定的监听器。它可以实时反馈每个阶段的训练结果。这些结果可以用于记录训练过程或者帮助 debug 神经网络训练过程中的问题。用户也可以定制自己的 TrainingListener 来对训练过程进行监听。
DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
    .addEvaluator(new Accuracy())
    .addTrainingListeners(TrainingListener.Defaults.logging());
try (Trainer trainer = model.newTrainer(config)){
    // 训练代码
}

当训练器产生后,我们可以定义输入的 Shape。之后就可以调用 fit 函数来进行训练。fit 函数会对输入数据,训练多个 epoch 是并最终将结果存储在本地目录下。

/*
 * MNIST 包含 28x28 灰度图片并导入成 28 * 28 NDArray。
 * 第一个维度是批大小, 在这里我们设置批大小为 1 用于初始化。
 */
Shape inputShape = new Shape(1, Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH);
int numEpoch = 5;
String outputDir = "/build/model";

// 用输入初始化 trainer
trainer.initialize(inputShape);

TrainingUtils.fit(trainer, numEpoch, trainingSet, validateSet, outputDir, "mlp");

这就是训练过程的全部流程了!用 DJL 训练是不是还是很轻松的?之后看一下输出每一步的训练结果。如果你用了我们默认的监听器,那么输出是类似于下图:

[INFO ] - Downloading libmxnet.dylib ...
[INFO ] - Training on: cpu().
[INFO ] - Load MXNet Engine Version 1.7.0 in 0.131 ms.
Training:    100% |████████████████████████████████████████| Accuracy: 0.93, SoftmaxCrossEntropyLoss: 0.24, speed: 1235.20 items/sec
Validating:  100% |████████████████████████████████████████|
[INFO ] - Epoch 1 finished.
[INFO ] - Train: Accuracy: 0.93, SoftmaxCrossEntropyLoss: 0.24
[INFO ] - Validate: Accuracy: 0.95, SoftmaxCrossEntropyLoss: 0.14
Training:    100% |████████████████████████████████████████| Accuracy: 0.97, SoftmaxCrossEntropyLoss: 0.10, speed: 2851.06 items/sec
Validating:  100% |████████████████████████████████████████|
[INFO ] - Epoch 2 finished.NG [1m 41s]
[INFO ] - Train: Accuracy: 0.97, SoftmaxCrossEntropyLoss: 0.10
[INFO ] - Validate: Accuracy: 0.97, SoftmaxCrossEntropyLoss: 0.09
[INFO ] - train P50: 12.756 ms, P90: 21.044 ms
[INFO ] - forward P50: 0.375 ms, P90: 0.607 ms
[INFO ] - training-metrics P50: 0.021 ms, P90: 0.034 ms
[INFO ] - backward P50: 0.608 ms, P90: 0.973 ms
[INFO ] - step P50: 0.543 ms, P90: 0.869 ms
[INFO ] - epoch P50: 35.989 s, P90: 35.989 s

当训练结果完成后,我们可以用刚才的模型进行推理来识别手写数字。

四、最后

在这个文章中,我们介绍了深度学习的基本概念,同时还有如何优雅的利用 DJL 构建深度学习模型并进行训练。



Tags:深度学习   点击:()  评论:()
声明:本站部分内容及图片来自互联网,转载是出于传递更多信息之目的,内容观点仅代表作者本人,如有任何标注错误或版权侵犯请与我们联系(Email:2595517585@qq.com),我们将及时更正、删除,谢谢。
▌相关推荐
现如今,谈起人工智能我们都会想到的是“深度学习”(deep learning),无论是战胜李世石的AlphaGo,还是能够随意写出人类水平文章的GPT-3,它们的背后都依托的是这套算法。 它具有很好...【详细内容】
2021-09-14  Tags: 深度学习  点击:(104)  评论:(0)  加入收藏
0 前言大家好,欢迎来到“自由技艺”的知识小馆。今天我们来探讨下深度学习中的 Attention 机制,中文名为“注意力”。本文内容结构组织如下:1 为什么需要引入 Attention 机制?2...【详细内容】
2021-06-09  Tags: 深度学习  点击:(149)  评论:(0)  加入收藏
微信正用着的深度学习框架,现在你也可以上手试一试了。 就在最近,腾讯把这个名叫deepx_core的深度学习基础库正式对外开源。 相比于PyTorch、TensorFlow等流行深度学习框架,这位选手不仅具有通用性,还针对高维稀疏数据...【详细内容】
2021-04-06  Tags: 深度学习  点击:(205)  评论:(0)  加入收藏
为深度学习项目建立一个良好的环境不是一件容易的任务。因为需要处理的事情太多了:库必须匹配特定的版本,整个环境需要可以复制到其他机器上,所有东西都需要能够机器中的所有...【详细内容】
2021-03-05  Tags: 深度学习  点击:(96)  评论:(0)  加入收藏
深度学习是机器学习的一个子领域,它采用了一个特定的模型:一族通过某种方式连接起来的简单函数。由于这类模型的结构是受到人类大脑结构的启发而创造出来的...【详细内容】
2021-02-26  Tags: 深度学习  点击:(269)  评论:(0)  加入收藏
基于人工智能和深度学习方法的现代计算机视觉技术在过去10年里取得了显著进展。如今,它被用于图像分类、人脸识别、图像中物体的识别、视频分析和分类以及机器人和自动驾驶车辆的图像处理等应用上。...【详细内容】
2021-01-07  Tags: 深度学习  点击:(87)  评论:(0)  加入收藏
什么是深度学习深度学习有如下一些众所周知且被广泛接受的定义。(1)深度学习是机器学习的子集。(2)深度学习使用级联的多层(非线性)处理单元,称为人工神经网络(ANN),以及受大脑结构和...【详细内容】
2020-12-07  Tags: 深度学习  点击:(184)  评论:(0)  加入收藏
本文将介绍在 Windows 计算机上配置深度学习环境的全过程,其中涉及安装所需的工具和驱动软件。出人意料的是,即便只是配置深度学习环境,任务也不轻松。你很有可能在这个过程中犯错。我个人已经很多次从头开始配置深度学...【详细内容】
2020-12-02  Tags: 深度学习  点击:(97)  评论:(0)  加入收藏
一代深度学习框架研究于璠华为技术有限公司摘要:从人工智能的历史出发,简述深度学习发展历程以及目前的挑战,通过介绍新一代深度学习框架的特点,分析总体框架,阐述自动并行、自动...【详细内容】
2020-11-10  Tags: 深度学习  点击:(94)  评论:(0)  加入收藏
HelloGitHub 推出的《讲解开源项目》 系列。这一期是由亚马逊工程师:Keerthan Vasist,为我们讲解 DJL(完全由 Java 构建的深度学习平台)系列的第 4 篇。...【详细内容】
2020-11-03  Tags: 深度学习  点击:(86)  评论:(0)  加入收藏
▌简易百科推荐
作为数据科学家或机器学习从业者,将可解释性集成到机器学习模型中可以帮助决策者和其他利益相关者有更多的可见性并可以让他们理解模型输出决策的解释。在本文中,我将介绍两个...【详细内容】
2021-12-17  deephub    Tags:AI   点击:(15)  评论:(0)  加入收藏
基于算法的业务或者说AI的应用在这几年发展得很快。但是,在实际应用的场景中,我们经常会遇到一些非常奇怪的偏差现象。例如,Facebook将黑人标记为灵长类动物、城市图像识别系统...【详细内容】
2021-11-08  数据学习DataLearner    Tags:机器学习   点击:(32)  评论:(0)  加入收藏
11月2日召开的世界顶尖科学家数字未来论坛上,2013年诺贝尔化学奖得主迈克尔·莱维特、2014年诺贝尔生理学或医学奖得主爱德华·莫索尔、2007年图灵奖得主约瑟夫·斯发斯基、1986年图灵奖得主约翰·霍普克罗夫特、2002...【详细内容】
2021-11-03  张淑贤  证券时报  Tags:人工智能   点击:(39)  评论:(0)  加入收藏
鉴于物联网设备广泛部署、5G快速无线技术闪亮登场,把计算、存储和分析放在靠近数据生成的地方来处理,让边缘计算有了用武之地。 边缘计算正在改变全球数百万个设备处理和传输...【详细内容】
2021-10-26    计算机世界  Tags:边缘计算   点击:(45)  评论:(0)  加入收藏
这是几位机器学习权威专家汇总的725个机器学习术语表,非常全面了,值得收藏! 英文术语 中文翻译 0-1 Loss Function 0-1损失函数 Accept-Reject Samplin...【详细内容】
2021-10-21  Python部落    Tags:机器学习   点击:(43)  评论:(0)  加入收藏
要开始为开源项目做贡献,有一些先决条件:1. 学习一门编程语言:由于在开源贡献中你需要编写代码才能参与开发,你需要学习任意一门编程语言。根据项目的需要,在后期学习另一种语言...【详细内容】
2021-10-20  TSINGSEE青犀视频    Tags:机器学习   点击:(37)  评论:(0)  加入收藏
SimpleAI.人工智能、机器学习、深度学习还是遥不可及?来这里看看吧~ 从基本的概念、原理、公式,到用生动形象的例子去理解,到动手做实验去感知,到著名案例的学习,到用所学来实现...【详细内容】
2021-10-19  憨昊昊    Tags:神经网络   点击:(47)  评论:(0)  加入收藏
语言是人类思维的基础,当计算机具备了处理自然语言的能力,才具有真正智能的想象。自然语言处理(Natural Language Processing, NLP)作为人工智能(Artificial Intelligence, AI)的核心技术之一,是用计算机来处理、理解以及运...【详细内容】
2021-10-11    36氪  Tags:NLP   点击:(48)  评论:(0)  加入收藏
边缘计算是什么?近年来,物联网设备数量呈线性增长趋势。根据艾瑞测算, 2020年,中国物联网设备的数量达74亿,预计2025年突破150亿个。同时,设备本身也变得越来越智能化,AI与互联网在...【详细内容】
2021-09-22  汉智兴科技    Tags:   点击:(54)  评论:(0)  加入收藏
说起人工智能,大家总把它和科幻电影中的机器人联系起来,而实际上这些科幻场景与现如今的人工智能没什么太大关系。人工智能确实跟人类大脑很相似,但它们的显著差异在于人工智能...【详细内容】
2021-09-17  异步社区    Tags:人工智能   点击:(57)  评论:(0)  加入收藏
最新更新
栏目热门
栏目头条