gpu support broadcast kernels

pull/934/head
wilfChen 5 years ago
parent afe048474d
commit 16f0688230

@ -0,0 +1,138 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/gpu/cuda_impl/broadcast_impl.cuh"
#include "device/gpu/cuda_common.h"
// Device functor: reports whether lhs is strictly greater than rhs.
// The boolean outcome is implicitly converted to the result type S.
template <typename T, typename S>
struct GreaterFunc {
  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) {
    if (lhs > rhs) {
      return true;
    }
    return false;
  }
};
// Device functor: reports whether lhs is strictly less than rhs.
// The boolean outcome is implicitly converted to the result type S.
template <typename T, typename S>
struct LessFunc {
  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) {
    if (lhs < rhs) {
      return true;
    }
    return false;
  }
};
// Device functor: returns lhs when lhs < rhs, otherwise rhs (converted to S).
template <typename T, typename S>
struct MinimumFunc {
  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) {
    if (lhs < rhs) {
      return lhs;
    }
    // Reached when lhs >= rhs or when the comparison is false for any other reason.
    return rhs;
  }
};
// Device functor: returns lhs when lhs > rhs, otherwise rhs (converted to S).
template <typename T, typename S>
struct MaximumFunc {
  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) {
    if (lhs > rhs) {
      return lhs;
    }
    // Reached when lhs <= rhs or when the comparison is false for any other reason.
    return rhs;
  }
};
// Device functor: raises lhs to the power rhs via pow() and converts the result to S.
template <typename T, typename S>
struct PowerFunc {
  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) {
    const T &base = lhs;
    const T &exponent = rhs;
    return pow(base, exponent);
  }
};
// Maps an output coordinate onto an input axis: an axis of extent 1 always reads element 0,
// any other axis reads the coordinate unchanged.
__device__ __forceinline__ int Index(const int &index, const int &dim) {
  if (dim == 1) {
    return 0;
  }
  return index;
}
// Applies Func element-wise over a 4-D output of extents (d0, d1, d2, d3).
//
// (l0..l3) and (r0..r3) are the extents of input0 and input1. For every flat output position the
// 4-D coordinate is recovered, each coordinate is mapped onto the inputs through Index() (axes of
// extent 1 read element 0), and the two selected values are combined into output[pos].
// Each thread starts at its global thread id and advances by the total number of launched threads.
template <typename T, typename S, typename Func>
__device__ __forceinline__ void BroadcastOperator(const int &l0, const int &l1, const int &l2, const int &l3,
                                                  const int &r0, const int &r1, const int &r2, const int &r3,
                                                  const int &d0, const int &d1, const int &d2, const int &d3,
                                                  const T *input0, const T *input1, S *output) {
  size_t pos = blockIdx.x * blockDim.x + threadIdx.x;
  while (pos < d0 * d1 * d2 * d3) {
    // Decompose the flat output offset into its four coordinates.
    const int c0 = pos / (d1 * d2 * d3) % d0;
    const int c1 = pos / (d2 * d3) % d1;
    const int c2 = pos / d3 % d2;
    const int c3 = pos % d3;

    // Flat offset into input0, using its own extents as strides.
    int lhs_offset = Index(c0, l0) * l1 * l2 * l3;
    lhs_offset += Index(c1, l1) * l2 * l3;
    lhs_offset += Index(c2, l2) * l3;
    lhs_offset += Index(c3, l3);

    // Flat offset into input1, using its own extents as strides.
    int rhs_offset = Index(c0, r0) * r1 * r2 * r3;
    rhs_offset += Index(c1, r1) * r2 * r3;
    rhs_offset += Index(c2, r2) * r3;
    rhs_offset += Index(c3, r3);

    output[pos] = Func()(input0[lhs_offset], input1[rhs_offset]);
    pos += blockDim.x * gridDim.x;
  }
}
// Kernel entry point for the broadcasting path: selects the functor matching `op` and runs
// BroadcastOperator with it. An `op` value without a case below leaves `output` untouched.
template <typename T, typename S>
__global__ void BroadcastKernel(const int l0, const int l1, const int l2, const int l3, const int r0, const int r1,
                                const int r2, const int r3, const int d0, const int d1, const int d2, const int d3,
                                enum BroadcastOpType op, const T *input0, const T *input1, S *output) {
  switch (op) {
    case BROADCAST_TYPE_GREATER:
      BroadcastOperator<T, S, GreaterFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
                                                 output);
      break;
    case BROADCAST_TYPE_LESS:
      BroadcastOperator<T, S, LessFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
                                              output);
      break;
    case BROADCAST_TYPE_MAXIMUM:
      BroadcastOperator<T, S, MaximumFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
                                                 output);
      break;
    case BROADCAST_TYPE_MINIMUM:
      BroadcastOperator<T, S, MinimumFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
                                                 output);
      break;
    case BROADCAST_TYPE_POWER:
      BroadcastOperator<T, S, PowerFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
                                               output);
      break;
  }
}
// Host launcher for the broadcasting path.
//
// Enqueues BroadcastKernel on `stream` with a launch configuration derived from the number of
// output elements (d0 * d1 * d2 * d3). The call does not wait for the kernel to finish.
// NOTE(review): GET_BLOCKS / GET_THREADS are presumably provided by device/gpu/cuda_common.h - confirm.
template <typename T, typename S>
void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, const int &r2,
               const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, enum BroadcastOpType op,
               const T *input0, const T *input1, S *output, cudaStream_t stream) {
  const int output_elements = d0 * d1 * d2 * d3;
  BroadcastKernel<T, S><<<GET_BLOCKS(output_elements), GET_THREADS, 0, stream>>>(
    l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, op, input0, input1, output);
}
// Applies Func to matching elements of two equally sized inputs: output[p] = Func(input0[p], input1[p])
// for every p below `nums`. Each thread starts at its global thread id and advances by the total
// number of launched threads.
template <typename T, typename S, typename Func>
__device__ __forceinline__ void NoBroadcastOperator(const int &nums, const T *input0, const T *input1, S *output) {
  size_t pos = blockIdx.x * blockDim.x + threadIdx.x;
  while (pos < nums) {
    output[pos] = Func()(input0[pos], input1[pos]);
    pos += blockDim.x * gridDim.x;
  }
}
// Kernel entry point for the same-shape path: selects the functor matching `op` and runs
// NoBroadcastOperator over `nums` elements. An `op` value without a case below leaves `output`
// untouched.
//
// Every functor is instantiated with the kernel's own result type S, exactly as BroadcastKernel
// does, so the two paths stay in step for any <T, S> combination.
template <typename T, typename S>
__global__ void NoBroadcastKernel(const int nums, enum BroadcastOpType op, const T *input0, const T *input1,
                                  S *output) {
  switch (op) {
    case BROADCAST_TYPE_GREATER:
      return NoBroadcastOperator<T, S, GreaterFunc<T, S>>(nums, input0, input1, output);
    case BROADCAST_TYPE_LESS:
      return NoBroadcastOperator<T, S, LessFunc<T, S>>(nums, input0, input1, output);
    case BROADCAST_TYPE_MINIMUM:
      return NoBroadcastOperator<T, S, MinimumFunc<T, S>>(nums, input0, input1, output);
    case BROADCAST_TYPE_MAXIMUM:
      return NoBroadcastOperator<T, S, MaximumFunc<T, S>>(nums, input0, input1, output);
    case BROADCAST_TYPE_POWER:
      return NoBroadcastOperator<T, S, PowerFunc<T, S>>(nums, input0, input1, output);
  }
}
// Host launcher for the same-shape path.
//
// Enqueues NoBroadcastKernel on `stream` with a launch configuration derived from `nums`, the
// number of elements to process. The call does not wait for the kernel to finish.
// NOTE(review): GET_BLOCKS / GET_THREADS are presumably provided by device/gpu/cuda_common.h - confirm.
template <typename T, typename S>
void NoBroadcast(const int &nums, enum BroadcastOpType op, const T *input0, const T *input1, S *output,
                 cudaStream_t stream) {
  NoBroadcastKernel<T, S><<<GET_BLOCKS(nums), GET_THREADS, 0, stream>>>(nums, op, input0, input1, output);
  return;
}
// Explicit instantiations of the host launchers. Only float inputs are instantiated here, with
// either a bool output or a float output.

