add inplace to assign op, test=develop (#19927)

expand_as_op_1
Zeng Jinle 6 years ago committed by GitHub
parent 55ce696986
commit cc157d5990
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -45,15 +45,16 @@ class NoNeedBufferVarsInference {
const AttributeMap &attrs_;
};
#define DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(class_type, ...) \
class class_type : public ::paddle::framework::NoNeedBufferVarsInference { \
public: \
using ::paddle::framework::NoNeedBufferVarsInference:: \
NoNeedBufferVarsInference; \
\
std::unordered_set<std::string> operator()() const override { \
return {__VA_ARGS__}; \
} \
#define DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(class_type, ...) \
class class_type final \
: public ::paddle::framework::NoNeedBufferVarsInference { \
public: \
using ::paddle::framework::NoNeedBufferVarsInference:: \
NoNeedBufferVarsInference; \
\
std::unordered_set<std::string> operator()() const final { \
return {__VA_ARGS__}; \
} \
}
} // namespace framework

@ -144,12 +144,14 @@ class AssignGradMaker : public framework::SingleGradOpDescMaker {
}
};
DECLARE_INPLACE_OP_INFERER(AssignOpInplaceInferer, {"X", "Out"});
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(assign, ops::AssignOp, ops::AssignGradMaker,
ops::AssignOpProtoMaker);
ops::AssignOpProtoMaker, ops::AssignOpInplaceInferer);
REGISTER_OP_CPU_KERNEL_FUNCTOR(assign, float, ops::AssignKernel, double,
ops::AssignKernel, int, ops::AssignKernel,
int64_t, ops::AssignKernel);

Loading…
Cancel
Save