@@ -135,33 +135,34 @@ struct Formater {
 };
 
 // TODO(ChunweiYan) there should be some other printers for TensorArray
-class TensorPrintOp : public framework::OperatorBase {
+class PrintOp : public framework::OperatorBase {
  public:
-  TensorPrintOp(const std::string &type,
-                const framework::VariableNameMap &inputs,
-                const framework::VariableNameMap &outputs,
-                const framework::AttributeMap &attrs)
+  PrintOp(const std::string &type, const framework::VariableNameMap &inputs,
+          const framework::VariableNameMap &outputs,
+          const framework::AttributeMap &attrs)
       : OperatorBase(type, inputs, outputs, attrs) {}
 
-  TensorPrintOp(const TensorPrintOp &o)
-      : framework::OperatorBase(
-            static_cast<const framework::OperatorBase &>(o)) {
-    PADDLE_THROW("Not implemented.");
-  }
-
  private:
   void RunImpl(const framework::Scope &scope,
                const platform::Place &place) const override {
-    const framework::Variable *in_var_ptr = nullptr;
-    std::string printed_var_name = "";
-    in_var_ptr = scope.FindVar(Input("In"));
-    printed_var_name = Inputs("In").front();
-    PADDLE_ENFORCE_NOT_NULL(in_var_ptr);
-    auto &in_tensor = in_var_ptr->Get<framework::LoDTensor>();
+    const auto in_var = scope.FindVar(Input("In"));
+    auto out_var = scope.FindVar(Output("Out"));
+    PADDLE_ENFORCE_NOT_NULL(in_var, "The input should not be found in scope",
+                            Input("In"));
+    PADDLE_ENFORCE_NOT_NULL(out_var, "The output should not be found in scope",
+                            Output("Out"));
+    auto &in_tensor = in_var->Get<framework::LoDTensor>();
+    framework::LoDTensor *out_tensor =
+        out_var->GetMutable<framework::LoDTensor>();
 
     PrintValue(place, Inputs("In").front(), in_tensor);
+    framework::TensorCopy(in_tensor, place, out_tensor);
+    out_tensor->set_lod(in_tensor.lod());
   }
 
   void PrintValue(const platform::Place &place,
                   const std::string &printed_var_name,
                   const framework::LoDTensor &in_tensor) const {
     std::string print_phase = Attr<std::string>("print_phase");
     bool is_forward = Attr<bool>("is_forward");
@@ -177,12 +178,12 @@ class TensorPrintOp : public framework::OperatorBase {
     printed_tensor.set_lod(in_tensor.lod());
     printed_tensor.Resize(in_tensor.dims());
 
-    if (platform::is_cpu_place(in_tensor.place())) {
+    if (is_cpu_place(in_tensor.place())) {
       printed_tensor.ShareDataWith(in_tensor);
     } else {
       // copy data to cpu to print
       platform::CPUPlace place;
-      framework::TensorCopy(in_tensor, place, &printed_tensor);
+      TensorCopy(in_tensor, place, &printed_tensor);
     }
 
     Formater formater;
@@ -215,6 +216,7 @@ class PrintOpProtoAndCheckMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
     AddInput("In", "Input tensor to be displayed.");
+    AddOutput("Out", "The output tensor.");
     AddAttr<int>("first_n", "Only log `first_n` number of times.");
     AddAttr<std::string>("message", "A string message to print as a prefix.");
     AddAttr<int>("summarize", "Number of elements printed.");
@@ -239,10 +241,23 @@ tensor `t`.)DOC");
   }
 };
 
-class InferShapeForward : public framework::InferShapeBase {
+class PrintOpInferShape : public framework::InferShapeBase {
  public:
-  void operator()(framework::InferShapeContext *context) const override {
-    PADDLE_ENFORCE(context->HasInput("In"), "Input(In) should not be null.");
+  void operator()(framework::InferShapeContext *ctx) const override {
+    VLOG(10) << "PrintOpInferShape";
+    PADDLE_ENFORCE(ctx->HasInput("In"), "Input(In) should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");
+    ctx->ShareDim("In", /*->*/ "Out");
+    ctx->ShareLoD("In", /*->*/ "Out");
+  }
+};
+
+class PrintOpVarTypeInference : public framework::VarTypeInference {
+ public:
+  void operator()(framework::InferVarTypeContext *ctx) const override {
+    auto input_type = ctx->GetType(ctx->Input("In")[0]);
+    auto out_name = ctx->Output("Out").front();
+    ctx->SetType(out_name, input_type);
   }
 };
@@ -253,7 +268,8 @@ class PrintOpGradientMaker : public framework::SingleGradOpDescMaker {
   std::unique_ptr<framework::OpDesc> Apply() const override {
     auto *op_desc_ptr = new framework::OpDesc();
     op_desc_ptr->SetType("print");
-    op_desc_ptr->SetInput("In", InputGrad("In"));
+    op_desc_ptr->SetInput("In", OutputGrad("Out"));
+    op_desc_ptr->SetOutput("Out", InputGrad("In"));
     op_desc_ptr->SetAttrMap(Attrs());
     op_desc_ptr->SetAttr("is_forward", false);
     return std::unique_ptr<framework::OpDesc>(op_desc_ptr);
@@ -265,5 +281,6 @@ class PrintOpGradientMaker : public framework::SingleGradOpDescMaker {
 
 namespace ops = paddle::operators;
 
-REGISTER_OPERATOR(print, ops::TensorPrintOp, ops::PrintOpProtoAndCheckMaker,
-                  ops::PrintOpGradientMaker, ops::InferShapeForward);
+REGISTER_OPERATOR(print, ops::PrintOp, ops::PrintOpProtoAndCheckMaker,
+                  ops::PrintOpGradientMaker, ops::PrintOpInferShape,
+                  ops::PrintOpVarTypeInference);
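
Not part of the patch itself: a minimal C++ sketch of how a `print` op description is wired up once this change lands. With the new `Out` output the op must be given both an input and an output variable; `Out` simply forwards `In` (data and LoD), and the gradient maker above emits a second `print` op whose `In` is `GRAD(Out)` and whose `Out` is `GRAD(In)`, so values can also be echoed during the backward pass. The block-building calls below use the standard `framework::BlockDesc`/`OpDesc` API; the variable names, attribute values, and the `"FORWARD"` phase string are assumptions for illustration, and only a subset of the attributes registered by the maker is shown.

```cpp
// Sketch only: appending a `print` op to a block after this change.
// The names "x" / "x.printed" and the attribute values are placeholders.
#include <string>

#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h"

void AppendPrintOp(paddle::framework::BlockDesc *block) {
  auto *op = block->AppendOp();
  op->SetType("print");
  op->SetInput("In", {"x"});            // tensor to display
  op->SetOutput("Out", {"x.printed"});  // new output: forwards In's data and LoD
  op->SetAttr("first_n", 10);           // only log the first 10 invocations
  op->SetAttr("message", std::string("x = "));
  op->SetAttr("summarize", 20);         // number of elements printed
  op->SetAttr("print_phase", std::string("FORWARD"));  // assumed value of kForward
  op->SetAttr("is_forward", true);
}
```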