!4001 add test case for OpenCL pooling

Merge pull request !4001 from chenzhongming/master
pull/4001/MERGE
mindspore-ci-bot committed 5 years ago (via Gitee)
commit 73b05602c9

@@ -289,7 +289,11 @@ if (SUPPORT_GPU)
         ${TEST_DIR}/ut/src/runtime/kernel/opencl/matmul_tests.cc
         ${TEST_DIR}/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc
         ${TEST_DIR}/ut/src/runtime/kernel/opencl/concat_tests.cc
-        ${TEST_DIR}/ut/src/runtime/kernel/opencl/softmax_cl_tests.cc
+        ${TEST_DIR}/ut/src/runtime/kernel/opencl/softmax_tests.cc
+        ${TEST_DIR}/ut/src/runtime/kernel/opencl/arithmetic_tests.cc
+        ${TEST_DIR}/ut/src/runtime/kernel/opencl/avg_pooling_tests.cc
+        ${TEST_DIR}/ut/src/runtime/kernel/opencl/max_pooling_tests.cc
+        ${TEST_DIR}/ut/src/runtime/kernel/opencl/utils_tests.cc
     )
 endif()

@@ -0,0 +1,176 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common_test.h"
#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
#include "mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.h"
namespace mindspore {
void BroadcastAdd(const float *a, const float b, float *c, const int size) {
  for (int i = 0; i < size; i++) {
    c[i] = a[i] + b;
  }
}
void ElementAdd(const float *a, const float *b, float *c, const int size) {
  for (int i = 0; i < size; i++) {
    c[i] = a[i] + b[i];
  }
}
bool DataCompare(const float *a, const float *b, const int size, const float accuracy = 1e-4) {
  for (int i = 0; i < size; i++) {
    auto diff = fabs(a[i] - b[i]);
    if (diff > accuracy) {
      MS_LOG(ERROR) << "compare failed at " << i << " expected " << a[i] << " but got " << b[i];
      return false;
    }
  }
  return true;
}
void InitData(void *data, const int size) {
  float *data_float = reinterpret_cast<float *>(data);
  static unsigned int seed = 123;
  for (int i = 0; i < size; i++) {
    data_float[i] = static_cast<int>(rand_r(&seed)) % 100;
  }
}
void LogData(void *data, const int size, const std::string &prefix) {
  std::cout << prefix;
  float *data_float = reinterpret_cast<float *>(data);
  for (int i = 0; i < size; i++) {
    std::cout << data_float[i] << ",";
  }
  std::cout << std::endl;
}
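// Runs a single Add case on both the CPU reference and the OpenCL kernel and
// compares the results. An empty shape_b selects the scalar bias-add
// (broadcast) path; otherwise shape_a and shape_b must match and the tensors
// are added elementwise.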
void TestCase(const std::vector<int> &shape_a, const std::vector<int> &shape_b) {
  std::cout << "TestCase" << std::endl;
  auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
  bool is_bias_add = shape_b.empty();
  auto tensorType = schema::NodeType_ValueNode;
  std::cout << "TestCase tensor" << std::endl;
  lite::tensor::Tensor *tensor_a =
      new lite::tensor::Tensor(kNumberTypeFloat32, shape_a, schema::Format_NHWC4, tensorType);
  lite::tensor::Tensor *tensor_b =
      new lite::tensor::Tensor(kNumberTypeFloat32, shape_b, schema::Format_NHWC4, tensorType);
  lite::tensor::Tensor *tensor_c =
      new lite::tensor::Tensor(kNumberTypeFloat32, shape_a, schema::Format_NHWC4, tensorType);
  int64_t element_num = tensor_a->ElementsC4Num();
  int64_t element_num_b = is_bias_add ? 1 : tensor_b->ElementsC4Num();
  std::cout << "TestCase new data" << std::endl;
  float *data_a = new float[element_num];
  float *data_b = new float[element_num_b];
  float *data_c_cpu = new float[element_num];
  float *data_c_ocl = new float[element_num];
  InitData(data_a, element_num);
  InitData(data_b, element_num_b);
  memset(data_c_ocl, 0, sizeof(float) * element_num);
  std::cout << "TestCase run cpu" << std::endl;
  if (is_bias_add) {
    BroadcastAdd(data_a, data_b[0], data_c_cpu, element_num);
  } else {
    ElementAdd(data_a, data_b, data_c_cpu, element_num);
  }
  std::cout << "TestCase set data" << std::endl;
  std::vector<lite::tensor::Tensor *> inputs = {tensor_a};
  if (!is_bias_add) {
    inputs.push_back(tensor_b);
  } else {
    tensor_b->MallocData();
    memcpy(tensor_b->Data(), data_b, sizeof(float));
  }
  std::vector<lite::tensor::Tensor *> outputs = {tensor_c};
  ArithmeticParameter *param = new ArithmeticParameter();
  param->ndim_ = 4;
  param->op_parameter_.type_ = PrimitiveType_Add;
  std::vector<lite::tensor::Tensor *> arithmetic_inputs = {tensor_a, tensor_b};
  lite::Context ctx;
  auto *arith_kernel =
      new kernel::ArithmeticOpenCLKernel(reinterpret_cast<OpParameter *>(param), arithmetic_inputs, outputs, &ctx);
  arith_kernel->Init();
  std::vector<kernel::LiteKernel *> kernels{arith_kernel};
  auto *kernel = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
  std::cout << "TestCase Init" << std::endl;
  kernel->Init();
  memcpy(inputs[0]->Data(), data_a, sizeof(float) * element_num);
  if (!is_bias_add) {
    memcpy(inputs[1]->Data(), data_b, sizeof(float) * element_num_b);
  }
  std::cout << "TestCase Run" << std::endl;
  kernel->Run();
  memcpy(data_c_ocl, outputs[0]->Data(), sizeof(float) * element_num);
  // ocl_runtime->SyncCommandQueue();
  LogData(data_a, 10, "Data A : ");
  LogData(data_b, tensor_b->shape().empty() ? 1 : 10, "Data B : ");
  LogData(data_c_cpu, 10, "Expect compute : ");
  LogData(outputs[0]->Data(), 10, "OpenCL compute : ");
  bool cmp = DataCompare(data_c_cpu, data_c_ocl, element_num);
  MS_LOG(INFO) << "Compare " << (cmp ? "success!" : "failed!");
  std::cout << "TestCase End" << std::endl;
  // free
  delete[] data_a;
  delete[] data_b;
  delete[] data_c_cpu;
  delete[] data_c_ocl;
  delete kernel;
  delete arith_kernel;
  for (auto tensor : inputs) {
    delete tensor;
  }
  for (auto tensor : outputs) {
    delete tensor;
  }
  lite::opencl::OpenCLRuntime::DeleteInstance();
}
class TestArithmeticOpenCL : public mindspore::Common {
 public:
  TestArithmeticOpenCL() {}
};
TEST_F(TestArithmeticOpenCL, AddElementwiseTest) {
  std::vector<int> shape_a = {1, 32, 32, 4};
  std::vector<int> shape_b = {1, 32, 32, 4};
  TestCase(shape_a, shape_b);
}
// TEST_F(TestArithmeticOpenCL, AddBroadcastTest) {
//   std::vector<int> shape_a = {1, 4, 128, 128};
//   std::vector<int> shape_b = {};
//   TestCase(shape_a, shape_b);
// }
} // namespace mindspore
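
A note on the ElementsC4Num() calls above: NHWC4 tensors pad the channel dimension up to a multiple of 4, so the host buffers are sized on the padded element count. A standalone sketch of that arithmetic, assuming the usual round-up rule (the shapes in this test all have C = 4, so padded and unpadded counts happen to coincide):

#include <cstdio>

// Assumed NHWC4 padding rule: channel count rounded up to a multiple of 4.
int ElementsC4Num(int n, int h, int w, int c) {
  const int c4 = (c + 3) / 4 * 4;
  return n * h * w * c4;
}

int main() {
  printf("%d\n", ElementsC4Num(1, 32, 32, 4));  // 4096: C = 4 is already aligned
  printf("%d\n", ElementsC4Num(1, 32, 32, 3));  // also 4096: C = 3 pads up to 4
  return 0;
}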

@@ -0,0 +1,124 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <memory>
#include "utils/log_adapter.h"
#include "common/common_test.h"
#include "mindspore/lite/src/common/file_utils.h"
#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
#include "mindspore/lite/src/runtime/kernel/opencl/kernel/pooling2d.h"
namespace mindspore {
class TestAvgPoolingOpenCL : public mindspore::Common {};
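// One 2x2 window with stride 2 and no padding reduces the 1x2x2x4 input to a
// single 1x1x1x4 output: each output channel is the mean of the four input
// values in that channel.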
void InitAvgPoolingParam(PoolingParameter *param) {
  param->input_batch_ = 1;
  param->input_h_ = 2;
  param->input_w_ = 2;
  param->input_channel_ = 4;
  param->output_batch_ = 1;
  param->output_h_ = 1;
  param->output_w_ = 1;
  param->output_channel_ = 4;
  param->window_h_ = 2;
  param->window_w_ = 2;
  param->stride_h_ = 2;
  param->stride_w_ = 2;
  param->pad_u_ = 0;
  param->pad_d_ = 0;
  param->pad_l_ = 0;
  param->pad_r_ = 0;
  param->max_pooling_ = false;
  param->avg_pooling_ = true;
}
TEST_F(TestAvgPoolingOpenCL, AvgPoolFp32) {
  MS_LOG(INFO) << "start TEST_F TestAvgPoolingOpenCL";
  auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
  ocl_runtime->Init();
  MS_LOG(INFO) << "create PoolingParameter";
  auto param = new PoolingParameter();
  InitAvgPoolingParam(param);
  MS_LOG(INFO) << "create Tensors";
  std::vector<int> shape_in = {
    param->input_batch_,
    param->input_h_,
    param->input_w_,
    param->input_channel_,
  };
  std::vector<int> shape_out = {
    param->output_batch_,
    param->output_h_,
    param->output_w_,
    param->output_channel_,
  };
  auto data_type = kNumberTypeFloat32;
  auto tensorType = schema::NodeType_ValueNode;
  lite::tensor::Tensor *tensor_in = new lite::tensor::Tensor(data_type, shape_in, schema::Format_NHWC, tensorType);
  lite::tensor::Tensor *tensor_out = new lite::tensor::Tensor(data_type, shape_out, schema::Format_NHWC, tensorType);
  std::vector<lite::tensor::Tensor *> inputs{tensor_in};
  std::vector<lite::tensor::Tensor *> outputs{tensor_out};
  MS_LOG(INFO) << "create OpenCL Kernel";
  auto *pooling_kernel = new kernel::PoolingOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
  pooling_kernel->Init();
  std::vector<kernel::LiteKernel *> kernels{pooling_kernel};
  MS_LOG(INFO) << "create SubGraphOpenCLKernel";
  auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
  pGraph->Init();
  MS_LOG(INFO) << "initialize data";
  const float input_data[16] = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f,
                                0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f};
  memcpy(tensor_in->Data(), input_data, tensor_in->Size());
  MS_LOG(INFO) << "pGraph->Run()";
  pGraph->Run();
  MS_LOG(INFO) << "==================output data=================";
  float *output_data = reinterpret_cast<float *>(tensor_out->Data());
  printf("output:");
  for (int i = 0; i < 4; i++) {
    printf("%.3f ", output_data[i]);
  }
  printf("\n");
  float expect[4] = {2.0f, 3.0f, 4.0f, 5.0f};
  bool all_close = true;
  for (int i = 0; i < tensor_out->ElementsNum(); ++i) {
    if (std::fabs(output_data[i] - expect[i]) > 1e-5) {
      printf("idx[%d] expect=%.3f output=%.3f, ", i, expect[i], output_data[i]);
      all_close = false;
    }
  }
  if (all_close) {
    printf("test all close OK!\n");
  }
  lite::CompareOutputData(output_data, expect, 4);
}
} // namespace mindspore
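
The expected values {2, 3, 4, 5} above can be reproduced by hand: per channel c, the window averages input[0][0][c], input[0][1][c], input[1][0][c], and input[1][1][c]. A standalone sketch (not part of the commit) that recomputes them:

#include <cstdio>

int main() {
  const int kH = 2, kW = 2, kC = 4;
  const float input[kH * kW * kC] = {0, 1, 2, 3, 4, 5, 6, 7,
                                     0, 1, 2, 3, 4, 5, 6, 7};  // NHWC, N = 1
  for (int c = 0; c < kC; ++c) {
    float sum = 0.0f;
    for (int h = 0; h < kH; ++h) {
      for (int w = 0; w < kW; ++w) {
        sum += input[(h * kW + w) * kC + c];
      }
    }
    printf("%.1f ", sum / (kH * kW));  // prints: 2.0 3.0 4.0 5.0
  }
  printf("\n");
  return 0;
}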

@@ -0,0 +1,87 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cstdio>
#include <string>
#include <memory>
#include "utils/log_adapter.h"
#include "common/common_test.h"
#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
#include "mindspore/lite/src/runtime/kernel/opencl/kernel/pooling2d.h"
#include "mindspore/lite/test/ut/src/runtime/kernel/opencl/utils_tests.h"
namespace mindspore {
class TestMaxPoolingOpenCL : public mindspore::Common {};
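// 2x2 window, stride 2, no padding: spatial dimensions are halved and the
// channel count is preserved, matching the shapes declared in the test below.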
void InitParameter(PoolingParameter *param) {
  param->window_h_ = 2;
  param->window_w_ = 2;
  param->stride_h_ = 2;
  param->stride_w_ = 2;
  param->pad_u_ = 0;
  param->pad_d_ = 0;
  param->pad_l_ = 0;
  param->pad_r_ = 0;
  param->avg_pooling_ = false;
  param->max_pooling_ = true;
}
TEST_F(TestMaxPoolingOpenCL, MaxPool_1_16_256_192) {
  MS_LOG(INFO) << "ocl runtime";
  auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
  ocl_runtime->Init();
  MS_LOG(INFO) << "PoolingParameter";
  auto param = new PoolingParameter();
  InitParameter(param);
  // define tensor
  MS_LOG(INFO) << "define tensor";
  std::vector<int> input_shape = {1, 16, 256, 192};
  std::vector<int> output_shape = {1, 8, 128, 192};
  auto data_type = kNumberTypeFloat32;
  auto tensorType = schema::NodeType_ValueNode;
  auto input_tensor = new lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC4, tensorType);
  auto output_tensor = new lite::tensor::Tensor(data_type, output_shape, schema::Format_NHWC4, tensorType);
  std::vector<lite::tensor::Tensor *> inputs{input_tensor};
  std::vector<lite::tensor::Tensor *> outputs{output_tensor};
  // run
  auto *pooling_kernel = new kernel::PoolingOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
  pooling_kernel->Init();
  std::vector<kernel::LiteKernel *> kernels{pooling_kernel};
  auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
  pGraph->Init();
  // load data
  MS_LOG(INFO) << "load data";
  std::string input_file = "maxpool_in.bin";
  std::string expect_file = "maxpool_out.bin";
  LoadTestData(input_tensor->Data(), input_tensor->Size(), input_file);
  auto *input_data = reinterpret_cast<float *>(input_tensor->Data());
  printf("input[0:10]:");
  for (int i = 0; i < 10; i++) {
    printf("[%d]:%.3f ", i, input_data[i]);
  }
  printf("\n");
  pGraph->Run();
  MS_LOG(INFO) << "compare result";
  CompareOutput(output_tensor, expect_file);
}
} // namespace mindspore
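
The test consumes maxpool_in.bin and maxpool_out.bin as-is. For reference, a plain CPU max-pooling sketch that could generate such an expectation offline; the flat NHWC float32 layout here is an assumption, since the tensors above are declared NHWC4:

#include <algorithm>
#include <limits>
#include <vector>

// Illustrative 2x2/stride-2/no-padding max-pooling reference over a flat
// NHWC float32 buffer (layout is assumed; the test tensors use NHWC4).
std::vector<float> MaxPool2x2Nhwc(const std::vector<float> &in,
                                  int n, int h, int w, int c) {
  const int oh = h / 2;
  const int ow = w / 2;
  std::vector<float> out(static_cast<size_t>(n) * oh * ow * c,
                         -std::numeric_limits<float>::infinity());
  for (int b = 0; b < n; ++b) {
    for (int y = 0; y < h; ++y) {
      for (int x = 0; x < w; ++x) {
        for (int k = 0; k < c; ++k) {
          const float v = in[((b * h + y) * w + x) * c + k];
          float &o = out[((b * oh + y / 2) * ow + x / 2) * c + k];
          o = std::max(o, v);
        }
      }
    }
  }
  return out;
}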

@@ -0,0 +1,63 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cmath>
#include <cstdio>
#include <cstring>
#include <string>
#include "utils/log_adapter.h"
#include "mindspore/lite/src/common/file_utils.h"
#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
#include "mindspore/lite/test/ut/src/runtime/kernel/opencl/utils_tests.h"
namespace mindspore {
void LoadTestData(void *dst, size_t dst_size, const std::string &file_path) {
  if (file_path.empty()) {
    memset(dst, 0x00, dst_size);
  } else {
    auto src_data = mindspore::lite::ReadFile(file_path.c_str(), &dst_size);
    memcpy(dst, src_data, dst_size);
  }
}
void CompareOutput(lite::tensor::Tensor *output_tensor, const std::string &file_path) {
  float *output_data = reinterpret_cast<float *>(output_tensor->Data());
  size_t output_size = output_tensor->Size();
  float *expect_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(file_path.c_str(), &output_size));
  printf("output[0:10]:");
  for (int i = 0; i < 10; i++) {
    printf("[%d]:%.3f ", i, output_data[i]);
  }
  printf("\n");
  printf("expect[0:10]:");
  for (int i = 0; i < 10; i++) {
    printf("[%d]:%.3f ", i, expect_data[i]);
  }
  printf("\n");
  constexpr float atol = 1e-5;
  for (int i = 0; i < output_tensor->ElementsNum(); ++i) {
    if (std::fabs(output_data[i] - expect_data[i]) > atol) {
      printf("error at idx[%d] expect=%.3f output=%.3f\n", i, expect_data[i], output_data[i]);
      return;
    }
  }
  printf("compare success!\n");
}
} // namespace mindspore

@@ -0,0 +1,36 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef TESTS_UT_OPENCL_KERNEL_TESTS_UTILS_H_
#define TESTS_UT_OPENCL_KERNEL_TESTS_UTILS_H_
#include <string>
#include <iostream>
#include "tests/ut/cpp/common/common_test.h"
#include "utils/log_adapter.h"
#include "mindspore/lite/src/common/file_utils.h"
#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
namespace mindspore {
void LoadTestData(void *dst, size_t dst_size, const std::string &file_path);
void CompareOutput(lite::tensor::Tensor *output_tensor, const std::string &file_path);
} // namespace mindspore
#endif // TESTS_UT_OPENCL_KERNEL_TESTS_UTILS_H_
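
A sketch of how these helpers are meant to be used from a kernel test, mirroring max_pooling_tests.cc above (the .bin names are placeholders, not real test assets):

// Inside a TEST_F body, once the subgraph kernel has been initialized:
//   LoadTestData(input_tensor->Data(), input_tensor->Size(), "op_in.bin");
//   pGraph->Run();
//   CompareOutput(output_tensor, "op_out.bin");
// Passing an empty file path makes LoadTestData zero-fill the buffer instead.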

@@ -446,7 +446,7 @@ def export(network, *inputs, file_name, mean=127.5, std_dev=127.5, file_format='
           Ascend model.
         - BINARY: Binary format for model. An intermediate representation format for models.
     """
supported_device = ["Ascend"]
supported_device = ["Ascend", "GPU"]
supported_formats = ['GEIR', 'BINARY']
mean = validator.check_type("mean", mean, (int, float))
