|
|
|
@ -148,9 +148,17 @@ class SliceOp : public framework::OperatorWithKernel {
|
|
|
|
|
protected:
|
|
|
|
|
framework::OpKernelType GetExpectedKernelType(
|
|
|
|
|
const framework::ExecutionContext &ctx) const override {
|
|
|
|
|
auto *in_var = ctx.InputVar("Input");
|
|
|
|
|
if (in_var->IsType<framework::LoDTensor>()) {
|
|
|
|
|
auto &in_tensor = in_var->Get<framework::LoDTensor>();
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
in_tensor.IsInitialized(), true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The tensor Input (Input) of Slice op is not initialized."));
|
|
|
|
|
return framework::OpKernelType(in_tensor.type(), in_tensor.place());
|
|
|
|
|
}
|
|
|
|
|
return framework::OpKernelType(
|
|
|
|
|
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
|
|
|
|
|
ctx.device_context());
|
|
|
|
|
OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace());
|
|
|
|
|
}
|
|
|
|
|
framework::OpKernelType GetKernelTypeForVar(
|
|
|
|
|
const std::string &var_name, const Tensor &tensor,
|
|
|
|
|