随着人工智能技术的不断发展,手写数字识别已经成为深度学习领域的一个经典案例。不管是老牌的机器学习模型还是现代的神经网络架构,手写数字识别总是大家学习和实战的起点之一。而对于我们日常使用的Java开发者来说,借助DeepLearning4J这个强大的Java深度学习框架,可以很方便地在Java项目中实现手写数字识别的功能。
在本文中,我们会以一种轻松的方式,带你一步步实现一个完整的手写数字识别系统,使用Spring Boot作为后端框架,结合Thymeleaf和Bootstrap来构建用户友好的界面。
在很多人提到深度学习时,首先想到的可能是TensorFlow或PyTorch。那么,为什么我们这次要使用DeepLearning4J呢?其实,DeepLearning4J最大的优势就在于它是专门为Java和JVM语言设计的深度学习框架。对于熟悉Java生态的开发者来说,DeepLearning4J让我们可以利用现有的Java工具链和库,轻松构建和部署深度学习模型。
除此之外,DeepLearning4J支持大规模的分布式训练,甚至可以与Hadoop和Spark进行集成,方便企业级应用开发。
在本文中,我们将实现一个可以识别手写数字的Web应用。用户可以通过网页上传一张手写数字的图片,系统会自动识别并返回预测结果。
为了实现这个功能,我们会按照以下步骤进行:
我们的项目结构将包括后端服务用于处理上传的图片和执行预测逻辑,以及前端页面用于用户上传图片和查看预测结果。
├── src/
│ ├── main/
│ │ ├── java/
│ │ │ ├── com/neo/
│ │ │ │ ├── controller/
│ │ │ │ │ └── OCRController.java
│ │ │ │ ├── service/
│ │ │ │ │ ├── ImageProcessingService.java
│ │ │ │ │ └── OCRPredictionService.java
│ │ │ │ └── model/
│ │ │ │ └── OCRModelService.java
│ │ ├── resources/
│ │ │ ├── static/
│ │ │ ├── templates/
│ │ │ │ └── upload.html
│ ├── pom.xml
在这段代码中,我们使用DeepLearning4J框架训练了一个手写数字识别模型,基于经典的MNIST数据集。下面我们将逐步解释每一部分代码的功能和作用,以便更好地理解整个模型训练的过程。
package com.neo.service;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
public class OCRModelService {
public static void trainModel() throws Exception {
int batchSize = 128;
int rngSeed = 123;
int numEpochs = 1;
DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, rngSeed);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(rngSeed)
.updater(new Nesterovs(0.006, 0.9))
.l2(1e-4)
.list()
.layer(0, new DenseLayer.Builder()
.nIn(28 * 28)
.nOut(1000)
.activation(Activation.RELU)
.build())
.layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.activation(Activation.SOFTMAX)
.nIn(1000)
.nOut(10)
.build())
.build();
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
model.setListeners(new ScoreIterationListener(100));
model.fit(mnistTrain, numEpochs);
// 保存模型
ModelSerializer.writeModel(model, "ocr-model.zip", true);
}
public static void main(String[] args) {
try {
// 调用训练方法
OCRModelService.trainModel();
System.out.println("模型训练完成并保存为 ocr-model.zip");
} catch (Exception e) {
e.printStackTrace();
}
}
}
int batchSize = 128;
int rngSeed = 123;
DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, rngSeed);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(rngSeed)
.updater(new Nesterovs(0.006, 0.9))
.l2(1e-4)
.list()
.layer(0, new DenseLayer.Builder()
.nIn(28 * 28)
.nOut(1000)
.activation(Activation.RELU)
.build())
.layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.activation(Activation.SOFTMAX)
.nIn(1000)
.nOut(10)
.build())
.build();
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
model.setListeners(new ScoreIterationListener(100));
model.fit(mnistTrain, numEpochs);
ModelSerializer.writeModel(model, "ocr-model.zip", true);
通过这段代码,我们实现了一个简单的深度学习模型,用于手写数字识别。我们详细讲解了每个部分的功能,包括数据加载、模型配置、训练和保存。
如果说,上面的概念太难懂,那我们来把深度学习模型训练的过程用一个做菜的例子来解释,让它变得更加简单易懂。
假设我们要做一道菜——“手写数字识别”。在这个比喻里,我们的目标就是“做出一道完美的菜”,也就是训练出一个能够准确识别手写数字的模型。而我们做菜的过程就像是模型的训练过程。
在做菜之前,首先要准备好食材。对于我们的“手写数字识别”任务,食材就是MNIST数据集,这个数据集包含了大量的手写数字图片,类似于我们准备的“原料”。
int batchSize = 128; // 这里就像是一次做菜时需要的食材数量
DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, 123); // 这是我们的食材(训练数据)
MnistDataSetIterator
来把数据集分成一批一批(比如每次做菜时要准备128个食材),确保我们每次只用一定量的食材,避免浪费资源。接下来,就像做菜要挑选合适的菜谱一样,我们需要为模型挑选合适的结构和步骤。比如,你要做一道麻辣火锅,你不能随便照着糖醋排骨的菜谱做,它的“配料”和“步骤”都不一样。
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(123) // 随机种子,就像菜谱的版本,确保每次做出来的菜口感一致
.updater(new Nesterovs(0.006, 0.9)) // 就像选择了一个合适的调味料:Nesterovs优化器
.l2(1e-4) // 加一点调味料:L2正则化,防止做出来的菜太腻
.list() // 开始选择菜谱中的每个步骤
.layer(0, new DenseLayer.Builder()
.nIn(28 * 28) // 食材的数量,就是手写数字图片的大小
.nOut(1000) // 菜谱第一步:做成1000种口味的菜
.activation(Activation.RELU) // 用的调味料是ReLU
.build())
.layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.activation(Activation.SOFTMAX) // 做的最后一道工序,用SOFTMAX调料,让它有个概率结果
.nIn(1000)
.nOut(10) // 最后把1000种口味缩减到10种:0到9的数字
.build())
.build();
在这里,我们通过NeuralNetConfiguration.Builder
来搭建了一个神经网络模型,它就像是我们选择的菜谱。我们定义了神经网络的“结构”:
当你选择好了菜谱,就要开始实际动手做菜了。我们通过训练模型,类似于在厨房里开始切菜、炒菜,经过一段时间的“烹饪”,最终做出一道合格的菜。
MultiLayerNetwork model = new MultiLayerNetwork(conf); // 按照菜谱开始做菜
model.init(); // 开始做菜的过程
model.setListeners(new ScoreIterationListener(100)); // 观察一下菜的火候,每100次迭代检查一次
model.fit(mnistTrain, 1); // 做菜:一次完整的“烹饪”,这里是1次迭代
model.fit
方法就像是我们把食材按步骤处理、混合,最后做成一道菜。ScoreIterationListener
监听器查看损失函数的变化)。如果菜做得不对,就赶紧调整火候和调料。一旦菜做得差不多了,我们就要把做好的菜保存下来,下一次可以重新享用,这就像是我们训练好的模型需要保存,以便以后使用。
ModelSerializer.writeModel(model, "ocr-model.zip", true); // 把做好的菜保存起来,方便下次享用
通过ModelSerializer.writeModel
,我们把做好的模型保存为ocr-model.zip
,它就像是我们做好的菜装进了保鲜盒,可以随时取出来再次使用。
model.fit(mnistTrain, numEpochs);
ModelSerializer.writeModel(model, new File("ocr-model.zip"), true);
我们使用model.fit
方法进行训练,训练完成后将模型保存到ocr-model.zip
文件中。这样,我们的模型就可以在Spring Boot项目中使用了。
在Web应用中,用户上传的图片可能是各种格式、尺寸和颜色的。我们需要对图片进行预处理,转换为模型能够接受的输入格式。
package com.neo.service;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.factory.Nd4j;
import org.springframework.stereotype.Service;
import java.awt.image.BufferedImage;
@Service
public class ImageProcessingService {
public INDArray preprocessImage(BufferedImage image) {
int width = 28;
int height = 28;
INDArray array = Nd4j.zeros(1, width * height);
// 遍历图像,将每个像素值转换为灰度值
for (int i = 0; i < width; i++) {
for (int j = 0; j < height; j++) {
int rgb = image.getRGB(i, j);
int r = (rgb >> 16) & 0xFF;
int g = (rgb >> 8) & 0xFF;
int b = rgb & 0xFF;
// 灰度转换公式
double gray = (0.299 * r + 0.587 * g + 0.114 * b) / 255.0;
array.putScalar(new int[]{0, j * width + i}, gray);
}
}
// 使用随机数据模拟样本,计算标准化参数
INDArray sampleData = Nd4j.rand(new int[]{100, width * height});
DataSet sampleDataSet = new DataSet(sampleData, null);
NormalizerStandardize normalizer = new NormalizerStandardize();
normalizer.fit(sampleDataSet); // 计算标准化参数
// 归一化图像数据
DataSet dataSet = new DataSet(array, null);
normalizer.transform(dataSet);
return dataSet.getFeatures(); // 返回标准化后的特征
}
}
这段代码的功能是处理图像,将图像转换为适合深度学习模型输入的格式,并对图像进行标准化处理。下面我们将详细解释每一部分的代码和背后的原理。
int width = 28;
int height = 28;
INDArray array = Nd4j.zeros(1, width * height);
这段代码的目的是为图像数据创建一个INDArray,也就是一个二维数组。在深度学习中,数据通常以数组的形式输入模型。
for (int i = 0; i < width; i++) {
for (int j = 0; j < height; j++) {
int rgb = image.getRGB(i, j);
int r = (rgb >> 16) & 0xFF;
int g = (rgb >> 8) & 0xFF;
int b = rgb & 0xFF;
// 灰度转换公式
double gray = (0.299 * r + 0.587 * g + 0.114 * b) / 255.0;
array.putScalar(new int[]{0, j * width + i}, gray);
}
}
这部分代码的目的是将图像的每个像素值转换为灰度值,因为在处理手写数字识别任务时,颜色信息(如红、绿、蓝)对识别数字的帮助有限,使用灰度图像就足够了。
INDArray sampleData = Nd4j.rand(new int[]{100, width * height});
DataSet sampleDataSet = new DataSet(sampleData, null);
NormalizerStandardize normalizer = new NormalizerStandardize();
normalizer.fit(sampleDataSet); // 计算标准化参数
这一部分的目的是通过一组随机数据来计算标准化参数。标准化是深度学习中常用的一种预处理方法,目的是让数据的分布更符合模型的要求,通常是将数据的均值调整为0,标准差调整为1。
DataSet dataSet = new DataSet(array, null);
normalizer.transform(dataSet);
现在我们已经有了一个训练好的标准化器,可以用它来对图像进行标准化处理。
return dataSet.getFeatures(); // 返回标准化后的特征
这行代码返回处理后的图像数据,dataSet.getFeatures()
获取了经过标准化处理的特征数据。这个特征就是接下来输入到深度学习模型中的数据。
INDArray
**:存储图像的像素值,并为接下来的计算做准备。NormalizerStandardize
来标准化图像数据。图像处理的核心目的是将图像数据转换为适合深度学习模型处理的格式,同时通过标准化减少数据的偏差,确保模型能够更快收敛,并提高预测的准确性。
接下来,我们需要一个控制器来处理用户上传的图片,并调用模型进行预测。
package com.neo.controller;
import com.neo.service.ImageProcessingService;
import com.neo.service.OCRPredictionService;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Controller;
import org.springframework.ui.Model;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.multipart.MultipartFile;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
@Controller
public class OCRController {
@Autowired
private ImageProcessingService imageProcessingService;
@Autowired
private OCRPredictionService ocrPredictionService;
// 文件保存路径,您可以根据实际需求修改此路径
private static final String UPLOAD_DIR = "src/main/resources/static/uploads/";
// 显示上传页面
@RequestMapping("/ocr")
public String showUploadPage() {
return "upload"; // 返回上传页面的视图
}
// 处理上传的图片并进行 OCR 预测
@PostMapping("/ocr/predict")
public String predict(@RequestParam("file") MultipartFile file, Model model) {
try {
String fileName = file.getOriginalFilename();
// 创建保存文件的路径
Path uploadPath = Paths.get(UPLOAD_DIR);
if (!Files.exists(uploadPath)) {
Files.createDirectories(uploadPath); // 创建目录
}
// 保存文件到本地
Path filePath = uploadPath.resolve(fileName);
file.transferTo(filePath);
// 读取保存的文件
// 读取图片
BufferedImage image = ImageIO.read(filePath.toFile());
// 处理图片并预测
INDArray processedImage = imageProcessingService.preprocessImage(image);
// 使用模型进行预测
int predictedDigit = ocrPredictionService.predict(processedImage);
System.out.println("识别的数字是: " + predictedDigit);
// 将图片和预测结果传递给前端
model.addAttribute("imagePath", fileName); // 只传递相对路径
model.addAttribute("prediction", predictedDigit);
return "upload"; // 返回上传页面并显示预测结果
} catch (IOException e) {
e.printStackTrace();
model.addAttribute("error", "图片处理失败,请重新上传");
return "upload"; // 如果出现错误,返回上传页面
}
}
}
为了让用户能够方便地上传图片,我们需要一个友好的用户界面。使用Thymeleaf和Bootstrap,我们可以快速构建一个简洁的上传页面。
通过这个简单的表单,用户可以上传图片,然后在页面上查看预测结果。
<!DOCTYPE html>
<html lang="en" xmlns:th="http://www.w3.org/1999/xhtml">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>OCR 图像上传和预测</title>
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0-alpha1/dist/css/bootstrap.min.css" rel="stylesheet">
</head>
<body>
<div class="container mt-5">
<h2 class="text-center">手写数字识别 OCR</h2>
<!-- 上传图片表单 -->
<form action="/ocr/predict" method="post" enctype="multipart/form-data" class="mt-4">
<div class="mb-3">
<label for="file" class="form-label">选择图片文件</label>
<input type="file" class="form-control" id="file" name="file" required>
</div>
<button type="submit" class="btn btn-primary">上传并预测</button>
</form>
<!-- 错误信息 -->
<div th:if="${error}" class="alert alert-danger mt-3" role="alert">
<p th:text="${error}"></p>
</div>
<!-- 显示上传的图片 -->
<div th:if="${imagePath}">
<h3 class="mt-4">上传的图片:</h3>
<img th:src="@{/uploads/{image}(image=${imagePath})}" alt="Uploaded Image" class="img-fluid">
</div>
<div th:if="${prediction != null}">
<h3 class="mt-4">预测结果:</h3>
<div class="alert alert-success" role="alert">
识别结果:<strong th:text="${prediction}"></strong>
</div>
</div>
</div>
<script src="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0-alpha1/dist/js/bootstrap.bundle.min.js"></script>
</body>
</html>
我们完成了一个从模型训练到Web应用集成的完整流程,展示了如何使用Spring Boot和DeepLearning4J构建一个手写数字识别系统。这个项目不仅展示了深度学习在实际应用中的潜力,也展示了Java开发环境中集成深度学习技术的可能性。
通过这个示例,相信你对深度学习有了更加直观的理解,也希望你能够在此基础上,进一步探索深度学习和Java的结合应用。无论是扩展这个项目的功能,还是尝试不同的数据集和模型结构,都是很好的学习和实践方式。
我也是一名刚刚入门深度学习的小学生,欢迎友好指正和交流~
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。