|
|
@ -44,8 +44,7 @@ class FillOp : public framework::OperatorWithKernel {
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* context) const override {
|
|
|
|
void InferShape(framework::InferShapeContext* context) const override {
|
|
|
|
PADDLE_ENFORCE_EQ(context->HasOutput("Out"), true,
|
|
|
|
OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out", "Fill");
|
|
|
|
"Output(Out) of FillOp should not be null.");
|
|
|
|
|
|
|
|
auto& shape = context->Attrs().Get<std::vector<int>>("shape");
|
|
|
|
auto& shape = context->Attrs().Get<std::vector<int>>("shape");
|
|
|
|
context->SetOutputDim("Out", framework::make_ddim(shape));
|
|
|
|
context->SetOutputDim("Out", framework::make_ddim(shape));
|
|
|
|
}
|
|
|
|
}
|
|
|
|