diff --git a/CMakeLists.txt b/CMakeLists.txt index 4921226ec1..4783095194 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -86,6 +86,14 @@ if(ANDROID OR IOS) "Disable MKLDNN when cross-compiling for Android and iOS" FORCE) set(WITH_MKLML OFF CACHE STRING "Disable MKLML package when cross-compiling for Android and iOS" FORCE) + + # Compile PaddlePaddle mobile inference library + if (NOT WITH_C_API) + set(WITH_C_API ON CACHE STRING + "Always compile the C_API when cross-compiling for Android and iOS" FORCE) + endif() + set(MOBILE_INFERENCE ON) + add_definitions(-DPADDLE_MOBILE_INFERENCE) endif() set(THIRD_PARTY_PATH "${CMAKE_BINARY_DIR}/third_party" CACHE STRING @@ -160,9 +168,11 @@ endif(USE_NNPACK) add_subdirectory(proto) -# "add_subdirectory(go)" should be placed after the following loine, -# because it depends on paddle/optimizer. -add_subdirectory(paddle/optimizer) +if(NOT MOBILE_INFERENCE) + # "add_subdirectory(go)" should be placed after the following loine, + # because it depends on paddle/optimizer. + add_subdirectory(paddle/optimizer) +endif() # "add_subdirectory(paddle)" and "add_subdirectory(python)" should be # placed after this block, because they depends on it. diff --git a/cmake/configure.cmake b/cmake/configure.cmake index c1c93e17fd..db8f5ab045 100644 --- a/cmake/configure.cmake +++ b/cmake/configure.cmake @@ -24,6 +24,10 @@ if(WITH_DOUBLE) add_definitions(-DPADDLE_TYPE_DOUBLE) endif(WITH_DOUBLE) +if(WITH_TESTING) + add_definitions(-DPADDLE_WITH_TESTING) +endif(WITH_TESTING) + if(NOT WITH_TIMER) add_definitions(-DPADDLE_DISABLE_TIMER) endif(NOT WITH_TIMER) diff --git a/cmake/util.cmake b/cmake/util.cmake index d1aee3e170..117ab7f49c 100644 --- a/cmake/util.cmake +++ b/cmake/util.cmake @@ -73,25 +73,43 @@ function(link_paddle_exe TARGET_NAME) generate_rdma_links() endif() - target_circle_link_libraries(${TARGET_NAME} - ARCHIVE_START - paddle_gserver - paddle_function - ARCHIVE_END - paddle_pserver - paddle_trainer_lib - paddle_network - paddle_math - paddle_utils - paddle_parameter - paddle_proto - paddle_cuda - paddle_optimizer - ${EXTERNAL_LIBS} - ${CMAKE_THREAD_LIBS_INIT} - ${CMAKE_DL_LIBS} - ${RDMA_LD_FLAGS} - ${RDMA_LIBS}) + if(MOBILE_INFERENCE) + target_circle_link_libraries(${TARGET_NAME} + ARCHIVE_START + paddle_gserver + paddle_function + ARCHIVE_END + paddle_math + paddle_utils + paddle_parameter + paddle_proto + paddle_cuda + ${EXTERNAL_LIBS} + ${CMAKE_THREAD_LIBS_INIT} + ${CMAKE_DL_LIBS} + ${RDMA_LD_FLAGS} + ${RDMA_LIBS}) + else() + target_circle_link_libraries(${TARGET_NAME} + ARCHIVE_START + paddle_gserver + paddle_function + ARCHIVE_END + paddle_pserver + paddle_trainer_lib + paddle_network + paddle_math + paddle_utils + paddle_parameter + paddle_proto + paddle_cuda + paddle_optimizer + ${EXTERNAL_LIBS} + ${CMAKE_THREAD_LIBS_INIT} + ${CMAKE_DL_LIBS} + ${RDMA_LD_FLAGS} + ${RDMA_LIBS}) + endif() if(ANDROID) target_link_libraries(${TARGET_NAME} log) diff --git a/doc/design/block.md b/doc/design/block.md index 4d5dd4ba95..9c812732d6 100644 --- a/doc/design/block.md +++ b/doc/design/block.md @@ -5,12 +5,12 @@ Both deep learning systems and programming languages help users describe computation procedures. These systems use various representations of computation: - Caffe, Torch, and Paddle: sequences of layers. -- TensorFlow, Caffe2, Mxnet: graphs of operators. +- TensorFlow, Caffe2, Mxnet: graph of operators. - PaddlePaddle: nested blocks, like C++ and Java programs. 
 ## Block in Programming Languages and Deep Learning
 
-In programming languages, a block is a pair of curly braces that includes local variables definitions and a sequence of instructions, or operators.
+In programming languages, a block is a pair of curly braces that includes local variable definitions and a sequence of instructions or operators.
 
 Blocks work with control flow structures like `if`, `else`, and `for`, which have equivalents in deep learning:
 
@@ -24,14 +24,14 @@ A key difference is that a C++ program describes a one pass computation, whereas
 
 ## Stack Frames and the Scope Hierarchy
 
-The existence of the backward makes the execution of a block of traditional programs and PaddlePaddle different to each other:
+The existence of the backward pass makes the execution of a block of PaddlePaddle different from traditional programs:
 
-| programming languages | PaddlePaddle |
-|-----------------------|-------------------------------|
-| stack | scope hierarchy |
-| stack frame | scope |
-| push at entering block| push at entering block |
-| pop at leaving block | destroy at minibatch completes|
+| programming languages | PaddlePaddle |
+|-----------------------|---------------------------------|
+| stack | scope hierarchy |
+| stack frame | scope |
+| push at entering block| push at entering block |
+| pop at leaving block | destroy when minibatch completes|
 
 1. In traditional programs:
 
@@ -42,9 +42,9 @@ The existence of the backward makes the execution of a block of traditional prog
 1. In PaddlePaddle
 
    - When the execution enters a block, PaddlePaddle adds a new scope, where it realizes variables.
-   - PaddlePaddle doesn't pop a scope after the execution of the block because variables therein are to be used by the backward pass. So it has a stack forest known as a *scope hierarchy*.
+   - PaddlePaddle doesn't pop a scope after the execution of the block because variables therein are used by the backward pass. So it has a stack forest known as a *scope hierarchy*.
   - The height of the highest tree is the maximum depth of nested blocks.
-   - After the process of a minibatch, PaddlePaddle destroys the scope hierarchy.
+   - After the processing of a minibatch, PaddlePaddle destroys the scope hierarchy.
 
 ## Use Blocks in C++ and PaddlePaddle Programs
 
@@ -94,14 +94,14 @@ with ie.false_block():
 o1, o2 = ie(cond)
 ```
 
-In both examples, the left branch computes `x+y` and `softmax(x+y)`, the right branch computes `x+1` and `fc(x)`.
+In both examples, the left branch computes `x+y` and `softmax(x+y)`, the right branch computes `fc(x)` and `x+1`.
 
-A difference is that variables in the C++ program contain scalar values, whereas those in the PaddlePaddle programs are mini-batches of instances. The `ie.input(true, 0)` invocation returns instances in the 0-th input, `x`, that corresponds to true values in `cond` as the local variable `x`, where `ie.input(false, 0)` returns instances corresponding to false values.
+The difference is that variables in the C++ program contain scalar values, whereas those in the PaddlePaddle programs are mini-batches of instances.
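+
+To make the mini-batch semantics concrete, the following is a sketch in plain numpy (not the PaddlePaddle API; the condition values and the merge step are illustrative assumptions) of how an `IfElseOp`-style operator partitions a mini-batch by `cond`, runs each branch on its own partition, and merges the branch outputs:
+
+```python
+import numpy as np
+
+x = np.array([0.5, 3.0, 1.2, 4.7])  # a mini-batch of instances
+y = np.array([1.0, 1.0, 1.0, 1.0])
+cond = x < 2.0                       # per-instance boolean condition
+
+true_x = x[cond]                     # instances routed to the true branch
+false_x = x[~cond]                   # instances routed to the false branch
+
+o1 = np.empty_like(x)
+o1[cond] = true_x + y[cond]          # true branch computes x + y
+o1[~cond] = false_x + 1.0            # false branch computes x + 1
+```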
### Blocks with `for` and `RNNOp` -The following RNN model from the [RNN design doc](./rnn.md) +The following RNN model in PaddlePaddle from the [RNN design doc](./rnn.md) : ```python x = sequence([10, 20, 30]) # shape=[None, 1] @@ -112,9 +112,9 @@ U = var(0.375, param=true) # shape=[1] rnn = pd.rnn() with rnn.step(): h = rnn.memory(init = m) - hh = rnn.previous_memory(h) + h_prev = rnn.previous_memory(h) a = layer.fc(W, x) - b = layer.fc(U, hh) + b = layer.fc(U, h_prev) s = pd.add(a, b) act = pd.sigmoid(s) rnn.update_memory(h, act) @@ -147,9 +147,9 @@ for (int i = 1; i <= sizeof(x)/sizeof(x[0]); ++i) { ## Compilation and Execution -Like TensorFlow programs, a PaddlePaddle program is written in Python. The first part describes a neural network as a protobuf message, and the rest part executes the message for training or inference. +Like TensorFlow, a PaddlePaddle program is written in Python. The first part describes a neural network as a protobuf message, and the rest executes the message for training or inference. -The generation of this protobuf message is like what a compiler generates a binary executable file. The execution of the message that the OS executes the binary file. +The generation of this protobuf message is similar to how a compiler generates a binary executable file. The execution of the message is similar to how the OS executes the binary file. ## The "Binary Executable File Format" @@ -186,8 +186,8 @@ Also, the RNN operator in above example is serialized into a protobuf message of ``` OpDesc { - inputs = {0} // the index of x - outputs = {5, 3} // indices of act and hidden_out + inputs = {0} // the index of x in vars of BlockDesc above + outputs = {5, 3} // indices of act and hidden_out in vars of BlockDesc above attrs { "memories" : {1} // the index of h "step_net" : @@ -203,14 +203,14 @@ This `OpDesc` value is in the `ops` field of the `BlockDesc` value representing During the generation of the Protobuf message, the Block should store VarDesc (the Protobuf message which describes Variable) and OpDesc (the Protobuf message which describes Operator). VarDesc in a block should have its name scope to avoid local variables affect parent block's name scope. -Child block's name scopes should inherit the parent's so that OpDesc in child block can reference a VarDesc that stored in parent block. For example +Child block's name scopes should inherit the parent's so that OpDesc in child block can reference a VarDesc that stored in parent block. For example: ```python -a = pd.Varaible(shape=[20, 20]) +a = pd.Variable(shape=[20, 20]) b = pd.fc(a, params=["fc.w", "fc.b"]) rnn = pd.create_rnn() -with rnn.stepnet() +with rnn.stepnet(): x = a.as_step_input() # reuse fc's parameter fc_without_b = pd.get_variable("fc.w") @@ -218,17 +218,17 @@ with rnn.stepnet() out = rnn() ``` -the method `pd.get_variable` can help retrieve a Variable by a name, a Variable may store in a parent block, but might be retrieved in a child block, so block should have a variable scope that supports inheritance. +The method `pd.get_variable` can help retrieve a Variable by the name. The Variable may be stored in a parent block, but might be retrieved in a child block, so block should have a variable scope that supports inheritance. In compiler design, the symbol table is a data structure created and maintained by compilers to store information about the occurrence of various entities such as variable names, function names, classes, etc. 
To store the definition of variables and operators, we define a C++ class `SymbolTable`, like the one used in compilers. -`SymbolTable` can do the following stuff: +`SymbolTable` can do the following: - store the definitions (some names and attributes) of variables and operators, -- to verify if a variable was declared, -- to make it possible to implement type checking (offer Protobuf message pointers to `InferShape` handlers). +- verify if a variable was declared, +- make it possible to implement type checking (offer Protobuf message pointers to `InferShape` handlers). ```c++ @@ -240,19 +240,18 @@ class SymbolTable { OpDesc* NewOp(const string& name=""); - // TODO determine whether name is generated by python or C++ - // currently assume that a unique name will be generated by C++ if the - // argument name left default. + // TODO determine whether name is generated by python or C++. + // Currently assume that a unique name will be generated by C++ if the + // argument name is left default. VarDesc* NewVar(const string& name=""); - // find a VarDesc by name, if recursive true, find parent's SymbolTable + // find a VarDesc by name, if recursive is true, find parent's SymbolTable // recursively. // this interface is introduced to support InferShape, find protobuf messages // of variables and operators, pass pointers into InferShape. - // operator // // NOTE maybe some C++ classes such as VarDescBuilder and OpDescBuilder should - // be proposed and embedded into pybind to enable python operate on C++ pointers. + // be proposed and embedded into pybind to enable python operation on C++ pointers. VarDesc* FindVar(const string& name, bool recursive=true); OpDesc* FindOp(const string& name); @@ -270,7 +269,7 @@ class SymbolTable { After all the description of variables and operators is added into SymbolTable, the block has enough information to run. -The `Block` class takes a `BlockDesc` as input, and provide `Run` and `InferShape` functions. +The `Block` class takes a `BlockDesc` as input, and provides `Run` and `InferShape` functions. ```c++ @@ -302,7 +301,7 @@ public: void CreateVariables(const framework::Scope& scope); void CreateOperators(); - // some other necessary interfaces of NetOp are list below + // some other necessary interfaces of NetOp are listed below // ... private: @@ -316,15 +315,14 @@ private: Block inherits from OperatorBase, which has a Run method. Block's Run method will run its operators sequentially. -There is another important interface called `Eval`, which take some arguments called targets, and generate a minimal graph which takes targets as the end points and creates a new Block, -after `Run`, `Eval` will get the latest value and return the targets. +There is another important interface called `Eval`, which takes some arguments called targets and generates a minimal graph which treats targets as the end points and creates a new Block. After `Run`, `Eval` will get the latest value and return the targets. The definition of Eval is as follows: ```c++ // clean a block description by targets using the corresponding dependency graph. // return a new BlockDesc with minimal number of operators. -// NOTE not return a Block but the block's description so that this can be distributed +// NOTE: The return type is not a Block but the block's description so that this can be distributed // to a cluster. 
 BlockDesc Prune(const BlockDesc& desc, vector<string> targets);
diff --git a/doc/design/dcgan.png b/doc/design/dcgan.png
new file mode 100644
index 0000000000..15e8e290a1
Binary files /dev/null and b/doc/design/dcgan.png differ
diff --git a/doc/design/gan_api.md b/doc/design/gan_api.md
new file mode 100644
index 0000000000..fb41df8615
--- /dev/null
+++ b/doc/design/gan_api.md
@@ -0,0 +1,253 @@
+# Design for GAN
+
+GAN (Generative Adversarial Net [https://arxiv.org/abs/1406.2661]) is an important model for unsupervised learning and is widely used in many areas.
+
+It applies several important concepts in machine learning system design, including building and running subgraphs, dependency tracing, different optimizers in one executor and so forth.
+
+In our GAN design, we wrap it as a user-friendly, easily customized Python API to design different models. We take the conditional DC-GAN (Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks [https://arxiv.org/abs/1511.06434]) as an example due to its good performance on image generation.
+

+
+Figure 1. The overall running logic of GAN. The black solid arrows indicate the forward pass; the green dashed arrows indicate the backward pass of generator training; the red dashed arrows indicate the backward pass of the discriminator training. The BP pass of the green (red) arrow should only update the parameters in the green (red) boxes. The diamonds indicate the data providers. d\_loss and g\_loss marked in red and green are the two targets we would like to run. +

+
+The operators, layers and functions required/optional to build a GAN demo are summarized in https://github.com/PaddlePaddle/Paddle/issues/4563.
+

+
+Figure 2. Photo borrowed from the original DC-GAN paper. +

+
+## The Conditional-GAN might be a class.
+This design adopts the popular open-source design in https://github.com/carpedm20/DCGAN-tensorflow and https://github.com/rajathkmp/DCGAN. It contains the following data structure:
+
+- DCGAN(object): which contains everything required to build a GAN model. It provides the following member functions as its API:
+
+- __init__(...): Initialize hyper-parameters (like conv dimension and so forth), and declare model parameters of both discriminator and generator.
+
+- generator(z, y=None): Generate a fake image from input noise z. If the label y is provided, the conditional GAN model will be chosen.
+Returns a generated image.
+
+- discriminator(image):
+Given an image, decide if it is from a real source or a fake one.
+Returns a 0/1 binary label.
+
+- build_model(self):
+Build the whole GAN model and define the training losses for both the generator and the discriminator.
+
+## Discussion on Engine Functions required to build GAN
+- Trace the tensor and variable dependency in the engine executor. (Very critical; otherwise GAN cannot be trained correctly.)
+- Different optimizers, each responsible for optimizing a different loss.
+
+In more detail, we introduce our design of DCGAN as follows:
+
+### Class member Function: Initializer
+- Set up hyper-parameters, including conditional dimension, noise dimension, batch size and so forth.
+- Declare and define all the model variables. All the discriminator parameters are included in the list self.theta_D and all the generator parameters are included in the list self.theta_G.
+```python
+class DCGAN(object):
+  def __init__(self, y_dim=None, z_dim=100):
+
+    # hyper parameters
+    self.y_dim = y_dim # conditional gan or not
+    self.batch_size = 100
+    self.z_dim = z_dim # input noise dimension
+    self.im_size = 28 # image width/height; 784 = 28 x 28 below
+
+    # define parameters of discriminators
+    self.D_W0 = pd.Variable(shape=[3,3, 1, 128], data=pd.gaussian_normal_randomizer())
+    self.D_b0 = pd.Variable(np.zeros(128)) # variables also support initialization from numpy data
+    self.D_W1 = pd.Variable(shape=[784, 128], data=pd.gaussian_normal_randomizer())
+    self.D_b1 = pd.Variable(np.zeros(128)) # variables also support initialization from numpy data
+    self.D_W2 = pd.Variable(np.random.rand(128, 1))
+    self.D_b2 = pd.Variable(np.zeros(128))
+    self.theta_D = [self.D_W0, self.D_b0, self.D_W1, self.D_b1, self.D_W2, self.D_b2]
+
+    # define parameters of generators
+    self.G_W0 = pd.Variable(shape=[784, 128], data=pd.gaussian_normal_randomizer())
+    self.G_b0 = pd.Variable(np.zeros(128)) # variables also support initialization from numpy data
+    self.G_W1 = pd.Variable(shape=[784, 128], data=pd.gaussian_normal_randomizer())
+    self.G_b1 = pd.Variable(np.zeros(128)) # variables also support initialization from numpy data
+    self.G_W2 = pd.Variable(np.random.rand(128, 1))
+    self.G_b2 = pd.Variable(np.zeros(128))
+    self.theta_G = [self.G_W0, self.G_b0, self.G_W1, self.G_b1, self.G_W2, self.G_b2]
+```
+
+### Class member Function: Generator
+- Given a noisy input z, returns a fake image.
+- Concatenation, batch-norm, FC operations required;
+- Deconv layer required, which is missing now...
+```python
+class DCGAN(object):
+  def generator(self, z, y = None):
+    # input z: the random noise
+    # input y: input data label (optional)
+    # output G_im: generated fake images
+
+    if self.y_dim: # concatenate the label for the conditional GAN
+      z = pd.layer.concat(1, [z, y])
+
+    G_h0 = pd.layer.fc(z, self.G_W0, self.G_b0)
+    G_h0_bn = pd.layer.batch_norm(G_h0)
+    G_h0_relu = pd.layer.relu(G_h0_bn)
+
+    G_h1 = pd.layer.deconv(G_h0_relu, self.G_W1, self.G_b1)
+    G_h1_bn = pd.layer.batch_norm(G_h1)
+    G_h1_relu = pd.layer.relu(G_h1_bn)
+
+    G_h2 = pd.layer.deconv(G_h1_relu, self.G_W2, self.G_b2)
+    G_im = pd.layer.tanh(G_h2)
+    return G_im
+```
+
+### Class member function: Discriminator
+- Given an image, returns the binary logit telling whether it is real or fake.
+- Concatenation, Convolution, batch-norm, FC, Leaky-ReLU operations required;
+```python
+class DCGAN(object):
+  def discriminator(self, image):
+    # input image: either generated images or real ones
+    # output D_h2: binary logit of the label
+
+    D_h0 = pd.layer.conv2d(image, w=self.D_W0, b=self.D_b0)
+    D_h0_bn = pd.layer.batchnorm(D_h0)
+    D_h0_relu = pd.layer.lrelu(D_h0_bn)
+
+    D_h1 = pd.layer.conv2d(D_h0_relu, w=self.D_W1, b=self.D_b1)
+    D_h1_bn = pd.layer.batchnorm(D_h1)
+    D_h1_relu = pd.layer.lrelu(D_h1_bn)
+
+    D_h2 = pd.layer.fc(D_h1_relu, w=self.D_W2, b=self.D_b2)
+    return D_h2
+```
+
+### Class member function: Build the model
+- Define data readers as placeholders to hold the data;
+- Build generator and discriminators;
+- Define two training losses for discriminator and generator, respectively.
+If we have an execution dependency engine to back-trace all tensors, the module that builds our GAN model will look like this:
+```python
+class DCGAN(object):
+  def build_model(self):
+    if self.y_dim:
+      self.y = pd.data(pd.float32, [self.batch_size, self.y_dim])
+    self.images = pd.data(pd.float32, [self.batch_size, self.im_size, self.im_size])
+    self.faked_images = pd.data(pd.float32, [self.batch_size, self.im_size, self.im_size])
+    self.z = pd.data(pd.float32, [None, self.z_dim])
+
+    # step 1: generate images by generator, classify real/fake images with discriminator
+    if self.y_dim: # if conditional GAN, includes label
+      self.G = self.generator(self.z, self.y)
+      self.D_t = self.discriminator(self.images)
+      # generated fake images
+      self.sampled = self.sampler(self.z, self.y)
+      self.D_f = self.discriminator(self.G)
+    else: # original version of GAN
+      self.G = self.generator(self.z)
+      self.D_t = self.discriminator(self.images)
+      # generate fake images
+      self.sampled = self.sampler(self.z)
+      self.D_f = self.discriminator(self.G)
+
+    # step 2: define the two losses
+    self.d_loss_real = pd.reduce_mean(pd.cross_entropy(self.D_t, np.ones(self.batch_size)))
+    self.d_loss_fake = pd.reduce_mean(pd.cross_entropy(self.D_f, np.zeros(self.batch_size)))
+    self.d_loss = self.d_loss_real + self.d_loss_fake
+
+    self.g_loss = pd.reduce_mean(pd.cross_entropy(self.D_f, np.ones(self.batch_size)))
+```
+
+If we do not have a dependency engine but use blocks, the module that builds our GAN model will look like this:
+```python
+class DCGAN(object):
+  def build_model(self, default_block):
+    # input data in the default block
+    if self.y_dim:
+      self.y = pd.data(pd.float32, [self.batch_size, self.y_dim])
+    self.images = pd.data(pd.float32, [self.batch_size, self.im_size, self.im_size])
+    # self.faked_images = pd.data(pd.float32, [self.batch_size, self.im_size, self.im_size])
+    self.z = pd.data(pd.float32, [None, self.z_dim])
+
+    # step 1: generate images by generator, classify real/fake images with discriminator
+    with pd.default_block().g_block():
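+      # Ops created under g_block below belong to the generator sub-graph;
+      # the g_block optimizer in the main function later runs only these ops.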
+      if self.y_dim: # if conditional GAN, includes label
+        self.G = self.generator(self.z, self.y)
+        self.D_g = self.discriminator(self.G, self.y)
+      else: # original version of GAN
+        self.G = self.generator(self.z)
+        self.D_g = self.discriminator(self.G)
+      self.g_loss = pd.reduce_mean(pd.cross_entropy(self.D_g, np.ones(self.batch_size)))
+
+    with pd.default_block().d_block():
+      if self.y_dim: # if conditional GAN, includes label
+        self.D_t = self.discriminator(self.images, self.y)
+        self.D_f = self.discriminator(self.G, self.y)
+      else: # original version of GAN
+        self.D_t = self.discriminator(self.images)
+        self.D_f = self.discriminator(self.G)
+
+      # step 2: define the two losses
+      self.d_loss_real = pd.reduce_mean(pd.cross_entropy(self.D_t, np.ones(self.batch_size)))
+      self.d_loss_fake = pd.reduce_mean(pd.cross_entropy(self.D_f, np.zeros(self.batch_size)))
+      self.d_loss = self.d_loss_real + self.d_loss_fake
+```
+Some points of confusion and problems with this design:
+- D\_g and D\_f are actually the same thing, but have to be written twice; i.e., if we want to run two sub-graphs conceptually, the same code has to be written twice if it is shared by both.
+- Requires the ability to create a block at any time, rather than only in if-else or RNN constructs;
+
+## Main function for the demo:
+Generally, the user of GAN just needs to do the following:
+- Define an object of the DCGAN class;
+- Build the DCGAN model;
+- Specify two optimizers for the two different losses with respect to their parameters.
+```python
+# pd for short, should be more concise.
+import paddle.v2 as pd
+import numpy as np
+import logging
+
+if __name__ == "__main__":
+  # dcgan class in the default graph/block
+  # if we used a dependency engine as TensorFlow does,
+  # the code would be slightly different:
+  # dcgan = DCGAN()
+  # dcgan.build_model()
+  with pd.block() as def_block:
+    dcgan = DCGAN()
+    dcgan.build_model(def_block)
+
+  # load mnist data
+  data_X, data_y = load_mnist()
+
+  # Two subgraphs required!!!
+  with pd.block().d_block():
+    d_optim = pd.train.Adam(lr = .001, beta= .1)
+    d_step = d_optim.minimize(dcgan.d_loss, dcgan.theta_D)
+  with pd.block().g_block():
+    g_optim = pd.train.Adam(lr = .001, beta= .1)
+    g_step = g_optim.minimize(dcgan.g_loss, dcgan.theta_G)
+
+  # executor
+  sess = pd.executor()
+
+  # training
+  for epoch in xrange(10000):
+    for batch_id in range(N / batch_size):
+      idx = ...
+      # sample a batch
+      batch_im, batch_label = data_X[idx:idx+batch_size], data_y[idx:idx+batch_size]
+      # sample z
+      batch_z = np.random.uniform(-1., 1., [batch_size, z_dim])
+
+      if batch_id % 2 == 0:
+        sess.run(d_step,
+                 feed_dict = {dcgan.images: batch_im,
+                              dcgan.y: batch_label,
+                              dcgan.z: batch_z})
+      else:
+        sess.run(g_step,
+                 feed_dict = {dcgan.z: batch_z})
+```
+
+# More thinking about dependency engine vs. block design:
+- What if we just want to run an intermediate result? Do we need to run the whole block/graph?
+- Should we call eval() to get the fake images in the first stage? And then train the discriminator in the second stage?
diff --git a/doc/design/optimizer.md b/doc/design/optimizer.md
new file mode 100644
index 0000000000..17440fae50
--- /dev/null
+++ b/doc/design/optimizer.md
@@ -0,0 +1,105 @@
+## Optimizer Design
+
+### The Problem
+
+A PaddlePaddle program, or a block, is a sequence of operators operating on variables. A training program needs to do three kinds of work:
+
+1. the forward pass, which computes intermediate results and the cost(s),
+1. the backward pass, which derives gradients from intermediate results and costs, and
+1. the optimization pass, which updates model parameters to optimize the cost(s).
+
+These passes rely on three kinds of operators:
+
+1. forward operators,
+1. gradient operators, and
+1. optimization operators.
+
+It's true that users should be able to create all these operators manually by calling some low-level API, but it would be much more convenient if they only had to describe the forward pass and let PaddlePaddle create the backward and optimization operators automatically.
+
+In this design, we propose a high-level API that automatically derives the optimization pass and operators from the forward pass.
+
+
+### High-level Python API to describe the training process
+
+1. Users write code to describe the network:
+
+   ```python
+   images = layer.data("images")
+   labels = layer.data("labels")
+   w1 = pd.var("w1")
+   b1 = pd.var("b1")
+   hidden = layer.fc(images, w=w1, b=b1)
+   cost = layer.mse(hidden, labels)
+   ```
+
+   The above code snippet will create forward operators in [Block](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/block.md).
+
+
+2. Users create a certain kind of Optimizer with some arguments.
+
+   ```python
+   optimizer = AdagradOptimizer(learning_rate=0.001)
+   ```
+
+3. Users use the optimizer to `minimize` a certain `cost` through updating parameters in parameter_list.
+
+   ```python
+   opt_op_list = optimizer.minimize(cost, parameter_list=[w1, b1])
+   ```
+
+   The above code snippet will create gradient and optimization operators in Block. The return value of `minimize()` is a list of optimization operators that will be run by the session.
+
+4. Users use Session/Executor to run this opt_op_list as the target to do training.
+
+   ```python
+   sess.run(target=opt_op_list, ...)
+   ```
+
+#### Optimizer Python interface:
+
+```python
+class Optimizer(object):
+    """Optimizer Base class.
+
+    """
+
+    def __init__(self):
+        pass
+
+    def create_backward_pass(self, loss, parameter_list=None):
+        """
+        Create and add gradient operators in BlockDesc to compute
+        gradients of `loss` for parameters in parameter_list.
+
+        Args:
+          loss: a variable generated by the cost function.
+          parameter_list: parameters that need to compute gradients and be
+            updated to optimize the loss.
+
+        Returns:
+          list of (parameters, gradients) pair.
+        """
+        return None
+
+    def create_optimization_pass(self, parameters_and_grads):
+        """Add optimization operators to apply gradients to variables.
+
+        Args:
+          parameters_and_grads: a list of (variable, gradient) pair to update.
+
+        Returns:
+          optimization_op_list: a list of optimization operators that will
+            update parameters using gradients.
+        """
+        return None
+
+    def minimize(self, loss, parameter_list):
+        """Add operations to minimize `loss` by updating `parameter_list`.
+
+        This method combines interface `create_backward_pass()` and
+        `create_optimization_pass()` into one.
+        """
+        params_grads = self.create_backward_pass(loss, parameter_list)
+        update_ops = self.create_optimization_pass(params_grads)
+        return update_ops
+
+```
+
+Users can inherit the Optimizer above to create their own Optimizer with some special logic, such as AdagradOptimizer.
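+
+As a concrete illustration, here is a minimal sketch of such a subclass, a plain SGD optimizer that implements `create_optimization_pass()`. The helper `append_sgd_op` is an assumption for illustration, not an existing PaddlePaddle API:
+
+```python
+class SGDOptimizer(Optimizer):
+    """A hypothetical SGD subclass: w = w - learning_rate * gradient."""
+
+    def __init__(self, learning_rate=0.01):
+        super(SGDOptimizer, self).__init__()
+        self.learning_rate = learning_rate
+
+    def create_optimization_pass(self, parameters_and_grads):
+        optimization_op_list = []
+        for param, grad in parameters_and_grads:
+            # One update operator per (parameter, gradient) pair;
+            # append_sgd_op is assumed to append an sgd OpDesc to the block.
+            optimization_op_list.append(
+                append_sgd_op(param, grad, self.learning_rate))
+        return optimization_op_list
+```
+
+With this subclass, `SGDOptimizer(learning_rate=0.01).minimize(cost, parameter_list=[w1, b1])` reuses the inherited `minimize()` to create the backward pass first and the update operators afterwards.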
diff --git a/doc/design/python_api.md b/doc/design/python_api.md index 6213da65c8..56ae1d925a 100644 --- a/doc/design/python_api.md +++ b/doc/design/python_api.md @@ -22,7 +22,7 @@ Whenever we create a block, we need to set its parent block to the current block ```python class Program(objects): def __init__(self): - self.proto = core.NewProgram() # a C++ ProgramDesc pointer. + self.desc = core.NewProgram() # a C++ ProgramDesc pointer. self.blocks = vector() self.blocks.append(Block(self, -1)) # the global block self.current_block = 0 # initialized to the global block @@ -57,7 +57,7 @@ A [Block](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/block.m ```python class Block(objects): def __init__(self, program, parent_idx): - self.proto = core.NewBlock(program.proto) + self.desc = core.NewBlock(program.desc) self.program = program self.vars = map() self.ops = vector() @@ -98,11 +98,11 @@ class Operator(object): outputs,# dict attrs # dict ): - self.proto = core.NewOpDesc(block.proto, type, inputs, outputs, attrs) - core.infer_shape(self.proto, inputs, outputs) + self.desc = core.NewOpDesc(block.desc, type, inputs, outputs, attrs) + core.infer_shape(self.desc, inputs, outputs) def type(self): - return self.proto.type() + return self.desc.type() ``` `Operator` creates the `OpDesc` message in C++ space, so that it can call the `InferShape` function, which is in C++. @@ -124,7 +124,7 @@ class Variable(object): name = unique_name_generator() self.name = name self.block = block - self.proto = core.NewVarDesc(block.proto, name, shape, lod_level) + self.desc = core.NewVarDesc(block.desc, name, shape, lod_level) self.writer = None ``` @@ -214,3 +214,7 @@ def fc_layer(input, size, ...): out.writer = op return out ``` + +## Optimizer + +[Optimizer Design Doc](./optimizer.md) diff --git a/doc/design/refactorization.md b/doc/design/refactorization.md index 629422e774..ec51aa1a0e 100644 --- a/doc/design/refactorization.md +++ b/doc/design/refactorization.md @@ -17,22 +17,22 @@ The goals of refactoring include: 1. A graph is composed of *variables* and *operators*. -1. The description of graphs must be capable of being serialized/deserialized, so that: +1. The description of graphs must be serializable/deserializable, so that: - 1. It can to be sent to the cloud for distributed execution, and + 1. It can be sent to the cloud for distributed execution, and 1. It can be sent to clients for mobile or enterprise deployment. -1. The Python program does the following steps +1. The Python program does two things - 1. *compilation*: run a Python program to generate a protobuf message representation of the graph and send it to + 1. *Compilation* runs a Python program to generate a protobuf message representation of the graph and send it to 1. the C++ library `libpaddle.so` for local execution, 1. the master process of a distributed training job for training, or 1. the server process of a Kubernetes serving job for distributed serving. - 1. *execution*: execute the graph by constructing instances of class [`Variable`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/variable.h#L24) and [`OperatorBase`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/operator.h#L70), according to the protobuf message. + 1. 
*Execution* executes the graph by constructing instances of class [`Variable`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/variable.h#L24) and [`OperatorBase`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/operator.h#L70), according to the protobuf message. ## Description and Realization of Computation Graph -At compile time, the Python program generates a protobuf message representation of the graph, or the description of the graph. +At compile time, the Python program generates a protobuf message representation of the graph, or a description of the graph. At runtime, the C++ program realizes the graph and runs it. @@ -42,11 +42,11 @@ At runtime, the C++ program realizes the graph and runs it. |Operation|[OpDesc](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/framework.proto#L35)|[Operator](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/operator.h#L64)| |Block|BlockDesc|Block| -The word *graph* is interchangeable with *block* in this document. A graph represents computation steps and local variables similar to a C++/Java program block, or a pair of parentheses(`{` and `}`). +The word *graph* is interchangeable with *block* in this document. A graph consists of computation steps and local variables similar to a C++/Java program block, or a pair of parentheses(`{` and `}`). ## Compilation and Execution -1. Run an application Python program to describe the graph. In particular, the Python application program does the following: +1. Run a Python program to describe the graph. In particular, the Python application program does the following: 1. Create `VarDesc` to represent local/intermediate variables, 1. Create operators and set attributes, @@ -54,10 +54,10 @@ The word *graph* is interchangeable with *block* in this document. A graph repr 1. Infer the type and the shape of variables, 1. Plan memory-reuse for variables, 1. Generate the backward graph - 1. Optimize the computation graph. - 1. Potentially, split the graph for distributed training. + 1. Add optimization operators to the computation graph. + 1. Optionally, split the graph for distributed training. -1. The invocation of `train` or [`infer`](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/inference.py#L108) methods in the application Python program does the following: +1. The invocation of `train` or [`infer`](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/inference.py#L108) methods in the Python program does the following: 1. Create a new Scope instance in the [scope hierarchy](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/scope.md) for each run of a block, 1. realize local variables defined in the BlockDesc message in the new scope, @@ -107,8 +107,8 @@ Compile Time -> IR -> Runtime ![class_diagram](http://api.paddlepaddle.org/graphviz?dot=https://gist.githubusercontent.com/reyoung/53df507f6749762675dff3e7ce53372f/raw/dd598e8f1976f5759f58af5e5ef94738a6b2e661/op.dot) * `Operator` is the fundamental building block of the user interface. - * Operator stores input/output variable names, and attributes. - * The `InferShape` interface is used to infer the shape of the output variable shapes based on the shapes of the input variables. + * Operator stores input/output variable names and attributes. + * The `InferShape` interface is used to infer the shape of the output variables based on the shapes of the input variables. 
 * Use `Run` to compute the `output` variables from the `input` variables.
 
 ---
@@ -139,7 +139,7 @@ Compile Time -> IR -> Runtime
   * Limit the number of `tensor.device(dev) = ` in your code.
 * `thrust::transform` and `std::transform`.
   * `thrust` has the same API as C++ standard library. Using `transform`, one can quickly implement customized element-wise kernels.
-  * `thrust` also has more complex APIs, like `scan`, `reduce`, `reduce_by_key`.
+  * `thrust`, in addition, supports more complex APIs, like `scan`, `reduce`, `reduce_by_key`.
 * Hand-writing `GPUKernel` and `CPU` code
   * Do not write in header (`.h`) files. CPU Kernel should be in cpp source (`.cc`) and GPU kernels should be in cuda (`.cu`) files. (GCC cannot compile GPU code.)
 ---
@@ -185,10 +185,10 @@ Make sure the registration process is executed and linked.
 
 1. Write an Op class and its gradient Op class, if required.
 2. Write an Op maker class. In the constructor of this class, describe the inputs, outputs and attributes of the operator.
 3. Invoke the macro `REGISTER_OP`. This macro will
-    1. Call maker class to complete the `proto` and the `checker`
+    1. Call maker class to complete `proto` and `checker`
     2. Using the completed `proto` and `checker`, it will add a new key-value pair to the `OpInfoMap`
 
-4. Invoke the `USE` macro in which the Op is used, to make sure that it is linked.
+4. Invoke the `USE` macro in which the Op is used to make sure that it is linked.
 
 ---
 # Backward Module (1/2)
@@ -199,13 +199,13 @@ Make sure the registration process is executed and linked.
 ---
 # Backward Module (2/2)
 ### Build Backward Network
-- **Input**: graph of forward operators
-- **Output**: graph of backward operators
+- **Input**: a graph of forward operators
+- **Output**: a graph of backward operators
 - **Corner cases in construction**
   - Shared Variables => insert an `Add` operator to combine gradients
   - No Gradient => insert a `fill_zero_grad` operator
   - Recursive NetOp => call `Backward` recursively
   - RNN Op => recursively call `Backward` on stepnet
 
 ---
@@ -215,10 +215,10 @@ Make sure the registration process is executed and linked.
   * Only dims and data pointers are stored in `Tensor`.
   * All operations on `Tensor` are written in `Operator` or global functions.
   * Variable length Tensor design [LoDTensor](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/lod_tensor.md)
-* `Variable` instances are the inputs and the outputs of an operator. Not just `Tensor`.
+* `Variable` instances are the inputs and the outputs of an operator, not just `Tensor`.
   * `step_scopes` in RNN is a variable and not a tensor.
-* `Scope` is where variables are stores.
-  * map
+* `Scope` is where variables are stored.
+  * map
 * `Scope` has a hierarchical structure. The local scope can get variables from its parent scope.
 
 ---
@@ -246,7 +246,7 @@ Make sure the registration process is executed and linked.
 ---
 # Control the migration quality
 - Compare the performance of migrated models with old ones.
-- Follow the google C++ style
+- Follow the Google C++ style guide.
 - Build the automatic workflow of generating Python/C++ documentations.
   - The documentation of layers and ops should be written inside the code.
   - Take the documentation quality into account when submitting pull requests.
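+
+The "Shared Variables" corner case above deserves a concrete picture. Below is a minimal sketch (plain Python, all variable and operator names hypothetical) of what the backward builder produces when two forward operators read the same variable `w`: each gradient operator writes its own renamed alias, and an inserted `sum` operator combines the aliases into the real gradient (this PR's `backward.cc` now emits exactly one multi-input `sum` instead of a chain of `add` ops):
+
+```python
+# duplicated gradient outputs produced by two gradient operators
+aliases = ["w@GRAD_0", "w@GRAD_1"]
+
+backward_ops = [
+    {"type": "mul_grad", "outputs": {"Y@GRAD": [aliases[0]]}},
+    {"type": "fc_grad", "outputs": {"W@GRAD": [aliases[1]]}},
+    # inserted op: combine all aliases into the real gradient of w
+    {"type": "sum", "inputs": {"X": aliases}, "outputs": {"Out": ["w@GRAD"]}},
+]
+```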
diff --git a/doc/design/selected_rows.md b/doc/design/selected_rows.md
new file mode 100644
index 0000000000..9e6f3b20cb
--- /dev/null
+++ b/doc/design/selected_rows.md
@@ -0,0 +1,74 @@
+# Design Doc: Selected Rows
+
+`SelectedRows` is a kind of sparse tensor data type, which is designed to support `embedding` operators. The gradient of an embedding table is a sparse tensor. Only a few rows in that tensor are non-zero. It is straightforward to represent the sparse tensor by the following sparse tensor data structure:
+
+```cpp
+class SelectedRows {
+ private:
+  vector<int> rows_;
+  Tensor value_;
+  int height_;
+};
+```
+
+The field `height_` shows the first dimension of `SelectedRows`. The field `rows_` holds the indices of the non-zero rows of `SelectedRows`. The `value_` field is an N-dim tensor whose shape is `[rows.size() /* NUM_ROWS */, ...]`, which supplies values for each row. The dimension of `SelectedRows` equals `[height_] + value_.shape[1:]`.
+
+Suppose that a SelectedRows-typed variable `x` has many rows, but only two of them have values -- row 73 is `[1, 2]` and row 84 is `[3, 4]`; then the `SelectedRows` representation would be:
+
+```
+x = SelectedRows {
+  rows = [73, 84],
+  value = [[1, 2], [3, 4]]
+}
+```
+
+
+## SelectedRows in Protobuf
+
+`SelectedRows` is a kind of `Variable`. `VarDesc` in protobuf should describe the `SelectedRows` information. Only the tensor dimension of a `SelectedRows` will be described at compile time since the `rows_` and `value_` are related to training data.
+So we use `TensorDesc` to unify `data_type` and `dims`. A `LodTensorDesc` contains a `TensorDesc` and `lod_level`. The description of `SelectedRows` is a Tensor description.
+
+```proto
+message TensorDesc {
+  required DataType data_type = 1;
+  repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480]
+}
+
+message LodTensorDesc {
+  required TensorDesc tensor = 1;
+  optional int32 lod_level = 2;
+}
+
+message VarDesc {
+  required string name = 1;
+  enum VarType {
+    LOD_TENSOR = 0;
+    SELECTED_ROWS = 1;
+  }
+  required VarType type = 2;
+  optional LodTensorDesc lod_desc = 3;
+  optional TensorDesc selected_rows_desc = 4;
+  optional bool persistable = 5 [ default = false ];
+}
+```
+
+## InferShape for Selected Rows
+
+Just like with `LoD` information, the `InferShape` method will infer the output tensor type as well. The operator should decide whether its output is a `SelectedRows` or `Dense` tensor.
+
+For example, the gradient operator of `TableLookup` will always generate `SelectedRows`. Its `InferShape` method should look like the following:
+
+```cpp
+void TableLookupGrad::InferShape(context) {
+  ...
+  context.SetDataType("Embedding.Grad", kSelectedRows);
+}
+```
+
+
+## Sparse Operators
+
+Several operators should be written to support `SelectedRows`. They are:
+
+1. Operators which generate a `SelectedRows` gradient, e.g., the gradient of `TableLookupOp`.
+2. Optimization operators which support a `SelectedRows` gradient, e.g., `SGD` or `AdaGrad` for `SelectedRows`. However, there should be only one `SGD` operator; `OpWithKernel::Run` should select a suitable kernel for either a dense tensor or `SelectedRows`.
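+
+To make the representation concrete, the following is a small sketch in plain numpy (not PaddlePaddle code; the height of 100 is an assumed value) that expands the `x` example above into its dense equivalent:
+
+```python
+import numpy as np
+
+height = 100                        # assumed first dimension of the sparse tensor
+rows = [73, 84]                     # indices of the non-zero rows
+value = np.array([[1, 2], [3, 4]])  # one value row per index in rows
+
+# dense shape is [height] + value.shape[1:]
+dense = np.zeros([height] + list(value.shape[1:]))
+dense[rows] = value                 # only rows 73 and 84 become non-zero
+```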
diff --git a/doc/design/test.dot b/doc/design/test.dot new file mode 100644 index 0000000000..62c69b8fc8 --- /dev/null +++ b/doc/design/test.dot @@ -0,0 +1,35 @@ + +digraph Test { + z -> generator -> G_img; + G_img -> discriminator -> D_f -> d_loss_f; + label0 -> d_loss_f -> d_loss; + + img -> discriminator -> D_t -> d_loss_t; + label1 -> d_loss_t -> d_loss; + + d_loss -> d_loss_t[color=red, style=dashed]; + d_loss -> d_loss_f[color=red, style=dashed]; + d_loss_t -> D_t[color=red, style=dashed]; + d_loss_f -> D_f[color=red, style=dashed]; + D_t -> discriminator[color=red, style=dashed]; + D_f -> discriminator[color=red, style=dashed]; + + D_f -> g_loss; + label2 -> g_loss; + + g_loss -> D_f[color=green, style=dashed]; + D_f -> discriminator[color=green, style=dashed]; + discriminator -> G_img[color=green, style=dashed]; + G_img -> generator[color=green, style=dashed]; + + discriminator [color=red, shape=box]; + generator [color=green, shape=box]; + z [shape=diamond]; + img [shape=diamond]; + label0 [shape=diamond]; + label1 [shape=diamond]; + label2 [shape=diamond]; + + d_loss [color=red]; + g_loss [color=green]; +} diff --git a/doc/design/test.dot.png b/doc/design/test.dot.png new file mode 100644 index 0000000000..4e121a40b9 Binary files /dev/null and b/doc/design/test.dot.png differ diff --git a/paddle/CMakeLists.txt b/paddle/CMakeLists.txt index b435de80a2..7d2becbdd7 100644 --- a/paddle/CMakeLists.txt +++ b/paddle/CMakeLists.txt @@ -1,27 +1,32 @@ add_subdirectory(cuda) add_subdirectory(function) add_subdirectory(utils) -add_subdirectory(testing) add_subdirectory(math) -add_subdirectory(parameter) add_subdirectory(gserver) -add_subdirectory(pserver) -add_subdirectory(trainer) -add_subdirectory(scripts) -add_subdirectory(string) - -if(Boost_FOUND) - add_subdirectory(memory) - add_subdirectory(platform) - add_subdirectory(framework) - add_subdirectory(operators) - add_subdirectory(pybind) -endif() +add_subdirectory(parameter) +add_subdirectory(testing) -if(WITH_C_API) +if(MOBILE_INFERENCE) add_subdirectory(capi) -endif() +else() + add_subdirectory(pserver) + add_subdirectory(trainer) + add_subdirectory(string) + add_subdirectory(scripts) + + if(WITH_C_API) + add_subdirectory(capi) + endif() + + if(Boost_FOUND) + add_subdirectory(memory) + add_subdirectory(platform) + add_subdirectory(framework) + add_subdirectory(operators) + add_subdirectory(pybind) + endif() -if(WITH_SWIG_PY) - add_subdirectory(api) + if(WITH_SWIG_PY) + add_subdirectory(api) + endif() endif() diff --git a/paddle/capi/CMakeLists.txt b/paddle/capi/CMakeLists.txt index b9bbe58951..2c458a78c5 100644 --- a/paddle/capi/CMakeLists.txt +++ b/paddle/capi/CMakeLists.txt @@ -37,9 +37,7 @@ set(PADDLE_CAPI_INFER_LIBS paddle_cuda paddle_function paddle_gserver - paddle_proto - paddle_pserver - paddle_network) + paddle_proto) cc_library(paddle_capi_whole DEPS paddle_capi ${PADDLE_CAPI_INFER_LIBS}) diff --git a/paddle/capi/tests/CMakeLists.txt b/paddle/capi/tests/CMakeLists.txt index 8208808b94..bb38ace628 100644 --- a/paddle/capi/tests/CMakeLists.txt +++ b/paddle/capi/tests/CMakeLists.txt @@ -4,11 +4,12 @@ add_unittest(capi_test_mats test_Vector.cpp target_include_directories(capi_test_mats PUBLIC ${PADDLE_CAPI_INC_PATH}) target_link_libraries(capi_test_mats paddle_capi) - -add_unittest_without_exec(capi_test_gradientMachine test_GradientMachine.cpp) -target_include_directories(capi_test_gradientMachine PUBLIC - ${PADDLE_CAPI_INC_PATH}) -target_link_libraries(capi_test_gradientMachine paddle_capi) -add_test(NAME 
capi_test_gradientMachine - COMMAND ${PADDLE_SOURCE_DIR}/paddle/.set_python_path.sh -d ${PADDLE_SOURCE_DIR}/python ${CMAKE_CURRENT_BINARY_DIR}/capi_test_gradientMachine - WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle/capi/tests) +if(NOT MOBILE_INFERENCE) + add_unittest_without_exec(capi_test_gradientMachine test_GradientMachine.cpp) + target_include_directories(capi_test_gradientMachine PUBLIC + ${PADDLE_CAPI_INC_PATH}) + target_link_libraries(capi_test_gradientMachine paddle_capi) + add_test(NAME capi_test_gradientMachine + COMMAND ${PADDLE_SOURCE_DIR}/paddle/.set_python_path.sh -d ${PADDLE_SOURCE_DIR}/python ${CMAKE_CURRENT_BINARY_DIR}/capi_test_gradientMachine + WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle/capi/tests) +endif() diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 3e0e0f5903..148610aa2c 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -19,7 +19,7 @@ cc_test(scope_test SRCS scope_test.cc DEPS scope) proto_library(framework_proto SRCS framework.proto) cc_library(attribute SRCS attribute.cc DEPS framework_proto) -cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute) +cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute ddim) cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute) cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker) cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto proto_desc) @@ -42,5 +42,12 @@ add_custom_command(TARGET framework_py_proto POST_BUILD cc_library(backward SRCS backward.cc DEPS net_op) cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context) +cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto backward ${GLOB_OP_LIB}) +if(WITH_GPU) + nv_test(executor_test SRCS executor_test.cc DEPS executor) +else() + cc_test(executor_test SRCS executor_test.cc DEPS executor) +endif() + cc_library(tensor_array SRCS tensor_array.cc DEPS lod_tensor) cc_test(tensor_array_test SRCS tensor_array_test.cc DEPS tensor_array place) diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index c970e01dd1..063b108500 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -172,30 +172,14 @@ static std::unique_ptr BackwardRecursive( std::to_string(i)); net->ops_[op_offset]->Rename(name, dup_outputs.back()); } - // collect all the offset to append `add` op for each alias - // - // one variable is shared between multiple operators. 
- // insert add operator one by one, then add it to output - for (size_t output_idx = 0; output_idx < dup_outputs.size() - 1; - ++output_idx) { - auto insert_add_x = dup_outputs[output_idx]; - auto insert_add_y = dup_outputs[output_idx + 1]; - auto insert_add_out = name + "@SHARED@" + std::to_string(output_idx); - // first add op inserted - if (output_idx == dup_outputs.size() - 2) { - insert_add_out = name; - } - if (output_idx != 0) { - insert_add_y = name + "@SHARED@" + std::to_string(output_idx - 1); - } - insert_position.push_back( - {dup_op.back(), - OpRegistry::CreateOp("sum", {{"X", {insert_add_x, insert_add_y}}}, - {{"Out", {insert_add_out}}}, {})}); - } + // collect all the offset for each alias, + // insert a sum operator to add all aliases to output + insert_position.push_back( + {dup_op.back(), OpRegistry::CreateOp("sum", {{"X", dup_outputs}}, + {{"Out", {name}}}, {})}); } - // make sure the inserted `add` ops follow the BFS order. + // make sure the inserted `sum` ops follow the BFS order. insert_position.sort( [](const Pos& l, const Pos& r) { return l.first > r.first; }); @@ -302,7 +286,7 @@ std::vector> MakeOpGrad( return grad_op_descs; // empty vector } - grad_op_descs = OpRegistry::CreateGradOpDescs(*op_desc); + grad_op_descs = OpRegistry::CreateGradOpDescs(op_desc.get()); std::list> pending_fill_zeros_ops; for (auto& desc : grad_op_descs) { diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index 30225a4a99..3b7cbcd989 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -58,6 +58,8 @@ class MulOpMaker : public OpProtoAndCheckerMaker { AddInput("X", "A"); AddInput("Y", "B"); AddOutput("Out", "Out"); + AddAttr("x_num_col_dims", "").SetDefault(1).EqualGreaterThan(1); + AddAttr("y_num_col_dims", "").SetDefault(1).EqualGreaterThan(1); AddComment("Mul"); } }; @@ -440,6 +442,28 @@ TEST(Backward, simple_single_op) { std::vector({f::GradVarName("b")})); } +TEST(Backward, default_attribute) { + f::ProgramDesc *program_desc = GetNewProgramDesc(); + f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc); + f::BlockDescBind *block = program.Block(0); + f::OpDescBind *op = block->AppendOp(); + op->SetType("mul"); + op->SetInput("X", {"x"}); + op->SetInput("Y", {"y"}); + op->SetOutput("Out", {"out"}); + + AppendBackward(program, {}); + + ASSERT_EQ(block->AllOps().size(), 2UL); + EXPECT_EQ(boost::get(op->GetAttr("x_num_col_dims")), 1); + EXPECT_EQ(boost::get(op->GetAttr("y_num_col_dims")), 1); + + f::OpDescBind *grad_op = block->AllOps()[1]; + ASSERT_EQ(grad_op->Type(), "mul_grad"); + EXPECT_EQ(boost::get(grad_op->GetAttr("x_num_col_dims")), 1); + EXPECT_EQ(boost::get(grad_op->GetAttr("y_num_col_dims")), 1); +} + TEST(Backward, simple_mult_op) { f::ProgramDesc *program_desc = GetNewProgramDesc(); f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc); diff --git a/paddle/framework/block_desc.cc b/paddle/framework/block_desc.cc index 01f50e1393..509aa235d3 100644 --- a/paddle/framework/block_desc.cc +++ b/paddle/framework/block_desc.cc @@ -74,6 +74,12 @@ void BlockDescBind::Sync() { for (auto &op_desc : ops_) { op_field.AddAllocated(op_desc->Proto()); } + auto &var_field = *this->desc_->mutable_vars(); + var_field.Clear(); + var_field.Reserve(static_cast(vars_.size())); + for (auto &var_desc : vars_) { + var_field.AddAllocated(var_desc.second->Proto()); + } need_update_ = false; } } diff --git a/paddle/framework/block_desc.h b/paddle/framework/block_desc.h index 
2de270f60e..3437e89923 100644 --- a/paddle/framework/block_desc.h +++ b/paddle/framework/block_desc.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include +#include #include #include #include "paddle/framework/op_desc.h" diff --git a/paddle/framework/data_type.h b/paddle/framework/data_type.h index 55e3931f87..649899d425 100644 --- a/paddle/framework/data_type.h +++ b/paddle/framework/data_type.h @@ -28,7 +28,6 @@ inline DataType ToDataType(std::type_index type) { return DataType::INT32; } else { PADDLE_THROW("Not supported"); - return static_cast(-1); } } diff --git a/paddle/framework/executor.cc b/paddle/framework/executor.cc new file mode 100644 index 0000000000..c388b2198e --- /dev/null +++ b/paddle/framework/executor.cc @@ -0,0 +1,163 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/executor.h" + +#include +#include +#include +#include +#include + +#include "paddle/framework/lod_tensor.h" +#include "paddle/framework/op_registry.h" +#include "paddle/framework/scope.h" + +namespace paddle { +namespace framework { + +const std::string kFeedOpType = "feed"; +const std::string kFetchOpType = "fetch"; + +Executor::Executor(const std::vector& places) { + PADDLE_ENFORCE_GT(places.size(), 0); + device_contexts_.resize(places.size()); + for (size_t i = 0; i < places.size(); i++) { + if (platform::is_cpu_place(places[i])) { + device_contexts_[i] = new platform::CPUDeviceContext( + boost::get(places[i])); + } else if (platform::is_gpu_place(places[i])) { +#ifdef PADDLE_WITH_CUDA + device_contexts_[i] = new platform::CUDADeviceContext( + boost::get(places[i])); +#else + PADDLE_THROW( + "'GPUPlace' is not supported, Please re-compile with WITH_GPU " + "option"); +#endif + } + } +} + +Executor::~Executor() { + for (auto& device_context : device_contexts_) { + delete device_context; + } +} + +void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id) { + // TODO(tonyyang-svail): + // - only runs on the first device (i.e. 
no interdevice communication) + // - will change to use multiple blocks for RNN op and Cond Op + PADDLE_ENFORCE_GT(pdesc.blocks_size(), block_id); + auto& block = pdesc.blocks(block_id); + auto& device = device_contexts_[0]; + + // Instantiate all the vars in the global scope + for (auto& var : block.vars()) { + scope->NewVar(var.name()); + } + + Scope& local_scope = scope->NewScope(); + + std::vector should_run = Prune(pdesc, block_id); + PADDLE_ENFORCE_EQ(should_run.size(), static_cast(block.ops_size())); + for (size_t i = 0; i < should_run.size(); ++i) { + if (should_run[i]) { + for (auto& var : block.ops(i).outputs()) { + for (auto& argu : var.arguments()) { + if (local_scope.FindVar(argu) == nullptr) { + local_scope.NewVar(argu); + } + } + } + auto op = paddle::framework::OpRegistry::CreateOp(block.ops(i)); + op->Run(local_scope, *device); + } + } + + // TODO(tonyyang-svail): + // - Destroy local_scope +} + +std::vector Prune(const ProgramDesc& pdesc, int block_id) { + // TODO(tonyyang-svail): + // - will change to use multiple blocks for RNN op and Cond Op + + auto& block = pdesc.blocks(block_id); + auto& ops = block.ops(); + + bool expect_feed = true; + for (auto& op_desc : ops) { + PADDLE_ENFORCE(op_desc.type() != kFeedOpType || expect_feed, + "All FeedOps are at the beginning of the ProgramDesc"); + expect_feed = (op_desc.type() == kFeedOpType); + } + + bool expect_fetch = true; + for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) { + auto& op_desc = *op_iter; + PADDLE_ENFORCE(op_desc.type() != kFetchOpType || expect_fetch, + "All FetchOps must at the end of the ProgramDesc"); + expect_fetch = (op_desc.type() == kFetchOpType); + } + + std::set dependent_vars; + std::vector should_run; + for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) { + auto& op_desc = *op_iter; + + bool found_dependent_vars = false; + for (auto& var : op_desc.outputs()) { + for (auto& argu : var.arguments()) { + if (dependent_vars.count(argu) != 0) { + found_dependent_vars = true; + } + } + } + + if (op_desc.type() == kFetchOpType || found_dependent_vars) { + // erase its output to the dependency graph + for (auto& var : op_desc.outputs()) { + for (auto& argu : var.arguments()) { + dependent_vars.erase(argu); + } + } + + // insert its input to the dependency graph + for (auto& var : op_desc.inputs()) { + for (auto& argu : var.arguments()) { + dependent_vars.insert(argu); + } + } + + should_run.push_back(true); + } else { + should_run.push_back(false); + } + } + + // TODO(tonyyang-svail): + // - check this after integration of Init + // PADDLE_ENFORCE(dependent_vars.empty()); + + // since we are traversing the ProgramDesc in reverse order + // we reverse the should_run vector + std::reverse(should_run.begin(), should_run.end()); + + return should_run; +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/executor.h b/paddle/framework/executor.h new file mode 100644 index 0000000000..4e3bc2c0a5 --- /dev/null +++ b/paddle/framework/executor.h @@ -0,0 +1,55 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/paddle/framework/executor.h b/paddle/framework/executor.h
new file mode 100644
index 0000000000..4e3bc2c0a5
--- /dev/null
+++ b/paddle/framework/executor.h
@@ -0,0 +1,55 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/framework/framework.pb.h"
+#include "paddle/framework/op_info.h"
+#include "paddle/framework/scope.h"
+#include "paddle/framework/tensor.h"
+
+namespace paddle {
+namespace framework {
+
+class Executor {
+ public:
+  explicit Executor(const std::vector<platform::Place>& places);
+  ~Executor();
+
+  /* @Brief
+   * Runtime evaluation of the given ProgramDesc under a certain Scope
+   *
+   * @param
+   *  ProgramDesc
+   *  Scope
+   */
+  void Run(const ProgramDesc&, Scope*, int);
+
+ private:
+  std::vector<platform::DeviceContext*> device_contexts_;
+};
+
+/* @Brief
+ * Prune the graph
+ *
+ * @param
+ *  ProgramDesc
+ *
+ * @return
+ *  vector<bool> Same size as ops. Indicates whether an op should be run.
+ */
+std::vector<bool> Prune(const ProgramDesc& pdesc, int block_id);
+
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/framework/executor_test.cc b/paddle/framework/executor_test.cc
new file mode 100644
index 0000000000..137e53d849
--- /dev/null
+++ b/paddle/framework/executor_test.cc
@@ -0,0 +1,318 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/framework/executor.h"
+
+#include
+#include
+
+#include "gtest/gtest.h"
+#include "paddle/framework/attribute.h"
+#include "paddle/framework/backward.h"
+#include "paddle/framework/block_desc.h"
+#include "paddle/framework/op_desc.h"
+#include "paddle/framework/op_registry.h"
+#include "paddle/framework/operator.h"
+
+USE_OP(elementwise_add);
+USE_OP(gaussian_random);
+USE_OP(feed);
+USE_OP(fetch);
+USE_OP(mul);
+USE_OP(sum);
+USE_OP(squared_l2_distance);
+USE_OP(fill_constant);
+USE_OP(sgd);
+
+using namespace paddle::platform;
+using namespace paddle::framework;
+
+void AddOp(const std::string& type, const VariableNameMap& inputs,
+           const VariableNameMap& outputs, AttributeMap attrs,
+           paddle::framework::BlockDescBind* block) {
+  // insert output
+  for (auto kv : outputs) {
+    for (auto v : kv.second) {
+      auto var = block->NewVar(v);
+      var->SetDataType(paddle::framework::DataType::FP32);
+    }
+  }
+
+  // insert op
+  auto op = block->AppendOp();
+  op->SetType(type);
+  for (auto& kv : inputs) {
+    op->SetInput(kv.first, kv.second);
+  }
+  for (auto& kv : outputs) {
+    op->SetOutput(kv.first, kv.second);
+  }
+  op->SetAttrMap(attrs);
+}
+
+// Tensors in the feed value variable will only be in CPUPlace,
+// so we can memcpy the data from a vector to feed_value
+template <typename T>
+void SetFeedVariable(const std::vector<std::vector<T>>& inputs,
+                     const std::vector<std::vector<int>>& dims) {
+  Variable* g_feed_value = GetGlobalScope().FindVar("feed_value");
+  auto& feed_inputs =
+      *(g_feed_value->GetMutable<std::vector<paddle::framework::Tensor>>());
+  size_t size = inputs.size();
+  feed_inputs.resize(size);
+  for (size_t i = 0; i < size; i++) {
+    T* dst = feed_inputs[i].mutable_data<T>(make_ddim(dims[i]), CPUPlace());
+    memcpy(dst, inputs[i].data(), inputs[i].size() * sizeof(T));
+  }
+}
+
+// Tensors in the fetch value variable will only be in CPUPlace,
+// so we can memcpy the data from
fetch_value to vector +template +std::vector> GetFetchVariable() { + Variable* g_fetch_value = GetGlobalScope().FindVar("fetch_value"); + auto& fetch_outputs = + *(g_fetch_value->GetMutable>()); + + size_t size = fetch_outputs.size(); + std::vector> result; + result.reserve(size); + for (size_t i = 0; i < size; i++) { + std::vector tmp; + tmp.resize(fetch_outputs[i].numel()); + memcpy(tmp.data(), fetch_outputs[i].data(), + fetch_outputs[i].numel() * sizeof(T)); + result.push_back(tmp); + } + + return result; +} + +class ExecutorTesterRandom : public ::testing::Test { + public: + virtual void SetUp() override { + int input_dim = 3, batch_size = 2, embed_dim = 5; + + auto temp_init_root_block = init_pdesc_.add_blocks(); + temp_init_root_block->set_idx(0); + temp_init_root_block->set_parent_idx(-1); + paddle::framework::ProgramDescBind& init_program = + paddle::framework::ProgramDescBind::Instance(&init_pdesc_); + paddle::framework::BlockDescBind* init_root_block = init_program.Block(0); + + AddOp("gaussian_random", {}, {{"Out", {"w1"}}}, + {{"dims", std::vector{input_dim, embed_dim}}}, init_root_block); + AddOp("gaussian_random", {}, {{"Out", {"w2"}}}, + {{"dims", std::vector{embed_dim, input_dim}}}, init_root_block); + AddOp("fetch", {{"Input", {"w1"}}}, {}, {{"col", 0}}, init_root_block); + AddOp("fetch", {{"Input", {"w2"}}}, {}, {{"col", 1}}, init_root_block); + + // flush + init_program.Proto(); + + // run block + auto temp_root_block = pdesc_.add_blocks(); + temp_root_block->set_idx(0); + temp_root_block->set_parent_idx(-1); + paddle::framework::ProgramDescBind& program = + paddle::framework::ProgramDescBind::Instance(&pdesc_); + paddle::framework::BlockDescBind* root_block = program.Block(0); + + // feed data + inputs_.push_back({1.0, 1.0, 1.0, 1.0, 1.0, 1.0}); + dims_.push_back({batch_size, input_dim}); + AddOp("feed", {}, {{"Out", {"a"}}}, + {{"dims", std::vector{batch_size, input_dim}}, {"col", 0}}, + root_block); + + // forward + AddOp("mul", {{"X", {"a"}}, {"Y", {"w1"}}}, {{"Out", {"b"}}}, {}, + root_block); + AddOp("mul", {{"X", {"b"}}, {"Y", {"w2"}}}, {{"Out", {"a_out"}}}, {}, + root_block); + AddOp("squared_l2_distance", {{"X", {"a"}}, {"Y", {"a_out"}}}, + {{"Out", {"l2_distance"}}, {"sub_result", {"l2_distance_sub"}}}, {}, + root_block); + + // backward + AddOp("fill_constant", {}, {{"Out", {"l2_distance@GRAD"}}}, + {{"shape", std::vector{batch_size, 1}}, {"value", float(1.0)}}, + root_block); + AppendBackward(program, {}); + + // update + AddOp("fill_constant", {}, {{"Out", {"learning_rate"}}}, + {{"shape", std::vector{1}}, {"value", float(0.001)}}, + root_block); + AddOp("sgd", {{"Param", {"w1"}}, + {"LearningRate", {"learning_rate"}}, + {"Grad", {"w1@GRAD"}}}, + {{"ParamOut", {"w1"}}}, {}, root_block); + AddOp("sgd", {{"Param", {"w2"}}, + {"LearningRate", {"learning_rate"}}, + {"Grad", {"w2@GRAD"}}}, + {{"ParamOut", {"w2"}}}, {}, root_block); + + AddOp("fetch", {{"Input", {"w1"}}}, {}, {{"col", 0}}, root_block); + AddOp("fetch", {{"Input", {"w2"}}}, {}, {{"col", 1}}, root_block); + AddOp("fetch", {{"Input", {"l2_distance"}}}, {}, {{"col", 0}}, root_block); + + // flush + program.Proto(); + } + + protected: + ProgramDesc init_pdesc_; + ProgramDesc pdesc_; + std::vector> inputs_; + std::vector> dims_; +}; + +class ExecutorTesterFeedAndFetch : public ::testing::Test { + public: + virtual void SetUp() override { + auto temp_root_block = pdesc_.add_blocks(); + temp_root_block->set_idx(0); + temp_root_block->set_parent_idx(-1); + + // wrap to BlockDescBind + 
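+  // The raw ProgramDesc proto built above is wrapped into a ProgramDescBind
+  // so that ops can be appended through the C++ builder API; the
+  // program.Proto() call at the end of SetUp() flushes those edits back
+  // into the underlying protobuf message before the Executor reads it.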
paddle::framework::ProgramDescBind& program = + paddle::framework::ProgramDescBind::Instance(&pdesc_); + paddle::framework::BlockDescBind* root_block = program.Block(0); + + std::vector dim{6}; + + AddOp("feed", {}, {{"Out", {"a"}}}, {{"dims", dim}, {"col", 0}}, + root_block); + AddOp("feed", {}, {{"Out", {"b"}}}, {{"dims", dim}, {"col", 1}}, + root_block); + AddOp("fetch", {{"Input", {"a"}}}, {}, {{"col", 0}}, root_block); + AddOp("fetch", {{"Input", {"b"}}}, {}, {{"col", 1}}, root_block); + + // flush + program.Proto(); + + std::vector vec1 = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; + std::vector vec2 = {4.0, 5.0, 6.0, 7.0, 8.0, 9.0}; + inputs_.push_back(vec1); + inputs_.push_back(vec2); + dims_.push_back({static_cast(vec1.size())}); + dims_.push_back({static_cast(vec2.size())}); + } + + protected: + ProgramDesc pdesc_; + std::vector> inputs_; + std::vector> dims_; +}; + +#ifndef PADDLE_WITH_CUDA +TEST_F(ExecutorTesterRandom, CPU) { + std::vector places; + CPUPlace cpu_place; + places.push_back(cpu_place); + + // We have a global Scope and BuddyAllocator, and we must ensure + // global BuddyAllocator is initialized before global Scope. Thus, + // global Scope will deconstruct before BuddyAllocator. Otherwise, + // "pointer being freed was not allocated" error will appear. + paddle::memory::Used(cpu_place); + + std::unique_ptr executor(new Executor(places)); + + executor->Run(init_pdesc_, &GetGlobalScope(), 0); + SetFeedVariable(inputs_, dims_); + executor->Run(pdesc_, &GetGlobalScope(), 0); + std::vector> result = GetFetchVariable(); +} + +TEST_F(ExecutorTesterFeedAndFetch, CPU) { + std::vector places; + CPUPlace cpu_place; + places.push_back(cpu_place); + + // We have a global Scope and BuddyAllocator, and we must ensure + // global BuddyAllocator is initialized before global Scope. Thus, + // global Scope will deconstruct before BuddyAllocator. Otherwise, + // "pointer being freed was not allocated" error will appear. + paddle::memory::Used(cpu_place); + + std::unique_ptr executor(new Executor(places)); + + for (int batch_id = 0; batch_id < 3; batch_id++) { + SetFeedVariable(inputs_, dims_); + executor->Run(pdesc_, &GetGlobalScope(), 0); + std::vector> result = GetFetchVariable(); + PADDLE_ENFORCE_EQ(result.size(), inputs_.size()); + for (size_t i = 0; i < result.size(); ++i) { + PADDLE_ENFORCE_EQ(result[i].size(), inputs_[i].size()); + for (size_t j = 0; j < result[i].size(); ++j) { + PADDLE_ENFORCE_EQ(result[i][j], inputs_[i][j]); + } + } + } +} +#else +TEST_F(ExecutorTesterRandom, GPU) { + std::vector places; + GPUPlace gpu_place(0); + places.push_back(gpu_place); + + // We have a global Scope and BuddyAllocator, and we must ensure + // global BuddyAllocator is initialized before global Scope. Thus, + // global Scope will deconstruct before BuddyAllocator. Otherwise, + // "pointer being freed was not allocated" error will appear. + // If paddle is compiled with GPU, both CPU and GPU BuddyAllocator + // need to be used at first. + paddle::memory::Used(CPUPlace()); + paddle::memory::Used(gpu_place); + + std::unique_ptr executor(new Executor(places)); + + executor->Run(init_pdesc_, &GetGlobalScope(), 0); + for (int batch_id = 0; batch_id < 3; batch_id++) { + SetFeedVariable(inputs_, dims_); + executor->Run(pdesc_, &GetGlobalScope(), 0); + } +} + +TEST_F(ExecutorTesterFeedAndFetch, GPU) { + std::vector places; + GPUPlace gpu_place(0); + places.push_back(gpu_place); + // We have a global Scope and BuddyAllocator, and we must ensure + // global BuddyAllocator is initialized before global Scope. 
Thus, + // global Scope will deconstruct before BuddyAllocator. Otherwise, + // "pointer being freed was not allocated" error will appear. + // If paddle is compiled with GPU, both CPU and GPU BuddyAllocator + // need to be used at first. + paddle::memory::Used(CPUPlace()); + paddle::memory::Used(gpu_place); + + std::unique_ptr executor(new Executor(places)); + + for (int batch_id = 0; batch_id < 3; batch_id++) { + SetFeedVariable(inputs_, dims_); + executor->Run(pdesc_, &GetGlobalScope(), 0); + std::vector> result = GetFetchVariable(); + PADDLE_ENFORCE_EQ(result.size(), inputs_.size()); + for (size_t i = 0; i < result.size(); ++i) { + PADDLE_ENFORCE_EQ(result[i].size(), inputs_[i].size()); + for (size_t j = 0; j < result[i].size(); ++j) { + PADDLE_ENFORCE_EQ(result[i][j], inputs_[i][j]); + } + } + } +} +#endif diff --git a/paddle/framework/framework.proto b/paddle/framework/framework.proto index ac2827e547..b7a63f9ba1 100644 --- a/paddle/framework/framework.proto +++ b/paddle/framework/framework.proto @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ syntax = "proto2"; +option optimize_for = LITE_RUNTIME; package paddle.framework; enum AttrType { diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc index 02aa74a842..e7538b4af3 100644 --- a/paddle/framework/op_desc.cc +++ b/paddle/framework/op_desc.cc @@ -13,7 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/framework/op_desc.h" +#include +#include #include "paddle/framework/block_desc.h" +#include "paddle/framework/operator.h" namespace paddle { namespace framework { @@ -25,6 +28,7 @@ OpDescBind::OpDescBind(const std::string &type, const VariableNameMap &inputs, inputs_ = inputs; outputs_ = outputs; attrs_ = attrs; + need_update_ = true; } OpDesc *OpDescBind::Proto() { @@ -184,5 +188,38 @@ void OpDescBind::Sync() { need_update_ = false; } } + +using InferShapeFuncMap = + std::unordered_map>; + +static InferShapeFuncMap &InferShapeFuncs() { + static InferShapeFuncMap *g_map = nullptr; + if (g_map == nullptr) { + g_map = new InferShapeFuncMap(); + auto &info_map = OpInfoMap::Instance(); + // all registered kernels + for (auto &pair : OperatorWithKernel::AllOpKernels()) { + auto &info = info_map.Get(pair.first); + // use empty type here to avoid runtime checks. 
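+  // Each registered kernel op is instantiated once, with an empty type
+  // string and empty inputs/outputs/attrs, purely so that its InferShape
+  // method can be captured by the lambda stored in this static map.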
+ auto op = + static_cast(info.Creator()("", {}, {}, {})); + g_map->insert( + {pair.first, [op](InferShapeContext *ctx) { op->InferShape(ctx); }}); + } + } + return *g_map; +} + +void OpDescBind::InferShape(const BlockDescBind &block) const { + auto &funcs = InferShapeFuncs(); + auto it = funcs.find(this->Type()); + if (it == funcs.end()) { + PADDLE_THROW("Operator %s has not been registered", this->Type()); + } + CompileTimeInferShapeContext ctx(*this, block); + it->second(&ctx); +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/op_desc.h b/paddle/framework/op_desc.h index b39808dad1..81c4225041 100644 --- a/paddle/framework/op_desc.h +++ b/paddle/framework/op_desc.h @@ -52,8 +52,6 @@ class OpDescBind { void SetOutput(const std::string ¶m_name, const std::vector &args); - std::string DebugString() { return this->Proto()->DebugString(); } - bool HasAttr(const std::string &name) const { return attrs_.find(name) != attrs_.end(); } @@ -97,6 +95,13 @@ class OpDescBind { const VariableNameMap &Outputs() const { return outputs_; } + AttributeMap *MutableAttrMap() { + this->need_update_ = true; + return &this->attrs_; + } + + void InferShape(const BlockDescBind &block) const; + private: template static std::vector MapKeys(const MapType &map) { diff --git a/paddle/framework/op_registry.cc b/paddle/framework/op_registry.cc index 66043f6e04..b118edae17 100644 --- a/paddle/framework/op_registry.cc +++ b/paddle/framework/op_registry.cc @@ -60,9 +60,14 @@ std::unique_ptr OpRegistry::CreateOp(const OpDescBind& op_desc) { } std::vector> OpRegistry::CreateGradOpDescs( - const OpDescBind& op_desc) { - auto& info = OpInfoMap::Instance().Get(op_desc.Type()); - return info.grad_op_maker_(op_desc); + OpDescBind* op_desc) { + auto& info = OpInfoMap::Instance().Get(op_desc->Type()); + + if (info.Checker() != nullptr) { + info.Checker()->Check(*op_desc->MutableAttrMap()); + } + + return info.grad_op_maker_(*op_desc); } } // namespace framework diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index cce3605fd4..5ca3af52a6 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -80,7 +80,7 @@ class OpRegistry { static std::unique_ptr CreateOp(const OpDesc& op_desc); static std::vector> CreateGradOpDescs( - const OpDescBind& op_desc); + OpDescBind* op_desc); static std::unique_ptr CreateOp(const OpDescBind& op_desc); }; diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 2ca838f838..2fca816f35 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -205,13 +205,13 @@ void OperatorBase::GenerateTemporaryNames() { } template <> -const Tensor* InferShapeContext::Input(const std::string& name) const { +const Tensor* ExecutionContext::Input(const std::string& name) const { auto* var = InputVar(name); return var == nullptr ? nullptr : GetTensorFromVar(var); } template <> -const std::vector InferShapeContext::MultiInput( +const std::vector ExecutionContext::MultiInput( const std::string& name) const { auto names = op().Inputs(name); std::vector res; @@ -225,13 +225,13 @@ const std::vector InferShapeContext::MultiInput( } template <> -Tensor* InferShapeContext::Output(const std::string& name) const { +Tensor* ExecutionContext::Output(const std::string& name) const { auto var = OutputVar(name); return var == nullptr ? 
nullptr : var->GetMutable(); } template <> -std::vector InferShapeContext::MultiOutput( +std::vector ExecutionContext::MultiOutput( const std::string& name) const { auto names = op().Outputs(name); std::vector res; diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index d7bc9c9ffb..15f80b5720 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -57,7 +57,6 @@ inline std::string GradVarName(const std::string& var_name) { } class OperatorBase; -class InferShapeContext; class ExecutionContext; extern const Tensor* GetTensorFromVar(const Variable* var); @@ -143,9 +142,9 @@ class OperatorBase { // Macro for define a clone method. // If you are writing an kernel operator, `Clone` will be defined when you // register it. i.e. `Clone` method is not needed to define by yourself. -#define DEFINE_OP_CLONE_METHOD(cls) \ - std::unique_ptr Clone() const final { \ - return std::unique_ptr(new cls(*this)); \ +#define DEFINE_OP_CLONE_METHOD(cls) \ + std::unique_ptr<::paddle::framework::OperatorBase> Clone() const final { \ + return std::unique_ptr<::paddle::framework::OperatorBase>(new cls(*this)); \ } // Macro for define a default constructor for Operator. @@ -169,10 +168,11 @@ class NOP : public OperatorBase { } }; -class InferShapeContext { +class ExecutionContext { public: - InferShapeContext(const OperatorBase& op, const Scope& scope) - : op_(op), scope_(scope) {} + ExecutionContext(const OperatorBase& op, const Scope& scope, + const platform::DeviceContext& device_context) + : op_(op), scope_(scope), device_context_(device_context) {} const OperatorBase& op() const { return op_; } @@ -278,31 +278,6 @@ class InferShapeContext { out_tensor->set_lod(in_tensor.lod()); } - private: - const OperatorBase& op_; - const Scope& scope_; -}; - -template <> -const Tensor* InferShapeContext::Input(const std::string& name) const; - -template <> -const std::vector InferShapeContext::MultiInput( - const std::string& name) const; - -template <> -Tensor* InferShapeContext::Output(const std::string& name) const; - -template <> -std::vector InferShapeContext::MultiOutput( - const std::string& name) const; - -class ExecutionContext : public InferShapeContext { - public: - ExecutionContext(const OperatorBase& op, const Scope& scope, - const platform::DeviceContext& device_context) - : InferShapeContext(op, scope), device_context_(device_context) {} - template ::EigenDeviceType> @@ -315,10 +290,26 @@ class ExecutionContext : public InferShapeContext { } private: + const OperatorBase& op_; + const Scope& scope_; const platform::DeviceContext& device_context_; }; -class CompileTimeInferShapeContext : public InferShapeContextBase { +template <> +const Tensor* ExecutionContext::Input(const std::string& name) const; + +template <> +const std::vector ExecutionContext::MultiInput( + const std::string& name) const; + +template <> +Tensor* ExecutionContext::Output(const std::string& name) const; + +template <> +std::vector ExecutionContext::MultiOutput( + const std::string& name) const; + +class CompileTimeInferShapeContext : public InferShapeContext { public: CompileTimeInferShapeContext(const OpDescBind& op, const BlockDescBind& block) : op_(op), block_(block) {} @@ -414,7 +405,7 @@ class CompileTimeInferShapeContext : public InferShapeContextBase { const BlockDescBind& block_; }; -class RuntimeInferShapeContext : public InferShapeContextBase { +class RuntimeInferShapeContext : public InferShapeContext { public: RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope) : 
op_(op), scope_(scope) {}
@@ -612,7 +603,7 @@ class OperatorWithKernel : public OperatorBase {
     });
   }
 
-  virtual void InferShape(InferShapeContextBase* ctx) const = 0;
+  virtual void InferShape(InferShapeContext* ctx) const = 0;
 
  protected:
   // indicate kernel DataType by input data. By default all input data must be
diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc
index a0c17b41f2..a02f4668bc 100644
--- a/paddle/framework/operator_test.cc
+++ b/paddle/framework/operator_test.cc
@@ -113,7 +113,7 @@ class OpWithKernelTest : public OperatorWithKernel {
   using OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(framework::InferShapeContextBase* ctx) const override {}
+  void InferShape(framework::InferShapeContext* ctx) const override {}
   DataType IndicateDataType(const ExecutionContext& ctx) const override {
     return DataType::FP32;
   }
diff --git a/paddle/framework/program_desc.h b/paddle/framework/program_desc.h
index 9b34a06aef..f29b1c54e7 100644
--- a/paddle/framework/program_desc.h
+++ b/paddle/framework/program_desc.h
@@ -14,6 +14,7 @@ limitations under the License. */
 #pragma once
 
+#include
 #include
 #include "paddle/framework/framework.pb.h"
 #include "paddle/platform/macros.h"
@@ -31,8 +32,6 @@ class ProgramDescBind {
 
   BlockDescBind *Block(size_t idx) { return blocks_[idx].get(); }
 
-  std::string DebugString() { return Proto()->DebugString(); }
-
   size_t Size() const { return blocks_.size(); }
 
   ProgramDesc *Proto();
diff --git a/paddle/framework/scope.cc b/paddle/framework/scope.cc
index 080b4ac621..5821bac928 100644
--- a/paddle/framework/scope.cc
+++ b/paddle/framework/scope.cc
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/framework/scope.h"
+
+#include <memory>  // for unique_ptr
+#include <mutex>   // for call_once
 #include "paddle/string/printf.h"
 
 namespace paddle {
@@ -62,5 +65,17 @@ void Scope::DropKids() {
   kids_.clear();
 }
 
+std::once_flag feed_variable_flag;
+
+framework::Scope& GetGlobalScope() {
+  static std::unique_ptr<framework::Scope> g_scope{nullptr};
+  std::call_once(feed_variable_flag, [&]() {
+    g_scope.reset(new framework::Scope());
+    g_scope->NewVar("feed_value");
+    g_scope->NewVar("fetch_value");
+  });
+  return *(g_scope.get());
+}
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/framework/scope.h b/paddle/framework/scope.h
index 7047f0d55e..a8cfb107c2 100644
--- a/paddle/framework/scope.h
+++ b/paddle/framework/scope.h
@@ -73,5 +73,7 @@ class Scope {
   DISABLE_COPY_AND_ASSIGN(Scope);
 };
 
+framework::Scope& GetGlobalScope();
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/framework/shape_inference.h b/paddle/framework/shape_inference.h
index 74e0371e32..64aab16ae5 100644
--- a/paddle/framework/shape_inference.h
+++ b/paddle/framework/shape_inference.h
@@ -20,11 +20,11 @@ namespace paddle {
 namespace framework {
 
 // TODO(longfei): Once after both CompileTimeInferShapeContext and
-// RuntimeInferShapeContext get merged, we can rename InferShapeContextBase into
+// RuntimeInferShapeContext get merged, we can rename InferShapeContext into
 // InferShapeContext so to replace the current InferShapeContext.
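// In short: the runtime class formerly named InferShapeContext (which holds
// an OperatorBase and a Scope) becomes ExecutionContext, while the abstract
// InferShapeContextBase below becomes the new InferShapeContext, the single
// interface shared by CompileTimeInferShapeContext and
// RuntimeInferShapeContext.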
-class InferShapeContextBase { +class InferShapeContext { public: - virtual ~InferShapeContextBase() {} + virtual ~InferShapeContext() {} virtual bool HasInput(const std::string &name) const = 0; virtual bool HasOutput(const std::string &name) const = 0; diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index 80a3f0a393..ba82127d9c 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -95,6 +95,19 @@ class Tensor { template inline void CopyFrom(const Tensor& src, const platform::Place& dst_place); + /** + * @brief Copy the content of an external vector to a tensor. + * + * @param[in] src The external vector. + * @param[in] ctx The device context contains place where to store. + * + * * @note CopyFromVector assumes that the tensor has been resized + * before invoking. + */ + template + inline void CopyFromVector(const std::vector& src, + const platform::Place& dst_place); + /** * @brief Return the slice of the tensor. * diff --git a/paddle/framework/tensor_array.h b/paddle/framework/tensor_array.h index 94a14c2df4..293da04997 100644 --- a/paddle/framework/tensor_array.h +++ b/paddle/framework/tensor_array.h @@ -87,12 +87,12 @@ class TensorArray { LoDTensor Stack() const; /* - * Unpacks the given division of a rank-`R` tensor into rank-`(R-1)` tensors. + * Unstacks the given division of a rank-`R` tensor into rank-`(R-1)` tensors. */ void Unstack(const LoDTensor &source) const; /* - * Unpacks the given division of a rank-`R` tensor into rank-`(R-1)` tensors, + * Unstacks the given division of a rank-`R` tensor into rank-`(R-1)` tensors, * with memory of tensors shared. */ void UnstackShared(const LoDTensor &source) const; diff --git a/paddle/framework/tensor_impl.h b/paddle/framework/tensor_impl.h index 379eac94f9..8ee9941982 100644 --- a/paddle/framework/tensor_impl.h +++ b/paddle/framework/tensor_impl.h @@ -123,6 +123,29 @@ inline void Tensor::CopyFrom(const Tensor& src, #endif } +template +inline void Tensor::CopyFromVector(const std::vector& src, + const platform::Place& dst_place) { + auto src_ptr = static_cast(src.data()); + platform::CPUPlace src_place; + auto dst_ptr = static_cast(mutable_data(dst_place)); + auto size = src.size() * sizeof(T); + + if (platform::is_cpu_place(dst_place)) { + memory::Copy(boost::get(dst_place), dst_ptr, src_place, + src_ptr, size); + } +#ifdef PADDLE_WITH_CUDA + else if (platform::is_gpu_place(dst_place)) { + memory::Copy(boost::get(dst_place), dst_ptr, src_place, + src_ptr, size, 0); + } + PADDLE_ENFORCE(cudaStreamSynchronize(0), + "cudaStreamSynchronize failed in Tensor CopyFromVector"); + +#endif +} + template inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const { check_memory_size(); diff --git a/paddle/framework/tensor_test.cc b/paddle/framework/tensor_test.cc index 58cf0fc3cb..492eba69e1 100644 --- a/paddle/framework/tensor_test.cc +++ b/paddle/framework/tensor_test.cc @@ -263,6 +263,93 @@ TEST(Tensor, CopyFrom) { #endif } +TEST(Tensor, CopyFromVector) { + using namespace paddle::framework; + using namespace paddle::platform; + { + std::vector src_vec = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + Tensor cpu_tensor; + + // Copy to CPU Tensor + cpu_tensor.Resize(make_ddim({3, 3})); + auto cpu_place = new paddle::platform::CPUPlace(); + cpu_tensor.CopyFromVector(src_vec, *cpu_place); + + // Compare Tensors + const int* cpu_ptr = cpu_tensor.data(); + const int* src_ptr = src_vec.data(); + ASSERT_NE(src_ptr, cpu_ptr); + for (size_t i = 0; i < 9; ++i) { + EXPECT_EQ(src_ptr[i], cpu_ptr[i]); + } + + 
src_vec.erase(src_vec.begin(), src_vec.begin() + 5); + cpu_tensor.Resize(make_ddim({2, 2})); + cpu_tensor.CopyFromVector(src_vec, *cpu_place); + cpu_ptr = cpu_tensor.data(); + src_ptr = src_vec.data(); + ASSERT_NE(src_ptr, cpu_ptr); + for (size_t i = 0; i < 5; ++i) { + EXPECT_EQ(src_ptr[i], cpu_ptr[i]); + } + + delete cpu_place; + } + +#ifdef PADDLE_WITH_CUDA + { + std::vector src_vec = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + Tensor cpu_tensor; + Tensor gpu_tensor; + Tensor dst_tensor; + + // Copy to CPU Tensor + cpu_tensor.Resize(make_ddim({3, 3})); + auto cpu_place = new paddle::platform::CPUPlace(); + cpu_tensor.CopyFromVector(src_vec, *cpu_place); + + // Copy to GPUTensor + gpu_tensor.Resize(make_ddim({3, 3})); + auto gpu_place = new paddle::platform::GPUPlace(); + gpu_tensor.CopyFromVector(src_vec, *gpu_place); + // Copy from GPU to CPU tensor for comparison + dst_tensor.CopyFrom(gpu_tensor, *cpu_place); + + // Compare Tensors + const int* src_ptr = src_vec.data(); + const int* cpu_ptr = cpu_tensor.data(); + const int* dst_ptr = dst_tensor.data(); + ASSERT_NE(src_ptr, cpu_ptr); + ASSERT_NE(src_ptr, dst_ptr); + for (size_t i = 0; i < 9; ++i) { + EXPECT_EQ(src_ptr[i], cpu_ptr[i]); + EXPECT_EQ(src_ptr[i], dst_ptr[i]); + } + + src_vec.erase(src_vec.begin(), src_vec.begin() + 5); + + cpu_tensor.Resize(make_ddim({2, 2})); + cpu_tensor.CopyFromVector(src_vec, *cpu_place); + gpu_tensor.Resize(make_ddim({2, 2})); + gpu_tensor.CopyFromVector(src_vec, *gpu_place); + dst_tensor.CopyFrom(gpu_tensor, *cpu_place); + + src_ptr = src_vec.data(); + cpu_ptr = cpu_tensor.data(); + dst_ptr = dst_tensor.data(); + ASSERT_NE(src_ptr, cpu_ptr); + ASSERT_NE(src_ptr, dst_ptr); + for (size_t i = 0; i < 5; ++i) { + EXPECT_EQ(src_ptr[i], cpu_ptr[i]); + EXPECT_EQ(src_ptr[i], dst_ptr[i]); + } + + delete cpu_place; + delete gpu_place; + } +#endif +} + TEST(Tensor, ReshapeToMatrix) { using namespace paddle::framework; using namespace paddle::platform; diff --git a/paddle/framework/type_defs.h b/paddle/framework/type_defs.h index a5b9472213..6f65a942ba 100644 --- a/paddle/framework/type_defs.h +++ b/paddle/framework/type_defs.h @@ -15,6 +15,7 @@ #pragma once #include #include +#include #include "paddle/platform/variant.h" namespace paddle { diff --git a/paddle/framework/var_desc.cc b/paddle/framework/var_desc.cc index 13b9c5f3cd..a88e813b5e 100644 --- a/paddle/framework/var_desc.cc +++ b/paddle/framework/var_desc.cc @@ -32,5 +32,13 @@ std::vector VarDescBind::Shape() const { DataType VarDescBind::GetDataType() const { return desc_.lod_tensor().data_type(); } + +void VarDescBind::SetLoDLevel(int32_t lod_level) { + desc_.mutable_lod_tensor()->set_lod_level(lod_level); +} + +int32_t VarDescBind::GetLodLevel() const { + return desc_.lod_tensor().lod_level(); +} } // namespace framework } // namespace paddle diff --git a/paddle/framework/var_desc.h b/paddle/framework/var_desc.h index 4763bf09d0..464fece85f 100644 --- a/paddle/framework/var_desc.h +++ b/paddle/framework/var_desc.h @@ -66,6 +66,10 @@ class VarDescBind { DataType GetDataType() const; + void SetLoDLevel(int32_t lod_level); + + int32_t GetLodLevel() const; + private: VarDesc desc_; }; diff --git a/paddle/gserver/CMakeLists.txt b/paddle/gserver/CMakeLists.txt index 62cff9361c..5f39167afc 100644 --- a/paddle/gserver/CMakeLists.txt +++ b/paddle/gserver/CMakeLists.txt @@ -60,6 +60,36 @@ if(NOT WITH_PYTHON) dataproviders/PyDataProvider.h) endif() +if(MOBILE_INFERENCE) + # Remove evaluators + list(REMOVE_ITEM GSERVER_SOURCES + layers/ValidationLayer.cpp + 
evaluators/Evaluator.cpp + evaluators/DetectionMAPEvaluator.cpp + evaluators/CTCErrorEvaluator.cpp + evaluators/ChunkEvaluator.cpp) + + # Remove dataproviders + list(REMOVE_ITEM GSERVER_SOURCES + dataproviders/DataProvider.cpp + dataproviders/MultiDataProvider.cpp + dataproviders/ProtoDataProvider.cpp + dataproviders/PyDataProvider2.cpp + dataproviders/PyDataProvider.cpp) + + # Remove useless gradientmachines + list(REMOVE_ITEM GSERVER_SOURCES + gradientmachines/MultiNetwork.cpp + gradientmachines/RecurrentGradientMachine.cpp + gradientmachines/ParallelNeuralNetwork.cpp + gradientmachines/GradientMachineMode.cpp + gradientmachines/MultiGradientMachine.cpp) + + # Remove useless layers + list(REMOVE_ITEM GSERVER_SOURCES + layers/RecurrentLayerGroup.cpp) +endif() + if(WITH_GPU) cuda_add_library(paddle_gserver ${GSERVER_SOURCES}) else() diff --git a/paddle/gserver/gradientmachines/GradientMachine.cpp b/paddle/gserver/gradientmachines/GradientMachine.cpp index b44e4dc202..de5faf5e1e 100644 --- a/paddle/gserver/gradientmachines/GradientMachine.cpp +++ b/paddle/gserver/gradientmachines/GradientMachine.cpp @@ -17,12 +17,15 @@ limitations under the License. */ #include #include "paddle/utils/Logging.h" +#include "NeuralNetwork.h" +#include "hl_gpu.h" + +#ifndef PADDLE_MOBILE_INFERENCE #include "GradientMachineMode.h" #include "MultiGradientMachine.h" #include "MultiNetwork.h" -#include "NeuralNetwork.h" #include "ParallelNeuralNetwork.h" -#include "hl_gpu.h" +#endif namespace paddle { @@ -30,13 +33,16 @@ GradientMachine* GradientMachine::create( const ModelConfig& config, int mode, const std::vector& parameterTypes) { +#ifndef PADDLE_MOBILE_INFERENCE if (auto gm = IGradientMachineMode::tryCreateGradientMachine(mode, config)) { return gm; } if (FLAGS_trainer_count > 1) { return new MultiGradientMachine(config, FLAGS_use_gpu); } +#endif if (FLAGS_trainer_count == 1) { // single +#ifndef PADDLE_MOBILE_INFERENCE NeuralNetwork* nn; if (config.type() == "multi_nn") { /* multi submodel calculate, thread(s) will be initialized inside */ @@ -48,6 +54,9 @@ GradientMachine* GradientMachine::create( /* single thread calculate */ nn = NeuralNetwork::create(config); } +#else + NeuralNetwork* nn = NeuralNetwork::create(config); +#endif ParamInitCallback testParamInitCb = [](int paramId, Parameter* para) { para->enableType(PARAMETER_VALUE); }; diff --git a/paddle/gserver/gradientmachines/GradientMachine.h b/paddle/gserver/gradientmachines/GradientMachine.h index f9c82a2bef..ebfe0573cf 100644 --- a/paddle/gserver/gradientmachines/GradientMachine.h +++ b/paddle/gserver/gradientmachines/GradientMachine.h @@ -20,13 +20,16 @@ limitations under the License. 
*/ #include "ModelConfig.pb.h" #include "TrainerConfig.pb.h" #include "paddle/gserver/dataproviders/DataProvider.h" -#include "paddle/gserver/evaluators/Evaluator.h" #include "paddle/gserver/layers/Layer.h" #include "paddle/math/Matrix.h" #include "paddle/parameter/Parameter.h" #include "paddle/parameter/ParameterUpdaterBase.h" #include "paddle/utils/Thread.h" +#ifndef PADDLE_MOBILE_INFERENCE +#include "paddle/gserver/evaluators/Evaluator.h" +#endif + namespace paddle { /** * @brief A gradient machine is capable of calculating some outputs given @@ -147,6 +150,7 @@ public: virtual void onPassEnd() = 0; +#ifndef PADDLE_MOBILE_INFERENCE /** * Create an evaluator which can be used for eval() */ @@ -156,6 +160,7 @@ public: * evaluate using the given evaluator */ virtual void eval(Evaluator* evaluator) const = 0; +#endif std::vector& getParameters() { return parameters_; } diff --git a/paddle/gserver/gradientmachines/NeuralNetwork.cpp b/paddle/gserver/gradientmachines/NeuralNetwork.cpp index 26cff3e677..dcf0acb5a2 100644 --- a/paddle/gserver/gradientmachines/NeuralNetwork.cpp +++ b/paddle/gserver/gradientmachines/NeuralNetwork.cpp @@ -14,15 +14,17 @@ limitations under the License. */ #include "paddle/utils/Util.h" +#include "NeuralNetwork.h" +#include "hl_gpu.h" +#include "paddle/gserver/layers/AgentLayer.h" #include "paddle/utils/CustomStackTrace.h" #include "paddle/utils/Logging.h" +#include "paddle/utils/Stat.h" +#ifndef PADDLE_MOBILE_INFERENCE #include "MultiNetwork.h" -#include "NeuralNetwork.h" #include "RecurrentGradientMachine.h" -#include "hl_gpu.h" -#include "paddle/gserver/layers/AgentLayer.h" -#include "paddle/utils/Stat.h" +#endif namespace paddle { void parameterInitNN(int paramId, @@ -54,6 +56,7 @@ void parameterInitNN(int paramId, } NeuralNetwork* NeuralNetwork::create(const ModelConfig& config) { +#ifndef PADDLE_MOBILE_INFERENCE if (config.type() == "recurrent_nn") { return newNeuralNetwork("root"); } else if (config.type() == "multi_nn") { @@ -61,6 +64,9 @@ NeuralNetwork* NeuralNetwork::create(const ModelConfig& config) { } else { return newNeuralNetwork(); } +#else + return new NeuralNetwork(); +#endif } std::map NeuralNetwork::dllInitMap; @@ -304,6 +310,8 @@ void NeuralNetwork::onPassEnd() { } } +#ifndef PADDLE_MOBILE_INFERENCE + class CombinedEvaluator : public Evaluator { public: void addEvaluator(std::unique_ptr&& evaluator) { @@ -466,6 +474,8 @@ Evaluator* NeuralNetwork::makeEvaluator() const { void NeuralNetwork::eval(Evaluator* evaluator) const { evaluator->eval(*this); } +#endif + void NeuralNetwork::setOutputGrad(const std::vector& args) { CHECK_GE(outputLayers_.size(), args.size()); for (size_t i = 0; i < args.size(); ++i) { diff --git a/paddle/gserver/gradientmachines/NeuralNetwork.h b/paddle/gserver/gradientmachines/NeuralNetwork.h index 12810f6425..56a1ec7846 100644 --- a/paddle/gserver/gradientmachines/NeuralNetwork.h +++ b/paddle/gserver/gradientmachines/NeuralNetwork.h @@ -97,9 +97,12 @@ public: virtual void onPassEnd(); +#ifndef PADDLE_MOBILE_INFERENCE virtual Evaluator* makeEvaluator() const; virtual void eval(Evaluator* evaluator) const; +#endif + virtual void resetState(); virtual void setOutputGrad(const std::vector& args); diff --git a/paddle/gserver/layers/Layer.cpp b/paddle/gserver/layers/Layer.cpp index e95f42c863..01f2aae6cf 100644 --- a/paddle/gserver/layers/Layer.cpp +++ b/paddle/gserver/layers/Layer.cpp @@ -15,11 +15,14 @@ limitations under the License. 
*/ #include "paddle/utils/Util.h" #include "CostLayer.h" -#include "ValidationLayer.h" #include "paddle/math/SparseMatrix.h" #include "paddle/utils/Error.h" #include "paddle/utils/Logging.h" +#ifndef PADDLE_MOBILE_INFERENCE +#include "ValidationLayer.h" +#endif + DEFINE_bool(log_error_clipping, false, "enable log error clipping or not"); namespace paddle { @@ -103,10 +106,12 @@ LayerPtr Layer::create(const LayerConfig& config) { return LayerPtr(new MultiClassCrossEntropy(config)); else if (type == "rank-cost") return LayerPtr(new RankingCost(config)); +#ifndef PADDLE_MOBILE_INFERENCE else if (type == "auc-validation") return LayerPtr(new AucValidation(config)); else if (type == "pnpair-validation") return LayerPtr(new PnpairValidation(config)); +#endif return LayerPtr(registrar_.createByType(config.type(), config)); } diff --git a/paddle/gserver/tests/CMakeLists.txt b/paddle/gserver/tests/CMakeLists.txt index de9b8e63df..fcee19415c 100644 --- a/paddle/gserver/tests/CMakeLists.txt +++ b/paddle/gserver/tests/CMakeLists.txt @@ -1,15 +1,17 @@ # gserver pacakge unittests +if(NOT MOBILE_INFERENCE) ################### test_ProtoDataProvider ############ -add_unittest_without_exec(test_ProtoDataProvider - test_ProtoDataProvider.cpp) - -# test_ProtoDataProvider will mkdir as same name, -# so if WORKING_DIRECTORY is default directory, then -# mkdir will get error. -add_test(NAME test_ProtoDataProvider - COMMAND ${CMAKE_CURRENT_BINARY_DIR}/test_ProtoDataProvider - WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle) + add_unittest_without_exec(test_ProtoDataProvider + test_ProtoDataProvider.cpp) + + # test_ProtoDataProvider will mkdir as same name, + # so if WORKING_DIRECTORY is default directory, then + # mkdir will get error. + add_test(NAME test_ProtoDataProvider + COMMAND ${CMAKE_CURRENT_BINARY_DIR}/test_ProtoDataProvider + WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle) +endif() ################# test_LayerGrad ####################### add_unittest_without_exec(test_LayerGrad @@ -98,9 +100,11 @@ add_unittest_without_exec(test_KmaxSeqScore add_test(NAME test_KmaxSeqScore COMMAND test_KmaxSeqScore) +if(NOT MOBILE_INFERENCE) ################## test_Evaluator ####################### -add_unittest(test_Evaluator - test_Evaluator.cpp) + add_unittest(test_Evaluator + test_Evaluator.cpp) +endif() ################ test_LinearChainCRF #################### add_simple_unittest(test_LinearChainCRF) @@ -131,27 +135,31 @@ if(NOT WITH_DOUBLE) WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle) endif() +if(NOT MOBILE_INFERENCE) ############### test_RecurrentGradientMachine ############### -# TODO(yuyang18): There is some bug in test_RecurrentGradientMachine -# I will fix it. 
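# These suites stay desktop-only because the MOBILE_INFERENCE build strips
# the dataproviders, evaluators, and recurrent gradient machines that they
# link against (see the paddle/gserver/CMakeLists.txt changes above).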
-add_unittest_without_exec(test_RecurrentGradientMachine - test_RecurrentGradientMachine.cpp) -add_test(NAME test_RecurrentGradientMachine - COMMAND .set_python_path.sh -d - ${PADDLE_SOURCE_DIR}/python:${PADDLE_SOURCE_DIR}/paddle/gserver/tests - ${CMAKE_CURRENT_BINARY_DIR}/test_RecurrentGradientMachine - WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle) - -add_unittest_without_exec(test_NetworkCompare - test_NetworkCompare.cpp) -if(WITH_GPU) - add_test(NAME test_NetworkCompare - COMMAND .set_python_path.sh -d ${PADDLE_SOURCE_DIR}/python ${CMAKE_CURRENT_BINARY_DIR}/test_NetworkCompare --use_gpu=true - WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle) -else() - add_test(NAME test_NetworkCompare - COMMAND .set_python_path.sh -d ${PADDLE_SOURCE_DIR}/python ${CMAKE_CURRENT_BINARY_DIR}/test_NetworkCompare --use_gpu=false - WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle) + # TODO(yuyang18): There is some bug in test_RecurrentGradientMachine + # I will fix it. + add_unittest_without_exec(test_RecurrentGradientMachine + test_RecurrentGradientMachine.cpp) + add_test(NAME test_RecurrentGradientMachine + COMMAND .set_python_path.sh -d + ${PADDLE_SOURCE_DIR}/python:${PADDLE_SOURCE_DIR}/paddle/gserver/tests + ${CMAKE_CURRENT_BINARY_DIR}/test_RecurrentGradientMachine + WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle) +endif() + +if(NOT MOBILE_INFERENCE) + add_unittest_without_exec(test_NetworkCompare + test_NetworkCompare.cpp) + if(WITH_GPU) + add_test(NAME test_NetworkCompare + COMMAND .set_python_path.sh -d ${PADDLE_SOURCE_DIR}/python ${CMAKE_CURRENT_BINARY_DIR}/test_NetworkCompare --use_gpu=true + WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle) + else() + add_test(NAME test_NetworkCompare + COMMAND .set_python_path.sh -d ${PADDLE_SOURCE_DIR}/python ${CMAKE_CURRENT_BINARY_DIR}/test_NetworkCompare --use_gpu=false + WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle) + endif() endif() diff --git a/paddle/gserver/tests/LayerGradUtil.h b/paddle/gserver/tests/LayerGradUtil.h index 88e831f78b..e10a27eedf 100644 --- a/paddle/gserver/tests/LayerGradUtil.h +++ b/paddle/gserver/tests/LayerGradUtil.h @@ -15,7 +15,6 @@ limitations under the License. */ #pragma once #include "ModelConfig.pb.h" #include "paddle/gserver/layers/DataLayer.h" -#include "paddle/trainer/Trainer.h" #include "paddle/testing/TestUtil.h" using namespace std; // NOLINT diff --git a/paddle/gserver/tests/test_ActivationGrad.cpp b/paddle/gserver/tests/test_ActivationGrad.cpp index de93972a58..f4c2a07c44 100644 --- a/paddle/gserver/tests/test_ActivationGrad.cpp +++ b/paddle/gserver/tests/test_ActivationGrad.cpp @@ -17,7 +17,6 @@ limitations under the License. */ #include #include "ModelConfig.pb.h" #include "paddle/gserver/layers/DataLayer.h" -#include "paddle/trainer/Trainer.h" #include "LayerGradUtil.h" #include "paddle/testing/TestUtil.h" diff --git a/paddle/gserver/tests/test_BatchNorm.cpp b/paddle/gserver/tests/test_BatchNorm.cpp index 050fde9d0a..41116f4809 100644 --- a/paddle/gserver/tests/test_BatchNorm.cpp +++ b/paddle/gserver/tests/test_BatchNorm.cpp @@ -17,7 +17,6 @@ limitations under the License. 
*/ #include #include "ModelConfig.pb.h" #include "paddle/gserver/layers/DataLayer.h" -#include "paddle/trainer/Trainer.h" #include "paddle/utils/GlobalConstants.h" #include "LayerGradUtil.h" diff --git a/paddle/gserver/tests/test_CRFLayerGrad.cpp b/paddle/gserver/tests/test_CRFLayerGrad.cpp index df14449291..f010066ebc 100644 --- a/paddle/gserver/tests/test_CRFLayerGrad.cpp +++ b/paddle/gserver/tests/test_CRFLayerGrad.cpp @@ -16,7 +16,6 @@ limitations under the License. */ #include "ModelConfig.pb.h" #include "paddle/gserver/layers/DataLayer.h" #include "paddle/gserver/layers/LinearChainCRF.h" -#include "paddle/trainer/Trainer.h" #include "LayerGradUtil.h" #include "paddle/testing/TestUtil.h" diff --git a/paddle/gserver/tests/test_ConvTrans.cpp b/paddle/gserver/tests/test_ConvTrans.cpp index 6035a866b4..5f2f966547 100644 --- a/paddle/gserver/tests/test_ConvTrans.cpp +++ b/paddle/gserver/tests/test_ConvTrans.cpp @@ -18,7 +18,6 @@ limitations under the License. */ #include "ModelConfig.pb.h" #include "paddle/gserver/layers/DataLayer.h" #include "paddle/math/MathUtils.h" -#include "paddle/trainer/Trainer.h" #include "paddle/utils/GlobalConstants.h" #include "LayerGradUtil.h" diff --git a/paddle/gserver/tests/test_ConvUnify.cpp b/paddle/gserver/tests/test_ConvUnify.cpp index ffcc47e2a8..8634355b52 100644 --- a/paddle/gserver/tests/test_ConvUnify.cpp +++ b/paddle/gserver/tests/test_ConvUnify.cpp @@ -18,7 +18,6 @@ limitations under the License. */ #include "ModelConfig.pb.h" #include "paddle/gserver/layers/DataLayer.h" #include "paddle/math/MathUtils.h" -#include "paddle/trainer/Trainer.h" #include "paddle/utils/GlobalConstants.h" #include "LayerGradUtil.h" diff --git a/paddle/gserver/tests/test_CrossEntropyOverBeamGrad.cpp b/paddle/gserver/tests/test_CrossEntropyOverBeamGrad.cpp index c922237d33..477638426f 100644 --- a/paddle/gserver/tests/test_CrossEntropyOverBeamGrad.cpp +++ b/paddle/gserver/tests/test_CrossEntropyOverBeamGrad.cpp @@ -18,7 +18,6 @@ limitations under the License. */ #include #include "ModelConfig.pb.h" #include "paddle/gserver/layers/DataLayer.h" -#include "paddle/trainer/Trainer.h" #include "LayerGradUtil.h" #include "paddle/testing/TestUtil.h" diff --git a/paddle/gserver/tests/test_KmaxSeqScore.cpp b/paddle/gserver/tests/test_KmaxSeqScore.cpp index 6386259882..ffe5cfb8db 100644 --- a/paddle/gserver/tests/test_KmaxSeqScore.cpp +++ b/paddle/gserver/tests/test_KmaxSeqScore.cpp @@ -18,7 +18,6 @@ limitations under the License. */ #include #include "ModelConfig.pb.h" #include "paddle/gserver/layers/DataLayer.h" -#include "paddle/trainer/Trainer.h" #include "paddle/utils/GlobalConstants.h" #include "LayerGradUtil.h" diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index 542db5ee5b..f63c93c943 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -21,7 +21,6 @@ limitations under the License. */ #include "ModelConfig.pb.h" #include "paddle/gserver/layers/DataLayer.h" #include "paddle/math/MathUtils.h" -#include "paddle/trainer/Trainer.h" #include "LayerGradUtil.h" #include "paddle/testing/TestUtil.h" diff --git a/paddle/gserver/tests/test_SelectiveFCLayer.cpp b/paddle/gserver/tests/test_SelectiveFCLayer.cpp index 4c87fe1bba..d164e382c4 100644 --- a/paddle/gserver/tests/test_SelectiveFCLayer.cpp +++ b/paddle/gserver/tests/test_SelectiveFCLayer.cpp @@ -24,7 +24,6 @@ limitations under the License. 
*/ #include "paddle/gserver/layers/Layer.h" #include "paddle/gserver/layers/SelectiveFullyConnectedLayer.h" #include "paddle/math/CpuSparseMatrix.h" -#include "paddle/trainer/Trainer.h" using namespace paddle; // NOLINT using namespace std; // NOLINT diff --git a/paddle/gserver/tests/test_SeqSliceLayerGrad.cpp b/paddle/gserver/tests/test_SeqSliceLayerGrad.cpp index 3366002ca1..3dbffc5634 100644 --- a/paddle/gserver/tests/test_SeqSliceLayerGrad.cpp +++ b/paddle/gserver/tests/test_SeqSliceLayerGrad.cpp @@ -15,7 +15,6 @@ limitations under the License. */ #include #include "ModelConfig.pb.h" #include "paddle/gserver/layers/DataLayer.h" -#include "paddle/trainer/Trainer.h" #include "LayerGradUtil.h" #include "paddle/testing/TestUtil.h" diff --git a/paddle/math/tests/test_GpuProfiler.cpp b/paddle/math/tests/test_GpuProfiler.cpp index 9402bd3ec4..d9f146f0d1 100644 --- a/paddle/math/tests/test_GpuProfiler.cpp +++ b/paddle/math/tests/test_GpuProfiler.cpp @@ -162,4 +162,4 @@ int main(int argc, char** argv) { return RUN_ALL_TESTS(); } -#endif /* PADDLE_ONLY_CPU */ +#endif diff --git a/paddle/memory/detail/buddy_allocator.cc b/paddle/memory/detail/buddy_allocator.cc index fdc5ed19dc..e212f7737a 100644 --- a/paddle/memory/detail/buddy_allocator.cc +++ b/paddle/memory/detail/buddy_allocator.cc @@ -182,7 +182,7 @@ BuddyAllocator::PoolSet::iterator BuddyAllocator::RefillPool() { max_chunk_size_ = platform::GpuMaxChunkSize(); } } -#endif // PADDLE_ONLY_CPU +#endif // Allocate a new maximum sized block size_t index = 0; diff --git a/paddle/memory/detail/system_allocator.cc b/paddle/memory/detail/system_allocator.cc index 6c9a46dd09..33166d9ce2 100644 --- a/paddle/memory/detail/system_allocator.cc +++ b/paddle/memory/detail/system_allocator.cc @@ -134,7 +134,7 @@ void GPUAllocator::Free(void* p, size_t size, size_t index) { bool GPUAllocator::UseGpu() const { return true; } -#endif // PADDLE_ONLY_CPU +#endif } // namespace detail } // namespace memory diff --git a/paddle/memory/detail/system_allocator.h b/paddle/memory/detail/system_allocator.h index ee9b012f91..552cab4f96 100644 --- a/paddle/memory/detail/system_allocator.h +++ b/paddle/memory/detail/system_allocator.h @@ -51,7 +51,7 @@ class GPUAllocator : public SystemAllocator { size_t gpu_alloc_size_ = 0; size_t fallback_alloc_size_ = 0; }; -#endif // PADDLE_ONLY_CPU +#endif } // namespace detail } // namespace memory diff --git a/paddle/memory/detail/system_allocator_test.cc b/paddle/memory/detail/system_allocator_test.cc index cd563844e7..6a8558937b 100644 --- a/paddle/memory/detail/system_allocator_test.cc +++ b/paddle/memory/detail/system_allocator_test.cc @@ -62,4 +62,4 @@ TEST(GPUAllocator, Alloc) { TestAllocator(a, 2048); TestAllocator(a, 0); } -#endif // PADDLE_ONLY_CPU +#endif diff --git a/paddle/memory/memcpy.cc b/paddle/memory/memcpy.cc index 790420a8ab..1df88a6da9 100644 --- a/paddle/memory/memcpy.cc +++ b/paddle/memory/memcpy.cc @@ -89,7 +89,7 @@ void Copy(platform::GPUPlace dst_place, platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToDevice); } -#endif // PADDLE_ONLY_CPU +#endif } // namespace memory } // namespace paddle diff --git a/paddle/memory/memcpy.h b/paddle/memory/memcpy.h index 0bccee58c3..9b36182c2b 100644 --- a/paddle/memory/memcpy.h +++ b/paddle/memory/memcpy.h @@ -53,7 +53,7 @@ template void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num, cudaStream_t stream); -#endif // PADDLE_ONLY_CPU +#endif } // namespace memory } // namespace paddle diff --git a/paddle/memory/memory.cc b/paddle/memory/memory.cc 
index 30ce8a82e1..5087c02385 100644 --- a/paddle/memory/memory.cc +++ b/paddle/memory/memory.cc @@ -111,7 +111,7 @@ size_t Used(platform::GPUPlace place) { return GetGPUBuddyAllocator(place.device)->Used(); } -#endif // PADDLE_ONLY_CPU +#endif } // namespace memory } // namespace paddle diff --git a/paddle/memory/memory_test.cc b/paddle/memory/memory_test.cc index 0d402038a0..2444931e26 100644 --- a/paddle/memory/memory_test.cc +++ b/paddle/memory/memory_test.cc @@ -135,4 +135,4 @@ TEST(BuddyAllocator, GPUMultAlloc) { } } -#endif // PADDLE_ONLY_CPU +#endif diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 0fa1fca2bc..7dae8fe2f9 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -55,12 +55,20 @@ function(op_library TARGET) set(pybind_flag 1) endif() + # pool_op contains several operators if ("${TARGET}" STREQUAL "pool_op") set(pybind_flag 1) # It's enough to just adding one operator to pybind file(APPEND ${pybind_file} "USE_OP(pool2d);\n") endif() + # pool_with_index_op contains several operators + if ("${TARGET}" STREQUAL "pool_with_index_op") + set(pybind_flag 1) + # It's enough to just adding one operator to pybind + file(APPEND ${pybind_file} "USE_OP(max_pool2d_with_index);\n") + endif() + # activation_op contains several operators if ("${TARGET}" STREQUAL "activation_op") set(pybind_flag 1) @@ -125,3 +133,4 @@ cc_test(gather_test SRCS gather_test.cc DEPS tensor) cc_test(net_op_test SRCS net_op_test.cc DEPS net_op) cc_test(scatter_test SRCS scatter_test.cc DEPS tensor) cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor paddle_memory) +cc_test(dynamic_recurrent_op_test SRCS dynamic_recurrent_op_test.cc DEPS dynamic_recurrent_op recurrent_op tensor_array) diff --git a/paddle/operators/accuracy_op.cc b/paddle/operators/accuracy_op.cc index 82010bfb53..c5fb113e0f 100644 --- a/paddle/operators/accuracy_op.cc +++ b/paddle/operators/accuracy_op.cc @@ -22,7 +22,7 @@ class AccuracyOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Inference"), "Input(Inference) of AccuracyOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Label"), diff --git a/paddle/operators/activation_op.cc b/paddle/operators/activation_op.cc index 66e9d2c401..ced14a8923 100644 --- a/paddle/operators/activation_op.cc +++ b/paddle/operators/activation_op.cc @@ -22,7 +22,7 @@ class ActivationOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { ctx->SetOutputDim("Y", ctx->GetInputDim("X")); ctx->ShareLoD("X", /*->*/ "Y"); } @@ -33,7 +33,7 @@ class ActivationOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Y")); } }; @@ -49,6 +49,18 @@ class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker { } }; +class LogSigmoidOpMaker : public framework::OpProtoAndCheckerMaker { + public: + LogSigmoidOpMaker(framework::OpProto *proto, + 
                    framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "Input of LogSigmoid operator");
+    AddOutput("Y", "Output of LogSigmoid operator");
+    AddComment(
+        "Logsigmoid activation operator, logsigmoid = log (1 / (1 + exp(-x)))");
+  }
+};
+
 class ExpOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
  ExpOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
@@ -85,6 +97,23 @@ class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
   }
 };
 
+template <typename AttrType>
+class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  SoftShrinkOpMaker(framework::OpProto *proto,
+                    framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "Input of Softshrink operator");
+    AddOutput("Y", "Output of Softshrink operator");
+    AddComment(
+        "Softshrink activation operator, "
+        "softshrink = x - lambda, if x > lambda;"
+        " x + lambda, if x < -lambda; 0 otherwise");
+    AddAttr<AttrType>("lambda", "non-negative offset")
+        .SetDefault(static_cast<AttrType>(0.5f));
+  }
+};
+
 class TanhOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
  TanhOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
@@ -108,6 +137,24 @@ class TanhShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
   }
 };
 
+template <typename AttrType>
+class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  HardShrinkOpMaker(framework::OpProto *proto,
+                    framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "Input of HardShrink operator");
+    AddOutput("Y", "Output of HardShrink operator");
+    AddComment(
+        "HardShrink activation operator, "
+        "hard_shrink(x) = x, if x > threshold; "
+        "hard_shrink(x) = x, if x < -threshold; "
+        "hard_shrink(x) = 0, otherwise");
+    AddAttr<AttrType>("threshold", "The value of threshold for HardShrink")
+        .SetDefault(static_cast<AttrType>(0.5));
+  }
+};
+
 class SqrtOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
  SqrtOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
@@ -159,6 +206,17 @@ class SquareOpMaker : public framework::OpProtoAndCheckerMaker {
   }
 };
 
+class SoftplusOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  SoftplusOpMaker(framework::OpProto *proto,
+                  framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "Input of Softplus operator");
+    AddOutput("Y", "Output of Softplus operator");
+    AddComment("Softplus activation operator, softplus(x) = log(1 + exp(x))");
+  }
+};
+
 class SoftsignOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
  SoftsignOpMaker(framework::OpProto *proto,
@@ -201,6 +259,40 @@ class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
   }
 };
 
+template <typename AttrType>
+class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  ELUOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X",
+             "(Tensor) The input of ELU operator, it shouldn't be empty. Input "
+             "is flattened and treated as a 1D array.");
+    AddOutput("Y",
+              "(Tensor) The output of ELU operator. It has the same shape as "
+              "the input.");
+    AddAttr<AttrType>(
+        "alpha", "(float, default 1.0) Alpha value in the elu formulation.")
+        .SetDefault(static_cast<AttrType>(1.));
+    AddComment(R"DOC(
+    ELU activation operator. It applies this element-wise computation on
+    the input: f(x) = max(0, x) + min(0, alpha * (exp(x) - 1)).
+    Check ..
_Link: https://arxiv.org/abs/1511.07289 for more details.)DOC"); + } +}; + +template +class Relu6OpMaker : public framework::OpProtoAndCheckerMaker { + public: + Relu6OpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "Input of Relu6 operator"); + AddOutput("Y", "Output of Relu6 operator"); + AddComment("Relu6 activation operator, relu6 = min(max(0, x), 6)"); + AddAttr("threshold", "The threshold value of Relu6") + .SetDefault(static_cast(6)); + } +}; + template class PowOpMaker : public framework::OpProtoAndCheckerMaker { public: @@ -237,6 +329,9 @@ namespace ops = paddle::operators; REGISTER_OP(sigmoid, ops::ActivationOp, ops::SigmoidOpMaker, sigmoid_grad, ops::ActivationOpGrad); +REGISTER_OP(logsigmoid, ops::ActivationOp, ops::LogSigmoidOpMaker, + logsigmoid_grad, ops::ActivationOpGrad); + REGISTER_OP(exp, ops::ActivationOp, ops::ExpOpMaker, exp_grad, ops::ActivationOpGrad); @@ -249,6 +344,9 @@ REGISTER_OP(tanh, ops::ActivationOp, ops::TanhOpMaker, tanh_grad, REGISTER_OP(tanh_shrink, ops::ActivationOp, ops::TanhShrinkOpMaker, tanh_shrink_grad, ops::ActivationOpGrad); +REGISTER_OP(softshrink, ops::ActivationOp, ops::SoftShrinkOpMaker, + softshrink_grad, ops::ActivationOpGrad); + REGISTER_OP(sqrt, ops::ActivationOp, ops::SqrtOpMaker, sqrt_grad, ops::ActivationOpGrad); @@ -264,6 +362,9 @@ REGISTER_OP(log, ops::ActivationOp, ops::LogOpMaker, log_grad, REGISTER_OP(square, ops::ActivationOp, ops::SquareOpMaker, square_grad, ops::ActivationOpGrad); +REGISTER_OP(softplus, ops::ActivationOp, ops::SoftplusOpMaker, softplus_grad, + ops::ActivationOpGrad); + REGISTER_OP(softsign, ops::ActivationOp, ops::SoftsignOpMaker, softsign_grad, ops::ActivationOpGrad); @@ -276,20 +377,27 @@ REGISTER_OP(leaky_relu, ops::ActivationOp, ops::LeakyReluOpMaker, REGISTER_OP(soft_relu, ops::ActivationOp, ops::SoftReluOpMaker, soft_relu_grad, ops::ActivationOpGrad); +REGISTER_OP(elu, ops::ActivationOp, ops::ELUOpMaker, elu_grad, + ops::ActivationOpGrad); + +REGISTER_OP(relu6, ops::ActivationOp, ops::Relu6OpMaker, relu6_grad, + ops::ActivationOpGrad); + REGISTER_OP(pow, ops::ActivationOp, ops::PowOpMaker, pow_grad, ops::ActivationOpGrad); REGISTER_OP(stanh, ops::ActivationOp, ops::STanhOpMaker, stanh_grad, ops::ActivationOpGrad); +REGISTER_OP(hard_shrink, ops::ActivationOp, ops::HardShrinkOpMaker, + hard_shrink_grad, ops::ActivationOpGrad); + #define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \ REGISTER_OP_CPU_KERNEL( \ act_type, \ - paddle::operators::ActivationKernel>); \ + ops::ActivationKernel>); \ REGISTER_OP_CPU_KERNEL(act_type##_grad, \ - paddle::operators::ActivationGradKernel< \ - paddle::platform::CPUPlace, \ - paddle::operators::grad_functor>); + ops::ActivationGradKernel>); FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL); diff --git a/paddle/operators/activation_op.cu b/paddle/operators/activation_op.cu index 93e9f1c694..7b7644519d 100644 --- a/paddle/operators/activation_op.cu +++ b/paddle/operators/activation_op.cu @@ -15,14 +15,14 @@ #define EIGEN_USE_GPU #include "paddle/operators/activation_op.h" +namespace ops = paddle::operators; + #define REGISTER_ACTIVATION_GPU_KERNEL(act_type, functor, grad_functor) \ REGISTER_OP_GPU_KERNEL( \ act_type, \ - paddle::operators::ActivationKernel>); \ + ops::ActivationKernel>); \ REGISTER_OP_GPU_KERNEL(act_type##_grad, \ - paddle::operators::ActivationGradKernel< \ - paddle::platform::GPUPlace, \ - paddle::operators::grad_functor>); + 
diff --git a/paddle/operators/activation_op.h b/paddle/operators/activation_op.h index 2450601742..f88c9c48eb 100644 --- a/paddle/operators/activation_op.h +++ b/paddle/operators/activation_op.h @@ -95,6 +95,41 @@ struct SigmoidGradFunctor : public BaseActivationFunctor { } }; +// Originally: logsigmoid(x) = -log (1 + exp(-x)) +// For numerical stability, we can use the log-sum-exp trick: +// https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/ +// We can rewrite the above equation as: +// y = -log( exp(0) + exp(-x)) [since exp(0) = 1] +// = -log( exp(max(-x, 0) - max(-x, 0)) + exp(-x + max(-x, 0) - max(-x, 0))) +// = -log( exp(max(-x, 0)) * exp(-max(-x, 0)) + exp(max(-x, 0)) * exp(-x - +// max(-x, 0))) +// = -log( exp(max(-x, 0)) * (exp(-max(-x, 0)) + exp(-x - max(-x, 0)))) +// = -log(exp(max(-x, 0))) - log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))) +// +// Hence, logsigmoid(x) = - (max(-x, 0) + log(exp(-max(-x, 0)) +// + exp(-x - max(-x, 0)))) +template +struct LogSigmoidFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Y y) const { + auto temp = (-x).cwiseMax(static_cast(0)); // temp = max(-x, 0) + y.device(d) = -temp - (((-temp).exp() + (-x - temp).exp()).log()); + } +}; + +// Originally: f' = exp(-x) / (1 + exp(-x)) +// For numerical stability: f' = exp(-x - max(-x, 0)) / (exp(-max(-x, 0)) + +// exp(-x - max(-x, 0))) +template +struct LogSigmoidGradFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Y y, dY dy, dX dx) const { + auto temp = (-x).cwiseMax(static_cast(0)); // temp = max(-x, 0) + dx.device(d) = + dy * ((-x - temp).exp() / ((-temp).exp() + (-x - temp).exp())); + } +}; + // exp(x) = e^x template struct ExpFunctor : public BaseActivationFunctor { @@ -164,6 +199,70 @@ struct TanhShrinkGradFunctor : public BaseActivationFunctor { } }; +// hard_shrink(x) = x, if x > threshold or x < -threshold; 0 otherwise +template +struct HardShrinkFunctor : public BaseActivationFunctor { + float threshold; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"threshold", &threshold}}; + } + template + void operator()(Device d, X x, Y y) const { + auto temp1 = (x < (threshold * -1)).template cast().eval(); + auto temp2 = (x > threshold).template cast().eval(); + y.device(d) = x * (temp1 + temp2); + } +}; + +template +struct HardShrinkGradFunctor : public BaseActivationFunctor { + float threshold; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"threshold", &threshold}}; + } + + template + void operator()(Device d, X x, Y y, dY dy, dX dx) const { + auto temp1 = (x < (threshold * -1)).template cast().eval(); + auto temp2 = (x > threshold).template cast().eval(); + dx.device(d) = dy * (temp1 + temp2).template cast(); + } +}; + +// softshrink(x) = x - lambda, if x > lambda; x + lambda, if x < -lambda; 0 +// otherwise +template +struct SoftShrinkFunctor : public BaseActivationFunctor { + float lambda; + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"lambda", &lambda}}; + } + + template + void operator()(Device d, X x, Y y) const { + auto temp1 = (x > lambda).template cast().eval(); + auto temp2 = (x < -lambda).template cast().eval(); + y.device(d) = temp1 * (x - lambda) + temp2 * (x + lambda); + } +}; + +template +struct SoftShrinkGradFunctor : public BaseActivationFunctor { + float lambda; + typename BaseActivationFunctor::AttrPair GetAttrs() { + return 
{{"lambda", &lambda}}; + } + template + void operator()(Device d, X x, Y y, dY dy, dX dx) const { + auto temp1 = (x > lambda).template cast().eval(); + auto temp2 = (x < -lambda).template cast().eval(); + dx.device(d) = dy * (temp1 + temp2).template cast(); + } +}; + // sqrt(x) = x^(1/2) template struct SqrtFunctor : public BaseActivationFunctor { @@ -280,6 +379,61 @@ struct BReluGradFunctor : public BaseActivationFunctor { } }; +// relu6(x) = min(max(0, x), 6) +template +struct Relu6Functor : public BaseActivationFunctor { + float threshold; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"threshold", &threshold}}; + } + + template + void operator()(Device d, X x, Y y) const { + y.device(d) = x.cwiseMax(static_cast(0)).cwiseMin(threshold); + } +}; + +template +struct Relu6GradFunctor : public BaseActivationFunctor { + float threshold; + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"threshold", &threshold}}; + } + template + void operator()(Device d, X x, Y y, dY dy, dX dx) const { + dx.device(d) = + dy * ((x > static_cast(0)) * (x < threshold)).template cast(); + } +}; + +// softplus(x) = log(1 + exp(x)) +// When x is a very large positive number, exp(x) may explode to inf, +// Using trick below for numerical stability +// https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/ +// Then: softplus(x) = max(x, 0) + log(exp(-max(x, 0)) + exp(x - max(x, 0))) +template +struct SoftplusFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Y y) { + auto temp = x.cwiseMax(static_cast(0)); // temp = max(x, 0) + y.device(d) = temp + (((-temp).exp() + (x - temp).exp()).log()); + } +}; + +// d(softplus(x))/dx = exp(x) / (1 + exp(x)) +// For numerical stability: +// d(softplus(x))/dx = exp(x - max(x, 0)) / (exp(-max(x, 0)) + +// exp(x - max(x, 0))) +template +struct SoftplusGradFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Y y, dY dy, dX dx) { + auto temp = x.cwiseMax(static_cast(0)); // temp = max(x, 0) + dx.device(d) = dy * ((x - temp).exp() / ((-temp).exp() + (x - temp).exp())); + } +}; + // softsign(x) = x / (1 + |x|) template struct SoftsignFunctor : public BaseActivationFunctor { @@ -354,6 +508,35 @@ struct LeakyReluGradFunctor : public BaseActivationFunctor { } }; +template +struct ELUFunctor : public BaseActivationFunctor { + float alpha; + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"alpha", &alpha}}; + } + + template + void operator()(Device d, X x, Y y) const { + y.device(d) = + x.cwiseMax(static_cast(0)) + + (alpha * (x.exp() - static_cast(1))).cwiseMin(static_cast(0)); + } +}; + +template +struct ELUGradFunctor : public BaseActivationFunctor { + float alpha; + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"alpha", &alpha}}; + } + template + void operator()(Device d, X x, Y y, dY dy, dX dx) const { + dx.device(d) = + dy * (x > static_cast(0)).template cast() + + dy * (y + alpha) * (x < static_cast(0)).template cast(); + } +}; + template struct PowFunctor : public BaseActivationFunctor { float factor; @@ -410,20 +593,26 @@ struct STanhGradFunctor : public BaseActivationFunctor { } // namespace operators } // namespace paddle -#define FOR_EACH_KERNEL_FUNCTOR(__macro) \ - __macro(sigmoid, SigmoidFunctor, SigmoidGradFunctor); \ - __macro(exp, ExpFunctor, ExpGradFunctor); \ - __macro(relu, ReluFunctor, ReluGradFunctor); \ - __macro(tanh, TanhFunctor, TanhGradFunctor); \ - __macro(sqrt, SqrtFunctor, SqrtGradFunctor); \ - __macro(abs, 
AbsFunctor, AbsGradFunctor); \ - __macro(reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \ - __macro(log, LogFunctor, LogGradFunctor); \ - __macro(square, SquareFunctor, SquareGradFunctor); \ - __macro(brelu, BReluFunctor, BReluGradFunctor); \ - __macro(soft_relu, SoftReluFunctor, SoftReluGradFunctor); \ - __macro(pow, PowFunctor, PowGradFunctor); \ - __macro(stanh, STanhFunctor, STanhGradFunctor); \ - __macro(softsign, SoftsignFunctor, SoftsignGradFunctor); \ - __macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor); \ - __macro(tanh_shrink, TanhShrinkFunctor, TanhShrinkGradFunctor) +#define FOR_EACH_KERNEL_FUNCTOR(__macro) \ + __macro(sigmoid, SigmoidFunctor, SigmoidGradFunctor); \ + __macro(logsigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor); \ + __macro(exp, ExpFunctor, ExpGradFunctor); \ + __macro(relu, ReluFunctor, ReluGradFunctor); \ + __macro(tanh, TanhFunctor, TanhGradFunctor); \ + __macro(softshrink, SoftShrinkFunctor, SoftShrinkGradFunctor); \ + __macro(sqrt, SqrtFunctor, SqrtGradFunctor); \ + __macro(abs, AbsFunctor, AbsGradFunctor); \ + __macro(reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \ + __macro(log, LogFunctor, LogGradFunctor); \ + __macro(square, SquareFunctor, SquareGradFunctor); \ + __macro(brelu, BReluFunctor, BReluGradFunctor); \ + __macro(soft_relu, SoftReluFunctor, SoftReluGradFunctor); \ + __macro(pow, PowFunctor, PowGradFunctor); \ + __macro(stanh, STanhFunctor, STanhGradFunctor); \ + __macro(softplus, SoftplusFunctor, SoftplusGradFunctor); \ + __macro(softsign, SoftsignFunctor, SoftsignGradFunctor); \ + __macro(relu6, Relu6Functor, Relu6GradFunctor); \ + __macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor); \ + __macro(tanh_shrink, TanhShrinkFunctor, TanhShrinkGradFunctor); \ + __macro(elu, ELUFunctor, ELUGradFunctor); \ + __macro(hard_shrink, HardShrinkFunctor, HardShrinkGradFunctor) diff --git a/paddle/operators/adadelta_op.cc b/paddle/operators/adadelta_op.cc index bd8c93b4a1..cf1bca1658 100644 --- a/paddle/operators/adadelta_op.cc +++ b/paddle/operators/adadelta_op.cc @@ -22,7 +22,7 @@ class AdadeltaOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Param"), "Input(Param) of AdadeltaOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Grad"), diff --git a/paddle/operators/adagrad_op.cc b/paddle/operators/adagrad_op.cc index ea2ff3c503..a17747efb7 100644 --- a/paddle/operators/adagrad_op.cc +++ b/paddle/operators/adagrad_op.cc @@ -22,7 +22,7 @@ class AdagradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Param"), "Input(Param) of AdagradOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Grad"), diff --git a/paddle/operators/adamax_op.cc b/paddle/operators/adamax_op.cc new file mode 100644 index 0000000000..5cf727742c --- /dev/null +++ b/paddle/operators/adamax_op.cc @@ -0,0 +1,139 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/adamax_op.h" + +namespace paddle { +namespace operators { + +class AdamaxOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Param"), + "Input(Param) of AdamaxOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Grad"), + "Input(Grad) of AdamaxOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Moment"), + "Input(Moment) of AdamaxOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("InfNorm"), + "Input(InfNorm) of AdamaxOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("LearningRate"), + "Input(LearningRate) of AdamaxOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Beta1Pow"), + "Input(Beta1Pow) of AdamaxOp should not be null."); + + PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), + "Output(ParamOut) of AdamaxOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("MomentOut"), + "Output(MomentOut) of AdamaxOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("InfNormOut"), + "Output(InfNormOut) of AdamaxOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Beta1PowOut"), + "Output(Beta1PowOut) of AdamaxOp should not be null."); + + auto lr_dims = ctx->GetInputDim("LearningRate"); + PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1, + "Learning rate should have 1 dimension"); + auto beta1_pow_dims = ctx->GetInputDim("Beta1Pow"); + PADDLE_ENFORCE_EQ(framework::product(beta1_pow_dims), 1, + "Beta1 power accumulator should have 1 dimension"); + auto param_dims = ctx->GetInputDim("Param"); + PADDLE_ENFORCE_EQ( + param_dims, ctx->GetInputDim("Grad"), + "Param and Grad input of AdamaxOp should have same dimension"); + PADDLE_ENFORCE_EQ( + param_dims, ctx->GetInputDim("Moment"), + "Param and Moment input of AdamaxOp should have same dimension"); + PADDLE_ENFORCE_EQ( + param_dims, ctx->GetInputDim("InfNorm"), + "Param and InfNorm input of AdamaxOp should have same dimension"); + + ctx->SetOutputDim("ParamOut", param_dims); + ctx->SetOutputDim("MomentOut", param_dims); + ctx->SetOutputDim("InfNormOut", param_dims); + ctx->SetOutputDim("Beta1PowOut", beta1_pow_dims); + } +}; + +class AdamaxOpMaker : public framework::OpProtoAndCheckerMaker { + public: + AdamaxOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Param", "(Tensor) Input parameter"); + AddInput("Grad", "(Tensor) Input gradient"); + AddInput("LearningRate", "(Tensor) Learning rate"); + AddInput("Moment", "(Tensor) First moment"); + AddInput("InfNorm", + "(Tensor) " + "Input exponentially weighted infinity norm"); + AddInput("Beta1Pow", "(Tensor) Input beta1 power accumulator"); + + AddOutput("ParamOut", "(Tensor) Output parameter"); + AddOutput("MomentOut", "(Tensor) Output first moment"); + AddOutput("InfNormOut", + "(Tensor) " + "Output exponentially weighted infinity norm"); + AddOutput("Beta1PowOut", "(Tensor) Output beta1 power accumulator"); + + AddAttr("beta1", + "(float, default 0.9) " + "Exponential decay rate 
for the " + "1st moment estimates.") + .SetDefault(0.9f); + AddAttr("beta2", + "(float, default 0.999) " + "exponential decay rate for the weighted " + "infinity norm estimates.") + .SetDefault(0.999f); + AddAttr("epsilon", + "(float, default 1.0e-8) " + "Constant for numerical stability") + .SetDefault(1.0e-8f); + AddComment(R"DOC( +Adamax Updates Operator. + +This implements the Adamax optimizer from Section 7 of the Adam +paper[1]. Adamax is a variant of the +Adam algorithm based on the infinity norm. + +Adamax updates: + +moment_out = beta1 * moment + (1 - beta1) * grad +inf_norm_out = max(beta2 * inf_norm + epsilon, abs(grad)) +beta1_pow_out = beta1_pow * beta1 +learning_rate_t = learning_rate/(1 - beta1_pow_out) +param_out = param - learning_rate_t * moment_out/inf_norm_out + +The original paper does not have an epsilon attribute. +However, it is added here for numerical stability +by preventing divide by 0. + +References: + [1] Adam: A Method for Stochastic Optimization + (https://arxiv.org/abs/1412.6980) + +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(adamax, ops::AdamaxOp, ops::AdamaxOpMaker); +REGISTER_OP_CPU_KERNEL(adamax, + ops::AdamaxOpKernel); diff --git a/paddle/operators/adamax_op.cu b/paddle/operators/adamax_op.cu new file mode 100644 index 0000000000..fee3b6fc6b --- /dev/null +++ b/paddle/operators/adamax_op.cu @@ -0,0 +1,20 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/adamax_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(adamax, + ops::AdamaxOpKernel); diff --git a/paddle/operators/adamax_op.h b/paddle/operators/adamax_op.h new file mode 100644 index 0000000000..9677b1bb78 --- /dev/null +++ b/paddle/operators/adamax_op.h @@ -0,0 +1,72 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class AdamaxOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto param_out_tensor = ctx.Output("ParamOut"); + auto moment_out_tensor = ctx.Output("MomentOut"); + auto inf_norm_out_tensor = ctx.Output("InfNormOut"); + auto beta1_pow_out_tensor = ctx.Output("Beta1PowOut"); + + param_out_tensor->mutable_data(ctx.GetPlace()); + moment_out_tensor->mutable_data(ctx.GetPlace()); + inf_norm_out_tensor->mutable_data(ctx.GetPlace()); + beta1_pow_out_tensor->mutable_data(ctx.GetPlace()); + + float beta1 = ctx.Attr("beta1"); + float beta2 = ctx.Attr("beta2"); + float epsilon = ctx.Attr("epsilon"); + + auto param = framework::EigenVector::Flatten( + *ctx.Input("Param")); + auto grad = framework::EigenVector::Flatten( + *ctx.Input("Grad")); + auto moment = framework::EigenVector::Flatten( + *ctx.Input("Moment")); + auto inf_norm = framework::EigenVector::Flatten( + *ctx.Input("InfNorm")); + auto lr = framework::EigenVector::Flatten( + *ctx.Input("LearningRate")); + auto beta1_pow = framework::EigenVector::Flatten( + *ctx.Input("Beta1Pow")); + auto param_out = framework::EigenVector::Flatten(*param_out_tensor); + auto moment_out = framework::EigenVector::Flatten(*moment_out_tensor); + auto inf_norm_out = + framework::EigenVector::Flatten(*inf_norm_out_tensor); + auto beta1_pow_out = + framework::EigenVector::Flatten(*beta1_pow_out_tensor); + auto place = ctx.GetEigenDevice(); + + moment_out.device(place) = beta1 * moment + (1 - beta1) * grad; + inf_norm_out.device(place) = + grad.abs().cwiseMax((beta2 * inf_norm) + epsilon); + beta1_pow_out.device(place) = beta1_pow * beta1; + auto lr_t = lr / (1 - beta1_pow_out); + Eigen::DSizes m_dsize(moment_out_tensor->numel()); + param_out.device(place) = + param - lr_t.broadcast(m_dsize) * (moment_out / inf_norm_out); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/clip_op.cc b/paddle/operators/clip_op.cc index b3dd060fd7..3e9b0d82ba 100644 --- a/paddle/operators/clip_op.cc +++ b/paddle/operators/clip_op.cc @@ -22,7 +22,7 @@ class ClipOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of ClipOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), @@ -61,7 +61,7 @@ class ClipOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), "Input(Out@GRAD) should not be null"); diff --git a/paddle/operators/concat_op.cc b/paddle/operators/concat_op.cc index 1ffa02c8f9..235c4449ac 100644 --- a/paddle/operators/concat_op.cc +++ b/paddle/operators/concat_op.cc @@ -24,7 +24,7 @@ class ConcatOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const 
override { PADDLE_ENFORCE_GE(ctx->Inputs("X").size(), 1UL, "Inputs(X) of ConcatOp should be empty.") PADDLE_ENFORCE(ctx->HasOutput("Out"), @@ -83,7 +83,7 @@ class ConcatOpGrad : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X")); } }; diff --git a/paddle/operators/conv2d_op.cc b/paddle/operators/conv2d_op.cc index 5cc82944bb..6325d4248f 100644 --- a/paddle/operators/conv2d_op.cc +++ b/paddle/operators/conv2d_op.cc @@ -27,7 +27,7 @@ class Conv2DOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Input"), "Input(Input) of Conv2DOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Filter"), @@ -106,7 +106,7 @@ class Conv2DOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { auto in_dims = ctx->GetInputDim("Input"); auto filter_dims = ctx->GetInputDim("Filter"); if (ctx->HasOutput(framework::GradVarName("Input"))) { diff --git a/paddle/operators/conv_shift_op.cc b/paddle/operators/conv_shift_op.cc new file mode 100644 index 0000000000..e1e321ed5f --- /dev/null +++ b/paddle/operators/conv_shift_op.cc @@ -0,0 +1,206 @@ +/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
*/ + +#include "paddle/operators/conv_shift_op.h" +#include "paddle/framework/eigen.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; +template +using EigenMatrix = framework::EigenMatrix; + +class ConvShiftOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null."); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should be not null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should be not null."); + + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2."); + PADDLE_ENFORCE_EQ(y_dims.size(), 2, "Input(Y)'s rank should be 2."); + PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0], + "The 1st dimension of Input(X) and Input(Y) should " + "be equal."); + PADDLE_ENFORCE_EQ(y_dims[1] % 2, 1, + "The 2nd dimension of Input(Y) should be odd."); + PADDLE_ENFORCE_LE(y_dims[1], x_dims[1], + "The 2nd dimension of Input(Y) should be less than or " + "equal to the 2nd dimension of Input(X)."); + ctx->SetOutputDim("Out", x_dims); + ctx->ShareLoD("X", /*->*/ "Out"); + } +}; + +class ConvShiftGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null."); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should be not null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) should be not null."); + + auto x_grad_name = framework::GradVarName("X"); + if (ctx->HasOutput(x_grad_name)) { + auto x_dims = ctx->GetInputDim("X"); + ctx->SetOutputDim(x_grad_name, x_dims); + } + + auto y_grad_name = framework::GradVarName("Y"); + if (ctx->HasOutput(y_grad_name)) { + auto y_dims = ctx->GetInputDim("Y"); + ctx->SetOutputDim(y_grad_name, y_dims); + } + } +}; + +class ConvShiftOpMaker : public framework::OpProtoAndCheckerMaker { + public: + ConvShiftOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : framework::OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "(Tensor, default Tensor), a 2-D tensor with shape B x M, " + "where B is the batch size and M is the data dimension."); + AddInput("Y", + "(Tensor, default Tensor), a 2-D tensor with shape B x N, " + "where B is the batch size and N is the data dimension. N must " + "be odd."); + AddOutput("Out", + "(Tensor, default Tensor), a 2-D tensor with shape B x M, " + "i.e., the same shape as X."); + AddComment(R"DOC( +ConvShift Operator. + +A layer for circular convolution of two vectors, +as used in the Neural Turing Machine: https://arxiv.org/abs/1410.5401 + +The equation is: + + \f[ + Out[i] = \sum_{j=-(N-1)/2}^{(N-1)/2} X_{i+j} * Y_{j} + \f] + +where X's index is computed modulo M, and b's index is computed modulo N. + +Both of the input `X` and `Y` can carry LoD (Level of Details) information. +However, the output only shares the LoD information with input `X`. 
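+ +As a toy example: with M = 3, N = 3, a batch row X = (1, 2, 3) and a filter +Y = (a, b, c), the output row is Out = (3a + 1b + 2c, 1a + 2b + 3c, 2a + 3b + 1c), +each entry being the dot product of Y with a cyclically shifted slice of X.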
+)DOC"); + } +}; + +template +class ConvShiftKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + auto *X = context.Input("X"); + auto *Y = context.Input("Y"); + auto *Out = context.Output("Out"); + Out->mutable_data(context.GetPlace()); + + auto x = EigenMatrix::From(*X); + auto y = EigenMatrix::From(*Y); + auto out = EigenMatrix::From(*Out); + out.setZero(); + + size_t batch_size = X->dims()[0]; + size_t x_width = X->dims()[1]; + size_t y_width = Y->dims()[1]; + size_t y_half_width = (y_width - 1) / 2; + + for (size_t k = 0; k < batch_size; ++k) { + for (size_t i = 0; i < x_width; ++i) { + for (size_t j = 0; j < y_width; ++j) { + int index = (i + j - y_half_width + x_width) % x_width; + out(k, i) += x(k, index) * y(k, j); + } + } + } + } +}; + +template +class ConvShiftGradKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + auto *X = context.Input("X"); + auto *Y = context.Input("Y"); + auto *dOut = context.Input(framework::GradVarName("Out")); + auto *dX = context.Output(framework::GradVarName("X")); + auto *dY = context.Output(framework::GradVarName("Y")); + + auto x = EigenMatrix::From(*X); + auto y = EigenMatrix::From(*Y); + auto dout = EigenMatrix::From(*dOut); + + auto x_dims = X->dims(); + auto y_dims = Y->dims(); + size_t batch_size = x_dims[0]; + size_t x_width = x_dims[1]; + size_t y_width = y_dims[1]; + size_t y_half_width = (y_width - 1) / 2; + + // The below trades code duplication for efficiency (keeping the if + // statement outside of the loop). + if (dX) { + dX->mutable_data(context.GetPlace()); + auto dx = EigenMatrix::From(*dX); + dx.setZero(); + for (size_t k = 0; k < batch_size; ++k) { + for (size_t i = 0; i < x_width; ++i) { + for (size_t j = 0; j < y_width; ++j) { + int index = (i + j - y_half_width + x_width) % x_width; + dx(k, index) += dout(k, i) * y(k, j); + } + } + } + } + + if (dY) { + dY->mutable_data(context.GetPlace()); + auto dy = EigenMatrix::From(*dY); + dy.setZero(); + for (size_t k = 0; k < batch_size; ++k) { + for (size_t i = 0; i < x_width; ++i) { + for (size_t j = 0; j < y_width; ++j) { + int index = (i + j - y_half_width + x_width) % x_width; + dy(k, j) += x(k, index) * dout(k, i); + } + } + } + } + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(conv_shift, ops::ConvShiftOp, ops::ConvShiftOpMaker, + conv_shift_grad, ops::ConvShiftGradOp); +REGISTER_OP_CPU_KERNEL(conv_shift, + ops::ConvShiftKernel); +REGISTER_OP_CPU_KERNEL( + conv_shift_grad, + ops::ConvShiftGradKernel); diff --git a/paddle/operators/conv_shift_op.cu b/paddle/operators/conv_shift_op.cu new file mode 100644 index 0000000000..145e966fe9 --- /dev/null +++ b/paddle/operators/conv_shift_op.cu @@ -0,0 +1,194 @@ +/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
*/ + +#include "paddle/operators/conv_shift_op.h" +#include "paddle/platform/cuda_helper.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +namespace { + +inline int div_up(int x, int y) { return (x + y - 1) / y; } + +// Some notes on the design: +// +// Each thread is responsible for computing a single output out[k, i]. +// Thread blocks are based on tiles of x with height 1 in the batch dimension. +// +// This design is based on the typical use case where the filter +// y is fairly small. For large y, it would probably be more efficient +// to also tile across y. +template +__global__ void conv_shift_forward(const T *x, const T *y, T *out, int x_width, + int y_width, int y_half_width, + int batch_size) { + extern __shared__ T mem[]; + + int tx = threadIdx.x; + int i = blockIdx.x * blockDim.x + tx; // global x index + int k = blockIdx.y; // batch index + + // Check if we are in a boundary block with fewer x's to process than + // blockDim.x. + int num_x = + (blockIdx.x == gridDim.x - 1) ? (x_width % blockDim.x) : blockDim.x; + + T *sx = mem; + T *sx_pad = &mem[num_x]; + T *sy = &mem[blockDim.x + y_width]; + + // Collaboratively load y[k, :] and length-y padding of x into shared memory. + int pad_start = blockIdx.x * blockDim.x + num_x + x_width - y_half_width; + for (int j = tx; j < y_width; j += blockDim.x) { + sy[j] = y[k * y_width + j]; + sx_pad[j] = x[k * x_width + (pad_start + j) % x_width]; + } + + // Load a cyclically shifted slice of x into shared memory. + if (tx < num_x) { + int load_i = (i - y_half_width + x_width) % x_width; + sx[tx] = x[k * x_width + load_i]; + } else { + return; + } + __syncthreads(); + + // Compute dot product of sx[tx:tx + y_width] and sy. + T sum = 0; + for (int j = 0; j < y_width; ++j) { + sum += sx[tx + j] * sy[j]; + } + + // Save to out[k, i]. + out[k * x_width + i] = sum; +} + +// Compute x gradient - initial naive implementation with atomic add. +template +__global__ void conv_shift_dx(const T *dout, const T *y, T *dx, int x_width, + int y_width, int y_half_width, int batch_size) { + int i = blockIdx.x * blockDim.x + threadIdx.x; // x index + int j = blockIdx.y; // y index + int k = blockIdx.z; // batch index + + if (i < x_width) { + int index = (i + j - y_half_width + x_width) % x_width; + atomicAdd(&dx[k * x_width + index], + dout[k * x_width + i] * y[k * y_width + j]); + } +} + +// Compute y gradient - initial naive implementation with atomic add. 
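+// Each thread (i, j, k) adds x[k, (i + j - y_half_width) mod x_width] * +// dout[k, i] into dy[k, j]; the atomicAdd is needed because every column i +// of a batch row contributes to the same dy[k, j].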
+template +__global__ void conv_shift_dy(const T *x, const T *dout, T *dy, int x_width, + int y_width, int y_half_width, int batch_size) { + int i = blockIdx.x * blockDim.x + threadIdx.x; // x index + int j = blockIdx.y; // y index + int k = blockIdx.z; // batch index + + if (i < x_width) { + int index = (i + j - y_half_width + x_width) % x_width; + atomicAdd(&dy[k * y_width + j], + x[k * x_width + index] * dout[k * x_width + i]); + } +} +} // namespace + +template +class ConvShiftKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + const Tensor *X = context.Input("X"); + const Tensor *Y = context.Input("Y"); + Tensor *Out = context.Output("Out"); + const T *x_data = X->data(); + const T *y_data = Y->data(); + T *out_data = Out->mutable_data(context.GetPlace()); + + int batch_size = X->dims()[0]; + int x_width = X->dims()[1]; + int y_width = Y->dims()[1]; + int y_half_width = (y_width - 1) / 2; + + const int x_per_block = 256; + int num_x_blocks = div_up(x_width, x_per_block); + int mem_per_block = (x_per_block + 2 * y_width) * sizeof(T); + + dim3 grid_dim(num_x_blocks, batch_size); + + auto stream = reinterpret_cast( + context.device_context()) + .stream(); + + conv_shift_forward<<>>( + x_data, y_data, out_data, x_width, y_width, y_half_width, batch_size); + } +}; + +template +class ConvShiftGradKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + const Tensor *X = context.Input("X"); + const Tensor *Y = context.Input("Y"); + const Tensor *dOut = context.Input(framework::GradVarName("Out")); + const T *x_data = X->data(); + const T *y_data = Y->data(); + const T *dout_data = dOut->data(); + + Tensor *dX = context.Output(framework::GradVarName("X")); + Tensor *dY = context.Output(framework::GradVarName("Y")); + + int batch_size = X->dims()[0]; + int x_width = X->dims()[1]; + int y_width = Y->dims()[1]; + int y_half_width = (y_width - 1) / 2; + + auto stream = reinterpret_cast( + context.device_context()) + .stream(); + + const int x_per_block = 256; + int num_x_blocks = div_up(x_width, x_per_block); + dim3 grid_dim(num_x_blocks, y_width, batch_size); + + if (dX) { + T *dx_data = dX->mutable_data(context.GetPlace()); + cudaMemsetAsync(dx_data, 0, dX->numel() * sizeof(T), stream); + conv_shift_dx<<>>( + dout_data, y_data, dx_data, x_width, y_width, y_half_width, + batch_size); + } + if (dY) { + T *dy_data = dY->mutable_data(context.GetPlace()); + cudaMemsetAsync(dy_data, 0, dY->numel() * sizeof(T), stream); + conv_shift_dy<<>>( + x_data, dout_data, dy_data, x_width, y_width, y_half_width, + batch_size); + } + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(conv_shift, + ops::ConvShiftKernel); +REGISTER_OP_GPU_KERNEL( + conv_shift_grad, + ops::ConvShiftGradKernel); diff --git a/paddle/operators/conv_shift_op.h b/paddle/operators/conv_shift_op.h new file mode 100644 index 0000000000..5a160b0f16 --- /dev/null +++ b/paddle/operators/conv_shift_op.h @@ -0,0 +1,33 @@ +/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class ConvShiftKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override; +}; + +template +class ConvShiftGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override; +}; +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/cos_sim_op.cc b/paddle/operators/cos_sim_op.cc index 040546f1a6..2b4c4b9c45 100644 --- a/paddle/operators/cos_sim_op.cc +++ b/paddle/operators/cos_sim_op.cc @@ -24,7 +24,7 @@ class CosSimOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { // notnull check PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of CosSimOp should not be null."); @@ -98,7 +98,7 @@ class CosSimOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { // notnull check PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null."); PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) must not be null."); diff --git a/paddle/operators/crop_op.cc b/paddle/operators/crop_op.cc index 9b2305e90e..a1424993cc 100644 --- a/paddle/operators/crop_op.cc +++ b/paddle/operators/crop_op.cc @@ -25,7 +25,7 @@ class CropOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of CropOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), @@ -115,7 +115,7 @@ class CropOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), "Input(Out@GRAD) should not be null"); diff --git a/paddle/operators/cross_entropy_op.cc b/paddle/operators/cross_entropy_op.cc index 4b67887f36..708e80e96a 100644 --- a/paddle/operators/cross_entropy_op.cc +++ b/paddle/operators/cross_entropy_op.cc @@ -22,7 +22,7 @@ class CrossEntropyOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null."); PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be 
not null."); PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should be not null."); @@ -60,7 +60,7 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null."); PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null."); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), diff --git a/paddle/operators/dropout_op.cc b/paddle/operators/dropout_op.cc index a669b5cf00..708ccfa0bf 100644 --- a/paddle/operators/dropout_op.cc +++ b/paddle/operators/dropout_op.cc @@ -24,7 +24,7 @@ class DropoutOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null."); PADDLE_ENFORCE_GE(ctx->Attrs().Get("dropout_prob"), 0); PADDLE_ENFORCE_LE(ctx->Attrs().Get("dropout_prob"), 1); @@ -70,7 +70,7 @@ class DropoutOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE_EQ(ctx->Attrs().Get("is_training"), 1, "GradOp is only callable when is_training is true"); diff --git a/paddle/operators/dynamic_recurrent_op.cc b/paddle/operators/dynamic_recurrent_op.cc new file mode 100644 index 0000000000..b919aef8fb --- /dev/null +++ b/paddle/operators/dynamic_recurrent_op.cc @@ -0,0 +1,276 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve . + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
*/ + +#include "paddle/operators/dynamic_recurrent_op.h" + +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using framework::Scope; +using framework::TensorArray; +using framework::LoDTensor; +using framework::Variable; + +namespace detail { + +inline void CreateVariables(Scope& scope, + const std::vector& var_names) { + for (const auto& name : var_names) { + scope.NewVar(name); + } +} + +} // namespace detail + +class DynamicRecurrentOpProtoAndCheckerMaker + : public framework::OpProtoAndCheckerMaker { + public: + DynamicRecurrentOpProtoAndCheckerMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + const auto& name = DynamicRecurrentOp::kArgName; + // inputs and outputs stored in proto + AddInput(name.inlinks, + "the inputs that need to be segmented for each step.") + .AsDuplicable(); + AddInput(name.boot_memories, "variables to initialize memories.") + .AsDuplicable(); + + AddOutput(name.outlinks, + "the outputs that need to be concatenated for all steps.") + .AsDuplicable(); + AddOutput(name.step_scopes, "step scopes"); + + // Attributes stored in AttributeMap + AddAttr>(name.pre_memories, + "names of pre-memories"); + AddAttr>(name.memories, "names of memories"); + + AddComment("This is an RNN operator for variable-length sequences."); + } +}; + +void DynamicRecurrentOp::Run(const Scope& scope, + const platform::DeviceContext& dev_ctx) const { + cache_.Init(kArgName, *this, scope, &arg_); + SplitInputs(); + CreateScopes(); + WriteStepInputs(); + InitStates(); + + // call stepnet in all the time steps + for (size_t step = 0; step < cache_.num_steps; step++) { + auto& step_scope = cache_.GetScope(step); + stepnet_->Run(step_scope, dev_ctx); + } + + WriteStepOutputs(); + ConcatOutputs(); +} + +void DynamicRecurrentOp::SplitInputs() const { + // TODO(superjom) make level a config + // TODO(superjom) check that all the inputs have the same LoD + int level = 0; + const auto& inlinks = cache_.inlinks; + for (const auto& item : inlinks) { + const auto& var = item.second; + const auto& tensor = var->Get(); + TensorArray& ta = step_inputs_[item.first]; + dy_seq_metas_[item.first] = + ta.Unpack(tensor, level, true /*length_descend*/); + + if (cache_.num_steps) { + PADDLE_ENFORCE_EQ(ta.size(), cache_.num_steps, + "inputs should have the same number of steps"); + } else { + cache_.num_steps = ta.size(); + } + } +} + +void DynamicRecurrentOp::WriteStepInputs() const { + for (const auto& item : cache_.inlinks) { + auto ta_it = step_inputs_.find(item.first); + PADDLE_ENFORCE(ta_it != step_inputs_.end(), + "step_inputs_ not compatible with memory set"); + TensorArray& ta = ta_it->second; + for (size_t step = 0; step < ta.size(); step++) { + auto tensor = ta.Read(step); + auto& step_scope = cache_.GetScope(step); + Variable* var = step_scope.FindVar(item.first); + if (var == nullptr) { + var = step_scope.NewVar(item.first); + } + var->GetMutable()->ShareDataWith(tensor); + } + } +} + +void DynamicRecurrentOp::WriteStepOutputs() const { + for (size_t step = 0; step < cache_.scopes->size(); step++) { + auto& scope = cache_.GetScope(step); + for (auto& item : step_outputs_) { + auto* var = scope.FindVar(item.first); + if (var == nullptr) { + var = scope.NewVar(item.first); + } + auto* tensor = var->GetMutable(); + item.second.WriteShared(step, *tensor); + } + } +} + +void DynamicRecurrentOp::CreateScopes() const { + PADDLE_ENFORCE_GT(cache_.num_steps, 0); + // resize scopes + size_t num_scopes_need_create = 
cache_.num_steps - cache_.scopes->size(); + for (size_t i = 0; i < num_scopes_need_create; i++) { + cache_.scopes->emplace_back(&cache_.scope->NewScope()); + } + + // init temporary inputs + PADDLE_ENFORCE_NOT_NULL(stepnet_, "stepnet should be set first"); + std::vector memories; + std::vector pre_memories; + std::transform(arg_.memories.begin(), arg_.memories.end(), + std::back_inserter(memories), + [](const rnn::MemoryAttr& m) { return m.var; }); + std::transform(arg_.memories.begin(), arg_.memories.end(), + std::back_inserter(pre_memories), + [](const rnn::MemoryAttr& m) { return m.pre_var; }); + + for (size_t step = 0; step < cache_.num_steps; step++) { + auto& scope = cache_.GetScope(step); + detail::CreateVariables(scope, arg_.inlinks); + detail::CreateVariables(scope, arg_.outlinks); + detail::CreateVariables(scope, memories); + detail::CreateVariables(scope, pre_memories); + } +} + +void DynamicRecurrentOp::ConcatOutputs() const { + // TODO(superjom) transform this to a config + int level = 0; + // TODO(superjom) pass in some lod + // just a placeholder + framework::LoD lod; + for (auto& item : step_outputs_) { + auto tensor = item.second.Pack(level, dy_seq_metas_[item.first], lod); + auto& output = cache_.outlinks[item.first]->Get(); + const_cast(&output)->ShareDataWith(tensor); + } +} + +void DynamicRecurrentOp::InitStates() const { + // init the first state + // TODO(superjom) parepare the scenerio that boot state not exists + for (auto memory : arg_.memories) { + auto* boot_state_var = cache_.scope->FindVar(memory.boot_var); + PADDLE_ENFORCE_NOT_NULL(boot_state_var); + auto& boot_state = boot_state_var->Get(); + const auto& dims = boot_state.dims(); + + for (size_t step = 0; step < cache_.num_steps; step++) { + auto& cur_scope = cache_.GetScope(step); + // link pre-state to boot_state + // init state and pre-state + auto* pre_state = cur_scope.FindVar(memory.pre_var); + PADDLE_ENFORCE_NOT_NULL(pre_state); + pre_state->GetMutable(); + + auto* state = cur_scope.FindVar(memory.var); + PADDLE_ENFORCE_NOT_NULL(state); + state->GetMutable()->Resize(dims); + state->GetMutable()->mutable_data( + platform::CPUPlace()); + + if (step == 0) { + auto* pre_state_tensor = pre_state->GetMutable(); + pre_state_tensor->Resize(boot_state.dims()); + pre_state_tensor->ShareDataWith(boot_state); + } else { + auto& pre_scope = cache_.GetScope(step - 1); + auto* state_pre = pre_scope.FindVar(memory.var); + PADDLE_ENFORCE_NOT_NULL(state_pre); + pre_state->GetMutable()->ShareDataWith( + *state_pre->GetMutable()); + } + } + } +} + +void DynamicRecurrentOp::ArgCache::Init( + const rnn::ArgumentName& name, const paddle::framework::OperatorBase& op, + const paddle::framework::Scope& scope, rnn::Argument* arg) { + this->scope = &scope; + InitArgument(name, op, arg); + CacheScopes(scope, *arg); + CacheInlinks(scope, arg->inlinks); + CacheOutlinks(scope, arg->outlinks); +} + +void DynamicRecurrentOp::ArgCache::InitArgument(const rnn::ArgumentName& name, + const OperatorBase& op, + rnn::Argument* arg) { + rnn::InitArgument(name, arg, op, false /*is_grad*/); +} + +void DynamicRecurrentOp::ArgCache::CacheScopes(const Scope& scope, + const rnn::Argument& arg) { + auto scopes_var = scope.FindVar(arg.step_scopes); + PADDLE_ENFORCE(scopes_var != nullptr, + "the step_scopes output argument [%s] should be created first " + "by framework.", + arg.step_scopes); + this->scopes = scopes_var->GetMutable>(); +} + +void DynamicRecurrentOp::ArgCache::CacheInlinks( + const Scope& scope, const std::vector& names) { + for (auto 
name : names) { + auto* var = GetVariable(scope, name); + inlinks[name] = var; + } +} + +void DynamicRecurrentOp::ArgCache::CacheOutlinks( + const Scope& scope, const std::vector& names) { + for (auto name : names) { + auto* var = GetVariable(scope, name); + outlinks[name] = var; + } +} + +Variable* DynamicRecurrentOp::ArgCache::GetVariable(const Scope& scope, + const std::string& name) { + auto* var = scope.FindVar(name); + PADDLE_ENFORCE_NOT_NULL(var, "variable [%s] not exist in scope", name); + return var; +} + +const rnn::ArgumentName DynamicRecurrentOp::kArgName{ + "step_net", "step_scopes", "inlinks", "outlinks", + "memories", "pre_memories", "boot_memories"}; + +void DynamicRecurrentGradientOp::Run( + const Scope& scope, const platform::DeviceContext& dev_ctx) const {} + +} // namespace operators +} // namespace paddle + +REGISTER_OP_WITHOUT_GRADIENT( + dynamic_recurrent, paddle::operators::DynamicRecurrentOp, + paddle::operators::DynamicRecurrentOpProtoAndCheckerMaker); diff --git a/paddle/operators/dynamic_recurrent_op.h b/paddle/operators/dynamic_recurrent_op.h new file mode 100644 index 0000000000..6a2970f27f --- /dev/null +++ b/paddle/operators/dynamic_recurrent_op.h @@ -0,0 +1,158 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#ifdef PADDLE_WITH_TESTING +#include "gtest/gtest.h" +#endif + +#include "paddle/framework/lod_tensor.h" +#include "paddle/framework/operator.h" +#include "paddle/framework/tensor_array.h" +#include "paddle/framework/variable.h" +#include "paddle/operators/rnn/recurrent_op_utils.h" + +namespace paddle { +namespace operators { + +class DynamicRecurrentOp : public framework::OperatorBase { + public: + static const rnn::ArgumentName kArgName; + using value_type = float; + + DynamicRecurrentOp(const std::string& type, + const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + DynamicRecurrentOp(const DynamicRecurrentOp& o) + : framework::OperatorBase( + static_cast(o)) { + // TODO(yuyang18): Implement copy ctor well. + PADDLE_THROW("Not implemented"); + } + + void Run(const framework::Scope& scope, + const platform::DeviceContext& dev_ctx) const override; + + /* + * Split the inputs(LoDTensors) to segments for each time step. + */ + void SplitInputs() const; + + /* + * Create step-scopes to store temporary outputs in each time steps. + */ + void CreateScopes() const; + + /* + * Link TensorArray steps to the corresponding variables located in + * step-scopes. + */ + void WriteStepInputs() const; + + /* + * Write output of each step to the corresponding TensorArray. + */ + void WriteStepOutputs() const; + + /* + * Initialize the states, each state will have a corresponding pre-state, + * which share the memory with the state in the previous time state. 
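+ * Concretely, the pre-state at step t aliases the state at step t - 1 via + * ShareDataWith rather than copying it.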
The + * pre-state in the first time step will be initialized with an zero tensor or + * a tensor in parent scope if is provided. + */ + void InitStates() const; + + /* + * Concatenate outputs in each time step and generate a LoDTensor. + */ + void ConcatOutputs() const; + + /* + * set a stepnet that is created according to a RecurrentOp's stepnet. + */ + void SetStepNet(std::unique_ptr net) { + PADDLE_ENFORCE_NOT_NULL(net); + stepnet_ = std::move(net); + } + const OperatorBase& GetStepNet() const { return *stepnet_; } + + protected: + struct ArgCache { + framework::Scope const* scope; + std::vector* scopes; + std::map inlinks; + std::map outlinks; + + size_t num_steps{0}; + + void Init(const rnn::ArgumentName& name, const OperatorBase& op, + const framework::Scope& scope, rnn::Argument* arg); + + framework::Scope& GetScope(size_t index) { + PADDLE_ENFORCE_LT(index, num_steps); + return *scopes->at(index); + } + + private: + void InitArgument(const rnn::ArgumentName& name, const OperatorBase& op, + rnn::Argument* arg); + void CacheScopes(const framework::Scope& scope, const rnn::Argument& arg); + void CacheInlinks(const framework::Scope& scope, + const std::vector& names); + void CacheOutlinks(const framework::Scope& scope, + const std::vector& names); + framework::Variable* GetVariable(const framework::Scope& scope, + const std::string& name); + }; + + private: + std::unique_ptr stepnet_; + mutable framework::TensorArray states_; + mutable std::map step_inputs_; + mutable std::map step_outputs_; + mutable std::map> + dy_seq_metas_; + mutable rnn::Argument arg_; + mutable ArgCache cache_; + +#ifdef PADDLE_WITH_TESTING + friend class DynamicRecurrentOpTestHelper; + FRIEND_TEST(DynamicRecurrentOpTestHelper, SplitInputs); + FRIEND_TEST(DynamicRecurrentOpTestHelper, CreateCache); + FRIEND_TEST(DynamicRecurrentOpTestHelper, CreateScopes); + FRIEND_TEST(DynamicRecurrentOpTestHelper, WriteStepInputs); + FRIEND_TEST(DynamicRecurrentOpTestHelper, WriteStepOutputs); + FRIEND_TEST(DynamicRecurrentOpTestHelper, InitStates); + FRIEND_TEST(DynamicRecurrentOpTestHelper, ConcatOutputs); +#endif +}; + +class DynamicRecurrentGradientOp : public framework::OperatorBase { + public: + DynamicRecurrentGradientOp(const std::string& type, + const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + void Run(const framework::Scope& scope, + const platform::DeviceContext& dev_ctx) const override; +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/dynamic_recurrent_op_test.cc b/paddle/operators/dynamic_recurrent_op_test.cc new file mode 100644 index 0000000000..675a7890f3 --- /dev/null +++ b/paddle/operators/dynamic_recurrent_op_test.cc @@ -0,0 +1,222 @@ +#include "paddle/operators/dynamic_recurrent_op.h" + +#include + +#include "paddle/framework/ddim.h" +#include "paddle/framework/lod_tensor.h" +#include "paddle/framework/op_desc.h" +#include "paddle/framework/op_registry.h" +#include "paddle/operators/net_op.h" + +namespace paddle { +namespace operators { + +using framework::Scope; +using framework::TensorArray; +using framework::LoDTensor; +using framework::Variable; + +class TestOp : public framework::OperatorBase { + public: + using framework::OperatorBase::OperatorBase; + DEFINE_OP_CLONE_METHOD(TestOp); + void Run(const Scope& scope, + const platform::DeviceContext& dev_ctx) const override {} +}; + +void OpDescNewVar(const std::string& param_name, + 
std::initializer_list arguments, + paddle::framework::OpDesc::Var* var) { + var->set_parameter(param_name); + for (auto& arg_name : arguments) { + var->add_arguments(arg_name); + } +} + +// create a LoD tensor in scope with specific dims +LoDTensor* CreateVar(Scope& scope, std::string name, framework::DDim dims, + const platform::Place& place) { + auto* var = scope.NewVar(name); + auto* tensor = var->GetMutable(); + tensor->Resize(dims); + tensor->mutable_data(place); + return tensor; +} + +class DynamicRecurrentOpTestHelper : public ::testing::Test { + protected: + const rnn::ArgumentName argname = DynamicRecurrentOp::kArgName; + + virtual void SetUp() override { + CreateGlobalVariables(); + + auto op_desc = CreateOpDesc(); + op = paddle::framework::OpRegistry::CreateOp(op_desc); + dop = dynamic_cast(op.get()); + InitCacheManually(); + InitStepNet(); + } + + framework::OpDesc CreateOpDesc() { + // create op + paddle::framework::OpDesc op_desc; + op_desc.set_type("dynamic_recurrent"); + + OpDescNewVar(argname.inlinks, {"in0"}, op_desc.add_inputs()); + OpDescNewVar(argname.boot_memories, {"boot_mem"}, op_desc.add_inputs()); + OpDescNewVar(argname.step_scopes, {"step_scopes"}, op_desc.add_outputs()); + OpDescNewVar(argname.outlinks, {"out0"}, op_desc.add_outputs()); + + // set pre-memories + auto pre_memories = op_desc.mutable_attrs()->Add(); + pre_memories->set_name(argname.pre_memories); + pre_memories->set_type(paddle::framework::AttrType::STRINGS); + auto pre_memories_item = pre_memories->add_strings(); + *pre_memories_item = "mem@pre"; + + // set memories + auto memories = op_desc.mutable_attrs()->Add(); + memories->set_name(argname.memories); + memories->set_type(paddle::framework::AttrType::STRINGS); + auto memories_item = memories->add_strings(); + *memories_item = "mem"; + return op_desc; + } + + void CreateGlobalVariables() { + platform::CPUPlace place; + scope.NewVar("step_scopes"); + CreateVar(scope, "boot_mem", framework::make_ddim({10, 20}), place); + // auto* out0 = + CreateVar(scope, "out0", framework::make_ddim({10, 20}), place); + auto* in0 = CreateVar(scope, "in0", framework::make_ddim({10, 8}), place); + // 10 instances in 4 sentences, of lengths 4, 3, 2, 1 respectively. 
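+ // The offsets {0, 4, 7, 9, 10} built next mark the sentence boundaries: + // rows [0, 4), [4, 7), [7, 9) and [9, 10) of in0 form the four sequences.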
+    framework::LoD in0_lod(1);
+    for (int x : std::vector<int>{0, 4, 7, 9, 10}) {
+      in0_lod[0].push_back(x);
+    }
+    in0->set_lod(in0_lod);
+    in0->Resize(framework::make_ddim({10, 8}));
+    // Set the content: every element of a sequence holds the value
+    // seqid + 0.1 * batchid, with seqid starting from 0.
+    int start = 0;
+    for (size_t seqid = 0; seqid < in0_lod[0].size() - 1; seqid++) {
+      for (size_t batchid = 0;
+           batchid < in0_lod[0][seqid + 1] - in0_lod[0][seqid]; batchid++) {
+        float v = seqid + batchid * 0.1;
+
+        for (size_t dim = 0; dim < 8; dim++) {
+          in0->data<float>()[start * 8 + dim] = v;
+        }
+        start++;
+      }
+    }
+  }
+
+  void InitCacheManually() {
+    dop->cache_.Init(DynamicRecurrentOp::kArgName, *dop, scope, &dop->arg_);
+  }
+
+  void InitStepNet() {
+    std::unique_ptr<framework::OperatorBase> stepnet{new NetOp};
+    dynamic_cast<NetOp*>(stepnet.get())
+        ->AppendOp(std::unique_ptr<TestOp>(new TestOp(
+            "test", {{"inlinks", {"in0"}}, {"boot_memories", {"boot_mem"}}},
+            {{"outlinks", {"out0"}}, {"step_scopes", {"step_scopes"}}}, {})));
+    dop->SetStepNet(std::move(stepnet));
+  }
+
+ protected:
+  DynamicRecurrentOp* dop;
+  std::unique_ptr<framework::OperatorBase> op;
+  paddle::platform::CPUDeviceContext device_context;
+  paddle::framework::Scope scope;
+};
+
+TEST_F(DynamicRecurrentOpTestHelper, CreateCache) {
+  const rnn::Argument& arg = dop->arg_;
+  ASSERT_EQ(arg.inlinks.size(), 1UL);
+  ASSERT_EQ(arg.outlinks.size(), 1UL);
+}
+
+TEST_F(DynamicRecurrentOpTestHelper, SplitInputs) {
+  dop->SplitInputs();
+  auto& in0_ta = dop->step_inputs_["in0"];
+  ASSERT_EQ(in0_ta.size(), 4UL);
+
+  const auto& batch0 = in0_ta.Read(0);
+  const auto& batch1 = in0_ta.Read(1);
+  const auto& batch2 = in0_ta.Read(2);
+  const auto& batch3 = in0_ta.Read(3);
+  EXPECT_EQ(batch0.dims()[0], 4);
+  EXPECT_EQ(batch1.dims()[0], 3);
+  EXPECT_EQ(batch2.dims()[0], 2);
+  EXPECT_EQ(batch3.dims()[0], 1);
+}
+
+TEST_F(DynamicRecurrentOpTestHelper, CreateScopes) {
+  dop->SplitInputs();
+  dop->CreateScopes();
+  ASSERT_EQ(dop->cache_.num_steps, 4UL);
+  ASSERT_EQ(dop->cache_.scopes->size(), 4UL);
+}
+
+TEST_F(DynamicRecurrentOpTestHelper, WriteStepInputs) {
+  dop->SplitInputs();
+  dop->CreateScopes();
+  dop->WriteStepInputs();
+
+  for (size_t step = 0; step < dop->cache_.num_steps; step++) {
+    auto& scope = dop->cache_.GetScope(step);
+    for (auto name : std::vector<std::string>({"in0"})) {
+      ASSERT_TRUE(scope.FindVar(name) != nullptr);
+    }
+  }
+}
+
+TEST_F(DynamicRecurrentOpTestHelper, WriteStepOutputs) {
+  dop->SplitInputs();
+  dop->CreateScopes();
+  dop->WriteStepInputs();
+  dop->WriteStepOutputs();
+
+  for (size_t step = 0; step < dop->cache_.num_steps; step++) {
+    auto& scope = dop->cache_.GetScope(step);
+    for (auto name : std::vector<std::string>({"out0"})) {
+      ASSERT_TRUE(scope.FindVar(name));
+    }
+  }
+}
+
+TEST_F(DynamicRecurrentOpTestHelper, ConcatOutputs) {
+  // This test is left to the Python unit tests.
+}
+
+TEST_F(DynamicRecurrentOpTestHelper, InitStates) {
+  dop->SplitInputs();
+  dop->CreateScopes();
+  dop->WriteStepInputs();
+  dop->WriteStepOutputs();
+  dop->InitStates();
+
+  for (size_t step = 0; step < dop->cache_.num_steps; step++) {
+    auto& scope = dop->cache_.GetScope(step);
+    auto* state = scope.FindVar("mem");
+    ASSERT_TRUE(state != nullptr);
+
+    auto* pre_state = scope.FindVar("mem@pre");
+    ASSERT_TRUE(pre_state != nullptr);
+
+    auto* boot_state = scope.FindVar("boot_mem");
+    ASSERT_TRUE(boot_state != nullptr);
+
+    if (step == 0) {
+      // check that pre_state references boot_state
+      ASSERT_EQ(boot_state->Get<LoDTensor>().data<float>(),
+                pre_state->Get<LoDTensor>().data<float>());
+    }
+  }
+}
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/elementwise_op.h b/paddle/operators/elementwise_op.h index 3082f37422..66f1910a47 100644 --- a/paddle/operators/elementwise_op.h +++ b/paddle/operators/elementwise_op.h
@@ -25,7 +25,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
 protected:
 using Tensor = framework::Tensor;
- void InferShape(framework::InferShapeContextBase* ctx) const override {
+ void InferShape(framework::InferShapeContext* ctx) const override {
 PADDLE_ENFORCE(ctx->HasInput("X"),
 "Input(X) of elementwise op should not be null");
 PADDLE_ENFORCE(ctx->HasInput("Y"),
@@ -106,7 +106,7 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
 using Tensor = framework::Tensor;
 protected:
- void InferShape(framework::InferShapeContextBase* ctx) const override {
+ void InferShape(framework::InferShapeContext* ctx) const override {
 PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
 PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
 PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
diff --git a/paddle/operators/feed_op.cc b/paddle/operators/feed_op.cc new file mode 100644 index 0000000000..fa325bb282 --- /dev/null +++ b/paddle/operators/feed_op.cc @@ -0,0 +1,59 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
*/ + +#include "paddle/operators/feed_op.h" + +namespace paddle { +namespace operators { + +class FeedOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output should be not null."); + auto& shape = ctx->Attrs().Get>("dims"); + std::vector shape_int64(shape.size(), 0); + std::transform(shape.begin(), shape.end(), shape_int64.begin(), + [](int a) { return static_cast(a); }); + ctx->SetOutputDim("Out", framework::make_ddim(shape_int64)); + // TODO(qijun): need to handle LodTensor later + } + + framework::DataType IndicateDataType( + const framework::ExecutionContext& ctx) const override { + return static_cast(Attr("dataType")); + } +}; + +class FeedOpMaker : public framework::OpProtoAndCheckerMaker { + public: + FeedOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddAttr("dataType", "output data type") + .SetDefault(framework::DataType::FP32); + AddAttr("col", "The col in global feed variable").SetDefault(0); + AddAttr>("dims", "The dimension of feed tensor."); + AddOutput("Out", "The output of feed op."); + AddComment(R"DOC(Feed data from global feed variable)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(feed, ops::FeedOp, ops::FeedOpMaker); +REGISTER_OP_CPU_KERNEL(feed, ops::FeedKernel); diff --git a/paddle/operators/feed_op.cu b/paddle/operators/feed_op.cu new file mode 100644 index 0000000000..7b6a2ac91e --- /dev/null +++ b/paddle/operators/feed_op.cu @@ -0,0 +1,18 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/feed_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(feed, ops::FeedKernel); diff --git a/paddle/operators/feed_op.h b/paddle/operators/feed_op.h new file mode 100644 index 0000000000..9d8158299f --- /dev/null +++ b/paddle/operators/feed_op.h @@ -0,0 +1,42 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class FeedKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + framework::Tensor* out = ctx.Output("Out"); + out->mutable_data(ctx.GetPlace()); + framework::Variable* g_feed_variable = + framework::GetGlobalScope().FindVar("feed_value"); + const auto& tensors = + g_feed_variable->Get>(); + int col = ctx.template Attr("col"); + PADDLE_ENFORCE_GT(tensors.size(), static_cast(col)); + // TODO(qijun): + // check tensors[col].dims() with attribute, + // except the first dimenson. + out->CopyFrom(tensors[col], ctx.GetPlace()); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/fetch_op.cc b/paddle/operators/fetch_op.cc new file mode 100644 index 0000000000..90737c8c55 --- /dev/null +++ b/paddle/operators/fetch_op.cc @@ -0,0 +1,52 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/fetch_op.h" + +namespace paddle { +namespace operators { + +class FetchOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Input"), "Input should be not null."); + } + + framework::DataType IndicateDataType( + const framework::ExecutionContext& ctx) const override { + return static_cast(Attr("dataType")); + } +}; + +class FetchOpMaker : public framework::OpProtoAndCheckerMaker { + public: + FetchOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddAttr("dataType", "output data type") + .SetDefault(framework::DataType::FP32); + AddAttr("col", "The col in global fetch variable").SetDefault(0); + AddInput("Input", "The output of fetch op."); + AddComment(R"DOC(Fetch data to global fetch variable)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(fetch, ops::FetchOp, ops::FetchOpMaker); +REGISTER_OP_CPU_KERNEL(fetch, ops::FetchKernel); diff --git a/paddle/operators/fetch_op.cu b/paddle/operators/fetch_op.cu new file mode 100644 index 0000000000..ca39d24c79 --- /dev/null +++ b/paddle/operators/fetch_op.cu @@ -0,0 +1,18 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/fetch_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(fetch, ops::FetchKernel); diff --git a/paddle/operators/fetch_op.h b/paddle/operators/fetch_op.h new file mode 100644 index 0000000000..eb9c3a7b59 --- /dev/null +++ b/paddle/operators/fetch_op.h @@ -0,0 +1,44 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class FetchKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const framework::Tensor* input = ctx.Input("Input"); + framework::Variable* g_fetch_variable = + framework::GetGlobalScope().FindVar("fetch_value"); + auto* tensors = + g_fetch_variable->GetMutable>(); + int col = ctx.template Attr("col"); + if (tensors->size() < static_cast(col + 1)) { + tensors->resize(col + 1); + } + PADDLE_ENFORCE_GT(tensors->size(), static_cast(col)); + (*tensors)[col].Resize(input->dims()); + (*tensors)[col].mutable_data(platform::CPUPlace()); + (*tensors)[col].CopyFrom(*input, platform::CPUPlace()); + // TODO(qijun): need to handle LodTensor later + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/fill_constant_op.cc b/paddle/operators/fill_constant_op.cc new file mode 100644 index 0000000000..65d03d5fa4 --- /dev/null +++ b/paddle/operators/fill_constant_op.cc @@ -0,0 +1,68 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/operators/fill_constant_op.h" + +namespace paddle { +namespace operators { + +class FillConstantOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of FillConstantOp should not be null."); + auto &shape = ctx->Attrs().Get>("shape"); + std::vector shape_int64(shape.size(), 0); + std::transform(shape.begin(), shape.end(), shape_int64.begin(), + [](int a) { return static_cast(a); }); + auto dims = framework::make_ddim(shape_int64); + ctx->SetOutputDim("Out", dims); + } + + framework::DataType IndicateDataType( + const framework::ExecutionContext &ctx) const override { + return static_cast(ctx.Attr("dataType")); + } +}; + +class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker { + public: + FillConstantOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : framework::OpProtoAndCheckerMaker(proto, op_checker) { + AddAttr("dataType", + "(int, default 5 (FP32)) " + "Output data type") + .SetDefault(framework::DataType::FP32); + AddAttr>("shape", "(vector) The shape of the output"); + AddAttr("value", "(float, default 0) The value to be filled") + .SetDefault(0.0f); + AddOutput("Out", + "(Tensor) Tensor of specified shape will be filled " + "with the specified value"); + AddComment(R"DOC(Fill up a variable with specified constant value.)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(fill_constant, ops::FillConstantOp, + ops::FillConstantOpMaker); +REGISTER_OP_CPU_KERNEL( + fill_constant, + ops::FillConstantOpKernel); diff --git a/paddle/operators/fill_constant_op.cu b/paddle/operators/fill_constant_op.cu new file mode 100644 index 0000000000..eef8fcbd7f --- /dev/null +++ b/paddle/operators/fill_constant_op.cu @@ -0,0 +1,22 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/framework/op_registry.h" +#include "paddle/operators/fill_constant_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL( + fill_constant, + ops::FillConstantOpKernel); diff --git a/paddle/operators/fill_constant_op.h b/paddle/operators/fill_constant_op.h new file mode 100644 index 0000000000..53b8b548ec --- /dev/null +++ b/paddle/operators/fill_constant_op.h @@ -0,0 +1,37 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class FillConstantOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* out = ctx.Output("Out"); + out->mutable_data(ctx.GetPlace()); + auto value = ctx.Attr("value"); + + auto out_eigen = framework::EigenVector::Flatten(*out); + auto place = ctx.GetEigenDevice(); + out_eigen.device(place) = out_eigen.constant(static_cast(value)); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/fill_zeros_like_op.cc b/paddle/operators/fill_zeros_like_op.cc index e164de6584..4c70b9a36b 100644 --- a/paddle/operators/fill_zeros_like_op.cc +++ b/paddle/operators/fill_zeros_like_op.cc @@ -22,7 +22,7 @@ class FillZerosLikeOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of FillZerosLikeOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Y"), diff --git a/paddle/operators/gather_op.cc b/paddle/operators/gather_op.cc index fe305337cb..fb99c6c016 100644 --- a/paddle/operators/gather_op.cc +++ b/paddle/operators/gather_op.cc @@ -23,7 +23,7 @@ class GatherOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of GatherOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Index"), @@ -51,7 +51,7 @@ class GatherGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); } diff --git a/paddle/operators/gaussian_random_op.cc b/paddle/operators/gaussian_random_op.cc index 5cd2c7d2c0..ca7fb38505 100644 --- a/paddle/operators/gaussian_random_op.cc +++ b/paddle/operators/gaussian_random_op.cc @@ -43,7 +43,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of GaussianRandomOp should not be null."); auto dims = ctx->Attrs().Get>("dims"); diff --git a/paddle/operators/interp_op.cc b/paddle/operators/interp_op.cc new file mode 100644 index 0000000000..d02b01c3f3 --- /dev/null +++ b/paddle/operators/interp_op.cc @@ -0,0 +1,113 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/op_registry.h" +#include "paddle/operators/net_op.h" + +namespace paddle { +namespace operators { + +class InterpOp : public NetOp { + public: + InterpOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : NetOp(type, inputs, outputs, attrs) { + PADDLE_ENFORCE_NE(Input("X"), framework::kEmptyVarName, + "Input(X) of InterpOp should not be null."); + PADDLE_ENFORCE_NE(Input("Y"), framework::kEmptyVarName, + "Input(Y) of InterpOp should not be null."); + PADDLE_ENFORCE_NE(Input("W"), framework::kEmptyVarName, + "Input(W) of InterpOp should not be null."); + PADDLE_ENFORCE_NE(Output("SubOut"), framework::kEmptyVarName, + "Output(SubOut) of InterpOp should not be null."); + PADDLE_ENFORCE_NE(Output("MulOut"), framework::kEmptyVarName, + "Output(MulOut) of InterpOp should not be null."); + PADDLE_ENFORCE_NE(Output("Out"), framework::kEmptyVarName, + "Output(Out) of InterpOp should not be null."); + + // SubOut = X - Y + auto x = Input("X"); + auto y = Input("Y"); + auto sub_out = Output("SubOut"); + AppendOp(framework::OpRegistry::CreateOp( + "elementwise_sub", {{"X", {x}}, {"Y", {y}}}, {{"Out", {sub_out}}}, {})); + + // MulOut = SubOut * W = (X - Y) * W + auto w = Input("W"); + auto mul_out = Output("MulOut"); + AppendOp(framework::OpRegistry::CreateOp( + "elementwise_mul", {{"X", {sub_out}}, {"Y", {w}}}, {{"Out", {mul_out}}}, + {{"axis", 0}})); + + // Out = MulOut + Y = (X - Y) * W + Y = X * W + Y * (1 - W) + AppendOp(framework::OpRegistry::CreateOp("elementwise_add", + {{"X", {mul_out}}, {"Y", {y}}}, + {{"Out", {Output("Out")}}}, {})); + + CompleteAddOp(false); + } +}; + +class InterpOpMaker : public framework::OpProtoAndCheckerMaker { + public: + InterpOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "(Tensor), 2-D Matrix of shape [batch_size, data_dim]" + "containing data samples, the first input of interp_op"); + AddInput("Y", + "(Tensor), 2-D Matrix of shape `[batch_size, data_dim]`" + "containing data samples, the second input of interp_op"); + AddInput("W", + "(Tensor), 1-D Vector of shape [batch_size]," + "the interpolated values in the half-open interval [0.0, 1.0)"); + AddOutput("SubOut", + "(Tensor), the intermediate subtraction outputs, saving X - Y.") + .AsIntermediate(); + AddOutput("MulOut", + "(Tensor), the intermediate multiplication outputs," + "saving the elementwise multiplication of (X - Y) and W.") + .AsIntermediate(); + AddOutput("Out", + "(Tensor), the output of interp_op, same shape with X," + "returns the first-dimensional piecewise linear interpolant " + "between X and Y"); + AddComment(R"DOC( + Linear Interpolation with two inputs, used in NEURAL TURING MACHINE. 
+ + Equation: + Out.row[i] = X.row[i] * W[i] + Y.row[i] * (1 - W[i]) + = (X.row[i] - Y.row[i]) * W[i] + Y.row[i] + + Example: + X = [[1,2],[3,4]], + Y = [[2,1],[4,3]], + W = [0.3, 0.4] + + Then, Out = [[1.7,1.3],[3.6,3.4]] + + where 1.7 = 1*0.3+2*(1-0.3), + 1.3 = 2*0.3+1*(1-0.3), + 3.6 = 3*0.4+4*(1-0.4), + 3.4 = 4*0.4+3*(1-0.4) +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(interp, ops::InterpOp, ops::InterpOpMaker); diff --git a/paddle/operators/lookup_table_op.cc b/paddle/operators/lookup_table_op.cc index 929008fbcb..3f8d4ab857 100644 --- a/paddle/operators/lookup_table_op.cc +++ b/paddle/operators/lookup_table_op.cc @@ -22,7 +22,7 @@ class LookupTableOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) of LookupTableOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Ids"), @@ -70,7 +70,7 @@ class LookupTableOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { auto table_dims = ctx->GetInputDim("W"); ctx->SetOutputDim(framework::GradVarName("W"), table_dims); } diff --git a/paddle/operators/lstm_unit_op.cc b/paddle/operators/lstm_unit_op.cc index dad56731de..13a45ec246 100644 --- a/paddle/operators/lstm_unit_op.cc +++ b/paddle/operators/lstm_unit_op.cc @@ -22,7 +22,7 @@ class LstmUnitOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of LSTM should not be null."); PADDLE_ENFORCE(ctx->HasInput("C_prev"), "Input(C_prev) of LSTM should not be null."); @@ -77,7 +77,7 @@ class LstmUnitGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("C")), "Input(C@GRAD) should not be null"); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("H")), diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index a0ceb029e3..2fd559e90a 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -3,11 +3,14 @@ if(WITH_GPU) nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) nv_library(softmax SRCS softmax.cc softmax.cu DEPS operator) nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator) + nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context) else() cc_library(math_function SRCS math_function.cc im2col.cc pooling.cc DEPS cblas device_context operator) cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) cc_library(softmax SRCS softmax.cc DEPS operator) cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator) + cc_library(vol2col SRCS vol2col.cc DEPS device_context) endif() cc_test(im2col_test SRCS 
im2col_test.cc DEPS math_function tensor) +cc_test(vol2col_test SRCS vol2col_test.cc DEPS vol2col tensor) diff --git a/paddle/operators/math/pooling.cc b/paddle/operators/math/pooling.cc index 3b706529d8..50cfb88bb5 100644 --- a/paddle/operators/math/pooling.cc +++ b/paddle/operators/math/pooling.cc @@ -18,6 +18,11 @@ namespace paddle { namespace operators { namespace math { +/* + * All tensors are in NCHW format. + * Ksize, strides, paddings are two elements. These two elements represent + * height and width, respectively. + */ template class Pool2dFunctor { public: @@ -73,6 +78,11 @@ class Pool2dFunctor { } }; +/* +* All tensors are in NCHW format. +* Ksize, strides, paddings are two elements. These two elements represent height +* and width, respectively. +*/ template class Pool2dGradFunctor { public: @@ -135,6 +145,11 @@ class Pool2dGradFunctor { } }; +/* + * All tensors are in NCHW format. + * Ksize, strides, paddings are two elements. These two elements represent + * height and width, respectively. + */ template class MaxPool2dGradFunctor { public: @@ -197,7 +212,7 @@ class MaxPool2dGradFunctor { }; template class MaxPool2dGradFunctor; -// template class MaxPool2dGradFunctor; +template class MaxPool2dGradFunctor; template class Pool2dFunctor, float>; @@ -216,6 +231,11 @@ template class Pool2dGradFunctor< template class Pool2dGradFunctor< platform::CPUPlace, paddle::operators::math::AvgPoolGrad, double>; +/* + * All tensors are in NCDHW format. + * Ksize, strides, paddings are three elements. These three elements represent + * depth, height and width, respectively. + */ template class Pool3dFunctor { public: @@ -286,6 +306,11 @@ class Pool3dFunctor { } }; +/* + * All tensors are in NCDHW format. + * Ksize, strides, paddings are three elements. These three elements represent + * depth, height and width, respectively. + */ template class Pool3dGradFunctor { public: @@ -364,6 +389,11 @@ class Pool3dGradFunctor { } }; +/* + * All tensors are in NCDHW format. + * Ksize, strides, paddings are three elements. These three elements represent + * depth, height and width, respectively. + */ template class MaxPool3dGradFunctor { public: @@ -440,7 +470,7 @@ class MaxPool3dGradFunctor { }; template class MaxPool3dGradFunctor; -// template class MaxPool3dGradFunctor; +template class MaxPool3dGradFunctor; template class Pool3dFunctor, float>; @@ -458,6 +488,253 @@ template class Pool3dGradFunctor< platform::CPUPlace, paddle::operators::math::MaxPoolGrad, double>; template class Pool3dGradFunctor< platform::CPUPlace, paddle::operators::math::AvgPoolGrad, double>; + +/* + * All tensors are in NCHW format. + * Ksize, strides, paddings are two elements. These two elements represent + * height and width, respectively. 
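+* With the usual pooling convention, each output spatial extent is
+* (input_size + 2 * padding - ksize) / stride + 1.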
+ */ +template +class MaxPool2dWithIndexFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& output, + framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings) { + const int batch_size = input.dims()[0]; + const int input_height = input.dims()[2]; + const int input_width = input.dims()[3]; + const int output_channels = output.dims()[1]; + const int output_height = output.dims()[2]; + const int output_width = output.dims()[3]; + const int ksize_height = ksize[0]; + const int ksize_width = ksize[1]; + const int stride_height = strides[0]; + const int stride_width = strides[1]; + const int padding_height = paddings[0]; + const int padding_width = paddings[1]; + const int input_stride = input_height * input_width; + const int output_stride = output_height * output_width; + + const T* input_data = input.data(); + T* output_data = output.mutable_data(context.GetPlace()); + T* mask_data = mask.mutable_data(context.GetPlace()); + + for (int i = 0; i < batch_size; i++) { + for (int c = 0; c < output_channels; ++c) { + for (int ph = 0; ph < output_height; ++ph) { + int hstart = ph * stride_height - padding_height; + int hend = std::min(hstart + ksize_height, input_height); + hstart = std::max(hstart, 0); + for (int pw = 0; pw < output_width; ++pw) { + int wstart = pw * stride_width - padding_width; + int wend = std::min(wstart + ksize_width, input_width); + wstart = std::max(wstart, 0); + + T ele = static_cast(-FLT_MAX); + int index = -1; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + if (ele < input_data[h * input_width + w]) { + ele = input_data[h * input_width + w]; + index = h * input_width + w; + } + } + } + output_data[ph * output_width + pw] = ele; + mask_data[ph * output_width + pw] = index; + } + } + // offset + input_data += input_stride; + output_data += output_stride; + mask_data += output_stride; + } + } + } +}; + +/* + * All tensors are in NCHW format. + * Ksize, strides, paddings are two elements. These two elements represent + * height and width, respectively. 
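+ * The mask written by the forward functor stores, for each output element,
+ * the flattened index of the maximum input inside its (height, width) slice;
+ * the backward functor simply scatters every output gradient back to that
+ * recorded index.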
+ */ +template +class MaxPool2dWithIndexGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + framework::Tensor& input_grad, + const framework::Tensor& output_grad, + const framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings) { + const int batch_size = input_grad.dims()[0]; + const int input_height = input_grad.dims()[2]; + const int input_width = input_grad.dims()[3]; + const int output_channels = output_grad.dims()[1]; + const int output_height = output_grad.dims()[2]; + const int output_width = output_grad.dims()[3]; + const int input_stride = input_height * input_width; + const int output_stride = output_height * output_width; + + const T* mask_data = mask.data(); + const T* output_grad_data = output_grad.data(); + T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + + for (int n = 0; n < batch_size; ++n) { + for (int c = 0; c < output_channels; ++c) { + for (int ph = 0; ph < output_height; ++ph) { + for (int pw = 0; pw < output_width; ++pw) { + const int output_idx = ph * output_width + pw; + const int input_idx = static_cast(mask_data[output_idx]); + input_grad_data[input_idx] += output_grad_data[output_idx]; + } + } + // offset + input_grad_data += input_stride; + output_grad_data += output_stride; + mask_data += output_stride; + } + } + } +}; + +template class MaxPool2dWithIndexFunctor; +template class MaxPool2dWithIndexGradFunctor; +template class MaxPool2dWithIndexFunctor; +template class MaxPool2dWithIndexGradFunctor; + +/* + * All tensors are in NCDHW format. + * Ksize, strides, paddings are three elements. These three elements represent + * depth, height and width, respectively. + */ +template +class MaxPool3dWithIndexFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& output, + framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings) { + const int batch_size = input.dims()[0]; + const int input_depth = input.dims()[2]; + const int input_height = input.dims()[3]; + const int input_width = input.dims()[4]; + const int output_channels = output.dims()[1]; + const int output_depth = output.dims()[2]; + const int output_height = output.dims()[3]; + const int output_width = output.dims()[4]; + const int ksize_depth = ksize[0]; + const int ksize_height = ksize[1]; + const int ksize_width = ksize[2]; + const int stride_depth = strides[0]; + const int stride_height = strides[1]; + const int stride_width = strides[2]; + const int padding_depth = paddings[0]; + const int padding_height = paddings[1]; + const int padding_width = paddings[2]; + const int input_stride = input_depth * input_height * input_width; + const int output_stride = output_depth * output_height * output_width; + + const T* input_data = input.data(); + T* output_data = output.mutable_data(context.GetPlace()); + T* mask_data = mask.mutable_data(context.GetPlace()); + + for (int i = 0; i < batch_size; i++) { + for (int c = 0; c < output_channels; ++c) { + for (int pd = 0; pd < output_depth; ++pd) { + int dstart = pd * stride_depth - padding_depth; + int dend = std::min(dstart + ksize_depth, input_depth); + dstart = std::max(dstart, 0); + for (int ph = 0; ph < output_height; ++ph) { + int hstart = ph * stride_height - padding_height; + int hend = std::min(hstart + ksize_height, input_height); + hstart = std::max(hstart, 0); + for (int pw = 0; pw < output_width; ++pw) { + int wstart = pw * stride_width - padding_width; + int 
wend = std::min(wstart + ksize_width, input_width); + wstart = std::max(wstart, 0); + + int output_idx = (pd * output_height + ph) * output_width + pw; + T ele = static_cast(-FLT_MAX); + int index = -1; + for (int d = dstart; d < dend; ++d) { + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + int input_idx = (d * input_height + h) * input_width + w; + if (ele < input_data[input_idx]) { + index = input_idx; + ele = input_data[input_idx]; + } + } + } + } + output_data[output_idx] = ele; + mask_data[output_idx] = index; + } + } + } + // offset + input_data += input_stride; + output_data += output_stride; + mask_data += output_stride; + } + } + } +}; + +/* + * All tensors are in NCDHW format. + * Ksize, strides, paddings are three elements. These three elements represent + * depth, height and width, respectively. + */ +template +class MaxPool3dWithIndexGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + framework::Tensor& input_grad, + const framework::Tensor& output_grad, + const framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings) { + const int batch_size = input_grad.dims()[0]; + const int input_depth = input_grad.dims()[2]; + const int input_height = input_grad.dims()[3]; + const int input_width = input_grad.dims()[4]; + const int output_channels = output_grad.dims()[1]; + const int output_depth = output_grad.dims()[2]; + const int output_height = output_grad.dims()[3]; + const int output_width = output_grad.dims()[4]; + const int input_stride = input_depth * input_height * input_width; + const int output_stride = output_depth * output_height * output_width; + + const T* mask_data = mask.data(); + const T* output_grad_data = output_grad.data(); + T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + + for (int n = 0; n < batch_size; ++n) { + for (int c = 0; c < output_channels; ++c) { + for (int pd = 0; pd < output_depth; ++pd) { + for (int ph = 0; ph < output_height; ++ph) { + for (int pw = 0; pw < output_width; ++pw) { + const int output_idx = + (pd * output_height + ph) * output_width + pw; + const int input_idx = static_cast(mask_data[output_idx]); + input_grad_data[input_idx] += output_grad_data[output_idx]; + } + } + } + // offset + input_grad_data += input_stride; + output_grad_data += output_stride; + mask_data += output_stride; + } + } + } +}; + +template class MaxPool3dWithIndexFunctor; +template class MaxPool3dWithIndexGradFunctor; +template class MaxPool3dWithIndexFunctor; +template class MaxPool3dWithIndexGradFunctor; } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/pooling.cu b/paddle/operators/math/pooling.cu index 8aeccd1f8e..736327f4b7 100644 --- a/paddle/operators/math/pooling.cu +++ b/paddle/operators/math/pooling.cu @@ -144,11 +144,16 @@ __global__ void KernelMaxPool2DGrad( if (maxIndex != -1) { // atomic add - atomicAdd(input_grad + maxIndex, output_grad[index]); + platform::CudaAtomicAdd(input_grad + maxIndex, output_grad[index]); } } } +/* + * All tensors are in NCHW format. + * Ksize, strides, paddings are two elements. These two elements represent + * height and width, respectively. + */ template class Pool2dFunctor { public: @@ -190,6 +195,11 @@ class Pool2dFunctor { } }; +/* + * All tensors are in NCHW format. + * Ksize, strides, paddings are two elements. These two elements represent + * height and width, respectively. 
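+ * As with the other CUDA functors in this file, the kernel is launched with
+ * 1024-thread blocks and walks its index range in a grid-stride loop.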
+ */ template class Pool2dGradFunctor { public: @@ -234,6 +244,11 @@ class Pool2dGradFunctor { } }; +/* + * All tensors are in NCHW format. + * Ksize, strides, paddings are two elements. These two elements represent + * height and width, respectively. + */ template class MaxPool2dGradFunctor { public: @@ -278,9 +293,7 @@ class MaxPool2dGradFunctor { }; template class MaxPool2dGradFunctor; -// template class MaxPool2dGradFunctor; // The -// 64-bit floating-point version of atomicAdd() is only supported by devices of -// compute capability 6.x and higher. +template class MaxPool2dGradFunctor; template class Pool2dFunctor, float>; @@ -453,11 +466,16 @@ __global__ void KernelMaxPool3DGrad( } if (maxIdx != -1) { // atomic add - atomicAdd(input_grad + maxIdx, output_grad[index]); + platform::CudaAtomicAdd(input_grad + maxIdx, output_grad[index]); } } } +/* + * All tensors are in NCDHW format. + * Ksize, strides, paddings are three elements. These three elements represent + * depth, height and width, respectively. + */ template class Pool3dFunctor { public: @@ -506,6 +524,11 @@ class Pool3dFunctor { } }; +/* + * All tensors are in NCDHW format. + * Ksize, strides, paddings are three elements. These three elements represent + * depth, height and width, respectively. + */ template class Pool3dGradFunctor { public: @@ -558,6 +581,11 @@ class Pool3dGradFunctor { } }; +/* + * All tensors are in NCDHW format. + * Ksize, strides, paddings are three elements. These three elements represent + * depth, height and width, respectively. + */ template class MaxPool3dGradFunctor { public: @@ -609,9 +637,7 @@ class MaxPool3dGradFunctor { }; template class MaxPool3dGradFunctor; -// template class MaxPool3dGradFunctor; // The -// 64-bit floating-point version of atomicAdd() is only supported by devices of -// compute capability 6.x and higher. 
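The removed comment above no longer applies: switching from the raw atomicAdd to platform::CudaAtomicAdd lifts the compute-capability restriction, which is what allows the double instantiations below to be re-enabled. Presumably the wrapper falls back to the well-known compare-and-swap emulation on devices below compute capability 6.0; a minimal sketch of that technique (not Paddle's actual implementation):

__device__ double AtomicAddDouble(double* address, double val) {
  // Reinterpret the double as a 64-bit integer so atomicCAS can be used.
  unsigned long long int* address_as_ull =
      reinterpret_cast<unsigned long long int*>(address);
  unsigned long long int old = *address_as_ull, assumed;
  do {
    assumed = old;
    old = atomicCAS(address_as_ull, assumed,
                    __double_as_longlong(val + __longlong_as_double(assumed)));
  } while (assumed != old);  // retry if another thread changed the value
  return __longlong_as_double(old);
}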
+template class MaxPool3dGradFunctor; template class Pool3dFunctor, float>; @@ -630,6 +656,404 @@ template class Pool3dGradFunctor< template class Pool3dGradFunctor< platform::GPUPlace, paddle::operators::math::AvgPoolGrad, double>; +template +__global__ void KernelMaxPool2dWithIdx( + const int nthreads, const T* input_data, T* output_data, T* mask_data, + const int channels, const int input_height, const int input_width, + const int output_height, const int output_width, const int ksize_height, + const int ksize_width, const int stride_height, const int stride_width, + const int padding_height, const int padding_width) { + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; + index += blockDim.x * gridDim.x) { + int pw = index % output_width; + int ph = (index / output_width) % output_height; + int c = (index / output_width / output_height) % channels; + int batch_idx = index / output_width / output_height / channels; + + int hstart = ph * stride_height - padding_height; + int hend = min(hstart + ksize_height, input_height); + hstart = max(hstart, 0); + + int wstart = pw * stride_width - padding_width; + int wend = min(wstart + ksize_width, input_width); + wstart = max(wstart, 0); + + input_data += (batch_idx * channels + c) * input_height * input_width; + T ele = -FLT_MAX; + int max_index = -1; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + int input_index = h * input_width + w; + if (ele < input_data[input_index]) { + max_index = input_index; + ele = input_data[input_index]; + } + } + } + output_data[index] = ele; + mask_data[index] = max_index; + } +} + +template +__global__ void KernelMaxPool2DWithIdxGrad( + const int nthreads, T* input_grad, const T* output_grad, const T* mask_data, + const int channels, const int input_height, const int input_width, + const int output_height, const int output_width, const int ksize_height, + const int ksize_width, const int stride_height, const int stride_width, + const int padding_height, const int padding_width) { + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; + index += blockDim.x * gridDim.x) { + int w_offset = index % input_width; + int h_offset = (index / input_width) % input_height; + int c_offset = (index / input_width / input_height) % channels; + int batch_idx = index / input_width / input_height / channels; + + int ph_start = + (h_offset + padding_height < ksize_height) + ? 0 + : (h_offset + padding_height - ksize_height) / stride_height + 1; + int pw_start = + (w_offset + padding_width < ksize_width) + ? 0 + : (w_offset + padding_width - ksize_width) / stride_width + 1; + int ph_end = + min((h_offset + padding_height) / stride_height + 1, output_height); + int pw_end = + min((w_offset + padding_width) / stride_width + 1, output_width); + + T gradient = 0; + int input_current_featuremap_idx = h_offset * input_width + w_offset; + int output_idx = + (batch_idx * channels + c_offset) * output_height * output_width; + + mask_data += output_idx; + output_grad += output_idx; + for (int ph = ph_start; ph < ph_end; ++ph) { + for (int pw = pw_start; pw < pw_end; ++pw) { + if (mask_data[ph * output_width + pw] == input_current_featuremap_idx) + gradient += output_grad[ph * output_width + pw]; + } + } + input_grad[index] = gradient; + } +} + +/* + * All tensors are in NCHW format. + * Ksize, strides, paddings are two elements. These two elements represent + * height and width, respectively. 
+ */ +template +class MaxPool2dWithIndexFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& output, + framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings) { + const int batch_size = input.dims()[0]; + const int input_channels = input.dims()[1]; + const int input_height = input.dims()[2]; + const int input_width = input.dims()[3]; + const int output_channels = output.dims()[1]; + const int output_height = output.dims()[2]; + const int output_width = output.dims()[3]; + const int ksize_height = ksize[0]; + const int ksize_width = ksize[1]; + const int stride_height = strides[0]; + const int stride_width = strides[1]; + const int padding_height = paddings[0]; + const int padding_width = paddings[1]; + + const T* input_data = input.data(); + T* output_data = output.mutable_data(context.GetPlace()); + T* mask_data = mask.mutable_data(context.GetPlace()); + + int nthreads = batch_size * output_channels * output_height * output_width; + int blocks = (nthreads + 1024 - 1) / 1024; + dim3 threads(1024, 1); + dim3 grid(blocks, 1); + + KernelMaxPool2dWithIdx< + T><<(context) + .stream()>>>(nthreads, input_data, output_data, mask_data, + input_channels, input_height, input_width, + output_height, output_width, ksize_height, + ksize_width, stride_height, stride_width, + padding_height, padding_width); + } +}; + +/* + * All tensors are in NCHW format. + * Ksize, strides, paddings are two elements. These two elements represent + * height and width, respectively. + */ +template +class MaxPool2dWithIndexGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + framework::Tensor& input_grad, + const framework::Tensor& output_grad, + const framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings) { + const int batch_size = input_grad.dims()[0]; + const int input_channels = input_grad.dims()[1]; + const int input_height = input_grad.dims()[2]; + const int input_width = input_grad.dims()[3]; + const int output_height = output_grad.dims()[2]; + const int output_width = output_grad.dims()[3]; + const int ksize_height = ksize[0]; + const int ksize_width = ksize[1]; + const int stride_height = strides[0]; + const int stride_width = strides[1]; + const int padding_height = paddings[0]; + const int padding_width = paddings[1]; + + const T* mask_data = mask.data(); + const T* output_grad_data = output_grad.data(); + T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + + int nthreads = batch_size * input_channels * input_height * input_width; + int blocks = (nthreads + 1024 - 1) / 1024; + dim3 threads(1024, 1); + dim3 grid(blocks, 1); + + KernelMaxPool2DWithIdxGrad< + T><<(context) + .stream()>>>(nthreads, input_grad_data, output_grad_data, + mask_data, input_channels, input_height, + input_width, output_height, output_width, + ksize_height, ksize_width, stride_height, + stride_width, padding_height, padding_width); + } +}; + +template class MaxPool2dWithIndexFunctor; +template class MaxPool2dWithIndexGradFunctor; +template class MaxPool2dWithIndexFunctor; +template class MaxPool2dWithIndexGradFunctor; + +template +__global__ void KernelMaxPool3DWithIdx( + const int nthreads, const T* input_data, T* output_data, T* mask_data, + const int channels, const int input_depth, const int input_height, + const int input_width, const int output_depth, const int output_height, + const int output_width, const int ksize_depth, const int 
ksize_height, + const int ksize_width, const int stride_depth, const int stride_height, + const int stride_width, const int padding_depth, const int padding_height, + const int padding_width) { + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; + index += blockDim.x * gridDim.x) { + int pw = index % output_width; + int ph = (index / output_width) % output_height; + int pd = (index / output_width / output_height) % output_depth; + int c = (index / output_width / output_height / output_depth) % channels; + int batch_idx = + index / output_width / output_height / output_depth / channels; + + int dstart = pd * stride_depth - padding_depth; + int hstart = ph * stride_height - padding_height; + int wstart = pw * stride_width - padding_width; + int dend = min(dstart + ksize_depth, input_depth); + int hend = min(hstart + ksize_height, input_height); + int wend = min(wstart + ksize_width, input_width); + dstart = max(dstart, 0); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + + T ele = -FLT_MAX; + int max_index = -1; + input_data += + (batch_idx * channels + c) * input_depth * input_height * input_width; + + for (int d = dstart; d < dend; ++d) { + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + if (ele < input_data[(d * input_height + h) * input_width + w]) { + max_index = (d * input_height + h) * input_width + w; + ele = input_data[max_index]; + } + } + } + } + output_data[index] = ele; + mask_data[index] = max_index; + } +} + +template +__global__ void KernelMaxPool3DWithIdxGrad( + const int nthreads, T* input_grad, const T* output_grad, const T* mask, + const int channels, const int input_depth, const int input_height, + const int input_width, const int output_depth, const int output_height, + const int output_width, const int ksize_depth, const int ksize_height, + const int ksize_width, const int stride_depth, const int stride_height, + const int stride_width, const int padding_depth, const int padding_height, + const int padding_width) { + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; + index += blockDim.x * gridDim.x) { + int w_offset = index % input_width; + int h_offset = (index / input_width) % input_height; + int d_offset = (index / input_width / input_height) % input_depth; + int c_offset = + (index / input_width / input_height / input_depth) % channels; + int batch_idx = index / input_width / input_height / input_depth / channels; + + int pd_start = + (d_offset + padding_depth < ksize_depth) + ? 0 + : (d_offset + padding_depth - ksize_depth) / stride_depth + 1; + int ph_start = + (h_offset + padding_height < ksize_height) + ? 0 + : (h_offset + padding_height - ksize_height) / stride_height + 1; + int pw_start = + (w_offset + padding_width < ksize_width) + ? 
0 + : (w_offset + padding_width - ksize_width) / stride_width + 1; + int pd_end = + min((d_offset + padding_depth) / stride_depth + 1, output_depth); + int ph_end = + min((h_offset + padding_height) / stride_height + 1, output_height); + int pw_end = + min((w_offset + padding_width) / stride_width + 1, output_width); + + T gradient = 0; + int input_current_feature_map_idx = + (d_offset * input_height + h_offset) * input_width + w_offset; + int output_idx = (batch_idx * channels + c_offset) * output_depth * + output_height * output_width; + mask += output_idx; + output_grad += output_idx; + + for (int pd = pd_start; pd < pd_end; ++pd) { + for (int ph = ph_start; ph < ph_end; ++ph) { + for (int pw = pw_start; pw < pw_end; ++pw) { + if (mask[(pd * output_height + ph) * output_width + pw] == + input_current_feature_map_idx) + gradient += + output_grad[(pd * output_height + ph) * output_width + pw]; + } + } + } + input_grad[index] = gradient; + } +} + +/* + * All tensors are in NCDHW format. + * Ksize, strides, paddings are three elements. These three elements represent + * depth, height and width, respectively. + */ +template +class MaxPool3dWithIndexFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& output, + framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings) { + const int batch_size = input.dims()[0]; + const int input_channels = input.dims()[1]; + const int input_depth = input.dims()[2]; + const int input_height = input.dims()[3]; + const int input_width = input.dims()[4]; + const int output_channels = output.dims()[1]; + const int output_depth = output.dims()[2]; + const int output_height = output.dims()[3]; + const int output_width = output.dims()[4]; + const int ksize_depth = ksize[0]; + const int ksize_height = ksize[1]; + const int ksize_width = ksize[2]; + const int stride_depth = strides[0]; + const int stride_height = strides[1]; + const int stride_width = strides[2]; + const int padding_depth = paddings[0]; + const int padding_height = paddings[1]; + const int padding_width = paddings[2]; + + const T* input_data = input.data(); + T* output_data = output.mutable_data(context.GetPlace()); + T* mask_data = mask.mutable_data(context.GetPlace()); + + int nthreads = batch_size * output_channels * output_depth * output_height * + output_width; + int blocks = (nthreads + 1024 - 1) / 1024; + dim3 threads(1024, 1); + dim3 grid(blocks, 1); + + KernelMaxPool3DWithIdx< + T><<(context) + .stream()>>>( + nthreads, input_data, output_data, mask_data, input_channels, + input_depth, input_height, input_width, output_depth, output_height, + output_width, ksize_depth, ksize_height, ksize_width, stride_depth, + stride_height, stride_width, padding_depth, padding_height, + padding_width); + } +}; + +/* + * All tensors are in NCDHW format. + * Ksize, strides, paddings are three elements. These three elements represent + * depth, height and width, respectively. 
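+ * One thread is assigned to each input element; it reconstructs the range of
+ * output windows that can contain that element and accumulates the gradients
+ * whose recorded mask index matches it, so no atomic operations are needed.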
+template <typename T>
+class MaxPool3dWithIndexGradFunctor<platform::GPUPlace, T> {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  framework::Tensor& input_grad,
+                  const framework::Tensor& output_grad,
+                  const framework::Tensor& mask, std::vector<int>& ksize,
+                  std::vector<int>& strides, std::vector<int>& paddings) {
+    const int batch_size = input_grad.dims()[0];
+    const int input_channels = input_grad.dims()[1];
+    const int input_depth = input_grad.dims()[2];
+    const int input_height = input_grad.dims()[3];
+    const int input_width = input_grad.dims()[4];
+    const int output_depth = output_grad.dims()[2];
+    const int output_height = output_grad.dims()[3];
+    const int output_width = output_grad.dims()[4];
+    const int ksize_depth = ksize[0];
+    const int ksize_height = ksize[1];
+    const int ksize_width = ksize[2];
+    const int stride_depth = strides[0];
+    const int stride_height = strides[1];
+    const int stride_width = strides[2];
+    const int padding_depth = paddings[0];
+    const int padding_height = paddings[1];
+    const int padding_width = paddings[2];
+
+    const T* output_grad_data = output_grad.data<T>();
+    const T* mask_data = mask.data<T>();
+    T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
+
+    int nthreads =
+        batch_size * input_channels * input_depth * input_height * input_width;
+    int blocks = (nthreads + 1024 - 1) / 1024;
+    dim3 threads(1024, 1);
+    dim3 grid(blocks, 1);
+
+    KernelMaxPool3DWithIdxGrad<
+        T><<<grid, threads, 0,
+             reinterpret_cast<const platform::CUDADeviceContext&>(context)
+                 .stream()>>>(
+        nthreads, input_grad_data, output_grad_data, mask_data, input_channels,
+        input_depth, input_height, input_width, output_depth, output_height,
+        output_width, ksize_depth, ksize_height, ksize_width, stride_depth,
+        stride_height, stride_width, padding_depth, padding_height,
+        padding_width);
+  }
+};
+
+template class MaxPool3dWithIndexFunctor<platform::GPUPlace, float>;
+template class MaxPool3dWithIndexGradFunctor<platform::GPUPlace, float>;
+template class MaxPool3dWithIndexFunctor<platform::GPUPlace, double>;
+template class MaxPool3dWithIndexGradFunctor<platform::GPUPlace, double>;
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/operators/math/pooling.h b/paddle/operators/math/pooling.h
index d214c68923..c50c57b5c5 100644
--- a/paddle/operators/math/pooling.h
+++ b/paddle/operators/math/pooling.h
@@ -21,15 +21,27 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 namespace math {
-//////////////////////
-#define FLT_MAX __FLT_MAX__  //
+#define FLT_MAX \
+  __FLT_MAX__  // It might need to be placed in another file, but I'm still
+               // wondering where to put it.
+
+/*
+ * \brief Extracting simple operations from pooling.
+ * Both MaxPool and AvgPool need "initial", "compute" and "finalize"
+ * operations.
+ * MaxPool initializes its temp variable to the negative maximum to find
+ * the maximum value in the pooling field.
+ * AvgPool initializes its temp variable to zero to accumulate all values
+ * in the pooling field, and finally takes the average.
+ * MaxPoolGrad and AvgPoolGrad are the corresponding gradient operations.
+ */
 template <class T>
 class MaxPool {
  public:
   DEVICE inline T initial() { return static_cast<T>(-FLT_MAX); }
   DEVICE inline void compute(T& y, const T& x) { y = y > x ? y : x; }
-  DEVICE inline void finalize(T& y, const T& poo_size) {}
+  DEVICE inline void finalize(T& y, const T& pool_field) {}
 };
 
 template <class T>
@@ -37,8 +49,9 @@ class AvgPool {
  public:
   DEVICE inline T initial() { return static_cast<T>(0); }
   DEVICE inline void compute(T& y, const T& x) { y += x; }
-  DEVICE inline void finalize(T& y, const T& poo_size) { y /= poo_size; }
+  DEVICE inline void finalize(T& y, const T& pool_field) { y /= pool_field; }
 };
+
 template <class T>
 class MaxPoolGrad {
  public:
@@ -57,6 +70,20 @@ class AvgPoolGrad {
   }
 };
 
+/*
+ * \brief Getting pooling results, and calculating gradient.
+ *
+ * In pool2d, all tensors are in NCHW format, where N is batch size, C is the
+ * number of channels, and H and W are the height and width of the feature.
+ * In pool3d, all tensors are in NCDHW format, where N is batch size, C is the
+ * number of channels, and D, H and W are the depth, height and width of the
+ * feature.
+ *
+ * In max pooling, the pooling region may contain multiple maximum elements;
+ * in this case we compute the gradient of the first maximum element only.
+ * This is different from average pooling, so we rewrite the max pooling
+ * gradient as MaxPool2dGradFunctor and MaxPool3dGradFunctor.
+ */
 template <typename Place, typename PoolProcess, typename T>
 class Pool2dFunctor {
  public:
@@ -117,6 +144,51 @@ class MaxPool3dGradFunctor {
                   std::vector<int>& strides, std::vector<int>& paddings);
 };
 
+/*
+ * \brief Getting max pooling results and the corresponding max index, and
+ * calculating gradient.
+ * In unpooling (up-sampling), it is necessary to know the index of the max
+ * element.
+ * In pool2d, all tensors are in NCHW format. In pool3d, all tensors are in
+ * NCDHW format.
+ */
+template <typename Place, typename T>
+class MaxPool2dWithIndexFunctor {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& input, framework::Tensor& output,
+                  framework::Tensor& mask, std::vector<int>& ksize,
+                  std::vector<int>& strides, std::vector<int>& paddings);
+};
+
+template <typename Place, typename T>
+class MaxPool2dWithIndexGradFunctor {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  framework::Tensor& input_grad,
+                  const framework::Tensor& output_grad,
+                  const framework::Tensor& mask, std::vector<int>& ksize,
+                  std::vector<int>& strides, std::vector<int>& paddings);
+};
+
+template <typename Place, typename T>
+class MaxPool3dWithIndexFunctor {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& input, framework::Tensor& output,
+                  framework::Tensor& mask, std::vector<int>& ksize,
+                  std::vector<int>& strides, std::vector<int>& paddings);
+};
+
+template <typename Place, typename T>
+class MaxPool3dWithIndexGradFunctor {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  framework::Tensor& input_grad,
+                  const framework::Tensor& output_grad,
+                  const framework::Tensor& mask, std::vector<int>& ksize,
+                  std::vector<int>& strides, std::vector<int>& paddings);
+};
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/operators/math/vol2col.cc b/paddle/operators/math/vol2col.cc
new file mode 100644
index 0000000000..e9718a0473
--- /dev/null
+++ b/paddle/operators/math/vol2col.cc
@@ -0,0 +1,155 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
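The initial/compute/finalize protocol declared in pooling.h above factors a pooling operator into three steps, so one driver loop can serve both max and average pooling. A minimal CPU-only sketch of the idea; the functor names mirror the header, but the driver is illustrative and not the library's actual Pool2dFunctor:

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

// The three-step protocol: initial() seeds the accumulator, compute() folds
// one element in, finalize() post-processes with the pooling-field size.
template <class T>
struct MaxPool {
  T initial() { return static_cast<T>(-1e30); }  // very small seed
  void compute(T& y, const T& x) { y = std::max(y, x); }
  void finalize(T& /*y*/, const T& /*pool_field*/) {}  // max needs no division
};

template <class T>
struct AvgPool {
  T initial() { return static_cast<T>(0); }  // accumulate from zero
  void compute(T& y, const T& x) { y += x; }
  void finalize(T& y, const T& pool_field) { y /= pool_field; }  // mean
};

// Illustrative driver: pool one window with any process class that follows
// the protocol (the real functors iterate windows over a whole tensor).
template <class T, class PoolProcess>
T PoolWindow(const std::vector<T>& window, PoolProcess op) {
  T y = op.initial();
  for (const T& x : window) op.compute(y, x);
  op.finalize(y, static_cast<T>(window.size()));
  return y;
}

int main() {
  std::vector<float> window = {1.f, 5.f, 2.f, 4.f};
  std::printf("max = %g, avg = %g\n", PoolWindow(window, MaxPool<float>()),
              PoolWindow(window, AvgPool<float>()));  // max = 5, avg = 3
  return 0;
}
```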
+
+#include "paddle/operators/math/vol2col.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+/*
+ * vol = [input_channels, input_depth, input_height, input_width]
+ * col =
+ *   [input_channels, filter_depth, filter_height, filter_width,
+ *    output_depth, output_height, output_width]
+ */
+template <typename T>
+class Vol2ColFunctor<platform::CPUPlace, T> {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& vol, framework::Tensor& col,
+                  int stride_depth, int stride_height, int stride_width,
+                  int padding_depth, int padding_height,
+                  int padding_width) const {
+    PADDLE_ENFORCE(vol.dims().size() == 4);
+    PADDLE_ENFORCE(col.dims().size() == 7);
+
+    int input_channels = vol.dims()[0];
+    int input_depth = vol.dims()[1];
+    int input_height = vol.dims()[2];
+    int input_width = vol.dims()[3];
+    int filter_depth = col.dims()[1];
+    int filter_height = col.dims()[2];
+    int filter_width = col.dims()[3];
+    int output_depth = col.dims()[4];
+    int output_height = col.dims()[5];
+    int output_width = col.dims()[6];
+    int channels_col =
+        input_channels * filter_depth * filter_height * filter_width;
+
+    const T* vol_data = vol.data<T>();
+    T* col_data = col.data<T>();
+
+    for (int c = 0; c < channels_col; ++c) {
+      int w_offset = c % filter_width;
+      int h_offset = (c / filter_width) % filter_height;
+      int d_offset = (c / filter_width / filter_height) % filter_depth;
+      int c_in = c / filter_width / filter_height / filter_depth;
+      for (int d = 0; d < output_depth; ++d) {
+        int d_pad = d * stride_depth - padding_depth + d_offset;
+        for (int h = 0; h < output_height; ++h) {
+          int h_pad = h * stride_height - padding_height + h_offset;
+          for (int w = 0; w < output_width; ++w) {
+            int w_pad = w * stride_width - padding_width + w_offset;
+
+            int col_idx =
+                ((c * output_depth + d) * output_height + h) * output_width + w;
+            if (h_pad < 0 || h_pad >= input_height || w_pad < 0 ||
+                w_pad >= input_width || d_pad < 0 || d_pad >= input_depth) {
+              col_data[col_idx] = static_cast<T>(0);
+            } else {
+              int vol_idx =
+                  ((c_in * input_depth + d_pad) * input_height + h_pad) *
+                      input_width +
+                  w_pad;
+              col_data[col_idx] = vol_data[vol_idx];
+            }
+          }
+        }
+      }
+    }
+  }
+};
+
+/*
+ * vol = [input_channels, input_depth, input_height, input_width]
+ * col =
+ *   [input_channels, filter_depth, filter_height, filter_width,
+ *    output_depth, output_height, output_width]
+ */
+template <typename T>
+class Col2VolFunctor<platform::CPUPlace, T> {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  framework::Tensor& vol, const framework::Tensor& col,
+                  int stride_depth, int stride_height, int stride_width,
+                  int padding_depth, int padding_height,
+                  int padding_width) const {
+    PADDLE_ENFORCE(vol.dims().size() == 4);
+    PADDLE_ENFORCE(col.dims().size() == 7);
+
+    int input_channels = vol.dims()[0];
+    int input_depth = vol.dims()[1];
+    int input_height = vol.dims()[2];
+    int input_width = vol.dims()[3];
+    int filter_depth = col.dims()[1];
+    int filter_height = col.dims()[2];
+    int filter_width = col.dims()[3];
+    int output_depth = col.dims()[4];
+    int output_height = col.dims()[5];
+    int output_width = col.dims()[6];
+    int channels_col =
+        input_channels * filter_depth * filter_height * filter_width;
+
+    T* vol_data = vol.data<T>();
+    const T* col_data = col.data<T>();
+
+    for (int c = 0; c < channels_col; ++c) {
+      int w_offset = c % filter_width;
+      int h_offset = (c / filter_width) % filter_height;
+      int d_offset = (c / filter_width / filter_height) %
filter_depth; + int cIm = c / filter_width / filter_height / filter_depth; + for (int d = 0; d < output_depth; ++d) { + int d_pad = d * stride_depth - padding_depth + d_offset; + for (int h = 0; h < output_height; ++h) { + int h_pad = h * stride_height - padding_height + h_offset; + for (int w = 0; w < output_width; ++w) { + int w_pad = w * stride_width - padding_width + w_offset; + + if (h_pad >= 0 && h_pad < input_height && w_pad >= 0 && + w_pad < input_width && d_pad >= 0 && d_pad < input_depth) { + int vol_idx = + ((cIm * input_depth + d_pad) * input_height + h_pad) * + input_width + + w_pad; + int col_idx = + ((c * output_depth + d) * output_height + h) * output_width + + w; + vol_data[vol_idx] += col_data[col_idx]; + } + } + } + } + } + } +}; + +template class Vol2ColFunctor; +template class Vol2ColFunctor; +template class Col2VolFunctor; +template class Col2VolFunctor; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/vol2col.cu b/paddle/operators/math/vol2col.cu new file mode 100644 index 0000000000..27b11fb237 --- /dev/null +++ b/paddle/operators/math/vol2col.cu @@ -0,0 +1,204 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/vol2col.h" +#include "paddle/platform/cuda_helper.h" + +namespace paddle { +namespace operators { +namespace math { + +template +__global__ void vol2col(int num_kernels, const T* data_vol, int depth, + int height, int width, int filter_depth, + int filter_height, int filter_width, int stride_depth, + int stride_height, int stride_width, int padding_depth, + int padding_height, int padding_width, int output_detph, + int output_height, int output_width, T* data_col) { + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels; + index += blockDim.x * gridDim.x) { + int w_out = index % output_width; + int h_out = (index / output_width) % output_height; + int d_out = (index / output_width / output_height) % output_detph; + int channel_in = index / output_width / output_height / output_detph; + int channel_out = channel_in * filter_depth * filter_height * filter_width; + int w_in = w_out * stride_width - padding_width; + int h_in = h_out * stride_height - padding_height; + int d_in = d_out * stride_depth - padding_depth; + + data_col += ((channel_out * output_detph + d_out) * output_height + h_out) * + output_width + + w_out; + data_vol += ((channel_in * depth + d_in) * height + h_in) * width + w_in; + for (int k = 0; k < filter_depth; ++k) { + for (int i = 0; i < filter_height; ++i) { + for (int j = 0; j < filter_width; ++j) { + int d = d_in + k; + int h = h_in + i; + int w = w_in + j; + *data_col = (d >= 0 && d < depth && h >= 0 && h < height && w >= 0 && + w < width) + ? 
data_vol[(k * height + i) * width + j] + : 0; + data_col += output_detph * output_height * output_width; + } + } + } + } +} + +/* + * im = [input_channels,intpu_depth, input_height, input_width] + * col = + * [input_channels, filter_depth, filter_height, filter_width, + * output_depth, output_height, output_width] + */ +template +class Vol2ColFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& vol, framework::Tensor& col, + int stride_depth, int stride_height, int stride_width, + int padding_depth, int padding_height, + int padding_width) const { + PADDLE_ENFORCE(vol.dims().size() == 4); + PADDLE_ENFORCE(col.dims().size() == 7); + + int input_channels = vol.dims()[0]; + int input_depth = vol.dims()[1]; + int input_height = vol.dims()[2]; + int input_width = vol.dims()[3]; + int filter_depth = col.dims()[1]; + int filter_height = col.dims()[2]; + int filter_width = col.dims()[3]; + int output_depth = col.dims()[4]; + int output_height = col.dims()[5]; + int output_width = col.dims()[6]; + + int num_outputs = + input_channels * output_depth * output_height * output_width; + + const int threads = 1024; + const int blocks = (num_outputs + 1024 - 1) / 1024; + vol2col<<(context) + .stream()>>>( + num_outputs, vol.data(), input_depth, input_height, input_width, + filter_depth, filter_height, filter_width, stride_depth, stride_height, + stride_width, padding_depth, padding_height, padding_width, + output_depth, output_height, output_width, col.data()); + } +}; + +template +__global__ void col2vol(int num_kernels, const T* data_col, int depth, + int height, int width, int filter_depth, + int filter_height, int filter_width, int stride_depth, + int stride_height, int stride_width, int padding_depth, + int padding_height, int padding_width, int output_detph, + int output_height, int output_width, T* data_vol) { + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels; + index += blockDim.x * gridDim.x) { + T src_val = 0; + int w = index % width + padding_width; + int h = (index / width) % height + padding_height; + int d = (index / width / height) % depth + padding_depth; + int c = index / width / height / depth; + // compute the start and end of the output + int w_col_start = + (w < filter_width) ? 0 : (w - filter_width) / stride_width + 1; + int w_col_end = min(w / stride_width + 1, output_width); + int h_col_start = + (h < filter_height) ? 0 : (h - filter_height) / stride_height + 1; + int h_col_end = min(h / stride_height + 1, output_height); + int d_col_start = + (d < filter_depth) ? 
0 : (d - filter_depth) / stride_depth + 1;
+    int d_col_end = min(d / stride_depth + 1, output_detph);
+
+    int offset = (c * filter_depth * filter_height * filter_width +
+                  d * filter_width * filter_height + h * filter_width + w) *
+                 output_detph * output_height * output_width;
+
+    int coeff_d_col =
+        (1 - stride_depth * filter_width * filter_height * output_detph) *
+        output_height * output_width;
+    int coeff_h_col =
+        (1 - stride_height * filter_width * output_detph * output_height) *
+        output_width;
+    int coeff_w_col =
+        (1 - stride_width * output_detph * output_height * output_width);
+
+    for (int d_col = d_col_start; d_col < d_col_end; ++d_col) {
+      for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
+        for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
+          src_val += data_col[offset + d_col * coeff_d_col +
+                              h_col * coeff_h_col + w_col * coeff_w_col];
+        }
+      }
+    }
+    data_vol[index] = src_val;
+  }
+}
+
+/*
+ * im = [input_channels, input_depth, input_height, input_width]
+ * col =
+ *   [input_channels, filter_depth, filter_height, filter_width,
+ *    output_depth, output_height, output_width]
+ */
+template <typename T>
+class Col2VolFunctor<platform::GPUPlace, T> {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  framework::Tensor& vol, const framework::Tensor& col,
+                  int stride_depth, int stride_height, int stride_width,
+                  int padding_depth, int padding_height,
+                  int padding_width) const {
+    PADDLE_ENFORCE(vol.dims().size() == 4);
+    PADDLE_ENFORCE(col.dims().size() == 7);
+
+    int input_channels = vol.dims()[0];
+    int input_depth = vol.dims()[1];
+    int input_height = vol.dims()[2];
+    int input_width = vol.dims()[3];
+    int filter_depth = col.dims()[1];
+    int filter_height = col.dims()[2];
+    int filter_width = col.dims()[3];
+    int output_depth = col.dims()[4];
+    int output_height = col.dims()[5];
+    int output_width = col.dims()[6];
+
+    int num_kernels = input_channels * input_depth * input_height * input_width;
+
+    const int threads = 1024;
+    const int blocks = (num_kernels + 1024 - 1) / 1024;
+
+    col2vol<T><<<blocks, threads, 0,
+                 reinterpret_cast<const platform::CUDADeviceContext&>(context)
+                     .stream()>>>(
+        num_kernels, col.data<T>(), input_depth, input_height, input_width,
+        filter_depth, filter_height, filter_width, stride_depth, stride_height,
+        stride_width, padding_depth, padding_height, padding_width,
+        output_depth, output_height, output_width, vol.data<T>());
+  }
+};
+
+template class Vol2ColFunctor<platform::GPUPlace, float>;
+template class Vol2ColFunctor<platform::GPUPlace, double>;
+template class Col2VolFunctor<platform::GPUPlace, float>;
+template class Col2VolFunctor<platform::GPUPlace, double>;
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/math/vol2col.h b/paddle/operators/math/vol2col.h
new file mode 100644
index 0000000000..f022365a16
--- /dev/null
+++ b/paddle/operators/math/vol2col.h
@@ -0,0 +1,78 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
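The col2vol kernel above avoids recomputing a full seven-dimensional index inside its triple loop: starting from `offset`, the column index at `(d_col, h_col, w_col) = (0, 0, 0)`, each unit step in an output coordinate changes the linear index by a constant, because it advances the output position while shrinking the filter offset by one stride. A small self-contained check of that identity, using assumed toy sizes:

```cpp
#include <cassert>

int main() {
  const int fD = 2, fH = 2, fW = 2;      // filter extent (assumed)
  const int oD = 3, oH = 3, oW = 3;      // output extent (assumed)
  const int sD = 1, sH = 1, sW = 1;      // strides (assumed)
  const int c = 1, d = 2, h = 2, w = 2;  // one padded input coordinate

  // Direct seven-dimensional index, flattened the way vol2col lays out col.
  auto direct = [&](int dc, int hc, int wc) {
    int kd = d - dc * sD, kh = h - hc * sH, kw = w - wc * sW;
    return (((c * fD + kd) * fH + kh) * fW + kw) * oD * oH * oW +
           (dc * oH + hc) * oW + wc;
  };

  // The kernel's closed forms: offset is the index at (0, 0, 0).
  int offset = (c * fD * fH * fW + d * fW * fH + h * fW + w) * oD * oH * oW;
  int coeff_d = (1 - sD * fW * fH * oD) * oH * oW;
  int coeff_h = (1 - sH * fW * oD * oH) * oW;
  int coeff_w = 1 - sW * oD * oH * oW;

  for (int dc = 1; dc <= 2; ++dc)  // the valid column range for d == 2
    for (int hc = 1; hc <= 2; ++hc)
      for (int wc = 1; wc <= 2; ++wc)
        assert(direct(dc, hc, wc) ==
               offset + dc * coeff_d + hc * coeff_h + wc * coeff_w);
  return 0;
}
```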
+
+#pragma once
+
+#include "paddle/framework/tensor.h"
+#include "paddle/platform/device_context.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+/*
+ * \brief Converts the feature data of four dimensions (CDHW) into colData of
+ * seven dimensions in the Vol2ColFunctor calculation; the Col2VolFunctor
+ * calculation reverses the conversion.
+ *
+ * \param volData   Vol data.
+ * \param volShape  The shape of volData,
+ *                  [input_channels, input_depth, input_height, input_width].
+ * \param colData   Column data.
+ * \param colShape  The shape of colData.
+ *
+ * The shape of colData is:
+ * [input_channels, filter_depth, filter_height, filter_width, output_depth,
+ * output_height, output_width]
+ * So, it is easy to reshape into a convolution matrix for convolution
+ * calculation based on matrix multiplication.
+ * The shape of the convolution matrix is [height, width], where the height
+ * equals input_channels * filter_depth * filter_height * filter_width, and
+ * the width equals output_depth * output_height * output_width.
+ *
+ * Reshape:
+ *     shape of colData          shape of convolution matrix
+ *     [input_channels,
+ *      filter_depth,
+ *      filter_height,
+ *      filter_width,    ======>      [height, width]
+ *      output_depth,
+ *      output_height,
+ *      output_width]
+ *
+ * \note The caller needs to ensure that volShape.inputChannels is equal to
+ *       colShape.inputChannels.
+ */
+template <typename Place, typename T>
+class Vol2ColFunctor {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& vol, framework::Tensor& col,
+                  int stride_depth, int stride_height, int stride_width,
+                  int padding_depth, int padding_height,
+                  int padding_width) const;
+};
+
+template <typename Place, typename T>
+class Col2VolFunctor {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  framework::Tensor& vol, const framework::Tensor& col,
+                  int stride_depth, int stride_height, int stride_width,
+                  int padding_depth, int padding_height,
+                  int padding_width) const;
+};
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/math/vol2col_test.cc b/paddle/operators/math/vol2col_test.cc
new file mode 100644
index 0000000000..81225e9a98
--- /dev/null
+++ b/paddle/operators/math/vol2col_test.cc
@@ -0,0 +1,135 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
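The reshape described in the vol2col.h comment above is what turns 3D convolution into a single matrix multiplication. A shape-arithmetic sketch with assumed toy sizes; the filter layout [M, C, fD, fH, fW] is the usual convention and not something this header spells out:

```cpp
#include <cstdio>

int main() {
  // Assumed toy sizes: C input channels, 3x3x3 filters, 8x8x8 output.
  const int C = 3, fD = 3, fH = 3, fW = 3;
  const int oD = 8, oH = 8, oW = 8;
  const int M = 16;  // number of filters (assumed)

  const int K = C * fD * fH * fW;  // GEMM inner dimension: 81
  const int N = oD * oH * oW;      // GEMM columns: 512

  // col [C, fD, fH, fW, oD, oH, oW] viewed as a [K, N] matrix, and filters
  // [M, C, fD, fH, fW] viewed as [M, K], give output = filter * col: [M, N].
  std::printf("filter[%d x %d] * col[%d x %d] = output[%d x %d]\n",
              M, K, K, N, M, N);
  return 0;
}
```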
+
+#include "paddle/operators/math/vol2col.h"
+#include <gtest/gtest.h>
+#include <iostream>
+
+template <typename Place>
+void testVol2col() {
+  paddle::framework::Tensor input;
+  paddle::framework::Tensor input_tmp;
+  paddle::framework::Tensor output;
+  paddle::framework::Tensor output_tmp;
+
+  auto* place = new Place();
+  paddle::platform::DeviceContext* context;
+  if (paddle::platform::is_cpu_place(*place)) {
+    context =
+        new paddle::platform::CPUDeviceContext(paddle::platform::CPUPlace());
+  } else {
+#ifdef PADDLE_WITH_CUDA
+    context =
+        new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace());
+#else
+    PADDLE_THROW("no GPU support");
+#endif  // PADDLE_WITH_CUDA
+  }
+
+  /**
+   * input = [[0, 1, 2,
+   *           3, 4, 5]
+   *          [6, 7, 8,
+   *           9, 10, 11]]
+   *
+   * output = [0, 1
+   *           1, 2
+   *           3, 4
+   *           4, 5
+   *           6, 7
+   *           7, 8
+   *           9, 10
+   *           10, 11]
+   *
+   * col2vol = [[0, 2, 2,
+   *             3, 8, 5]
+   *            [6, 14, 8,
+   *             9, 20, 11]]
+   *
+   */
+  int input_depth = 2;
+  int input_height = 2;
+  int input_width = 3;
+  int filter_size = 2;
+  int stride = 1;
+  int padding = 0;
+  int output_depth = (input_depth - filter_size + 2 * padding) / stride + 1;
+  int output_height = (input_height - filter_size + 2 * padding) / stride + 1;
+  int output_width = (input_width - filter_size + 2 * padding) / stride + 1;
+
+  // Vol2Col test
+  float* input_ptr = input_tmp.mutable_data<float>(
+      {1, input_depth, input_height, input_width},
+      paddle::platform::CPUPlace());
+  float arr[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
+  memcpy(input_ptr, arr, 12 * sizeof(float));
+
+  if (paddle::platform::is_cpu_place(*place)) {
+    input = input_tmp;
+  } else {
+    input.CopyFrom<float>(input_tmp, *place);
+  }
+  output.mutable_data<float>({1, filter_size, filter_size, filter_size,
+                              output_depth, output_height, output_width},
+                             *place);
+
+  paddle::operators::math::Vol2ColFunctor<Place, float> vol2col;
+  vol2col(*context, input, output, stride, stride, stride, padding, padding,
+          padding);
+
+  float vol_2_col[] = {0, 1, 1, 2, 3, 4, 4, 5, 6, 7, 7, 8, 9, 10, 10, 11};
+  float* out_cfo_ptr;
+  if (paddle::platform::is_cpu_place(*place)) {
+    out_cfo_ptr = output.data<float>();
+  } else {
+    output_tmp.CopyFrom<float>(output, paddle::platform::CPUPlace());
+    out_cfo_ptr = output_tmp.data<float>();
+  }
+
+  for (int i = 0; i < 16; ++i) {
+    EXPECT_EQ(out_cfo_ptr[i], vol_2_col[i]);
+  }
+
+  // Col2Vol test
+  float col_2_vol[] = {0, 2, 2, 3, 8, 5, 6, 14, 8, 9, 20, 11};
+  memset(input_ptr, 0, 12 * sizeof(float));
+  if (paddle::platform::is_cpu_place(*place)) {
+    input = input_tmp;
+  } else {
+    input.CopyFrom<float>(input_tmp, *place);
+  }
+
+  paddle::operators::math::Col2VolFunctor<Place, float> col2vol;
+  col2vol(*context, input, output, stride, stride, stride, padding, padding,
+          padding);
+
+  float* in_ptr;
+  if (paddle::platform::is_cpu_place(*place)) {
+    in_ptr = input.data<float>();
+  } else {
+    input_tmp.CopyFrom<float>(input, paddle::platform::CPUPlace());
+    in_ptr = input_tmp.data<float>();
+  }
+
+  for (int i = 0; i < 12; ++i) {
+    EXPECT_EQ(in_ptr[i], col_2_vol[i]);
+  }
+}
+
+TEST(math, vol2col) {
+  testVol2col<paddle::platform::CPUPlace>();
+#ifdef PADDLE_WITH_CUDA
+  testVol2col<paddle::platform::GPUPlace>();
+#endif  // PADDLE_WITH_CUDA
+}
diff --git a/paddle/operators/mean_op.cc b/paddle/operators/mean_op.cc
index 2332c9546b..441543049f 100644
--- a/paddle/operators/mean_op.cc
+++ b/paddle/operators/mean_op.cc
@@ -22,7 +22,7 @@ class MeanOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(framework::InferShapeContextBase* ctx) const override {
+  void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("X"),
                    "Input(X)
of MeanOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), @@ -47,7 +47,7 @@ class MeanGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); } }; diff --git a/paddle/operators/minus_op.cc b/paddle/operators/minus_op.cc index 7057dcbd6e..d7fd2f901b 100644 --- a/paddle/operators/minus_op.cc +++ b/paddle/operators/minus_op.cc @@ -26,7 +26,7 @@ class MinusOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of MinusOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Y"), diff --git a/paddle/operators/modified_huber_loss_op.cc b/paddle/operators/modified_huber_loss_op.cc index 84212a2b3b..6522327fdc 100644 --- a/paddle/operators/modified_huber_loss_op.cc +++ b/paddle/operators/modified_huber_loss_op.cc @@ -22,7 +22,7 @@ class ModifiedHuberLossOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "X must be initialized."); PADDLE_ENFORCE(ctx->HasInput("Y"), "Y must be initialized."); @@ -74,7 +74,7 @@ class ModifiedHuberLossGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "X must be initialized."); PADDLE_ENFORCE(ctx->HasInput("Y"), "Y must be initialized."); PADDLE_ENFORCE(ctx->HasInput("IntermediateVal"), diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index 3c8fe04d2e..ec0683d887 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -24,7 +24,7 @@ class MulOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of MulOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) of MulOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), @@ -97,7 +97,7 @@ class MulOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null"); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), diff --git a/paddle/operators/multiplex_op.cc b/paddle/operators/multiplex_op.cc index a069127a19..a86685b6dd 100644 --- a/paddle/operators/multiplex_op.cc +++ b/paddle/operators/multiplex_op.cc @@ -24,7 +24,7 @@ class MultiplexOp : public framework::OperatorWithKernel { using 
framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Ids"), "Input(Ids) shouldn't be null."); PADDLE_ENFORCE(!ctx->Inputs("X").empty(), "MultiInput(X) shouldn't be empty."); @@ -90,7 +90,7 @@ class MultiplexGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(!ctx->Inputs("X").empty(), "Input(X) should not be null."); PADDLE_ENFORCE(!ctx->Outputs(framework::GradVarName("X")).empty(), "Output(X@Grad) should not be null."); diff --git a/paddle/operators/net_op.h b/paddle/operators/net_op.h index 2388b094d2..ebeb262d96 100644 --- a/paddle/operators/net_op.h +++ b/paddle/operators/net_op.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include "paddle/framework/framework.pb.h" #include "paddle/framework/op_registry.h" diff --git a/paddle/operators/pad_op.cc b/paddle/operators/pad_op.cc index 15aa05f266..2f26ada85e 100644 --- a/paddle/operators/pad_op.cc +++ b/paddle/operators/pad_op.cc @@ -24,7 +24,7 @@ class PadOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of PadOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of PadOp should not be null."); @@ -98,7 +98,7 @@ class PadOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), "Input(Out@GRAD) should not be null"); diff --git a/paddle/operators/pool_op.cc b/paddle/operators/pool_op.cc index c29f51f056..ba3b5ed207 100644 --- a/paddle/operators/pool_op.cc +++ b/paddle/operators/pool_op.cc @@ -27,7 +27,7 @@ class PoolOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "X(Input) of Pooling should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), @@ -74,7 +74,7 @@ class PoolOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "X(Input) of Pooling should not be null."); PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), diff --git a/paddle/operators/pool_with_index_op.cc b/paddle/operators/pool_with_index_op.cc new file mode 100644 index 0000000000..7b6afcfd1f --- /dev/null +++ b/paddle/operators/pool_with_index_op.cc @@ -0,0 +1,228 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/pool_with_index_op.h" + +namespace paddle { +namespace operators { + +inline int OutputSizeMaxPool(int input_size, int filter_size, int padding, + int stride) { + int output_size = (input_size - filter_size + 2 * padding) / stride + 1; + return output_size; +} + +class MaxPoolWithIndexOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "X(Input) of Pooling should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Out(Output) of Pooling should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Mask"), + "Mask(Output) of Pooling should not be null."); + + auto in_x_dims = ctx->GetInputDim("X"); + + std::vector ksize = ctx->Attrs().Get>("ksize"); + std::vector strides = ctx->Attrs().Get>("strides"); + std::vector paddings = ctx->Attrs().Get>("paddings"); + + PADDLE_ENFORCE(in_x_dims.size() == 4 || in_x_dims.size() == 5, + "Pooling intput should be 4-D or 5-D"); + + if (ctx->Attrs().Get("globalPooling")) { + ksize.resize(static_cast(in_x_dims.size()) - 2); + for (size_t i = 0; i < ksize.size(); ++i) + ksize[i] = static_cast(in_x_dims[i + 2]); + } + + PADDLE_ENFORCE(in_x_dims.size() - ksize.size() == 2U, + "Intput size and pooling size should be consistent."); + PADDLE_ENFORCE_EQ(ksize.size(), strides.size(), + "Strides size and pooling size should be the same."); + PADDLE_ENFORCE_EQ(ksize.size(), paddings.size(), + "Paddings size and pooling size should be the same."); + + std::vector output_shape({in_x_dims[0], in_x_dims[1]}); + for (size_t i = 0; i < ksize.size(); ++i) { + output_shape.push_back(OutputSizeMaxPool(in_x_dims[i + 2], ksize[i], + paddings[i], strides[i])); + } + ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); + ctx->SetOutputDim("Mask", framework::make_ddim(output_shape)); + } +}; + +class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null."); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), + "Input(X@GRAD) should not be null."); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } +}; + +class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker { + public: + MaxPool2dWithIndexOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput( + "X", + "The input tensor of pooling operator. " + "The format of input tensor is NCHW. Where N is batch size, C is the " + "number of channels, H and W is the height and width of image."); + AddOutput("Out", + "The output tensor of pooling operator." + "The format of output tensor is also NCHW." 
+ "Where N is batch size, C is " + "the number of channels, H and W is the height and " + "width of image."); + AddOutput("Mask", + "The Mask tensor of pooling operator." + "The format of output tensor is also NCHW." + "Where N is batch size, C is the number of channels, H and W " + "is the height and width of image." + "The value in it is the index in current feature map"); + + AddAttr>( + "ksize", + "The pooling size(height, width) of pooling operator." + "If globalPooling = true, ksize is ignored and need not be " + "specified."); // TODO(Chengduo): Add checker. (Currently, + // TypedAttrChecker don't support vector type.) + AddAttr( + "globalPooling", + "Whether to use the globalPooling." + "Bool constant equal to false or true." + "Default false." + "If globalPooling = true, ksize is ignored and need not be specified.") + .SetDefault(false); + AddAttr>("strides", + "Strides(height, width) of pooling operator." + "Default {1,1}.") + .SetDefault({1, 1}); // TODO(Chengduo): Add checker. (Currently, + // TypedAttrChecker don't support vector type.) + AddAttr>("paddings", + "Paddings(height, width) of pooling operator." + "Default {0,0}.") + .SetDefault({0, 0}); // TODO(Chengduo): Add checker. (Currently, + // TypedAttrChecker don't support vector type.) + + AddComment(R"DOC( +The maxPooling2d with index operation calculates the output and the mask +based on the input and ksize, strides, paddings parameters. Input(X) and +output(Out, Mask) are in NCHW format. Where N is batch size, C is the +number of channels, H and W is the height and width of feature. +Parameters(ksize, strides, paddings) are two elements. +These two elements represent height and width, respectively. +)DOC"); + } +}; + +class MaxPool3dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker { + public: + MaxPool3dWithIndexOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput( + "X", + "The input tensor of pooling operator. " + "The format of input tensor is NCDHW. Where N is batch size, C is " + "the number of channels, D, H and W is the depth, height and width of " + "image."); + AddOutput("Out", + "The output tensor of pooling operator." + "The format of output tensor is also NCDHW." + "Where N is batch size, C is " + "the number of channels, D, H and W is the depth, height and " + "width of image."); + AddOutput("Mask", + "The Mask tensor of pooling operator." + "The format of output tensor is also NCDHW." + "Where N is batch size, C is the number of channels, D, H and W " + "is the depth, height and width of image." + "The value in it is the index in current feature map"); + + AddAttr>( + "ksize", + "The pooling size(depth, height, width) of pooling operator." + "If globalPooling = true, ksize is ignored and need not be " + "specified."); // TODO(Chengduo): Add checker. (Currently, + // TypedAttrChecker don't support vector type.) + AddAttr( + "globalPooling", + "Whether to use the globalPooling." + "Bool constant equal to false or true." + "Default false." + "If globalPooling = true, ksize is ignored and need not be specified.") + .SetDefault(false); + AddAttr>( + "strides", + "Strides(depth, height, width) of pooling operator." + "Default {1,1,1}.") + .SetDefault({1, 1, 1}); // TODO(Chengduo): Add checker. (Currently, + // TypedAttrChecker don't support vector type.) + AddAttr>( + "paddings", + "Paddings(depth, height, width) of pooling operator." + "Default {0,0,0}.") + .SetDefault({0, 0, 0}); // TODO(Chengduo): Add checker. 
(Currently,
+                              // TypedAttrChecker doesn't support vector type.)
+
+    AddComment(R"DOC(
+The max pooling3d with index operation calculates the output and the mask
+based on the input and the ksize, strides and paddings parameters.
+Input(X) and outputs(Out, Mask) are in NCDHW format, where N is batch
+size, C is the number of channels, and D, H and W are the depth, height and
+width of the feature. Parameters(ksize, strides, paddings) are three elements.
+These three elements represent depth, height and width, respectively.
+)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OP(max_pool2d_with_index, ops::MaxPoolWithIndexOp,
+            ops::MaxPool2dWithIndexOpMaker, max_pool2d_with_index_grad,
+            ops::MaxPoolWithIndexOpGrad);
+
+REGISTER_OP_CPU_KERNEL(
+    max_pool2d_with_index,
+    ops::MaxPoolWithIndexKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(
+    max_pool2d_with_index_grad,
+    ops::MaxPoolWithIndexGradKernel<paddle::platform::CPUPlace, float>)
+
+REGISTER_OP(max_pool3d_with_index, ops::MaxPoolWithIndexOp,
+            ops::MaxPool3dWithIndexOpMaker, max_pool3d_with_index_grad,
+            ops::MaxPoolWithIndexOpGrad);
+
+REGISTER_OP_CPU_KERNEL(
+    max_pool3d_with_index,
+    ops::MaxPoolWithIndexKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(
+    max_pool3d_with_index_grad,
+    ops::MaxPoolWithIndexGradKernel<paddle::platform::CPUPlace, float>)
diff --git a/paddle/operators/pool_with_index_op.cu b/paddle/operators/pool_with_index_op.cu
new file mode 100644
index 0000000000..287657d4b1
--- /dev/null
+++ b/paddle/operators/pool_with_index_op.cu
@@ -0,0 +1,31 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/pool_with_index_op.h"
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_GPU_KERNEL(
+    max_pool2d_with_index,
+    ops::MaxPoolWithIndexKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(
+    max_pool2d_with_index_grad,
+    ops::MaxPoolWithIndexGradKernel<paddle::platform::GPUPlace, float>)
+
+REGISTER_OP_GPU_KERNEL(
+    max_pool3d_with_index,
+    ops::MaxPoolWithIndexKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(
+    max_pool3d_with_index_grad,
+    ops::MaxPoolWithIndexGradKernel<paddle::platform::GPUPlace, float>)
diff --git a/paddle/operators/pool_with_index_op.h b/paddle/operators/pool_with_index_op.h
new file mode 100644
index 0000000000..01b961ca82
--- /dev/null
+++ b/paddle/operators/pool_with_index_op.h
@@ -0,0 +1,103 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
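The registrations above pair each forward op with its gradient, and the output shape they infer comes from OutputSizeMaxPool in pool_with_index_op.cc. A worked instance of that formula, with sizes assumed purely for illustration:

```cpp
#include <cstdio>

// Same formula as OutputSizeMaxPool above; integer division floors.
int OutputSizeMaxPool(int input_size, int filter_size, int padding,
                      int stride) {
  return (input_size - filter_size + 2 * padding) / stride + 1;
}

int main() {
  // Assumed example: a 32x32 feature map, 3x3 window, stride 2, padding 1:
  // (32 - 3 + 2 * 1) / 2 + 1 = 31 / 2 + 1 = 15 + 1 = 16.
  std::printf("%d\n", OutputSizeMaxPool(32, 3, 1, 2));  // prints 16
  return 0;
}
```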
+
+#pragma once
+
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+#include "paddle/operators/math/math_function.h"
+#include "paddle/operators/math/pooling.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename Place, typename T>
+class MaxPoolWithIndexKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const Tensor* in_x = context.Input<Tensor>("X");
+    Tensor* out = context.Output<Tensor>("Out");
+    Tensor* mask = context.Output<Tensor>("Mask");
+
+    std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
+    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
+    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
+    if (context.Attr<bool>("globalPooling")) {
+      for (size_t i = 0; i < ksize.size(); ++i) {
+        ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
+      }
+    }
+
+    switch (ksize.size()) {
+      case 2: {
+        paddle::operators::math::MaxPool2dWithIndexFunctor<Place, T>
+            pool2d_forward;
+        pool2d_forward(context.device_context(), *in_x, *out, *mask, ksize,
+                       strides, paddings);
+      } break;
+      case 3: {
+        paddle::operators::math::MaxPool3dWithIndexFunctor<Place, T>
+            pool3d_forward;
+        pool3d_forward(context.device_context(), *in_x, *out, *mask, ksize,
+                       strides, paddings);
+      } break;
+    }
+  }
+};
+
+template <typename Place, typename T>
+class MaxPoolWithIndexGradKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const Tensor* mask = context.Input<Tensor>("Mask");
+    const Tensor* out_grad =
+        context.Input<Tensor>(framework::GradVarName("Out"));
+    Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
+
+    std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
+    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
+    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
+    if (context.Attr<bool>("globalPooling")) {
+      for (size_t i = 0; i < ksize.size(); ++i) {
+        ksize[i] = static_cast<int>(in_x_grad->dims()[i + 2]);
+      }
+    }
+
+    if (in_x_grad) {
+      in_x_grad->mutable_data<T>(context.GetPlace());
+      auto temp = framework::EigenVector<T>::Flatten(*in_x_grad);
+      temp.device(context.GetEigenDevice<Place>()) =
+          temp.constant(static_cast<T>(0));
+
+      switch (ksize.size()) {
+        case 2: {
+          paddle::operators::math::MaxPool2dWithIndexGradFunctor<Place, T>
+              pool2d_backward;
+          pool2d_backward(context.device_context(), *in_x_grad, *out_grad,
+                          *mask, ksize, strides, paddings);
+        } break;
+        case 3: {
+          paddle::operators::math::MaxPool3dWithIndexGradFunctor<Place, T>
+              pool3d_backward;
+          pool3d_backward(context.device_context(), *in_x_grad, *out_grad,
+                          *mask, ksize, strides, paddings);
+        } break;
+      }
+    }
+  }
+};
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/prelu_op.cc b/paddle/operators/prelu_op.cc
index 1692464f28..166fe26824 100644
--- a/paddle/operators/prelu_op.cc
+++ b/paddle/operators/prelu_op.cc
@@ -26,7 +26,7 @@ class PReluOp : public framework::OperatorWithKernel {
       : OperatorWithKernel(type, inputs, outputs, attrs) {}
 
  protected:
-  void InferShape(framework::InferShapeContextBase *ctx) const override {
+  void InferShape(framework::InferShapeContext *ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
     PADDLE_ENFORCE(ctx->HasInput("Alpha"), "Input(Alpha) should not be null");
     PADDLE_ENFORCE(product(ctx->GetInputDim("Alpha")) == 1,
@@ -63,7 +63,7 @@ class PReluGradOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(framework::InferShapeContextBase *ctx) const override {
+  void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null."); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), "Input(Out@GRAD) should not be null"); diff --git a/paddle/operators/rank_loss_op.cc b/paddle/operators/rank_loss_op.cc index 1ba22006f2..e0abbc4db1 100644 --- a/paddle/operators/rank_loss_op.cc +++ b/paddle/operators/rank_loss_op.cc @@ -25,7 +25,7 @@ class RankLossOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { // input check PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) shouldn't be null"); PADDLE_ENFORCE(ctx->HasInput("Left"), "Input(Left) shouldn't be null"); @@ -90,7 +90,7 @@ class RankLossGradOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) shouldn't be null."); PADDLE_ENFORCE(ctx->HasInput("Left"), "Input(Left) shouldn't be null."); PADDLE_ENFORCE(ctx->HasInput("Right"), "Input(Right) shouldn't be null."); diff --git a/paddle/operators/reduce_op.cc b/paddle/operators/reduce_op.cc index 3ef443d1c7..005f88b57c 100644 --- a/paddle/operators/reduce_op.cc +++ b/paddle/operators/reduce_op.cc @@ -24,7 +24,7 @@ class ReduceOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of ReduceOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), @@ -58,7 +58,7 @@ class ReduceGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null."); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), "Input(Out@GRAD) should not be null."); @@ -168,36 +168,22 @@ namespace ops = paddle::operators; REGISTER_OP(reduce_sum, ops::ReduceOp, ops::ReduceSumOpMaker, reduce_sum_grad, ops::ReduceGradOp); -REGISTER_OP_CPU_KERNEL( - reduce_sum, - ops::ReduceKernel); -REGISTER_OP_CPU_KERNEL(reduce_sum_grad, - ops::ReduceGradKernel); REGISTER_OP(reduce_mean, ops::ReduceOp, ops::ReduceMeanOpMaker, reduce_mean_grad, ops::ReduceGradOp); -REGISTER_OP_CPU_KERNEL( - reduce_mean, - ops::ReduceKernel); -REGISTER_OP_CPU_KERNEL(reduce_mean_grad, - ops::ReduceGradKernel); REGISTER_OP(reduce_max, ops::ReduceOp, ops::ReduceMaxOpMaker, reduce_max_grad, ops::ReduceGradOp); -REGISTER_OP_CPU_KERNEL( - reduce_max, - ops::ReduceKernel); -REGISTER_OP_CPU_KERNEL(reduce_max_grad, - ops::ReduceGradKernel); - -REGISTER_OP(reduce_min, ops::ReduceOp, ops::ReduceMaxOpMaker, reduce_min_grad, + +REGISTER_OP(reduce_min, ops::ReduceOp, ops::ReduceMinOpMaker, reduce_min_grad, ops::ReduceGradOp); -REGISTER_OP_CPU_KERNEL( - reduce_min, - ops::ReduceKernel); -REGISTER_OP_CPU_KERNEL(reduce_min_grad, - ops::ReduceGradKernel); + +#define REGISTER_REDUCE_CPU_KERNEL(reduce_type, functor, grad_functor) \ + REGISTER_OP_CPU_KERNEL( \ + reduce_type, \ + 
ops::ReduceKernel); \ + REGISTER_OP_CPU_KERNEL(reduce_type##_grad, \ + ops::ReduceGradKernel); + +FOR_EACH_KERNEL_FUNCTOR(REGISTER_REDUCE_CPU_KERNEL); diff --git a/paddle/operators/reduce_op.cu b/paddle/operators/reduce_op.cu index 595127b858..d306e1a240 100644 --- a/paddle/operators/reduce_op.cu +++ b/paddle/operators/reduce_op.cu @@ -17,30 +17,12 @@ namespace ops = paddle::operators; -REGISTER_OP_GPU_KERNEL( - reduce_sum, - ops::ReduceKernel); -REGISTER_OP_GPU_KERNEL(reduce_sum_grad, - ops::ReduceGradKernel); - -REGISTER_OP_GPU_KERNEL( - reduce_mean, - ops::ReduceKernel); -REGISTER_OP_GPU_KERNEL(reduce_mean_grad, - ops::ReduceGradKernel); - -REGISTER_OP_GPU_KERNEL( - reduce_max, - ops::ReduceKernel); -REGISTER_OP_GPU_KERNEL(reduce_max_grad, - ops::ReduceGradKernel); - -REGISTER_OP_GPU_KERNEL( - reduce_min, - ops::ReduceKernel); -REGISTER_OP_GPU_KERNEL(reduce_min_grad, - ops::ReduceGradKernel); +#define REGISTER_REDUCE_GPU_KERNEL(reduce_type, functor, grad_functor) \ + REGISTER_OP_GPU_KERNEL( \ + reduce_type, \ + ops::ReduceKernel); \ + REGISTER_OP_GPU_KERNEL(reduce_type##_grad, \ + ops::ReduceGradKernel); + +FOR_EACH_KERNEL_FUNCTOR(REGISTER_REDUCE_GPU_KERNEL); diff --git a/paddle/operators/reduce_op.h b/paddle/operators/reduce_op.h index ba3f3db81d..45043c440b 100644 --- a/paddle/operators/reduce_op.h +++ b/paddle/operators/reduce_op.h @@ -198,3 +198,9 @@ class ReduceGradKernel : public framework::OpKernel { } // namespace operators } // namespace paddle + +#define FOR_EACH_KERNEL_FUNCTOR(__macro) \ + __macro(reduce_sum, SumFunctor, SumGradFunctor); \ + __macro(reduce_mean, MeanFunctor, MeanGradFunctor); \ + __macro(reduce_max, MaxFunctor, MaxOrMinGradFunctor); \ + __macro(reduce_min, MinFunctor, MaxOrMinGradFunctor); diff --git a/paddle/operators/reshape_op.cc b/paddle/operators/reshape_op.cc index a3c3fa2716..3cd54930a0 100644 --- a/paddle/operators/reshape_op.cc +++ b/paddle/operators/reshape_op.cc @@ -26,7 +26,7 @@ class ReshapeOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { // input check PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of ReshapeOp should not be null."); @@ -94,7 +94,7 @@ class ReshapeGradOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) shouldn't be null."); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), "Input(Out@GRAD) shouldn't be null."); diff --git a/paddle/operators/rmsprop_op.cc b/paddle/operators/rmsprop_op.cc index 8f61c7fdda..ada6f2bc3c 100644 --- a/paddle/operators/rmsprop_op.cc +++ b/paddle/operators/rmsprop_op.cc @@ -22,7 +22,7 @@ class RmspropOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Param"), "Input(Param) of RmspropOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("MeanSquare"), diff --git a/paddle/operators/scale_op.cc b/paddle/operators/scale_op.cc index e225aecc27..ac297da6b7 100644 --- a/paddle/operators/scale_op.cc +++ 
b/paddle/operators/scale_op.cc @@ -26,7 +26,7 @@ class ScaleOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of ScaleOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), diff --git a/paddle/operators/scatter_op.cc b/paddle/operators/scatter_op.cc index d15ba15153..fbea01a8db 100644 --- a/paddle/operators/scatter_op.cc +++ b/paddle/operators/scatter_op.cc @@ -23,7 +23,7 @@ class ScatterOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Ref"), "Input(Ref) of ScatterOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Index"), @@ -60,7 +60,7 @@ class ScatterGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { ctx->SetOutputDim(framework::GradVarName("Updates"), ctx->GetInputDim("Updates")); ctx->SetOutputDim(framework::GradVarName("Ref"), ctx->GetInputDim("Ref")); diff --git a/paddle/operators/sequence_pool_op.cc b/paddle/operators/sequence_pool_op.cc index bc4af2f704..06c00d31ea 100644 --- a/paddle/operators/sequence_pool_op.cc +++ b/paddle/operators/sequence_pool_op.cc @@ -22,7 +22,7 @@ class SequencePoolOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of SequencePoolOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), @@ -74,7 +74,7 @@ class SequencePoolGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), "Gradient of Out should not be null."); PADDLE_ENFORCE(ctx->HasInput("X"), "The input X should not be null."); diff --git a/paddle/operators/sequence_softmax_op.cc b/paddle/operators/sequence_softmax_op.cc index 621779ab61..ea217ba459 100644 --- a/paddle/operators/sequence_softmax_op.cc +++ b/paddle/operators/sequence_softmax_op.cc @@ -22,7 +22,7 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of SequenceSoftmaxOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), @@ -67,7 +67,7 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Out"), 
"Input(Out) of SequenceSoftmaxGradOp should not be null."); PADDLE_ENFORCE( diff --git a/paddle/operators/sgd_op.cc b/paddle/operators/sgd_op.cc index 31d491f130..2a6a162a02 100644 --- a/paddle/operators/sgd_op.cc +++ b/paddle/operators/sgd_op.cc @@ -22,7 +22,7 @@ class SGDOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Param"), "Input(Param) of SGDOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Grad"), diff --git a/paddle/operators/sigmoid_cross_entropy_with_logits_op.cc b/paddle/operators/sigmoid_cross_entropy_with_logits_op.cc index ede458e011..b6653e1cc7 100644 --- a/paddle/operators/sigmoid_cross_entropy_with_logits_op.cc +++ b/paddle/operators/sigmoid_cross_entropy_with_logits_op.cc @@ -24,7 +24,7 @@ class SigmoidCrossEntropyWithLogitsOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null."); PADDLE_ENFORCE(ctx->HasInput("Labels"), "Input(Labels) should be not null."); @@ -53,7 +53,7 @@ class SigmoidCrossEntropyWithLogitsGradOp using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null."); PADDLE_ENFORCE(ctx->HasInput("Labels"), "Input(Labels) should be not null."); diff --git a/paddle/operators/smooth_l1_loss_op.cc b/paddle/operators/smooth_l1_loss_op.cc index 2d197e3b1b..91391dc945 100644 --- a/paddle/operators/smooth_l1_loss_op.cc +++ b/paddle/operators/smooth_l1_loss_op.cc @@ -22,7 +22,7 @@ class SmoothL1LossOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "X must be initialized."); PADDLE_ENFORCE(ctx->HasInput("Y"), "Y must be initialized."); @@ -94,7 +94,7 @@ class SmoothL1LossGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { auto in_dims = ctx->GetInputDim("X"); auto out_dims = ctx->GetInputDim(framework::GradVarName("Out")); diff --git a/paddle/operators/softmax_op.cc b/paddle/operators/softmax_op.cc index e353afee3e..4c131ed44d 100644 --- a/paddle/operators/softmax_op.cc +++ b/paddle/operators/softmax_op.cc @@ -22,7 +22,7 @@ class SoftmaxOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of SoftmaxOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Y"), @@ -69,7 +69,7 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel { using 
framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should be not null."); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), "Input(Y@GRAD) should be not null."); diff --git a/paddle/operators/softmax_with_cross_entropy_op.cc b/paddle/operators/softmax_with_cross_entropy_op.cc index 42c1ba6fdf..5431a1657c 100644 --- a/paddle/operators/softmax_with_cross_entropy_op.cc +++ b/paddle/operators/softmax_with_cross_entropy_op.cc @@ -83,7 +83,7 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Logits"), "Input(Logits) should be not null."); PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null."); @@ -128,7 +128,7 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Loss")), "Input(Loss@Grad) should not be null."); PADDLE_ENFORCE(ctx->HasInput("Softmax"), diff --git a/paddle/operators/split_op.cc b/paddle/operators/split_op.cc index 5f4b5539af..d5dd4df2a2 100644 --- a/paddle/operators/split_op.cc +++ b/paddle/operators/split_op.cc @@ -24,7 +24,7 @@ class SplitOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of SplitOp should not be null."); PADDLE_ENFORCE_GE(ctx->Outputs("Out").size(), 1UL, diff --git a/paddle/operators/squared_l2_distance_op.cc b/paddle/operators/squared_l2_distance_op.cc index 5a0cb59600..cce4e527c3 100644 --- a/paddle/operators/squared_l2_distance_op.cc +++ b/paddle/operators/squared_l2_distance_op.cc @@ -22,7 +22,7 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of SquaredL2DistanceOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Y"), @@ -86,7 +86,7 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), "Gradient of Out should not be null"); auto out_dims = ctx->GetInputDim(framework::GradVarName("Out")); diff --git a/paddle/operators/sum_op.cc b/paddle/operators/sum_op.cc index c701ee8dde..ffb0cb9211 100644 --- a/paddle/operators/sum_op.cc +++ b/paddle/operators/sum_op.cc @@ -22,7 +22,7 @@ class SumOp : public framework::OperatorWithKernel { using 
framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInputs("X"), "Inputs(X) should not be null"); auto x_dims = ctx->GetInputsDim("X"); PADDLE_ENFORCE(ctx->HasOutput("Out"), diff --git a/paddle/operators/top_k_op.cc b/paddle/operators/top_k_op.cc index 5f22bf1df8..c954819912 100644 --- a/paddle/operators/top_k_op.cc +++ b/paddle/operators/top_k_op.cc @@ -22,7 +22,7 @@ class TopkOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of TopkOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), diff --git a/paddle/operators/transpose_op.cc b/paddle/operators/transpose_op.cc index 0672f9342d..1101bbe3ef 100644 --- a/paddle/operators/transpose_op.cc +++ b/paddle/operators/transpose_op.cc @@ -24,7 +24,7 @@ class TransposeOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null"); auto x_dims = ctx->GetInputDim("X"); @@ -93,7 +93,7 @@ class TransposeOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), "Input(Out@GRAD) should not be null"); diff --git a/paddle/operators/uniform_random_op.cc b/paddle/operators/uniform_random_op.cc index 97b1d0bed4..e330877fc4 100644 --- a/paddle/operators/uniform_random_op.cc +++ b/paddle/operators/uniform_random_op.cc @@ -47,7 +47,7 @@ class UniformRandomOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of UniformRandomOp should not be null."); diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc index a9b6b79903..36450e9268 100644 --- a/paddle/platform/device_context.cc +++ b/paddle/platform/device_context.cc @@ -136,7 +136,7 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; } cudaStream_t CUDADeviceContext::stream() const { return stream_; } -#endif // PADDLE_ONLY_CPU +#endif } // namespace platform } // namespace paddle diff --git a/paddle/platform/enforce.h b/paddle/platform/enforce.h index 15d8446cd8..cd906c3fa9 100644 --- a/paddle/platform/enforce.h +++ b/paddle/platform/enforce.h @@ -41,7 +41,7 @@ limitations under the License. 
*/ #include #include -#endif // PADDLE_ONLY_CPU +#endif namespace paddle { namespace platform { diff --git a/paddle/platform/gpu_info.cc b/paddle/platform/gpu_info.cc index 70ad611d5d..0cab5ffc56 100644 --- a/paddle/platform/gpu_info.cc +++ b/paddle/platform/gpu_info.cc @@ -43,6 +43,8 @@ int GetCurrentDeviceId() { } void SetDeviceId(int id) { + // TODO(qijun): find a better way to cache the cuda device count + PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must be less than GPU count"); PADDLE_ENFORCE(cudaSetDevice(id), "cudaSetDevice failed in paddle::platform::SetDeviceId"); } diff --git a/paddle/platform/gpu_info.h b/paddle/platform/gpu_info.h index fb33db07bd..37665b97d7 100644 --- a/paddle/platform/gpu_info.h +++ b/paddle/platform/gpu_info.h @@ -63,4 +63,4 @@ void GpuMemcpyPeer(void *dst, int dst_device, const void *src, int src_device, } // namespace platform } // namespace paddle -#endif // PADDLE_ONLY_CPU +#endif diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc index 218821b35b..116c99bd2c 100644 --- a/paddle/pybind/protobuf.cc +++ b/paddle/pybind/protobuf.cc @@ -117,7 +117,6 @@ void BindProgramDesc(py::module &m) { .def("append_block", &ProgramDescBind::AppendBlock, py::return_value_policy::reference) .def("block", &ProgramDescBind::Block, py::return_value_policy::reference) - .def("__str__", &ProgramDescBind::DebugString) .def("num_blocks", &ProgramDescBind::Size); } @@ -167,7 +166,9 @@ void BindVarDsec(py::module &m) { .def("set_shape", &VarDescBind::SetShape) .def("set_data_type", &VarDescBind::SetDataType) .def("shape", &VarDescBind::Shape, py::return_value_policy::reference) - .def("data_type", &VarDescBind::GetDataType); + .def("data_type", &VarDescBind::GetDataType) + .def("lod_level", &VarDescBind::GetLodLevel) + .def("set_lod_level", &VarDescBind::SetLoDLevel); } void BindOpDesc(py::module &m) { @@ -191,15 +192,14 @@ void BindOpDesc(py::module &m) { .def("output", &OpDescBind::Output) .def("output_names", &OpDescBind::OutputNames) .def("set_output", &OpDescBind::SetOutput) - .def("__str__", &OpDescBind::DebugString) - .def("__repr__", &OpDescBind::DebugString) .def("has_attr", &OpDescBind::HasAttr) .def("attr_type", &OpDescBind::GetAttrType) .def("attr_names", &OpDescBind::AttrNames) .def("set_attr", &OpDescBind::SetAttr) .def("attr", &OpDescBind::GetAttr) .def("set_block_attr", &OpDescBind::SetBlockAttr) - .def("get_block_attr", &OpDescBind::GetBlockAttr); + .def("get_block_attr", &OpDescBind::GetBlockAttr) + .def("infer_shape", &OpDescBind::InferShape); } } // namespace pybind diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 356c4986e2..0f6e3101e2 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -231,21 +231,6 @@ All parameter, weight, gradient are variables in Paddle.
desc.InitializationErrorString()); return OpRegistry::CreateOp(desc); }) - .def_static("infer_shape", - [](OpDescBind &op_desc, BlockDescBind &block) { - auto op = OpRegistry::CreateOp(*op_desc.Proto()); - auto *op_with_kernel = - dynamic_cast<OperatorWithKernel *>(op.get()); - if (op_with_kernel != nullptr) { - auto ctx = CompileTimeInferShapeContext(op_desc, block); - op_with_kernel->InferShape(&ctx); - } else { - PADDLE_THROW( - "OP(%s) is not type of OperatorWithKernel, " - "should not call this function", - op_desc.Type()); - } - }) .def("backward", [](const OperatorBase &forwardOp, const std::unordered_set<std::string> &no_grad_vars) { diff --git a/python/paddle/v2/framework/graph.py b/python/paddle/v2/framework/graph.py new file mode 100644 index 0000000000..0f0a2847e5 --- /dev/null +++ b/python/paddle/v2/framework/graph.py @@ -0,0 +1,240 @@ +import paddle.v2.framework.core as core +import collections +import numpy as np +import copy + +__all__ = ['Block', 'Variable', 'Program', 'Operator'] + + +class Variable(object): + def __init__(self, + block, + name=None, + shape=None, + dtype=None, + lod_level=None, + **kwargs): + self.block = block + + if name is None: + name = Variable._unique_var_name_() + try: + self.desc = self.block.desc.var(name) + is_new_var = False + except core.EnforceNotMet: + self.desc = self.block.desc.new_var(name) + is_new_var = True + + if shape is not None: + if is_new_var: + self.desc.set_shape(shape) + else: + old_shape = self.shape + shape = tuple(shape) + if shape != old_shape: + raise ValueError( + "Variable {0} has been created before. The previous " + "shape is {1}; the new shape is {2}. They do not " + "match.".format(self.name, old_shape, shape)) + if dtype is not None: + if not isinstance(dtype, core.DataType): + dtype = Variable._convert_np_dtype_to_dtype_(dtype) + if is_new_var: + self.desc.set_data_type(dtype) + else: + old_dtype = self.data_type + if dtype != old_dtype: + raise ValueError("Variable {0} has been created before. " + "The previous data type is {1}; the new " + "data type is {2}. They do not " + "match.".format(self.name, old_dtype, + dtype)) + + if lod_level is not None: + if is_new_var: + self.desc.set_lod_level(lod_level) + else: + if lod_level != self.lod_level: + raise ValueError("Variable {0} has been created before. " + "The previous lod_level is {1}; the new " + "lod_level is {2}. They do not " + "match".format(self.name, self.lod_level, + lod_level)) + self.block.vars[name] = self + self.op = None + + @property + def name(self): + return self.desc.name() + + @property + def shape(self): + # Convert to tuple, to match the numpy API. + return tuple(self.desc.shape()) + + @property + def data_type(self): + return self.desc.data_type() + + @property + def lod_level(self): + return self.desc.lod_level() + + @staticmethod + def _unique_var_name_(): + uid = core.unique_integer() # unique during whole process.
+ return "_generated_var_%d" % uid + + @staticmethod + def _convert_np_dtype_to_dtype_(np_dtype): + dtype = np.dtype(np_dtype) + if dtype == np.float32: + return core.DataType.FP32 + elif dtype == np.float64: + return core.DataType.FP64 + elif dtype == np.float16: + return core.DataType.FP16 + elif dtype == np.int32: + return core.DataType.INT32 + elif dtype == np.int16: + return core.DataType.INT16 + elif dtype == np.int64: + return core.DataType.INT64 + elif dtype == np.bool: + return core.DataType.BOOL + else: + raise ValueError("Not supported numpy dtype " + str(dtype)) + + +class Operator(object): + def __init__(self, + block, + desc, + type=None, + inputs=None, + outputs=None, + attrs=None): + self.block = block + self.desc = desc + if type is not None: + # TODO. + pass + if inputs is not None: + # TODO + pass + if outputs is not None: + # TODO + pass + if attrs is not None: + # TODO + pass + + # TODO: Getters + + +class Block(object): + def __init__(self, program, idx): + self.desc = program.desc.block(idx) + self.vars = dict() # var_name --> var + self.ops = collections.deque() # operator list + self.program = program + + @property + def parent_idx(self): + return self.desc.parent + + @property + def idx(self): + return self.desc.id + + def create_var(self, *args, **kwargs): + return Variable(self, *args, **kwargs) + + def create_parameter(self, *args, **kwargs): + global_block = self.program.global_block() + return Parameter(global_block, *args, **kwargs) + + def append_op(self, *args, **kwargs): + op_desc = self.desc.append_op() + op = Operator(self, op_desc, *args, **kwargs) + self.ops.append(op) + return op + + def prepend_op(self, *args, **kwargs): + op_desc = self.desc.prepend_op() + op = Operator(self, op_desc, *args, **kwargs) + self.ops.appendleft(op) + return op + + +class Program(object): + @classmethod + def instance(cls): + # From https://stackoverflow.com/questions/8212053 + # Making Program as a Singleton class. + if not hasattr(cls, '_instance'): + cls._instance = cls() + return cls._instance + + def __init__(self): + assert not hasattr(self.__class__, + '_instance'), 'Do not call constructor directly!' 
+ self.desc = core.ProgramDesc.instance() + self.blocks = [Block(self, 0)] + self.current_block_idx = 0 + + def global_block(self): + return self.blocks[0] + + def current_block(self): + return self.blocks[self.current_block_idx] + + def create_block(self): + new_block_idx = len(self.blocks) + self.desc.append_block(self.current_block().desc) + self.current_block_idx = new_block_idx + self.blocks.append(Block(self, self.current_block_idx)) + return self.current_block() + + def rollback(self): + self.current_block_idx = self.current_block().parent_idx + + +class Parameter(Variable): + def __init__(self, block, shape, dtype, **kwargs): + if shape is None or dtype is None: + raise ValueError("Parameter must set shape and dtype") + if len(shape) == 0: + raise ValueError("Parameter shape cannot be empty") + + for each in shape: + if each < 0: + raise ValueError("Parameter shape should not be related with " + "batch-size") + + Variable.__init__(self, block, shape=shape, dtype=dtype, **kwargs) + self.trainable = kwargs.get('trainable', True) + self.init_attr = kwargs.get('initialize_attr', { + 'type': 'uniform_random', + 'min': -1.0, + 'max': 1.0 + }) + + self.optimize_attr = kwargs.get('optimize_attr', {'learning_rate': 1.0}) + self._append_initialize_ops_() + + def _append_initialize_ops_(self): + attr = copy.deepcopy(self.init_attr) + op_type = attr.pop('type', None) + block = self.block + assert isinstance(block, Block) + shape = self.shape + attr['dims'] = shape + attr['data_type'] = int(self.data_type) + op = block.prepend_op( + type=op_type, inputs=None, outputs={'Out': [self]}, attrs=attr) + self.op = op + + +# program is a global instance. +g_program = Program.instance() diff --git a/python/paddle/v2/framework/tests/test_activation_op.py b/python/paddle/v2/framework/tests/test_activation_op.py index 701e1a1aee..a28c4431e1 100644 --- a/python/paddle/v2/framework/tests/test_activation_op.py +++ b/python/paddle/v2/framework/tests/test_activation_op.py @@ -33,6 +33,21 @@ class TestSigmoid(OpTest): self.check_grad(['X'], 'Y', max_relative_error=0.008) +class TestLogSigmoid(OpTest): + def setUp(self): + self.op_type = "logsigmoid" + self.inputs = { + 'X': np.random.uniform(-1, 1, [11, 17]).astype("float32") + } + self.outputs = {'Y': np.log(1 / (1 + np.exp(-self.inputs['X'])))} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Y', max_relative_error=0.008) + + class TestTanh(OpTest): def setUp(self): self.op_type = "tanh" @@ -63,6 +78,46 @@ class TestTanhShrink(OpTest): self.check_grad(['X'], 'Y', max_relative_error=0.008) +class TestHardShrink(OpTest): + def setUp(self): + self.op_type = "hard_shrink" + x = np.random.uniform(-1, 1, [4, 4]).astype("float32") + threshold = 0.5 + + self.inputs = {'X': x} + self.attrs = {'lambda': threshold} + + t = np.copy(x) + t[(t >= -threshold) & (t <= threshold)] = 0 + self.outputs = {'Y': t} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Y', max_relative_error=0.005) + + +class TestSoftShrink(OpTest): + def setUp(self): + self.op_type = "softshrink" + lambda_val = 0.1 + self.attrs = {'lambda': lambda_val} + self.inputs = { + 'X': np.random.uniform(0.25, 10, [4, 4]).astype("float32") + } + y = np.copy(self.inputs['X']) + y = (y < -lambda_val) * (y + lambda_val) + (y > lambda_val) * ( + y - lambda_val) + self.outputs = {'Y': y} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + 
self.check_grad(['X'], 'Y', max_relative_error=0.007) + + class TestSqrt(OpTest): def setUp(self): self.op_type = "sqrt" @@ -137,21 +192,26 @@ class TestBRelu(OpTest): self.check_grad(['X'], 'Y', max_relative_error=0.02) -class TestLeakyRelu(OpTest): +class TestRelu6(OpTest): def setUp(self): - self.op_type = "leaky_relu" - alpha = 0.02 - self.attrs = {'alpha': alpha} - self.inputs = {'X': np.random.uniform(-3, 3, [4, 4]).astype("float32")} + self.op_type = "relu6" + x = np.random.uniform(-1, 1, [4, 10]).astype("float32") + threshold = 6.0 + # The same with TestAbs + x[np.abs(x) < 0.005] = 0.02 + x[np.abs(x - threshold) < 0.005] = threshold + 0.02 + + self.inputs = {'X': x} + self.attrs = {'threshold': threshold} self.outputs = { - 'Y': np.maximum(self.inputs['X'], alpha * self.inputs['X']) + 'Y': np.minimum(np.maximum(self.inputs['X'], 0), threshold) } def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Y', max_relative_error=0.007) + self.check_grad(['X'], 'Y', max_relative_error=0.02) class TestSoftRelu(OpTest): @@ -176,6 +236,26 @@ class TestSoftRelu(OpTest): self.check_grad(['X'], 'Y', max_relative_error=0.02) +class TestELU(OpTest): + def setUp(self): + self.op_type = "elu" + x = np.random.uniform(-3, 3, [4, 4]).astype("float32") + alpha = 1. + # Note: unlike other Relu extensions, point 0 on standard ELU function (i.e. alpha = 1) + # is differentiable, so we can skip modifications like x[np.abs(x) < 0.005] = 0.02 here + self.inputs = {'X': x} + self.attrs = {'alpha': alpha} + self.outputs = { + 'Y': np.maximum(0, x) + np.minimum(0, alpha * (np.exp(x) - 1)) + } + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Y', max_relative_error=0.02) + + class TestReciprocal(OpTest): def setUp(self): self.op_type = "reciprocal" @@ -251,6 +331,21 @@ class TestSTanh(OpTest): self.check_grad(['X'], 'Y', max_relative_error=0.007) +class TestSoftplus(OpTest): + def setUp(self): + self.op_type = "softplus" + self.inputs = { + 'X': np.random.uniform(-1, 1, [11, 17]).astype("float32") + } + self.outputs = {'Y': np.log(1 + np.exp(self.inputs['X']))} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Y', max_relative_error=0.007) + + class TestSoftsign(OpTest): def setUp(self): self.op_type = "softsign" diff --git a/python/paddle/v2/framework/tests/test_adamax_op.py b/python/paddle/v2/framework/tests/test_adamax_op.py new file mode 100644 index 0000000000..af81075d6a --- /dev/null +++ b/python/paddle/v2/framework/tests/test_adamax_op.py @@ -0,0 +1,178 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestAdamaxOp1(OpTest): + def setUp(self): + '''Test Adamax Operator with supplied attributes + ''' + self.op_type = "adamax" + param = np.random.uniform(-1, 1, (102, 105)).astype("float32") + grad = np.random.uniform(-1, 1, (102, 105)).astype("float32") + moment = np.random.uniform(-1, 1, (102, 105)).astype("float32") + # The infinity norm is positive + inf_norm = np.random.random((102, 105)).astype("float32") + + learning_rate = 0.002 + beta1 = 0.78 + beta2 = 0.899 + epsilon = 1e-5 + beta1_pow = beta1**10 + + self.inputs = { + 'Param': param, + 'Grad': grad, + 'Moment': moment, + 'InfNorm': inf_norm, + 'LearningRate': np.array([learning_rate]).astype("float32"), + 'Beta1Pow': np.array([beta1_pow]).astype("float32") + } + + self.attrs = {'beta1': beta1, 'beta2': beta2, 'epsilon': epsilon} + + param_out, 
moment_out, inf_norm_out, beta1_pow_out = adamax_step( + self.inputs, self.attrs) + + self.outputs = { + 'ParamOut': param_out, + 'MomentOut': moment_out, + 'InfNormOut': inf_norm_out, + 'Beta1PowOut': beta1_pow_out + } + + def test_check_output(self): + self.check_output() + + +class TestAdamaxOp2(OpTest): + '''Test Adamax Operator with default attributes + ''' + + def setUp(self): + self.op_type = "adamax" + param = np.random.uniform(-1, 1, (102, 105)).astype("float32") + grad = np.random.uniform(-1, 1, (102, 105)).astype("float32") + moment = np.random.uniform(-1, 1, (102, 105)).astype("float32") + # The infinity norm is positive + inf_norm = np.random.random((102, 105)).astype("float32") + + learning_rate = 0.002 + beta1 = 0.9 + beta2 = 0.999 + epsilon = 1e-8 + beta1_pow = beta1**8 + + self.inputs = { + 'Param': param, + 'Grad': grad, + 'Moment': moment, + 'InfNorm': inf_norm, + 'LearningRate': np.array([learning_rate]).astype("float32"), + 'Beta1Pow': np.array([beta1_pow]).astype("float32") + } + + attrs = {'beta1': beta1, 'beta2': beta2, 'epsilon': epsilon} + param_out, moment_out, inf_norm_out, beta1_pow_out = adamax_step( + self.inputs, attrs) + + self.outputs = { + 'ParamOut': param_out, + 'MomentOut': moment_out, + 'InfNormOut': inf_norm_out, + 'Beta1PowOut': beta1_pow_out + } + + def test_check_output(self): + self.check_output() + + +class TestAdamaxOpMultipleSteps(OpTest): + def setUp(self): + '''Test Adamax Operator with supplied attributes + ''' + self.op_type = "adamax" + self.num_steps = 10 + + param = np.random.uniform(-1, 1, (102, 105)).astype("float32") + grad = np.random.uniform(-1, 1, (102, 105)).astype("float32") + moment = np.random.uniform(-1, 1, (102, 105)).astype("float32") + # The infinity norm is positive + inf_norm = np.random.random((102, 105)).astype("float32") + + learning_rate = 0.002 + beta1 = 0.8 + beta2 = 0.99 + epsilon = 1e-5 + beta1_pow = 1 + + self.inputs = { + 'Param': param, + 'Grad': grad, + 'Moment': moment, + 'InfNorm': inf_norm, + 'LearningRate': np.array([learning_rate]).astype("float32"), + 'Beta1Pow': np.array([beta1_pow]).astype("float32") + } + + self.attrs = {'beta1': beta1, 'beta2': beta2, 'epsilon': epsilon} + + param_out, moment_out, inf_norm_out, beta1_pow_out = adamax_step( + self.inputs, self.attrs) + + def test_check_output(self): + for _ in range(self.num_steps): + param_out, moment_out, inf_norm_out, beta1_pow_out = adamax_step( + self.inputs, self.attrs) + + self.outputs = { + 'ParamOut': param_out, + 'MomentOut': moment_out, + 'InfNormOut': inf_norm_out, + 'Beta1PowOut': beta1_pow_out + } + + # Verify output for this step + self.check_output() + + # Output of this step becomes input for next step + self.inputs['Param'] = param_out + self.inputs['Moment'] = moment_out + self.inputs['InfNorm'] = inf_norm_out + self.inputs['Beta1Pow'] = beta1_pow_out + + # Randomize gradient for next step + self.inputs['Grad'] = np.random.uniform( + -1, 1, (102, 105)).astype("float32") + + +def adamax_step(inputs, attributes): + ''' + Simulate one step of the adamax optimizer + :param inputs: dict of inputs + :param attributes: dict of attributes + :return tuple: tuple of output param, moment, inf_norm and + beta1 power accumulator + ''' + param = inputs['Param'] + grad = inputs['Grad'] + moment = inputs['Moment'] + inf_norm = inputs['InfNorm'] + lr = inputs['LearningRate'] + beta1_pow = inputs['Beta1Pow'] + + beta1 = attributes['beta1'] + beta2 = attributes['beta2'] + epsilon = attributes['epsilon'] + + moment_out = beta1 * moment + (1 - beta1) * 
grad + inf_norm_out = np.maximum(beta2 * inf_norm + epsilon, np.abs(grad)) + beta1_pow_out = beta1_pow * beta1 + lr_t = (lr / (1 - beta1_pow_out)) + param_out = param - lr_t * np.divide(moment_out, inf_norm_out) + + return param_out, moment_out, inf_norm_out, beta1_pow_out + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_conv_shift_op.py b/python/paddle/v2/framework/tests/test_conv_shift_op.py new file mode 100644 index 0000000000..b9ab21a06a --- /dev/null +++ b/python/paddle/v2/framework/tests/test_conv_shift_op.py @@ -0,0 +1,47 @@ +import unittest +import numpy as np +from op_test import OpTest + + +def conv_shift_forward(x, y): + out = np.zeros_like(x) + M = x.shape[1] + N = y.shape[1] + y_half_width = (N - 1) / 2 + for i in xrange(M): + for j in xrange(N): + out[:, i] += x[:, (i + j + M - y_half_width) % M] * y[:, j] + return out + + +class TestConvShiftOp(OpTest): + def setUp(self): + self.op_type = "conv_shift" + + batch_size = 4 + x_dim = 17 + y_dim = 3 # must be odd and <= x_dim + x = np.random.random((batch_size, x_dim)).astype("float32") + y = np.random.random((batch_size, y_dim)).astype("float32") + self.inputs = {'X': x, 'Y': y} + + out = conv_shift_forward(x, y) + self.outputs = {'Out': out} + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.05) + + def test_check_grad_ignore_x(self): + self.check_grad( + ['Y'], 'Out', max_relative_error=0.05, no_grad_set=set("X")) + + def test_check_grad_ignore_y(self): + self.check_grad( + ['X'], 'Out', max_relative_error=0.05, no_grad_set=set('Y')) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_fill_constant_op.py b/python/paddle/v2/framework/tests/test_fill_constant_op.py new file mode 100644 index 0000000000..dff7b615aa --- /dev/null +++ b/python/paddle/v2/framework/tests/test_fill_constant_op.py @@ -0,0 +1,35 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestFillConstantOp1(OpTest): + def setUp(self): + '''Test fill_constant op with specified value + ''' + self.op_type = "fill_constant" + + self.inputs = {} + self.attrs = {'shape': [123, 92], 'value': 3.8} + self.outputs = {'Out': np.full((123, 92), 3.8)} + + def test_check_output(self): + self.check_output() + + +class TestFillConstantOp2(OpTest): + def setUp(self): + '''Test fill_constant op with default value + ''' + self.op_type = "fill_constant" + + self.inputs = {} + self.attrs = {'shape': [123, 92]} + self.outputs = {'Out': np.full((123, 92), 0.0)} + + def test_check_output(self): + self.check_output() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_infer_shape.py b/python/paddle/v2/framework/tests/test_infer_shape.py index b38ec9c037..99562890fd 100644 --- a/python/paddle/v2/framework/tests/test_infer_shape.py +++ b/python/paddle/v2/framework/tests/test_infer_shape.py @@ -1,6 +1,6 @@ import unittest + import paddle.v2.framework.core as core -from paddle.v2.framework.op import Operator class TestInferShape(unittest.TestCase): @@ -26,7 +26,7 @@ class TestInferShape(unittest.TestCase): sum_op_desc.set_input("X", ["x1", "x2"]) sum_op_desc.set_output("Out", ["out"]) - core.Operator.infer_shape(sum_op_desc, block) + sum_op_desc.infer_shape(block) self.assertEqual(out.shape(), shape) def test_mul_op(self): @@ -55,7 +55,7 @@ class TestInferShape(unittest.TestCase): 
mul_op_desc.set_attr("x_num_col_dims", 1) mul_op_desc.set_attr("y_num_col_dims", 1) - core.Operator.infer_shape(mul_op_desc, block) + mul_op_desc.infer_shape(block) self.assertEqual(out.shape(), [x_shape[0], y_shape[1]]) diff --git a/python/paddle/v2/framework/tests/test_interp_op.py b/python/paddle/v2/framework/tests/test_interp_op.py new file mode 100644 index 0000000000..066569b96c --- /dev/null +++ b/python/paddle/v2/framework/tests/test_interp_op.py @@ -0,0 +1,28 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestInterpOp(OpTest): + def setUp(self): + self.op_type = "interp" + x = np.random.random((2, 3)).astype("float32") + y = np.random.random((2, 3)).astype("float32") + w = np.random.random(2).astype("float32") + + sub_out = x - y + mul_out = sub_out * w.reshape(2, 1) + out = mul_out + y + + self.inputs = {'X': x, 'Y': y, 'W': w} + self.outputs = {'Out': out, 'SubOut': sub_out, 'MulOut': mul_out} + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X', 'Y'], 'Out') + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_parameter.py b/python/paddle/v2/framework/tests/test_parameter.py new file mode 100644 index 0000000000..3b5d38f257 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_parameter.py @@ -0,0 +1,27 @@ +import unittest +from paddle.v2.framework.graph import g_program +import paddle.v2.framework.core as core + + +class TestParameter(unittest.TestCase): + def test_param(self): + b = g_program.create_block() + param = b.create_parameter( + name='fc.w', + shape=[784, 100], + dtype='float32', + initialize_attr={ + 'type': 'uniform_random', + 'seed': 13, + 'min': -5.0, + 'max': 5.0 + }) + self.assertIsNotNone(param) + self.assertEqual('fc.w', param.name) + self.assertEqual((784, 100), param.shape) + self.assertEqual(core.DataType.FP32, param.data_type) + self.assertEqual(0, param.block.idx) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_pool_max_op.py b/python/paddle/v2/framework/tests/test_pool_max_op.py new file mode 100644 index 0000000000..f0f8aa6089 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_pool_max_op.py @@ -0,0 +1,212 @@ +import unittest +import numpy as np +from op_test import OpTest + + +def max_pool3D_forward_naive(x, + ksize, + strides, + paddings=[0, 0, 0], + global_pool=0): + + N, C, D, H, W = x.shape + if global_pool == 1: + ksize = [D, H, W] + D_out = (D - ksize[0] + 2 * paddings[0]) / strides[0] + 1 + H_out = (H - ksize[1] + 2 * paddings[1]) / strides[1] + 1 + W_out = (W - ksize[2] + 2 * paddings[2]) / strides[2] + 1 + out = np.zeros((N, C, D_out, H_out, W_out)) + mask = np.zeros((N, C, D_out, H_out, W_out)) + for k in xrange(D_out): + d_start = np.max((k * strides[0] - paddings[0], 0)) + d_end = np.min((k * strides[0] + ksize[0] - paddings[0], D)) + for i in xrange(H_out): + h_start = np.max((i * strides[1] - paddings[1], 0)) + h_end = np.min((i * strides[1] + ksize[1] - paddings[1], H)) + for j in xrange(W_out): + w_start = np.max((j * strides[2] - paddings[2], 0)) + w_end = np.min((j * strides[2] + ksize[2] - paddings[2], W)) + x_masked = x[:, :, d_start:d_end, h_start:h_end, w_start:w_end] + + out[:, :, k, i, j] = np.max(x_masked, axis=(2, 3, 4)) + + for n in xrange(N): + for c in xrange(C): + arr = x_masked[n, c, :, :, :] + index = np.where(arr == np.max(arr)) + sub_deep = index[0][0] + sub_row = index[1][0] + sub_col = index[2][0] + index =
((d_start + sub_deep) * H + + (h_start + sub_row)) * W + w_start + sub_col + mask[n, c, k, i, j] = index + + return out, mask + + +def max_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0): + + N, C, H, W = x.shape + if global_pool == 1: + ksize = [H, W] + H_out = (H - ksize[0] + 2 * paddings[0]) / strides[0] + 1 + W_out = (W - ksize[1] + 2 * paddings[1]) / strides[1] + 1 + out = np.zeros((N, C, H_out, W_out)) + mask = np.zeros((N, C, H_out, W_out)) + for i in xrange(H_out): + for j in xrange(W_out): + r_start = np.max((i * strides[0] - paddings[0], 0)) + r_end = np.min((i * strides[0] + ksize[0] - paddings[0], H)) + c_start = np.max((j * strides[1] - paddings[1], 0)) + c_end = np.min((j * strides[1] + ksize[1] - paddings[1], W)) + x_masked = x[:, :, r_start:r_end, c_start:c_end] + + out[:, :, i, j] = np.max(x_masked, axis=(2, 3)) + + for n in xrange(N): + for c in xrange(C): + arr = x_masked[n, c, :, :] + index = np.where(arr == np.max(arr)) + sub_row = index[0][0] + sub_col = index[1][0] + index = (r_start + sub_row) * W + c_start + sub_col + mask[n, c, i, j] = index + + return out, mask + + +class TestMaxPoolWithIndex_Op(OpTest): + def setUp(self): + self.initTestCase() + input = np.random.random(self.shape).astype("float32") + output, mask = self.pool_forward_naive(input, self.ksize, self.strides, + self.paddings, self.global_pool) + + self.attrs = { + 'strides': self.strides, + 'paddings': self.paddings, + 'ksize': self.ksize, + 'globalPooling': self.global_pool, + } + + self.inputs = {'X': input} + self.outputs = {'Out': output, "Mask": mask} + + def test_check_output(self): + self.check_output() + + # def test_check_grad(self): + # self.check_grad(set(['X']), ['Out'], max_relative_error=0.07) + + def initTestCase(self): + self.global_pool = True + self.index = "max_pool3d_with_index" + self.op_type = "%s" % self.index + self.pool_forward_naive = max_pool3D_forward_naive + self.shape = [2, 3, 5, 5, 5] + self.ksize = [3, 3, 3] + self.strides = [1, 1, 1] + self.paddings = [1, 1, 1] + + +class TestCase1(TestMaxPoolWithIndex_Op): + def initTestCase(self): + self.global_pool = True + self.op_type = "max_pool3d_with_index" + self.pool_forward_naive = max_pool3D_forward_naive + self.shape = [2, 3, 5, 5, 5] + self.ksize = [3, 3, 3] + self.strides = [1, 1, 1] + self.paddings = [1, 1, 1] + + +class TestCase2(TestMaxPoolWithIndex_Op): + def initTestCase(self): + self.global_pool = False + self.op_type = "max_pool3d_with_index" + self.pool_forward_naive = max_pool3D_forward_naive + self.shape = [2, 3, 7, 7, 7] + self.ksize = [3, 3, 3] + self.strides = [1, 1, 1] + self.paddings = [1, 1, 1] + + +class TestCase3(TestMaxPoolWithIndex_Op): + def initTestCase(self): + self.global_pool = False + self.op_type = "max_pool3d_with_index" + self.pool_forward_naive = max_pool3D_forward_naive + self.shape = [2, 3, 7, 7, 7] + self.ksize = [3, 3, 3] + self.strides = [2, 2, 2] + self.paddings = [0, 0, 0] + + +class TestCase4(TestMaxPoolWithIndex_Op): + def initTestCase(self): + self.global_pool = True + self.op_type = "max_pool3d_with_index" + self.pool_forward_naive = max_pool3D_forward_naive + self.shape = [2, 3, 5, 5, 5] + self.ksize = [3, 3, 3] + self.strides = [1, 1, 1] + self.paddings = [1, 1, 1] + + +class TestCase5(TestMaxPoolWithIndex_Op): + def initTestCase(self): + self.global_pool = True + self.op_type = "max_pool3d_with_index" + self.pool_forward_naive = max_pool3D_forward_naive + self.shape = [2, 3, 5, 5, 5] + self.ksize = [3, 3, 3] + self.strides = [2, 2, 2] + self.paddings = 
[0, 0, 0] + + +class TestCase6(TestMaxPoolWithIndex_Op): + def initTestCase(self): + self.global_pool = False + self.op_type = "max_pool2d_with_index" + self.pool_forward_naive = max_pool2D_forward_naive + self.shape = [2, 3, 7, 7] + self.ksize = [3, 3] + self.strides = [1, 1] + self.paddings = [1, 1] + + +class TestCase7(TestMaxPoolWithIndex_Op): + def initTestCase(self): + self.global_pool = False + self.op_type = "max_pool2d_with_index" + self.pool_forward_naive = max_pool2D_forward_naive + self.shape = [2, 3, 7, 7] + self.ksize = [3, 3] + self.strides = [2, 2] + self.paddings = [0, 0] + + +class TestCase8(TestMaxPoolWithIndex_Op): + def initTestCase(self): + self.global_pool = True + self.op_type = "max_pool2d_with_index" + self.pool_forward_naive = max_pool2D_forward_naive + self.shape = [2, 3, 5, 5] + self.ksize = [3, 3] + self.strides = [1, 1] + self.paddings = [1, 1] + + +class TestCase9(TestMaxPoolWithIndex_Op): + def initTestCase(self): + self.global_pool = True + self.op_type = "max_pool2d_with_index" + self.pool_forward_naive = max_pool2D_forward_naive + self.shape = [2, 3, 5, 5] + self.ksize = [3, 3] + self.strides = [2, 2] + self.paddings = [0, 0] + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_program.py b/python/paddle/v2/framework/tests/test_program.py new file mode 100644 index 0000000000..b82d1760d6 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_program.py @@ -0,0 +1,36 @@ +import unittest +from paddle.v2.framework.graph import g_program + + +class TestProgram(unittest.TestCase): + def test_program(self): + b = g_program.current_block() + self.assertEqual(-1, b.parent_idx) + self.assertEqual(0, b.idx) + + b = g_program.create_block() + self.assertEqual(1, b.idx) + self.assertEqual(0, b.parent_idx) + + b = g_program.create_block() + self.assertEqual(2, b.idx) + self.assertEqual(1, b.parent_idx) + + g_program.rollback() + + b = g_program.current_block() + self.assertEqual(1, b.idx) + self.assertEqual(0, b.parent_idx) + + b = g_program.create_block() + self.assertEqual(3, b.idx) + self.assertEqual(1, b.parent_idx) + + g_program.rollback() + b = g_program.current_block() + self.assertEqual(1, b.idx) + self.assertEqual(0, b.parent_idx) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_variable.py b/python/paddle/v2/framework/tests/test_variable.py new file mode 100644 index 0000000000..8ea1083ff6 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_variable.py @@ -0,0 +1,40 @@ +import unittest +from paddle.v2.framework.graph import Variable, g_program +import paddle.v2.framework.core as core +import numpy as np + + +class TestVariable(unittest.TestCase): + def test_np_dtype_convert(self): + DT = core.DataType + convert = Variable._convert_np_dtype_to_dtype_ + self.assertEqual(DT.FP32, convert(np.float32)) + self.assertEqual(DT.FP16, convert("float16")) + self.assertEqual(DT.FP64, convert("float64")) + self.assertEqual(DT.INT32, convert("int32")) + self.assertEqual(DT.INT16, convert("int16")) + self.assertEqual(DT.INT64, convert("int64")) + self.assertEqual(DT.BOOL, convert("bool")) + self.assertRaises(ValueError, lambda: convert("int8")) + + def test_var(self): + b = g_program.current_block() + w = b.create_var( + dtype="float64", shape=[784, 100], lod_level=0, name="fc.w") + self.assertEqual(core.DataType.FP64, w.data_type) + self.assertEqual((784, 100), w.shape) + self.assertEqual("fc.w", w.name) + self.assertEqual(0, w.lod_level) + + w = 
b.create_var(name='fc.w') + self.assertEqual(core.DataType.FP64, w.data_type) + self.assertEqual((784, 100), w.shape) + self.assertEqual("fc.w", w.name) + self.assertEqual(0, w.lod_level) + + self.assertRaises(ValueError, + lambda: b.create_var(name="fc.w", shape=(24, 100))) + + +if __name__ == '__main__': + unittest.main()
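
Reviewer note on the new `graph.py`: the sketch below walks the `Program`/`Block`/`Variable` API the same way `test_program.py` and `test_variable.py` do. It assumes the package is built so that `paddle.v2.framework.graph` imports (Python 2, as elsewhere in this tree); the variable names are illustrative only.

```python
from paddle.v2.framework.graph import g_program

root = g_program.global_block()        # block 0; its parent_idx is -1
x = root.create_var(
    name="x", shape=[784, 100], dtype="float32", lod_level=0)
assert x.shape == (784, 100)           # shapes come back as tuples

sub = g_program.create_block()         # nested block whose parent is block 0
# ... build ops inside `sub` with sub.append_op(...) / sub.prepend_op(...)
g_program.rollback()                   # step back to the parent block

assert g_program.current_block().idx == root.idx
```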
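To make the new Adamax test easier to review, here is a self-contained NumPy sketch of the update rule that `adamax_step` encodes; the helper name `adamax_update` and its defaults are illustrative, not part of the operator's interface.

```python
import numpy as np

def adamax_update(param, grad, moment, inf_norm, beta1_pow,
                  lr=0.002, beta1=0.9, beta2=0.999, epsilon=1e-8):
    # First-moment estimate: exponential moving average of the gradient.
    moment = beta1 * moment + (1 - beta1) * grad
    # Exponentially weighted infinity norm; epsilon keeps it away from zero.
    inf_norm = np.maximum(beta2 * inf_norm + epsilon, np.abs(grad))
    # Accumulate beta1^t for bias correction of the first moment.
    beta1_pow = beta1_pow * beta1
    lr_t = lr / (1 - beta1_pow)
    param = param - lr_t * moment / inf_norm
    return param, moment, inf_norm, beta1_pow

# Drive a few steps, mirroring TestAdamaxOpMultipleSteps.
p = np.ones((2, 2), dtype=np.float32)
m, u, b1p = np.zeros_like(p), np.zeros_like(p), 1.0
for _ in range(3):
    g = np.random.uniform(-1, 1, p.shape).astype(np.float32)
    p, m, u, b1p = adamax_update(p, g, m, u, b1p)
```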
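The double loop in `conv_shift_forward` is a circular correlation, which can also be written with `np.roll`; a sanity-check sketch, assuming it is run next to the new test file so the import resolves:

```python
import numpy as np
from test_conv_shift_op import conv_shift_forward  # assumes tests dir on path

def conv_shift_roll(x, y):
    # Same circular correlation as conv_shift_forward, expressed via np.roll.
    N = y.shape[1]
    half = (N - 1) // 2
    return sum(
        np.roll(x, half - j, axis=1) * y[:, j:j + 1] for j in range(N))

x = np.random.random((4, 17)).astype("float32")
y = np.random.random((4, 3)).astype("float32")
assert np.allclose(conv_shift_roll(x, y), conv_shift_forward(x, y))
```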
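Finally, a quick way to see what the `Mask` output of the max-pool-with-index reference holds: each entry is a flat index into the input's H*W plane, so the pooled maxima can be recovered by plain indexing. A sketch under the same assumptions as above (Python 2, run next to the new test file):

```python
import numpy as np
from test_pool_max_op import max_pool2D_forward_naive  # assumes tests dir on path

x = np.random.random((2, 3, 5, 5)).astype("float32")
out, mask = max_pool2D_forward_naive(
    x, ksize=[3, 3], strides=[2, 2], paddings=[0, 0])

n, c = 1, 2
flat = x[n, c].reshape(-1)                      # flattened H*W plane
recovered = flat[mask[n, c].astype(int).reshape(-1)]
assert np.allclose(recovered.reshape(out.shape[2:]), out[n, c])
```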