GPU ops rectification and bug fix

pull/9360/head
Pengyongrong 4 years ago
parent 51d885815a
commit 1fb298b68b

@@ -147,7 +147,7 @@ __kernel void ElementFloorMod(__read_only image2d_t input_a, __read_only image2d
FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(X, Y));
FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(X, Y));
FLT4 result = floor(divide_no_check(a, b)) * b;
FLT4 result = a - floor(divide_no_check(a, b)) * b;
result = clamp(result, (FLT)(act_min), (FLT)(act_max));
WRITE_IMAGE(output, (int2)(X, Y), result);
}
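The fix above replaces floor(divide_no_check(a, b)) * b with a - floor(divide_no_check(a, b)) * b, which is the standard floor-modulo definition (the remainder takes the sign of the divisor). A minimal host-side sketch of the same formula, for reference only and not part of the patch:

#include <cmath>
#include <cstdio>

// Floor modulo, matching the corrected kernel line: a - floor(a / b) * b.
static float FloorMod(float a, float b) { return a - std::floor(a / b) * b; }

int main() {
  std::printf("%f\n", FloorMod(-1.1f, 3.0f));  // floor(-0.3667) = -1, so -1.1 + 3 = 1.9
  return 0;
}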
@@ -445,33 +445,39 @@ __kernel void BroadcastFloorDiv(__read_only image2d_t input_a, float b, __write_
result = clamp(result, (FLT)(act_min), (FLT)(act_max));
WRITE_IMAGE(output, (int2)(X, Y), result);
}
__kernel void BroadcastFloorMod(__read_only image2d_t input_a, float b, __write_only image2d_t output,
const int2 output_shape, float act_min, float act_max) {
int X = get_global_id(0);
int Y = get_global_id(1);
if (X >= output_shape.x || Y >= output_shape.y) {
__kernel void BroadcastNHWC4FloorMod(__read_only image2d_t input_a, __read_only image2d_t input_b,
__write_only image2d_t output, const int4 a_shape, const int4 b_shape,
const int4 output_shape, const int broadcastC_flag, float act_min, float act_max) {
int X = get_global_id(0); // C4
int Y = get_global_id(1); // W
int Z = get_global_id(2); // H
if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.y) {
return;
}
FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(X, Y));
FLT4 result = floor(divide_no_check(a, (FLT4)b)) * (FLT)b;
FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(Y * a_shape.w + X, Z));
FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(X, 0));
FLT4 result = a - floor(divide_no_check(a, b)) * b;
result = clamp(result, (FLT)(act_min), (FLT)(act_max));
WRITE_IMAGE(output, (int2)(X, Y), result);
WRITE_IMAGE(output, (int2)(Y * output_shape.w + X, Z), result);
}
__kernel void BroadcastSquaredDifference(__read_only image2d_t input_a, float b, __write_only image2d_t output,
const int2 output_shape, float act_min, float act_max) {
int X = get_global_id(0);
int Y = get_global_id(1);
if (X >= output_shape.x || Y >= output_shape.y) {
__kernel void BroadcastNHWC4SquaredDifference(__read_only image2d_t input_a, __read_only image2d_t input_b,
__write_only image2d_t output, const int4 a_shape, const int4 b_shape,
const int4 output_shape, const int broadcastC_flag, float act_min,
float act_max) {
int X = get_global_id(0); // C4
int Y = get_global_id(1); // W
int Z = get_global_id(2); // H
if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.y) {
return;
}
FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(X, Y));
FLT4 result = pown((a - (FLT4)b), (int4)2);
FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(Y * a_shape.w + X, Z));
FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(X, 0));
FLT4 result = pown((a - b), (int4)2);
result = clamp(result, (FLT)(act_min), (FLT)(act_max));
WRITE_IMAGE(output, (int2)(X, Y), result);
WRITE_IMAGE(output, (int2)(Y * output_shape.w + X, Z), result);
}
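Both rewritten broadcast kernels (BroadcastNHWC4FloorMod and BroadcastNHWC4SquaredDifference) address the NHWC4 images the same way: element (h, w, c4) of the first input maps to image coordinate (w * a_shape.w + c4, h), and the second input is read at (c4, 0), i.e. a 1x1xC tensor broadcast across H and W. A small sketch of that coordinate mapping, assuming this layout (the helper name is illustrative, not from the patch):

#include <utility>

// Map an NHWC4 element (h, w, c4) to 2D image coordinates as the kernels do:
// x = w * c4_num + c4, y = h, where c4_num = ceil(C / 4).
static std::pair<int, int> Nhwc4ToImageCoord(int h, int w, int c4, int c4_num) {
  return {w * c4_num + c4, h};
}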
__kernel void BroadcastEqual(__read_only image2d_t input_a, float b, __write_only image2d_t output,

@@ -1,11 +1,11 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#define INT4 int4
#define INT2 int2
#define C4NUM 4
__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;
__kernel void Batch_normalization_NHWC4(__read_only image2d_t input, __read_only image2d_t scale,
__read_only image2d_t offset, __read_only image2d_t mean,
__read_only image2d_t variance, __write_only image2d_t output,
const INT4 input_shape, float epsilon) {
__kernel void Batch_normalization_NHWC4(__read_only image2d_t input, __global FLT *scale, __global FLT *offset,
__global FLT *mean, __global FLT *variance, __write_only image2d_t output,
const INT4 input_shape, float epsilon, int unalign_input_w) {
int X = get_global_id(0); // H
int Y = get_global_id(1); // W
int Z = get_global_id(2); // C/4
@@ -14,38 +14,28 @@ __kernel void Batch_normalization_NHWC4(__read_only image2d_t input, __read_only
}
FLT4 result = READ_IMAGE(input, smp_none, (int2)((Y)*input_shape.w + Z, (X)));
FLT4 result_mean = READ_IMAGE(mean, smp_none, (int2)((Z), (0)));
FLT4 result_var = READ_IMAGE(variance, smp_none, (int2)((Z), (0)));
FLT4 result_scale = READ_IMAGE(scale, smp_none, (int2)((Z), (0)));
FLT4 result_offset = READ_IMAGE(offset, smp_none, (int2)((Z), (0)));
result.x = result_scale.x * ((result.x - result_mean.x) / sqrt(result_var.x + epsilon)) + result_offset.x;
result.y = result_scale.y * ((result.y - result_mean.y) / sqrt(result_var.y + epsilon)) + result_offset.y;
result.z = result_scale.z * ((result.z - result_mean.z) / sqrt(result_var.z + epsilon)) + result_offset.z;
result.w = result_scale.w * ((result.w - result_mean.w) / sqrt(result_var.w + epsilon)) + result_offset.w;
WRITE_IMAGE(output, (int2)((Y)*input_shape.w + Z, (X)), result);
}
__kernel void Batch_normalization_NC4HW4(__read_only image2d_t input, __read_only image2d_t scale,
__read_only image2d_t offset, __read_only image2d_t mean,
__read_only image2d_t variance, __write_only image2d_t output,
const INT4 input_shape, float epsilon) {
int X = get_global_id(0); // H
int Y = get_global_id(1); // W
int Z = get_global_id(2); // C/4
if (X >= input_shape.y || Y >= input_shape.z || Z >= input_shape.w) {
return;
FLT result_mean[4] = {0.0f, 0.0f, 0.0f, 0.0f};
FLT result_var[4] = {0.0f, 0.0f, 0.0f, 0.0f};
FLT result_scale[4] = {1.0f, 1.0f, 1.0f, 1.0f};
FLT result_offset[4] = {0.0f, 0.0f, 0.0f, 0.0f};
if ((Z + 1) * C4NUM <= unalign_input_w) {
for (int i = 0; i < C4NUM; ++i) {
result_mean[i] = mean[Z * C4NUM + i];
result_var[i] = variance[Z * C4NUM + i];
result_scale[i] = scale[Z * C4NUM + i];
result_offset[i] = offset[Z * C4NUM + i];
}
} else {
for (int i = 0; i < unalign_input_w % C4NUM; ++i) {
result_mean[i] = mean[Z * C4NUM + i];
result_var[i] = variance[Z * C4NUM + i];
result_scale[i] = scale[Z * C4NUM + i];
result_offset[i] = offset[Z * C4NUM + i];
}
}
FLT4 result = READ_IMAGE(input, smp_none, (int2)((Y), (Z * input_shape.y + X)));
FLT4 result_mean = READ_IMAGE(mean, smp_none, (int2)((0), (Z)));
FLT4 result_var = READ_IMAGE(variance, smp_none, (int2)((0), (Z)));
FLT4 result_scale = READ_IMAGE(scale, smp_none, (int2)((0), (Z)));
FLT4 result_offset = READ_IMAGE(offset, smp_none, (int2)((0), (Z)));
result.x = result_scale.x * ((result.x - result_mean.x) / sqrt(result_var.x + epsilon)) + result_offset.x;
result.y = result_scale.y * ((result.y - result_mean.y) / sqrt(result_var.y + epsilon)) + result_offset.y;
result.z = result_scale.z * ((result.z - result_mean.z) / sqrt(result_var.z + epsilon)) + result_offset.z;
result.w = result_scale.w * ((result.w - result_mean.w) / sqrt(result_var.w + epsilon)) + result_offset.w;
WRITE_IMAGE(output, (int2)((Y), (Z * input_shape.y + X)), result);
result.x = result_scale[0] * ((result.x - result_mean[0]) / sqrt(result_var[0] + epsilon)) + result_offset[0];
result.y = result_scale[1] * ((result.y - result_mean[1]) / sqrt(result_var[1] + epsilon)) + result_offset[1];
result.z = result_scale[2] * ((result.z - result_mean[2]) / sqrt(result_var[2] + epsilon)) + result_offset[2];
result.w = result_scale[3] * ((result.w - result_mean[3]) / sqrt(result_var[3] + epsilon)) + result_offset[3];
WRITE_IMAGE(output, (int2)((Y)*input_shape.w + Z, (X)), result);
}
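In the rewritten Batch_normalization_NHWC4, scale, offset, mean and variance come from __global buffers rather than images, and the last C4 slot is filled only up to the real channel count (unalign_input_w), with identity defaults (scale 1, the rest 0) in the padded lanes. A hedged sketch of the lane count the two branches effectively use (illustrative helper, not in the patch):

// Number of valid lanes in C4 slot z for a tensor with `channels` channels.
static int ValidLanes(int z, int channels) {
  constexpr int kC4Num = 4;
  const int remaining = channels - z * kC4Num;
  return remaining >= kC4Num ? kC4Num : remaining;  // tail slot gets channels % 4 lanes
}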

@@ -1,5 +1,7 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;
#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))
#define C4NUM 4
#define CHECK_IDX \
int X = get_global_id(0); \
int Y = get_global_id(1); \
@@ -22,17 +24,27 @@ FLT OptimizedPowerImpl(FLT x, int exponent) {
return exponent >= 0 ? result : 1 / result;
}
__kernel void power(__read_only image2d_t input0, __read_only image2d_t input1, __write_only image2d_t output,
__kernel void power(__read_only image2d_t input0, __global FLT *input1, __write_only image2d_t output,
int4 output_shape, FLT4 parameter) {
CHECK_IDX;
int n = X / output_shape.y;
int h = X % output_shape.y;
int unalign_w = (int)parameter.w;
FLT4 result;
FLT4 result0 = READ_IMAGE(input0, smp_none, (int2)((Y)*output_shape.w + Z, (n * output_shape.y + h)));
FLT4 result1 = READ_IMAGE(input1, smp_none, (int2)((Y)*output_shape.w + Z, (n * output_shape.y + h)));
int index_weight = (n * output_shape.y + h) * output_shape.z * unalign_w + Y * unalign_w + Z * C4NUM;
FLT tmp_result[4];
FLT tmp_result0[4] = {result0.x, result0.y, result0.z, result0.w};
FLT tmp_result1[4] = {result1.x, result1.y, result1.z, result1.w};
FLT tmp_result1[4] = {0.0f, 0.0f, 0.0f, 0.0f};
if ((Z + 1) * C4NUM <= unalign_w) {
for (int i = 0; i < C4NUM; ++i) {
tmp_result1[i] = input1[index_weight + i];
}
} else {
for (int i = 0; i < unalign_w % C4NUM; ++i) {
tmp_result1[i] = input1[index_weight + i];
}
}
for (int i = 0; i < 4; ++i) {
tmp_result0[i] = tmp_result0[i] * parameter.z + parameter.y;
if (floor(tmp_result1[i]) == tmp_result1[i]) {

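The power kernel now takes the exponent tensor as a __global buffer; parameter.w carries the unaligned channel count, and index_weight is the flat NHWC offset of the first lane of C4 slot Z, with the same tail handling as above when the channel count is not a multiple of 4. A sketch of that offset arithmetic (hypothetical helper mirroring the kernel's expression):

// Flat NHWC offset of element (n, h, w, c) when the innermost dimension has `channels` elements.
static int FlatNhwcOffset(int n, int h, int w, int c, int H, int W, int channels) {
  return ((n * H + h) * W + w) * channels + c;
}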
@@ -30,7 +30,13 @@ using mindspore::schema::PrimitiveType_BatchNorm;
namespace mindspore::kernel {
int BatchNormOpenCLKernel::CheckSpecs() { return RET_OK; }
int BatchNormOpenCLKernel::CheckSpecs() {
if (in_tensors_.at(0)->shape()[0] > 1) {
MS_LOG(ERROR) << " Unsupported batch_size >1 ";
return RET_ERROR;
}
return RET_OK;
}
void BatchNormGetWorkGroup(const std::vector<size_t> &global, std::vector<size_t> *local, int max_size) {
const int max_divider = 8;
@@ -49,17 +55,19 @@ void BatchNormGetWorkGroup(const std::vector<size_t> &global, std::vector<size_t
void BatchNormOpenCLKernel::SetConstArgs() {
int arg_cn = 6;
auto param = reinterpret_cast<BatchNormParameter *>(this->op_parameter_);
auto input0_shape = in_tensors_[0]->shape();
cl_int4 input_shape_ = {input0_shape[0], input0_shape[1], input0_shape[2], UP_DIV(input0_shape[3], C4NUM)};
auto input0_shape = in_tensors_.at(0)->shape();
cl_int4 input_shape_ = {input0_shape.at(0), input0_shape.at(1), input0_shape.at(2),
UP_DIV(input0_shape.at(3), C4NUM)};
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, input_shape_);
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, param->epsilon_);
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, input0_shape.at(3));
}
void BatchNormOpenCLKernel::SetGlobalLocal() {
auto output_shape = out_tensors_[0]->shape();
uint32_t OH = output_shape[1];
uint32_t OW = output_shape[2];
uint32_t OC = UP_DIV(output_shape[3], C4NUM);
auto output_shape = out_tensors_.at(0)->shape();
uint32_t OH = output_shape.at(1);
uint32_t OW = output_shape.at(2);
uint32_t OC = UP_DIV(output_shape.at(3), C4NUM);
const std::vector<size_t> &max_global = ocl_runtime_->GetWorkItemSize();
local_size_ = {1, 1, 1}; // init local
@@ -68,7 +76,86 @@ void BatchNormOpenCLKernel::SetGlobalLocal() {
OpenCLKernel::AlignGlobalLocal(global_size_, local_size_);
}
int BatchNormOpenCLKernel::Initweight() {
auto allocator = ocl_runtime_->GetAllocator();
GpuTensorInfo img_info(in_tensors_.at(1));
auto weight_tensor = in_tensors_.at(1);
size_t weight_size = img_info.OriginSize;
// allocated memory for weight and init value
scale_ = allocator->Malloc(weight_size);
offset_ = allocator->Malloc(weight_size);
mean_ = allocator->Malloc(weight_size);
variance_ = allocator->Malloc(weight_size);
allocator->MapBuffer(scale_, CL_MAP_WRITE, nullptr, true);
allocator->MapBuffer(offset_, CL_MAP_WRITE, nullptr, true);
allocator->MapBuffer(mean_, CL_MAP_WRITE, nullptr, true);
allocator->MapBuffer(variance_, CL_MAP_WRITE, nullptr, true);
memset(scale_, 1, weight_size);
memset(offset_, 0x00, weight_size);
memset(mean_, 0x00, weight_size);
memset(variance_, 0x00, weight_size);
if (weight_tensor->data_type() == kNumberTypeFloat16) {
if (use_fp16_enable_) {
memcpy(scale_, in_tensors_.at(1)->data_c(), weight_size);
memcpy(offset_, in_tensors_.at(2)->data_c(), weight_size);
memcpy(mean_, in_tensors_.at(3)->data_c(), weight_size);
memcpy(variance_, in_tensors_.at(4)->data_c(), weight_size);
} else {
auto scale_fp32 = reinterpret_cast<float *>(scale_);
auto offset_fp32 = reinterpret_cast<float *>(offset_);
auto mean_fp32 = reinterpret_cast<float *>(mean_);
auto variance_fp32 = reinterpret_cast<float *>(variance_);
auto origin_scale_fp16 = reinterpret_cast<float16_t *>(in_tensors_.at(1)->data_c());
auto origin_offset_fp16 = reinterpret_cast<float16_t *>(in_tensors_.at(2)->data_c());
auto origin_mean_fp16 = reinterpret_cast<float16_t *>(in_tensors_.at(3)->data_c());
auto origin_variance_fp16 = reinterpret_cast<float16_t *>(in_tensors_.at(4)->data_c());
for (int i = 0; i < img_info.ElementsNum; ++i) {
scale_fp32[i] = static_cast<float>(origin_scale_fp16[i]);
offset_fp32[i] = static_cast<float>(origin_offset_fp16[i]);
mean_fp32[i] = static_cast<float>(origin_mean_fp16[i]);
variance_fp32[i] = static_cast<float>(origin_variance_fp16[i]);
}
}
} else {
if (use_fp16_enable_) {
auto scale_fp16 = reinterpret_cast<float16_t *>(scale_);
auto offset_fp16 = reinterpret_cast<float16_t *>(offset_);
auto mean_fp16 = reinterpret_cast<float16_t *>(mean_);
auto variance_fp16 = reinterpret_cast<float16_t *>(variance_);
auto origin_scale_fp32 = reinterpret_cast<float *>(in_tensors_.at(1)->data_c());
auto origin_offset_fp32 = reinterpret_cast<float *>(in_tensors_.at(2)->data_c());
auto origin_mean_fp32 = reinterpret_cast<float *>(in_tensors_.at(3)->data_c());
auto origin_variance_fp32 = reinterpret_cast<float *>(in_tensors_.at(4)->data_c());
for (int i = 0; i < img_info.ElementsNum; ++i) {
scale_fp16[i] = static_cast<float16_t>(origin_scale_fp32[i]);
offset_fp16[i] = static_cast<float16_t>(origin_offset_fp32[i]);
mean_fp16[i] = static_cast<float16_t>(origin_mean_fp32[i]);
variance_fp16[i] = static_cast<float16_t>(origin_variance_fp32[i]);
}
} else {
memcpy(scale_, in_tensors_.at(1)->data_c(), weight_size);
memcpy(offset_, in_tensors_.at(2)->data_c(), weight_size);
memcpy(mean_, in_tensors_.at(3)->data_c(), weight_size);
memcpy(variance_, in_tensors_.at(4)->data_c(), weight_size);
}
}
allocator->UnmapBuffer(scale_);
allocator->UnmapBuffer(offset_);
allocator->UnmapBuffer(mean_);
allocator->UnmapBuffer(variance_);
return RET_OK;
}
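Initweight repeats the same convert-or-copy pattern for each of scale, offset, mean and variance, switching between fp16 and fp32 depending on the weight tensor's type and the kernel's fp16 mode. A hedged sketch of how that pattern could be factored (hypothetical helper, not part of the patch; float16_t comes from the surrounding codebase):

#include <cstddef>
#include <cstring>
#include <type_traits>

// Copy when the element types match, otherwise convert element by element.
template <typename SrcT, typename DstT>
void ConvertOrCopy(void *dst, const void *src, std::size_t element_num, std::size_t byte_size) {
  if (std::is_same<SrcT, DstT>::value) {
    std::memcpy(dst, src, byte_size);
  } else {
    auto *d = reinterpret_cast<DstT *>(dst);
    auto *s = reinterpret_cast<const SrcT *>(src);
    for (std::size_t i = 0; i < element_num; ++i) {
      d[i] = static_cast<DstT>(s[i]);
    }
  }
}
// Usage sketch: ConvertOrCopy<float16_t, float>(scale_, in_tensors_.at(1)->data_c(), img_info.ElementsNum, weight_size);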
int BatchNormOpenCLKernel::Prepare() {
use_fp16_enable_ = ocl_runtime_->GetFp16Enable();
std::string kernel_name = "Batch_normalization_NHWC4";
std::set<std::string> build_options;
std::string source = batchnorm_source;
@@ -76,6 +163,11 @@ int BatchNormOpenCLKernel::Prepare() {
ocl_runtime_->LoadSource(program_name, source);
ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name, build_options);
MS_LOG(DEBUG) << kernel_name << " Init Done!";
int ret = Initweight();
if (ret) {
MS_LOG(ERROR) << "Initweight failed ";
return RET_ERROR;
}
SetConstArgs();
SetGlobalLocal();
@@ -85,12 +177,12 @@ int BatchNormOpenCLKernel::Prepare() {
int BatchNormOpenCLKernel::Run() {
MS_LOG(DEBUG) << this->name() << " Running! ";
int arg_cn = 0;
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[0]->data_c()); // input tensor
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[1]->data_c()); // scale
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[2]->data_c()); // offset
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[3]->data_c()); // mean
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[4]->data_c()); // variance
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->data_c()); // out tensor
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_.at(0)->data_c()); // input tensor
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, scale_, lite::opencl::MemType::BUF); // scale
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, offset_, lite::opencl::MemType::BUF); // offset
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, mean_, lite::opencl::MemType::BUF); // mean
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, variance_, lite::opencl::MemType::BUF); // variance
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_.at(0)->data_c()); // out tensor
ocl_runtime_->RunKernel(kernel_, global_range_, local_range_, nullptr, &event_);
return RET_OK;
}
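With the buffer-based weights, Run() binds slots 0-5 (input image, then scale/offset/mean/variance as BUF, then the output image), and SetConstArgs() continues at index 6 with the input shape, epsilon and the unaligned channel count. A comment-style sketch of the slot layout inferred from the code above (the enum itself is illustrative, not in the patch):

// Kernel argument slots for Batch_normalization_NHWC4, as set by Run() and SetConstArgs().
enum BatchNormArgSlot {
  kInput = 0, kScale = 1, kOffset = 2, kMean = 3, kVariance = 4, kOutput = 5,
  kInputShape = 6, kEpsilon = 7, kUnalignedChannels = 8
};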

@@ -39,6 +39,14 @@ class BatchNormOpenCLKernel : public OpenCLKernel {
void SetGlobalLocal() override;
private:
int Initweight();
private:
bool use_fp16_enable_{false};
void *scale_{nullptr};
void *offset_{nullptr};
void *mean_{nullptr};
void *variance_{nullptr};
cl::Kernel kernel_;
};

@@ -200,6 +200,6 @@ int ConcatOpenCLKernel::Run() {
return RET_OK;
}
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Concat, OpenCLKernelCreator<ConcatOpenCLKernel>);
REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_Concat, OpenCLKernelCreator<ConcatOpenCLKernel>);
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Concat, OpenCLKernelCreator<ConcatOpenCLKernel>)
REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_Concat, OpenCLKernelCreator<ConcatOpenCLKernel>)
} // namespace mindspore::kernel

@@ -31,31 +31,54 @@ using mindspore::schema::PrimitiveType_Power;
namespace mindspore::kernel {
int PowerOpenCLKernel::Init() {
use_fp16_enable_ = ocl_runtime_->GetFp16Enable();
int PowerOpenCLKernel::CheckSpecs() {
auto param = reinterpret_cast<PowerParameter *>(this->op_parameter_);
std::string kernel_name = "power";
std::string source = power_source;
std::string program_name = "power";
broadcast_ = param->broadcast_;
if (!(broadcast_ && in_tensors_.size() == 1)) {
if (in_tensors_.size() == 2 && in_tensors_.at(0)->shape().size() != in_tensors_.at(1)->shape().size()) {
MS_LOG(ERROR) << "Unsupported input->shape.size " << in_tensors_.at(0)->shape().size()
<< "!=" << in_tensors_.at(1)->shape().size();
return RET_ERROR;
} else if (in_tensors_.size() > 2 || in_tensors_.at(0)->shape().size() > 4) {
MS_LOG(ERROR) << "Unsupported in_tensors_->shape.size " << in_tensors_.size() << " or "
<< "in_tensors_[0]->shape().size(): " << in_tensors_.at(0)->shape().size();
return RET_ERROR;
}
}
return RET_OK;
}
if (in_tensors_.size() == 2 && in_tensors_[0]->shape().size() != in_tensors_[1]->shape().size()) {
MS_LOG(ERROR) << "Unsupported input0->shape.size " << in_tensors_[0]->shape().size()
<< "!=" << in_tensors_[1]->shape().size();
return RET_ERROR;
} else if (in_tensors_.size() > 2 || in_tensors_[0]->shape().size() > 4) {
MS_LOG(ERROR) << "Unsupported in_tensors_->shape.size " << in_tensors_.size() << " or "
<< "in_tensors_[0]->shape().size(): " << in_tensors_[0]->shape().size();
return RET_ERROR;
} else if (broadcast_ && in_tensors_.size() == 1) {
power_ = param->power_;
kernel_name += "_broadcast";
int PowerOpenCLKernel::Initweight() {
auto allocator = ocl_runtime_->GetAllocator();
GpuTensorInfo img_info(in_tensors_.at(1));
auto weight_tensor = in_tensors_.at(1);
size_t weight_size = img_info.OriginSize;
weight_ = allocator->Malloc(weight_size);
allocator->MapBuffer(weight_, CL_MAP_WRITE, nullptr, true);
memset(weight_, 0x00, weight_size);
if (weight_tensor->data_type() == kNumberTypeFloat16) {
if (use_fp16_enable_) {
memcpy(weight_, weight_tensor->data_c(), weight_size);
} else {
auto weight_fp32 = reinterpret_cast<float *>(weight_);
auto origin_bias_fp16 = reinterpret_cast<float16_t *>(weight_tensor->data_c());
for (int i = 0; i < img_info.ElementsNum; ++i) {
weight_fp32[i] = static_cast<float>(origin_bias_fp16[i]);
}
}
} else {
if (use_fp16_enable_) {
auto weight_fp16 = reinterpret_cast<float16_t *>(weight_);
auto origin_bias_fp32 = reinterpret_cast<float *>(weight_tensor->data_c());
for (int i = 0; i < img_info.ElementsNum; ++i) {
weight_fp16[i] = static_cast<float16_t>(origin_bias_fp32[i]);
}
} else {
memcpy(weight_, weight_tensor->data_c(), weight_size);
}
}
scale_ = param->scale_;
shift_ = param->shift_;
ocl_runtime_->LoadSource(program_name, source);
ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name);
MS_LOG(DEBUG) << kernel_name << " Init Done!";
allocator->UnmapBuffer(weight_);
return RET_OK;
}
@@ -73,87 +96,83 @@ void PowerGetWorkGroup(const std::vector<size_t> &global, std::vector<size_t> *l
local->push_back(z);
}
int PowerOpenCLKernel::InferShapeTo4D() {
if (in_tensors_[0]->shape().size() <= 4) {
if (in_tensors_[0]->shape().size() == 1) {
N_ = in_tensors_[0]->shape()[0];
} else if (in_tensors_[0]->shape().size() == 2) {
N_ = in_tensors_[0]->shape()[0];
C_ = in_tensors_[0]->shape()[1];
} else if (in_tensors_[0]->shape().size() == 3) {
N_ = in_tensors_[0]->shape()[0];
W_ = in_tensors_[0]->shape()[1];
C_ = in_tensors_[0]->shape()[2];
} else {
N_ = in_tensors_[0]->shape()[0];
H_ = in_tensors_[0]->shape()[1];
W_ = in_tensors_[0]->shape()[2];
C_ = in_tensors_[0]->shape()[3];
}
} else {
MS_LOG(ERROR) << "Unsupported inputdim: " << in_tensors_[0]->shape().size();
return RET_ERROR;
}
return RET_OK;
}
int PowerOpenCLKernel::Run() {
MS_LOG(DEBUG) << this->name() << " Running! ";
auto output_shape = out_tensors_[0]->shape();
InferShapeTo4D();
cl_int4 output_shape_ = {static_cast<cl_int>(N_), static_cast<cl_int>(H_), static_cast<cl_int>(W_),
static_cast<cl_int>(UP_DIV(C_, C4NUM))};
const std::vector<size_t> &max_global = ocl_runtime_->GetWorkItemSize();
std::vector<size_t> local = {1, 1, 1};
uint32_t OH = N_ * H_;
uint32_t OW = W_;
uint32_t OC = UP_DIV(C_, C4NUM);
std::vector<size_t> global = {OH, OW, OC};
PowerGetWorkGroup(global, &local, max_global[0]);
int arg_cn = 0;
if (broadcast_) {
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[0]->data_c());
void PowerOpenCLKernel::SetConstArgs() {
float unalign_w = static_cast<float>(out_shape_.s[3]);
out_shape_.s[3] = UP_DIV(out_shape_.s[3], C4NUM);
int arg_cn = 2;
if (!broadcast_) {
arg_cn++;
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_shape_);
} else {
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[0]->data_c());
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[1]->data_c());
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_shape_);
}
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->data_c());
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, output_shape_);
if (use_fp16_enable_) {
auto x = static_cast<float16_t>(power_);
auto y = static_cast<float16_t>(shift_);
auto z = static_cast<float16_t>(scale_);
auto w = static_cast<float16_t>(unalign_w);
cl_half4 parameter = {*(reinterpret_cast<uint16_t *>(&x)), *(reinterpret_cast<uint16_t *>(&y)),
*(reinterpret_cast<uint16_t *>(&z)), 1};
*(reinterpret_cast<uint16_t *>(&z)), *(reinterpret_cast<uint16_t *>(&w))};
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, parameter);
} else {
cl_float4 parameter = {power_, shift_, scale_, 1};
cl_float4 parameter = {power_, shift_, scale_, unalign_w};
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, parameter);
}
AlignGlobalLocal(global, local);
ocl_runtime_->RunKernel(kernel_, global_range_, local_range_);
return RET_OK;
}
kernel::LiteKernel *PowerOpenCLKernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
const lite::InnerContext *ctx, const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
auto *kernel = new (std::nothrow) PowerOpenCLKernel(opParameter, inputs, outputs);
if (kernel == nullptr) {
MS_LOG(ERROR) << " new PowerOpenCLKernel failed ";
free(opParameter);
return nullptr;
void PowerOpenCLKernel::SetGlobalLocal() {
cl_int4 output_shape = {};
for (int i = 0; i < out_tensors_.at(0)->shape().size(); ++i) {
output_shape.s[i] = out_tensors_.at(0)->shape()[i];
}
Broadcast2GpuShape(out_shape_.s, output_shape.s, out_tensors_.at(0)->shape().size(), 1);
const std::vector<size_t> &max_global = ocl_runtime_->GetWorkItemSize();
std::vector<size_t> local_size_ = {1, 1, 1};
uint32_t OH = out_shape_.s[0] * out_shape_.s[1];
uint32_t OW = out_shape_.s[2];
uint32_t OC = UP_DIV(out_shape_.s[3], C4NUM);
std::vector<size_t> global_size_ = {OH, OW, OC};
PowerGetWorkGroup(global_size_, &local_size_, max_global[0]);
OpenCLKernel::AlignGlobalLocal(global_size_, local_size_);
}
int PowerOpenCLKernel::Prepare() {
use_fp16_enable_ = ocl_runtime_->GetFp16Enable();
auto param = reinterpret_cast<PowerParameter *>(this->op_parameter_);
std::string kernel_name = "power";
std::string source = power_source;
std::string program_name = "power";
broadcast_ = param->broadcast_;
if (broadcast_ && in_tensors_.size() == 1) {
power_ = param->power_;
kernel_name += "_broadcast";
} else {
Initweight();
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << " Init kernel failed, name: Power ";
delete kernel;
return nullptr;
scale_ = param->scale_;
shift_ = param->shift_;
ocl_runtime_->LoadSource(program_name, source);
ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name);
MS_LOG(DEBUG) << kernel_name << " Init Done!";
SetGlobalLocal();
SetConstArgs();
return RET_OK;
}
int PowerOpenCLKernel::Run() {
MS_LOG(DEBUG) << this->name() << " Running! ";
int arg_cn = 0;
if (broadcast_) {
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_.at(0)->data_c());
} else {
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_.at(0)->data_c());
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, weight_, lite::opencl::MemType::BUF);
}
return kernel;
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_.at(0)->data_c());
ocl_runtime_->RunKernel(kernel_, global_range_, local_range_);
return RET_OK;
}
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Power, PowerOpenCLKernelCreator)
REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_Power, PowerOpenCLKernelCreator)
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Power, OpenCLKernelCreator<PowerOpenCLKernel>)
REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_Power, OpenCLKernelCreator<PowerOpenCLKernel>)
} // namespace mindspore::kernel

@@ -31,20 +31,20 @@ class PowerOpenCLKernel : public OpenCLKernel {
~PowerOpenCLKernel() override = default;
int Init() override;
int Prepare() override;
int CheckSpecs() override;
void SetConstArgs() override;
void SetGlobalLocal() override;
int Run() override;
private:
int InferShapeTo4D();
int Initweight();
private:
size_t N_{1};
size_t H_{1};
size_t W_{1};
size_t C_{1};
cl_int4 out_shape_{};
bool broadcast_{false};
bool use_fp16_enable_{false};
void *weight_{nullptr};
float power_{1.0};
float scale_{0.0};
float shift_{1.0};

@@ -134,7 +134,9 @@ void StackOpenCLKernel::SetGlobalLocal() {
int StackOpenCLKernel::Prepare() {
enable_fp16_ = ocl_runtime_->GetFp16Enable();
if (axis_ == 0) {
return RET_OK;
}
if (in_tensors_[0]->shape().size() == 1 && axis_ == 1) {
axis_ += 2;
} else if (in_tensors_[0]->shape().size() == axis_) {

@@ -112,6 +112,74 @@ TEST_F(TestOpenCL_Arithmetic, BroadcastSub2) {
}
}
TEST_F(TestOpenCL_Arithmetic, BroadcastFloorMod) {
std::vector<int> input0_shape = {1, 1, 3, 4};
std::vector<int> input1_shape = {1, 1, 1, 4};
std::vector<int> output_shape = {1, 1, 3, 4};
float input0_data[] = {1.1, -1.1, 3.123, -5.432, 0.1234, -0.0312, 12.1, 21.1, 9.1, 9.0, -100, 0.1};
float input1_data[] = {1, 3, 2, 0.3};
float output_data[] = {0.100000, 1.900000, 1.123000, 0.268000, 0.123400, 2.968800,
0.100000, 0.100000, 0.100000, 0.000000, 0.000000, 0.100000};
for (auto fp16_enable : {true, false}) {
auto *param = CreateParameter(schema::PrimitiveType_FloorMod, input0_shape, input1_shape);
TestMain({{input0_shape, input0_data, VAR}, {input1_shape, input1_data, CONST_TENSOR}}, {output_shape, output_data},
param, fp16_enable, fp16_enable ? 1e-2 : 1e-6);
}
}
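Spot-checking the expected data against floor modulo: -1.1 mod 3 = -1.1 - floor(-1.1 / 3) * 3 = -1.1 + 3 = 1.9, and -5.432 mod 0.3 = -5.432 - (-19) * 0.3 = 0.268, matching the second and fourth outputs above.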
TEST_F(TestOpenCL_Arithmetic, FloorMod) {
std::vector<int> input0_shape = {1, 1, 3, 4};
std::vector<int> input1_shape = {1, 1, 3, 4};
std::vector<int> output_shape = {1, 1, 3, 4};
float input0_data[] = {1.1, -1.1, 3.123, -5.432, 0.1234, -0.0312, 12.1, 21.1, 9.1, 9.0, -100, 0.1};
float input1_data[] = {1, 3, 2, 0.3, 1, 3, 2, 0.3, 1, 3, 2, 0.3};
float output_data[] = {0.100000, 1.900000, 1.123000, 0.268000, 0.123400, 2.968800,
0.100000, 0.100000, 0.100000, 0.000000, 0.000000, 0.100000};
for (auto fp16_enable : {true, false}) {
auto *param = CreateParameter(schema::PrimitiveType_FloorMod, input0_shape, input1_shape);
TestMain({{input0_shape, input0_data, VAR}, {input1_shape, input1_data, CONST_TENSOR}}, {output_shape, output_data},
param, fp16_enable, fp16_enable ? 1e-2 : 1e-6);
}
}
TEST_F(TestOpenCL_Arithmetic, FloorModFile) {
std::vector<int> input0_shape = {1, 3, 4, 5};
std::vector<int> input1_shape = {1, 3, 4, 5};
std::vector<int> output_shape = {1, 3, 4, 5};
size_t input1_size, input2_size, output_size;
std::string input1Ppath = "./test_data/FloodModfp32_input1.bin";
std::string input2Ppath = "./test_data/FloodModfp32_input2.bin";
std::string correctOutputPath = "./test_data/FloodModfp32_output.bin";
auto input0_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input1Ppath.c_str(), &input1_size));
auto input1_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input2Ppath.c_str(), &input2_size));
auto output_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(correctOutputPath.c_str(), &output_size));
for (auto fp16_enable : {true}) {
auto *param = CreateParameter(schema::PrimitiveType_FloorMod, input0_shape, input1_shape);
TestMain({{input0_shape, input0_data, VAR}, {input1_shape, input1_data, CONST_TENSOR}}, {output_shape, output_data},
param, fp16_enable, fp16_enable ? 1e-2 : 1e-7);
}
}
TEST_F(TestOpenCL_Arithmetic, SquaredDifference) {
std::vector<int> input0_shape = {1, 512, 1, 5};
std::vector<int> input1_shape = {1, 1, 1, 5};
std::vector<int> output_shape = {1, 512, 1, 5};
size_t input1_size, input2_size, output_size;
std::string input1Ppath = "./test_data/SquaredDifferencefp32_input1.bin";
std::string input2Ppath = "./test_data/SquaredDifferencefp32_input2.bin";
std::string correctOutputPath = "./test_data/SquaredDifferencefp32_output.bin";
auto input0_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input1Ppath.c_str(), &input1_size));
auto input1_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input2Ppath.c_str(), &input2_size));
auto output_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(correctOutputPath.c_str(), &output_size));
for (auto fp16_enable : {true}) {
auto *param = CreateParameter(schema::PrimitiveType_SquaredDifference, input0_shape, input1_shape);
TestMain({{input0_shape, input0_data, VAR}, {input1_shape, input1_data, CONST_TENSOR}}, {output_shape, output_data},
param, fp16_enable, fp16_enable ? 1e-2 : 1e-9);
}
}
TEST_F(TestOpenCL_Arithmetic, ElementwiseDiv) {
std::vector<int> input0_shape = {1, 2, 2, 3};
std::vector<int> input1_shape = {1, 2, 2, 3};

@@ -29,9 +29,9 @@ OpParameter *CreateParameter(float epsilon) {
}
} // namespace
TEST_F(TestOpenCL_BatchNorm, test0) {
TEST_F(TestOpenCL_BatchNorm, Align) {
std::vector<int> input_shape = {1, 2, 2, 8};
std::vector<int> weight_shape = {1, 1, 1, input_shape[3]};
std::vector<int> weight_shape = {1, 1, 1, 8};
std::vector<int> output_shape = {1, 2, 2, 8};
float input_data[] = {2.471454, -2.1379554, -0.0904604, 1.2928944, -0.19215967, -0.8677279, -0.12759617,
1.2242758, -0.06398406, -0.4041858, 0.20352598, -2.067808, 0.52113044, -1.567617,
@@ -51,10 +51,38 @@ TEST_F(TestOpenCL_BatchNorm, test0) {
for (auto fp16_enable : {false, true}) {
auto *param = CreateParameter(1e-5);
TestMain({{input_shape, input_data, VAR},
{weight_shape, scale_data, VAR},
{weight_shape, offset_data, VAR},
{weight_shape, mean_data, VAR},
{weight_shape, var_data, VAR}},
{weight_shape, scale_data, CONST_TENSOR},
{weight_shape, offset_data, CONST_TENSOR},
{weight_shape, mean_data, CONST_TENSOR},
{weight_shape, var_data, CONST_TENSOR}},
{output_shape, output_data}, param, fp16_enable, fp16_enable ? 1e-3 : 1e-5);
}
}
TEST_F(TestOpenCL_BatchNorm, UnAlign) {
std::vector<int> input_shape = {1, 2, 2, 7};
std::vector<int> weight_shape = {1, 1, 1, 7};
std::vector<int> output_shape = {1, 2, 2, 7};
float input_data[] = {2.471454, -2.1379554, -0.0904604, 1.2928944, -0.19215967, -0.8677279, -0.12759617,
-0.06398406, -0.4041858, 0.20352598, -2.067808, 0.52113044, -1.567617, 0.28003863,
0.77298605, 0.29908583, 1.4015813, 1.330567, 1.760135, 0.6320845, 0.6995399,
-1.9738104, -1.3283046, 1.022744, 0.02741058, 0.84505165, -0.89434445, 1.983211};
float scale_data[] = {0.1201471, 0.142174, 0.5683258, 0.86815494, 0.23426804, 0.3634345, 0.0077846};
float offset_data[] = {0.58764684, 0.70790595, 0.945536, 0.8817803, 0.78489226, 0.5884778, 0.3441211};
float mean_data[] = {0.3016613, -0.89284, 0.63434774, 0.145766, 0.73353934, -0.6744012, 0.7087985};
float var_data[] = {2.5604038, 0.84985304, 0.36261332, 1.9083935, 0.4920925, 0.6476224, 0.6269014};
float output_data[] = {0.7505676, 0.515882, 0.26147857, 1.6026789, 0.47575232, 0.50116986, 0.33589783,
0.56019205, 0.7832671, 0.53893626, -0.5093127, 0.71395767, 0.18509413, 0.33990562,
0.6230367, 0.89172685, 1.6696336, 1.6263539, 1.1277269, 1.1784974, 0.34403008,
0.4167911, 0.6407478, 1.3120956, 0.80740136, 0.8221321, 0.4891496, 0.3566509};
for (auto fp16_enable : {false, true}) {
auto *param = CreateParameter(1e-5);
TestMain({{input_shape, input_data, VAR},
{weight_shape, scale_data, CONST_TENSOR},
{weight_shape, offset_data, CONST_TENSOR},
{weight_shape, mean_data, CONST_TENSOR},
{weight_shape, var_data, CONST_TENSOR}},
{output_shape, output_data}, param, fp16_enable, fp16_enable ? 1e-3 : 1e-5);
}
}
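Spot-checking the first expected value with y = scale * (x - mean) / sqrt(var + eps) + offset: 0.1201471 * (2.471454 - 0.3016613) / sqrt(2.5604038 + 1e-5) + 0.58764684 ≈ 0.75057, which matches output_data[0] = 0.7505676.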

@@ -13,12 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <memory>
#include "src/common/log_adapter.h"
#include "common/common_test.h"
#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
#include "mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.h"
#include "ut/src/runtime/kernel/opencl/common.h"
#include "mindspore/lite/src/runtime/kernel/opencl/kernel/power.h"
// PrimitiveType_Power: src/ops/populate/power_populate.cc
@@ -30,142 +25,105 @@ class TestPowerOpenCLCI : public CommonTest {
public:
TestPowerOpenCLCI() {}
};
template <class T>
void CompareData(const T *output_data, const T *correct_data, int size, float err_bound) {
for (int i = 0; i < size; i++) {
T abs = fabs(output_data[i] - correct_data[i]);
ASSERT_LE(abs, err_bound);
}
}
template <class T>
void TEST_MAIN(const T *input_data1, const T *input_data2, const T *expect_data, const TypeId data_type,
const std::vector<int> &shape_a, const std::vector<int> &shape_b, const std::vector<int> &out_shape,
bool broadcast, const T scale = 1.0, const T shift = 0, const T exponent = 2) {
MS_LOG(INFO) << " begin test ";
auto runtime_wrapper = lite::opencl::OpenCLRuntimeWrapper();
auto runtime = runtime_wrapper.GetInstance();
runtime->Init();
if (data_type == kNumberTypeFloat16) {
runtime->SetFp16Enable(true);
}
auto allocator = runtime->GetAllocator();
auto tensor_type = lite::Tensor::CONST_TENSOR;
auto in_tensor1 = Tensor(data_type, shape_a, Format_NHWC, tensor_type);
auto in_tensor2 = Tensor(data_type, shape_b, Format_NHWC, tensor_type);
auto output_tensor = Tensor(data_type, out_shape, Format_NHWC, tensor_type);
MS_LOG(INFO) << " initialize tensors ";
auto param = reinterpret_cast<PowerParameter *>(malloc(sizeof(PowerParameter)));
if (param == nullptr) {
MS_LOG(INFO) << " new ActivationParameter failed ";
return;
}
param->scale_ = scale;
param->shift_ = shift;
std::vector<lite::Tensor *> inputs;
std::vector<lite::Tensor *> outputs{&output_tensor};
if (broadcast) {
param->broadcast_ = true;
inputs.push_back(&in_tensor1);
param->power_ = exponent;
} else {
inputs.push_back(&in_tensor1);
inputs.push_back(&in_tensor2);
}
auto *power_kernel =
new (std::nothrow) kernel::PowerOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
if (power_kernel == nullptr) {
MS_LOG(INFO) << " new kernel::PowerOpenCLKernel failed ";
delete param;
return;
}
power_kernel->Init();
// to do allocate memory for inputs
for (auto &input_tensor : inputs) {
input_tensor->MallocData(allocator);
}
MS_LOG(INFO) << " initialize sub_graph ";
std::vector<kernel::LiteKernel *> kernels{power_kernel};
auto *sub_graph = new (std::nothrow) kernel::OpenCLSubGraph(inputs, outputs, kernels, kernels, kernels);
if (sub_graph == nullptr) {
MS_LOG(INFO) << " new kernel::OpenCLSubGraph failed ";
delete param;
delete power_kernel;
return;
}
sub_graph->Init();
MS_LOG(INFO) << " initialize input data ";
size_t size = 1 * sizeof(T);
for (int i = 0; i < out_shape.size(); ++i) {
size *= out_shape[i];
}
if (broadcast) {
memcpy(inputs[0]->data_c(), input_data1, size);
} else {
memcpy(inputs[0]->data_c(), input_data1, size);
memcpy(inputs[1]->data_c(), input_data2, size);
}
std::cout << "==================output data================" << std::endl;
sub_graph->Run();
T *output_data_gpu = reinterpret_cast<T *>(output_tensor.data_c());
CompareData(output_data_gpu, expect_data, output_tensor.ElementsNum(), 0.0001);
delete sub_graph;
// PrimitiveType_Power: src/ops/populate/power_populate.cc
OpParameter *CreateParameter(bool broadcast_, float shift_, float scale_, float power_ = 2) {
auto *param = test::CreateParameter<PowerParameter>(schema::PrimitiveType_Power);
param->power_ = power_;
param->broadcast_ = broadcast_;
param->shift_ = shift_;
param->scale_ = scale_;
return reinterpret_cast<OpParameter *>(param);
}
TEST_F(TestPowerOpenCLCI, Int32CI) {
MS_LOG(INFO) << " init tensors ";
std::vector<int> shape_a = {1, 2, 8};
std::vector<int> shape_b = {1, 2, 8};
std::vector<int> input0_shape = {1, 2, 8};
std::vector<int> input1_shape = {1, 2, 8};
std::vector<int> output_shape = {1, 2, 8};
auto data_type = kNumberTypeFloat32;
const float input_data1[] = {2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0};
const float input_data2[] = {2, 2, 2, 1, 2, 2, 3, 3, 2, 2, 3, 0, 2, 2, 1, 2};
const float expect_data[] = {4.0, 9.0, 16.0, 5.0, 36.0, 49.0, 512, 729,
100.0, 121.0, 1728.0, 1.0, 196.0, 225.0, 16.0, 289.0};
TEST_MAIN(input_data1, input_data2, expect_data, data_type, shape_a, shape_b, output_shape, false);
bool broadcast_ = false;
float shift_ = 0;
float scale_ = 1;
float input0_data[] = {2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0};
float input1_data[] = {2, 2, 2, 1, 2, 2, 3, 3, 2, 2, 3, 0, 2, 2, 1, 2};
float output_data[] = {4.0, 9.0, 16.0, 5.0, 36.0, 49.0, 512, 729,
100.0, 121.0, 1728.0, 1.0, 196.0, 225.0, 16.0, 289.0};
for (auto fp16_enable : {false, true}) {
auto *param = CreateParameter(broadcast_, shift_, scale_);
TestMain({{input0_shape, input0_data, VAR}, {input1_shape, input1_data, CONST_TENSOR}}, {output_shape, output_data},
param, fp16_enable, fp16_enable ? 1e-3 : 1e-9);
}
}
TEST_F(TestPowerOpenCLCI, Fp32CI) {
MS_LOG(INFO) << " init tensors ";
std::vector<int> shape_a = {2, 8};
std::vector<int> shape_b = {2, 8};
std::vector<int> input0_shape = {2, 8};
std::vector<int> input1_shape = {2, 8};
std::vector<int> output_shape = {2, 8};
auto data_type = kNumberTypeFloat32;
const float input_data1[] = {0.78957046, -0.99770847, 1.05838929, 1.60738329, -1.66226552, -2.03170525,
-0.48257631, -0.94244638, 1.47462044, -0.80247114, 0.12354778, -0.36436107,
-2.41973013, -0.40221205, -0.26739485, 0.23298305};
const float input_data2[] = {3, 2, 2, 1, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2};
const float expect_data[] = {0.49223521, 0.99542219, 1.12018788, 1.60738329, 2.76312667, 4.1278262,
0.23287989, 0.88820518, 3.20657016, 0.64395994, 0.01526405, 0.13275899,
5.85509388, 0.16177453, 0.07150001, 0.0542811};
TEST_MAIN(input_data1, input_data2, expect_data, data_type, shape_a, shape_b, output_shape, false);
bool broadcast_ = false;
float shift_ = 0;
float scale_ = 1;
float input0_data[] = {0.78957046, -0.99770847, 1.05838929, 1.60738329, -1.66226552, -2.03170525,
-0.48257631, -0.94244638, 1.47462044, -0.80247114, 0.12354778, -0.36436107,
-2.41973013, -0.40221205, -0.26739485, 0.23298305};
float input1_data[] = {3, 2, 2, 1, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2};
float output_data[] = {0.49223521, 0.99542219, 1.12018788, 1.60738329, 2.76312667, 4.1278262, 0.23287989, 0.88820518,
3.20657016, 0.64395994, 0.01526405, 0.13275899, 5.85509388, 0.16177453, 0.07150001, 0.0542811};
for (auto fp16_enable : {false, true}) {
auto *param = CreateParameter(broadcast_, shift_, scale_);
TestMain({{input0_shape, input0_data, VAR}, {input1_shape, input1_data, CONST_TENSOR}}, {output_shape, output_data},
param, fp16_enable, fp16_enable ? 1e-2 : 1e-6);
}
}
TEST_F(TestPowerOpenCLCI, Fp16CI) {
MS_LOG(INFO) << " init tensors ";
std::vector<int> shape_a = {2, 8};
std::vector<int> shape_b = {2, 8};
std::vector<int> output_shape = {2, 8};
auto data_type = kNumberTypeFloat16;
const float16_t input_data1[] = {0.1531, -0.8003, -0.1848, 0.3833, -1.469, 0.5586, -0.3223, -0.8887,
0.697, -1.007, -0.45, -1.736, -0.462, -0.699, -0.596, 0.7466};
const float16_t input_data2[] = {2.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 1.0, 2.0, 2.0, 1.0, 1.0};
const float16_t expect_data[] = {0.02344, -0.8003, -0.1848, 0.147, 2.156, 0.312, 0.1039, 0.7896,
0.4856, 1.014, 0.2025, -1.736, 0.2134, 0.489, -0.596, 0.7466};
TEST_MAIN(input_data1, input_data2, expect_data, data_type, shape_a, shape_b, output_shape, false);
TEST_F(TestPowerOpenCLCI, Fp32UnAlign) {
std::vector<int> input0_shape = {2, 7};
std::vector<int> input1_shape = {2, 7};
std::vector<int> output_shape = {2, 7};
bool broadcast_ = false;
float shift_ = 0;
float scale_ = 1;
float input0_data[] = {0.78957046, -0.99770847, 1.05838929, 1.60738329, -1.66226552, -2.03170525, -0.48257631,
1.47462044, -0.80247114, 0.12354778, -0.36436107, -2.41973013, -0.40221205, -0.26739485};
float input1_data[] = {3, 2, 2, 1, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2};
float output_data[] = {0.49223521, 0.99542219, 1.12018788, 1.60738329, 2.76312667, 4.1278262, 0.23287989,
3.20657016, 0.64395994, 0.01526405, 0.13275899, 5.85509388, 0.16177453, 0.07150001};
for (auto fp16_enable : {false, true}) {
auto *param = CreateParameter(broadcast_, shift_, scale_);
TestMain({{input0_shape, input0_data, VAR}, {input1_shape, input1_data, CONST_TENSOR}}, {output_shape, output_data},
param, fp16_enable, fp16_enable ? 1e-2 : 1e-6);
}
}
TEST_F(TestPowerOpenCLCI, broadcast) {
MS_LOG(INFO) << " init tensors ";
std::vector<int> shape_a = {1, 2, 8};
std::vector<int> shape_b = {};
std::vector<int> input0_shape = {1, 2, 8};
std::vector<int> output_shape = {1, 2, 8};
auto data_type = kNumberTypeFloat32;
float input_data1[] = {2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0};
float expect_data[] = {4.0, 9.0, 16.0, 25.0, 36.0, 49.0, 64, 81, 100.0, 121.0, 144, 169, 196.0, 225.0, 256, 289.0};
TEST_MAIN(input_data1, input_data1, expect_data, data_type, shape_a, shape_b, output_shape, true);
bool broadcast_ = true;
float shift_ = 0;
float scale_ = 1;
float input0_data[] = {2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0};
float output_data[] = {4.0, 9.0, 16.0, 25.0, 36.0, 49.0, 64, 81, 100.0, 121.0, 144, 169, 196.0, 225.0, 256, 289.0};
for (auto fp16_enable : {false}) {
auto *param = CreateParameter(broadcast_, shift_, scale_);
TestMain({{input0_shape, input0_data, VAR}}, {output_shape, output_data}, param, fp16_enable,
fp16_enable ? 1e-3 : 1e-6);
}
}
TEST_F(TestPowerOpenCLCI, Fp16CI) {
std::vector<int> input0_shape = {2, 8};
std::vector<int> input1_shape = {2, 8};
std::vector<int> output_shape = {2, 8};
bool broadcast_ = false;
float shift_ = 0;
float scale_ = 1;
float input0_data[] = {0.1531, -0.8003, -0.1848, 0.3833, -1.469, 0.5586, -0.3223, -0.8887,
0.697, -1.007, -0.45, -1.736, -0.462, -0.699, -0.596, 0.7466};
float input1_data[] = {2.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 1.0, 2.0, 2.0, 1.0, 1.0};
float output_data[] = {0.02344, -0.8003, -0.1848, 0.147, 2.156, 0.312, 0.1039, 0.7896,
0.4856, 1.014, 0.2025, -1.736, 0.2134, 0.489, -0.596, 0.7466};
for (auto fp16_enable : {true}) {
auto *param = CreateParameter(broadcast_, shift_, scale_);
TestMain({{input0_shape, input0_data, VAR}, {input1_shape, input1_data, CONST_TENSOR}}, {output_shape, output_data},
param, fp16_enable, fp16_enable ? 1e-3 : 1e-6);
}
}
} // namespace mindspore::lite::opencl::test
