Rnn make stepnet member (#3469)

* make stepnet member

* add pybind support

* fix Inputs Outputs

* remove unique_ptr
Yan Chunwei committed via GitHub (parent 80de7e5ede, commit 0079fa3256)
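The gist of the change, seen from the Python API: the RNN's step network is no longer a variable fetched from the scope via the `step_net` input name, but a `NetOp` member of `RecurrentOp` set explicitly. A before/after sketch distilled from the test diff at the bottom (`rnn_op` stands for a `RecurrentOp` built as in the test):

```python
import paddle.v2.framework.core as core

# Before: the step net lived in the scope, and the op looked it up
# at run time through its "step_net" input:
#   var = self.scope.new_var("stepnet")
#   stepnet = var.get_net()

# After: the step net is an ordinary NetOp attached as a member.
stepnet = core.Net.create()
# ... add the per-step operators to stepnet here ...
stepnet.complete_add_op(True)
rnn_op.set_stepnet(stepnet)  # new binding added in pybind.cc below
```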

@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/framework/op_registry.h"
#include "paddle/framework/tensor_py.h"
#include "paddle/operators/net_op.h"
#include "paddle/operators/recurrent_op.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/place.h"
#include "paddle/string/to_string.h"
@@ -241,6 +242,11 @@ All parameter, weight, gradient are variables in Paddle.
const std::shared_ptr<operators::NetOp> &net) -> void {
self.AddOp(std::static_pointer_cast<OperatorBase>(net));
})
.def("add_op",
[](operators::NetOp &self,
const std::shared_ptr<operators::RecurrentOp> &rnn) -> void {
self.AddOp(std::static_pointer_cast<OperatorBase>(rnn));
})
.def("complete_add_op", &operators::NetOp::CompleteAddOp)
.def("complete_add_op", [](std::shared_ptr<operators::NetOp> &self) {
self->CompleteAddOp();
@@ -248,6 +254,29 @@ All parameter, weight, gradient are variables in Paddle.
ExposeOperator(net);
// recurrent_op
py::class_<operators::RecurrentOp, std::shared_ptr<operators::RecurrentOp>>
rnn(m, "RecurrentOp");
rnn.def_static(
"create",
[](py::bytes protobin) -> std::shared_ptr<operators::RecurrentOp> {
OpDesc desc;
PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
"Cannot parse user input to OpDesc");
PADDLE_ENFORCE(desc.IsInitialized(),
"User OpDesc is not initialized, reason %s",
desc.InitializationErrorString());
auto rnn_op = OpRegistry::CreateOp(desc);
return std::dynamic_pointer_cast<operators::RecurrentOp>(rnn_op);
})
.def("set_stepnet",
[](operators::RecurrentOp &self,
const std::shared_ptr<operators::NetOp> &net) -> void {
self.set_stepnet(net);
});
ExposeOperator(rnn);
m.def("unique_integer", UniqueIntegerGenerator);
m.def("is_compile_gpu", IsCompileGPU);

@@ -66,6 +66,5 @@ op_library(sgd_op SRCS sgd_op.cc sgd_op.cu)
op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
DEPS framework_proto tensor op_registry operator net_op)
cc_test(recurrent_op_test SRCS recurrent_op_test.cc DEPS recurrent_op gtest mul_op add_op)
op_library(uniform_random_op
SRCS uniform_random_op.cc uniform_random_op.cu)

@@ -36,15 +36,13 @@ void RecurrentAlgorithm::InferShape(const Scope& scope) const {
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
true /*infer_shape_mode*/);
InitMemories(step_scopes[0], true /*infer_shape_mode*/);
Variable* net = scope.FindVar(arg_->step_net);
PADDLE_ENFORCE(net != nullptr, "failed to get step net");
for (size_t i = 0; i < seq_len_; i++) {
if (i > 0) {
rnn::LinkMemories(step_scopes, arg_->memories, i, -1,
true /*infer_shape_mode*/);
}
net->GetMutable<NetOp>()->InferShape(*step_scopes[i]);
(*stepnet_)->InferShape(*step_scopes[i]);
}
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
true /*infer_shape_mode*/);
@@ -56,7 +54,6 @@ void RecurrentAlgorithm::Run(const Scope& scope,
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
false /*infer_shape_mode*/);
InitMemories(step_scopes[0], false /*infer_shape_mode*/);
Variable* net = scope.FindVar(arg_->step_net);
for (size_t step_id = 0; step_id < seq_len_; step_id++) {
// create output alias variables
@@ -64,7 +61,7 @@ void RecurrentAlgorithm::Run(const Scope& scope,
rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1,
false /*infer_shape_mode*/);
}
net->GetMutable<NetOp>()->Run(*step_scopes[step_id], dev_ctx);
(*stepnet_)->Run(*step_scopes[step_id], dev_ctx);
}
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
false /*infer_shape_mode*/);
@@ -78,18 +75,16 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
auto step_scopes = step_scopes_var->GetMutable<std::vector<Scope*>>();
// Now all variables in scope must be created outside of op.
auto net_var = scope.FindVar(arg_->step_net);
PADDLE_ENFORCE(net_var != nullptr, "no stepnet called %s in scope",
arg_->step_net);
auto net_op = net_var->GetMutable<NetOp>();
PADDLE_ENFORCE(!net_op->Outputs().empty(), "net_op has no outputs");
PADDLE_ENFORCE_NOT_NULL(stepnet_);
PADDLE_ENFORCE(!(*stepnet_)->Outputs().empty(), "stepnet_ op has no outputs");
if (seq_len_ > step_scopes->size()) {
for (size_t i = step_scopes->size(); i < seq_len_; ++i) {
auto& step_scope = scope.NewScope();
// create step net's temp inputs
for (auto& input : net_op->Inputs()) {
for (auto& input : (*stepnet_)->Inputs()) {
// the weights are located in the parent scope
for (auto& var_name : input.second) {
if (!step_scope.FindVar(var_name)) {
@@ -98,7 +93,7 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
}
}
// create stepnet's outputs
for (const auto& output : net_op->Outputs()) {
for (const auto& output : (*stepnet_)->Outputs()) {
for (auto& var_name : output.second) {
step_scope.NewVar(var_name);
}
@@ -140,9 +135,8 @@ RecurrentOp::RecurrentOp(const std::string& type,
const framework::OperatorBase::VarNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {
std::unique_ptr<rnn::Argument> arg(new rnn::Argument());
rnn::InitArgument(kArgName, arg.get(), *this);
alg_.Init(std::move(arg));
rnn::InitArgument(kArgName, &arg_, *this);
alg_.Init(&arg_, &stepnet_);
}
class RecurrentAlgorithmProtoAndCheckerMaker
@@ -158,7 +152,6 @@ class RecurrentAlgorithmProtoAndCheckerMaker
.AsDuplicable();
AddInput(name.boot_memories, "variables to initialize memories.")
.AsDuplicable();
AddInput(name.step_net, "network shared by all steps.");
AddOutput(name.outlinks, "the outputs that need to be concatenated for all steps.")
.AsDuplicable();
@@ -180,14 +173,12 @@ void RecurrentGradientAlgorithm::Run(
auto step_scopes = GetStepScopes(scope);
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
false /*infer_shape_mode*/);
Variable* net = scope.FindVar(arg_->step_net);
PADDLE_ENFORCE(net != nullptr, "failed to get step net");
for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) {
if (static_cast<size_t>(step_id) != seq_len_ - 1) {
rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1,
false /*infer_shape_mode*/);
}
net->GetMutable<NetOp>()->Run(*step_scopes[step_id], dev_ctx);
(*stepnet_)->Run(*step_scopes[step_id], dev_ctx);
}
LinkBootMemoryGradients(step_scopes[0], false);
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
@@ -219,14 +210,12 @@ void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const {
auto step_scopes = GetStepScopes(scope);
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
true /*infer_shape_mode*/);
Variable* net = scope.FindVar(arg_->step_net);
PADDLE_ENFORCE(net != nullptr, "failed to get step net");
for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) {
if (static_cast<size_t>(step_id) != seq_len_ - 1) {
rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1,
true /*infer_shape_mode*/);
}
net->GetMutable<NetOp>()->InferShape(*step_scopes[step_id]);
(*stepnet_)->InferShape(*step_scopes[step_id]);
}
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
true /*infer_shape_mode*/);
@@ -238,9 +227,8 @@ RecurrentGradientOp::RecurrentGradientOp(
const framework::OperatorBase::VarNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {
std::unique_ptr<rnn::Argument> arg(new rnn::Argument());
rnn::InitArgument(kArgName, arg.get(), *this);
alg_.Init(std::move(arg));
rnn::InitArgument(kArgName, &arg_, *this);
alg_.Init(&arg_, &stepnet_);
}
} // namespace operators

@@ -15,6 +15,7 @@
#pragma once
#include "paddle/framework/operator.h"
#include "paddle/operators/net_op.h"
#include "paddle/operators/rnn/recurrent_op_utils.h"
namespace paddle {
@@ -33,7 +34,11 @@ class RecurrentAlgorithm {
void Run(const framework::Scope& scope,
const platform::DeviceContext& dev_ctx) const;
void Init(std::unique_ptr<rnn::Argument> arg) { arg_ = std::move(arg); }
void Init(rnn::Argument* arg, std::shared_ptr<NetOp>* stepnet) {
PADDLE_ENFORCE_NOT_NULL(stepnet, "stepnet should be set before.");
arg_ = arg;
stepnet_ = stepnet;
}
/**
* InferShape must be called before Run.
@@ -58,7 +63,8 @@ class RecurrentAlgorithm {
void InitMemories(framework::Scope* step_scopes, bool infer_shape_mode) const;
private:
std::unique_ptr<rnn::Argument> arg_;
std::shared_ptr<NetOp>* stepnet_;
rnn::Argument* arg_;
mutable size_t seq_len_;
};
@@ -74,7 +80,11 @@ class RecurrentGradientAlgorithm {
* operator.
*/
public:
void Init(std::unique_ptr<rnn::Argument> arg) { arg_ = std::move(arg); }
void Init(rnn::Argument* arg, std::shared_ptr<NetOp>* stepnet) {
PADDLE_ENFORCE_NOT_NULL(stepnet, "stepnet should be set before.");
arg_ = std::move(arg);
stepnet_ = stepnet;
}
void Run(const framework::Scope& scope,
const platform::DeviceContext& dev_ctx) const;
@@ -95,8 +105,9 @@ class RecurrentGradientAlgorithm {
}
private:
std::unique_ptr<rnn::Argument> arg_;
rnn::Argument* arg_;
mutable size_t seq_len_;
std::shared_ptr<NetOp>* stepnet_;
};
class RecurrentOp final : public framework::OperatorBase {
@@ -115,10 +126,15 @@ class RecurrentOp final : public framework::OperatorBase {
alg_.Run(scope, dev_ctx);
}
void set_stepnet(std::shared_ptr<NetOp> net) { stepnet_ = net; }
const NetOp* stepnet() const { return stepnet_.get(); }
static const rnn::ArgumentName kArgName;
private:
RecurrentAlgorithm alg_;
rnn::Argument arg_;
std::shared_ptr<NetOp> stepnet_;
};
class RecurrentGradientOp final : public framework::OperatorBase {
@@ -141,8 +157,13 @@ class RecurrentGradientOp final : public framework::OperatorBase {
static const rnn::ArgumentName kArgName;
void set_stepnet(const std::shared_ptr<NetOp>& net) { stepnet_ = net; }
const NetOp* stepnet() const { return stepnet_.get(); }
private:
RecurrentGradientAlgorithm alg_;
std::shared_ptr<NetOp> stepnet_;
rnn::Argument arg_;
};
} // namespace operators

File diff suppressed because it is too large

@@ -106,7 +106,6 @@ void LinkMemories(const std::vector<Scope*>& scopes,
void InitArgument(const ArgumentName& name, Argument* arg,
const framework::OperatorBase& op) {
arg->step_net = op.Input(name.step_net);
arg->step_scopes = op.Output(name.step_scopes);
auto inlinks = op.Inputs(name.inlinks);

@@ -23,7 +23,7 @@ class OpDescCreationMethod(object):
"""
A functor object that converts user input (keyword arguments) into an OpDesc,
based on the OpProto.
:param op_proto: The OpProto object.
:type op_proto: op_proto_pb2.OpProto
"""
@@ -177,4 +177,26 @@ class OperatorFactory(object):
return self.get_op_info(type).attrs
class __RecurrentOp__(object):
__proto__ = None
type = 'recurrent_op'
def __init__(self):
# cache recurrent_op's proto
if self.__proto__ is None:
for op_proto in get_all_op_protos():
if op_proto.type == self.type:
self.__proto__ = op_proto
def __call__(self, *args, **kwargs):
if self.type not in args and 'type' not in kwargs:
kwargs['type'] = self.type
# create proto
create_method = OpDescCreationMethod(self.__proto__)
proto = create_method(*args, **kwargs)
# create rnnop
return core.RecurrentOp.create(proto.SerializeToString())
Operator = OperatorFactory() # Default global factory
RecurrentOp = __RecurrentOp__()
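Usage then reads like a normal operator call: keyword arguments are converted to an OpDesc against the cached proto and handed to `core.RecurrentOp.create` as a serialized string. Only the arguments visible in the test diff are shown below; the test passes the full set of link and scope names:

```python
rnn_op = RecurrentOp(
    inlinks=["x"],
    boot_memories=["h_boot"],
    outlink_alias=["h@alias"],
    pre_memories=["h@pre"],
    memories=["h@alias"])  # plus the remaining link/scope arguments from the test
```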

@@ -2,7 +2,7 @@ import logging
import paddle.v2.framework.core as core
import unittest
import numpy as np
from paddle.v2.framework.op import Operator
from paddle.v2.framework.op import Operator, RecurrentOp
def py_sigmoid(x):
@@ -98,11 +98,11 @@ class TestRecurrentOp(unittest.TestCase):
def forward(self):
self.scope = core.Scope()
self.create_global_variables()
self.create_rnn_op()
self.create_step_net()
rnn_op = self.create_rnn_op()
ctx = core.DeviceContext.create(core.CPUPlace())
rnn_op.infer_shape(self.scope)
rnn_op.run(self.scope, ctx)
self.rnnop.infer_shape(self.scope)
self.rnnop.run(self.scope, ctx)
return np.array(self.scope.find_var("h").get_tensor())
def create_global_variables(self):
@@ -128,8 +128,7 @@ class TestRecurrentOp(unittest.TestCase):
def create_rnn_op(self):
# create RNNOp
rnnop = Operator(
"recurrent_op",
self.rnnop = RecurrentOp(
# inputs
inlinks=["x"],
boot_memories=["h_boot"],
@@ -142,14 +141,9 @@ class TestRecurrentOp(unittest.TestCase):
outlink_alias=["h@alias"],
pre_memories=["h@pre"],
memories=["h@alias"])
return rnnop
def create_step_net(self):
var = self.scope.new_var("stepnet")
stepnet = var.get_net()
# x_fc_op = Operator("fc", X="x@alias", W="W", Y="Wx")
# h_fc_op = Operator("fc", X="h@pre", W="U", Y="Uh")
stepnet = core.Net.create()
x_fc_op = Operator("mul", X="x@alias", Y="W", Out="Wx")
h_fc_op = Operator("mul", X="h@pre", Y="U", Out="Uh")
sum_op = Operator("add_two", X="Wx", Y="Uh", Out="sum")
@@ -158,6 +152,7 @@ class TestRecurrentOp(unittest.TestCase):
for op in [x_fc_op, h_fc_op, sum_op, sig_op]:
stepnet.add_op(op)
stepnet.complete_add_op(True)
self.rnnop.set_stepnet(stepnet)
def test_forward(self):
print 'test recurrent op forward'
