|
|
|
@ -11,6 +11,7 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/operators/sum_op.h"
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include "paddle/operators/net_op.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
@ -57,21 +58,23 @@ or not. But the output only shares the LoD with the first input.
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class SumGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
class SumGradOp : public NetOp {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
SumGradOp(const std::string& type, const framework::VariableNameMap& inputs,
|
|
|
|
|
const framework::VariableNameMap& outputs,
|
|
|
|
|
const framework::AttributeMap& attrs)
|
|
|
|
|
: NetOp(type, inputs, outputs, attrs) {
|
|
|
|
|
auto& x_grad_names = Outputs(framework::GradVarName("X"));
|
|
|
|
|
auto out_grad_name = this->Input(framework::GradVarName("Out"));
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(framework::InferShapeContextBase* ctx) const override {
|
|
|
|
|
auto out_grad_dims = ctx->GetInputDim(framework::GradVarName("Out"));
|
|
|
|
|
auto x_grad_names = ctx->Outputs(framework::GradVarName("X"));
|
|
|
|
|
size_t x_length = x_grad_names.size();
|
|
|
|
|
std::vector<framework::DDim> x_grad_dims;
|
|
|
|
|
x_grad_dims.reserve(x_length);
|
|
|
|
|
for (size_t i = 0; i < x_length; ++i) {
|
|
|
|
|
x_grad_dims.push_back(out_grad_dims);
|
|
|
|
|
framework::AttributeMap grad_attrs;
|
|
|
|
|
grad_attrs["scale"] = 1.0f;
|
|
|
|
|
for (auto& x_grad_name : x_grad_names) {
|
|
|
|
|
AppendOp(framework::OpRegistry::CreateOp(
|
|
|
|
|
"scale", {{"X", {out_grad_name}}}, {{"Out", {x_grad_name}}},
|
|
|
|
|
grad_attrs));
|
|
|
|
|
}
|
|
|
|
|
ctx->SetOutputsDim(framework::GradVarName("X"), x_grad_dims);
|
|
|
|
|
CompleteAddOp(false);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -81,5 +84,3 @@ class SumGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
REGISTER_OP(sum, ops::SumOp, ops::SumOpMaker, sum_grad, ops::SumGradOp);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(sum, ops::SumKernel<paddle::platform::CPUPlace, float>);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(sum_grad,
|
|
|
|
|
ops::SumGradKernel<paddle::platform::CPUPlace, float>);
|
|
|
|
|