!8766 fix bugs when shape is empty or equal to zero

From: @huaweib
Reviewed-by: @kisnwang,@liangchenghui,@jjfeing
Signed-off-by: @jjfeing
pull/8766/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 10a89ce17d

@ -40,6 +40,9 @@ void ReduceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
} }
shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
CheckAxis(kernel_node); CheckAxis(kernel_node);
if (shape_.empty()) {
shape_.push_back(1);
}
for (size_t i = 0; i < shape_.size(); ++i) { for (size_t i = 0; i < shape_.size(); ++i) {
if (shape_[i] <= 0) { if (shape_[i] <= 0) {
MS_LOG(EXCEPTION) << "shape value is invalid."; MS_LOG(EXCEPTION) << "shape value is invalid.";

@ -40,8 +40,10 @@ class BinaryCrossEntropyGpuKernel : public GpuKernel {
T *weight = GetDeviceAddress<T>(inputs, 2); T *weight = GetDeviceAddress<T>(inputs, 2);
T *loss = GetDeviceAddress<T>(outputs, 0); T *loss = GetDeviceAddress<T>(outputs, 0);
T *tmp_loss = GetDeviceAddress<T>(workspace, 0); T *tmp_loss = GetDeviceAddress<T>(workspace, 0);
BinaryCrossEntropyLoss(input_size_, reduction_, input_x, input_y, weight, loss, tmp_loss, if (input_size_ > 0) {
reinterpret_cast<cudaStream_t>(stream_ptr)); BinaryCrossEntropyLoss(input_size_, reduction_, input_x, input_y, weight, loss, tmp_loss,
reinterpret_cast<cudaStream_t>(stream_ptr));
}
return true; return true;
} }

@ -42,8 +42,10 @@ class BinaryCrossEntropyGradGpuKernel : public GpuKernel {
T *dloss = GetDeviceAddress<T>(inputs, 2); T *dloss = GetDeviceAddress<T>(inputs, 2);
T *weight = GetDeviceAddress<T>(inputs, 3); T *weight = GetDeviceAddress<T>(inputs, 3);
T *dx = GetDeviceAddress<T>(outputs, 0); T *dx = GetDeviceAddress<T>(outputs, 0);
BinaryCrossEntropyLossGrad(input_size_, reduction_, input_x, input_y, weight, dloss, dx, if (input_size_ > 0) {
reinterpret_cast<cudaStream_t>(stream_ptr)); BinaryCrossEntropyLossGrad(input_size_, reduction_, input_x, input_y, weight, dloss, dx,
reinterpret_cast<cudaStream_t>(stream_ptr));
}
return true; return true;
} }
@ -52,7 +54,6 @@ class BinaryCrossEntropyGradGpuKernel : public GpuKernel {
for (size_t i = 0; i < input_shape.size(); i++) { for (size_t i = 0; i < input_shape.size(); i++) {
input_size_ *= input_shape[i]; input_size_ *= input_shape[i];
} }
string reduction = GetAttr<string>(kernel_node, "reduction"); string reduction = GetAttr<string>(kernel_node, "reduction");
if (reduction == "none") { if (reduction == "none") {
reduction_ = 0; reduction_ = 0;

Loading…
Cancel
Save