// Broadcast: float inputs -> bool output.
template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
enum BroadcastOpType op, const float *input0, const float *input1, bool *output,
cudaStream_t stream);
// Broadcast: float inputs -> float output.
template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
enum BroadcastOpType op, const float *input0, const float *input1, float *output,
cudaStream_t stream);
// NoBroadcast: float inputs -> bool output.
template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1,
bool *output, cudaStream_t stream);
// NoBroadcast: float inputs -> float output.
template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1,
float *output, cudaStream_t stream);

@ -0,0 +1,40 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_H_
#include "device/gpu/cuda_common.h"
// Identifies the element-wise binary operation that Broadcast() / NoBroadcast() should apply.
enum BroadcastOpType {
BROADCAST_TYPE_GREATER = 0,  // lhs > rhs
BROADCAST_TYPE_LESS = 1,  // lhs < rhs
BROADCAST_TYPE_MAXIMUM = 2,  // larger of lhs and rhs
BROADCAST_TYPE_MINIMUM = 3,  // smaller of lhs and rhs
BROADCAST_TYPE_POWER = 4,  // pow(lhs, rhs)
BROADCAST_TYPE_INVALID = 0xffffffff,  // placeholder for "no operation selected"
};
// Launches the element-wise operation `op` on `stream` for two inputs that are broadcast against
// each other.
//   l0..l3  extents of input0
//   r0..r3  extents of input1
//   d0..d3  extents of output
//   input0, input1, output  buffers read/written by the kernel
// NOTE(review): an input axis of extent 1 is presumably repeated along the matching output axis;
// confirm against the definition in the accompanying .cu file.
template <typename T, typename S>
void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, const int &r2,
const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, enum BroadcastOpType op,
const T *input0, const T *input1, S *output, cudaStream_t stream);
// Launches the element-wise operation `op` on `stream` for two inputs that need no broadcasting.
//   size    number of elements to process
//   input0, input1, output  buffers read/written by the kernel
template <typename T, typename S>
void NoBroadcast(const int &size, enum BroadcastOpType op, const T *input0, const T *input1, S *output,
cudaStream_t stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_H_

@ -38,13 +38,5 @@ MS_REG_GPU_KERNEL_ONE(
MS_REG_GPU_KERNEL_ONE(
Sub, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BinaryOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
Maximum,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BinaryOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
Maximum,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BinaryOpGpuKernel, half)
} // namespace kernel
} // namespace mindspore

