Check the rank of input in kernel of set_value op (#30147)

revert-31562-mean
liym27 4 years ago committed by GitHub
parent b7335b4db7
commit 3ce878f309
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -83,7 +83,7 @@ template <typename DeviceContext, typename T>
class SetValueKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
const int rank = ctx.Output<framework::LoDTensor>("Out")->dims().size();
const int rank = ctx.Input<framework::LoDTensor>("Input")->dims().size();
// TODO(liym27): A more elegent code to do this. C++ has to make template
// integer as constant, but we had better have alternative writing in the
@ -107,6 +107,9 @@ class SetValueKernel : public framework::OpKernel<T> {
case 6:
SetValueCompute<6>(ctx);
break;
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"The rank of input should be less than 7, but received %d.", rank));
}
}

Loading…
Cancel
Save