!8162 gpu support dynamic shape

From: @wilfchen
Reviewed-by: @limingqi107, @cristoval
Signed-off-by: @cristoval
pull/8162/MERGE
Committed by mindspore-ci-bot via Gitee
commit 36e69f6ef9

@@ -35,21 +35,7 @@ const std::map<std::string, cudnnReduceTensorOp_t> kReduceTypeMap = {
template <typename T>
class ArrayReduceGpuKernel : public GpuKernel {
public:
ArrayReduceGpuKernel()
: cudnn_handle_(nullptr),
reduce_tensor_op_(CUDNN_REDUCE_TENSOR_ADD),
data_type_(CUDNN_DATA_FLOAT),
nan_prop_(CUDNN_NOT_PROPAGATE_NAN),
reduce_indices_(CUDNN_REDUCE_TENSOR_NO_INDICES),
reduce_tensor_descriptor_(nullptr),
inputA_descriptor_(nullptr),
outputC_descriptor_(nullptr),
keep_dims_(false),
all_match_(false),
is_null_input_(false),
input_size_(0),
output_size_(0),
workspace_size_(0) {}
ArrayReduceGpuKernel() { ResetResource(); }
~ArrayReduceGpuKernel() override { DestroyResource(); }
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -94,7 +80,7 @@ class ArrayReduceGpuKernel : public GpuKernel {
MS_LOG(ERROR) << "Output number is " << output_num << ", but reduce op needs 1 output.";
return false;
}
int input_dim_length = SizeToInt(AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0).size());
int input_dim_length = SizeToInt(AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0).size());
if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("axis")->isa<ValueTuple>() ||
AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("axis")->isa<ValueList>()) {
@@ -117,8 +103,8 @@ class ArrayReduceGpuKernel : public GpuKernel {
}
keep_dims_ = GetAttr<bool>(kernel_node, "keep_dims");
auto inputA_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto outputC_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
auto inputA_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
auto outputC_shape = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, 0);
is_null_input_ = CHECK_NULL_INPUT(inputA_shape);
if (is_null_input_) {
MS_LOG(WARNING) << "ArrayReduceGpuKernel input is null";
@@ -132,6 +118,35 @@ class ArrayReduceGpuKernel : public GpuKernel {
return true;
}
void ResetResource() noexcept override {
cudnn_handle_ = nullptr;
reduce_tensor_op_ = CUDNN_REDUCE_TENSOR_ADD;
data_type_ = CUDNN_DATA_FLOAT;
nan_prop_ = CUDNN_NOT_PROPAGATE_NAN;
reduce_indices_ = CUDNN_REDUCE_TENSOR_NO_INDICES;
reduce_tensor_descriptor_ = nullptr;
inputA_descriptor_ = nullptr;
outputC_descriptor_ = nullptr;
keep_dims_ = false;
all_match_ = false;
is_null_input_ = false;
input_size_ = 0;
output_size_ = 0;
workspace_size_ = 0;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
void DestroyResource() noexcept override {
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyReduceTensorDescriptor(reduce_tensor_descriptor_),
"cudnnDestroyReduceTensorDescriptor failed.");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(inputA_descriptor_),
"cudnnDestroyTensorDescriptor failed.");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(outputC_descriptor_),
"cudnnDestroyTensorDescriptor failed.");
}
protected:
void InitResource() override {
cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
@@ -160,14 +175,6 @@ class ArrayReduceGpuKernel : public GpuKernel {
}
private:
void DestroyResource() noexcept {
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyReduceTensorDescriptor(reduce_tensor_descriptor_),
"cudnnDestroyReduceTensorDescriptor failed.");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(inputA_descriptor_),
"cudnnDestroyTensorDescriptor failed.");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(outputC_descriptor_),
"cudnnDestroyTensorDescriptor failed.");
}
void InferArrayReduceType(const CNodePtr &kernel_node) {
std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
auto iter = kReduceTypeMap.find(kernel_name);

@@ -26,5 +26,14 @@ MS_REG_GPU_KERNEL_TWO(
GatherV2,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
GatherV2GpuFwdKernel, half, int)
MS_REG_GPU_KERNEL_TWO(
SparseGatherV2,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
GatherV2GpuFwdKernel, float, int)
MS_REG_GPU_KERNEL_TWO(
SparseGatherV2,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
GatherV2GpuFwdKernel, half, int)
} // namespace kernel
} // namespace mindspore

@@ -27,7 +27,7 @@ namespace kernel {
template <typename T, typename S>
class GatherV2GpuFwdKernel : public GpuKernel {
public:
GatherV2GpuFwdKernel() : axis_(0), handle_(nullptr) {}
GatherV2GpuFwdKernel() { ResetResource(); }
~GatherV2GpuFwdKernel() = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -52,9 +52,9 @@ class GatherV2GpuFwdKernel : public GpuKernel {
if (input_num != 2) {
MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but GatherGpuV2FwdKernel needs 2.";
}
input_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
indices_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
output_shapes_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
input_shapes_ = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
indices_shapes_ = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 1);
output_shapes_ = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, 0);
axis_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "axis"));
if (axis_ < 0) {
@@ -65,9 +65,18 @@ class GatherV2GpuFwdKernel : public GpuKernel {
InitSizeLists();
return true;
}
void ResetResource() noexcept override {
input_shapes_.clear();
indices_shapes_.clear();
output_shapes_.clear();
std::fill(dims_, dims_ + 3, 0);
axis_ = 0;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
protected:
void InitResource() override { handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); }
void InitSizeLists() override {
size_t size = GetSize(input_shapes_);
input_size_list_.push_back(size);
@@ -118,7 +127,6 @@ class GatherV2GpuFwdKernel : public GpuKernel {
size_t dims_[3] = {};
int axis_;
cudnnHandle_t handle_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;

@@ -28,14 +28,7 @@ namespace kernel {
template <typename T>
class SplitGpuFwdKernel : public GpuKernel {
public:
SplitGpuFwdKernel()
: axis_(0),
output_num_(1),
input_size_(1),
axis_step_(1),
all_size_before_axis_(1),
all_size_axis_(1),
outputs_host_(nullptr) {}
SplitGpuFwdKernel() { ResetResource(); }
~SplitGpuFwdKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
@@ -59,7 +52,7 @@ class SplitGpuFwdKernel : public GpuKernel {
bool Init(const CNodePtr &kernel_node) override {
axis_ = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "axis"));
if (axis_ < 0) {
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto input_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
axis_ += SizeToInt(input_shape.size());
}
output_num_ = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "output_num"));
@@ -68,7 +61,7 @@ class SplitGpuFwdKernel : public GpuKernel {
return false;
}
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto input_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
input_size_ = 1;
all_size_before_axis_ = 1;
all_size_axis_ = 1;
@@ -88,7 +81,7 @@ class SplitGpuFwdKernel : public GpuKernel {
for (int i = 0; i < output_num_; i++) {
size_t output_size = 1;
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, i);
auto output_shape = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, i);
for (size_t j = 0; j < output_shape.size(); j++) {
output_size *= output_shape[j];
}
@@ -100,6 +93,19 @@ class SplitGpuFwdKernel : public GpuKernel {
return true;
}
void ResetResource() noexcept override {
axis_ = 0;
output_num_ = 1;
input_size_ = 1;
axis_step_ = 1;
all_size_before_axis_ = 1;
all_size_axis_ = 1;
outputs_host_ = nullptr;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
protected:
void InitSizeLists() override {}

@@ -62,7 +62,7 @@ class TransposeGpuFwdKernel : public GpuKernel {
MS_LOG(ERROR) << "Output number is " << output_num << ", but transpose needs 1 output.";
return false;
}
auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
auto input_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
shape_size_ = input_shape.size();
if (shape_size_ > TRANSPOSE_MAX_DIMENSION) {
MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but transpose supports max " << TRANSPOSE_MAX_DIMENSION

@@ -27,8 +27,7 @@ namespace kernel {
template <typename T, typename S>
class UnsortedSegmentSumGpuKernel : public GpuKernel {
public:
UnsortedSegmentSumGpuKernel()
: input_dim0_(1), input_dim1_(1), output_dim0_(1), output_dim1_(1), is_null_input_(false) {}
UnsortedSegmentSumGpuKernel() { ResetResource(); }
~UnsortedSegmentSumGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -53,15 +52,15 @@ class UnsortedSegmentSumGpuKernel : public GpuKernel {
}
bool Init(const CNodePtr &kernel_node) override {
auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto input_shapes = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
is_null_input_ = CHECK_NULL_INPUT(input_shapes);
if (is_null_input_) {
MS_LOG(WARNING) << "UnsortedSegmentSum input is null";
InitSizeLists();
return true;
}
auto ids_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0);
auto ids_shapes = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 1);
auto output_shapes = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, 0);
auto axis = ids_shapes.size();
for (size_t i = 0; i < input_shapes.size(); i++) {
@@ -81,6 +80,17 @@ class UnsortedSegmentSumGpuKernel : public GpuKernel {
return true;
}
void ResetResource() noexcept override {
input_dim0_ = 1;
input_dim1_ = 1;
output_dim0_ = 1;
output_dim1_ = 1;
is_null_input_ = false;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_dim0_ * input_dim1_ * sizeof(T));

@@ -0,0 +1,36 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
namespace mindspore {
namespace kernel {
void GpuDynamicKernel::UpdateArgs() {
if (!is_input_dynamic_shape_ && is_output_dynamic_shape_ && !have_depends()) {
return;
}
MS_LOG(INFO) << "Update Args: " << cnode_ptr_->fullname_with_scope();
auto kernel_mod = AnfAlgo::GetKernelMod(cnode_ptr_);
MS_EXCEPTION_IF_NULL(kernel_mod);
auto gpu_kernel_mod = dynamic_cast<GpuKernel *>(kernel_mod);
MS_EXCEPTION_IF_NULL(gpu_kernel_mod);
gpu_kernel_mod->DestroyResource();
gpu_kernel_mod->ResetResource();
gpu_kernel_mod->Init(cnode_ptr_);
}
} // namespace kernel
} // namespace mindspore
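
To make the contract above concrete, here is a minimal, self-contained C++ sketch of the cycle UpdateArgs() performs. FakeKernel, ReduceLikeKernel, and the free UpdateArgs() function are stand-ins invented for illustration, not MindSpore types; only the destroy -> reset -> re-Init ordering mirrors the code above.

#include <cstddef>
#include <iostream>
#include <vector>

// Stand-in for the framework base class; only the lifecycle is real here.
struct FakeKernel {
  virtual ~FakeKernel() = default;
  virtual bool Init(const std::vector<size_t> &shape) = 0;
  virtual void ResetResource() noexcept = 0;  // restore constructed state
  virtual void DestroyResource() noexcept {}  // release descriptors, if any
};

struct ReduceLikeKernel : FakeKernel {
  ReduceLikeKernel() { ResetResource(); }  // constructor delegates, as in the PR
  ~ReduceLikeKernel() override { DestroyResource(); }

  bool Init(const std::vector<size_t> &shape) override {
    input_size_ = 1;
    for (size_t dim : shape) input_size_ *= dim;
    input_size_list_.push_back(input_size_);
    return true;
  }
  void ResetResource() noexcept override {
    input_size_ = 0;
    input_size_list_.clear();  // without this, re-Init would append stale sizes
  }

  size_t input_size_ = 0;
  std::vector<size_t> input_size_list_;
};

// The sequence GpuDynamicKernel::UpdateArgs() runs when shapes have changed.
void UpdateArgs(FakeKernel *kernel, const std::vector<size_t> &real_shape) {
  kernel->DestroyResource();
  kernel->ResetResource();
  kernel->Init(real_shape);
}

int main() {
  ReduceLikeKernel kernel;
  UpdateArgs(&kernel, {2, 3});
  UpdateArgs(&kernel, {4, 5});              // new shape on the next step
  std::cout << kernel.input_size_ << "\n";  // 20: sizes were recomputed
  return 0;
}

Clearing the size lists inside ResetResource() is the subtle part: InitSizeLists() appends to them, so a re-Init without the clear would accumulate stale entries, which is exactly why every ResetResource() override in this diff clears all three lists.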

@@ -23,11 +23,13 @@
#include <vector>
#include <utility>
#include <map>
#include <memory>
#include "backend/kernel_compiler/kernel.h"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
#include "runtime/device/gpu/gpu_device_manager.h"
#include "runtime/device/gpu/gpu_common.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "runtime/device/executor/dynamic_kernel.h"
using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm;
namespace mindspore {
@@ -45,10 +47,28 @@ static std::map<int, int> kNHWCToNCHWAxisMap = {
{3, 1},
};
class GpuDynamicKernel : public device::DynamicKernel {
public:
explicit GpuDynamicKernel(const CNodePtr &cnode_ptr) : DynamicKernel(nullptr, cnode_ptr) {}
~GpuDynamicKernel() = default;
void UpdateArgs() override;
void PostExecute() final { MS_LOG(EXCEPTION) << "`PostExecute()` should not be invoked with the GPU backend"; }
void Execute() final { MS_LOG(EXCEPTION) << "`Execute()` should not be invoked with the GPU backend"; }
};
class GpuKernel : public KernelMod {
public:
virtual ~GpuKernel() = default;
virtual bool Init(const CNodePtr &kernel_node) = 0;
virtual void ResetResource() noexcept {
MS_LOG(EXCEPTION) << "kernel must override the `ResetResource()` method when dynamic shape";
}
virtual void DestroyResource() noexcept {}
virtual void PostExecute() {}
void InitDynamicKernel(const CNodePtr &cnode_ptr) { dynamic_kernel_ = std::make_shared<GpuDynamicKernel>(cnode_ptr); }
device::DynamicKernelPtr DynamicKernel() const { return dynamic_kernel_; }
protected:
virtual void InitResource() {}
@@ -228,7 +248,10 @@ class GpuKernel : public KernelMod {
}
return type->second;
}
device::DynamicKernelPtr dynamic_kernel_;
};
} // namespace kernel
} // namespace mindspore
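
Note the base-class defaults above: ResetResource() raises unless overridden, so a kernel that never runs with dynamic shapes needs no change and fails loudly if it is re-initialized anyway, while DestroyResource() defaults to a no-op because kernels that hold no cuDNN descriptors have nothing to release.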

@@ -123,6 +123,10 @@ class AddNGpuFwdKernel : public GpuKernel {
return true;
}
void DestroyResource() noexcept override {
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(input_descriptor_), "cudnnDestroyTensorDescriptor failed");
}
protected:
void InitResource() override {
cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
@@ -141,9 +145,6 @@ class AddNGpuFwdKernel : public GpuKernel {
}
private:
void DestroyResource() noexcept {
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(input_descriptor_), "cudnnDestroyTensorDescriptor failed");
}
cudnnHandle_t cudnn_handle_;
cudnnTensorDescriptor_t input_descriptor_;
cudnnDataType_t cudnn_data_type_;

@@ -112,6 +112,12 @@ class BiasAddGpuKernel : public GpuKernel {
return true;
}
void DestroyResource() noexcept override {
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyOpTensorDescriptor(op_desc_), "cudnnDestroyOpTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(b_desc_), "cudnnDestroyTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_desc_), "cudnnDestroyTensorDescriptor failed");
}
protected:
void InitResource() override {
cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
@@ -129,12 +135,6 @@ class BiasAddGpuKernel : public GpuKernel {
}
private:
void DestroyResource() noexcept {
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyOpTensorDescriptor(op_desc_), "cudnnDestroyTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(b_desc_), "cudnnDestroyTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_desc_), "cudnnDestroyOpTensorDescriptor failed");
}
cudnnHandle_t cudnn_handle_;
cudnnDataType_t cudnn_data_type_;
cudnnTensorDescriptor_t x_desc_;

@@ -31,13 +31,7 @@ constexpr int MAX_DIMS = 7;
template <typename T>
class BroadcastOpGpuKernel : public GpuKernel {
public:
BroadcastOpGpuKernel()
: op_type_(BROADCAST_TYPE_INVALID),
need_broadcast_(false),
is_comp_op_(false),
input1_num_(1),
input2_num_(1),
output_num_(1) {}
BroadcastOpGpuKernel() { ResetResource(); }
~BroadcastOpGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -71,9 +65,9 @@ class BroadcastOpGpuKernel : public GpuKernel {
}
bool Init(const CNodePtr &kernel_node) override {
GetOpType(kernel_node);
auto shape1 = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
auto shape2 = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
auto shape3 = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
auto shape1 = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
auto shape2 = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 1);
auto shape3 = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, 0);
need_broadcast_ = IsBroadcast(shape1, shape2);
if (need_broadcast_ && shape1.size() > 7) {
MS_LOG(EXCEPTION) << "Broadcast operation not support dim greater than 7";
@@ -106,6 +100,20 @@ class BroadcastOpGpuKernel : public GpuKernel {
InitSizeLists();
return true;
}
void ResetResource() noexcept override {
op_type_ = BROADCAST_TYPE_INVALID;
need_broadcast_ = false;
is_comp_op_ = false;
input1_num_ = 1;
input2_num_ = 1;
output_num_ = 1;
lhs_shape_.clear();
rhs_shape_.clear();
output_shape_.clear();
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
protected:
void InitResource() override { return; }

@@ -30,14 +30,7 @@ namespace kernel {
template <typename T>
class BroadcastOpGradGpuKernel : public GpuKernel {
public:
BroadcastOpGradGpuKernel()
: op_type_(BROADCAST_GRAD_TYPE_INVALID),
need_broadcast_(false),
input1_num_(1),
input2_num_(1),
output_num_(1),
grad_x_(false),
grad_y_(false) {}
BroadcastOpGradGpuKernel() { ResetResource(); }
~BroadcastOpGradGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -105,6 +98,22 @@ class BroadcastOpGradGpuKernel : public GpuKernel {
return true;
}
void ResetResource() noexcept override {
op_type_ = BROADCAST_GRAD_TYPE_INVALID;
need_broadcast_ = false;
input1_num_ = 1;
input2_num_ = 1;
output_num_ = 1;
std::fill(x1_shape_, x1_shape_ + 4, 1);
std::fill(x2_shape_, x2_shape_ + 4, 1);
std::fill(dy_shape_, dy_shape_ + 4, 1);
grad_x_ = false;
grad_y_ = false;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
protected:
void InitResource() override { return; }
void InitSizeLists() override {

@@ -69,21 +69,15 @@ static const std::map<std::string, UnaryOptype> kUnaryOpTypeMap = {{"Exp", UNARY
template <typename T>
class UnaryOpGpuKernel : public GpuKernel {
public:
UnaryOpGpuKernel()
: unary_op_type_(UNARY_OP_INVALID_TYPE),
input_size_(sizeof(T)),
output_size_(sizeof(T)),
workspace_size_(0),
is_null_input_(false) {}
UnaryOpGpuKernel() { ResetResource(); }
~UnaryOpGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
VARIABLE_NOT_USED(workspace);
T *input_addr = GetDeviceAddress<T>(inputs, 0);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
@@ -184,7 +178,7 @@ class UnaryOpGpuKernel : public GpuKernel {
MS_LOG(ERROR) << "Output number is " << output_num << ", but unary op needs 1 output.";
return false;
}
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto input_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
is_null_input_ = CHECK_NULL_INPUT(input_shape);
if (is_null_input_) {
MS_LOG(WARNING) << "UnaryOpGpuKernel input is null";
@@ -198,6 +192,16 @@ class UnaryOpGpuKernel : public GpuKernel {
InitSizeLists();
return true;
}
void ResetResource() noexcept override {
unary_op_type_ = UNARY_OP_INVALID_TYPE;
input_size_ = sizeof(T);
output_size_ = sizeof(T);
workspace_size_ = 0;
is_null_input_ = false;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
protected:
void InitSizeLists() override {

@@ -29,16 +29,7 @@ namespace kernel {
template <typename T>
class ActivationGpuFwdKernel : public GpuKernel {
public:
ActivationGpuFwdKernel()
: cudnn_handle_(nullptr),
activation_desc_(nullptr),
mode_(CUDNN_ACTIVATION_RELU),
data_descriptor_(nullptr),
is_null_input_(false),
cudnn_data_type_(CUDNN_DATA_FLOAT),
input_size_(0),
output_size_(0),
workspace_size_(0) {}
ActivationGpuFwdKernel() { ResetResource(); }
~ActivationGpuFwdKernel() override { DestroyResource(); }
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
@@ -75,7 +66,7 @@ class ActivationGpuFwdKernel : public GpuKernel {
MS_LOG(ERROR) << "Argument number is " << input_num << ", but ActivationGpuFwdKernel needs 1.";
return false;
}
auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
auto input_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
is_null_input_ = CHECK_NULL_INPUT(input_shape);
if (is_null_input_) {
MS_LOG(WARNING) << "ActivationGpuFwdKernel input is null.";
@@ -113,6 +104,27 @@ class ActivationGpuFwdKernel : public GpuKernel {
return true;
}
void DestroyResource() noexcept override {
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyActivationDescriptor(activation_desc_),
"cudnnDestroyActivationDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(data_descriptor_), "cudnnDestroyTensorDescriptor failed");
}
void ResetResource() noexcept override {
cudnn_handle_ = nullptr;
activation_desc_ = nullptr;
mode_ = CUDNN_ACTIVATION_RELU;
data_descriptor_ = nullptr;
is_null_input_ = false;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
cudnn_data_type_ = CUDNN_DATA_FLOAT;
input_size_ = 0;
output_size_ = 0;
workspace_size_ = 0;
}
protected:
void InitResource() override {
cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
@@ -132,12 +144,6 @@ class ActivationGpuFwdKernel : public GpuKernel {
}
private:
void DestroyResource() noexcept {
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyActivationDescriptor(activation_desc_),
"cudnnDestroyActivationDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(data_descriptor_), "cudnnDestroyTensorDescriptor failed");
}
std::map<std::string, cudnnActivationMode_t> kernel_map = {{"ReLU", CUDNN_ACTIVATION_RELU},
{"ReLU6", CUDNN_ACTIVATION_CLIPPED_RELU},
{"Tanh", CUDNN_ACTIVATION_TANH},

@@ -29,14 +29,7 @@ namespace kernel {
template <typename T>
class ActivationGradGpuKernel : public GpuKernel {
public:
ActivationGradGpuKernel()
: cudnn_handle_(nullptr),
activation_desc_(nullptr),
mode_(CUDNN_ACTIVATION_RELU),
data_descriptor_(nullptr),
is_null_input_(false),
cudnn_data_type_(CUDNN_DATA_FLOAT),
input_size_(0) {}
ActivationGradGpuKernel() { ResetResource(); }
~ActivationGradGpuKernel() override { DestroyResource(); }
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
@@ -117,6 +110,25 @@ class ActivationGradGpuKernel : public GpuKernel {
return true;
}
void DestroyResource() noexcept override {
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyActivationDescriptor(activation_desc_),
"cudnnDestroyActivationDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(data_descriptor_), "cudnnDestroyTensorDescriptor failed");
}
void ResetResource() noexcept override {
cudnn_handle_ = nullptr;
activation_desc_ = nullptr;
mode_ = CUDNN_ACTIVATION_RELU;
data_descriptor_ = nullptr;
is_null_input_ = false;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
cudnn_data_type_ = CUDNN_DATA_FLOAT;
input_size_ = 0;
}
protected:
void InitResource() override {
cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
@@ -135,12 +147,6 @@ class ActivationGradGpuKernel : public GpuKernel {
}
private:
void DestroyResource() noexcept {
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyActivationDescriptor(activation_desc_),
"cudnnDestroyActivationDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(data_descriptor_), "cudnnDestroyTensorDescriptor failed");
}
std::map<std::string, cudnnActivationMode_t> kernel_map = {{"ReluGrad", CUDNN_ACTIVATION_RELU},
{"ReLU6Grad", CUDNN_ACTIVATION_CLIPPED_RELU},
{"TanhGrad", CUDNN_ACTIVATION_TANH},

@@ -121,6 +121,13 @@ class BatchNormGradGpuKernel : public GpuKernel {
return true;
}
void DestroyResource() noexcept override {
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(scale_bias_desc_), "Destroy para desc failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dx_desc_), "Destroy dx desc failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dy_desc_), "Destroy dy desc failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed");
}
protected:
void InitResource() override {
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
@@ -152,13 +159,6 @@ class BatchNormGradGpuKernel : public GpuKernel {
}
private:
void DestroyResource() noexcept {
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(scale_bias_desc_), "Destroy para desc failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dx_desc_), "Destroy dx desc failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dy_desc_), "Destroy dy desc failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed");
}
int batch_;
int channel_;
int height_;

@@ -111,6 +111,13 @@ class BiasAddGradGpuKernel : public GpuKernel {
return true;
}
void DestroyResource() noexcept override {
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyReduceTensorDescriptor(op_desc_),
"cudnnDestroyReduceTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(db_desc_), "cudnnDestroyTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dy_desc_), "cudnnDestroyTensorDescriptor failed");
}
protected:
void InitResource() override {
cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
@@ -137,13 +144,6 @@ class BiasAddGradGpuKernel : public GpuKernel {
}
private:
void DestroyResource() noexcept {
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnDestroyReduceTensorDescriptor(op_desc_),
"cudnnDestroyReduceTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(db_desc_), "cudnnDestroyTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dy_desc_), "cudnnDestroyOpTensorDescriptor failed");
}
bool same_dims_;
cudnnHandle_t cudnn_handle_;
cudnnDataType_t cudnn_data_type_;

@@ -198,6 +198,15 @@ class Conv2dGpuFwdKernel : public GpuKernel {
return true;
}
void DestroyResource() noexcept override {
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyConvolutionDescriptor(conv_desc_),
"cudnnDestroyConvolutionDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyFilterDescriptor(filter_desc_), "cudnnDestroyFilterDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(padded_desc_), "cudnnDestroyTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(output_desc_), "cudnnDestroyTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(input_desc_), "cudnnDestroyTensorDescriptor failed");
}
protected:
void InitResource() override {
cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
@@ -243,14 +252,6 @@ class Conv2dGpuFwdKernel : public GpuKernel {
}
private:
void DestroyResource() noexcept {
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyConvolutionDescriptor(conv_desc_),
"cudnnDestroyConvolutionDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyFilterDescriptor(filter_desc_), "cudnnDestroyTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(padded_desc_), "cudnnDestroyTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(output_desc_), "cudnnDestroyTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(input_desc_), "cudnnDestroyTensorDescriptor failed");
}
bool CheckParam(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 2) {

@@ -199,6 +199,15 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel {
return true;
}
void DestroyResource() noexcept override {
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyConvolutionDescriptor(conv_desc_),
"cudnnDestroyConvolutionDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyFilterDescriptor(dw_desc_), "cudnnDestroyFilterDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(padded_descriptor_), "cudnnDestroyTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dy_desc_), "cudnnDestroyTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_desc_), "cudnnDestroyTensorDescriptor failed");
}
protected:
void InitResource() override {
cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
@@ -243,14 +252,6 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel {
}
private:
void DestroyResource() noexcept {
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyConvolutionDescriptor(conv_desc_),
"cudnnDestroyConvolutionDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyFilterDescriptor(dw_desc_), "cudnnDestroyFilterDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(padded_descriptor_), "cudnnDestroyTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dy_desc_), "cudnnDestroyTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_desc_), "cudnnDestroyTensorDescriptor failed");
}
bool CheckParam(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 2) {

@@ -203,6 +203,15 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
return true;
}
void DestroyResource() noexcept override {
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyConvolutionDescriptor(conv_desc_),
"cudnnDestroyConvolutionDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyFilterDescriptor(w_desc_), "cudnnDestroyFilterDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(padded_descriptor_), "cudnnDestroyTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dy_desc_), "cudnnDestroyTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dx_desc_), "cudnnDestroyTensorDescriptor failed");
}
protected:
void InitResource() override {
cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
@@ -244,14 +253,6 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
}
private:
void DestroyResource() noexcept {
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyConvolutionDescriptor(conv_desc_),
"cudnnDestroyConvolutionDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyFilterDescriptor(w_desc_), "cudnnDestroyFilterDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(padded_descriptor_), "cudnnDestroyTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dy_desc_), "cudnnDestroyTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dx_desc_), "cudnnDestroyTensorDescriptor failed");
}
bool CheckParam(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 2) {

@@ -27,7 +27,7 @@ namespace kernel {
template <typename T>
class FlattenGpuFwdKernel : public GpuKernel {
public:
FlattenGpuFwdKernel() : input_size_(0), output_size_(0), workspace_size_(0) {}
FlattenGpuFwdKernel() : input_size_(0) {}
~FlattenGpuFwdKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -47,7 +47,7 @@ class FlattenGpuFwdKernel : public GpuKernel {
return true;
}
bool Init(const CNodePtr &kernel_node) override {
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
input_size_ = sizeof(T);
for (size_t i = 0; i < shape.size(); ++i) {
input_size_ *= shape[i];
@@ -55,12 +55,17 @@ class FlattenGpuFwdKernel : public GpuKernel {
InitSizeLists();
return true;
}
void ResetResource() noexcept override {
input_size_ = 0;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_);
output_size_ = input_size_;
output_size_list_.push_back(output_size_);
output_size_list_.push_back(input_size_);
}
private:
@@ -69,8 +74,6 @@ class FlattenGpuFwdKernel : public GpuKernel {
std::vector<size_t> workspace_size_list_;
size_t input_size_;
size_t output_size_;
size_t workspace_size_;
};
} // namespace kernel
} // namespace mindspore

@@ -27,7 +27,7 @@ namespace kernel {
template <typename T>
class FlattenGardGpuBkwKernel : public GpuKernel {
public:
FlattenGardGpuBkwKernel() : input_size_(0), output_size_(0), workspace_size_(0) {}
FlattenGardGpuBkwKernel() { ResetResource(); }
~FlattenGardGpuBkwKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -54,7 +54,7 @@ class FlattenGardGpuBkwKernel : public GpuKernel {
return false;
}
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
for (size_t i = 0; i < shape.size(); ++i) {
if (input_size_ == 0) {
input_size_ = 1;
@@ -67,11 +67,17 @@ class FlattenGardGpuBkwKernel : public GpuKernel {
return true;
}
void ResetResource() noexcept override {
input_size_ = 0;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_);
output_size_ = input_size_;
output_size_list_.push_back(output_size_);
output_size_list_.push_back(input_size_);
}
private:
@@ -80,8 +86,6 @@ class FlattenGardGpuBkwKernel : public GpuKernel {
std::vector<size_t> workspace_size_list_;
size_t input_size_;
size_t output_size_;
size_t workspace_size_;
};
} // namespace kernel
} // namespace mindspore

@@ -140,6 +140,20 @@ class FusedBatchNormExGpuKernel : public GpuKernel {
return true;
}
void DestroyResource() noexcept override {
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(y_desc_), "Destroy y desc failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(scale_bias_mean_var_desc_), "Destroy para desc failed");
if (bn_ops_ == CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION) {
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(z_desc_), "Destroy z desc failed");
}
if (bn_ops_ != CUDNN_BATCHNORM_OPS_BN) {
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyActivationDescriptor(activation_desc_),
"Destroy activation descriptor failed");
}
}
protected:
void InitResource() override {
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
@@ -238,20 +252,6 @@ class FusedBatchNormExGpuKernel : public GpuKernel {
}
}
void DestroyResource() noexcept {
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(y_desc_), "Destroy y desc failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(scale_bias_mean_var_desc_), "Destroy para desc failed");
if (bn_ops_ == CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION) {
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(z_desc_), "Destroy z desc failed");
}
if (bn_ops_ != CUDNN_BATCHNORM_OPS_BN) {
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyActivationDescriptor(activation_desc_),
"Destroy activation descriptor failed");
}
}
size_t input_x_size_;
size_t input_z_size_;
size_t para_size_;

@@ -133,6 +133,12 @@ class FusedBatchNormGpuKernel : public GpuKernel {
return true;
}
void DestroyResource() noexcept override {
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(y_desc_), "Destroy y desc failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(scale_bias_mean_var_desc_), "Destroy para desc failed");
}
protected:
void InitResource() override {
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
@@ -165,12 +171,6 @@ class FusedBatchNormGpuKernel : public GpuKernel {
}
private:
void DestroyResource() noexcept {
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(y_desc_), "Destroy y desc failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(scale_bias_mean_var_desc_), "Destroy para desc failed");
}
int batch_;
int channel_;
int height_;

@@ -201,6 +201,21 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel {
workspace_size_list_.push_back(workspace_size_);
}
void DestroyResource() noexcept override {
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed");
if (bn_ops_ != CUDNN_BATCHNORM_OPS_BN) {
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(y_desc_), "Destroy y desc failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyActivationDescriptor(activation_desc_),
"Destroy activation descriptor failed");
}
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dy_desc_), "Destroy dy desc failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dx_desc_), "Destroy dx desc failed");
if (bn_ops_ == CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION) {
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dz_desc_), "Destroy dz desc failed");
}
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(scale_bias_diff_desc_), "Destroy para desc failed");
}
private:
void SetTensorDescriptor(const std::string &format, const std::vector<size_t> &shape) {
@@ -255,22 +270,6 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel {
}
}
void DestroyResource() noexcept {
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed");
if (bn_ops_ != CUDNN_BATCHNORM_OPS_BN) {
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(y_desc_), "Destroy y desc failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyActivationDescriptor(activation_desc_),
"Destroy activation descriptor failed");
}
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dy_desc_), "Destroy dy desc failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dx_desc_), "Destroy dx desc failed");
if (bn_ops_ == CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION) {
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dz_desc_), "Destroy z desc failed");
}
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(scale_bias_diff_desc_), "Destroy para desc failed");
}
size_t x_size_;
size_t para_size_;
size_t workspace_size_;

Some files were not shown because too many files have changed in this diff.