relu optimize

pull/9661/head
wilfChen 4 years ago
parent 651b1c3577
commit c1d3bd2160

@ -0,0 +1,27 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADD_RELU_V2_IMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADD_RELU_V2_IMPL_H_
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void AddReluV2(const size_t num, const T *x1, const T *x2, T *y, uint32_t *mask, cudaStream_t cuda_stream);
template <typename T>
void AddReluGradV2(const size_t size, const T *x1, const T *x2, const uint32_t *mask, T *dx, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADD_RELU_IMPL_H_

@ -0,0 +1,68 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/cuda_impl/add_relu_v2_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
template <typename T>
__global__ void AddReluV2Kernel(const size_t num, const T *x1, const T *x2, T *y, uint32_t *mask) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += gridDim.x * blockDim.x) {
T sum = x1[i] + x2[i];
bool p = sum > static_cast<T>(0);
y[i] = p ? sum : static_cast<T>(0);
auto warp_predict = BallotSync(p, __activemask());
if (LaneId() == 0) {
mask[WarpId(i)] = warp_predict;
}
}
}
template <typename T>
void AddReluV2(const size_t num, const T *x1, const T *x2, T *y, uint32_t *mask, cudaStream_t cuda_stream) {
AddReluV2Kernel<<<kBlocksPerGrid(num), kThreadsPerBlock, 0, cuda_stream>>>(num, x1, x2, y, mask);
}
template <typename T>
__global__ void AddReluGradV2Kernel(const size_t num, const T *x1, const T *x2, const uint32_t *mask, T *dx) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += gridDim.x * blockDim.x) {
bool positive = mask[WarpId(i)] & (1 << LaneId());
dx[i] = positive ? x1[i] + x2[i] : static_cast<T>(0);
}
}
template <typename T>
void AddReluGradV2(const size_t num, const T *x1, const T *x2, const uint32_t *mask, T *dx, cudaStream_t cuda_stream) {
AddReluGradV2Kernel<<<kBlocksPerGrid(num), kThreadsPerBlock, 0, cuda_stream>>>(num, x1, x2, mask, dx);
}
template void AddReluV2(const size_t num, const float *x1, const float *x2, float *y, uint32_t *mask,
cudaStream_t cuda_stream);
template void AddReluV2(const size_t num, const half *x1, const half *x2, half *y, uint32_t *mask,
cudaStream_t cuda_stream);
template void AddReluV2(const size_t num, const int32_t *x1, const int32_t *x2, int32_t *y, uint32_t *mask,
cudaStream_t cuda_stream);
template void AddReluV2(const size_t num, const int64_t *x1, const int64_t *x2, int64_t *y, uint32_t *mask,
cudaStream_t cuda_stream);
template void AddReluGradV2(const size_t num, const float *x1, const float *x2, const uint32_t *mask, float *dx,
cudaStream_t cuda_stream);
template void AddReluGradV2(const size_t num, const half *x1, const half *x2, const uint32_t *mask, half *dx,
cudaStream_t cuda_stream);
template void AddReluGradV2(const size_t num, const int32_t *x1, const int32_t *x2, const uint32_t *mask, int32_t *dx,
cudaStream_t cuda_stream);
template void AddReluGradV2(const size_t num, const int64_t *x1, const int64_t *x2, const uint32_t *mask, int64_t *dx,
cudaStream_t cuda_stream);

