Add Init to OperatorBase (#2838)

cblas_new
Qiao Longfei 8 years ago committed by GitHub
parent 90cf44d79a
commit 728665d709

@ -119,6 +119,7 @@ class OpRegistry {
op->attrs_[attr.name()] = AttrTypeHelper::GetAttrValue(attr);
}
op_checkers().at(op_type).Check(op->attrs_);
op->Init();
return op;
}

@ -49,6 +49,10 @@ class OperatorBase {
std::string DebugString() const;
/// Init will be called after CreateOperator, you can put some initialization
/// logic here.
virtual void Init() {}
/// InferShape infer the size of Variables used by this Operator with
/// information inside scope
virtual void InferShape(const std::shared_ptr<Scope>& scope) const = 0;

@ -21,14 +21,19 @@ namespace framework {
class OperatorTest : public OperatorBase {
public:
void Init() override { x = 1; }
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override {
float scale = GetAttr<float>("scale");
ASSERT_NEAR(scale, 3.14, 1e-5);
ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr);
ASSERT_EQ(x, 1);
ASSERT_NE(scope->GetVariable(outputs_[0]), nullptr);
}
public:
float x = 0;
};
class OperatorTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {

Loading…
Cancel
Save