|
|
|
|
@ -106,9 +106,9 @@ class LoDTensorArray2TensorOp : public framework::OperatorBase {
|
|
|
|
|
out_inx_dim[0] = inx.size();
|
|
|
|
|
out_inx.Resize(out_inx_dim);
|
|
|
|
|
|
|
|
|
|
auto &local_scope = scope.NewScope();
|
|
|
|
|
std::string var_name = "out_index";
|
|
|
|
|
framework::Variable *tmp_index_var =
|
|
|
|
|
const_cast<framework::Scope &>(scope).Var(var_name);
|
|
|
|
|
framework::Variable *tmp_index_var = local_scope.Var(var_name);
|
|
|
|
|
auto &tmp_index_tensor =
|
|
|
|
|
*(tmp_index_var->GetMutable<paddle::framework::LoDTensor>());
|
|
|
|
|
tmp_index_tensor.Resize(out_inx_dim);
|
|
|
|
|
@ -128,12 +128,12 @@ class LoDTensorArray2TensorOp : public framework::OperatorBase {
|
|
|
|
|
out_dims[axis] = out_dim_sum;
|
|
|
|
|
out.Resize(out_dims);
|
|
|
|
|
|
|
|
|
|
LodTensorArray2LodTensorVector(scope, base_name, Input("X"), &names);
|
|
|
|
|
// Invoke Reshape Op
|
|
|
|
|
LodTensorArray2LodTensorVector(local_scope, base_name, Input("X"), &names);
|
|
|
|
|
// Invoke concat Op
|
|
|
|
|
auto concat_op = framework::OpRegistry::CreateOp(
|
|
|
|
|
"concat", {{"X", names}}, {{"Out", {Output("Out")}}}, attrs);
|
|
|
|
|
|
|
|
|
|
concat_op->Run(scope, place);
|
|
|
|
|
concat_op->Run(local_scope, place);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|