@ -0,0 +1,27 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADD_RELU_V2_IMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADD_RELU_V2_IMPL_H_
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void AddReluV2(const size_t num, const T *x1, const T *x2, T *y, uint32_t *mask, cudaStream_t cuda_stream);
template <typename T>
void AddReluGradV2(const size_t size, const T *x1, const T *x2, const uint32_t *mask, T *dx, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADD_RELU_IMPL_H_

@ -15,6 +15,7 @@
*/
#include "backend/kernel_compiler/gpu/cuda_impl/relu_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
@ -34,3 +35,47 @@ template void CalReLU(int size, float *input_addr, float *output_addr, cudaStrea
template void CalReLU(int size, half *input_addr, half *output_addr, cudaStream_t cuda_stream);
template void CalReLU(int size, int32_t *input_addr, int32_t *output_addr, cudaStream_t cuda_stream);
template void CalReLU(int size, int64_t *input_addr, int64_t *output_addr, cudaStream_t cuda_stream);
template <typename T>
__global__ void ReluV2Kernel(const size_t num, const T *x, T *y, uint32_t *mask) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += blockDim.x * gridDim.x) {
T v = x[i];
bool p = v > static_cast<T>(0);
y[i] = p ? v : static_cast<T>(0);
auto warp_predict = BallotSync(p, __activemask());
if (LaneId() == 0) {
mask[WarpId(i)] = warp_predict;
}
}
}
template <typename T>
void ReluV2(const size_t num, const T *x, T *y, uint32_t *mask, cudaStream_t cuda_stream) {
ReluV2Kernel<<<kBlocksPerGrid(num), kThreadsPerBlock, 0, cuda_stream>>>(num, x, y, mask);
}
template <typename T>
__global__ void ReluGradV2Kernel(const size_t num, const T *dy, const uint32_t *mask, T *dx) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += blockDim.x * gridDim.x) {
bool p = mask[WarpId(i)] & (1 << LaneId());
dx[i] = p ? dy[i] : static_cast<T>(0);
}
}
template <typename T>
void ReluGradV2(const size_t num, const T *dy, const uint32_t *mask, T *dx, cudaStream_t cuda_stream) {
ReluGradV2Kernel<<<kBlocksPerGrid(num), kThreadsPerBlock, 0, cuda_stream>>>(num, dy, mask, dx);
}
template void ReluV2(const size_t num, const float *x, float *y, uint32_t *mask, cudaStream_t cuda_stream);
template void ReluV2(const size_t num, const half *x, half *y, uint32_t *mask, cudaStream_t cuda_stream);
template void ReluV2(const size_t num, const int32_t *x, int32_t *y, uint32_t *mask, cudaStream_t cuda_stream);
template void ReluV2(const size_t num, const int64_t *x, int64_t *y, uint32_t *mask, cudaStream_t cuda_stream);
template void ReluGradV2(const size_t num, const float *dy, const uint32_t *mask, float *dx, cudaStream_t cuda_stream);
template void ReluGradV2(const size_t num, const half *dy, const uint32_t *mask, half *dx, cudaStream_t cuda_stream);
template void ReluGradV2(const size_t num, const int32_t *dy, const uint32_t *mask, int32_t *dx,
cudaStream_t cuda_stream);
template void ReluGradV2(const size_t num, const int64_t *dy, const uint32_t *mask, int64_t *dx,
cudaStream_t cuda_stream);

