前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >安卓软件开发:如何实现机器学习部署到安卓端

安卓软件开发:如何实现机器学习部署到安卓端

原创
作者头像
Nimyears
修改2024-09-21 11:55:07
4570
修改2024-09-21 11:55:07
举报
文章被收录于专栏:JetpackCompose M3

2024年已经过半了,我作为聋人独立开发者,我经常反思自己在这半年中的成长,自己这半年到底进步了多少?在这篇文章里,分享机器学习模型部署Android端的开发案例。无论你有没有开发经验,相信这篇文章对你会非常有所帮助。

一、背景

现在智能手机和移动设备越来越普及,很多应用都依赖机器学习模型提升用户体验,比如图像识别、文本识别、面部识别、语音处理、图像分类等。因为移动设备的硬件资源有限,直接使用大模型往往会卡顿,无法顺畅运行。所以,如何在移动端高效地部署和优化模型,成了开发的关键。

我个人特别喜欢使用 TensorFlow 框架做开发,简称“TF”,研究如何使用机器学习模型部署工作,TensorFlow 的功能强大,简化开发流程,真的非常成功。TensorFlow 官网上有非常全面的文档,可以参考:TensorFlow 官网

思考一:为啥选择 TensorFlow的原因?

TensorFlow 是一个适合移动端的平台,无论你是刚入门还是专家级别,都可以使用它轻松构建部署机器学习模型。

思考二:如何轻松构建和部署模型?

TensorFlow 提供了不同层次的工具,比如Keras API,能大大简化模型的构建和训练流程,初学者都可以很快上手。


二、讲解核心代码

首先看一下如何使用 TensorFlow 进行基础的机器学习开发。

2.1 安装和导入 TensorFlow

代码语言:javascript
复制
import tensorflow as tf

2.2 加载预处理数据集

代码语言:javascript
复制
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

2.3 构建模型

代码语言:javascript
复制
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10)
])

2.4 训练模型

代码语言:javascript
复制
model.fit(x_train, y_train, epochs=5)

三、在 Android 项目中集成 TensorFlow Lite

3.1 在 Android Studio 中导入 TensorFlow Lite 模型

  • 在Project Explorer 中选择 File > New > Other > TensorFlow Lite Model
  • 选择已训练好的 .tflite 模型文件。
  • 导入完成后,Android Studio 会显示模型的概要信息,提供示例代码。

然后可以看到提供了两种编程语言代码的模板,根据个人喜爱用哪种编程语言。

3.2 在build.gradle依赖指定tensorflow版本:

代码语言:javascript
复制
dependencies {
    implementation 'org.tensorflow:tensorflow-lite:2.12.0'
    implementation 'org.tensorflow:tensorflow-lite-support:0.4.0'
    implementation 'org.tensorflow:tensorflow-lite-metadata:0.1.0'
}

3.3 为了适配更多的 Android 设备,还需要配置 ABI:

代码语言:javascript
复制
android {
    defaultConfig {
        ndk {
            abiFilters 'armeabi-v7a', 'arm64-v8a'
        }
    }
}

不对劲,这篇文章看起来好像很复杂看不懂(,那么开始做项目作为演示,用kotlin实现,以手写数字识别App实现。


四、项目开发

在这个项目中,我展示如何使用 TensorFlow Lite 实现一个简单的手写数字识别App。

4.1 使用 TensorFlow 训练模型,最后导出 .tflite 模型

以下模型训练的代码,最后生成nim_model.tflite 文件部署:

代码语言:python
代码运行次数:0
复制
import tensorflow as tf

# 加载 MNIST 数据集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 数据预处理
x_train, x_test = x_train / 255.0, x_test / 255.0

# 定义模型
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),  # 一维数组
    tf.keras.layers.Dense(128, activation='relu'),  # 全连接层
    tf.keras.layers.Dropout(0.2),  # 随机丢弃20%的神经元
    tf.keras.layers.Dense(10, activation='softmax')  # 输出层,有10个类别
])

# 编译模型
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# 训练模型
model.fit(x_train, y_train, epochs=5)

# 保存为 TensorFlow Lite 模型
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

# 保存模型文件
with open('nim_model.tflite', 'wb') as f:
    f.write(tflite_model)

保存模型文件代码码会输出一个 nim_model.tflite 文件,参考第三章的操作步骤实现,略讲。