@ -27,16 +27,9 @@
#include "kernel/gpu/kernel_constants.h"
namespace mindspore {
namespace kernel {
enum BinaryOpType {
BINARY_OP_ADD = 0,
BINARY_OP_SUB,
BINARY_OP_MUL,
BINARY_OP_DIV,
BINARY_OP_MAX,
BINARY_OP_INVALID_TYPE = 255
};
enum BinaryOpType { BINARY_OP_ADD = 0, BINARY_OP_SUB, BINARY_OP_MUL, BINARY_OP_DIV, BINARY_OP_INVALID_TYPE = 255 };
static const std::map<std::string, BinaryOpType> kBinaryOpTypeMap = {
{"Sub", BINARY_OP_SUB}, {"Mul", BINARY_OP_MUL}, {"RealDiv", BINARY_OP_DIV}, {"Maximum", BINARY_OP_MAX}};
{"Sub", BINARY_OP_SUB}, {"Mul", BINARY_OP_MUL}, {"RealDiv", BINARY_OP_DIV}};
template <typename T>
class BinaryOpGpuKernel : public GpuKernel {
public:
@ -88,10 +81,6 @@ class BinaryOpGpuKernel : public GpuKernel {
inputB_addr = workspace_addr;
break;
}
case BINARY_OP_MAX: {
inputB_addr = input_addr2;
break;
}
default: {
MS_LOG(EXCEPTION) << "Binary operation " << binary_op_type_ << " is not supported.";
}
@ -209,10 +198,6 @@ class BinaryOpGpuKernel : public GpuKernel {
tensor_op_ = CUDNN_OP_TENSOR_ADD;
break;
}
case BINARY_OP_MAX: {
tensor_op_ = CUDNN_OP_TENSOR_MAX;
break;
}
default: {
MS_LOG(EXCEPTION) << "Binary operation " << binary_op_type_ << " is not supported.";
}

@ -0,0 +1,40 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/gpu/math/broadcast_gpu_kernel.h"
namespace mindspore {
namespace kernel {
// Kernel registrations for BroadcastOpGpuKernel. Each entry names the operator, its input/output
// dtypes, and the <T, S> template arguments used for the kernel class.

// Comparison operators: float32 inputs, bool output.
MS_REG_GPU_KERNEL_TWO(
Greater,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernel, float, bool)
MS_REG_GPU_KERNEL_TWO(
Less, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernel, float, bool)
// Arithmetic operators: float32 inputs, float32 output.
MS_REG_GPU_KERNEL_TWO(
Maximum,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernel, float, float)
MS_REG_GPU_KERNEL_TWO(
Minimum,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernel, float, float)
MS_REG_GPU_KERNEL_TWO(
Pow, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernel, float, float)
} // namespace kernel
} // namespace mindspore

@ -0,0 +1,132 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_BROADCAST_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_BROADCAST_GPU_KERNEL_H_
#include <cuda_runtime_api.h>
#include <vector>
#include <string>
#include <map>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
#include "kernel/gpu/cuda_impl/broadcast_impl.cuh"
#include "kernel/gpu/kernel_constants.h"
namespace mindspore {
namespace kernel {
// GPU kernel wrapper for element-wise binary operators (Greater, Less, Maximum, Minimum, Pow).
//
// T is the element type of both inputs, S the element type of the output. Init() reads the
// operator name and the inferred shapes from the graph node; Launch() then dispatches to either
// the broadcasting CUDA launcher (at most 4 dimensions) or the flat same-shape launcher.
template <typename T, typename S>
class BroadcastOpGpuKernel : public GpuKernel {
 public:
  BroadcastOpGpuKernel()
      : op_type_(BROADCAST_TYPE_INVALID), need_broadcast_(false), input1_num_(1), input2_num_(1), output_num_(1) {}
  ~BroadcastOpGpuKernel() override = default;

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  // Enqueues the selected operation on the stream encoded in `stream_ptr`.
  // inputs[0] / inputs[1] are the operands, outputs[0] receives the result; the workspace list is
  // unused. Always returns true.
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override {
    T *lhs = GetDeviceAddress<T>(inputs, 0);
    T *rhs = GetDeviceAddress<T>(inputs, 1);
    S *output = GetDeviceAddress<S>(outputs, 0);

    if (need_broadcast_) {
      Broadcast(lhs_shape_[0], lhs_shape_[1], lhs_shape_[2], lhs_shape_[3], rhs_shape_[0], rhs_shape_[1], rhs_shape_[2],
                rhs_shape_[3], output_shape_[0], output_shape_[1], output_shape_[2], output_shape_[3], op_type_, lhs,
                rhs, output, reinterpret_cast<cudaStream_t>(stream_ptr));
    } else {
      NoBroadcast(output_num_, op_type_, lhs, rhs, output, reinterpret_cast<cudaStream_t>(stream_ptr));
    }

    return true;
  }

