diff --git a/cmake/external_libs/opencl.cmake b/cmake/external_libs/opencl.cmake
index d40c2e7201..90ad02c314 100644
--- a/cmake/external_libs/opencl.cmake
+++ b/cmake/external_libs/opencl.cmake
@@ -1,9 +1,9 @@
 if (ENABLE_GITEE)
-    set(REQ_URL "https://gitee.com/mirrors/flatbuffers/repository/archive/v2020.06.16.tar.gz")
+    set(REQ_URL "https://gitee.com/mirrors/OpenCL-Headers/repository/archive/v2020.06.16.tar.gz")
     set(MD5 "fc7627b5a8a95ecbe3d5df43bc88aa44")
     set(PKG_GIT_TAG "")
     __download_pkg_with_git(OpenCL-Headers ${REQ_URL} ${PKG_GIT_TAG} ${MD5})
-    set(REQ_URL "https://gitee.com/mirrors/flatbuffers/repository/archive/v2.0.12.tar.gz")
+    set(REQ_URL "https://gitee.com/mirrors/OpenCL-CLHPP/repository/archive/v2.0.12.tar.gz")
     set(MD5 "bd00fca8f861b3b65660d719f00a58dd")
     set(PKG_GIT_TAG "")
     __download_pkg_with_git(OpenCL-CLHPP ${REQ_URL} ${PKG_GIT_TAG} ${MD5})
diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/layer_norm.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/layer_norm.cl
new file mode 100644
index 0000000000..5d5c7d162f
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/opencl/cl/layer_norm.cl
@@ -0,0 +1,111 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
+#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))
+#define C4NUM 4
+
+// Computes the mean and variance over the channel axis for every (n, h, w) position of an NHWC4 image.
+// in_shape is (N, H, W, C); normalized_shape_size is the divisor used for both statistics.
+// One work-item handles one (n*h, w) position and writes one element of mean_ and of variance_.
+// The mean sums all four lanes of every slice, so the padding lanes of the last slice are
+// presumably zero-filled by the host (NOTE(review): confirm against the image upload path).
+__kernel void ComputeMeanVarDim1NHWC4(__read_only image2d_t src_data, __global FLT *mean_, __global FLT *variance_,
+                                      int4 in_shape, int normalized_shape_size) {
+  int X = get_global_id(0);  // n*h
+  int Y = get_global_id(1);  // w
+  // Valid ids are [0, N*H) x [0, W); use >= so an id equal to the extent cannot index one past the end.
+  if (X >= in_shape.x * in_shape.y || Y >= in_shape.z || in_shape.y == 0) {
+    return;
+  }
+  int n = X / in_shape.y;
+  int h = X % in_shape.y;
+  int w = Y;
+  int row = n * in_shape.y + h;
+  int ci4 = UP_DIV(in_shape.w, C4NUM);
+  int remainder = in_shape.w % C4NUM;
+  FLT4 mean_temp = {0.0f, 0.0f, 0.0f, 0.0f};
+  FLT4 var_temp = {0.0f, 0.0f, 0.0f, 0.0f};
+  FLT mean = 0.0f;
+  FLT var = 0.0f;
+
+  // compute mean
+  for (int i = 0; i < ci4; ++i) {
+    FLT4 result_temp = READ_IMAGE(src_data, smp_none, (int2)(w * ci4 + i, row));
+    mean_temp += result_temp;
+  }
+  mean = (mean_temp.x + mean_temp.y + mean_temp.z + mean_temp.w) / normalized_shape_size;
+  mean_temp.x = mean_temp.y = mean_temp.z = mean_temp.w = mean;
+
+  // compute var
+  for (int i = 0; i < ci4; ++i) {
+    FLT4 result_temp = READ_IMAGE(src_data, smp_none, (int2)(w * ci4 + i, row));
+    if ((i + 1) * C4NUM > in_shape.w) {
+      // Last, partially filled slice: drop the mean from the padding lanes so they do not add mean^2.
+      if (remainder == 1) {
+        mean_temp.y = mean_temp.z = mean_temp.w = 0.0f;
+      } else if (remainder == 2) {
+        mean_temp.z = mean_temp.w = 0.0f;
+      } else {
+        mean_temp.w = 0.0f;
+      }
+    }
+    var_temp += (result_temp - mean_temp) * (result_temp - mean_temp);
+  }
+  var = (var_temp.x + var_temp.y + var_temp.z + var_temp.w) / normalized_shape_size;
+
+  // write result to dst
+  int position = row * in_shape.z + w;
+  mean_[position] = mean;
+  variance_[position] = var;
+}
+
+// Applies layer normalization to one C4 slice of an NHWC4 image:
+//   dst = (src - mean) / sqrt(variance + epsilon_), then * gamma + beta when elementwise_affine_ != 0.
+// in_shape is (N, H, W, C). normalized_dims_ (1, 2 or 3) selects how mean_/variance_ and gamma_/beta_ are indexed.
+// gamma_ and beta_ are read at position_gb .. position_gb + 3, so both buffers must be readable up to that index.
+__kernel void LayerNormalization_NHWC4(__read_only image2d_t src_data, __write_only image2d_t dst_data,
+                                       __global FLT *mean_, __global FLT *variance_, __global FLT *gamma_,
+                                       __global FLT *beta_, int4 in_shape, float epsilon_, int normalized_dims_,
+                                       int elementwise_affine_) {
+  int X = get_global_id(0);  // n*h
+  int Y = get_global_id(1);  // w
+  int Z = get_global_id(2);  // c4
+  int ci4 = UP_DIV(in_shape.w, C4NUM);
+  // Z counts C4 slices, not channels: compare against ci4, otherwise padded work-items with
+  // ci4 <= Z < C would write into the pixels that belong to the next w position.
+  if (X >= in_shape.x * in_shape.y || Y >= in_shape.z || Z >= ci4 || in_shape.y == 0) {
+    return;
+  }
+  int n = X / in_shape.y;
+  int h = X % in_shape.y;
+  int w = Y;
+  int c = Z;
+  int position_mv = 0;
+  int position_gb = 0;
+  if (normalized_dims_ == 1) {
+    position_mv = (n * in_shape.y + h) * in_shape.z + w;
+    position_gb = c * C4NUM;
+  } else if (normalized_dims_ == 2) {
+    position_mv = n * in_shape.y + h;
+    position_gb = w * ci4 * C4NUM + c * C4NUM;
+  } else if (normalized_dims_ == 3) {
+    position_mv = n;
+    position_gb = (h * in_shape.z + w) * ci4 * C4NUM + c * C4NUM;
+  }
+  FLT4 result_in = READ_IMAGE(src_data, smp_none, (int2)(w * ci4 + c, n * in_shape.y + h));
+  // The statistics are shared by all four lanes; load them and take the square root once.
+  FLT mean = mean_[position_mv];
+  float stddev = sqrt(variance_[position_mv] + epsilon_);
+  FLT4 result = {0.0f, 0.0f, 0.0f, 0.0f};
+  if (elementwise_affine_) {
+    result.x = ((result_in.x - mean) / stddev) * gamma_[position_gb] + beta_[position_gb];
+    result.y = ((result_in.y - mean) / stddev) * gamma_[position_gb + 1] + beta_[position_gb + 1];
+    result.z = ((result_in.z - mean) / stddev) * gamma_[position_gb + 2] + beta_[position_gb + 2];
+    result.w = ((result_in.w - mean) / stddev) * gamma_[position_gb + 3] + beta_[position_gb + 3];
+  } else {
+    result.x = (result_in.x - mean) / stddev;
+    result.y = (result_in.y - mean) / stddev;
+    result.z = (result_in.z - mean) / stddev;
+    result.w = (result_in.w - mean) / stddev;
+  }
+  WRITE_IMAGE(dst_data, (int2)((w * ci4 + c), (n * in_shape.y + h)), result);
+}
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/layer_norm.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/layer_norm.cc
new file mode 100644
index 0000000000..b98e24ff77
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/layer_norm.cc
@@ -0,0 +1,250 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include  // NOTE(review): the header names of these four includes are missing in this copy of the patch
+#include
+#include
+#include
+#include "src/kernel_registry.h"
+#include "src/runtime/kernel/opencl/kernel/layer_norm.h"
+#include "nnacl/layer_norm_parameter.h"
+#include "src/runtime/kernel/opencl/utils.h"
+#include "src/runtime/kernel/opencl/cl/layer_norm.cl.inc"
+
+using mindspore::kernel::KERNEL_ARCH::kGPU;
+using mindspore::lite::KernelRegistrar;
+using mindspore::lite::RET_ERROR;
+using mindspore::lite::RET_OK;
+using mindspore::schema::PrimitiveType_LayerNorm;
+
+namespace mindspore::kernel {
+// Validates tensor counts (3 inputs per-channel, 1 input without affine), 4-D input, 1 output, normalized_dims_ == 1.
+int LayerNormOpenCLKernel::CheckSpecs() {
+  auto param = reinterpret_cast(this->op_parameter_);  // NOTE(review): cast target type is missing in this copy
+  if (param->elementwise_mode_ == ELEMENTWISE_PER_CHANNEL) {
+    if (in_tensors_.size() != 3) {
+      MS_LOG(ERROR) << " invalid in_tensors_ size" << in_tensors_.size() << std::endl;
+      return RET_ERROR;
+    }
+    if (param->normalized_dims_ > in_tensors_.at(0)->shape().size()) {
+      MS_LOG(ERROR) << " invalid normalized_shape_ size" << param->normalized_dims_ << std::endl;
+      return RET_ERROR;
+    }
+  } else if (param->elementwise_mode_ == ELEMENTWISE_NOT) {
+    if (in_tensors_.size() != 1) {
+      MS_LOG(ERROR) << " invalid in_tensors_ size" << in_tensors_.size() << std::endl;
+      return RET_ERROR;
+    }
+  } else {
+    MS_LOG(ERROR) << "Unsupported elementwise_mode_" << param->elementwise_mode_;
+    return RET_ERROR;
+  }
+  if (in_tensors_.at(0)->shape().size() != 4 || out_tensors_.size() != 1) {
+    MS_LOG(ERROR) << "UnSupported in_tensors_.shape.size: " << in_tensors_.at(0)->shape().size()
+                  << " out_tensors_.size(): " << out_tensors_.size();
+    return RET_ERROR;
+  }
+  if (param->normalized_dims_ != 1) {  // only normalization over the last dimension is accepted
+    MS_LOG(ERROR) << "UnSupported normalized_shape_ size: " << param->normalized_dims_;
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+// Derives a 3-D local size from the global size: x is capped at 4, y at 8, z at UP_DIV(global[2], 2).
+void LayerNormGetWorkGroup(const std::vector &global, std::vector *local, int max_size) {
+  const int max_divider = 8;
+  const int max_x = 4, max_y = 8;
+  int x = std::min(GetMaxDivisorStrategy1(global[0], max_divider), max_x);
+  int yz = max_size / x;  // work-items left for the y and z dimensions
+  int y = std::min(std::min(GetMaxDivisorStrategy1(global[1], max_divider), yz), max_y);
+  int z = std::min(yz / y, static_cast(UP_DIV(global[2], 2)));
+
+  local->clear();
+  local->push_back(x);
+  local->push_back(y);
+  local->push_back(z);
+}
+// Binds the arguments that stay fixed between runs: args 6..9 of kernel_ and args 3..4 of kernel_mean_var_.
+void LayerNormOpenCLKernel::SetConstArgs() {
+  int arg_cn = 6;  // args 0..5 are bound in Run()
+  ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_shape_);
+  ocl_runtime_->SetKernelArg(kernel_, arg_cn++, epsilon_);
+  ocl_runtime_->SetKernelArg(kernel_, arg_cn++, normalized_dims_);
+  if (elementwise_affine_) {
+    ocl_runtime_->SetKernelArg(kernel_, arg_cn++, 1);
+  } else {
+    ocl_runtime_->SetKernelArg(kernel_, arg_cn++, 0);
+  }
+
+  ocl_runtime_->SetKernelArg(kernel_mean_var_, 3, in_shape_);
+  ocl_runtime_->SetKernelArg(kernel_mean_var_, 4, normalized_shape_size_);
+}
+// Builds the NDRanges of the mean/variance kernel; each global dimension is passed through UP_ROUND with its local one.
+void AlignMeanVarGlobalLocal(const std::vector &global, const std::vector &local, cl::NDRange *global_range,
+                             cl::NDRange *local_range) {
+  *local_range = cl::NDRange(local[0], local[1], local[2]);
+  *global_range =
+    cl::NDRange(UP_ROUND(global[0], local[0]), UP_ROUND(global[1], local[1]), UP_ROUND(global[2], local[2]));
+}
+// Sets the work sizes of the normalization kernel (N*H, W, C4 slices) and of the mean/variance kernel.
+void LayerNormOpenCLKernel::SetGlobalLocal() {
+  size_t OH = 1, OW = 1, OC = 1;
+  OH = in_shape_.s[0] * in_shape_.s[1];
+  OW = in_shape_.s[2];
+  OC = UP_DIV(in_shape_.s[3], C4NUM);
+  local_size_ = {1, 1, 1};  // init local
+  global_size_ = {OH, OW, OC};
+  const std::vector &max_global = ocl_runtime_->GetWorkItemSize();
+  LayerNormGetWorkGroup(global_size_, &local_size_, max_global[0]);
+  OpenCLKernel::AlignGlobalLocal(global_size_, local_size_);
+  if (normalized_dims_ != in_tensors_.at(0)->shape().size()) {  // mean/var kernel: one work-item per kept position
+    if (normalized_dims_ == 1) {
+      OH = in_shape_.s[0] * in_shape_.s[1];
+      OW = in_shape_.s[2];
+      OC = 1;
+    } else if (normalized_dims_ == 2) {
+      OH = in_shape_.s[0] * in_shape_.s[1];
+      OW = 1;
+      OC = 1;
+    } else {
+      OH = in_shape_.s[0];
+      OW = 1;
+      OC = 1;
+    }
+  } else {
+    OH = 1;
+    OW = 1;
+    OC = 1;
+  }
+  AlignMeanVarGlobalLocal({static_cast(OH), static_cast(OW), static_cast(OC)}, {1, 1, 1},
+                          &global_mean_var_, &local_mean_var_);
+}
+// Allocates gamma_/beta_ and fills them from in_tensors_[1]/[2], converting fp16 <-> fp32 when the types differ.
+int LayerNormOpenCLKernel::Initweight() {
+  auto allocator = ocl_runtime_->GetAllocator();
+  GpuTensorInfo img_info(in_tensors_.at(1));  // gamma
+  auto weight_tensor = in_tensors_.at(1);
+  size_t weight_size = img_info.Image2DSize;
+  // allocated memory for weight and init value
+  gamma_ = allocator->Malloc(weight_size);
+  beta_ = allocator->Malloc(weight_size);
+  allocator->MapBuffer(gamma_, CL_MAP_WRITE, nullptr, true);
+  allocator->MapBuffer(beta_, CL_MAP_WRITE, nullptr, true);
+  memset(gamma_, 0x01, weight_size);  // NOTE(review): sets every byte to 0x01, which is not 1.0 in fp16/fp32 - confirm
+  memset(beta_, 0x00, weight_size);
+
+  if (weight_tensor->data_type() == kNumberTypeFloat16) {
+    if (use_fp16_enable_) {
+      memcpy(gamma_, in_tensors_.at(1)->data_c(), weight_size);  // NOTE(review): copies Image2DSize bytes; confirm
+      memcpy(beta_, in_tensors_.at(2)->data_c(), weight_size);   // the source tensors hold at least that many bytes
+    } else {
+      auto gamma_fp32 = reinterpret_cast(gamma_);
+      auto beta_fp32 = reinterpret_cast(beta_);
+      auto origin_gamma_fp16 = reinterpret_cast(in_tensors_.at(1)->data_c());
+      auto origin_beta_fp16 = reinterpret_cast(in_tensors_.at(2)->data_c());
+
+      for (int i = 0; i < img_info.ElementsNum; ++i) {
+        gamma_fp32[i] = static_cast(origin_gamma_fp16[i]);
+        beta_fp32[i] = static_cast(origin_beta_fp16[i]);
+      }
+    }
+  } else {
+    if (use_fp16_enable_) {
+      auto gamma_fp16 = reinterpret_cast(gamma_);
+      auto beta_fp16 = reinterpret_cast(beta_);
+      auto origin_gamma_fp32 = reinterpret_cast(in_tensors_.at(1)->data_c());
+      auto origin_beta_fp32 = reinterpret_cast(in_tensors_.at(2)->data_c());
+
+      for (int i = 0; i < img_info.ElementsNum; ++i) {
+        gamma_fp16[i] = static_cast(origin_gamma_fp32[i]);
+        beta_fp16[i] = static_cast(origin_beta_fp32[i]);
+      }
+    } else {
+      memcpy(gamma_, in_tensors_.at(1)->data_c(), weight_size);  // NOTE(review): same size question as above
+      memcpy(beta_, in_tensors_.at(2)->data_c(), weight_size);
+    }
+  }
+  allocator->UnmapBuffer(gamma_);
+  allocator->UnmapBuffer(beta_);
+  return RET_OK;
+}
+// Reads the op parameters, allocates mean_/var_, builds both kernels, then sets constant args and work sizes.
+int LayerNormOpenCLKernel::Prepare() {
+  use_fp16_enable_ = ocl_runtime_->GetFp16Enable();
+  auto param = reinterpret_cast(this->op_parameter_);
+  elementwise_affine_ = param->elementwise_mode_;
+  normalized_dims_ = param->normalized_dims_;
+  epsilon_ = param->epsilon_;
+  if (elementwise_affine_) {
+    int ret = Initweight();
+    if (ret) {
+      MS_LOG(ERROR) << "Initweight failed ";
+      return RET_ERROR;
+    }
+  }
+  auto allocator = ocl_runtime_->GetAllocator();
+  size_t mean_size = 1;
+  size_t size = in_tensors_.at(0)->shape().size() - normalized_dims_;  // number of leading, non-normalized dims
+  for (int i = 0; i < size; ++i) {
+    mean_size *= in_tensors_.at(0)->shape()[i];
+  }
+  size_t size_dtype = use_fp16_enable_ ? sizeof(float16_t) : sizeof(float);
+  mean_size *= size_dtype;
+  mean_ = allocator->Malloc(mean_size);
+  var_ = allocator->Malloc(mean_size);
+  GpuTensorInfo img_info(in_tensors_.at(0));
+  in_shape_.s[0] = img_info.N, in_shape_.s[1] = img_info.H, in_shape_.s[2] = img_info.W, in_shape_.s[3] = img_info.C;
+
+  for (int i = 0; i < normalized_dims_; ++i) {
+    normalized_shape_size_ *= param->normalized_shape_[i];  // NOTE(review): accumulates if Prepare() runs twice
+  }
+  std::string kernel_name = "LayerNormalization_NHWC4";
+  std::string kernel_name_mean_var = "ComputeMeanVar";
+  std::set build_options;
+  std::string source = layer_norm_source;
+  std::string program_name = "LayerNormalization";
+  ocl_runtime_->LoadSource(program_name, source);
+  ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name, build_options);
+  kernel_name_mean_var += "Dim" + std::to_string(normalized_dims_) + "NHWC4";  // e.g. ComputeMeanVarDim1NHWC4
+  ocl_runtime_->BuildKernel(kernel_mean_var_, program_name, kernel_name_mean_var, build_options);
+  MS_LOG(DEBUG) << kernel_name << " Init Done!";
+  SetConstArgs();
+  SetGlobalLocal();
+
+  return RET_OK;
+}
+// Binds the per-run buffers and runs the mean/variance kernel followed by the normalization kernel.
+int LayerNormOpenCLKernel::Run() {
+  MS_LOG(DEBUG) << this->name() << " Running! ";
+  int arg1_cn = 0;
+  ocl_runtime_->SetKernelArg(kernel_mean_var_, arg1_cn++, in_tensors_.at(0)->data_c());          // input tensor
+  ocl_runtime_->SetKernelArg(kernel_mean_var_, arg1_cn++, mean_, lite::opencl::MemType::BUF);  // mean_
+  ocl_runtime_->SetKernelArg(kernel_mean_var_, arg1_cn++, var_, lite::opencl::MemType::BUF);   // var_
+  ocl_runtime_->RunKernel(kernel_mean_var_, global_mean_var_, local_mean_var_, nullptr, &event_);
+
+  int arg_cn = 0;
+  ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_.at(0)->data_c());                   // input tensor
+  ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_.at(0)->data_c());                  // out tensor
+  ocl_runtime_->SetKernelArg(kernel_, arg_cn++, mean_, lite::opencl::MemType::BUF);           // mean_
+  ocl_runtime_->SetKernelArg(kernel_, arg_cn++, var_, lite::opencl::MemType::BUF);            // var_
+  ocl_runtime_->SetKernelArg(kernel_, arg_cn++, gamma_, lite::opencl::MemType::BUF);          // gamma_
+  ocl_runtime_->SetKernelArg(kernel_, arg_cn++, beta_, lite::opencl::MemType::BUF);           // beta_
+  ocl_runtime_->RunKernel(kernel_, global_range_, local_range_, nullptr, &event_);
+  return RET_OK;
+}
+// Registers the kernel for the GPU backend with fp32 and fp16 data types.
+REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_LayerNorm, OpenCLKernelCreator)
+REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_LayerNorm, OpenCLKernelCreator)
+}  // namespace mindspore::kernel
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/layer_norm.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/layer_norm.h
new file mode 100644
index 0000000000..3bc57c12c1
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/layer_norm.h
@@ -0,0 +1,61 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_LAYER_NORM_H_
+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_LAYER_NORM_H_
+
+#include  // NOTE(review): header name is missing in this copy of the patch
+#include "src/runtime/kernel/opencl/opencl_kernel.h"
+
+namespace mindspore::kernel {
+// OpenCL LayerNorm kernel wrapper: owns the normalization kernel and a second kernel that computes mean/variance.
+class LayerNormOpenCLKernel : public OpenCLKernel {
+ public:
+  LayerNormOpenCLKernel(OpParameter *parameter, const std::vector &inputs,
+                        const std::vector &outputs)
+      : OpenCLKernel(parameter, inputs, outputs) {}
+
+  ~LayerNormOpenCLKernel() override = default;
+
+  int Run() override;
+  int Prepare() override;
+
+  int CheckSpecs() override;
+  void SetConstArgs() override;
+  void SetGlobalLocal() override;
+
+ private:
+  int Initweight();   // allocates and fills gamma_/beta_
+  void GetMeanVar();  // NOTE(review): no definition is visible in layer_norm.cc of this patch
+
+ private:
+  cl::Kernel kernel_mean_var_;                     // kernel that writes mean_ and var_
+  cl::NDRange global_mean_var_, local_mean_var_;  // work sizes of kernel_mean_var_
+  bool use_fp16_enable_{false};
+  void *gamma_{nullptr};
+  void *mean_{nullptr};
+  void *var_{nullptr};
+  void *beta_{nullptr};
+  cl_int4 in_shape_{};  // filled with N, H, W, C in Prepare()
+  int elementwise_affine_;  // NOTE(review): the only member without a default initializer; assigned in Prepare()
+  int32_t normalized_dims_{1};
+  int normalized_shape_size_{1};  // product of the normalized dimensions, multiplied up in Prepare()
+  float epsilon_{0.0f};
+  cl::Kernel kernel_;  // the normalization kernel
+};
+
+}  // namespace mindspore::kernel
+#endif
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/common.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/common.cc
index 68b8409670..57c7ece607 100644
--- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/common.cc
+++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/common.cc
@@ -104,6 +104,10 @@ void TestMain(const std::vector &input_infos, std::tupleInit() after construct subgraph like scheduler.cc
MS_LOG(DEBUG) << "call sub_graph->Init()"; + EXPECT_TRUE(sub_graph->Init() == RET_OK); + // simulating benchmark: session_->CompileGraph() -> PrepareKernels() -> OpenCLSubGraph.Prepare() MS_LOG(DEBUG) << "call sub_graph->Prepare()"; EXPECT_TRUE(sub_graph->Prepare() == RET_OK); // will set Tensor's allocator be OpenCLAllocator diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/layer_norm_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/layer_norm_tests.cc new file mode 100644 index 0000000000..c199c5c7d1 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/layer_norm_tests.cc @@ -0,0 +1,62 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "ut/src/runtime/kernel/opencl/common.h" +#include "nnacl/layer_norm_parameter.h" + +namespace mindspore::lite::opencl::test { + +class TestOpenCL_LayerNorm : public CommonTest {}; + +namespace { +// PrimitiveType_Stack: src/ops/populate/stack_populate.cc +OpParameter *CreateParameter(float epsilon, int normalized_dims_, std::vector normalizedShape) { + auto *param = test::CreateParameter(schema::PrimitiveType_LayerNorm); + param->elementwise_mode_ = ELEMENTWISE_PER_CHANNEL; + param->epsilon_ = epsilon; + param->normalized_dims_ = normalized_dims_; + for (int i = 0; i < normalizedShape.size() && i < normalized_dims_; ++i) { + param->normalized_shape_[i] = normalizedShape[i]; + } + return reinterpret_cast(param); +} +} // namespace + +TEST_F(TestOpenCL_LayerNorm, test1) { + float epsilon = 1e-5; + int normalized_dims_ = 1; + std::vector normalizedShape = {5}; + std::vector input_shape = {2, 3, 4, 5}; + std::vector gamma_shape = {1, 1, 1, 5}; + std::vector beta_shape = {1, 1, 1, 5}; + std::vector output_shape = {2, 3, 4, 5}; + size_t input_size, gamma_size, beta_size, output_size; + std::string inputPpath = "./test_data/layernormfp32_input.bin"; + std::string gammaPpath = "./test_data/gammafp32_input.bin"; + std::string betaPpath = "./test_data/betafp32_input.bin"; + std::string correctOutputPath = "./test_data/layernormfp32_output.bin"; + auto input_data = reinterpret_cast(mindspore::lite::ReadFile(inputPpath.c_str(), &input_size)); + auto gamma_data = reinterpret_cast(mindspore::lite::ReadFile(gammaPpath.c_str(), &gamma_size)); + auto beta_data = reinterpret_cast(mindspore::lite::ReadFile(betaPpath.c_str(), &beta_size)); + auto output_data = reinterpret_cast(mindspore::lite::ReadFile(correctOutputPath.c_str(), &output_size)); + for (auto fp16_enable : {false}) { + auto *param = CreateParameter(epsilon, normalized_dims_, normalizedShape); + + TestMain( + {{input_shape, input_data, VAR}, {gamma_shape, gamma_data, CONST_TENSOR}, {beta_shape, beta_data, 
CONST_TENSOR}}, + {output_shape, output_data}, param, fp16_enable, fp16_enable ? 1e-3 : 1e-6); + } +} +} // namespace mindspore::lite::opencl::test diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/test_data/layer_norm/test1/betafp32_input.bin b/mindspore/lite/test/ut/src/runtime/kernel/opencl/test_data/layer_norm/test1/betafp32_input.bin new file mode 100644 index 0000000000..df879cf495 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/opencl/test_data/layer_norm/test1/betafp32_input.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/test_data/layer_norm/test1/gammafp32_input.bin b/mindspore/lite/test/ut/src/runtime/kernel/opencl/test_data/layer_norm/test1/gammafp32_input.bin new file mode 100644 index 0000000000..774852fd9d Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/opencl/test_data/layer_norm/test1/gammafp32_input.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/test_data/layer_norm/test1/layernormfp32_input.bin b/mindspore/lite/test/ut/src/runtime/kernel/opencl/test_data/layer_norm/test1/layernormfp32_input.bin new file mode 100644 index 0000000000..d2a676b6d1 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/test_data/layer_norm/test1/layernormfp32_input.bin @@ -0,0 +1,2 @@ +SvBa9<=c?EyO> =ڿD=0>v+>F?̳'>덿8;8q>?#HsEg>@?yc>7$cKQ?13C??c@^?o?,$2?1e"?]﮿9v?NTl ?Ă> 2?}L>Z~?K +?;|?¿~?~t>#hl? ?r=ﺿ{q>?8>}I?UX~yx?P>(brn?cƁ q|?>0|?o?p???P[@>02Vt=k>0ᾶh>6=X'G=g>^zɉ? ">8|>LjǻMێI?o>>z>/}%龺ɐ&[? \ No newline at end of file diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/test_data/layer_norm/test1/layernormfp32_output.bin b/mindspore/lite/test/ut/src/runtime/kernel/opencl/test_data/layer_norm/test1/layernormfp32_output.bin new file mode 100644 index 0000000000..0a40469cec Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/opencl/test_data/layer_norm/test1/layernormfp32_output.bin differ