parent
edc4ac2c25
commit
91568de9eb
@ -0,0 +1,232 @@
|
||||
// Enable the half-precision extension only when the compiler advertises it.
#ifdef cl_khr_fp16
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
#endif
|
||||
|
||||
// Group size used with UP_DIV in OneHotAxis3 to count groups of four input values.
#define C4NUM 4
|
||||
// Integer division of x by y rounded up (for positive operands).
#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))
|
||||
// Sampler shared by every kernel below: unnormalized integer coordinates,
// clamped addressing and nearest-neighbour filtering.
__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
|
||||
__kernel void OneHotAxis0(__read_only image2d_t src_data, __write_only image2d_t dst_data, int2 in_image2d_shape,
|
||||
int4 out_shape, int depth, float on_value, float off_value, int C) {
|
||||
int X = get_global_id(0); // C4
|
||||
int Y = get_global_id(1); // W
|
||||
int Z = get_global_id(2); // H * N
|
||||
if (X >= out_shape.w || Y >= out_shape.z || Z >= out_shape.x * out_shape.y) return;
|
||||
int N = Z / out_shape.y;
|
||||
int H = Z % out_shape.y;
|
||||
int in_index = (H * out_shape.z + Y) * out_shape.w + X;
|
||||
FLT4 indices = READ_IMAGE(src_data, smp_zero, (int2)(in_index % in_image2d_shape.x, in_index / in_image2d_shape.x));
|
||||
int *indices_int = (int *)&indices;
|
||||
FLT4 result = (FLT4)(0.f);
|
||||
if (4 * X < C) {
|
||||
if (indices_int[0] == N) {
|
||||
result.x = (FLT)(on_value);
|
||||
} else {
|
||||
result.x = (FLT)(off_value);
|
||||
}
|
||||
}
|
||||
if (4 * X + 1 < C) {
|
||||
if (indices_int[1] == N) {
|
||||
result.y = (FLT)(on_value);
|
||||
} else {
|
||||
result.y = (FLT)(off_value);
|
||||
}
|
||||
}
|
||||
if (4 * X + 2 < C) {
|
||||
if (indices_int[2] == N) {
|
||||
result.z = (FLT)(on_value);
|
||||
} else {
|
||||
result.z = (FLT)(off_value);
|
||||
}
|
||||
}
|
||||
if (4 * X + 3 < C) {
|
||||
if (indices_int[3] == N) {
|
||||
result.w = (FLT)(on_value);
|
||||
} else {
|
||||
result.w = (FLT)(off_value);
|
||||
}
|
||||
}
|
||||
WRITE_IMAGE(dst_data, (int2)(Y * out_shape.w + X, Z), result);
|
||||
}
|
||||
|
||||
__kernel void OneHotAxis1(__read_only image2d_t src_data, __write_only image2d_t dst_data, int2 in_image2d_shape,
|
||||
int4 out_shape, int depth, float on_value, float off_value, int C) {
|
||||
int X = get_global_id(0); // C4
|
||||
int Y = get_global_id(1); // W
|
||||
int Z = get_global_id(2); // H * N
|
||||
if (X >= out_shape.w || Y >= out_shape.z || Z >= out_shape.x * out_shape.y) return;
|
||||
int N = Z / out_shape.y;
|
||||
int H = Z % out_shape.y;
|
||||
int in_index = (N * out_shape.z + Y) * out_shape.w + X;
|
||||
FLT4 indices = READ_IMAGE(src_data, smp_zero, (int2)(in_index % in_image2d_shape.x, in_index / in_image2d_shape.x));
|
||||
int *indices_int = (int *)&indices;
|
||||
FLT4 result = (FLT4)(0.f);
|
||||
if (4 * X < C) {
|
||||
if (indices_int[0] == H) {
|
||||
result.x = (FLT)(on_value);
|
||||
} else {
|
||||
result.x = (FLT)(off_value);
|
||||
}
|
||||
}
|
||||
if (4 * X + 1 < C) {
|
||||
if (indices_int[1] == H) {
|
||||
result.y = (FLT)(on_value);
|
||||
} else {
|
||||
result.y = (FLT)(off_value);
|
||||
}
|
||||
}
|
||||
if (4 * X + 2 < C) {
|
||||
if (indices_int[2] == H) {
|
||||
result.z = (FLT)(on_value);
|
||||
} else {
|
||||
result.z = (FLT)(off_value);
|
||||
}
|
||||
}
|
||||
if (4 * X + 3 < C) {
|
||||
if (indices_int[3] == H) {
|
||||
result.w = (FLT)(on_value);
|
||||
} else {
|
||||
result.w = (FLT)(off_value);
|
||||
}
|
||||
}
|
||||
WRITE_IMAGE(dst_data, (int2)(Y * out_shape.w + X, Z), result);
|
||||
}
|
||||
|
||||
__kernel void OneHotAxis2(__read_only image2d_t src_data, __write_only image2d_t dst_data, int2 in_image2d_shape,
|
||||
int4 out_shape, int depth, float on_value, float off_value, int C) {
|
||||
int X = get_global_id(0); // C4
|
||||
int Y = get_global_id(1); // W
|
||||
int Z = get_global_id(2); // H * N
|
||||
if (X >= out_shape.w || Y >= out_shape.z || Z >= out_shape.x * out_shape.y) return;
|
||||
int N = Z / out_shape.y;
|
||||
int H = Z % out_shape.y;
|
||||
int in_index = (N * out_shape.y + H) * out_shape.w + X;
|
||||
FLT4 indices = READ_IMAGE(src_data, smp_zero, (int2)(in_index % in_image2d_shape.x, in_index / in_image2d_shape.x));
|
||||
int *indices_int = (int *)&indices;
|
||||
FLT4 result = (FLT4)(0.f);
|
||||
if (4 * X < C) {
|
||||
if (indices_int[0] == Y) {
|
||||
result.x = (FLT)(on_value);
|
||||
} else {
|
||||
result.x = (FLT)(off_value);
|
||||
}
|
||||
}
|
||||
if (4 * X + 1 < C) {
|
||||
if (indices_int[1] == Y) {
|
||||
result.y = (FLT)(on_value);
|
||||
} else {
|
||||
result.y = (FLT)(off_value);
|
||||
}
|
||||
}
|
||||
if (4 * X + 2 < C) {
|
||||
if (indices_int[2] == Y) {
|
||||
result.z = (FLT)(on_value);
|
||||
} else {
|
||||
result.z = (FLT)(off_value);
|
||||
}
|
||||
}
|
||||
if (4 * X + 3 < C) {
|
||||
if (indices_int[3] == Y) {
|
||||
result.w = (FLT)(on_value);
|
||||
} else {
|
||||
result.w = (FLT)(off_value);
|
||||
}
|
||||
}
|
||||
WRITE_IMAGE(dst_data, (int2)(Y * out_shape.w + X, Z), result);
|
||||
}
|
||||
|
||||
__kernel void OneHotAxis3(__read_only image2d_t src_data, __write_only image2d_t dst_data, int2 in_image2d_shape,
|
||||
int4 out_shape, int depth, float on_value, float off_value, int C) {
|
||||
int X = get_global_id(0); // C4
|
||||
int Y = get_global_id(1); // W
|
||||
int Z = get_global_id(2); // H * N
|
||||
if (X >= out_shape.w || Y >= out_shape.z || Z >= out_shape.x * out_shape.y) return;
|
||||
int N = Z / out_shape.y;
|
||||
int H = Z % out_shape.y;
|
||||
int ci4_size = UP_DIV(out_shape.z, C4NUM);
|
||||
int in_index_c4 = (N * out_shape.y + H) * ci4_size + Y / 4;
|
||||
int in_index_c4_remainder = Y % 4;
|
||||
FLT4 indices =
|
||||
READ_IMAGE(src_data, smp_zero, (int2)(in_index_c4 % in_image2d_shape.x, in_index_c4 / in_image2d_shape.x));
|
||||
int *indices_int = (int *)&indices;
|
||||
int index_one = indices_int[in_index_c4_remainder];
|
||||
FLT4 result = (FLT4)(0.f);
|
||||
if (4 * X < C) {
|
||||
if (index_one == 4 * X) {
|
||||
result.x = (FLT)(on_value);
|
||||
} else {
|
||||
result.x = (FLT)(off_value);
|
||||
}
|
||||
}
|
||||
if (4 * X + 1 < C) {
|
||||
if (index_one == 4 * X + 1) {
|
||||
result.y = (FLT)(on_value);
|
||||
} else {
|
||||
result.y = (FLT)(off_value);
|
||||
}
|
||||
}
|
||||
if (4 * X + 2 < C) {
|
||||
if (index_one == 4 * X + 2) {
|
||||
result.z = (FLT)(on_value);
|
||||
} else {
|
||||
result.z = (FLT)(off_value);
|
||||
}
|
||||
}
|
||||
if (4 * X + 3 < C) {
|
||||
if (index_one == 4 * X + 3) {
|
||||
result.w = (FLT)(on_value);
|
||||
} else {
|
||||
result.w = (FLT)(off_value);
|
||||
}
|
||||
}
|
||||
WRITE_IMAGE(dst_data, (int2)(Y * out_shape.w + X, Z), result);
|
||||
}
|
||||
|
||||
__kernel void OneHot2DAxis0(__read_only image2d_t src_data, __write_only image2d_t dst_data, int2 in_image2d_shape,
|
||||
int4 out_shape, int depth, float on_value, float off_value, int C) {
|
||||
int X = get_global_id(0); // C4
|
||||
int Y = get_global_id(1); // W
|
||||
int Z = get_global_id(2); // N
|
||||
if (X >= out_shape.w || Y >= out_shape.z || Z >= out_shape.x * out_shape.y) return;
|
||||
FLT4 result = (FLT4)(0.f);
|
||||
int channel = 4 * X;
|
||||
if (channel < C) {
|
||||
FLT4 indices = READ_IMAGE(src_data, smp_zero, (int2)(0, channel));
|
||||
int index = ((int *)&indices)[0];
|
||||
if (index == Z) {
|
||||
result.x = (FLT)(on_value);
|
||||
} else {
|
||||
result.x = (FLT)(off_value);
|
||||
}
|
||||
}
|
||||
channel++;
|
||||
if (channel < C) {
|
||||
FLT4 indices = READ_IMAGE(src_data, smp_zero, (int2)(0, channel));
|
||||
int index = ((int *)&indices)[0];
|
||||
if (index == Z) {
|
||||
result.y = (FLT)(on_value);
|
||||
} else {
|
||||
result.y = (FLT)(off_value);
|
||||
}
|
||||
}
|
||||
channel++;
|
||||
if (channel < C) {
|
||||
FLT4 indices = READ_IMAGE(src_data, smp_zero, (int2)(0, channel));
|
||||
int index = ((int *)&indices)[0];
|
||||
if (index == Z) {
|
||||
result.z = (FLT)(on_value);
|
||||
} else {
|
||||
result.z = (FLT)(off_value);
|
||||
}
|
||||
}
|
||||
channel++;
|
||||
if (channel < C) {
|
||||
FLT4 indices = READ_IMAGE(src_data, smp_zero, (int2)(0, channel));
|
||||
int index = ((int *)&indices)[0];
|
||||
if (index == Z) {
|
||||
result.w = (FLT)(on_value);
|
||||
} else {
|
||||
result.w = (FLT)(off_value);
|
||||
}
|
||||
}
|
||||
WRITE_IMAGE(dst_data, (int2)(Y * out_shape.w + X, Z), result);
|
||||
}
|
@ -0,0 +1,102 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include "include/errorcode.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "src/runtime/kernel/opencl/kernel/one_hot.h"
|
||||
#include "src/runtime/kernel/opencl/cl/one_hot.cl.inc"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kGPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::schema::PrimitiveType_OneHot;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
int OneHotOpenCLKernel::CheckSpecs() { return RET_OK; }
|
||||
|
||||
int OneHotOpenCLKernel::Prepare() {
|
||||
std::string kernel_name = "OneHot";
|
||||
auto param = reinterpret_cast<OneHotParameter *>(op_parameter_);
|
||||
in_shape_ = Image2DInfo(in_tensors_[0]);
|
||||
out_shape_ = Image2DInfo(out_tensors_[0]);
|
||||
axis_ = out_shape_.AlignAxis(param->axis_);
|
||||
if (in_tensors_[0]->shape().size() == 1 && axis_ == 0) {
|
||||
kernel_name += "2DAxis0";
|
||||
} else {
|
||||
kernel_name += "Axis" + std::to_string(axis_);
|
||||
}
|
||||
#ifdef PROGRAM_WITH_IL
|
||||
kernel_ = ocl_runtime_->GetKernelFromBinary(kernel_name);
|
||||
#else
|
||||
std::set<std::string> build_options;
|
||||
std::string source = one_hot_source;
|
||||
std::string program_name = "OneHot";
|
||||
ocl_runtime_->LoadSource(program_name, source);
|
||||
ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name, build_options);
|
||||
#endif
|
||||
InitWeights();
|
||||
SetConstArgs();
|
||||
SetGlobalLocal();
|
||||
MS_LOG(DEBUG) << kernel_name << " Init Done!";
|
||||
return mindspore::lite::RET_OK;
|
||||
}
|
||||
|
||||
int OneHotOpenCLKernel::InitWeights() {
|
||||
if (in_tensors_.size() <= 1) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
depth_ = static_cast<int32_t *>(in_tensors_[1]->data_c())[0];
|
||||
if (in_tensors_.size() > 2) {
|
||||
on_value_ = static_cast<float *>(in_tensors_[2]->data_c())[0];
|
||||
}
|
||||
if (in_tensors_.size() > 3) {
|
||||
off_value_ = static_cast<float *>(in_tensors_[3]->data_c())[0];
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
void OneHotOpenCLKernel::SetConstArgs() {
|
||||
cl_int2 cl_in_image2d_shape = {static_cast<cl_int>(in_shape_.width), static_cast<cl_int>(in_shape_.height)};
|
||||
cl_int4 cl_out_shape = {static_cast<cl_int>(out_shape_.N), static_cast<cl_int>(out_shape_.H),
|
||||
static_cast<cl_int>(out_shape_.W), static_cast<cl_int>(out_shape_.Slice)};
|
||||
int arg_idx = 2;
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, cl_in_image2d_shape);
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, cl_out_shape);
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, depth_);
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, on_value_);
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, off_value_);
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, static_cast<int>(out_shape_.C));
|
||||
}
|
||||
// Sets the global work size to one work-item per output texel:
// (channel slices, W, H * N). The local range is left untouched.
void OneHotOpenCLKernel::SetGlobalLocal() {
  const auto rows = out_shape_.H * out_shape_.N;
  global_range_ = {out_shape_.Slice, out_shape_.W, rows};
}
|
||||
|
||||
int OneHotOpenCLKernel::Run() {
|
||||
MS_LOG(DEBUG) << this->name() << " Running!";
|
||||
int arg_idx = 0;
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->data_c());
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->data_c());
|
||||
ocl_runtime_->RunKernel(kernel_, global_range_, local_range_, nullptr);
|
||||
return mindspore::lite::RET_OK;
|
||||
}
|
||||
|
||||
// Register OneHotOpenCLKernel as the GPU kernel for the OneHot primitive,
// once for float32 and once for float16.
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_OneHot, OpenCLKernelCreator<OneHotOpenCLKernel>)
|
||||
REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_OneHot, OpenCLKernelCreator<OneHotOpenCLKernel>)
|
||||
} // namespace mindspore::kernel
|
@ -0,0 +1,52 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_ONE_HOT_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_ONE_HOT_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "src/lite_kernel.h"
|
||||
#include "src/runtime/kernel/opencl/opencl_kernel.h"
|
||||
#include "nnacl/fp32/one_hot.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
class OneHotOpenCLKernel : public OpenCLKernel {
|
||||
public:
|
||||
OneHotOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs)
|
||||
: OpenCLKernel(parameter, inputs, outputs) {}
|
||||
~OneHotOpenCLKernel() override = default;
|
||||
|
||||
int Run() override;
|
||||
int Prepare() override;
|
||||
int InitWeights() override;
|
||||
int CheckSpecs() override;
|
||||
void SetConstArgs() override;
|
||||
void SetGlobalLocal() override;
|
||||
|
||||
private:
|
||||
cl::Kernel kernel_;
|
||||
int depth_{0};
|
||||
float on_value_{1.0f};
|
||||
float off_value_{0.0f};
|
||||
int axis_{0};
|
||||
Image2DInfo in_shape_ = Image2DInfo(nullptr);
|
||||
Image2DInfo out_shape_ = Image2DInfo(nullptr);
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_ONE_HOT_H_
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in new issue