Fix fill_constant_batch_size_like_op when input is LoDTensor. (#10943)

release/0.13.0
qingqing01 7 years ago committed by GitHub
parent bf869e45c1
commit 91bd5835df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -24,6 +24,14 @@ class FillConstantBatchSizeLikeOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* out = ctx.Output<framework::Tensor>("Out");
auto* in = ctx.Input<framework::LoDTensor>("Input");
if (in->lod().size() && ctx.Attr<int>("input_dim_idx") == 0) {
// set the correct batch size for the LoDTensor.
auto odims = out->dims();
int output_dim_idx = ctx.Attr<int>("output_dim_idx");
odims[output_dim_idx] = static_cast<int>(in->lod().back().size()) - 1;
out->mutable_data<T>(odims, ctx.GetPlace());
}
out->mutable_data<T>(ctx.GetPlace());
auto value = ctx.Attr<float>("value");

@ -50,5 +50,27 @@ class TestFillConstantBatchSizeLikeWhenSecondDimIsBatchSize(OpTest):
self.check_output()
class TestFillConstantBatchSizeLikeWithLoDTensor(OpTest):
def setUp(self):
self.op_type = "fill_constant_batch_size_like"
self.inputs = {
'Input': (np.random.random((31, 28)).astype("float32"),
[[0, 9, 23, 31]])
}
self.attrs = {
'value': 3.5,
'shape': [-1, 16],
'input_dim_idx': 0,
'output_dim_idx': 0
}
out = np.random.random((3, 16)).astype("float32")
out.fill(3.5)
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
if __name__ == "__main__":
unittest.main()

Loading…
Cancel
Save