parent
49c78de682
commit
78408c1a99
@ -0,0 +1,36 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
|
||||
|
||||
#define Pad(dataformat, in_x, in_y, out_x, out_y) \
|
||||
__kernel void Pad_##dataformat(__read_only image2d_t input, __write_only image2d_t output, int4 input_shape, \
|
||||
int4 output_shape, int2 pad, float constant_value) { \
|
||||
int oh = get_global_id(0); \
|
||||
int ow = get_global_id(1); \
|
||||
int co_slice = get_global_id(2); \
|
||||
int OH = output_shape.y; \
|
||||
int OW = output_shape.z; \
|
||||
int CO_SLICES = output_shape.w; \
|
||||
\
|
||||
if (oh >= OH || ow >= OW || co_slice >= CO_SLICES) { \
|
||||
return; \
|
||||
} \
|
||||
\
|
||||
int IH = input_shape.y; \
|
||||
int IW = input_shape.z; \
|
||||
int CI_SLICES = input_shape.w; \
|
||||
\
|
||||
int pad_top = pad.x; \
|
||||
int pad_left = pad.y; \
|
||||
int ih = oh - pad_top; \
|
||||
int iw = ow - pad_left; \
|
||||
\
|
||||
FLT4 result = (FLT4)(constant_value); \
|
||||
if (ih >= 0 && ih < IH && iw >= 0 && iw < IW) { \
|
||||
result = READ_IMAGE(input, smp_zero, (int2)(in_x, in_y)); \
|
||||
} \
|
||||
WRITE_IMAGE(output, (int2)(out_x, out_y), result); \
|
||||
}
|
||||
|
||||
Pad(NHWC4, iw *CI_SLICES + co_slice, ih, ow *CO_SLICES + co_slice, oh);
|
||||
Pad(NC4HW4, iw, co_slice *IH + ih, ow, co_slice *OH + oh);
|
@ -0,0 +1,157 @@
|
||||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <string>
|
||||
#include <set>
|
||||
#include <algorithm>
|
||||
#include "src/common/utils.h"
|
||||
#include "src/runtime/kernel/opencl/kernel/pad.h"
|
||||
#include "src/runtime/kernel/opencl/utils.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/runtime/kernel/opencl/cl/pad.cl.inc"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kGPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::schema::PaddingMode_CONSTANT;
|
||||
using mindspore::schema::PrimitiveType_Pad;
|
||||
using mindspore::schema::Format::Format_NC4HW4;
|
||||
using mindspore::schema::Format::Format_NCHW;
|
||||
using mindspore::schema::Format::Format_NHWC;
|
||||
using mindspore::schema::Format::Format_NHWC4;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
|
||||
int PadOpenCLKernel::Init() {
|
||||
auto param = reinterpret_cast<PadParameter *>(op_parameter_);
|
||||
std::set<std::string> build_options;
|
||||
|
||||
if (op_format_ != Format_NHWC4 && op_format_ != Format_NC4HW4) {
|
||||
MS_LOG(ERROR) << "op_format_ " << op_format_ << " not support!";
|
||||
}
|
||||
if (in_tensors_.empty()) {
|
||||
MS_LOG(ERROR) << "PadOpenCLKernel in_tensors is empty";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (out_tensors_.empty()) {
|
||||
MS_LOG(ERROR) << "PadOpenCLKernel out_tensors is empty";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (param->paddings_[0] || param->paddings_[1] || param->paddings_[6] || param->paddings_[7]) {
|
||||
MS_LOG(ERROR) << "PadOpenCLKernel not support pad at Batch/Channel axis";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (param->pad_mode_ != PaddingMode_CONSTANT) {
|
||||
MS_LOG(ERROR) << "PadOpenCLKernel only support CONSTANT MODE";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
auto input_tensor = in_tensors_[0];
|
||||
auto output_tensor = out_tensors_[0];
|
||||
in_ori_format_ = input_tensor->GetFormat();
|
||||
out_ori_format_ = output_tensor->GetFormat();
|
||||
input_tensor->SetFormat(op_format_);
|
||||
output_tensor->SetFormat(op_format_);
|
||||
|
||||
CI_ = input_tensor->Channel();
|
||||
IH_ = input_tensor->Height();
|
||||
IW_ = input_tensor->Width();
|
||||
CO_ = output_tensor->Channel();
|
||||
OH_ = output_tensor->Height();
|
||||
OW_ = output_tensor->Width();
|
||||
CI_SLICES_ = UP_DIV(CI_, C4NUM);
|
||||
CO_SLICES_ = UP_DIV(CO_, C4NUM);
|
||||
|
||||
const std::string source = pad_source;
|
||||
const std::string kernel_name = op_format_ == Format_NHWC4 ? "Pad_NHWC4" : "Pad_NC4HW4";
|
||||
const std::string &program_name = kernel_name;
|
||||
ocl_runtime_->LoadSource(program_name, source);
|
||||
ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name, build_options);
|
||||
|
||||
MS_LOG(DEBUG) << "Pad Init Done!";
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int PadOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) {
|
||||
size_t im_dst_x, im_dst_y;
|
||||
if (in_tensors_[0]->GetFormat() == Format_NHWC4) {
|
||||
if (OW_ * CO_SLICES_ <= MAX_IMAGE2D_SIZE) {
|
||||
{
|
||||
im_dst_x = OW_ * CO_SLICES_;
|
||||
im_dst_y = OH_;
|
||||
}
|
||||
} else {
|
||||
im_dst_x = OH_ * CO_SLICES_;
|
||||
im_dst_y = OW_;
|
||||
}
|
||||
} else {
|
||||
im_dst_y = OH_ * CO_SLICES_;
|
||||
im_dst_x = OW_;
|
||||
}
|
||||
size_t img_dtype = ocl_runtime_->GetFp16Enable() ? CL_HALF_FLOAT : CL_FLOAT;
|
||||
img_size->clear();
|
||||
img_size->push_back(im_dst_x);
|
||||
img_size->push_back(im_dst_y);
|
||||
img_size->push_back(img_dtype);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int PadOpenCLKernel::Run() {
|
||||
MS_LOG(DEBUG) << this->name() << " Running!";
|
||||
|
||||
auto param = reinterpret_cast<PadParameter *>(op_parameter_);
|
||||
cl_int4 input_shape = {1, IH_, IW_, CI_SLICES_};
|
||||
cl_int4 output_shape = {1, OH_, OW_, CO_SLICES_};
|
||||
cl_int2 pad_top_left = {param->paddings_[2], param->paddings_[4]};
|
||||
|
||||
int arg_cn = 0;
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[0]->data_c(), lite::opencl::MemType::IMG);
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->data_c(), lite::opencl::MemType::IMG);
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, input_shape);
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, output_shape);
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, pad_top_left);
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, static_cast<cl_float>(param->constant_value_));
|
||||
|
||||
std::vector<size_t> global = {static_cast<size_t>(OH_), static_cast<size_t>(OW_), static_cast<size_t>(CO_SLICES_)};
|
||||
std::vector<size_t> local = {8, 4, 1};
|
||||
ocl_runtime_->RunKernel(kernel_, global, local, nullptr);
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
kernel::LiteKernel *OpenCLPadKernelCreator(const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
|
||||
const lite::InnerContext *ctx, const kernel::KernelKey &desc,
|
||||
const mindspore::lite::PrimitiveC *primitive) {
|
||||
auto *kernel = new (std::nothrow) PadOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs);
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "Create OpenCL Pad kernel failed!";
|
||||
return nullptr;
|
||||
}
|
||||
auto ret = kernel->Init();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init kernel failed, name: Pad";
|
||||
delete kernel;
|
||||
return nullptr;
|
||||
}
|
||||
return kernel;
|
||||
}
|
||||
|
||||
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Pad, OpenCLPadKernelCreator)
|
||||
REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_Pad, OpenCLPadKernelCreator)
|
||||
} // namespace mindspore::kernel
|
@ -0,0 +1,53 @@
|
||||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_PAD_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_PAD_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "src/tensor.h"
|
||||
#include "src/runtime/kernel/opencl/opencl_kernel.h"
|
||||
#include "schema/model_generated.h"
|
||||
#include "nnacl/pad_parameter.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
|
||||
class PadOpenCLKernel : public OpenCLKernel {
|
||||
public:
|
||||
explicit PadOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs)
|
||||
: OpenCLKernel(parameter, inputs, outputs) {}
|
||||
~PadOpenCLKernel() override{};
|
||||
|
||||
int Init() override;
|
||||
int Run() override;
|
||||
int GetImageSize(size_t idx, std::vector<size_t> *img_size) override;
|
||||
|
||||
private:
|
||||
int CI_{};
|
||||
int IH_{};
|
||||
int IW_{};
|
||||
int CO_{};
|
||||
int OH_{};
|
||||
int OW_{};
|
||||
int CI_SLICES_{};
|
||||
int CO_SLICES_{};
|
||||
cl::Kernel kernel_;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_PAD_H_
|
@ -0,0 +1,168 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <memory>
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "common/common_test.h"
|
||||
#include "mindspore/lite/src/common/file_utils.h"
|
||||
#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
|
||||
#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
|
||||
#include "mindspore/lite/src/runtime/kernel/opencl/kernel/pad.h"
|
||||
#include "nnacl/pack.h"
|
||||
|
||||
using mindspore::kernel::LiteKernel;
|
||||
using mindspore::kernel::PadOpenCLKernel;
|
||||
using mindspore::kernel::SubGraphOpenCLKernel;
|
||||
using mindspore::lite::Tensor;
|
||||
using mindspore::schema::Format;
|
||||
using mindspore::schema::Format_NC4HW4;
|
||||
using mindspore::schema::Format_NHWC;
|
||||
using mindspore::schema::Format_NHWC4;
|
||||
using mindspore::schema::NodeType_ValueNode;
|
||||
using mindspore::schema::PaddingMode;
|
||||
using mindspore::schema::PaddingMode_CONSTANT;
|
||||
using mindspore::schema::PaddingMode_REFLECT;
|
||||
using mindspore::schema::PaddingMode_SYMMETRIC;
|
||||
|
||||
namespace mindspore {
|
||||
|
||||
class TestPadOpenCL : public mindspore::CommonTest {};
|
||||
|
||||
void TEST_MAIN(PadParameter *param, Format input_format, Format output_format, Format op_format, const TypeId data_type,
|
||||
const std::vector<int> &input_shape, const std::vector<int> &output_shape, const float *input_data,
|
||||
const float *expect_data) {
|
||||
auto ocl_runtime_wrapper = lite::opencl::OpenCLRuntimeWrapper();
|
||||
auto ocl_runtime = ocl_runtime_wrapper.GetInstance();
|
||||
ocl_runtime->Init();
|
||||
ocl_runtime->SetFp16Enable(data_type == kNumberTypeFloat16);
|
||||
auto allocator = ocl_runtime->GetAllocator();
|
||||
|
||||
MS_LOG(DEBUG) << "create Tensors";
|
||||
auto input = Tensor(kNumberTypeFloat32, input_shape, input_format, lite::TensorCategory(NodeType_ValueNode));
|
||||
auto output = Tensor(kNumberTypeFloat32, output_shape, output_format, lite::TensorCategory(NodeType_ValueNode));
|
||||
|
||||
MS_LOG(DEBUG) << "create OpenCL Kernel";
|
||||
std::vector<lite::Tensor *> inputs{&input};
|
||||
std::vector<lite::Tensor *> outputs{&output};
|
||||
auto kernel = std::make_unique<PadOpenCLKernel>(reinterpret_cast<OpParameter *>(param), inputs, outputs);
|
||||
if (kernel == nullptr) {
|
||||
return;
|
||||
}
|
||||
kernel->SetFormatType(op_format);
|
||||
kernel->Init();
|
||||
|
||||
MS_LOG(DEBUG) << "create SubGraph";
|
||||
std::vector<kernel::LiteKernel *> kernels{kernel.release()};
|
||||
auto sub_graph = new (std::nothrow) SubGraphOpenCLKernel({&input}, {&output}, kernels, kernels, kernels);
|
||||
input.MallocData(allocator);
|
||||
sub_graph->Init();
|
||||
memcpy(input.data_c(), input_data, input.Size());
|
||||
sub_graph->Run();
|
||||
if (lite::CompareOutputData(reinterpret_cast<float *>(output.data_c()), output.ElementsNum(),
|
||||
const_cast<float *>(expect_data), output.ElementsNum())) {
|
||||
FAIL();
|
||||
} else {
|
||||
std::cout << "COMPARE SUCCESS!\n";
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "release resources";
|
||||
input.SetData(nullptr);
|
||||
output.SetData(nullptr);
|
||||
delete sub_graph;
|
||||
}
|
||||
|
||||
TEST_F(TestPadOpenCL, TestPad3) {
|
||||
auto param = static_cast<PadParameter *>(malloc(sizeof(PadParameter)));
|
||||
if (param == nullptr) {
|
||||
MS_LOG(ERROR) << "PadParameter create error.";
|
||||
return;
|
||||
}
|
||||
param->pad_mode_ = PaddingMode_CONSTANT;
|
||||
param->constant_value_ = 0.0f;
|
||||
param->padding_length = MAX_PAD_SIZE;
|
||||
int paddings[MAX_PAD_SIZE] = {0, 0, 3, 3, 3, 3, 0, 0};
|
||||
memcpy(param->paddings_, paddings, sizeof(paddings));
|
||||
|
||||
float input_data[48] = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0,
|
||||
12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0,
|
||||
24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0,
|
||||
36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0};
|
||||
float expect_data[300] = {
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 12.0, 13.0, 14.0, 15.0,
|
||||
16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 36.0,
|
||||
37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0};
|
||||
|
||||
TEST_MAIN(param, Format_NHWC, Format_NHWC, Format_NHWC4, kNumberTypeFloat32, {1, 4, 4, 3}, {1, 10, 10, 3}, input_data,
|
||||
expect_data);
|
||||
TEST_MAIN(param, Format_NHWC, Format_NHWC, Format_NC4HW4, kNumberTypeFloat32, {1, 4, 4, 3}, {1, 10, 10, 3},
|
||||
input_data, expect_data);
|
||||
TEST_MAIN(param, Format_NHWC, Format_NHWC, Format_NHWC4, kNumberTypeFloat16, {1, 4, 4, 3}, {1, 10, 10, 3}, input_data,
|
||||
expect_data);
|
||||
TEST_MAIN(param, Format_NHWC, Format_NHWC, Format_NC4HW4, kNumberTypeFloat16, {1, 4, 4, 3}, {1, 10, 10, 3},
|
||||
input_data, expect_data);
|
||||
}
|
||||
|
||||
TEST_F(TestPadOpenCL, TestPad4) {
|
||||
auto param = static_cast<PadParameter *>(malloc(sizeof(PadParameter)));
|
||||
if (param == nullptr) {
|
||||
MS_LOG(ERROR) << "PadParameter create error.";
|
||||
return;
|
||||
}
|
||||
param->pad_mode_ = PaddingMode_CONSTANT;
|
||||
param->constant_value_ = 1.0f;
|
||||
param->padding_length = MAX_PAD_SIZE;
|
||||
int paddings[MAX_PAD_SIZE] = {0, 0, 3, 3, 3, 3, 0, 0};
|
||||
memcpy(param->paddings_, paddings, sizeof(paddings));
|
||||
|
||||
float input_data[48] = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0,
|
||||
12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0,
|
||||
24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0,
|
||||
36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0};
|
||||
float expect_data[300] = {
|
||||
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
|
||||
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
|
||||
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
|
||||
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
|
||||
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
|
||||
1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 1.0, 1.0, 1.0,
|
||||
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 12.0, 13.0, 14.0, 15.0,
|
||||
16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
|
||||
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0,
|
||||
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 36.0,
|
||||
37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
|
||||
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
|
||||
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
|
||||
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
|
||||
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
|
||||
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0};
|
||||
|
||||
TEST_MAIN(param, Format_NHWC, Format_NHWC, Format_NHWC4, kNumberTypeFloat32, {1, 4, 4, 3}, {1, 10, 10, 3}, input_data,
|
||||
expect_data);
|
||||
}
|
||||
|
||||
} // namespace mindspore
|
Loading…
Reference in new issue