parent
651b1c3577
commit
c1d3bd2160
@@ -0,0 +1,27 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADD_RELU_V2_IMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADD_RELU_V2_IMPL_H_

#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void AddReluV2(const size_t num, const T *x1, const T *x2, T *y, uint32_t *mask, cudaStream_t cuda_stream);

template <typename T>
void AddReluGradV2(const size_t num, const T *x1, const T *x2, const uint32_t *mask, T *dx, cudaStream_t cuda_stream);

#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADD_RELU_V2_IMPL_H_
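Note (illustrative, not part of the commit): besides the activation y, AddReluV2 emits a packed bit mask, one bit per element, with 32 consecutive elements sharing a single uint32_t word, so the mask buffer holds (num + 31) / 32 words. A minimal host-side C++ sketch of that layout; MaskWords and MaskBit are hypothetical helper names:

#include <cstddef>
#include <cstdint>

// Number of uint32_t words needed for a mask over `num` elements.
inline size_t MaskWords(size_t num) { return (num + 31) / 32; }

// Predicate bit of element i: word i / 32 (its WarpId), bit i % 32 (its LaneId).
inline bool MaskBit(const uint32_t *mask, size_t i) {
  return (mask[i / 32] >> (i % 32)) & 1u;
}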
@@ -0,0 +1,68 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/gpu/cuda_impl/add_relu_v2_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"

template <typename T>
__global__ void AddReluV2Kernel(const size_t num, const T *x1, const T *x2, T *y, uint32_t *mask) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += gridDim.x * blockDim.x) {
    T sum = x1[i] + x2[i];
    bool p = sum > static_cast<T>(0);
    y[i] = p ? sum : static_cast<T>(0);

    auto warp_predict = BallotSync(p, __activemask());
    if (LaneId() == 0) {
      mask[WarpId(i)] = warp_predict;
    }
  }
}

template <typename T>
void AddReluV2(const size_t num, const T *x1, const T *x2, T *y, uint32_t *mask, cudaStream_t cuda_stream) {
  AddReluV2Kernel<<<kBlocksPerGrid(num), kThreadsPerBlock, 0, cuda_stream>>>(num, x1, x2, y, mask);
}

template <typename T>
__global__ void AddReluGradV2Kernel(const size_t num, const T *x1, const T *x2, const uint32_t *mask, T *dx) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += gridDim.x * blockDim.x) {
    bool positive = mask[WarpId(i)] & (1 << LaneId());
    dx[i] = positive ? x1[i] + x2[i] : static_cast<T>(0);
  }
}

template <typename T>
void AddReluGradV2(const size_t num, const T *x1, const T *x2, const uint32_t *mask, T *dx, cudaStream_t cuda_stream) {
  AddReluGradV2Kernel<<<kBlocksPerGrid(num), kThreadsPerBlock, 0, cuda_stream>>>(num, x1, x2, mask, dx);
}

template void AddReluV2(const size_t num, const float *x1, const float *x2, float *y, uint32_t *mask,
                        cudaStream_t cuda_stream);
template void AddReluV2(const size_t num, const half *x1, const half *x2, half *y, uint32_t *mask,
                        cudaStream_t cuda_stream);
template void AddReluV2(const size_t num, const int32_t *x1, const int32_t *x2, int32_t *y, uint32_t *mask,
                        cudaStream_t cuda_stream);
template void AddReluV2(const size_t num, const int64_t *x1, const int64_t *x2, int64_t *y, uint32_t *mask,
                        cudaStream_t cuda_stream);

template void AddReluGradV2(const size_t num, const float *x1, const float *x2, const uint32_t *mask, float *dx,
                            cudaStream_t cuda_stream);
template void AddReluGradV2(const size_t num, const half *x1, const half *x2, const uint32_t *mask, half *dx,
                            cudaStream_t cuda_stream);
template void AddReluGradV2(const size_t num, const int32_t *x1, const int32_t *x2, const uint32_t *mask, int32_t *dx,
                            cudaStream_t cuda_stream);
template void AddReluGradV2(const size_t num, const int64_t *x1, const int64_t *x2, const uint32_t *mask, int64_t *dx,
                            cudaStream_t cuda_stream);
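Note (illustrative, not part of the commit): in AddReluV2Kernel above, every thread computes its ReLU predicate, BallotSync packs the 32 per-lane predicates of a warp into one uint32_t, and lane 0 stores that word at mask[WarpId(i)]. A serialized CPU reference with the same contract, assuming WarpId(i) == i / 32 and LaneId() == i % 32 as suggested by util.cuh:

#include <cstddef>
#include <cstdint>

template <typename T>
void AddReluV2Ref(size_t num, const T *x1, const T *x2, T *y, uint32_t *mask) {
  for (size_t w = 0; w < (num + 31) / 32; ++w) mask[w] = 0;  // clear all mask words
  for (size_t i = 0; i < num; ++i) {
    T sum = x1[i] + x2[i];
    bool p = sum > static_cast<T>(0);       // per-element ReLU predicate
    y[i] = p ? sum : static_cast<T>(0);
    if (p) mask[i / 32] |= 1u << (i % 32);  // what the warp ballot + lane-0 store achieves
  }
}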
@@ -0,0 +1,50 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/gpu/nn/fused_add_relu_grad_v2_gpu_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(FusedAddReluGradV2,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeUInt32)
                        .AddOutputAttr(kNumberTypeFloat32),
                      FusedAddReluGradV2GpuKernel, float)
MS_REG_GPU_KERNEL_ONE(FusedAddReluGradV2,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeUInt32)
                        .AddOutputAttr(kNumberTypeFloat16),
                      FusedAddReluGradV2GpuKernel, half)
MS_REG_GPU_KERNEL_ONE(FusedAddReluGradV2,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeUInt32)
                        .AddOutputAttr(kNumberTypeInt32),
                      FusedAddReluGradV2GpuKernel, int32_t)
MS_REG_GPU_KERNEL_ONE(FusedAddReluGradV2,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt64)
                        .AddInputAttr(kNumberTypeInt64)
                        .AddInputAttr(kNumberTypeUInt32)
                        .AddOutputAttr(kNumberTypeInt64),
                      FusedAddReluGradV2GpuKernel, int64_t)
}  // namespace kernel
}  // namespace mindspore
@@ -0,0 +1,85 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_ADD_RELU_GRAD_V2_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_ADD_RELU_GRAD_V2_GPU_KERNEL_H_

#include <vector>
#include <algorithm>
#include <functional>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/add_relu_v2_impl.cuh"

namespace mindspore {
namespace kernel {
template <typename T>
class FusedAddReluGradV2GpuKernel : public GpuKernel {
 public:
  FusedAddReluGradV2GpuKernel() { ResetResource(); }
  ~FusedAddReluGradV2GpuKernel() override = default;

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    auto x1 = GetDeviceAddress<T>(inputs, 0);
    auto x2 = GetDeviceAddress<T>(inputs, 1);
    auto mask = GetDeviceAddress<uint32_t>(inputs, 2);
    auto dx = GetDeviceAddress<T>(outputs, 0);

    AddReluGradV2(element_num_, x1, x2, mask, dx, reinterpret_cast<cudaStream_t>(stream_ptr));
    return true;
  }

  bool Init(const CNodePtr &kernel_node) override {
    MS_EXCEPTION_IF_NULL(kernel_node);
    auto shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
    element_num_ = std::accumulate(shape.begin(), shape.end(), static_cast<size_t>(1), std::multiplies<size_t>());
    InitSizeLists();
    return true;
  }

  void ResetResource() noexcept override {
    element_num_ = 0;
    input_size_list_.clear();
    output_size_list_.clear();
    workspace_size_list_.clear();
  }

 protected:
  void InitSizeLists() override {
    auto size = element_num_ * sizeof(T);
    input_size_list_.push_back(size);
    input_size_list_.push_back(size);
    output_size_list_.push_back(size);

    size = (element_num_ + 31) / 32 * sizeof(uint32_t);
    input_size_list_.push_back(size);
  }

 private:
  size_t element_num_;

  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_ADD_RELU_GRAD_V2_GPU_KERNEL_H_
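Note (illustrative, not part of the commit): the backward kernel wrapped by this class never touches the forward activation. The mask bit replays the forward decision, so dx is x1 + x2 (the two incoming gradient tensors) where the fused Add+ReLU was positive, and 0 elsewhere. A CPU equivalent of the AddReluGradV2 call in Launch, under the same assumed bit layout:

#include <cstddef>
#include <cstdint>

template <typename T>
void AddReluGradV2Ref(size_t num, const T *x1, const T *x2, const uint32_t *mask, T *dx) {
  for (size_t i = 0; i < num; ++i) {
    bool positive = (mask[i / 32] >> (i % 32)) & 1u;  // replay the forward ReLU predicate
    dx[i] = positive ? x1[i] + x2[i] : static_cast<T>(0);
  }
}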
@@ -0,0 +1,50 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/gpu/nn/fused_add_relu_v2_gpu_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(FusedAddReluV2,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeUInt32),
                      FusedAddReluV2GpuKernel, float)
MS_REG_GPU_KERNEL_ONE(FusedAddReluV2,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddOutputAttr(kNumberTypeFloat16)
                        .AddOutputAttr(kNumberTypeUInt32),
                      FusedAddReluV2GpuKernel, half)
MS_REG_GPU_KERNEL_ONE(FusedAddReluV2,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddOutputAttr(kNumberTypeInt32)
                        .AddOutputAttr(kNumberTypeUInt32),
                      FusedAddReluV2GpuKernel, int32_t)
MS_REG_GPU_KERNEL_ONE(FusedAddReluV2,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt64)
                        .AddInputAttr(kNumberTypeInt64)
                        .AddOutputAttr(kNumberTypeInt64)
                        .AddOutputAttr(kNumberTypeUInt32),
                      FusedAddReluV2GpuKernel, int64_t)
}  // namespace kernel
}  // namespace mindspore
@@ -0,0 +1,85 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_ADD_RELU_V2_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_ADD_RELU_V2_GPU_KERNEL_H_

#include <vector>
#include <algorithm>
#include <functional>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/add_relu_v2_impl.cuh"

namespace mindspore {
namespace kernel {
template <typename T>
class FusedAddReluV2GpuKernel : public GpuKernel {
 public:
  FusedAddReluV2GpuKernel() { ResetResource(); }
  ~FusedAddReluV2GpuKernel() override = default;

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    auto x1 = GetDeviceAddress<T>(inputs, 0);
    auto x2 = GetDeviceAddress<T>(inputs, 1);
    auto y = GetDeviceAddress<T>(outputs, 0);
    auto mask = GetDeviceAddress<uint32_t>(outputs, 1);

    AddReluV2(element_num_, x1, x2, y, mask, reinterpret_cast<cudaStream_t>(stream_ptr));
    return true;
  }

  bool Init(const CNodePtr &kernel_node) override {
    MS_EXCEPTION_IF_NULL(kernel_node);
    auto shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
    element_num_ = std::accumulate(shape.begin(), shape.end(), static_cast<size_t>(1), std::multiplies<size_t>());
    InitSizeLists();
    return true;
  }

  void ResetResource() noexcept override {
    element_num_ = 0;
    input_size_list_.clear();
    output_size_list_.clear();
    workspace_size_list_.clear();
  }

 protected:
  void InitSizeLists() override {
    auto size = element_num_ * sizeof(T);
    input_size_list_.push_back(size);
    input_size_list_.push_back(size);
    output_size_list_.push_back(size);

    size = (element_num_ + 31) / 32 * sizeof(uint32_t);
    output_size_list_.push_back(size);
  }

 private:
  size_t element_num_;

  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_ADD_RELU_V2_GPU_KERNEL_H_
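Note (illustrative, not part of the commit): a worked sizing example for InitSizeLists above. With the 1x1x3x3 float32 tensors used in the tests at the end of this commit, element_num_ is 9, so each data tensor gets 36 bytes and the mask output a single uint32_t word. A standalone check of the arithmetic:

#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
  size_t element_num = 9;  // e.g. a 1x1x3x3 tensor
  size_t data_bytes = element_num * sizeof(float);                 // 36
  size_t mask_bytes = (element_num + 31) / 32 * sizeof(uint32_t);  // 4
  std::printf("data=%zu mask=%zu\n", data_bytes, mask_bytes);
  return 0;
}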
@@ -0,0 +1,38 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/gpu/nn/relu_grad_v2_gpu_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
  ReluGradV2,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat32),
  ReluGradV2GpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
  ReluGradV2,
  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat16),
  ReluGradV2GpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
  ReluGradV2,
  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt32),
  ReluGradV2GpuKernel, int32_t)
MS_REG_GPU_KERNEL_ONE(
  ReluGradV2,
  KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64),
  ReluGradV2GpuKernel, int64_t)
}  // namespace kernel
}  // namespace mindspore
@@ -0,0 +1,83 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GRAD_V2_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GRAD_V2_GPU_KERNEL_H_

#include <vector>
#include <algorithm>
#include <functional>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/relu_impl.cuh"

namespace mindspore {
namespace kernel {
template <typename T>
class ReluGradV2GpuKernel : public GpuKernel {
 public:
  ReluGradV2GpuKernel() { ResetResource(); }
  ~ReluGradV2GpuKernel() override = default;

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    auto dy = GetDeviceAddress<T>(inputs, 0);
    auto mask = GetDeviceAddress<uint32_t>(inputs, 1);
    auto dx = GetDeviceAddress<T>(outputs, 0);

    ReluGradV2(element_num_, dy, mask, dx, reinterpret_cast<cudaStream_t>(stream_ptr));
    return true;
  }

  bool Init(const CNodePtr &kernel_node) override {
    MS_EXCEPTION_IF_NULL(kernel_node);
    auto shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
    element_num_ = std::accumulate(shape.begin(), shape.end(), static_cast<size_t>(1), std::multiplies<size_t>());
    InitSizeLists();
    return true;
  }

  void ResetResource() noexcept override {
    element_num_ = 0;
    input_size_list_.clear();
    output_size_list_.clear();
    workspace_size_list_.clear();
  }

 protected:
  void InitSizeLists() override {
    auto size = element_num_ * sizeof(T);
    input_size_list_.push_back(size);
    output_size_list_.push_back(size);

    auto mask_size = (element_num_ + 31) / 32 * sizeof(uint32_t);
    input_size_list_.push_back(mask_size);
  }

 private:
  size_t element_num_;

  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GRAD_V2_GPU_KERNEL_H_
@@ -0,0 +1,37 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/gpu/nn/relu_v2_gpu_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
  ReLUV2,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt32),
  ReluV2GpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
  ReLUV2,
  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt32),
  ReluV2GpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
  ReLUV2, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
  ReluV2GpuKernel, int32_t)
MS_REG_GPU_KERNEL_ONE(
  ReLUV2,
  KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
  ReluV2GpuKernel, int64_t)
}  // namespace kernel
}  // namespace mindspore
@@ -0,0 +1,82 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_V2_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_V2_GPU_KERNEL_H_

#include <vector>
#include <algorithm>
#include <functional>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/relu_impl.cuh"

namespace mindspore {
namespace kernel {
template <typename T>
class ReluV2GpuKernel : public GpuKernel {
 public:
  ReluV2GpuKernel() { ResetResource(); }
  ~ReluV2GpuKernel() override = default;

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    auto x = GetDeviceAddress<T>(inputs, 0);
    auto y = GetDeviceAddress<T>(outputs, 0);
    auto mask = GetDeviceAddress<uint32_t>(outputs, 1);

    ReluV2(element_num_, x, y, mask, reinterpret_cast<cudaStream_t>(stream_ptr));
    return true;
  }

  bool Init(const CNodePtr &kernel_node) override {
    MS_EXCEPTION_IF_NULL(kernel_node);
    auto shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
    element_num_ = std::accumulate(shape.begin(), shape.end(), static_cast<size_t>(1), std::multiplies<size_t>());
    InitSizeLists();
    return true;
  }

  void ResetResource() noexcept override {
    element_num_ = 0;
    input_size_list_.clear();
    output_size_list_.clear();
    workspace_size_list_.clear();
  }

 protected:
  void InitSizeLists() override {
    auto size = element_num_ * sizeof(T);
    input_size_list_.push_back(size);
    output_size_list_.push_back(size);
    auto mask_size = (element_num_ + 31) / 32 * sizeof(uint32_t);
    output_size_list_.push_back(mask_size);
  }

 private:
  size_t element_num_;

  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_V2_GPU_KERNEL_H_
@@ -0,0 +1,89 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/optimizer/gpu/add_relu_grad_v2_fusion.h"

#include <memory>
#include <vector>
#include <string>

#include "backend/session/anf_runtime_algorithm.h"
#include "ir/primitive.h"
#include "utils/utils.h"
#include "backend/optimizer/common/helper.h"

namespace mindspore {
namespace opt {
namespace {
kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
  std::vector<std::string> inputs_format;
  std::vector<std::string> outputs_format;
  std::vector<TypeId> inputs_type;
  std::vector<TypeId> outputs_type;
  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;

  for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) {
    inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index));
    inputs_format.push_back(kOpFormat_DEFAULT);
  }
  for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) {
    outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index));
    outputs_format.push_back(kOpFormat_DEFAULT);
  }
  builder.SetInputsDeviceType(inputs_type);
  builder.SetInputsFormat(inputs_format);
  builder.SetOutputsDeviceType(outputs_type);
  builder.SetOutputsFormat(outputs_format);
  return builder.Build();
}
}  // namespace

const BaseRef AddReluGradV2Fusion::DefinePattern() const {
  VectorRef relu_grad = VectorRef({prim::kPrimReluGradV2, VectorRef({prim::kPrimTensorAdd, x1_, x2_}), mask_});
  return relu_grad;
}

const AnfNodePtr AddReluGradV2Fusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                              const EquivPtr &equiv) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(equiv);
  auto x1 = utils::cast<AnfNodePtr>((*equiv)[x1_]);
  auto x2 = utils::cast<AnfNodePtr>((*equiv)[x2_]);
  auto mask = utils::cast<AnfNodePtr>((*equiv)[mask_]);

  auto tensor_add = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
  MS_EXCEPTION_IF_NULL(tensor_add);
  auto users = GetRealNodeUsedList(graph, tensor_add);
  if (users->size() > 1) {
    return nullptr;
  }

  auto prim = std::make_shared<Primitive>(kFusedAddReluGradV2Name);
  MS_EXCEPTION_IF_NULL(prim);
  std::vector<AnfNodePtr> inputs = {NewValueNode(prim), x1, x2, mask};
  auto add_relugrad = graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(add_relugrad);
  auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
  auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)};
  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, add_relugrad.get());
  add_relugrad->set_scope(node->scope());

  auto build_info = GenerateKernelBuildInfo(add_relugrad);
  AnfAlgo::SetSelectKernelBuildInfo(build_info, add_relugrad.get());
  return add_relugrad;
}
}  // namespace opt
}  // namespace mindspore
@@ -0,0 +1,42 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_ADD_RELU_GRAD_V2_FUSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_ADD_RELU_GRAD_V2_FUSION_H_

#include <memory>
#include "backend/optimizer/common/optimizer.h"

namespace mindspore {
namespace opt {
class AddReluGradV2Fusion : public PatternProcessPass {
 public:
  explicit AddReluGradV2Fusion(bool multigraph = true) : PatternProcessPass("add_relu_grad", multigraph) {
    x1_ = std::make_shared<Var>();
    x2_ = std::make_shared<Var>();
    mask_ = std::make_shared<Var>();
  }
  ~AddReluGradV2Fusion() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

 private:
  VarPtr x1_;
  VarPtr x2_;
  VarPtr mask_;
};
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_ADD_RELU_GRAD_V2_FUSION_H_
@@ -0,0 +1,94 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/optimizer/gpu/add_relu_v2_fusion.h"

#include <memory>
#include <vector>
#include <string>

#include "backend/session/anf_runtime_algorithm.h"
#include "ir/primitive.h"
#include "utils/utils.h"
#include "backend/optimizer/common/helper.h"

namespace mindspore {
namespace opt {
namespace {
kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
  std::vector<std::string> inputs_format;
  std::vector<std::string> outputs_format;
  std::vector<TypeId> inputs_type;
  std::vector<TypeId> outputs_type;
  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;

  for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) {
    inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index));
    inputs_format.push_back(kOpFormat_DEFAULT);
  }
  for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) {
    outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index));
    outputs_format.push_back(kOpFormat_DEFAULT);
  }
  builder.SetInputsDeviceType(inputs_type);
  builder.SetInputsFormat(inputs_format);
  builder.SetOutputsDeviceType(outputs_type);
  builder.SetOutputsFormat(outputs_format);
  return builder.Build();
}
}  // namespace

const BaseRef AddReluV2Fusion::DefinePattern() const {
  VectorRef relu = VectorRef({prim::kPrimReluV2, VectorRef({prim::kPrimTensorAdd, x1_, x2_})});
  return relu;
}

const AnfNodePtr AddReluV2Fusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                          const EquivPtr &equiv) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(equiv);
  auto x1 = utils::cast<AnfNodePtr>((*equiv)[x1_]);
  auto x2 = utils::cast<AnfNodePtr>((*equiv)[x2_]);

  auto tensor_add = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
  MS_EXCEPTION_IF_NULL(tensor_add);
  auto users = GetRealNodeUsedList(graph, tensor_add);
  if (users->size() > 1) {
    return nullptr;
  }

  auto prim = std::make_shared<Primitive>(kFusedAddReluV2Name);
  MS_EXCEPTION_IF_NULL(prim);
  std::vector<AnfNodePtr> inputs = {NewValueNode(prim), x1, x2};
  auto add_relu = graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(add_relu);

  std::vector<TypeId> types;
  std::vector<std::vector<size_t>> shapes;
  for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(node); i++) {
    types.push_back(AnfAlgo::GetOutputInferDataType(node, i));
    shapes.push_back(AnfAlgo::GetOutputInferShape(node, i));
  }

  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, add_relu.get());
  add_relu->set_scope(node->scope());

  auto build_info = GenerateKernelBuildInfo(add_relu);
  AnfAlgo::SetSelectKernelBuildInfo(build_info, add_relu.get());
  return add_relu;
}
}  // namespace opt
}  // namespace mindspore
@@ -0,0 +1,40 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_ADD_RELU_V2_FUSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_ADD_RELU_V2_FUSION_H_

#include <memory>
#include "backend/optimizer/common/optimizer.h"

namespace mindspore {
namespace opt {
class AddReluV2Fusion : public PatternProcessPass {
 public:
  explicit AddReluV2Fusion(bool multigraph = true) : PatternProcessPass("add_relu_v2_fusion", multigraph) {
    x1_ = std::make_shared<Var>();
    x2_ = std::make_shared<Var>();
  }
  ~AddReluV2Fusion() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

 private:
  VarPtr x1_;
  VarPtr x2_;
};
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_ADD_RELU_V2_FUSION_H_
@@ -0,0 +1,150 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/optimizer/gpu/relu_v2_pass.h"
#include <memory>
#include <string>
#include <vector>
#include <algorithm>
#include <functional>
#include "backend/session/anf_runtime_algorithm.h"
#include "ir/primitive.h"
#include "utils/utils.h"
#include "abstract/abstract_value.h"
#include "backend/optimizer/common/helper.h"

namespace mindspore {
namespace opt {
namespace {
const size_t kReluV2OutputNum = 2;

CNodePtr GetRelu(const CNodePtr &relu_grad) {
  MS_EXCEPTION_IF_NULL(relu_grad);
  if (relu_grad->size() != kReluGradInputNum) {
    MS_LOG_EXCEPTION << "ReluGrad has wrong input size " << relu_grad->size();
  }
  auto relu_anf = relu_grad->input(2);
  MS_EXCEPTION_IF_NULL(relu_anf);
  return relu_anf->cast<CNodePtr>();
}

kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
  std::vector<std::string> inputs_format;
  std::vector<std::string> outputs_format;
  std::vector<TypeId> inputs_type;
  std::vector<TypeId> outputs_type;
  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;

  for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) {
    inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index));
    inputs_format.push_back(kOpFormat_DEFAULT);
  }
  for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) {
    outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index));
    outputs_format.push_back(kOpFormat_DEFAULT);
  }
  builder.SetInputsDeviceType(inputs_type);
  builder.SetInputsFormat(inputs_format);
  builder.SetOutputsDeviceType(outputs_type);
  builder.SetOutputsFormat(outputs_format);
  return builder.Build();
}

CNodePtr CreateReluV2(const FuncGraphPtr &graph, const CNodePtr &relu) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(relu);
  if (relu->size() != kReluInputNum) {
    MS_LOG_EXCEPTION << "Relu has wrong input size " << relu->size();
  }

  auto prim = std::make_shared<Primitive>(kReluV2OpName);
  std::vector<AnfNodePtr> inputs = {NewValueNode(prim), relu->input(1)};
  auto new_node = graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(new_node);
  new_node->set_scope(relu->scope());

  if (AnfAlgo::IsDynamicShape(relu)) {
    return nullptr;
  }
  std::vector<size_t> output_shape = AnfAlgo::GetOutputInferShape(relu, 0);
  auto element_num =
    std::accumulate(output_shape.begin(), output_shape.end(), static_cast<size_t>(1), std::multiplies<size_t>());

  std::vector<size_t> mask_shape = {(element_num + 31) / 32};
  auto shapes = {AnfAlgo::GetOutputInferShape(relu, 0), mask_shape};
  auto types = {AnfAlgo::GetOutputInferDataType(relu, 0), kNumberTypeUInt32};
  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, new_node.get());

  auto build_info = GenerateKernelBuildInfo(new_node);
  AnfAlgo::SetSelectKernelBuildInfo(build_info, new_node.get());
  return new_node;
}

CNodePtr CreateReluGradV2(const FuncGraphPtr &graph, const CNodePtr &relu_grad, const AnfNodePtr &second_input) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(relu_grad);
  MS_EXCEPTION_IF_NULL(second_input);

  auto prim = std::make_shared<Primitive>(kReluGradV2OpName);
  std::vector<AnfNodePtr> inputs = {NewValueNode(prim), relu_grad->input(1), second_input};
  auto new_node = graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(new_node);
  new_node->set_scope(relu_grad->scope());
  new_node->set_abstract(relu_grad->abstract());

  std::vector<TypeId> types;
  std::vector<std::vector<size_t>> shapes;
  for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(relu_grad); i++) {
    types.push_back(AnfAlgo::GetOutputInferDataType(relu_grad, i));
    shapes.push_back(AnfAlgo::GetOutputInferShape(relu_grad, i));
  }

  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, new_node.get());

  auto build_info = GenerateKernelBuildInfo(new_node);
  AnfAlgo::SetSelectKernelBuildInfo(build_info, new_node.get());

  return new_node;
}
}  // namespace

const BaseRef ReluV2Pass::DefinePattern() const {
  VectorRef relu_grad({prim::kPrimReluGrad, dy_, VectorRef({prim::kPrimRelu, x_})});
  return relu_grad;
}

const AnfNodePtr ReluV2Pass::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  auto relu_grad = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(relu_grad);
  auto relu = GetRelu(relu_grad);
  MS_EXCEPTION_IF_NULL(relu);

  auto relu_v2 = CreateReluV2(graph, relu);
  if (relu_v2 == nullptr) {
    return nullptr;
  }
  std::vector<AnfNodePtr> relu_v2_node_outputs;
  CreateMultipleOutputsOfAnfNode(graph, relu_v2, kReluV2OutputNum, &relu_v2_node_outputs);

  auto relu_grad_v2 = CreateReluGradV2(graph, relu_grad, relu_v2_node_outputs[1]);
  auto manage = graph->manager();
  MS_EXCEPTION_IF_NULL(manage);
  manage->Replace(relu, relu_v2_node_outputs[0]);
  return relu_grad_v2;
}
}  // namespace opt
}  // namespace mindspore
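Note (editor's sketch of the rewrite performed by ReluV2Pass::Process above):

  before:  y = ReLU(x)              dx = ReluGrad(dy, y)
  after:   (y, mask) = ReLUV2(x)    dx = ReluGradV2(dy, mask)

manage->Replace(relu, relu_v2_node_outputs[0]) reroutes every consumer of the original ReLU to ReLUV2's first output, and returning relu_grad_v2 substitutes it for the matched ReluGrad node; the backward op then reads the packed uint32 mask instead of the forward activation.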
@@ -0,0 +1,40 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_RELU_V2_FUSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_RELU_V2_FUSION_H_

#include <memory>
#include "backend/optimizer/common/optimizer.h"

namespace mindspore {
namespace opt {
class ReluV2Pass : public PatternProcessPass {
 public:
  explicit ReluV2Pass(bool multigraph = true) : PatternProcessPass("relu_v2_fusion", multigraph) {
    x_ = std::make_shared<Var>();
    dy_ = std::make_shared<Var>();
  }
  ~ReluV2Pass() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

 private:
  VarPtr x_;
  VarPtr dy_;
};
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_RELU_V2_FUSION_H_
@@ -0,0 +1,142 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
import mindspore.ops.operations._grad_ops as G


class ReluNet(nn.Cell):
    def __init__(self):
        super(ReluNet, self).__init__()
        self.relu = P.ReLU()
        self.relu_grad = G.ReluGrad()

    def construct(self, x, dy):
        y = self.relu(x)
        dx = self.relu_grad(dy, y)
        return y, dx

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_ReluV2():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=True)

    x = Tensor(np.array([[[[-1, 1, 10],
                           [1, -1, 1],
                           [10, 1, -1]]]]).astype(np.float32))
    dy = Tensor(np.array([[[[1, 0, 3],
                            [0, 1, 0],
                            [2, 1, 1]]]]).astype(np.float32))
    expect_y = np.array([[[[0, 1, 10],
                           [1, 0, 1],
                           [10, 1, 0]]]]).astype(np.float32)
    expect_dx = np.array([[[[0, 0, 3],
                            [0, 0, 0],
                            [2, 1, 0]]]]).astype(np.float32)
    net = ReluNet()
    y, dx = net(x, dy)

    assert np.allclose(y.asnumpy(), expect_y)
    assert np.allclose(dx.asnumpy(), expect_dx)


class AddReluNet(nn.Cell):
    def __init__(self):
        super(AddReluNet, self).__init__()
        self.add = P.TensorAdd()
        self.relu = P.ReLU()
        self.relu_grad = G.ReluGrad()

    def construct(self, x1, x2, dy):
        y = self.add(x1, x2)
        y = self.relu(y)
        dx = self.relu_grad(dy, y)
        return y, dx


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_AddRelu():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=True)

    x1 = Tensor(np.array([[[[-1, 1, 10],
                            [1, -1, 1],
                            [10, 1, -1]]]]).astype(np.float32))
    x2 = Tensor(np.array([[[[-1, 1, 10],
                            [1, -1, 1],
                            [10, 1, -1]]]]).astype(np.float32))
    dy = Tensor(np.array([[[[1, 0, 3],
                            [0, 1, 0],
                            [2, 1, 1]]]]).astype(np.float32))
    expect_y = np.array([[[[0, 2, 20],
                           [2, 0, 2],
                           [20, 2, 0]]]]).astype(np.float32)
    expect_dx = np.array([[[[0, 0, 3],
                            [0, 0, 0],
                            [2, 1, 0]]]]).astype(np.float32)
    net = AddReluNet()
    y, dx1 = net(x1, x2, dy)

    assert np.allclose(y.asnumpy(), expect_y)
    assert np.allclose(dx1.asnumpy(), expect_dx)

class AddReluGradNet(nn.Cell):
    def __init__(self):
        super(AddReluGradNet, self).__init__()
        self.add = P.TensorAdd()
        self.relu = P.ReLU()
        self.relu_grad = G.ReluGrad()

    def construct(self, x, dy1, dy2):
        y = self.relu(x)
        dy = self.add(dy1, dy2)
        dx = self.relu_grad(dy, y)
        return y, dx


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_AddReluGrad():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=True)

    x = Tensor(np.array([[[[-1, 1, 10],
                           [1, -1, 1],
                           [10, 1, -1]]]]).astype(np.float32))
    dy1 = Tensor(np.array([[[[1, 0, 3],
                             [0, 1, 0],
                             [2, 1, 1]]]]).astype(np.float32))
    dy2 = Tensor(np.array([[[[1, 0, 3],
                             [0, 1, 0],
                             [2, 1, 1]]]]).astype(np.float32))
    expect_y = np.array([[[[0, 1, 10],
                           [1, 0, 1],
                           [10, 1, 0]]]]).astype(np.float32)
    expect_dx = np.array([[[[0, 0, 6],
                            [0, 0, 0],
                            [4, 2, 0]]]]).astype(np.float32)
    net = AddReluGradNet()
    y, dx1 = net(x, dy1, dy2)

    assert np.allclose(y.asnumpy(), expect_y)
    assert np.allclose(dx1.asnumpy(), expect_dx)