|
|
|
@ -110,10 +110,12 @@ struct VisitDataArgMinMaxFunctor {
|
|
|
|
|
CALL_ARG_MINMAX_FUNCTOR(6);
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
PADDLE_THROW(
|
|
|
|
|
"%s operator doesn't supports tensors whose ranks are greater "
|
|
|
|
|
"than 6.",
|
|
|
|
|
(EnumArgMinMaxValue == kArgMin ? "argmin" : "argmax"));
|
|
|
|
|
PADDLE_ENFORCE_LE(
|
|
|
|
|
x_dims.size(), 6,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"%s operator doesn't supports tensors whose ranks are greater "
|
|
|
|
|
"than 6.",
|
|
|
|
|
(EnumArgMinMaxValue == kArgMin ? "argmin" : "argmax")));
|
|
|
|
|
break;
|
|
|
|
|
#undef CALL_ARG_MINMAX_FUNCTOR
|
|
|
|
|
}
|
|
|
|
@ -164,7 +166,8 @@ class ArgMinMaxOp : public framework::OperatorWithKernel {
|
|
|
|
|
PADDLE_ENFORCE_LT(
|
|
|
|
|
axis, x_dims.size(),
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"'axis'(%d) must be less than Rank(X)(%d).", axis, x_dims.size()));
|
|
|
|
|
"'axis'(%d) must be less than Rank(X)(%d) of Input(X).", axis,
|
|
|
|
|
x_dims.size()));
|
|
|
|
|
|
|
|
|
|
const int& dtype = ctx->Attrs().Get<int>("dtype");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
@ -192,10 +195,11 @@ class ArgMinMaxOp : public framework::OperatorWithKernel {
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_LE(
|
|
|
|
|
all_element_num, INT_MAX,
|
|
|
|
|
"The element num of the argmin/argmax input at axis is "
|
|
|
|
|
"%d, is larger than int32 maximum value:%d, you must "
|
|
|
|
|
"set the dtype of argmin/argmax to 'int64'.",
|
|
|
|
|
all_element_num, INT_MAX);
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The element num of the argmin/argmax input at axis is "
|
|
|
|
|
"%d, is larger than int32 maximum value:%d, you must "
|
|
|
|
|
"set the dtype of argmin/argmax to 'int64'.",
|
|
|
|
|
all_element_num, INT_MAX));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
std::vector<int64_t> vec;
|
|
|
|
|