& no_grad_vars);
-```
-
-The implementation behind it can be divided into two parts, **Backward Operator Creating** and **Backward Operator Building**.
-
-### Backward Operator Registry
-
-A backward network is built up with several backward operators. Backward operators take forward operators' inputs, outputs, and output gradients and then calculate its input gradients.
-
-| | forward operator | backward operator
-| ---------------------- | ---------------- |------------------------- |
-| **Operator::inputs_** | Inputs | Inputs, Outputs, OutputGradients |
-| **Operator::outputs_** | Outputs | InputGradients |
-
- In most cases, there is a one-to-one relation between the forward and backward operators. These relations are recorded by a global hash map(`OpInfoMap`). To follow the philosophy of minimum core and to make operators pluggable, the registry mechanism is introduced.
-
-For example, we have `mul_op`, and we can register its information and corresponding backward operator by the following macro:
-
-```cpp
-REGISTER_OP(mul, MulOp, MulOpMaker, mul_grad, MulOpGrad);
-```
-
-`mul` is the operator's type. `MulOp` and `MulOpMaker` are the operator class and the operator maker class respectively.
-
-`mul_grad` is the type of backward operator, and `MulOpGrad` is its class name.
-
-### Backward Opeartor Creating
-
-Given a certain forward operator, we can get its corresponding backward operator by calling:
-
-```cpp
-OperatorBase* bwd_op = BuildGradOp(const OperatorBase* fwd_op);
-```
-
-The function `BuildGradOp` will sequentially execute following processes:
-
-1. Get the `type_` of given forward operator, and then get the corresponding backward operator's type by looking up the `OpInfoMap`.
-
-2. Build two maps named `inputs` and `outputs` to temporarily store backward operator's inputs and outputs. Copy forward operator's `inputs_` and `outputs_` to map `inputs`, except these, are not necessary for gradient computing.
-
-3. Add forward inputs' gradient variables into map `output`, adding forward outputs' gradient variables into map `input`.
-
-4. Building backward operator with `inputs`, `outputs` and forward operator's attributes.
-
-### Backward Network Building
-
-A backward network is a series of backward operators. The main idea of building a backward network is creating backward operators in the inverted sequence and appending them together one by one. There are some corner cases that need special processing.
-
-1. Op
-
- When the input forward network is an Op, return its gradient Operator immediately. If all of its outputs are in no gradient set, then return a special `NOP`.
-
-2. NetOp
-
- In our design, the network itself is also a kind of operator(**NetOp**). So the operators contained by a big network may be some small network. When the input forward network is a NetOp, it needs to call the sub NetOp/Operators backward function recursively. During the process, we need to collect the `OutputGradients` name according to the forward NetOp.
-
-3. RnnOp
-
- RnnOp is a nested stepnet operator. Backward module needs to recusively call `Backward` for every stepnet.
-
-4. Sharing Variables
-
- As illustrated in the figure 1 and figure 2, two operators share the same variable name **W@GRAD**, which will overwrite their shared input variable.
-
-
-
-
- Figure 1. Sharing variables in operators.
-
-
-
- Sharing variable between operators or same input variable used in multiple operators can lead to duplicate gradient variables. As illustrated in figure 2, we need to rename the gradient names recursively and add a generic add operator to prevent overwriting.
-
-
-
-
- Figure 2. Replace sharing variable's gradient with `Add` operator.
-
-
-
- Because the framework finds variables according to their names, we need to rename the output links. We add an integer suffix to represent its position in the clockwise direction.
-
-5. Part of the Gradient is Zero.
-
- In the whole graph, there is some case of that one operator's gradient is not needed, but its input's gradient is a dependency link of other operator, we need to fill a same shape gradient matrix in the position. In our implementation, we insert a special `fillZeroLike` operator.
-
-
-Follow these rules above, then collect the sub graph `OutputGradients`/`InputGradients` as the NetOp's and return it.
diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc
index 0957646b56..692406b1c3 100644
--- a/paddle/framework/backward_test.cc
+++ b/paddle/framework/backward_test.cc
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#include "paddle/framework/backward.h"
diff --git a/paddle/framework/data_layout.h b/paddle/framework/data_layout.h
index 7d7a444cf0..4a8669c3a4 100644
--- a/paddle/framework/data_layout.h
+++ b/paddle/framework/data_layout.h
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
+#include "paddle/platform/enforce.h"
#include
#include "paddle/platform/enforce.h"
@@ -20,7 +21,7 @@ limitations under the License. */
namespace paddle {
namespace framework {
-enum DataLayout {
+enum class DataLayout {
kNHWC = 0,
kNCHW = 1,
kAnyLayout = 2,
@@ -38,11 +39,11 @@ inline DataLayout StringToDataLayout(const std::string& str) {
inline std::string DataLayoutToString(const DataLayout& data_layout) {
switch (data_layout) {
- case kNHWC:
+ case DataLayout::kNHWC:
return "NHWC";
- case kNCHW:
+ case DataLayout::kNCHW:
return "NCHW";
- case kAnyLayout:
+ case DataLayout::kAnyLayout:
return "ANY_LAYOUT";
default:
PADDLE_THROW("unknown DataLayou %d", data_layout);
diff --git a/paddle/framework/data_transform.cc b/paddle/framework/data_transform.cc
new file mode 100644
index 0000000000..376268888e
--- /dev/null
+++ b/paddle/framework/data_transform.cc
@@ -0,0 +1,27 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/framework/data_transform.h"
+#include "paddle/framework/lod_tensor.h"
+
+namespace paddle {
+namespace framework {
+
+DataTransformFnMap& DataTransformFnMap::Instance() {
+ static DataTransformFnMap data_transform_map;
+ return data_transform_map;
+}
+
+} // namespace framework
+} // namespace paddle
diff --git a/paddle/framework/data_transform.h b/paddle/framework/data_transform.h
new file mode 100644
index 0000000000..bd6d301c12
--- /dev/null
+++ b/paddle/framework/data_transform.h
@@ -0,0 +1,108 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include
+#include
+#include
+
+#include "paddle/framework/op_kernel_type.h"
+#include "paddle/framework/tensor.h"
+#include "paddle/framework/variable.h"
+#include "paddle/platform/device_context.h"
+#include "paddle/platform/macros.h"
+
+namespace paddle {
+namespace framework {
+
+using DataTransformFn =
+ std::function<void(const platform::DeviceContext*, const Variable&,
+ Variable*)>;
+using KernelTypePair = std::pair<OpKernelType, OpKernelType>;
+
+struct KernelTypePairHash {
+ static void HashCombine(const OpKernelType& t, std::size_t* seed) {
+ OpKernelType::Hash kernel_type_hasher;
+ (*seed) ^= kernel_type_hasher(t) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2);
+ }
+
+ size_t operator()(const KernelTypePair& kernel_pair) const {
+ std::size_t seed = 0;
+ HashCombine(kernel_pair.first, &seed);
+ HashCombine(kernel_pair.second, &seed);
+ return seed;
+ }
+};
+
+using DataTransformMap =
+ std::unordered_map<KernelTypePair, DataTransformFn, KernelTypePairHash>;
+
+class DataTransformFnMap {
+ public:
+ static DataTransformFnMap& Instance();
+
+ bool Has(const KernelTypePair& key_pair) const {
+ return map_.find(key_pair) != map_.end();
+ }
+
+ void Insert(const OpKernelType& left, const OpKernelType& right,
+ const DataTransformFn& data_transform_fn) {
+ Insert(std::make_pair(left, right), data_transform_fn);
+ }
+
+ void Insert(const KernelTypePair& kernel_type_pair,
+ const DataTransformFn& data_transform_fn) {
+ PADDLE_ENFORCE(!Has(kernel_type_pair),
+ "KernelTypePair %s has been registered", "");
+ map_.insert({kernel_type_pair, data_transform_fn});
+ }
+
+ const DataTransformFn& Get(const KernelTypePair& key_pair) const {
+ auto data_transformer = GetNullable(key_pair);
+ PADDLE_ENFORCE_NOT_NULL(data_transformer,
+ "DataTransformFn should not be NULL");
+ return *data_transformer;
+ }
+
+ const DataTransformFn* GetNullable(const KernelTypePair& key_pair) const {
+ auto it = map_.find(key_pair);
+ if (it == map_.end()) {
+ return nullptr;
+ } else {
+ return &(it->second);
+ }
+ }
+
+ const DataTransformMap& Map() const { return map_; }
+
+ private:
+ DataTransformFnMap() = default;
+ DataTransformMap map_;
+ DISABLE_COPY_AND_ASSIGN(DataTransformFnMap);
+};
+
+// generate unique name with __LINE__
+// refs https://stackoverflow.com/questions/1597007
+#define TOKENPASTE(x, y) x##y
+#define TOKENPASTE2(x, y) TOKENPASTE(x, y)
+#define REGISTER_DATA_TRANSFORM_FN(from, to, fn) \
+ static int TOKENPASTE2(fn_, __LINE__)() { \
+ ::paddle::framework::DataTransformFnMap::Instance().Insert(from, to, fn); \
+ return 0; \
+ } \
+ static int TOKENPASTE2(var_, __LINE__) __attribute__((unused)) = \
+ TOKENPASTE2(fn_, __LINE__)()
+
+} // namespace framework
+} // namespace paddle
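
Reviewer note (not part of the patch): the header above combines a pair-keyed hash map, a boost-style hash combiner, and a function-local static singleton. The following is a minimal, self-contained sketch of that pattern only. `Key`, `TransformFn`, and `TransformRegistry` are hypothetical stand-ins: `std::string` takes the place of `OpKernelType`, `int(int)` takes the place of the real transform signature, and a plain `assert` takes the place of `PADDLE_ENFORCE`.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical stand-ins for OpKernelType and DataTransformFn.
using Key = std::string;
using KeyPair = std::pair<Key, Key>;
using TransformFn = std::function<int(int)>;

struct KeyPairHash {
  // Fold one key's hash into the running seed (boost-style hash_combine).
  static void HashCombine(const Key& k, std::size_t* seed) {
    *seed ^= std::hash<Key>()(k) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2);
  }

  std::size_t operator()(const KeyPair& p) const {
    std::size_t seed = 0;
    HashCombine(p.first, &seed);
    HashCombine(p.second, &seed);
    return seed;
  }
};

class TransformRegistry {
 public:
  // Function-local static: constructed on first use.
  static TransformRegistry& Instance() {
    static TransformRegistry registry;
    return registry;
  }

  bool Has(const KeyPair& key) const { return map_.count(key) != 0; }

  // Registering the same (from, to) pair twice is treated as a programming error.
  void Insert(const Key& from, const Key& to, TransformFn fn) {
    KeyPair key(from, to);
    assert(!Has(key) && "pair already registered");
    map_.emplace(std::move(key), std::move(fn));
  }

  // Returns nullptr when no transform is registered for the pair.
  const TransformFn* GetNullable(const KeyPair& key) const {
    auto it = map_.find(key);
    return it == map_.end() ? nullptr : &it->second;
  }

  std::size_t Size() const { return map_.size(); }

 private:
  TransformRegistry() = default;
  TransformRegistry(const TransformRegistry&) = delete;
  TransformRegistry& operator=(const TransformRegistry&) = delete;

  std::unordered_map<KeyPair, TransformFn, KeyPairHash> map_;
};
```

Because the key is an ordered pair, a transform registered for (A, B) is not found when looking up (B, A); each direction needs its own registration.
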
diff --git a/paddle/framework/data_transform_test.cc b/paddle/framework/data_transform_test.cc
new file mode 100644
index 0000000000..5f05e881fa
--- /dev/null
+++ b/paddle/framework/data_transform_test.cc
@@ -0,0 +1,99 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+#include
+#include
+
+#include
+
+#include "paddle/framework/data_transform.h"
+
+namespace paddle {
+namespace framework {
+using namespace platform;
+
+/**
+ * @brief Cross-validation of transforms between different kernel types.
+ * A four-bit map represents each combination.
+ * If a field has more than two possible values, only two of them are used.
+ * For DataType, only FP32 (float) and FP64 (double) are tested.
+ * e.g. 0000 -> FP32, CPUPlace, kNHWC, kPlain
+ * 1111 -> FP64, CUDAPlace, kNCHW, kMKLDNN
+ */
+
+std::array<proto::DataType, 2> kDataType = {
+ {proto::DataType::FP32, proto::DataType::FP64}};
+
+std::array<Place, 2> kPlace = {{CPUPlace(), CUDAPlace(0)}};
+
+std::array<DataLayout, 2> kDataLayout = {
+ {DataLayout::kNHWC, DataLayout::kNCHW}};
+
+std::array<LibraryType, 2> kLibraryType = {
+ {LibraryType::kPlain, LibraryType::kMKLDNN}};
+
+OpKernelType GenFromBit(const std::vector<bool> bits) {
+ return OpKernelType(kDataType[bits[0]], kPlace[bits[1]], kDataLayout[bits[2]],
+ kLibraryType[bits[3]]);
+}
+
+int test_value = 0;
+
+auto kernel0 = GenFromBit({0, 0, 0, 0});
+auto kernel1 = GenFromBit({0, 0, 0, 1});
+auto kernel2 = GenFromBit({0, 0, 1, 0});
+auto kernel3 = GenFromBit({0, 0, 1, 1});
+
+void TransDataType_t(const platform::DeviceContext* ctx, const Variable& in,
+ Variable* out) {
+ test_value++;
+}
+
+void TransDataLayout_t(const platform::DeviceContext* ctx, const Variable& in,
+ Variable* out) {
+ test_value--;
+}
+
+void TransLibraryType_t(const platform::DeviceContext* ctx, const Variable& in,
+ Variable* out) {
+ test_value += 2;
+}
+
+} // namespace framework
+} // namespace paddle
+
+namespace frw = paddle::framework;
+
+REGISTER_DATA_TRANSFORM_FN(frw::kernel0, frw::kernel1, frw::TransDataType_t);
+REGISTER_DATA_TRANSFORM_FN(frw::kernel1, frw::kernel2, frw::TransDataLayout_t);
+REGISTER_DATA_TRANSFORM_FN(frw::kernel0, frw::kernel2, frw::TransLibraryType_t);
+
+TEST(DataTransform, Register) {
+ using namespace paddle::framework;
+ using namespace paddle::platform;
+
+ auto& instance = DataTransformFnMap::Instance();
+ ASSERT_EQ(instance.Map().size(), 3UL);
+ DeviceContext* ctx = nullptr;
+ paddle::framework::Variable in;
+ paddle::framework::Variable out;
+
+ instance.Get(std::make_pair(frw::kernel0, frw::kernel1))(ctx, in, &out);
+ ASSERT_EQ(test_value, 1);
+
+ instance.Get(std::make_pair(frw::kernel1, frw::kernel2))(ctx, in, &out);
+ ASSERT_EQ(test_value, 0);
+
+ instance.Get(std::make_pair(frw::kernel0, frw::kernel2))(ctx, in, &out);
+ ASSERT_EQ(test_value, 2);
+}
diff --git a/paddle/framework/data_type.h b/paddle/framework/data_type.h
index e94ee2ed52..6a372ac32e 100644
--- a/paddle/framework/data_type.h
+++ b/paddle/framework/data_type.h
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#pragma once
#include
diff --git a/paddle/framework/ddim_test.cc b/paddle/framework/ddim_test.cc
index bd5ea09d7d..bc259d1f60 100644
--- a/paddle/framework/ddim_test.cc
+++ b/paddle/framework/ddim_test.cc
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#include
#include
diff --git a/paddle/framework/details/op_registry.h b/paddle/framework/details/op_registry.h
index 7f5151c41d..6d50e820b2 100644
--- a/paddle/framework/details/op_registry.h
+++ b/paddle/framework/details/op_registry.h
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#pragma once
diff --git a/paddle/framework/executor.cc b/paddle/framework/executor.cc
index 997773c168..bf1f0471cc 100644
--- a/paddle/framework/executor.cc
+++ b/paddle/framework/executor.cc
@@ -14,18 +14,17 @@ limitations under the License. */
#include "paddle/framework/executor.h"
-#include
-#include
-#include
#include
-#include
+#include "gflags/gflags.h"
#include "paddle/framework/feed_fetch_type.h"
#include "paddle/framework/lod_rank_table.h"
-#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/op_registry.h"
-#include "paddle/framework/scope.h"
+
+DEFINE_bool(check_nan_inf, false,
+ "Check whether operators produce NAN/INF. This is extremely slow, "
+ "so use this flag wisely.");
namespace paddle {
namespace framework {
@@ -58,6 +57,19 @@ static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
}
}
+static void CheckTensorNANOrInf(const std::string& name,
+ const framework::Tensor& tensor) {
+ if (tensor.memory_size() == 0) {
+ return;
+ }
+ if (tensor.type().hash_code() != typeid(float).hash_code() &&
+ tensor.type().hash_code() != typeid(double).hash_code()) {
+ return;
+ }
+ PADDLE_ENFORCE(!framework::HasInf(tensor), "Tensor %s has Inf", name);
+ PADDLE_ENFORCE(!framework::HasNAN(tensor), "Tensor %s has NAN", name);
+}
+
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
bool create_local_scope, bool create_vars) {
// TODO(tonyyang-svail):
@@ -101,8 +113,17 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
VLOG(3) << op->DebugString();
op->Run(*local_scope, place_);
+ if (FLAGS_check_nan_inf) {
+ for (auto& vname : op->OutputVars(true)) {
+ auto* var = local_scope->FindVar(vname);
+ if (var == nullptr) continue;
+ if (var->IsType<framework::LoDTensor>()) {
+ CheckTensorNANOrInf(vname, var->Get<framework::LoDTensor>());
+ }
+ }
+ }
}
- if (create_local_scope) {
+ if (create_vars && create_local_scope) {
scope->DeleteScope(local_scope);
}
}
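
Reviewer note (not part of the patch): the `check_nan_inf` path skips empty tensors, then rejects Inf, then rejects NaN. A minimal standalone sketch of that check order over a flat buffer follows. `FiniteCheck` and `CheckNaNOrInf` are hypothetical names; a `std::vector` stands in for `framework::Tensor`, and a return code stands in for the `PADDLE_ENFORCE` failure. The element-type filter (float/double only) is expressed here through the template parameter rather than a runtime `typeid` comparison.

```cpp
#include <cmath>
#include <vector>

// Hypothetical result type; the patch raises via PADDLE_ENFORCE instead.
enum class FiniteCheck { kOk, kHasInf, kHasNaN };

// Scans a flat buffer of floating-point values. Inf is tested before NaN,
// matching the order of the two PADDLE_ENFORCE calls in the patch.
template <typename T>
FiniteCheck CheckNaNOrInf(const std::vector<T>& data) {
  if (data.empty()) return FiniteCheck::kOk;  // nothing to inspect
  for (const T& v : data) {
    if (std::isinf(v)) return FiniteCheck::kHasInf;
  }
  for (const T& v : data) {
    if (std::isnan(v)) return FiniteCheck::kHasNaN;
  }
  return FiniteCheck::kOk;
}
```

Each check is a full pass over the data after every operator, which is consistent with the flag description warning that enabling it is very slow.
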
diff --git a/paddle/framework/feed_fetch_type.h b/paddle/framework/feed_fetch_type.h
index bc4ae440fc..9bc4a90c44 100644
--- a/paddle/framework/feed_fetch_type.h
+++ b/paddle/framework/feed_fetch_type.h
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#pragma once
#include
diff --git a/paddle/framework/grad_op_desc_maker.h b/paddle/framework/grad_op_desc_maker.h
index cf411fa710..2de5242831 100644
--- a/paddle/framework/grad_op_desc_maker.h
+++ b/paddle/framework/grad_op_desc_maker.h
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#pragma once
#include
diff --git a/paddle/framework/init.cc b/paddle/framework/init.cc
index 3ff2da3446..682cff168d 100644
--- a/paddle/framework/init.cc
+++ b/paddle/framework/init.cc
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#include
#include
@@ -71,7 +71,7 @@ bool InitDevices(const std::vector &devices) {
places.emplace_back(platform::CPUPlace());
LOG(WARNING) << "Not specified CPU device, create CPU by Default.";
}
- platform::DeviceContextPool::Create(places);
+ platform::DeviceContextPool::Init(places);
return true;
}
diff --git a/paddle/framework/init.h b/paddle/framework/init.h
index 1715cd81e6..33907f9eb0 100644
--- a/paddle/framework/init.h
+++ b/paddle/framework/init.h
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#pragma once
#include
diff --git a/paddle/framework/init_test.cc b/paddle/framework/init_test.cc
index cb1ba7ce8f..f0788051d4 100644
--- a/paddle/framework/init_test.cc
+++ b/paddle/framework/init_test.cc
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/framework/init.h"
diff --git a/paddle/framework/library_type.h b/paddle/framework/library_type.h
index aa66cf00f3..7707799cae 100644
--- a/paddle/framework/library_type.h
+++ b/paddle/framework/library_type.h
@@ -20,18 +20,41 @@ namespace framework {
// For more details about the design of LibraryType, Please refer to
// https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/operator_kernel_type.md#library
-enum LibraryType { kPlain = 0, kMKLDNN = 1, kCUDNN = 2 };
+enum class LibraryType {
+ kPlain = 0,
+ kMKLDNN = 1,
+ kCUDNN = 2,
+};
inline std::string LibraryTypeToString(const LibraryType& library_type) {
switch (library_type) {
- case kPlain:
+ case LibraryType::kPlain:
return "PLAIN";
- case kMKLDNN:
+ case LibraryType::kMKLDNN:
return "MKLDNN";
- case kCUDNN:
+ case LibraryType::kCUDNN:
return "CUDNN";
default:
- PADDLE_THROW("unknown LibraryType %d", library_type);
+ PADDLE_THROW("unknown LibraryType %d", static_cast<int>(library_type));
+ }
+}
+
+inline LibraryType StringToLibraryType(const char* ctype) {
+ std::string s(ctype);
+ if (s == std::string("PLAIN")) {
+ return LibraryType::kPlain;
+ } else if (s == std::string("MKLDNN")) {
+ return LibraryType::kMKLDNN;
+ } else if (s == std::string("CUDNN")) {
+ return LibraryType::kCUDNN;
+ // To be compatible with register macro.
+ // CPU, CUDA, PLAIN are same library type.
+ } else if (s == std::string("CPU")) {
+ return LibraryType::kPlain;
+ } else if (s == std::string("CUDA")) {
+ return LibraryType::kPlain;
+ } else {
+ PADDLE_THROW("Unknown LibraryType %s", s.c_str());
}
}
diff --git a/paddle/framework/lod_rank_table.cc b/paddle/framework/lod_rank_table.cc
index 17d524c092..704bce2a0e 100644
--- a/paddle/framework/lod_rank_table.cc
+++ b/paddle/framework/lod_rank_table.cc
@@ -1,16 +1,16 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#include "paddle/framework/lod_rank_table.h"
diff --git a/paddle/framework/lod_rank_table.h b/paddle/framework/lod_rank_table.h
index d3007d3d73..df188709e9 100644
--- a/paddle/framework/lod_rank_table.h
+++ b/paddle/framework/lod_rank_table.h
@@ -1,16 +1,16 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#pragma once
#include
diff --git a/paddle/framework/lod_tensor.cc b/paddle/framework/lod_tensor.cc
index d766d3c416..7b6dc09bdb 100644
--- a/paddle/framework/lod_tensor.cc
+++ b/paddle/framework/lod_tensor.cc
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/data_type.h"
@@ -189,62 +189,16 @@ void AppendLoD(LoD *lod, const LoD &lod_length) {
void SerializeToStream(std::ostream &os, const LoDTensor &tensor,
const platform::DeviceContext &dev_ctx) {
- // TODO(typhoonzero): serialize to ostream
- { // the 1st field, uint32_t version
+ { // the 1st field, uint32_t version for LoDTensor
constexpr uint32_t version = 0;
os.write(reinterpret_cast(&version), sizeof(version));
}
- { // the 2nd field, tensor description
- // int32_t size
- // void* protobuf message
- proto::TensorDesc desc;
- desc.set_data_type(framework::ToDataType(tensor.type()));
- auto dims = framework::vectorize(tensor.dims());
- auto *pb_dims = desc.mutable_dims();
- pb_dims->Resize(static_cast(dims.size()), 0);
- std::copy(dims.begin(), dims.end(), pb_dims->begin());
- int32_t size = desc.ByteSize();
- os.write(reinterpret_cast(&size), sizeof(size));
- auto out = desc.SerializeAsString();
- os.write(out.data(), size);
- }
- { // the 3rd field, tensor data
- uint64_t size = tensor.memory_size();
- auto *data_ptr = tensor.data();
- PADDLE_ENFORCE(size < std::numeric_limits::max(),
- "Index overflow when writing tensor");
- if (platform::is_gpu_place(tensor.place())) {
-#ifdef PADDLE_WITH_CUDA
- constexpr size_t kBufSize = 1024 * 1024 * 64; // 64MB
- std::unique_ptr buf(new char[kBufSize]);
- auto &gpu_dev_ctx =
- static_cast(dev_ctx);
- platform::CPUPlace cpu;
- uintptr_t data = reinterpret_cast(data_ptr);
- while (size != 0) {
- size_t size_to_write = std::min(kBufSize, static_cast(size));
- memory::Copy(cpu, buf.get(),
- boost::get(tensor.place()),
- reinterpret_cast(data), size_to_write,
- gpu_dev_ctx.stream());
- gpu_dev_ctx.Wait();
- os.write(buf.get(), size_to_write);
- data += size_to_write;
- size -= size_to_write;
- }
-#else
- PADDLE_THROW("Unexpected branch");
-#endif
- } else {
- os.write(static_cast(data_ptr),
- static_cast(size));
- }
- }
- { // the 4th field, lod information
- // uint64_t lod_level
- // uint64_t lod_level_1 size in byte.
- // int* lod_level_1 data
- // ...
+ {
+ // the 2nd field, LoD information
+ // uint64_t lod_level
+ // uint64_t lod_level_1 size in byte.
+ // int* lod_level_1 data
+ // ...
auto lod = tensor.lod();
uint64_t size = lod.size();
os.write(reinterpret_cast(&size), sizeof(size));
@@ -256,49 +210,19 @@ void SerializeToStream(std::ostream &os, const LoDTensor &tensor,
static_cast(size));
}
}
+ // the 3rd field, Tensor
+ SerializeToStream(os, static_cast(tensor), dev_ctx);
}
void DeserializeFromStream(std::istream &is, LoDTensor *tensor) {
- uint32_t version;
- is.read(reinterpret_cast(&version), sizeof(version));
- PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
- proto::TensorDesc desc;
- { // int32_t size
- // proto buffer
- int32_t size;
- is.read(reinterpret_cast(&size), sizeof(size));
- std::unique_ptr buf(new char[size]);
- is.read(reinterpret_cast(buf.get()), size);
- PADDLE_ENFORCE(desc.ParseFromArray(buf.get(), size),
- "Cannot parse tensor desc");
- }
- { // read tensor
- std::vector dims;
- dims.reserve(static_cast(desc.dims().size()));
- std::copy(desc.dims().begin(), desc.dims().end(), std::back_inserter(dims));
- tensor->Resize(framework::make_ddim(dims));
-
- void *buf;
- platform::Place cpu = platform::CPUPlace();
- switch (desc.data_type()) {
- case proto::FP32:
- buf = tensor->mutable_data(cpu);
- break;
- case proto::FP64:
- buf = tensor->mutable_data(cpu);
- break;
- case proto::INT32:
- buf = tensor->mutable_data(cpu);
- break;
- case proto::INT64:
- buf = tensor->mutable_data(cpu);
- break;
- default:
- PADDLE_THROW("DataType %d not supported", desc.data_type());
- }
- is.read(static_cast(buf), tensor->memory_size());
- }
- { // read lod
+ {
+ // the 1st field, uint32_t version for LoDTensor
+ uint32_t version;
+ is.read(reinterpret_cast(&version), sizeof(version));
+ PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
+ }
+ {
+ // the 2nd field, LoD information
uint64_t lod_level;
is.read(reinterpret_cast(&lod_level), sizeof(lod_level));
auto &lod = *tensor->mutable_lod();
@@ -312,6 +236,8 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor) {
lod[i] = tmp;
}
}
+ // the 3rd field, Tensor
+ DeserializeFromStream(is, static_cast(tensor));
}
} // namespace framework
diff --git a/paddle/framework/lod_tensor.h b/paddle/framework/lod_tensor.h
index 0923c52a0a..147db3ab08 100644
--- a/paddle/framework/lod_tensor.h
+++ b/paddle/framework/lod_tensor.h
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#pragma once
diff --git a/paddle/framework/lod_tensor_array.h b/paddle/framework/lod_tensor_array.h
index 13f0608d24..4a8e7f4fa5 100644
--- a/paddle/framework/lod_tensor_array.h
+++ b/paddle/framework/lod_tensor_array.h
@@ -1,16 +1,16 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#pragma once
#include
diff --git a/paddle/framework/lod_tensor_test.cc b/paddle/framework/lod_tensor_test.cc
index 02d84b6823..0747c8db53 100644
--- a/paddle/framework/lod_tensor_test.cc
+++ b/paddle/framework/lod_tensor_test.cc
@@ -126,6 +126,20 @@ TEST_F(LoDTensorTester, ShrinkInLevel) {
EXPECT_NE(t1.data(), lod_tensor_.data());
}
+TEST_F(LoDTensorTester, SerializeAndDeserialize) {
+ LoDTensor dst_tensor;
+ platform::CPUDeviceContext cpu_ctx((platform::CPUPlace()));
+ std::ostringstream oss;
+ SerializeToStream(oss, lod_tensor_, cpu_ctx);
+ std::istringstream iss(oss.str());
+ DeserializeFromStream(iss, &dst_tensor);
+ float* dst_ptr = dst_tensor.mutable_data(platform::CPUPlace());
+ for (int i = 0; i < kLodTensorSize; ++i) {
+ EXPECT_EQ(dst_ptr[i], i);
+ }
+ EXPECT_EQ(dst_tensor.lod(), lod_tensor_.lod());
+}
+
TEST(LodExpand, test) {
LoD lod{{0, 2}};
LoDTensor tensor;
diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc
index b361e64438..781bbb4c19 100644
--- a/paddle/framework/op_desc.cc
+++ b/paddle/framework/op_desc.cc
@@ -88,6 +88,14 @@ OpDesc::OpDesc(const std::string &type, const VariableNameMap &inputs,
need_update_ = true;
}
+void OpDesc::CopyFrom(const OpDesc &op_desc) {
+ desc_.set_type(op_desc.Type());
+ inputs_ = op_desc.inputs_;
+ outputs_ = op_desc.outputs_;
+ attrs_ = op_desc.attrs_;
+ need_update_ = true;
+}
+
OpDesc::OpDesc(const proto::OpDesc &desc, ProgramDesc *prog)
: desc_(desc), need_update_(false) {
// restore inputs_
diff --git a/paddle/framework/op_desc.h b/paddle/framework/op_desc.h
index 93d4a88f3c..4cf784a0d0 100644
--- a/paddle/framework/op_desc.h
+++ b/paddle/framework/op_desc.h
@@ -35,6 +35,8 @@ class OpDesc {
OpDesc(const proto::OpDesc &desc, ProgramDesc *prog);
+ void CopyFrom(const OpDesc &op_desc);
+
proto::OpDesc *Proto();
std::string Type() const { return desc_.type(); }
diff --git a/paddle/framework/op_info.cc b/paddle/framework/op_info.cc
index 81ba29797c..b520108109 100644
--- a/paddle/framework/op_info.cc
+++ b/paddle/framework/op_info.cc
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#include "paddle/framework/op_info.h"
diff --git a/paddle/framework/op_info.h b/paddle/framework/op_info.h
index 7772d6e745..d9b89f9cac 100644
--- a/paddle/framework/op_info.h
+++ b/paddle/framework/op_info.h
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#pragma once
#include
diff --git a/paddle/framework/op_kernel_type.h b/paddle/framework/op_kernel_type.h
index e9c45b958c..b06002096f 100644
--- a/paddle/framework/op_kernel_type.h
+++ b/paddle/framework/op_kernel_type.h
@@ -40,6 +40,7 @@ struct OpKernelType {
// place, data_type, library_type kinds less than 2^8
constexpr static int LEFT_SHIFT = 8;
+
proto::DataType data_type_;
DataLayout data_layout_;
platform::Place place_;
@@ -67,6 +68,8 @@ struct OpKernelType {
data_type_ == o.data_type_ && data_layout_ == o.data_layout_ &&
library_type_ == o.library_type_;
}
+
+ bool operator!=(const OpKernelType& o) const { return !(*this == o); }
};
inline std::ostream& operator<<(std::ostream& os,
@@ -77,5 +80,11 @@ inline std::ostream& operator<<(std::ostream& os,
return os;
}
+inline std::string KernelTypeToString(const OpKernelType& kernel_key) {
+ std::ostringstream stream;
+ stream << kernel_key;
+ return stream.str();
+}
+
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/op_kernel_type_test.cc b/paddle/framework/op_kernel_type_test.cc
index 8753d7cc37..649afeee8a 100644
--- a/paddle/framework/op_kernel_type_test.cc
+++ b/paddle/framework/op_kernel_type_test.cc
@@ -26,10 +26,8 @@ TEST(OpKernelType, ToString) {
OpKernelType op_kernel_type(DataType::FP32, CPUPlace(), DataLayout::kNCHW,
LibraryType::kCUDNN);
- std::ostringstream stream;
- stream << op_kernel_type;
ASSERT_EQ(
- stream.str(),
+ paddle::framework::KernelTypeToString(op_kernel_type),
"data_type[5]:data_layout[NCHW]:place[CPUPlace]:library_type[CUDNN]");
}
@@ -48,4 +46,4 @@ TEST(OpKernelType, Hash) {
OpKernelType::Hash hasher;
ASSERT_NE(hasher(op_kernel_type_1), hasher(op_kernel_type_2));
-}
\ No newline at end of file
+}
diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h
index 9bb2a3b5c2..bdaa259181 100644
--- a/paddle/framework/op_registry.h
+++ b/paddle/framework/op_registry.h
@@ -79,30 +79,31 @@ struct OpKernelRegistrarFunctor {
using KERNEL_TYPE =
typename std::tuple_element>::type;
- void operator()(const char* op_type) const {
+ void operator()(const char* op_type, const char* library_type) const {
using T = typename KERNEL_TYPE::ELEMENT_TYPE;
- OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType());
+ OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType(),
+ DataLayout::kAnyLayout, StringToLibraryType(library_type));
OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KERNEL_TYPE);
constexpr auto size = std::tuple_size>::value;
OpKernelRegistrarFunctor
func;
- func(op_type);
+ func(op_type, library_type);
}
};
template
struct OpKernelRegistrarFunctor {
- void operator()(const char* op_type) const {}
+ void operator()(const char* op_type, const char* library_type) const {}
};
// User can register many kernel in one place. The data type could be different.
template
class OpKernelRegistrar : public Registrar {
public:
- explicit OpKernelRegistrar(const char* op_type) {
+ explicit OpKernelRegistrar(const char* op_type, const char* library_type) {
OpKernelRegistrarFunctor func;
- func(op_type);
+ func(op_type, library_type);
}
};
@@ -181,7 +182,8 @@ class OpKernelRegistrar : public Registrar {
__reg_op_kernel_##op_type##_##DEVICE_TYPE##__, \
"REGISTER_OP_KERNEL must be called in global namespace"); \
static ::paddle::framework::OpKernelRegistrar \
- __op_kernel_registrar_##op_type##_##DEVICE_TYPE##__(#op_type); \
+ __op_kernel_registrar_##op_type##_##DEVICE_TYPE##__(#op_type, \
+ #DEVICE_TYPE); \
int TouchOpKernelRegistrar_##op_type##_##DEVICE_TYPE() { \
__op_kernel_registrar_##op_type##_##DEVICE_TYPE##__.Touch(); \
return 0; \
diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc
index 4cdf6e0865..cef530c6e6 100644
--- a/paddle/framework/op_registry_test.cc
+++ b/paddle/framework/op_registry_test.cc
@@ -1,3 +1,17 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. */
+
#include "paddle/framework/op_registry.h"
#include
@@ -182,3 +196,71 @@ TEST(OperatorRegistrar, Test) {
using namespace paddle::framework;
OperatorRegistrar reg("cos");
}
+
+namespace paddle {
+namespace framework {
+
+class OpKernelTestMaker : public OpProtoAndCheckerMaker {
+ public:
+ OpKernelTestMaker(OpProto* proto, OpAttrChecker* op_checker)
+ : OpProtoAndCheckerMaker(proto, op_checker) {
+ AddComment("NoGradOp, same input output. no Grad");
+ }
+};
+
+class OpWithKernelTest : public OperatorWithKernel {
+ public:
+ using OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+ void InferShape(InferShapeContext* ctx) const override {}
+
+ framework::OpKernelType GetActualKernelType(
+ const framework::ExecutionContext& ctx) const override {
+ return framework::OpKernelType(proto::DataType::FP32, ctx.device_context());
+ }
+};
+
+template
+class OpKernelTest : public paddle::framework::OpKernel {
+ public:
+ void Compute(const paddle::framework::ExecutionContext& ctx) const {}
+};
+
+} // namespace framework
+} // namespace paddle
+
+REGISTER_OP_WITHOUT_GRADIENT(op_with_kernel,
+ paddle::framework::OpWithKernelTest,
+ paddle::framework::OpKernelTestMaker);
+REGISTER_OP_CPU_KERNEL(
+ op_with_kernel,
+ paddle::framework::OpKernelTest);
+
+REGISTER_OP_CUDA_KERNEL(op_with_kernel,
+ paddle::framework::OpKernelTest<
+ paddle::platform::CUDADeviceContext, float>);
+
+TEST(OperatorRegistrar, CPU) {
+ paddle::framework::proto::OpDesc op_desc;
+ paddle::platform::CPUPlace cpu_place;
+ paddle::framework::Scope scope;
+
+ op_desc.set_type("op_with_kernel");
+ auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
+
+ op->Run(scope, cpu_place);
+}
+
+#ifdef PADDLE_WITH_CUDA
+TEST(OperatorRegistrar, CUDA) {
+ paddle::framework::proto::OpDesc op_desc;
+ paddle::platform::CUDAPlace cuda_place(0);
+ paddle::framework::Scope scope;
+
+ op_desc.set_type("op_with_kernel");
+ auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
+
+ op->Run(scope, cuda_place);
+}
+#endif
diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc
index 66840a2e03..a3ce96c409 100644
--- a/paddle/framework/operator.cc
+++ b/paddle/framework/operator.cc
@@ -15,6 +15,7 @@ limitations under the License. */
#include
#include
+#include "paddle/framework/data_transform.h"
#include "paddle/framework/executor.h"
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/operator.h"
@@ -383,12 +384,30 @@ class RuntimeInferShapeContext : public InferShapeContext {
const Scope& scope_;
};
+const platform::DeviceContext* GetDeviceContext(
+ framework::KernelTypePair& kernel_pair) {
+ auto& actual_kernel_key = kernel_pair.first;
+ auto& expected_kernel_key = kernel_pair.second;
+ platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+
+ if (platform::is_gpu_place(actual_kernel_key.place_) &&
+ platform::is_cpu_place(expected_kernel_key.place_)) {
+ return pool.Get(actual_kernel_key.place_);
+ } else if (platform::is_cpu_place(actual_kernel_key.place_) &&
+ platform::is_gpu_place(expected_kernel_key.place_)) {
+ return pool.Get(expected_kernel_key.place_);
+ } else {
+ PADDLE_THROW(
+ "Currently, model parallelism is only supported between CPU and CUDA");
+ }
+}
+
void OperatorWithKernel::Run(const Scope& scope,
const platform::Place& place) const {
RuntimeInferShapeContext infer_shape_ctx(*this, scope);
this->InferShape(&infer_shape_ctx);
- platform::DeviceContextPool& pool = platform::DeviceContextPool::Get();
- auto dev_ctx = pool.Borrow(place);
+ platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+ auto dev_ctx = pool.Get(place);
// check if op[type] has kernel registered.
auto& all_op_kernels = AllOpKernels();
@@ -411,6 +430,47 @@ void OperatorWithKernel::Run(const Scope& scope,
expected_kernel_key);
}
+ if (actual_kernel_key == expected_kernel_key) {
+ PADDLE_ENFORCE_EQ(actual_kernel_key.place_, expected_kernel_key.place_,
+ "Currently, model parallelism is only supported between "
+ "CPU and other devices. For example, multi-GPU model "
+ "parallelism will fail.");
+ } else {
+ auto kernel_pair = std::make_pair(actual_kernel_key, expected_kernel_key);
+ const DataTransformFn* trans_fun =
+ DataTransformFnMap::Instance().GetNullable(kernel_pair);
+ if (trans_fun) {
+ auto input_vars = this->InputVars();
+ // TODO(qijun) filter the input vars that do not need to be transformed
+
+ // filter vars that have been transformed
+ std::vector need_trans;
+ for (auto var_name : input_vars) {
+ auto var_name_trans =
+ var_name + framework::KernelTypeToString(expected_kernel_key);
+ if (!scope.FindVar(var_name_trans)) {
+ const_cast(scope).Var(var_name_trans);
+ need_trans.push_back(var_name);
+ }
+ }
+
+ if (!need_trans.empty()) {
+ auto trans_dev_ctx = GetDeviceContext(kernel_pair);
+
+ // Wait for transform starting
+ dev_ctx->Wait();
+
+ for (auto var_name : need_trans) {
+ (*trans_fun)(trans_dev_ctx, *(scope.FindVar(var_name)),
+ scope.FindVar(var_name + framework::KernelTypeToString(
+ expected_kernel_key)));
+ }
+ // Wait for data transform finishing
+ trans_dev_ctx->Wait();
+ }
+ }
+ }
+
kernel_iter->second->Compute(ctx);
}
diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h
index 55eed57e66..d0a9b643d5 100644
--- a/paddle/framework/operator.h
+++ b/paddle/framework/operator.h
@@ -89,6 +89,9 @@ class OperatorBase {
/// Net will call this function to Run an op.
virtual void Run(const Scope& scope, const platform::Place& place) const = 0;
+ // FIXME(typhoonzero): this is only used for recv_op to stop event_loop.
+ virtual void Stop() {}
+
virtual bool IsNetOp() const { return false; }
virtual bool SupportGPU() const { return false; }
diff --git a/paddle/framework/program_desc_test.cc b/paddle/framework/program_desc_test.cc
index a49886f7ea..59947c9f21 100644
--- a/paddle/framework/program_desc_test.cc
+++ b/paddle/framework/program_desc_test.cc
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#include "paddle/framework/program_desc.h"
#include "gtest/gtest.h"
diff --git a/paddle/framework/prune_test.cc b/paddle/framework/prune_test.cc
index bdd5765943..d76c5abca9 100644
--- a/paddle/framework/prune_test.cc
+++ b/paddle/framework/prune_test.cc
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#include "paddle/framework/prune.h"
diff --git a/paddle/framework/scope.cc b/paddle/framework/scope.cc
index 656736e238..0c01d605bc 100644
--- a/paddle/framework/scope.cc
+++ b/paddle/framework/scope.cc
@@ -74,17 +74,9 @@ void Scope::DropKids() {
kids_.clear();
}
-std::vector Scope::GetAllNames(bool recursive) const {
- std::vector known_vars(vars_.size());
-
- if (recursive) {
- for (auto& kid : kids_) {
- auto kid_vars = kid->GetAllNames();
- for (auto& p : kid_vars) {
- known_vars.emplace_back(p);
- }
- }
- }
+std::vector Scope::LocalVarNames() const {
+ std::vector known_vars;
+ known_vars.reserve(this->vars_.size());
for (auto& p : vars_) {
known_vars.emplace_back(p.first);
}
diff --git a/paddle/framework/scope.h b/paddle/framework/scope.h
index 56e815db54..10143326df 100644
--- a/paddle/framework/scope.h
+++ b/paddle/framework/scope.h
@@ -66,7 +66,7 @@ class Scope {
void DropKids();
// enumerate all the variables current contains.
- std::vector GetAllNames(bool recursive = false) const;
+ std::vector LocalVarNames() const;
// Rename variable to a new name
void Rename(const std::string& origin_name,
diff --git a/paddle/framework/scope_test.cc b/paddle/framework/scope_test.cc
index f738d5ba9e..0f5b86061d 100644
--- a/paddle/framework/scope_test.cc
+++ b/paddle/framework/scope_test.cc
@@ -61,7 +61,7 @@ TEST(Scope, GetAllNames) {
Variable* v = s.Var("a");
EXPECT_EQ(&s, s.FindScope(v));
- std::vector ans = s.GetAllNames();
+ std::vector ans = s.LocalVarNames();
std::string str;
for (auto& var : ans) {
str += var;
diff --git a/paddle/framework/selected_rows.cc b/paddle/framework/selected_rows.cc
index c74459c9dd..82adfa7123 100644
--- a/paddle/framework/selected_rows.cc
+++ b/paddle/framework/selected_rows.cc
@@ -12,5 +12,58 @@ limitations under the License. */
#include "paddle/framework/selected_rows.h"
namespace paddle {
-namespace framework {} // namespace framework
+namespace framework {
+void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows,
+ const platform::DeviceContext& dev_ctx) {
+ { // the 1st field, uint32_t version
+ constexpr uint32_t version = 0;
+ os.write(reinterpret_cast(&version), sizeof(version));
+ }
+ {
+ // the 2nd field, rows information
+ auto& rows = selected_rows.rows();
+ uint64_t size = rows.size();
+ os.write(reinterpret_cast(&size), sizeof(size));
+ for (uint64_t i = 0; i < size; ++i) {
+ os.write(reinterpret_cast(&rows[i]), sizeof(rows[i]));
+ }
+ }
+ {
+ // the 3rd field, the height of SelectedRows
+ int64_t height = selected_rows.height();
+ os.write(reinterpret_cast(&height), sizeof(height));
+ }
+ // the 4th field, Tensor data
+ SerializeToStream(os, selected_rows.value(), dev_ctx);
+}
+
+void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows) {
+ auto& tensor = *selected_rows->mutable_value();
+ {
+ // the 1st field, uint32_t version for SelectedRows
+ uint32_t version;
+ is.read(reinterpret_cast(&version), sizeof(version));
+ PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
+ }
+ {
+ // the 2nd field, rows information
+ uint64_t size;
+ is.read(reinterpret_cast(&size), sizeof(size));
+ auto& rows = *selected_rows->mutable_rows();
+ rows.resize(size);
+ for (uint64_t i = 0; i < size; ++i) {
+ is.read(reinterpret_cast(&rows[i]), sizeof(int64_t));
+ }
+ }
+ {
+ // the 3rd field, the height of the SelectedRows
+ int64_t height;
+ is.read(reinterpret_cast(&height), sizeof(int64_t));
+ selected_rows->set_height(height);
+ }
+ // the 4th field, tensor which contains the data
+ DeserializeFromStream(is, &tensor);
+}
+
+} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/selected_rows.h b/paddle/framework/selected_rows.h
index 0332b91323..699e392688 100644
--- a/paddle/framework/selected_rows.h
+++ b/paddle/framework/selected_rows.h
@@ -59,5 +59,14 @@ class SelectedRows {
int64_t height_;
};
+/*
+ * Serialize/Deserialize SelectedRows to/from a stream.
+ * You can pass an ofstream or ostringstream to serialize to a file
+ * or to an in-memory string. GPU tensors will be copied to CPU.
+ */
+void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows,
+ const platform::DeviceContext& dev_ctx);
+void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows);
+
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/selected_rows_test.cc b/paddle/framework/selected_rows_test.cc
index 4ee13a65d7..75487c4010 100644
--- a/paddle/framework/selected_rows_test.cc
+++ b/paddle/framework/selected_rows_test.cc
@@ -43,5 +43,19 @@ TEST_F(SelectedRowsTester, complete_dims) {
ASSERT_EQ(selected_rows_->GetCompleteDims(), make_ddim({10, 100}));
}
+TEST_F(SelectedRowsTester, SerializeAndDeseralize) {
+ SelectedRows dst_tensor;
+ platform::CPUDeviceContext cpu_ctx(place_);
+ std::ostringstream oss;
+
+ SerializeToStream(oss, *selected_rows_, cpu_ctx);
+
+ std::istringstream iss(oss.str());
+ DeserializeFromStream(iss, &dst_tensor);
+
+ ASSERT_EQ(selected_rows_->rows(), dst_tensor.rows());
+ ASSERT_EQ(selected_rows_->height(), dst_tensor.height());
+}
+
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/shape_inference.cc b/paddle/framework/shape_inference.cc
index 86dc01665b..e53cc0cdab 100644
--- a/paddle/framework/shape_inference.cc
+++ b/paddle/framework/shape_inference.cc
@@ -1,16 +1,16 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#include "paddle/framework/shape_inference.h"
#include "grad_op_desc_maker.h"
#include "paddle/framework/operator.h"
diff --git a/paddle/framework/tensor.cc b/paddle/framework/tensor.cc
index ea7b2a1f7b..f922e60624 100644
--- a/paddle/framework/tensor.cc
+++ b/paddle/framework/tensor.cc
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#include "paddle/framework/tensor.h"
diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h
index 6a0c5133c9..341a6949be 100644
--- a/paddle/framework/tensor.h
+++ b/paddle/framework/tensor.h
@@ -20,12 +20,12 @@ limitations under the License. */
#include
#include
+#include "paddle/framework/data_layout.h"
#include "paddle/framework/ddim.h"
#include "paddle/memory/memory.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/place.h"
-#include "unsupported/Eigen/CXX11/Tensor"
namespace paddle {
@@ -115,6 +115,10 @@ class Tensor {
inline void check_memory_size() const;
+ inline DataLayout layout() const { return layout_; }
+
+ inline void set_layout(const DataLayout layout) { layout_ = layout; }
+
private:
friend class LoDTensor;
@@ -173,6 +177,19 @@ class Tensor {
DDim dims_;
+ /**
+ * @brief the layout of the memory block, default is NHWC.
+ *
+ * @note the memory allocation order, describing how weight/data is stored.
+ * For example, for a 4-D Tensor (rank=4), there are three commonly
+ * used layouts. They are
+ * NCHW, NHWC, CHWN.
+ * N, C, H, W stand respectively for the batch size, the number of
+ * feature maps, the height, and the width.
+ */
+
+ DataLayout layout_ = DataLayout::kNHWC;
+
/**
* @brief A PlaceHolder may be shared by more than one tensor.
*
diff --git a/paddle/framework/tensor_impl.h b/paddle/framework/tensor_impl.h
index 3d93b7808b..6c6f298edc 100644
--- a/paddle/framework/tensor_impl.h
+++ b/paddle/framework/tensor_impl.h
@@ -165,6 +165,7 @@ inline Tensor Tensor::Slice(int begin_idx, int end_idx) const {
size_t base = numel() / dims_[0];
Tensor dst;
dst.holder_ = holder_;
+ dst.set_layout(layout_);
DDim dst_dims = dims_;
dst_dims[0] = end_idx - begin_idx;
dst.Resize(dst_dims);
diff --git a/paddle/framework/tensor_test.cc b/paddle/framework/tensor_test.cc
index f347981f2e..a1b4a03289 100644
--- a/paddle/framework/tensor_test.cc
+++ b/paddle/framework/tensor_test.cc
@@ -15,12 +15,13 @@
#include
#include
+namespace framework = paddle::framework;
+namespace platform = paddle::platform;
+
TEST(Tensor, Dims) {
- using namespace paddle::framework;
- using namespace paddle::platform;
- Tensor tt;
+ framework::Tensor tt;
tt.Resize({2, 3, 4});
- DDim dims = tt.dims();
+ framework::DDim dims = tt.dims();
ASSERT_EQ(arity(dims), 3);
for (int i = 0; i < 3; ++i) {
EXPECT_EQ(i + 2, dims[i]);
@@ -28,12 +29,12 @@ TEST(Tensor, Dims) {
}
TEST(Tensor, DataAssert) {
- paddle::framework::Tensor src_tensor;
+ framework::Tensor src_tensor;
bool caught = false;
try {
src_tensor.data();
- } catch (paddle::platform::EnforceNotMet err) {
+ } catch (platform::EnforceNotMet err) {
caught = true;
std::string msg =
"holder_ should not be null\nTensor holds no memory. Call "
@@ -50,61 +51,65 @@ TEST(Tensor, DataAssert) {
because Memory::Alloc() and Memory::Free() have not been ready.
*/
TEST(Tensor, MutableData) {
- using namespace paddle::framework;
- using namespace paddle::platform;
{
- Tensor src_tensor;
+ framework::Tensor src_tensor;
float* p1 = nullptr;
float* p2 = nullptr;
// initialization
- p1 = src_tensor.mutable_data<float>(make_ddim({1, 2, 3}), CPUPlace());
+ p1 = src_tensor.mutable_data<float>(framework::make_ddim({1, 2, 3}),
+ platform::CPUPlace());
EXPECT_NE(p1, nullptr);
// set src_tensor a new dim with large size
// momery is supposed to be re-allocated
- p2 = src_tensor.mutable_data<float>(make_ddim({3, 4}), CPUPlace());
+ p2 = src_tensor.mutable_data<float>(framework::make_ddim({3, 4}),
+ platform::CPUPlace());
EXPECT_NE(p2, nullptr);
EXPECT_NE(p1, p2);
// set src_tensor a new dim with same size
// momery block is supposed to be unchanged
- p1 = src_tensor.mutable_data<float>(make_ddim({2, 2, 3}), CPUPlace());
+ p1 = src_tensor.mutable_data<float>(framework::make_ddim({2, 2, 3}),
+ platform::CPUPlace());
EXPECT_EQ(p1, p2);
// set src_tensor a new dim with smaller size
// momery block is supposed to be unchanged
- p2 = src_tensor.mutable_data<float>(make_ddim({2, 2}), CPUPlace());
+ p2 = src_tensor.mutable_data<float>(framework::make_ddim({2, 2}),
+ platform::CPUPlace());
EXPECT_EQ(p1, p2);
}
#ifdef PADDLE_WITH_CUDA
{
- Tensor src_tensor;
+ framework::Tensor src_tensor;
float* p1 = nullptr;
float* p2 = nullptr;
// initialization
- p1 = src_tensor.mutable_data<float>(make_ddim({1, 2, 3}), CUDAPlace());
+ p1 = src_tensor.mutable_data<float>(framework::make_ddim({1, 2, 3}),
+ platform::CUDAPlace());
EXPECT_NE(p1, nullptr);
// set src_tensor a new dim with large size
// momery is supposed to be re-allocated
- p2 = src_tensor.mutable_data<float>(make_ddim({3, 4}), CUDAPlace());
+ p2 = src_tensor.mutable_data<float>(framework::make_ddim({3, 4}),
+ platform::CUDAPlace());
EXPECT_NE(p2, nullptr);
EXPECT_NE(p1, p2);
// set src_tensor a new dim with same size
// momery block is supposed to be unchanged
- p1 = src_tensor.mutable_data<float>(make_ddim({2, 2, 3}), CUDAPlace());
+ p1 = src_tensor.mutable_data<float>(framework::make_ddim({2, 2, 3}),
+ platform::CUDAPlace());
EXPECT_EQ(p1, p2);
// set src_tensor a new dim with smaller size
// momery block is supposed to be unchanged
- p2 = src_tensor.mutable_data<float>(make_ddim({2, 2}), CUDAPlace());
+ p2 = src_tensor.mutable_data<float>(framework::make_ddim({2, 2}),
+ platform::CUDAPlace());
EXPECT_EQ(p1, p2);
}
#endif
}
TEST(Tensor, ShareDataWith) {
- using namespace paddle::framework;
- using namespace paddle::platform;
{
- Tensor src_tensor;
- Tensor dst_tensor;
+ framework::Tensor src_tensor;
+ framework::Tensor dst_tensor;
// Try to share data form uninitialized tensor
bool caught = false;
try {
@@ -121,16 +126,18 @@ TEST(Tensor, ShareDataWith) {
}
ASSERT_TRUE(caught);
- src_tensor.mutable_data(make_ddim({2, 3, 4}), CPUPlace());
+ src_tensor.mutable_data(framework::make_ddim({2, 3, 4}),
+ platform::CPUPlace());
dst_tensor.ShareDataWith(src_tensor);
ASSERT_EQ(src_tensor.data(), dst_tensor.data());
}
#ifdef PADDLE_WITH_CUDA
{
- Tensor src_tensor;
- Tensor dst_tensor;
- src_tensor.mutable_data(make_ddim({2, 3, 4}), CUDAPlace());
+ framework::Tensor src_tensor;
+ framework::Tensor dst_tensor;
+ src_tensor.mutable_data(framework::make_ddim({2, 3, 4}),
+ platform::CUDAPlace());
dst_tensor.ShareDataWith(src_tensor);
ASSERT_EQ(src_tensor.data(), dst_tensor.data());
}
@@ -138,13 +145,12 @@ TEST(Tensor, ShareDataWith) {
}
TEST(Tensor, Slice) {
- using namespace paddle::framework;
- using namespace paddle::platform;
{
- Tensor src_tensor;
- src_tensor.mutable_data<int>(make_ddim({5, 3, 4}), CPUPlace());
- Tensor slice_tensor = src_tensor.Slice(1, 3);
- DDim slice_dims = slice_tensor.dims();
+ framework::Tensor src_tensor;
+ src_tensor.mutable_data<int>(framework::make_ddim({5, 3, 4}),
+ platform::CPUPlace());
+ framework::Tensor slice_tensor = src_tensor.Slice(1, 3);
+ framework::DDim slice_dims = slice_tensor.dims();
ASSERT_EQ(arity(slice_dims), 3);
EXPECT_EQ(slice_dims[0], 2);
EXPECT_EQ(slice_dims[1], 3);
@@ -153,11 +159,12 @@ TEST(Tensor, Slice) {
uintptr_t src_data_address =
reinterpret_cast<uintptr_t>(src_tensor.data<int>());
uintptr_t src_mutable_data_address = reinterpret_cast<uintptr_t>(
- src_tensor.mutable_data<int>(src_tensor.dims(), CPUPlace()));
+ src_tensor.mutable_data<int>(src_tensor.dims(), platform::CPUPlace()));
uintptr_t slice_data_address =
reinterpret_cast<uintptr_t>(slice_tensor.data<int>());
- uintptr_t slice_mutable_data_address = reinterpret_cast<uintptr_t>(
- slice_tensor.mutable_data<int>(slice_tensor.dims(), CPUPlace()));
+ uintptr_t slice_mutable_data_address =
+ reinterpret_cast<uintptr_t>(slice_tensor.mutable_data<int>(
+ slice_tensor.dims(), platform::CPUPlace()));
EXPECT_EQ(src_data_address, src_mutable_data_address);
EXPECT_EQ(slice_data_address, slice_mutable_data_address);
EXPECT_EQ(src_data_address + 3 * 4 * 1 * sizeof(int), slice_data_address);
@@ -165,22 +172,25 @@ TEST(Tensor, Slice) {
#ifdef PADDLE_WITH_CUDA
{
- Tensor src_tensor;
- src_tensor.mutable_data<double>(make_ddim({6, 9}), CUDAPlace());
- Tensor slice_tensor = src_tensor.Slice(2, 6);
- DDim slice_dims = slice_tensor.dims();
+ framework::Tensor src_tensor;
+ src_tensor.mutable_data<double>(framework::make_ddim({6, 9}),
+ platform::CUDAPlace());
+ framework::Tensor slice_tensor = src_tensor.Slice(2, 6);
+ framework::DDim slice_dims = slice_tensor.dims();
ASSERT_EQ(arity(slice_dims), 2);
EXPECT_EQ(slice_dims[0], 4);
EXPECT_EQ(slice_dims[1], 9);
uintptr_t src_data_address =
reinterpret_cast<uintptr_t>(src_tensor.data<double>());
- uintptr_t src_mutable_data_address = reinterpret_cast<uintptr_t>(
- src_tensor.mutable_data<double>(src_tensor.dims(), CUDAPlace()));
+ uintptr_t src_mutable_data_address =
+ reinterpret_cast<uintptr_t>(src_tensor.mutable_data<double>(
+ src_tensor.dims(), platform::CUDAPlace()));
uintptr_t slice_data_address =
reinterpret_cast<uintptr_t>(slice_tensor.data<double>());
- uintptr_t slice_mutable_data_address = reinterpret_cast<uintptr_t>(
- slice_tensor.mutable_data<double>(slice_tensor.dims(), CUDAPlace()));
+ uintptr_t slice_mutable_data_address =
+ reinterpret_cast<uintptr_t>(slice_tensor.mutable_data<double>(
+ slice_tensor.dims(), platform::CUDAPlace()));
EXPECT_EQ(src_data_address, src_mutable_data_address);
EXPECT_EQ(slice_data_address, slice_mutable_data_address);
EXPECT_EQ(src_data_address + 9 * 2 * sizeof(double), slice_data_address);
@@ -189,14 +199,19 @@ TEST(Tensor, Slice) {
}
TEST(Tensor, ReshapeToMatrix) {
- using namespace paddle::framework;
- using namespace paddle::platform;
- Tensor src;
- int* src_ptr = src.mutable_data<int>({2, 3, 4, 9}, CPUPlace());
+ framework::Tensor src;
+ int* src_ptr = src.mutable_data<int>({2, 3, 4, 9}, platform::CPUPlace());
for (int i = 0; i < 2 * 3 * 4 * 9; ++i) {
src_ptr[i] = i;
}
- Tensor res = ReshapeToMatrix(src, 2);
+ framework::Tensor res = framework::ReshapeToMatrix(src, 2);
ASSERT_EQ(res.dims()[0], 2 * 3);
ASSERT_EQ(res.dims()[1], 4 * 9);
}
+
+TEST(Tensor, Layout) {
+ framework::Tensor src;
+ ASSERT_EQ(src.layout(), framework::DataLayout::kNHWC);
+ src.set_layout(framework::DataLayout::kAnyLayout);
+ ASSERT_EQ(src.layout(), framework::DataLayout::kAnyLayout);
+}
diff --git a/paddle/framework/tensor_util.cc b/paddle/framework/tensor_util.cc
new file mode 100644
index 0000000000..7efc649d0b
--- /dev/null
+++ b/paddle/framework/tensor_util.cc
@@ -0,0 +1,119 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. */
+
+#include "paddle/framework/tensor_util.h"
+
+namespace paddle {
+namespace framework {
+template <typename Predicate, typename DevCtx>
+struct AnyDTypeVisitor {
+ Predicate predicate_;
+ const Tensor& tensor_;
+ const DevCtx& ctx_;
+ Tensor* out_;
+
+ AnyDTypeVisitor(Predicate predicate, const Tensor& tensor, const DevCtx& ctx,
+ Tensor* out)
+ : predicate_(predicate), tensor_(tensor), ctx_(ctx), out_(out) {}
+
+ template <typename T>
+ void operator()() const {
+ auto t = EigenVector<T>::Flatten(tensor_);
+ auto o = EigenScalar<bool>::From(*out_);
+ // Write true to the output if predicate_(t) holds for any element.
+ o.device(*ctx_.eigen_device()) = predicate_(t).any();
+ }
+};
+
+template <typename Predicate, typename DevCtx>
+inline void AnyImpl(Predicate predicate, const framework::Tensor& tensor,
+ const DevCtx& ctx, framework::Tensor* out) {
+ VisitDataType(ToDataType(tensor.type()), AnyDTypeVisitor<Predicate, DevCtx>(
+ predicate, tensor, ctx, out));
+}
+
+template <typename Predicate>
+struct AnyVisitor : public boost::static_visitor<bool> {
+ const framework::Tensor& tensor_;
+ Predicate predicate_;
+
+ AnyVisitor(const framework::Tensor& tensor, Predicate predicate)
+ : tensor_(tensor), predicate_(std::move(predicate)) {}
+
+ template <typename Place>
+ bool operator()(const Place& place) const {
+ framework::Tensor out;
+ out.Resize({1});
+ out.mutable_data<bool>(place);
+ auto* ctx = platform::DeviceContextPool::Instance().GetByPlace(place);
+ AnyImpl(predicate_, tensor_, *ctx, &out);
+ return this->GetResult(out, place);
+ }
+
+ bool GetResult(const framework::Tensor& out,
+ const platform::CUDAPlace& gpu) const {
+ platform::CPUPlace cpu;
+ framework::Tensor tmp;
+ tmp.Resize({1});
+ tmp.mutable_data<bool>(cpu);
+ auto gpuctx = platform::DeviceContextPool::Instance().Get(gpu);
+ gpuctx->Wait();
+ CopyFrom(out, cpu, *gpuctx, &tmp);
+ gpuctx->Wait();
+ return GetResult(tmp, cpu);
+ }
+
+ bool GetResult(const framework::Tensor& out,
+ const platform::CPUPlace& cpu) const {
+ return *out.data<bool>();
+ }
+};
+
+template <typename Predicate>
+inline bool Any(const framework::Tensor& tensor, Predicate predicate) {
+ AnyVisitor<Predicate> visitor(tensor, predicate);
+ auto place = tensor.place();
+ return platform::VisitPlace(place, visitor);
+}
+
+struct HasNANPredicate {
+ template <typename T>
+ auto operator()(const T& eigen_vec) const
+ -> decltype(std::declval<T>().isnan()) {
+ // Cast eigen_vector to a vector of bool; true where an element is NaN.
+ return eigen_vec.isnan();
+ }
+};
+
+bool HasNAN(const framework::Tensor& tensor) {
+ HasNANPredicate predicate;
+ return Any(tensor, predicate);
+}
+
+struct HasInfPredicate {
+ template <typename T>
+ auto operator()(const T& eigen_vec) const
+ -> decltype(std::declval<T>().isinf()) {
+ // Cast eigen_vector to a vector of bool; true where an element is Inf.
+ return eigen_vec.isinf();
+ }
+};
+
+bool HasInf(const framework::Tensor& tensor) {
+ HasInfPredicate predicate;
+ return Any(tensor, predicate);
+}
+
+} // namespace framework
+} // namespace paddle
diff --git a/paddle/framework/tensor_util.cu b/paddle/framework/tensor_util.cu
new file mode 120000
index 0000000000..b00e6e59d9
--- /dev/null
+++ b/paddle/framework/tensor_util.cu
@@ -0,0 +1 @@
+./tensor_util.cc
\ No newline at end of file
diff --git a/paddle/framework/tensor_util.h b/paddle/framework/tensor_util.h
index ebfb0e5538..6a21f8db1e 100644
--- a/paddle/framework/tensor_util.h
+++ b/paddle/framework/tensor_util.h
@@ -1,19 +1,23 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#pragma once
+#include "paddle/framework/data_type.h"
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/framework.pb.h"
#include "paddle/framework/tensor.h"
+#include "paddle/platform/device_context.h"
namespace paddle {
namespace framework {
@@ -33,6 +37,7 @@ inline void CopyFrom(const Tensor& src, const platform::Place& dst_place,
src.check_memory_size();
dst->Resize(src.dims());
+ dst->set_layout(src.layout());
auto src_place = src.place();
auto src_ptr = src.data();
@@ -89,6 +94,7 @@ inline void CopyFrom(const Tensor& src, const platform::Place& dst_place,
Tensor* dst) {
src.check_memory_size();
dst->Resize(src.dims());
+ dst->set_layout(src.layout());
auto src_place = src.place();
auto src_ptr = src.data();
@@ -203,5 +209,109 @@ inline void CopyToVector(const Tensor& src, std::vector* dst) {
src_ptr, size);
}
+// Returns true if a tensor contains NAN, i.e., Not A Number.
+bool HasNAN(const framework::Tensor& tensor);
+
+// Returns true if a tensor contains Inf, i.e., Infinity.
+bool HasInf(const framework::Tensor& tensor);
+
+inline void SerializeToStream(std::ostream& os, const Tensor& tensor,
+ const platform::DeviceContext& dev_ctx) {
+ // TODO(typhoonzero): serialize to ostream
+ { // the 1st field, uint32_t version
+ constexpr uint32_t version = 0;
+ os.write(reinterpret_cast<const char*>(&version), sizeof(version));
+ }
+ { // the 2nd field, tensor description
+ // int32_t size
+ // void* protobuf message
+ proto::TensorDesc desc;
+ desc.set_data_type(framework::ToDataType(tensor.type()));
+ auto dims = framework::vectorize(tensor.dims());
+ auto* pb_dims = desc.mutable_dims();
+ pb_dims->Resize(static_cast<int>(dims.size()), 0);
+ std::copy(dims.begin(), dims.end(), pb_dims->begin());
+ int32_t size = desc.ByteSize();
+ os.write(reinterpret_cast<const char*>(&size), sizeof(size));
+ auto out = desc.SerializeAsString();
+ os.write(out.data(), size);
+ }
+ { // the 3rd field, tensor data
+ uint64_t size = tensor.memory_size();
+ auto* data_ptr = tensor.data<void>();
+ PADDLE_ENFORCE(size < std::numeric_limits<std::streamsize>::max(),
+ "Index overflow when writing tensor");
+ if (platform::is_gpu_place(tensor.place())) {
+#ifdef PADDLE_WITH_CUDA
+ constexpr size_t kBufSize = 1024 * 1024 * 64; // 64MB
+ std::unique_ptr<char[]> buf(new char[kBufSize]);
+ auto& gpu_dev_ctx =
+ static_cast<const platform::CUDADeviceContext&>(dev_ctx);
+ platform::CPUPlace cpu;
+ uintptr_t data = reinterpret_cast<uintptr_t>(data_ptr);
+ while (size != 0) {
+ size_t size_to_write = std::min(kBufSize, static_cast<size_t>(size));
+ memory::Copy(cpu, buf.get(),
+ boost::get<platform::CUDAPlace>(tensor.place()),
+ reinterpret_cast<const void*>(data), size_to_write,
+ gpu_dev_ctx.stream());
+ gpu_dev_ctx.Wait();
+ os.write(buf.get(), size_to_write);
+ data += size_to_write;
+ size -= size_to_write;
+ }
+#else
+ PADDLE_THROW("Unexpected branch");
+#endif
+ } else {
+ os.write(static_cast<const char*>(data_ptr),
+ static_cast<std::streamsize>(size));
+ }
+ }
+}
+
+inline void DeserializeFromStream(std::istream& is, Tensor* tensor) {
+ uint32_t version;
+ is.read(reinterpret_cast<char*>(&version), sizeof(version));
+ PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
+ proto::TensorDesc desc;
+ { // int32_t size
+ // proto buffer
+ int32_t size;
+ is.read(reinterpret_cast<char*>(&size), sizeof(size));
+ std::unique_ptr<char[]> buf(new char[size]);
+ is.read(reinterpret_cast<char*>(buf.get()), size);
+ PADDLE_ENFORCE(desc.ParseFromArray(buf.get(), size),
+ "Cannot parse tensor desc");
+ }
+ { // read tensor
+ std::vector<int64_t> dims;
+ dims.reserve(static_cast<size_t>(desc.dims().size()));
+ std::copy(desc.dims().begin(), desc.dims().end(), std::back_inserter(dims));
+ tensor->Resize(framework::make_ddim(dims));
+
+ void* buf;
+ platform::Place cpu = platform::CPUPlace();
+ // TODO(Yancey1989): use VisiterDataType instead of DataType switch
+ switch (desc.data_type()) {
+ case proto::FP32:
+ buf = tensor->mutable_data<float>(cpu);
+ break;
+ case proto::FP64:
+ buf = tensor->mutable_data<double>(cpu);
+ break;
+ case proto::INT32:
+ buf = tensor->mutable_data<int>(cpu);
+ break;
+ case proto::INT64:
+ buf = tensor->mutable_data<int64_t>(cpu);
+ break;
+ default:
+ PADDLE_THROW("DataType %d not supported", desc.data_type());
+ }
+ is.read(static_cast<char*>(buf), tensor->memory_size());
+ }
+}
+
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/tensor_util_test.cc b/paddle/framework/tensor_util_test.cc
index 6fc243aaf6..0dc5166fca 100644
--- a/paddle/framework/tensor_util_test.cc
+++ b/paddle/framework/tensor_util_test.cc
@@ -13,6 +13,7 @@
#include "paddle/framework/tensor_util.h"
#include
+#include
#include
namespace paddle {
@@ -28,6 +29,7 @@ TEST(CopyFrom, Tensor) {
int arr[9] = {1, 2, 3, 4, 5, 6, 7, 8, 9};
memcpy(src_ptr, arr, 9 * sizeof(int));
+ src_tensor.set_layout(DataLayout::kAnyLayout);
auto cpu_place = new platform::CPUPlace();
CopyFrom(src_tensor, *cpu_place, &dst_tensor);
@@ -38,14 +40,18 @@ TEST(CopyFrom, Tensor) {
EXPECT_EQ(src_ptr[i], dst_ptr[i]);
}
+ EXPECT_TRUE(dst_tensor.layout() == src_tensor.layout());
+
Tensor slice_tensor = src_tensor.Slice(1, 2);
- CopyFrom(slice_tensor, *cpu_place, cpu_ctx, &dst_tensor);
+ CopyFrom(slice_tensor, *cpu_place, &dst_tensor);
const int* slice_ptr = slice_tensor.data();
dst_ptr = dst_tensor.data();
ASSERT_NE(dst_ptr, slice_ptr);
for (size_t i = 0; i < 3; ++i) {
EXPECT_EQ(dst_ptr[i], slice_ptr[i]);
}
+ EXPECT_TRUE(dst_tensor.layout() == src_tensor.layout());
+
#ifdef PADDLE_WITH_CUDA
{
Tensor src_tensor;
@@ -91,6 +97,8 @@ TEST(CopyFrom, Tensor) {
for (size_t i = 0; i < 3; ++i) {
EXPECT_EQ(dst_ptr[i], slice_ptr[i]);
}
+
+ EXPECT_TRUE(dst_tensor.layout() == src_tensor.layout());
}
#endif
}
@@ -223,5 +231,78 @@ TEST(CopyToVector, Tensor) {
#endif
}
+TEST(HasNAN, CPU) {
+ using namespace paddle::framework;
+ using namespace paddle::platform;
+ Tensor src;
+ float* buf = src.mutable_data<float>({3}, CPUPlace());
+ buf[0] = 0.0;
+ buf[1] = NAN;
+ buf[2] = 0.0;
+
+ ASSERT_TRUE(HasNAN(src));
+}
+
+TEST(HasInf, CPU) {
+ using namespace paddle::framework;
+ using namespace paddle::platform;
+ Tensor src;
+ double* buf = src.mutable_data<double>({3}, CPUPlace());
+ buf[0] = 1.0;
+ buf[1] = INFINITY;
+ buf[2] = 0.0;
+ ASSERT_TRUE(HasInf(src));
+}
+
+TEST(Tensor, SerializeAndDeserialize) {
+ framework::Tensor src_tensor;
+ int array[6] = {1, 2, 3, 4, 5, 6};
+ src_tensor.Resize({2, 3});
+ int* src_ptr = src_tensor.mutable_data<int>(platform::CPUPlace());
+ for (int i = 0; i < 6; ++i) {
+ src_ptr[i] = array[i];
+ }
+ {
+ framework::Tensor dst_tensor;
+ auto place = new platform::CPUPlace();
+ platform::CPUDeviceContext cpu_ctx(*place);
+ std::ostringstream oss;
+ SerializeToStream(oss, src_tensor, cpu_ctx);
+
+ std::istringstream iss(oss.str());
+ DeserializeFromStream(iss, &dst_tensor);
+ int* dst_ptr = dst_tensor.mutable_data<int>(platform::CPUPlace());
+ for (int i = 0; i < 6; ++i) {
+ ASSERT_EQ(dst_ptr[i], array[i]);
+ }
+ delete place;
+ }
+#ifdef PADDLE_WITH_CUDA
+ {
+ Tensor gpu_tensor;
+ gpu_tensor.Resize({2, 3});
+ Tensor dst_tensor;
+
+ auto gpu_place = new platform::CUDAPlace();
+ platform::CUDADeviceContext gpu_ctx(*gpu_place);
+
+ CopyFrom(src_tensor, *gpu_place, gpu_ctx, &gpu_tensor);
+
+ std::ostringstream oss;
+ SerializeToStream(oss, gpu_tensor, gpu_ctx);
+
+ std::istringstream iss(oss.str());
+ DeserializeFromStream(iss, &dst_tensor);
+
+ int* dst_ptr = dst_tensor.mutable_data<int>(platform::CPUPlace());
+ for (int i = 0; i < 6; ++i) {
+ ASSERT_EQ(dst_ptr[i], array[i]);
+ }
+
+ delete gpu_place;
+ }
+#endif
+}
+
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/tensor_util_test.cu b/paddle/framework/tensor_util_test.cu
new file mode 100644
index 0000000000..ebd35fdf6c
--- /dev/null
+++ b/paddle/framework/tensor_util_test.cu
@@ -0,0 +1,57 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. */
+
+#include "gtest/gtest.h"
+#include "paddle/framework/tensor_util.h"
+#include "paddle/platform/device_context.h"
+#include "paddle/platform/place.h"
+
+namespace paddle {
+namespace framework {
+
+static __global__ void FillNAN(float* buf) {
+ buf[0] = 0.0;
+ buf[1] = 0.1;
+ buf[2] = NAN;
+}
+static __global__ void FillInf(float* buf) {
+ buf[0] = 0.0;
+ buf[1] = INFINITY;
+ buf[2] = 0.5;
+}
+
+TEST(HasNAN, GPU) {
+ Tensor tensor;
+ platform::CUDAPlace gpu(0);
+ auto& pool = platform::DeviceContextPool::Instance();
+ auto* cuda_ctx = pool.GetByPlace(gpu);
+ float* buf = tensor.mutable_data<float>({3}, gpu);
+ FillNAN<<<1, 1, 0, cuda_ctx->stream()>>>(buf);
+ cuda_ctx->Wait();
+ ASSERT_TRUE(HasNAN(tensor));
+}
+
+TEST(HasInf, GPU) {
+ Tensor tensor;
+ platform::CUDAPlace gpu(0);
+ auto& pool = platform::DeviceContextPool::Instance();
+ auto* cuda_ctx = pool.GetByPlace(gpu);
+ float* buf = tensor.mutable_data<float>({3}, gpu);
+ FillInf<<<1, 1, 0, cuda_ctx->stream()>>>(buf);
+ cuda_ctx->Wait();
+ ASSERT_TRUE(HasInf(tensor));
+}
+
+} // namespace framework
+} // namespace paddle
diff --git a/paddle/framework/threadpool.cc b/paddle/framework/threadpool.cc
new file mode 100644
index 0000000000..109a7e7dc4
--- /dev/null
+++ b/paddle/framework/threadpool.cc
@@ -0,0 +1,24 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/framework/threadpool.h"
+
+namespace paddle {
+namespace framework {
+
+std::unique_ptr<ThreadPool> ThreadPool::threadpool(nullptr);
+std::once_flag ThreadPool::init_flag;
+
+} // namespace framework
+} // namespace paddle
diff --git a/paddle/framework/threadpool.h b/paddle/framework/threadpool.h
index 9a1ece3ae8..bcd8190755 100644
--- a/paddle/framework/threadpool.h
+++ b/paddle/framework/threadpool.h
@@ -13,24 +13,24 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
+
#include
-#include
#include
-#include
+#include
#include