|
|
|
@ -65,25 +65,58 @@ int CastFp16CPUKernel::DoCast(int thread_id) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto offset = thread_id * stride_;
|
|
|
|
|
auto output_data = out_tensors_.at(0)->MutableData();
|
|
|
|
|
switch (input->data_type()) {
|
|
|
|
|
case kNumberTypeBool:
|
|
|
|
|
BoolToFloat16(reinterpret_cast<bool *>(input->MutableData()) + offset,
|
|
|
|
|
reinterpret_cast<float16_t *>(output_data) + offset, data_num);
|
|
|
|
|
case kNumberTypeUInt8:
|
|
|
|
|
Uint8ToFloat16(reinterpret_cast<uint8_t *>(input->MutableData()) + offset,
|
|
|
|
|
reinterpret_cast<float16_t *>(output_data) + offset, data_num);
|
|
|
|
|
case kNumberTypeFloat32:
|
|
|
|
|
Float32ToFloat16(reinterpret_cast<float *>(input->MutableData()) + offset,
|
|
|
|
|
reinterpret_cast<float16_t *>(output_data) + offset, data_num);
|
|
|
|
|
break;
|
|
|
|
|
case kNumberTypeFloat16:
|
|
|
|
|
Float16ToFloat32(reinterpret_cast<float16_t *>(input->MutableData()) + offset,
|
|
|
|
|
reinterpret_cast<float *>(output_data) + offset, data_num);
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
MS_LOG(ERROR) << "Unsupported input data type " << input->data_type();
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
auto output = out_tensors_.at(0);
|
|
|
|
|
auto output_data = output->data_c();
|
|
|
|
|
auto input_data_type = input->data_type();
|
|
|
|
|
auto output_data_type = output->data_type();
|
|
|
|
|
|
|
|
|
|
if (input_data_type == kNumberTypeFloat16) {
|
|
|
|
|
switch (output_data_type) {
|
|
|
|
|
case kNumberTypeInt64:
|
|
|
|
|
Float16ToInt64(reinterpret_cast<float16_t *>(input->data_c()) + offset,
|
|
|
|
|
reinterpret_cast<int64_t *>(output_data) + offset, data_num);
|
|
|
|
|
break;
|
|
|
|
|
case kNumberTypeInt32:
|
|
|
|
|
Float16ToInt32(reinterpret_cast<float16_t *>(input->data_c()) + offset,
|
|
|
|
|
reinterpret_cast<int32_t *>(output_data) + offset, data_num);
|
|
|
|
|
break;
|
|
|
|
|
case kNumberTypeFloat32:
|
|
|
|
|
Float16ToFloat32(reinterpret_cast<float16_t *>(input->MutableData()) + offset,
|
|
|
|
|
reinterpret_cast<float *>(output_data) + offset, data_num);
|
|
|
|
|
break;
|
|
|
|
|
case kNumberTypeFloat16:
|
|
|
|
|
memcpy(reinterpret_cast<float16_t *>(output_data) + offset,
|
|
|
|
|
reinterpret_cast<float16_t *>(input->data_c()) + offset, data_num * sizeof(float16_t));
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
MS_LOG(ERROR) << "Unsupported output data type " << output_data_type;
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
} else if (input_data_type == kNumberTypeFloat32) {
|
|
|
|
|
switch (output_data_type) {
|
|
|
|
|
case kNumberTypeInt64:
|
|
|
|
|
Float32ToInt64(reinterpret_cast<float *>(input->data_c()) + offset,
|
|
|
|
|
reinterpret_cast<int64_t *>(output_data) + offset, data_num);
|
|
|
|
|
break;
|
|
|
|
|
case kNumberTypeInt32:
|
|
|
|
|
Float32ToInt32(reinterpret_cast<float *>(input->data_c()) + offset,
|
|
|
|
|
reinterpret_cast<int32_t *>(output_data) + offset, data_num);
|
|
|
|
|
break;
|
|
|
|
|
case kNumberTypeFloat32:
|
|
|
|
|
memcpy(reinterpret_cast<float *>(output_data) + offset, reinterpret_cast<float *>(input->data_c()) + offset,
|
|
|
|
|
data_num * sizeof(float));
|
|
|
|
|
break;
|
|
|
|
|
case kNumberTypeFloat16:
|
|
|
|
|
Float32ToFloat16(reinterpret_cast<float *>(input->MutableData()) + offset,
|
|
|
|
|
reinterpret_cast<float16_t *>(output_data) + offset, data_num);
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
MS_LOG(ERROR) << "Unsupported output data type " << output_data_type;
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << "Unsupported input data type " << input_data_type;
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|