|
|
|
|
@ -25,7 +25,7 @@ class SliceOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Input"),
|
|
|
|
|
"Input (Input) of slice op should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
@ -58,7 +58,7 @@ class SliceOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
framework::OpKernelType GetExpectedKernelType(
|
|
|
|
|
const framework::ExecutionContext &ctx) const override {
|
|
|
|
|
const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
return framework::OpKernelType(
|
|
|
|
|
framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
|
|
|
|
|
ctx.GetPlace());
|
|
|
|
|
@ -119,15 +119,54 @@ Following examples will explain how slice works:
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class SliceOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Input"), "Input should not be null");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
|
|
|
|
|
"Input(Out@GRAD) should not be null");
|
|
|
|
|
auto x_dims = ctx->GetInputDim("Input");
|
|
|
|
|
auto x_grad_name = framework::GradVarName("Input");
|
|
|
|
|
if (ctx->HasOutput(x_grad_name)) {
|
|
|
|
|
ctx->SetOutputDim(x_grad_name, x_dims);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class SliceOpGradMaker : public framework::SingleGradOpDescMaker {
|
|
|
|
|
public:
|
|
|
|
|
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
std::unique_ptr<framework::OpDesc> Apply() const override {
|
|
|
|
|
auto* bind = new framework::OpDesc();
|
|
|
|
|
bind->SetInput("Input", Input("Input"));
|
|
|
|
|
bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
|
|
|
|
|
bind->SetOutput(framework::GradVarName("Input"), InputGrad("Input"));
|
|
|
|
|
bind->SetAttrMap(Attrs());
|
|
|
|
|
bind->SetType("slice_grad");
|
|
|
|
|
return std::unique_ptr<framework::OpDesc>(bind);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
REGISTER_OPERATOR(slice, ops::SliceOp, ops::SliceOpMaker,
|
|
|
|
|
paddle::framework::EmptyGradOpMaker);
|
|
|
|
|
ops::SliceOpGradMaker);
|
|
|
|
|
REGISTER_OPERATOR(slice_grad, ops::SliceOpGrad);
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
slice, ops::SliceKernel<paddle::platform::CPUDeviceContext, int>,
|
|
|
|
|
ops::SliceKernel<paddle::platform::CPUDeviceContext, int64_t>,
|
|
|
|
|
ops::SliceKernel<paddle::platform::CPUDeviceContext, float>,
|
|
|
|
|
ops::SliceKernel<paddle::platform::CPUDeviceContext, double>);
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
slice_grad, ops::SliceGradKernel<paddle::platform::CPUDeviceContext, int>,
|
|
|
|
|
ops::SliceGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
|
|
|
|
|
ops::SliceGradKernel<paddle::platform::CPUDeviceContext, float>,
|
|
|
|
|
ops::SliceGradKernel<paddle::platform::CPUDeviceContext, double>);
|
|
|
|
|
|