Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into profiler_tool

detection_output_fixbug
Yibing Liu 7 years ago
commit 24458ae305

@ -21,6 +21,8 @@ cc_test(variable_test SRCS variable_test.cc)
cc_library(scope SRCS scope.cc DEPS glog)
cc_test(scope_test SRCS scope_test.cc DEPS scope)
cc_library(data_transform SRCS data_transform.cc DEPS tensor framework_proto)
cc_test(data_transform_test SRCS data_transform_test.cc DEPS data_transform device_context)
cc_library(attribute SRCS attribute.cc DEPS framework_proto)
cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc
@ -29,7 +31,8 @@ cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute)
cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog shape_inference)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog
shape_inference data_transform)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry init)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog)
@ -65,6 +68,3 @@ cc_library(init SRCS init.cc DEPS gflags device_context place stringpiece)
cc_test(init_test SRCS init_test.cc DEPS init)
cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto)
cc_library(data_transform SRCS data_transform.cc DEPS tensor framework_proto)
cc_test(data_transform_test SRCS data_transform_test.cc DEPS data_transform device_context)

@ -32,17 +32,16 @@ using DataTransformFN =
const Variable& in, Variable* out)>;
using KernelTypePair = std::pair<OpKernelType, OpKernelType>;
static void hash_combine(std::size_t& seed, const OpKernelType& t) {
OpKernelType::Hash kernel_type_hasher;
seed ^= kernel_type_hasher(t) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
struct KernelTypePairHash {
static void HashCombine(const OpKernelType& t, std::size_t* seed) {
OpKernelType::Hash kernel_type_hasher;
(*seed) ^= kernel_type_hasher(t) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2);
}
size_t operator()(const KernelTypePair& kernel_pair) const {
std::size_t seed = 0;
hash_combine(seed, kernel_pair.first);
hash_combine(seed, kernel_pair.second);
HashCombine(kernel_pair.first, &seed);
HashCombine(kernel_pair.second, &seed);
return seed;
}
};

@ -15,6 +15,7 @@ limitations under the License. */
#include <algorithm>
#include <atomic>
#include "paddle/framework/data_transform.h"
#include "paddle/framework/executor.h"
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/operator.h"
@ -411,7 +412,38 @@ void OperatorWithKernel::Run(const Scope& scope,
expected_kernel_key);
}
kernel_iter->second->Compute(ctx);
if (actual_kernel_key == expected_kernel_key) {
kernel_iter->second->Compute(ctx);
} else {
Scope& op_scope = scope.NewScope();
auto input_vars = this->InputVars();
for (auto var_name : input_vars) {
op_scope.Var(var_name);
}
// TODO(qijun) get appropriate DeviceContext from DeviceContext pool
platform::DeviceContext* trans_dev_ctx = nullptr;
std::vector<platform::DeviceContext*> trans_dev_ctx_vec{trans_dev_ctx};
// TODO(qijun) get appropriate DataTransformFN from global map
framework::DataTransformFN trans_fun = nullptr;
// Wait for transform starting
dev_ctx->Wait();
for (auto var_name : input_vars) {
trans_fun(trans_dev_ctx_vec, *(scope.FindVar(var_name)),
op_scope.FindVar(var_name));
}
// Wait for data transform finishing
for (auto ctx : trans_dev_ctx_vec) {
ctx->Wait();
}
// Create a new ExecutionContext
ExecutionContext op_ctx(*this, op_scope, *dev_ctx);
kernel_iter->second->Compute(op_ctx);
}
}
OpKernelType OperatorWithKernel::GetActualKernelType(

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -24,13 +24,13 @@ class SoftmaxOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SoftmaxOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Y"),
"Output(Y) of SoftmaxOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SoftmaxOp should not be null.");
auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE(x_dims.size() == 2UL,
"The input of softmax op must be a matrix.");
ctx->SetOutputDim("Y", x_dims);
ctx->SetOutputDim("Out", x_dims);
}
};
@ -41,7 +41,7 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X",
"The input tensor of softmax. "
"2-D with shape [batch_size, input_feature_dimensions].");
AddOutput("Y", "The normalized values with the same shape as X.");
AddOutput("Out", "The normalized values with the same shape as X.");
AddComment(R"DOC(
Softmax Operator.
@ -59,7 +59,7 @@ exponential values of all the other dimensions is the output of the softmax
operator.
For each row $i$ and each column $j$ in Input(X), we have:
$$Y[i, j] = \frac{\exp(X[i, j])}{\sum_j(exp(X[i, j])}$$
$$Out[i, j] = \frac{\exp(X[i, j])}{\sum_j(exp(X[i, j])}$$
)DOC");
}
@ -70,12 +70,12 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should be not null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
"Input(Y@GRAD) should be not null.");
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Y"),
ctx->GetInputDim(framework::GradVarName("Y")),
"Input(Y) and its gradients should have a same shape.");
PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) should be not null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should be not null.");
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Out"),
ctx->GetInputDim(framework::GradVarName("Out")),
"Input(Out) and its gradients should have a same shape.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}

