Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into reduce-elementwise-warning

update-doc-pybind
qiaolongfei 7 years ago
commit abdcd8282e

@@ -24,6 +24,9 @@ static ProgramDesc* g_program_desc = nullptr;
 ProgramDesc& GetProgramDesc() {
   if (g_program_desc == nullptr) {
     g_program_desc = new ProgramDesc();
+    auto root_block = g_program_desc->mutable_blocks()->Add();
+    root_block->set_idx(0);
+    root_block->set_parent_idx(-1);
   }
   return *g_program_desc;
 }
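Note: with the root block created eagerly here, the global ProgramDesc always exposes a valid block 0 the first time it is fetched. A minimal sketch of what that looks like from Python (assuming the pybind bindings added later in this commit are built and importable):

    import paddle.v2.framework.core as core

    prog = core.ProgramDesc.instance()   # the global ProgramDesc singleton
    assert prog.block(0) is not None     # root block: idx 0, parent_idx -1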

@@ -18,7 +18,7 @@ namespace paddle {
 namespace operators {
 namespace math {
-template class SoftmaxFunctor<platform::GPUPlace, float>;
+template class SoftmaxFunctor<platform::CPUPlace, float>;
 }  // namespace math
 }  // namespace operators

@@ -82,40 +82,38 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  void InferShape(const framework::InferShapeContext& ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Logits"),
-                            "Input(Logits) should be not null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
-                            "Input(Label) should be not null.");
-
-    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Softmax"),
-                            "Output(Softmax) should be not null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Loss"),
-                            "Output(Loss) should be not null.");
-
-    const Tensor* logits = ctx.Input<Tensor>("Logits");
-    const Tensor* labels = ctx.Input<Tensor>("Label");
+  void InferShape(framework::InferShapeContextBase* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("Logits"),
+                   "Input(Logits) should be not null.");
+    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Softmax"),
+                   "Output(Softmax) should be not null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Loss"), "Output(Loss) should be not null.");
+
+    auto logits_dims = ctx->GetInputDim("Logits");
+    auto labels_dims = ctx->GetInputDim("Label");
     PADDLE_ENFORCE_EQ(
-        logits->dims().size(), 2UL,
+        logits_dims.size(), 2UL,
         "The input of softmax_with_cross_entropy should be a 2-D tensor.");
-    PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("Label")->dims().size(), 2UL,
+    PADDLE_ENFORCE_EQ(labels_dims.size(), 2UL,
                       "The labels should be a 2-D tensor.");

-    if (ctx.Attr<bool>("softLabel")) {
-      PADDLE_ENFORCE_EQ(logits->dims()[1], labels->dims()[1],
+    if (ctx->Attrs().Get<bool>("softLabel")) {
+      PADDLE_ENFORCE_EQ(logits_dims[1], labels_dims[1],
                         "If Attr(softLabel) == true, the 2nd dimension of "
                         "Input(X) and Input(Label) should be equal.");
     } else {
-      PADDLE_ENFORCE_EQ(labels->dims()[1], 1UL,
+      PADDLE_ENFORCE_EQ(labels_dims[1], 1UL,
                         "If Attr(softLabel) == false, the 2nd dimension of "
                         "Input(Label) should be 1.");
     }

-    ctx.Output<framework::Tensor>("Softmax")->Resize(logits->dims());
-    ctx.Output<framework::Tensor>("Loss")->Resize({logits->dims()[0], 1});
+    ctx->SetOutputDim("Softmax", logits_dims);
+    ctx->SetOutputDim("Loss", {logits_dims[0], 1});

-    ctx.ShareLoD("Logits", /*->*/ "Softmax");
-    ctx.ShareLoD("Logits", /*->*/ "Loss");
+    ctx->ShareLoD("Logits", /*->*/ "Softmax");
+    ctx->ShareLoD("Logits", /*->*/ "Loss");
   }
 };
@@ -124,33 +122,32 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  void InferShape(const framework::InferShapeContext& ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Loss")),
-                            "Input(Loss@Grad) should not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Softmax"),
-                            "Input(Softmax) should be not null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
-                            "Input(Label) should be not null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar(framework::GradVarName("Logits")),
-                            "Output(Logits@Grad) should be not null.");
-
-    const Tensor* softmax = ctx.Input<Tensor>("Softmax");
-    const Tensor* labels = ctx.Input<Tensor>("Label");
-    PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("Label")->dims().size(), 2UL,
+  void InferShape(framework::InferShapeContextBase* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Loss")),
+                   "Input(Loss@Grad) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Softmax"),
+                   "Input(Softmax) should be not null.");
+    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
+    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Logits")),
+                   "Output(Logits@Grad) should be not null.");
+
+    auto softmax_dims = ctx->GetInputDim("Softmax");
+    auto labels_dims = ctx->GetInputDim("Label");
+    PADDLE_ENFORCE_EQ(labels_dims.size(), 2UL,
                       "The labels should be a 2-D tensor.");

-    if (ctx.Attr<bool>("softLabel")) {
-      PADDLE_ENFORCE_EQ(softmax->dims()[1], labels->dims()[1],
+    if (ctx->Attrs().Get<bool>("softLabel")) {
+      PADDLE_ENFORCE_EQ(softmax_dims[1], labels_dims[1],
                         "When Attr(softLabel) == true, the 2nd dimension of "
                         "Input(X) and Input(Label) should be equal.");
     } else {
-      PADDLE_ENFORCE_EQ(labels->dims()[1], 1UL,
+      PADDLE_ENFORCE_EQ(labels_dims[1], 1UL,
                         "When Attr(softLabel) == false, the 2nd dimension of "
                         "Input(Label) should be 1.");
     }

-    ctx.Output<framework::LoDTensor>(framework::GradVarName("Logits"))
-        ->Resize(ctx.Input<Tensor>("Softmax")->dims());
+    ctx->SetOutputDim(framework::GradVarName("Logits"),
+                      ctx->GetInputDim("Softmax"));
   }
 };
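Note: the shape checks in the two hunks above amount to a simple contract; the sketch below spells it out with assumed example values (N, K and the data are illustrative, not taken from this diff):

    import numpy as np

    N, K = 4, 10                                                        # batch size, classes
    logits = np.random.uniform(0.1, 1.0, (N, K)).astype('float32')      # Input(Logits): [N, K]
    hard_label = np.random.randint(0, K, (N, 1)).astype('int32')        # softLabel == false: 2nd dim must be 1
    soft_label = np.random.dirichlet(np.ones(K), N).astype('float32')   # softLabel == true: 2nd dim equals K
    # InferShape then resolves Output(Softmax) to [N, K] and Output(Loss) to [N, 1].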

@@ -1,6 +1,6 @@
 if(WITH_PYTHON)
   cc_library(paddle_pybind SHARED
-    SRCS pybind.cc
+    SRCS pybind.cc protobuf.cc
     DEPS pybind python backward
          ${GLOB_OP_LIB})
 endif(WITH_PYTHON)

File diff suppressed because it is too large.

@@ -0,0 +1,35 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <Python.h>
#include <fstream>
#include <vector>
#include "paddle/framework/op_registry.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace py = pybind11;
namespace paddle {
namespace pybind {
void BindProgramDesc(py::module& m);
void BindBlockDesc(py::module& m);
void BindVarDsec(py::module& m);
void BindOpDesc(py::module& m);
} // namespace pybind
} // namespace paddle

@@ -12,13 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-#include <Python.h>
-#include <fstream>
-#include <vector>
+#include "paddle/pybind/protobuf.h"

 #include "paddle/framework/backward.h"
 #include "paddle/framework/lod_tensor.h"
-#include "paddle/framework/op_registry.h"
 #include "paddle/operators/cond_op.h"
 #include "paddle/operators/net_op.h"
 #include "paddle/operators/recurrent_op.h"
@@ -27,11 +24,6 @@ limitations under the License. */
 #include "paddle/pybind/pybind.h"
 #include "paddle/pybind/tensor_py.h"
 #include "paddle/string/to_string.h"
-#include "pybind11/numpy.h"
-#include "pybind11/pybind11.h"
-#include "pybind11/stl.h"
-
-namespace py = pybind11;

 namespace paddle {
 namespace pybind {
@@ -320,6 +312,11 @@ All parameter, weight, gradient are variables in Paddle.
   m.def("is_compile_gpu", IsCompileGPU);

+  BindProgramDesc(m);
+  BindBlockDesc(m);
+  BindVarDsec(m);
+  BindOpDesc(m);
+
   return m.ptr();
 }

 }  // namespace pybind

@@ -5,22 +5,31 @@ from op_test import OpTest

 def modified_huber_loss_forward(val):
     if val < -1:
-        return -4 * val
+        return -4. * val
     elif val < 1:
-        return (1 - val) * (1 - val)
+        return (1. - val) * (1. - val)
     else:
-        return 0
+        return 0.


 class TestModifiedHuberLossOp(OpTest):
     def setUp(self):
         self.op_type = 'modified_huber_loss'
         samples_num = 32
-        self.inputs = {
-            'X': np.random.uniform(-1, 1., (samples_num, 1)).astype('float32'),
-            'Y': np.random.choice([0, 1], samples_num).reshape((samples_num, 1))
-        }
-        product_res = self.inputs['X'] * (2 * self.inputs['Y'] - 1)
+
+        x_np = np.random.uniform(-2., 2., (samples_num, 1)).astype('float32')
+        y_np = np.random.choice([0, 1], samples_num).reshape(
+            (samples_num, 1)).astype('float32')
+        product_res = x_np * (2. * y_np - 1.)
+        # keep away from the junction of piecewise function
+        for pos, val in np.ndenumerate(product_res):
+            while abs(val - 1.) < 0.05:
+                x_np[pos] = np.random.uniform(-2., 2.)
+                y_np[pos] = np.random.choice([0, 1])
+                product_res[pos] = x_np[pos] * (2 * y_np[pos] - 1)
+                val = product_res[pos]
+
+        self.inputs = {'X': x_np, 'Y': y_np}

         loss = np.vectorize(modified_huber_loss_forward)(product_res)
         self.outputs = {
@@ -32,7 +41,7 @@ class TestModifiedHuberLossOp(OpTest):
         self.check_output()

     def test_check_grad(self):
-        self.check_grad(['X'], 'Out', max_relative_error=0.005)
+        self.check_grad(['X'], 'Out', max_relative_error=0.01)


 if __name__ == '__main__':
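Note on the resampling loop added above: the loss is piecewise, so a central-difference gradient that straddles the junction at product == 1 disagrees with the analytic gradient; that is why samples near the junction are redrawn and the gradient-check tolerance is relaxed to 0.01. A minimal sketch of the mismatch (the epsilon and sample point are chosen purely for illustration):

    def modified_huber_loss_forward(val):
        if val < -1:
            return -4. * val
        elif val < 1:
            return (1. - val) * (1. - val)
        else:
            return 0.

    eps = 5e-3
    v = 1.0 + eps / 4.                       # just to the right of the junction
    numeric = (modified_huber_loss_forward(v + eps) -
               modified_huber_loss_forward(v - eps)) / (2. * eps)
    analytic = 0.                            # exact derivative on the val >= 1 branch
    print(numeric, analytic)                 # numeric is non-zero, analytic is 0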

@@ -0,0 +1,131 @@
import unittest
import paddle.v2.framework.core as core


class TestOpDesc(unittest.TestCase):
    def test_op_desc(self):
        prog = core.ProgramDesc.__create_program_desc__()
        self.assertIsNotNone(prog)
        block = prog.block(0)
        self.assertIsNotNone(block)
        op = block.append_op()
        self.assertIsNotNone(op)

        op.set_type("test")
        self.assertEqual("test", op.type())

        op.set_input("X", ["a", "b", "c"])
        self.assertEqual(["a", "b", "c"], op.input("X"))
        self.assertEqual(["X"], op.input_names())

        op.set_output("Out", ["z"])
        self.assertEqual(['z'], op.output("Out"))
        self.assertEqual(["Out"], op.output_names())

        op.set_attr("int_attr", 1)
        self.assertEqual(1, op.attr("int_attr"))
        self.assertTrue(op.has_attr("int_attr"))
        self.assertEqual(core.AttrType.INT, op.attr_type("int_attr"))

        op.set_attr("float_attr", -1.32)
        self.assertAlmostEqual(-1.32, op.attr("float_attr"), delta=1e-4)
        self.assertTrue(op.has_attr("float_attr"))

        op.set_attr("bool_attr", False)
        self.assertFalse(op.attr("bool_attr"))

        op.set_attr("string_attr", "abc")
        self.assertEqual("abc", op.attr("string_attr"))
        self.assertTrue(op.has_attr("string_attr"))

        op.set_attr("ints_attr", [1, 2, 3])
        self.assertEqual([1, 2, 3], op.attr("ints_attr"))

        expected = [1.2, 2.3, 3.4]
        op.set_attr("floats_attr", expected)
        for e, a in zip(expected, op.attr("floats_attr")):
            self.assertAlmostEqual(e, a, delta=1e-4)

        op.set_attr("strings_attr", ["a", "b", "c"])
        self.assertEqual(["a", "b", "c"], op.attr("strings_attr"))

        op.set_attr("bools_attr", [True, False, True])
        self.assertEqual([True, False, True], op.attr("bools_attr"))

        self.assertEqual(8, len(op.attr_names()))

        op.set_block_attr("block_attr", prog.block(0))
        self.assertEqual(0, op.get_block_attr("block_attr"))


class TestProgramDesc(unittest.TestCase):
    def test_instance(self):
        program_desc = core.ProgramDesc.__create_program_desc__()
        self.assertIsNotNone(program_desc)
        del program_desc
        program_desc = core.ProgramDesc.instance()
        self.assertIsNotNone(program_desc)
        self.assertIsNotNone(program_desc.block(0))
        del program_desc

    def test_append_block(self):
        prog_desc = core.ProgramDesc.__create_program_desc__()
        self.assertIsNotNone(prog_desc)
        block_root = prog_desc.block(0)
        self.assertIsNotNone(block_root)
        self.assertEqual(block_root.id, 0)
        block1 = prog_desc.append_block(block_root)
        block2 = prog_desc.append_block(block1)
        self.assertIsNotNone(block1)
        self.assertEqual(block1.id, block2.parent)
        self.assertEqual(block_root.id, block1.parent)
        block3 = prog_desc.append_block(block_root)
        self.assertEqual(block3.parent, block_root.id)
        self.assertEqual(prog_desc.block(1).id, 1)
        self.assertEqual(4, prog_desc.num_blocks())


class TestVarDesc(unittest.TestCase):
    def test_shape(self):
        program_desc = core.ProgramDesc.__create_program_desc__()
        block = program_desc.block(0)
        var = block.new_var('my_var')
        src_shape = [3, 2, 10, 8]
        var.set_shape(src_shape)
        res_shape = var.shape()
        self.assertEqual(src_shape, res_shape)

    def test_data_type(self):
        program_desc = core.ProgramDesc.__create_program_desc__()
        block = program_desc.block(0)
        var = block.new_var('my_var')
        var.set_data_type(core.DataType.INT32)
        self.assertEqual(core.DataType.INT32, var.data_type())


class TestBlockDesc(unittest.TestCase):
    def test_add_var(self):
        prog = core.ProgramDesc.__create_program_desc__()
        self.assertIsNotNone(prog)
        block = prog.block(0)
        self.assertIsNotNone(block)
        var1 = block.new_var("var1")
        var2 = block.new_var("var2")
        var3 = block.new_var("var3")
        all_vars = block.all_vars()
        self.assertEqual(set(all_vars), set([var1, var2, var3]))
        var2_re = block.var("var2")
        self.assertEqual(var2_re, var2)

    def test_add_op(self):
        prog = core.ProgramDesc.__create_program_desc__()
        self.assertIsNotNone(prog)
        block = prog.block(0)
        self.assertIsNotNone(block)
        op1 = block.append_op()
        op2 = block.append_op()
        op0 = block.prepend_op()
        all_ops = block.all_ops()
        self.assertEqual(all_ops, [op0, op1, op2])


if __name__ == '__main__':
    unittest.main()