Merge pull request #7574 from lcy-seso/wraper_for_l2_normalize

Add a Python wrapper for the l2_normalize layer.
Guo Sheng 7 years ago committed by GitHub
commit 4b3e22b865
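For context, a minimal usage sketch of the new wrapper, mirroring the test added in this PR (the `axis`/`epsilon` arguments and the feeding style follow the test below; treat the exact signature as an assumption):

```python
import numpy as np
import paddle.v2.fluid as fluid
import paddle.v2.fluid.core as core

# Declare the input and attach the new normalization layer.
data = fluid.layers.data(
    name="input", shape=[2, 3, 7], dtype="float32", append_batch_size=False)
l2_norm = fluid.layers.l2_normalize(x=data, axis=1, epsilon=1e-6)

# Feed a random tensor and fetch the normalized result on CPU.
place = core.CPUPlace()
tensor = fluid.Tensor()
tensor.set(np.random.random((2, 3, 7)).astype("float32"), place)
exe = fluid.Executor(place)
output = exe.run(fluid.default_main_program(),
                 feed={"input": tensor},
                 fetch_list=[l2_norm],
                 return_numpy=True)
```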

@@ -499,3 +499,8 @@ swish
 ------
 .. autofunction:: paddle.v2.fluid.layers.swish
     :noindex:
+
+l2_normalize
+------------
+.. autofunction:: paddle.v2.fluid.layers.l2_normalize
+    :noindex:

@@ -51,8 +51,8 @@ class ClipOpMaker : public framework::OpProtoAndCheckerMaker {
     AddComment(R"DOC(
 Clip Operator.
 
-The clip operator limits the value of given input within an interval. The interval is
-specified with arguments 'min' and 'max':
+The clip operator limits the value of given input within an interval. The
+interval is specified with arguments 'min' and 'max':
 
 $$
 Out = \min(\max(X, min), max)
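The documented formula is exactly element-wise clipping; a quick numpy sanity check (illustrative only, not part of this diff):

```python
import numpy as np

x = np.array([-2.0, -0.5, 0.3, 1.7])
# Out = min(max(X, min), max): clip each element to [min, max] = [-1, 1].
out = np.minimum(np.maximum(x, -1.0), 1.0)
assert np.allclose(out, np.clip(x, -1.0, 1.0))
```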

@@ -26,9 +26,9 @@ class ElementwiseOp : public framework::OperatorWithKernel {
   using Tensor = framework::Tensor;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("X"),
-                   "Input(X) of elementwise op should not be null");
+                   "Input(X) of elementwise op should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Y"),
-                   "Input(Y) of elementwise op should not be null");
+                   "Input(Y) of elementwise op should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "Output(Out) of elementwise op should not be null.");
@@ -45,12 +45,12 @@ class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   ElementwiseOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "(Tensor) The first input tensor of elementwise op");
-    AddInput("Y", "(Tensor) The second input tensor of elementwise op");
-    AddOutput("Out", "The output of elementwise op");
+    AddInput("X", "(Tensor), The first input tensor of elementwise op.");
+    AddInput("Y", "(Tensor), The second input tensor of elementwise op.");
+    AddOutput("Out", "The output of elementwise op.");
     AddAttr<int>("axis",
-                 "(int, default -1) The starting dimension index "
-                 "for broadcasting Y onto X")
+                 "(int, default -1). The start dimension index "
+                 "for broadcasting Y onto X.")
         .SetDefault(-1)
         .EqualGreaterThan(-1);
     comment_ = R"DOC(
@@ -58,19 +58,18 @@ Limited Elementwise {name} Operator.
 
 The equation is:
 
-.. math::
-    {equation}
+$${equation}$$
 
-X is a tensor of any dimension and the dimensions of tensor Y must be smaller than
-or equal to the dimensions of X.
+$X$ is a tensor of any dimension and the dimensions of tensor $Y$ must be
+smaller than or equal to the dimensions of $X$.
 
 There are two cases for this operator:
-1. The shape of Y is same with X;
-2. The shape of Y is a subset of X.
+1. The shape of $Y$ is same with $X$;
+2. The shape of $Y$ is a subset of $X$.
 
 For case 2:
-Y will be broadcasted to match the shape of X and axis should be
-the starting dimension index for broadcasting Y onto X.
+$Y$ will be broadcasted to match the shape of $X$ and axis should be
+set to index of the start dimension to broadcast $Y$ onto $X$.
 
 For example
   .. code-block:: python
@@ -81,7 +80,8 @@ For example
     shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
     shape(X) = (2, 3, 4, 5), shape(Y) = (2), with axis=0
 
-Either of the inputs X and Y or none can carry the LoD (Level of Details) information. However, the output only shares the LoD information with input X.
+Either of the inputs $X$ and $Y$ or none can carry the LoD (Level of Details)
+information. However, the output only shares the LoD information with input $X$.
 
 )DOC";
     AddComment(comment_);

@@ -58,21 +58,21 @@ class ExpandOpMaker : public framework::OpProtoAndCheckerMaker {
   ExpandOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
-             "(Tensor, default Tensor<float>) A tensor with rank in [1, 6]."
-             "X is the input tensor to be expanded.");
+             "(Tensor, default Tensor<float>). A tensor with rank in [1, 6]."
+             "X is the input to be expanded.");
     AddOutput("Out",
-              "(Tensor, default Tensor<float>) A tensor with rank in [1, 6]."
-              "The rank of Output(Out) is same as Input(X) except that each "
-              "dimension size of Output(Out) is equal to corresponding "
-              "dimension size of Input(X) multiplying corresponding value of "
-              "Attr(expand_times).");
+              "(Tensor, default Tensor<float>). A tensor with rank in [1, 6]."
+              "The rank of Output(Out) have the same with Input(X). "
+              "After expanding, size of each dimension of Output(Out) is equal "
+              "to size of the corresponding dimension of Input(X) multiplying "
+              "the corresponding value given by Attr(expand_times).");
     AddAttr<std::vector<int>>("expand_times",
                               "Expand times number for each dimension.");
     AddComment(R"DOC(
 Expand operator tiles the input by given times number. You should set times
 number for each dimension by providing attribute 'expand_times'. The rank of X
-should be in [1, 6]. Please notice that size of 'expand_times' must be same with
-X's rank. Following is a using case:
+should be in [1, 6]. Please note that size of 'expand_times' must be the same
+with X's rank. Following is a using case:
 
 Input(X) is a 3-D tensor with shape [2, 3, 1]:
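The documented behavior matches numpy's `tile`, with one repeat count per dimension; a small illustration using the [2, 3, 1] shape from the doc comment (not part of this diff):

```python
import numpy as np

X = np.random.rand(2, 3, 1)
expand_times = [1, 2, 4]   # keep dim 0, double dim 1, repeat dim 2 four times
out = np.tile(X, expand_times)
assert out.shape == (2, 6, 4)
```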

@@ -16,13 +16,22 @@ from paddle.trainer.config_parser import *
 from default_decorators import *
 
 __all__ = [
-    "evaluator_base", "classification_error_evaluator", "auc_evaluator",
-    "pnpair_evaluator", "precision_recall_evaluator", "ctc_error_evaluator",
-    "chunk_evaluator", "sum_evaluator", "column_sum_evaluator",
-    "value_printer_evaluator", "gradient_printer_evaluator",
-    "maxid_printer_evaluator", "maxframe_printer_evaluator",
-    "seqtext_printer_evaluator", "classification_error_printer_evaluator",
-    "detection_map_evaluator"
+    "evaluator_base",
+    "classification_error_evaluator",
+    "auc_evaluator",
+    "pnpair_evaluator",
+    "precision_recall_evaluator",
+    "ctc_error_evaluator",
+    "chunk_evaluator",
+    "sum_evaluator",
+    "column_sum_evaluator",
+    "value_printer_evaluator",
+    "gradient_printer_evaluator",
+    "maxid_printer_evaluator",
+    "maxframe_printer_evaluator",
+    "seqtext_printer_evaluator",
+    "classification_error_printer_evaluator",
+    "detection_map_evaluator",
 ]

@@ -116,8 +116,8 @@ def _debug_string_(proto, throw_on_error=True):
     """
     error_fields = list()
     if not proto.IsInitialized(error_fields) and throw_on_error:
-        raise ValueError("{0} are not initialized\nThe message is {1}".format(
-            error_fields, proto))
+        raise ValueError("{0} are not initialized.\nThe message is {1}:\n".
+                         format(error_fields, proto))
     return proto.__str__()
@@ -374,12 +374,13 @@ class Operator(object):
         >>>                        outputs={"Out": [var1]})
 
         Args:
-            block(Block): The block has the current operator
-            desc(core.OpDesc): The protobuf description
+            block(Block): The block has the current operator.
+            desc(core.OpDesc): The protobuf description.
             type(str): The type of operator.
             inputs(dict): The input dictionary. Key is the input parameter name.
                 Value is a list of variables.
-            outputs(dict): The output dictionary. Has same format with inputs
+            outputs(dict): The output dictionary which has the same format with
+                inputs.
             attrs(dict): The attributes dictionary. Key is attribute name. Value
                 is the attribute value. The attribute type should be as same as
                 the type registered in C++
@@ -436,10 +437,11 @@ class Operator(object):
             for m in proto.outputs:
                 need.add(m.name)
             if not given == need:
-                raise ValueError(
-                    "Incorrect setting for output(s) of operator \"%s\". Need: [%s] Given: [%s]"
-                    % (type, ", ".join(str(e) for e in need), ", ".join(
-                        str(e) for e in given)))
+                raise ValueError(("Incorrect setting for output(s) of "
+                                  "operator \"%s\". Need: [%s] Given: [%s]") %
+                                 (type, ", ".join(str(e)
+                                                  for e in need), ", ".join(
+                                                      str(e) for e in given)))
 
             for out_proto in proto.outputs:
                 out_args = outputs[out_proto.name]
@@ -818,9 +820,8 @@ class Program(object):
                 if isinstance(t, Variable):
                     t = t.op
                 else:
-                    raise ValueError(
-                        "All targets of prune() can only be Variable or Operator."
-                    )
+                    raise ValueError(("All targets of prune() can only be "
+                                      "Variable or Operator."))
 
             targets_idx.append([t.block.idx, t.idx])
         res = Program()

@@ -28,9 +28,9 @@ def data(name,
     **Data Layer**
 
     This function takes in the input and based on whether data has
-    to be returned back as a minibatch, it creates the global variable using
+    to be returned back as a minibatch, it creates the global variable by using
     the helper functions. The global variables can be accessed by all the
-    following operations and layers in the graph.
+    following operators in the graph.
 
     All the input variables of this function are passed in as local variables
     to the LayerHelper constructor.
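For reference, a minimal declaration of such a data layer (illustrative only, not part of this diff):

```python
import paddle.v2.fluid as fluid

# A 3x32x32 float input; the batch dimension is prepended automatically
# because append_batch_size defaults to True.
image = fluid.layers.data(name="image", shape=[3, 32, 32], dtype="float32")
```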

File diff suppressed because it is too large.

@@ -0,0 +1,95 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle.v2.fluid as fluid
import paddle.v2.fluid.core as core
import numpy as np
class TestNormalization(unittest.TestCase):
data_desc = {"name": "input", "shape": (2, 3, 7)}
def gen_random_input(self):
"""Generate random input data.
"""
self.data = np.random.random(
size=self.data_desc["shape"]).astype("float32")
def set_program(self, axis, epsilon):
"""Build the test program.
"""
data = fluid.layers.data(
name=self.data_desc["name"],
shape=self.data_desc["shape"],
dtype="float32",
append_batch_size=False)
data.stop_gradient = False
l2_norm = fluid.layers.l2_normalize(x=data, axis=axis, epsilon=epsilon)
out = fluid.layers.reduce_sum(l2_norm, dim=None)
fluid.backward.append_backward(loss=out)
self.fetch_list = [l2_norm]
def run_program(self):
"""Run the test program.
"""
places = [core.CPUPlace()]
if core.is_compile_gpu():
places.append(core.CUDAPlace(0))
for place in places:
self.set_inputs(place)
exe = fluid.Executor(place)
output = exe.run(fluid.default_main_program(),
feed=self.inputs,
fetch_list=self.fetch_list,
return_numpy=True)
self.op_output = output
def set_inputs(self, place):
"""Set the randomly generated data to the test program.
"""
self.inputs = {}
tensor = fluid.Tensor()
tensor.set(self.data, place)
self.inputs[self.data_desc["name"]] = tensor
def l2_normalize(self, data, axis, epsilon):
""" Compute the groundtruth.
"""
        # Normalize by the L2 norm along `axis`: x / sqrt(sum(x^2)).
        output = data * np.reciprocal(
            np.sqrt(np.sum(np.square(data), axis=axis, keepdims=True)))
return output
def test_l2_normalize(self):
""" Test the python wrapper for l2_normalize.
"""
axis = 1
        # TODO(caoying) epsilon is not supported due to lack of a maximum_op.
epsilon = 1e-6
self.gen_random_input()
self.set_program(axis, epsilon)
self.run_program()
expect_output = self.l2_normalize(self.data, axis, epsilon)
# check output
self.assertTrue(np.allclose(self.op_output, expect_output, atol=0.001))
if __name__ == '__main__':
unittest.main()
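A side note on the TODO in the test above: once a maximum_op is available, an epsilon-aware ground truth could floor the norm before dividing. The following numpy reference is a hypothetical sketch of that intended semantics, not code from this PR:

```python
import numpy as np

def l2_normalize_with_epsilon(x, axis, epsilon=1e-6):
    # Floor the squared norm at epsilon before the square root, which is
    # what the missing maximum_op would make expressible in the wrapper.
    squared_norm = np.sum(np.square(x), axis=axis, keepdims=True)
    return x / np.sqrt(np.maximum(squared_norm, epsilon))
```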