|
|
|
@ -67,7 +67,30 @@ int OpenCLKernel::GetImageSize(size_t idx, lite::opencl::ImageSize *img_size) {
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
auto img_info = GpuTensorInfo(out_tensors_[idx]);
|
|
|
|
|
size_t img_dtype = ocl_runtime_->GetFp16Enable() ? CL_HALF_FLOAT : CL_FLOAT;
|
|
|
|
|
size_t img_dtype = CL_FLOAT;
|
|
|
|
|
switch (out_tensors_[idx]->data_type()) {
|
|
|
|
|
case kNumberTypeFloat32:
|
|
|
|
|
case kNumberTypeInt32:
|
|
|
|
|
case kNumberTypeUInt32: {
|
|
|
|
|
img_dtype = CL_FLOAT;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
case kNumberTypeFloat16:
|
|
|
|
|
case kNumberTypeInt16:
|
|
|
|
|
case kNumberTypeUInt16: {
|
|
|
|
|
img_dtype = CL_HALF_FLOAT;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
case kNumberTypeInt8:
|
|
|
|
|
case kNumberTypeUInt8: {
|
|
|
|
|
img_dtype = CL_UNSIGNED_INT8;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
default: {
|
|
|
|
|
MS_LOG(WARNING) << "Unsupported data_type " << out_tensors_[idx]->data_type();
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
*img_size = {img_info.width, img_info.height, img_dtype};
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
@ -326,6 +349,7 @@ std::set<size_t> OpenCLKernel::GenerateLocalByGlobal(size_t global_i) {
|
|
|
|
|
}
|
|
|
|
|
return local_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int OpenCLKernel::DequantWeight() {
|
|
|
|
|
bool is_fp16 = ocl_runtime_->GetFp16Enable();
|
|
|
|
|
auto *weight_tensor = in_tensors_.at(kWeightIndex);
|
|
|
|
@ -368,6 +392,7 @@ int OpenCLKernel::DequantWeight() {
|
|
|
|
|
}
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void OpenCLKernel::FreeDequantedWeight() {
|
|
|
|
|
auto *weight_tensor = in_tensors_.at(kWeightIndex);
|
|
|
|
|
if (dequant_flag_) {
|
|
|
|
|