Merge pull request #3430 from wangkuiyi/add_operatorbase_constructors

Add constructors to OperatorBase and all sub-classes
revert-3824-remove_grad_op_type
Yi Wang 8 years ago committed by GitHub
commit 38f4b1d59e

@ -30,6 +30,8 @@ using DeviceContext = platform::DeviceContext;
class EmptyOp : public OperatorBase { class EmptyOp : public OperatorBase {
public: public:
DEFINE_OPERATOR_CTOR(EmptyOp, OperatorBase)
void InferShape(const Scope &scope) const override {} void InferShape(const Scope &scope) const override {}
void Run(const Scope &scope, const DeviceContext &dev_ctx) const override {} void Run(const Scope &scope, const DeviceContext &dev_ctx) const override {}
}; };

@ -10,6 +10,8 @@ namespace framework {
class NOP : public OperatorBase { class NOP : public OperatorBase {
public: public:
DEFINE_OPERATOR_CTOR(NOP, OperatorBase)
void InferShape(const Scope &scope) const override {} void InferShape(const Scope &scope) const override {}
void Run(const Scope &scope, void Run(const Scope &scope,
const platform::DeviceContext &dev_ctx) const override {} const platform::DeviceContext &dev_ctx) const override {}

@ -7,6 +7,8 @@ namespace paddle {
namespace framework { namespace framework {
class CosineOp : public OperatorBase { class CosineOp : public OperatorBase {
public: public:
DEFINE_OPERATOR_CTOR(CosineOp, OperatorBase)
void Run(const Scope& scope, void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {} const platform::DeviceContext& dev_ctx) const override {}
void InferShape(const Scope& scope) const override {} void InferShape(const Scope& scope) const override {}
@ -27,6 +29,8 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
class MyTestOp : public OperatorBase { class MyTestOp : public OperatorBase {
public: public:
DEFINE_OPERATOR_CTOR(MyTestOp, OperatorBase)
void InferShape(const Scope& scope) const override {} void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope, void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {} const platform::DeviceContext& dev_ctx) const override {}

@ -63,6 +63,17 @@ class ExecutionContext;
*/ */
class OperatorBase { class OperatorBase {
public: public:
OperatorBase() {} // TODO(yi): This constructor is to be removed.
OperatorBase(const std::string& type, const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs,
const AttributeMap& attrs,
std::unordered_map<std::string, int>* in_out_idxs)
: type_(type),
inputs_(inputs),
outputs_(outputs),
attrs_(attrs),
in_out_idxs_(in_out_idxs) {}
virtual ~OperatorBase() {} virtual ~OperatorBase() {}
template <typename T> template <typename T>
@ -109,6 +120,9 @@ class OperatorBase {
const std::vector<std::string> Inputs() const { return inputs_; } const std::vector<std::string> Inputs() const { return inputs_; }
const std::vector<std::string> Outputs() const { return outputs_; } const std::vector<std::string> Outputs() const { return outputs_; }
const AttributeMap& Attrs() const { return attrs_; } const AttributeMap& Attrs() const { return attrs_; }
const std::unordered_map<std::string, int>* InOutIdx() const {
return in_out_idxs_.get();
}
public: public:
std::string type_; std::string type_;
@ -286,6 +300,14 @@ class OpKernel {
class OperatorWithKernel : public OperatorBase { class OperatorWithKernel : public OperatorBase {
public: public:
OperatorWithKernel() {} // TODO(yi): This constructor is to be removed.
OperatorWithKernel(const std::string& type,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs,
const AttributeMap& attrs,
std::unordered_map<std::string, int>* in_out_idxs)
: OperatorBase(type, inputs, outputs, attrs, in_out_idxs) {}
struct OpKernelKey { struct OpKernelKey {
platform::Place place_; platform::Place place_;
@ -335,5 +357,15 @@ class OperatorWithKernel : public OperatorBase {
virtual void InferShape(const InferShapeContext& ctx) const = 0; virtual void InferShape(const InferShapeContext& ctx) const = 0;
}; };
#define DEFINE_OPERATOR_CTOR(Class, ParentClass) \
public: \
Class() { /* TODO(yi): This constructor is to be removed. */ \
} \
Class(const std::string& type, const std::vector<std::string>& inputs, \
const std::vector<std::string>& outputs, \
const ::paddle::framework::AttributeMap& attrs, \
std::unordered_map<std::string, int>* in_out_idxs) \
: ParentClass(type, inputs, outputs, attrs, in_out_idxs) {}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle

@ -23,6 +23,8 @@ static int op_run_num = 0;
class OpWithoutKernelTest : public OperatorBase { class OpWithoutKernelTest : public OperatorBase {
public: public:
DEFINE_OPERATOR_CTOR(OpWithoutKernelTest, OperatorBase)
void Init() override { x = 1; } void Init() override { x = 1; }
void InferShape(const Scope& scope) const override {} void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope, void Run(const Scope& scope,
@ -97,6 +99,8 @@ class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
static int cpu_kernel_run_num = 0; static int cpu_kernel_run_num = 0;
class OpWithKernelTest : public OperatorWithKernel { class OpWithKernelTest : public OperatorWithKernel {
public:
DEFINE_OPERATOR_CTOR(OpWithKernelTest, OperatorWithKernel)
protected: protected:
void InferShape(const framework::InferShapeContext& ctx) const override {} void InferShape(const framework::InferShapeContext& ctx) const override {}
}; };
@ -116,6 +120,8 @@ class CPUKernelTest : public OpKernel {
// multiple inputs test // multiple inputs test
class OperatorMultiInputsTest : public OperatorBase { class OperatorMultiInputsTest : public OperatorBase {
public: public:
DEFINE_OPERATOR_CTOR(OperatorMultiInputsTest, OperatorBase)
void Init() override { x = 1; } void Init() override { x = 1; }
void InferShape(const Scope& scope) const override {} void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope, void Run(const Scope& scope,

@ -18,6 +18,7 @@ namespace paddle {
namespace operators { namespace operators {
class AddOp : public framework::OperatorWithKernel { class AddOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(AddOp, framework::OperatorWithKernel)
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.InputSize(), 2); PADDLE_ENFORCE_EQ(ctx.InputSize(), 2);
@ -47,6 +48,7 @@ The equation is: Out = X + Y
}; };
class AddOpGrad : public framework::OperatorWithKernel { class AddOpGrad : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(AddOpGrad, framework::OperatorWithKernel)
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override {} void InferShape(const framework::InferShapeContext &ctx) const override {}
}; };

@ -18,6 +18,7 @@ namespace paddle {
namespace operators { namespace operators {
class OnehotCrossEntropyOp : public framework::OperatorWithKernel { class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(OnehotCrossEntropyOp, framework::OperatorWithKernel)
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.InputSize(), 2, PADDLE_ENFORCE_EQ(ctx.InputSize(), 2,
@ -38,6 +39,8 @@ class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
}; };
class OnehotCrossEntropyGradientOp : public framework::OperatorWithKernel { class OnehotCrossEntropyGradientOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(OnehotCrossEntropyGradientOp,
framework::OperatorWithKernel)
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
auto X_grad = ctx.Output<Tensor>(framework::GradVarName("X")); auto X_grad = ctx.Output<Tensor>(framework::GradVarName("X"));

@ -18,6 +18,7 @@ namespace paddle {
namespace operators { namespace operators {
class FillZerosLikeOp : public framework::OperatorWithKernel { class FillZerosLikeOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(FillZerosLikeOp, framework::OperatorWithKernel)
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.InputSize(), 1UL, PADDLE_ENFORCE_EQ(ctx.InputSize(), 1UL,

@ -43,6 +43,7 @@ class GaussianRandomKernel : public framework::OpKernel {
}; };
class GaussianRandomOp : public framework::OperatorWithKernel { class GaussianRandomOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(GaussianRandomOp, framework::OperatorWithKernel)
protected: protected:
void InferShape(const framework::InferShapeContext& context) const override { void InferShape(const framework::InferShapeContext& context) const override {
auto* tensor = context.Output<framework::Tensor>(0); auto* tensor = context.Output<framework::Tensor>(0);

@ -18,6 +18,7 @@ namespace paddle {
namespace operators { namespace operators {
class MeanOp : public framework::OperatorWithKernel { class MeanOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(MeanOp, framework::OperatorWithKernel)
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.InputSize(), 1, "Input size of AddOp must be one"); PADDLE_ENFORCE_EQ(ctx.InputSize(), 1, "Input size of AddOp must be one");
@ -39,6 +40,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
}; };
class MeanGradOp : public framework::OperatorWithKernel { class MeanGradOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(MeanGradOp, framework::OperatorWithKernel)
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<Tensor>(framework::GradVarName("X")) ctx.Output<Tensor>(framework::GradVarName("X"))

@ -18,6 +18,7 @@ namespace paddle {
namespace operators { namespace operators {
class MulOp : public framework::OperatorWithKernel { class MulOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(MulOp, framework::OperatorWithKernel)
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 2, "The mul op must take two inputs"); PADDLE_ENFORCE(ctx.InputSize() == 2, "The mul op must take two inputs");
@ -53,6 +54,7 @@ The equation is: Out = X * Y
}; };
class MulOpGrad : public framework::OperatorWithKernel { class MulOpGrad : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(MulOpGrad, framework::OperatorWithKernel)
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override {} void InferShape(const framework::InferShapeContext &ctx) const override {}
std::string DebugString() const override { std::string DebugString() const override {

@ -35,6 +35,8 @@ namespace operators {
*/ */
class NetOp : public framework::OperatorBase { class NetOp : public framework::OperatorBase {
public: public:
DEFINE_OPERATOR_CTOR(NetOp, framework::OperatorBase)
/** /**
* Infer all the operators' input and output variables' shapes, will be called * Infer all the operators' input and output variables' shapes, will be called
* before every mini-batch * before every mini-batch

@ -12,6 +12,8 @@ static int run_cnt = 0;
class TestOp : public framework::OperatorBase { class TestOp : public framework::OperatorBase {
public: public:
DEFINE_OPERATOR_CTOR(TestOp, framework::OperatorBase)
void InferShape(const Scope& scope) const override { ++infer_shape_cnt; } void InferShape(const Scope& scope) const override { ++infer_shape_cnt; }
void Run(const Scope& scope, void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override { const platform::DeviceContext& dev_ctx) const override {
@ -21,6 +23,8 @@ class TestOp : public framework::OperatorBase {
class EmptyOp : public framework::OperatorBase { class EmptyOp : public framework::OperatorBase {
public: public:
DEFINE_OPERATOR_CTOR(EmptyOp, framework::OperatorBase)
void InferShape(const Scope& scope) const override {} void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope, const DeviceContext& dev_ctx) const override {} void Run(const Scope& scope, const DeviceContext& dev_ctx) const override {}
}; };

@ -100,6 +100,7 @@ class RecurrentGradientAlgorithm {
}; };
class RecurrentOp final : public framework::OperatorBase { class RecurrentOp final : public framework::OperatorBase {
DEFINE_OPERATOR_CTOR(RecurrentOp, framework::OperatorBase)
public: public:
void Init() override; void Init() override;

@ -18,6 +18,7 @@ namespace paddle {
namespace operators { namespace operators {
class RowWiseAddOp : public framework::OperatorWithKernel { class RowWiseAddOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(RowWiseAddOp, framework::OperatorWithKernel)
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 2UL, PADDLE_ENFORCE(ctx.InputSize() == 2UL,

@ -18,6 +18,7 @@ namespace paddle {
namespace operators { namespace operators {
class SGDOp : public framework::OperatorWithKernel { class SGDOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(SGDOp, framework::OperatorWithKernel)
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.InputSize(), 2, "Input size of SGDOp must be two"); PADDLE_ENFORCE_EQ(ctx.InputSize(), 2, "Input size of SGDOp must be two");

@ -18,6 +18,7 @@ namespace paddle {
namespace operators { namespace operators {
class SigmoidOp : public framework::OperatorWithKernel { class SigmoidOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(SigmoidOp, framework::OperatorWithKernel)
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 1, "Sigmoid Op only have one input"); PADDLE_ENFORCE(ctx.InputSize() == 1, "Sigmoid Op only have one input");
@ -38,6 +39,7 @@ class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
}; };
class SigmoidOpGrad : public framework::OperatorWithKernel { class SigmoidOpGrad : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(SigmoidOpGrad, framework::OperatorWithKernel)
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims()); ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());

@ -18,6 +18,7 @@ namespace paddle {
namespace operators { namespace operators {
class SoftmaxOp : public framework::OperatorWithKernel { class SoftmaxOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(SoftmaxOp, framework::OperatorWithKernel)
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.InputSize(), 1UL, PADDLE_ENFORCE_EQ(ctx.InputSize(), 1UL,
@ -42,6 +43,7 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
}; };
class SoftmaxOpGrad : public framework::OperatorWithKernel { class SoftmaxOpGrad : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(SoftmaxOpGrad, framework::OperatorWithKernel)
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.InputSize(), 3UL, PADDLE_ENFORCE_EQ(ctx.InputSize(), 3UL,

@ -46,6 +46,7 @@ class CPUUniformRandomKernel : public framework::OpKernel {
}; };
class UniformRandomOp : public framework::OperatorWithKernel { class UniformRandomOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(UniformRandomOp, framework::OperatorWithKernel)
protected: protected:
void InferShape(const framework::InferShapeContext& ctx) const override { void InferShape(const framework::InferShapeContext& ctx) const override {
PADDLE_ENFORCE(GetAttr<float>("min") < GetAttr<float>("max"), PADDLE_ENFORCE(GetAttr<float>("min") < GetAttr<float>("max"),

Loading…
Cancel
Save