infer feed operator output variable shape with dims attribute

revert-4814-Add_sequence_project_op
qijun 8 years ago
parent 2fc7fc7a18
commit 975a51294e

@ -32,8 +32,12 @@ class FeedOp : public framework::OperatorWithKernel {
g_feed_variable->Get<std::vector<framework::Tensor>>();
PADDLE_ENFORCE_GT(tensors.size(), static_cast<size_t>(col));
auto in_dim = tensors[col].dims();
ctx->SetOutputDim("Out", in_dim);
auto& shape = ctx->Attrs().Get<std::vector<int>>("dims");
std::vector<int64_t> shape_int64(shape.size(), 0);
std::transform(shape.begin(), shape.end(), shape_int64.begin(),
[](int a) { return static_cast<int64_t>(a); });
ctx->SetOutputDim("Out", framework::make_ddim(shape_int64));
// TODO(qijun): need to handle LodTensor later
}

Loading…
Cancel
Save