|
|
|
@ -39,46 +39,7 @@ using mindspore::schema::PrimitiveType_Activation;
|
|
|
|
|
|
|
|
|
|
namespace mindspore::kernel {
|
|
|
|
|
|
|
|
|
|
void ActivationOpenClKernel::InitBuffer() {
|
|
|
|
|
auto allocator = lite::opencl::OpenCLRuntime::GetInstance()->GetAllocator();
|
|
|
|
|
int elem_num = UP_ROUND(nhwc_shape_[3], C4NUM);
|
|
|
|
|
alpha_buff_ = allocator->Malloc(elem_num * fp_size);
|
|
|
|
|
alpha_buff_ = allocator->MapBuffer(alpha_buff_, CL_MAP_WRITE, nullptr, true);
|
|
|
|
|
memset(alpha_buff_, 0x00, elem_num * fp_size);
|
|
|
|
|
if (in_tensors_.size() == 1) {
|
|
|
|
|
if (enable_fp16_) {
|
|
|
|
|
uint16_t alpha_fp16 = Float32ToShort(alpha_);
|
|
|
|
|
auto alpha_buff_fp16 = reinterpret_cast<uint16_t *>(alpha_buff_);
|
|
|
|
|
for (int i = 0; i < nhwc_shape_[3]; i++) {
|
|
|
|
|
alpha_buff_fp16[i] = alpha_fp16;
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
auto alpha_buff_fp16 = reinterpret_cast<float *>(alpha_buff_);
|
|
|
|
|
for (int i = 0; i < nhwc_shape_[3]; i++) {
|
|
|
|
|
alpha_buff_fp16[i] = alpha_;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
if (enable_fp16_) {
|
|
|
|
|
if (in_tensors_[1]->data_type() == kNumberTypeFloat32) {
|
|
|
|
|
auto alpha_buff_fp16 = reinterpret_cast<uint16_t *>(alpha_buff_);
|
|
|
|
|
for (int i = 0; i < nhwc_shape_[3]; i++) {
|
|
|
|
|
alpha_buff_fp16[i] = Float32ToShort(reinterpret_cast<float *>(in_tensors_[0]->Data())[i]);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
memcpy(alpha_buff_, in_tensors_[0]->Data(), nhwc_shape_[3] * fp_size);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
if (in_tensors_[1]->data_type() == kNumberTypeFloat16) {
|
|
|
|
|
MS_LOG(WARNING) << "fp16 model run in fp32 mode not support.";
|
|
|
|
|
memcpy(alpha_buff_, in_tensors_[0]->Data(), nhwc_shape_[3] * fp_size);
|
|
|
|
|
} else {
|
|
|
|
|
memcpy(alpha_buff_, in_tensors_[0]->Data(), nhwc_shape_[3] * fp_size);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
allocator->UnmapBuffer(alpha_buff_);
|
|
|
|
|
}
|
|
|
|
|
void ActivationOpenClKernel::InitBuffer() {}
|
|
|
|
|
|
|
|
|
|
int ActivationOpenClKernel::Init() {
|
|
|
|
|
in_size_ = in_tensors_[0]->shape().size();
|
|
|
|
@ -102,9 +63,6 @@ int ActivationOpenClKernel::Init() {
|
|
|
|
|
MS_LOG(ERROR) << "Activate fun only support dim=4 or 2, but your dim=" << in_size_;
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
if (type_ == ActivationType_LEAKY_RELU) {
|
|
|
|
|
InitBuffer();
|
|
|
|
|
}
|
|
|
|
|
std::map<int, std::vector<std::string>> Program_Kernel{
|
|
|
|
|
{ActivationType_LEAKY_RELU, std::vector<std::string>{"LEAKY_RELU", "LeakyRelu"}},
|
|
|
|
|
{ActivationType_RELU, std::vector<std::string>{"RELU", "Relu"}},
|
|
|
|
@ -119,9 +77,6 @@ int ActivationOpenClKernel::Init() {
|
|
|
|
|
std::set<std::string> build_options;
|
|
|
|
|
ocl_runtime->LoadSource(Program_Kernel[type_][0], source);
|
|
|
|
|
std::string kernel_name = Program_Kernel[type_][1];
|
|
|
|
|
if (type_ == ActivationType_LEAKY_RELU) {
|
|
|
|
|
kernel_name += "_" + std::string(EnumNameFormat(op_format_));
|
|
|
|
|
}
|
|
|
|
|
ocl_runtime->BuildKernel(kernel_, Program_Kernel[type_][0], kernel_name, build_options);
|
|
|
|
|
in_ori_format_ = in_tensors_[0]->GetFormat();
|
|
|
|
|
out_ori_format_ = out_tensors_[0]->GetFormat();
|
|
|
|
@ -140,10 +95,7 @@ int ActivationOpenClKernel::Run() {
|
|
|
|
|
ocl_runtime->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->Data());
|
|
|
|
|
ocl_runtime->SetKernelArg(kernel_, arg_idx++, img2d_shape);
|
|
|
|
|
if (type_ == ActivationType_LEAKY_RELU) {
|
|
|
|
|
ocl_runtime->SetKernelArg(kernel_, arg_idx++, alpha_buff_, lite::opencl::MemType::BUF);
|
|
|
|
|
cl_int4 input_shape = {static_cast<int>(nhwc_shape_[0]), static_cast<int>(nhwc_shape_[1]),
|
|
|
|
|
static_cast<int>(nhwc_shape_[2]), static_cast<int>(nhwc_shape_[3])};
|
|
|
|
|
ocl_runtime->SetKernelArg(kernel_, arg_idx++, input_shape);
|
|
|
|
|
ocl_runtime->SetKernelArg(kernel_, arg_idx++, alpha_);
|
|
|
|
|
}
|
|
|
|
|
std::vector<size_t> local = {};
|
|
|
|
|
std::vector<size_t> global = {static_cast<size_t>(img2d_shape.s[1]), static_cast<size_t>(img2d_shape.s[2])};
|
|
|
|
|