!13307 GPU fix shared_ptr in GpuKernel

From: @VectorSL
Reviewed-by: @cristoval,@chujinjin
Signed-off-by: @chujinjin
pull/13307/MERGE
Committed-by: @mindspore-ci-bot (via Gitee)
commit eb1c0310a9

@@ -54,18 +54,18 @@ class DynamicRangeGpuKernel : public GpuKernel {
     DynamicRangeErrorCode error_code = DynamicRangeErrorCode::kOk;
-    CHECK_CUDA_RET_WITH_ERROR(c_node_ptr_,
+    CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
                               cudaMemcpyAsync(&error_code, error_code_device_address, sizeof(DynamicRangeErrorCode),
                                               cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_ptr)),
                               "Failed to copy error code to host.");
-    CHECK_CUDA_RET_WITH_EXCEPT(c_node_ptr_, cudaDeviceSynchronize(), "cudaDeviceSyncFailed");
+    CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaDeviceSynchronize(), "cudaDeviceSyncFailed");
     // use workspace[0] for actual output shape, we know it must be 1d
-    CHECK_CUDA_RET_WITH_ERROR(c_node_ptr_,
+    CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
                               cudaMemcpyAsync(&output_shape_, output_shape_device_address, sizeof(int64_t),
                                               cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_ptr)),
                               "Failed to copy output_shape to host.");
-    CHECK_CUDA_RET_WITH_EXCEPT(c_node_ptr_, cudaDeviceSynchronize(), "cudaDeviceSyncFailed");
+    CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaDeviceSynchronize(), "cudaDeviceSyncFailed");
     LogExceptionIfNotOk(error_code);
@@ -98,17 +98,16 @@ class DynamicRangeGpuKernel : public GpuKernel {
   void PostExecute() override {
     // required synchronize for PostExecute
-    CHECK_CUDA_RET_WITH_EXCEPT(c_node_ptr_, cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_ptr_)),
+    CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_ptr_)),
                                "cudaStreamSynchronize failed");
-    std::vector<TypeId> output_type = {AnfAlgo::GetOutputInferDataType(c_node_ptr_, 0)};
+    std::vector<TypeId> output_type = {AnfAlgo::GetOutputInferDataType(kernel_node_.lock(), 0)};
     std::vector<std::vector<size_t>> output_shape = {{(size_t)output_shape_}};
-    AnfAlgo::SetOutputInferTypeAndShape(output_type, output_shape, c_node_ptr_.get());
+    AnfAlgo::SetOutputInferTypeAndShape(output_type, output_shape, kernel_node_.lock().get());
   }
   void ResetResource() noexcept override {
     stream_ptr_ = nullptr;
-    c_node_ptr_ = nullptr;
     output_shape_ = 0;
     max_output_length_ = 0;
     input_size_list_.clear();
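
Every direct use of the node now goes through `kernel_node_.lock()`, and raw-pointer APIs get `kernel_node_.lock().get()`. The temporary `shared_ptr` returned by `lock()` keeps the node alive until the end of the full expression, so the raw pointer handed to the callee stays valid while it runs; if the node is already gone, `lock()` returns an empty pointer instead of dangling. A small sketch of the idiom (`UseRaw` is a hypothetical stand-in, not the `AnfAlgo` API):

```cpp
// Sketch of the lock()/get() idiom used in PostExecute above
// (CNode and UseRaw are illustrative stand-ins).
#include <cassert>
#include <memory>

struct CNode { int id = 0; };

void UseRaw(CNode *node) { assert(node != nullptr); }  // raw-pointer API

int main() {
  auto owner = std::make_shared<CNode>();
  std::weak_ptr<CNode> kernel_node = owner;

  // The temporary shared_ptr from lock() owns the node for the duration
  // of the full expression, so the .get() pointer cannot dangle here.
  UseRaw(kernel_node.lock().get());

  owner.reset();  // node destroyed
  assert(kernel_node.lock().get() == nullptr);  // expired: lock() is empty
}
```
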
@@ -124,7 +123,7 @@ class DynamicRangeGpuKernel : public GpuKernel {
     }
     max_output_length_ = GetAttr<int64_t>(kernel_node, "maxlen");
-    c_node_ptr_ = kernel_node;
+    kernel_node_ = kernel_node;
     InitSizeLists();
     return true;
@@ -145,7 +144,6 @@ class DynamicRangeGpuKernel : public GpuKernel {
  private:
   void *stream_ptr_;
-  CNodePtr c_node_ptr_;
   int64_t output_shape_;
   int64_t max_output_length_;

@@ -54,6 +54,7 @@ class TopKGpuKernel : public GpuKernel {
   }
   bool Init(const CNodePtr &kernel_node) override {
+    kernel_node_ = kernel_node;
     auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
     auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0);
     input_shape_size_ = input_shapes.size();

@@ -64,17 +64,17 @@ class UniqueGpuKernel : public GpuKernel {
                                "cudaStreamSynchronized failed");
     std::vector<TypeId> type_ids;
     std::vector<std::vector<size_t>> shapes;
-    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node_);
+    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node_.lock());
     for (size_t i = 0; i < output_num; ++i) {
-      std::vector<size_t> shape = AnfAlgo::GetOutputInferShape(kernel_node_, i);
+      std::vector<size_t> shape = AnfAlgo::GetOutputInferShape(kernel_node_.lock(), i);
       if (i == 0) {
         shape[0] = post_output_size_;
       }
-      TypeId type_id = AnfAlgo::GetOutputInferDataType(kernel_node_, i);
+      TypeId type_id = AnfAlgo::GetOutputInferDataType(kernel_node_.lock(), i);
       type_ids.emplace_back(type_id);
       shapes.emplace_back(shape);
     }
-    AnfAlgo::SetOutputInferTypeAndShape(type_ids, shapes, kernel_node_.get());
+    AnfAlgo::SetOutputInferTypeAndShape(type_ids, shapes, kernel_node_.lock().get());
   }
   void ResetResource() noexcept override {
@@ -84,7 +84,6 @@ class UniqueGpuKernel : public GpuKernel {
     num_elements_ = 1;
     post_output_size_ = 0;
     stream_ptr_ = nullptr;
-    kernel_node_ = nullptr;
     input_size_list_.clear();
     output_size_list_.clear();
     workspace_size_list_.clear();
@@ -106,7 +105,6 @@ class UniqueGpuKernel : public GpuKernel {
   size_t workspace_size_;
   int num_elements_;
   int post_output_size_;
-  CNodePtr kernel_node_;
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;

@@ -39,7 +39,7 @@ class ZerosLikeGpuKernel : public GpuKernel {
     CHECK_CUDA_RET_WITH_EXCEPT(
       kernel_node_,
-      // have to use a float literal instead of an int literal beacuse of ambigious half() overload.
+      // have to use a float literal instead of an int literal because of ambiguous half() overload.
       cudaMemsetAsync(output_device_address, static_cast<T>(0.0), input_size_ * sizeof(T),
                       reinterpret_cast<cudaStream_t>(stream_ptr)),
       "cudaMemset failed");
@@ -61,7 +61,6 @@ class ZerosLikeGpuKernel : public GpuKernel {
   }
   void ResetResource() noexcept override {
-    kernel_node_ = nullptr;
     input_size_ = 1;
     input_size_list_.clear();
     output_size_list_.clear();
@@ -76,7 +75,6 @@ class ZerosLikeGpuKernel : public GpuKernel {
   }
  private:
-  CNodePtr kernel_node_;
   size_t input_size_;
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;

@@ -73,7 +73,7 @@ class GpuKernel : public KernelMod {
  protected:
   virtual void InitResource() {}
   virtual void InitSizeLists() = 0;
-  CNodePtr kernel_node_;
+  std::weak_ptr<CNode> kernel_node_;
   template <typename T>
   inline T *GetDeviceAddress(const std::vector<AddressPtr> &addr_list, size_t index) {
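
The hunk above is the heart of the change: `GpuKernel` itself now declares the node handle as a `std::weak_ptr<CNode>`, so the per-kernel `CNodePtr` members and their `nullptr` resets in `ResetResource` (removed throughout this PR) are no longer needed. Each kernel's `Init` just assigns the incoming `CNodePtr`, which converts implicitly to a weak reference. A condensed sketch of the resulting structure (class names are illustrative):

```cpp
// Condensed sketch of the base/derived layout after the change
// (GpuKernelSketch/SomeKernel are illustrative, not the real classes).
#include <memory>

struct CNode {};
using CNodePtr = std::shared_ptr<CNode>;

class GpuKernelSketch {
 protected:
  // One weak reference in the base class replaces the per-kernel
  // CNodePtr members that each kernel previously declared.
  std::weak_ptr<CNode> kernel_node_;
};

class SomeKernel : public GpuKernelSketch {
 public:
  bool Init(const CNodePtr &kernel_node) {
    kernel_node_ = kernel_node;  // shared_ptr assigns to weak_ptr directly
    return true;
  }
  void ResetResource() noexcept {
    // No `kernel_node_ = nullptr;` here: weak_ptr does not accept nullptr
    // assignment anyway, and it never owns the node it observes.
  }
};

int main() {
  auto node = std::make_shared<CNode>();
  SomeKernel kernel;
  return kernel.Init(node) ? 0 : 1;
}
```
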
@@ -202,7 +202,7 @@ class GpuKernel : public KernelMod {
   // set the tensor descriptor for cudnn/cublas
   void CudnnSetTensorNdDescriptor(const std::vector<size_t> &shape, cudnnTensorDescriptor_t descriptor,
-                                  cudnnDataType_t data_type, const CNodePtr &node) {
+                                  cudnnDataType_t data_type, const std::weak_ptr<CNode> &node) {
     if (shape.size() < 3) {
       MS_EXCEPTION(ValueError) << "cudnnSetTensorNdDescriptor don't support" << shape.size() << "D.";
     }

@@ -63,6 +63,7 @@ class LayerNormGradGradGpuKernel : public GpuKernel {
     return true;
   }
   bool Init(const CNodePtr &kernel_node) override {
+    kernel_node_ = kernel_node;
     int begin_norm_axis = static_cast<int>(GetAttr<int64_t>(kernel_node, "begin_norm_axis"));
     int begin_params_axis = static_cast<int>(GetAttr<int64_t>(kernel_node, "begin_params_axis"));

@@ -53,16 +53,16 @@ class GpuConvertToDynamicShapeGpuKernel : public GpuKernel {
     CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(cuda_stream_ptr_)),
                                "cudaStreamSynchronized failed");
-    std::vector<TypeId> output_types = {AnfAlgo::GetOutputInferDataType(c_node_ptr_, 0)};
+    std::vector<TypeId> output_types = {AnfAlgo::GetOutputInferDataType(kernel_node_.lock(), 0)};
     std::vector<std::vector<size_t>> output_shapes = {input_shape_};
-    AnfAlgo::SetOutputInferTypeAndShape(output_types, output_shapes, c_node_ptr_.get());
+    AnfAlgo::SetOutputInferTypeAndShape(output_types, output_shapes, kernel_node_.lock().get());
   }
   bool Init(const CNodePtr &kernel_node) override {
+    kernel_node_ = kernel_node;
     size_t input_count = AnfAlgo::GetInputTensorNum(kernel_node);
     if (input_count != 1) {
-      MS_LOG(ERROR) << input_count << "inputs were provided, but GpuConvertToDynamicShapeGpuKernel exepects 1.";
+      MS_LOG(ERROR) << input_count << "inputs were provided, but GpuConvertToDynamicShapeGpuKernel expects 1.";
       return false;
     }
@@ -71,15 +71,12 @@ class GpuConvertToDynamicShapeGpuKernel : public GpuKernel {
       input_size_ *= e;
     }
-    c_node_ptr_ = kernel_node;
     InitSizeLists();
     return true;
   }
   void ResetResource() noexcept override {
-    c_node_ptr_ = nullptr;
     cuda_stream_ptr_ = nullptr;
     input_shape_.clear();
     input_size_ = 1;
@@ -93,7 +90,6 @@ class GpuConvertToDynamicShapeGpuKernel : public GpuKernel {
  private:
   void *cuda_stream_ptr_;
-  CNodePtr c_node_ptr_;
   std::vector<size_t> input_shape_;
   size_t input_size_;

@@ -49,7 +49,7 @@ namespace gpu {
    cudaError_t status = (expression);                                                                     \
    if (status != cudaSuccess) {                                                                           \
      MS_LOG(ERROR) << "CUDA Error: " << message << " | Error Number: " << status << " " << cudaGetErrorString(status) \
-                   << trace::DumpSourceLines(node);                                                       \
+                   << trace::DumpSourceLines(node.lock());                                                \
    }                                                                                                      \
  }
@@ -72,13 +72,13 @@ namespace gpu {
    }                                                                                                      \
  }
-#define CHECK_CUDA_RET_WITH_EXCEPT(node, expression, message)                                              \
-  {                                                                                                        \
-    cudaError_t status = (expression);                                                                     \
-    if (status != cudaSuccess) {                                                                           \
-      MS_LOG(EXCEPTION) << "CUDA Error: " << message << " | Error Number: " << status << " "               \
-                        << cudaGetErrorString(status) << trace::DumpSourceLines(node);                     \
-    }                                                                                                      \
+#define CHECK_CUDA_RET_WITH_EXCEPT(node, expression, message)                                              \
+  {                                                                                                        \
+    cudaError_t status = (expression);                                                                     \
+    if (status != cudaSuccess) {                                                                           \
+      MS_LOG(EXCEPTION) << "CUDA Error: " << message << " | Error Number: " << status << " "               \
+                        << cudaGetErrorString(status) << trace::DumpSourceLines(node.lock());              \
+    }                                                                                                      \
  }
#define CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(expression, message)                                            \
@@ -90,13 +90,13 @@ namespace gpu {
    }                                                                                                      \
  }
-#define CHECK_CUDNN_RET_WITH_EXCEPT(node, expression, message)                                             \
-  {                                                                                                        \
-    cudnnStatus_t status = (expression);                                                                   \
-    if (status != CUDNN_STATUS_SUCCESS) {                                                                  \
-      MS_LOG(EXCEPTION) << "cuDNN Error: " << message << " | Error Number: " << status << " "              \
-                        << cudnnGetErrorString(status) << trace::DumpSourceLines(node);                    \
-    }                                                                                                      \
+#define CHECK_CUDNN_RET_WITH_EXCEPT(node, expression, message)                                             \
+  {                                                                                                        \
+    cudnnStatus_t status = (expression);                                                                   \
+    if (status != CUDNN_STATUS_SUCCESS) {                                                                  \
+      MS_LOG(EXCEPTION) << "cuDNN Error: " << message << " | Error Number: " << status << " "              \
+                        << cudnnGetErrorString(status) << trace::DumpSourceLines(node.lock());             \
+    }                                                                                                      \
  }
#define CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(expression, message)                                           \
@@ -117,13 +117,13 @@ namespace gpu {
    }                                                                                                      \
  }
-#define CHECK_CUDNN_RET_WITH_ERROR(node, expression, message)                                              \
-  {                                                                                                        \
-    cudnnStatus_t status = (expression);                                                                   \
-    if (status != CUDNN_STATUS_SUCCESS) {                                                                  \
-      MS_LOG(ERROR) << "cuDNN Error: " << message << " | Error Number: " << status << " "                  \
-                    << cudnnGetErrorString(status) << trace::DumpSourceLines(node);                        \
-    }                                                                                                      \
+#define CHECK_CUDNN_RET_WITH_ERROR(node, expression, message)                                              \
+  {                                                                                                        \
+    cudnnStatus_t status = (expression);                                                                   \
+    if (status != CUDNN_STATUS_SUCCESS) {                                                                  \
+      MS_LOG(ERROR) << "cuDNN Error: " << message << " | Error Number: " << status << " "                  \
+                    << cudnnGetErrorString(status) << trace::DumpSourceLines(node.lock());                 \
+    }                                                                                                      \
  }
#define CHECK_CUBLAS_RET_WITH_EXCEPT_NOTRACE(expression, message)                                          \
@@ -139,7 +139,7 @@ namespace gpu {
    cublasStatus_t status = (expression);                                                                  \
    if (status != CUBLAS_STATUS_SUCCESS) {                                                                 \
      MS_LOG(EXCEPTION) << "cuBLAS Error: " << message << " | Error Number: " << status                    \
-                       << trace::DumpSourceLines(node);                                                   \
+                       << trace::DumpSourceLines(node.lock());                                            \
    }                                                                                                      \
  }
@@ -164,7 +164,7 @@ namespace gpu {
    cusolverStatus_t status = (expression);                                                                \
    if (status != CUSOLVER_STATUS_SUCCESS) {                                                               \
      MS_LOG(EXCEPTION) << "cusolver Error: " << message << " | Error Number: " << status                  \
-                       << trace::DumpSourceLines(node);                                                   \
+                       << trace::DumpSourceLines(node.lock());                                            \
      ;                                                                                                    \
    }                                                                                                      \
  }
@@ -177,12 +177,13 @@ namespace gpu {
    }                                                                                                      \
  }
-#define CHECK_NCCL_RET_WITH_EXCEPT(node, expression, message)                                              \
-  {                                                                                                        \
-    int result = (expression);                                                                             \
-    if (result != ncclSuccess) {                                                                           \
-      MS_LOG(EXCEPTION) << "NCCL Error: " << message << " | Error Number: " << result << trace::DumpSourceLines(node); \
-    }                                                                                                      \
+#define CHECK_NCCL_RET_WITH_EXCEPT(node, expression, message)                                              \
+  {                                                                                                        \
+    int result = (expression);                                                                             \
+    if (result != ncclSuccess) {                                                                           \
+      MS_LOG(EXCEPTION) << "NCCL Error: " << message << " | Error Number: " << result                      \
+                        << trace::DumpSourceLines(node.lock());                                            \
+    }                                                                                                      \
  }
#define VARIABLE_NOT_USED(var)                                                                             \
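
All of the `CHECK_*` macros change the same way: the `node` argument is now expected to be a `std::weak_ptr<CNode>`, locked only on the failure path where `trace::DumpSourceLines` needs the node to print source-line information. A simplified, self-contained sketch of that pattern (`Status` and this `DumpSourceLines` are placeholders, not the real CUDA or MindSpore symbols):

```cpp
// Simplified sketch of the checking-macro pattern after the change:
// the macro takes a weak_ptr and locks it only on the error path.
#include <iostream>
#include <memory>
#include <string>

struct CNode { std::string debug_info = "Range op, constructed at model.py:42"; };

// Stand-in for trace::DumpSourceLines(const std::shared_ptr<CNode>&).
std::string DumpSourceLines(const std::shared_ptr<CNode> &node) {
  return node ? " | " + node->debug_info : "";
}

enum class Status { kSuccess, kFailure };  // stand-in for cudaError_t

#define CHECK_RET_WITH_ERROR(node, expression, message)                   \
  {                                                                       \
    Status status = (expression);                                         \
    if (status != Status::kSuccess) {                                     \
      std::cerr << "Error: " << (message) << DumpSourceLines(node.lock()) \
                << std::endl;                                             \
    }                                                                     \
  }

int main() {
  auto cnode = std::make_shared<CNode>();
  std::weak_ptr<CNode> kernel_node = cnode;
  CHECK_RET_WITH_ERROR(kernel_node, Status::kFailure, "simulated failure");
  return 0;
}
```
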

@@ -177,7 +177,8 @@ bool GenSendRecvCNodesForAllReduce(const std::shared_ptr<session::KernelGraph> &
   MS_EXCEPTION_IF_NULL(*recv_node);
   cudaEvent_t event = nullptr;
-  CHECK_CUDA_RET_WITH_EXCEPT(*send_node, cudaEventCreate(&event, cudaEventDisableTiming),
+  std::weak_ptr<CNode> send_node_ = *send_node;
+  CHECK_CUDA_RET_WITH_EXCEPT(send_node_, cudaEventCreate(&event, cudaEventDisableTiming),
                              "Creating cuda event failed.");
   AnfAlgo::SetNodeAttr(kAttrRecordEvent, MakeValue(reinterpret_cast<uintptr_t>(event)), *send_node);
   AnfAlgo::SetNodeAttr(kAttrWaitEvent, MakeValue(reinterpret_cast<uintptr_t>(event)), *recv_node);
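
One call-site consequence: the macro body now invokes `node.lock()`, and a `CNodePtr` such as `*send_node` has no `lock()` member, so the caller first builds a local `std::weak_ptr<CNode>` from it. A trivial sketch of that conversion (types are illustrative):

```cpp
// Sketch: a std::weak_ptr constructs directly from a shared_ptr, which is
// all the send_node_ temporary above does (illustrative types).
#include <memory>

struct CNode {};
using CNodePtr = std::shared_ptr<CNode>;

int main() {
  CNodePtr send_node = std::make_shared<CNode>();
  std::weak_ptr<CNode> send_node_ = send_node;    // implicit shared -> weak
  return send_node_.lock() == send_node ? 0 : 1;  // lock() recovers the node
}
```
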
