Merge pull request #3300 from dzhwinter/type_alias

"remove type alias"
fixstartbug
Yi Wang 8 years ago committed by GitHub
commit 18cf078660

@ -13,6 +13,7 @@
limitations under the License. */
#include "paddle/framework/backward.h"
#include <list>
#include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h"

@ -17,16 +17,21 @@
#include <gtest/gtest.h>
#include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h"
#include "paddle/operators/type_alias.h"
namespace paddle {
namespace framework {
using OperatorBase = framework::OperatorBase;
using OpProtoAndCheckerMaker = framework::OpProtoAndCheckerMaker;
using OpProto = framework::OpProto;
using OpAttrChecker = framework::OpAttrChecker;
using Scope = framework::Scope;
using DeviceContext = platform::DeviceContext;
class EmptyOp : public OperatorBase {
public:
void InferShape(const Scope &scope) const override {}
void Run(const Scope &scope,
const platform::DeviceContext &dev_ctx) const override {}
void Run(const Scope &scope, const DeviceContext &dev_ctx) const override {}
};
class RowWiseAddOpMaker : public OpProtoAndCheckerMaker {
@ -71,7 +76,7 @@ class NoGradOpMaker : public OpProtoAndCheckerMaker {
}
};
class FcOp : public ops::NetOp {
class FcOp : public operators::NetOp {
public:
void Init() override {
AddOp(OpRegistry::CreateOp("mul", {Input("X"), Input("W")},
@ -143,6 +148,7 @@ class AddOpMaker : public OpProtoAndCheckerMaker {
} // namespace paddle
namespace f = paddle::framework;
namespace ops = paddle::operators;
using EnforceNotMet = paddle::platform::EnforceNotMet;
REGISTER_OP(rowwise_add, f::EmptyOp, f::RowWiseAddOpMaker);
REGISTER_GRADIENT_OP(rowwise_add, rowwise_add_grad, f::EmptyOp);

@ -18,11 +18,8 @@ limitations under the License. */
#include "paddle/framework/backward.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/tensor_py.h"
#include "paddle/operators/net_op.h"
#include "paddle/operators/type_alias.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/place.h"
#include "pybind11/numpy.h"
@ -45,6 +42,9 @@ USE_OP_WITHOUT_KERNEL(recurrent_op);
USE_OP(uniform_random);
namespace paddle {
namespace framework {
using Tensor = framework::Tensor;
template <typename ClassType>
void ExposeOperator(ClassType &m) {
m.def("infer_shape", &ClassType::type::InferShape)
@ -150,8 +150,8 @@ All parameter, weight, gradient are variables in Paddle.
[](Variable &self) -> Tensor * { return self.GetMutable<Tensor>(); },
py::return_value_policy::reference)
.def("get_net",
[](Variable &self) -> ops::NetOp * {
return self.GetMutable<ops::NetOp>();
[](Variable &self) -> operators::NetOp * {
return self.GetMutable<operators::NetOp>();
},
py::return_value_policy::reference);
@ -230,23 +230,24 @@ All parameter, weight, gradient are variables in Paddle.
ExposeOperator(operator_base);
py::class_<ops::NetOp, std::shared_ptr<ops::NetOp>> net(m, "Net");
py::class_<operators::NetOp, std::shared_ptr<operators::NetOp>> net(m, "Net");
net.def_static("create",
[]() -> std::shared_ptr<ops::NetOp> {
auto retv = std::make_shared<ops::NetOp>();
[]() -> std::shared_ptr<operators::NetOp> {
auto retv = std::make_shared<operators::NetOp>();
retv->type_ = "plain_net";
return retv;
})
.def("add_op", &ops::NetOp::AddOp)
.def(
"add_op",
[](ops::NetOp &self, const std::shared_ptr<ops::NetOp> &net) -> void {
self.AddOp(std::static_pointer_cast<OperatorBase>(net));
})
.def("complete_add_op", &ops::NetOp::CompleteAddOp)
.def("complete_add_op",
[](std::shared_ptr<ops::NetOp> &self) { self->CompleteAddOp(); });
.def("add_op", &operators::NetOp::AddOp)
.def("add_op",
[](operators::NetOp &self,
const std::shared_ptr<operators::NetOp> &net) -> void {
self.AddOp(std::static_pointer_cast<OperatorBase>(net));
})
.def("complete_add_op", &operators::NetOp::CompleteAddOp)
.def("complete_add_op", [](std::shared_ptr<operators::NetOp> &self) {
self->CompleteAddOp();
});
ExposeOperator(net);

@ -59,6 +59,7 @@ op_library(cross_entropy_op SRCS cross_entropy_op.cc cross_entropy_op.cu)
op_library(fill_zeros_like_op SRCS fill_zeros_like_op.cc fill_zeros_like_op.cu)
op_library(sgd_op SRCS sgd_op.cc sgd_op.cu)
cc_test(sgd_op_test SRCS sgd_op_test.cc DEPS sgd_op)
op_library(fc_op
SRCS fc_op.cc

@ -17,9 +17,9 @@ limitations under the License. */
namespace paddle {
namespace operators {
class AddOp : public OperatorWithKernel {
class AddOp : public framework::OperatorWithKernel {
protected:
void InferShape(const InferShapeContext &ctx) const override {
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.InputSize(), 2);
PADDLE_ENFORCE_EQ(ctx.OutputSize(), 1);
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(0), "Inputs of AddOp must all be set");
@ -31,9 +31,9 @@ class AddOp : public OperatorWithKernel {
}
};
class AddOpMaker : public OpProtoAndCheckerMaker {
class AddOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
AddOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of add op");
AddInput("Y", "The second input of add op");
@ -46,14 +46,17 @@ The equation is: Out = X + Y
}
};
class AddOpGrad : public OperatorWithKernel {
class AddOpGrad : public framework::OperatorWithKernel {
protected:
void InferShape(const InferShapeContext &ctx) const override {}
void InferShape(const framework::InferShapeContext &ctx) const override {}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(add_two, ops::AddOp, ops::AddOpMaker);
REGISTER_GRADIENT_OP(add_two, add_two_grad, ops::AddOpGrad);
REGISTER_OP_CPU_KERNEL(add_two, ops::AddKernel<ops::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(add_two,
ops::AddKernel<paddle::platform::CPUPlace, float>);

@ -16,4 +16,6 @@
#include "paddle/framework/op_registry.h"
#include "paddle/operators/add_op.h"
REGISTER_OP_GPU_KERNEL(add_two, ops::AddKernel<ops::GPUPlace, float>);
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(add_two,
ops::AddKernel<paddle::platform::GPUPlace, float>);

@ -13,15 +13,21 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/operators/type_alias.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T>
class AddKernel : public OpKernel {
class AddKernel : public framework::OpKernel {
public:
void Compute(const ExecutionContext& context) const override {
void Compute(const framework::ExecutionContext& context) const override {
auto input0 = context.Input<Tensor>(0);
auto input1 = context.Input<Tensor>(1);
auto output = context.Output<Tensor>(0);

@ -14,9 +14,9 @@ limitations under the License. */
#include <gtest/gtest.h>
#define private public
#include <paddle/framework/op_registry.h>
#include "paddle/framework/op_registry.h"
USE_OP(add_two);
// USE_OP(add_two_grad);
TEST(AddOp, GetOpProto) {
auto& protos = paddle::framework::OpRegistry::protos();

@ -17,9 +17,9 @@ limitations under the License. */
namespace paddle {
namespace operators {
class OnehotCrossEntropyOp : public OperatorWithKernel {
class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
protected:
void InferShape(const InferShapeContext &ctx) const override {
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.InputSize(), 2,
"Input size of OnehotCrossEntropyOp must be two");
PADDLE_ENFORCE_EQ(ctx.OutputSize(), 1,
@ -37,9 +37,9 @@ class OnehotCrossEntropyOp : public OperatorWithKernel {
}
};
class OnehotCrossEntropyGradientOp : public OperatorWithKernel {
class OnehotCrossEntropyGradientOp : public framework::OperatorWithKernel {
protected:
void InferShape(const InferShapeContext &ctx) const override {
void InferShape(const framework::InferShapeContext &ctx) const override {
auto X_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto X = ctx.Input<Tensor>("X");
@ -48,9 +48,10 @@ class OnehotCrossEntropyGradientOp : public OperatorWithKernel {
}
};
class OnehotCrossEntropyOpMaker : public OpProtoAndCheckerMaker {
class OnehotCrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
public:
OnehotCrossEntropyOpMaker(OpProto *proto, OpAttrChecker *op_checker)
OnehotCrossEntropyOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of OnehotCrossEntropyOp");
AddInput("label", "The second input of OnehotCrossEntropyOp");
@ -66,12 +67,14 @@ OnehotCrossEntropy Operator.
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(onehot_cross_entropy, ops::OnehotCrossEntropyOp,
ops::OnehotCrossEntropyOpMaker);
REGISTER_OP_CPU_KERNEL(onehot_cross_entropy,
ops::OnehotCrossEntropyOpKernel<ops::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
onehot_cross_entropy,
ops::OnehotCrossEntropyOpKernel<paddle::platform::CPUPlace, float>);
REGISTER_GRADIENT_OP(onehot_cross_entropy, onehot_cross_entropy_grad,
ops::OnehotCrossEntropyGradientOp);
REGISTER_OP_CPU_KERNEL(
onehot_cross_entropy_grad,
ops::OnehotCrossEntropyGradientOpKernel<ops::CPUPlace, float>);
ops::OnehotCrossEntropyGradientOpKernel<paddle::platform::CPUPlace, float>);

@ -14,3 +14,8 @@
#define EIGEN_USE_GPU
#include "paddle/operators/cross_entropy_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
onehot_cross_entropy,
ops::OnehotCrossEntropyOpKernel<paddle::platform::GPUPlace, float>);

@ -13,11 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/operators/type_alias.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
T tolerable_value(T x) {
static_assert(std::is_floating_point<T>::value,
@ -38,9 +40,9 @@ T tolerable_value(T x) {
}
template <typename Place, typename T>
class OnehotCrossEntropyOpKernel : public OpKernel {
class OnehotCrossEntropyOpKernel : public framework::OpKernel {
public:
void Compute(const ExecutionContext& ctx) const override {
void Compute(const framework::ExecutionContext& ctx) const override {
auto X = ctx.Input<Tensor>("X");
const T* Xdata = X->data<T>();
const int* label_data = ctx.Input<Tensor>(1)->data<int>();
@ -61,9 +63,9 @@ class OnehotCrossEntropyOpKernel : public OpKernel {
};
template <typename Place, typename T>
class OnehotCrossEntropyGradientOpKernel : public OpKernel {
class OnehotCrossEntropyGradientOpKernel : public framework::OpKernel {
public:
void Compute(const ExecutionContext& ctx) const override {
void Compute(const framework::ExecutionContext& ctx) const override {
auto X = ctx.Input<Tensor>("X");
auto dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dY = ctx.Input<Tensor>(framework::GradVarName("Y"));

@ -12,11 +12,16 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include "type_alias.h"
#include "paddle/operators/net_op.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using OpRegistry = framework::OpRegistry;
class FullyConnectedOp : public NetOp {
public:
void Init() override {
@ -39,9 +44,10 @@ class FullyConnectedOp : public NetOp {
}
};
class FullyConnectedOpMaker : public OpProtoAndCheckerMaker {
class FullyConnectedOpMaker : public framework::OpProtoAndCheckerMaker {
public:
FullyConnectedOpMaker(OpProto *proto, OpAttrChecker *op_checker)
FullyConnectedOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "the input of fc operator");
AddInput("W", "the weight of fc operator");
@ -66,4 +72,5 @@ USE_OP(rowwise_add);
USE_OP(sigmoid);
USE_OP(softmax);
namespace ops = paddle::operators;
REGISTER_OP(fc, ops::FullyConnectedOp, ops::FullyConnectedOpMaker);

@ -50,8 +50,8 @@ The output will have the same size with input.
} // namespace operators
} // namespace paddle
REGISTER_OP(fill_zeros_like, paddle::operators::FillZerosLikeOp,
paddle::operators::FillZerosLikeOpMaker);
namespace ops = paddle::operators;
REGISTER_OP(fill_zeros_like, ops::FillZerosLikeOp, ops::FillZerosLikeOpMaker);
REGISTER_OP_CPU_KERNEL(
fill_zeros_like,
paddle::operators::FillZerosLikeKernel<paddle::platform::CPUPlace, float>);
ops::FillZerosLikeKernel<paddle::platform::CPUPlace, float>);

@ -16,6 +16,7 @@
#include "paddle/framework/op_registry.h"
#include "paddle/operators/fill_zeros_like_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
fill_zeros_like,
paddle::operators::FillZerosLikeKernel<paddle::platform::GPUPlace, float>);
ops::FillZerosLikeKernel<paddle::platform::GPUPlace, float>);

@ -13,7 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/operators/type_alias.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {

@ -17,9 +17,9 @@ limitations under the License. */
namespace paddle {
namespace operators {
class MeanOp : public OperatorWithKernel {
class MeanOp : public framework::OperatorWithKernel {
protected:
void InferShape(const InferShapeContext &ctx) const override {
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.InputSize(), 1, "Input size of AddOp must be one");
PADDLE_ENFORCE_EQ(ctx.OutputSize(), 1, "Output size of AddOp must be one");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(0), "input should be set");
@ -28,9 +28,9 @@ class MeanOp : public OperatorWithKernel {
}
};
class MeanOpMaker : public OpProtoAndCheckerMaker {
class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MeanOpMaker(OpProto *proto, OpAttrChecker *op_checker)
MeanOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of mean op");
AddOutput("Out", "The output of mean op").IgnoreGradient();
@ -38,9 +38,9 @@ class MeanOpMaker : public OpProtoAndCheckerMaker {
}
};
class MeanGradOp : public OperatorWithKernel {
class MeanGradOp : public framework::OperatorWithKernel {
protected:
void InferShape(const InferShapeContext &ctx) const override {
void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<Tensor>("X" + framework::kGradVarSuffix)
->Resize(ctx.Input<Tensor>("X")->dims());
}
@ -49,7 +49,10 @@ class MeanGradOp : public OperatorWithKernel {
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker);
REGISTER_OP_CPU_KERNEL(mean, ops::MeanKernel<ops::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(mean,
ops::MeanKernel<paddle::platform::CPUPlace, float>);
REGISTER_GRADIENT_OP(mean, mean_grad, ops::MeanGradOp);
REGISTER_OP_CPU_KERNEL(mean_grad, ops::MeanGradKernel<ops::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(mean_grad,
ops::MeanGradKernel<paddle::platform::CPUPlace, float>);

@ -16,5 +16,8 @@
#include "paddle/operators/mean_op.h"
REGISTER_OP_GPU_KERNEL(mean, ops::MeanKernel<ops::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(mean_grad, ops::MeanGradKernel<ops::GPUPlace, float>);
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(mean,
ops::MeanKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(mean_grad,
ops::MeanGradKernel<paddle::platform::GPUPlace, float>);

@ -13,15 +13,24 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/operators/type_alias.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenScalar = framework::EigenScalar<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T>
class MeanKernel : public OpKernel {
class MeanKernel : public framework::OpKernel {
public:
void Compute(const ExecutionContext& context) const override {
void Compute(const framework::ExecutionContext& context) const override {
auto input = context.Input<Tensor>(0);
auto output = context.Output<Tensor>(0);
@ -36,9 +45,9 @@ class MeanKernel : public OpKernel {
};
template <typename Place, typename T>
class MeanGradKernel : public OpKernel {
class MeanGradKernel : public framework::OpKernel {
public:
void Compute(const ExecutionContext& context) const override {
void Compute(const framework::ExecutionContext& context) const override {
auto OG = context.Input<Tensor>("Out" + framework::kGradVarSuffix);
PADDLE_ENFORCE(framework::product(OG->dims()) == 1,
"Mean Gradient should be scalar");

@ -17,9 +17,9 @@
namespace paddle {
namespace operators {
class MulOp : public OperatorWithKernel {
class MulOp : public framework::OperatorWithKernel {
protected:
void InferShape(const InferShapeContext &ctx) const override {
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 2, "The mul op must take two inputs");
auto dim0 = ctx.Input<Tensor>(0)->dims();
auto dim1 = ctx.Input<Tensor>(1)->dims();
@ -37,9 +37,9 @@ class MulOp : public OperatorWithKernel {
}
};
class MulOpMaker : public OpProtoAndCheckerMaker {
class MulOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MulOpMaker(OpProto *proto, OpAttrChecker *op_checker)
MulOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of mul op");
AddInput("Y", "The second input of mul op");
@ -52,9 +52,9 @@ The equation is: Out = X * Y
}
};
class MulOpGrad : public OperatorWithKernel {
class MulOpGrad : public framework::OperatorWithKernel {
protected:
void InferShape(const InferShapeContext &ctx) const override {}
void InferShape(const framework::InferShapeContext &ctx) const override {}
std::string DebugString() const override {
LOG(INFO) << "MulGrad";
return "";
@ -64,7 +64,8 @@ class MulOpGrad : public OperatorWithKernel {
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker);
REGISTER_GRADIENT_OP(mul, mul_grad, ops::MulOpGrad);
REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel<ops::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel<paddle::platform::CPUPlace, float>);

@ -15,4 +15,6 @@
#define EIGEN_USE_GPU
#include "paddle/operators/mul_op.h"
REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel<ops::GPUPlace, float>);
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel<paddle::platform::GPUPlace, float>);

@ -13,16 +13,21 @@
limitations under the License. */
#pragma once
#include "paddle/operators/type_alias.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T>
class MulKernel : public OpKernel {
class MulKernel : public framework::OpKernel {
public:
void Compute(const ExecutionContext& context) const override {
void Compute(const framework::ExecutionContext& context) const override {
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = {
{Eigen::IndexPair<Eigen::DenseIndex>(1, 0)}};
@ -40,5 +45,6 @@ class MulKernel : public OpKernel {
Z.device(place) = X.contract(Y, dim_pair);
}
};
} // namespace operators
} // namespace paddle

@ -15,7 +15,6 @@
*/
#include "paddle/operators/net_op.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {

@ -14,13 +14,7 @@ limitations under the License. */
#pragma once
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/scope.h"
#include "paddle/operators/type_alias.h"
#include "paddle/platform/device_context.h"
namespace paddle {
namespace operators {

@ -2,31 +2,27 @@
#include <gtest/gtest.h>
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
namespace paddle {
namespace operators {
using Scope = framework::Scope;
using DeviceContext = platform::DeviceContext;
static int infer_shape_cnt = 0;
static int run_cnt = 0;
class TestOp : public OperatorBase {
class TestOp : public framework::OperatorBase {
public:
void InferShape(const framework::Scope& scope) const override {
++infer_shape_cnt;
}
void Run(const framework::Scope& scope,
const paddle::platform::DeviceContext& dev_ctx) const override {
void InferShape(const Scope& scope) const override { ++infer_shape_cnt; }
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {
++run_cnt;
}
};
class EmptyOp : public OperatorBase {
class EmptyOp : public framework::OperatorBase {
public:
void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {}
void Run(const Scope& scope, const DeviceContext& dev_ctx) const override {}
};
template <typename T>
@ -72,7 +68,7 @@ TEST(OpKernel, all) {
net->Run(scope, dev_ctx);
ASSERT_EQ(2, infer_shape_cnt);
ASSERT_EQ(2, run_cnt);
ASSERT_THROW(net->AddOp(op2), paddle::platform::EnforceNotMet);
ASSERT_THROW(net->AddOp(op2), platform::EnforceNotMet);
}
TEST(NetOp, insert_op) {

@ -14,17 +14,19 @@
#include "paddle/operators/recurrent_op.h"
#include <glog/logging.h>
#include <cstring>
#include <sstream>
#include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h"
#include "paddle/platform/enforce.h"
namespace paddle {
namespace operators {
using Scope = framework::Scope;
using Variable = framework::Variable;
using Tensor = framework::Tensor;
void RecurrentAlgorithm::InferShape(const Scope& scope) const {
seq_len_ = scope.FindVar((arg_->inlinks[0]).external)
->GetMutable<Tensor>()
@ -135,10 +137,11 @@ void RecurrentOp::Init() {
alg_.Init(std::move(arg));
}
class RecurrentAlgorithmProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
class RecurrentAlgorithmProtoAndCheckerMaker
: public framework::OpProtoAndCheckerMaker {
public:
RecurrentAlgorithmProtoAndCheckerMaker(OpProto* proto,
OpAttrChecker* op_checker)
RecurrentAlgorithmProtoAndCheckerMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
const auto& name = RecurrentOp::kArgName;
// inputs and outputs stored in proto

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save