|
|
|
@ -148,10 +148,14 @@ class PrintOp : public framework::OperatorBase {
|
|
|
|
|
const platform::Place &place) const override {
|
|
|
|
|
const auto in_var = scope.FindVar(Input("In"));
|
|
|
|
|
auto out_var = scope.FindVar(Output("Out"));
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(in_var, "The input should not be found in scope",
|
|
|
|
|
Input("In"));
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(out_var, "The output should not be found in scope",
|
|
|
|
|
Output("Out"));
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
in_var, platform::errors::NotFound("The input:%s not found in scope",
|
|
|
|
|
Input("In")));
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
out_var, platform::errors::NotFound("The output:%s not found in scope",
|
|
|
|
|
Output("Out")));
|
|
|
|
|
|
|
|
|
|
auto &in_tensor = in_var->Get<framework::LoDTensor>();
|
|
|
|
|
framework::LoDTensor *out_tensor =
|
|
|
|
|
out_var->GetMutable<framework::LoDTensor>();
|
|
|
|
@ -246,8 +250,8 @@ class PrintOpInferShape : public framework::InferShapeBase {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
VLOG(10) << "PrintOpInferShape";
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("In"), "Input(In) should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("In"), "Input", "In", "Print");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Print");
|
|
|
|
|
ctx->ShareDim("In", /*->*/ "Out");
|
|
|
|
|
ctx->ShareLoD("In", /*->*/ "Out");
|
|
|
|
|
}
|
|
|
|
|