|
|
|
@ -66,27 +66,47 @@ class AssignFunctor {
|
|
|
|
|
const platform::DeviceContext &dev_ctx_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class AssignOp : public framework::OperatorBase {
|
|
|
|
|
class AssignOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
AssignOp(const std::string &type, const framework::VariableNameMap &inputs,
|
|
|
|
|
const framework::VariableNameMap &outputs,
|
|
|
|
|
const framework::AttributeMap &attrs)
|
|
|
|
|
: OperatorBase(type, inputs, outputs, attrs) {}
|
|
|
|
|
: OperatorWithKernel(type, inputs, outputs, attrs) {}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
void RunImpl(const framework::Scope &scope,
|
|
|
|
|
const platform::Place &place) const override {
|
|
|
|
|
auto *x = scope.FindVar(Input("X"));
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
if (ctx->HasInput("X")) {
|
|
|
|
|
auto type = ctx->GetInputsVarType("X")[0];
|
|
|
|
|
if (type == framework::proto::VarType::SELECTED_ROWS ||
|
|
|
|
|
type == framework::proto::VarType::LOD_TENSOR) {
|
|
|
|
|
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
|
|
|
|
|
if (type == framework::proto::VarType::LOD_TENSOR) {
|
|
|
|
|
ctx->ShareLoD("X", /*->*/ "Out");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
framework::OpKernelType GetExpectedKernelType(
|
|
|
|
|
const framework::ExecutionContext &ctx) const override {
|
|
|
|
|
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
|
|
|
|
|
ctx.device_context());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class AssignKernel {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(const framework::ExecutionContext &ctx) const {
|
|
|
|
|
auto *x = ctx.InputVar("X");
|
|
|
|
|
if (x == nullptr) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
auto *out = scope.FindVar(Output("Out"));
|
|
|
|
|
auto *out = ctx.OutputVar("Out");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
out != nullptr,
|
|
|
|
|
"The Output(Out) should not be null if the Input(X) is set.");
|
|
|
|
|
|
|
|
|
|
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
|
|
|
|
|
auto &dev_ctx = *pool.Get(place);
|
|
|
|
|
auto &dev_ctx = *pool.Get(ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
framework::VisitVarType(*x, AssignFunctor(out, dev_ctx));
|
|
|
|
|
}
|
|
|
|
@ -110,19 +130,6 @@ raise error if the type is not listed above.
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class AssignInferShape : public framework::InferShapeBase {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(framework::InferShapeContext *context) const override {
|
|
|
|
|
if (context->HasInput("X")) {
|
|
|
|
|
auto type = context->GetInputsVarType("X")[0];
|
|
|
|
|
if (type == framework::proto::VarType::SELECTED_ROWS ||
|
|
|
|
|
type == framework::proto::VarType::LOD_TENSOR) {
|
|
|
|
|
context->SetOutputDim("Out", context->GetInputDim("X"));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class AssignGradMaker : public framework::SingleGradOpDescMaker {
|
|
|
|
|
public:
|
|
|
|
|
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
|
|
|
|
@ -142,4 +149,13 @@ class AssignGradMaker : public framework::SingleGradOpDescMaker {
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
REGISTER_OPERATOR(assign, ops::AssignOp, ops::AssignGradMaker,
|
|
|
|
|
ops::AssignInferShape, ops::AssignOpProtoMaker);
|
|
|
|
|
ops::AssignOpProtoMaker);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL_FUNCTOR(assign, float, ops::AssignKernel, double,
|
|
|
|
|
ops::AssignKernel, int, ops::AssignKernel,
|
|
|
|
|
int64_t, ops::AssignKernel);
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
REGISTER_OP_CUDA_KERNEL_FUNCTOR(assign, float, ops::AssignKernel, double,
|
|
|
|
|
ops::AssignKernel, int, ops::AssignKernel,
|
|
|
|
|
int64_t, ops::AssignKernel);
|
|
|
|
|
#endif
|
|
|
|
|