|
|
|
@ -24,6 +24,14 @@ class FillConstantBatchSizeLikeOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
auto* out = ctx.Output<framework::Tensor>("Out");
|
|
|
|
|
auto* in = ctx.Input<framework::LoDTensor>("Input");
|
|
|
|
|
if (in->lod().size() && ctx.Attr<int>("input_dim_idx") == 0) {
|
|
|
|
|
// set the correct batch size for the LoDTensor.
|
|
|
|
|
auto odims = out->dims();
|
|
|
|
|
int output_dim_idx = ctx.Attr<int>("output_dim_idx");
|
|
|
|
|
odims[output_dim_idx] = static_cast<int>(in->lod().back().size()) - 1;
|
|
|
|
|
out->mutable_data<T>(odims, ctx.GetPlace());
|
|
|
|
|
}
|
|
|
|
|
out->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
auto value = ctx.Attr<float>("value");
|
|
|
|
|
|
|
|
|
|