@ -23,24 +23,24 @@ class ScatterOp : public framework::OperatorWithKernel {
using framework : : OperatorWithKernel : : OperatorWithKernel ;
using framework : : OperatorWithKernel : : OperatorWithKernel ;
void InferShape ( framework : : InferShapeContext * ctx ) const override {
void InferShape ( framework : : InferShapeContext * ctx ) const override {
PADDLE_ENFORCE ( ctx - > HasInput ( " Ref " ) ,
PADDLE_ENFORCE ( ctx - > HasInput ( " X " ) ,
" Input( Ref ) of ScatterOp should not be null." ) ;
" Input( X ) of ScatterOp should not be null." ) ;
PADDLE_ENFORCE ( ctx - > HasInput ( " I ndex " ) ,
PADDLE_ENFORCE ( ctx - > HasInput ( " I ds " ) ,
" Input(I ndex ) of ScatterOp should not be null." ) ;
" Input(I ds ) of ScatterOp should not be null." ) ;
PADDLE_ENFORCE ( ctx - > HasInput ( " Updates " ) ,
PADDLE_ENFORCE ( ctx - > HasInput ( " Updates " ) ,
" Input(Updates) of ScatterOp should not be null. " ) ;
" Input(Updates) of ScatterOp should not be null. " ) ;
PADDLE_ENFORCE ( ctx - > HasOutput ( " Out " ) ,
PADDLE_ENFORCE ( ctx - > HasOutput ( " Out " ) ,
" Output(Out) of ScatterOp should not be null. " ) ;
" Output(Out) of ScatterOp should not be null. " ) ;
auto updates_dims = ctx - > GetInputDim ( " Updates " ) ;
auto updates_dims = ctx - > GetInputDim ( " Updates " ) ;
auto ref_dims = ctx - > GetInputDim ( " Ref " ) ;
auto ref_dims = ctx - > GetInputDim ( " X " ) ;
PADDLE_ENFORCE_EQ ( ctx - > GetInputDim ( " I ndex " ) . size ( ) , 1 ,
PADDLE_ENFORCE_EQ ( ctx - > GetInputDim ( " I ds " ) . size ( ) , 1 ,
" Update I ndex should be 1-D." ) ;
" Update I ds should be 1-D." ) ;
PADDLE_ENFORCE_EQ ( ref_dims . size ( ) , updates_dims . size ( ) ,
PADDLE_ENFORCE_EQ ( ref_dims . size ( ) , updates_dims . size ( ) ,
" Ref erence and Updates should have the same shape size" ) ;
" X erence and Updates should have the same shape size" ) ;
PADDLE_ENFORCE_EQ ( ctx - > GetInputDim ( " Updates " ) [ 0 ] ,
PADDLE_ENFORCE_EQ ( ctx - > GetInputDim ( " Updates " ) [ 0 ] ,
ctx - > GetInputDim ( " I ndex " ) [ 0 ] ,
ctx - > GetInputDim ( " I ds " ) [ 0 ] ,
" Updates and I ndex should have same batch-size." ) ;
" Updates and I ds should have same batch-size." ) ;
framework : : DDim data_dim ( updates_dims ) ;
framework : : DDim data_dim ( updates_dims ) ;
for ( int i = 1 ; i < data_dim . size ( ) ; + + i ) {
for ( int i = 1 ; i < data_dim . size ( ) ; + + i ) {
PADDLE_ENFORCE_EQ ( data_dim [ i ] , updates_dims [ i ] ) ;
PADDLE_ENFORCE_EQ ( data_dim [ i ] , updates_dims [ i ] ) ;
@ -52,7 +52,7 @@ class ScatterOp : public framework::OperatorWithKernel {
framework : : OpKernelType GetExpectedKernelType (
framework : : OpKernelType GetExpectedKernelType (
const framework : : ExecutionContext & ctx ) const override {
const framework : : ExecutionContext & ctx ) const override {
return framework : : OpKernelType (
return framework : : OpKernelType (
framework : : ToDataType ( ctx . Input < Tensor > ( " Ref " ) - > type ( ) ) ,
framework : : ToDataType ( ctx . Input < Tensor > ( " X " ) - > type ( ) ) ,
ctx . device_context ( ) ) ;
ctx . device_context ( ) ) ;
}
}
} ;
} ;
@ -64,14 +64,14 @@ class ScatterGradOp : public framework::OperatorWithKernel {
void InferShape ( framework : : InferShapeContext * ctx ) const override {
void InferShape ( framework : : InferShapeContext * ctx ) const override {
ctx - > SetOutputDim ( framework : : GradVarName ( " Updates " ) ,
ctx - > SetOutputDim ( framework : : GradVarName ( " Updates " ) ,
ctx - > GetInputDim ( " Updates " ) ) ;
ctx - > GetInputDim ( " Updates " ) ) ;
ctx - > SetOutputDim ( framework : : GradVarName ( " Ref " ) , ctx - > GetInputDim ( " Ref " ) ) ;
ctx - > SetOutputDim ( framework : : GradVarName ( " X " ) , ctx - > GetInputDim ( " X " ) ) ;
}
}
protected :
protected :
framework : : OpKernelType GetExpectedKernelType (
framework : : OpKernelType GetExpectedKernelType (
const framework : : ExecutionContext & ctx ) const override {
const framework : : ExecutionContext & ctx ) const override {
return framework : : OpKernelType (
return framework : : OpKernelType (
framework : : ToDataType ( ctx . Input < Tensor > ( " Ref " ) - > type ( ) ) ,
framework : : ToDataType ( ctx . Input < Tensor > ( " X " ) - > type ( ) ) ,
ctx . device_context ( ) ) ;
ctx . device_context ( ) ) ;
}
}
} ;
} ;
@ -80,9 +80,8 @@ class ScatterOpMaker : public framework::OpProtoAndCheckerMaker {
public :
public :
ScatterOpMaker ( OpProto * proto , OpAttrChecker * op_checker )
ScatterOpMaker ( OpProto * proto , OpAttrChecker * op_checker )
: OpProtoAndCheckerMaker ( proto , op_checker ) {
: OpProtoAndCheckerMaker ( proto , op_checker ) {
AddInput ( " Ref " , " The source input of scatter op " ) ;
AddInput ( " X " , " The source input of scatter op " ) ;
AddInput ( " Index " ,
AddInput ( " Ids " , " The index input of scatter op where X will be updated " ) ;
" The index input of scatter op where Ref will be updated " ) ;
AddInput ( " Updates " , " The updated value of updates op " ) ;
AddInput ( " Updates " , " The updated value of updates op " ) ;
AddOutput ( " Out " , " The output of add op " ) ;
AddOutput ( " Out " , " The output of add op " ) ;
AddComment ( R " DOC(
AddComment ( R " DOC(
@ -91,8 +90,8 @@ Scatter Operator.
This operator obtains output by updating the input on selected indices on the first axis :
This operator obtains output by updating the input on selected indices on the first axis :
$ $
$ $
Out = Ref \ \
Out = X \ \
Out [ I ndex] = Ref [ Index ] + Updates
Out [ I ds] = X [ Ids ] + Updates
$ $
$ $
) DOC " );
) DOC " );