|
|
|
@ -32,8 +32,12 @@ class FeedOp : public framework::OperatorWithKernel {
|
|
|
|
|
g_feed_variable->Get<std::vector<framework::Tensor>>();
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_GT(tensors.size(), static_cast<size_t>(col));
|
|
|
|
|
auto in_dim = tensors[col].dims();
|
|
|
|
|
ctx->SetOutputDim("Out", in_dim);
|
|
|
|
|
|
|
|
|
|
auto& shape = ctx->Attrs().Get<std::vector<int>>("dims");
|
|
|
|
|
std::vector<int64_t> shape_int64(shape.size(), 0);
|
|
|
|
|
std::transform(shape.begin(), shape.end(), shape_int64.begin(),
|
|
|
|
|
[](int a) { return static_cast<int64_t>(a); });
|
|
|
|
|
ctx->SetOutputDim("Out", framework::make_ddim(shape_int64));
|
|
|
|
|
// TODO(qijun): need to handle LodTensor later
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|