|
|
|
@ -351,57 +351,6 @@ std::set<size_t> OpenCLKernel::GenerateLocalByGlobal(size_t global_i) {
|
|
|
|
|
return local_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int OpenCLKernel::DequantWeight() {
|
|
|
|
|
bool is_fp16 = ocl_runtime_->GetFp16Enable();
|
|
|
|
|
auto *weight_tensor = in_tensors_.at(kWeightIndex);
|
|
|
|
|
restore_quant_data_ = weight_tensor->data_c();
|
|
|
|
|
dequant_flag_ = !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited &&
|
|
|
|
|
restore_quant_data_ != nullptr;
|
|
|
|
|
if (dequant_flag_) {
|
|
|
|
|
void *dequant_weight{nullptr};
|
|
|
|
|
bool set_flag{true};
|
|
|
|
|
if (is_fp16) {
|
|
|
|
|
#ifdef ENABLE_ARM64
|
|
|
|
|
if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt8) {
|
|
|
|
|
dequant_weight = lite::DequantUtil::DequantData<int8_t, float16_t>(weight_tensor);
|
|
|
|
|
weight_tensor->set_data_type(kNumberTypeFloat16);
|
|
|
|
|
} else if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt16) {
|
|
|
|
|
dequant_weight = lite::DequantUtil::DequantData<int16_t, float16_t>(weight_tensor);
|
|
|
|
|
weight_tensor->set_data_type(kNumberTypeFloat16);
|
|
|
|
|
} else {
|
|
|
|
|
set_flag = false;
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
set_flag = false;
|
|
|
|
|
#endif
|
|
|
|
|
} else {
|
|
|
|
|
if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt8) {
|
|
|
|
|
dequant_weight = lite::DequantUtil::DequantData<int8_t, float>(weight_tensor);
|
|
|
|
|
weight_tensor->set_data_type(kNumberTypeFloat32);
|
|
|
|
|
} else if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt16) {
|
|
|
|
|
dequant_weight = lite::DequantUtil::DequantData<int16_t, float>(weight_tensor);
|
|
|
|
|
weight_tensor->set_data_type(kNumberTypeFloat32);
|
|
|
|
|
} else {
|
|
|
|
|
set_flag = false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (set_flag && dequant_weight == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "dequant data failed.";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
weight_tensor->set_data(dequant_weight);
|
|
|
|
|
}
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void OpenCLKernel::FreeDequantedWeight() {
|
|
|
|
|
auto *weight_tensor = in_tensors_.at(kWeightIndex);
|
|
|
|
|
if (dequant_flag_) {
|
|
|
|
|
free(weight_tensor->data_c());
|
|
|
|
|
weight_tensor->set_data(restore_quant_data_);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int OpenCLKernel::CheckSpecs() {
|
|
|
|
|
if (out_mem_type_ == lite::opencl::MemType::IMG) {
|
|
|
|
|
if (!GpuTensorInfo(out_tensors_[0]).IsImageSizeValid()) {
|
|
|
|
|