|
|
|
@ -58,6 +58,8 @@ class ArgMinMaxKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto& out = *(ctx.Output<framework::LoDTensor>("Out"));
|
|
|
|
|
out.mutable_data<Tout>(ctx.GetPlace());
|
|
|
|
|
auto axis = ctx.Attr<int64_t>("axis");
|
|
|
|
|
auto x_rank = x.dims().size();
|
|
|
|
|
if (axis < 0) axis += x_rank;
|
|
|
|
|
auto& dev_ctx = ctx.template device_context<DeviceContext>();
|
|
|
|
|
|
|
|
|
|
#define CALL_ARG_MINMAX_FUNCTOR(rank) \
|
|
|
|
|