|
|
|
@ -24,22 +24,23 @@ using Tensor = framework::Tensor;
|
|
|
|
|
using LoDTensor = framework::LoDTensor;
|
|
|
|
|
|
|
|
|
|
void SegmentInputs(const std::vector<Scope*>& step_scopes,
|
|
|
|
|
const std::vector<Link>& inlinks, const size_t seq_len,
|
|
|
|
|
bool infer_shape_mode) {
|
|
|
|
|
const std::vector<std::string>& inlinks,
|
|
|
|
|
const size_t seq_len, bool infer_shape_mode) {
|
|
|
|
|
PADDLE_ENFORCE(!inlinks.empty(), "no in links are provided.");
|
|
|
|
|
for (size_t i = 0; i < inlinks.size(); ++i) {
|
|
|
|
|
auto input_var = step_scopes[0]->FindVar(inlinks[i].external);
|
|
|
|
|
PADDLE_ENFORCE(input_var != nullptr, "input link [%s] is not in scope.",
|
|
|
|
|
inlinks[i].external);
|
|
|
|
|
// global inputs
|
|
|
|
|
auto input_var = step_scopes[0]->parent().FindVar(inlinks[i]);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(input_var, "input link [%s] is not in scope.",
|
|
|
|
|
inlinks[i]);
|
|
|
|
|
|
|
|
|
|
LoDTensor* input = input_var->GetMutable<LoDTensor>();
|
|
|
|
|
f::DDim dims = input->dims();
|
|
|
|
|
PADDLE_ENFORCE(static_cast<size_t>(dims[0]) == seq_len,
|
|
|
|
|
"all the inlinks must have same length");
|
|
|
|
|
PADDLE_ENFORCE_EQ(static_cast<size_t>(dims[0]), seq_len,
|
|
|
|
|
"all the inlinks be the same length");
|
|
|
|
|
f::DDim step_dims = slice_ddim(dims, 1, dims.size());
|
|
|
|
|
for (size_t j = 0; j < seq_len; j++) {
|
|
|
|
|
Tensor* step_input =
|
|
|
|
|
step_scopes[j]->NewVar(inlinks[i].internal)->GetMutable<Tensor>();
|
|
|
|
|
step_scopes[j]->NewVar(inlinks[i])->GetMutable<Tensor>();
|
|
|
|
|
if (!infer_shape_mode) {
|
|
|
|
|
// The input of operators of each step is Tensor here.
|
|
|
|
|
// Maybe need to modify Slice function.
|
|
|
|
@ -51,18 +52,17 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ConcatOutputs(const std::vector<Scope*>& step_scopes,
|
|
|
|
|
const std::vector<Link>& outlinks, const size_t seq_len,
|
|
|
|
|
bool infer_shape_mode) {
|
|
|
|
|
const std::vector<std::string>& outlinks,
|
|
|
|
|
const size_t seq_len, bool infer_shape_mode) {
|
|
|
|
|
for (size_t i = 0; i < outlinks.size(); i++) {
|
|
|
|
|
auto output_var = step_scopes[0]->FindVar(outlinks[i].external);
|
|
|
|
|
PADDLE_ENFORCE(output_var != nullptr, "output link [%s] is not in scope.",
|
|
|
|
|
outlinks[i].external);
|
|
|
|
|
auto output_var = step_scopes[0]->parent().FindVar(outlinks[i]);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(output_var, "output link [%s] is not in scope.",
|
|
|
|
|
outlinks[i]);
|
|
|
|
|
LoDTensor* output = output_var->GetMutable<LoDTensor>();
|
|
|
|
|
|
|
|
|
|
if (infer_shape_mode) {
|
|
|
|
|
auto step_scope_var = step_scopes[0]->FindVar(outlinks[i].internal);
|
|
|
|
|
PADDLE_ENFORCE(step_scope_var != nullptr, "%s not in scope",
|
|
|
|
|
outlinks[i].internal);
|
|
|
|
|
auto step_scope_var = step_scopes[0]->FindVar(outlinks[i]);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(step_scope_var, "%s not in scope", outlinks[i]);
|
|
|
|
|
f::DDim step_dims =
|
|
|
|
|
step_scope_var->template GetMutable<LoDTensor>()->dims();
|
|
|
|
|
std::vector<int64_t> dims_vec = vectorize(step_dims);
|
|
|
|
@ -71,9 +71,8 @@ void ConcatOutputs(const std::vector<Scope*>& step_scopes,
|
|
|
|
|
} else {
|
|
|
|
|
output->mutable_data<float>(platform::CPUPlace());
|
|
|
|
|
for (size_t j = 0; j < seq_len; j++) {
|
|
|
|
|
LoDTensor* step_output = step_scopes[j]
|
|
|
|
|
->FindVar(outlinks[i].internal)
|
|
|
|
|
->GetMutable<LoDTensor>();
|
|
|
|
|
LoDTensor* step_output =
|
|
|
|
|
step_scopes[j]->FindVar(outlinks[i])->GetMutable<LoDTensor>();
|
|
|
|
|
// TODO(luotao02) data type and platform::DeviceContext() should set
|
|
|
|
|
// correctly
|
|
|
|
|
(output->Slice<float>(j, j + 1))
|
|
|
|
@ -113,29 +112,9 @@ void InitArgument(const ArgumentName& name, Argument* arg,
|
|
|
|
|
const framework::OperatorBase& op) {
|
|
|
|
|
arg->step_scopes = op.Output(name.step_scopes);
|
|
|
|
|
|
|
|
|
|
auto inlinks = op.Inputs(name.inlinks);
|
|
|
|
|
auto inlink_alias = op.Attr<std::vector<std::string>>(name.inlink_alias);
|
|
|
|
|
PADDLE_ENFORCE(inlinks.size() == inlink_alias.size(),
|
|
|
|
|
"the size of inlinks and inlink_alias don't match:%d,%d",
|
|
|
|
|
inlinks.size(), inlink_alias.size());
|
|
|
|
|
for (size_t i = 0; i < inlinks.size(); ++i) {
|
|
|
|
|
rnn::Link link;
|
|
|
|
|
link.external = inlinks[i];
|
|
|
|
|
link.internal = inlink_alias[i];
|
|
|
|
|
(arg->inlinks).push_back(link);
|
|
|
|
|
}
|
|
|
|
|
arg->inlinks = op.Inputs(name.inlinks);
|
|
|
|
|
|
|
|
|
|
auto outlinks = op.Outputs(name.outlinks);
|
|
|
|
|
auto outlink_alias = op.Attr<std::vector<std::string>>(name.outlink_alias);
|
|
|
|
|
PADDLE_ENFORCE(outlinks.size() == outlink_alias.size(),
|
|
|
|
|
"the size of outlinks and outlink_alias don't match:%d,%d",
|
|
|
|
|
outlinks.size(), outlink_alias.size());
|
|
|
|
|
for (size_t i = 0; i < outlinks.size(); ++i) {
|
|
|
|
|
rnn::Link link;
|
|
|
|
|
link.external = outlinks[i];
|
|
|
|
|
link.internal = outlink_alias[i];
|
|
|
|
|
(arg->outlinks).push_back(link);
|
|
|
|
|
}
|
|
|
|
|
arg->outlinks = op.Outputs(name.outlinks);
|
|
|
|
|
|
|
|
|
|
auto boot_memories = op.Inputs(name.boot_memories);
|
|
|
|
|
|
|
|
|
|