diff --git a/CMakeLists.txt b/CMakeLists.txt index f56c5d382a..920c20d6f8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -204,12 +204,11 @@ include(external/snappy) # download snappy include(external/snappystream) include(external/threadpool) +set(WITH_ANAKIN OFF CACHE STRING "Disable Anakin first, will add it later." FORCE) if(WITH_GPU) include(cuda) include(tensorrt) include(external/anakin) -else() - set(WITH_ANAKIN OFF CACHE STRING "Anakin is valid only when GPU is set." FORCE) endif() include(cudnn) # set cudnn libraries, must before configure diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 2ce73df024..c020ff45ad 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -6,7 +6,7 @@ paddle.fluid.Program.create_block ArgSpec(args=['self', 'parent_idx'], varargs=N paddle.fluid.Program.current_block ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) paddle.fluid.Program.get_desc ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) paddle.fluid.Program.global_block ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) -paddle.fluid.Program.inference_optimize ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) +paddle.fluid.Program.inference_optimize ArgSpec(args=['self', 'export_for_deployment'], varargs=None, keywords=None, defaults=(True,)) paddle.fluid.Program.list_vars ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) paddle.fluid.Program.optimized_guard ArgSpec(args=[], varargs='args', keywords='kwds', defaults=None) paddle.fluid.Program.parse_from_string ArgSpec(args=['binary_str'], varargs=None, keywords=None, defaults=None) @@ -18,6 +18,9 @@ paddle.fluid.Operator.all_attrs ArgSpec(args=['self'], varargs=None, keywords=No paddle.fluid.Operator.attr ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=None) paddle.fluid.Operator.attr_type ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=None) paddle.fluid.Operator.block_attr ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=None) +paddle.fluid.Operator.block_attr_id ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=None) +paddle.fluid.Operator.blocks_attr ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=None) +paddle.fluid.Operator.blocks_attr_ids ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=None) paddle.fluid.Operator.has_attr ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=None) paddle.fluid.Operator.has_kernel ArgSpec(args=['self', 'op_type'], varargs=None, keywords=None, defaults=None) paddle.fluid.Operator.input ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=None) @@ -52,7 +55,7 @@ paddle.fluid.Inferencer.__init__ ArgSpec(args=['self', 'infer_func', 'param_path paddle.fluid.Inferencer.infer ArgSpec(args=['self', 'inputs', 'return_numpy'], varargs=None, keywords=None, defaults=(True,)) paddle.fluid.DistributeTranspiler.__init__ ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.DistributeTranspiler.get_pserver_program ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None) -paddle.fluid.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program'], varargs=None, keywords=None, defaults=None) +paddle.fluid.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program', 'startup_program'], varargs=None, keywords=None, defaults=(None,)) 
paddle.fluid.DistributeTranspiler.get_trainer_program ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) paddle.fluid.DistributeTranspiler.transpile ArgSpec(args=['self', 'trainer_id', 'program', 'pservers', 'trainers', 'sync_mode'], varargs=None, keywords=None, defaults=(None, '127.0.0.1:6174', 1, True)) paddle.fluid.InferenceTranspiler.__init__ @@ -74,7 +77,7 @@ paddle.fluid.io.save_persistables ArgSpec(args=['executor', 'dirname', 'main_pro paddle.fluid.io.load_vars ArgSpec(args=['executor', 'dirname', 'main_program', 'vars', 'predicate', 'filename'], varargs=None, keywords=None, defaults=(None, None, None, None)) paddle.fluid.io.load_params ArgSpec(args=['executor', 'dirname', 'main_program', 'filename'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.io.load_persistables ArgSpec(args=['executor', 'dirname', 'main_program', 'filename'], varargs=None, keywords=None, defaults=(None, None)) -paddle.fluid.io.save_inference_model ArgSpec(args=['dirname', 'feeded_var_names', 'target_vars', 'executor', 'main_program', 'model_filename', 'params_filename'], varargs=None, keywords=None, defaults=(None, None, None)) +paddle.fluid.io.save_inference_model ArgSpec(args=['dirname', 'feeded_var_names', 'target_vars', 'executor', 'main_program', 'model_filename', 'params_filename', 'export_for_deployment'], varargs=None, keywords=None, defaults=(None, None, None, True)) paddle.fluid.io.load_inference_model ArgSpec(args=['dirname', 'executor', 'model_filename', 'params_filename'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.io.get_inference_program ArgSpec(args=['target_vars', 'main_program'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.initializer.ConstantInitializer.__init__ ArgSpec(args=['self', 'value', 'force_cpu'], varargs=None, keywords=None, defaults=(0.0, False)) @@ -156,6 +159,7 @@ paddle.fluid.layers.relu ArgSpec(args=['x'], varargs=None, keywords=None, defaul paddle.fluid.layers.log ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.crop ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.layers.rank_loss ArgSpec(args=['label', 'left', 'right', 'name'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.layers.flatten ArgSpec(args=['x', 'axis', 'name'], varargs=None, keywords=None, defaults=(1, None)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True)) paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)) @@ -324,7 +328,7 @@ paddle.fluid.contrib.BeamSearchDecoder.update_array ArgSpec(args=['self', 'array paddle.fluid.contrib.memory_usage ArgSpec(args=['program', 'batch_size'], varargs=None, keywords=None, defaults=None) paddle.fluid.transpiler.DistributeTranspiler.__init__ ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.transpiler.DistributeTranspiler.get_pserver_program ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None) 
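The new `paddle.fluid.layers.flatten` entry above maps to the layer added in `python/paddle/fluid/layers/nn.py` later in this change; as a quick, illustrative sketch of its shape semantics (the shapes follow the cases given in the layer's own docstring, the snippet is not part of the diff):

```python
import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[3, 100, 100, 4], dtype='float32',
                      append_batch_size=False)

# axis=2: dimensions before axis 2 form the rows, the rest form the columns.
y2 = fluid.layers.flatten(x, axis=2)   # shape (3 * 100, 100 * 4) = (300, 400)
# axis=0: everything is flattened into a single row.
y0 = fluid.layers.flatten(x, axis=0)   # shape (1, 3 * 100 * 100 * 4)
```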
-paddle.fluid.transpiler.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program'], varargs=None, keywords=None, defaults=None)
+paddle.fluid.transpiler.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program', 'startup_program'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.transpiler.DistributeTranspiler.get_trainer_program ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.transpiler.DistributeTranspiler.transpile ArgSpec(args=['self', 'trainer_id', 'program', 'pservers', 'trainers', 'sync_mode'], varargs=None, keywords=None, defaults=(None, '127.0.0.1:6174', 1, True))
 paddle.fluid.transpiler.InferenceTranspiler.__init__
diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h
index c9d55fbf52..5736a5c4e2 100644
--- a/paddle/fluid/framework/ir/graph.h
+++ b/paddle/fluid/framework/ir/graph.h
@@ -28,6 +28,38 @@ namespace paddle {
 namespace framework {
 namespace ir {
+/*
+ * The graph is a Directed Acyclic Single Static Assignment Graph.
+ *
+ * In more detail, the following properties must hold:
+ *
+ * The graph shouldn't contain a cycle. Each node is a black box to the graph,
+ * so the node itself could be a loop operator.
+ *
+ * Each Variable-type node has only one input (thus single static assignment).
+ *
+ * The output/input of an operator is a variable and the output/input of a
+ * variable is an operator.
+ *
+ * The following data hazards in Program are addressed in the Graph:
+ *
+ * Write-After-Read
+ *   a = op1(x)
+ *   x = op2(b)
+ * A control-dependency connection is created between op1 and op2 such that
+ * op1->op2, so as to ensure correct order.
+ *
+ * Write-After-Write
+ *   x = op1(a)
+ *   x = op2(b)
+ * A control-dependency connection is created between op1 and op2 such that
+ * op1->op2, so as to ensure correct order.
+ *
+ * Other properties currently hold, but are not enforced yet:
+ *
+ * Variable-type nodes (not control dep) with the same variable name share
+ * the same underlying VarDesc.
+ */ class Graph { public: explicit Graph(const ProgramDesc &program); diff --git a/paddle/fluid/framework/ir/graph_test.cc b/paddle/fluid/framework/ir/graph_test.cc index f9e6bdf362..b1b8d1c586 100644 --- a/paddle/fluid/framework/ir/graph_test.cc +++ b/paddle/fluid/framework/ir/graph_test.cc @@ -36,7 +36,7 @@ class SumOpMaker : public OpProtoAndCheckerMaker { public: void Make() { AddInput("X", "").AsDuplicable(); - AddOutput("Out", ""); + AddOutput("Out", "").AsDuplicable(); AddComment(""); } }; @@ -59,11 +59,27 @@ class SumOpVarTypeInference : public VarTypeInference { block->Var(out_var_name)->SetType(default_var_type); } }; + +class DummyOpMaker : public OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("X", "").AsDuplicable(); + AddOutput("Out", "").AsDuplicable(); + AddComment(""); + } +}; + +class DummyOpVarTypeInference : public VarTypeInference { + public: + void operator()(const OpDesc &op_desc, BlockDesc *block) const override {} +}; } // namespace framework } // namespace paddle REGISTER_OPERATOR(sum, paddle::framework::NOP, paddle::framework::SumOpMaker, paddle::framework::SumOpVarTypeInference); +REGISTER_OPERATOR(dummy, paddle::framework::NOP, paddle::framework::SumOpMaker, + paddle::framework::SumOpVarTypeInference); REGISTER_OPERATOR(sum_without_infer_var_type, paddle::framework::NOP, paddle::framework::SumOpMaker); @@ -110,5 +126,83 @@ TEST(GraphTest, Basic) { } ASSERT_EQ(nodes.size(), 5); } + +TEST(GraphTest, WriteAfterRead) { + // void Test() { + ProgramDesc prog; + auto *op = prog.MutableBlock(0)->AppendOp(); + op->SetType("sum"); + op->SetInput("X", {"a"}); + op->SetOutput("Out", {"b"}); + op->SetAttr("op_role", 1); + + op = prog.MutableBlock(0)->AppendOp(); + op->SetType("dummy"); + op->SetInput("X", {"c"}); + op->SetOutput("Out", {"a"}); + op->SetAttr("op_role", 1); + + prog.MutableBlock(0)->Var("a")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(0)->Var("b")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(0)->Var("c")->SetType(proto::VarType::LOD_TENSOR); + + std::unique_ptr g(new ir::Graph(prog)); + ir::Node *control_dep1 = nullptr; + ir::Node *control_dep2 = nullptr; + for (ir::Node *n : g->Nodes()) { + if (n->Name() == "sum") { + ASSERT_EQ(n->outputs[0]->Name(), "b"); + ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1])); + control_dep1 = n->outputs[1]; + ASSERT_EQ(n->outputs.size(), 2); + } + if (n->Name() == "dummy") { + ASSERT_EQ(n->inputs[0]->Name(), "c"); + ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1])); + control_dep2 = n->inputs[1]; + ASSERT_EQ(n->inputs.size(), 2); + } + } + ASSERT_EQ(control_dep1, control_dep2); +} + +TEST(GraphTest, WriteAfterWrite) { + // void Test() { + ProgramDesc prog; + auto *op = prog.MutableBlock(0)->AppendOp(); + op->SetType("sum"); + op->SetInput("X", {"a"}); + op->SetOutput("Out", {"b"}); + op->SetAttr("op_role", 1); + + op = prog.MutableBlock(0)->AppendOp(); + op->SetType("dummy"); + op->SetInput("X", {"c"}); + op->SetOutput("Out", {"b"}); + op->SetAttr("op_role", 1); + + prog.MutableBlock(0)->Var("a")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(0)->Var("b")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(0)->Var("c")->SetType(proto::VarType::LOD_TENSOR); + + std::unique_ptr g(new ir::Graph(prog)); + ir::Node *control_dep1 = nullptr; + ir::Node *control_dep2 = nullptr; + for (ir::Node *n : g->Nodes()) { + if (n->Name() == "sum") { + ASSERT_EQ(n->outputs[0]->Name(), "b"); + ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1])); + ASSERT_EQ(n->outputs.size(), 2); + 
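The two tests above exercise the write-after-read and write-after-write cases described in the `graph.h` comment. Roughly the same write-after-read situation, sketched from the Python side (the `sum` op is only a stand-in, like the test's `dummy` op; the snippet builds the program but does not run it):

```python
import paddle.fluid as fluid

prog = fluid.Program()
block = prog.global_block()
a = block.create_var(name='a', shape=[1], dtype='float32')
b = block.create_var(name='b', shape=[1], dtype='float32')
c = block.create_var(name='c', shape=[1], dtype='float32')

# op1 reads `a` ...
block.append_op(type='sum', inputs={'X': [a]}, outputs={'Out': [b]})
# ... and op2 later overwrites `a`: a write-after-read hazard.
block.append_op(type='sum', inputs={'X': [c]}, outputs={'Out': [a]})

# When prog.desc is lowered to ir::Graph, a control-dependency variable is
# inserted between the two operators so op1 is still scheduled before op2.
```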
control_dep1 = n->outputs[1]; + } + if (n->Name() == "dummy") { + ASSERT_EQ(n->inputs[0]->Name(), "c"); + ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1])); + control_dep2 = n->inputs[1]; + ASSERT_EQ(n->inputs.size(), 2); + ASSERT_EQ(control_dep1, control_dep2); + } + } +} } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index a190199f1c..03f7e71c03 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -238,7 +238,20 @@ Attribute OpDesc::GetNullableAttr(const std::string &name) const { } } -int OpDesc::GetBlockAttr(const std::string &name) const { +std::vector OpDesc::GetBlocksAttrIds(const std::string &name) const { + auto it = attrs_.find(name); + PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name); + auto blocks = boost::get>(it->second); + + std::vector ids; + for (auto n : blocks) { + ids.push_back(n->ID()); + } + + return ids; +} + +int OpDesc::GetBlockAttrId(const std::string &name) const { auto it = attrs_.find(name); PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name); return boost::get(it->second)->ID(); diff --git a/paddle/fluid/framework/op_desc.h b/paddle/fluid/framework/op_desc.h index 74dd8ec002..b77d84125a 100644 --- a/paddle/fluid/framework/op_desc.h +++ b/paddle/fluid/framework/op_desc.h @@ -83,7 +83,9 @@ class OpDesc { Attribute GetNullableAttr(const std::string &name) const; - int GetBlockAttr(const std::string &name) const; + int GetBlockAttrId(const std::string &name) const; + + std::vector GetBlocksAttrIds(const std::string &name) const; void Rename(const std::string &old_name, const std::string &new_name); diff --git a/paddle/fluid/framework/program_desc.cc b/paddle/fluid/framework/program_desc.cc index 1e01a6e900..20bdc7830f 100644 --- a/paddle/fluid/framework/program_desc.cc +++ b/paddle/fluid/framework/program_desc.cc @@ -58,7 +58,7 @@ ProgramDesc::ProgramDesc(const ProgramDesc &o) { for (const std::string &attr_name : op->AttrNames()) { if (op->GetAttrType(attr_name) == proto::AttrType::BLOCK) { int sub_block_id = - o.Block(block_id).Op(op_id)->GetBlockAttr(attr_name); + o.Block(block_id).Op(op_id)->GetBlockAttrId(attr_name); op->SetBlockAttr(attr_name, MutableBlock(sub_block_id)); } } diff --git a/paddle/fluid/framework/tensor.cc b/paddle/fluid/framework/tensor.cc index c7286dacf0..56bb9142da 100644 --- a/paddle/fluid/framework/tensor.cc +++ b/paddle/fluid/framework/tensor.cc @@ -112,5 +112,6 @@ Tensor& Tensor::Resize(const DDim& dims) { const DDim& Tensor::dims() const { return dims_; } int64_t Tensor::numel() const { return product(dims_); } + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/tensor_impl.h b/paddle/fluid/framework/tensor_impl.h index 7f678f869a..b7b62eef23 100644 --- a/paddle/fluid/framework/tensor_impl.h +++ b/paddle/fluid/framework/tensor_impl.h @@ -59,6 +59,14 @@ inline T* Tensor::mutable_data(platform::Place place) { } inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) { + int rank = src.dims().size(); + PADDLE_ENFORCE_GE( + rank, 2, + "'ReshapeToMatrix()' is only used for flatten high rank " + "tensors to matrixs. 
Can not be used in reshaping vectors."); + if (rank == 2) { + return src; + } Tensor res; res.ShareDataWith(src); res.Resize(flatten_to_2d(src.dims(), num_col_dims)); diff --git a/paddle/fluid/inference/api/api_impl.cc b/paddle/fluid/inference/api/api_impl.cc index 08d7af6d3a..e31c637e96 100644 --- a/paddle/fluid/inference/api/api_impl.cc +++ b/paddle/fluid/inference/api/api_impl.cc @@ -22,6 +22,9 @@ limitations under the License. */ #include #include "paddle/fluid/inference/api/api_impl.h" +#include "paddle/fluid/platform/profiler.h" + +DEFINE_bool(profile, false, "Turn on profiler for fluid"); namespace paddle { namespace { @@ -58,6 +61,15 @@ bool NativePaddlePredictor::Init( std::shared_ptr parent_scope) { VLOG(3) << "Predictor::init()"; + if (FLAGS_profile) { + LOG(WARNING) << "Profiler is actived, might affect the performance"; + LOG(INFO) << "You can turn off by set gflags '-profile false'"; + + auto tracking_device = config_.use_gpu ? platform::ProfilerState::kAll + : platform::ProfilerState::kCPU; + platform::EnableProfiler(tracking_device); + } + if (config_.use_gpu) { place_ = paddle::platform::CUDAPlace(config_.device); } else { @@ -102,6 +114,10 @@ bool NativePaddlePredictor::Init( } NativePaddlePredictor::~NativePaddlePredictor() { + if (FLAGS_profile) { + platform::DisableProfiler(platform::EventSortingKey::kTotal, + "./profile.log"); + } if (sub_scope_) { scope_->DeleteScope(sub_scope_); } diff --git a/paddle/fluid/operators/.flatten_op.cc.swp b/paddle/fluid/operators/.flatten_op.cc.swp deleted file mode 100644 index 3395b6074b..0000000000 Binary files a/paddle/fluid/operators/.flatten_op.cc.swp and /dev/null differ diff --git a/paddle/fluid/operators/cross_entropy_op.cc b/paddle/fluid/operators/cross_entropy_op.cc index a3bec3da45..578ab63bc3 100644 --- a/paddle/fluid/operators/cross_entropy_op.cc +++ b/paddle/fluid/operators/cross_entropy_op.cc @@ -28,23 +28,26 @@ class CrossEntropyOp : public framework::OperatorWithKernel { auto x_dims = ctx->GetInputDim("X"); auto label_dims = ctx->GetInputDim("Label"); - PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "Input(X)'s rank should be 2."); - PADDLE_ENFORCE_EQ(label_dims.size(), 2UL, - "Input(Label)'s rank should be 2."); - PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0], - "The 1st dimension of Input(X) and Input(Label) should " - "be equal."); + int rank = x_dims.size(); + PADDLE_ENFORCE_EQ(rank, label_dims.size(), + "Input(X) and Input(Label) shall have the same rank."); + PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1), + framework::slice_ddim(label_dims, 0, rank - 1), + "Input(X) and Input(Label) shall have the same shape " + "except the last dimension."); if (ctx->Attrs().Get("soft_label")) { - PADDLE_ENFORCE_EQ(x_dims[1], label_dims[1], - "If Attr(soft_label) == true, the 2nd dimension of " + PADDLE_ENFORCE_EQ(x_dims[rank - 1], label_dims[rank - 1], + "If Attr(soft_label) == true, the last dimension of " "Input(X) and Input(Label) should be equal."); } else { - PADDLE_ENFORCE_EQ(label_dims[1], 1UL, - "If Attr(softLabel) == false, the 2nd dimension of " + PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1UL, + "If Attr(softLabel) == false, the last dimension of " "Input(Label) should be 1."); } - ctx->SetOutputDim("Y", {x_dims[0], 1}); + auto y_dims = x_dims; + y_dims[rank - 1] = 1; + ctx->SetOutputDim("Y", y_dims); ctx->ShareLoD("X", /*->*/ "Y"); } @@ -74,24 +77,28 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel { auto x_dims = ctx->GetInputDim("X"); auto label_dims = ctx->GetInputDim("Label"); auto dy_dims 
= ctx->GetInputDim(framework::GradVarName("Y")); - PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2."); - PADDLE_ENFORCE_EQ(dy_dims.size(), 2, "Input(Y@Grad)'s rank should be 2."); - PADDLE_ENFORCE_EQ(label_dims.size(), 2, "Input(Label)'s rank should be 2."); - PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0], - "The 1st dimension of Input(X) and Input(Label) should " - "be equal."); - PADDLE_ENFORCE_EQ(x_dims[0], dy_dims[0], - "The 1st dimension of Input(X) and Input(Y@Grad) should " - "be equal."); - PADDLE_ENFORCE_EQ(dy_dims[1], 1, - "The 2nd dimension of Input(Y@Grad) should be 1."); + int rank = x_dims.size(); + PADDLE_ENFORCE_EQ(dy_dims.size(), rank, + "Input(Y@Grad) and Input(X) should have the same rank."); + PADDLE_ENFORCE_EQ(label_dims.size(), rank, + "Input(Label) and Input(X) should have the same rank."); + PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1), + framework::slice_ddim(label_dims, 0, rank - 1), + "The Input(X) and Input(Label) should have the same " + "shape except the last dimension."); + PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1), + framework::slice_ddim(dy_dims, 0, rank - 1), + "The Input(X) and Input(Y@Grad) should have the same " + "shape except the last dimension."); + PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1, + "The last dimension of Input(Y@Grad) should be 1."); if (ctx->Attrs().Get("soft_label")) { - PADDLE_ENFORCE_EQ(x_dims[1], label_dims[1], - "When Attr(soft_label) == true, the 2nd dimension of " + PADDLE_ENFORCE_EQ(x_dims[rank - 1], label_dims[rank - 1], + "When Attr(soft_label) == true, the last dimension of " "Input(X) and Input(Label) should be equal."); } else { - PADDLE_ENFORCE_EQ(label_dims[1], 1, - "When Attr(soft_label) == false, the 2nd dimension of " + PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1, + "When Attr(soft_label) == false, the last dimension of " "Input(Label) should be 1."); } ctx->SetOutputDim(framework::GradVarName("X"), x_dims); @@ -113,18 +120,20 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("X", - "(Tensor, default Tensor), a 2-D tensor with shape [N x D]," - " where N is the batch size and D is the number of classes. " - "This input is a probability computed by the previous operator, " - "which is almost always the result of a softmax operator."); - AddInput("Label", - "(Tensor), the ground truth which is a 2-D tensor. When " - "soft_label is set to false, Label is a Tensor with shape " - "[N x 1]. When soft_label is set to true, Label is a " - "Tensor with shape [N x D]."); + "(Tensor, default Tensor), a tensor whose last dimension " + "size is equal to the number of classes. This input is a " + "probability computed by the previous operator, which is almost " + "always the result of a softmax operator."); + AddInput( + "Label", + "(Tensor), the tensor which represents the ground truth. It has the " + "same shape with 'X' except the last dimension. When soft_label is set " + "to false, the last dimension size is 1; when soft_label is set to " + "true, the last dimension size is equal to the number of classes."); AddOutput("Y", - "(Tensor, default Tensor), a 2-D tensor with shape " - "[N x 1]. The cross entropy loss."); + "(Tensor, default Tensor), a tensor whose shape is same " + "with 'X' except that the last dimension size is 1. 
It " + "represents the cross entropy loss."); AddAttr("soft_label", "(bool, default false), a flag indicating whether to " "interpretate the given labels as soft labels.") @@ -132,6 +141,12 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC( CrossEntropy Operator. +The input 'X' and 'Label' will first be logically flattened to 2-D matrixs. +The matrix's second dimension(row length) is as same as the original last +dimension, and the first dimension(column length) is the product of all other +original dimensions. Then the softmax computation will take palce on each raw +of flattened matrixs. + It supports both standard cross-entropy and soft-label cross-entropy loss computation. 1) One-hot cross-entropy: diff --git a/paddle/fluid/operators/cross_entropy_op.h b/paddle/fluid/operators/cross_entropy_op.h index 19a2aec92b..36b58d8014 100644 --- a/paddle/fluid/operators/cross_entropy_op.h +++ b/paddle/fluid/operators/cross_entropy_op.h @@ -33,8 +33,13 @@ class CrossEntropyOpKernel : public framework::OpKernel { auto* y = ctx.Output("Y"); y->mutable_data(ctx.GetPlace()); + int rank = x->dims().size(); + Tensor x_2d = framework::ReshapeToMatrix(*x, rank - 1); + Tensor labels_2d = framework::ReshapeToMatrix(*labels, rank - 1); + Tensor y_2d = framework::ReshapeToMatrix(*y, rank - 1); + math::CrossEntropyFunctor()( - ctx.template device_context(), y, x, labels, + ctx.template device_context(), &y_2d, &x_2d, &labels_2d, ctx.Attr("soft_label")); } }; @@ -98,9 +103,12 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel { auto* dy = ctx.Input(framework::GradVarName("Y")); auto* label = ctx.Input("Label"); auto* dx = ctx.Output(framework::GradVarName("X")); - auto* dx_data = dx->mutable_data(ctx.GetPlace()); + T* dx_data = dx->mutable_data(ctx.GetPlace()); - int64_t class_num = x->dims()[1]; + // Following computation only depends on the last dimension size. So it's + // unnecessary to convert tensors to 2-D views. + int rank = x->dims().size(); + int64_t class_num = x->dims()[rank - 1]; if (ctx.Attr("soft_label")) { XeSoftlabelGradFunctor functor(dx_data, dy->data(), x->data(), label->data(), diff --git a/paddle/fluid/operators/shape_op.cc b/paddle/fluid/operators/shape_op.cc index b44d5f8980..1be9fe47af 100644 --- a/paddle/fluid/operators/shape_op.cc +++ b/paddle/fluid/operators/shape_op.cc @@ -38,7 +38,7 @@ class ShapeOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Input", "(Tensor), The input tensor."); AddOutput("Out", "(Tensor), The shape of input tensor, the data type of the shape" - " is int64_t, will be on the same device with the input Tensor."); + " is int32_t, will be on the same device with the input Tensor."); AddComment(R"DOC( Shape Operator @@ -53,5 +53,5 @@ Get the shape of input tensor. Only support CPU input Tensor now. namespace ops = paddle::operators; REGISTER_OPERATOR(shape, ops::ShapeOp, ops::ShapeOpMaker, paddle::framework::EmptyGradOpMaker); -REGISTER_OP_CPU_KERNEL(shape, ops::ShapeKernel, ops::ShapeKernel, +REGISTER_OP_CPU_KERNEL(shape, ops::ShapeKernel, ops::ShapeKernel, ops::ShapeKernel, ops::ShapeKernel); diff --git a/paddle/fluid/operators/shape_op.cu b/paddle/fluid/operators/shape_op.cu index 7736a2a1e1..d8fa9515ab 100644 --- a/paddle/fluid/operators/shape_op.cu +++ b/paddle/fluid/operators/shape_op.cu @@ -15,6 +15,6 @@ limitations under the License. 
*/ #include "paddle/fluid/operators/shape_op.h" REGISTER_OP_CUDA_KERNEL(shape, paddle::operators::ShapeKernel, - paddle::operators::ShapeKernel, + paddle::operators::ShapeKernel, paddle::operators::ShapeKernel, paddle::operators::ShapeKernel); diff --git a/paddle/fluid/operators/shape_op.h b/paddle/fluid/operators/shape_op.h index 3be86b66a5..0d510a5055 100644 --- a/paddle/fluid/operators/shape_op.h +++ b/paddle/fluid/operators/shape_op.h @@ -27,7 +27,7 @@ class ShapeKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { auto* in_t = ctx.Input("Input"); auto* out_t = ctx.Output("Out"); - auto out_data = out_t->mutable_data(platform::CPUPlace()); + auto out_data = out_t->mutable_data(platform::CPUPlace()); auto in_dims = in_t->dims(); for (int i = 0; i < in_dims.size(); ++i) { out_data[i] = in_dims[i]; diff --git a/paddle/fluid/operators/softmax_op.h b/paddle/fluid/operators/softmax_op.h index 1205bd0587..cf1eeb017d 100644 --- a/paddle/fluid/operators/softmax_op.h +++ b/paddle/fluid/operators/softmax_op.h @@ -31,16 +31,12 @@ class SoftmaxKernel : public framework::OpKernel { // allocate memory on device. Out->mutable_data(context.GetPlace()); - auto dims = X->dims(); - auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1); - framework::LoDTensor flattened_x; - framework::LoDTensor flattened_out; - flattened_x.ShareDataWith(*X).Resize(flattened_dims); - flattened_out.ShareDataWith(*Out).Resize(flattened_dims); + int rank = X->dims().size(); + Tensor X_2d = framework::ReshapeToMatrix(*X, rank - 1); + Tensor Out_2d = framework::ReshapeToMatrix(*Out, rank - 1); math::SoftmaxFunctor()( - context.template device_context(), &flattened_x, - &flattened_out); + context.template device_context(), &X_2d, &Out_2d); } }; @@ -55,18 +51,14 @@ class SoftmaxGradKernel : public framework::OpKernel { // allocate memory on device. 
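With `ReshapeToMatrix` in the kernels above, `softmax` and `cross_entropy` accept inputs of rank greater than 2 and treat the last dimension as the class dimension, and `shape` now produces an int32 tensor. A usage sketch, assuming the Python wrappers simply forward the high-rank tensors to the operators:

```python
import paddle.fluid as fluid

# Logits with an extra time-step dimension: [batch, steps, classes].
logits = fluid.layers.data(name='logits', shape=[5, 7, 10],
                           append_batch_size=False, dtype='float32')
label = fluid.layers.data(name='label', shape=[5, 7, 1],
                          append_batch_size=False, dtype='int64')

probs = fluid.layers.softmax(logits)                         # softmax over the last dim
loss = fluid.layers.cross_entropy(input=probs, label=label)  # shape [5, 7, 1]
dims = fluid.layers.shape(logits)                            # int32 tensor holding [5, 7, 10]
```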
dX->mutable_data(context.GetPlace()); - auto dims = Out->dims(); - auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1); - framework::LoDTensor flattened_out; - framework::LoDTensor flattened_d_out; - framework::LoDTensor flattened_d_x; - flattened_out.ShareDataWith(*Out).Resize(flattened_dims); - flattened_d_out.ShareDataWith(*dOut).Resize(flattened_dims); - flattened_d_x.ShareDataWith(*dX).Resize(flattened_dims); + int rank = Out->dims().size(); + Tensor Out_2d = framework::ReshapeToMatrix(*Out, rank - 1); + Tensor dOut_2d = framework::ReshapeToMatrix(*dOut, rank - 1); + Tensor dX_2d = framework::ReshapeToMatrix(*dX, rank - 1); math::SoftmaxGradFunctor()( - context.template device_context(), &flattened_out, - &flattened_d_out, &flattened_d_x); + context.template device_context(), &Out_2d, &dOut_2d, + &dX_2d); } }; diff --git a/paddle/fluid/platform/profiler.cc b/paddle/fluid/platform/profiler.cc index d0286719b9..652a6ec7a4 100644 --- a/paddle/fluid/platform/profiler.cc +++ b/paddle/fluid/platform/profiler.cc @@ -270,12 +270,13 @@ struct EventItem { double min_time; double max_time; double ave_time; + float ratio; }; // Print results void PrintProfiler(const std::vector>& events_table, const std::string& sorted_domain, const size_t name_width, - const size_t data_width) { + const size_t data_width, double total) { // Output header information std::cout << "\n------------------------->" << " Profiling Report " @@ -300,7 +301,8 @@ void PrintProfiler(const std::vector>& events_table, std::cout << std::setw(name_width) << "Event" << std::setw(data_width) << "Calls" << std::setw(data_width) << "Total" << std::setw(data_width) << "Min." << std::setw(data_width) - << "Max." << std::setw(data_width) << "Ave." << std::endl; + << "Max." << std::setw(data_width) << "Ave." + << std::setw(data_width) << "Ratio." << std::endl; for (size_t i = 0; i < events_table.size(); ++i) { for (size_t j = 0; j < events_table[i].size(); ++j) { const EventItem& event_item = events_table[i][j]; @@ -309,7 +311,9 @@ void PrintProfiler(const std::vector>& events_table, << std::setw(data_width) << event_item.total_time << std::setw(data_width) << event_item.min_time << std::setw(data_width) << event_item.max_time - << std::setw(data_width) << event_item.ave_time << std::endl; + << std::setw(data_width) << event_item.ave_time + << std::setw(data_width) << event_item.total_time / total + << std::endl; } } std::cout << std::endl; @@ -359,6 +363,7 @@ void ParseEvents(const std::vector>& events, std::vector> events_table; size_t max_name_width = 0; + double total = 0.; // the total time for (size_t i = 0; i < events.size(); i++) { std::list pushed_events; std::vector event_items; @@ -379,6 +384,7 @@ void ParseEvents(const std::vector>& events, g_state == ProfilerState::kAll) ? 
rit->CudaElapsedMs(events[i][j]) : rit->CpuElapsedMs(events[i][j]); + total += event_time; std::string event_name = "thread" + std::to_string(rit->thread_id()) + "::" + rit->name(); @@ -387,7 +393,8 @@ void ParseEvents(const std::vector>& events, if (event_idx.find(event_name) == event_idx.end()) { event_idx[event_name] = event_items.size(); EventItem event_item = {event_name, 1, event_time, - event_time, event_time, event_time}; + event_time, event_time, event_time, + 0.}; event_items.push_back(event_item); } else { int index = event_idx[event_name]; @@ -431,7 +438,7 @@ void ParseEvents(const std::vector>& events, } // Print report - PrintProfiler(events_table, sorted_domain, max_name_width + 4, 12); + PrintProfiler(events_table, sorted_domain, max_name_width + 4, 12, total); } void DisableProfiler(EventSortingKey sorted_key, diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc index 2199f5311f..be623703c2 100644 --- a/paddle/fluid/pybind/protobuf.cc +++ b/paddle/fluid/pybind/protobuf.cc @@ -301,7 +301,8 @@ void BindOpDesc(pybind11::module *m) { std::string ser(seriralized); self.SetAttr(name, ser); }) - .def("block_attr", &pd::OpDesc::GetBlockAttr) + .def("block_attr_id", &pd::OpDesc::GetBlockAttrId) + .def("blocks_attr_ids", &pd::OpDesc::GetBlocksAttrIds) .def("check_attrs", &pd::OpDesc::CheckAttrs) .def("infer_shape", &pd::OpDesc::InferShape) .def("infer_var_type", &pd::OpDesc::InferVarType) diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index 6b73974511..fd6a76dd0c 100644 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -344,7 +344,7 @@ def _append_backward_ops_(block, grad_sub_block_list = [] # If the op has its own sub-block, deal with the sub-block first if op.has_attr("sub_block"): - sub_block = program.block(op.block_attr("sub_block")) + sub_block = program.block(op.block_attr_id("sub_block")) grad_sub_block = program.create_block() grad_sub_block._set_forward_block_idx(sub_block.idx) cb = _callback_lookup_(op) @@ -406,7 +406,7 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map): for op_idx in range(start_op_idx, block.desc.op_size()): op_desc = block.desc.op(op_idx) if op_desc.has_attr("sub_block"): - sub_block = block.program.block(op_desc.block_attr("sub_block")) + sub_block = block.program.block(op_desc.block_attr_id("sub_block")) _append_backward_vars_(sub_block, 0, grad_to_var, grad_info_map) new_vars = set() # create new gradient variables diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 3d7c29c6ea..45b3abb88c 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -476,23 +476,25 @@ class Operator(object): attrs=None): self.block = block self.desc = desc - self.attrs = attrs - if self.attrs is None: - self.attrs = dict() + # note: not add self.attrs here: + # https://github.com/PaddlePaddle/Paddle/pull/12583#pullrequestreview-145093173 + op_attrs = attrs + if op_attrs is None: + op_attrs = dict() del attrs op_maker = core.op_proto_and_checker_maker - if op_maker.kOpRoleAttrName() not in self.attrs: - self.attrs[op_maker.kOpRoleAttrName()] = self.block.program.op_role + if op_maker.kOpRoleAttrName() not in op_attrs: + op_attrs[op_maker.kOpRoleAttrName()] = self.block.program.op_role role_var_name = op_maker.kOpRoleVarAttrName() if len(self.block.program. 
- op_role_var) != 0 and role_var_name not in self.attrs: - self.attrs[role_var_name] = self.block.program.op_role_var + op_role_var) != 0 and role_var_name not in op_attrs: + op_attrs[role_var_name] = self.block.program.op_role_var - if role_var_name in self.attrs and len(self.attrs[role_var_name]) == 0: - del self.attrs[role_var_name] + if role_var_name in op_attrs and len(op_attrs[role_var_name]) == 0: + del op_attrs[role_var_name] if len(self.desc.type()) != 0: return @@ -576,15 +578,14 @@ class Operator(object): arg.op = self self.desc.set_output(out_proto.name, out_arg_names) - if self.attrs is not None: - if not isinstance(self.attrs, dict): + if op_attrs is not None: + if not isinstance(op_attrs, dict): raise TypeError("'attrs' should be a dict.") for attr in proto.attrs: attr_name = attr.name - if (attr_name not in self.attrs) or ( - self.attrs[attr_name] is None): + if (attr_name not in op_attrs) or (op_attrs[attr_name] is None): continue - attr_val = self.attrs[attr_name] + attr_val = op_attrs[attr_name] self._update_desc_attr(attr_name, attr_val) self.desc.check_attrs() @@ -732,7 +733,6 @@ class Operator(object): Raises: ValueError: If the type of value doesn't match with desc.attr_type(name). """ - self.attrs[name] = val self._update_desc_attr(name, val) def _update_desc_attr(self, name, val): @@ -774,9 +774,9 @@ class Operator(object): """ return self.desc.attr(name) - def block_attr(self, name): + def block_attr_id(self, name): """ - Get the block attribute by name. + Get the block attribute's id by name. Args: name(str): the attribute name. @@ -784,22 +784,74 @@ class Operator(object): Returns: int: the block index. """ - return self.desc.block_attr(name) + return self.desc.block_attr_id(name) + + def block_attr(self, name): + """ + Get the block attribute by name. + + Args: + name(str): the attribute name. + + Returns: + block: the block attribute. + """ + + id = self.block_attr_id(name) + assert (id >= 0 and id < len(self.block.program.blocks)) + return self.block.program.blocks[id] + + def blocks_attr(self, name): + """ + Get the blocks attribute by name. + + Args: + name(str): the attribute name. + + Returns: + list: list of the blocks attribute. + """ + attrs = [] + for i in self.blocks_attr_ids(name): + assert (i >= 0 and i < len(self.block.program.blocks)) + attrs.append(self.block.program.blocks[i]) + + return attrs + + def blocks_attr_ids(self, name): + """ + Get the blocks attribute's ids by name. + + Args: + name(str): the attribute name. + + Returns: + list: list of the blocks ids. + """ + + return self.desc.blocks_attr_ids(name) def all_attrs(self): """ Get the attribute dict. Returns: - dict: The Operator's attribute dict. + dict: The Operator's attribute dict, name->attr. """ attr_names = self.attr_names attr_map = {} for n in attr_names: - if n == 'sub_block': + attr_type = self.desc.attr_type(n) + if attr_type == core.AttrType.BLOCK: attr_map[n] = self.block_attr(n) - else: - attr_map[n] = self.attr(n) + continue + + if attr_type == core.AttrType.BLOCKS: + attr_map[n] = self.blocks_attr(n) + continue + + attr_map[n] = self.attr(n) + return attr_map @@ -1518,11 +1570,17 @@ class Program(object): The two code snippets above will generate same programs. 
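To make the renamed accessors concrete: `block_attr_id` returns the index of a single sub-block attribute, `block_attr` now resolves it to the `Block` object, and `blocks_attr`/`blocks_attr_ids` do the same for list-of-blocks attributes (for example the `optimize_blocks` attribute of `listen_and_serv`, as used in the transpiler test further down). A sketch with a `While` op, whose body is stored in the `sub_block` attribute:

```python
import paddle.fluid as fluid

main = fluid.Program()
with fluid.program_guard(main):
    i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)
    limit = fluid.layers.fill_constant(shape=[1], dtype='int64', value=10)
    cond = fluid.layers.less_than(x=i, y=limit)
    loop = fluid.layers.While(cond=cond)
    with loop.block():
        i = fluid.layers.increment(x=i, in_place=True)
        fluid.layers.less_than(x=i, y=limit, cond=cond)

while_op = [op for op in main.global_block().ops if op.type == 'while'][0]
print(while_op.block_attr_id('sub_block'))      # block index, e.g. 1
print(while_op.block_attr('sub_block'))         # the Block object itself
print(type(while_op.all_attrs()['sub_block']))  # Block, no longer an int
```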
""" if for_test: - p = self.inference_optimize() + p = self.inference_optimize(export_for_deployment=False) else: p = Program() + p.current_block_idx = self.current_block_idx + p._seed = self._seed p.desc = core.ProgramDesc(self.desc) - p.blocks = [Block(p, i) for i in range(self.desc.num_blocks())] + p.blocks = [Block(p, i) for i in xrange(self.desc.num_blocks())] + + p._current_role = self._current_role + p._op_role_var = self._op_role_var + p._sync_with_cpp() p._copy_param_info_from(self) @@ -1578,7 +1636,7 @@ class Program(object): res._sync_with_cpp() return res - def inference_optimize(self): + def inference_optimize(self, export_for_deployment=True): """ This method will create a new program and do following adjustments on it: 1. Remove all reader variables and their creator ops if exist. @@ -1589,6 +1647,10 @@ class Program(object): attribute of operators to :code:`True`. All the :code:`Parameter` information will be lost. + Args: + export_for_deployment(bool): remove the read ops that are added by py_reader + for cpp inference library + Notes: This API is a very low level API. Use :code:`Program.clone(for_test=True)` instead. @@ -1603,16 +1665,17 @@ class Program(object): # remove all readers and the read_op if exist read_op_idx = 0 root_block = res.desc.block(0) - while True: - if read_op_idx >= root_block.op_size() or root_block.op( - read_op_idx).type() == 'read': - break - read_op_idx += 1 - if read_op_idx < root_block.op_size(): - root_block._remove_op(0, read_op_idx + 1) - for var in root_block.all_vars(): - if var.type() == core.VarDesc.VarType.READER: - root_block._remove_var(var.name()) + if export_for_deployment: + while True: + if read_op_idx >= root_block.op_size() or root_block.op( + read_op_idx).type() == 'read': + break + read_op_idx += 1 + if read_op_idx < root_block.op_size(): + root_block._remove_op(0, read_op_idx + 1) + for var in root_block.all_vars(): + if var.type() == core.VarDesc.VarType.READER: + root_block._remove_var(var.name()) # change all `is_test` attributes to True for i in range(res.desc.num_blocks()): diff --git a/python/paddle/fluid/initializer.py b/python/paddle/fluid/initializer.py index 83290ac608..3f740dd7c5 100644 --- a/python/paddle/fluid/initializer.py +++ b/python/paddle/fluid/initializer.py @@ -264,7 +264,8 @@ class NormalInitializer(Initializer): "dtype": int(var.dtype), "mean": self._mean, "std": self._std_dev, - "seed": self._seed + "seed": self._seed, + "use_mkldnn": False }) var.op = op return op diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 55e517f1f4..af73421032 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -555,7 +555,8 @@ def save_inference_model(dirname, executor, main_program=None, model_filename=None, - params_filename=None): + params_filename=None, + export_for_deployment=True): """ Prune the given `main_program` to build a new program especially for inference, and then save it and all related parameters to given `dirname` by the `executor`. @@ -577,6 +578,8 @@ def save_inference_model(dirname, params_filename(str|None): The name of file to save all related parameters. If it is setted None, parameters will be saved in separate files . + export_for_deployment(bool): remove the read ops that are added by py_reader + for cpp inference lib. 
Default True Returns: None @@ -643,7 +646,8 @@ def save_inference_model(dirname, copy_program.desc.flush() pruned_program = copy_program.prune(targets=target_vars) - inference_program = pruned_program.inference_optimize() + inference_program = pruned_program.inference_optimize( + export_for_deployment=export_for_deployment) fetch_var_names = [v.name for v in target_vars] prepend_feed_ops(inference_program, feeded_var_names) diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 0800c02d9e..b996c83688 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -20,7 +20,9 @@ from .layer_function_generator import autodoc, templatedoc from ..layer_helper import LayerHelper from . import tensor from . import nn +from . import ops import math +import numpy from functools import reduce __all__ = [ @@ -264,10 +266,11 @@ def detection_output(loc, prior_box_var=prior_box_var, target_box=loc, code_type='decode_center_size') - old_shape = scores.shape - scores = nn.reshape(x=scores, shape=(-1, old_shape[-1])) + compile_shape = scores.shape + run_shape = ops.shape(scores) + scores = nn.flatten(x=scores, axis=2) scores = nn.softmax(input=scores) - scores = nn.reshape(x=scores, shape=old_shape) + scores = nn.reshape(x=scores, shape=compile_shape, actual_shape=run_shape) scores = nn.transpose(scores, perm=[0, 2, 1]) scores.stop_gradient = True nmsed_outs = helper.create_tmp_variable(dtype=decoded_box.dtype) @@ -677,9 +680,10 @@ def ssd_loss(location, raise ValueError("Only support mining_type == max_negative now.") num, num_prior, num_class = confidence.shape + conf_shape = ops.shape(confidence) def __reshape_to_2d(var): - return nn.reshape(x=var, shape=[-1, var.shape[-1]]) + return nn.flatten(x=var, axis=2) # 1. Find matched boundding box by prior box. # 1.1 Compute IOU similarity between ground-truth boxes and prior boxes. @@ -690,7 +694,8 @@ def ssd_loss(location, # 2. Compute confidence for mining hard examples # 2.1. Get the target label based on matched indices - gt_label = nn.reshape(x=gt_label, shape=gt_label.shape + (1, )) + gt_label = nn.reshape( + x=gt_label, shape=(len(gt_label.shape) - 1) * (0, ) + (-1, 1)) gt_label.stop_gradient = True target_label, _ = target_assign( gt_label, matched_indices, mismatch_value=background_label) @@ -701,9 +706,12 @@ def ssd_loss(location, target_label = __reshape_to_2d(target_label) target_label.stop_gradient = True conf_loss = nn.softmax_with_cross_entropy(confidence, target_label) - # 3. Mining hard examples - conf_loss = nn.reshape(x=conf_loss, shape=(num, num_prior)) + conf_loss = nn.reshape( + x=conf_loss, + shape=(num, num_prior), + actual_shape=ops.slice( + conf_shape, axes=[0], starts=[0], ends=[2])) conf_loss.stop_gradient = True neg_indices = helper.create_tmp_variable(dtype='int32') dtype = matched_indices.dtype @@ -772,7 +780,11 @@ def ssd_loss(location, # 5.3 Compute overall weighted loss. loss = conf_loss_weight * conf_loss + loc_loss_weight * loc_loss # reshape to [N, Np], N is the batch size and Np is the prior box number. 
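The detection layers above replace hard-coded `reshape` calls with `flatten` plus a `reshape` that carries both a compile-time shape and a run-time `actual_shape` tensor, so the batch dimension does not need to be known at build time. The pattern in isolation (shapes here are illustrative):

```python
import numpy
import paddle.fluid as fluid

conf = fluid.layers.data(name='conf', shape=[6, 4, 4], dtype='float32')

# Static shape used for compile-time inference ...
compile_shape = [conf.shape[0], 6 * 4 * 4 // 4, 4]
# ... and a run-time shape tensor (0 keeps the corresponding input dimension,
# -1 is inferred), so the reshape still works for a dynamic batch size.
run_shape = fluid.layers.assign(numpy.array([0, -1, 4]).astype('int32'))

out = fluid.layers.reshape(conf, shape=compile_shape, actual_shape=run_shape)
```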
- loss = nn.reshape(x=loss, shape=[-1, num_prior]) + loss = nn.reshape( + x=loss, + shape=(num, num_prior), + actual_shape=ops.slice( + conf_shape, axes=[0], starts=[0], ends=[2])) loss = nn.reduce_sum(loss, dim=1, keep_dim=True) if normalize: normalizer = nn.reduce_sum(target_loc_weight) @@ -1005,13 +1017,7 @@ def multi_box_head(inputs, """ def _reshape_with_axis_(input, axis=1): - if not (axis > 0 and axis < len(input.shape)): - raise ValueError("The axis should be smaller than " - "the arity of input and bigger than 0.") - new_shape = [ - -1, reduce(lambda x, y: x * y, input.shape[axis:len(input.shape)]) - ] - out = nn.reshape(x=input, shape=new_shape) + out = nn.flatten(x=input, axis=axis) return out def _is_list_or_tuple_(data): @@ -1101,11 +1107,13 @@ def multi_box_head(inputs, stride=stride) mbox_loc = nn.transpose(mbox_loc, perm=[0, 2, 3, 1]) - new_shape = [ + compile_shape = [ mbox_loc.shape[0], mbox_loc.shape[1] * mbox_loc.shape[2] * mbox_loc.shape[3] / 4, 4 ] - mbox_loc_flatten = nn.reshape(mbox_loc, shape=new_shape) + run_shape = tensor.assign(numpy.array([0, -1, 4]).astype("int32")) + mbox_loc_flatten = nn.reshape( + mbox_loc, shape=compile_shape, actual_shape=run_shape) mbox_locs.append(mbox_loc_flatten) # get conf @@ -1117,11 +1125,15 @@ def multi_box_head(inputs, padding=pad, stride=stride) conf_loc = nn.transpose(conf_loc, perm=[0, 2, 3, 1]) - new_shape = [ + new_shape = [0, -1, num_classes] + compile_shape = [ conf_loc.shape[0], conf_loc.shape[1] * conf_loc.shape[2] * conf_loc.shape[3] / num_classes, num_classes ] - conf_loc_flatten = nn.reshape(conf_loc, shape=new_shape) + run_shape = tensor.assign( + numpy.array([0, -1, num_classes]).astype("int32")) + conf_loc_flatten = nn.reshape( + conf_loc, shape=compile_shape, actual_shape=run_shape) mbox_confs.append(conf_loc_flatten) if len(box_results) == 1: diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 0960b54123..c75e7eeb43 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -112,6 +112,7 @@ __all__ = [ 'log', 'crop', 'rank_loss', + 'flatten', ] @@ -5361,3 +5362,70 @@ def rank_loss(label, left, right, name=None): "Right": right}, outputs={'Out': out}) return out + + +def flatten(x, axis=1, name=None): + """ + **Flatten layer** + Flattens the input tensor into a 2D matrix. + + Examples: + Case 1: + Given + X.shape = (3, 100, 100, 4) + and + axis = 2 + We get: + Out.shape = (3 * 100, 4 * 100) + + Case 2: + Given + X.shape = (3, 100, 100, 4) + and + axis = 0 + We get: + Out.shape = (1, 3 * 100 * 100 * 4) + + Args: + x (Variable): A tensor of rank >= axis. + axis (int): Indicate up to which input dimensions (exclusive) should + be flattened to the outer dimension of the output. + The value for axis must be in the range [0, R], where R + is the rank of the input tensor. When axis = 0, the shape + of the output tensor is (1, (d_0 X d_1 ... d_n), where the + shape of the input tensor is (d_0, d_1, ... d_n). + name(str|None): A name for this layer(optional). If set None, the layer + will be named automatically. + + Returns: + Variable: A 2D tensor with the contents of the input tensor, with input + dimensions up to axis flattened to the outer dimension of + the output and remaining input dimensions flattened into the + inner dimension of the output. + + Raises: + ValueError: If x is not a variable. + ValueError: If axis is not in range [0, rank(x)]. + + Examples: + + .. 
code-block:: python + + x = fluid.layers.data(name="x", shape=[4, 4, 3], dtype="float32") + out = fluid.layers.flatten(x=x, axis=2) + """ + helper = LayerHelper('flatten', **locals()) + + if not (isinstance(x, Variable)): + raise ValueError("The input x should be a Variable") + + if not (isinstance(axis, int)) or axis > len(x.shape) or axis < 0: + raise ValueError("The axis should be a int, and in range [0, rank(x)]") + + out = helper.create_tmp_variable(x.dtype) + helper.append_op( + type='flatten', + inputs={"X": x}, + outputs={'Out': out}, + attrs={"axis": axis}) + return out diff --git a/python/paddle/fluid/tests/unittests/test_cross_entropy_op.py b/python/paddle/fluid/tests/unittests/test_cross_entropy_op.py index c5b9e92d69..86ac159323 100644 --- a/python/paddle/fluid/tests/unittests/test_cross_entropy_op.py +++ b/python/paddle/fluid/tests/unittests/test_cross_entropy_op.py @@ -105,5 +105,107 @@ class TestCrossEntropyOp3(OpTest): ["X"], "Y", max_relative_error=0.05, numeric_grad_delta=0.001) +class TestCrossEntropyOp4(OpTest): + """Test high rank tensor cross-entropy with discrete one-hot labels. + """ + + def setUp(self): + self.op_type = "cross_entropy" + shape = [10, 2, 4] + ins_num = np.prod(np.array(shape)) + class_num = 10 + + X_2d = randomize_probability(ins_num, class_num, dtype='float64') + + label_2d = np.random.randint(0, class_num, (ins_num, 1), dtype="int64") + cross_entropy_2d = np.asmatrix( + [[-np.log(X_2d[i][label_2d[i][0]])] for i in range(X_2d.shape[0])], + dtype="float64") + + X = X_2d.reshape(shape + [class_num]) + label = label_2d.reshape(shape + [1]) + cross_entropy = np.array(cross_entropy_2d).reshape(shape + [1]) + + self.inputs = {"X": X, "Label": label} + self.outputs = {"Y": cross_entropy} + self.attrs = {"soft_label": False} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Y", numeric_grad_delta=0.001) + + +class TestCrossEntropyOp5(OpTest): + """Test high rank tensor cross-entropy with vectorized soft labels. + """ + + def setUp(self): + self.op_type = "cross_entropy" + shape = [4, 3] + ins_num = np.prod(np.array(shape)) + class_num = 37 + + X_2d = randomize_probability(ins_num, class_num) + label_2d = np.random.uniform(0.1, 1.0, + [ins_num, class_num]).astype("float32") + label_2d /= label_2d.sum(axis=1, keepdims=True) + cross_entropy_2d = (-label_2d * np.log(X_2d)).sum( + axis=1, keepdims=True).astype("float32") + + X = X_2d.reshape(shape + [class_num]) + label = label_2d.reshape(shape + [class_num]) + cross_entropy = np.array(cross_entropy_2d).reshape(shape + [1]) + + self.inputs = {"X": X, "Label": label} + self.outputs = {"Y": cross_entropy} + self.attrs = {"soft_label": True} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad( + ["X"], "Y", max_relative_error=0.05, numeric_grad_delta=0.001) + + +class TestCrossEntropyOp6(OpTest): + """Test high rank tensor cross-entropy with vectorized one-hot representation of labels. 
+ """ + + def setUp(self): + self.op_type = "cross_entropy" + shape = [4, 3, 2] + ins_num = np.prod(np.array(shape)) + class_num = 17 + + X_2d = randomize_probability(ins_num, class_num) + label_index_2d = np.random.randint( + 0, class_num, (ins_num), dtype="int32") + label_2d = np.zeros(X_2d.shape) + label_2d[np.arange(ins_num), label_index_2d] = 1 + + cross_entropy_2d = np.asmatrix( + [[-np.log(X_2d[i][label_index_2d[i]])] + for i in range(X_2d.shape[0])], + dtype="float32") + + X = X_2d.reshape(shape + [class_num]) + label = label_2d.reshape(shape + [class_num]) + cross_entropy = np.array(cross_entropy_2d).reshape(shape + [1]) + + self.inputs = {"X": X, "Label": label.astype(np.float32)} + self.outputs = {"Y": cross_entropy} + self.attrs = {"soft_label": True} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad( + ["X"], "Y", max_relative_error=0.05, numeric_grad_delta=0.001) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_desc_clone.py b/python/paddle/fluid/tests/unittests/test_desc_clone.py new file mode 100644 index 0000000000..8603d3a5b3 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_desc_clone.py @@ -0,0 +1,196 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import argparse +import time +import math + +import paddle +import paddle.fluid as fluid +import paddle.fluid.profiler as profiler +from paddle.fluid import core +import unittest +from multiprocessing import Process +import os +import signal +import collections + +SEED = 1 +DTYPE = "float32" +paddle.dataset.mnist.fetch() + + +# random seed must set before configuring the network. 
+# fluid.default_startup_program().random_seed = SEED +def cnn_model(data): + conv_pool_1 = fluid.nets.simple_img_conv_pool( + input=data, + filter_size=5, + num_filters=20, + pool_size=2, + pool_stride=2, + act="relu") + conv_pool_2 = fluid.nets.simple_img_conv_pool( + input=conv_pool_1, + filter_size=5, + num_filters=50, + pool_size=2, + pool_stride=2, + act="relu") + + # TODO(dzhwinter) : refine the initializer and random seed settting + SIZE = 10 + input_shape = conv_pool_2.shape + param_shape = [reduce(lambda a, b: a * b, input_shape[1:], 1)] + [SIZE] + scale = (2.0 / (param_shape[0]**2 * SIZE))**0.5 + + predict = fluid.layers.fc( + input=conv_pool_2, + size=SIZE, + act="softmax", + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.NormalInitializer( + loc=0.0, scale=scale))) + return predict + + +def get_model(batch_size): + # Input data + images = fluid.layers.data(name='pixel', shape=[1, 28, 28], dtype=DTYPE) + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + + # Train program + predict = cnn_model(images) + cost = fluid.layers.cross_entropy(input=predict, label=label) + avg_cost = fluid.layers.mean(x=cost) + + # Evaluator + batch_size_tensor = fluid.layers.create_tensor(dtype='int64') + batch_acc = fluid.layers.accuracy( + input=predict, label=label, total=batch_size_tensor) + + inference_program = fluid.default_main_program().clone() + # Optimization + opt = fluid.optimizer.AdamOptimizer( + learning_rate=0.001, beta1=0.9, beta2=0.999) + + # Reader + train_reader = paddle.batch( + paddle.dataset.mnist.train(), batch_size=batch_size) + test_reader = paddle.batch( + paddle.dataset.mnist.test(), batch_size=batch_size) + opt.minimize(avg_cost) + return inference_program, avg_cost, train_reader, test_reader, batch_acc, predict + + +def get_transpiler(trainer_id, main_program, pserver_endpoints, trainers): + t = fluid.DistributeTranspiler() + t.transpile( + trainer_id=trainer_id, + program=main_program, + pservers=pserver_endpoints, + trainers=trainers) + return t + + +def operator_equal(a, b): + for k, v in a.__dict__.iteritems(): + if isinstance(v, fluid.framework.Program) or \ + isinstance(v, fluid.framework.Block): + continue + + elif isinstance(v, core.OpDesc): + if v.serialize_to_string() != b.__dict__[k].serialize_to_string(): + raise ValueError("In operator_equal not equal:{0}\n".format(k)) + + elif isinstance(v, collections.OrderedDict): + v0 = sorted(v.iteritems(), key=lambda x: x[0]) + v1 = sorted(b.__dict__[k].iteritems(), key=lambda x: x[0]) + + if v0 != v1: + raise ValueError("In operator_equal not equal:{0}\n".format(k)) + + elif (v != b.__dict__[k]): + raise ValueError("In operator_equal not equal:{0}\n".format(k)) + + return True + + +def block_equal(a, b): + for k, v in a.__dict__.iteritems(): + if isinstance(v, core.ProgramDesc) or isinstance( + v, fluid.framework.Program) or isinstance(v, core.BlockDesc): + continue + + elif k == "ops": + for i in range(0, len(a.ops)): + if not operator_equal(a.ops[i], b.ops[i]): + raise ValueError("In block_equal not equal:{0}\n".format(k)) + assert (len(a.ops) == len(b.ops)) + + elif isinstance(v, collections.OrderedDict): + v0 = sorted(v.iteritems(), key=lambda x: x[0]) + v1 = sorted(b.__dict__[k].iteritems(), key=lambda x: x[0]) + + if v0 != v1: + raise ValueError("In block_equal not equal:{0}\n".format(k)) + + elif (v != b.__dict__[k]): + raise ValueError("In block_equal not equal:{0}\n".format(k)) + + return True + + +def program_equal(a, b): + for k, v in a.__dict__.iteritems(): + if 
isinstance(v, core.ProgramDesc): + continue + + elif k == 'blocks': + for i in range(0, len(a.blocks)): + if not block_equal(a.blocks[i], b.blocks[i]): + raise ValueError("In operator_equal not equal:{0}\n".format( + k)) + return False + assert (len(a.blocks) == len(b.blocks)) + + elif (v != b.__dict__[k]): + raise ValueError("In program_equal not equal:{0}\n".format(k)) + + return True + + +class TestDistMnist(unittest.TestCase): + def test_desc_clone(self): + get_model(batch_size=20) + + pserver_endpoints = "127.0.0.1:9123" + trainers = 1 + current_endpoint = "127.0.0.1:9123" + t = get_transpiler(0, + fluid.default_main_program(), pserver_endpoints, + trainers) + + pserver_prog = t.get_pserver_program(current_endpoint) + startup_prog = t.get_startup_program(current_endpoint, pserver_prog) + main = pserver_prog.clone() + startup = startup_prog.clone() + + self.assertTrue(program_equal(main, pserver_prog)) + self.assertTrue(program_equal(startup, startup_prog)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index 1deccbe4af..4379463aca 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -130,7 +130,7 @@ class TestDistBase(unittest.TestCase): self._ps_endpoints = "127.0.0.1:9123,127.0.0.1:9124" self._python_interp = "python" - def start_pserver(self, model_file): + def start_pserver(self, model_file, check_error_log): ps0_ep, ps1_ep = self._ps_endpoints.split(",") ps0_cmd = "%s %s pserver %s 0 %s %d TRUE" % \ (self._python_interp, model_file, self._ps_endpoints, ps0_ep, @@ -139,11 +139,23 @@ class TestDistBase(unittest.TestCase): (self._python_interp, model_file, self._ps_endpoints, ps1_ep, self._trainers) + ps0_pipe = subprocess.PIPE + ps1_pipe = subprocess.PIPE + if check_error_log: + print("ps0_cmd:", ps0_cmd) + print("ps1_cmd:", ps1_cmd) + ps0_pipe = open("/tmp/ps0_err.log", "wb") + ps1_pipe = open("/tmp/ps1_err.log", "wb") + ps0_proc = subprocess.Popen( - ps0_cmd.split(" "), stdout=subprocess.PIPE, stderr=subprocess.PIPE) + ps0_cmd.split(" "), stdout=subprocess.PIPE, stderr=ps0_pipe) ps1_proc = subprocess.Popen( - ps1_cmd.split(" "), stdout=subprocess.PIPE, stderr=subprocess.PIPE) - return ps0_proc, ps1_proc + ps1_cmd.split(" "), stdout=subprocess.PIPE, stderr=ps1_pipe) + + if not check_error_log: + return ps0_proc, ps1_proc, None, None + else: + return ps0_proc, ps1_proc, ps0_pipe, ps1_pipe def _wait_ps_ready(self, pid): retry_times = 50 @@ -160,7 +172,7 @@ class TestDistBase(unittest.TestCase): (e, retry_times)) retry_times -= 1 - def check_with_place(self, model_file, delta=1e-3): + def check_with_place(self, model_file, delta=1e-3, check_error_log=False): # *ATTENTION* THIS TEST NEEDS AT LEAST 2GPUS TO RUN required_envs = { "PATH": os.getenv("PATH"), @@ -169,17 +181,32 @@ class TestDistBase(unittest.TestCase): "FLAGS_fraction_of_gpu_memory_to_use": "0.15", "FLAGS_cudnn_deterministic": "1" } + + if check_error_log: + required_envs["GLOG_v"] = "7" + required_envs["GLOG_logtostderr"] = "1" + # Run local to get a base line env_local = {"CUDA_VISIBLE_DEVICES": "0"} env_local.update(required_envs) local_cmd = "%s %s trainer %s 0 %s %d FLASE" % \ (self._python_interp, model_file, "127.0.0.1:1234", "127.0.0.1:1234", 1) - local_proc = subprocess.Popen( - local_cmd.split(" "), - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - env=env_local) + if not check_error_log: + local_proc = 
diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py
index 1deccbe4af..4379463aca 100644
--- a/python/paddle/fluid/tests/unittests/test_dist_base.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_base.py
@@ -130,7 +130,7 @@ class TestDistBase(unittest.TestCase):
         self._ps_endpoints = "127.0.0.1:9123,127.0.0.1:9124"
         self._python_interp = "python"
 
-    def start_pserver(self, model_file):
+    def start_pserver(self, model_file, check_error_log):
         ps0_ep, ps1_ep = self._ps_endpoints.split(",")
         ps0_cmd = "%s %s pserver %s 0 %s %d TRUE" % \
                   (self._python_interp, model_file, self._ps_endpoints, ps0_ep,
@@ -139,11 +139,23 @@ class TestDistBase(unittest.TestCase):
                   (self._python_interp, model_file, self._ps_endpoints, ps1_ep,
                    self._trainers)
 
+        ps0_pipe = subprocess.PIPE
+        ps1_pipe = subprocess.PIPE
+        if check_error_log:
+            print("ps0_cmd:", ps0_cmd)
+            print("ps1_cmd:", ps1_cmd)
+            ps0_pipe = open("/tmp/ps0_err.log", "wb")
+            ps1_pipe = open("/tmp/ps1_err.log", "wb")
+
         ps0_proc = subprocess.Popen(
-            ps0_cmd.split(" "), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+            ps0_cmd.split(" "), stdout=subprocess.PIPE, stderr=ps0_pipe)
         ps1_proc = subprocess.Popen(
-            ps1_cmd.split(" "), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
-        return ps0_proc, ps1_proc
+            ps1_cmd.split(" "), stdout=subprocess.PIPE, stderr=ps1_pipe)
+
+        if not check_error_log:
+            return ps0_proc, ps1_proc, None, None
+        else:
+            return ps0_proc, ps1_proc, ps0_pipe, ps1_pipe
 
     def _wait_ps_ready(self, pid):
         retry_times = 50
@@ -160,7 +172,7 @@ class TestDistBase(unittest.TestCase):
                     (e, retry_times))
                 retry_times -= 1
 
-    def check_with_place(self, model_file, delta=1e-3):
+    def check_with_place(self, model_file, delta=1e-3, check_error_log=False):
         # *ATTENTION* THIS TEST NEEDS AT LEAST 2GPUS TO RUN
         required_envs = {
             "PATH": os.getenv("PATH"),
@@ -169,17 +181,32 @@ class TestDistBase(unittest.TestCase):
             "FLAGS_fraction_of_gpu_memory_to_use": "0.15",
             "FLAGS_cudnn_deterministic": "1"
         }
+
+        if check_error_log:
+            required_envs["GLOG_v"] = "7"
+            required_envs["GLOG_logtostderr"] = "1"
+
         # Run local to get a base line
         env_local = {"CUDA_VISIBLE_DEVICES": "0"}
         env_local.update(required_envs)
         local_cmd = "%s %s trainer %s 0 %s %d FLASE" % \
                     (self._python_interp, model_file,
                      "127.0.0.1:1234", "127.0.0.1:1234", 1)
-        local_proc = subprocess.Popen(
-            local_cmd.split(" "),
-            stdout=subprocess.PIPE,
-            stderr=subprocess.PIPE,
-            env=env_local)
+        if not check_error_log:
+            local_proc = subprocess.Popen(
+                local_cmd.split(" "),
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE,
+                env=env_local)
+        else:
+            print("trainer cmd:", local_cmd)
+            err_log = open("/tmp/trainer.err.log", "wb")
+            local_proc = subprocess.Popen(
+                local_cmd.split(" "),
+                stdout=subprocess.PIPE,
+                stderr=err_log,
+                env=env_local)
+
         local_proc.wait()
         out, err = local_proc.communicate()
         local_ret = out
@@ -187,7 +214,8 @@ class TestDistBase(unittest.TestCase):
             sys.stderr.write('local_stderr: %s\n' % err)
 
         # Run dist train to compare with local results
-        ps0, ps1 = self.start_pserver(model_file)
+        ps0, ps1, ps0_pipe, ps1_pipe = self.start_pserver(model_file,
+                                                          check_error_log)
         self._wait_ps_ready(ps0.pid)
         self._wait_ps_ready(ps1.pid)
 
@@ -205,15 +233,23 @@ class TestDistBase(unittest.TestCase):
         env1.update(required_envs)
 
         FNULL = open(os.devnull, 'w')
+        tr0_pipe = subprocess.PIPE
+        tr1_pipe = subprocess.PIPE
+        if check_error_log:
+            print("tr0_cmd:", tr0_cmd)
+            print("tr1_cmd:", tr1_cmd)
+            tr0_pipe = open("/tmp/tr0_err.log", "wb")
+            tr1_pipe = open("/tmp/tr1_err.log", "wb")
+
         tr0_proc = subprocess.Popen(
             tr0_cmd.split(" "),
             stdout=subprocess.PIPE,
-            stderr=subprocess.PIPE,
+            stderr=tr0_pipe,
             env=env0)
         tr1_proc = subprocess.Popen(
             tr1_cmd.split(" "),
             stdout=subprocess.PIPE,
-            stderr=subprocess.PIPE,
+            stderr=tr1_pipe,
             env=env1)
 
         tr0_proc.wait()
@@ -230,6 +266,13 @@ class TestDistBase(unittest.TestCase):
         local_first_loss = eval(local_lines[0])[0]
         local_last_loss = eval(local_lines[1])[0]
 
+        # close trainer file
+        if check_error_log:
+            tr0_pipe.close()
+            tr1_pipe.close()
+
+            ps0_pipe.close()
+            ps1_pipe.close()
 
         # FIXME: use terminate() instead of sigkill.
         os.kill(ps0.pid, signal.SIGKILL)
         os.kill(ps1.pid, signal.SIGKILL)
diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py
index b6f4f0726f..0543e62381 100644
--- a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py
@@ -259,7 +259,7 @@ class TestLRDecayConditional(TranspilerTest):
         serv_op = pserver.blocks[0].ops[0]
         sub_blocks = []
         optimize_blocks = []
-        for b in serv_op.attrs["optimize_blocks"]:
+        for b in serv_op.all_attrs()["optimize_blocks"]:
             optimize_blocks.append(b.idx)
         for b in pserver.blocks:
             if b.idx not in optimize_blocks:
diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py
index 8f2dac786d..38a138a8fa 100644
--- a/python/paddle/fluid/tests/unittests/test_layers.py
+++ b/python/paddle/fluid/tests/unittests/test_layers.py
@@ -465,6 +465,17 @@ class TestBook(unittest.TestCase):
             self.assertIsNotNone(out)
             print(str(program))
 
+    def test_flatten(self):
+        program = Program()
+        with program_guard(program):
+            x = layers.data(
+                name='x',
+                append_batch_size=False,
+                shape=[4, 4, 3],
+                dtype="float32")
+            out = layers.flatten(x, axis=1, name="flatten")
+            self.assertIsNotNone(out)
+
     def test_shape(self):
         program = Program()
         with program_guard(program):
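
Note: test_flatten above exercises the new fluid.layers.flatten layer, which
reshapes its input into a 2-D matrix: the dimensions before axis are folded
into the first output dimension and the remaining ones into the second. A
minimal usage sketch (variable names are hypothetical, shape chosen to mirror
the test):

    import paddle.fluid as fluid

    prog = fluid.Program()
    with fluid.program_guard(prog):
        x = fluid.layers.data(
            name='x', shape=[4, 4, 3], append_batch_size=False, dtype='float32')
        # axis=1 folds [4] x [4, 3] into a (4, 12) matrix;
        # axis=2 would give (16, 3) instead.
        flat = fluid.layers.flatten(x, axis=1)
        print(flat.shape)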
diff --git a/python/paddle/fluid/tests/unittests/test_program.py b/python/paddle/fluid/tests/unittests/test_program.py
index c51a482393..0997afc97a 100644
--- a/python/paddle/fluid/tests/unittests/test_program.py
+++ b/python/paddle/fluid/tests/unittests/test_program.py
@@ -17,6 +17,7 @@ import unittest
 
 from paddle.fluid.framework import Program, default_main_program, program_guard, grad_var_name
 import paddle.fluid.layers as layers
+import paddle.fluid as fluid
 
 main_program = default_main_program()
 
@@ -98,6 +99,39 @@ class TestProgram(unittest.TestCase):
         new_program = main_program.clone()
         self.assertNotEqual(0, len(new_program.blocks[0].all_parameters()))
 
+    def test_program_inference_optimize(self):
+        def net():
+            reader = fluid.layers.py_reader(
+                capacity=10,
+                shapes=[[-1, 10], [-1, 1]],
+                lod_levels=[0, 0],
+                dtypes=['float32', 'int64'],
+                use_double_buffer=True)
+            in_data, label = fluid.layers.read_file(reader)
+            predict_label = fluid.layers.fc(in_data, size=2, act='softmax')
+            loss = fluid.layers.mean(
+                fluid.layers.cross_entropy(
+                    input=predict_label, label=label))
+
+            optimizer = fluid.optimizer.Adam()
+            optimizer.minimize(loss)
+
+        startup_program = fluid.Program()
+        main_program = fluid.Program()
+        with fluid.program_guard(main_program, startup_program):
+            net()
+        no_read_program = main_program.inference_optimize()
+        keep_read_program = main_program.inference_optimize(
+            export_for_deployment=False)
+        no_read_ops = no_read_program.global_block().ops
+        keep_read_ops = keep_read_program.global_block().ops
+        self.assertEqual(len(keep_read_ops) - len(no_read_ops), 2)
+        self.assertEqual(keep_read_ops[0].type, 'create_double_buffer_reader')
+        self.assertEqual(keep_read_ops[1].type, 'read')
+
+        for i in range(len(no_read_ops)):
+            self.assertEqual(no_read_ops[i].type, keep_read_ops[i + 2].type)
+
 
 if __name__ == '__main__':
     unittest.main()
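
Note: as the new test shows, Program.inference_optimize() now takes an
export_for_deployment flag. With the default (True) the reader-related ops
(create_double_buffer_reader / read in the test above) are pruned so the
program stands alone for export, while False keeps them so the optimized
program can still be fed through its py_reader. A usage sketch along the same
lines (the network and sizes are arbitrary):

    import paddle.fluid as fluid

    main = fluid.Program()
    startup = fluid.Program()
    with fluid.program_guard(main, startup):
        reader = fluid.layers.py_reader(
            capacity=4,
            shapes=[[-1, 10], [-1, 1]],
            lod_levels=[0, 0],
            dtypes=['float32', 'int64'])
        img, label = fluid.layers.read_file(reader)
        predict = fluid.layers.fc(img, size=2, act='softmax')

    # True (default): reader ops are pruned, suitable for export/serving.
    deploy_prog = main.inference_optimize()
    # False: reader ops are kept, so evaluation can still run off the py_reader.
    eval_prog = main.inference_optimize(export_for_deployment=False)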
""" s_prog = Program() - orig_s_prog = default_startup_program() + if not startup_program: + orig_s_prog = default_startup_program() + else: + orig_s_prog = startup_program s_prog.random_seed = orig_s_prog.random_seed params = self.param_grad_ep_mapping[endpoint]["params"] @@ -584,12 +592,12 @@ class DistributeTranspiler(object): if op.type in [ "gaussian_random", "fill_constant", "uniform_random" ]: - op.attrs["shape"] = new_outputs["Out"].shape + op.set_attr("shape", list(new_outputs["Out"].shape)) s_prog.global_block().append_op( type=op.type, inputs=new_inputs, outputs=new_outputs, - attrs=op.attrs) + attrs=op.all_attrs()) return s_prog # ====================== private transpiler functions ===================== @@ -603,7 +611,7 @@ class DistributeTranspiler(object): self.table_name = None for op in self.origin_program.global_block().ops: if op.type == LOOKUP_TABLE_TYPE: - if op.attrs['is_distributed'] is True: + if op.attr('is_distributed') is True: if self.table_name is None: self.table_name = op.input("W")[0] if self.table_name != op.input("W")[0]: @@ -1263,7 +1271,7 @@ class DistributeTranspiler(object): type=opt_op.type, inputs=new_inputs, outputs=outputs, - attrs=opt_op.attrs) + attrs=opt_op.all_attrs()) def _is_splited_grad_var(self, var, var_dict): grad_block = None @@ -1296,7 +1304,7 @@ class DistributeTranspiler(object): block._clone_variable(var) return block.append_op( - type=op.type, inputs=inputs, outputs=outputs, attrs=op.attrs) + type=op.type, inputs=inputs, outputs=outputs, attrs=op.all_attrs()) def _append_pserver_non_opt_ops(self, optimize_block, opt_op): program = optimize_block.program @@ -1337,7 +1345,7 @@ class DistributeTranspiler(object): type=opt_op.type, inputs=inputs, outputs=outputs, - attrs=opt_op.attrs) + attrs=opt_op.all_attrs()) def _is_op_connected(self, op1, op2): # If one op's input is another op's output or @@ -1442,8 +1450,8 @@ class DistributeTranspiler(object): # optimize op_maker = core.op_proto_and_checker_maker optimize_role = core.op_proto_and_checker_maker.OpRole.Optimize - if op_maker.kOpRoleAttrName() in op.attrs and \ - int(op.attrs[op_maker.kOpRoleAttrName()]) == int(optimize_role): + if op_maker.kOpRoleAttrName() in op.attr_names and \ + int(op.all_attrs()[op_maker.kOpRoleAttrName()]) == int(optimize_role): return True return False @@ -1466,8 +1474,8 @@ class DistributeTranspiler(object): # and op_role_var to get the pair. for input_name in op.input_arg_names: if input_name.find("@GRAD") != -1 and \ - op.attrs[RPC_OP_ROLE_ATTR_NAME]: - param_name = op.attrs[OP_ROLE_VAR_ATTR_NAME][0] + op.attr(RPC_OP_ROLE_ATTR_NAME): + param_name = op.attr(OP_ROLE_VAR_ATTR_NAME)[0] params_grads.append([ origin_var_dict[param_name], origin_var_dict[input_name]