Add the cpu version of segment sum mean max min op (revert-27520-disable_pr)
parent
afe94903c3
commit
f4c750d721
@ -0,0 +1,148 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/math/segment_pooling.h"
|
||||
#include <string>
|
||||
#include "paddle/fluid/framework/eigen.h"
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/operators/math/math_function.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using Tensor = framework::Tensor;
|
||||
|
||||
template <typename T, typename IndexT>
|
||||
class SegmentPoolFunctor<platform::CPUDeviceContext, T, IndexT> {
|
||||
public:
|
||||
void operator()(const platform::CPUDeviceContext& context,
|
||||
const framework::Tensor& input,
|
||||
const framework::Tensor& segments, framework::Tensor* output,
|
||||
framework::Tensor* index,
|
||||
const std::string pooltype = "SUM") {
|
||||
const IndexT* segment_ids = segments.data<IndexT>();
|
||||
auto curent_id = segment_ids[0];
|
||||
int64_t last_idx = 0;
|
||||
int64_t w = input.numel() / input.dims()[0];
|
||||
auto& place = *context.eigen_device();
|
||||
for (int64_t idx = 1; idx <= segments.numel(); ++idx) {
|
||||
if (idx < segments.numel()) {
|
||||
if (segment_ids[idx] == curent_id) continue;
|
||||
PADDLE_ENFORCE_GE(segment_ids[idx], curent_id,
|
||||
platform::errors::InvalidArgument(
|
||||
"The segment ids should be sorted, but got "
|
||||
"segment_ids[%d]:%d > segment_ids[%d]:%d.",
|
||||
idx - 1, curent_id, idx, segment_ids[idx]));
|
||||
}
|
||||
|
||||
Tensor out_t = output->Slice(curent_id, curent_id + 1);
|
||||
Tensor in_t = input.Slice(last_idx, idx);
|
||||
|
||||
int64_t h = idx - last_idx;
|
||||
auto in_e =
|
||||
framework::EigenMatrix<T>::From(in_t, framework::make_ddim({h, w}));
|
||||
auto out_e = framework::EigenVector<T>::Flatten(out_t);
|
||||
|
||||
auto reduce_dim = Eigen::array<int, 1>({{0}});
|
||||
if (pooltype == "MEAN") {
|
||||
out_e.device(place) = in_e.mean(reduce_dim);
|
||||
} else if (pooltype == "SUM") {
|
||||
out_e.device(place) = in_e.sum(reduce_dim);
|
||||
} else if (pooltype == "MAX") {
|
||||
out_e.device(place) = in_e.maximum(reduce_dim);
|
||||
} else if (pooltype == "MIN") {
|
||||
out_e.device(place) = in_e.minimum(reduce_dim);
|
||||
} else {
|
||||
PADDLE_THROW(platform::errors::InvalidArgument(
|
||||
"Unsupported segment pooling type, only MEAN, SUM, MAX, MIN "
|
||||
"available, but got %s.",
|
||||
pooltype));
|
||||
}
|
||||
|
||||
last_idx = idx;
|
||||
if (idx < segments.numel()) curent_id = segment_ids[idx];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename IndexT>
|
||||
class SegmentPoolGradFunctor<platform::CPUDeviceContext, T, IndexT> {
|
||||
public:
|
||||
void operator()(const platform::CPUDeviceContext& context,
|
||||
const framework::Tensor& input,
|
||||
const framework::Tensor& output,
|
||||
const framework::Tensor& out_grad,
|
||||
const framework::Tensor& segments, framework::Tensor* in_grad,
|
||||
const framework::Tensor* index = nullptr,
|
||||
const std::string pooltype = "SUM") {
|
||||
const IndexT* segment_ids = segments.data<IndexT>();
|
||||
auto& place = *context.eigen_device();
|
||||
auto curent_id = segment_ids[0];
|
||||
int64_t last_idx = 0;
|
||||
int64_t w = in_grad->numel() / in_grad->dims()[0];
|
||||
for (int64_t idx = 1; idx <= segments.numel(); ++idx) {
|
||||
if (idx < segments.numel()) {
|
||||
if (segment_ids[idx] == curent_id) continue;
|
||||
PADDLE_ENFORCE_GE(segment_ids[idx], curent_id,
|
||||
platform::errors::InvalidArgument(
|
||||
"The segment ids should be sorted, but got "
|
||||
"segment_ids[%d]:%d > segment_ids[%d]:%d.",
|
||||
idx - 1, curent_id, idx, segment_ids[idx]));
|
||||
}
|
||||
|
||||
Tensor out_g_t = out_grad.Slice(curent_id, curent_id + 1);
|
||||
Tensor in_g_t = in_grad->Slice(last_idx, idx);
|
||||
|
||||
int64_t h = idx - last_idx;
|
||||
auto in_g_e = framework::EigenMatrix<T>::From(in_g_t, {h, w});
|
||||
auto out_g_e = framework::EigenMatrix<T>::From(out_g_t, {1, w});
|
||||
Eigen::DSizes<int, 2> bcast(h, 1);
|
||||
|
||||
if (pooltype == "MEAN") {
|
||||
in_g_e.device(place) = (out_g_e / static_cast<T>(h)).broadcast(bcast);
|
||||
} else if (pooltype == "SUM") {
|
||||
in_g_e.device(place) = out_g_e.broadcast(bcast);
|
||||
} else if (pooltype == "MAX" || pooltype == "MIN") {
|
||||
Tensor out_t = output.Slice(curent_id, curent_id + 1);
|
||||
Tensor in_t = input.Slice(last_idx, idx);
|
||||
auto in_e = framework::EigenMatrix<T>::From(in_t, {h, w});
|
||||
auto out_e = framework::EigenMatrix<T>::From(out_t, {1, w});
|
||||
in_g_e.device(place) =
|
||||
(in_e == out_e.broadcast(bcast)).template cast<T>() *
|
||||
out_g_e.broadcast(bcast);
|
||||
} else {
|
||||
PADDLE_THROW(platform::errors::InvalidArgument(
|
||||
"Unsupported segment pooling type, only MEAN, SUM, MAX, MIN "
|
||||
"available, but got %s.",
|
||||
pooltype));
|
||||
}
|
||||
|
||||
last_idx = idx;
|
||||
if (idx < segments.numel()) curent_id = segment_ids[idx];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
using CPU = platform::CPUDeviceContext;
|
||||
template class SegmentPoolFunctor<CPU, float, int>;
|
||||
template class SegmentPoolFunctor<CPU, float, int64_t>;
|
||||
template class SegmentPoolFunctor<CPU, double, int>;
|
||||
template class SegmentPoolFunctor<CPU, double, int64_t>;
|
||||
template class SegmentPoolGradFunctor<CPU, float, int>;
|
||||
template class SegmentPoolGradFunctor<CPU, float, int64_t>;
|
||||
template class SegmentPoolGradFunctor<CPU, double, int>;
|
||||
template class SegmentPoolGradFunctor<CPU, double, int64_t>;
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,46 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include "paddle/fluid/framework/tensor.h"
|
||||
#include "paddle/fluid/platform/device_context.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
template <typename DeviceContext, typename T, typename IndexT>
|
||||
class SegmentPoolFunctor {
|
||||
public:
|
||||
/* mean pool has summed_ids output */
|
||||
void operator()(const DeviceContext& context, const framework::Tensor& input,
|
||||
const framework::Tensor& segments, framework::Tensor* output,
|
||||
framework::Tensor* summed_ids = nullptr,
|
||||
const std::string pooltype = "SUM");
|
||||
};
|
||||
|
||||
template <typename DeviceContext, typename T, typename IndexT>
|
||||
class SegmentPoolGradFunctor {
|
||||
public:
|
||||
/* mean pool has summed_ids output */
|
||||
void operator()(const DeviceContext& context, const framework::Tensor& input,
|
||||
const framework::Tensor& output,
|
||||
const framework::Tensor& out_grad,
|
||||
const framework::Tensor& segments, framework::Tensor* in_grad,
|
||||
const framework::Tensor* summed_ids = nullptr,
|
||||
const std::string pooltype = "SUM");
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,166 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/segment_pool_op.h"
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class SegmentPoolOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SegmentPool");
|
||||
OP_INOUT_CHECK(ctx->HasInput("SegmentIds"), "Input", "SegmentIds",
|
||||
"SegmentPool");
|
||||
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SegmentPool");
|
||||
auto dims = ctx->GetInputDim("X");
|
||||
dims[0] = -1;
|
||||
ctx->SetOutputDim("Out", dims);
|
||||
|
||||
if (ctx->Attrs().Get<std::string>("pooltype") == "MEAN") {
|
||||
OP_INOUT_CHECK(ctx->HasOutput("SummedIds"), "Output", "SummedIds",
|
||||
"SegmentPool");
|
||||
ctx->SetOutputDim("SummedIds", {-1, 1});
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext& ctx) const override {
|
||||
return framework::OpKernelType(
|
||||
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
|
||||
ctx.device_context());
|
||||
}
|
||||
};
|
||||
|
||||
// Declares the inputs, outputs, attributes and user-facing documentation of
// the segment_pool operator.
class SegmentPoolOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor) The input data of SegmentPoolOp");
    AddInput("SegmentIds",
             "(Tensor) 1-D tensor which has the same size as the first "
             "dimension of input X.");
    AddOutput("Out", "(Tensor) The output of SegmentPoolOp.");
    // Intermediate output: declared for the backward pass of MEAN pooling.
    AddOutput("SummedIds",
              "(Tensor) This tensor is used to record the counts of segment "
              "ids for the backward of the mean pool.")
        .AsIntermediate();
    AddAttr<std::string>(
        "pooltype",
        "(string, default 'SUM') the pooling type of SegmentPoolOp.")
        .SetDefault("SUM")
        .InEnum({"SUM", "MEAN", "MIN", "MAX"});
    AddComment(R"DOC(
Segment Pool Operator.

This operator will pool the elements of input `X` that share the same index
in `SegmentIds`.

For SUM operation, it computes a tensor such that $Out_i = \sum_{j} X_{j}$
where sum is over j such that `SegmentIds[j] == i`.

For MEAN operation, it computes a tensor such that
$Out_i = \frac{1}{n_i} \sum_{j} X_{j}$ where sum is over j such that
`SegmentIds[j] == i` and $n_i$ is the number of all index `SegmentIds[j] == i`.

For MIN operation, it computes a tensor such that $Out_i = \min_{j} X_{j}$
where min is over j such that `SegmentIds[j] == i`.

For MAX operation, it computes a tensor such that $Out_i = \max_{j} X_{j}$
where max is over j such that `SegmentIds[j] == i`.
)DOC");
  }
};
|
||||
|
||||
class SegmentPoolGradOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
|
||||
framework::GradVarName("Out"), "SegmentPoolGrad");
|
||||
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SegmentPoolGrad");
|
||||
auto og_dims = ctx->GetInputDim(framework::GradVarName("Out"));
|
||||
auto x_dims = ctx->GetInputDim("X");
|
||||
PADDLE_ENFORCE_EQ(og_dims.size(), x_dims.size(),
|
||||
platform::errors::InvalidArgument(
|
||||
"The rank of output grad must equal to Input(X). But "
|
||||
"received: input rank %u, input shape [%s].",
|
||||
og_dims.size(), og_dims));
|
||||
for (int64_t i = 1; i < og_dims.size(); ++i) {
|
||||
PADDLE_ENFORCE_EQ(
|
||||
og_dims[i], x_dims[i],
|
||||
platform::errors::InvalidArgument(
|
||||
"The dimension mismatch between Input(OUT@GRAD) and "
|
||||
"Input(X). Received Input(OUT@GRAD): input rank %u, "
|
||||
"input shape [%s]; received Input(X): input rank %u, "
|
||||
"input shape [%s].",
|
||||
og_dims.size(), og_dims, x_dims.size(), x_dims));
|
||||
}
|
||||
|
||||
ctx->ShareDim("X", /*->*/ framework::GradVarName("X"));
|
||||
}
|
||||
|
||||
protected:
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext& ctx) const override {
|
||||
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
|
||||
ctx, framework::GradVarName("Out")),
|
||||
ctx.device_context());
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class SegmentPoolGradOpMaker : public framework::SingleGradOpMaker<T> {
|
||||
public:
|
||||
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
|
||||
|
||||
protected:
|
||||
void Apply(GradOpPtr<T> op_desc_ptr) const override {
|
||||
op_desc_ptr->SetType("segment_pool_grad");
|
||||
op_desc_ptr->SetInput("X", this->Input("X"));
|
||||
op_desc_ptr->SetInput("SegmentIds", this->Input("SegmentIds"));
|
||||
op_desc_ptr->SetInput("Out", this->Output("Out"));
|
||||
if (BOOST_GET_CONST(std::string, this->GetAttr("pooltype")) == "MEAN") {
|
||||
op_desc_ptr->SetInput("SummedIds", this->Output("SummedIds"));
|
||||
}
|
||||
op_desc_ptr->SetInput(framework::GradVarName("Out"),
|
||||
this->OutputGrad("Out"));
|
||||
op_desc_ptr->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
|
||||
op_desc_ptr->SetAttrMap(this->Attrs());
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OPERATOR(segment_pool, ops::SegmentPoolOp, ops::SegmentPoolOpMaker,
|
||||
ops::SegmentPoolGradOpMaker<paddle::framework::OpDesc>,
|
||||
ops::SegmentPoolGradOpMaker<paddle::imperative::OpBase>);
|
||||
REGISTER_OPERATOR(segment_pool_grad, ops::SegmentPoolGradOp);
|
||||
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
segment_pool,
|
||||
ops::SegmentPoolKernel<paddle::platform::CPUDeviceContext, float>,
|
||||
ops::SegmentPoolKernel<paddle::platform::CPUDeviceContext, double>);
|
||||
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
segment_pool_grad,
|
||||
ops::SegmentPoolGradKernel<paddle::platform::CPUDeviceContext, float>,
|
||||
ops::SegmentPoolGradKernel<paddle::platform::CPUDeviceContext, double>);
|
@ -0,0 +1,130 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include "paddle/fluid/framework/eigen.h"
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/operators/math/math_function.h"
|
||||
#include "paddle/fluid/operators/math/segment_pooling.h"
|
||||
#include "paddle/fluid/platform/macros.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using Tensor = framework::Tensor;
|
||||
|
||||
template <typename DeviceContext, typename T, typename IndexT>
|
||||
void SegmentKernelLaunchHelper(const framework::ExecutionContext& context) {
|
||||
auto* input = context.Input<Tensor>("X");
|
||||
auto* segment = context.Input<Tensor>("SegmentIds");
|
||||
auto* output = context.Output<Tensor>("Out");
|
||||
std::string pooltype = context.Attr<std::string>("pooltype");
|
||||
Tensor* summed_ids = nullptr;
|
||||
|
||||
int64_t num_indices = segment->numel();
|
||||
PADDLE_ENFORCE_EQ(
|
||||
num_indices, input->dims()[0],
|
||||
platform::errors::InvalidArgument(
|
||||
"Segment_ids should be the same size as dimension 0 of input X."));
|
||||
PADDLE_ENFORCE_EQ(num_indices, segment->dims()[0],
|
||||
platform::errors::InvalidArgument(
|
||||
"Segment_ids should be 1-D tensor, or it's other "
|
||||
"dimension size is 1. Segment_ids's shape is: [%s].",
|
||||
segment->dims()));
|
||||
|
||||
if (input->numel() == 0 || segment->numel() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
bool cpu_place = context.GetPlace().type() == typeid(platform::CPUPlace);
|
||||
if (cpu_place) {
|
||||
auto dims = input->dims();
|
||||
auto* segment_ids = segment->data<IndexT>();
|
||||
dims[0] = static_cast<int64_t>(segment_ids[segment->numel() - 1] + 1);
|
||||
PADDLE_ENFORCE_GT(
|
||||
dims[0], 0,
|
||||
platform::errors::InvalidArgument(
|
||||
"Segment ids must be >= 0, but got last id %d", dims[0]));
|
||||
output->Resize({dims});
|
||||
output->mutable_data<T>(context.GetPlace());
|
||||
math::SetConstant<DeviceContext, T> set_zero;
|
||||
auto& dev_ctx = context.template device_context<DeviceContext>();
|
||||
set_zero(dev_ctx, output, static_cast<T>(0));
|
||||
}
|
||||
|
||||
SegmentPoolFunctor<DeviceContext, T, IndexT> pool;
|
||||
|
||||
pool(context.template device_context<DeviceContext>(), *input, *segment,
|
||||
output, summed_ids, pooltype);
|
||||
}
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class SegmentPoolKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& context) const override {
|
||||
auto* segment = context.Input<Tensor>("SegmentIds");
|
||||
auto index_type = segment->type();
|
||||
if (index_type == framework::proto::VarType::INT32) {
|
||||
SegmentKernelLaunchHelper<DeviceContext, T, int>(context);
|
||||
} else if (index_type == framework::proto::VarType::INT64) {
|
||||
SegmentKernelLaunchHelper<DeviceContext, T, int64_t>(context);
|
||||
} else {
|
||||
PADDLE_THROW(platform::errors::InvalidArgument(
|
||||
"Unsupported index type, Expected int, int64, but got %s.",
|
||||
index_type));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class SegmentPoolGradKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& context) const override {
|
||||
auto* input = context.Input<Tensor>("X");
|
||||
auto* output = context.Input<Tensor>("Out");
|
||||
auto* segment = context.Input<Tensor>("SegmentIds");
|
||||
auto* out_g = context.Input<Tensor>(framework::GradVarName("Out"));
|
||||
auto* in_g = context.Output<Tensor>(framework::GradVarName("X"));
|
||||
std::string pooltype = context.Attr<std::string>("pooltype");
|
||||
|
||||
const Tensor* summed_ids = nullptr;
|
||||
if (pooltype == "MEAN") {
|
||||
summed_ids = context.Input<Tensor>("SummedIds");
|
||||
}
|
||||
|
||||
in_g->mutable_data<T>(context.GetPlace());
|
||||
math::SetConstant<DeviceContext, T> set_zero;
|
||||
auto& dev_ctx = context.template device_context<DeviceContext>();
|
||||
set_zero(dev_ctx, in_g, static_cast<T>(0));
|
||||
|
||||
auto index_type = segment->type();
|
||||
if (index_type == framework::proto::VarType::INT32) {
|
||||
SegmentPoolGradFunctor<DeviceContext, T, int> pool;
|
||||
pool(context.template device_context<DeviceContext>(), *input, *output,
|
||||
*out_g, *segment, in_g, summed_ids, pooltype);
|
||||
} else if (index_type == framework::proto::VarType::INT64) {
|
||||
SegmentPoolGradFunctor<DeviceContext, T, int64_t> pool;
|
||||
pool(context.template device_context<DeviceContext>(), *input, *output,
|
||||
*out_g, *segment, in_g, summed_ids, pooltype);
|
||||
} else {
|
||||
PADDLE_THROW(platform::errors::InvalidArgument(
|
||||
"Unsupported index type, Expected int, int64, but got %s.",
|
||||
index_type));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,202 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
import sys
|
||||
from op_test import OpTest
|
||||
|
||||
|
||||
def compute_segment_sum(x, segment_ids):
    """Reference segment sum.

    Row ``i`` of ``x`` is added into row ``segment_ids[i]`` of the result.
    The result has ``segment_ids[-1] + 1`` rows, the remaining dimensions
    and the dtype of ``x``; rows whose id never occurs stay zero.
    """
    num_segments = segment_ids[-1] + 1
    out_shape = list(x.shape)
    out_shape[0] = num_segments
    pooled = np.zeros(out_shape, dtype=x.dtype)
    for row in range(len(segment_ids)):
        seg = segment_ids[row]
        pooled[seg, :] += x[row, :]
    return pooled
|
||||
|
||||
|
||||
def compute_segment_mean(x, segment_ids):
    """Reference segment mean.

    Sums the rows of ``x`` per segment id and divides each sum by the number
    of rows in that segment. The counts start at 1e-8 rather than 0, so the
    division is also defined for ids that never occur.
    """
    num_segments = segment_ids[-1] + 1
    out_shape = list(x.shape)
    out_shape[0] = num_segments
    sums = np.zeros(out_shape, dtype=x.dtype)
    counts = np.zeros(num_segments, dtype=x.dtype) + 1e-8
    for row in range(len(segment_ids)):
        seg = segment_ids[row]
        sums[seg, :] += x[row, :]
        counts[seg] += 1
    return sums / counts.reshape([-1, 1])
|
||||
|
||||
|
||||
def compute_segment_min_max(x, segment_ids, pooltype="MAX"):
    """Reference segment max / min together with its gradient.

    Every run of consecutive equal ids in ``segment_ids`` is one segment; the
    column-wise maximum (``"MAX"``) or minimum (``"MIN"``) of its rows is
    stored in row ``id`` of the result. Any other ``pooltype`` raises
    ``ValueError``.

    Returns ``(results, gradient)``. ``gradient`` has the shape of ``x`` and
    holds ``1 / results.size`` wherever an element equals the pooled value of
    its segment, and 0 elsewhere.
    """
    num_segments = segment_ids[-1] + 1
    out_shape = list(x.shape)
    out_shape[0] = num_segments
    grad_mask = np.zeros_like(x)
    pooled = np.zeros(out_shape, dtype=x.dtype)

    total = len(segment_ids)
    start = 0
    while start < total:
        seg = segment_ids[start]
        # Extend the run while the id stays the same.
        stop = start + 1
        while stop < total and segment_ids[stop] == seg:
            stop += 1

        rows = x[start:stop, :]
        if pooltype == "MAX":
            pooled[seg] = np.amax(rows, axis=0)
        elif pooltype == "MIN":
            pooled[seg] = np.amin(rows, axis=0)
        else:
            raise ValueError("Invalid pooltype, only MAX, MIN supported!")
        # Mark the elements that produced the pooled value.
        grad_mask[start:stop, :][rows == pooled[seg]] = 1
        start = stop

    return pooled, grad_mask / pooled.size
|
||||
|
||||
|
||||
class TestSegmentOps(OpTest):
    """Checks output and gradient of ``segment_pool`` with SUM pooling.

    Subclasses override ``prepare`` / ``compute`` / ``setUp`` to cover the
    other pooling types, dtypes and index types.
    """

    def set_data(self):
        """Return a random ``x`` of ``self.shape`` and matching segment ids."""
        x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
        segment_ids = self.set_segment(len(x), len(x) // 5 + 1)
        return x, segment_ids

    def set_segment(self, origin_len, reduce_len):
        """Return ``origin_len`` sorted int64 ids drawn from [0, reduce_len)."""
        segment = np.random.randint(0, reduce_len, size=[origin_len])
        segment = np.sort(segment)
        return segment.astype('int64')

    def compute(self, x, segment_ids):
        """Reference result for the pooling type under test."""
        return compute_segment_sum(x, segment_ids)

    def prepare(self):
        """Set the op type, dtype, input shape and attributes."""
        self.op_type = "segment_pool"
        self.dtype = np.float64
        self.shape = [30, 15]
        self.attrs = {"pooltype": "SUM"}

    def setUp(self):
        self.prepare()
        x, segment_ids = self.set_data()
        result = self.compute(x, segment_ids)
        self.inputs = {
            'X': x.astype(self.dtype),
            'SegmentIds': segment_ids.astype(np.int64)
        }
        self.outputs = {'Out': result.astype(self.dtype)}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(["X"], "Out")
|
||||
|
||||
|
||||
class TestSegmentSum2(TestSegmentOps):
    """SUM pooling with float32 data of shape [40, 20] and int32 ids."""

    def prepare(self):
        super(TestSegmentSum2, self).prepare()
        self.shape = [40, 20]
        self.dtype = np.float32

    def setUp(self):
        self.prepare()
        data, ids = self.set_data()
        expected = self.compute(data, ids)
        # Same as the base setUp, except that the ids are fed as int32.
        self.inputs = {
            'X': data.astype(self.dtype),
            'SegmentIds': ids.astype(np.int32)
        }
        self.outputs = {'Out': expected.astype(self.dtype)}
|
||||
|
||||
|
||||
class TestSegmentMax(TestSegmentOps):
    """MAX pooling; the gradient is checked against a user-defined one."""

    def compute(self, x, segment_ids):
        """Return the reference ``(result, gradient)`` pair for MAX."""
        return compute_segment_min_max(x, segment_ids, pooltype="MAX")

    def prepare(self):
        super(TestSegmentMax, self).prepare()
        self.shape = [40, 20]
        self.attrs = {'pooltype': "MAX"}

    def setUp(self):
        self.prepare()
        data, ids = self.set_data()
        # compute() also yields the gradient consumed by test_check_grad.
        expected, self.gradient = self.compute(data, ids)
        self.inputs = {
            'X': data.astype(self.dtype),
            'SegmentIds': ids.astype(np.int32)
        }
        self.outputs = {'Out': expected.astype(self.dtype)}

    def test_check_grad(self):
        self.check_grad(["X"], "Out", user_defined_grads=[self.gradient])
|
||||
|
||||
|
||||
class TestSegmentMax2(TestSegmentMax):
    """MAX pooling with float32 data."""

    def prepare(self):
        super(TestSegmentMax2, self).prepare()
        # Only the element type differs from TestSegmentMax.
        self.dtype = np.float32
|
||||
|
||||
|
||||
class TestSegmentMin(TestSegmentMax):
    """MIN pooling; reuses the setUp and gradient check of TestSegmentMax."""

    def compute(self, x, segment_ids):
        """Return the reference ``(result, gradient)`` pair for MIN."""
        return compute_segment_min_max(x, segment_ids, pooltype="MIN")

    def prepare(self):
        super(TestSegmentMin, self).prepare()
        self.attrs = {'pooltype': "MIN"}
|
||||
|
||||
|
||||
class TestSegmentMin2(TestSegmentMin):
    """MIN pooling with float32 data."""

    def prepare(self):
        super(TestSegmentMin2, self).prepare()
        # Only the element type differs from TestSegmentMin.
        self.dtype = np.float32
|
||||
|
||||
|
||||
class TestSegmentMean(TestSegmentOps):
    """MEAN pooling; also declares the expected ``SummedIds`` output."""

    def compute(self, x, segment_ids):
        """Reference result for MEAN."""
        return compute_segment_mean(x, segment_ids)

    def prepare(self):
        super(TestSegmentMean, self).prepare()
        self.shape = [40, 20]
        self.attrs = {'pooltype': "MEAN"}

    def setUp(self):
        self.prepare()
        data, ids = self.set_data()
        expected = self.compute(data, ids)
        # Segment sum of a column of ones, i.e. the number of rows per id.
        ones_column = np.ones([len(data), 1]).astype(self.dtype)
        self.inputs = {'X': data, 'SegmentIds': ids}
        self.outputs = {
            'Out': expected,
            'SummedIds': compute_segment_sum(ones_column, ids)
        }
|
||||
|
||||
|
||||
class TestSegmentMean2(TestSegmentMean):
    """MEAN pooling with float32 data of shape [30, 20]."""

    def prepare(self):
        super(TestSegmentMean2, self).prepare()
        self.dtype = np.float32
        self.shape = [30, 20]
        self.attrs = {'pooltype': "MEAN"}
|
||||
|
||||
|
||||
# Run every test case in this module when the file is executed directly.
if __name__ == '__main__':
    unittest.main()
|
Loading…
Reference in new issue