Android | Deploying a PyTorch Model to Android

Final result:

[Screenshot: the finished app classifying the live camera feed]

Prerequisites:

  • PyTorch 1.4 (the PyTorch environment must be at least 1.3; 1.4 is the latest release at the time of writing)
  • A trained PyTorch model
  • The Jetpack component CameraX (used to drive the camera)

If needed, have a look at these two posts of mine first:
If your PyTorch environment does not meet the requirement, upgrade it: "Win10 + PyTorch 1.4 + CUDA 10.1 installation: starting from the graphics driver".
Before using CameraX, make sure you understand it first: "Jetpack CameraX in practice: preview and analysis".
Model conversion
# In the PyTorch environment
model_pth = os.path.join(MODEL_PATH, 'resnet18.pth')  # parameter file of the trained resnet18
mobile_pt = os.path.join(MODEL_PATH, 'resnet18.pt')   # the file the Android app will load

model = make_model('resnet18')                 # build the network
model.load_state_dict(torch.load(model_pth))   # load the parameters
model.eval()                                   # set the model to evaluation mode

# one 3-channel 224*224 image
input_tensor = torch.rand(1, 3, 224, 224)      # example input defining the input format

mobile = torch.jit.trace(model, input_tensor)  # convert (trace) the model
mobile.save(mobile_pt)                         # save the TorchScript file

Note: that completes the model conversion; you now have the resnet18.pt file.
Setting up CameraX on Android: preview

Add the dependencies:
// CameraX core library using the camera2 implementation
def camerax_version = "1.0.0-beta01"
implementation "androidx.camera:camera-camera2:${camerax_version}"
implementation "androidx.camera:camera-view:1.0.0-alpha08"
implementation "androidx.camera:camera-extensions:1.0.0-alpha08"
implementation "androidx.camera:camera-lifecycle:${camerax_version}"

// PyTorch Mobile
implementation 'org.pytorch:pytorch_android:1.4.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.4.0'

Request the camera permission and navigate:

Add the permission to AndroidManifest.xml: <uses-permission android:name="android.permission.CAMERA" />. Then request the permission at runtime before navigating to the CameraX page (you could also move the runtime request into CameraXFragment and ask only when the camera is actually opened):
package com.example.gca.leftFragment

import android.Manifest
......
import kotlinx.android.synthetic.main.left_fragment.*

private const val REQUEST_CODE_PERMISSIONS = 10 // request code for the permission callback
private val REQUIRED_PERMISSIONS = arrayOf(Manifest.permission.CAMERA) // camera permission

class LeftFragment : Fragment() {

    override fun onCreateView(
        inflater: LayoutInflater, container: ViewGroup?,
        savedInstanceState: Bundle?
    ): View? {
        return inflater.inflate(R.layout.left_fragment, container, false)
    }

    override fun onActivityCreated(savedInstanceState: Bundle?) {
        super.onActivityCreated(savedInstanceState)

        // Go to the CameraXFragment page
        buttonCameraX.setOnClickListener {
            // Check the camera permission first
            if (allPermissionsGranted()) {
                // Navigation is done with the Navigation component
                Navigation.findNavController(it).navigate(R.id.action_leftFragment_to_cameraXFragment)
            } else {
                requestPermissions(REQUIRED_PERMISSIONS, REQUEST_CODE_PERMISSIONS)
            }
        }
    }

    // Callback with the result of the permission request
    override fun onRequestPermissionsResult(
        requestCode: Int,
        permissions: Array<out String>,
        grantResults: IntArray
    ) {
        if (requestCode == REQUEST_CODE_PERMISSIONS) {
            if (allPermissionsGranted()) {
                // Permission granted, navigate
                Navigation.findNavController(requireView()).navigate(R.id.action_leftFragment_to_cameraXFragment)
            } else {
                Toast.makeText(
                    requireContext(),
                    "Permissions not granted by the user.",
                    Toast.LENGTH_SHORT
                ).show()
            }
        }
    }

    // Check whether all required permissions are granted
    private fun allPermissionsGranted() = REQUIRED_PERMISSIONS.all {
        ContextCompat.checkSelfPermission(requireContext(), it) == PackageManager.PERMISSION_GRANTED
    }
}

Create a new fragment plus a layout file to host the camera. The layout (fragment_camera_x.xml) needs an androidx.camera.view.PreviewView with the id previewView; the analysis step later also expects three TextViews (textView2, textView3, textView4) to display the classification result.

Set up the fragment (CameraXFragment.kt):
package com.example.gca.leftFragment.cameraXFragment

import android.os.Bundle
......
import java.util.concurrent.Executors

class CameraXFragment : Fragment(), CameraXConfig.Provider {

    override fun getCameraXConfig(): CameraXConfig {
        return Camera2Config.defaultConfig()
    }

    private lateinit var cameraProviderFuture: ListenableFuture<ProcessCameraProvider> // camera controller
    private lateinit var imagePreview: Preview // preview use case
    private lateinit var cameraPreviewView: PreviewView // view that shows the camera feed

    override fun onCreateView(
        inflater: LayoutInflater, container: ViewGroup?,
        savedInstanceState: Bundle?
    ): View? {
        // Inflate the layout for this fragment
        return inflater.inflate(R.layout.fragment_camera_x, container, false)
    }

    override fun onActivityCreated(savedInstanceState: Bundle?) {
        super.onActivityCreated(savedInstanceState)

        cameraProviderFuture = ProcessCameraProvider.getInstance(requireContext()) // obtain the camera provider
        cameraPreviewView = previewView // the PreviewView from the layout

        // Start the camera once the view is ready
        cameraPreviewView.post { startCamera() }
    }

    private fun startCamera() {
        // Preview use case
        imagePreview = Preview.Builder().apply {
            setTargetAspectRatio(AspectRatio.RATIO_16_9)
            setTargetRotation(previewView.display.rotation)
        }.build()
        imagePreview.setSurfaceProvider(previewView.previewSurfaceProvider)

        // Bind the use case to the lifecycle
        val cameraSelector = CameraSelector.Builder().requireLensFacing(CameraSelector.LENS_FACING_BACK).build()
        cameraProviderFuture.addListener(Runnable {
            val cameraProvider = cameraProviderFuture.get()
            cameraProvider.bindToLifecycle(this, cameraSelector, imagePreview)
        }, ContextCompat.getMainExecutor(requireContext()))
    }
}

Note: at this point you can run the project and get a live camera preview. If it doesn't work, something was probably missed along the way; refer to "Jetpack CameraX in practice: preview and analysis".
Deploying the PyTorch model

Add the resource: copy the converted model resnet18.pt into the assets folder (if your project has no assets folder yet, see: https://blog.csdn.net/y_dd6011).
Add two constants:
const val MODEL_NAME = "resnet18.pt" // the neural network model file
val IMAGE_CLASSIFICATION = arrayOf( // the image classes your network can recognize
    "tench, Tinca tinca",
    ......
    "goldfish, Carassius auratus",
)
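Hard-coding the labels works, but the full ImageNet list that resnet18 is usually trained on has 1,000 entries. An alternative sketch (hypothetical: it assumes you ship a labels.txt asset with one class name per line) loads them at startup instead:

import android.content.Context

// Hypothetical alternative to the hard-coded array: read class labels
// from assets/labels.txt, one label per line.
fun loadLabels(context: Context): Array<String> =
    context.assets.open("labels.txt").bufferedReader().useLines { lines ->
        lines.map { it.trim() }.filter { it.isNotEmpty() }.toList().toTypedArray()
    }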

Create a new Kotlin object (Unit.kt) that copies the model out of assets and returns its absolute file path:
package com.example.gca.unit

import android.content.Context
import android.util.Log
import java.io.File
import java.io.FileOutputStream
import java.io.IOException

object Unit {
    fun assetFilePath(context: Context, assetName: String): String? {
        val file = File(context.filesDir, assetName)
        try {
            context.assets.open(assetName).use { `is` ->
                FileOutputStream(file).use { os ->
                    val buffer = ByteArray(4 * 1024)
                    while (true) {
                        val length = `is`.read(buffer)
                        if (length <= 0) break
                        os.write(buffer, 0, length)
                    }
                    os.flush()
                }
                return file.absolutePath
            }
        } catch (e: IOException) {
            Log.e("pytorch", "Error copying asset $assetName to a file path")
        }
        return null
    }
}
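Module.load needs a real filesystem path, which is why assetFilePath copies the model out of the APK into filesDir. As an optional refinement (a sketch, not part of the original code; cachedAssetFilePath is a hypothetical helper), you can skip the copy when a previous launch already made one:

import android.content.Context
import com.example.gca.unit.Unit.assetFilePath
import java.io.File

// Hypothetical wrapper around assetFilePath: reuse the copy from a previous launch.
fun cachedAssetFilePath(context: Context, assetName: String): String? {
    val file = File(context.filesDir, assetName)
    if (file.exists() && file.length() > 0) return file.absolutePath // already copied earlier
    return assetFilePath(context, assetName) // otherwise copy from assets as above
}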

Then create another Kotlin class (ImageClassificationResult.kt) that wraps the result handed back from image analysis:
package com.example.gca.unit

import com.example.gca.IMAGE_CLASSIFICATION

class ImageClassificationResult(private val index: Int, private val value: Float) {
    fun getImageClassification() = IMAGE_CLASSIFICATION[index]
    fun getGarbageIndex() = index
    fun getGarbageValue() = value
}
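A quick usage sketch (the 9.3f score is made up for illustration):

val result = ImageClassificationResult(0, 9.3f)
println(result.getImageClassification()) // "tench, Tinca tinca" (entry 0 of IMAGE_CLASSIFICATION)
println(result.getGarbageIndex())        // 0
println(result.getGarbageValue())        // 9.3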

Final step: give the camera an image analyzer. The complete CameraXFragment.kt:
package com.example.gca.leftFragment.cameraXFragment

import android.os.Bundle
import android.util.Log
import android.util.Size
import android.view.LayoutInflater
import android.view.View
import android.view.ViewGroup
import androidx.camera.camera2.Camera2Config
import androidx.camera.core.*
import androidx.camera.lifecycle.ProcessCameraProvider
import androidx.camera.view.PreviewView
import androidx.core.content.ContextCompat
import androidx.fragment.app.Fragment
import com.example.gca.MODEL_NAME
import com.example.gca.R
import com.example.gca.unit.ImageClassificationResult
import com.example.gca.unit.Unit.assetFilePath
import com.google.common.util.concurrent.ListenableFuture
import kotlinx.android.synthetic.main.fragment_camera_x.*
import kotlinx.coroutines.MainScope
import kotlinx.coroutines.launch
import org.pytorch.IValue
import org.pytorch.Module
import org.pytorch.Tensor
import org.pytorch.torchvision.TensorImageUtils
import java.nio.ByteBuffer
import java.util.concurrent.Executors

// Result type delivered by the image analyzer; typealias just names the function type
typealias ResultListener = (result: ImageClassificationResult) -> Unit

class CameraXFragment : Fragment(), CameraXConfig.Provider {

    override fun getCameraXConfig(): CameraXConfig {
        return Camera2Config.defaultConfig()
    }

    private lateinit var cameraProviderFuture: ListenableFuture<ProcessCameraProvider> // camera controller
    private lateinit var imagePreview: Preview // preview use case
    private lateinit var imageAnalysis: ImageAnalysis // analysis use case
    private val executor = Executors.newSingleThreadExecutor() // background thread for analysis
    private lateinit var cameraPreviewView: PreviewView // view that shows the camera feed
    private lateinit var module: Module // the model

    override fun onCreateView(
        inflater: LayoutInflater, container: ViewGroup?,
        savedInstanceState: Bundle?
    ): View? {
        // Inflate the layout for this fragment
        return inflater.inflate(R.layout.fragment_camera_x, container, false)
    }

    override fun onActivityCreated(savedInstanceState: Bundle?) {
        super.onActivityCreated(savedInstanceState)

        cameraProviderFuture = ProcessCameraProvider.getInstance(requireContext()) // obtain the camera provider
        cameraPreviewView = previewView // the PreviewView from the layout

        // Load the image classification model
        try {
            val modulePath = assetFilePath(requireContext(), MODEL_NAME)
            module = Module.load(modulePath)
        } catch (e: Exception) {
            Log.e(CameraXFragment::class.java.simpleName, e.toString())
        }

        // Start the camera once the view is ready
        cameraPreviewView.post { startCamera() }
    }

    private fun startCamera() {
        // Preview use case
        imagePreview = Preview.Builder().apply {
            setTargetAspectRatio(AspectRatio.RATIO_16_9)
            setTargetRotation(previewView.display.rotation)
        }.build()
        imagePreview.setSurfaceProvider(previewView.previewSurfaceProvider)

        // Analysis use case
        imageAnalysis = ImageAnalysis.Builder().apply {
            setBackpressureStrategy(ImageAnalysis.STRATEGY_KEEP_ONLY_LATEST) // analyze only the latest frame
            setTargetResolution(Size(224, 224))
        }.build()
        imageAnalysis.setAnalyzer(executor, ImageClassificationAnalyzer(module) {
            // Hop back to the main thread to update the UI
            MainScope().launch {
                textView2.text = it.getImageClassification()
                textView3.text = it.getGarbageIndex().toString()
                textView4.text = it.getGarbageValue().toString()
            }
            Log.v(CameraXFragment::class.java.simpleName, it.toString())
        })

        // Bind the use cases to the lifecycle
        val cameraSelector = CameraSelector.Builder().requireLensFacing(CameraSelector.LENS_FACING_BACK).build()
        cameraProviderFuture.addListener(Runnable {
            val cameraProvider = cameraProviderFuture.get()
            cameraProvider.bindToLifecycle(this, cameraSelector, imagePreview, imageAnalysis)
        }, ContextCompat.getMainExecutor(requireContext()))
    }

    // The image classification analyzer
    private class ImageClassificationAnalyzer(module: Module, listener: ResultListener? = null) : ImageAnalysis.Analyzer {

        private val mModule = module
        private val listeners = ArrayList<ResultListener>().apply { listener?.let { add(it) } }

        private fun ByteBuffer.toByteArray(): ByteArray {
            rewind() // rewind the buffer to zero
            val data = ByteArray(remaining())
            get(data) // copy the buffer into a byte array
            return data // return the byte array
        }

        override fun analyze(imageProxy: ImageProxy) {
            if (listeners.isEmpty()) {
                imageProxy.close()
                return
            }

            val buffer = imageProxy.planes[0].buffer
            val data = buffer.toByteArray() // not used below; TensorImageUtils reads imageProxy.image directly

            // Image classification
            val inputTensorBuffer = Tensor.allocateFloatBuffer(3 * 224 * 224) // buffer for the input data
            val inputTensor = Tensor.fromBlob(inputTensorBuffer, longArrayOf(1, 3, 224, 224)) // wrap it as a tensor

            TensorImageUtils.imageYUV420CenterCropToFloatBuffer( // load the camera frame into the buffer
                imageProxy.image, 0, 224, 224,
                TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
                TensorImageUtils.TORCHVISION_NORM_STD_RGB,
                inputTensorBuffer, 0)

            val outputTensor = mModule.forward(IValue.from(inputTensor)).toTensor() // run the model
            val scores = outputTensor.dataAsFloatArray
            var topScore = 0.0f
            var topIndex = 0
            for (index in scores.indices) { // pick the class with the highest score
                if (topScore < scores[index]) {
                    topScore = scores[index]
                    topIndex = index
                }
            }

            // Call all listeners with the new value
            listeners.forEach { it(ImageClassificationResult(topIndex, topScore)) }

            imageProxy.close()
        }
    }
}
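One caveat: forward returns the raw scores (logits) of the network's last layer, so the value shown in textView4 is not a probability. If you want a percentage-style confidence, a minimal softmax sketch (a hypothetical helper, not part of the original code) looks like this:

import kotlin.math.exp

// Softmax over raw logits: probs[i] = exp(x[i] - max) / sum_j exp(x[j] - max).
// Subtracting the max first is the standard trick for numerical stability.
fun softmax(scores: FloatArray): FloatArray {
    var max = Float.NEGATIVE_INFINITY
    for (s in scores) if (s > max) max = s
    val exps = DoubleArray(scores.size) { i -> exp((scores[i] - max).toDouble()) }
    val sum = exps.sum()
    return FloatArray(scores.size) { i -> (exps[i] / sum).toFloat() }
}

Inside analyze you would then report softmax(scores)[topIndex] instead of the raw topScore.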

Note: that's the whole PyTorch model deployment. The complete flow and all the core code are here; the rest you can extend yourself.