  // Reads the operator type and the inferred shapes of both inputs and the output, decides whether
  // broadcasting is required, and records element counts and buffer sizes.
  // Reports through MS_LOG(EXCEPTION) when broadcasting is needed for more than 4 dimensions, or
  // when the inferred output rank is smaller than an input rank.
  bool Init(const CNodePtr &kernel_node) override {
    GetOpType(kernel_node);
    auto shape1 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    auto shape2 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
    auto shape3 = AnfAlgo::GetOutputInferShape(kernel_node, 0);
    need_broadcast_ = IsBroadcast(shape1, shape2);

    // Element counts are needed on both paths and never touch the fixed-size shape arrays, so
    // same-shape inputs of any rank are handled safely.
    input1_num_ = 1;
    for (size_t i = 0; i < shape1.size(); i++) {
      input1_num_ *= static_cast<int>(shape1[i]);
    }
    input2_num_ = 1;
    for (size_t i = 0; i < shape2.size(); i++) {
      input2_num_ *= static_cast<int>(shape2[i]);
    }
    output_num_ = 1;
    for (size_t i = 0; i < shape3.size(); i++) {
      output_num_ *= static_cast<int>(shape3[i]);
    }

    if (need_broadcast_) {
      const size_t max_dims = 4;  // capacity of lhs_shape_ / rhs_shape_ / output_shape_
      if (shape1.size() > max_dims || shape2.size() > max_dims || shape3.size() > max_dims) {
        MS_LOG(EXCEPTION) << "Broadcast operation not support dim greater than 4";
      }
      if (shape3.size() < shape1.size() || shape3.size() < shape2.size()) {
        MS_LOG(EXCEPTION) << "Broadcast output rank " << shape3.size() << " is smaller than an input rank.";
      }

      // Start from all-ones so that unused trailing axes, and leading axes of a lower-rank input,
      // behave as broadcast axes.
      for (size_t i = 0; i < max_dims; i++) {
        lhs_shape_[i] = 1;
        rhs_shape_[i] = 1;
        output_shape_[i] = 1;
      }
      for (size_t i = 0; i < shape3.size(); i++) {
        output_shape_[i] = static_cast<int>(shape3[i]);
      }
      // Right-align each input against the output so inputs of different rank line up on their
      // trailing axes; with equal ranks the offset is zero.
      const size_t lhs_offset = shape3.size() - shape1.size();
      for (size_t i = 0; i < shape1.size(); i++) {
        lhs_shape_[i + lhs_offset] = static_cast<int>(shape1[i]);
      }
      const size_t rhs_offset = shape3.size() - shape2.size();
      for (size_t i = 0; i < shape2.size(); i++) {
        rhs_shape_[i + rhs_offset] = static_cast<int>(shape2[i]);
      }
    }

    InitSizeLists();
    return true;
  }

