|
|
|
@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/operators/reduce_op.h"
|
|
|
|
|
#include "paddle/operators/net_op.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
@ -38,10 +37,14 @@ class ReduceOp : public framework::OperatorWithKernel {
|
|
|
|
|
dim, x_rank,
|
|
|
|
|
"The dim should be in the range [-rank(input), rank(input)).");
|
|
|
|
|
bool reduce_all = ctx->Attrs().Get<bool>("reduce_all");
|
|
|
|
|
bool keep_dim = ctx->Attrs().Get<bool>("keep_dim");
|
|
|
|
|
if (reduce_all) {
|
|
|
|
|
ctx->SetOutputDim("Out", {1});
|
|
|
|
|
if (keep_dim)
|
|
|
|
|
ctx->SetOutputDim(
|
|
|
|
|
"Out", framework::make_ddim(std::vector<int64_t>(x_rank, 1)));
|
|
|
|
|
else
|
|
|
|
|
ctx->SetOutputDim("Out", {1});
|
|
|
|
|
} else {
|
|
|
|
|
bool keep_dim = ctx->Attrs().Get<bool>("keep_dim");
|
|
|
|
|
auto dims_vector = vectorize(x_dims);
|
|
|
|
|
if (keep_dim || x_rank == 1) {
|
|
|
|
|
dims_vector[dim] = 1;
|
|
|
|
|