Refine reduce codes to save compiling time and binary size (#20676)
* refine reduce code to save compiling time and binary sizes, test=develop * add reduce rank check to avoid bug, test=developrevert-20712-fix_depthwise_conv
parent
1d925440ca
commit
34e3adaece
@ -0,0 +1,53 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace detail {
|
||||
|
||||
TEST(test_reduce_rank_check, all) {
|
||||
using EnforceNotMet = paddle::platform::EnforceNotMet;
|
||||
constexpr int kMaxRank = framework::DDim::kMaxRank;
|
||||
|
||||
for (int rank = 0; rank < kMaxRank; rank++) {
|
||||
for (int reduce_rank = 0; reduce_rank <= rank; reduce_rank++) {
|
||||
bool is_valid = false;
|
||||
if (rank % 2 == 0) {
|
||||
is_valid = (reduce_rank == rank / 2);
|
||||
} else {
|
||||
if (reduce_rank == (rank - 1) / 2) {
|
||||
is_valid = true;
|
||||
} else if (reduce_rank == (rank + 1) / 2) {
|
||||
is_valid = true;
|
||||
} else {
|
||||
is_valid = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (is_valid) {
|
||||
CheckReduceRankIsValid(reduce_rank, rank);
|
||||
} else {
|
||||
ASSERT_THROW(CheckReduceRankIsValid(reduce_rank, rank),
|
||||
paddle::platform::EnforceNotMet);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
Loading…
Reference in new issue