Add nullptr check

update-doc-pybind
wanghaoshuang 8 years ago
parent 14fb15b685
commit 743dfd82e7

@ -68,9 +68,10 @@ class ClipOpGrad : public framework::OperatorWithKernel {
"Input(Out@GRAD) should not be null");
auto x_dims = ctx.Input<LoDTensor>("X")->dims();
auto *x_grad = ctx.Output<LoDTensor>(framework::GradVarName("X"));
if (x_grad != nullptr) {
x_grad->Resize(x_dims);
}
}
};
} // namespace operators

@ -43,6 +43,7 @@ class ClipGradientOpCUDAKernel : public framework::OpKernel {
auto min = context.Attr<float>("min");
auto* d_out = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* d_x = context.Output<LoDTensor>(framework::GradVarName("X"));
if (d_x != nullptr) {
auto* x = context.Input<LoDTensor>("X");
auto dims = d_x->dims();
int64_t count = d_out->numel();
@ -60,6 +61,7 @@ class ClipGradientOpCUDAKernel : public framework::OpKernel {
.stream()>>>(count, min, max, x_data, d_out_data,
d_x_data);
}
}
};
} // namespace operators

@ -78,6 +78,7 @@ class ClipGradKernel : public framework::OpKernel {
auto min = context.op().Attr<float>("min");
auto* d_out = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* d_x = context.Output<LoDTensor>(framework::GradVarName("X"));
if (d_x != nullptr) {
auto* x = context.Input<LoDTensor>("X");
auto dims = d_x->dims();
int64_t count = d_out->numel();
@ -92,6 +93,7 @@ class ClipGradKernel : public framework::OpKernel {
}
}
}
}
};
} // namespace operators

Loading…
Cancel
Save