From 2c002b1f6ee5fa04b09f204bc96962fe75a46bb0 Mon Sep 17 00:00:00 2001 From: yeyunpeng Date: Wed, 26 Aug 2020 18:08:53 +0800 Subject: [PATCH] Get different types of data in tensor --- .../java/com/mindspore/lite/MSTensor.java | 48 +++++-- .../lite/java/native/runtime/ms_tensor.cpp | 119 +++++++++++++++++- 2 files changed, 149 insertions(+), 18 deletions(-) diff --git a/mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/MSTensor.java b/mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/MSTensor.java index 3e4274e31e..0dcf8d2049 100644 --- a/mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/MSTensor.java +++ b/mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/MSTensor.java @@ -1,12 +1,12 @@ /** * Copyright 2020 Huawei Technologies Co., Ltd - * + *

* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + *

* http://www.apache.org/licenses/LICENSE-2.0 - * + *

* Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -18,6 +18,8 @@ package com.mindspore.lite; import android.util.Log; +import java.nio.ByteBuffer; + public class MSTensor { private long tensorPtr; @@ -29,7 +31,7 @@ public class MSTensor { this.tensorPtr = tensorPtr; } - public boolean init (int dataType, int[] shape) { + public boolean init(int dataType, int[] shape) { this.tensorPtr = createMSTensor(dataType, shape, shape.length); return this.tensorPtr != 0; } @@ -50,18 +52,30 @@ public class MSTensor { this.setDataType(this.tensorPtr, dataType); } - public byte[] getData() { - return this.getData(this.tensorPtr); + public byte[] getBtyeData() { + return this.getByteData(this.tensorPtr); } public float[] getFloatData() { - return decodeBytes(this.getData(this.tensorPtr)); + return this.getFloatData(this.tensorPtr); + } + + public int[] getIntData() { + return this.getIntData(this.tensorPtr); + } + + public long[] getLongData() { + return this.getLongData(this.tensorPtr); } public void setData(byte[] data) { this.setData(this.tensorPtr, data, data.length); } + public void setData(ByteBuffer data) { + this.setByteBufferData(this.tensorPtr, data); + } + public long size() { return this.size(this.tensorPtr); } @@ -82,13 +96,13 @@ public class MSTensor { } int size = bytes.length / 4; float[] ret = new float[size]; - for (int i = 0; i < size; i=i+4) { + for (int i = 0; i < size; i = i + 4) { int accNum = 0; accNum = accNum | (bytes[i] & 0xff) << 0; - accNum = accNum | (bytes[i+1] & 0xff) << 8; - accNum = accNum | (bytes[i+2] & 0xff) << 16; - accNum = accNum | (bytes[i+3] & 0xff) << 24; - ret[i/4] = Float.intBitsToFloat(accNum); + accNum = accNum | (bytes[i + 1] & 0xff) << 8; + accNum = accNum | (bytes[i + 2] & 0xff) << 16; + accNum = accNum | (bytes[i + 3] & 0xff) << 24; + ret[i / 4] = Float.intBitsToFloat(accNum); } return ret; } @@ -103,10 +117,18 @@ public class MSTensor { private native boolean setDataType(long tensorPtr, int dataType); - private native byte[] getData(long tensorPtr); + private native byte[] getByteData(long tensorPtr); + + private native long[] getLongData(long tensorPtr); + + private native int[] getIntData(long tensorPtr); + + private native float[] getFloatData(long tensorPtr); private native boolean setData(long tensorPtr, byte[] data, long dataLen); + private native boolean setByteBufferData(long tensorPtr, ByteBuffer buffer); + private native long size(long tensorPtr); private native int elementsNum(long tensorPtr); diff --git a/mindspore/lite/java/native/runtime/ms_tensor.cpp b/mindspore/lite/java/native/runtime/ms_tensor.cpp index 8f3a607f85..d71aca9f41 100644 --- a/mindspore/lite/java/native/runtime/ms_tensor.cpp +++ b/mindspore/lite/java/native/runtime/ms_tensor.cpp @@ -99,8 +99,8 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_MSTensor_setDataTy return ret == data_type; } -extern "C" JNIEXPORT jbyteArray JNICALL Java_com_mindspore_lite_MSTensor_getData(JNIEnv *env, jobject thiz, - jlong tensor_ptr) { +extern "C" JNIEXPORT jbyteArray JNICALL Java_com_mindspore_lite_MSTensor_getByteData(JNIEnv *env, jobject thiz, + jlong tensor_ptr) { auto *pointer = reinterpret_cast(tensor_ptr); if (pointer == nullptr) { MS_LOGE("Tensor pointer from java is nullptr"); @@ -113,16 +113,95 @@ extern "C" JNIEXPORT jbyteArray JNICALL Java_com_mindspore_lite_MSTensor_getData return env->NewByteArray(0); } - auto *float_local_data = reinterpret_cast(ms_tensor_ptr->MutableData()); - for (size_t i = 0; i < ms_tensor_ptr->ElementsNum() && i < 5; i++) { - MS_LOGE("data[%zu] = %f", i, float_local_data[i]); + if (ms_tensor_ptr->data_type() != mindspore::kNumberTypeUInt8) { + MS_LOGE("data type is error : %d", ms_tensor_ptr->data_type()); + return env->NewByteArray(0); } + auto local_data_size = ms_tensor_ptr->Size(); auto ret = env->NewByteArray(local_data_size); env->SetByteArrayRegion(ret, 0, local_data_size, local_data); return ret; } +extern "C" JNIEXPORT jlongArray JNICALL Java_com_mindspore_lite_MSTensor_getLongData(JNIEnv *env, jobject thiz, + jlong tensor_ptr) { + auto *pointer = reinterpret_cast(tensor_ptr); + if (pointer == nullptr) { + MS_LOGE("Tensor pointer from java is nullptr"); + return env->NewLongArray(0); + } + + auto *ms_tensor_ptr = static_cast(pointer); + + auto *local_data = static_cast(ms_tensor_ptr->MutableData()); + if (local_data == nullptr) { + MS_LOGD("Tensor has no data"); + return env->NewLongArray(0); + } + + if (ms_tensor_ptr->data_type() != mindspore::kNumberTypeInt64) { + MS_LOGE("data type is error : %d", ms_tensor_ptr->data_type()); + return env->NewLongArray(0); + } + auto local_data_size = ms_tensor_ptr->Size(); + auto ret = env->NewLongArray(local_data_size); + env->SetLongArrayRegion(ret, 0, local_data_size, local_data); + return ret; +} + +extern "C" JNIEXPORT jintArray JNICALL Java_com_mindspore_lite_MSTensor_getIntData(JNIEnv *env, jobject thiz, + jlong tensor_ptr) { + auto *pointer = reinterpret_cast(tensor_ptr); + if (pointer == nullptr) { + MS_LOGE("Tensor pointer from java is nullptr"); + return env->NewIntArray(0); + } + + auto *ms_tensor_ptr = static_cast(pointer); + + auto *local_data = static_cast(ms_tensor_ptr->MutableData()); + if (local_data == nullptr) { + MS_LOGD("Tensor has no data"); + return env->NewIntArray(0); + } + + if (ms_tensor_ptr->data_type() != mindspore::kNumberTypeInt32) { + MS_LOGE("data type is error : %d", ms_tensor_ptr->data_type()); + return env->NewIntArray(0); + } + auto local_data_size = ms_tensor_ptr->Size(); + auto ret = env->NewIntArray(local_data_size); + env->SetIntArrayRegion(ret, 0, local_data_size, local_data); + return ret; +} + +extern "C" JNIEXPORT jfloatArray JNICALL Java_com_mindspore_lite_MSTensor_getFloatData(JNIEnv *env, jobject thiz, + jlong tensor_ptr) { + auto *pointer = reinterpret_cast(tensor_ptr); + if (pointer == nullptr) { + MS_LOGE("Tensor pointer from java is nullptr"); + return env->NewFloatArray(0); + } + + auto *ms_tensor_ptr = static_cast(pointer); + + auto *local_data = static_cast(ms_tensor_ptr->MutableData()); + if (local_data == nullptr) { + MS_LOGD("Tensor has no data"); + return env->NewFloatArray(0); + } + + if (ms_tensor_ptr->data_type() != mindspore::kNumberTypeFloat32) { + MS_LOGE("data type is error : %d", ms_tensor_ptr->data_type()); + return env->NewFloatArray(0); + } + auto local_data_size = ms_tensor_ptr->Size(); + auto ret = env->NewFloatArray(local_data_size); + env->SetFloatArrayRegion(ret, 0, local_data_size, local_data); + return ret; +} + extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_MSTensor_setData(JNIEnv *env, jobject thiz, jlong tensor_ptr, jbyteArray data, jlong data_len) { @@ -143,6 +222,36 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_MSTensor_setData(J return static_cast(true); } +extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_MSTensor_setByteBufferData(JNIEnv *env, jobject thiz, + jlong tensor_ptr, + jobject buffer) { + jbyte *p_data = reinterpret_cast(env->GetDirectBufferAddress(buffer)); // get buffer poiter + jlong data_len = env->GetDirectBufferCapacity(buffer); // get buffer capacity + if (!p_data) { + MS_LOGE("GetDirectBufferAddress return null"); + return NULL; + } + jbyteArray data = env->NewByteArray(data_len); // create byte[] + env->SetByteArrayRegion(data, 0, data_len, p_data); // copy data to byte[] + + auto *pointer = reinterpret_cast(tensor_ptr); + if (pointer == nullptr) { + MS_LOGE("Tensor pointer from java is nullptr"); + return static_cast(false); + } + + auto *ms_tensor_ptr = static_cast(pointer); + if (data_len != ms_tensor_ptr->Size()) { + MS_LOGE("data_len(%ld) not equal to Size of ms_tensor(%zu)", data_len, ms_tensor_ptr->Size()); + return static_cast(false); + } + jboolean is_copy = false; + auto *data_arr = env->GetByteArrayElements(data, &is_copy); + auto *local_data = ms_tensor_ptr->MutableData(); + memcpy(local_data, data_arr, data_len); + return static_cast(true); +} + extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_MSTensor_size(JNIEnv *env, jobject thiz, jlong tensor_ptr) { auto *pointer = reinterpret_cast(tensor_ptr); if (pointer == nullptr) {