From cb06337f9e6503a14d8802ee2355e1aa64db9960 Mon Sep 17 00:00:00 2001 From: Yang Yang Date: Fri, 16 Feb 2018 01:07:36 +0000 Subject: [PATCH 1/5] change outputsize func name --- paddle/fluid/operators/conv_op.cc | 5 +++-- paddle/fluid/operators/conv_op.h | 4 ++-- paddle/fluid/operators/im2sequence_op.cc | 8 ++++---- paddle/fluid/operators/im2sequence_op.h | 20 ++++++++++---------- paddle/fluid/operators/pool_op.cc | 4 ++-- paddle/fluid/operators/pool_with_index_op.cc | 4 ++-- paddle/fluid/operators/unpool_op.cc | 6 +++--- 7 files changed, 26 insertions(+), 25 deletions(-) diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index 6b378ec1bc..2ecece7073 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -60,8 +60,9 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { "Due to the settings of paddings, filter_dims and " "dilations, the output size is less than 0, please check " "again."); - output_shape.push_back(OutputSize(in_dims[i + 2], filter_dims[i + 2], - dilations[i], paddings[i], strides[i])); + output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], + dilations[i], paddings[i], + strides[i])); } ctx->SetOutputDim("Output", framework::make_ddim(output_shape)); ctx->ShareLoD("Input", "Output"); diff --git a/paddle/fluid/operators/conv_op.h b/paddle/fluid/operators/conv_op.h index ecbe3d505a..c93c2e73f7 100644 --- a/paddle/fluid/operators/conv_op.h +++ b/paddle/fluid/operators/conv_op.h @@ -28,8 +28,8 @@ using Tensor = framework::Tensor; // Base convolution operator definitions for other conv // like operators to reuse the implementation. -inline int OutputSize(int input_size, int filter_size, int dilation, - int padding, int stride) { +inline int ConvOutputSize(int input_size, int filter_size, int dilation, + int padding, int stride) { const int dkernel = dilation * (filter_size - 1) + 1; const int output_size = (input_size + 2 * padding - dkernel) / stride + 1; return output_size; } diff --git a/paddle/fluid/operators/im2sequence_op.cc b/paddle/fluid/operators/im2sequence_op.cc index 5bc28e0a52..048391549d 100644 --- a/paddle/fluid/operators/im2sequence_op.cc +++ b/paddle/fluid/operators/im2sequence_op.cc @@ -41,10 +41,10 @@ class Im2SequenceOp : public framework::OperatorWithKernel { int img_height = in_dim[2]; int img_width = in_dim[3]; - int output_height = OutputSize(img_height, kernels[0], paddings[0], - paddings[2], strides[0]); - int output_width = - OutputSize(img_width, kernels[1], paddings[1], paddings[3], strides[1]); + int output_height = Im2SeqOutputSize(img_height, kernels[0], paddings[0], + paddings[2], strides[0]); + int output_width = Im2SeqOutputSize(img_width, kernels[1], paddings[1], + paddings[3], strides[1]); ctx->SetOutputDim("Out", {batch_size * output_height * output_width, img_channels * kernels[0] * kernels[1]}); diff --git a/paddle/fluid/operators/im2sequence_op.h b/paddle/fluid/operators/im2sequence_op.h index 4193819b78..a6a83fefbc 100644 --- a/paddle/fluid/operators/im2sequence_op.h +++ b/paddle/fluid/operators/im2sequence_op.h @@ -26,8 +26,8 @@ namespace operators { using Tensor = framework::Tensor; using LoDTensor = framework::LoDTensor; -inline int OutputSize(int input_size, int filter_size, int padding_0, - int padding_1, int stride) { +inline int Im2SeqOutputSize(int input_size, int filter_size, int padding_0, - int padding_1, int stride) { +inline int Im2SeqOutputSize(int input_size, int filter_size, int padding_0, + int padding_1, int stride) { const int output_size = (input_size + padding_0 + padding_1 - filter_size) / stride + 1; return output_size;
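
All of the helpers renamed in this patch compute the same family of size arithmetic. As an illustrative sketch (plain Python, not part of the patch; the names mirror the C++ helpers, including the pool and unpool variants that appear further down in the patch):

    def conv_output_size(input_size, filter_size, dilation, padding, stride):
        # ConvOutputSize: dilation stretches the effective kernel extent.
        dkernel = dilation * (filter_size - 1) + 1
        return (input_size + 2 * padding - dkernel) // stride + 1

    def im2seq_output_size(input_size, filter_size, padding_0, padding_1, stride):
        # Im2SeqOutputSize: padding may differ on the two borders.
        return (input_size + padding_0 + padding_1 - filter_size) // stride + 1

    def pool_output_size(input_size, filter_size, padding, stride):
        # PoolOutputSize / MaxPoolOutputSize: the conv formula with dilation == 1.
        return (input_size - filter_size + 2 * padding) // stride + 1

    def unpool_output_size(input_size, ksize, padding, stride):
        # UnpoolOutputSize: the pooling formula solved for the input size.
        return (input_size - 1) * stride - 2 * padding + ksize

    # A 3x3 kernel with stride 2 and padding 1 on a 32-wide edge:
    assert conv_output_size(32, 3, dilation=1, padding=1, stride=2) == 16
    assert pool_output_size(32, 3, padding=1, stride=2) == 16
    assert unpool_output_size(16, ksize=3, padding=1, stride=2) == 31

Note that unpooling recovers the pre-pooling size only up to the remainder that integer division discarded in the forward pass (31 here rather than 32).
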
@@ -53,10 +53,10 @@ class Im2SequenceKernel : public framework::OpKernel { auto kernels = ctx.Attr>("kernels"); auto strides = ctx.Attr>("strides"); auto paddings = ctx.Attr>("paddings"); - int output_height = OutputSize(img_height, kernels[0], paddings[0], - paddings[2], strides[0]); - int output_width = - OutputSize(img_width, kernels[1], paddings[1], paddings[3], strides[1]); + int output_height = Im2SeqOutputSize(img_height, kernels[0], paddings[0], + paddings[2], strides[0]); + int output_width = Im2SeqOutputSize(img_width, kernels[1], paddings[1], + paddings[3], strides[1]); const std::vector dilations({1, 1}); @@ -109,10 +109,10 @@ class Im2SequenceGradKernel : public framework::OpKernel { auto kernels = ctx.Attr>("kernels"); auto strides = ctx.Attr>("strides"); auto paddings = ctx.Attr>("paddings"); - int output_height = OutputSize(img_height, kernels[0], paddings[0], - paddings[2], strides[0]); - int output_width = - OutputSize(img_width, kernels[1], paddings[1], paddings[3], strides[1]); + int output_height = Im2SeqOutputSize(img_height, kernels[0], paddings[0], + paddings[2], strides[0]); + int output_width = Im2SeqOutputSize(img_width, kernels[1], paddings[1], + paddings[3], strides[1]); const std::vector dilations({1, 1}); diff --git a/paddle/fluid/operators/pool_op.cc b/paddle/fluid/operators/pool_op.cc index a80b23b8ed..c7729ad132 100644 --- a/paddle/fluid/operators/pool_op.cc +++ b/paddle/fluid/operators/pool_op.cc @@ -17,7 +17,7 @@ limitations under the License. */ namespace paddle { namespace operators { -int OutputSizePool(int input_size, int filter_size, int padding, int stride) { +int PoolOutputSize(int input_size, int filter_size, int padding, int stride) { int output_size = (input_size - filter_size + 2 * padding) / stride + 1; return output_size; } @@ -55,7 +55,7 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const { std::vector output_shape({in_x_dims[0], in_x_dims[1]}); for (size_t i = 0; i < ksize.size(); ++i) { output_shape.push_back( - OutputSizePool(in_x_dims[i + 2], ksize[i], paddings[i], strides[i])); + PoolOutputSize(in_x_dims[i + 2], ksize[i], paddings[i], strides[i])); } ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); ctx->ShareLoD("X", "Out"); diff --git a/paddle/fluid/operators/pool_with_index_op.cc b/paddle/fluid/operators/pool_with_index_op.cc index 3a59365d17..4df0a14577 100644 --- a/paddle/fluid/operators/pool_with_index_op.cc +++ b/paddle/fluid/operators/pool_with_index_op.cc @@ -17,7 +17,7 @@ limitations under the License. 
*/ namespace paddle { namespace operators { -inline int OutputSizeMaxPool(int input_size, int filter_size, int padding, +inline int MaxPoolOutputSize(int input_size, int filter_size, int padding, int stride) { int output_size = (input_size - filter_size + 2 * padding) / stride + 1; return output_size; @@ -61,7 +61,7 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel { std::vector output_shape({in_x_dims[0], in_x_dims[1]}); for (size_t i = 0; i < ksize.size(); ++i) { - output_shape.push_back(OutputSizeMaxPool(in_x_dims[i + 2], ksize[i], + output_shape.push_back(MaxPoolOutputSize(in_x_dims[i + 2], ksize[i], paddings[i], strides[i])); } ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); diff --git a/paddle/fluid/operators/unpool_op.cc b/paddle/fluid/operators/unpool_op.cc index d3bd7fda09..0ca7ea00fa 100644 --- a/paddle/fluid/operators/unpool_op.cc +++ b/paddle/fluid/operators/unpool_op.cc @@ -64,7 +64,7 @@ Paper: http://www.matthewzeiler.com/wp-content/uploads/2017/07/iccv2011.pdf } }; -int OutputSize(int input_size, int ksize, int padding, int stride) { +int UnpoolOutputSize(int input_size, int ksize, int padding, int stride) { int output_size = (input_size - 1) * stride - 2 * padding + ksize; return output_size; } @@ -101,8 +101,8 @@ class UnpoolOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(in_x_dims, in_y_dims); std::vector output_shape({in_x_dims[0], in_x_dims[1]}); for (size_t i = 0; i < ksize.size(); ++i) { - output_shape.push_back( - OutputSize(in_x_dims[i + 2], ksize[i], paddings[i], strides[i])); + output_shape.push_back(UnpoolOutputSize(in_x_dims[i + 2], ksize[i], + paddings[i], strides[i])); } ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); } From 74404fadcd5256d321f5440fec9fac44a3c8fc3e Mon Sep 17 00:00:00 2001 From: Abhinav Arora Date: Thu, 15 Feb 2018 17:08:57 -0800 Subject: [PATCH 2/5] Python implementation for a proposed Go Op. 
(#8434) * Adding Python boilerplate code for Go op * Add very basic test case * Adding the python logic for go routine * Fix syntax * Changing test to notest * Rename Routine to Go * Combining GoGuard and Go in one class * Modify test * Adding fluid close channel * Fixing __init__.py for calling fluid.go() * Adding stubs for channel methods and updating test case * Removing import * * Adding imports from concurrency --- python/paddle/v2/fluid/__init__.py | 4 +- python/paddle/v2/fluid/concurrency.py | 86 +++++++++++++++++++ .../v2/fluid/tests/notest_concurrency.py | 38 ++++++++ 3 files changed, 127 insertions(+), 1 deletion(-) create mode 100644 python/paddle/v2/fluid/concurrency.py create mode 100644 python/paddle/v2/fluid/tests/notest_concurrency.py diff --git a/python/paddle/v2/fluid/__init__.py b/python/paddle/v2/fluid/__init__.py index 9f710c4a4a..361fb3f5ad 100644 --- a/python/paddle/v2/fluid/__init__.py +++ b/python/paddle/v2/fluid/__init__.py @@ -34,13 +34,15 @@ from data_feeder import DataFeeder from core import LoDTensor, CPUPlace, CUDAPlace from distribute_transpiler import DistributeTranspiler from distribute_transpiler_simple import SimpleDistributeTranspiler +from concurrency import (Go, make_channel, channel_send, channel_recv, + channel_close) import clip from memory_optimization_transpiler import memory_optimize import profiler Tensor = LoDTensor -__all__ = framework.__all__ + executor.__all__ + [ +__all__ = framework.__all__ + executor.__all__ + concurrency.__all__ + [ 'io', 'initializer', 'layers', diff --git a/python/paddle/v2/fluid/concurrency.py b/python/paddle/v2/fluid/concurrency.py new file mode 100644 index 0000000000..5f868b6e86 --- /dev/null +++ b/python/paddle/v2/fluid/concurrency.py @@ -0,0 +1,86 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# TODO: Variables: make_channel +# TODO: Operators: send, close_channel, recv, go, select +from layers.control_flow import BlockGuard +from layer_helper import LayerHelper + +__all__ = [ + 'Go', + 'make_channel', + 'channel_send', + 'channel_recv', + 'channel_close', +] + + +class Go(BlockGuard): + def __init__(self, name=None): + self.helper = LayerHelper("go", name=name) + super(Go, self).__init__(self.helper.main_program) + + def __enter__(self): + super(Go, self).__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is not None: + return False + self.construct_go_op() + return super(Go, self).__exit__(exc_type, exc_val, exc_tb) + + def construct_go_op(self): + main_program = self.helper.main_program + go_block = main_program.current_block() + parent_block = main_program.block(main_program.current_block() + .parent_idx) + + x_name_list = set() + out_vars = set() + for op in go_block.ops: + # Iterate over all operators, get all the inputs + # and add as input to the Go operator. 
+ for iname in op.input_names: + for in_var_name in op.input(iname): + x_name_list.add(in_var_name) + + # Iterate over all operators , get all the outputs + # add to the output list of Go operator only if + # they exist in the parent block. + for oname in op.output_names: + for out_var_name in op.output(oname): + if out_var_name in parent_block.vars: + out_vars.add(parent_block.var(out_var_name)) + + parent_block.append_op( + type='go', + inputs={'X': [parent_block.var(x_name) for x_name in x_name_list]}, + outputs={'Out': out_vars}, + attrs={'sub_block': go_block}) + + +def make_channel(dtype, size=0): + return True + + +def channel_send(channel, value): + return True + + +def channel_recv(channel): + return True + + +def channel_close(channel): + return True diff --git a/python/paddle/v2/fluid/tests/notest_concurrency.py b/python/paddle/v2/fluid/tests/notest_concurrency.py new file mode 100644 index 0000000000..9d87ed9c07 --- /dev/null +++ b/python/paddle/v2/fluid/tests/notest_concurrency.py @@ -0,0 +1,38 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle.v2.fluid as fluid +import paddle.v2.fluid.core as core +from paddle.v2.fluid.executor import Executor + + +class TestRoutineOp(unittest.TestCase): + def test_simple_routine(self): + ch = fluid.make_channel(dtype=bool) + with fluid.Go(): + fluid.channel_send(ch, True) + + result = fluid.channel_recv(ch) + fluid.channel_close(ch) + + cpu = core.CPUPlace() + exe = Executor(cpu) + + outs = exe.run(fetch_list=[result]) + self.assertEqual(outs[0], True) + + +if __name__ == '__main__': + unittest.main() From 74e0eb7267b4eae7016a44fb1fbc62bf2e952bde Mon Sep 17 00:00:00 2001 From: kexinzhao Date: Thu, 15 Feb 2018 17:10:29 -0800 Subject: [PATCH 3/5] make float16 a pod type (#8456) --- paddle/fluid/framework/tensor_impl.h | 5 +++- paddle/fluid/platform/float16.h | 43 ++++++++++++++++++++++----- paddle/fluid/platform/float16_test.cc | 32 ++++++++++++++++---- paddle/fluid/platform/float16_test.cu | 33 ++++++++++++++++++++ 4 files changed, 99 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/framework/tensor_impl.h b/paddle/fluid/framework/tensor_impl.h index 59e6269ea0..638bd0db9d 100644 --- a/paddle/fluid/framework/tensor_impl.h +++ b/paddle/fluid/framework/tensor_impl.h @@ -15,6 +15,7 @@ limitations under the License. 
*/ #pragma once #include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/float16.h" namespace paddle { namespace framework { @@ -52,7 +53,9 @@ struct SizeOfTypeFunctor { }; static inline size_t SizeOfType(std::type_index type) { - SizeOfTypeFunctor functor; + SizeOfTypeFunctor + functor; size_t size = functor(type); PADDLE_ENFORCE(size != 0UL, "Cannot get size of type %s", type.name()); return size; diff --git a/paddle/fluid/platform/float16.h b/paddle/fluid/platform/float16.h index c36bfad4bc..cf6a4b09db 100644 --- a/paddle/fluid/platform/float16.h +++ b/paddle/fluid/platform/float16.h @@ -62,6 +62,7 @@ limitations under the License. */ #define PADDLE_ALIGN(x) __attribute__((aligned(x))) namespace paddle { +namespace platform { // Use PADDLE_ALIGNED(2) to ensure that each float16 will be allocated // and aligned at least on a 2-byte boundary, which leads to efficient @@ -71,11 +72,21 @@ struct PADDLE_ALIGN(2) float16 { public: uint16_t x; - // Constructors - HOSTDEVICE inline float16() : x(0) {} + // The following defaulted special class member functions + // are added to make float16 pass the std::is_trivial test + HOSTDEVICE inline float16() = default; - HOSTDEVICE inline float16(const float16& h) : x(h.x) {} + HOSTDEVICE inline float16(const float16&) = default; + HOSTDEVICE inline float16& operator=(const float16&) = default; + + HOSTDEVICE inline float16(float16&&) = default; + + HOSTDEVICE inline float16& operator=(float16&&) = default; + + HOSTDEVICE inline ~float16() = default; + +// Constructors #ifdef PADDLE_CUDA_FP16 HOSTDEVICE inline explicit float16(const half& h) { #if CUDA_VERSION >= 9000 @@ -136,11 +147,6 @@ struct PADDLE_ALIGN(2) float16 { HOSTDEVICE inline explicit float16(const T& val) : x(float16(static_cast(val)).x) {} - HOSTDEVICE inline float16& operator=(const float16& rhs) { - x = rhs.x; - return *this; - } - // Assignment operators #ifdef PADDLE_CUDA_FP16 HOSTDEVICE inline float16& operator=(const half& rhs) { @@ -727,4 +733,25 @@ HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) { return float(a) >= float(b); } #endif + +} // namespace platform } // namespace paddle + +namespace std { + +// Override the std::is_pod::value for float16 +// The reason is that different compilers implemented std::is_pod based on +// different C++ standards. float16 class is a plain old data in C++11 given +// that it is both trivial and standard_layout. +// However, std::is_pod in nvcc 8.0 host c++ compiler follows C++0x and is +// more restricted in that you cannot provide any customized +// constructor in float16. Hence, we override is_pod here following C++11 +// so that .cu files can be successfully compiled by nvcc. +template <> +struct is_pod { + static const bool value = + is_trivial::value && + is_standard_layout::value; +}; + +} // namespace std diff --git a/paddle/fluid/platform/float16_test.cc b/paddle/fluid/platform/float16_test.cc index bed29dbfa7..b716ad9df4 100644 --- a/paddle/fluid/platform/float16_test.cc +++ b/paddle/fluid/platform/float16_test.cc @@ -10,10 +10,13 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "paddle/fluid/platform/float16.h" +#include "paddle/fluid/framework/init.h" +#include "paddle/fluid/framework/lod_tensor.h" #include namespace paddle { +namespace platform { TEST(float16, conversion_cpu) { // Explicit conversion from Eigen::half @@ -54,13 +57,9 @@ TEST(float16, conversion_cpu) { EXPECT_EQ(float16(true).x, 0x3c00); EXPECT_EQ(float16(false).x, 0x0000); - // Default constructor - float16 v_def; - EXPECT_EQ(v_def.x, 0x0000); - // Assignment operator float16 v_assign; - v_assign = v_def; + v_assign = float16(0); EXPECT_EQ(v_assign.x, 0x0000); v_assign = Eigen::half(1.0f); EXPECT_EQ(v_assign.x, 0x3c00); @@ -116,4 +115,27 @@ TEST(float16, comparison_cpu) { EXPECT_FALSE(float16(-0.0f) > float16(0.0f)); } +TEST(float16, lod_tensor_cpu) { + framework::LoDTensor lod_tensor; + + std::vector input_data = {float16(1.0f), float16(0.5f), + float16(0.33333f), float16(0.0f)}; + EXPECT_EQ(input_data[0].x, 0x3c00); + EXPECT_EQ(input_data[1].x, 0x3800); + EXPECT_EQ(input_data[2].x, 0x3555); + EXPECT_EQ(input_data[3].x, 0x0000); + + lod_tensor.Resize({4, 1}); + lod_tensor.set_lod(framework::LoD({{0, 2, 4}})); + float16* data_ptr = lod_tensor.mutable_data(CPUPlace()); + + EXPECT_NE(data_ptr, nullptr); + EXPECT_EQ(input_data.size(), static_cast(lod_tensor.numel())); + for (size_t i = 0; i < input_data.size(); ++i) { + data_ptr[i] = input_data[i]; + EXPECT_EQ(data_ptr[i].x, input_data[i].x); + } +} + +} // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/float16_test.cu b/paddle/fluid/platform/float16_test.cu index 7e6c9f58ac..567209df4e 100644 --- a/paddle/fluid/platform/float16_test.cu +++ b/paddle/fluid/platform/float16_test.cu @@ -13,6 +13,8 @@ limitations under the License. */ #include +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/tensor_util.h" #include "paddle/utils/Logging.h" #define ARITHMETIC_KERNEL(op_type, sign) \ @@ -108,6 +110,7 @@ limitations under the License. 
*/ #ifdef PADDLE_CUDA_FP16 namespace paddle { +namespace platform { #if CUDA_VERSION < 9000 ARITHMETIC_KERNEL(Add, +) @@ -209,5 +212,35 @@ TEST(float16, conversion_on_gpu) { EXPECT_EQ(v_assign.x, 0x3c00); } +TEST(float16, lod_tensor_on_gpu) { + framework::LoDTensor src_tensor; + framework::LoDTensor gpu_tensor; + framework::LoDTensor dst_tensor; + + float16* src_ptr = src_tensor.mutable_data( + framework::make_ddim({2, 2}), CPUPlace()); + + float16 arr[4] = {float16(1.0f), float16(0.5f), float16(0.33333f), + float16(0.0f)}; + memcpy(src_ptr, arr, 4 * sizeof(float16)); + + // CPU LoDTensor to GPU LoDTensor + CUDAPlace gpu_place(0); + CUDADeviceContext gpu_ctx(gpu_place); + framework::TensorCopy(src_tensor, gpu_place, gpu_ctx, &gpu_tensor); + + // GPU LoDTensor to CPU LoDTensor + framework::TensorCopy(gpu_tensor, CPUPlace(), gpu_ctx, &dst_tensor); + + // Sync before comparing LoDTensors + gpu_ctx.Wait(); + const float16* dst_ptr = dst_tensor.data(); + ASSERT_NE(src_ptr, dst_ptr); + for (size_t i = 0; i < 4; ++i) { + EXPECT_EQ(src_ptr[i].x, dst_ptr[i].x); + } +} + +} // namespace platform } // namespace paddle #endif // PADDLE_CUDA_FP16 From c7ad26d6a4a37c6a2f0a59408e93fda23315cf94 Mon Sep 17 00:00:00 2001 From: Abhinav Arora Date: Thu, 15 Feb 2018 21:32:07 -0800 Subject: [PATCH 4/5] [WIP] Move DataType enum inside VarType (#8447) * Move Pod Types from DataType enum to Type enum * Fixed data_type.h * Fix type in TensorDesc * Add comment to framework.proto * Fixed type in data_type.h * Updated format of type in data_type.h * Fix var_desc.h * Fix op_kernel_type.h * Fixed data_type_transform_test.cc * Fix operator.h * Fixed data_type_transform.cc * Fixed op_kernel_type_test.cc * Fix operator.cc * Fixed data_layout_transform_test.cc * Fix var_desc.cc * Fixed assign_value_op.cc * Fixed assign_value_op.h * fixed protobuf.cc * Fix data_layout_transform_test.cc and op_kernel_type_test.cc * Fixed rnn_memory_helper_op.cc * Fix progrma_desc_test.cc * Fixed fill_constant_batch_size_like_op.cc * Fix operator_test.cc * Fixed fill_constant_op.cc * Fixed gaussian_random_op.cc * Fixed uniform_random_op.cc * Fixed edit_distance_op.cc * Fixed fill_constant_batch_size_like_op.cc * Fixed rnn_memory_helper_op.cc * Fixed chunk_eval_op.cc * Fixed assign_value_op.cc * Fixed assign_value_op.h * Fixed cast_op.h * Fixed cast_op.h * Fix fill constant op * Fixed clang for assign_value_op.cc * Fix one_hot_op.h * Fix one_hot_op.cc * Fix fill_op.cc * Fixed sum_op.cc * Fixed sum_op clang * Fix uniform_random_op.cc * Fix gaussian_random_op.cc * Fix backward.cc * Fix protobuf.cc * Fixed prune_test.cc * Fixed op_registry_test.cc * Fix data_device_transform_test.cu * Fix travis error * Fixed one_hot_op.cu * Fixed op_registry_test.cc * Fixed nccl_op.cc * Fixing python tests * Revert "Fixing python tests" This reverts commit fccaa4c5818ed9f379ea1ce4315066cc78076c64. 
* Fixing Pybind to remove data type * Fixing tensor.py * Updated the new files: * Resolve error in merge conflict of fill_constant_batch_size_like_op.cc --- paddle/fluid/framework/backward.cc | 2 +- .../framework/data_device_transform_test.cu | 4 +- .../framework/data_layout_transform_test.cc | 4 +- paddle/fluid/framework/data_type.h | 54 +++++++++---------- paddle/fluid/framework/data_type_transform.cc | 10 ++-- .../framework/data_type_transform_test.cc | 6 +-- paddle/fluid/framework/framework.proto | 41 +++++++------- paddle/fluid/framework/op_kernel_type.h | 6 +-- paddle/fluid/framework/op_kernel_type_test.cc | 4 +- paddle/fluid/framework/op_registry_test.cc | 8 +-- paddle/fluid/framework/operator.cc | 4 +- paddle/fluid/framework/operator.h | 4 +- paddle/fluid/framework/operator_test.cc | 2 +- paddle/fluid/framework/program_desc_test.cc | 8 +-- paddle/fluid/framework/prune_test.cc | 2 +- paddle/fluid/framework/var_desc.cc | 10 ++-- paddle/fluid/framework/var_desc.h | 9 ++-- paddle/fluid/operators/assign_value_op.cc | 7 +-- paddle/fluid/operators/assign_value_op.h | 4 +- paddle/fluid/operators/cast_op.h | 3 +- paddle/fluid/operators/chunk_eval_op.cc | 2 +- paddle/fluid/operators/edit_distance_op.cc | 2 +- .../fill_constant_batch_size_like_op.cc | 4 +- paddle/fluid/operators/fill_constant_op.cc | 4 +- paddle/fluid/operators/fill_op.cc | 5 +- .../gaussian_random_batch_size_like_op.cc | 4 +- paddle/fluid/operators/gaussian_random_op.cc | 4 +- paddle/fluid/operators/nccl_op.cc | 2 +- paddle/fluid/operators/one_hot_op.cc | 2 +- paddle/fluid/operators/one_hot_op.cu | 3 +- paddle/fluid/operators/one_hot_op.h | 3 +- .../fluid/operators/rnn_memory_helper_op.cc | 4 +- paddle/fluid/operators/sum_op.cc | 3 +- .../uniform_random_batch_size_like_op.cc | 4 +- paddle/fluid/operators/uniform_random_op.cc | 4 +- paddle/fluid/pybind/protobuf.cc | 16 +++--- python/paddle/v2/fluid/backward.py | 2 +- python/paddle/v2/fluid/data_feeder.py | 8 +-- python/paddle/v2/fluid/evaluator.py | 2 +- python/paddle/v2/fluid/framework.py | 29 +++++----- python/paddle/v2/fluid/layers/control_flow.py | 2 +- python/paddle/v2/fluid/layers/nn.py | 2 +- python/paddle/v2/fluid/layers/tensor.py | 14 ++--- .../fluid/memory_optimization_transpiler.py | 14 ++--- .../paddle/v2/fluid/tests/test_cpp_reader.py | 2 +- .../v2/fluid/tests/unittests/op_test.py | 4 +- .../tests/unittests/test_batch_norm_op.py | 4 +- .../v2/fluid/tests/unittests/test_cast_op.py | 4 +- .../v2/fluid/tests/unittests/test_fill_op.py | 2 +- .../tests/unittests/test_layer_norm_op.py | 4 +- .../fluid/tests/unittests/test_one_hot_op.py | 2 +- .../fluid/tests/unittests/test_parameter.py | 2 +- .../tests/unittests/test_protobuf_descs.py | 7 +-- .../v2/fluid/tests/unittests/test_variable.py | 6 +-- 54 files changed, 189 insertions(+), 179 deletions(-) diff --git a/paddle/fluid/framework/backward.cc b/paddle/fluid/framework/backward.cc index 68f4fd4424..1314af2b3d 100644 --- a/paddle/fluid/framework/backward.cc +++ b/paddle/fluid/framework/backward.cc @@ -341,7 +341,7 @@ static void CreateGradVarInBlock( auto* param = block_desc->FindVarRecursive(pname); auto* grad = block_desc->FindVar(arg); if (param == nullptr) { - grad->SetDataType(proto::DataType::FP32); + grad->SetDataType(proto::VarType::FP32); } else { grad->SetDataType(param->GetDataType()); } diff --git a/paddle/fluid/framework/data_device_transform_test.cu b/paddle/fluid/framework/data_device_transform_test.cu index db6687985d..e896a06162 100644 --- a/paddle/fluid/framework/data_device_transform_test.cu +++ 
b/paddle/fluid/framework/data_device_transform_test.cu @@ -51,10 +51,10 @@ class TestOpWithKernel : public OperatorWithKernel { const ExecutionContext& ctx) const override { if (Attr("use_gpu")) { VLOG(3) << "force use gpu kernel"; - return OpKernelType(proto::DataType::FP32, platform::CUDAPlace(0)); + return OpKernelType(proto::VarType::FP32, platform::CUDAPlace(0)); } else { VLOG(3) << "use default kernel"; - return OpKernelType(proto::DataType::FP32, + return OpKernelType(proto::VarType::FP32, ctx.Input("input")->place()); } } diff --git a/paddle/fluid/framework/data_layout_transform_test.cc b/paddle/fluid/framework/data_layout_transform_test.cc index 73689cc9bc..dd17cac0e1 100644 --- a/paddle/fluid/framework/data_layout_transform_test.cc +++ b/paddle/fluid/framework/data_layout_transform_test.cc @@ -27,9 +27,9 @@ TEST(DataTransform, DataLayoutFunction) { in.mutable_data(make_ddim({2, 3, 1, 2}), place); in.set_layout(DataLayout::kNHWC); - auto kernel_nhwc = OpKernelType(proto::DataType::FP32, place, + auto kernel_nhwc = OpKernelType(proto::VarType::FP32, place, DataLayout::kNHWC, LibraryType::kPlain); - auto kernel_ncwh = OpKernelType(proto::DataType::FP32, place, + auto kernel_ncwh = OpKernelType(proto::VarType::FP32, place, DataLayout::kNCHW, LibraryType::kPlain); TransDataLayout(kernel_nhwc, kernel_ncwh, in, &out); diff --git a/paddle/fluid/framework/data_type.h b/paddle/fluid/framework/data_type.h index 127bbcf5d0..1dec766a34 100644 --- a/paddle/fluid/framework/data_type.h +++ b/paddle/fluid/framework/data_type.h @@ -20,35 +20,35 @@ limitations under the License. */ namespace paddle { namespace framework { -inline proto::DataType ToDataType(std::type_index type) { +inline proto::VarType::Type ToDataType(std::type_index type) { using namespace paddle::framework::proto; if (typeid(float).hash_code() == type.hash_code()) { - return DataType::FP32; + return proto::VarType::FP32; } else if (typeid(double).hash_code() == type.hash_code()) { - return DataType::FP64; + return proto::VarType::FP64; } else if (typeid(int).hash_code() == type.hash_code()) { - return DataType::INT32; + return proto::VarType::INT32; } else if (typeid(int64_t).hash_code() == type.hash_code()) { - return DataType::INT64; + return proto::VarType::INT64; } else if (typeid(bool).hash_code() == type.hash_code()) { - return DataType::BOOL; + return proto::VarType::BOOL; } else { PADDLE_THROW("Not supported"); } } -inline std::type_index ToTypeIndex(proto::DataType type) { +inline std::type_index ToTypeIndex(proto::VarType::Type type) { using namespace paddle::framework::proto; switch (type) { - case DataType::FP32: + case proto::VarType::FP32: return typeid(float); - case DataType::FP64: + case proto::VarType::FP64: return typeid(double); - case DataType::INT32: + case proto::VarType::INT32: return typeid(int); - case DataType::INT64: + case proto::VarType::INT64: return typeid(int64_t); - case DataType::BOOL: + case proto::VarType::BOOL: return typeid(bool); default: PADDLE_THROW("Not support type %d", type); @@ -56,22 +56,22 @@ inline std::type_index ToTypeIndex(proto::DataType type) { } template -inline void VisitDataType(proto::DataType type, Visitor visitor) { +inline void VisitDataType(proto::VarType::Type type, Visitor visitor) { using namespace paddle::framework::proto; switch (type) { - case DataType::FP32: + case proto::VarType::FP32: visitor.template operator()(); break; - case DataType::FP64: + case proto::VarType::FP64: visitor.template operator()(); break; - case DataType::INT32: + case 
proto::VarType::INT32: visitor.template operator()(); break; - case DataType::INT64: + case proto::VarType::INT64: visitor.template operator()(); break; - case DataType::BOOL: + case proto::VarType::BOOL: visitor.template operator()(); break; default: @@ -79,22 +79,22 @@ inline void VisitDataType(proto::DataType type, Visitor visitor) { } } -inline std::string DataTypeToString(const proto::DataType type) { +inline std::string DataTypeToString(const proto::VarType::Type type) { using namespace paddle::framework::proto; switch (type) { - case DataType::FP16: + case proto::VarType::FP16: return "float16"; - case DataType::FP32: + case proto::VarType::FP32: return "float32"; - case DataType::FP64: + case proto::VarType::FP64: return "float64"; - case DataType::INT16: + case proto::VarType::INT16: return "int16"; - case DataType::INT32: + case proto::VarType::INT32: return "int32"; - case DataType::INT64: + case proto::VarType::INT64: return "int64"; - case DataType::BOOL: + case proto::VarType::BOOL: return "bool"; default: PADDLE_THROW("Not support type %d", type); @@ -102,7 +102,7 @@ inline std::string DataTypeToString(const proto::DataType type) { } inline std::ostream& operator<<(std::ostream& out, - const proto::DataType& type) { + const proto::VarType::Type& type) { out << DataTypeToString(type); return out; } diff --git a/paddle/fluid/framework/data_type_transform.cc b/paddle/fluid/framework/data_type_transform.cc index e5836998e2..54cc1575d8 100644 --- a/paddle/fluid/framework/data_type_transform.cc +++ b/paddle/fluid/framework/data_type_transform.cc @@ -65,19 +65,19 @@ void TransDataType(const OpKernelType& kernel_type_for_var, auto ctx = pool.Get(in.place()); switch (src_type) { - case proto::DataType::FP32: + case proto::VarType::FP32: framework::VisitDataType(dst_type, CastDataType(in, out, ctx)); break; - case proto::DataType::FP64: + case proto::VarType::FP64: framework::VisitDataType(dst_type, CastDataType(in, out, ctx)); break; - case proto::DataType::INT32: + case proto::VarType::INT32: framework::VisitDataType(dst_type, CastDataType(in, out, ctx)); break; - case proto::DataType::INT64: + case proto::VarType::INT64: framework::VisitDataType(dst_type, CastDataType(in, out, ctx)); break; - case proto::DataType::BOOL: + case proto::VarType::BOOL: framework::VisitDataType(dst_type, CastDataType(in, out, ctx)); break; default: diff --git a/paddle/fluid/framework/data_type_transform_test.cc b/paddle/fluid/framework/data_type_transform_test.cc index 444d3b823c..724c8c301f 100644 --- a/paddle/fluid/framework/data_type_transform_test.cc +++ b/paddle/fluid/framework/data_type_transform_test.cc @@ -32,11 +32,11 @@ TEST(DataTypeTransform, CPUTransform) { ptr[i] = i / 3; } - auto kernel_fp32 = OpKernelType(proto::DataType::FP32, place, + auto kernel_fp32 = OpKernelType(proto::VarType::FP32, place, DataLayout::kAnyLayout, LibraryType::kPlain); - auto kernel_fp64 = OpKernelType(proto::DataType::FP64, place, + auto kernel_fp64 = OpKernelType(proto::VarType::FP64, place, DataLayout::kAnyLayout, LibraryType::kPlain); - auto kernel_int32 = OpKernelType(proto::DataType::INT32, place, + auto kernel_int32 = OpKernelType(proto::VarType::INT32, place, DataLayout::kAnyLayout, LibraryType::kPlain); TransDataType(kernel_fp32, kernel_fp64, in, &out); diff --git a/paddle/fluid/framework/framework.proto b/paddle/fluid/framework/framework.proto index fa7f437851..22d0692394 100644 --- a/paddle/fluid/framework/framework.proto +++ b/paddle/fluid/framework/framework.proto @@ -91,33 +91,34 @@ message OpProto { 
required string comment = 5; } -enum DataType { - BOOL = 0; - INT16 = 1; - INT32 = 2; - INT64 = 3; - FP16 = 4; - FP32 = 5; - FP64 = 6; -} - message VarType { enum Type { - LOD_TENSOR = 1; - SELECTED_ROWS = 2; - FEED_MINIBATCH = 3; - FETCH_LIST = 4; - STEP_SCOPES = 5; - LOD_RANK_TABLE = 6; - LOD_TENSOR_ARRAY = 7; - PLACE_LIST = 8; - READER = 9; + // Pod Types + BOOL = 0; + INT16 = 1; + INT32 = 2; + INT64 = 3; + FP16 = 4; + FP32 = 5; + FP64 = 6; + + // Other types that may need additional descriptions + LOD_TENSOR = 7; + SELECTED_ROWS = 8; + FEED_MINIBATCH = 9; + FETCH_LIST = 10; + STEP_SCOPES = 11; + LOD_RANK_TABLE = 12; + LOD_TENSOR_ARRAY = 13; + PLACE_LIST = 14; + READER = 15; } required Type type = 1; message TensorDesc { - required DataType data_type = 1; + // Should only be PODType. Is enforced in C++ + required Type data_type = 1; repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480] } optional TensorDesc selected_rows = 2; diff --git a/paddle/fluid/framework/op_kernel_type.h b/paddle/fluid/framework/op_kernel_type.h index 980e4eafaa..3a1036742c 100644 --- a/paddle/fluid/framework/op_kernel_type.h +++ b/paddle/fluid/framework/op_kernel_type.h @@ -40,12 +40,12 @@ struct OpKernelType { // place, data_type, library_type kinds less than 2^8 constexpr static int LEFT_SHIFT = 8; - proto::DataType data_type_; + proto::VarType::Type data_type_; DataLayout data_layout_; platform::Place place_; LibraryType library_type_; - OpKernelType(proto::DataType data_type, platform::Place place, + OpKernelType(proto::VarType::Type data_type, platform::Place place, DataLayout data_layout = DataLayout::kAnyLayout, LibraryType library_type = LibraryType::kPlain) : data_type_(data_type), @@ -53,7 +53,7 @@ struct OpKernelType { place_(place), library_type_(library_type) {} - OpKernelType(proto::DataType data_type, + OpKernelType(proto::VarType::Type data_type, const platform::DeviceContext& dev_ctx, DataLayout data_layout = DataLayout::kAnyLayout, LibraryType library_type = LibraryType::kPlain) diff --git a/paddle/fluid/framework/op_kernel_type_test.cc b/paddle/fluid/framework/op_kernel_type_test.cc index e56fe35c01..d37ce149ce 100644 --- a/paddle/fluid/framework/op_kernel_type_test.cc +++ b/paddle/fluid/framework/op_kernel_type_test.cc @@ -18,7 +18,7 @@ limitations under the License. 
*/ TEST(OpKernelType, ToString) { using OpKernelType = paddle::framework::OpKernelType; - using DataType = paddle::framework::proto::DataType; + using DataType = paddle::framework::proto::VarType; using CPUPlace = paddle::platform::CPUPlace; using DataLayout = paddle::framework::DataLayout; using LibraryType = paddle::framework::LibraryType; @@ -33,7 +33,7 @@ TEST(OpKernelType, ToString) { TEST(OpKernelType, Hash) { using OpKernelType = paddle::framework::OpKernelType; - using DataType = paddle::framework::proto::DataType; + using DataType = paddle::framework::proto::VarType; using CPUPlace = paddle::platform::CPUPlace; using CUDAPlace = paddle::platform::CUDAPlace; using DataLayout = paddle::framework::DataLayout; diff --git a/paddle/fluid/framework/op_registry_test.cc b/paddle/fluid/framework/op_registry_test.cc index b92647e892..0d791c8583 100644 --- a/paddle/fluid/framework/op_registry_test.cc +++ b/paddle/fluid/framework/op_registry_test.cc @@ -226,7 +226,7 @@ class OpWithKernelTest : public OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(proto::DataType::FP32, ctx.device_context()); + return framework::OpKernelType(proto::VarType::FP32, ctx.device_context()); } }; @@ -290,9 +290,9 @@ class OpWithMultiKernelTest : public OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - proto::DataType::FP32, platform::CUDAPlace(0), DataLayout::kAnyLayout, - framework::LibraryType::kCUDNN); + return framework::OpKernelType(proto::VarType::FP32, platform::CUDAPlace(0), + DataLayout::kAnyLayout, + framework::LibraryType::kCUDNN); } }; diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index ff90aba10b..7debdd8525 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -569,7 +569,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, } } -proto::DataType OperatorWithKernel::IndicateDataType( +proto::VarType::Type OperatorWithKernel::IndicateDataType( const ExecutionContext& ctx) const { auto& scope = ctx.scope(); int data_type = -1; @@ -595,7 +595,7 @@ proto::DataType OperatorWithKernel::IndicateDataType( } } PADDLE_ENFORCE(data_type != -1, "DataType should be indicated by input"); - return static_cast(data_type); + return static_cast(data_type); } OpKernelType OperatorWithKernel::GetExpectedKernelType( diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index c2782066ce..41214b41cb 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -394,9 +394,9 @@ class OperatorWithKernel : public OperatorBase { const OpKernelType& expected_kernel_type) const; private: - // indicate kernel DataType by input data. Defaultly all input data must be + // indicate kernel DataType by input data. By default all input data must be // same. 
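
For context on the IndicateDataType signature change just below: the inference logic itself is unchanged by this patch; it scans every initialized input, requires them all to share one data type, and returns that common type as the kernel's data type. A rough Python rendering of that contract (hypothetical helper names, not Paddle API):

    def indicate_data_type(input_tensors):
        # All initialized inputs must agree on a single data type;
        # that common type becomes the kernel's expected data type.
        data_type = None
        for t in input_tensors:
            if t is None:  # uninitialized or optional input: skip
                continue
            if data_type is not None and data_type != t.dtype:
                raise ValueError("DataType of Paddle Op must be the same.")
            data_type = t.dtype
        if data_type is None:
            raise ValueError("DataType should be indicated by input")
        return data_type

The patch only retypes the return value from proto::DataType to proto::VarType::Type; the scan itself is untouched.
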
- proto::DataType IndicateDataType(const ExecutionContext& ctx) const; + proto::VarType::Type IndicateDataType(const ExecutionContext& ctx) const; void RunImpl(const Scope& scope, const platform::Place& place) const final; }; diff --git a/paddle/fluid/framework/operator_test.cc b/paddle/fluid/framework/operator_test.cc index 08a471e0a1..44ca4d7ca5 100644 --- a/paddle/fluid/framework/operator_test.cc +++ b/paddle/fluid/framework/operator_test.cc @@ -119,7 +119,7 @@ class OpWithKernelTest : public OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override {} OpKernelType GetExpectedKernelType( const ExecutionContext& ctx) const override { - return OpKernelType(proto::DataType::FP32, ctx.GetPlace()); + return OpKernelType(proto::VarType::FP32, ctx.GetPlace()); } }; diff --git a/paddle/fluid/framework/program_desc_test.cc b/paddle/fluid/framework/program_desc_test.cc index d9c4331da1..66618a291b 100644 --- a/paddle/fluid/framework/program_desc_test.cc +++ b/paddle/fluid/framework/program_desc_test.cc @@ -24,13 +24,13 @@ TEST(ProgramDesc, copy_ctor) { auto* x = global_block->Var("X"); x->SetType(proto::VarType::LOD_TENSOR); x->SetLoDLevel(0); - x->SetDataType(proto::FP32); + x->SetDataType(proto::VarType::FP32); x->SetShape({1000, 784}); auto* y = global_block->Var("Y"); y->SetType(proto::VarType::LOD_TENSOR); y->SetLoDLevel(0); - y->SetDataType(proto::FP32); + y->SetDataType(proto::VarType::FP32); y->SetShape({784, 100}); auto* op = global_block->AppendOp(); @@ -86,13 +86,13 @@ TEST(ProgramDescBind, serialize_and_deserialize) { auto* x = global_block->Var("X"); x->SetType(proto::VarType::LOD_TENSOR); x->SetLoDLevel(0); - x->SetDataType(proto::FP32); + x->SetDataType(proto::VarType::FP32); x->SetShape({1000, 784}); auto* y = global_block->Var("Y"); y->SetType(proto::VarType::LOD_TENSOR); y->SetLoDLevel(0); - y->SetDataType(proto::FP32); + y->SetDataType(proto::VarType::FP32); y->SetShape({784, 100}); auto* op = global_block->AppendOp(); diff --git a/paddle/fluid/framework/prune_test.cc b/paddle/fluid/framework/prune_test.cc index b612fe8ad5..0e44b34383 100644 --- a/paddle/fluid/framework/prune_test.cc +++ b/paddle/fluid/framework/prune_test.cc @@ -34,7 +34,7 @@ void AddOp(const std::string &type, const f::VariableNameMap &inputs, for (auto kv : outputs) { for (auto v : kv.second) { auto var = block->Var(v); - var->SetDataType(paddle::framework::proto::DataType::FP32); + var->SetDataType(paddle::framework::proto::VarType::FP32); } } diff --git a/paddle/fluid/framework/var_desc.cc b/paddle/fluid/framework/var_desc.cc index bb2be1ab50..7e3f002b53 100644 --- a/paddle/fluid/framework/var_desc.cc +++ b/paddle/fluid/framework/var_desc.cc @@ -87,12 +87,12 @@ std::vector> VarDesc::GetShapes() const { return res; } -void VarDesc::SetDataType(proto::DataType data_type) { +void VarDesc::SetDataType(proto::VarType::Type data_type) { mutable_tensor_desc()->set_data_type(data_type); } void VarDesc::SetDataTypes( - const std::vector &multiple_data_type) { + const std::vector &multiple_data_type) { if (multiple_data_type.size() != GetTensorDescNum()) { VLOG(3) << "WARNING: The number of given data types(" << multiple_data_type.size() @@ -108,13 +108,13 @@ void VarDesc::SetDataTypes( } } -proto::DataType VarDesc::GetDataType() const { +proto::VarType::Type VarDesc::GetDataType() const { return tensor_desc().data_type(); } -std::vector VarDesc::GetDataTypes() const { +std::vector VarDesc::GetDataTypes() const { std::vector descs = tensor_descs(); - std::vector res; + std::vector res; 
res.reserve(descs.size()); for (const auto &tensor_desc : descs) { res.push_back(tensor_desc.data_type()); diff --git a/paddle/fluid/framework/var_desc.h b/paddle/fluid/framework/var_desc.h index 013ba446b9..19b8d890c1 100644 --- a/paddle/fluid/framework/var_desc.h +++ b/paddle/fluid/framework/var_desc.h @@ -80,13 +80,14 @@ class VarDesc { std::vector> GetShapes() const; - void SetDataType(proto::DataType data_type); + void SetDataType(proto::VarType::Type data_type); - void SetDataTypes(const std::vector &multiple_data_type); + void SetDataTypes( + const std::vector &multiple_data_type); - proto::DataType GetDataType() const; + proto::VarType::Type GetDataType() const; - std::vector GetDataTypes() const; + std::vector GetDataTypes() const; void SetLoDLevel(int32_t lod_level); diff --git a/paddle/fluid/operators/assign_value_op.cc b/paddle/fluid/operators/assign_value_op.cc index 2985fc28a0..e8123cb1a4 100644 --- a/paddle/fluid/operators/assign_value_op.cc +++ b/paddle/fluid/operators/assign_value_op.cc @@ -36,7 +36,8 @@ class AssignValueOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( - framework::proto::DataType(ctx.Attr("dtype")), ctx.GetPlace()); + framework::proto::VarType::Type(ctx.Attr("dtype")), + ctx.GetPlace()); } }; @@ -49,8 +50,8 @@ class AssignValueOpMaker : public framework::OpProtoAndCheckerMaker { "(vector) " "Shape of values."); AddAttr("dtype", "data type of values") - .InEnum({framework::proto::DataType::INT32, - framework::proto::DataType::FP32}); + .InEnum({framework::proto::VarType::INT32, + framework::proto::VarType::FP32}); AddAttr>("fp32_values", "store the float values") .SetDefault({}); AddAttr>("int32_values", "store the int values") diff --git a/paddle/fluid/operators/assign_value_op.h b/paddle/fluid/operators/assign_value_op.h index d51b215a08..c7b1a55a5c 100644 --- a/paddle/fluid/operators/assign_value_op.h +++ b/paddle/fluid/operators/assign_value_op.h @@ -30,10 +30,10 @@ class AssignValueKernel : public framework::OpKernel { int dtype = ctx.Attr("dtype"); const char* value_name = nullptr; switch (dtype) { - case framework::proto::DataType::INT32: + case framework::proto::VarType::INT32: value_name = "int32_values"; break; - case framework::proto::DataType::FP32: + case framework::proto::VarType::FP32: value_name = "fp32_values"; break; default: diff --git a/paddle/fluid/operators/cast_op.h b/paddle/fluid/operators/cast_op.h index ccfbd09a6b..6220e57f59 100644 --- a/paddle/fluid/operators/cast_op.h +++ b/paddle/fluid/operators/cast_op.h @@ -55,7 +55,8 @@ class CastOpKernel : public framework::OpKernel { auto* in = context.Input("X"); auto* out = context.Output("Out"); framework::VisitDataType( - static_cast(context.Attr("out_dtype")), + static_cast( + context.Attr("out_dtype")), CastOpFunctor( in, out, context.template device_context())); } diff --git a/paddle/fluid/operators/chunk_eval_op.cc b/paddle/fluid/operators/chunk_eval_op.cc index 09d090e187..77d3cffe7c 100644 --- a/paddle/fluid/operators/chunk_eval_op.cc +++ b/paddle/fluid/operators/chunk_eval_op.cc @@ -57,7 +57,7 @@ class ChunkEvalOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(framework::proto::DataType::FP32, + return framework::OpKernelType(framework::proto::VarType::FP32, platform::CPUPlace()); } }; diff --git 
a/paddle/fluid/operators/edit_distance_op.cc b/paddle/fluid/operators/edit_distance_op.cc index dbcbfec971..c7f037d2df 100644 --- a/paddle/fluid/operators/edit_distance_op.cc +++ b/paddle/fluid/operators/edit_distance_op.cc @@ -42,7 +42,7 @@ class EditDistanceOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(framework::proto::DataType::FP32, + return framework::OpKernelType(framework::proto::VarType::FP32, ctx.device_context()); } }; diff --git a/paddle/fluid/operators/fill_constant_batch_size_like_op.cc b/paddle/fluid/operators/fill_constant_batch_size_like_op.cc index 55eca71c8b..72da80baaf 100644 --- a/paddle/fluid/operators/fill_constant_batch_size_like_op.cc +++ b/paddle/fluid/operators/fill_constant_batch_size_like_op.cc @@ -24,7 +24,7 @@ class FillConstantBatchSizeLikeOp : public BatchSizeLikeOp { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( - static_cast(ctx.Attr("dtype")), + static_cast(ctx.Attr("dtype")), ctx.device_context()); } }; @@ -36,7 +36,7 @@ class FillConstantBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker { AddAttr("dtype", "(int, default 5 (FP32)) " "Output data type") - .SetDefault(framework::proto::DataType::FP32); + .SetDefault(framework::proto::VarType::FP32); AddAttr("value", "(float, default 0) The value to be filled") .SetDefault(0.0f); AddComment(R"DOC( diff --git a/paddle/fluid/operators/fill_constant_op.cc b/paddle/fluid/operators/fill_constant_op.cc index 0b65c83d3a..07e0a80f8d 100644 --- a/paddle/fluid/operators/fill_constant_op.cc +++ b/paddle/fluid/operators/fill_constant_op.cc @@ -38,7 +38,7 @@ class FillConstantOp : public framework::OperatorBase { void RunImpl(const framework::Scope &scope, const platform::Place &dev_place) const override { auto data_type = - static_cast(Attr("dtype")); + static_cast(Attr("dtype")); auto value = Attr("value"); auto force_cpu = Attr("force_cpu"); auto &out = @@ -64,7 +64,7 @@ class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("dtype", "(int, default 5 (FP32)) " "Output data type") - .SetDefault(framework::proto::DataType::FP32); + .SetDefault(framework::proto::VarType::FP32); AddAttr>("shape", "(vector) The shape of the output"); AddAttr("value", "(float, default 0) The value to be filled") .SetDefault(0.0f); diff --git a/paddle/fluid/operators/fill_op.cc b/paddle/fluid/operators/fill_op.cc index 3b4b409231..ee8a2fc353 100644 --- a/paddle/fluid/operators/fill_op.cc +++ b/paddle/fluid/operators/fill_op.cc @@ -51,7 +51,8 @@ class FillOp : public framework::OperatorBase { "Cannot find variable %s", Output("Out")) .GetMutable()); out.Resize(framework::make_ddim(Attr>("shape"))); - auto dtype = static_cast(Attr("dtype")); + auto dtype = + static_cast(Attr("dtype")); platform::CPUPlace cpu; auto force_cpu = Attr("force_cpu"); out.mutable_data(force_cpu ? cpu : place, framework::ToTypeIndex(dtype)); @@ -93,7 +94,7 @@ Fill an tensor with `value` and `shape`. The type of the tensor is specify by "value", "The float values of tensor, which are flatten in row major"); AddAttr>("shape", "The shape of output tensor"); AddAttr("dtype", "The data type of output tensor, Default is float") - .SetDefault(framework::proto::DataType::FP32); + .SetDefault(framework::proto::VarType::FP32); AddAttr("force_cpu", "Whether the output tensor must be at CPU memory or not. 
" "Default is false.") diff --git a/paddle/fluid/operators/gaussian_random_batch_size_like_op.cc b/paddle/fluid/operators/gaussian_random_batch_size_like_op.cc index ac516986ad..53c706a83e 100644 --- a/paddle/fluid/operators/gaussian_random_batch_size_like_op.cc +++ b/paddle/fluid/operators/gaussian_random_batch_size_like_op.cc @@ -26,7 +26,7 @@ class GaussianRandomBatchSizeLikeOp : public BatchSizeLikeOp { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( - static_cast(ctx.Attr("dtype")), + static_cast(ctx.Attr("dtype")), ctx.GetPlace()); } }; @@ -53,7 +53,7 @@ class GaussianRandomBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker { AddAttr("dtype", "(int, default 5(FP32)) " "Output data type.") - .SetDefault(framework::proto::DataType::FP32); + .SetDefault(framework::proto::VarType::FP32); AddComment(R"DOC( GaussianRandom Operator. diff --git a/paddle/fluid/operators/gaussian_random_op.cc b/paddle/fluid/operators/gaussian_random_op.cc index 7fb2b2c230..4d197637b3 100644 --- a/paddle/fluid/operators/gaussian_random_op.cc +++ b/paddle/fluid/operators/gaussian_random_op.cc @@ -63,7 +63,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - static_cast(ctx.Attr("dtype")), + static_cast(ctx.Attr("dtype")), ctx.device_context()); } }; @@ -95,7 +95,7 @@ class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("dtype", "(int, default 5(FP32)) " "Output data type.") - .SetDefault(framework::proto::DataType::FP32); + .SetDefault(framework::proto::VarType::FP32); AddComment(R"DOC( GaussianRandom Operator. diff --git a/paddle/fluid/operators/nccl_op.cc b/paddle/fluid/operators/nccl_op.cc index 7f1278f3a5..5ae50590dd 100644 --- a/paddle/fluid/operators/nccl_op.cc +++ b/paddle/fluid/operators/nccl_op.cc @@ -55,7 +55,7 @@ class NCCLInitOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("dtype", "(int, default 5 (FP32)) " "Output data type") - .SetDefault(framework::proto::DataType::FP32); + .SetDefault(framework::proto::VarType::FP32); AddComment(R"DOC( NCCLInit Operator. diff --git a/paddle/fluid/operators/one_hot_op.cc b/paddle/fluid/operators/one_hot_op.cc index 21d3405b70..1d42dfdd76 100644 --- a/paddle/fluid/operators/one_hot_op.cc +++ b/paddle/fluid/operators/one_hot_op.cc @@ -60,7 +60,7 @@ class OneHotOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("dtype", "An integer to specify the data type of one-hot " "vector. The default value is FP32.") - .SetDefault(paddle::framework::proto::DataType::FP32); + .SetDefault(paddle::framework::proto::VarType::FP32); AddComment(R"DOC( One Hot Operator. This operator creates the one-hot representations for input index values. 
The following example will help to explain the function of this diff --git a/paddle/fluid/operators/one_hot_op.cu b/paddle/fluid/operators/one_hot_op.cu index 87c285df4e..240ac895e2 100644 --- a/paddle/fluid/operators/one_hot_op.cu +++ b/paddle/fluid/operators/one_hot_op.cu @@ -65,7 +65,8 @@ class OneHotCUDAKernel : public framework::OpKernel { int depth = context.Attr("depth"); framework::VisitDataType( - static_cast(context.Attr("dtype")), + static_cast( + context.Attr("dtype")), OneHotOpCUDAFunctor( in, out, depth, context.template device_context())); } diff --git a/paddle/fluid/operators/one_hot_op.h b/paddle/fluid/operators/one_hot_op.h index 1409f8af62..7e77f25089 100644 --- a/paddle/fluid/operators/one_hot_op.h +++ b/paddle/fluid/operators/one_hot_op.h @@ -58,7 +58,8 @@ class OneHotKernel : public framework::OpKernel { int depth = context.Attr("depth"); framework::VisitDataType( - static_cast(context.Attr("dtype")), + static_cast( + context.Attr("dtype")), OneHotOpFunctor( in, out, depth, context.template device_context())); } diff --git a/paddle/fluid/operators/rnn_memory_helper_op.cc b/paddle/fluid/operators/rnn_memory_helper_op.cc index 8ab9f010a2..70f205d887 100644 --- a/paddle/fluid/operators/rnn_memory_helper_op.cc +++ b/paddle/fluid/operators/rnn_memory_helper_op.cc @@ -66,7 +66,7 @@ class RNNMemoryHelperOpInfoMaker : public framework::OpProtoAndCheckerMaker { AddAttr("dtype", "(int, default 5 (FP32)) " "Output data type") - .SetDefault(framework::proto::DataType::FP32); + .SetDefault(framework::proto::VarType::FP32); AddComment(""); } }; @@ -126,7 +126,7 @@ class RNNMemoryHelperGradOpInfoMaker AddAttr("dtype", "(int, default 5 (FP32)) " "Output data type") - .SetDefault(framework::proto::DataType::FP32); + .SetDefault(framework::proto::VarType::FP32); AddComment(""); } }; diff --git a/paddle/fluid/operators/sum_op.cc b/paddle/fluid/operators/sum_op.cc index 7b88387c33..c3abb3ea4a 100644 --- a/paddle/fluid/operators/sum_op.cc +++ b/paddle/fluid/operators/sum_op.cc @@ -73,7 +73,8 @@ class SumOp : public framework::OperatorWithKernel { "Sum operator should have at least one tensor"); return framework::OpKernelType( - static_cast(dtype), ctx.device_context()); + static_cast(dtype), + ctx.device_context()); } else if (x_vars[0]->IsType()) { return framework::OpKernelType( framework::ToDataType( diff --git a/paddle/fluid/operators/uniform_random_batch_size_like_op.cc b/paddle/fluid/operators/uniform_random_batch_size_like_op.cc index fa31dad513..00f00bb403 100644 --- a/paddle/fluid/operators/uniform_random_batch_size_like_op.cc +++ b/paddle/fluid/operators/uniform_random_batch_size_like_op.cc @@ -26,7 +26,7 @@ class UniformRandomBatchSizeLikeOp : public BatchSizeLikeOp { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( - static_cast(ctx.Attr("dtype")), + static_cast(ctx.Attr("dtype")), ctx.GetPlace()); } }; @@ -58,7 +58,7 @@ This operator initializes a tensor with the same batch_size as the Input tensor "generate the same random numbers every time.") .SetDefault(0); AddAttr("dtype", "(int, default 5(FP32)) Output tensor data type") - .SetDefault(framework::proto::DataType::FP32); + .SetDefault(framework::proto::VarType::FP32); } }; diff --git a/paddle/fluid/operators/uniform_random_op.cc b/paddle/fluid/operators/uniform_random_op.cc index 3a0a0d6fca..87699362b2 100644 --- a/paddle/fluid/operators/uniform_random_op.cc +++ b/paddle/fluid/operators/uniform_random_op.cc @@ -66,7 +66,7 @@ class 
UniformRandomOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - static_cast(ctx.Attr("dtype")), + static_cast(ctx.Attr("dtype")), ctx.GetPlace()); } }; @@ -101,7 +101,7 @@ uniform distribution. "generate the same random numbers every time.") .SetDefault(0); AddAttr("dtype", "(int, default 5(FP32)) Output tensor data type") - .SetDefault(framework::proto::DataType::FP32); + .SetDefault(framework::proto::VarType::FP32); } }; } // namespace operators diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc index 9f97cc5007..99716ccb24 100644 --- a/paddle/fluid/pybind/protobuf.cc +++ b/paddle/fluid/pybind/protobuf.cc @@ -195,15 +195,6 @@ void BindBlockDesc(py::module &m) { } void BindVarDsec(py::module &m) { - py::enum_(m, "DataType", "") - .value("BOOL", proto::DataType::BOOL) - .value("INT16", proto::DataType::INT16) - .value("INT32", proto::DataType::INT32) - .value("INT64", proto::DataType::INT64) - .value("FP16", proto::DataType::FP16) - .value("FP32", proto::DataType::FP32) - .value("FP64", proto::DataType::FP64); - py::class_ var_desc(m, "VarDesc", ""); var_desc .def("name", @@ -233,6 +224,13 @@ void BindVarDsec(py::module &m) { .def("set_persistable", &VarDesc::SetPersistable); py::enum_(var_desc, "VarType", "") + .value("BOOL", proto::VarType::BOOL) + .value("INT16", proto::VarType::INT16) + .value("INT32", proto::VarType::INT32) + .value("INT64", proto::VarType::INT64) + .value("FP16", proto::VarType::FP16) + .value("FP32", proto::VarType::FP32) + .value("FP64", proto::VarType::FP64) .value("LOD_TENSOR", proto::VarType::LOD_TENSOR) .value("SELECTED_ROWS", proto::VarType::SELECTED_ROWS) .value("FEED_MINIBATCH", proto::VarType::FEED_MINIBATCH) diff --git a/python/paddle/v2/fluid/backward.py b/python/paddle/v2/fluid/backward.py index a690c14300..26b35cfc19 100644 --- a/python/paddle/v2/fluid/backward.py +++ b/python/paddle/v2/fluid/backward.py @@ -68,7 +68,7 @@ def _infer_var_data_type_(grad_var_name, block): fwd_var = block.desc.find_var_recursive(fwd_name.encode("ascii")) grad_var.set_dtype(fwd_var.dtype()) else: - grad_var.set_dtype(core.DataType.FP32) + grad_var.set_dtype(core.VarDesc.VarType.FP32) def _all_in_set_(cands, s): diff --git a/python/paddle/v2/fluid/data_feeder.py b/python/paddle/v2/fluid/data_feeder.py index 070bcadd71..ac02401c79 100644 --- a/python/paddle/v2/fluid/data_feeder.py +++ b/python/paddle/v2/fluid/data_feeder.py @@ -27,13 +27,13 @@ class DataToLoDTensorConverter(object): self.place = place self.lod_level = lod_level self.shape = shape - if dtype == core.DataType.FP32: + if dtype == core.VarDesc.VarType.FP32: self.dtype = 'float32' - elif dtype == core.DataType.INT64: + elif dtype == core.VarDesc.VarType.INT64: self.dtype = 'int64' - elif dtype == core.DataType.FP64: + elif dtype == core.VarDesc.VarType.FP64: self.dtype = 'float64' - elif dtype == core.DataType.INT32: + elif dtype == core.VarDesc.VarType.INT32: self.dtype = 'int32' else: raise ValueError("dtype must be any of [int32, float32, int64, " diff --git a/python/paddle/v2/fluid/evaluator.py b/python/paddle/v2/fluid/evaluator.py index 30d87c76c2..1f4618310c 100644 --- a/python/paddle/v2/fluid/evaluator.py +++ b/python/paddle/v2/fluid/evaluator.py @@ -89,7 +89,7 @@ class Evaluator(object): Args: suffix(str): the state suffix. 
diff --git a/python/paddle/v2/fluid/evaluator.py b/python/paddle/v2/fluid/evaluator.py
index 30d87c76c2..1f4618310c 100644
--- a/python/paddle/v2/fluid/evaluator.py
+++ b/python/paddle/v2/fluid/evaluator.py
@@ -89,7 +89,7 @@ class Evaluator(object):
 
         Args:
             suffix(str): the state suffix.
-            dtype(str|core.DataType): the state data type
+            dtype(str|core.VarDesc.VarType): the state data type
             shape(tuple|list): the shape of state
 
         Returns: State variable
diff --git a/python/paddle/v2/fluid/framework.py b/python/paddle/v2/fluid/framework.py
index dfd7e8047c..fb4cd5b75a 100644
--- a/python/paddle/v2/fluid/framework.py
+++ b/python/paddle/v2/fluid/framework.py
@@ -67,24 +67,24 @@ def convert_np_dtype_to_dtype_(np_dtype):
     Args:
         np_dtype(np.dtype): the data type in numpy
 
-    Returns(core.DataType): the data type in Paddle
+    Returns(core.VarDesc.VarType): the data type in Paddle
 
     """
     dtype = np.dtype(np_dtype)
     if dtype == np.float32:
-        return core.DataType.FP32
+        return core.VarDesc.VarType.FP32
     elif dtype == np.float64:
-        return core.DataType.FP64
+        return core.VarDesc.VarType.FP64
     elif dtype == np.float16:
-        return core.DataType.FP16
+        return core.VarDesc.VarType.FP16
     elif dtype == np.int32:
-        return core.DataType.INT32
+        return core.VarDesc.VarType.INT32
     elif dtype == np.int16:
-        return core.DataType.INT16
+        return core.VarDesc.VarType.INT16
     elif dtype == np.int64:
-        return core.DataType.INT64
+        return core.VarDesc.VarType.INT64
     elif dtype == np.bool:
-        return core.DataType.BOOL
+        return core.VarDesc.VarType.BOOL
     else:
         raise ValueError("Not supported numpy dtype " + str(dtype))
 
@@ -93,16 +93,19 @@ def dtype_is_floating(dtype):
     """
     Check the data type is floating or not.
     Args:
-        dtype(np.dtype|core.DataType): data type.
+        dtype(np.dtype|core.VarDesc.VarType): data type.
             Could be numpy format or Paddle format
 
     Returns(bool): True if data type is a float value
 
     """
-    if not isinstance(dtype, core.DataType):
+    if not isinstance(dtype, core.VarDesc.VarType):
         dtype = convert_np_dtype_to_dtype_(dtype)
 
-    return dtype in [core.DataType.FP16, core.DataType.FP32, core.DataType.FP64]
+    return dtype in [
+        core.VarDesc.VarType.FP16, core.VarDesc.VarType.FP32,
+        core.VarDesc.VarType.FP64
+    ]
 
 
 def _debug_string_(proto, throw_on_error=True):
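With framework.py updated, the numpy-to-Paddle dtype helpers round-trip through the new enum. A small usage sketch of the two functions changed above:

    import numpy as np
    from paddle.v2.fluid import core
    from paddle.v2.fluid.framework import (convert_np_dtype_to_dtype_,
                                           dtype_is_floating)

    assert convert_np_dtype_to_dtype_(np.float32) == core.VarDesc.VarType.FP32
    assert dtype_is_floating('float64')      # str -> np.dtype -> VarType.FP64
    assert not dtype_is_floating(np.int32)   # integral dtypes report False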
@@ -148,7 +151,7 @@ class Variable(object):
             framework.proto for details.
         shape(tuple|list|None): The shape of variable. -1 means the batch size.
             Some kinds of variable do not contain shape, just set it to None.
-        dtype(np.dtype|core.DataType|str): The data type of variable.
+        dtype(np.dtype|core.VarDesc.VarType|str): The data type of variable.
         lod_level(int): The level of lod tensor. 0 means there is not a time
             series data.
         persistable(bool): True if the variable should be saved as check point.
@@ -200,7 +203,7 @@ class Variable(object):
                     "shape is {1}; the new shape is {2}. They are not "
                     "matched.".format(self.name, old_shape, shape))
         if dtype is not None:
-            if not isinstance(dtype, core.DataType):
+            if not isinstance(dtype, core.VarDesc.VarType):
                 dtype = convert_np_dtype_to_dtype_(dtype)
             if is_new_var:
                 self.desc.set_dtype(dtype)
diff --git a/python/paddle/v2/fluid/layers/control_flow.py b/python/paddle/v2/fluid/layers/control_flow.py
index 1ca11bb35b..b56a391618 100644
--- a/python/paddle/v2/fluid/layers/control_flow.py
+++ b/python/paddle/v2/fluid/layers/control_flow.py
@@ -612,7 +612,7 @@ class While(object):
         if not isinstance(cond, Variable):
             raise TypeError("condition should be a variable")
         assert isinstance(cond, Variable)
-        if cond.dtype != core.DataType.BOOL:
+        if cond.dtype != core.VarDesc.VarType.BOOL:
             raise TypeError("condition should be a bool variable")
         if reduce(lambda a, b: a * b, cond.shape, 1) != 1:
             raise TypeError("condition should be a bool scalar")
diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py
index d1ac6583dd..c4baa62ccd 100644
--- a/python/paddle/v2/fluid/layers/nn.py
+++ b/python/paddle/v2/fluid/layers/nn.py
@@ -221,7 +221,7 @@ def embedding(input,
             :math:`padding_idx < 0`, the padding_idx to use in lookup is
             :math:`size[0] + dim`.
         param_attr(ParamAttr): Parameters for this layer
-        dtype(np.dtype|core.DataType|str): The type of data : float32, float_16, int etc
+        dtype(np.dtype|core.VarDesc.VarType|str): The type of data : float32, float_16, int etc
 
     Returns:
         Variable: The tensor variable storing the embeddings of the \
diff --git a/python/paddle/v2/fluid/layers/tensor.py b/python/paddle/v2/fluid/layers/tensor.py
index db400aad37..97e8f082cf 100644
--- a/python/paddle/v2/fluid/layers/tensor.py
+++ b/python/paddle/v2/fluid/layers/tensor.py
@@ -17,7 +17,7 @@ from ..param_attr import ParamAttr
 from ..framework import convert_np_dtype_to_dtype_
 from ..framework import Variable
 from ..initializer import Constant, force_init_on_cpu
-from ..core import DataType
+from ..core import VarDesc
 import numpy
 
 __all__ = [
@@ -199,10 +199,10 @@ def assign(input, output):
             attrs={'scale': 1.0})
     elif isinstance(input, numpy.ndarray):
         dtype = convert_np_dtype_to_dtype_(input.dtype)
-        if dtype == DataType.FP32:
+        if dtype == VarDesc.VarType.FP32:
             value_name = "fp32_values"
             values = [float(v) for v in input.flat]
-        elif dtype == DataType.INT32:
+        elif dtype == VarDesc.VarType.INT32:
             value_name = "int32_values"
             values = [int(v) for v in input.flat]
         else:
@@ -236,7 +236,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None):
 
     Args:
         shape(tuple|list|None): Shape of the output tensor.
-        dtype(np.dtype|core.DataType|str): Data type of the output tensor.
+        dtype(np.dtype|core.VarDesc.VarType|str): Data type of the output tensor.
         value(float): The constant value used to initialize the output tensor.
         out(Variable): The output tensor.
         force_cpu(True|False): data should be on CPU if set true.
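Since the layers in tensor.py keep accepting dtype as a string, a numpy dtype, or the enum itself, user code only breaks if it referenced core.DataType directly. A sketch of the spellings; create_tensor is assumed to be exported from fluid.layers alongside the functions shown above:

    import numpy as np
    import paddle.v2.fluid as fluid

    x = fluid.layers.fill_constant(shape=[2, 3], dtype='float32', value=1.0)
    y = fluid.layers.create_tensor(dtype=fluid.core.VarDesc.VarType.FP32)
    # assign() now branches on VarDesc.VarType.FP32/INT32 for numpy inputs.
    fluid.layers.assign(np.ones([2, 3], dtype=np.float32), y)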
@@ -285,7 +285,7 @@ def fill_constant_batch_size_like(input,
     Args:
         input(Variable): Tensor whose dimensions will be used to get batch size
         shape(tuple|list|None): Shape of output tensor
-        dtype(np.dtype|core.DataType|str): Data type of output tensor
+        dtype(np.dtype|core.VarDesc.VarType|str): Data type of output tensor
         value(float): Constant value to initialize the output tensor
         input_dim_idx(int): Index of input's batch size dimension
         output_dim_idx(int): Index of output's batch size dimension
@@ -327,7 +327,7 @@ def ones(shape, dtype, force_cpu=False):
 
     Args:
         shape(tuple|list|None): Shape of output tensor
-        dtype(np.dtype|core.DataType|str): Data type of output tensor
+        dtype(np.dtype|core.VarDesc.VarType|str): Data type of output tensor
 
     Returns:
         Variable: The tensor variable storing the output
@@ -351,7 +351,7 @@ def zeros(shape, dtype, force_cpu=False):
 
     Args:
         shape(tuple|list|None): Shape of output tensor
-        dtype(np.dtype|core.DataType|str): Data type of output tensor
+        dtype(np.dtype|core.VarDesc.VarType|str): Data type of output tensor
 
     Returns:
         Variable: The tensor variable storing the output
diff --git a/python/paddle/v2/fluid/memory_optimization_transpiler.py b/python/paddle/v2/fluid/memory_optimization_transpiler.py
index 78dc56f849..ee56ccdcf1 100644
--- a/python/paddle/v2/fluid/memory_optimization_transpiler.py
+++ b/python/paddle/v2/fluid/memory_optimization_transpiler.py
@@ -20,13 +20,13 @@ from backward import _rename_arg_
 from . import core
 
 dtype_to_size = {
-    core.DataType.FP16: 2,
-    core.DataType.FP32: 4,
-    core.DataType.FP64: 8,
-    core.DataType.INT16: 2,
-    core.DataType.INT32: 4,
-    core.DataType.INT64: 8,
-    core.DataType.BOOL: 1
+    core.VarDesc.VarType.FP16: 2,
+    core.VarDesc.VarType.FP32: 4,
+    core.VarDesc.VarType.FP64: 8,
+    core.VarDesc.VarType.INT16: 2,
+    core.VarDesc.VarType.INT32: 4,
+    core.VarDesc.VarType.INT64: 8,
+    core.VarDesc.VarType.BOOL: 1
 }
diff --git a/python/paddle/v2/fluid/tests/test_cpp_reader.py b/python/paddle/v2/fluid/tests/test_cpp_reader.py
index 8d4f454611..6d2312dbcb 100644
--- a/python/paddle/v2/fluid/tests/test_cpp_reader.py
+++ b/python/paddle/v2/fluid/tests/test_cpp_reader.py
@@ -22,7 +22,7 @@ block = prog.current_block()
 random_reader = block.create_var(
     type=fluid.core.VarDesc.VarType.READER, name="RandomDataGenerator")
 random_reader.desc.set_dtypes(
-    [fluid.core.DataType.FP32, fluid.core.DataType.FP32])
+    [fluid.core.VarDesc.VarType.FP32, fluid.core.VarDesc.VarType.FP32])
 
 create_random_data_generator_op = block.append_op(
     type="create_random_data_generator",
diff --git a/python/paddle/v2/fluid/tests/unittests/op_test.py b/python/paddle/v2/fluid/tests/unittests/op_test.py
index 4761811f0a..d8867550ca 100644
--- a/python/paddle/v2/fluid/tests/unittests/op_test.py
+++ b/python/paddle/v2/fluid/tests/unittests/op_test.py
@@ -119,9 +119,9 @@ def get_numeric_gradient(place,
     tensor_to_check = scope.find_var(input_to_check).get_tensor()
     tensor_size = product(tensor_to_check.get_dims())
     tensor_to_check_dtype = tensor_to_check.dtype()
-    if tensor_to_check_dtype == core.DataType.FP32:
+    if tensor_to_check_dtype == core.VarDesc.VarType.FP32:
         tensor_to_check_dtype = np.float32
-    elif tensor_to_check_dtype == core.DataType.FP64:
+    elif tensor_to_check_dtype == core.VarDesc.VarType.FP64:
         tensor_to_check_dtype = np.float64
     else:
         raise ValueError("Not supported data type " + str(
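The dtype_to_size table keyed by the renamed enum is what the memory optimization pass uses to price variables in bytes. A rough sketch of that arithmetic, assuming the dict remains importable at module level:

    from paddle.v2.fluid import core
    from paddle.v2.fluid.memory_optimization_transpiler import dtype_to_size

    shape = [64, 784]
    numel = 1
    for dim in shape:
        numel *= dim
    # 50176 elements * 4 bytes per FP32 element = 200704 bytes
    assert numel * dtype_to_size[core.VarDesc.VarType.FP32] == 200704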
diff --git a/python/paddle/v2/fluid/tests/unittests/test_batch_norm_op.py b/python/paddle/v2/fluid/tests/unittests/test_batch_norm_op.py
index 778c7044ce..b7c0cb521a 100644
--- a/python/paddle/v2/fluid/tests/unittests/test_batch_norm_op.py
+++ b/python/paddle/v2/fluid/tests/unittests/test_batch_norm_op.py
@@ -140,9 +140,9 @@ def set_output_grad(scope, outputs, place, feed_dict=None):
         grad_tensor = scope.var(grad_var_name(name)).get_tensor()
         out_dtype = out_tensor.dtype()
         if data is None:
-            if out_dtype == core.DataType.FP64:
+            if out_dtype == core.VarDesc.VarType.FP64:
                 data = np.ones(out_tensor.shape(), dtype=np.float64)
-            elif out_dtype == core.DataType.FP32:
+            elif out_dtype == core.VarDesc.VarType.FP32:
                 data = np.ones(out_tensor.shape(), dtype=np.float32)
             else:
                 raise ValueError("Not supported data type " + str(out_dtype))
diff --git a/python/paddle/v2/fluid/tests/unittests/test_cast_op.py b/python/paddle/v2/fluid/tests/unittests/test_cast_op.py
index 44859e2155..3d05a319cd 100644
--- a/python/paddle/v2/fluid/tests/unittests/test_cast_op.py
+++ b/python/paddle/v2/fluid/tests/unittests/test_cast_op.py
@@ -24,8 +24,8 @@ class TestCastOp(op_test.OpTest):
         self.inputs = {'X': ipt.astype('float32')}
         self.outputs = {'Out': ipt.astype('float64')}
         self.attrs = {
-            'in_dtype': int(core.DataType.FP32),
-            'out_dtype': int(core.DataType.FP64)
+            'in_dtype': int(core.VarDesc.VarType.FP32),
+            'out_dtype': int(core.VarDesc.VarType.FP64)
         }
         self.op_type = 'cast'
diff --git a/python/paddle/v2/fluid/tests/unittests/test_fill_op.py b/python/paddle/v2/fluid/tests/unittests/test_fill_op.py
index 34c6401377..c2e3cfe6f3 100644
--- a/python/paddle/v2/fluid/tests/unittests/test_fill_op.py
+++ b/python/paddle/v2/fluid/tests/unittests/test_fill_op.py
@@ -26,7 +26,7 @@ class TestFillOp(OpTest):
         self.attrs = {
             'value': val.flatten().tolist(),
             'shape': [100, 200],
-            'dtype': int(core.DataType.FP64)
+            'dtype': int(core.VarDesc.VarType.FP64)
         }
         self.outputs = {'Out': val.astype('float64')}
diff --git a/python/paddle/v2/fluid/tests/unittests/test_layer_norm_op.py b/python/paddle/v2/fluid/tests/unittests/test_layer_norm_op.py
index b723b471bc..a1206b3b85 100644
--- a/python/paddle/v2/fluid/tests/unittests/test_layer_norm_op.py
+++ b/python/paddle/v2/fluid/tests/unittests/test_layer_norm_op.py
@@ -97,9 +97,9 @@ def set_output_grad(scope, outputs, place, feed_dict=None):
         grad_tensor = scope.var(grad_var_name(name)).get_tensor()
         out_dtype = out_tensor.dtype()
         if data is None:
-            if out_dtype == core.DataType.FP64:
+            if out_dtype == core.VarDesc.VarType.FP64:
                 data = np.ones(out_tensor.shape(), dtype=np.float64)
-            elif out_dtype == core.DataType.FP32:
+            elif out_dtype == core.VarDesc.VarType.FP32:
                 data = np.ones(out_tensor.shape(), dtype=np.float32)
             else:
                 raise ValueError("Not supported data type " + str(out_dtype))
diff --git a/python/paddle/v2/fluid/tests/unittests/test_one_hot_op.py b/python/paddle/v2/fluid/tests/unittests/test_one_hot_op.py
index c93be0efda..b7db30104a 100644
--- a/python/paddle/v2/fluid/tests/unittests/test_one_hot_op.py
+++ b/python/paddle/v2/fluid/tests/unittests/test_one_hot_op.py
@@ -38,7 +38,7 @@ class TestOneHotOp(OpTest):
             out[i, x[i]] = 1.0
 
         self.inputs = {'X': (x, x_lod)}
-        self.attrs = {'depth': depth, 'dtype': int(core.DataType.FP32)}
+        self.attrs = {'depth': depth, 'dtype': int(core.VarDesc.VarType.FP32)}
         self.outputs = {'Out': (out, x_lod)}
 
     def test_check_output(self):
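The unit-test hunks above all follow one pattern: wherever a dtype is stored in an operator's attribute map, the enum is wrapped in int(), because attributes carry plain integers. For example, the cast op's attributes after this patch (values as defined in framework.proto):

    from paddle.v2.fluid import core

    attrs = {
        'in_dtype': int(core.VarDesc.VarType.FP32),   # 5
        'out_dtype': int(core.VarDesc.VarType.FP64),  # 6
    }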
diff --git a/python/paddle/v2/fluid/tests/unittests/test_parameter.py b/python/paddle/v2/fluid/tests/unittests/test_parameter.py
index 0ba9235fdb..88356a7ea1 100644
--- a/python/paddle/v2/fluid/tests/unittests/test_parameter.py
+++ b/python/paddle/v2/fluid/tests/unittests/test_parameter.py
@@ -36,7 +36,7 @@ class TestParameter(unittest.TestCase):
         self.assertIsNotNone(param)
         self.assertEqual('fc.w', param.name)
         self.assertEqual((784, 100), param.shape)
-        self.assertEqual(core.DataType.FP32, param.dtype)
+        self.assertEqual(core.VarDesc.VarType.FP32, param.dtype)
         self.assertEqual(0, param.block.idx)
         exe = Executor(core.CPUPlace())
         p = exe.run(main_program, fetch_list=[param])[0]
diff --git a/python/paddle/v2/fluid/tests/unittests/test_protobuf_descs.py b/python/paddle/v2/fluid/tests/unittests/test_protobuf_descs.py
index 55d18d2729..c3bef95874 100644
--- a/python/paddle/v2/fluid/tests/unittests/test_protobuf_descs.py
+++ b/python/paddle/v2/fluid/tests/unittests/test_protobuf_descs.py
@@ -131,8 +131,8 @@ class TestVarDesc(unittest.TestCase):
         block = program_desc.block(0)
         var = block.var('my_var')
         var.set_type(core.VarDesc.VarType.LOD_TENSOR)
-        var.set_dtype(core.DataType.INT32)
-        self.assertEqual(core.DataType.INT32, var.dtype())
+        var.set_dtype(core.VarDesc.VarType.INT32)
+        self.assertEqual(core.VarDesc.VarType.INT32, var.dtype())
         self.assertEqual(core.VarDesc.VarType.LOD_TENSOR, var.type())
 
     def test_multiple_dtype(self):
@@ -141,7 +141,8 @@ class TestVarDesc(unittest.TestCase):
         var = block.var('my_reader')
         var.set_type(core.VarDesc.VarType.READER)
         src_types = [
-            core.DataType.INT32, core.DataType.FP64, core.DataType.FP32
+            core.VarDesc.VarType.INT32, core.VarDesc.VarType.FP64,
+            core.VarDesc.VarType.FP32
         ]
         var.set_dtypes(src_types)
         self.assertEqual(src_types, var.dtypes())
diff --git a/python/paddle/v2/fluid/tests/unittests/test_variable.py b/python/paddle/v2/fluid/tests/unittests/test_variable.py
index b06bcfb075..4ae3909d27 100644
--- a/python/paddle/v2/fluid/tests/unittests/test_variable.py
+++ b/python/paddle/v2/fluid/tests/unittests/test_variable.py
@@ -20,7 +20,7 @@ import numpy as np
 
 class TestVariable(unittest.TestCase):
     def test_np_dtype_convert(self):
-        DT = core.DataType
+        DT = core.VarDesc.VarType
         convert = convert_np_dtype_to_dtype_
         self.assertEqual(DT.FP32, convert(np.float32))
         self.assertEqual(DT.FP16, convert("float16"))
@@ -36,13 +36,13 @@ class TestVariable(unittest.TestCase):
         w = b.create_var(
             dtype="float64", shape=[784, 100], lod_level=0, name="fc.w")
         self.assertNotEqual(str(w), "")
-        self.assertEqual(core.DataType.FP64, w.dtype)
+        self.assertEqual(core.VarDesc.VarType.FP64, w.dtype)
         self.assertEqual((784, 100), w.shape)
         self.assertEqual("fc.w", w.name)
         self.assertEqual(0, w.lod_level)
 
         w = b.create_var(name='fc.w')
-        self.assertEqual(core.DataType.FP64, w.dtype)
+        self.assertEqual(core.VarDesc.VarType.FP64, w.dtype)
         self.assertEqual((784, 100), w.shape)
         self.assertEqual("fc.w", w.name)
         self.assertEqual(0, w.lod_level)
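test_protobuf_descs.py above also covers the one case where a variable carries several dtypes: READER variables. A condensed sketch mirroring that test:

    from paddle.v2.fluid import core

    program_desc = core.ProgramDesc()
    block = program_desc.block(0)
    var = block.var('my_reader')
    var.set_type(core.VarDesc.VarType.READER)
    # A READER holds one dtype per tensor it yields.
    var.set_dtypes([core.VarDesc.VarType.INT32, core.VarDesc.VarType.FP32])
    assert var.dtypes() == [core.VarDesc.VarType.INT32,
                            core.VarDesc.VarType.FP32]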
From 56d5319261f94a7c1b135ffe904b415cdfe8f4e8 Mon Sep 17 00:00:00 2001
From: emailweixu
Date: Fri, 16 Feb 2018 15:05:33 -0800
Subject: [PATCH 5/5] Fix typo Paddle/tools/manylinux1/README.md (#8463)

---
 tools/manylinux1/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tools/manylinux1/README.md b/tools/manylinux1/README.md
index cb0a9ac22c..898e00bd37 100644
--- a/tools/manylinux1/README.md
+++ b/tools/manylinux1/README.md
@@ -12,7 +12,7 @@ with newer version compilers cannot work with those with older versions. The
 suggested building environment is as old as CentOS 5. However, PaddlePaddle
 relies on CUDA, and the earlies version of [CentOS works with CUDA is
 6](https://hub.docker.com/r/nvidia/cuda/).
-So, here we provide a Docker image basing on CentOS 6 and CUDA for
+So, here we provide a Docker image based on CentOS 6 and CUDA for
 building PaddlePaddle and making the release supports "as-manylinux as
 possible." or "sufficiently many Linux" according to [this discussion](https://mail.python.org/pipermail/wheel-builders/2016-July/000175.html).