|
|
@ -42,13 +42,25 @@ class OpBase {
|
|
|
|
|
|
|
|
|
|
|
|
~OpBase() { VLOG(3) << "Destruct Op: " << Type(); }
|
|
|
|
~OpBase() { VLOG(3) << "Destruct Op: " << Type(); }
|
|
|
|
|
|
|
|
|
|
|
|
const std::string& Type() const { return op_->Type(); }
|
|
|
|
const std::string& Type() const {
|
|
|
|
|
|
|
|
return op_ ? op_->Type() : UnknownOpType();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
const framework::AttributeMap& Attrs() const { return attrs_; }
|
|
|
|
const framework::AttributeMap& Attrs() const { return attrs_; }
|
|
|
|
|
|
|
|
|
|
|
|
const framework::OpInfo& Info() const { return op_->Info(); }
|
|
|
|
const framework::OpInfo& Info() const {
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(op_, platform::errors::PreconditionNotMet(
|
|
|
|
|
|
|
|
"OpBase::Info() should be called after "
|
|
|
|
|
|
|
|
"OpBase::SetType() is called"));
|
|
|
|
|
|
|
|
return op_->Info();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
const framework::OperatorBase& InnerOp() const { return *op_; }
|
|
|
|
const framework::OperatorBase& InnerOp() const {
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(op_, platform::errors::PreconditionNotMet(
|
|
|
|
|
|
|
|
"OpBase::InnerOp() should be called after "
|
|
|
|
|
|
|
|
"OpBase::SetType() is called"));
|
|
|
|
|
|
|
|
return *op_;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void ClearBackwardTrace();
|
|
|
|
void ClearBackwardTrace();
|
|
|
|
|
|
|
|
|
|
|
@ -63,7 +75,7 @@ class OpBase {
|
|
|
|
void SetType(const std::string& type);
|
|
|
|
void SetType(const std::string& type);
|
|
|
|
|
|
|
|
|
|
|
|
void CheckAttrs() {
|
|
|
|
void CheckAttrs() {
|
|
|
|
auto& info = op_->Info();
|
|
|
|
auto& info = Info();
|
|
|
|
if (info.Checker() != nullptr) {
|
|
|
|
if (info.Checker() != nullptr) {
|
|
|
|
info.Checker()->Check(&attrs_, true);
|
|
|
|
info.Checker()->Check(&attrs_, true);
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -150,6 +162,12 @@ class OpBase {
|
|
|
|
const framework::AttributeMap& attrs,
|
|
|
|
const framework::AttributeMap& attrs,
|
|
|
|
const platform::Place& place);
|
|
|
|
const platform::Place& place);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
|
|
|
static const std::string& UnknownOpType() {
|
|
|
|
|
|
|
|
static std::string kUnknownOpType{"unknown"};
|
|
|
|
|
|
|
|
return kUnknownOpType;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
private:
|
|
|
|
NameVarMap<VariableWrapper> ins_;
|
|
|
|
NameVarMap<VariableWrapper> ins_;
|
|
|
|
NameVarMap<VariableWrapper> outs_;
|
|
|
|
NameVarMap<VariableWrapper> outs_;
|
|
|
|