|
|
|
@ -71,7 +71,8 @@ int CastCPUKernel::DoCast(int thread_id) {
|
|
|
|
|
auto input_data_type = input->data_type();
|
|
|
|
|
auto output_data_type = output->data_type();
|
|
|
|
|
if (output_data_type != kNumberTypeFloat32) {
|
|
|
|
|
if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeInt32) {
|
|
|
|
|
if (input_data_type == kNumberTypeFloat32 &&
|
|
|
|
|
(output_data_type == kNumberTypeInt32 || output_data_type == kNumberTypeInt64)) {
|
|
|
|
|
Float32ToInt32(reinterpret_cast<float *>(input->data_c()) + offset,
|
|
|
|
|
reinterpret_cast<int32_t *>(output_data) + offset, data_num);
|
|
|
|
|
} else if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeFloat16) {
|
|
|
|
@ -81,10 +82,6 @@ int CastCPUKernel::DoCast(int thread_id) {
|
|
|
|
|
(output_data_type == kNumberTypeInt32 || output_data_type == kNumberTypeInt64)) {
|
|
|
|
|
memcpy(reinterpret_cast<int32_t *>(output_data) + offset, reinterpret_cast<int32_t *>(input->data_c()) + offset,
|
|
|
|
|
data_num * sizeof(int32_t));
|
|
|
|
|
} else if (input_data_type == kNumberTypeFloat32 &&
|
|
|
|
|
(output_data_type == kNumberTypeInt32 || output_data_type == kNumberTypeInt64)) {
|
|
|
|
|
memcpy(reinterpret_cast<float *>(output_data) + offset, reinterpret_cast<float *>(input->data_c()) + offset,
|
|
|
|
|
data_num * sizeof(float));
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << "Unsupported datatype from " << input_data_type << " to " << output_data_type;
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|