!11339 [MS][LITE][Develop] add new ops named split and modify matmul_strassen
From: @pengyongrong Reviewed-by: Signed-off-by:pull/11339/MERGE
commit
ca99a7bd0b
@ -0,0 +1,114 @@
|
||||
// Allow half-precision types in this program.
#pragma OPENCL EXTENSION cl_khr_fp16 : enable

// Sampler used for every image read in this file: unnormalized coordinates,
// clamped addressing and nearest filtering.
__constant sampler_t smp_none = CLK_FILTER_NEAREST | CLK_ADDRESS_CLAMP | CLK_NORMALIZED_COORDS_FALSE;

// Lane count of the FLT4 vectors used by the kernels below.
#define C4NUM 4

// Integer division of x by y, rounded up (intended for positive operands).
#define UP_DIV(x, y) (((x) + (y) - 1) / (y))
|
||||
|
||||
// Declares the work-item coordinates X, Y, Z (global ids 0, 1, 2) and returns from the
// kernel when the work-item lies outside the extent described by `in_shape`
// (x * y along dimension 0, z along dimension 1, w along dimension 2) or when
// in_shape.y is zero (which would make the later X / in_shape.y division invalid).
// The comparisons are `>=` because valid indices run from 0 to extent - 1; a work-item
// whose id equals the extent is already out of range.
#define CHECK_IDX_ALIGN                                                                        \
  const int X = get_global_id(0);                                                              \
  const int Y = get_global_id(1);                                                              \
  const int Z = get_global_id(2);                                                              \
  if (X >= in_shape.x * in_shape.y || Y >= in_shape.z || Z >= in_shape.w || in_shape.y == 0) { \
    return;                                                                                    \
  }
|
||||
|
||||
// Decomposes X into IN (= X / in_shape.y) and IH (= X % in_shape.y), derives the input
// image coordinate from (IN, IH, Y, Z) and loads that texel into `result`.
// `coordinate_x` / `coordinate_y` stay mutable so the calling kernel may reuse them.
#define ARGS_ALIGN                                        \
  const int IN = X / in_shape.y, IH = X % in_shape.y;     \
  int coordinate_y = Y * in_shape.w + Z;                  \
  int coordinate_x = IN * in_shape.y + IH;                \
  FLT4 result = READ_IMAGE(input, smp_none, (int2)(coordinate_y, coordinate_x));
