parent
af48c17798
commit
e2d56df80f
@ -0,0 +1,42 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;
|
||||
|
||||
__kernel void Scale_IMG(__read_only image2d_t input, __read_only image2d_t scale, __read_only image2d_t offset,
|
||||
__write_only image2d_t output, const int2 output_shape) {
|
||||
int X = get_global_id(0);
|
||||
int Y = get_global_id(1);
|
||||
if (X >= output_shape.x || Y >= output_shape.y) {
|
||||
return;
|
||||
}
|
||||
|
||||
FLT4 in = read_imagef(input, smp_none, (int2)(X, Y));
|
||||
FLT4 s = read_imagef(scale, smp_none, (int2)(X, Y));
|
||||
FLT4 o = read_imagef(offset, smp_none, (int2)(X, Y));
|
||||
WRITE_IMAGE(output, (int2)(X, Y), in * s + o);
|
||||
}
|
||||
|
||||
__kernel void BoardcastScale_IMG(__read_only image2d_t input, float scale, float offset, __write_only image2d_t output,
|
||||
const int2 output_shape) {
|
||||
int X = get_global_id(0);
|
||||
int Y = get_global_id(1);
|
||||
if (X >= output_shape.x || Y >= output_shape.y) {
|
||||
return;
|
||||
}
|
||||
|
||||
FLT4 in = read_imagef(input, smp_none, (int2)(X, Y));
|
||||
WRITE_IMAGE(output, (int2)(X, Y), in * (FLT)scale + (FLT)offset);
|
||||
}
|
||||
|
||||
__kernel void Scale_C_IMG(__read_only image2d_t input, __read_only image2d_t scale, __read_only image2d_t offset,
|
||||
__write_only image2d_t output, const int2 output_shape, const int C) {
|
||||
int X = get_global_id(0);
|
||||
int Y = get_global_id(1);
|
||||
if (X >= output_shape.x || Y >= output_shape.y) {
|
||||
return;
|
||||
}
|
||||
|
||||
FLT4 in = read_imagef(input, smp_none, (int2)(X, Y));
|
||||
FLT4 s = read_imagef(scale, smp_none, (int2)(X % C, 0));
|
||||
FLT4 o = read_imagef(offset, smp_none, (int2)(X % C, 0));
|
||||
WRITE_IMAGE(output, (int2)(X, Y), in * s + o);
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,56 @@
|
||||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_SCALE_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_SCALE_H_
|
||||
|
||||
#include <vector>
|
||||
#include "nnacl/scale.h"
|
||||
#include "src/runtime/opencl/opencl_runtime.h"
|
||||
#include "src/runtime/kernel/opencl/opencl_kernel.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
|
||||
class ScaleOpenCLKernel : public OpenCLKernel {
|
||||
public:
|
||||
explicit ScaleOpenCLKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
|
||||
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
|
||||
: OpenCLKernel(parameter, inputs, outputs) {}
|
||||
~ScaleOpenCLKernel() override;
|
||||
|
||||
int Init() override;
|
||||
int Run() override;
|
||||
int GetImageSize(size_t idx, std::vector<size_t> *img_size) override;
|
||||
|
||||
private:
|
||||
std::vector<size_t> InitGlobalSize() const;
|
||||
void Image2dGetWorkGroupSize();
|
||||
void BufferGetWorkGroupSize();
|
||||
int InitBuffer();
|
||||
|
||||
cl::Kernel kernel_;
|
||||
lite::opencl::OpenCLRuntime *ocl_runtime_;
|
||||
bool element_flag_{true};
|
||||
void *scale_ptr_{nullptr};
|
||||
void *offset_ptr_{nullptr};
|
||||
int axis_{0};
|
||||
|
||||
std::vector<size_t> local_size_;
|
||||
std::vector<size_t> global_size_;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_SCALE_H_
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in new issue