|
|
|
@ -23,13 +23,13 @@ namespace f = paddle::framework;
|
|
|
|
|
using Tensor = framework::Tensor;
|
|
|
|
|
|
|
|
|
|
void SegmentInputs(const std::vector<Scope*>& step_scopes,
|
|
|
|
|
const std::vector<Link>& inlinks, const size_t seq_len,
|
|
|
|
|
bool infer_shape_mode) {
|
|
|
|
|
const std::vector<std::string>& inlinks,
|
|
|
|
|
const size_t seq_len, bool infer_shape_mode) {
|
|
|
|
|
PADDLE_ENFORCE(!inlinks.empty(), "no in links are provided.");
|
|
|
|
|
for (size_t i = 0; i < inlinks.size(); ++i) {
|
|
|
|
|
auto input_var = step_scopes[0]->FindVar(inlinks[i].external);
|
|
|
|
|
auto input_var = step_scopes[0]->FindVar(inlinks[i]);
|
|
|
|
|
PADDLE_ENFORCE(input_var != nullptr, "input link [%s] is not in scope.",
|
|
|
|
|
inlinks[i].external);
|
|
|
|
|
inlinks[i]);
|
|
|
|
|
|
|
|
|
|
Tensor* input = input_var->GetMutable<Tensor>();
|
|
|
|
|
f::DDim dims = input->dims();
|
|
|
|
@ -38,7 +38,7 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
|
|
|
|
|
f::DDim step_dims = slice_ddim(dims, 1, dims.size());
|
|
|
|
|
for (size_t j = 0; j < seq_len; j++) {
|
|
|
|
|
Tensor* step_input =
|
|
|
|
|
step_scopes[j]->NewVar(inlinks[i].internal)->GetMutable<Tensor>();
|
|
|
|
|
step_scopes[j]->NewVar(inlinks[i])->GetMutable<Tensor>();
|
|
|
|
|
if (!infer_shape_mode) {
|
|
|
|
|
*step_input = input->Slice<float>(j, j + 1);
|
|
|
|
|
}
|
|
|
|
@ -48,18 +48,17 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ConcatOutputs(const std::vector<Scope*>& step_scopes,
|
|
|
|
|
const std::vector<Link>& outlinks, const size_t seq_len,
|
|
|
|
|
bool infer_shape_mode) {
|
|
|
|
|
const std::vector<std::string>& outlinks,
|
|
|
|
|
const size_t seq_len, bool infer_shape_mode) {
|
|
|
|
|
for (size_t i = 0; i < outlinks.size(); i++) {
|
|
|
|
|
auto output_var = step_scopes[0]->FindVar(outlinks[i].external);
|
|
|
|
|
auto output_var = step_scopes[0]->FindVar(outlinks[i]);
|
|
|
|
|
PADDLE_ENFORCE(output_var != nullptr, "output link [%s] is not in scope.",
|
|
|
|
|
outlinks[i].external);
|
|
|
|
|
outlinks[i]);
|
|
|
|
|
Tensor* output = output_var->GetMutable<Tensor>();
|
|
|
|
|
|
|
|
|
|
if (infer_shape_mode) {
|
|
|
|
|
auto step_scope_var = step_scopes[0]->FindVar(outlinks[i].internal);
|
|
|
|
|
PADDLE_ENFORCE(step_scope_var != nullptr, "%s not in scope",
|
|
|
|
|
outlinks[i].internal);
|
|
|
|
|
auto step_scope_var = step_scopes[0]->FindVar(outlinks[i]);
|
|
|
|
|
PADDLE_ENFORCE(step_scope_var != nullptr, "%s not in scope", outlinks[i]);
|
|
|
|
|
f::DDim step_dims = step_scope_var->template GetMutable<Tensor>()->dims();
|
|
|
|
|
std::vector<int64_t> dims_vec = vectorize(step_dims);
|
|
|
|
|
dims_vec.insert(dims_vec.begin(), seq_len);
|
|
|
|
@ -68,7 +67,7 @@ void ConcatOutputs(const std::vector<Scope*>& step_scopes,
|
|
|
|
|
output->mutable_data<float>(platform::CPUPlace());
|
|
|
|
|
for (size_t j = 0; j < seq_len; j++) {
|
|
|
|
|
Tensor* step_output =
|
|
|
|
|
step_scopes[j]->FindVar(outlinks[i].internal)->GetMutable<Tensor>();
|
|
|
|
|
step_scopes[j]->FindVar(outlinks[i])->GetMutable<Tensor>();
|
|
|
|
|
// TODO(luotao02) data type and platform::DeviceContext() should set
|
|
|
|
|
// correctly
|
|
|
|
|
(output->Slice<float>(j, j + 1))
|
|
|
|
@ -108,29 +107,9 @@ void InitArgument(const ArgumentName& name, Argument* arg,
|
|
|
|
|
const framework::OperatorBase& op) {
|
|
|
|
|
arg->step_scopes = op.Output(name.step_scopes);
|
|
|
|
|
|
|
|
|
|
auto inlinks = op.Inputs(name.inlinks);
|
|
|
|
|
auto inlink_alias = op.Attr<std::vector<std::string>>(name.inlink_alias);
|
|
|
|
|
PADDLE_ENFORCE(inlinks.size() == inlink_alias.size(),
|
|
|
|
|
"the size of inlinks and inlink_alias don't match:%d,%d",
|
|
|
|
|
inlinks.size(), inlink_alias.size());
|
|
|
|
|
for (size_t i = 0; i < inlinks.size(); ++i) {
|
|
|
|
|
rnn::Link link;
|
|
|
|
|
link.external = inlinks[i];
|
|
|
|
|
link.internal = inlink_alias[i];
|
|
|
|
|
(arg->inlinks).push_back(link);
|
|
|
|
|
}
|
|
|
|
|
arg->inlinks = op.Inputs(name.inlinks);
|
|
|
|
|
|
|
|
|
|
auto outlinks = op.Outputs(name.outlinks);
|
|
|
|
|
auto outlink_alias = op.Attr<std::vector<std::string>>(name.outlink_alias);
|
|
|
|
|
PADDLE_ENFORCE(outlinks.size() == outlink_alias.size(),
|
|
|
|
|
"the size of outlinks and outlink_alias don't match:%d,%d",
|
|
|
|
|
outlinks.size(), outlink_alias.size());
|
|
|
|
|
for (size_t i = 0; i < outlinks.size(); ++i) {
|
|
|
|
|
rnn::Link link;
|
|
|
|
|
link.external = outlinks[i];
|
|
|
|
|
link.internal = outlink_alias[i];
|
|
|
|
|
(arg->outlinks).push_back(link);
|
|
|
|
|
}
|
|
|
|
|
arg->outlinks = op.Outputs(name.outlinks);
|
|
|
|
|
|
|
|
|
|
auto boot_memories = op.Inputs(name.boot_memories);
|
|
|
|
|
|
|
|
|
|