|
|
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/operators/modified_huber_loss_op.h"
|
|
|
|
|
#include <memory>
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
@ -86,38 +87,55 @@ class ModifiedHuberLossGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), "X must be initialized.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Y"), "Y must be initialized.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("IntermediateVal"),
|
|
|
|
|
"Intermediate value must not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
|
|
|
|
|
"Input(Out@Grad) must not be null.");
|
|
|
|
|
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
auto y_dims = ctx->GetInputDim("Y");
|
|
|
|
|
auto intermediate_dims = ctx->GetInputDim("IntermediateVal");
|
|
|
|
|
auto out_grad_dims = ctx->GetInputDim(framework::GradVarName("Out"));
|
|
|
|
|
|
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
intermediate_dims, x_dims,
|
|
|
|
|
intermediate_dims, y_dims,
|
|
|
|
|
"The shape of X and intermediate value must be the same.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(out_grad_dims, x_dims,
|
|
|
|
|
PADDLE_ENFORCE_EQ(out_grad_dims, y_dims,
|
|
|
|
|
"The shape of Input(Out@Grad) and X must be the same.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (ctx->HasOutput(framework::GradVarName("X"))) {
|
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
|
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("X"), y_dims);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class ModifiedHuberLossGradOpDescMaker
|
|
|
|
|
: public framework::SingleGradOpDescMaker {
|
|
|
|
|
public:
|
|
|
|
|
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
std::unique_ptr<framework::OpDesc> Apply() const override {
|
|
|
|
|
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
|
|
|
|
|
op->SetType("modified_huber_loss_grad");
|
|
|
|
|
op->SetInput("Y", Input("Y"));
|
|
|
|
|
op->SetInput("IntermediateVal", Output("IntermediateVal"));
|
|
|
|
|
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
|
|
|
|
|
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
|
|
|
|
|
op->SetAttrMap(Attrs());
|
|
|
|
|
return op;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
REGISTER_OPERATOR(modified_huber_loss, ops::ModifiedHuberLossOp,
|
|
|
|
|
ops::ModifiedHuberLossOpMaker,
|
|
|
|
|
paddle::framework::DefaultGradOpDescMaker<true>);
|
|
|
|
|
ops::ModifiedHuberLossGradOpDescMaker);
|
|
|
|
|
REGISTER_OPERATOR(modified_huber_loss_grad, ops::ModifiedHuberLossGradOp);
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|