|
|
|
@ -105,15 +105,7 @@ class LoDTensorArray2TensorOp : public framework::OperatorBase {
|
|
|
|
|
auto out_inx_dim = out_inx.dims();
|
|
|
|
|
out_inx_dim[0] = inx.size();
|
|
|
|
|
out_inx.Resize(out_inx_dim);
|
|
|
|
|
|
|
|
|
|
auto &local_scope = scope.NewScope();
|
|
|
|
|
std::string var_name = "out_index";
|
|
|
|
|
framework::Variable *tmp_index_var = local_scope.Var(var_name);
|
|
|
|
|
auto &tmp_index_tensor =
|
|
|
|
|
*(tmp_index_var->GetMutable<paddle::framework::LoDTensor>());
|
|
|
|
|
tmp_index_tensor.Resize(out_inx_dim);
|
|
|
|
|
int *tmp_index_data =
|
|
|
|
|
tmp_index_tensor.mutable_data<int>(platform::CPUPlace());
|
|
|
|
|
int *tmp_index_data = out_inx.mutable_data<int>(platform::CPUPlace());
|
|
|
|
|
|
|
|
|
|
auto out_dims = inx[0].dims();
|
|
|
|
|
size_t out_dim_sum = 0;
|
|
|
|
@ -122,18 +114,17 @@ class LoDTensorArray2TensorOp : public framework::OperatorBase {
|
|
|
|
|
out_dim_sum += inx_dims[axis];
|
|
|
|
|
tmp_index_data[index] = inx_dims[axis];
|
|
|
|
|
}
|
|
|
|
|
out_inx.ShareDataWith(tmp_index_tensor);
|
|
|
|
|
|
|
|
|
|
// get input array items' dims
|
|
|
|
|
out_dims[axis] = out_dim_sum;
|
|
|
|
|
out.Resize(out_dims);
|
|
|
|
|
|
|
|
|
|
LodTensorArray2LodTensorVector(local_scope, base_name, Input("X"), &names);
|
|
|
|
|
LodTensorArray2LodTensorVector(scope, base_name, Input("X"), &names);
|
|
|
|
|
// Invoke concat Op
|
|
|
|
|
auto concat_op = framework::OpRegistry::CreateOp(
|
|
|
|
|
"concat", {{"X", names}}, {{"Out", {Output("Out")}}}, attrs);
|
|
|
|
|
|
|
|
|
|
concat_op->Run(local_scope, place);
|
|
|
|
|
concat_op->Run(scope, place);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|