Enhance print_op.

add_depthwiseConv_op_gpu
yangyaming 7 years ago
parent 9deb1756a2
commit a091d1a31c

@ -16,12 +16,17 @@
#include <ctime>
#include "paddle/framework/op_registry.h"
#include "paddle/framework/variable.h"
namespace paddle {
namespace operators {
#define CLOG std::cout
const std::string kForward = "FORWARD";
const std::string kBackward = "BACKWARD";
const std::string kBoth = "BOTH";
struct Formater {
std::string message;
std::string name;
@ -122,40 +127,77 @@ class TensorPrintOp : public framework::OperatorBase {
TensorPrintOp(const TensorPrintOp& o)
: framework::OperatorBase(
static_cast<const framework::OperatorBase&>(o)) {
PADDLE_THROW("Not implemented");
PADDLE_THROW("Not implemented.");
}
void Run(const framework::Scope& scope,
const platform::Place& place) const override {
// Only run the `first_n` times.
const framework::Variable* in_var_ptr = nullptr;
std::string phase = kForward;
std::string printed_var_name = "";
auto& inputs = Inputs();
if (inputs.find("In") != inputs.end() && !Inputs("In").empty()) {
in_var_ptr = scope.FindVar(Input("In"));
printed_var_name = Inputs("In").front();
} else if (inputs.find("In@GRAD") != inputs.end() &&
!Inputs("In@GRAD").empty()) {
in_var_ptr = scope.FindVar(Input("In@GRAD"));
printed_var_name = Inputs("In@GRAD").front();
phase = kBackward;
} else {
PADDLE_THROW("Unknown phase, should be forward or backward.");
}
PADDLE_ENFORCE_NOT_NULL(in_var_ptr);
auto& in_tensor = in_var_ptr->Get<framework::LoDTensor>();
auto* out_var_ptr = scope.FindVar(Output("Out"));
auto& out_tensor = *out_var_ptr->GetMutable<framework::LoDTensor>();
// Just copy data from input tensor to output tensor
// output tensor share same memory with input tensor
out_tensor.ShareDataWith(in_tensor);
out_tensor.set_lod(in_tensor.lod());
std::string print_phase = Attr<std::string>("print_phase");
if (print_phase != phase && print_phase != kBoth) {
return;
}
int first_n = Attr<int>("first_n");
if (first_n > 0 && ++times_ > first_n) return;
PADDLE_ENFORCE(!Inputs("input").empty(), "input should be set");
auto* input_var = scope.FindVar(Input("input"));
PADDLE_ENFORCE_NOT_NULL(input_var);
auto& tensor = input_var->Get<framework::LoDTensor>();
framework::LoDTensor printed_tensor;
printed_tensor.set_lod(in_tensor.lod());
printed_tensor.Resize(in_tensor.dims());
// TODO(ChunweiYan) support GPU
PADDLE_ENFORCE(platform::is_cpu_place(tensor.place()));
if (platform::is_cpu_place(in_tensor.place())) {
printed_tensor.ShareDataWith(in_tensor);
} else {
// copy data to cpu to print
platform::CPUPlace place;
framework::Copy(in_tensor, place, &printed_tensor);
}
Formater formater;
if (Attr<bool>("print_tensor_name")) {
formater.name = Inputs("input").front();
formater.name = printed_var_name;
}
if (Attr<bool>("print_tensor_type")) {
formater.dtype = tensor.type();
formater.dtype = printed_tensor.type();
}
if (Attr<bool>("print_tensor_shape")) {
formater.dims.assign(tensor.dims()[0],
tensor.dims()[tensor.dims().size() - 1]);
auto& dims = printed_tensor.dims();
formater.dims.resize(dims.size());
for (int i = 0; i < dims.size(); ++i) formater.dims[i] = dims[i];
}
if (Attr<bool>("print_tensor_lod")) {
formater.lod = tensor.lod();
formater.lod = printed_tensor.lod();
}
formater.summarize = Attr<int>("summarize");
formater.data = (void*)tensor.data<void>();
formater(tensor.numel());
formater.data = (void*)printed_tensor.data<void>();
formater(printed_tensor.numel());
}
private:
@ -166,27 +208,46 @@ class PrintOpProtoAndCheckMaker : public framework::OpProtoAndCheckerMaker {
public:
PrintOpProtoAndCheckMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "the tensor that will be displayed.");
AddInput("In", "Input tensor to be displayed.");
AddAttr<int>("first_n", "Only log `first_n` number of times.");
AddAttr<std::string>("message", "A string message to print as a prefix.");
AddAttr<int>("summarize", "Print this number of elements in the tensor.");
AddAttr<int>("summarize", "Number of elements printed.");
AddAttr<bool>("print_tensor_name", "Whether to print the tensor name.");
AddAttr<bool>("print_tensor_type", "Whether to print the tensor's dtype.");
AddAttr<bool>("print_tensor_shape", "Whether to print the tensor's shape.");
AddAttr<bool>("print_tensor_lod", "Whether to print the tensor's lod.");
AddAttr<std::string>(
"print_phase",
"(string, default 'BOTH') Which phase to display including 'FORWARD' "
"'BACKWARD' and 'BOTH'.")
.SetDefault(kBoth)
.InEnum({kForward, kBackward, kBoth});
AddOutput("Out", "Output tensor with same data as input tensor.");
AddComment(R"DOC(
Creates a print op that will print when a tensor is accessed.
Creates a print op that will print when a tensor is accessed.
Wraps the tensor passed in so that whenever that a tensor is accessed,
the message `message` is printed, along with the current value of the
tensor `t`.)DOC");
Wraps the tensor passed in so that whenever that a tensor is accessed,
the message `message` is printed, along with the current value of the
tensor `t`.)DOC");
}
};
class InferShape : public framework::InferShapeBase {
class InferShapeForward : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* context) const override {
PADDLE_ENFORCE(context->HasInput("input"), "input should be set");
PADDLE_ENFORCE(context->HasInput("In"), "Input(In) should not be null.");
context->ShareLoD("In", /*->*/ "Out");
context->SetOutputDim("Out", context->GetInputDim("In"));
}
};
class InferShapeBackward : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* context) const override {
PADDLE_ENFORCE(context->HasInput("In@GRAD"),
"Input(In@GRAD) should not be null.");
context->ShareLoD("In@GRAD", /*->*/ "Out");
context->SetOutputDim("Out", context->GetInputDim("In@GRAD"));
}
};
@ -196,11 +257,27 @@ class InferVarType : public framework::VarTypeInference {
framework::BlockDesc* block) const override {}
};
class PrintOpProtoAndCheckGradOpMaker
: public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
std::unique_ptr<framework::OpDesc> Apply() const override {
auto* op_desc_ptr = new framework::OpDesc();
op_desc_ptr->SetType("print_grad");
op_desc_ptr->SetInput("In@GRAD", OutputGrad("Out"));
op_desc_ptr->SetOutput("Out", InputGrad("In"));
op_desc_ptr->SetAttrMap(Attrs());
return std::unique_ptr<framework::OpDesc>(op_desc_ptr);
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(print, paddle::operators::TensorPrintOp,
paddle::operators::PrintOpProtoAndCheckMaker,
paddle::operators::InferShape,
paddle::operators::InferVarType,
paddle::framework::EmptyGradOpMaker);
namespace ops = paddle::operators;
REGISTER_OPERATOR(print, ops::TensorPrintOp, ops::PrintOpProtoAndCheckMaker,
ops::PrintOpProtoAndCheckGradOpMaker, ops::InferShapeForward,
ops::InferVarType);
REGISTER_OPERATOR(print_grad, ops::TensorPrintOp, ops::InferShapeBackward);

@ -117,7 +117,8 @@ def Print(input,
print_tensor_name=True,
print_tensor_type=True,
print_tensor_shape=True,
print_tensor_lod=True):
print_tensor_lod=True,
print_phase='both'):
'''
**Print operator**
@ -128,18 +129,21 @@ def Print(input,
tensor `t`.
Args:
input(Variable): A Tensor to print.
summarize(int): Print this number of elements in the tensor, will print all
if left negative.
message(str): A string message to print as a prefix.
first_n(int): Only log `first_n` number of times.
print_tensor_name(bool): Print the tensor name.
print_tensor_type(bool): Print the tensor type.
print_tensor_shape(bool): Print the tensor shape.
print_tensor_lod(bool): Print the tensor lod.
input (Variable): A Tensor to print.
summarize (int): Print this number of elements in the tensor, will print
all if left is negative.
message (str): A string message to print as a prefix.
first_n (int): Only log `first_n` number of times.
print_tensor_name (bool): Print the tensor name.
print_tensor_type (bool): Print the tensor type.
print_tensor_shape (bool): Print the tensor shape.
print_tensor_lod (bool): Print the tensor lod.
print_phase (bool): Which phase to displace, including 'forward',
'backward' and 'both'. If set to 'backward' or 'both', will
print the gradients of input tensor.
Returns:
None
Variable: Output tensor, same data with input tensor.
Examples:
.. code-block:: python
@ -149,10 +153,10 @@ def Print(input,
message="The content of some_layer: ")
'''
helper = LayerHelper('print', **locals())
out = helper.create_tmp_variable(dtype='int32')
out = helper.create_tmp_variable(dtype=helper.input_dtype())
helper.append_op(
type='print',
inputs={'input': input},
inputs={'In': input},
attrs={
'first_n': first_n,
'summarize': summarize,
@ -161,7 +165,9 @@ def Print(input,
'print_tensor_type': print_tensor_type,
'print_tensor_shape': print_tensor_shape,
'print_tensor_lod': print_tensor_lod,
})
'print_phase': print_phase.upper()
},
outputs={'Out': out})
return out

