!11479 【MS】【LITE】【GPU】opencl support int dtype

From: @wangdongxu6
Reviewed-by: @HilbertDavid,@ddwsky
Signed-off-by: @ddwsky
pull/11479/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 1929b420b2

@ -67,7 +67,30 @@ int OpenCLKernel::GetImageSize(size_t idx, lite::opencl::ImageSize *img_size) {
return RET_ERROR;
}
auto img_info = GpuTensorInfo(out_tensors_[idx]);
size_t img_dtype = ocl_runtime_->GetFp16Enable() ? CL_HALF_FLOAT : CL_FLOAT;
size_t img_dtype = CL_FLOAT;
switch (out_tensors_[idx]->data_type()) {
case kNumberTypeFloat32:
case kNumberTypeInt32:
case kNumberTypeUInt32: {
img_dtype = CL_FLOAT;
break;
}
case kNumberTypeFloat16:
case kNumberTypeInt16:
case kNumberTypeUInt16: {
img_dtype = CL_HALF_FLOAT;
break;
}
case kNumberTypeInt8:
case kNumberTypeUInt8: {
img_dtype = CL_UNSIGNED_INT8;
break;
}
default: {
MS_LOG(WARNING) << "Unsupported data_type " << out_tensors_[idx]->data_type();
return RET_ERROR;
}
}
*img_size = {img_info.width, img_info.height, img_dtype};
return RET_OK;
}
@ -326,6 +349,7 @@ std::set<size_t> OpenCLKernel::GenerateLocalByGlobal(size_t global_i) {
}
return local_;
}
int OpenCLKernel::DequantWeight() {
bool is_fp16 = ocl_runtime_->GetFp16Enable();
auto *weight_tensor = in_tensors_.at(kWeightIndex);
@ -368,6 +392,7 @@ int OpenCLKernel::DequantWeight() {
}
return RET_OK;
}
void OpenCLKernel::FreeDequantedWeight() {
auto *weight_tensor = in_tensors_.at(kWeightIndex);
if (dequant_flag_) {

@ -254,7 +254,10 @@ int OpenCLSubGraph::UpdateTensorDataTypePass() {
for (auto jv : cur_outs) {
if (out_set.count(jv) == 0) {
MS_ASSERT(jv);
jv->set_data_type(kNumberTypeFloat16);
// When fp16 is enabled, only convert fp32 tensors to fp16; every other dtype is preserved as-is.
if (jv->data_type() == kNumberTypeFloat32) {
jv->set_data_type(kNumberTypeFloat16);
}
}
}
}

@ -140,9 +140,19 @@ void *OpenCLAllocator::_Malloc(MemType mem_type, void *data, size_t size, const
auto svm_capabilities = ocl_runtime_->GetSVMCapabilities();
MS_ASSERT(img_size.size() == 0 || img_size.size() == 3);
if (mem_type == MemType::IMG) {
size_t dtype_size = img_size.dtype == CL_FLOAT ? sizeof(cl_float4) : sizeof(cl_half4);
size_t dtype_size = 0;
if (img_size.dtype == CL_FLOAT) {
dtype_size = sizeof(cl_float);
} else if (img_size.dtype == CL_HALF_FLOAT) {
dtype_size = sizeof(cl_half);
} else if (img_size.dtype == CL_UNSIGNED_INT8) {
dtype_size = sizeof(cl_uchar);
} else {
MS_LOG(ERROR) << "Unsupported dtype " << img_size.dtype;
return nullptr;
}
uint32_t image_alignment = ocl_runtime_->GetImagePitchAlignment();
size = UP_ROUND(img_size.width, image_alignment) * img_size.height * dtype_size;
size = UP_ROUND(img_size.width, image_alignment) * img_size.height * C4NUM * dtype_size;
}
if (size > ocl_runtime_->GetMaxAllocSize()) {
MS_LOG(ERROR) << "MallocData out of max_size, size: " << size;

@ -360,7 +360,7 @@ int OpenCLRuntime::BuildKernel(cl::Kernel &kernel, const std::string &program_na
"-DWRITE_IMAGE=write_imageh -DREAD_IMAGE=read_imageh -DTO_FLT=convert_half -DTO_FLT4=convert_half4";
} else {
build_option +=
" -DFLT=float -DFLT4=float4 -DFLT16=float16 -DAS_FLT4=as_float4 -DAS_UINT4=as_uint4 -DUINT4=uint4 "
" -DFLT=float -DFLT4=float4 -DFLT16=float16 -DAS_FLT4=as_float4 -DAS_UINT4=as_uint4 -DUINT4=uint4 "
"-DWRITE_IMAGE=write_imagef -DREAD_IMAGE=read_imagef -DTO_FLT=convert_float -DTO_FLT4=convert_float4";
}
build_option =

Loading…
Cancel
Save