@ -26,13 +26,13 @@ class SoftmaxKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<Tensor>("X");
auto* Y = context.Output<Tensor>("Y");
auto* Out = context.Output<Tensor>("Out");
// allocate memory on device.
Y->mutable_data<T>(context.GetPlace());
Out->mutable_data<T>(context.GetPlace());
math::SoftmaxFunctor<DeviceContext, T>()(
context.template device_context<DeviceContext>(), X, Y);
context.template device_context<DeviceContext>(), X, Out);
}
};
@ -40,15 +40,15 @@ template <typename DeviceContext, typename T>
class SoftmaxGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* Y = context.Input<Tensor>("Y");
auto* dY = context.Input<Tensor>(framework::GradVarName("Y"));
auto* Out = context.Input<Tensor>("Out");
auto* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
// allocate memory on device.
dX->mutable_data<T>(context.GetPlace());
math::SoftmaxGradFunctor<DeviceContext, T>()(
context.template device_context<DeviceContext>(), Y, dY, dX);
context.template device_context<DeviceContext>(), Out, dOut, dX);
}
};

@ -180,10 +180,22 @@ def save_inference_model(dirname,
:return: None
"""
if isinstance(feeded_var_names, basestring):
feeded_var_names = [feeded_var_names]
else:
if not (bool(feeded_var_names) and all(
isinstance(name, basestring) for name in feeded_var_names)):
raise ValueError("'feed_var_names' should be a list of str.")
if isinstance(target_vars, Variable):
feeded_var_names = [feeded_var_names]
else:
if not (bool(target_vars) and all(
isinstance(var, Variable) for var in target_vars)):
raise ValueError("'target_vars' should be a list of Variable.")
if main_program is None:
main_program = default_main_program()
if not isinstance(target_vars, list):
target_vars = [target_vars]
if not os.path.isdir(dirname):
os.makedirs(dirname)

@ -184,7 +184,7 @@ class LayerHelper(object):
self.append_op(
type=act_type,
inputs={"X": [input_var]},
outputs={"Y": [tmp]},
outputs={"Out": [tmp]},
attrs=act)
return tmp

@ -386,7 +386,8 @@ def square_error_cost(input, label, **kwargs):
square_out = helper.create_tmp_variable(dtype=input.dtype)
helper.append_op(
type='square', inputs={'X': [minus_out]}, outputs={'Y': [square_out]})
type='square', inputs={'X': [minus_out]},
outputs={'Out': [square_out]})
return square_out
@ -604,7 +605,7 @@ def sequence_pool(input, pool_type, **kwargs):
sqrt : out.data = [2.82, 6.93, 4.24], where 2.82=(1+3)/sqrt(2),
6.93=(2+4+6)/sqrt(3), 4.24=(5+1)/sqrt(2)
max : out.data = [3, 6, 5], where 3=max(1,3), 6=max(2,4,6), 5=max(5,1)
Args:
input(variable): The input variable which is a LoDTensor.
pool_type (string): The pooling type of sequence_pool.
@ -616,7 +617,7 @@ def sequence_pool(input, pool_type, **kwargs):
Examples:
.. code-block:: python
x = fluid.layers.data(name='x', shape=[7, 1],
dtype='float32', lod_level=1)
avg_x = fluid.layers.sequence_pool(input=x, pool_type='average')
@ -654,7 +655,7 @@ def sequence_first_step(input, **kwargs):
out.dim = [3, 1]
with condition len(x.lod[-1]) - 1 == out.dims[0]
out.data = [1, 2, 5], where 1=first(1,3), 2=first(2,4,6), 5=first(5,1)
Args:
input(variable): The input variable which is a LoDTensor.
@ -664,7 +665,7 @@ def sequence_first_step(input, **kwargs):
Examples:
.. code-block:: python
x = fluid.layers.data(name='x', shape=[7, 1],
dtype='float32', lod_level=1)
x_first_step = fluid.layers.sequence_first_step(input=x)
@ -687,7 +688,7 @@ def sequence_last_step(input, **kwargs):
out.dim = [3, 1]
with condition len(x.lod[-1]) - 1 == out.dims[0]
out.data = [3, 6, 1], where 3=last(1,3), 6=last(2,4,6), 1=last(5,1)
Args:
input(variable): The input variable which is a LoDTensor.
@ -697,7 +698,7 @@ def sequence_last_step(input, **kwargs):
Examples:
.. code-block:: python
x = fluid.layers.data(name='x', shape=[7, 1],
dtype='float32', lod_level=1)
x_last_step = fluid.layers.sequence_last_step(input=x)
@ -1132,7 +1133,7 @@ def reduce_sum(input, dim=None, keep_dim=False):
Returns:
Variable: The reduced Tensor variable.
Examples:
.. code-block:: python
@ -1176,7 +1177,7 @@ def reduce_mean(input, dim=None, keep_dim=False):
Returns:
Variable: The reduced Tensor variable.
Examples:
.. code-block:: python

File diff suppressed because it is too large Load Diff

@ -7,7 +7,7 @@ def fc(X, W, Y):
ret_v = core.Net.create()
ret_v.append_op(Operator("mul", X="X", Y="W", Out="pre_activation"))
ret_v.append_op(Operator("sigmoid", X="pre_activation", Y=Y))
ret_v.append_op(Operator("sigmoid", X="pre_activation", Out=Y))
ret_v.complete_add_op(True)
return ret_v
@ -30,7 +30,7 @@ Op(plain_net), inputs:{all[W, X, Y]}, outputs:{all[Out, fc.out, pre_activation]}
Op(plain_net), inputs:{all[W, X]}, outputs:{all[fc.out, pre_activation]}.
Op(plain_net), inputs:{all[W, X]}, outputs:{all[fc.out, pre_activation]}.
Op(mul), inputs:{X[X], Y[W]}, outputs:{Out[pre_activation]}.
Op(sigmoid), inputs:{X[pre_activation]}, outputs:{Y[fc.out]}.
Op(sigmoid), inputs:{X[pre_activation]}, outputs:{Out[fc.out]}.
'''
self.assertEqual(expected, "\n" + str(net))

@ -17,14 +17,14 @@ class TestSoftmaxOp(OpTest):
'X': np.random.uniform(0.1, 1, [10, 10]).astype("float32")
}
self.outputs = {
'Y': np.apply_along_axis(stable_softmax, 1, self.inputs['X'])
'Out': np.apply_along_axis(stable_softmax, 1, self.inputs['X'])
}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y')
self.check_grad(['X'], 'Out')
if __name__ == "__main__":

Loading…
Cancel
Save