@@ -21,6 +21,7 @@ namespace rnn {
 namespace f = paddle::framework;
 
 using Tensor = framework::Tensor;
+using LoDTensor = framework::LoDTensor;
 
 void SegmentInputs(const std::vector<Scope*>& step_scopes,
                    const std::vector<Link>& inlinks, const size_t seq_len,
@@ -31,7 +32,7 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
     PADDLE_ENFORCE(input_var != nullptr, "input link [%s] is not in scope.",
                    inlinks[i].external);
-    Tensor* input = input_var->GetMutable<Tensor>();
+    LoDTensor* input = input_var->GetMutable<LoDTensor>();
     f::DDim dims = input->dims();
     PADDLE_ENFORCE(static_cast<size_t>(dims[0]) == seq_len,
                    "all the inlinks must have same length");
@@ -40,6 +41,8 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
       Tensor* step_input =
           step_scopes[j]->NewVar(inlinks[i].internal)->GetMutable<Tensor>();
       if (!infer_shape_mode) {
+        // The input of operators of each step is Tensor here.
+        // Maybe need to modify Slice function.
         *step_input = input->Slice<float>(j, j + 1);
       }
       step_input->Resize(step_dims);
@@ -54,21 +57,23 @@ void ConcatOutputs(const std::vector<Scope*>& step_scopes,
     auto output_var = step_scopes[0]->FindVar(outlinks[i].external);
     PADDLE_ENFORCE(output_var != nullptr, "output link [%s] is not in scope.",
                    outlinks[i].external);
-    Tensor* output = output_var->GetMutable<Tensor>();
+    LoDTensor* output = output_var->GetMutable<LoDTensor>();
 
     if (infer_shape_mode) {
       auto step_scope_var = step_scopes[0]->FindVar(outlinks[i].internal);
       PADDLE_ENFORCE(step_scope_var != nullptr, "%s not in scope",
                      outlinks[i].internal);
-      f::DDim step_dims = step_scope_var->template GetMutable<Tensor>()->dims();
+      f::DDim step_dims =
+          step_scope_var->template GetMutable<LoDTensor>()->dims();
       std::vector<int64_t> dims_vec = vectorize(step_dims);
       dims_vec.insert(dims_vec.begin(), seq_len);
       output->Resize(f::make_ddim(dims_vec));
     } else {
       output->mutable_data<float>(platform::CPUPlace());
       for (size_t j = 0; j < seq_len; j++) {
-        Tensor* step_output =
-            step_scopes[j]->FindVar(outlinks[i].internal)->GetMutable<Tensor>();
+        LoDTensor* step_output = step_scopes[j]
+                                     ->FindVar(outlinks[i].internal)
+                                     ->GetMutable<LoDTensor>();
         // TODO(luotao02) data type and platform::DeviceContext() should set
         // correctly
         (output->Slice<float>(j, j + 1))
@@ -94,8 +99,8 @@ void LinkMemories(const std::vector<Scope*>& scopes,
   auto scope = scopes[step_id];
   auto linked_scope = scopes[step_id + offset];
   for (auto& attr : memories) {
-    auto mem = scope->FindVar(attr.pre_var)->GetMutable<Tensor>();
-    auto linked_mem = linked_scope->FindVar(attr.var)->GetMutable<Tensor>();
+    auto mem = scope->FindVar(attr.pre_var)->GetMutable<LoDTensor>();
+    auto linked_mem = linked_scope->FindVar(attr.var)->GetMutable<LoDTensor>();
     if (infer_shape_mode) {
       mem->Resize(linked_mem->dims());
     } else {