@ -20,4 +20,9 @@
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void CalReLU(int input_size, T *input_addr, T *output_addr, cudaStream_t cuda_stream);
template <typename T>
void ReluV2(const size_t num, const T *x, T *y, uint32_t *mask, cudaStream_t cuda_stream);
template <typename T>
void ReluGradV2(const size_t num, const T *dy, const uint32_t *mask, T *dx, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RELU_H_

@ -20,16 +20,18 @@
#include <cuda_fp16.h>
#include "runtime/device/gpu/cuda_common.h"
#define kThreadsPerBlock (256)
#define kBlocksPerGrid(n) ((n + kThreadsPerBlock - 1) / kThreadsPerBlock)
__device__ static inline double MsAtomicAdd(double *address, const double val) {
unsigned long long int* address_as_ull = (unsigned long long int*)address; // NOLINT
unsigned long long int old = *address_as_ull; // NOLINT
unsigned long long int assumed; // NOLINT
do {
assumed = old;
old = atomicCAS(address_as_ull, assumed, __double_as_longlong(val + __longlong_as_double(assumed)));
}
while (assumed != old); // NOLINT
return __longlong_as_double(old);
unsigned long long int *address_as_ull = (unsigned long long int *)address; // NOLINT
unsigned long long int old = *address_as_ull; // NOLINT
unsigned long long int assumed; // NOLINT
do {
assumed = old;
old = atomicCAS(address_as_ull, assumed, __double_as_longlong(val + __longlong_as_double(assumed)));
} while (assumed != old); // NOLINT
return __longlong_as_double(old);
}
__device__ static inline float MsAtomicAdd(float *address, const float val) { return atomicAdd(address, val); }
@ -42,7 +44,7 @@ __device__ static inline unsigned int MsAtomicAdd(unsigned int *address, unsigne
__device__ static inline int8_t MsAtomicAdd(int8_t *address, int8_t val) {
size_t offset = (size_t)address & 3;
uint32_t * address_as_ui = (uint32_t *)((char *)address - offset); // NOLINT
uint32_t *address_as_ui = (uint32_t *)((char *)address - offset); // NOLINT
uint32_t old = *address_as_ui;
uint32_t shift = offset * 8;
uint32_t old_byte;
@ -60,27 +62,27 @@ __device__ static inline int8_t MsAtomicAdd(int8_t *address, int8_t val) {
}
__device__ static inline int64_t MsAtomicAdd(int64_t *address, int64_t val) {
unsigned long long * address_as_ui = (unsigned long long *) (address); // NOLINT
unsigned long long old = *address_as_ui; // NOLINT
unsigned long long newval; // NOLINT
unsigned long long assumed; // NOLINT
unsigned long long *address_as_ui = (unsigned long long *)(address); // NOLINT
unsigned long long old = *address_as_ui; // NOLINT
unsigned long long newval; // NOLINT
unsigned long long assumed; // NOLINT
do {
assumed = old;
newval = val + (int64_t)old;
newval = val + (int64_t)old;
old = atomicCAS(address_as_ui, assumed, newval);
} while (assumed != old);
return (int64_t)old;
}
__device__ static inline bool MsAtomicAdd(bool *address, bool val) {
*address = address && val;
return address[0];
*address = address && val;
return address[0];
}
__device__ static inline unsigned char MsAtomicAdd(short *address, short val) { // NOLINT
bool is_4_byte_aligned = ((size_t) address & 2) == 0;
unsigned int *aligned = (unsigned int *) ((size_t) address & ~2);
bool is_4_byte_aligned = ((size_t)address & 2) == 0;
unsigned int *aligned = (unsigned int *)((size_t)address & ~2);
unsigned int old = *aligned;
unsigned int assumed;
@ -91,16 +93,16 @@ __device__ static inline unsigned char MsAtomicAdd(short *address, short val) {
if (is_4_byte_aligned) {
replacement = (old & 0xffff0000) | (((old & 0xffff) + val) & 0xffff);
} else {
replacement = old + ((unsigned int) val << 16);
replacement = old + ((unsigned int)val << 16);
}
old = atomicCAS(aligned, assumed, replacement);
} while (assumed != old);
if (is_4_byte_aligned) {
return (short) (old & 0xffff); // NOLINT
return (short)(old & 0xffff); // NOLINT
} else {
return (short) (old >> 16); // NOLINT
return (short)(old >> 16); // NOLINT
}
}
@ -112,7 +114,8 @@ __device__ static inline half MsAtomicAdd(half *address, half val) {
unsigned short old_as_us; // NOLINT
do {
assumed = old;
old_as_us = static_cast<unsigned short>(reinterpret_cast<size_t>(address) & 2 ? old >> 16 : old & 0xffff); // NOLINT
old_as_us =
static_cast<unsigned short>(reinterpret_cast<size_t>(address) & 2 ? old >> 16 : old & 0xffff); // NOLINT
half sum = __float2half_rn(__half2float(__ushort_as_half(old_as_us)) + static_cast<float>(val));
unsigned short sum_as_us = __half_as_ushort(sum); // NOLINT
unsigned int sum_as_ui =
@ -123,16 +126,16 @@ __device__ static inline half MsAtomicAdd(half *address, half val) {
return half(raw);
}
__device__ static inline unsigned char MsAtomicAdd(unsigned char* address, unsigned char val) {
__device__ static inline unsigned char MsAtomicAdd(unsigned char *address, unsigned char val) {
// We use cuda's atomicCAS(unsigned int*, unsigned int, unsigned int) to
// implement MsAtomicAdd. An unsigned char may not be 4 byte aligned, but
// unsigned int* must be 4 byte aligned. This variable contains the offset,
// in bytes, of the beginning of address, within the 4 byte aligned space that
// contains it.
size_t address_offset = (size_t) address & 3;
size_t address_offset = (size_t)address & 3;
// Address of the 4 byte aligned space that contains address.
unsigned int* aligned = (unsigned int*) ((unsigned char*) address - address_offset);
unsigned int *aligned = (unsigned int *)((unsigned char *)address - address_offset);
// Constants which will be used later with __byte_perm. __byte_perm is a cuda
// function which takes 3 unsigned int's (x, y, selector) as parameters and
@ -166,9 +169,9 @@ __device__ static inline unsigned char MsAtomicAdd(unsigned char* address, unsig
return __byte_perm(old, 0, address_offset);
}
__device__ static inline char MsAtomicAdd(char* address, char val) {
size_t address_offset = (size_t) address & 3;
unsigned int* aligned = reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) - address_offset);
__device__ static inline char MsAtomicAdd(char *address, char val) {
size_t address_offset = (size_t)address & 3;
unsigned int *aligned = reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) - address_offset);
unsigned int selectors[] = {0x3214, 0x3240, 0x3410, 0x4210};
unsigned int selector = selectors[address_offset];
unsigned int old = *aligned;
@ -185,4 +188,12 @@ __device__ static inline char MsAtomicAdd(char* address, char val) {
return __byte_perm(old, 0, address_offset);
}
__device__ __forceinline__ unsigned BallotSync(int predicate, unsigned mask = 0xffffffff) {
return __ballot_sync(mask, predicate);
}
enum : unsigned { warp_size = 32, log_wap_size = 5 };
__device__ __forceinline__ unsigned LaneId() { return threadIdx.x & (warp_size - 1); }
__device__ __forceinline__ unsigned WarpId(const unsigned &tid) { return tid >> log_wap_size; }
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UTIL_H_

@ -0,0 +1,50 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/nn/fused_add_relu_grad_v2_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(FusedAddReluGradV2,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeUInt32)
.AddOutputAttr(kNumberTypeFloat32),
FusedAddReluGradV2GpuKernel, float)
MS_REG_GPU_KERNEL_ONE(FusedAddReluGradV2,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeUInt32)
.AddOutputAttr(kNumberTypeFloat16),
FusedAddReluGradV2GpuKernel, half)
MS_REG_GPU_KERNEL_ONE(FusedAddReluGradV2,
KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeUInt32)
.AddOutputAttr(kNumberTypeInt32),
FusedAddReluGradV2GpuKernel, int32_t)
MS_REG_GPU_KERNEL_ONE(FusedAddReluGradV2,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeUInt32)
.AddOutputAttr(kNumberTypeInt64),
FusedAddReluGradV2GpuKernel, int64_t)
} // namespace kernel
} // namespace mindspore

