commit da1181bfc6
@ -0,0 +1,50 @@
INCLUDE(ExternalProject)

SET(NCCL_SOURCE_DIR ${THIRD_PARTY_PATH}/nccl)

INCLUDE_DIRECTORIES(${NCCL_SOURCE_DIR}/src/extern_nccl/src)

if(WITH_DSO)
  # If we use DSO, we do not build nccl, just download the dependencies
  set(NCCL_BUILD_COMMAND "")
  set(NCCL_INSTALL_COMMAND "")
  set(NCCL_INSTALL_DIR "")
else()
  # otherwise, we build nccl and link it.
  set(NCCL_BUILD_COMMAND "make -j 8")
  set(NCCL_INSTALL_COMMAND "make install")
  SET(NCCL_INSTALL_DIR ${THIRD_PARTY_PATH}/install/nccl)
endif()

ExternalProject_Add(
  extern_nccl
  ${EXTERNAL_PROJECT_LOG_ARGS}
  GIT_REPOSITORY "https://github.com/NVIDIA/nccl.git"
  GIT_TAG "v1.3.4-1"
  PREFIX "${NCCL_SOURCE_DIR}"
  UPDATE_COMMAND ""
  CONFIGURE_COMMAND ""
  BUILD_COMMAND "${NCCL_BUILD_COMMAND}"
  INSTALL_COMMAND "${NCCL_INSTALL_COMMAND}"
  INSTALL_DIR "${NCCL_INSTALL_DIR}"
  TEST_COMMAND ""
)

if (WITH_DSO)
  if (${CMAKE_VERSION} VERSION_LESS "3.3.0")
    set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/lib_any_dummy.c)
    file(WRITE ${dummyfile} "const char * dummy_any = \"${dummyfile}\";")
    add_library(nccl STATIC ${dummyfile})
  else()
    add_library(nccl INTERFACE)
  endif()
else()
  ADD_LIBRARY(nccl STATIC IMPORTED GLOBAL)
  SET_PROPERTY(TARGET nccl PROPERTY IMPORTED_LOCATION
               ${NCCL_INSTALL_DIR}/lib/libnccl.a)
endif()

add_dependencies(nccl extern_nccl)

LIST(APPEND external_project_dependencies nccl)
@ -1,30 +0,0 @@
if (NOT WITH_GPU)
  return()
endif()

set(NCCL_ROOT "/usr" CACHE PATH "CUDNN ROOT")
find_path(NCCL_INCLUDE_DIR nccl.h PATHS
  ${NCCL_ROOT} ${NCCL_ROOT}/include
  $ENV{NCCL_ROOT} $ENV{NCCL_ROOT}/include ${CUDA_TOOLKIT_INCLUDE}
  NO_DEFAULT_PATH)

get_filename_component(__libpath_hist ${CUDA_CUDART_LIBRARY} PATH)

set(TARGET_ARCH "x86_64")
if(NOT ${CMAKE_SYSTEM_PROCESSOR})
  set(TARGET_ARCH ${CMAKE_SYSTEM_PROCESSOR})
endif()

list(APPEND NCCL_CHECK_LIBRARY_DIRS
  ${NCCL_ROOT}
  ${NCCL_ROOT}/lib64
  ${NCCL_ROOT}/lib
  ${NCCL_ROOT}/lib/${TARGET_ARCH}-linux-gnu
  $ENV{NCCL_ROOT}
  $ENV{NCCL_ROOT}/lib64
  $ENV{NCCL_ROOT}/lib
  /usr/lib)
find_library(NCCL_LIBRARY NAMES libnccl.so libnccl.dylib # libcudnn_static.a
  PATHS ${NCCL_CHECK_LIBRARY_DIRS} ${NCCL_INCLUDE_DIR} ${__libpath_hist}
  NO_DEFAULT_PATH
  DOC "Path to nccl library.")
@ -0,0 +1,232 @@
## Survey on Graph

Neural network frameworks often provide a symbolic API so that users can describe a network topology conveniently. This doc mainly focuses on the symbolic APIs of the most popular neural network frameworks and tries to find out how to parse a symbolic configuration into a portable file, such as protobuf or JSON.

### MXNet

The core concept of the symbolic API is `Symbol`. MXNet implements the `Symbol` class in C++ and exports it to Python through the C API. Please refer to the comments in MXNet:

> `Symbol` is a helper class used to represent the operator node in a Graph.
> `Symbol` acts as an interface for building graphs from different components like Variable, Functor and Group. `Symbol` is also exported to the Python front-end (while Graph is not) to enable quick testing and deployment. Conceptually, a symbol is the final operation of a graph and thus includes all the information required (the graph) to evaluate its output value.

A simple network topology written with Symbol is as follows:

```python
def get_symbol(num_classes=10, **kwargs):
    data = mx.symbol.Variable('data')
    data = mx.symbol.Flatten(data=data)
    fc1 = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=128)
    act1 = mx.symbol.Activation(data=fc1, name='relu1', act_type="relu")
    fc2 = mx.symbol.FullyConnected(data=act1, name='fc2', num_hidden=64)
    act2 = mx.symbol.Activation(data=fc2, name='relu2', act_type="relu")
    fc3 = mx.symbol.FullyConnected(data=act2, name='fc3', num_hidden=num_classes)
    mlp = mx.symbol.SoftmaxOutput(data=fc3, name='softmax')
    return mlp
```

`Variable` here is actually a Symbol. Every basic Symbol corresponds to one Node, and every Node has its own NodeAttr. There is an `op` field in the NodeAttr class; when a Symbol represents a Variable (often input data), the `op` field is null.

A Symbol contains a data member, `std::vector<NodeEntry> outputs`, and each NodeEntry contains a pointer to a Node. By following the Node pointers we can recover the whole Graph.

A Symbol can also be saved to a JSON file.
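For reference, here is a simplified sketch of these structures. The field names loosely follow NNVM's `Node` and `NodeEntry`, but the declarations are abridged and illustrative rather than the exact MXNet code:

```cpp
#include <cstdint>
#include <memory>
#include <string>
#include <vector>

// Abridged, illustrative versions of the NNVM-style structures described
// above; not the exact MXNet/NNVM declarations.
struct Node;
using NodePtr = std::shared_ptr<Node>;

// A NodeEntry refers to one output of a Node.
struct NodeEntry {
  NodePtr node;    // the Node that produces this output
  uint32_t index;  // which output of that Node
};

struct Node {
  std::string op;                 // operator name; empty for a Variable
  std::vector<NodeEntry> inputs;  // edges to the producing nodes
};

// A Symbol only holds output entries; recursively following
// NodeEntry::node recovers the whole graph.
struct Symbol {
  std::vector<NodeEntry> outputs;
};
```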

Here is a detailed example:

```
>>> import mxnet as mx
>>> data = mx.symbol.Variable('data')
>>> print data.debug_str()
Variable:data

>>> data = mx.symbol.Flatten(data=data)
>>> print data.debug_str()
Symbol Outputs:
	output[0]=flatten0(0)
Variable:data
--------------------
Op:Flatten, Name=flatten0
Inputs:
	arg[0]=data(0) version=0

>>> fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128)
>>> print fc1.debug_str()
Symbol Outputs:
	output[0]=fc1(0)
Variable:data
--------------------
Op:Flatten, Name=flatten0
Inputs:
	arg[0]=data(0) version=0
Variable:fc1_weight
Variable:fc1_bias
--------------------
Op:FullyConnected, Name=fc1
Inputs:
	arg[0]=flatten0(0)
	arg[1]=fc1_weight(0) version=0
	arg[2]=fc1_bias(0) version=0
Attrs:
	num_hidden=128
```

### TensorFlow

The core concept of the symbolic API is `Tensor`. TensorFlow defines `Tensor` in Python. Please refer to the comments in TensorFlow:

> A `Tensor` is a symbolic handle to one of the outputs of an `Operation`. It does not hold the values of that operation's output, but instead provides a means of computing those values in a TensorFlow [Session](https://www.tensorflow.org/api_docs/python/tf/Session).

A simple example is as follows:

```python
# Build a dataflow graph.
c = tf.constant([[1.0, 2.0], [3.0, 4.0]])
d = tf.constant([[1.0, 1.0], [0.0, 1.0]])
e = tf.matmul(c, d)

# Construct a `Session` to execute the graph.
sess = tf.Session()

# Execute the graph and store the value that `e` represents in `result`.
result = sess.run(e)
```

The main methods of `Tensor` are as follows:

```python
@property
def op(self):
    """The `Operation` that produces this tensor as an output."""
    return self._op

@property
def dtype(self):
    """The `DType` of elements in this tensor."""
    return self._dtype

@property
def graph(self):
    """The `Graph` that contains this tensor."""
    return self._op.graph

@property
def name(self):
    """The string name of this tensor."""
    if not self._op.name:
        raise ValueError("Operation was not named: %s" % self._op)
    return "%s:%d" % (self._op.name, self._value_index)

@property
def device(self):
    """The name of the device on which this tensor will be produced, or None."""
    return self._op.device
```

A Tensor can be passed to a session as a run target. Through its `op` and `graph` properties, a Tensor reaches all the information of the Graph and tracks its data dependencies.

Here is a detailed example:

```
>>> import tensorflow as tf
>>> c = tf.constant([[1.0, 2.0], [3.0, 4.0]])
>>> print c.graph
<tensorflow.python.framework.ops.Graph object at 0x10f256d50>
>>> d = tf.constant([[1.0, 1.0], [0.0, 1.0]])
>>> print d.graph
<tensorflow.python.framework.ops.Graph object at 0x10f256d50>
>>> e = tf.matmul(c, d)
>>> print e.graph
<tensorflow.python.framework.ops.Graph object at 0x10f256d50>
```

### DyNet

The core concept of the symbolic API is `Expression`, and DyNet defines the `Expression` class in C++.

A simple example is as follows:

```cpp
ComputationGraph cg;
Expression W = parameter(cg, pW);

Expression in = input(cg, xs[i]);
Expression label = input(cg, ys[i]);
Expression pred = W * in;
Expression loss = square(pred - label);
```

Input data and parameters are also represented by Expressions. Every basic Expression corresponds to a Node, and input data is also a Node.

Expression has a data member ComputationGraph, which is modified while the user configures the network. An Expression can be a run target, because it carries all of its dependencies.

Here is a detailed example:

Write the topology in C++:

```cpp
ComputationGraph cg;
Expression W = parameter(cg, pW);
cg.print_graphviz();

Expression pred = W * xs[i];
cg.print_graphviz();

Expression loss = square(pred - ys[i]);
cg.print_graphviz();
```

Compile and print:

```
# first print
digraph G {
  rankdir=LR;
  nodesep=.05;
  N0 [label="v0 = parameters({1}) @ 0x7ffe4de00110"];
}
# second print
digraph G {
  rankdir=LR;
  nodesep=.05;
  N0 [label="v0 = parameters({1}) @ 0x7ffe4de00110"];
  N1 [label="v1 = v0 * -0.98"];
  N0 -> N1;
}
# third print
digraph G {
  rankdir=LR;
  nodesep=.05;
  N0 [label="v0 = parameters({1}) @ 0x7ffe4de00110"];
  N1 [label="v1 = v0 * -0.98"];
  N0 -> N1;
  N2 [label="v2 = -1.88387 - v1"];
  N1 -> N2;
  N3 [label="v3 = -v2"];
  N2 -> N3;
  N4 [label="v4 = square(v3)"];
  N3 -> N4;
}
```

### Conclusion

Symbol in MXNet, Tensor in TensorFlow, and Expression in DyNet are concepts at the same level. We use the unified name Expression here; this concept has the following features (a minimal sketch follows the list):

- Users write the topology with a symbolic API, and every returned value is an Expression, including input data and parameters.
- Each Expression corresponds to a global Graph, and Expressions can be composed.
- An Expression tracks all of its dependencies and can be taken as a run target.
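
As a thought experiment (none of the frameworks above exposes exactly this API, and every name below is made up for illustration), the three features could be captured roughly as follows:

```cpp
#include <memory>
#include <utility>
#include <vector>

// Hypothetical types for illustration only; not the API of MXNet,
// TensorFlow, or DyNet.
struct Node;

struct Graph {
  std::vector<std::shared_ptr<Node>> nodes;  // the single global graph
};

struct Node {
  std::vector<std::shared_ptr<Node>> inputs;  // dependencies of this node
};

// Everything the user touches (input data, parameters, layer outputs) is an
// Expression that refers to one node of the global Graph.
class Expression {
 public:
  Expression(Graph* graph, std::shared_ptr<Node> node)
      : graph_(graph), node_(std::move(node)) {}

  // Expressions compose: the result is a new node whose inputs are the two
  // operands, so the returned Expression carries its whole dependency chain
  // and could later be handed to a runtime as a run target.
  friend Expression operator*(const Expression& a, const Expression& b) {
    auto n = std::make_shared<Node>();
    n->inputs = {a.node_, b.node_};
    a.graph_->nodes.push_back(n);
    return Expression(a.graph_, n);
  }

 private:
  Graph* graph_;
  std::shared_ptr<Node> node_;
};
```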
@ -0,0 +1,107 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/conv2dtranspose_op.h"

namespace paddle {
namespace operators {

void Conv2DTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
  PADDLE_ENFORCE(ctx->HasInput("Input"),
                 "Input(Input) of Conv2DTransposeOp should not be null.");
  PADDLE_ENFORCE(ctx->HasInput("Filter"),
                 "Input(Filter) of Conv2DTransposeOp should not be null.");
  PADDLE_ENFORCE(ctx->HasOutput("Output"),
                 "Output(Output) of Conv2DTransposeOp should not be null.");

  auto in_dims = ctx->GetInputDim("Input");
  auto filter_dims = ctx->GetInputDim("Filter");
  std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
  std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");

  for (size_t i = 0; i < paddings.size(); ++i) {
    PADDLE_ENFORCE_EQ(paddings[i], 0,
                      "No Padding allowed in conv transpose op.");
  }

  PADDLE_ENFORCE_EQ(in_dims.size(), 4,
                    "Conv2DTransposeOp input should be 4-D tensor.");
  PADDLE_ENFORCE_EQ(filter_dims.size(), 4,
                    "Conv2DTransposeOp filter should be 4-D tensor.");
  PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[0],
                    "input and kernel input dimension should be equal.");

  // Transposed convolution with zero padding: each spatial output size is
  // (input_size - 1) * stride + filter_size.
  auto output_height = (in_dims[2] - 1) * strides[0] + filter_dims[2];
  auto output_width = (in_dims[3] - 1) * strides[1] + filter_dims[3];
  ctx->SetOutputDim("Output",
                    {in_dims[0], filter_dims[1], output_height, output_width});
}

Conv2DTransposeOpMaker::Conv2DTransposeOpMaker(
    framework::OpProto* proto, framework::OpAttrChecker* op_checker)
    : OpProtoAndCheckerMaker(proto, op_checker) {
  AddInput(
      "Input",
      "(Tensor) The input tensor of convolution transpose operator. "
      "The format of input tensor is NCHW. Where N is batch size, C is the "
      "number of input channels, H and W is the height and width of image.");
  AddInput("Filter",
           "(Tensor) The filter tensor of convolution transpose operator."
           "The format of the filter tensor is CMHW, where C is the number of "
           "output image channels, M is the number of input image channels, "
           "H and W is height and width of filter. "
           "We enforce groups number == 1 and padding == 0 in "
           "convolution transpose Scenario.");
  AddOutput("Output",
            "(Tensor) The output tensor of convolution transpose operator."
            "The format of output tensor is also NCHW.");
  AddAttr<std::vector<int>>("strides",
                            "strides of convolution transpose operator.")
      .SetDefault({1, 1});
  AddAttr<std::vector<int>>("paddings",
                            "paddings of convolution transpose operator.")
      .SetDefault({0, 0});
  AddComment(R"DOC(
The convolution transpose operation calculates the output based on the input, filter
and strides, paddings, groups parameters. The size of each dimension of the
parameters is checked in the infer-shape.
)DOC");
}

void Conv2DTransposeOpGrad::InferShape(
    framework::InferShapeContext* ctx) const {
  auto in_dims = ctx->GetInputDim("Input");
  auto filter_dims = ctx->GetInputDim("Filter");
  if (ctx->HasOutput(framework::GradVarName("Input"))) {
    ctx->SetOutputDim(framework::GradVarName("Input"), in_dims);
  }
  if (ctx->HasOutput(framework::GradVarName("Filter"))) {
    ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims);
  }
}

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(conv2dtranspose, ops::Conv2DTransposeOp,
            ops::Conv2DTransposeOpMaker, conv2dtranspose_grad,
            ops::Conv2DTransposeOpGrad);

REGISTER_OP_CPU_KERNEL(
    conv2dtranspose,
    ops::GemmConv2DTransposeKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
    conv2dtranspose_grad,
    ops::GemmConv2DTransposeGradKernel<paddle::platform::CPUPlace, float>);
@ -0,0 +1,24 @@
/* Copyright (c) 2016 PaddlePaddle Authors All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/conv2dtranspose_op.h"

namespace ops = paddle::operators;

REGISTER_OP_GPU_KERNEL(
    conv2dtranspose,
    ops::GemmConv2DTransposeKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
    conv2dtranspose_grad,
    ops::GemmConv2DTransposeGradKernel<paddle::platform::GPUPlace, float>);
File diff suppressed because it is too large
@ -1,2 +1,3 @@
cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags)
nv_library(dynload_cuda SRCS cublas.cc cudnn.cc curand.cc nccl.cc DEPS dynamic_loader)
nv_library(dynload_cuda SRCS cublas.cc cudnn.cc curand.cc nccl.cc
           DEPS dynamic_loader nccl)
@ -0,0 +1,102 @@
import unittest
import numpy as np
from op_test import OpTest


def conv2dtranspose_forward_naive(input_, filter_, conv2dtranspose_param):
    # input: [2, 3, 5, 5]
    in_n, in_c, in_h, in_w = input_.shape
    # filter: [3, 6, 3, 3]
    f_c, out_c, f_h, f_w = filter_.shape
    assert in_c == f_c

    stride, pad = conv2dtranspose_param['stride'], conv2dtranspose_param['pad']
    out_h = (in_h - 1) * stride[0] + f_h
    out_w = (in_w - 1) * stride[1] + f_w

    out = np.zeros((in_n, out_c, out_h, out_w))

    for n in range(in_n):
        for i in range(in_h):
            for j in range(in_w):
                input_masked = input_[n, :, i, j]  # (c)
                input_masked = np.reshape(input_masked, (in_c, 1, 1))
                input_masked = np.tile(input_masked, (1, f_h, f_w))

                for k in range(out_c):
                    tmp_out = np.sum(input_masked * filter_[:, k, :, :], axis=0)
                    i1, i2 = i * stride[0], i * stride[0] + f_h
                    # use the width stride on the j axis
                    j1, j2 = j * stride[1], j * stride[1] + f_w
                    out[n, k, i1:i2, j1:j2] += tmp_out

    return out


class TestConv2dTransposeOp(OpTest):
    def setUp(self):
        # init as conv transpose
        self.init_op_type()

        # input [2, 3, 5, 5] -> kernel [3, 6, 3, 3] -> output [2, 6, 7, 7]
        self.init_test_case()

        conv2dtranspose_param = {'stride': self.stride, 'pad': self.pad}
        input_ = np.random.random(self.input_size).astype("float32")
        filter_ = np.random.random(self.filter_size).astype("float32")
        output = conv2dtranspose_forward_naive(input_, filter_,
                                               conv2dtranspose_param)
        # print 'deconv output py', output, output.shape

        self.inputs = {'Input': input_, 'Filter': filter_}
        self.attrs = {
            'strides': self.stride,
            'paddings': self.pad,
            # 'dilations': self.dilations
        }
        self.outputs = {'Output': output}

    def test_check_output(self):
        print 'check output here'
        self.check_output()

    def test_check_grad(self):
        self.check_grad(
            set(['Input', 'Filter']), 'Output', max_relative_error=0.05)

    def test_check_grad_no_filter(self):
        self.check_grad(
            ['Input'],
            'Output',
            max_relative_error=0.05,
            no_grad_set=set(['Filter']))

    def test_check_grad_no_input(self):
        self.check_grad(
            ['Filter'],
            'Output',
            max_relative_error=0.05,
            no_grad_set=set(['Input']))

    def init_test_case(self):
        self.pad = [0, 0]
        self.stride = [1, 1]
        self.dilations = [1, 1]
        self.input_size = [2, 3, 5, 5]  # NCHW
        f_c = self.input_size[1]
        self.filter_size = [f_c, 6, 3, 3]

    def init_op_type(self):
        self.op_type = "conv2dtranspose"


"""
class TestCudnn(TestConv2dOp):
    def init_group(self):
        self.groups = 1

    def init_op_type(self):
        self.op_type = "conv_cudnn"
"""

if __name__ == '__main__':
    unittest.main()
@ -0,0 +1,38 @@
import unittest
from paddle.v2.framework.layers import *
from paddle.v2.framework.framework import g_program


class TestRNN(unittest.TestCase):
    def test_rnn(self):
        img = data(
            shape=[
                80,  # sequence length
                22,  # image height
                22
            ],  # image width
            data_type='float32',
            name='image')
        hidden = fc(input=img, size=100, act='sigmoid', num_flatten_dims=2)
        self.assertEqual((-1, 80, 100), hidden.shape)
        hidden = fc(input=hidden, size=100, act='sigmoid', num_flatten_dims=2)
        self.assertEqual((-1, 80, 100), hidden.shape)

        rnn = StaticRNN()
        with rnn.step():
            hidden = rnn.step_input(hidden)
            self.assertEqual((-1, 100), hidden.shape)
            memory = rnn.memory(shape=(-1, 32), dtype='float32', init_value=0.0)

            rnn_out = fc(input=[hidden, memory], size=32, act='sigmoid')
            self.assertEqual((-1, 32), rnn_out.shape)
            rnn.update_memory(memory, rnn_out)
            rnn.output(rnn_out)

        out = rnn()
        self.assertEqual((-1, 80, 32), out.shape)
        print g_program


if __name__ == '__main__':
    unittest.main()