Patch summary (commentary only; patch tools skip text before the first
"diff --git" header).

paddle/fluid/operators/gather_op.cc
  Adds a GetKernelTypeForVar override to GatherGradOp. For the input variable
  named "Axis" it returns expected_kernel_type unchanged; for every other
  variable it returns an OpKernelType built from the expected data type and
  the tensor's own place and layout.
  NOTE(review): presumably the "Axis" special case exists so the framework
  does not transform/copy the axis tensor to the kernel's device - confirm
  against OperatorWithKernel's data-transform logic.

python/paddle/tensor/manipulation.py
  In gather(): when axis is not a Variable and equals 0, the call returns
  paddle.fluid.layers.gather(input=x, index=index, overwrite=True) directly.
  Otherwise, when axis is not a Variable, the axis tensor is still created
  with fill_constant inside device_guard("cpu"), now also passing
  force_cpu=True.

NOTE(review): leading whitespace inside the hunks, and the position of the
blank context line in each hunk, were reconstructed from the hunk line counts
and usual formatting conventions - verify against the tree before applying.

diff --git a/paddle/fluid/operators/gather_op.cc b/paddle/fluid/operators/gather_op.cc
index 648afe7e82..162766546b 100644
--- a/paddle/fluid/operators/gather_op.cc
+++ b/paddle/fluid/operators/gather_op.cc
@@ -93,6 +93,15 @@ class GatherGradOp : public framework::OperatorWithKernel {
                                        ctx, framework::GradVarName("Out")),
                                    ctx.device_context());
   }
+  framework::OpKernelType GetKernelTypeForVar(
+      const std::string& var_name, const framework::Tensor& tensor,
+      const framework::OpKernelType& expected_kernel_type) const override {
+    if (var_name == "Axis") {
+      return expected_kernel_type;
+    }
+    return framework::OpKernelType(expected_kernel_type.data_type_,
+                                   tensor.place(), tensor.layout());
+  }
 };
 
 class GatherOpMaker : public framework::OpProtoAndCheckerMaker {
diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py
index 4a01f7e7fa..adad9cfdc2 100644
--- a/python/paddle/tensor/manipulation.py
+++ b/python/paddle/tensor/manipulation.py
@@ -785,9 +785,12 @@ def gather(x, index, axis=None, name=None):
     if axis is None:
         axis = 0
     axis_tensor = axis
+    if not isinstance(axis, Variable) and axis == 0:
+        return paddle.fluid.layers.gather(input=x, index=index, overwrite=True)
     if not isinstance(axis, Variable):
         with device_guard("cpu"):
-            axis_tensor = fill_constant(shape=[1], dtype='int64', value=axis)
+            axis_tensor = fill_constant(
+                shape=[1], dtype='int64', value=axis, force_cpu=True)
     if in_dygraph_mode():
         return core.ops.gather(x, index, axis_tensor)
 