@ -0,0 +1,86 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_ADD_RELU_GRAD_V2_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_ADD_RELU_GRAD_V2_GPU_KERNEL_H_
#include <vector>
#include <algorithm>
#include <functional>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/add_relu_v2_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class FusedAddReluGradV2GpuKernel : public GpuKernel {
public:
FusedAddReluGradV2GpuKernel() { ResetResource(); }
~FusedAddReluGradV2GpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
auto x1 = GetDeviceAddress<T>(inputs, 0);
auto x2 = GetDeviceAddress<T>(inputs, 1);
auto mask = GetDeviceAddress<uint32_t>(inputs, 2);
auto dx = GetDeviceAddress<T>(outputs, 0);
AddReluGradV2(element_num_, x1, x2, mask, dx, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
MS_EXCEPTION_IF_NULL(kernel_node);
auto shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
element_num_ = std::accumulate(shape.begin(), shape.end(), static_cast<size_t>(1), std::multiplies<size_t>());
InitSizeLists();
return true;
}
void ResetResource() noexcept override {
element_num_ = 0;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
protected:
void InitSizeLists() override {
auto size = element_num_ * sizeof(T);
input_size_list_.push_back(size);
input_size_list_.push_back(size);
input_size_list_.push_back(size);
output_size_list_.push_back(size);
size = (element_num_ + 31) / 32 * sizeof(uint32_t);
input_size_list_.push_back(size);
}
private:
size_t element_num_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_ADD_RELU_GRAD_V2_GPU_KERNEL_H_

@ -0,0 +1,50 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/nn/fused_add_relu_v2_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(FusedAddReluV2,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeUInt32),
FusedAddReluV2GpuKernel, float)
MS_REG_GPU_KERNEL_ONE(FusedAddReluV2,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeUInt32),
FusedAddReluV2GpuKernel, half)
MS_REG_GPU_KERNEL_ONE(FusedAddReluV2,
KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeUInt32),
FusedAddReluV2GpuKernel, int32_t)
MS_REG_GPU_KERNEL_ONE(FusedAddReluV2,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeUInt32),
FusedAddReluV2GpuKernel, int64_t)
} // namespace kernel
} // namespace mindspore

@ -0,0 +1,85 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_ADD_RELU_V2_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_ADD_RELU_V2_GPU_KERNEL_H_
#include <vector>
#include <algorithm>
#include <functional>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/add_relu_v2_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class FusedAddReluV2GpuKernel : public GpuKernel {
public:
FusedAddReluV2GpuKernel() { ResetResource(); }
~FusedAddReluV2GpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
auto x1 = GetDeviceAddress<T>(inputs, 0);
auto x2 = GetDeviceAddress<T>(inputs, 1);
auto y = GetDeviceAddress<T>(outputs, 0);
auto mask = GetDeviceAddress<uint32_t>(outputs, 1);
AddReluV2(element_num_, x1, x2, y, mask, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
MS_EXCEPTION_IF_NULL(kernel_node);
auto shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
element_num_ = std::accumulate(shape.begin(), shape.end(), static_cast<size_t>(1), std::multiplies<size_t>());
InitSizeLists();
return true;
}
void ResetResource() noexcept override {
element_num_ = 0;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
protected:
void InitSizeLists() override {
auto size = element_num_ * sizeof(T);
input_size_list_.push_back(size);
input_size_list_.push_back(size);
output_size_list_.push_back(size);
size = (element_num_ + 31) / 32 * sizeof(uint32_t);
output_size_list_.push_back(size);
}
private:
size_t element_num_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_ADD_RELU_V2_GPU_KERNEL_H_

@ -0,0 +1,38 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/nn/relu_grad_v2_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
ReluGradV2,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat32),
ReluGradV2GpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
ReluGradV2,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat16),
ReluGradV2GpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
ReluGradV2,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt32),
ReluGradV2GpuKernel, int32_t)
MS_REG_GPU_KERNEL_ONE(
ReluGradV2,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64),
ReluGradV2GpuKernel, int64_t)
} // namespace kernel
} // namespace mindspore

@ -0,0 +1,83 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GRAD_V2_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GRAD_V2_GPU_KERNEL_H_
#include <vector>
#include <algorithm>
#include <functional>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/relu_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class ReluGradV2GpuKernel : public GpuKernel {
public:
ReluGradV2GpuKernel() { ResetResource(); }
~ReluGradV2GpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
auto dy = GetDeviceAddress<T>(inputs, 0);
auto mask = GetDeviceAddress<uint32_t>(inputs, 1);
auto dx = GetDeviceAddress<T>(outputs, 0);
ReluGradV2(element_num_, dy, mask, dx, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
MS_EXCEPTION_IF_NULL(kernel_node);
auto shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
element_num_ = std::accumulate(shape.begin(), shape.end(), static_cast<size_t>(1), std::multiplies<size_t>());
InitSizeLists();
return true;
}
void ResetResource() noexcept override {
element_num_ = 0;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
protected:
void InitSizeLists() override {
auto size = element_num_ * sizeof(T);
input_size_list_.push_back(size);
output_size_list_.push_back(size);
auto mask_size = (element_num_ + 31) / 32 * sizeof(uint32_t);
input_size_list_.push_back(mask_size);
}
private:
size_t element_num_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_V2_GRAD_GPU_KERNEL_H_

@ -0,0 +1,37 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/nn/relu_v2_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
ReLUV2,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt32),
ReluV2GpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
ReLUV2,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt32),
ReluV2GpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
ReLUV2, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
ReluV2GpuKernel, int32_t)
MS_REG_GPU_KERNEL_ONE(
ReLUV2,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt32),
ReluV2GpuKernel, int64_t)
} // namespace kernel
} // namespace mindspore

