|
|
|
@ -28,14 +28,15 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
|
|
|
|
|
const size_t seq_len, bool infer_shape_mode) {
|
|
|
|
|
PADDLE_ENFORCE(!inlinks.empty(), "no in links are provided.");
|
|
|
|
|
for (size_t i = 0; i < inlinks.size(); ++i) {
|
|
|
|
|
auto input_var = step_scopes[0]->FindVar(inlinks[i]);
|
|
|
|
|
PADDLE_ENFORCE(input_var != nullptr, "input link [%s] is not in scope.",
|
|
|
|
|
inlinks[i]);
|
|
|
|
|
// global inputs
|
|
|
|
|
auto input_var = step_scopes[0]->parent().FindVar(inlinks[i]);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(input_var, "input link [%s] is not in scope.",
|
|
|
|
|
inlinks[i]);
|
|
|
|
|
|
|
|
|
|
LoDTensor* input = input_var->GetMutable<LoDTensor>();
|
|
|
|
|
f::DDim dims = input->dims();
|
|
|
|
|
PADDLE_ENFORCE(static_cast<size_t>(dims[0]) == seq_len,
|
|
|
|
|
"all the inlinks must have same length");
|
|
|
|
|
PADDLE_ENFORCE_EQ(static_cast<size_t>(dims[0]), seq_len,
|
|
|
|
|
"all the inlinks be the same length");
|
|
|
|
|
f::DDim step_dims = slice_ddim(dims, 1, dims.size());
|
|
|
|
|
for (size_t j = 0; j < seq_len; j++) {
|
|
|
|
|
Tensor* step_input =
|
|
|
|
@ -54,15 +55,14 @@ void ConcatOutputs(const std::vector<Scope*>& step_scopes,
|
|
|
|
|
const std::vector<std::string>& outlinks,
|
|
|
|
|
const size_t seq_len, bool infer_shape_mode) {
|
|
|
|
|
for (size_t i = 0; i < outlinks.size(); i++) {
|
|
|
|
|
auto output_var = step_scopes[0]->FindVar(outlinks[i]);
|
|
|
|
|
PADDLE_ENFORCE(output_var != nullptr, "output link [%s] is not in scope.",
|
|
|
|
|
outlinks[i]);
|
|
|
|
|
auto output_var = step_scopes[0]->parent().FindVar(outlinks[i]);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(output_var, "output link [%s] is not in scope.",
|
|
|
|
|
outlinks[i]);
|
|
|
|
|
LoDTensor* output = output_var->GetMutable<LoDTensor>();
|
|
|
|
|
|
|
|
|
|
if (infer_shape_mode) {
|
|
|
|
|
auto step_scope_var = step_scopes[0]->FindVar(outlinks[i].internal);
|
|
|
|
|
PADDLE_ENFORCE(step_scope_var != nullptr, "%s not in scope",
|
|
|
|
|
outlinks[i].internal);
|
|
|
|
|
auto step_scope_var = step_scopes[0]->FindVar(outlinks[i]);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(step_scope_var, "%s not in scope", outlinks[i]);
|
|
|
|
|
f::DDim step_dims =
|
|
|
|
|
step_scope_var->template GetMutable<LoDTensor>()->dims();
|
|
|
|
|
std::vector<int64_t> dims_vec = vectorize(step_dims);
|
|
|
|
|