|
||||
|
||||
// Two-way split along axis 3 for the aligned case.
// Each work-item reads one texel of `input` and forwards it to `output1` when its Z index
// is below UP_DIV(split_sizes_[0], C4NUM), otherwise to `output2` with Z rebased by that
// boundary.
__kernel void split_out2_axis3(__read_only image2d_t input, __write_only image2d_t output1,
                               __write_only image2d_t output2, __global int *split_sizes_, int4 in_shape,
                               int4 out_shape1, int4 out_shape2) {
  CHECK_IDX_ALIGN;
  ARGS_ALIGN;
  // First Z index that belongs to the second output.
  const int first_z_out2 = UP_DIV(split_sizes_[0], C4NUM);
  if (Z >= first_z_out2) {
    WRITE_IMAGE(output2, (int2)(Y * out_shape2.w + Z - first_z_out2, IN * out_shape2.y + IH), result);
    return;
  }
  WRITE_IMAGE(output1, (int2)(Y * out_shape1.w + Z, IN * out_shape1.y + IH), result);
}
|
||||
|
||||
// Two-way split along axis 2 for the aligned case.
// Work-items with Y below split_sizes_[0] write their texel to `output1`; the rest write
// to `output2` with Y rebased by split_sizes_[0].
__kernel void split_out2_axis2(__read_only image2d_t input, __write_only image2d_t output1,
                               __write_only image2d_t output2, __global int *split_sizes_, int4 in_shape,
                               int4 out_shape1, int4 out_shape2) {
  CHECK_IDX_ALIGN;
  ARGS_ALIGN;
  const int first_y_out2 = split_sizes_[0];
  if (Y >= first_y_out2) {
    WRITE_IMAGE(output2, (int2)((Y - first_y_out2) * out_shape2.w + Z, IN * out_shape2.y + IH), result);
    return;
  }
  WRITE_IMAGE(output1, (int2)(Y * out_shape1.w + Z, IN * out_shape1.y + IH), result);
}
|
||||
|
||||
// Two-way split along axis 1 for the aligned case.
// Work-items with IH below split_sizes_[0] write their texel to `output1`; the rest write
// to `output2` with IH rebased by split_sizes_[0].
__kernel void split_out2_axis1(__read_only image2d_t input, __write_only image2d_t output1,
                               __write_only image2d_t output2, __global int *split_sizes_, int4 in_shape,
                               int4 out_shape1, int4 out_shape2) {
  CHECK_IDX_ALIGN;
  ARGS_ALIGN;
  const int first_h_out2 = split_sizes_[0];
  if (IH >= first_h_out2) {
    WRITE_IMAGE(output2, (int2)(Y * out_shape2.w + Z, IN * out_shape2.y + IH - first_h_out2), result);
    return;
  }
  WRITE_IMAGE(output1, (int2)(Y * out_shape1.w + Z, IN * out_shape1.y + IH), result);
}
|
||||
|
||||
// Unaligned case on axis C (channel count not a multiple of C4NUM) for split
|
||||
// Declares the work-item coordinates X and Y (global ids 0 and 1) and returns from the
// kernel when in_shape.y is zero or the work-item lies outside the in_shape.x * in_shape.y
// by in_shape.z extent.
#define CHECK_IDX_UNALIGN                                                     \
  const int X = get_global_id(0);                                             \
  const int Y = get_global_id(1);                                             \
  if (in_shape.y == 0 || Y >= in_shape.z || X >= in_shape.x * in_shape.y) {   \
    return;                                                                   \
  }
|
||||
|
||||
// Decomposes X into IN / IH, aliases Y as IW and computes `index_input`, the element
// offset into the flat input buffer: one row of `stride_w` elements per (IN, IH) pair,
// plus IW groups of UP_DIV(in_shape.w, C4NUM) * C4NUM elements inside that row.
#define ARGS_UNALIGN                                         \
  const int IN = X / in_shape.y;                             \
  const int IH = X % in_shape.y;                             \
  const int IW = Y;                                          \
  const int Align_inShape = UP_DIV(in_shape.w, C4NUM);       \
  int index_input = (IN * in_shape.y + IH) * stride_w + IW * Align_inShape * C4NUM;