@ -0,0 +1,82 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_V2_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_V2_GPU_KERNEL_H_
#include <vector>
#include <algorithm>
#include <functional>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/relu_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class ReluV2GpuKernel : public GpuKernel {
public:
ReluV2GpuKernel() { ResetResource(); }
~ReluV2GpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
auto x = GetDeviceAddress<T>(inputs, 0);
auto y = GetDeviceAddress<T>(outputs, 0);
auto mask = GetDeviceAddress<uint32_t>(outputs, 1);
ReluV2(element_num_, x, y, mask, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
MS_EXCEPTION_IF_NULL(kernel_node);
auto shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
element_num_ = std::accumulate(shape.begin(), shape.end(), static_cast<size_t>(1), std::multiplies<size_t>());
InitSizeLists();
return true;
}
void ResetResource() noexcept override {
element_num_ = 0;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
protected:
void InitSizeLists() override {
auto size = element_num_ * sizeof(T);
input_size_list_.push_back(size);
output_size_list_.push_back(size);
auto mask_size = (element_num_ + 31) / 32 * sizeof(uint32_t);
output_size_list_.push_back(mask_size);
}
private:
size_t element_num_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_MASK_GPU_KERNEL_H_

@ -0,0 +1,89 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/gpu/add_relu_grad_v2_fusion.h"
#include <memory>
#include <vector>
#include <string>
#include "backend/session/anf_runtime_algorithm.h"
#include "ir/primitive.h"
#include "utils/utils.h"
#include "backend/optimizer/common/helper.h"
namespace mindspore {
namespace opt {
namespace {
kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
std::vector<std::string> inputs_format;
std::vector<std::string> outputs_format;
std::vector<TypeId> inputs_type;
std::vector<TypeId> outputs_type;
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) {
inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index));
inputs_format.push_back(kOpFormat_DEFAULT);
}
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) {
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index));
outputs_format.push_back(kOpFormat_DEFAULT);
}
builder.SetInputsDeviceType(inputs_type);
builder.SetInputsFormat(inputs_format);
builder.SetOutputsDeviceType(outputs_type);
builder.SetOutputsFormat(outputs_format);
return builder.Build();
}
} // namespace
const BaseRef AddReluGradV2Fusion::DefinePattern() const {
VectorRef relu_grad = VectorRef({prim::kPrimReluGradV2, VectorRef({prim::kPrimTensorAdd, x1_, x2_}), mask_});
return relu_grad;
}
const AnfNodePtr AddReluGradV2Fusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(equiv);
auto x1 = utils::cast<AnfNodePtr>((*equiv)[x1_]);
auto x2 = utils::cast<AnfNodePtr>((*equiv)[x2_]);
auto mask = utils::cast<AnfNodePtr>((*equiv)[mask_]);
auto tensor_add = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
MS_EXCEPTION_IF_NULL(tensor_add);
auto users = GetRealNodeUsedList(graph, tensor_add);
if (users->size() > 1) {
return nullptr;
}
auto prim = std::make_shared<Primitive>(kFusedAddReluGradV2Name);
MS_EXCEPTION_IF_NULL(prim);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), x1, x2, mask};
auto add_relugrad = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(add_relugrad);
auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)};
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, add_relugrad.get());
add_relugrad->set_scope(node->scope());
auto build_info = GenerateKernelBuildInfo(add_relugrad);
AnfAlgo::SetSelectKernelBuildInfo(build_info, add_relugrad.get());
return add_relugrad;
}
} // namespace opt
} // namespace mindspore