生成结果是:

PS:我不做推荐用哪个平台产品训练模型!

生成到云硬盘上的文件自行下载。

在Android项目加载导入tf文件即可。

4.2 编写模型推理逻辑

MainActivity 中,编写代码加载模型进行推理。

代码语言:java
复制
package com.nim.nimhanddigits
import android.app.Activity
import android.content.Intent
import android.content.pm.PackageManager
import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.net.Uri
import android.os.Build
import android.os.Bundle
import android.widget.Button
import android.widget.ImageView
import android.widget.TextView
import android.widget.Toast
import androidx.appcompat.app.AppCompatActivity
import androidx.core.app.ActivityCompat
import androidx.core.content.ContextCompat
import com.spd.nimhanddigits.ml.NimModel
import org.tensorflow.lite.DataType
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
import java.io.InputStream
import java.nio.ByteBuffer
import java.nio.ByteOrder

class MainActivity : AppCompatActivity() {

    private val FILE_SELECT_CODE = 100
    private val PERMISSION_REQUEST_CODE = 200

    private lateinit var imageView: ImageView
    private lateinit var resultTextView: TextView
    private lateinit var selectButton: Button
     private lateinit var selectImageButton: Button  

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)

        imageView = findViewById(R.id.imageView)
        resultTextView = findViewById(R.id.resultTextView)
        selectButton = findViewById(R.id.selectButton)

        selectButton.setOnClickListener {
            if (checkPermissions()) {
                openFileChooser()
            } else {
                requestPermissions()
            }
        }
        
        selectImageButton.setOnClickListener {
            val bitmap = (imageView.drawable as? BitmapDrawable)?.bitmap
            if (bitmap != null) {
                showLoading()
                val result = runModel(bitmap)
                hideLoading()
                resultTextView.text = "预测结果: $result"
            } else {
                Toast.makeText(this, "请先选择一张图片", Toast.LENGTH_SHORT).show()
            }
        }
    }
    }

    private fun openFileChooser() {
        val intent = Intent(Intent.ACTION_GET_CONTENT)
        intent.type = "image/*"
        startActivityForResult(intent, FILE_SELECT_CODE)
    }

    private fun checkPermissions(): Boolean {
        return if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU) {
            // Android 13+ 权限处理
            ContextCompat.checkSelfPermission(
                this,
                android.Manifest.permission.READ_MEDIA_IMAGES
            ) == PackageManager.PERMISSION_GRANTED
        } else {
            // Android 12 及以下版本
            ContextCompat.checkSelfPermission(
                this,
                android.Manifest.permission.READ_EXTERNAL_STORAGE
            ) == PackageManager.PERMISSION_GRANTED
        }
    }
    
    private fun requestPermissions() {
        if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU) {
            // Android 13+
            ActivityCompat.requestPermissions(
                this,
                arrayOf(android.Manifest.permission.READ_MEDIA_IMAGES),
                PERMISSION_REQUEST_CODE
            )
        } else {
            // Android 12以下
            ActivityCompat.requestPermissions(
                this,
                arrayOf(android.Manifest.permission.READ_EXTERNAL_STORAGE),
                PERMISSION_REQUEST_CODE
            )
        }
    }

    
    override fun onRequestPermissionsResult(
        requestCode: Int,
        permissions: Array<out String>,
        grantResults: IntArray
    ) {
        super.onRequestPermissionsResult(requestCode, permissions, grantResults)
        if (requestCode == PERMISSION_REQUEST_CODE) {
            if (grantResults.isNotEmpty() && grantResults[0] == PackageManager.PERMISSION_GRANTED) {
                openFileChooser()
            } else {
                Toast.makeText(this, "权限被拒绝", Toast.LENGTH_SHORT).show()
            }
        }
    }

   
    override fun onActivityResult(requestCode: Int, resultCode: Int, data: Intent?) {
        super.onActivityResult(requestCode, resultCode, data)
        if (requestCode == FILE_SELECT_CODE && resultCode == Activity.RESULT_OK) {
            val imageUri: Uri? = data?.data
            imageUri?.let {
                val imageBitmap = getBitmapFromUri(it)
                imageView.setImageBitmap(imageBitmap)
                imageBitmap?.let { bitmap ->
                    showLoading()
                    val result = runModel(bitmap)
                    hideLoading()
                    resultTextView.text = "预测结果: $result"
                }
            }
        }
    }

    // 从Url获取Bitmap
    private fun getBitmapFromUri(uri: Uri): Bitmap? {
        return try {
            val inputStream: InputStream? = contentResolver.openInputStream(uri)
            BitmapFactory.decodeStream(inputStream)
        } catch (e: Exception) {
            e.printStackTrace()
            null
        }
    }

    // Bitmap转换为模型输入格式进行推理
    private fun runModel(bitmap: Bitmap): Int {
        return try {
            val model = NimModel.newInstance(this)

            // 输入大小 28x28
            val scaledBitmap = Bitmap.createScaledBitmap(bitmap, 28, 28, true)

            //Bitmap转为ByteBuffer
            val byteBuffer = ByteBuffer.allocateDirect(28 * 28 * 4).order(ByteOrder.nativeOrder())
            for (y in 0 until 28) {
                for (x in 0 until 28) {
                    val pixelValue = scaledBitmap.getPixel(x, y)
                    val r = (pixelValue shr 16 and 0xFF).toFloat()
                    val g = (pixelValue shr 8 and 0xFF).toFloat()
                    val b = (pixelValue and 0xFF).toFloat()
                    val gray = (r + g + b) / 3 / 255.0f
                    byteBuffer.putFloat(gray)
                }
            }

            // 创建输入对象
            val inputFeature0 = TensorBuffer.createFixedSize(intArrayOf(1, 28, 28, 1), DataType.FLOAT32)
            inputFeature0.loadBuffer(byteBuffer)

            val outputs = model.process(inputFeature0)
            val outputFeature0 = outputs.outputFeature0AsTensorBuffer.floatArray

            // 找到最大的预测值的索引
            val predictedDigit = outputFeature0.indices.maxByOrNull { outputFeature0[it] } ?: -1

            // 关闭模型
            model.close()

            predictedDigit
        } catch (e: Exception) {
            e.printStackTrace()
            -1 
        }
    }

    private fun showLoading() {
        resultTextView.text = "预测中..."
    }

    private fun hideLoading() {
        resultTextView.text = ""
    }
}

