Merge remote-tracking branch 'upstream/develop' into factorization_machine_layer

release/0.11.0
wangmeng28 7 years ago
commit 13ec6f99fe

@ -5,6 +5,7 @@ height = 224
width = 224
num_class = 1000
batch_size = get_config_arg('batch_size', int, 128)
use_gpu = get_config_arg('use_gpu', bool, True)
args = {'height': height, 'width': width, 'color': True, 'num_class': num_class}
define_py_data_sources2(
@ -16,6 +17,8 @@ settings(
learning_method=MomentumOptimizer(0.9),
regularization=L2Regularization(0.0005 * batch_size))
conv_projection = conv_projection if use_gpu else img_conv_layer
def inception2(name, input, channels, \
filter1,
filter3R, filter3,
@ -138,7 +141,7 @@ def inception(name, input, channels, \
cat = concat_layer(
name=name,
input=[cov1, cov3, cov5, covprj],
bias_attr=True,
bias_attr=True if use_gpu else False,
act=ReluActivation())
return cat

@ -40,6 +40,7 @@ fi
for use_mkldnn in True False; do
for batchsize in 64 128 256; do
train vgg 19 $batchsize $use_mkldnn
train resnet 50 $batchsize $use_mkldnn
train googlenet v1 $batchsize $use_mkldnn
done
done

@ -212,6 +212,37 @@ Error __must_check backward(Argument& act) {
}
END_DEFINE_ACTIVATION(sequence_softmax)
/*
* @brief SoftSign Activation.
* \f[
* f(z) = \frac{z}{1 + |z|}
* \f]
*/
BEGIN_DEFINE_ACTIVATION(softsign)
private:
MatrixPtr denominator_;
Error __must_check forward(Argument& act) {
size_t height = act.value->getHeight();
size_t width = act.value->getWidth();
Matrix::resizeOrCreate(
denominator_, height, width, false, useGpu(act.deviceId));
denominator_->assign(*act.value);
denominator_->abs2();
denominator_->add(1.);
act.value->dotDiv(*act.value, *denominator_);
return Error();
}
Error __must_check backward(Argument& act) {
denominator_->square2();
denominator_->scalarDiv(*denominator_, 1.);
act.grad->dotMul(*act.grad, *denominator_);
return Error();
}
END_DEFINE_ACTIVATION(softsign)
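For reference, the backward pass reuses denominator_ = 1 + |z| saved during forward; squaring it and taking its reciprocal is exactly the softsign derivative,

\[
f(z) = \frac{z}{1 + |z|}, \qquad
f'(z) = \frac{(1 + |z|) - z \cdot \operatorname{sgn}(z)}{(1 + |z|)^2} = \frac{1}{(1 + |z|)^2},
\]

so act.grad is scaled elementwise by 1 / (1 + |z|)^2.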
/**
* @brief Relu Activation.
* forward. y = max(0, z)

@ -40,7 +40,8 @@ REGISTER_OP(conv_cudnn, ops::ConvOp, ops::CudnnConvOpMaker, conv_cudnn_grad,
ops::ConvOpGrad);
REGISTER_OP_CPU_KERNEL(conv_cudnn,
ops::GemmConvKernel<paddle::platform::CPUPlace, float>);
ops::GemmConvKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(
conv_cudnn_grad,
ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>);
conv_cudnn_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvGradKernel<paddle::platform::CPUPlace, double>);

@ -259,6 +259,8 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
} // namespace operators
} // namespace paddle
REGISTER_OP_GPU_KERNEL(conv_cudnn, paddle::operators::CudnnConvOpKernel<float>);
REGISTER_OP_GPU_KERNEL(conv_cudnn, paddle::operators::CudnnConvOpKernel<float>,
paddle::operators::CudnnConvOpKernel<double>);
REGISTER_OP_GPU_KERNEL(conv_cudnn_grad,
paddle::operators::CudnnConvGradOpKernel<float>);
paddle::operators::CudnnConvGradOpKernel<float>,
paddle::operators::CudnnConvGradOpKernel<double>);

@ -61,10 +61,12 @@ REGISTER_OP(conv2d_transpose_cudnn, ops::ConvTransposeOp,
REGISTER_OP_CPU_KERNEL(
conv2d_transpose_cudnn,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>);
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(
conv2d_transpose_cudnn_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>);
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP(conv3d_transpose_cudnn, ops::ConvTransposeOp,
ops::CudnnConv3DTransposeOpMaker, conv3d_transpose_cudnn_grad,
@ -72,7 +74,9 @@ REGISTER_OP(conv3d_transpose_cudnn, ops::ConvTransposeOp,
REGISTER_OP_CPU_KERNEL(
conv3d_transpose_cudnn,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>);
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(
conv3d_transpose_cudnn_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>);
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, double>);

@ -235,11 +235,15 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn,
ops::CudnnConvTransposeOpKernel<float>);
ops::CudnnConvTransposeOpKernel<float>,
ops::CudnnConvTransposeOpKernel<double>);
REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn_grad,
ops::CudnnConvTransposeGradOpKernel<float>);
ops::CudnnConvTransposeGradOpKernel<float>,
ops::CudnnConvTransposeGradOpKernel<double>);
REGISTER_OP_GPU_KERNEL(conv3d_transpose_cudnn,
ops::CudnnConvTransposeOpKernel<float>);
ops::CudnnConvTransposeOpKernel<float>,
ops::CudnnConvTransposeOpKernel<double>);
REGISTER_OP_GPU_KERNEL(conv3d_transpose_cudnn_grad,
ops::CudnnConvTransposeGradOpKernel<float>);
ops::CudnnConvTransposeGradOpKernel<float>,
ops::CudnnConvTransposeGradOpKernel<double>);

@ -498,8 +498,8 @@ template class Pool3dGradFunctor<
* Ksize, strides, paddings are two elements. These two elements represent
* height and width, respectively.
*/
template <typename T>
class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> {
template <typename T1, typename T2>
class MaxPool2dWithIndexFunctor<platform::CPUPlace, T1, T2> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, std::vector<int>& ksize,
@ -520,9 +520,9 @@ class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> {
const int input_stride = input_height * input_width;
const int output_stride = output_height * output_width;
const T* input_data = input.data<T>();
T* output_data = output->mutable_data<T>(context.GetPlace());
T* mask_data = mask->mutable_data<T>(context.GetPlace());
const T1* input_data = input.data<T1>();
T1* output_data = output->mutable_data<T1>(context.GetPlace());
T2* mask_data = mask->mutable_data<T2>(context.GetPlace());
for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) {
@ -535,7 +535,7 @@ class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> {
int wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0);
T ele = static_cast<T>(-FLT_MAX);
T1 ele = static_cast<T1>(-FLT_MAX);
int index = -1;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
@ -563,8 +563,8 @@ class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> {
* Ksize, strides, paddings are two elements. These two elements represent
* height and width, respectively.
*/
template <typename T>
class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T> {
template <typename T1, typename T2>
class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T1, T2> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& output_grad,
@ -580,9 +580,9 @@ class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T> {
const int input_stride = input_height * input_width;
const int output_stride = output_height * output_width;
const T* mask_data = mask.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
const T2* mask_data = mask.data<T2>();
const T1* output_grad_data = output_grad.data<T1>();
T1* input_grad_data = input_grad->mutable_data<T1>(context.GetPlace());
for (int n = 0; n < batch_size; ++n) {
for (int c = 0; c < output_channels; ++c) {
@ -602,18 +602,18 @@ class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T> {
}
};
template class MaxPool2dWithIndexFunctor<platform::CPUPlace, float>;
template class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, float>;
template class MaxPool2dWithIndexFunctor<platform::CPUPlace, double>;
template class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, double>;
template class MaxPool2dWithIndexFunctor<platform::CPUPlace, float, int>;
template class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, float, int>;
template class MaxPool2dWithIndexFunctor<platform::CPUPlace, double, int>;
template class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, double, int>;
/*
* All tensors are in NCDHW format.
* Ksize, strides, paddings are three elements. These three elements represent
* depth, height and width, respectively.
*/
template <typename T>
class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
template <typename T1, typename T2>
class MaxPool3dWithIndexFunctor<platform::CPUPlace, T1, T2> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, std::vector<int>& ksize,
@ -639,9 +639,9 @@ class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
const int input_stride = input_depth * input_height * input_width;
const int output_stride = output_depth * output_height * output_width;
const T* input_data = input.data<T>();
T* output_data = output->mutable_data<T>(context.GetPlace());
T* mask_data = mask->mutable_data<T>(context.GetPlace());
const T1* input_data = input.data<T1>();
T1* output_data = output->mutable_data<T1>(context.GetPlace());
T2* mask_data = mask->mutable_data<T2>(context.GetPlace());
for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) {
@ -659,7 +659,7 @@ class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
wstart = std::max(wstart, 0);
int output_idx = (pd * output_height + ph) * output_width + pw;
T ele = static_cast<T>(-FLT_MAX);
T1 ele = static_cast<T1>(-FLT_MAX);
int index = -1;
for (int d = dstart; d < dend; ++d) {
for (int h = hstart; h < hend; ++h) {
@ -691,8 +691,8 @@ class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
* Ksize, strides, paddings are three elements. These three elements represent
* depth, height and width, respectively.
*/
template <typename T>
class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, T> {
template <typename T1, typename T2>
class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, T1, T2> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& output_grad,
@ -710,9 +710,9 @@ class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, T> {
const int input_stride = input_depth * input_height * input_width;
const int output_stride = output_depth * output_height * output_width;
const T* mask_data = mask.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
const T2* mask_data = mask.data<T2>();
const T1* output_grad_data = output_grad.data<T1>();
T1* input_grad_data = input_grad->mutable_data<T1>(context.GetPlace());
for (int n = 0; n < batch_size; ++n) {
for (int c = 0; c < output_channels; ++c) {
@ -735,10 +735,10 @@ class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, T> {
}
};
template class MaxPool3dWithIndexFunctor<platform::CPUPlace, float>;
template class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, float>;
template class MaxPool3dWithIndexFunctor<platform::CPUPlace, double>;
template class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, double>;
template class MaxPool3dWithIndexFunctor<platform::CPUPlace, float, int>;
template class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, float, int>;
template class MaxPool3dWithIndexFunctor<platform::CPUPlace, double, int>;
template class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, double, int>;
} // namespace math
} // namespace operators
} // namespace paddle
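The extra template parameter separates the value type T1 (float or double, per the explicit instantiations above) from the integer mask type T2 that stores the argmax position. A rough NumPy reference of the 2-D forward computation, with a hypothetical helper name and assuming every pooling window overlaps the input:

import numpy as np

def max_pool2d_with_index_ref(x, ksize, strides, paddings):
    # x: (N, C, H, W); returns pooled values plus a mask holding h * W + w of
    # the max element inside the unpadded input, like the C++ loop above.
    n, c, h, w = x.shape
    kh, kw = ksize
    sh, sw = strides
    ph, pw = paddings
    oh = (h - kh + 2 * ph) // sh + 1
    ow = (w - kw + 2 * pw) // sw + 1
    out = np.zeros((n, c, oh, ow), dtype=x.dtype)      # plays the role of T1
    mask = np.zeros((n, c, oh, ow), dtype=np.int64)    # plays the role of T2
    for i in range(n):
        for ch in range(c):
            for p in range(oh):
                hstart = max(p * sh - ph, 0)
                hend = min(p * sh - ph + kh, h)
                for q in range(ow):
                    wstart = max(q * sw - pw, 0)
                    wend = min(q * sw - pw + kw, w)
                    win = x[i, ch, hstart:hend, wstart:wend]
                    r, s = np.unravel_index(np.argmax(win), win.shape)
                    out[i, ch, p, q] = win[r, s]
                    mask[i, ch, p, q] = (hstart + r) * w + (wstart + s)
    return out, mask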

File diff suppressed because it is too large

@ -153,7 +153,7 @@ class MaxPool3dGradFunctor {
* In pool2d, all tensors are in NCHW format. In pool3d, all tensors are in
* NCDHW format.
*/
template <typename Place, typename T>
template <typename Place, typename T1, typename T2>
class MaxPool2dWithIndexFunctor {
public:
void operator()(const platform::DeviceContext& context,
@ -162,7 +162,7 @@ class MaxPool2dWithIndexFunctor {
framework::Tensor* output, framework::Tensor* mask);
};
template <typename Place, typename T>
template <typename Place, typename T1, typename T2>
class MaxPool2dWithIndexGradFunctor {
public:
void operator()(const platform::DeviceContext& context,
@ -172,7 +172,7 @@ class MaxPool2dWithIndexGradFunctor {
framework::Tensor* input_grad);
};
template <typename Place, typename T>
template <typename Place, typename T1, typename T2>
class MaxPool3dWithIndexFunctor {
public:
void operator()(const platform::DeviceContext& context,
@ -181,7 +181,7 @@ class MaxPool3dWithIndexFunctor {
framework::Tensor* output, framework::Tensor* mask);
};
template <typename Place, typename T>
template <typename Place, typename T1, typename T2>
class MaxPool3dWithIndexGradFunctor {
public:
void operator()(const platform::DeviceContext& context,

@ -20,6 +20,18 @@ REGISTER_OP(pool2d_cudnn, ops::PoolOp, ops::Pool2dOpMaker, pool2d_cudnn_grad,
ops::PoolOpGrad);
REGISTER_OP_CPU_KERNEL(pool2d_cudnn,
ops::PoolKernel<paddle::platform::CPUPlace, float>);
ops::PoolKernel<paddle::platform::CPUPlace, float>,
ops::PoolKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(pool2d_cudnn_grad,
ops::PoolGradKernel<paddle::platform::CPUPlace, float>)
ops::PoolGradKernel<paddle::platform::CPUPlace, float>,
ops::PoolGradKernel<paddle::platform::CPUPlace, double>)
REGISTER_OP(pool3d_cudnn, ops::PoolOp, ops::Pool3dOpMaker, pool3d_cudnn_grad,
ops::PoolOpGrad);
REGISTER_OP_CPU_KERNEL(pool3d_cudnn,
ops::PoolKernel<paddle::platform::CPUPlace, float>,
ops::PoolKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(pool3d_cudnn_grad,
ops::PoolGradKernel<paddle::platform::CPUPlace, float>,
ops::PoolGradKernel<paddle::platform::CPUPlace, double>)

@ -52,7 +52,13 @@ class PoolCudnnOpKernel : public framework::OpKernel<T> {
ScopedTensorDescriptor input_desc;
ScopedTensorDescriptor output_desc;
ScopedPoolingDescriptor pool_desc;
DataLayout layout = DataLayout::kNCHW;
DataLayout layout;
if (strides.size() == 2U) {
layout = DataLayout::kNCHW;
} else {
layout = DataLayout::kNCDHW;
}
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
layout, framework::vectorize2int(input->dims()));
@ -112,7 +118,13 @@ class PoolCudnnGradOpKernel : public framework::OpKernel<T> {
ScopedTensorDescriptor input_desc;
ScopedTensorDescriptor output_desc;
ScopedPoolingDescriptor pool_desc;
DataLayout layout = DataLayout::kNCHW;
DataLayout layout;
if (strides.size() == 2U) {
layout = DataLayout::kNCHW;
} else {
layout = DataLayout::kNCDHW;
}
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
layout, framework::vectorize2int(input->dims()));
@ -150,5 +162,12 @@ class PoolCudnnGradOpKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(pool2d_cudnn, ops::PoolCudnnOpKernel<float>);
REGISTER_OP_GPU_KERNEL(pool2d_cudnn_grad, ops::PoolCudnnGradOpKernel<float>);
REGISTER_OP_GPU_KERNEL(pool2d_cudnn, ops::PoolCudnnOpKernel<float>,
ops::PoolCudnnOpKernel<double>);
REGISTER_OP_GPU_KERNEL(pool2d_cudnn_grad, ops::PoolCudnnGradOpKernel<float>,
ops::PoolCudnnGradOpKernel<double>);
REGISTER_OP_GPU_KERNEL(pool3d_cudnn, ops::PoolCudnnOpKernel<float>,
ops::PoolCudnnOpKernel<double>);
REGISTER_OP_GPU_KERNEL(pool3d_cudnn_grad, ops::PoolCudnnGradOpKernel<float>,
ops::PoolCudnnGradOpKernel<double>);

@ -217,14 +217,18 @@ REGISTER_OP(pool2d, ops::PoolOp, ops::Pool2dOpMaker, pool2d_grad,
ops::PoolOpGrad);
REGISTER_OP_CPU_KERNEL(pool2d,
ops::PoolKernel<paddle::platform::CPUPlace, float>);
ops::PoolKernel<paddle::platform::CPUPlace, float>,
ops::PoolKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(pool2d_grad,
ops::PoolGradKernel<paddle::platform::CPUPlace, float>)
ops::PoolGradKernel<paddle::platform::CPUPlace, float>,
ops::PoolGradKernel<paddle::platform::CPUPlace, double>)
REGISTER_OP(pool3d, ops::PoolOp, ops::Pool3dOpMaker, pool3d_grad,
ops::PoolOpGrad);
REGISTER_OP_CPU_KERNEL(pool3d,
ops::PoolKernel<paddle::platform::CPUPlace, float>);
ops::PoolKernel<paddle::platform::CPUPlace, float>,
ops::PoolKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(pool3d_grad,
ops::PoolGradKernel<paddle::platform::CPUPlace, float>);
ops::PoolGradKernel<paddle::platform::CPUPlace, float>,
ops::PoolGradKernel<paddle::platform::CPUPlace, double>);

@ -17,11 +17,15 @@ limitations under the License. */
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(pool2d,
ops::PoolKernel<paddle::platform::GPUPlace, float>);
ops::PoolKernel<paddle::platform::GPUPlace, float>,
ops::PoolKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(pool2d_grad,
ops::PoolGradKernel<paddle::platform::GPUPlace, float>);
ops::PoolGradKernel<paddle::platform::GPUPlace, float>,
ops::PoolGradKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(pool3d,
ops::PoolKernel<paddle::platform::GPUPlace, float>);
ops::PoolKernel<paddle::platform::GPUPlace, float>,
ops::PoolKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(pool3d_grad,
ops::PoolGradKernel<paddle::platform::GPUPlace, float>);
ops::PoolGradKernel<paddle::platform::GPUPlace, float>,
ops::PoolGradKernel<paddle::platform::GPUPlace, double>);

@ -29,11 +29,11 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"X(Input) of Pooling should not be null.");
"Input(X) of Pooling should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Out(Output) of Pooling should not be null.");
"Output(Out) of Pooling should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Mask"),
"Mask(Output) of Pooling should not be null.");
"Output(Mask) of Pooling should not be null.");
auto in_x_dims = ctx->GetInputDim("X");
@ -67,6 +67,14 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
ctx->SetOutputDim("Mask", framework::make_ddim(output_shape));
}
protected:
framework::OpKernelType GetKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.device_context());
}
};
class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel {
@ -80,6 +88,14 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel {
"Input(X@GRAD) should not be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
protected:
framework::OpKernelType GetKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.device_context());
}
};
class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
@ -116,7 +132,7 @@ class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
// TypedAttrChecker don't support vector type.)
AddAttr<bool>(
"global_pooling",
"(bool, default false) Whether to use the global pooling. "
"(bool, default:false) Whether to use the global pooling. "
"If global_pooling = true, ksize and paddings will be ignored.")
.SetDefault(false);
AddAttr<std::vector<int>>("strides",
@ -126,7 +142,7 @@ class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
// TypedAttrChecker don't support vector type.)
AddAttr<std::vector<int>>(
"paddings",
"(vector<int>, defalut {0, 0}), paddings(height, width) of pooling "
"(vector<int>, defalut:{0, 0}), paddings(height, width) of pooling "
"operator. "
"If global_pooling = true, paddings and will be ignored.")
.SetDefault({0, 0}); // TODO(Chengduo): Add checker. (Currently,
@ -250,10 +266,12 @@ REGISTER_OP(max_pool2d_with_index, ops::MaxPoolWithIndexOp,
REGISTER_OP_CPU_KERNEL(
max_pool2d_with_index,
ops::MaxPoolWithIndexKernel<paddle::platform::CPUPlace, float>);
ops::MaxPoolWithIndexKernel<paddle::platform::CPUPlace, float, int>,
ops::MaxPoolWithIndexKernel<paddle::platform::CPUPlace, double, int>);
REGISTER_OP_CPU_KERNEL(
max_pool2d_with_index_grad,
ops::MaxPoolWithIndexGradKernel<paddle::platform::CPUPlace, float>)
ops::MaxPoolWithIndexGradKernel<paddle::platform::CPUPlace, float, int>,
ops::MaxPoolWithIndexGradKernel<paddle::platform::CPUPlace, double, int>)
REGISTER_OP(max_pool3d_with_index, ops::MaxPoolWithIndexOp,
ops::MaxPool3dWithIndexOpMaker, max_pool3d_with_index_grad,
@ -261,7 +279,9 @@ REGISTER_OP(max_pool3d_with_index, ops::MaxPoolWithIndexOp,
REGISTER_OP_CPU_KERNEL(
max_pool3d_with_index,
ops::MaxPoolWithIndexKernel<paddle::platform::CPUPlace, float>);
ops::MaxPoolWithIndexKernel<paddle::platform::CPUPlace, float, int>,
ops::MaxPoolWithIndexKernel<paddle::platform::CPUPlace, double, int>);
REGISTER_OP_CPU_KERNEL(
max_pool3d_with_index_grad,
ops::MaxPoolWithIndexGradKernel<paddle::platform::CPUPlace, float>)
ops::MaxPoolWithIndexGradKernel<paddle::platform::CPUPlace, float, int>,
ops::MaxPoolWithIndexGradKernel<paddle::platform::CPUPlace, double, int>)

@ -18,14 +18,18 @@ namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
max_pool2d_with_index,
ops::MaxPoolWithIndexKernel<paddle::platform::GPUPlace, float>);
ops::MaxPoolWithIndexKernel<paddle::platform::GPUPlace, float, int>,
ops::MaxPoolWithIndexKernel<paddle::platform::GPUPlace, double, int>);
REGISTER_OP_GPU_KERNEL(
max_pool2d_with_index_grad,
ops::MaxPoolWithIndexGradKernel<paddle::platform::GPUPlace, float>)
ops::MaxPoolWithIndexGradKernel<paddle::platform::GPUPlace, float, int>,
ops::MaxPoolWithIndexGradKernel<paddle::platform::GPUPlace, double, int>)
REGISTER_OP_GPU_KERNEL(
max_pool3d_with_index,
ops::MaxPoolWithIndexKernel<paddle::platform::GPUPlace, float>);
ops::MaxPoolWithIndexKernel<paddle::platform::GPUPlace, float, int>,
ops::MaxPoolWithIndexKernel<paddle::platform::GPUPlace, double, int>);
REGISTER_OP_GPU_KERNEL(
max_pool3d_with_index_grad,
ops::MaxPoolWithIndexGradKernel<paddle::platform::GPUPlace, float>)
ops::MaxPoolWithIndexGradKernel<paddle::platform::GPUPlace, float, int>,
ops::MaxPoolWithIndexGradKernel<paddle::platform::GPUPlace, double, int>)

@ -24,8 +24,8 @@ namespace operators {
using Tensor = framework::Tensor;
template <typename Place, typename T>
class MaxPoolWithIndexKernel : public framework::OpKernel<T> {
template <typename Place, typename T1, typename T2>
class MaxPoolWithIndexKernel : public framework::OpKernel<T1> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* in_x = context.Input<Tensor>("X");
@ -44,13 +44,13 @@ class MaxPoolWithIndexKernel : public framework::OpKernel<T> {
switch (ksize.size()) {
case 2: {
paddle::operators::math::MaxPool2dWithIndexFunctor<Place, T>
paddle::operators::math::MaxPool2dWithIndexFunctor<Place, T1, T2>
pool2d_forward;
pool2d_forward(context.device_context(), *in_x, ksize, strides,
paddings, out, mask);
} break;
case 3: {
paddle::operators::math::MaxPool3dWithIndexFunctor<Place, T>
paddle::operators::math::MaxPool3dWithIndexFunctor<Place, T1, T2>
pool3d_forward;
pool3d_forward(context.device_context(), *in_x, ksize, strides,
paddings, out, mask);
@ -60,8 +60,8 @@ class MaxPoolWithIndexKernel : public framework::OpKernel<T> {
}
};
template <typename Place, typename T>
class MaxPoolWithIndexGradKernel : public framework::OpKernel<T> {
template <typename Place, typename T1, typename T2>
class MaxPoolWithIndexGradKernel : public framework::OpKernel<T1> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* mask = context.Input<Tensor>("Mask");
@ -80,19 +80,19 @@ class MaxPoolWithIndexGradKernel : public framework::OpKernel<T> {
}
if (in_x_grad) {
in_x_grad->mutable_data<T>(context.GetPlace());
in_x_grad->mutable_data<T1>(context.GetPlace());
auto& device_ctx = context.device_context();
math::set_constant(device_ctx, in_x_grad, 0);
switch (ksize.size()) {
case 2: {
paddle::operators::math::MaxPool2dWithIndexGradFunctor<Place, T>
paddle::operators::math::MaxPool2dWithIndexGradFunctor<Place, T1, T2>
pool2d_backward;
pool2d_backward(device_ctx, *out_grad, *mask, ksize, strides,
paddings, in_x_grad);
} break;
case 3: {
paddle::operators::math::MaxPool3dWithIndexGradFunctor<Place, T>
paddle::operators::math::MaxPool3dWithIndexGradFunctor<Place, T1, T2>
pool3d_backward;
pool3d_backward(device_ctx, *out_grad, *mask, ksize, strides,
paddings, in_x_grad);

@ -0,0 +1,132 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/sequence_slice_op.h"
namespace paddle {
namespace operators {
class SequenceSliceOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SequenceSliceOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Offset"),
"Input(Offset) of SequenceSliceOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Length"),
"Input(Length) of SequenceSliceOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SequenceSliceOp should not be null.");
auto input_dims = ctx->GetInputDim("X");
auto offset_dim = ctx->GetInputDim("Offset");
auto length_dim = ctx->GetInputDim("Length");
PADDLE_ENFORCE_EQ(
offset_dim.size(), 2UL,
"Only support one level sequence now, The rank of offset must be 2.");
PADDLE_ENFORCE_EQ(
length_dim.size(), 2UL,
"Only support one level sequence now, The rank of Length must be 2.");
// Initialize the output's dims to maximum,
// and re-set to real dims by the value of Offset and Length at kernel
ctx->SetOutputDim("Out", input_dims);
}
protected:
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context());
}
};
class SequenceSliceGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"The gradient of Out should not be null.");
PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName("X")),
"The gradient of X should not be null.");
ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
}
protected:
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context());
}
};
class SequenceSliceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SequenceSliceOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(LoDTensor), "
"the input of SequenceSliceOp.");
AddInput("Offset",
"(Tensor), "
"a vector<int> to describe the offset of every input sequence for "
"sub sequence item.");
AddInput("Length",
"(Tensor), "
"a vector<int> to describe the length of every input sequence for "
"sub sequence item.");
AddOutput("Out",
"(LoDTensor), the output of SequenceSliceOp.");
AddComment(R"DOC(
Sequence slice operator
The operator crops a subsequence from given sequence with given start offset and subsequence length.
It only supports sequence (LoD Tensor with level number is 1).
- Case:
X = [[a1, a2;
b1, b2;
c1, c2]
[d1, d2;
e1, e2]]
LoD(X) = {{0, 3, 5}}; Dims(X) = (5, 2)
Offset = [[0], [1]]; Length = [[2], [1]]
Out = [[a1, a2;
b1, b2]
[e1, e2]]
LoD(Out) = {{0, 2, 3}}; Dims(Out) = (3, 2)
NOTE: The first dimension size of input, the size of offset and Length, should be equal. The offset start from 0.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(sequence_slice, ops::SequenceSliceOp, ops::SequenceSliceOpMaker,
sequence_slice_grad, ops::SequenceSliceGradOp);
REGISTER_OP_CPU_KERNEL(
sequence_slice,
ops::SequenceSliceOpKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
sequence_slice_grad,
ops::SequenceSliceGradOpKernel<paddle::platform::CPUPlace, float>);
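The Case in the comment block can be checked with a short NumPy sketch (hypothetical helper, names made up) that mirrors SequenceSliceLoD plus the StridedMemcpy loop in the kernel:

import numpy as np

def sequence_slice_ref(x, lod, offset, length):
    # x: (total_len, dim) array, lod: e.g. [0, 3, 5],
    # offset/length: one entry per sequence, as in the Offset/Length inputs
    pieces, out_lod = [], [0]
    for i in range(len(lod) - 1):
        start = lod[i] + offset[i]
        pieces.append(x[start:start + length[i]])
        out_lod.append(out_lod[-1] + length[i])
    return np.concatenate(pieces, axis=0), out_lod

x = np.array([['a1', 'a2'], ['b1', 'b2'], ['c1', 'c2'],
              ['d1', 'd2'], ['e1', 'e2']])
out, out_lod = sequence_slice_ref(x, [0, 3, 5], offset=[0, 1], length=[2, 1])
# out is [[a1 a2], [b1 b2], [e1 e2]] and out_lod is [0, 2, 3], matching the Case.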

@ -0,0 +1,23 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/sequence_slice_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
sequence_slice,
ops::SequenceSliceOpKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
sequence_slice_grad,
ops::SequenceSliceGradOpKernel<paddle::platform::GPUPlace, float>);

@ -0,0 +1,173 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/strided_memcpy.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using LoD = framework::LoD;
template <typename T>
inline LoD SequenceSliceLoD(const T& in, const int64_t* offset_data,
const int64_t* length_data) {
auto out_lod = in.lod();
size_t lod_offset = 0;
auto n = in.lod()[0].size() - 1;
out_lod[0][0] = 0;
for (size_t i = 0; i < n; ++i) {
lod_offset += length_data[i];
out_lod[0][i+1] = lod_offset;
}
return out_lod;
}
template <typename Place, typename T>
class SequenceSliceOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<LoDTensor>("X");
auto* offset = ctx.Input<Tensor>("Offset");
auto* length = ctx.Input<Tensor>("Length");
auto* out = ctx.Output<LoDTensor>("Out");
auto lod = in->lod();
auto n = lod[0].size() - 1;
PADDLE_ENFORCE_EQ(lod.size(), 1UL,
"Only support one level sequence now.");
PADDLE_ENFORCE_EQ(
n, static_cast<size_t>(length->dims()[0]),
"The size of input-sequence and length-array should be the same")
PADDLE_ENFORCE_EQ(
n, static_cast<size_t>(offset->dims()[0]),
"The size of input-sequence and offset-array should be the same")
const int64_t* offset_data = offset->data<int64_t>();
const int64_t* length_data = length->data<int64_t>();
framework::Tensor offset_cpu;
framework::Tensor length_cpu;
if (platform::is_gpu_place(ctx.GetPlace())) {
offset_cpu.mutable_data<T>(offset->dims(), platform::CPUPlace());
offset_cpu.CopyFrom(*offset, platform::CPUPlace(), ctx.device_context());
offset_data = offset_cpu.data<int64_t>();
length_cpu.mutable_data<T>(length->dims(), platform::CPUPlace());
length_cpu.CopyFrom(*length, platform::CPUPlace(), ctx.device_context());
length_data = length_cpu.data<int64_t>();
}
for (size_t i = 0; i < n; ++i) {
PADDLE_ENFORCE_LT(0, offset_data[i],
"The offset[%d] must greater than zero.", i)
PADDLE_ENFORCE_LT(0, length_data[i],
"The length[%d] must greater than zero.", i)
PADDLE_ENFORCE_LT(
lod[0][i] + offset_data[i] + length_data[i],
lod[0][i + 1],
"The target tensor's length overflow.")
}
out->mutable_data<T>(ctx.GetPlace());
auto out_lod = SequenceSliceLoD(*in, offset_data, length_data);
auto out_dims = in->dims();
out_dims[0] = out_lod[0][out_lod[0].size() - 1];
out->Resize(out_dims);
out->set_lod(out_lod);
auto in_stride = framework::stride(in->dims());
auto out_stride = framework::stride(out->dims());
size_t out_offset = 0;
for (size_t i = 0; i < n; ++i) {
Tensor in_t =
in->Slice(static_cast<int>(lod[0][i] + offset_data[i]),
static_cast<int>(lod[0][i] + offset_data[i] +
length_data[i]));
StridedMemcpy<T>(ctx.device_context(), in_t.data<T>(),
in_stride, in_t.dims(), out_stride,
out->data<T>() + out_offset);
out_offset += length_data[i] * in_stride[0];
}
}
};
template <typename Place, typename T>
class SequenceSliceGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<LoDTensor>("X");
auto* offset = ctx.Input<Tensor>("Offset");
auto* length = ctx.Input<Tensor>("Length");
auto* out_grad =
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto* x_grad =
ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
const int64_t* offset_data = offset->data<int64_t>();
const int64_t* length_data = length->data<int64_t>();
framework::Tensor offset_cpu;
framework::Tensor length_cpu;
if (platform::is_gpu_place(ctx.GetPlace())) {
offset_cpu.mutable_data<T>(offset->dims(), platform::CPUPlace());
offset_cpu.CopyFrom(*offset, platform::CPUPlace(), ctx.device_context());
offset_data = offset_cpu.data<int64_t>();
length_cpu.mutable_data<T>(length->dims(), platform::CPUPlace());
length_cpu.CopyFrom(*length, platform::CPUPlace(), ctx.device_context());
length_data = length_cpu.data<int64_t>();
}
auto lod = in->lod();
auto out_lod = out_grad->lod();
if (x_grad) {
x_grad->mutable_data<T>(ctx.GetPlace());
x_grad->set_lod(in->lod());
math::SetConstant<Place, T> set_zero;
set_zero(ctx.device_context(), x_grad, static_cast<T>(0));
auto out_grad_stride = framework::stride(out_grad->dims());
for (size_t i = 0; i < out_lod[0].size() - 1; ++i) {
Tensor out_grad_t =
out_grad->Slice(static_cast<int>(out_lod[0][i]),
static_cast<int>(out_lod[0][i + 1]));
auto out_grad_stride = framework::stride(out_grad_t.dims());
auto x_grad_stride = framework::stride(x_grad->dims());
Tensor x_grad_t = x_grad->Slice(
static_cast<int>(lod[0][i] + offset_data[i]),
static_cast<int>(lod[0][i] + offset_data[i] + length_data[i]));
StridedMemcpy<T>(ctx.device_context(), out_grad_t.data<T>(),
out_grad_stride, out_grad_t.dims(), x_grad_stride,
x_grad_t.data<T>());
}
}
}
};
} // namespace operators
} // namespace paddle

@ -224,13 +224,15 @@ class ScopedConvolutionDescriptor {
PADDLE_ENFORCE_EQ(pads.size(), strides.size());
PADDLE_ENFORCE_EQ(pads.size(), dilations.size());
#if CUDNN_VERSION < 6000
#if !CUDNN_VERSION_MIN(6, 0, 0)
// cudnn v5 does not support dilation conv, the argument is called upscale
// instead of dilations and it is must be one.
for (size_t i = 0; i < dilations.size(); ++i) {
PADDLE_ENFORCE_EQ(
dilations[i], 1,
"Dilations conv is not supported in this cuDNN version");
"Dilations conv is not supported in this cuDNN version(%d.%d.%d).",
CUDNN_VERSION / 1000, CUDNN_VERSION % 1000 / 100,
CUDNN_VERSION % 100);
}
#endif
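CUDNN_VERSION is packed as major * 1000 + minor * 100 + patch, which is what the three expressions in the new message decode; a quick sketch with an illustrative value:

# illustrative only: decode CUDNN_VERSION the same way the error message does
CUDNN_VERSION = 6021                    # cuDNN 6.0.21
major = CUDNN_VERSION // 1000           # 6
minor = CUDNN_VERSION % 1000 // 100     # 0
patch = CUDNN_VERSION % 100             # 21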

@ -38,6 +38,26 @@ TEST(CudnnHelper, ScopedTensorDescriptor) {
EXPECT_EQ(strides[2], 6);
EXPECT_EQ(strides[1], 36);
EXPECT_EQ(strides[0], 144);
// test tensor5d: ScopedTensorDescriptor
ScopedTensorDescriptor tensor5d_desc;
std::vector<int> shape_5d = {2, 4, 6, 6, 6};
auto desc_5d = tensor5d_desc.descriptor<float>(DataLayout::kNCDHW, shape_5d);
std::vector<int> dims_5d(5);
std::vector<int> strides_5d(5);
paddle::platform::dynload::cudnnGetTensorNdDescriptor(
desc_5d, 5, &type, &nd, dims_5d.data(), strides_5d.data());
EXPECT_EQ(nd, 5);
for (size_t i = 0; i < dims_5d.size(); ++i) {
EXPECT_EQ(dims_5d[i], shape_5d[i]);
}
EXPECT_EQ(strides_5d[4], 1);
EXPECT_EQ(strides_5d[3], 6);
EXPECT_EQ(strides_5d[2], 36);
EXPECT_EQ(strides_5d[1], 216);
EXPECT_EQ(strides_5d[0], 864);
}
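The expected 5-D strides are simply the packed NCDHW strides of shape (2, 4, 6, 6, 6), each one the product of the trailing dimensions:

\[
\text{strides} = (4 \cdot 6 \cdot 6 \cdot 6,\; 6 \cdot 6 \cdot 6,\; 6 \cdot 6,\; 6,\; 1) = (864,\; 216,\; 36,\; 6,\; 1)
\]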
TEST(CudnnHelper, ScopedFilterDescriptor) {
@ -60,6 +80,20 @@ TEST(CudnnHelper, ScopedFilterDescriptor) {
for (size_t i = 0; i < shape.size(); ++i) {
EXPECT_EQ(kernel[i], shape[i]);
}
ScopedFilterDescriptor filter_desc_4d;
std::vector<int> shape_4d = {2, 3, 3, 3};
auto desc_4d = filter_desc.descriptor<float>(DataLayout::kNCDHW, shape_4d);
std::vector<int> kernel_4d(4);
paddle::platform::dynload::cudnnGetFilterNdDescriptor(
desc_4d, 4, &type, &format, &nd, kernel_4d.data());
EXPECT_EQ(GetCudnnTensorFormat(DataLayout::kNCHW), format);
EXPECT_EQ(nd, 4);
for (size_t i = 0; i < shape_4d.size(); ++i) {
EXPECT_EQ(kernel_4d[i], shape_4d[i]);
}
}
TEST(CudnnHelper, ScopedConvolutionDescriptor) {

@ -17,7 +17,8 @@ __all__ = [
"IdentityActivation", "LinearActivation", 'SequenceSoftmaxActivation',
'ExpActivation', "ReluActivation", "BReluActivation", "SoftReluActivation",
"STanhActivation", "AbsActivation", "SquareActivation", "BaseActivation",
"LogActivation", "SqrtActivation", "ReciprocalActivation"
"LogActivation", "SqrtActivation", "ReciprocalActivation",
"SoftSignActivation"
]
@ -243,8 +244,20 @@ class ReciprocalActivation(BaseActivation):
Reciprocal Activation.
.. math::
f(z) = 1/z
f(z)=\\frac{1}{z}
"""
def __init__(self):
BaseActivation.__init__(self, 'reciprocal', False)
class SoftSignActivation(BaseActivation):
"""
SoftSign Activation.
.. math::
f(z)=\\frac{z}{1 + |z|}
"""
def __init__(self):
BaseActivation.__init__(self, 'softsign', False)
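A usage sketch with the v1 config helpers (layer names and sizes here are illustrative, not part of the diff):

from paddle.trainer_config_helpers import data_layer, fc_layer
from paddle.trainer_config_helpers.activations import SoftSignActivation

# hypothetical config snippet: a fully connected layer with the new activation
x = data_layer(name='x', size=64)
hidden = fc_layer(input=x, size=128, act=SoftSignActivation())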

@ -661,7 +661,7 @@ def conv2d(input,
if groups is None:
num_filter_channels = num_channels
else:
if num_channels % groups is not 0:
if num_channels % groups != 0:
raise ValueError("num_channels must be divisible by groups.")
num_filter_channels = num_channels / groups
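The old check only worked by accident: "is not" compares object identity and CPython happens to cache small integers, whereas "!=" compares values, which is what is intended here. A quick illustration with made-up channel counts:

num_channels, groups = 66, 4
print(num_channels % groups != 0)   # True  -> would raise the ValueError above
num_channels, groups = 64, 4
print(num_channels % groups != 0)   # False -> 64 / 4 = 16 filter channels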

@ -4,6 +4,7 @@ import paddle.v2.fluid.core as core
import paddle.v2.fluid.framework as framework
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid.nets as nets
import paddle.v2.fluid.evaluator as evaluator
from paddle.v2.fluid.executor import Executor
from paddle.v2.fluid.initializer import XavierInitializer
from paddle.v2.fluid.optimizer import AdamOptimizer
@ -103,12 +104,13 @@ net = vgg16_bn_drop(images)
predict = layers.fc(input=net, size=classdim, act='softmax')
cost = layers.cross_entropy(input=predict, label=label)
avg_cost = layers.mean(x=cost)
accuracy = layers.accuracy(input=predict, label=label)
# optimizer = SGDOptimizer(learning_rate=0.001)
optimizer = AdamOptimizer(learning_rate=0.001)
opts = optimizer.minimize(avg_cost)
accuracy, acc_out = evaluator.accuracy(input=predict, label=label)
BATCH_SIZE = 128
PASS_NUM = 1
@ -124,6 +126,7 @@ exe.run(framework.default_startup_program())
for pass_id in range(PASS_NUM):
batch_id = 0
accuracy.reset(exe)
for data in train_reader():
img_data = np.array(map(lambda x: x[0].reshape(data_shape),
data)).astype("float32")
@ -141,12 +144,14 @@ for pass_id in range(PASS_NUM):
outs = exe.run(framework.default_main_program(),
feed={"pixel": tensor_img,
"label": tensor_y},
fetch_list=[avg_cost, accuracy])
fetch_list=[avg_cost, acc_out])
loss = np.array(outs[0])
acc = np.array(outs[1])
pass_acc = accuracy.eval(exe)
print("pass_id:" + str(pass_id) + " batch_id:" + str(batch_id) +
" loss:" + str(loss) + " acc:" + str(acc))
" loss:" + str(loss) + " acc:" + str(acc) + " pass_acc:" + str(
pass_acc))
batch_id = batch_id + 1
if batch_id > 1:

Some files were not shown because too many files have changed in this diff