@ -0,0 +1,42 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_ADD_RELU_GRAD_V2_FUSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_ADD_RELU_GRAD_V2_FUSION_H_
#include <memory>
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
class AddReluGradV2Fusion : public PatternProcessPass {
public:
explicit AddReluGradV2Fusion(bool multigraph = true) : PatternProcessPass("add_relu_grad", multigraph) {
x1_ = std::make_shared<Var>();
x2_ = std::make_shared<Var>();
mask_ = std::make_shared<Var>();
}
~AddReluGradV2Fusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
VarPtr x1_;
VarPtr x2_;
VarPtr mask_;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_ADD_RELUGRAD_FUSION_H_

@ -0,0 +1,94 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/gpu/add_relu_v2_fusion.h"
#include <memory>
#include <vector>
#include <string>
#include "backend/session/anf_runtime_algorithm.h"
#include "ir/primitive.h"
#include "utils/utils.h"
#include "backend/optimizer/common/helper.h"
namespace mindspore {
namespace opt {
namespace {
kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
std::vector<std::string> inputs_format;
std::vector<std::string> outputs_format;
std::vector<TypeId> inputs_type;
std::vector<TypeId> outputs_type;
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) {
inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index));
inputs_format.push_back(kOpFormat_DEFAULT);
}
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) {
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index));
outputs_format.push_back(kOpFormat_DEFAULT);
}
builder.SetInputsDeviceType(inputs_type);
builder.SetInputsFormat(inputs_format);
builder.SetOutputsDeviceType(outputs_type);
builder.SetOutputsFormat(outputs_format);
return builder.Build();
}
} // namespace
const BaseRef AddReluV2Fusion::DefinePattern() const {
VectorRef relu = VectorRef({prim::kPrimReluV2, VectorRef({prim::kPrimTensorAdd, x1_, x2_})});
return relu;
}
const AnfNodePtr AddReluV2Fusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(equiv);
auto x1 = utils::cast<AnfNodePtr>((*equiv)[x1_]);
auto x2 = utils::cast<AnfNodePtr>((*equiv)[x2_]);
auto tensor_add = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
MS_EXCEPTION_IF_NULL(tensor_add);
auto users = GetRealNodeUsedList(graph, tensor_add);
if (users->size() > 1) {
return nullptr;
}
auto prim = std::make_shared<Primitive>(kFusedAddReluV2Name);
MS_EXCEPTION_IF_NULL(prim);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), x1, x2};
auto add_relu = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(add_relu);
std::vector<TypeId> types;
std::vector<std::vector<size_t>> shapes;
for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(node); i++) {
types.push_back(AnfAlgo::GetOutputInferDataType(node, i));
shapes.push_back(AnfAlgo::GetOutputInferShape(node, i));
}
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, add_relu.get());
add_relu->set_scope(node->scope());
auto build_info = GenerateKernelBuildInfo(add_relu);
AnfAlgo::SetSelectKernelBuildInfo(build_info, add_relu.get());
return add_relu;
}
} // namespace opt
} // namespace mindspore

@ -0,0 +1,40 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_ADD_RELU_FUSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_ADD_RELU_FUSION_H_
#include <memory>
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
class AddReluV2Fusion : public PatternProcessPass {
public:
explicit AddReluV2Fusion(bool multigraph = true) : PatternProcessPass("add_relu_v2_fusion", multigraph) {
x1_ = std::make_shared<Var>();
x2_ = std::make_shared<Var>();
}
~AddReluV2Fusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
VarPtr x1_;
VarPtr x2_;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_ADD_RELU_FUSION_H_

@ -0,0 +1,151 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/gpu/relu_v2_pass.h"
#include <memory>
#include <string>
#include <vector>
#include <algorithm>
#include <functional>
#include "backend/session/anf_runtime_algorithm.h"
#include "ir/primitive.h"
#include "utils/utils.h"
#include "abstract/abstract_value.h"
#include "backend/optimizer/common/helper.h"
namespace mindspore {
namespace opt {
namespace {
const size_t kReluV2OutputNum = 2;
CNodePtr GetRelu(const CNodePtr &relu_grad) {
MS_EXCEPTION_IF_NULL(relu_grad);
if (relu_grad->size() != kReluGradInputNum) {
MS_LOG_EXCEPTION << "ReluGrad has wrong input size " << relu_grad->size();
}
auto relu_anf = relu_grad->input(2);
MS_EXCEPTION_IF_NULL(relu_anf);
return relu_anf->cast<CNodePtr>();
}
kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
std::vector<std::string> inputs_format;
std::vector<std::string> outputs_format;
std::vector<TypeId> inputs_type;
std::vector<TypeId> outputs_type;
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) {
inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index));
inputs_format.push_back(kOpFormat_DEFAULT);
}
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) {
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index));
outputs_format.push_back(kOpFormat_DEFAULT);
}
builder.SetInputsDeviceType(inputs_type);
builder.SetInputsFormat(inputs_format);
builder.SetOutputsDeviceType(outputs_type);
builder.SetOutputsFormat(outputs_format);
return builder.Build();
}
CNodePtr CreateReluV2(const FuncGraphPtr &graph, const CNodePtr &relu) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(relu);
if (relu->size() != kReluInputNum) {
MS_LOG_EXCEPTION << "Relu has wrong input size " << relu->size();
}
auto prim = std::make_shared<Primitive>(kReluV2OpName);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), relu->input(1)};
auto new_node = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(new_node);
new_node->set_scope(relu->scope());
if (AnfAlgo::IsDynamicShape(relu)) {
return nullptr;
}
std::vector<size_t> output_shape = AnfAlgo::GetOutputInferShape(relu, 0);
auto element_num =
std::accumulate(output_shape.begin(), output_shape.end(), static_cast<size_t>(1), std::multiplies<size_t>());
std::vector<size_t> mask_shape = {(element_num + 31) / 32};
auto shapes = {AnfAlgo::GetOutputInferShape(relu, 0), mask_shape};
auto types = {AnfAlgo::GetOutputInferDataType(relu, 0), kNumberTypeUInt32};
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, new_node.get());
auto build_info = GenerateKernelBuildInfo(new_node);
AnfAlgo::SetSelectKernelBuildInfo(build_info, new_node.get());
return new_node;
}
CNodePtr CreateReluGradV2(const FuncGraphPtr &graph, const CNodePtr &relu_grad, const AnfNodePtr &second_input) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(relu_grad);
MS_EXCEPTION_IF_NULL(second_input);
auto prim = std::make_shared<Primitive>(kReluGradV2OpName);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), relu_grad->input(1), second_input};
auto new_node = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(new_node);
new_node->set_scope(relu_grad->scope());
new_node->set_abstract(relu_grad->abstract());
std::vector<TypeId> types;
std::vector<std::vector<size_t>> shapes;
for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(relu_grad); i++) {
types.push_back(AnfAlgo::GetOutputInferDataType(relu_grad, i));
shapes.push_back(AnfAlgo::GetOutputInferShape(relu_grad, i));
}
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, new_node.get());
new_node->set_scope(relu_grad->scope());
auto build_info = GenerateKernelBuildInfo(new_node);
AnfAlgo::SetSelectKernelBuildInfo(build_info, new_node.get());
return new_node;
}
} // namespace
const BaseRef ReluV2Pass::DefinePattern() const {
VectorRef relu_grad({prim::kPrimReluGrad, dy_, VectorRef({prim::kPrimRelu, x_})});
return relu_grad;
}
const AnfNodePtr ReluV2Pass::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
auto relu_grad = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(relu_grad);
auto relu = GetRelu(relu_grad);
MS_EXCEPTION_IF_NULL(relu);
auto relu_v2 = CreateReluV2(graph, relu);
if (relu_v2 == nullptr) {
return nullptr;
}
std::vector<AnfNodePtr> relu_v2_node_outputs;
CreateMultipleOutputsOfAnfNode(graph, relu_v2, kReluV2OutputNum, &relu_v2_node_outputs);
auto relu_grad_v2 = CreateReluGradV2(graph, relu_grad, relu_v2_node_outputs[1]);
auto manage = graph->manager();
MS_EXCEPTION_IF_NULL(manage);
manage->Replace(relu, relu_v2_node_outputs[0]);
return relu_grad_v2;
}
} // namespace opt
} // namespace mindspore