4.3 绘制UI

代码语言:xml
复制
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    android:gravity="center"
    android:orientation="vertical"
    android:padding="16dp">

    <TextView
        android:layout_width="match_parent"
        android:layout_height="50dp"
        android:gravity="center"
        android:text="Nim"
        android:layout_margin="40px"
        android:textColor="#03A9F4"
        android:textSize="40sp" />

    <ImageView
        android:id="@+id/imageView"
        android:layout_width="250dp"
        android:layout_height="250dp"
        android:layout_gravity="center"
        android:layout_marginBottom="16dp"
        android:background="@color/design_default_color_secondary_variant" />

    <Button
        android:id="@+id/selectButton"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:layout_gravity="center"
        android:backgroundTint="#03DAC5"
        android:text="选择图片"
        android:textColor="#FFFFFF" />

    <Button
        android:id="@+id/selectImageButton"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:layout_gravity="center"
        android:layout_marginBottom="16dp"
        android:backgroundTint="#6200EE"
        android:text="预测"
        android:textColor="#FFFFFF" />


    <TextView
        android:id="@+id/resultTextView"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:layout_margin="40dp"
        android:text="预测结果"
        android:textSize="24sp" />

</LinearLayout>

4.4 效果图

4.5 视频演示

视频内容

五、技术难点

在开发手写数字识别应用的过程中,主要的技术难点总结以下几点:

5.1. 模型压缩与优化

手写数字识别应用虽然模型相对简单,但由于移动端设备的资源有限,如何在保证准确度的前提下压缩模型并优化性能是一个技术挑战。TensorFlow Lite 提供了量化技术,模型的权重和激活函数从浮点数表示转换为整数,从而减少模型大小加快推理速度。

挑战点

• 在模型压缩的过程中,如何在保持模型精度的同时降低模型大小。

• 实现轻量级模型时,如何减少运算资源的消耗而不影响用户体验。

5.2 实时推理的延迟控制

手写数字识别属于实时性要求较高的任务。为了提升用户体验,需要降低推理延迟。通过 TensorFlow Lite 的优化和多线程处理,可以有效降低推理时的延迟。

挑战点

• 如何通过多线程或者硬件加速器来减少延迟,同时保证推理结果的准确性。

