|
|
|
@ -32,7 +32,7 @@ class LogsumexpOp : public framework::OperatorWithKernel {
|
|
|
|
|
PADDLE_ENFORCE_LE(x_rank, 4,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The input tensor X's dimensions of logsumexp "
|
|
|
|
|
"should be less equal than 4. But received X's "
|
|
|
|
|
"should be less or equal than 4. But received X's "
|
|
|
|
|
"dimensions = %d, X's shape = [%s].",
|
|
|
|
|
x_rank, x_dims));
|
|
|
|
|
auto axis = ctx->Attrs().Get<std::vector<int>>("axis");
|
|
|
|
@ -45,20 +45,18 @@ class LogsumexpOp : public framework::OperatorWithKernel {
|
|
|
|
|
axis.size()));
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < axis.size(); i++) {
|
|
|
|
|
PADDLE_ENFORCE_LT(
|
|
|
|
|
axis[i], x_rank,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"axis[%d] should be in the "
|
|
|
|
|
"range [-dimension(X), dimension(X)] "
|
|
|
|
|
"where dimesion(X) is %d. But received axis[i] = %d.",
|
|
|
|
|
i, x_rank, axis[i]));
|
|
|
|
|
PADDLE_ENFORCE_GE(
|
|
|
|
|
axis[i], -x_rank,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"axis[%d] should be in the "
|
|
|
|
|
"range [-dimension(X), dimension(X)] "
|
|
|
|
|
"where dimesion(X) is %d. But received axis[i] = %d.",
|
|
|
|
|
i, x_rank, axis[i]));
|
|
|
|
|
PADDLE_ENFORCE_LT(axis[i], x_rank,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"axis[%d] should be in the "
|
|
|
|
|
"range [-D, D), where D is the dimensions of X and "
|
|
|
|
|
"D is %d. But received axis[%d] = %d.",
|
|
|
|
|
i, x_rank, i, axis[i]));
|
|
|
|
|
PADDLE_ENFORCE_GE(axis[i], -x_rank,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"axis[%d] should be in the "
|
|
|
|
|
"range [-D, D), where D is the dimensions of X and "
|
|
|
|
|
"D is %d. But received axis[%d] = %d.",
|
|
|
|
|
i, x_rank, i, axis[i]));
|
|
|
|
|
if (axis[i] < 0) {
|
|
|
|
|
axis[i] += x_rank;
|
|
|
|
|
}
|
|
|
|
|