@ -0,0 +1,40 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_RELU_V2_FUSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_RELU_V2_FUSION_H_
#include <memory>
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
class ReluV2Pass : public PatternProcessPass {
public:
explicit ReluV2Pass(bool multigraph = true) : PatternProcessPass("relu_v2_fusion", multigraph) {
x_ = std::make_shared<Var>();
dy_ = std::make_shared<Var>();
}
~ReluV2Pass() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
VarPtr x_;
VarPtr dy_;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_RELU_V2_FUSION_H_

@ -35,6 +35,9 @@
#include "backend/optimizer/gpu/remove_format_transform_pair.h"
#include "backend/optimizer/gpu/remove_redundant_format_transform.h"
#include "backend/optimizer/gpu/reduce_precision_fusion.h"
#include "backend/optimizer/gpu/relu_v2_pass.h"
#include "backend/optimizer/gpu/add_relu_v2_fusion.h"
#include "backend/optimizer/gpu/add_relu_grad_v2_fusion.h"
#include "backend/optimizer/graph_kernel/add_atomic_clean_gpu.h"
#include "backend/optimizer/graph_kernel/arithmetic_simplify.h"
#include "backend/optimizer/graph_kernel/basic_ops_fusion.h"
@ -142,6 +145,9 @@ void GPUSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_gra
pm->AddPass(std::make_shared<opt::RemoveFormatTransformPair>());
pm->AddPass(std::make_shared<opt::RemoveRedundantFormatTransform>());
pm->AddPass(std::make_shared<opt::CudnnInplaceAggregate>());
pm->AddPass(std::make_shared<opt::ReluV2Pass>());
pm->AddPass(std::make_shared<opt::AddReluV2Fusion>());
pm->AddPass(std::make_shared<opt::AddReluGradV2Fusion>());
pm->AddPass(std::make_shared<opt::AllReduceFusion>());
pm->AddPass(std::make_shared<opt::GetitemTuple>());
pm->AddPass(std::make_shared<opt::ReducePrecisionFusion>("reduce_precision"));

@ -245,6 +245,8 @@ constexpr auto kBasicLSTMCellCStateGradOpName = "BasicLSTMCellCStateGrad";
constexpr auto kBasicLSTMCellCStateGradV2OpName = "BasicLSTMCellCStateGradV2";
constexpr auto kMatMulV2OpName = "MatMulV2";
constexpr auto kBroadcastToOpName = "BroadcastTo";
constexpr auto kFusedAddReluV2Name = "FusedAddReluV2";
constexpr auto kFusedAddReluGradV2Name = "FusedAddReluGradV2";
// Hcom Op Type
constexpr auto kHcomOpTypeAllReduce = "HcomAllReduce";

@ -146,6 +146,7 @@ inline const PrimitivePtr kPrimFusedBatchNormGradEx = std::make_shared<Primitive
inline const PrimitivePtr kPrimBatchNorm = std::make_shared<Primitive>("BatchNorm");
inline const PrimitivePtr kPrimBatchNormGrad = std::make_shared<Primitive>("BatchNormGrad");
inline const PrimitivePtr kPrimReluGrad = std::make_shared<Primitive>("ReluGrad");
inline const PrimitivePtr kPrimReluGradV2 = std::make_shared<Primitive>("ReluGradV2");
inline const PrimitivePtr kPrimRelu6Grad = std::make_shared<Primitive>("ReLU6Grad");
inline const PrimitivePtr kPrimConv2DBackpropInput = std::make_shared<Primitive>("Conv2DBackpropInput");
inline const PrimitivePtr kPrimConv2DBackpropFilter = std::make_shared<Primitive>("Conv2DBackpropFilter");

@ -0,0 +1,142 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
import mindspore.ops.operations._grad_ops as G
class ReluNet(nn.Cell):
def __init__(self):
super(ReluNet, self).__init__()
self.relu = P.ReLU()
self.relu_grad = G.ReluGrad()
def construct(self, x, dy):
y = self.relu(x)
dx = self.relu_grad(dy, y)
return y, dx
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_ReluV2():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=True)
x = Tensor(np.array([[[[-1, 1, 10],
[1, -1, 1],
[10, 1, -1]]]]).astype(np.float32))
dy = Tensor(np.array([[[[1, 0, 3],
[0, 1, 0],
[2, 1, 1]]]]).astype(np.float32))
expect_y = np.array([[[[0, 1, 10,],
[1, 0, 1,],
[10, 1, 0.]]]]).astype(np.float32)
expect_dx = np.array([[[[0, 0, 3],
[0, 0, 0],
[2, 1, 0]]]]).astype(np.float32)
net = ReluNet()
y, dx = net(Tensor(x), Tensor(dy))
assert np.allclose(y.asnumpy(), expect_y)
assert np.allclose(dx.asnumpy(), expect_dx)
class AddReluNet(nn.Cell):
def __init__(self):
super(AddReluNet, self).__init__()
self.add = P.TensorAdd()
self.relu = P.ReLU()
self.relu_grad = G.ReluGrad()
def construct(self, x1, x2, dy):
y = self.add(x1, x2)
y = self.relu(y)
dx = self.relu_grad(dy, y)
return y, dx
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_AddRelu():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=True)
x1 = Tensor(np.array([[[[-1, 1, 10],
[1, -1, 1],
[10, 1, -1]]]]).astype(np.float32))
x2 = Tensor(np.array([[[[-1, 1, 10],
[1, -1, 1],
[10, 1, -1]]]]).astype(np.float32))
dy = Tensor(np.array([[[[1, 0, 3],
[0, 1, 0],
[2, 1, 1]]]]).astype(np.float32))
expect_y = np.array([[[[0, 2, 20],
[2, 0, 2],
[20, 2, 0]]]]).astype(np.float32)
expect_dx = np.array([[[[0, 0, 3],
[0, 0, 0],
[2, 1, 0]]]]).astype(np.float32)
net = AddReluNet()
y, dx1 = net(Tensor(x1), Tensor(x2), Tensor(dy))
assert np.allclose(y.asnumpy(), expect_y)
assert np.allclose(dx1.asnumpy(), expect_dx)
class AddReluGradNet(nn.Cell):
def __init__(self):
super(AddReluGradNet, self).__init__()
self.add = P.TensorAdd()
self.relu = P.ReLU()
self.relu_grad = G.ReluGrad()
def construct(self, x, dy1, dy2):
y = self.relu(x)
dy = self.add(dy1, dy2)
dx = self.relu_grad(dy, y)
return y, dx
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_AddReluGrad():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=True)
x = Tensor(np.array([[[[-1, 1, 10],
[1, -1, 1],
[10, 1, -1]]]]).astype(np.float32))
dy1 = Tensor(np.array([[[[1, 0, 3],
[0, 1, 0],
[2, 1, 1]]]]).astype(np.float32))
dy2 = Tensor(np.array([[[[1, 0, 3],
[0, 1, 0],
[2, 1, 1]]]]).astype(np.float32))
expect_y = np.array([[[[0, 1, 10,],
[1, 0, 1,],
[10, 1, 0.]]]]).astype(np.float32)
expect_dx = np.array([[[[0, 0, 6],
[0, 0, 0],
[4, 2, 0]]]]).astype(np.float32)
net = AddReluGradNet()
y, dx1 = net(Tensor(x), Tensor(dy1), Tensor(dy2))
assert np.allclose(y.asnumpy(), expect_y)
assert np.allclose(dx1.asnumpy(), expect_dx)
Loading…
Cancel
Save