|
|
|
@ -24,8 +24,18 @@ class ScatterOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(const framework::InferShapeContext &ctx) const override {
|
|
|
|
|
framework::DDim output_dims(ctx.Input<Tensor>("Ref")->dims());
|
|
|
|
|
ctx.Output<Tensor>("Out")->Resize(output_dims);
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("Index")->dims().size(), 1,
|
|
|
|
|
"Update Index should be 1-D.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("Ref")->dims().size(),
|
|
|
|
|
ctx.Input<Tensor>("Updates")->dims().size(),
|
|
|
|
|
"Reference and Updates should have the same shape size");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("Updates")->dims()[0],
|
|
|
|
|
ctx.Input<Tensor>("Index")->dims()[0],
|
|
|
|
|
"Updates and Index should have same batch-size.");
|
|
|
|
|
framework::DDim data_dim(ctx.Input<Tensor>("Updates")->dims());
|
|
|
|
|
for (int i = 1; i < data_dim.size(); ++i)
|
|
|
|
|
PADDLE_ENFORCE_EQ(data_dim[i], ctx.Input<Tensor>("Updates")->dims()[i]);
|
|
|
|
|
ctx.Output<Tensor>("Out")->Resize(ctx.Input<Tensor>("Ref")->dims());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -35,13 +45,13 @@ class ScatterGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(const framework::InferShapeContext &ctx) const override {
|
|
|
|
|
auto Updates_grad = ctx.Output<Tensor>(framework::GradVarName("Updates"));
|
|
|
|
|
auto Updates = ctx.Input<Tensor>("Updates");
|
|
|
|
|
auto Ref_grad = ctx.Output<Tensor>(framework::GradVarName("Ref"));
|
|
|
|
|
auto Ref = ctx.Input<Tensor>("Ref");
|
|
|
|
|
auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
|
|
|
|
|
auto *Updates = ctx.Input<Tensor>("Updates");
|
|
|
|
|
auto *dRef = ctx.Output<Tensor>(framework::GradVarName("Ref"));
|
|
|
|
|
auto *Ref = ctx.Input<Tensor>("Ref");
|
|
|
|
|
|
|
|
|
|
Ref_grad->Resize(Ref->dims());
|
|
|
|
|
Updates_grad->Resize(Updates->dims());
|
|
|
|
|
dRef->Resize(Ref->dims());
|
|
|
|
|
dUpdates->Resize(Updates->dims());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|