|
|
|
@ -24,7 +24,7 @@ class FillConstantInferShape : public framework::InferShapeBase {
|
|
|
|
|
void operator()(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
"Output(Out) of FillConstantOp should not be null.");
|
|
|
|
|
auto &shape = ctx->Attrs().Get<std::vector<int>>("shape");
|
|
|
|
|
auto &shape = ctx->Attrs().Get<std::vector<int64_t>>("shape");
|
|
|
|
|
ctx->SetOutputDim("Out", framework::make_ddim(shape));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -47,10 +47,10 @@ class FillConstantOp : public framework::OperatorBase {
|
|
|
|
|
|
|
|
|
|
if (out_var.IsType<framework::LoDTensor>()) {
|
|
|
|
|
tensor = out_var.GetMutable<framework::LoDTensor>();
|
|
|
|
|
tensor->Resize(framework::make_ddim(Attr<std::vector<int>>("shape")));
|
|
|
|
|
tensor->Resize(framework::make_ddim(Attr<std::vector<int64_t>>("shape")));
|
|
|
|
|
} else if (out_var.IsType<framework::SelectedRows>()) {
|
|
|
|
|
tensor = out_var.GetMutable<framework::SelectedRows>()->mutable_value();
|
|
|
|
|
tensor->Resize(framework::make_ddim(Attr<std::vector<int>>("shape")));
|
|
|
|
|
tensor->Resize(framework::make_ddim(Attr<std::vector<int64_t>>("shape")));
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW(
|
|
|
|
|
"fill constant op's output only"
|
|
|
|
@ -83,7 +83,8 @@ class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
"(int, default 5 (FP32)) "
|
|
|
|
|
"Output data type")
|
|
|
|
|
.SetDefault(framework::proto::VarType::FP32);
|
|
|
|
|
AddAttr<std::vector<int>>("shape", "(vector<int>) The shape of the output");
|
|
|
|
|
AddAttr<std::vector<int64_t>>("shape",
|
|
|
|
|
"(vector<int64_t>) The shape of the output");
|
|
|
|
|
AddAttr<float>("value", "(float, default 0) The value to be filled")
|
|
|
|
|
.SetDefault(0.0f);
|
|
|
|
|
AddAttr<bool>("force_cpu",
|
|
|
|
|