Correct the forward of sequence_softmax_op.

tonyyang-svail-feed-op-desgin
Liu Yiqun 8 years ago
parent 4d9293940b
commit 12f2b8eb07

@ -42,8 +42,7 @@ class ReshapeOp : public framework::OperatorWithKernel {
int64_t capacity =
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
auto *in = ctx.Input<framework::Tensor>("X");
int64_t in_size = framework::product(in->dims());
PADDLE_ENFORCE_EQ(capacity, in_size,
PADDLE_ENFORCE_EQ(capacity, in->numel(),
"The size of Input(X) mismatches with Attr(shape).");
// resize output
std::vector<int64_t> shape_int64(shape.size(), 0);

@ -30,18 +30,20 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
"Output(Out) of SequenceSoftmaxOp should not be null.");
auto *x = ctx.Input<framework::LoDTensor>("X");
auto dims = x->dims();
auto lod = x->lod();
PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
auto dims = x->dims();
PADDLE_ENFORCE_GE(
dims[0],
/* batch_size */ static_cast<int64_t>(lod[0].size() - 1),
"The first dimension of Input(X) should be larger than batch size.");
PADDLE_ENFORCE_EQ(x->numel(), static_cast<int64_t>(lod[0].size() - 1),
const size_t level = lod.size() - 1;
PADDLE_ENFORCE_EQ(x->numel(), static_cast<int64_t>(lod[level].back()),
"The width of each timestep in Input(X) of "
"SequenceSoftmaxOp should be 1.");
dims[0] = lod[0].size() - 1;
std::cout << DebugString() << std::endl;
ctx.Output<framework::LoDTensor>("Out")->Resize({dims});
}
};

@ -38,7 +38,7 @@ class SequenceSoftmaxKernel : public framework::OpKernel {
auto* out = ctx.Output<LoDTensor>("Out");
auto lod = x->lod();
const size_t level = lod.size();
const size_t level = lod.size() - 1;
out->mutable_data<T>(ctx.GetPlace());
for (int i = 0; i < static_cast<int>(lod[level].size()) - 1; ++i) {
@ -47,6 +47,10 @@ class SequenceSoftmaxKernel : public framework::OpKernel {
Tensor x_i = x->Slice<T>(start_pos, end_pos);
Tensor out_i = out->Slice<T>(start_pos, end_pos);
// Reshape from (end_pos - start_pos) x 1UL to 1UL x (end_pos - start_pos)
framework::DDim dims = framework::make_ddim({1UL, end_pos - start_pos});
x_i.Resize(dims);
out_i.Resize(dims);
math::SoftmaxFunctor<Place, T>()(&x_i, &out_i, ctx);
}
}

Loading…
Cancel
Save