• 控制实时推理的时间,通常需要将延迟控制在150毫秒以下,确保用户感觉到应用响应迅速。

5.3 模型的跨平台兼容性

保证应用在特定设备上运行良好,还要确保在不同硬件架构的设备上(如 armeabi-v7a 和 arm64-v8a)都能正常工作,涉及到 TensorFlow Lite 模型在不同设备间的兼容性。

挑战点

• 在 Android 项目中,需要针对不同的硬件平台配置 ABI,支持各种 Android 设备。

• 同时,使用 ONNX 格式可以帮助模型在不同框架和平台间迁移,但在转换过程中,可能遇到精度下降或者其他兼容性问题。

5.4 UI 交互与用户体验

在手写数字识别App中,用户选择图片、显示推理结果、交互流畅性等细节都需要精心设计,才能让用户获得良好的体验。

挑战点

• 保证应用 UI 流程简洁流畅,用户能够快速完成操作,得到识别结果。

• 优化加载和推理过程中 UI 的反馈。

六、学习技术笔记

6.1 简化模型部署的体验

TensorFlow Lite 很好地简化了模型的部署过程,让开发者无需过多关注底层优化细节,就能在移动端上部署机器学习模型。我特别喜欢它的 API 设计,它让复杂的模型推理工作变得直观易懂。通过一些工具和指南,轻松就能将 Keras 模型转换为 .tflite 文件并集成到 Android 项目中。

6.2 模型量化带来的性能提升

在使用量化技术时,我感受到模型的大小大幅减少,同时推理速度也有了明显提升。在原始模型大小过大的情况下,通过量化能将模型大小减少近 75%,对于移动设备来说,这种优化是非常实用的。

6.2 如何通过量化技术优化模型

模型权重和激活函数的浮点数表示形式转换为整数表示的过程。

6.3 跨平台兼容性和挑战

ONNX 格式为模型的跨平台迁移提供了强有力的支持。尽管这个技术非常方便,但在模型从一个框架转换到另一个框架时,遇到了一些兼容性问题,例如某些层的行为与期望不符,需要进行额外调整。不过总体来说,ONNX 确实大大简化了跨平台的部署工作。

6.4 技术细节的把控

在将机器学习模型应用于移动设备时,深刻感受到硬件性能和资源的局限性,特别是在推理时间、内存使用和功耗之间做平衡时,需要不断优化和调试代码.

七、总结

通过这个项目的开发,我学习了如何优化机器学习模型在移动设备上高效运行,还学会了如何利用多种优化技术,比如量化和硬件加速,提升性能。

总体来说,使用 TensorFlow Lite 和相关技术时,虽然面临一些技术难点和挑战,但让我更加深入了解了移动端机器学习应用开发的核心技巧。

有任何问题欢迎提问,感谢大家阅读 :)

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 一、背景
  • 二、讲解核心代码
    • 2.1 安装和导入 TensorFlow
      • 2.2 加载预处理数据集
        • 2.3 构建模型
          • 2.4 训练模型
          • 三、在 Android 项目中集成 TensorFlow Lite
            • 3.1 在 Android Studio 中导入 TensorFlow Lite 模型
              • 3.2 在build.gradle依赖指定tensorflow版本:
                • 3.3 为了适配更多的 Android 设备,还需要配置 ABI:
                • 四、项目开发
                  • 4.1 使用 TensorFlow 训练模型,最后导出 .tflite 模型
                    • 4.2 编写模型推理逻辑
                      • 4.3 绘制UI
                        • 4.4 效果图
                          • 4.5 视频演示
                          • 五、技术难点
                          • 六、学习技术笔记
                            • 6.1 简化模型部署的体验
                              • 6.2 模型量化带来的性能提升
                                • 6.2 如何通过量化技术优化模型
                                  • 6.3 跨平台兼容性和挑战
                                    • 6.4 技术细节的把控
                                    • 七、总结
                                    相关产品与服务
                                    云硬盘
                                    云硬盘(Cloud Block Storage,CBS)为您提供用于 CVM 的持久性数据块级存储服务。云硬盘中的数据自动地在可用区内以多副本冗余方式存储,避免数据的单点故障风险,提供高达99.9999999%的数据可靠性。同时提供多种类型及规格,满足稳定低延迟的存储性能要求。
                                    领券
                                    问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档