|
||||
|
||||
// Copies out_shape.w consecutive elements of `input`, starting at `index_input`, into the
// image `output` as UP_DIV(out_shape.w, C4NUM) texels for position (IN, IH, IW).
// Full texels take four elements; a trailing partial texel takes out_shape.w % C4NUM
// elements and leaves the remaining lanes zero.
// Returns the buffer offset just past the last element consumed.
int dosplit(__global FLT *input, __write_only image2d_t output, int4 out_shape, int IN, int IH, int IW,
            int index_input) {
  const int tail_lanes = out_shape.w % C4NUM;
  const int texels_per_pos = UP_DIV(out_shape.w, C4NUM);
  const int dst_row = IN * out_shape.y + IH;
  for (int t = 0; t < texels_per_pos; ++t) {
    // A texel is full when all four of its lanes fall inside out_shape.w.
    const int lanes = ((t + 1) * C4NUM <= out_shape.w) ? C4NUM : tail_lanes;
    FLT staged[4] = {};
    for (int lane = 0; lane < lanes; ++lane) {
      staged[lane] = input[index_input + lane];
    }
    index_input += lanes;
    FLT4 texel = {staged[0], staged[1], staged[2], staged[3]};
    WRITE_IMAGE(output, (int2)(IW * texels_per_pos + t, dst_row), texel);
  }
  return index_input;
}
|
||||
|
||||
// Two-way split along axis 3 for the unaligned case.
// The input is read as a flat buffer; each work-item handles one (IN, IH, IW) position and
// emits the elements for `output1` followed immediately by the elements for `output2`.
// `split_sizes_` is part of the shared argument layout but is not read here.
__kernel void split_out2_axis3_unalign(__global FLT *input, __write_only image2d_t output1,
                                       __write_only image2d_t output2, __global int *split_sizes_, int4 in_shape,
                                       int4 out_shape1, int4 out_shape2, int stride_w) {
  CHECK_IDX_UNALIGN;
  ARGS_UNALIGN;
  // The second call resumes where the first one stopped reading.
  const int after_first = dosplit(input, output1, out_shape1, IN, IH, IW, index_input);
  dosplit(input, output2, out_shape2, IN, IH, IW, after_first);
}
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,206 @@
|
||||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/runtime/kernel/opencl/kernel/split.h"
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <set>
|
||||
#include "src/kernel_registry.h"
|
||||
#include "src/runtime/kernel/opencl/cl/split.cl.inc"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kGPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::schema::PrimitiveType_Split;
|
||||
namespace mindspore::kernel {
|
||||
|
||||
int SplitOpenCLKernel::RunAxis0() {
|
||||
auto allocator_ = ocl_runtime_->GetAllocator();
|
||||
std::vector<size_t> img_size;
|
||||
auto src_data = in_tensors_[0]->data_c();
|
||||
cl::Image2D *in_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(src_data));
|
||||
if (in_image == nullptr) {
|
||||
MS_LOG(ERROR) << "RunAxis0 in_image can not be nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto src_area = cl::array<cl::size_type, 3U>{0, 0, 0};
|
||||
for (int i = 0; i < out_tensors_.size(); i++) {
|
||||
auto dst_data = out_tensors_[i]->data_c();
|
||||
allocator_->GetImageSize(dst_data, &img_size);
|
||||
auto dst_area = cl::array<cl::size_type, 3U>{0, 0, 0};
|
||||
auto region = cl::array<cl::size_type, 3U>{img_size[0], img_size[1], 1};
|
||||
cl::Image2D *out_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(dst_data));
|
||||
if (out_image == nullptr) {
|
||||
MS_LOG(ERROR) << "RunAxis0 out_image can not be nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
ocl_runtime_->GetDefaultCommandQueue()->enqueueCopyImage(*in_image, *out_image, src_area, dst_area, region);
|
||||
src_area[1] += region[1];
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int SplitOpenCLKernel::CheckSpecs() {
|
||||
if (out_tensors_.size() != 2 || in_tensors_.size() != 1) {
|
||||
MS_LOG(ERROR) << "in size: " << in_tensors_.size() << ", out size: " << out_tensors_.size();
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (in_tensors_.at(0)->IsConst()) {
|
||||
MS_LOG(ERROR) << "in_tensors_ must be tensor";
|
||||
return RET_ERROR;
|
||||
}
|
||||
for (auto &out_tensor : out_tensors_) {
|
||||
if (out_tensor->IsConst()) {
|
||||
MS_LOG(ERROR) << "out_tensor must be tensor";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
auto param = reinterpret_cast<SplitParameter *>(this->op_parameter_);
|
||||
if (param->num_split_ != 2 && param->num_split_ != 1) {
|
||||
MS_LOG(ERROR) << "num_split_ only supported 1 or 2 yet";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (param->split_dim_ < 0 || param->split_dim_ > 3) {
|
||||
MS_LOG(ERROR) << "split_dim_ must between 0~3";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (param->split_sizes_ == nullptr) {
|
||||
MS_LOG(ERROR) << "split_sizes_ can not nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
// Builds `split_sizes_`, the array of cumulative split boundaries along the split axis
// that is later handed to the device kernel.
//
//  - num_split_ == 1: param->split_sizes_[0] is a uniform chunk size; boundaries are
//    chunk, 2 * chunk, ... and the final boundary is the axis length itself.
//  - otherwise: boundaries are the running sums of param->split_sizes_[0..num_split_).
//
// param    : split parameters; split_dim_ indexes `in_shape`, split_sizes_ must be readable.
// in_shape : shape of the input tensor.
//
// On invalid input or allocation failure an error is logged and `split_sizes_` is left
// untouched (the function has no way to return a status).
// NOTE(review): the buffer obtained from the allocator is not released anywhere in this
// block - confirm who owns it.
void SplitOpenCLKernel::AlignSplitSizes(SplitParameter *param, const std::vector<int> &in_shape) {
  auto allocator = ocl_runtime_->GetAllocator();
  const int shape_dim = in_shape.at(param->split_dim_);
  if (num_split_ == 1) {
    const int chunk = param->split_sizes_[0];
    // Guard the division below and the `count - 1` loop bound against zero/negative values.
    if (chunk <= 0 || shape_dim <= 0) {
      MS_LOG(ERROR) << "AlignSplitSizes: invalid split size " << chunk << " for axis length " << shape_dim;
      return;
    }
    const int count = UP_DIV(shape_dim, chunk);
    auto *boundaries = reinterpret_cast<int *>(allocator->Malloc(count * sizeof(int)));
    if (boundaries == nullptr) {
      MS_LOG(ERROR) << "AlignSplitSizes: malloc split_sizes_ failed";
      return;
    }
    for (int i = 0; i < count - 1; ++i) {
      boundaries[i] = (i + 1) * chunk;
    }
    // The last chunk may be shorter than `chunk`; it always ends at the axis length.
    boundaries[count - 1] = shape_dim;
    split_sizes_ = boundaries;
  } else {
    const int count = static_cast<int>(num_split_);
    if (count <= 0) {
      MS_LOG(ERROR) << "AlignSplitSizes: invalid num_split_ " << num_split_;
      return;
    }
    auto *boundaries = reinterpret_cast<int *>(allocator->Malloc(count * sizeof(int)));
    if (boundaries == nullptr) {
      MS_LOG(ERROR) << "AlignSplitSizes: malloc split_sizes_ failed";
      return;
    }
    int sum = 0;
    // Fill every entry (including the last) so no uninitialised value reaches the device.
    for (int i = 0; i < count; ++i) {
      sum += param->split_sizes_[i];
      boundaries[i] = sum;
    }
    split_sizes_ = boundaries;
  }
}
|
||||
|
||||
int SplitOpenCLKernel::Prepare() {
|
||||
auto param = reinterpret_cast<SplitParameter *>(this->op_parameter_);
|
||||
auto in_shape = in_tensors_.at(0)->shape();
|
||||
int increment_dim = C4NUM - in_shape.size();
|
||||
split_dim_ = param->split_dim_ == 0 ? param->split_dim_ : param->split_dim_ + increment_dim;
|
||||
num_split_ = param->num_split_;
|
||||
if (split_dim_ == 0) {
|
||||
return RET_OK;
|
||||
}
|
||||
for (int i = 0; i < out_tensors_.size(); ++i) {
|
||||
int length = out_tensors_[0]->shape().size();
|
||||
if (split_dim_ == 3) {
|
||||
if (out_tensors_[i]->shape()[length - 1] % C4NUM != 0) {
|
||||
Align_ = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
AlignSplitSizes(param, in_shape);
|
||||
std::string kernel_name = "split_out";
|
||||
kernel_name += num_split_ == 1 ? std::to_string(out_tensors().size()) : std::to_string(num_split_);
|
||||
kernel_name += "_axis" + std::to_string(split_dim_);
|
||||
if (!Align_) {
|
||||
kernel_name += "_unalign";
|
||||
}
|
||||
MS_LOG(DEBUG) << "kernel_name=: " << kernel_name;
|
||||
std::string source = split_source;
|
||||
std::string program_name = "split";
|
||||
ocl_runtime_->LoadSource(program_name, source);
|
||||
ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name);
|
||||
MS_LOG(DEBUG) << kernel_name << " Init Done!";
|
||||
SetConstArgs();
|
||||
SetGlobalLocal();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
void SplitOpenCLKernel::SetConstArgs() {
|
||||
int arg_cn = out_tensors_.size() + 2;
|
||||
cl_int4 shape = {};
|
||||
for (int i = 0; i < in_tensors_[0]->shape().size(); ++i) {
|
||||
shape.s[i] = in_tensors_[0]->shape()[i];
|
||||
}
|
||||
Broadcast2GpuShape(in_shape_.s, shape.s, out_tensors_[0]->shape().size(), 1);
|
||||
if (Align_) {
|
||||
in_shape_.s[3] = UP_DIV(in_shape_.s[3], C4NUM);
|
||||
}
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_shape_);
|
||||
|
||||
for (int i = 0; i < out_tensors_.size(); ++i) {
|
||||
cl_int4 temp = {};
|
||||
for (int j = 0; j < out_tensors_[i]->shape().size(); ++j) {
|
||||
temp.s[j] = out_tensors_[i]->shape()[j];
|
||||
}
|
||||
Broadcast2GpuShape(out_shape_.s, temp.s, out_tensors_[i]->shape().size(), 1);
|
||||
if (Align_) {
|
||||
out_shape_.s[3] = UP_DIV(out_shape_.s[3], C4NUM);
|
||||
}
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_shape_);
|
||||
}
|
||||
GpuTensorInfo img_info(in_tensors_.at(0));
|
||||
size_t dtype = enable_fp16_ ? sizeof(cl_half) : sizeof(cl_float);
|
||||
stride_w = img_info.RowPitch() / dtype;
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, stride_w);
|
||||
return;
|
||||
}
|
||||
|
||||
void SplitOpenCLKernel::SetGlobalLocal() {
|
||||
OH = in_shape_.s[0] * in_shape_.s[1];
|
||||
OW = in_shape_.s[2];
|
||||
if (Align_) {
|
||||
OC = in_shape_.s[3];
|
||||
}
|
||||
global_size_ = {OH, OW, OC};
|
||||
local_size_ = {1, 1, 1};
|
||||
OpenCLKernel::AlignGlobalLocal(global_size_, local_size_);
|
||||
return;
|
||||
}
|
||||
|
||||
int SplitOpenCLKernel::Run() {
|
||||
if (split_dim_ == 0) {
|
||||
RunAxis0();
|
||||
return RET_OK;
|
||||
}
|
||||
int arg_cn = 0;
|
||||
if (Align_) {
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_.at(0)->data_c());
|
||||
} else {
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_.at(0)->data_c(), lite::opencl::MemType::BUF);
|
||||
}
|
||||
for (int i = 0; i < out_tensors_.size(); ++i) {
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_.at(i)->data_c());
|
||||
}
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, split_sizes_, lite::opencl::MemType::BUF);
|
||||
ocl_runtime_->RunKernel(kernel_, global_range_, local_range_, nullptr, &event_);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
// Register SplitOpenCLKernel's creator for the Split primitive on kGPU with float32 data.
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Split, OpenCLKernelCreator<SplitOpenCLKernel>)
|
||||
// Register SplitOpenCLKernel's creator for the Split primitive on kGPU with float16 data.
REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_Split, OpenCLKernelCreator<SplitOpenCLKernel>)
|
||||
} // namespace mindspore::kernel
|
@ -0,0 +1,60 @@
|
||||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_SPLIT_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_SPLIT_H_
|
||||
|
||||
#include <vector>
|
||||
#include "src/runtime/kernel/opencl/opencl_kernel.h"
|
||||
#include "nnacl/split_parameter.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
|
||||
class SplitOpenCLKernel : public OpenCLKernel {
|
||||
public:
|
||||
SplitOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs)
|
||||
: OpenCLKernel(parameter, inputs, outputs) {}
|
||||
|
||||
~SplitOpenCLKernel() override = default;
|
||||
|
||||
int Prepare() override;
|
||||
|
||||
int CheckSpecs() override;
|
||||
void SetConstArgs() override;
|
||||
void SetGlobalLocal() override;
|
||||
int Run() override;
|
||||
|
||||
private:
|
||||
void AlignSplitSizes(SplitParameter *param, const std::vector<int> &in_shape);
|
||||
int RunAxis0();
|
||||
|
||||
private:
|
||||
cl_int4 in_shape_{};
|
||||
cl_int4 out_shape_ = {};
|
||||
bool Align_{true};
|
||||
bool enable_fp16_{false};
|
||||
size_t num_split_ = 1;
|
||||
int *split_sizes_{nullptr};
|
||||
int split_dim_ = 0;
|
||||
cl_int stride_w{1};
|
||||
uint32_t OH = {1};
|
||||
uint32_t OW = {1};
|
||||
uint32_t OC = {1};
|
||||
};
|
||||
|
||||
} // namespace mindspore::kernel
|
||||
#endif
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,77 @@
|
||||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_STRASSEN_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_STRASSEN_H_
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "src/runtime/kernel/opencl/kernel/matmul.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
|
||||
class StrassenOpenCLKernel : public MatMulOpenCLKernel {
|
||||
public:
|
||||
StrassenOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs)
|
||||
: MatMulOpenCLKernel(parameter, inputs, outputs) {}
|
||||
~StrassenOpenCLKernel() override = default;
|
||||
|
||||
public:
|
||||
int Run() override;
|
||||
int Prepare() override;
|
||||
int InitWeights() override;
|
||||
void SetConstArgs() override;
|
||||
void SetGlobalLocal() override;
|
||||
|
||||
// strassen
|
||||
private:
|
||||
void AllocatorMemoryForStrassen(int NumA, int NumB);
|
||||
void DoStrassen(void *data, void *weight, void *result, const int size, const int depth, const int threshold);
|
||||
void StrassenSetGlobalLocal(size_t strassen_size, int type_flag);
|
||||
void StrassenSetConstArgs(cl::Kernel *kernel, int index, int strassen_size, bool is_matmul_kernel);
|
||||
void StrassenDataFilled(cl::Kernel *kernel, void *input, void *output, const int size, cl_int2 offset,
|
||||
lite::opencl::MemType mem_type);
|
||||
void StrassenAddSub(cl::Kernel *kernel, void *input, void *output, const int size, cl_int4 offset, int flag,
|
||||
lite::opencl::MemType mem_type);
|
||||
void StrassenBackResult(cl::Kernel *kernel, void *input1, void *input2, void *input3, void *input4, void *input5,
|
||||
void *input6, void *input7, void *output, const int size);
|
||||
void StrassenRunMmatmul(void *input, void *weight, void *output, const int size);
|
||||
void PrintImage2d(void *IMGData, size_t typesize, size_t width, size_t size);
|
||||
cl::Kernel kernel_IMG_add_sub_2;
|
||||
cl::Kernel MatMul_StrassenBUFFilled;
|
||||
cl::Kernel MatMul_StrassenIMGFilled;
|
||||
cl::Kernel kernel_BUF_add_sub_2;
|
||||
cl::Kernel kernel_back_result;
|
||||
cl::NDRange global_add_sub_, local_add_sub_;
|
||||
std::vector<size_t> global_size_add_sub;
|
||||
std::vector<size_t> local_size_add_sub;
|
||||
// image 2d
|
||||
void *A_temp[MAXDEPTH] = {nullptr};
|
||||
void *M1[MAXDEPTH] = {nullptr};
|
||||
void *M2[MAXDEPTH] = {nullptr};
|
||||
void *M3[MAXDEPTH] = {nullptr};
|
||||
void *M4[MAXDEPTH] = {nullptr};
|
||||
void *M5[MAXDEPTH] = {nullptr};
|
||||
void *M6[MAXDEPTH] = {nullptr};
|
||||
void *M7[MAXDEPTH] = {nullptr};
|
||||
// buffer
|
||||
void *B_temp[MAXDEPTH] = {nullptr};
|
||||
};
|
||||
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_STRASSEN_H_
|
@ -0,0 +1,57 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "ut/src/runtime/kernel/opencl/common.h"
|
||||
#include "nnacl/split_parameter.h"
|
||||
|
||||
namespace mindspore::lite::opencl::test {
|
||||
|
||||
// Test fixture for the OpenCL Split kernel; it adds nothing on top of CommonTest.
class TestOpenCL_Split : public CommonTest {};
|
||||
|
||||
namespace {
|
||||
// PrimitiveType_Split: src/ops/populate/split_populate.cc
|
||||
// Builds a SplitParameter for the tests.
//
// split_dim_   : axis to split along.
// num_split_   : number of entries to store in the parameter's split_sizes_ array.
// split_sizes_ : values copied into that array; when it holds fewer than num_split_
//                entries the remaining slots are zero-filled instead of being read out
//                of range.
//
// Returns the parameter as an OpParameter*, or nullptr when the parameter itself could
// not be created. When num_split_ is not positive or the array allocation fails,
// split_sizes_ is left as nullptr.
OpParameter *CreateParameter(int split_dim_, int num_split_, std::vector<int> split_sizes_) {
  auto *param = test::CreateParameter<SplitParameter>(schema::PrimitiveType_Split);
  if (param == nullptr) {
    return nullptr;
  }
  param->split_dim_ = split_dim_;
  param->num_split_ = num_split_;
  param->split_sizes_ = nullptr;
  if (num_split_ <= 0) {
    return reinterpret_cast<OpParameter *>(param);
  }
  param->split_sizes_ = reinterpret_cast<int *>(malloc(param->num_split_ * sizeof(int)));
  if (param->split_sizes_ == nullptr) {
    return reinterpret_cast<OpParameter *>(param);
  }
  const int provided = static_cast<int>(split_sizes_.size());
  for (int i = 0; i < param->num_split_; ++i) {
    param->split_sizes_[i] = i < provided ? split_sizes_[i] : 0;
  }
  return reinterpret_cast<OpParameter *>(param);
}
|
||||
} // namespace
|
||||
|
||||
// Splits a {2, 2, 2, 12} float32 tensor along axis 3 into two {2, 2, 2, 6} tensors and
// compares both outputs with reference data loaded from ./test_data. Only the fp32 path
// is exercised.
TEST_F(TestOpenCL_Split, input2_axis3) {
  // Operator configuration.
  int split_dim_ = 3;
  int num_split_ = 2;  // len of split_sizes_
  std::vector<int> split_sizes_{6, 6};

  // Tensor shapes.
  std::vector<int> input_shape = {2, 2, 2, 12};
  std::vector<int> output_shape1 = {2, 2, 2, 6};
  std::vector<int> output_shape2 = {2, 2, 2, 6};

  // Input and reference outputs.
  std::string input_path = "./test_data/splitfp32_input.bin";
  std::string expect1_path = "./test_data/splitfp32_output1.bin";
  std::string expect2_path = "./test_data/splitfp32_output2.bin";
  size_t input_size = 0;
  size_t output1_size = 0;
  size_t output2_size = 0;
  auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size));
  auto output_data1 = reinterpret_cast<float *>(mindspore::lite::ReadFile(expect1_path.c_str(), &output1_size));
  auto output_data2 = reinterpret_cast<float *>(mindspore::lite::ReadFile(expect2_path.c_str(), &output2_size));

  const bool fp16_enable = false;
  auto *param = CreateParameter(split_dim_, num_split_, split_sizes_);
  TestMain({{input_shape, input_data, VAR}}, {{output_shape1, output_data1}, {output_shape2, output_data2}}, param,
           fp16_enable, fp16_enable ? 1e-3 : 1e-9);
}
|
||||
} // namespace mindspore::lite::opencl::test
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading…
Reference in new issue