add supported data types for cast ops

pull/8287/head
liuwenhao4 4 years ago
parent d79bcc923e
commit bcaf43e0fb

@ -17,6 +17,7 @@
#include <vector> #include <vector>
#include "schema/model_generated.h" #include "schema/model_generated.h"
#include "src/kernel_registry.h" #include "src/kernel_registry.h"
#include "src/tensor.h"
#include "nnacl/fp32/cast.h" #include "nnacl/fp32/cast.h"
#include "nnacl/op_base.h" #include "nnacl/op_base.h"
#include "src/runtime/runtime_api.h" #include "src/runtime/runtime_api.h"
@ -70,6 +71,12 @@ int CastCPUKernel::DoCast(int thread_id) {
MS_ASSERT(output_data != nullptr); MS_ASSERT(output_data != nullptr);
auto input_data_type = input->data_type(); auto input_data_type = input->data_type();
auto output_data_type = output->data_type(); auto output_data_type = output->data_type();
if (input_data_type == output_data_type) {
auto datalen = lite::DataTypeSize(input_data_type);
memcpy(reinterpret_cast<char *>(output_data) + offset * datalen,
reinterpret_cast<char *>(input->data_c()) + offset * datalen, data_num * datalen);
return RET_OK;
}
if (output_data_type != kNumberTypeFloat32) { if (output_data_type != kNumberTypeFloat32) {
if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeInt64) { if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeInt64) {
Float32ToInt64(reinterpret_cast<float *>(input->data_c()) + offset, Float32ToInt64(reinterpret_cast<float *>(input->data_c()) + offset,
@ -83,9 +90,6 @@ int CastCPUKernel::DoCast(int thread_id) {
} else if (input_data_type == kNumberTypeInt32 && output_data_type == kNumberTypeInt64) { } else if (input_data_type == kNumberTypeInt32 && output_data_type == kNumberTypeInt64) {
Int32ToInt64(reinterpret_cast<int32_t *>(input->data_c()) + offset, Int32ToInt64(reinterpret_cast<int32_t *>(input->data_c()) + offset,
reinterpret_cast<int64_t *>(output_data) + offset, data_num); reinterpret_cast<int64_t *>(output_data) + offset, data_num);
} else if (input_data_type == kNumberTypeInt32 && output_data_type == kNumberTypeInt32) {
memcpy(reinterpret_cast<int32_t *>(output_data) + offset, reinterpret_cast<int32_t *>(input->data_c()) + offset,
data_num * sizeof(int32_t));
} else { } else {
MS_LOG(ERROR) << "Unsupported datatype from " << input_data_type << " to " << output_data_type; MS_LOG(ERROR) << "Unsupported datatype from " << input_data_type << " to " << output_data_type;
return RET_ERROR; return RET_ERROR;
@ -108,10 +112,6 @@ int CastCPUKernel::DoCast(int thread_id) {
Fp16ToFloat32(reinterpret_cast<uint16_t *>(input->MutableData()) + offset, Fp16ToFloat32(reinterpret_cast<uint16_t *>(input->MutableData()) + offset,
reinterpret_cast<float *>(output_data) + offset, data_num); reinterpret_cast<float *>(output_data) + offset, data_num);
break; break;
case kNumberTypeFloat32:
memcpy(reinterpret_cast<float *>(output_data) + offset, reinterpret_cast<float *>(input->data_c()) + offset,
data_num * sizeof(float));
break;
default: default:
MS_LOG(ERROR) << "Unsupported input data type " << input_data_type; MS_LOG(ERROR) << "Unsupported input data type " << input_data_type;
return RET_ERROR; return RET_ERROR;

Loading…
Cancel
Save