|
|
|
@ -58,11 +58,11 @@ class CompileTimeInferShapeContext : public InferShapeContext {
|
|
|
|
|
PADDLE_ENFORCE_LT(j, Outputs(out).size());
|
|
|
|
|
auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
|
|
|
|
|
auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
|
|
|
|
|
if (in_var->GetType() != VarDesc::LOD_TENSOR) {
|
|
|
|
|
if (in_var->GetType() != proto::VarDesc::LOD_TENSOR) {
|
|
|
|
|
VLOG(3) << "input " << in << " is not LodTensor";
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_EQ(in_var->GetType(), VarDesc::LOD_TENSOR,
|
|
|
|
|
PADDLE_ENFORCE_EQ(in_var->GetType(), proto::VarDesc::LOD_TENSOR,
|
|
|
|
|
"The %d-th output of Output(%s) must be LoDTensor.", j,
|
|
|
|
|
out);
|
|
|
|
|
out_var->SetLoDLevel(in_var->GetLodLevel());
|
|
|
|
@ -70,7 +70,7 @@ class CompileTimeInferShapeContext : public InferShapeContext {
|
|
|
|
|
bool IsRuntime() const override;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
VarDesc::VarType GetVarType(const std::string &name) const override;
|
|
|
|
|
proto::VarDesc::VarType GetVarType(const std::string &name) const override;
|
|
|
|
|
|
|
|
|
|
DDim GetDim(const std::string &name) const override;
|
|
|
|
|
|
|
|
|
@ -90,12 +90,12 @@ OpDescBind::OpDescBind(const std::string &type, const VariableNameMap &inputs,
|
|
|
|
|
need_update_ = true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
OpDescBind::OpDescBind(const OpDesc &desc, ProgramDescBind *prog)
|
|
|
|
|
OpDescBind::OpDescBind(const proto::OpDesc &desc, ProgramDescBind *prog)
|
|
|
|
|
: desc_(desc), need_update_(false) {
|
|
|
|
|
// restore inputs_
|
|
|
|
|
int input_size = desc_.inputs_size();
|
|
|
|
|
for (int i = 0; i < input_size; ++i) {
|
|
|
|
|
const OpDesc::Var &var = desc_.inputs(i);
|
|
|
|
|
const proto::OpDesc::Var &var = desc_.inputs(i);
|
|
|
|
|
std::vector<std::string> &args = inputs_[var.parameter()];
|
|
|
|
|
int argu_size = var.arguments_size();
|
|
|
|
|
args.reserve(argu_size);
|
|
|
|
@ -106,7 +106,7 @@ OpDescBind::OpDescBind(const OpDesc &desc, ProgramDescBind *prog)
|
|
|
|
|
// restore outputs_
|
|
|
|
|
int output_size = desc_.outputs_size();
|
|
|
|
|
for (int i = 0; i < output_size; ++i) {
|
|
|
|
|
const OpDesc::Var &var = desc_.outputs(i);
|
|
|
|
|
const proto::OpDesc::Var &var = desc_.outputs(i);
|
|
|
|
|
std::vector<std::string> &args = outputs_[var.parameter()];
|
|
|
|
|
int argu_size = var.arguments_size();
|
|
|
|
|
args.reserve(argu_size);
|
|
|
|
@ -115,9 +115,9 @@ OpDescBind::OpDescBind(const OpDesc &desc, ProgramDescBind *prog)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// restore attrs_
|
|
|
|
|
for (const OpDesc::Attr &attr : desc_.attrs()) {
|
|
|
|
|
for (const proto::OpDesc::Attr &attr : desc_.attrs()) {
|
|
|
|
|
std::string attr_name = attr.name();
|
|
|
|
|
if (attr.type() != AttrType::BLOCK) {
|
|
|
|
|
if (attr.type() != proto::AttrType::BLOCK) {
|
|
|
|
|
attrs_[attr_name] = GetAttrValue(attr);
|
|
|
|
|
} else {
|
|
|
|
|
auto bid = attr.block_idx();
|
|
|
|
@ -126,7 +126,7 @@ OpDescBind::OpDescBind(const OpDesc &desc, ProgramDescBind *prog)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
OpDesc *OpDescBind::Proto() {
|
|
|
|
|
proto::OpDesc *OpDescBind::Proto() {
|
|
|
|
|
Flush();
|
|
|
|
|
return &desc_;
|
|
|
|
|
}
|
|
|
|
@ -175,10 +175,10 @@ void OpDescBind::SetOutput(const std::string ¶m_name,
|
|
|
|
|
this->outputs_[param_name] = args;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AttrType OpDescBind::GetAttrType(const std::string &name) const {
|
|
|
|
|
proto::AttrType OpDescBind::GetAttrType(const std::string &name) const {
|
|
|
|
|
auto it = attrs_.find(name);
|
|
|
|
|
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
|
|
|
|
|
return static_cast<AttrType>(it->second.which() - 1);
|
|
|
|
|
return static_cast<proto::AttrType>(it->second.which() - 1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<std::string> OpDescBind::AttrNames() const {
|
|
|
|
@ -253,8 +253,8 @@ void OpDescBind::RenameInput(const std::string &old_name,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
struct SetAttrDescVisitor : public boost::static_visitor<void> {
|
|
|
|
|
explicit SetAttrDescVisitor(OpDesc::Attr *attr) : attr_(attr) {}
|
|
|
|
|
mutable OpDesc::Attr *attr_;
|
|
|
|
|
explicit SetAttrDescVisitor(proto::OpDesc::Attr *attr) : attr_(attr) {}
|
|
|
|
|
mutable proto::OpDesc::Attr *attr_;
|
|
|
|
|
void operator()(int v) const { attr_->set_i(v); }
|
|
|
|
|
void operator()(float v) const { attr_->set_f(v); }
|
|
|
|
|
void operator()(const std::string &v) const { attr_->set_s(v); }
|
|
|
|
@ -272,7 +272,9 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> {
|
|
|
|
|
void operator()(const std::vector<bool> &v) const {
|
|
|
|
|
VectorToRepeated(v, attr_->mutable_bools());
|
|
|
|
|
}
|
|
|
|
|
void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->idx()); }
|
|
|
|
|
void operator()(proto::BlockDesc *desc) const {
|
|
|
|
|
attr_->set_block_idx(desc->idx());
|
|
|
|
|
}
|
|
|
|
|
void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); }
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -297,7 +299,7 @@ void OpDescBind::Flush() {
|
|
|
|
|
auto *attr_desc = desc_.add_attrs();
|
|
|
|
|
attr_desc->set_name(attr.first);
|
|
|
|
|
attr_desc->set_type(
|
|
|
|
|
static_cast<framework::AttrType>(attr.second.which() - 1));
|
|
|
|
|
static_cast<proto::AttrType>(attr.second.which() - 1));
|
|
|
|
|
SetAttrDescVisitor visitor(attr_desc);
|
|
|
|
|
boost::apply_visitor(visitor, attr.second);
|
|
|
|
|
}
|
|
|
|
@ -375,7 +377,7 @@ void OpDescBind::InferVarType(BlockDescBind *block) const {
|
|
|
|
|
for (auto &out_pair : this->outputs_) {
|
|
|
|
|
for (auto &out_var_name : out_pair.second) {
|
|
|
|
|
block->FindRecursiveOrCreateVar(out_var_name)
|
|
|
|
|
->SetType(VarDesc::LOD_TENSOR);
|
|
|
|
|
->SetType(proto::VarDesc::LOD_TENSOR);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -484,7 +486,7 @@ void CompileTimeInferShapeContext::SetDim(const std::string &name,
|
|
|
|
|
}
|
|
|
|
|
bool CompileTimeInferShapeContext::IsRuntime() const { return false; }
|
|
|
|
|
|
|
|
|
|
VarDesc::VarType CompileTimeInferShapeContext::GetVarType(
|
|
|
|
|
proto::VarDesc::VarType CompileTimeInferShapeContext::GetVarType(
|
|
|
|
|
const std::string &name) const {
|
|
|
|
|
return block_.FindVarRecursive(name)->GetType();
|
|
|
|
|
}
|
|
|
|
|