|
|
|
@ -31,6 +31,9 @@ class FeedKernel : public framework::OpKernel<T> {
|
|
|
|
|
g_feed_variable->Get<std::vector<framework::Tensor>>();
|
|
|
|
|
int col = ctx.template Attr<int>("col");
|
|
|
|
|
PADDLE_ENFORCE_GT(tensors.size(), static_cast<size_t>(col));
|
|
|
|
|
// TODO(qijun):
|
|
|
|
|
// check tensors[col].dims() with attribute,
|
|
|
|
|
// except the first dimenson.
|
|
|
|
|
out->CopyFrom<T>(tensors[col], ctx.GetPlace());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|