You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
mindspore/model_zoo/official/lite/style_transfer/README.md

14 KiB

MindSpore Lite 端侧风格迁移demoAndroid

本示例程序演示了如何在端侧利用MindSpore Lite API以及MindSpore Lite风格迁移模型完成端侧推理根据demo内置的标准图片更换目标图片的艺术风格并在App图像预览界面中显示出来。

运行依赖

  • Android Studio >= 3.2 (推荐4.0以上版本)

构建与运行

  1. 在Android Studio中加载本示例源码。

    start_home

    启动Android Studio后点击File->Settings->System Settings->Android SDK,勾选相应的SDK Tools。如下图所示,勾选后,点击OKAndroid Studio即可自动安装SDK。

    start_sdk

    Android SDK Tools为默认安装项取消Hide Obsolete Packages选框之后可看到。

    使用过程中若出现问题可参考第4项解决。

  2. 连接Android设备运行该应用程序。

    通过USB连接Android手机。待成功识别到设备后点击Run 'app'即可在您的手机上运行本示例项目。

    编译过程中Android Studio会自动下载MindSpore Lite、模型文件等相关依赖项编译过程需做耐心等待。

    Android Studio连接设备调试操作可参考https://developer.android.com/studio/run/device?hl=zh-cn

    手机需开启“USB调试模式”Android Studio 才能识别到手机。 华为手机一般在设置->系统和更新->开发人员选项->USB调试中开始“USB调试模型”。

    run_app

  3. 在Android设备上点击“继续安装”安装完即可查看到设备摄像头捕获的内容和推理结果。

    install

    如下图所示,识别出的概率最高的物体是植物。

    result

  4. Demo部署问题解决方案。

    4.1 NDK、CMake、JDK等工具问题

    如果Android Studio内安装的工具出现无法识别等问题可重新从相应官网下载和安装并配置路径。

    • NDK >= 21.3 NDK

    • CMake >= 3.10.2 CMake

    • Android SDK >= 26 SDK

    • JDK >= 1.8 JDK

      project_structure

    4.2 NDK版本不匹配问题

    打开Android SDK,点击Show Package Details根据报错信息选择安装合适的NDK版本。 NDK_version

    4.3 Android Studio版本问题

    工具栏-help-Checkout for Updates中更新Android Studio版本。

    4.4 Gradle下依赖项安装过慢问题

    如图所示, 打开Demo根目录下build.gradle文件,加入华为镜像源地址:maven {url 'https://developer.huawei.com/repo/'}修改classpath为4.0.0,点击sync进行同步。下载完成后将classpath版本复原再次进行同步。 maven

示例程序详细说明

风格Android示例程序通过Android Camera 2 API实现摄像头获取图像帧以及相应的图像处理等功能Runtime中完成模型推理的过程。

示例程序结构


├── app
│   ├── build.gradle # 其他Android配置文件
│   ├── download.gradle # APP构建时由gradle自动从HuaWei Server下载依赖的库文件及模型文件
│   ├── proguard-rules.pro
│   └── src
│       ├── main
│       │   ├── AndroidManifest.xml # Android配置文件
│       │   ├── java # java层应用代码
│       │   │   └── com
│       │   │       └── mindspore
│       │   │           └── posenetdemo # 图像处理及推理流程实现
│       │   │               ├── CameraDataDealListener.java
│       │   │               ├── ImageUtils.java
│       │   │               ├── MainActivity.java
│       │   │               ├── PoseNetFragment.java
│       │   │               ├── Posenet.java #
│       │   │               └── TestActivity.java
│       │   └── res # 存放Android相关的资源文件
│       └── test
└── ...

下载及部署模型文件

从MindSpore Model Hub中下载模型文件本示例程序中使用的目标检测模型文件为style_predict_quant.msstyle_transfer_quant.ms,同样通过download.gradle脚本在APP构建时自动下载并放置在app/src/main/assets工程目录下。

若下载失败请手动下载模型文件style_predict_quant.ms 下载链接以及style_transfer_quant.ms 下载链接

编写端侧推理代码

在风格迁移demo中使用Java API实现端测推理。相比于C++ APIJava API可以直接在Java Class中调用无需实现JNI层的相关代码具有更好的便捷性。