 protected:
  void InitResource() override { return; }

  // Publishes the byte sizes of both inputs and of the output; no workspace is requested.
  void InitSizeLists() override {
    input_size_list_.push_back(input1_num_ * sizeof(T));
    input_size_list_.push_back(input2_num_ * sizeof(T));
    output_size_list_.push_back(output_num_ * sizeof(S));
  }

 private:
  // Translates the node's operator name into a BroadcastOpType; reports through
  // MS_LOG(EXCEPTION) for names that are not in the table.
  void GetOpType(const CNodePtr &kernel_node) {
    std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);

    static const std::map<std::string, BroadcastOpType> kBroadcastTypeMap = {
      {"Greater", BROADCAST_TYPE_GREATER}, {"Less", BROADCAST_TYPE_LESS}, {"Maximum", BROADCAST_TYPE_MAXIMUM},
      {"Minimum", BROADCAST_TYPE_MINIMUM}, {"Pow", BROADCAST_TYPE_POWER},
    };

    auto iter = kBroadcastTypeMap.find(kernel_name);
    if (iter == kBroadcastTypeMap.end()) {
      MS_LOG(EXCEPTION) << "operation " << kernel_name << " is not supported.";
    } else {
      op_type_ = iter->second;
    }
  }

  // Returns true when the two input shapes differ, either in rank or in any extent.
  bool IsBroadcast(const std::vector<size_t> &lhs, const std::vector<size_t> &rhs) {
    if (lhs.size() != rhs.size()) {
      return true;
    }
    for (size_t i = 0; i < lhs.size(); i++) {
      if (lhs[i] != rhs[i]) {
        return true;
      }
    }
    return false;
  }

  BroadcastOpType op_type_;  // operation selected by GetOpType()
  bool need_broadcast_;      // true when the input shapes differ
  int input1_num_;           // element count of the first input
  int input2_num_;           // element count of the second input
  int output_num_;           // element count of the output
  // Extents passed to the broadcasting launcher; only filled in when need_broadcast_ is true.
  int lhs_shape_[4] = {1, 1, 1, 1};
  int rhs_shape_[4] = {1, 1, 1, 1};
  int output_shape_[4] = {1, 1, 1, 1};

  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_BROADCAST_GPU_KERNEL_H_

@ -0,0 +1,81 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
from mindspore.ops import operations as P
from mindspore.nn import Cell
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
import mindspore.context as context
import numpy as np
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_nobroadcast():
    """Compare each element-wise GPU op with NumPy for two (10, 20) float32 inputs."""
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')

    lhs_np = np.random.rand(10, 20).astype(np.float32)
    rhs_np = np.random.rand(10, 20).astype(np.float32)

    # (MindSpore primitive, NumPy reference) pairs, checked in this order.
    op_pairs = (
        (P.Minimum, np.minimum),
        (P.Maximum, np.maximum),
        (P.Greater, np.greater),
        (P.Less, np.less),
        (P.Pow, np.power),
    )
    for primitive, reference in op_pairs:
        result_ms = primitive()(Tensor(lhs_np), Tensor(rhs_np))
        expected_np = reference(lhs_np, rhs_np)
        assert np.allclose(result_ms.asnumpy(), expected_np)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_broadcast():
    """Compare each element-wise GPU op with NumPy for (3, 1, 5, 1) against (1, 4, 1, 6) inputs."""
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')

    lhs_np = np.random.rand(3, 1, 5, 1).astype(np.float32)
    rhs_np = np.random.rand(1, 4, 1, 6).astype(np.float32)

    # (MindSpore primitive, NumPy reference) pairs, checked in this order.
    op_pairs = (
        (P.Minimum, np.minimum),
        (P.Maximum, np.maximum),
        (P.Greater, np.greater),
        (P.Less, np.less),
        (P.Pow, np.power),
    )
    for primitive, reference in op_pairs:
        result_ms = primitive()(Tensor(lhs_np), Tensor(rhs_np))
        expected_np = reference(lhs_np, rhs_np)
        assert np.allclose(result_ms.asnumpy(), expected_np)
Loading…
Cancel
Save