|
|
|
@ -24,7 +24,7 @@ class FakeInitInferShape : public framework::InferShapeBase {
|
|
|
|
|
void operator()(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
"Output(Out) of FakeInitOp should not be null.");
|
|
|
|
|
auto &shape = ctx->Attrs().Get<std::vector<int>>("shape");
|
|
|
|
|
auto &shape = ctx->Attrs().Get<std::vector<int64_t>>("shape");
|
|
|
|
|
ctx->SetOutputDim("Out", framework::make_ddim(shape));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -42,10 +42,10 @@ class FakeInitOp : public framework::OperatorBase {
|
|
|
|
|
|
|
|
|
|
if (out_var.IsType<framework::LoDTensor>()) {
|
|
|
|
|
tensor = out_var.GetMutable<framework::LoDTensor>();
|
|
|
|
|
tensor->Resize(framework::make_ddim(Attr<std::vector<int>>("shape")));
|
|
|
|
|
tensor->Resize(framework::make_ddim(Attr<std::vector<int64_t>>("shape")));
|
|
|
|
|
} else if (out_var.IsType<framework::SelectedRows>()) {
|
|
|
|
|
tensor = out_var.GetMutable<framework::SelectedRows>()->mutable_value();
|
|
|
|
|
tensor->Resize(framework::make_ddim(Attr<std::vector<int>>("shape")));
|
|
|
|
|
tensor->Resize(framework::make_ddim(Attr<std::vector<int64_t>>("shape")));
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW(
|
|
|
|
|
"fake init op's output only"
|
|
|
|
@ -63,7 +63,8 @@ class FakeInitOpVarTypeInference : public framework::VarTypeInference {
|
|
|
|
|
class FakeInitOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
void Make() override {
|
|
|
|
|
AddAttr<std::vector<int>>("shape", "(vector<int>) The shape of the output");
|
|
|
|
|
AddAttr<std::vector<int64_t>>("shape",
|
|
|
|
|
"(vector<int64_t>) The shape of the output");
|
|
|
|
|
AddOutput("Out",
|
|
|
|
|
"(Tensor) Tensor of specified shape will be filled "
|
|
|
|
|
"with the specified value");
|
|
|
|
|