|
|
|
@ -83,7 +83,7 @@ template <typename DeviceContext, typename T>
|
|
|
|
|
class SetValueKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const {
|
|
|
|
|
const int rank = ctx.Output<framework::LoDTensor>("Out")->dims().size();
|
|
|
|
|
const int rank = ctx.Input<framework::LoDTensor>("Input")->dims().size();
|
|
|
|
|
|
|
|
|
|
// TODO(liym27): A more elegent code to do this. C++ has to make template
|
|
|
|
|
// integer as constant, but we had better have alternative writing in the
|
|
|
|
@ -107,6 +107,9 @@ class SetValueKernel : public framework::OpKernel<T> {
|
|
|
|
|
case 6:
|
|
|
|
|
SetValueCompute<6>(ctx);
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
PADDLE_THROW(platform::errors::InvalidArgument(
|
|
|
|
|
"The rank of input should be less than 7, but received %d.", rank));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|