@ -1,20 +1,54 @@
import unittest
import numpy as np
from paddle.v2.fluid.executor import Executor
import paddle.v2.fluid.core as core
import paddle.v2.fluid.layers as pd
from paddle.v2.fluid.executor import Executor
import paddle.v2.fluid.layers as layers
from paddle.v2.fluid.backward import append_backward
from paddle.v2.fluid.framework import switch_main_program
from paddle.v2.fluid.framework import Program
import numpy as np
class TestPrintOpCPU(unittest.TestCase):
def setUp(self):
self.place = core.CPUPlace()
self.x_tensor = core.LoDTensor()
tensor_np = np.random.random(size=(2, 3)).astype('float32')
self.x_tensor.set(tensor_np, self.place)
self.x_tensor.set_lod([[0, 1, 1]])
def build_network(self, only_forward, **kargs):
x = layers.data('x', shape=[3], dtype='float32', lod_level=1)
x.stop_gradient = False
printed = layers.Print(input=x, **kargs)
if only_forward: return printed
loss = layers.mean(x=printed)
append_backward(loss=loss)
return loss
class TestSumOp(unittest.TestCase):
def test_tensor(self):
i = pd.zeros(shape=[2, 10], dtype='float32')
def test_forward(self):
switch_main_program(Program())
printed = self.build_network(True, print_phase='forward')
exe = Executor(self.place)
outs = exe.run(feed={'x': self.x_tensor},
fetch_list=[printed],
return_numpy=False)
pd.Print(i, message="I am a message", summarize=10)
def test_backward(self):
switch_main_program(Program())
loss = self.build_network(False, print_phase='backward')
exe = Executor(self.place)
outs = exe.run(feed={'x': self.x_tensor},
fetch_list=[loss],
return_numpy=False)
cpu = core.CPUPlace()
exe = Executor(cpu)
exe.run()
class TestPrintOpGPU(TestPrintOpCPU):
def setUp(self):
self.place = core.CUDAPlace(0)
self.x_tensor = core.LoDTensor()
tensor_np = np.random.random(size=(2, 3)).astype('float32')
self.x_tensor.set(tensor_np, self.place)
self.x_tensor.set_lod([[0, 1, 1]])
if __name__ == '__main__':

Loading…
Cancel
Save