tensor_array_to_tensor_op.cc, test=develop (#19289)

expand_as_op_1
石晓伟 6 years ago committed by GitHub
parent 0436efd6a3
commit 30adea0a23
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -105,15 +105,7 @@ class LoDTensorArray2TensorOp : public framework::OperatorBase {
auto out_inx_dim = out_inx.dims();
out_inx_dim[0] = inx.size();
out_inx.Resize(out_inx_dim);
auto &local_scope = scope.NewScope();
std::string var_name = "out_index";
framework::Variable *tmp_index_var = local_scope.Var(var_name);
auto &tmp_index_tensor =
*(tmp_index_var->GetMutable<paddle::framework::LoDTensor>());
tmp_index_tensor.Resize(out_inx_dim);
int *tmp_index_data =
tmp_index_tensor.mutable_data<int>(platform::CPUPlace());
int *tmp_index_data = out_inx.mutable_data<int>(platform::CPUPlace());
auto out_dims = inx[0].dims();
size_t out_dim_sum = 0;
@ -122,18 +114,17 @@ class LoDTensorArray2TensorOp : public framework::OperatorBase {
out_dim_sum += inx_dims[axis];
tmp_index_data[index] = inx_dims[axis];
}
out_inx.ShareDataWith(tmp_index_tensor);
// get input array items' dims
out_dims[axis] = out_dim_sum;
out.Resize(out_dims);
LodTensorArray2LodTensorVector(local_scope, base_name, Input("X"), &names);
LodTensorArray2LodTensorVector(scope, base_name, Input("X"), &names);
// Invoke concat Op
auto concat_op = framework::OpRegistry::CreateOp(
"concat", {{"X", names}}, {{"Out", {Output("Out")}}}, attrs);
concat_op->Run(local_scope, place);
concat_op->Run(scope, place);
}
};

Loading…
Cancel
Save