风格迁移demo推理代码流程如下完整代码请参见src/main/java/com/mindspore/styletransferdemo/StyleTransferModelExecutor.java

  1. 加载MindSpore Lite模型文件构建上下文、会话以及用于推理的计算图。

    • 加载模型从文件系统中读取MindSpore Lite模型并进行模型解析。

      // Load the .ms model.
      style_predict_model = new Model();
      if (!style_predict_model.loadModel(mContext, "style_predict_quant.ms")) {
          Log.e("MS_LITE", "Load style_predict_model failed");
      }
      
      style_transform_model = new Model();
      if (!style_transform_model.loadModel(mContext, "style_transfer_quant.ms")) {
          Log.e("MS_LITE", "Load style_transform_model failed");
      }
      
    • 创建配置上下文:创建配置上下文MSConfig,保存会话所需的一些基本配置参数,用于指导图编译和图执行。

      msConfig = new MSConfig();
      if (!msConfig.init(DeviceType.DT_CPU, NUM_THREADS, CpuBindMode.MID_CPU)) {
          Log.e("MS_LITE", "Init context failed");
      }
      
    • 创建会话:创建LiteSession,并调用init方法将上一步得到MSConfig配置到会话中。

      // Create the MindSpore lite session.
      Predict_session = new LiteSession();
      if (!Predict_session.init(msConfig)) {
          Log.e("MS_LITE", "Create Predict_session failed");
          msConfig.free();
      }
      
      Transform_session = new LiteSession();
      if (!Transform_session.init(msConfig)) {
          Log.e("MS_LITE", "Create Predict_session failed");
          msConfig.free();
      }
      msConfig.free();
      
    • 加载模型文件并构建用于推理的计算图

      // Complile graph.
      if (!Predict_session.compileGraph(style_predict_model)) {
          Log.e("MS_LITE", "Compile style_predict graph failed");
          style_predict_model.freeBuffer();
      }
      if (!Transform_session.compileGraph(style_transform_model)) {
          Log.e("MS_LITE", "Compile style_transform graph failed");
          style_transform_model.freeBuffer();
      }
      
      // Note: when use model.freeBuffer(), the model can not be complile graph again.
      style_predict_model.freeBuffer();
      style_transform_model.freeBuffer();
      
  2. 输入数据: Java目前支持byte[]或者ByteBuffer两种类型的数据设置输入Tensor的数据。

    • 在输入数据之前将float数组转换为byte数组。

      
       public static byte[] floatArrayToByteArray(float[] floats) {
           ByteBuffer buffer = ByteBuffer.allocate(4 * floats.length);
           buffer.order(ByteOrder.nativeOrder());
           FloatBuffer floatBuffer = buffer.asFloatBuffer();
           floatBuffer.put(floats);
           return buffer.array();
       }
      
    • 通过ByteBuffer输入数据。contentImage为用户提供的图片,styleBitmap为预置风格图片。

       public ModelExecutionResult execute(Bitmap contentImage, Bitmap styleBitmap) {
           Log.i(TAG, "running models");
           fullExecutionTime = SystemClock.uptimeMillis();
           preProcessTime = SystemClock.uptimeMillis();
           ByteBuffer contentArray =
                   ImageUtils.bitmapToByteBuffer(contentImage, CONTENT_IMAGE_SIZE, CONTENT_IMAGE_SIZE, 0, 255);
           ByteBuffer input = ImageUtils.bitmapToByteBuffer(styleBitmap, STYLE_IMAGE_SIZE, STYLE_IMAGE_SIZE, 0, 255);
      
  3. 对输入Tensor按照模型进行推理获取输出Tensor并进行后处理。

    • 使用runGraph对预置图片进行模型推理,并获取结果Predict_results

      List<MSTensor> Predict_inputs = Predict_session.getInputs();
      if (Predict_inputs.size() != 1) {
          return null;
      }
      MSTensor Predict_inTensor = Predict_inputs.get(0);
      Predict_inTensor.setData(input);
      
      preProcessTime = SystemClock.uptimeMillis() - preProcessTime;
      stylePredictTime = SystemClock.uptimeMillis();
      
      
      if (!Predict_session.runGraph()) {
          Log.e("MS_LITE", "Run Predict_graph failed");
          return null;
      }
      stylePredictTime = SystemClock.uptimeMillis() - stylePredictTime;
      Log.d(TAG, "Style Predict Time to run: " + stylePredictTime);
      
      // Get output tensor values.
      List<String> tensorNames = Predict_session.getOutputTensorNames();
      Map<String, MSTensor> outputs = Predict_session.getOutputMapByTensor();
      Set<Map.Entry<String, MSTensor>> entrys = outputs.entrySet();
      
      float[] Predict_results = null;
      for (String tensorName : tensorNames) {
          MSTensor output = outputs.get(tensorName);
          if (output == null) {
              Log.e("MS_LITE", "Can not find Predict_session output " + tensorName);
              return null;
          }
          int type = output.getDataType();
          Predict_results = output.getFloatData();
      }
      
    • 利用上一步获取的结果,再次对用户图片进行模型推理,得到风格转换的数据transform_results

          List<MSTensor> Transform_inputs = Transform_session.getInputs();
          // transform model have 2 input tensor,  tensor0: 1*1*1*100,   tensor11*384*384*3
          MSTensor Transform_inputs_inTensor0 = Transform_inputs.get(0);
          Transform_inputs_inTensor0.setData(floatArrayToByteArray(Predict_results));
      
          MSTensor Transform_inputs_inTensor1 = Transform_inputs.get(1);
          Transform_inputs_inTensor1.setData(contentArray);
      
      
          styleTransferTime = SystemClock.uptimeMillis();
      
          if (!Transform_session.runGraph()) {
              Log.e("MS_LITE", "Run Transform_graph failed");
              return null;
          }
      
          styleTransferTime = SystemClock.uptimeMillis() - styleTransferTime;
          Log.d(TAG, "Style apply Time to run: " + styleTransferTime);
      
          postProcessTime = SystemClock.uptimeMillis();
      
          // Get output tensor values.
          List<String> Transform_tensorNames = Transform_session.getOutputTensorNames();
          Map<String, MSTensor> Transform_outputs = Transform_session.getOutputMapByTensor();
      
          float[] transform_results = null;
          for (String tensorName : Transform_tensorNames) {
              MSTensor output1 = Transform_outputs.get(tensorName);
              if (output1 == null) {
                  Log.e("MS_LITE", "Can not find Transform_session output " + tensorName);
                  return null;
              }
              transform_results = output1.getFloatData();
          }
      
    • 对输出节点的数据进行处理,得到推理后的最终结果。

      float[][][][] outputImage = new float[1][][][];  // 1 384 384 3
      for (int x = 0; x < 1; x++) {
          float[][][] arrayThree = new float[CONTENT_IMAGE_SIZE][][];
          for (int y = 0; y < CONTENT_IMAGE_SIZE; y++) {
              float[][] arrayTwo = new float[CONTENT_IMAGE_SIZE][];
              for (int z = 0; z < CONTENT_IMAGE_SIZE; z++) {
                  float[] arrayOne = new float[3];
                  for (int i = 0; i < 3; i++) {
                      int n = i + z * 3 + y * CONTENT_IMAGE_SIZE * 3 + x * CONTENT_IMAGE_SIZE * CONTENT_IMAGE_SIZE * 3;
                      arrayOne[i] = transform_results[n];
                  }
                  arrayTwo[z] = arrayOne;
              }
              arrayThree[y] = arrayTwo;
          }
          outputImage[x] = arrayThree;
      }
      
      
      Bitmap styledImage =
              ImageUtils.convertArrayToBitmap(outputImage, CONTENT_IMAGE_SIZE, CONTENT_IMAGE_SIZE);
      postProcessTime = SystemClock.uptimeMillis() - postProcessTime;
      
      fullExecutionTime = SystemClock.uptimeMillis() - fullExecutionTime;
      Log.d(TAG, "Time to run everything: $" + fullExecutionTime);
      
      return new ModelExecutionResult(styledImage,
              preProcessTime,
              stylePredictTime,
              styleTransferTime,
              postProcessTime,
              fullExecutionTime,
              formatExecutionLog());