|
|
|
|
@ -21,14 +21,19 @@ namespace framework {
|
|
|
|
|
|
|
|
|
|
class OperatorTest : public OperatorBase {
|
|
|
|
|
public:
|
|
|
|
|
void Init() override { x = 1; }
|
|
|
|
|
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
|
|
|
|
|
void Run(const std::shared_ptr<Scope>& scope,
|
|
|
|
|
const platform::DeviceContext& dev_ctx) const override {
|
|
|
|
|
float scale = GetAttr<float>("scale");
|
|
|
|
|
ASSERT_NEAR(scale, 3.14, 1e-5);
|
|
|
|
|
ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr);
|
|
|
|
|
ASSERT_EQ(x, 1);
|
|
|
|
|
ASSERT_NE(scope->GetVariable(outputs_[0]), nullptr);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
float x = 0;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class OperatorTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
|
|
|
|
|
|