Rnn make stepnet member (#3469)

* make stepnet member

* add pybind support

* fix Inputs Outputs

* remove unique_ptr
Author: Yan Chunwei (committed via GitHub)
commit 0079fa3256, parent 80de7e5ede

@@ -20,6 +20,7 @@ limitations under the License. */
 #include "paddle/framework/op_registry.h"
 #include "paddle/framework/tensor_py.h"
 #include "paddle/operators/net_op.h"
+#include "paddle/operators/recurrent_op.h"
 #include "paddle/platform/enforce.h"
 #include "paddle/platform/place.h"
 #include "paddle/string/to_string.h"
@@ -241,6 +242,11 @@ All parameter, weight, gradient are variables in Paddle.
              const std::shared_ptr<operators::NetOp> &net) -> void {
             self.AddOp(std::static_pointer_cast<OperatorBase>(net));
           })
+      .def("add_op",
+           [](operators::NetOp &self,
+              const std::shared_ptr<operators::RecurrentOp> &rnn) -> void {
+             self.AddOp(std::static_pointer_cast<OperatorBase>(rnn));
+           })
       .def("complete_add_op", &operators::NetOp::CompleteAddOp)
       .def("complete_add_op", [](std::shared_ptr<operators::NetOp> &self) {
         self->CompleteAddOp();
@@ -248,6 +254,29 @@ All parameter, weight, gradient are variables in Paddle.
   ExposeOperator(net);
 
+  // recurrent_op
+  py::class_<operators::RecurrentOp, std::shared_ptr<operators::RecurrentOp>>
+      rnn(m, "RecurrentOp");
+  rnn.def_static(
+         "create",
+         [](py::bytes protobin) -> std::shared_ptr<operators::RecurrentOp> {
+           OpDesc desc;
+           PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
+                          "Cannot parse user input to OpDesc");
+           PADDLE_ENFORCE(desc.IsInitialized(),
+                          "User OpDesc is not initialized, reason %s",
+                          desc.InitializationErrorString());
+           auto rnn_op = OpRegistry::CreateOp(desc);
+           return std::dynamic_pointer_cast<operators::RecurrentOp>(rnn_op);
+         })
+      .def("set_stepnet",
+           [](operators::RecurrentOp &self,
+              const std::shared_ptr<operators::NetOp> &net) -> void {
+             self.set_stepnet(net);
+           });
+  ExposeOperator(rnn);
+
   m.def("unique_integer", UniqueIntegerGenerator);
   m.def("is_compile_gpu", IsCompileGPU);

@@ -66,6 +66,5 @@ op_library(sgd_op SRCS sgd_op.cc sgd_op.cu)
 op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
         DEPS framework_proto tensor op_registry operator net_op)
-cc_test(recurrent_op_test SRCS recurrent_op_test.cc DEPS recurrent_op gtest mul_op add_op)
 op_library(uniform_random_op
         SRCS uniform_random_op.cc uniform_random_op.cu)

@@ -36,15 +36,13 @@ void RecurrentAlgorithm::InferShape(const Scope& scope) const {
   rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
                      true /*infer_shape_mode*/);
   InitMemories(step_scopes[0], true /*infer_shape_mode*/);
-  Variable* net = scope.FindVar(arg_->step_net);
-  PADDLE_ENFORCE(net != nullptr, "failed to get step net");
   for (size_t i = 0; i < seq_len_; i++) {
     if (i > 0) {
       rnn::LinkMemories(step_scopes, arg_->memories, i, -1,
                         true /*infer_shape_mode*/);
     }
-    net->GetMutable<NetOp>()->InferShape(*step_scopes[i]);
+    (*stepnet_)->InferShape(*step_scopes[i]);
   }
   rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
                      true /*infer_shape_mode*/);
@@ -56,7 +54,6 @@ void RecurrentAlgorithm::Run(const Scope& scope,
   rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
                      false /*infer_shape_mode*/);
   InitMemories(step_scopes[0], false /*infer_shape_mode*/);
-  Variable* net = scope.FindVar(arg_->step_net);
 
   for (size_t step_id = 0; step_id < seq_len_; step_id++) {
     // create output alias variables
@@ -64,7 +61,7 @@ void RecurrentAlgorithm::Run(const Scope& scope,
       rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1,
                         false /*infer_shape_mode*/);
     }
-    net->GetMutable<NetOp>()->Run(*step_scopes[step_id], dev_ctx);
+    (*stepnet_)->Run(*step_scopes[step_id], dev_ctx);
   }
   rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
                      false /*infer_shape_mode*/);
@@ -78,18 +75,16 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
   auto step_scopes = step_scopes_var->GetMutable<std::vector<Scope*>>();
 
   // Now all variables in scope must be created outside of op.
-  auto net_var = scope.FindVar(arg_->step_net);
-  PADDLE_ENFORCE(net_var != nullptr, "no stepnet called %s in scope",
-                 arg_->step_net);
-  auto net_op = net_var->GetMutable<NetOp>();
-  PADDLE_ENFORCE(!net_op->Outputs().empty(), "net_op has no outputs");
+  PADDLE_ENFORCE_NOT_NULL(stepnet_);
+  PADDLE_ENFORCE(!(*stepnet_)->Outputs().empty(), "stepnet_ op has no outputs");
+  PADDLE_ENFORCE(!(*stepnet_)->Outputs().empty(), "net_op has no outputs");
 
   if (seq_len_ > step_scopes->size()) {
     for (size_t i = step_scopes->size(); i < seq_len_; ++i) {
       auto& step_scope = scope.NewScope();
 
       // create step net's temp inputs
-      for (auto& input : net_op->Inputs()) {
+      for (auto& input : (*stepnet_)->Inputs()) {
         // the weight are located in parent scope
         for (auto& var_name : input.second) {
           if (!step_scope.FindVar(var_name)) {
@@ -98,7 +93,7 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
         }
       }
       // create stepnet's outputs
-      for (const auto& output : net_op->Outputs()) {
+      for (const auto& output : (*stepnet_)->Outputs()) {
         for (auto& var_name : output.second) {
           step_scope.NewVar(var_name);
         }
@@ -140,9 +135,8 @@ RecurrentOp::RecurrentOp(const std::string& type,
                          const framework::OperatorBase::VarNameMap& outputs,
                          const framework::AttributeMap& attrs)
     : OperatorBase(type, inputs, outputs, attrs) {
-  std::unique_ptr<rnn::Argument> arg(new rnn::Argument());
-  rnn::InitArgument(kArgName, arg.get(), *this);
-  alg_.Init(std::move(arg));
+  rnn::InitArgument(kArgName, &arg_, *this);
+  alg_.Init(&arg_, &stepnet_);
 }
 
 class RecurrentAlgorithmProtoAndCheckerMaker
@@ -158,7 +152,6 @@ class RecurrentAlgorithmProtoAndCheckerMaker
         .AsDuplicable();
     AddInput(name.boot_memories, "variables to initialize memories.")
         .AsDuplicable();
-    AddInput(name.step_net, "network shared by all steps.");
 
     AddOutput(name.outlinks, "the outputs that need to concated for all steps.")
         .AsDuplicable();
@@ -180,14 +173,12 @@ void RecurrentGradientAlgorithm::Run(
   auto step_scopes = GetStepScopes(scope);
   rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
                      false /*infer_shape_mode*/);
-  Variable* net = scope.FindVar(arg_->step_net);
-  PADDLE_ENFORCE(net != nullptr, "failed to get step net");
   for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) {
     if (static_cast<size_t>(step_id) != seq_len_ - 1) {
       rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1,
                         false /*infer_shape_mode*/);
     }
-    net->GetMutable<NetOp>()->Run(*step_scopes[step_id], dev_ctx);
+    (*stepnet_)->Run(*step_scopes[step_id], dev_ctx);
   }
   LinkBootMemoryGradients(step_scopes[0], false);
   rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
@@ -219,14 +210,12 @@ void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const {
   auto step_scopes = GetStepScopes(scope);
   rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
                      true /*infer_shape_mode*/);
-  Variable* net = scope.FindVar(arg_->step_net);
-  PADDLE_ENFORCE(net != nullptr, "failed to get step net");
   for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) {
     if (static_cast<size_t>(step_id) != seq_len_ - 1) {
       rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1,
                         true /*infer_shape_mode*/);
     }
-    net->GetMutable<NetOp>()->InferShape(*step_scopes[step_id]);
+    (*stepnet_)->InferShape(*step_scopes[step_id]);
   }
   rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
                      true /*infer_shape_mode*/);
@@ -238,9 +227,8 @@ RecurrentGradientOp::RecurrentGradientOp(
     const framework::OperatorBase::VarNameMap& outputs,
     const framework::AttributeMap& attrs)
     : OperatorBase(type, inputs, outputs, attrs) {
-  std::unique_ptr<rnn::Argument> arg(new rnn::Argument());
-  rnn::InitArgument(kArgName, arg.get(), *this);
-  alg_.Init(std::move(arg));
+  rnn::InitArgument(kArgName, &arg_, *this);
+  alg_.Init(&arg_, &stepnet_);
 }
 
 } // namespace operators

@@ -15,6 +15,7 @@
 #pragma once
 
 #include "paddle/framework/operator.h"
+#include "paddle/operators/net_op.h"
 #include "paddle/operators/rnn/recurrent_op_utils.h"
 
 namespace paddle {
@@ -33,7 +34,11 @@ class RecurrentAlgorithm {
   void Run(const framework::Scope& scope,
            const platform::DeviceContext& dev_ctx) const;
 
-  void Init(std::unique_ptr<rnn::Argument> arg) { arg_ = std::move(arg); }
+  void Init(rnn::Argument* arg, std::shared_ptr<NetOp>* stepnet) {
+    PADDLE_ENFORCE_NOT_NULL(stepnet, "stepnet should be set before.");
+    arg_ = arg;
+    stepnet_ = stepnet;
+  }
 
   /**
    * InferShape must be called before Run.
@@ -58,7 +63,8 @@ class RecurrentAlgorithm {
   void InitMemories(framework::Scope* step_scopes, bool infer_shape_mode) const;
 
  private:
-  std::unique_ptr<rnn::Argument> arg_;
+  std::shared_ptr<NetOp>* stepnet_;
+  rnn::Argument* arg_;
   mutable size_t seq_len_;
 };
@@ -74,7 +80,11 @@ class RecurrentGradientAlgorithm {
    * operator.
    */
  public:
-  void Init(std::unique_ptr<rnn::Argument> arg) { arg_ = std::move(arg); }
+  void Init(rnn::Argument* arg, std::shared_ptr<NetOp>* stepnet) {
+    PADDLE_ENFORCE_NOT_NULL(stepnet, "stepnet should be set before.");
+    arg_ = std::move(arg);
+    stepnet_ = stepnet;
+  }
 
   void Run(const framework::Scope& scope,
            const platform::DeviceContext& dev_ctx) const;
@@ -95,8 +105,9 @@ class RecurrentGradientAlgorithm {
   }
 
  private:
-  std::unique_ptr<rnn::Argument> arg_;
+  rnn::Argument* arg_;
   mutable size_t seq_len_;
+  std::shared_ptr<NetOp>* stepnet_;
 };
 
 class RecurrentOp final : public framework::OperatorBase {
@@ -115,10 +126,15 @@ class RecurrentOp final : public framework::OperatorBase {
     alg_.Run(scope, dev_ctx);
   }
 
+  void set_stepnet(std::shared_ptr<NetOp> net) { stepnet_ = net; }
+  const NetOp* stepnet() const { return stepnet_.get(); }
+
   static const rnn::ArgumentName kArgName;
 
  private:
   RecurrentAlgorithm alg_;
+  rnn::Argument arg_;
+  std::shared_ptr<NetOp> stepnet_;
 };
 
 class RecurrentGradientOp final : public framework::OperatorBase {
@@ -141,8 +157,13 @@ class RecurrentGradientOp final : public framework::OperatorBase {
 
   static const rnn::ArgumentName kArgName;
 
+  void set_stepnet(const std::shared_ptr<NetOp>& net) { stepnet_ = net; }
+  const NetOp* stepnet() const { return stepnet_.get(); }
+
  private:
   RecurrentGradientAlgorithm alg_;
+  std::shared_ptr<NetOp> stepnet_;
+  rnn::Argument arg_;
 };
 
 } // namespace operators

File diff suppressed because it is too large.

@@ -106,7 +106,6 @@ void LinkMemories(const std::vector<Scope*>& scopes,
 void InitArgument(const ArgumentName& name, Argument* arg,
                   const framework::OperatorBase& op) {
-  arg->step_net = op.Input(name.step_net);
   arg->step_scopes = op.Output(name.step_scopes);
 
   auto inlinks = op.Inputs(name.inlinks);

@@ -177,4 +177,26 @@ class OperatorFactory(object):
         return self.get_op_info(type).attrs
 
+
+class __RecurrentOp__(object):
+    __proto__ = None
+    type = 'recurrent_op'
+
+    def __init__(self):
+        # cache recurrent_op's proto
+        if self.__proto__ is None:
+            for op_proto in get_all_op_protos():
+                if op_proto.type == self.type:
+                    self.__proto__ = op_proto
+
+    def __call__(self, *args, **kwargs):
+        if self.type not in args and 'type' not in kwargs:
+            kwargs['type'] = self.type
+        # create proto
+        create_method = OpDescCreationMethod(self.__proto__)
+        proto = create_method(*args, **kwargs)
+        # create rnnop
+        return core.RecurrentOp.create(proto.SerializeToString())
+
+
 Operator = OperatorFactory()  # Default global factory
+RecurrentOp = __RecurrentOp__()
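The RecurrentOp factory is callable like the generic Operator factory but always builds a recurrent_op and returns the dedicated core.RecurrentOp binding, so set_stepnet is available on the result. A minimal sketch of the intended call pattern, with argument values taken from the updated TestRecurrentOp test further down (some arguments are elided here):

    import paddle.v2.framework.core as core
    from paddle.v2.framework.op import Operator, RecurrentOp

    # Build the recurrent op; keyword arguments mirror the recurrent_op proto.
    rnn_op = RecurrentOp(
        inlinks=["x"],
        boot_memories=["h_boot"],
        # ... remaining link/alias arguments as in the unit test ...
        outlink_alias=["h@alias"],
        pre_memories=["h@pre"],
        memories=["h@alias"])

    # The step network is an ordinary NetOp built in Python ...
    stepnet = core.Net.create()
    stepnet.add_op(Operator("mul", X="x@alias", Y="W", Out="Wx"))
    stepnet.add_op(Operator("mul", X="h@pre", Y="U", Out="Uh"))
    stepnet.add_op(Operator("add_two", X="Wx", Y="Uh", Out="sum"))
    # (the unit test also adds a sigmoid op that writes h@alias)
    stepnet.complete_add_op(True)

    # ... and is attached directly, replacing the old scope variable named "stepnet".
    rnn_op.set_stepnet(stepnet)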

@@ -2,7 +2,7 @@ import logging
 import paddle.v2.framework.core as core
 import unittest
 import numpy as np
-from paddle.v2.framework.op import Operator
+from paddle.v2.framework.op import Operator, RecurrentOp
 
 
 def py_sigmoid(x):
@@ -98,11 +98,11 @@ class TestRecurrentOp(unittest.TestCase):
     def forward(self):
         self.scope = core.Scope()
         self.create_global_variables()
+        self.create_rnn_op()
         self.create_step_net()
-        rnn_op = self.create_rnn_op()
         ctx = core.DeviceContext.create(core.CPUPlace())
-        rnn_op.infer_shape(self.scope)
-        rnn_op.run(self.scope, ctx)
+        self.rnnop.infer_shape(self.scope)
+        self.rnnop.run(self.scope, ctx)
         return np.array(self.scope.find_var("h").get_tensor())
 
     def create_global_variables(self):
@@ -128,8 +128,7 @@ class TestRecurrentOp(unittest.TestCase):
     def create_rnn_op(self):
         # create RNNOp
-        rnnop = Operator(
-            "recurrent_op",
+        self.rnnop = RecurrentOp(
             # inputs
             inlinks=["x"],
             boot_memories=["h_boot"],
@@ -142,14 +141,9 @@ class TestRecurrentOp(unittest.TestCase):
             outlink_alias=["h@alias"],
             pre_memories=["h@pre"],
             memories=["h@alias"])
-        return rnnop
 
     def create_step_net(self):
-        var = self.scope.new_var("stepnet")
-        stepnet = var.get_net()
-        # x_fc_op = Operator("fc", X="x@alias", W="W", Y="Wx")
-        # h_fc_op = Operator("fc", X="h@pre", W="U", Y="Uh")
+        stepnet = core.Net.create()
         x_fc_op = Operator("mul", X="x@alias", Y="W", Out="Wx")
         h_fc_op = Operator("mul", X="h@pre", Y="U", Out="Uh")
         sum_op = Operator("add_two", X="Wx", Y="Uh", Out="sum")
@@ -158,6 +152,7 @@ class TestRecurrentOp(unittest.TestCase):
         for op in [x_fc_op, h_fc_op, sum_op, sig_op]:
             stepnet.add_op(op)
         stepnet.complete_add_op(True)
+        self.rnnop.set_stepnet(stepnet)
 
     def test_forward(self):
         print 'test recurrent op forward'
