From 67cbb3e3b6bc5a00b66b3fb1c2de4991ad2e4a21 Mon Sep 17 00:00:00 2001 From: wanghaox Date: Wed, 13 Dec 2017 18:50:03 +0800 Subject: [PATCH 01/43] detection map evaluator for SSD --- paddle/operators/detection_map_op.cc | 77 +++++ paddle/operators/detection_map_op.cu | 20 ++ paddle/operators/detection_map_op.h | 316 ++++++++++++++++++ paddle/operators/math/detection_util.cc | 22 ++ paddle/operators/math/detection_util.cu | 23 ++ paddle/operators/math/detection_util.h | 128 +++++++ .../v2/fluid/tests/test_detection_map_op.py | 155 +++++++++ 7 files changed, 741 insertions(+) create mode 100644 paddle/operators/detection_map_op.cc create mode 100644 paddle/operators/detection_map_op.cu create mode 100644 paddle/operators/detection_map_op.h create mode 100644 paddle/operators/math/detection_util.cc create mode 100644 paddle/operators/math/detection_util.cu create mode 100644 paddle/operators/math/detection_util.h create mode 100644 python/paddle/v2/fluid/tests/test_detection_map_op.py diff --git a/paddle/operators/detection_map_op.cc b/paddle/operators/detection_map_op.cc new file mode 100644 index 0000000000..b59d3bfad9 --- /dev/null +++ b/paddle/operators/detection_map_op.cc @@ -0,0 +1,77 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/detection_map_op.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +class DetectionMAPOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + auto map_dim = framework::make_ddim({1}); + ctx->SetOutputDim("MAP", map_dim); + } + + protected: + framework::OpKernelType GetKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("Label")->type()), + ctx.device_context()); + } +}; + +class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker { + public: + DetectionMAPOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Detect", "The detection output."); + AddInput("Label", "The label data."); + AddOutput("MAP", "The MAP evaluate result of the detection."); + + AddAttr("overlap_threshold", "The overlap threshold.") + .SetDefault(.3f); + AddAttr("evaluate_difficult", + "Switch to control whether the difficult data is evaluated.") + .SetDefault(true); + AddAttr("ap_type", + "The AP algorithm type, 'Integral' or '11point'.") + .SetDefault("Integral"); + + AddComment(R"DOC( +Detection MAP Operator. + +Detection MAP evaluator for SSD(Single Shot MultiBox Detector) algorithm. +Please get more information from the following papers: +https://arxiv.org/abs/1512.02325. + +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(detection_map, ops::DetectionMAPOp, + ops::DetectionMAPOpMaker); +REGISTER_OP_CPU_KERNEL( + detection_map, ops::DetectionMAPOpKernel, + ops::DetectionMAPOpKernel); diff --git a/paddle/operators/detection_map_op.cu b/paddle/operators/detection_map_op.cu new file mode 100644 index 0000000000..ab9a992c36 --- /dev/null +++ b/paddle/operators/detection_map_op.cu @@ -0,0 +1,20 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/detection_map_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + detection_map, ops::DetectionMAPOpKernel, + ops::DetectionMAPOpKernel); diff --git a/paddle/operators/detection_map_op.h b/paddle/operators/detection_map_op.h new file mode 100644 index 0000000000..3e862abda6 --- /dev/null +++ b/paddle/operators/detection_map_op.h @@ -0,0 +1,316 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/detection_util.h" +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +template +inline void GetAccumulation(std::vector> in_pairs, + std::vector* accu_vec) { + std::stable_sort(in_pairs.begin(), in_pairs.end(), + math::SortScorePairDescend); + accu_vec->clear(); + size_t sum = 0; + for (size_t i = 0; i < in_pairs.size(); ++i) { + // auto score = in_pairs[i].first; + auto count = in_pairs[i].second; + sum += count; + accu_vec->push_back(sum); + } +} + +template +class DetectionMAPOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input_label = ctx.Input("Label"); + auto* input_detect = ctx.Input("Detect"); + auto* map_out = ctx.Output("MAP"); + + float overlap_threshold = ctx.Attr("overlap_threshold"); + float evaluate_difficult = ctx.Attr("evaluate_difficult"); + std::string ap_type = ctx.Attr("ap_type"); + + auto label_lod = input_label->lod(); + PADDLE_ENFORCE_EQ(label_lod.size(), 1UL, + "Only support one level sequence now."); + auto batch_size = label_lod[0].size() - 1; + + std::vector>>> gt_bboxes; + + std::vector< + std::map>>>> + detect_bboxes; + + if (platform::is_gpu_place(ctx.GetPlace())) { + framework::LoDTensor input_label_cpu; + framework::Tensor input_detect_cpu; + input_label_cpu.set_lod(input_label->lod()); + input_label_cpu.Resize(input_label->dims()); + input_detect_cpu.Resize(input_detect->dims()); + input_label_cpu.mutable_data(platform::CPUPlace()); + input_detect_cpu.mutable_data(platform::CPUPlace()); + framework::CopyFrom(*input_label, platform::CPUPlace(), + ctx.device_context(), &input_label_cpu); + framework::CopyFrom(*input_detect, platform::CPUPlace(), + ctx.device_context(), &input_detect_cpu); + GetBBoxes(input_label_cpu, input_detect_cpu, gt_bboxes, detect_bboxes); + } else { + GetBBoxes(*input_label, *input_detect, gt_bboxes, detect_bboxes); + } + + std::map label_pos_count; + std::map>> true_pos; + std::map>> false_pos; + + CalcTrueAndFalsePositive(batch_size, evaluate_difficult, overlap_threshold, + gt_bboxes, detect_bboxes, label_pos_count, + true_pos, false_pos); + + T map = CalcMAP(ap_type, label_pos_count, true_pos, false_pos); + + T* map_data = nullptr; + framework::Tensor map_cpu; + map_out->mutable_data(ctx.GetPlace()); + if (platform::is_gpu_place(ctx.GetPlace())) { + map_data = map_cpu.mutable_data(map_out->dims(), platform::CPUPlace()); + map_data[0] = map; + framework::CopyFrom(map_cpu, platform::CPUPlace(), ctx.device_context(), + map_out); + } else { + map_data = map_out->mutable_data(ctx.GetPlace()); + map_data[0] = map; + } + } + + protected: + void GetBBoxes( + const framework::LoDTensor& input_label, + const framework::Tensor& input_detect, + std::vector>>>& + gt_bboxes, + std::vector< + std::map>>>>& + detect_bboxes) const { + const T* label_data = input_label.data(); + const T* detect_data = input_detect.data(); + + auto label_lod = input_label.lod(); + auto batch_size = label_lod[0].size() - 1; + auto label_index = label_lod[0]; + + for (size_t n = 0; n < batch_size; ++n) { + std::map>> bboxes; + for (int i = label_index[n]; i < label_index[n + 1]; ++i) { + std::vector> bbox; + math::GetBBoxFromLabelData(label_data + i * 6, 1, bbox); + int label = static_cast(label_data[i * 6]); + bboxes[label].push_back(bbox[0]); + } + gt_bboxes.push_back(bboxes); + } + + size_t n = 0; + size_t detect_box_count = input_detect.dims()[0]; + for (size_t img_id = 0; img_id < batch_size; ++img_id) { + std::map>>> bboxes; + size_t cur_img_id = static_cast((detect_data + n * 7)[0]); + while (cur_img_id == img_id && n < detect_box_count) { + std::vector label; + std::vector score; + std::vector> bbox; + math::GetBBoxFromDetectData(detect_data + n * 7, 1, label, score, + bbox); + bboxes[label[0]].push_back(std::make_pair(score[0], bbox[0])); + ++n; + cur_img_id = static_cast((detect_data + n * 7)[0]); + } + detect_bboxes.push_back(bboxes); + } + } + + void CalcTrueAndFalsePositive( + size_t batch_size, bool evaluate_difficult, float overlap_threshold, + const std::vector>>>& + gt_bboxes, + const std::vector< + std::map>>>>& + detect_bboxes, + std::map& label_pos_count, + std::map>>& true_pos, + std::map>>& false_pos) const { + for (size_t n = 0; n < batch_size; ++n) { + auto image_gt_bboxes = gt_bboxes[n]; + for (auto it = image_gt_bboxes.begin(); it != image_gt_bboxes.end(); + ++it) { + size_t count = 0; + auto labeled_bboxes = it->second; + if (evaluate_difficult) { + count = labeled_bboxes.size(); + } else { + for (size_t i = 0; i < labeled_bboxes.size(); ++i) + if (!(labeled_bboxes[i].is_difficult)) ++count; + } + if (count == 0) { + continue; + } + int label = it->first; + if (label_pos_count.find(label) == label_pos_count.end()) { + label_pos_count[label] = count; + } else { + label_pos_count[label] += count; + } + } + } + + for (size_t n = 0; n < detect_bboxes.size(); ++n) { + auto image_gt_bboxes = gt_bboxes[n]; + auto detections = detect_bboxes[n]; + + if (image_gt_bboxes.size() == 0) { + for (auto it = detections.begin(); it != detections.end(); ++it) { + auto pred_bboxes = it->second; + int label = it->first; + for (size_t i = 0; i < pred_bboxes.size(); ++i) { + auto score = pred_bboxes[i].first; + true_pos[label].push_back(std::make_pair(score, 0)); + false_pos[label].push_back(std::make_pair(score, 1)); + } + } + continue; + } + + for (auto it = detections.begin(); it != detections.end(); ++it) { + int label = it->first; + auto pred_bboxes = it->second; + if (image_gt_bboxes.find(label) == image_gt_bboxes.end()) { + for (size_t i = 0; i < pred_bboxes.size(); ++i) { + auto score = pred_bboxes[i].first; + true_pos[label].push_back(std::make_pair(score, 0)); + false_pos[label].push_back(std::make_pair(score, 1)); + } + continue; + } + + auto matched_bboxes = image_gt_bboxes.find(label)->second; + std::vector visited(matched_bboxes.size(), false); + // Sort detections in descend order based on scores + std::sort(pred_bboxes.begin(), pred_bboxes.end(), + math::SortScorePairDescend>); + for (size_t i = 0; i < pred_bboxes.size(); ++i) { + float max_overlap = -1.0; + size_t max_idx = 0; + auto score = pred_bboxes[i].first; + for (size_t j = 0; j < matched_bboxes.size(); ++j) { + float overlap = + JaccardOverlap(pred_bboxes[i].second, matched_bboxes[j]); + if (overlap > max_overlap) { + max_overlap = overlap; + max_idx = j; + } + } + if (max_overlap > overlap_threshold) { + bool match_evaluate_difficult = + evaluate_difficult || + (!evaluate_difficult && !matched_bboxes[max_idx].is_difficult); + if (match_evaluate_difficult) { + if (!visited[max_idx]) { + true_pos[label].push_back(std::make_pair(score, 1)); + false_pos[label].push_back(std::make_pair(score, 0)); + visited[max_idx] = true; + } else { + true_pos[label].push_back(std::make_pair(score, 0)); + false_pos[label].push_back(std::make_pair(score, 1)); + } + } + } else { + true_pos[label].push_back(std::make_pair(score, 0)); + false_pos[label].push_back(std::make_pair(score, 1)); + } + } + } + } + } + + T CalcMAP( + std::string ap_type, const std::map& label_pos_count, + const std::map>>& true_pos, + const std::map>>& false_pos) const { + T mAP = 0.0; + int count = 0; + for (auto it = label_pos_count.begin(); it != label_pos_count.end(); ++it) { + int label = it->first; + int label_num_pos = it->second; + if (label_num_pos == 0 || true_pos.find(label) == true_pos.end()) + continue; + auto label_true_pos = true_pos.find(label)->second; + auto label_false_pos = false_pos.find(label)->second; + // Compute average precision. + std::vector tp_sum; + GetAccumulation(label_true_pos, &tp_sum); + std::vector fp_sum; + GetAccumulation(label_false_pos, &fp_sum); + std::vector precision, recall; + size_t num = tp_sum.size(); + // Compute Precision. + for (size_t i = 0; i < num; ++i) { + // CHECK_LE(tpCumSum[i], labelNumPos); + precision.push_back(static_cast(tp_sum[i]) / + static_cast(tp_sum[i] + fp_sum[i])); + recall.push_back(static_cast(tp_sum[i]) / label_num_pos); + } + // VOC2007 style + if (ap_type == "11point") { + std::vector max_precisions(11, 0.0); + int start_idx = num - 1; + for (int j = 10; j >= 0; --j) + for (int i = start_idx; i >= 0; --i) { + if (recall[i] < j / 10.) { + start_idx = i; + if (j > 0) max_precisions[j - 1] = max_precisions[j]; + break; + } else { + if (max_precisions[j] < precision[i]) + max_precisions[j] = precision[i]; + } + } + for (int j = 10; j >= 0; --j) mAP += max_precisions[j] / 11; + ++count; + } else if (ap_type == "Integral") { + // Nature integral + float average_precisions = 0.; + float prev_recall = 0.; + for (size_t i = 0; i < num; ++i) { + if (fabs(recall[i] - prev_recall) > 1e-6) + average_precisions += precision[i] * fabs(recall[i] - prev_recall); + prev_recall = recall[i]; + } + mAP += average_precisions; + ++count; + } else { + LOG(FATAL) << "Unkown ap version: " << ap_type; + } + } + if (count != 0) mAP /= count; + return mAP * 100; + } +}; // namespace operators + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/detection_util.cc b/paddle/operators/math/detection_util.cc new file mode 100644 index 0000000000..4131a0cb0e --- /dev/null +++ b/paddle/operators/math/detection_util.cc @@ -0,0 +1,22 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/detection_util.h" +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace operators { +namespace math {} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/detection_util.cu b/paddle/operators/math/detection_util.cu new file mode 100644 index 0000000000..d2bb992396 --- /dev/null +++ b/paddle/operators/math/detection_util.cu @@ -0,0 +1,23 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/detection_util.h" +#include "paddle/operators/math/math_function.h" +#include "paddle/platform/cuda_helper.h" + +namespace paddle { +namespace operators { +namespace math {} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/detection_util.h b/paddle/operators/math/detection_util.h new file mode 100644 index 0000000000..2a4dadc545 --- /dev/null +++ b/paddle/operators/math/detection_util.h @@ -0,0 +1,128 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#pragma once +#include "paddle/framework/selected_rows.h" +#include "paddle/platform/device_context.h" + +namespace paddle { +namespace operators { +namespace math { + +template +struct BBox { + BBox(T x_min, T y_min, T x_max, T y_max) + : x_min(x_min), + y_min(y_min), + x_max(x_max), + y_max(y_max), + is_difficult(false) {} + + BBox() {} + + T get_width() const { return x_max - x_min; } + + T get_height() const { return y_max - y_min; } + + T get_center_x() const { return (x_min + x_max) / 2; } + + T get_center_y() const { return (y_min + y_max) / 2; } + + T get_area() const { return get_width() * get_height(); } + + // coordinate of bounding box + T x_min; + T y_min; + T x_max; + T y_max; + // whether difficult object (e.g. object with heavy occlusion is difficult) + bool is_difficult; +}; + +template +void GetBBoxFromDetectData(const T* detect_data, const size_t num_bboxes, + std::vector& labels, std::vector& scores, + std::vector>& bboxes) { + size_t out_offset = bboxes.size(); + labels.resize(out_offset + num_bboxes); + scores.resize(out_offset + num_bboxes); + bboxes.resize(out_offset + num_bboxes); + for (size_t i = 0; i < num_bboxes; ++i) { + labels[out_offset + i] = *(detect_data + i * 7 + 1); + scores[out_offset + i] = *(detect_data + i * 7 + 2); + BBox bbox; + bbox.x_min = *(detect_data + i * 7 + 3); + bbox.y_min = *(detect_data + i * 7 + 4); + bbox.x_max = *(detect_data + i * 7 + 5); + bbox.y_max = *(detect_data + i * 7 + 6); + bboxes[out_offset + i] = bbox; + }; +} + +template +void GetBBoxFromLabelData(const T* label_data, const size_t num_bboxes, + std::vector>& bboxes) { + size_t out_offset = bboxes.size(); + bboxes.resize(bboxes.size() + num_bboxes); + for (size_t i = 0; i < num_bboxes; ++i) { + BBox bbox; + bbox.x_min = *(label_data + i * 6 + 1); + bbox.y_min = *(label_data + i * 6 + 2); + bbox.x_max = *(label_data + i * 6 + 3); + bbox.y_max = *(label_data + i * 6 + 4); + T is_difficult = *(label_data + i * 6 + 5); + if (std::abs(is_difficult - 0.0) < 1e-6) + bbox.is_difficult = false; + else + bbox.is_difficult = true; + bboxes[out_offset + i] = bbox; + } +} + +template +inline float JaccardOverlap(const BBox& bbox1, const BBox& bbox2) { + if (bbox2.x_min > bbox1.x_max || bbox2.x_max < bbox1.x_min || + bbox2.y_min > bbox1.y_max || bbox2.y_max < bbox1.y_min) { + return 0.0; + } else { + float inter_x_min = std::max(bbox1.x_min, bbox2.x_min); + float inter_y_min = std::max(bbox1.y_min, bbox2.y_min); + float inter_x_max = std::min(bbox1.x_max, bbox2.x_max); + float inter_y_max = std::min(bbox1.y_max, bbox2.y_max); + + float inter_width = inter_x_max - inter_x_min; + float inter_height = inter_y_max - inter_y_min; + float inter_area = inter_width * inter_height; + + float bbox_area1 = bbox1.get_area(); + float bbox_area2 = bbox2.get_area(); + + return inter_area / (bbox_area1 + bbox_area2 - inter_area); + } +} + +template +bool SortScorePairDescend(const std::pair& pair1, + const std::pair& pair2) { + return pair1.first > pair2.first; +} + +// template <> +// bool SortScorePairDescend(const std::pair& pair1, +// const std::pair& pair2) { +// return pair1.first > pair2.first; +// } + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/fluid/tests/test_detection_map_op.py b/python/paddle/v2/fluid/tests/test_detection_map_op.py new file mode 100644 index 0000000000..50ce3afbb9 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_detection_map_op.py @@ -0,0 +1,155 @@ +import unittest +import numpy as np +import sys +import collections +import math +from op_test import OpTest + + +class TestDetectionMAPOp(OpTest): + def set_data(self): + self.init_test_case() + + self.mAP = [self.calc_map(self.tf_pos)] + self.label = np.array(self.label).astype('float32') + self.detect = np.array(self.detect).astype('float32') + self.mAP = np.array(self.mAP).astype('float32') + + self.inputs = { + 'Label': (self.label, self.label_lod), + 'Detect': self.detect + } + + self.attrs = { + 'overlap_threshold': self.overlap_threshold, + 'evaluate_difficult': self.evaluate_difficult, + 'ap_type': self.ap_type + } + + self.outputs = {'MAP': self.mAP} + + def init_test_case(self): + self.overlap_threshold = 0.3 + self.evaluate_difficult = True + self.ap_type = "Integral" + + self.label_lod = [[0, 2, 4]] + # label xmin ymin xmax ymax difficult + self.label = [[1, 0.1, 0.1, 0.3, 0.3, 0], [1, 0.6, 0.6, 0.8, 0.8, 1], + [2, 0.3, 0.3, 0.6, 0.5, 0], [1, 0.7, 0.1, 0.9, 0.3, 0]] + + # image_id label score xmin ymin xmax ymax difficult + self.detect = [ + [0, 1, 0.3, 0.1, 0.0, 0.4, 0.3], [0, 1, 0.7, 0.0, 0.1, 0.2, 0.3], + [0, 1, 0.9, 0.7, 0.6, 0.8, 0.8], [1, 2, 0.8, 0.2, 0.1, 0.4, 0.4], + [1, 2, 0.1, 0.4, 0.3, 0.7, 0.5], [1, 1, 0.2, 0.8, 0.1, 1.0, 0.3], + [1, 3, 0.2, 0.8, 0.1, 1.0, 0.3] + ] + + # image_id label score false_pos false_pos + # [-1, 1, 3, -1, -1], + # [-1, 2, 1, -1, -1] + self.tf_pos = [[0, 1, 0.9, 1, 0], [0, 1, 0.7, 1, 0], [0, 1, 0.3, 0, 1], + [1, 1, 0.2, 1, 0], [1, 2, 0.8, 0, 1], [1, 2, 0.1, 1, 0], + [1, 3, 0.2, 0, 1]] + + def calc_map(self, tf_pos): + mAP = 0.0 + count = 0 + + class_pos_count = {} + true_pos = {} + false_pos = {} + + def get_accumulation(pos_list): + sorted_list = sorted(pos_list, key=lambda pos: pos[0], reverse=True) + sum = 0 + accu_list = [] + for (score, count) in sorted_list: + sum += count + accu_list.append(sum) + return accu_list + + label_count = collections.Counter() + for (label, xmin, ymin, xmax, ymax, difficult) in self.label: + if self.evaluate_difficult: + label_count[label] += 1 + elif not difficult: + label_count[label] += 1 + + true_pos = collections.defaultdict(list) + false_pos = collections.defaultdict(list) + for (image_id, label, score, tp, fp) in tf_pos: + true_pos[label].append([score, tp]) + false_pos[label].append([score, fp]) + + for (label, label_pos_num) in label_count.items(): + if label_pos_num == 0 or label not in true_pos: + continue + + label_true_pos = true_pos[label] + label_false_pos = false_pos[label] + + accu_tp_sum = get_accumulation(label_true_pos) + accu_fp_sum = get_accumulation(label_false_pos) + + precision = [] + recall = [] + + for i in range(len(accu_tp_sum)): + precision.append( + float(accu_tp_sum[i]) / + float(accu_tp_sum[i] + accu_fp_sum[i])) + recall.append(float(accu_tp_sum[i]) / label_pos_num) + + if self.ap_type == "11point": + max_precisions = [11.0, 0.0] + start_idx = len(accu_tp_sum) - 1 + for j in range(10, 0, -1): + for i in range(start_idx, 0, -1): + if recall[i] < j / 10.0: + start_idx = i + if j > 0: + max_precisions[j - 1] = max_precisions[j] + break + else: + if max_precisions[j] < accu_precision[i]: + max_precisions[j] = accu_precision[i] + for j in range(10, 0, -1): + mAP += max_precisions[j] / 11 + count += 1 + elif self.ap_type == "Integral": + average_precisions = 0.0 + prev_recall = 0.0 + for i in range(len(accu_tp_sum)): + if math.fabs(recall[i] - prev_recall) > 1e-6: + average_precisions += precision[i] * \ + math.fabs(recall[i] - prev_recall) + prev_recall = recall[i] + + mAP += average_precisions + count += 1 + + if count != 0: mAP /= count + return mAP * 100.0 + + def setUp(self): + self.op_type = "detection_map" + self.set_data() + + def test_check_output(self): + self.check_output() + + +class TestDetectionMAPOpSkipDiff(TestDetectionMAPOp): + def init_test_case(self): + super(TestDetectionMAPOpSkipDiff, self).init_test_case() + + self.evaluate_difficult = False + + self.tf_pos = [[0, 1, 0.7, 1, 0], [0, 1, 0.3, 0, 1], [1, 1, 0.2, 1, 0], + [1, 2, 0.8, 0, 1], [1, 2, 0.1, 1, 0], [1, 3, 0.2, 0, 1]] + + +if __name__ == '__main__': + unittest.main() From 26f03ea13d14a28c199185aa1fd5feda84d4eb6e Mon Sep 17 00:00:00 2001 From: wanghaox Date: Tue, 30 Jan 2018 19:49:50 +0800 Subject: [PATCH 02/43] update detection_map operator --- paddle/operators/detection_map_op.cc | 73 ++++- paddle/operators/detection_map_op.cu | 20 -- paddle/operators/detection_map_op.h | 249 +++++++++--------- paddle/operators/math/detection_util.cc | 22 -- paddle/operators/math/detection_util.cu | 23 -- paddle/operators/math/detection_util.h | 128 --------- .../v2/fluid/tests/test_detection_map_op.py | 71 ++--- 7 files changed, 231 insertions(+), 355 deletions(-) delete mode 100644 paddle/operators/detection_map_op.cu delete mode 100644 paddle/operators/math/detection_util.cc delete mode 100644 paddle/operators/math/detection_util.cu delete mode 100644 paddle/operators/math/detection_util.h diff --git a/paddle/operators/detection_map_op.cc b/paddle/operators/detection_map_op.cc index b59d3bfad9..aa47cb3c80 100644 --- a/paddle/operators/detection_map_op.cc +++ b/paddle/operators/detection_map_op.cc @@ -24,6 +24,29 @@ class DetectionMAPOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Detection"), + "Input(Detection) of DetectionMAPOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Label"), + "Input(Label) of DetectionMAPOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("MAP"), + "Output(MAP) of DetectionMAPOp should not be null."); + + auto det_dims = ctx->GetInputDim("Detection"); + PADDLE_ENFORCE_EQ(det_dims.size(), 2UL, + "The rank of Input(Detection) must be 2, " + "the shape is [N, 6]."); + PADDLE_ENFORCE_EQ(det_dims[1], 6UL, + "The shape is of Input(Detection) [N, 6]."); + auto label_dims = ctx->GetInputDim("Label"); + PADDLE_ENFORCE_EQ(label_dims.size(), 2UL, + "The rank of Input(Label) must be 2, " + "the shape is [N, 6]."); + PADDLE_ENFORCE_EQ(label_dims[1], 6UL, + "The shape is of Input(Label) [N, 6]."); + + auto ap_type = GetAPType(ctx->Attrs().Get("ap_type")); + PADDLE_ENFORCE_NE(ap_type, APType::kNone, + "The ap_type should be 'integral' or '11point."); auto map_dim = framework::make_ddim({1}); ctx->SetOutputDim("MAP", map_dim); } @@ -42,25 +65,49 @@ class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker { DetectionMAPOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("Detect", "The detection output."); - AddInput("Label", "The label data."); - AddOutput("MAP", "The MAP evaluate result of the detection."); - - AddAttr("overlap_threshold", "The overlap threshold.") + AddInput("Label", + "(LoDTensor) A 2-D LoDTensor with shape[N, 6] represents the" + "Labeled ground-truth data. Each row has 6 values: " + "[label, is_difficult, xmin, ymin, xmax, ymax], N is the total " + "number of ground-truth data in this mini-batch. For each " + "instance, the offsets in first dimension are called LoD, " + "the number of offset is N + 1, if LoD[i + 1] - LoD[i] == 0, " + "means there is no ground-truth data."); + AddInput("Detection", + "(LoDTensor) A 2-D LoDTensor with shape [M, 6] represents the " + "detections. Each row has 6 values: " + "[label, confidence, xmin, ymin, xmax, ymax], M is the total " + "number of detections in this mini-batch. For each instance, " + "the offsets in first dimension are called LoD, the number of " + "offset is N + 1, if LoD[i + 1] - LoD[i] == 0, means there is " + "no detected data."); + AddOutput("MAP", + "(Tensor) A tensor with shape [1], store the mAP evaluate " + "result of the detection."); + + AddAttr("overlap_threshold", + "(float) " + "The jaccard overlap threshold of detection output and " + "ground-truth data.") .SetDefault(.3f); AddAttr("evaluate_difficult", + "(bool, default true) " "Switch to control whether the difficult data is evaluated.") .SetDefault(true); AddAttr("ap_type", - "The AP algorithm type, 'Integral' or '11point'.") - .SetDefault("Integral"); - + "(string, default 'integral') " + "The AP algorithm type, 'integral' or '11point'.") + .SetDefault("integral") + .InEnum({"integral", "11point"}); AddComment(R"DOC( -Detection MAP Operator. - -Detection MAP evaluator for SSD(Single Shot MultiBox Detector) algorithm. -Please get more information from the following papers: -https://arxiv.org/abs/1512.02325. +Detection mAP evaluate operator. +The general steps are as follows. First, calculate the true positive and + false positive according to the input of detection and labels, then + calculate the mAP evaluate value. + Supporting '11 point' and 'integral' mAP algorithm. Please get more information + from the following articles: + https://sanchom.wordpress.com/tag/average-precision/ + https://arxiv.org/abs/1512.02325 )DOC"); } diff --git a/paddle/operators/detection_map_op.cu b/paddle/operators/detection_map_op.cu deleted file mode 100644 index ab9a992c36..0000000000 --- a/paddle/operators/detection_map_op.cu +++ /dev/null @@ -1,20 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/operators/detection_map_op.h" - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - detection_map, ops::DetectionMAPOpKernel, - ops::DetectionMAPOpKernel); diff --git a/paddle/operators/detection_map_op.h b/paddle/operators/detection_map_op.h index 3e862abda6..d29a6968e4 100644 --- a/paddle/operators/detection_map_op.h +++ b/paddle/operators/detection_map_op.h @@ -13,22 +13,37 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" -#include "paddle/operators/math/detection_util.h" -#include "paddle/operators/math/math_function.h" namespace paddle { namespace operators { +enum APType { kNone = 0, kIntegral, k11point }; + +APType GetAPType(std::string str) { + if (str == "integral") { + return APType::kIntegral; + } else if (str == "11point") { + return APType::k11point; + } else { + return APType::kNone; + } +} + +template +inline bool SortScorePairDescend(const std::pair& pair1, + const std::pair& pair2) { + return pair1.first > pair2.first; +} + template inline void GetAccumulation(std::vector> in_pairs, std::vector* accu_vec) { - std::stable_sort(in_pairs.begin(), in_pairs.end(), - math::SortScorePairDescend); + std::stable_sort(in_pairs.begin(), in_pairs.end(), SortScorePairDescend); accu_vec->clear(); size_t sum = 0; for (size_t i = 0; i < in_pairs.size(); ++i) { - // auto score = in_pairs[i].first; auto count = in_pairs[i].second; sum += count; accu_vec->push_back(sum); @@ -39,126 +54,125 @@ template class DetectionMAPOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* input_label = ctx.Input("Label"); - auto* input_detect = ctx.Input("Detect"); - auto* map_out = ctx.Output("MAP"); + auto* in_detect = ctx.Input("Detection"); + auto* in_label = ctx.Input("Label"); + auto* out_map = ctx.Output("MAP"); float overlap_threshold = ctx.Attr("overlap_threshold"); float evaluate_difficult = ctx.Attr("evaluate_difficult"); - std::string ap_type = ctx.Attr("ap_type"); + auto ap_type = GetAPType(ctx.Attr("ap_type")); - auto label_lod = input_label->lod(); + auto label_lod = in_label->lod(); + auto detect_lod = in_detect->lod(); PADDLE_ENFORCE_EQ(label_lod.size(), 1UL, "Only support one level sequence now."); - auto batch_size = label_lod[0].size() - 1; - - std::vector>>> gt_bboxes; - - std::vector< - std::map>>>> - detect_bboxes; - - if (platform::is_gpu_place(ctx.GetPlace())) { - framework::LoDTensor input_label_cpu; - framework::Tensor input_detect_cpu; - input_label_cpu.set_lod(input_label->lod()); - input_label_cpu.Resize(input_label->dims()); - input_detect_cpu.Resize(input_detect->dims()); - input_label_cpu.mutable_data(platform::CPUPlace()); - input_detect_cpu.mutable_data(platform::CPUPlace()); - framework::CopyFrom(*input_label, platform::CPUPlace(), - ctx.device_context(), &input_label_cpu); - framework::CopyFrom(*input_detect, platform::CPUPlace(), - ctx.device_context(), &input_detect_cpu); - GetBBoxes(input_label_cpu, input_detect_cpu, gt_bboxes, detect_bboxes); - } else { - GetBBoxes(*input_label, *input_detect, gt_bboxes, detect_bboxes); - } + PADDLE_ENFORCE_EQ(label_lod[0].size(), detect_lod[0].size(), + "The batch_size of input(Label) and input(Detection) " + "must be the same."); + + std::vector>> gt_boxes; + std::vector>>> detect_boxes; + + GetBoxes(*in_label, *in_detect, gt_boxes, detect_boxes); std::map label_pos_count; std::map>> true_pos; std::map>> false_pos; - CalcTrueAndFalsePositive(batch_size, evaluate_difficult, overlap_threshold, - gt_bboxes, detect_bboxes, label_pos_count, - true_pos, false_pos); + CalcTrueAndFalsePositive(gt_boxes, detect_boxes, evaluate_difficult, + overlap_threshold, label_pos_count, true_pos, + false_pos); T map = CalcMAP(ap_type, label_pos_count, true_pos, false_pos); - T* map_data = nullptr; - framework::Tensor map_cpu; - map_out->mutable_data(ctx.GetPlace()); - if (platform::is_gpu_place(ctx.GetPlace())) { - map_data = map_cpu.mutable_data(map_out->dims(), platform::CPUPlace()); - map_data[0] = map; - framework::CopyFrom(map_cpu, platform::CPUPlace(), ctx.device_context(), - map_out); + T* map_data = out_map->mutable_data(ctx.GetPlace()); + map_data[0] = map; + } + + protected: + struct Box { + Box(T xmin, T ymin, T xmax, T ymax) + : xmin(xmin), ymin(ymin), xmax(xmax), ymax(ymax), is_difficult(false) {} + + T xmin, ymin, xmax, ymax; + bool is_difficult; + }; + + inline T JaccardOverlap(const Box& box1, const Box& box2) const { + if (box2.xmin > box1.xmax || box2.xmax < box1.xmin || + box2.ymin > box1.ymax || box2.ymax < box1.ymin) { + return 0.0; } else { - map_data = map_out->mutable_data(ctx.GetPlace()); - map_data[0] = map; + T inter_xmin = std::max(box1.xmin, box2.xmin); + T inter_ymin = std::max(box1.ymin, box2.ymin); + T inter_xmax = std::min(box1.xmax, box2.xmax); + T inter_ymax = std::min(box1.ymax, box2.ymax); + + T inter_width = inter_xmax - inter_xmin; + T inter_height = inter_ymax - inter_ymin; + T inter_area = inter_width * inter_height; + + T bbox_area1 = (box1.xmax - box1.xmin) * (box1.ymax - box1.ymin); + T bbox_area2 = (box2.xmax - box2.xmin) * (box2.ymax - box2.ymin); + + return inter_area / (bbox_area1 + bbox_area2 - inter_area); } } - protected: - void GetBBoxes( - const framework::LoDTensor& input_label, - const framework::Tensor& input_detect, - std::vector>>>& - gt_bboxes, - std::vector< - std::map>>>>& - detect_bboxes) const { - const T* label_data = input_label.data(); - const T* detect_data = input_detect.data(); + void GetBoxes(const framework::LoDTensor& input_label, + const framework::LoDTensor& input_detect, + std::vector>>& gt_boxes, + std::vector>>>& + detect_boxes) const { + auto labels = framework::EigenTensor::From(input_label); + auto detect = framework::EigenTensor::From(input_detect); auto label_lod = input_label.lod(); - auto batch_size = label_lod[0].size() - 1; + auto detect_lod = input_detect.lod(); + + int batch_size = label_lod[0].size() - 1; auto label_index = label_lod[0]; - for (size_t n = 0; n < batch_size; ++n) { - std::map>> bboxes; + for (int n = 0; n < batch_size; ++n) { + std::map> boxes; for (int i = label_index[n]; i < label_index[n + 1]; ++i) { - std::vector> bbox; - math::GetBBoxFromLabelData(label_data + i * 6, 1, bbox); - int label = static_cast(label_data[i * 6]); - bboxes[label].push_back(bbox[0]); + Box box(labels(i, 2), labels(i, 3), labels(i, 4), labels(i, 5)); + int label = labels(i, 0); + auto is_difficult = labels(i, 1); + if (std::abs(is_difficult - 0.0) < 1e-6) + box.is_difficult = false; + else + box.is_difficult = true; + boxes[label].push_back(box); } - gt_bboxes.push_back(bboxes); + gt_boxes.push_back(boxes); } - size_t n = 0; - size_t detect_box_count = input_detect.dims()[0]; - for (size_t img_id = 0; img_id < batch_size; ++img_id) { - std::map>>> bboxes; - size_t cur_img_id = static_cast((detect_data + n * 7)[0]); - while (cur_img_id == img_id && n < detect_box_count) { - std::vector label; - std::vector score; - std::vector> bbox; - math::GetBBoxFromDetectData(detect_data + n * 7, 1, label, score, - bbox); - bboxes[label[0]].push_back(std::make_pair(score[0], bbox[0])); - ++n; - cur_img_id = static_cast((detect_data + n * 7)[0]); + auto detect_index = detect_lod[0]; + for (int n = 0; n < batch_size; ++n) { + std::map>> boxes; + for (int i = detect_index[n]; i < detect_index[n + 1]; ++i) { + Box box(detect(i, 2), detect(i, 3), detect(i, 4), detect(i, 5)); + int label = detect(i, 0); + auto score = detect(i, 1); + boxes[label].push_back(std::make_pair(score, box)); } - detect_bboxes.push_back(bboxes); + detect_boxes.push_back(boxes); } } void CalcTrueAndFalsePositive( - size_t batch_size, bool evaluate_difficult, float overlap_threshold, - const std::vector>>>& - gt_bboxes, - const std::vector< - std::map>>>>& - detect_bboxes, + const std::vector>>& gt_boxes, + const std::vector>>>& + detect_boxes, + bool evaluate_difficult, float overlap_threshold, std::map& label_pos_count, std::map>>& true_pos, std::map>>& false_pos) const { - for (size_t n = 0; n < batch_size; ++n) { - auto image_gt_bboxes = gt_bboxes[n]; - for (auto it = image_gt_bboxes.begin(); it != image_gt_bboxes.end(); - ++it) { + int batch_size = gt_boxes.size(); + for (int n = 0; n < batch_size; ++n) { + auto image_gt_boxes = gt_boxes[n]; + for (auto it = image_gt_boxes.begin(); it != image_gt_boxes.end(); ++it) { size_t count = 0; auto labeled_bboxes = it->second; if (evaluate_difficult) { @@ -179,16 +193,16 @@ class DetectionMAPOpKernel : public framework::OpKernel { } } - for (size_t n = 0; n < detect_bboxes.size(); ++n) { - auto image_gt_bboxes = gt_bboxes[n]; - auto detections = detect_bboxes[n]; + for (size_t n = 0; n < detect_boxes.size(); ++n) { + auto image_gt_boxes = gt_boxes[n]; + auto detections = detect_boxes[n]; - if (image_gt_bboxes.size() == 0) { + if (image_gt_boxes.size() == 0) { for (auto it = detections.begin(); it != detections.end(); ++it) { - auto pred_bboxes = it->second; + auto pred_boxes = it->second; int label = it->first; - for (size_t i = 0; i < pred_bboxes.size(); ++i) { - auto score = pred_bboxes[i].first; + for (size_t i = 0; i < pred_boxes.size(); ++i) { + auto score = pred_boxes[i].first; true_pos[label].push_back(std::make_pair(score, 0)); false_pos[label].push_back(std::make_pair(score, 1)); } @@ -198,28 +212,27 @@ class DetectionMAPOpKernel : public framework::OpKernel { for (auto it = detections.begin(); it != detections.end(); ++it) { int label = it->first; - auto pred_bboxes = it->second; - if (image_gt_bboxes.find(label) == image_gt_bboxes.end()) { - for (size_t i = 0; i < pred_bboxes.size(); ++i) { - auto score = pred_bboxes[i].first; + auto pred_boxes = it->second; + if (image_gt_boxes.find(label) == image_gt_boxes.end()) { + for (size_t i = 0; i < pred_boxes.size(); ++i) { + auto score = pred_boxes[i].first; true_pos[label].push_back(std::make_pair(score, 0)); false_pos[label].push_back(std::make_pair(score, 1)); } continue; } - auto matched_bboxes = image_gt_bboxes.find(label)->second; + auto matched_bboxes = image_gt_boxes.find(label)->second; std::vector visited(matched_bboxes.size(), false); // Sort detections in descend order based on scores - std::sort(pred_bboxes.begin(), pred_bboxes.end(), - math::SortScorePairDescend>); - for (size_t i = 0; i < pred_bboxes.size(); ++i) { - float max_overlap = -1.0; + std::sort(pred_boxes.begin(), pred_boxes.end(), + SortScorePairDescend); + for (size_t i = 0; i < pred_boxes.size(); ++i) { + T max_overlap = -1.0; size_t max_idx = 0; - auto score = pred_bboxes[i].first; + auto score = pred_boxes[i].first; for (size_t j = 0; j < matched_bboxes.size(); ++j) { - float overlap = - JaccardOverlap(pred_bboxes[i].second, matched_bboxes[j]); + T overlap = JaccardOverlap(pred_boxes[i].second, matched_bboxes[j]); if (overlap > max_overlap) { max_overlap = overlap; max_idx = j; @@ -249,7 +262,7 @@ class DetectionMAPOpKernel : public framework::OpKernel { } T CalcMAP( - std::string ap_type, const std::map& label_pos_count, + APType ap_type, const std::map& label_pos_count, const std::map>>& true_pos, const std::map>>& false_pos) const { T mAP = 0.0; @@ -266,18 +279,18 @@ class DetectionMAPOpKernel : public framework::OpKernel { GetAccumulation(label_true_pos, &tp_sum); std::vector fp_sum; GetAccumulation(label_false_pos, &fp_sum); - std::vector precision, recall; + std::vector precision, recall; size_t num = tp_sum.size(); // Compute Precision. for (size_t i = 0; i < num; ++i) { // CHECK_LE(tpCumSum[i], labelNumPos); - precision.push_back(static_cast(tp_sum[i]) / - static_cast(tp_sum[i] + fp_sum[i])); - recall.push_back(static_cast(tp_sum[i]) / label_num_pos); + precision.push_back(static_cast(tp_sum[i]) / + static_cast(tp_sum[i] + fp_sum[i])); + recall.push_back(static_cast(tp_sum[i]) / label_num_pos); } // VOC2007 style - if (ap_type == "11point") { - std::vector max_precisions(11, 0.0); + if (ap_type == APType::k11point) { + std::vector max_precisions(11, 0.0); int start_idx = num - 1; for (int j = 10; j >= 0; --j) for (int i = start_idx; i >= 0; --i) { @@ -292,7 +305,7 @@ class DetectionMAPOpKernel : public framework::OpKernel { } for (int j = 10; j >= 0; --j) mAP += max_precisions[j] / 11; ++count; - } else if (ap_type == "Integral") { + } else if (ap_type == APType::kIntegral) { // Nature integral float average_precisions = 0.; float prev_recall = 0.; diff --git a/paddle/operators/math/detection_util.cc b/paddle/operators/math/detection_util.cc deleted file mode 100644 index 4131a0cb0e..0000000000 --- a/paddle/operators/math/detection_util.cc +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/operators/math/detection_util.h" -#include "paddle/operators/math/math_function.h" - -namespace paddle { -namespace operators { -namespace math {} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/operators/math/detection_util.cu b/paddle/operators/math/detection_util.cu deleted file mode 100644 index d2bb992396..0000000000 --- a/paddle/operators/math/detection_util.cu +++ /dev/null @@ -1,23 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/operators/math/detection_util.h" -#include "paddle/operators/math/math_function.h" -#include "paddle/platform/cuda_helper.h" - -namespace paddle { -namespace operators { -namespace math {} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/operators/math/detection_util.h b/paddle/operators/math/detection_util.h deleted file mode 100644 index 2a4dadc545..0000000000 --- a/paddle/operators/math/detection_util.h +++ /dev/null @@ -1,128 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ -#pragma once -#include "paddle/framework/selected_rows.h" -#include "paddle/platform/device_context.h" - -namespace paddle { -namespace operators { -namespace math { - -template -struct BBox { - BBox(T x_min, T y_min, T x_max, T y_max) - : x_min(x_min), - y_min(y_min), - x_max(x_max), - y_max(y_max), - is_difficult(false) {} - - BBox() {} - - T get_width() const { return x_max - x_min; } - - T get_height() const { return y_max - y_min; } - - T get_center_x() const { return (x_min + x_max) / 2; } - - T get_center_y() const { return (y_min + y_max) / 2; } - - T get_area() const { return get_width() * get_height(); } - - // coordinate of bounding box - T x_min; - T y_min; - T x_max; - T y_max; - // whether difficult object (e.g. object with heavy occlusion is difficult) - bool is_difficult; -}; - -template -void GetBBoxFromDetectData(const T* detect_data, const size_t num_bboxes, - std::vector& labels, std::vector& scores, - std::vector>& bboxes) { - size_t out_offset = bboxes.size(); - labels.resize(out_offset + num_bboxes); - scores.resize(out_offset + num_bboxes); - bboxes.resize(out_offset + num_bboxes); - for (size_t i = 0; i < num_bboxes; ++i) { - labels[out_offset + i] = *(detect_data + i * 7 + 1); - scores[out_offset + i] = *(detect_data + i * 7 + 2); - BBox bbox; - bbox.x_min = *(detect_data + i * 7 + 3); - bbox.y_min = *(detect_data + i * 7 + 4); - bbox.x_max = *(detect_data + i * 7 + 5); - bbox.y_max = *(detect_data + i * 7 + 6); - bboxes[out_offset + i] = bbox; - }; -} - -template -void GetBBoxFromLabelData(const T* label_data, const size_t num_bboxes, - std::vector>& bboxes) { - size_t out_offset = bboxes.size(); - bboxes.resize(bboxes.size() + num_bboxes); - for (size_t i = 0; i < num_bboxes; ++i) { - BBox bbox; - bbox.x_min = *(label_data + i * 6 + 1); - bbox.y_min = *(label_data + i * 6 + 2); - bbox.x_max = *(label_data + i * 6 + 3); - bbox.y_max = *(label_data + i * 6 + 4); - T is_difficult = *(label_data + i * 6 + 5); - if (std::abs(is_difficult - 0.0) < 1e-6) - bbox.is_difficult = false; - else - bbox.is_difficult = true; - bboxes[out_offset + i] = bbox; - } -} - -template -inline float JaccardOverlap(const BBox& bbox1, const BBox& bbox2) { - if (bbox2.x_min > bbox1.x_max || bbox2.x_max < bbox1.x_min || - bbox2.y_min > bbox1.y_max || bbox2.y_max < bbox1.y_min) { - return 0.0; - } else { - float inter_x_min = std::max(bbox1.x_min, bbox2.x_min); - float inter_y_min = std::max(bbox1.y_min, bbox2.y_min); - float inter_x_max = std::min(bbox1.x_max, bbox2.x_max); - float inter_y_max = std::min(bbox1.y_max, bbox2.y_max); - - float inter_width = inter_x_max - inter_x_min; - float inter_height = inter_y_max - inter_y_min; - float inter_area = inter_width * inter_height; - - float bbox_area1 = bbox1.get_area(); - float bbox_area2 = bbox2.get_area(); - - return inter_area / (bbox_area1 + bbox_area2 - inter_area); - } -} - -template -bool SortScorePairDescend(const std::pair& pair1, - const std::pair& pair2) { - return pair1.first > pair2.first; -} - -// template <> -// bool SortScorePairDescend(const std::pair& pair1, -// const std::pair& pair2) { -// return pair1.first > pair2.first; -// } - -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/python/paddle/v2/fluid/tests/test_detection_map_op.py b/python/paddle/v2/fluid/tests/test_detection_map_op.py index 50ce3afbb9..bb545031ae 100644 --- a/python/paddle/v2/fluid/tests/test_detection_map_op.py +++ b/python/paddle/v2/fluid/tests/test_detection_map_op.py @@ -10,14 +10,14 @@ class TestDetectionMAPOp(OpTest): def set_data(self): self.init_test_case() - self.mAP = [self.calc_map(self.tf_pos)] + self.mAP = [self.calc_map(self.tf_pos, self.tf_pos_lod)] self.label = np.array(self.label).astype('float32') self.detect = np.array(self.detect).astype('float32') self.mAP = np.array(self.mAP).astype('float32') self.inputs = { 'Label': (self.label, self.label_lod), - 'Detect': self.detect + 'Detection': (self.detect, self.detect_lod) } self.attrs = { @@ -31,29 +31,29 @@ class TestDetectionMAPOp(OpTest): def init_test_case(self): self.overlap_threshold = 0.3 self.evaluate_difficult = True - self.ap_type = "Integral" + self.ap_type = "integral" self.label_lod = [[0, 2, 4]] - # label xmin ymin xmax ymax difficult - self.label = [[1, 0.1, 0.1, 0.3, 0.3, 0], [1, 0.6, 0.6, 0.8, 0.8, 1], - [2, 0.3, 0.3, 0.6, 0.5, 0], [1, 0.7, 0.1, 0.9, 0.3, 0]] + # label difficult xmin ymin xmax ymax + self.label = [[1, 0, 0.1, 0.1, 0.3, 0.3], [1, 1, 0.6, 0.6, 0.8, 0.8], + [2, 0, 0.3, 0.3, 0.6, 0.5], [1, 0, 0.7, 0.1, 0.9, 0.3]] - # image_id label score xmin ymin xmax ymax difficult + # label score xmin ymin xmax ymax difficult + self.detect_lod = [[0, 3, 7]] self.detect = [ - [0, 1, 0.3, 0.1, 0.0, 0.4, 0.3], [0, 1, 0.7, 0.0, 0.1, 0.2, 0.3], - [0, 1, 0.9, 0.7, 0.6, 0.8, 0.8], [1, 2, 0.8, 0.2, 0.1, 0.4, 0.4], - [1, 2, 0.1, 0.4, 0.3, 0.7, 0.5], [1, 1, 0.2, 0.8, 0.1, 1.0, 0.3], - [1, 3, 0.2, 0.8, 0.1, 1.0, 0.3] + [1, 0.3, 0.1, 0.0, 0.4, 0.3], [1, 0.7, 0.0, 0.1, 0.2, 0.3], + [1, 0.9, 0.7, 0.6, 0.8, 0.8], [2, 0.8, 0.2, 0.1, 0.4, 0.4], + [2, 0.1, 0.4, 0.3, 0.7, 0.5], [1, 0.2, 0.8, 0.1, 1.0, 0.3], + [3, 0.2, 0.8, 0.1, 1.0, 0.3] ] - # image_id label score false_pos false_pos - # [-1, 1, 3, -1, -1], - # [-1, 2, 1, -1, -1] - self.tf_pos = [[0, 1, 0.9, 1, 0], [0, 1, 0.7, 1, 0], [0, 1, 0.3, 0, 1], - [1, 1, 0.2, 1, 0], [1, 2, 0.8, 0, 1], [1, 2, 0.1, 1, 0], - [1, 3, 0.2, 0, 1]] + # label score true_pos false_pos + self.tf_pos_lod = [[0, 3, 7]] + self.tf_pos = [[1, 0.9, 1, 0], [1, 0.7, 1, 0], [1, 0.3, 0, 1], + [1, 0.2, 1, 0], [2, 0.8, 0, 1], [2, 0.1, 1, 0], + [3, 0.2, 0, 1]] - def calc_map(self, tf_pos): + def calc_map(self, tf_pos, tf_pos_lod): mAP = 0.0 count = 0 @@ -71,7 +71,7 @@ class TestDetectionMAPOp(OpTest): return accu_list label_count = collections.Counter() - for (label, xmin, ymin, xmax, ymax, difficult) in self.label: + for (label, difficult, xmin, ymin, xmax, ymax) in self.label: if self.evaluate_difficult: label_count[label] += 1 elif not difficult: @@ -79,7 +79,7 @@ class TestDetectionMAPOp(OpTest): true_pos = collections.defaultdict(list) false_pos = collections.defaultdict(list) - for (image_id, label, score, tp, fp) in tf_pos: + for (label, score, tp, fp) in tf_pos: true_pos[label].append([score, tp]) false_pos[label].append([score, fp]) @@ -103,22 +103,22 @@ class TestDetectionMAPOp(OpTest): recall.append(float(accu_tp_sum[i]) / label_pos_num) if self.ap_type == "11point": - max_precisions = [11.0, 0.0] + max_precisions = [0.0] * 11 start_idx = len(accu_tp_sum) - 1 - for j in range(10, 0, -1): - for i in range(start_idx, 0, -1): - if recall[i] < j / 10.0: + for j in range(10, -1, -1): + for i in range(start_idx, -1, -1): + if recall[i] < float(j) / 10.0: start_idx = i if j > 0: max_precisions[j - 1] = max_precisions[j] break - else: - if max_precisions[j] < accu_precision[i]: - max_precisions[j] = accu_precision[i] - for j in range(10, 0, -1): + else: + if max_precisions[j] < precision[i]: + max_precisions[j] = precision[i] + for j in range(10, -1, -1): mAP += max_precisions[j] / 11 count += 1 - elif self.ap_type == "Integral": + elif self.ap_type == "integral": average_precisions = 0.0 prev_recall = 0.0 for i in range(len(accu_tp_sum)): @@ -147,8 +147,17 @@ class TestDetectionMAPOpSkipDiff(TestDetectionMAPOp): self.evaluate_difficult = False - self.tf_pos = [[0, 1, 0.7, 1, 0], [0, 1, 0.3, 0, 1], [1, 1, 0.2, 1, 0], - [1, 2, 0.8, 0, 1], [1, 2, 0.1, 1, 0], [1, 3, 0.2, 0, 1]] + self.tf_pos_lod = [[0, 2, 6]] + # label score true_pos false_pos + self.tf_pos = [[1, 0.7, 1, 0], [1, 0.3, 0, 1], [1, 0.2, 1, 0], + [2, 0.8, 0, 1], [2, 0.1, 1, 0], [3, 0.2, 0, 1]] + + +class TestDetectionMAPOp11Point(TestDetectionMAPOp): + def init_test_case(self): + super(TestDetectionMAPOp11Point, self).init_test_case() + + self.ap_type = "11point" if __name__ == '__main__': From dd6b59da6beca7ee66ad86ae7899a63a8cf57a6e Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Thu, 8 Feb 2018 13:56:41 +0800 Subject: [PATCH 03/43] add Python interface of prior_boxes --- python/paddle/v2/fluid/layers/nn.py | 152 +++++++++++++++++++++++++++- 1 file changed, 151 insertions(+), 1 deletion(-) diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index a79479f469..891d89a24b 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -14,7 +14,7 @@ """ All layers just related to the neural network. """ - +import math from ..layer_helper import LayerHelper from ..initializer import Normal, Constant from ..framework import Variable @@ -65,6 +65,7 @@ __all__ = [ 'beam_search', 'row_conv', 'multiplex', + 'prior_boxes', ] @@ -2993,3 +2994,152 @@ def multiplex(inputs, index): 'Ids': index}, outputs={'Out': [out]}) return out + + +def prior_box(input, + image, + min_sizes, + max_sizes, + aspect_ratios, + variance, + flip, + clip, + step_w, + step_h, + offset, + name=None): + """ + **Prior_box** + + """ + helper = LayerHelper("prior_box", **locals()) + dtype = helper.input_dtype() + + box = helper.create_tmp_variable(dtype) + var = helper.create_tmp_variable(dtype) + helper.append_op( + type="prior_box", + inputs={"Input": input, + "Image": image}, + outputs={"Boxes": box, + "Variances": var}, + attrs={ + 'min_sizes': min_sizes, + 'max_sizes': max_sizes, + 'aspect_ratios': aspect_ratios, + 'variances': variance, + 'flip': flip, + 'clip': clip, + 'step_w': step_w, + 'step_h': step_h, + 'offset': offset + }) + return box, var + + +def prior_boxes(input_layers, + image, + min_ratio, + max_ratio, + steps, + aspect_ratios, + min_dim, + step_w=None, + step_h=None, + offset=0.5, + variance=[0.1], + flip=True, + clip=True, + name=None): + """ + **Prior_boxes** + e.g. + prior_boxes( + input_layers = [conv1, conv2, conv3, conv4, conv5, conv6], + image = data, + min_ratio = 0.2, + max_ratio = 0.9, + steps = [8, 16, 32, 64, 100, 300], + aspect_ratios = [[2], [2, 3], [2, 3], [2, 3], [2], [2]], + min_dim = 300, + offset = 0.5, + variance = [0.1], + flip=True, + clip=True) + """ + assert isinstance(input_layers, list), 'input_layer should be a list.' + assert not step_h and not steps, '' + assert not step_w and not steps, '' + + num_layer = len(input_layers) + assert num_layer > 2 # TODO(zcd): currently, num_layer must be bigger than two. + + min_sizes = [] + max_sizes = [] + if num_layer > 2: + step = int(math.floor((max_ratio - min_ratio) / (num_layer - 2))) + for ratio in xrange(min_ratio, max_ratio + 1, step): + min_sizes.append(min_dim * ratio) + max_sizes.append(min_dim * (ratio + step)) + min_sizes = [min_dim * .10] + min_sizes + max_sizes = [min_dim * .20] + max_sizes + + if step_h: + assert isinstance(step_h,list) and len(step_h) == num_layer, \ + 'step_h should be list and input_layers and step_h should have same length' + if step_w: + assert isinstance(step_w,list) and len(step_w) == num_layer, \ + 'step_w should be list and input_layers and step_w should have same length' + if steps: + assert isinstance(steps,list) and len(step_w) == num_layer, \ + 'steps should be list and input_layers and step_w should have same length' + step_w = steps + step_h = steps + if aspect_ratios: + assert isinstance(aspect_ratios, list) and len(aspect_ratios) == num_layer, \ + 'aspect_ratios should be list and input_layers and aspect_ratios should ' \ + 'have same length' + + helper = LayerHelper("prior_box", **locals()) + dtype = helper.input_dtype() + + box_results = [] + var_results = [] + for i, input in enumerate(input_layers): + min_size = min_sizes[i] + max_size = max_sizes[i] + if isinstance(min_size, list): + min_size = [min_size] + if isinstance(max_size, list): + max_size = [max_size] + if aspect_ratios: + aspect_ratio = aspect_ratios[i] + if isinstance(aspect_ratio, list): + aspect_ratio = [aspect_ratio] + + box, var = prior_box(input, image, min_size, max_size, aspect_ratios, + variance, flip, clip, step_w[i], step_h[i], offset) + + box_results.append(box) + var_results.append(var) + + if len(box_results) == 1: + box = box_results[0] + var = var_results[0] + else: + axis = 1 + box = helper.create_tmp_variable(dtype) + helper.append_op( + type="concat", + inputs={"X": box_results}, + outputs={"Out": box}, + attrs={'axis': axis}) + + var = helper.create_tmp_variable(dtype) + helper.append_op( + type="concat", + inputs={"X": var_results}, + outputs={"Out": var}, + attrs={'axis': axis}) + + return box, var From 5ca0b7628d90098298604ecf4f62d4845db99b7d Mon Sep 17 00:00:00 2001 From: wanghaox Date: Thu, 8 Feb 2018 17:43:32 +0800 Subject: [PATCH 04/43] add OutPosCount for detection_map op --- paddle/operators/detection_map_op.cc | 47 ++++++- paddle/operators/detection_map_op.h | 132 +++++++++++++++++- .../v2/fluid/tests/test_detection_map_op.py | 111 +++++++++++++-- 3 files changed, 271 insertions(+), 19 deletions(-) diff --git a/paddle/operators/detection_map_op.cc b/paddle/operators/detection_map_op.cc index 553adb215d..1ab691eb4f 100644 --- a/paddle/operators/detection_map_op.cc +++ b/paddle/operators/detection_map_op.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -28,6 +28,12 @@ class DetectionMAPOp : public framework::OperatorWithKernel { "Input(Detection) of DetectionMAPOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) of DetectionMAPOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("OutPosCount"), + "Output(OutPosCount) of DetectionMAPOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("OutTruePos"), + "Output(OutTruePos) of DetectionMAPOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("OutFalsePos"), + "Output(OutFalsePos) of DetectionMAPOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("MAP"), "Output(MAP) of DetectionMAPOp should not be null."); @@ -44,9 +50,6 @@ class DetectionMAPOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(label_dims[1], 6UL, "The shape is of Input(Label) [N, 6]."); - auto ap_type = GetAPType(ctx->Attrs().Get("ap_type")); - PADDLE_ENFORCE_NE(ap_type, APType::kNone, - "The ap_type should be 'integral' or '11point."); auto map_dim = framework::make_ddim({1}); ctx->SetOutputDim("MAP", map_dim); } @@ -55,7 +58,8 @@ class DetectionMAPOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - framework::ToDataType(ctx.Input("Label")->type()), + framework::ToDataType( + ctx.Input("Detection")->type()), ctx.device_context()); } }; @@ -80,6 +84,33 @@ class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker { "the offsets in first dimension are called LoD, the number of " "offset is N + 1, if LoD[i + 1] - LoD[i] == 0, means there is " "no detected data."); + AddInput("PosCount", + "(Tensor) A tensor with shape [Ncls, 1], store the " + "input positive example count of each class.") + .AsDispensable(); + AddInput("TruePos", + "(LodTensor) A 2-D LodTensor with shape [Ntp, 2], store the " + "input true positive example of each class.") + .AsDispensable(); + AddInput("FalsePos", + "(LodTensor) A 2-D LodTensor with shape [Nfp, 2], store the " + "input false positive example of each class.") + .AsDispensable(); + AddOutput("OutPosCount", + "(Tensor) A tensor with shape [Ncls, 1], store the " + "positive example count of each class. It combines the input " + "input(PosCount) and the positive example count computed from " + "input(Detection) and input(Label)."); + AddOutput("OutTruePos", + "(LodTensor) A LodTensor with shape [Ntp', 2], store the " + "true positive example of each class. It combines the " + "input(TruePos) and the true positive examples computed from " + "input(Detection) and input(Label)."); + AddOutput("OutFalsePos", + "(LodTensor) A LodTensor with shape [Nfp', 2], store the " + "false positive example of each class. It combines the " + "input(FalsePos) and the false positive examples computed from " + "input(Detection) and input(Label)."); AddOutput("MAP", "(Tensor) A tensor with shape [1], store the mAP evaluate " "result of the detection."); @@ -97,7 +128,11 @@ class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker { "(string, default 'integral') " "The AP algorithm type, 'integral' or '11point'.") .SetDefault("integral") - .InEnum({"integral", "11point"}); + .InEnum({"integral", "11point"}) + .AddCustomChecker([](const std::string& ap_type) { + PADDLE_ENFORCE_NE(GetAPType(ap_type), APType::kNone, + "The ap_type should be 'integral' or '11point."); + }); AddComment(R"DOC( Detection mAP evaluate operator. The general steps are as follows. First, calculate the true positive and diff --git a/paddle/operators/detection_map_op.h b/paddle/operators/detection_map_op.h index d29a6968e4..fd0ddd10aa 100644 --- a/paddle/operators/detection_map_op.h +++ b/paddle/operators/detection_map_op.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -58,6 +58,14 @@ class DetectionMAPOpKernel : public framework::OpKernel { auto* in_label = ctx.Input("Label"); auto* out_map = ctx.Output("MAP"); + auto* in_pos_count = ctx.Input("PosCount"); + auto* in_true_pos = ctx.Input("TruePos"); + auto* in_false_pos = ctx.Input("FalsePos"); + + auto* out_pos_count = ctx.Output("OutPosCount"); + auto* out_true_pos = ctx.Output("OutTruePos"); + auto* out_false_pos = ctx.Output("OutFalsePos"); + float overlap_threshold = ctx.Attr("overlap_threshold"); float evaluate_difficult = ctx.Attr("evaluate_difficult"); auto ap_type = GetAPType(ctx.Attr("ap_type")); @@ -79,12 +87,20 @@ class DetectionMAPOpKernel : public framework::OpKernel { std::map>> true_pos; std::map>> false_pos; + if (in_pos_count != nullptr) { + GetInputPos(*in_pos_count, *in_true_pos, *in_false_pos, label_pos_count, + true_pos, false_pos); + } + CalcTrueAndFalsePositive(gt_boxes, detect_boxes, evaluate_difficult, overlap_threshold, label_pos_count, true_pos, false_pos); T map = CalcMAP(ap_type, label_pos_count, true_pos, false_pos); + GetOutputPos(ctx, label_pos_count, true_pos, false_pos, *out_pos_count, + *out_true_pos, *out_false_pos); + T* map_data = out_map->mutable_data(ctx.GetPlace()); map_data[0] = map; } @@ -161,6 +177,119 @@ class DetectionMAPOpKernel : public framework::OpKernel { } } + void GetOutputPos( + const framework::ExecutionContext& ctx, + const std::map& label_pos_count, + const std::map>>& true_pos, + const std::map>>& false_pos, + framework::Tensor& output_pos_count, + framework::LoDTensor& output_true_pos, + framework::LoDTensor& output_false_pos) const { + int max_class_id = 0; + int true_pos_count = 0; + int false_pos_count = 0; + for (auto it = label_pos_count.begin(); it != label_pos_count.end(); ++it) { + int label = it->first; + if (label > max_class_id) max_class_id = label; + int label_num_pos = it->second; + if (label_num_pos == 0 || true_pos.find(label) == true_pos.end()) + continue; + auto label_true_pos = true_pos.find(label)->second; + auto label_false_pos = false_pos.find(label)->second; + true_pos_count += label_true_pos.size(); + false_pos_count += label_false_pos.size(); + } + + int* pos_count_data = output_pos_count.mutable_data( + framework::make_ddim({max_class_id + 1, 1}), ctx.GetPlace()); + T* true_pos_data = output_true_pos.mutable_data( + framework::make_ddim({true_pos_count, 2}), ctx.GetPlace()); + T* false_pos_data = output_false_pos.mutable_data( + framework::make_ddim({false_pos_count, 2}), ctx.GetPlace()); + true_pos_count = 0; + false_pos_count = 0; + std::vector true_pos_starts = {0}; + std::vector false_pos_starts = {0}; + for (int i = 0; i <= max_class_id; ++i) { + auto it_count = label_pos_count.find(i); + pos_count_data[i] = 0; + if (it_count != label_pos_count.end()) { + pos_count_data[i] = it_count->second; + } + auto it_true_pos = true_pos.find(i); + if (it_true_pos != true_pos.end()) { + const std::vector>& true_pos_vec = + it_true_pos->second; + for (const std::pair& tp : true_pos_vec) { + true_pos_data[true_pos_count * 2] = tp.first; + true_pos_data[true_pos_count * 2 + 1] = static_cast(tp.second); + true_pos_count++; + } + } + true_pos_starts.push_back(true_pos_count); + + auto it_false_pos = false_pos.find(i); + if (it_false_pos != false_pos.end()) { + const std::vector>& false_pos_vec = + it_false_pos->second; + for (const std::pair& fp : false_pos_vec) { + false_pos_data[false_pos_count * 2] = fp.first; + false_pos_data[false_pos_count * 2 + 1] = static_cast(fp.second); + false_pos_count++; + } + } + false_pos_starts.push_back(false_pos_count); + } + + framework::LoD true_pos_lod; + true_pos_lod.emplace_back(true_pos_starts); + framework::LoD false_pos_lod; + false_pos_lod.emplace_back(false_pos_starts); + + output_true_pos.set_lod(true_pos_lod); + output_false_pos.set_lod(false_pos_lod); + return; + } + + void GetInputPos( + const framework::Tensor& input_pos_count, + const framework::LoDTensor& input_true_pos, + const framework::LoDTensor& input_false_pos, + std::map& label_pos_count, + std::map>>& true_pos, + std::map>>& false_pos) const { + constexpr T kEPS = static_cast(1e-6); + int class_number = input_pos_count.dims()[0]; + const int* pos_count_data = input_pos_count.data(); + for (int i = 0; i < class_number; ++i) { + label_pos_count[i] = pos_count_data[i]; + } + + const T* true_pos_data = input_true_pos.data(); + auto true_pos_data_lod = input_true_pos.lod(); + for (int i = 0; i < true_pos_data_lod.size(); ++i) { + for (int j = true_pos_data_lod[0][i]; j < true_pos_data_lod[0][i + 1]; + ++j) { + T score = true_pos_data[j * 2]; + int flag = 1; + if (true_pos_data[j * 2 + 1] < kEPS) flag = 0; + true_pos[i].push_back(std::make_pair(score, flag)); + } + } + const T* false_pos_data = input_false_pos.data(); + auto false_pos_data_lod = input_false_pos.lod(); + for (int i = 0; i < false_pos_data_lod.size(); ++i) { + for (int j = false_pos_data_lod[0][i]; j < false_pos_data_lod[0][i + 1]; + ++j) { + T score = false_pos_data[j * 2]; + int flag = 1; + if (false_pos_data[j * 2 + 1] < kEPS) flag = 0; + false_pos[i].push_back(std::make_pair(score, flag)); + } + } + return; + } + void CalcTrueAndFalsePositive( const std::vector>>& gt_boxes, const std::vector>>>& @@ -283,7 +412,6 @@ class DetectionMAPOpKernel : public framework::OpKernel { size_t num = tp_sum.size(); // Compute Precision. for (size_t i = 0; i < num; ++i) { - // CHECK_LE(tpCumSum[i], labelNumPos); precision.push_back(static_cast(tp_sum[i]) / static_cast(tp_sum[i] + fp_sum[i])); recall.push_back(static_cast(tp_sum[i]) / label_num_pos); diff --git a/python/paddle/v2/fluid/tests/test_detection_map_op.py b/python/paddle/v2/fluid/tests/test_detection_map_op.py index db8012334a..ec57ca4ad5 100644 --- a/python/paddle/v2/fluid/tests/test_detection_map_op.py +++ b/python/paddle/v2/fluid/tests/test_detection_map_op.py @@ -29,10 +29,24 @@ class TestDetectionMAPOp(OpTest): self.detect = np.array(self.detect).astype('float32') self.mAP = np.array(self.mAP).astype('float32') - self.inputs = { - 'Label': (self.label, self.label_lod), - 'Detection': (self.detect, self.detect_lod) - } + if (len(self.class_pos_count) > 0): + self.class_pos_count = np.array(self.class_pos_count).astype( + 'int32') + self.true_pos = np.array(self.true_pos).astype('float32') + self.false_pos = np.array(self.false_pos).astype('float32') + + self.inputs = { + 'Label': (self.label, self.label_lod), + 'Detection': (self.detect, self.detect_lod), + 'PosCount': self.class_pos_count, + 'TruePos': (self.true_pos, self.true_pos_lod), + 'FalsePos': (self.false_pos, self.false_pos_lod) + } + else: + self.inputs = { + 'Label': (self.label, self.label_lod), + 'Detection': (self.detect, self.detect_lod), + } self.attrs = { 'overlap_threshold': self.overlap_threshold, @@ -40,7 +54,17 @@ class TestDetectionMAPOp(OpTest): 'ap_type': self.ap_type } - self.outputs = {'MAP': self.mAP} + self.out_class_pos_count = np.array(self.out_class_pos_count).astype( + 'int') + self.out_true_pos = np.array(self.out_true_pos).astype('float32') + self.out_false_pos = np.array(self.out_false_pos).astype('float32') + + self.outputs = { + 'MAP': self.mAP, + 'OutPosCount': self.out_class_pos_count, + 'OutTruePos': (self.out_true_pos, self.out_true_pos_lod), + 'OutFalsePos': (self.out_false_pos, self.out_false_pos_lod) + } def init_test_case(self): self.overlap_threshold = 0.3 @@ -67,13 +91,64 @@ class TestDetectionMAPOp(OpTest): [1, 0.2, 1, 0], [2, 0.8, 0, 1], [2, 0.1, 1, 0], [3, 0.2, 0, 1]] + self.class_pos_count = [] + self.true_pos_lod = [[]] + self.true_pos = [[]] + self.false_pos_lod = [[]] + self.false_pos = [[]] + def calc_map(self, tf_pos, tf_pos_lod): mAP = 0.0 count = 0 - class_pos_count = {} - true_pos = {} - false_pos = {} + def get_input_pos(class_pos_count, true_pos, true_pos_lod, false_pos, + false_pos_lod): + class_pos_count_dict = collections.Counter() + true_pos_dict = collections.defaultdict(list) + false_pos_dict = collections.defaultdict(list) + for i, count in enumerate(class_pos_count): + class_pos_count_dict[i] = count + + for i in range(len(true_pos_lod[0]) - 1): + start = true_pos_lod[0][i] + end = true_pos_lod[0][i + 1] + for j in range(start, end): + true_pos_dict[i].append(true_pos[j]) + + for i in range(len(false_pos_lod[0]) - 1): + start = false_pos_lod[0][i] + end = false_pos_lod[0][i + 1] + for j in range(start, end): + false_pos_dict[i].append(false_pos[j]) + + return class_pos_count_dict, true_pos_dict, false_pos_dict + + def get_output_pos(label_count, true_pos, false_pos): + max_label = 0 + for (label, label_pos_num) in label_count.items(): + if max_label < label: + max_label = label + + label_number = max_label + 1 + + out_class_pos_count = [] + out_true_pos_lod = [0] + out_true_pos = [] + out_false_pos_lod = [0] + out_false_pos = [] + + for i in range(label_number): + out_class_pos_count.append([label_count[i]]) + true_pos_list = true_pos[i] + out_true_pos += true_pos_list + out_true_pos_lod.append(len(out_true_pos)) + false_pos_list = false_pos[i] + out_false_pos += false_pos_list + out_false_pos_lod.append(len(out_false_pos)) + + return out_class_pos_count, out_true_pos, [ + out_true_pos_lod + ], out_false_pos, [out_false_pos_lod] def get_accumulation(pos_list): sorted_list = sorted(pos_list, key=lambda pos: pos[0], reverse=True) @@ -84,7 +159,9 @@ class TestDetectionMAPOp(OpTest): accu_list.append(sum) return accu_list - label_count = collections.Counter() + label_count, true_pos, false_pos = get_input_pos( + self.class_pos_count, self.true_pos, self.true_pos_lod, + self.false_pos, self.false_pos_lod) for (label, difficult, xmin, ymin, xmax, ymax) in self.label: if self.evaluate_difficult: label_count[label] += 1 @@ -143,8 +220,10 @@ class TestDetectionMAPOp(OpTest): mAP += average_precisions count += 1 - - if count != 0: mAP /= count + self.out_class_pos_count, self.out_true_pos, self.out_true_pos_lod, self.out_false_pos, self.out_false_pos_lod = get_output_pos( + label_count, true_pos, false_pos) + if count != 0: + mAP /= count return mAP * 100.0 def setUp(self): @@ -174,5 +253,15 @@ class TestDetectionMAPOp11Point(TestDetectionMAPOp): self.ap_type = "11point" +class TestDetectionMAPOpMultiBatch(TestDetectionMAPOp): + def init_test_case(self): + super(TestDetectionMAPOpMultiBatch, self).init_test_case() + self.class_pos_count = [0, 2, 1] + self.true_pos_lod = [[0, 0, 3, 5]] + self.true_pos = [[0.7, 1.], [0.3, 0.], [0.2, 1.], [0.8, 0.], [0.1, 1.]] + self.false_pos_lod = [[0, 0, 3, 5]] + self.false_pos = [[0.7, 0.], [0.3, 1.], [0.2, 0.], [0.8, 1.], [0.1, 0.]] + + if __name__ == '__main__': unittest.main() From 19749d52348669cbf2cd000a67b2ffe790384e8c Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Thu, 8 Feb 2018 16:01:41 +0800 Subject: [PATCH 05/43] refine prior_box --- paddle/operators/prior_box_op.cc | 20 ++-- paddle/operators/prior_box_op.h | 8 +- python/paddle/v2/fluid/layers/nn.py | 94 ++++++++++++++----- .../v2/fluid/tests/test_prior_box_op.py | 4 +- 4 files changed, 87 insertions(+), 39 deletions(-) diff --git a/paddle/operators/prior_box_op.cc b/paddle/operators/prior_box_op.cc index 1dc4b28855..b7f38b3cb6 100644 --- a/paddle/operators/prior_box_op.cc +++ b/paddle/operators/prior_box_op.cc @@ -38,8 +38,8 @@ class PriorBoxOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_LT(input_dims[3], image_dims[3], "The width of input must smaller than image."); - auto min_sizes = ctx->Attrs().Get>("min_sizes"); - auto max_sizes = ctx->Attrs().Get>("max_sizes"); + auto min_sizes = ctx->Attrs().Get>("min_sizes"); + auto max_sizes = ctx->Attrs().Get>("max_sizes"); auto variances = ctx->Attrs().Get>("variances"); auto aspect_ratios = ctx->Attrs().Get>("aspect_ratios"); bool flip = ctx->Attrs().Get("flip"); @@ -47,7 +47,7 @@ class PriorBoxOp : public framework::OperatorWithKernel { std::vector aspect_ratios_vec; ExpandAspectRatios(aspect_ratios, flip, aspect_ratios_vec); - int num_priors = aspect_ratios_vec.size() * min_sizes.size(); + size_t num_priors = aspect_ratios_vec.size() * min_sizes.size(); if (max_sizes.size() > 0) { PADDLE_ENFORCE_EQ(max_sizes.size(), min_sizes.size(), "The number of min_size and max_size must be equal."); @@ -90,20 +90,20 @@ class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker { "H is the height of input, W is the width of input, num_priors " "is the box count of each position."); - AddAttr>("min_sizes", - "(vector) List of min sizes " - "of generated prior boxes.") - .AddCustomChecker([](const std::vector& min_sizes) { + AddAttr>("min_sizes", + "(vector) List of min sizes " + "of generated prior boxes.") + .AddCustomChecker([](const std::vector& min_sizes) { PADDLE_ENFORCE_GT(min_sizes.size(), 0, "Size of min_sizes must be at least 1."); for (size_t i = 0; i < min_sizes.size(); ++i) { - PADDLE_ENFORCE_GT(min_sizes[i], 0, + PADDLE_ENFORCE_GT(min_sizes[i], 0.0, "min_sizes[%d] must be positive.", i); } }); - AddAttr>( + AddAttr>( "max_sizes", - "(vector) List of max sizes of generated prior boxes."); + "(vector) List of max sizes of generated prior boxes."); AddAttr>( "aspect_ratios", "(vector) List of aspect ratios of generated prior boxes."); diff --git a/paddle/operators/prior_box_op.h b/paddle/operators/prior_box_op.h index 6b221cb74e..d8ff5d19eb 100644 --- a/paddle/operators/prior_box_op.h +++ b/paddle/operators/prior_box_op.h @@ -60,8 +60,8 @@ class PriorBoxOpKernel : public framework::OpKernel { auto* boxes = ctx.Output("Boxes"); auto* vars = ctx.Output("Variances"); - auto min_sizes = ctx.Attr>("min_sizes"); - auto max_sizes = ctx.Attr>("max_sizes"); + auto min_sizes = ctx.Attr>("min_sizes"); + auto max_sizes = ctx.Attr>("max_sizes"); auto input_aspect_ratio = ctx.Attr>("aspect_ratios"); auto variances = ctx.Attr>("variances"); auto flip = ctx.Attr("flip"); @@ -108,7 +108,7 @@ class PriorBoxOpKernel : public framework::OpKernel { T box_width, box_height; int idx = 0; for (size_t s = 0; s < min_sizes.size(); ++s) { - int min_size = min_sizes[s]; + auto min_size = min_sizes[s]; // first prior: aspect_ratio = 1, size = min_size box_width = box_height = min_size; // xmin @@ -124,7 +124,7 @@ class PriorBoxOpKernel : public framework::OpKernel { idx++; if (max_sizes.size() > 0) { - int max_size = max_sizes[s]; + auto max_size = max_sizes[s]; // second prior: aspect_ratio = 1, // size = sqrt(min_size * max_size) box_width = box_height = sqrt(min_size * max_size); diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index 891d89a24b..dc1839fd82 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -14,13 +14,16 @@ """ All layers just related to the neural network. """ -import math + from ..layer_helper import LayerHelper from ..initializer import Normal, Constant from ..framework import Variable from ..param_attr import ParamAttr from layer_function_generator import autodoc from tensor import concat +import math +import numpy as np +from operator import mul __all__ = [ 'fc', @@ -64,7 +67,10 @@ __all__ = [ 'nce', 'beam_search', 'row_conv', + 'reshape', + 'reshape_with_axis', 'multiplex', + 'prior_box' 'prior_boxes', ] @@ -2996,6 +3002,40 @@ def multiplex(inputs, index): return out +def reshape_with_axis(input, axis): + """ + **ReshapeWithAxis Layer** + + """ + assert len(input.shape) > axis and axis >= 0, ' ' + input_shape = input.shape + new_dim = [-1, reduce(mul, input_shape[axis:len(input_shape)], 1)] + + helper = LayerHelper('reshape', **locals()) + out = helper.create_tmp_variable(helper.input_dtype()) + helper.append_op( + type='reshape', + inputs={'X': [input]}, + outputs={'Out': [out]}, + attrs={'shape': new_dim}) + return out + + +def reshape(input, new_dim): + """ + **Reshape Layer** + + """ + helper = LayerHelper('reshape', **locals()) + out = helper.create_tmp_variable(helper.input_dtype()) + helper.append_op( + type='reshape', + inputs={'X': [input]}, + outputs={'Out': [out]}, + attrs={'shape': new_dim}) + return out + + def prior_box(input, image, min_sizes, @@ -3041,13 +3081,13 @@ def prior_boxes(input_layers, image, min_ratio, max_ratio, - steps, aspect_ratios, min_dim, + steps=None, step_w=None, step_h=None, offset=0.5, - variance=[0.1], + variance=[0.1, 0.1, 0.1, 0.1], flip=True, clip=True, name=None): @@ -3059,8 +3099,8 @@ def prior_boxes(input_layers, image = data, min_ratio = 0.2, max_ratio = 0.9, - steps = [8, 16, 32, 64, 100, 300], - aspect_ratios = [[2], [2, 3], [2, 3], [2, 3], [2], [2]], + steps = [8., 16., 32., 64., 100., 300.], + aspect_ratios = [[2.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]], min_dim = 300, offset = 0.5, variance = [0.1], @@ -3068,19 +3108,16 @@ def prior_boxes(input_layers, clip=True) """ assert isinstance(input_layers, list), 'input_layer should be a list.' - assert not step_h and not steps, '' - assert not step_w and not steps, '' - num_layer = len(input_layers) assert num_layer > 2 # TODO(zcd): currently, num_layer must be bigger than two. min_sizes = [] max_sizes = [] if num_layer > 2: - step = int(math.floor((max_ratio - min_ratio) / (num_layer - 2))) + step = int(math.floor(((max_ratio - min_ratio)) / (num_layer - 2))) for ratio in xrange(min_ratio, max_ratio + 1, step): - min_sizes.append(min_dim * ratio) - max_sizes.append(min_dim * (ratio + step)) + min_sizes.append(min_dim * ratio / 100.) + max_sizes.append(min_dim * (ratio + step) / 100.) min_sizes = [min_dim * .10] + min_sizes max_sizes = [min_dim * .20] + max_sizes @@ -3091,7 +3128,7 @@ def prior_boxes(input_layers, assert isinstance(step_w,list) and len(step_w) == num_layer, \ 'step_w should be list and input_layers and step_w should have same length' if steps: - assert isinstance(steps,list) and len(step_w) == num_layer, \ + assert isinstance(steps,list) and len(steps) == num_layer, \ 'steps should be list and input_layers and step_w should have same length' step_w = steps step_h = steps @@ -3100,25 +3137,25 @@ def prior_boxes(input_layers, 'aspect_ratios should be list and input_layers and aspect_ratios should ' \ 'have same length' - helper = LayerHelper("prior_box", **locals()) - dtype = helper.input_dtype() - box_results = [] var_results = [] for i, input in enumerate(input_layers): min_size = min_sizes[i] max_size = max_sizes[i] - if isinstance(min_size, list): + aspect_ratio = [] + if not isinstance(min_size, list): min_size = [min_size] - if isinstance(max_size, list): + if not isinstance(max_size, list): max_size = [max_size] if aspect_ratios: aspect_ratio = aspect_ratios[i] - if isinstance(aspect_ratio, list): + if not isinstance(aspect_ratio, list): aspect_ratio = [aspect_ratio] - box, var = prior_box(input, image, min_size, max_size, aspect_ratios, - variance, flip, clip, step_w[i], step_h[i], offset) + box, var = prior_box(input, image, min_size, max_size, aspect_ratio, + variance, flip, clip, step_w[i] + if step_w else [], step_h[i] + if step_w else [], offset) box_results.append(box) var_results.append(var) @@ -3127,18 +3164,29 @@ def prior_boxes(input_layers, box = box_results[0] var = var_results[0] else: - axis = 1 + axis = 3 + reshaped_boxes = [] + reshaped_vars = [] + for i in range(len(box_results)): + reshaped_boxes += [reshape_with_axis(box_results[i], axis=axis)] + reshaped_vars += [reshape_with_axis(var_results[i], axis=axis)] + + helper = LayerHelper("concat", **locals()) + dtype = helper.input_dtype() box = helper.create_tmp_variable(dtype) + var = helper.create_tmp_variable(dtype) + + axis = 0 helper.append_op( type="concat", - inputs={"X": box_results}, + inputs={"X": reshaped_boxes}, outputs={"Out": box}, attrs={'axis': axis}) var = helper.create_tmp_variable(dtype) helper.append_op( type="concat", - inputs={"X": var_results}, + inputs={"X": reshaped_vars}, outputs={"Out": var}, attrs={'axis': axis}) diff --git a/python/paddle/v2/fluid/tests/test_prior_box_op.py b/python/paddle/v2/fluid/tests/test_prior_box_op.py index ca8d2bca74..25dfc4307c 100644 --- a/python/paddle/v2/fluid/tests/test_prior_box_op.py +++ b/python/paddle/v2/fluid/tests/test_prior_box_op.py @@ -65,9 +65,9 @@ class TestPriorBoxOp(OpTest): self.batch_size = 10 self.min_sizes = [2, 4] - self.min_sizes = np.array(self.min_sizes).astype('int64') + self.min_sizes = np.array(self.min_sizes).astype('float32') self.max_sizes = [5, 10] - self.max_sizes = np.array(self.max_sizes).astype('int64') + self.max_sizes = np.array(self.max_sizes).astype('float32') self.aspect_ratios = [2.0, 3.0] self.flip = True self.real_aspect_ratios = [1, 2.0, 1.0 / 2.0, 3.0, 1.0 / 3.0] From 98c943730e886ffaf3b6feb59b64d977158f995e Mon Sep 17 00:00:00 2001 From: Yang Yang Date: Fri, 9 Feb 2018 00:12:54 +0000 Subject: [PATCH 06/43] use op run as wrapper of run_impl; make run_impl as private virtual function --- paddle/framework/op_registry_test.cc | 10 +++++-- paddle/framework/operator.cc | 16 +++++++++-- paddle/framework/operator.h | 15 ++++++---- paddle/framework/operator_test.cc | 11 ++++++-- paddle/operators/array_to_lod_tensor_op.cc | 6 ++-- paddle/operators/assign_op.cc | 6 ++-- paddle/operators/beam_search_decode_op.cc | 6 ++-- paddle/operators/beam_search_op.h | 5 ++-- paddle/operators/cond_op.cc | 2 +- paddle/operators/cond_op.h | 5 ++-- paddle/operators/conditional_block_op.cc | 12 +++++--- paddle/operators/create_reader_op.cc | 18 ++++++++---- paddle/operators/feed_op.cc | 6 ++-- paddle/operators/fetch_op.cc | 5 ++-- paddle/operators/fill_constant_op.cc | 6 ++-- paddle/operators/fill_op.cc | 6 ++-- paddle/operators/get_places_op.cc | 6 ++-- paddle/operators/increment_op.cc | 5 ++-- paddle/operators/is_empty_op.cc | 5 ++-- paddle/operators/load_combine_op.cc | 6 ++-- paddle/operators/load_op.cc | 6 ++-- paddle/operators/lod_array_length_op.cc | 6 ++-- paddle/operators/lod_rank_table_op.cc | 6 ++-- paddle/operators/lod_tensor_to_array_op.cc | 6 ++-- paddle/operators/max_sequence_len_op.cc | 5 ++-- paddle/operators/merge_lod_tensor_op.cc | 6 ++-- paddle/operators/nccl_op.cc | 5 ++-- paddle/operators/net_op.h | 28 +++++++++---------- paddle/operators/net_op_test.cc | 5 +++- paddle/operators/parallel_do_op.cc | 10 ++++--- paddle/operators/print_op.cc | 5 ++-- paddle/operators/read_op.cc | 6 ++-- paddle/operators/recurrent_op.cc | 10 ++++--- .../reorder_lod_tensor_by_rank_op.cc | 6 ++-- paddle/operators/rnn_memory_helper_op.cc | 12 +++++--- paddle/operators/save_combine_op.cc | 6 ++-- paddle/operators/save_op.cc | 6 ++-- paddle/operators/shrink_rnn_memory_op.cc | 10 ++++--- paddle/operators/split_lod_tensor_op.cc | 6 ++-- .../operators/tensor_array_read_write_op.cc | 11 +++++--- paddle/operators/while_op.cc | 10 ++++--- 41 files changed, 214 insertions(+), 114 deletions(-) diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc index 341da8befd..b22e06cc79 100644 --- a/paddle/framework/op_registry_test.cc +++ b/paddle/framework/op_registry_test.cc @@ -25,7 +25,10 @@ namespace framework { class CosineOp : public OperatorBase { public: using OperatorBase::OperatorBase; - void Run(const Scope& scope, const platform::Place& place) const override {} + + private: + void RunImpl(const Scope& scope, + const platform::Place& place) const override {} }; class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { @@ -44,7 +47,10 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { class MyTestOp : public OperatorBase { public: using OperatorBase::OperatorBase; - void Run(const Scope& scope, const platform::Place& place) const override {} + + private: + void RunImpl(const Scope& scope, + const platform::Place& place) const override {} }; class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 52387aabd9..240a0602c9 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -64,6 +64,18 @@ static LoD GetLoD(const Scope& scope, const std::string& name) { } } +void OperatorBase::Run(const Scope& scope, const platform::Place& place) { + if (platform::is_gpu_place(place)) { +#ifndef PADDLE_WITH_CUDA + PADDLE_THROW("Cannot run operator on place %s", place); +#else + auto dev_id = boost::get(place).device; + platform::SetDeviceId(dev_id); +#endif + } + RunImpl(scope, place); +} + std::string OperatorBase::Input(const std::string& name) const { auto& ins = Inputs(name); PADDLE_ENFORCE_LE(ins.size(), 1UL, @@ -475,8 +487,8 @@ class RuntimeInferShapeContext : public InferShapeContext { const Scope& scope_; }; -void OperatorWithKernel::Run(const Scope& scope, - const platform::Place& place) const { +void OperatorWithKernel::RunImpl(const Scope& scope, + const platform::Place& place) const { RuntimeInferShapeContext infer_shape_ctx(*this, scope); this->InferShape(&infer_shape_ctx); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index c9140f304c..886e373348 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -89,8 +89,9 @@ class OperatorBase { std::string DebugString() const { return DebugStringEx(nullptr); } - /// Net will call this function to Run an op. - virtual void Run(const Scope& scope, const platform::Place& place) const = 0; + /// Net will call this interface function to Run an op. + // The implementation should be written at RunImpl + void Run(const Scope& scope, const platform::Place& place); // FIXME(typhoonzero): this is only used for recv_op to stop event_loop. virtual void Stop() {} @@ -144,6 +145,8 @@ class OperatorBase { private: void GenerateTemporaryNames(); void CheckAllInputOutputSet() const; + virtual void RunImpl(const Scope& scope, + const platform::Place& place) const = 0; }; // Macro for define a clone method. @@ -168,10 +171,13 @@ class OperatorBase { class NOP : public OperatorBase { public: using OperatorBase::OperatorBase; - void Run(const Scope& scope, const platform::Place& place) const override {} std::unique_ptr Clone() const override { return std::unique_ptr(new NOP(*this)); } + + private: + void RunImpl(const Scope& scope, + const platform::Place& place) const override {} }; class ExecutionContext { @@ -363,8 +369,6 @@ class OperatorWithKernel : public OperatorBase { const VariableNameMap& outputs, const AttributeMap& attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const Scope& scope, const platform::Place& place) const final; - static std::unordered_map& AllOpKernels() { static std::unordered_map g_all_op_kernels; @@ -393,6 +397,7 @@ class OperatorWithKernel : public OperatorBase { // indicate kernel DataType by input data. Defaultly all input data must be // same. proto::DataType IndicateDataType(const ExecutionContext& ctx) const; + void RunImpl(const Scope& scope, const platform::Place& place) const final; }; extern bool OpSupportGPU(const std::string& op_type); diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index b69d7c7a74..7100e64732 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -28,7 +28,10 @@ class OpWithoutKernelTest : public OperatorBase { OpWithoutKernelTest(const std::string& type, const VariableNameMap& inputs, const VariableNameMap& outputs, const AttributeMap& attrs) : OperatorBase(type, inputs, outputs, attrs), x(1) {} - void Run(const Scope& scope, const platform::Place& place) const override { + + private: + void RunImpl(const Scope& scope, + const platform::Place& place) const override { ++op_run_num; ASSERT_EQ(static_cast(inputs_.size()), 1); ASSERT_EQ(static_cast(outputs_.size()), 1); @@ -259,8 +262,10 @@ class OperatorClone : public paddle::framework::OperatorBase { const paddle::framework::VariableNameMap& outputs, const paddle::framework::AttributeMap& attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const paddle::framework::Scope& scope, - const paddle::platform::Place& place) const override {} + + private: + void RunImpl(const paddle::framework::Scope& scope, + const paddle::platform::Place& place) const override {} }; TEST(Operator, Clone) { diff --git a/paddle/operators/array_to_lod_tensor_op.cc b/paddle/operators/array_to_lod_tensor_op.cc index ba5c6bd3c6..3b9ebae153 100644 --- a/paddle/operators/array_to_lod_tensor_op.cc +++ b/paddle/operators/array_to_lod_tensor_op.cc @@ -31,8 +31,10 @@ class ArrayToLoDTensorOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &dev_place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { auto &x = scope.FindVar(Input("X"))->Get(); auto &rank_table = scope.FindVar(Input("RankTable"))->Get(); diff --git a/paddle/operators/assign_op.cc b/paddle/operators/assign_op.cc index e04aa2d28c..0d1ce62bd6 100644 --- a/paddle/operators/assign_op.cc +++ b/paddle/operators/assign_op.cc @@ -71,8 +71,10 @@ class AssignOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto *x = scope.FindVar(Input("X")); if (x == nullptr) { return; diff --git a/paddle/operators/beam_search_decode_op.cc b/paddle/operators/beam_search_decode_op.cc index 72e05607b0..a1b4430425 100644 --- a/paddle/operators/beam_search_decode_op.cc +++ b/paddle/operators/beam_search_decode_op.cc @@ -55,8 +55,10 @@ class BeamSearchDecodeOp : public framework::OperatorBase { const framework::VariableNameMap& outputs, const framework::AttributeMap& attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope& scope, - const platform::Place& dev_place) const override { + + private: + void RunImpl(const framework::Scope& scope, + const platform::Place& dev_place) const override { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto& dev_ctx = *pool.Get(dev_place); diff --git a/paddle/operators/beam_search_op.h b/paddle/operators/beam_search_op.h index 7ad85874fc..8d62e71565 100644 --- a/paddle/operators/beam_search_op.h +++ b/paddle/operators/beam_search_op.h @@ -204,8 +204,9 @@ class BeamSearchOp : public framework::OperatorBase { PADDLE_THROW("Not Implemented"); } - void Run(const framework::Scope& scope, - const platform::Place& dev_place) const override { + private: + void RunImpl(const framework::Scope& scope, + const platform::Place& dev_place) const override { auto ids_var = scope.FindVar(Input("ids")); auto scores_var = scope.FindVar(Input("scores")); auto pre_ids_var = scope.FindVar(Input("pre_ids")); diff --git a/paddle/operators/cond_op.cc b/paddle/operators/cond_op.cc index e333002bfd..28bac0b7be 100644 --- a/paddle/operators/cond_op.cc +++ b/paddle/operators/cond_op.cc @@ -193,7 +193,7 @@ void CondOp::MergeDataFromSubnet(const framework::Scope& scope, } } -void CondOp::Run(const Scope& scope, const platform::Place& place) const { +void CondOp::RunImpl(const Scope& scope, const platform::Place& place) const { // get device context from pool platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto& dev_ctx = *pool.Get(place); diff --git a/paddle/operators/cond_op.h b/paddle/operators/cond_op.h index 7dcdc47e0b..2dc0e23301 100644 --- a/paddle/operators/cond_op.h +++ b/paddle/operators/cond_op.h @@ -77,8 +77,9 @@ class CondOp : public framework::OperatorBase { sub_net_op_[FALSE_BRANCH] = std::move(net); } - void Run(const framework::Scope& scope, - const platform::Place& place) const override; + private: + void RunImpl(const framework::Scope& scope, + const platform::Place& place) const override; private: const int TRUE_BRANCH = 0; diff --git a/paddle/operators/conditional_block_op.cc b/paddle/operators/conditional_block_op.cc index bdcdb85be7..f7572ccfaf 100644 --- a/paddle/operators/conditional_block_op.cc +++ b/paddle/operators/conditional_block_op.cc @@ -65,8 +65,10 @@ class ConditionalBlockOp : public ConditionalOp { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : ConditionalOp(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &dev_place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { auto xs = InputTensors(scope); bool need_run; @@ -128,8 +130,10 @@ class ConditionalBlockGradOp : public ConditionalOp { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : ConditionalOp(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &dev_place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { auto xs = this->InputTensors(scope); bool need_run; diff --git a/paddle/operators/create_reader_op.cc b/paddle/operators/create_reader_op.cc index 5ba2a25ab4..66fd132b3a 100644 --- a/paddle/operators/create_reader_op.cc +++ b/paddle/operators/create_reader_op.cc @@ -72,8 +72,10 @@ template class CreateRandomDataGeneratorOp : public framework::OperatorBase { public: using framework::OperatorBase::OperatorBase; - void Run(const framework::Scope& scope, - const platform::Place& dev_place) const override { + + private: + void RunImpl(const framework::Scope& scope, + const platform::Place& dev_place) const override { const auto& shape_concat = Attr>("shape_concat"); const auto& ranks = Attr>("ranks"); PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty()); @@ -120,8 +122,10 @@ class CreateRandomDataGeneratorOpMaker class CreateShuffleReaderOp : public framework::OperatorBase { public: using framework::OperatorBase::OperatorBase; - void Run(const framework::Scope& scope, - const platform::Place& dev_place) const override { + + private: + void RunImpl(const framework::Scope& scope, + const platform::Place& dev_place) const override { const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) ->Get(); auto* out = scope.FindVar(Output("Out")) @@ -152,8 +156,10 @@ class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker { class CreateBatchReaderOp : public framework::OperatorBase { public: using framework::OperatorBase::OperatorBase; - void Run(const framework::Scope& scope, - const platform::Place& dev_place) const override { + + private: + void RunImpl(const framework::Scope& scope, + const platform::Place& dev_place) const override { const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) ->Get(); auto* out = scope.FindVar(Output("Out")) diff --git a/paddle/operators/feed_op.cc b/paddle/operators/feed_op.cc index 789d01e002..3f6f8a589d 100644 --- a/paddle/operators/feed_op.cc +++ b/paddle/operators/feed_op.cc @@ -24,8 +24,10 @@ class FeedOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto feed_var_name = Input("X"); auto *feed_var = scope.FindVar(feed_var_name); diff --git a/paddle/operators/fetch_op.cc b/paddle/operators/fetch_op.cc index 7205ee2a87..bb4b7356e7 100644 --- a/paddle/operators/fetch_op.cc +++ b/paddle/operators/fetch_op.cc @@ -26,8 +26,9 @@ class FetchOp : public framework::OperatorBase { const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto fetch_var_name = Input("X"); auto *fetch_var = scope.FindVar(fetch_var_name); PADDLE_ENFORCE(fetch_var != nullptr, diff --git a/paddle/operators/fill_constant_op.cc b/paddle/operators/fill_constant_op.cc index dcd43a30c8..ce4e7bf7f2 100644 --- a/paddle/operators/fill_constant_op.cc +++ b/paddle/operators/fill_constant_op.cc @@ -33,8 +33,10 @@ class FillConstantInferShape : public framework::InferShapeBase { class FillConstantOp : public framework::OperatorBase { public: using framework::OperatorBase::OperatorBase; - void Run(const framework::Scope &scope, - const platform::Place &dev_place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { auto data_type = static_cast(Attr("dtype")); auto value = Attr("value"); diff --git a/paddle/operators/fill_op.cc b/paddle/operators/fill_op.cc index 4f5a2ed169..bc72a18902 100644 --- a/paddle/operators/fill_op.cc +++ b/paddle/operators/fill_op.cc @@ -42,8 +42,10 @@ class FillOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto &out = detail::Ref(detail::Ref(scope.FindVar(Output("Out")), "Cannot find variable %s", Output("Out")) diff --git a/paddle/operators/get_places_op.cc b/paddle/operators/get_places_op.cc index 24fafb2307..a7168a1079 100644 --- a/paddle/operators/get_places_op.cc +++ b/paddle/operators/get_places_op.cc @@ -37,8 +37,10 @@ class GetPlacesOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { bool is_gpu; if (Attr("device_type") == "AUTO") { is_gpu = platform::is_gpu_place(place); diff --git a/paddle/operators/increment_op.cc b/paddle/operators/increment_op.cc index e0b80cc4e7..adc7e8f1a4 100644 --- a/paddle/operators/increment_op.cc +++ b/paddle/operators/increment_op.cc @@ -51,8 +51,9 @@ class IncrementOp : public framework::OperatorBase { const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto &x = scope.FindVar(Input("X"))->Get(); auto &out = *scope.FindVar(Output("Out"))->GetMutable(); diff --git a/paddle/operators/is_empty_op.cc b/paddle/operators/is_empty_op.cc index 492ae48845..1de3437b0c 100644 --- a/paddle/operators/is_empty_op.cc +++ b/paddle/operators/is_empty_op.cc @@ -28,8 +28,9 @@ class IsEmptyOp : public framework::OperatorBase { const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { // get input auto *var = scope.FindVar(Input(kInput)); PADDLE_ENFORCE_NOT_NULL(var); diff --git a/paddle/operators/load_combine_op.cc b/paddle/operators/load_combine_op.cc index f4be793d7b..13b1c5da90 100644 --- a/paddle/operators/load_combine_op.cc +++ b/paddle/operators/load_combine_op.cc @@ -26,8 +26,10 @@ class LoadCombineOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto filename = Attr("file_path"); std::ifstream fin(filename); diff --git a/paddle/operators/load_op.cc b/paddle/operators/load_op.cc index f886b423ac..88d0cc725d 100644 --- a/paddle/operators/load_op.cc +++ b/paddle/operators/load_op.cc @@ -25,8 +25,10 @@ class LoadOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto filename = Attr("file_path"); std::ifstream fin(filename); PADDLE_ENFORCE(static_cast(fin), "Cannot open file %s for load op", diff --git a/paddle/operators/lod_array_length_op.cc b/paddle/operators/lod_array_length_op.cc index d2c52745cf..aa18aa2646 100644 --- a/paddle/operators/lod_array_length_op.cc +++ b/paddle/operators/lod_array_length_op.cc @@ -25,8 +25,10 @@ class LoDArrayLengthOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto &x = scope.FindVar(Input("X"))->Get(); auto &out = *scope.FindVar(Output("Out"))->GetMutable(); diff --git a/paddle/operators/lod_rank_table_op.cc b/paddle/operators/lod_rank_table_op.cc index 692b9bf371..8e05ee63a0 100644 --- a/paddle/operators/lod_rank_table_op.cc +++ b/paddle/operators/lod_rank_table_op.cc @@ -23,8 +23,10 @@ class LoDRankTableOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &dev_place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { auto x = scope.FindVar(Input("X"))->Get(); auto *out = scope.FindVar(Output("Out"))->GetMutable(); diff --git a/paddle/operators/lod_tensor_to_array_op.cc b/paddle/operators/lod_tensor_to_array_op.cc index 685a807a8a..0b1d2ffc8f 100644 --- a/paddle/operators/lod_tensor_to_array_op.cc +++ b/paddle/operators/lod_tensor_to_array_op.cc @@ -32,8 +32,10 @@ class LoDTensorToArrayOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto &x = detail::Ref(scope.FindVar(Input("X")), "Cannot find input %s", Input("X")) .Get(); diff --git a/paddle/operators/max_sequence_len_op.cc b/paddle/operators/max_sequence_len_op.cc index 019150e491..794a1e56d3 100644 --- a/paddle/operators/max_sequence_len_op.cc +++ b/paddle/operators/max_sequence_len_op.cc @@ -27,8 +27,9 @@ class MaxSeqenceLenOp : public framework::OperatorBase { const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &dev_place) const override { + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { auto &rank_table = scope.FindVar(Input("RankTable"))->Get(); auto *out = diff --git a/paddle/operators/merge_lod_tensor_op.cc b/paddle/operators/merge_lod_tensor_op.cc index 87644d316d..53ee7d63f3 100644 --- a/paddle/operators/merge_lod_tensor_op.cc +++ b/paddle/operators/merge_lod_tensor_op.cc @@ -27,8 +27,10 @@ class MergeLoDTensorOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &dev_place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { // get device context from pool platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(dev_place); diff --git a/paddle/operators/nccl_op.cc b/paddle/operators/nccl_op.cc index 9d51153b06..974ae9d963 100644 --- a/paddle/operators/nccl_op.cc +++ b/paddle/operators/nccl_op.cc @@ -26,8 +26,9 @@ class NCCLInitOp : public framework::OperatorBase { const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { const auto &name = Output("Communicator"); PADDLE_ENFORCE_NOT_NULL(scope.FindVar(name), "Can not find variable '%s' in the scope.", name); diff --git a/paddle/operators/net_op.h b/paddle/operators/net_op.h index b24042f5ef..9ac8f34347 100644 --- a/paddle/operators/net_op.h +++ b/paddle/operators/net_op.h @@ -57,20 +57,6 @@ class NetOp : public framework::OperatorBase { this->CompleteAddOp(); } - /** - * @brief Run the network. - * - * Run all the operators with the `scope`, if no scope is provided, default - * scope will be used instead. If no OpContext is provicded, default context - * will be used. - */ - void Run(const framework::Scope& scope, - const platform::Place& place) const override { - for (auto& op : ops_) { - op->Run(scope, place); - } - } - bool SupportGPU() const override { for (auto& op : ops_) { if (!op->SupportGPU()) { @@ -117,6 +103,20 @@ class NetOp : public framework::OperatorBase { std::vector> ops_; private: + /** + * @brief Run the network. + * + * Run all the operators with the `scope`, if no scope is provided, default + * scope will be used instead. If no OpContext is provicded, default context + * will be used. + */ + void RunImpl(const framework::Scope& scope, + const platform::Place& place) const override { + for (auto& op : ops_) { + op->Run(scope, place); + } + } + bool add_op_done_{false}; std::set intermediate_outputs_; diff --git a/paddle/operators/net_op_test.cc b/paddle/operators/net_op_test.cc index 9358f29f62..95d21f1516 100644 --- a/paddle/operators/net_op_test.cc +++ b/paddle/operators/net_op_test.cc @@ -26,7 +26,10 @@ class TestOp : public framework::OperatorBase { public: using framework::OperatorBase::OperatorBase; DEFINE_OP_CLONE_METHOD(TestOp); - void Run(const Scope& scope, const platform::Place& place) const override { + + private: + void RunImpl(const Scope& scope, + const platform::Place& place) const override { ++run_cnt; } }; diff --git a/paddle/operators/parallel_do_op.cc b/paddle/operators/parallel_do_op.cc index 89045923f9..b1233c93f8 100644 --- a/paddle/operators/parallel_do_op.cc +++ b/paddle/operators/parallel_do_op.cc @@ -124,8 +124,9 @@ class ParallelDoOp : public framework::OperatorBase { const framework::AttributeMap &attrs) : framework::OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { // get device context from pool platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(place); @@ -216,8 +217,9 @@ class ParallelDoGradOp : public framework::OperatorBase { const framework::AttributeMap &attrs) : framework::OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto *block = Attr(kParallelBlock); auto *program = block->Program(); diff --git a/paddle/operators/print_op.cc b/paddle/operators/print_op.cc index 8b233d64c9..e869e4d620 100644 --- a/paddle/operators/print_op.cc +++ b/paddle/operators/print_op.cc @@ -130,8 +130,9 @@ class TensorPrintOp : public framework::OperatorBase { PADDLE_THROW("Not implemented."); } - void Run(const framework::Scope& scope, - const platform::Place& place) const override { + private: + void RunImpl(const framework::Scope& scope, + const platform::Place& place) const override { const framework::Variable* in_var_ptr = nullptr; std::string phase = kForward; std::string printed_var_name = ""; diff --git a/paddle/operators/read_op.cc b/paddle/operators/read_op.cc index 3ae454101f..924b787faa 100644 --- a/paddle/operators/read_op.cc +++ b/paddle/operators/read_op.cc @@ -54,8 +54,10 @@ class ReadInferVarType : public framework::VarTypeInference { class ReadOp : public framework::OperatorBase { public: using framework::OperatorBase::OperatorBase; - void Run(const framework::Scope& scope, - const platform::Place& dev_place) const override { + + private: + void RunImpl(const framework::Scope& scope, + const platform::Place& dev_place) const override { framework::ReaderHolder* reader = scope.FindVar(Input("Reader"))->GetMutable(); if (!reader->HasNext()) { diff --git a/paddle/operators/recurrent_op.cc b/paddle/operators/recurrent_op.cc index a136c5b447..19ad7fbb70 100644 --- a/paddle/operators/recurrent_op.cc +++ b/paddle/operators/recurrent_op.cc @@ -226,8 +226,9 @@ class RecurrentOp : public RecurrentBase { const framework::AttributeMap &attrs) : RecurrentBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto seq_len = static_cast(this->GetSequenceLength(scope)); VLOG(3) << "Static RNN input sequence length = " << seq_len; StepScopes scopes = CreateStepScopes(scope, seq_len); @@ -315,8 +316,9 @@ class RecurrentGradOp : public RecurrentBase { const framework::AttributeMap &attrs) : RecurrentBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto seq_len = static_cast(GetSequenceLength(scope)); StepScopes scopes = CreateStepScopes(scope, seq_len); auto reverse = Attr(kReverse); diff --git a/paddle/operators/reorder_lod_tensor_by_rank_op.cc b/paddle/operators/reorder_lod_tensor_by_rank_op.cc index 3c30447949..f5c16870b5 100644 --- a/paddle/operators/reorder_lod_tensor_by_rank_op.cc +++ b/paddle/operators/reorder_lod_tensor_by_rank_op.cc @@ -75,8 +75,10 @@ class ReorderLoDTensorByRankTableBase : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto &x = detail::Ref(scope.FindVar(Input("X")), "Cannot find input lod tensor variable %s", Input("X")) diff --git a/paddle/operators/rnn_memory_helper_op.cc b/paddle/operators/rnn_memory_helper_op.cc index eb55ed6a05..fe88aa1fb5 100644 --- a/paddle/operators/rnn_memory_helper_op.cc +++ b/paddle/operators/rnn_memory_helper_op.cc @@ -24,8 +24,10 @@ class RNNMemoryHelperOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &dev_place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { auto mem_var_name = Input("X"); auto *mem_var = scope.FindVar(mem_var_name); PADDLE_ENFORCE(mem_var != nullptr, @@ -76,8 +78,10 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &dev_place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { auto out_grad_var_name = Input(framework::GradVarName("Out")); auto *out_grad_var = scope.FindVar(out_grad_var_name); diff --git a/paddle/operators/save_combine_op.cc b/paddle/operators/save_combine_op.cc index bffa2908bc..5ce0bfb914 100644 --- a/paddle/operators/save_combine_op.cc +++ b/paddle/operators/save_combine_op.cc @@ -63,8 +63,10 @@ class SaveCombineOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto filename = Attr("file_path"); auto overwrite = Attr("overwrite"); diff --git a/paddle/operators/save_op.cc b/paddle/operators/save_op.cc index 4b1cbe8883..c8250d0c3d 100644 --- a/paddle/operators/save_op.cc +++ b/paddle/operators/save_op.cc @@ -62,8 +62,10 @@ class SaveOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto filename = Attr("file_path"); auto overwrite = Attr("overwrite"); diff --git a/paddle/operators/shrink_rnn_memory_op.cc b/paddle/operators/shrink_rnn_memory_op.cc index bf870115a4..cd96ec5133 100644 --- a/paddle/operators/shrink_rnn_memory_op.cc +++ b/paddle/operators/shrink_rnn_memory_op.cc @@ -27,8 +27,9 @@ class ShrinkRNNMemoryOp : public ArrayOp { const framework::AttributeMap &attrs) : ArrayOp(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto *x_var = scope.FindVar(Input("X")); PADDLE_ENFORCE(x_var != nullptr, "Input X must be set"); auto &x_tensor = x_var->Get(); @@ -108,8 +109,9 @@ class ShrinkRNNMemoryGradOp : public ArrayOp { const framework::AttributeMap &attrs) : ArrayOp(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto *dout_var = scope.FindVar(Input(framework::GradVarName("Out"))); auto *dx_var = scope.FindVar(Output(framework::GradVarName("X"))); PADDLE_ENFORCE(dx_var != nullptr, "Input Gradient should not be nullptr"); diff --git a/paddle/operators/split_lod_tensor_op.cc b/paddle/operators/split_lod_tensor_op.cc index bd93c49201..cd833889ed 100644 --- a/paddle/operators/split_lod_tensor_op.cc +++ b/paddle/operators/split_lod_tensor_op.cc @@ -33,8 +33,10 @@ class SplitLoDTensorOp : public framework::OperatorBase { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &dev_place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { auto &x = scope.FindVar(Input("X"))->Get(); auto &mask = scope.FindVar(Input("Mask"))->Get(); auto *out_true = diff --git a/paddle/operators/tensor_array_read_write_op.cc b/paddle/operators/tensor_array_read_write_op.cc index a70be8b875..af3d9b7cc3 100644 --- a/paddle/operators/tensor_array_read_write_op.cc +++ b/paddle/operators/tensor_array_read_write_op.cc @@ -24,8 +24,9 @@ class WriteToArrayOp : public ArrayOp { const framework::AttributeMap &attrs) : ArrayOp(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto *x = scope.FindVar(Input("X")); if (x == nullptr) return; auto &x_tensor = x->Get(); @@ -122,8 +123,10 @@ class ReadFromArrayOp : public ArrayOp { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : ArrayOp(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &place) const override { + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { auto *x = scope.FindVar(Input("X")); PADDLE_ENFORCE(x != nullptr, "X must be set"); auto &x_array = x->Get(); diff --git a/paddle/operators/while_op.cc b/paddle/operators/while_op.cc index a744ebd615..06b0c77485 100644 --- a/paddle/operators/while_op.cc +++ b/paddle/operators/while_op.cc @@ -39,8 +39,9 @@ class WhileOp : public framework::OperatorBase { const framework::AttributeMap &attrs) : framework::OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &dev_place) const override { + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { PADDLE_ENFORCE_NOT_NULL(scope.FindVar(Input(kCondition))); auto &cond = scope.FindVar(Input(kCondition))->Get(); PADDLE_ENFORCE_EQ(cond.dims(), paddle::framework::make_ddim({1})); @@ -99,8 +100,9 @@ class WhileGradOp : public framework::OperatorBase { const framework::AttributeMap &attrs) : framework::OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &dev_place) const override { + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { // get device context from pool platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(dev_place); From 3cdb419b15b13cdf29803aef9e5b4fd28cca930e Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Fri, 9 Feb 2018 16:44:06 +0800 Subject: [PATCH 07/43] add doc for prior box --- python/paddle/v2/fluid/layers/nn.py | 158 ++++++++++++++++++++++++---- 1 file changed, 137 insertions(+), 21 deletions(-) diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index dc1839fd82..0d944c332b 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -22,7 +22,6 @@ from ..param_attr import ParamAttr from layer_function_generator import autodoc from tensor import concat import math -import numpy as np from operator import mul __all__ = [ @@ -3006,10 +3005,43 @@ def reshape_with_axis(input, axis): """ **ReshapeWithAxis Layer** - """ - assert len(input.shape) > axis and axis >= 0, ' ' + According to the axis to merge the adjacent dim of input. Currently, the axis of + reshape_with_axis must be a scalar. + + Args: + input(variable): The input tensor. + axis(list): According to the axis to merge the adjacent dim. + + Returns: + Variable: A tensor variable. + + Examples: + .. code-block:: python + + x = fluid.layers.data(name="data", shape=[3, 32, 32], dtype="float32") + reshaped = fluid.layers.reshape_with_axis(input=x, axis=2) + reshaped.shape + >> [-1, 1024] + reshaped = fluid.layers.reshape_with_axis(input=x, axis=[1,3]) + reshaped.shape + >> [-1, 96, 32] + """ + assert isinstance(axis, list), "axis should be list." + assert len(input.shape) > len( + axis), "the length of axis should be litter than input.shape's." input_shape = input.shape - new_dim = [-1, reduce(mul, input_shape[axis:len(input_shape)], 1)] + temp = 0 + for ax in axis: + assert ax < len(input.shape) and ax > 0, \ + 'The data of Axis should be between 1 and len(input.shape)' + assert ax > temp, 'Axis should be incremented sequence' + temp = ax + axis += [len(input.shape)] + + new_shape = [] + for i in range(len(axis) - 1): + new_shape += [reduce(mul, input_shape[axis[i]:axis[i + 1]], 1)] + new_shape = [-1] + new_shape helper = LayerHelper('reshape', **locals()) out = helper.create_tmp_variable(helper.input_dtype()) @@ -3017,14 +3049,28 @@ def reshape_with_axis(input, axis): type='reshape', inputs={'X': [input]}, outputs={'Out': [out]}, - attrs={'shape': new_dim}) + attrs={'shape': new_shape}) return out -def reshape(input, new_dim): +def reshape(input, new_shape): """ **Reshape Layer** + Reshape the shape of input according to new_dim. + + Args: + input(variable): The input tensor. + new_shape(list): The new shape of input. + + Returns: + Variable: A tensor variable. + + Examples: + .. code-block:: python + + x = fluid.layers.data(name="data", shape=[3, 32, 32], dtype="float32") + reshaped = fluid.layers.reshape(input=x, new_shape=[-1, 1024]) """ helper = LayerHelper('reshape', **locals()) out = helper.create_tmp_variable(helper.input_dtype()) @@ -3051,6 +3097,44 @@ def prior_box(input, """ **Prior_box** + Generate prior boxes for SSD(Single Shot MultiBox Detector) algorithm. + Each position of the input produce N prior boxes, N is determined by + the count of min_sizes, max_sizes and aspect_ratios, The size of the + box is in range(min_size, max_size) interval, which is generated in + sequence according to the aspect_ratios. + + Args: + input(variable): The input feature data of PriorBox, the layout is NCHW. + image(variable): The input image data of PriorBoxOp, the layout is NCHW. + min_sizes(list): the min sizes of generated prior boxes. + max_sizes(list): the max sizes of generated prior boxes. + aspect_ratios(list): the aspect ratios of generated prior boxes. + variance(list): the variances to be encoded in prior boxes. + flip(bool): Whether to flip aspect ratios. + clip(bool): Whether to clip out-of-boundary boxes. + step_w(list): Prior boxes step across width, 0 for auto calculation. + step_h(list): Prior boxes step across height, 0 for auto calculation. + offset(float): Prior boxes center offset. + name(str): Name of the prior box layer. + + Returns: + boxes(variable): the output prior boxes of PriorBoxOp. The layout is + [H, W, num_priors, 4]. H is the height of input, W is the width + of input, num_priors is the box count of each position. + Variances(variable): the expanded variances of PriorBoxOp. The layout + is [H, W, num_priors, 4]. H is the height of input, W is the width + of input, num_priors is the box count of each position. + + Examples: + .. code-block:: python + + data = fluid.layers.data(name="data", shape=[3, 32, 32], dtype="float32") + conv2d = fluid.layers.conv2d( + input=data, num_filters=2, filter_size=3) + box, var = fluid.layers.prior_box(conv2d, data, + min_size, max_size, aspect_ratio, + variance, flip, clip, + step_w, step_h, offset) """ helper = LayerHelper("prior_box", **locals()) dtype = helper.input_dtype() @@ -3093,19 +3177,51 @@ def prior_boxes(input_layers, name=None): """ **Prior_boxes** - e.g. - prior_boxes( - input_layers = [conv1, conv2, conv3, conv4, conv5, conv6], - image = data, - min_ratio = 0.2, - max_ratio = 0.9, - steps = [8., 16., 32., 64., 100., 300.], - aspect_ratios = [[2.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]], - min_dim = 300, - offset = 0.5, - variance = [0.1], - flip=True, - clip=True) + + Generate prior boxes for SSD(Single Shot MultiBox Detector) algorithm. + Each position of the input produce N prior boxes, N is determined by + the count of min_sizes, max_sizes and aspect_ratios, The size of the + box is in range(min_size, max_size) interval, which is generated in + sequence according to the aspect_ratios. + + Args: + input(list): The list of input variables, the format of all variables is NCHW. + image(variable): The input image data of PriorBoxOp, the layout is NCHW. + min_ratio(list): the min sizes of generated prior boxes. + max_ratio(list): the max sizes of generated prior boxes. + aspect_ratios(list): the aspect ratios of generated prior boxes. + min_dim(int): + step_w(list): Prior boxes step across width, 0 for auto calculation. + step_h(list): Prior boxes step across height, 0 for auto calculation. + offset(float): Prior boxes center offset. + variance(list): the variances to be encoded in prior boxes. + flip(bool): Whether to flip aspect ratios. + clip(bool): Whether to clip out-of-boundary boxes. + name(str): Name of the prior box layer. + + Returns: + boxes(variable): the output prior boxes of PriorBoxOp. The layout is + [num_priors, 4]. num_priors is the total box count of each + position of input_layers. + Variances(variable): the expanded variances of PriorBoxOp. The layout + is [num_priors, 4]. num_priors is the total box count of each + position of input_layers + + Examples: + .. code-block:: python + + prior_boxes( + input_layers = [conv1, conv2, conv3, conv4, conv5, conv6], + image = data, + min_ratio = 0.2, + max_ratio = 0.9, + steps = [8., 16., 32., 64., 100., 300.], + aspect_ratios = [[2.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]], + min_dim = 300, + offset = 0.5, + variance = [0.1,0.1,0.1,0.1], + flip=True, + clip=True) """ assert isinstance(input_layers, list), 'input_layer should be a list.' num_layer = len(input_layers) @@ -3168,8 +3284,8 @@ def prior_boxes(input_layers, reshaped_boxes = [] reshaped_vars = [] for i in range(len(box_results)): - reshaped_boxes += [reshape_with_axis(box_results[i], axis=axis)] - reshaped_vars += [reshape_with_axis(var_results[i], axis=axis)] + reshaped_boxes += [reshape_with_axis(box_results[i], axis=[axis])] + reshaped_vars += [reshape_with_axis(var_results[i], axis=[axis])] helper = LayerHelper("concat", **locals()) dtype = helper.input_dtype() From 4a8559c0ccf237ec09c6434accc6e2b76a5e4d06 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Fri, 9 Feb 2018 20:26:25 +0800 Subject: [PATCH 08/43] follow comments and code refine --- paddle/operators/prior_box_op.cc | 8 +- python/paddle/v2/fluid/layers/nn.py | 153 +++++++++++++--------------- 2 files changed, 72 insertions(+), 89 deletions(-) diff --git a/paddle/operators/prior_box_op.cc b/paddle/operators/prior_box_op.cc index 82b4eb1528..064543c2b4 100644 --- a/paddle/operators/prior_box_op.cc +++ b/paddle/operators/prior_box_op.cc @@ -51,11 +51,11 @@ class PriorBoxOp : public framework::OperatorWithKernel { if (max_sizes.size() > 0) { PADDLE_ENFORCE_EQ(max_sizes.size(), min_sizes.size(), "The number of min_size and max_size must be equal."); - for (size_t i = 0; i < min_sizes.size(); ++i) { + num_priors += max_sizes.size(); + for (size_t i = 0; i < max_sizes.size(); ++i) { PADDLE_ENFORCE_GT(max_sizes[i], min_sizes[i], "max_size[%d] must be greater than min_size[%d].", i, i); - num_priors += 1; } } @@ -125,13 +125,13 @@ class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker { .SetDefault(true); AddAttr("step_w", - "Prior boxes step across width, 0 for auto calculation.") + "Prior boxes step across width, 0.0 for auto calculation.") .SetDefault(0.0) .AddCustomChecker([](const float& step_w) { PADDLE_ENFORCE_GE(step_w, 0.0, "step_w should be larger than 0."); }); AddAttr("step_h", - "Prior boxes step across height, 0 for auto calculation.") + "Prior boxes step across height, 0.0 for auto calculation.") .SetDefault(0.0) .AddCustomChecker([](const float& step_h) { PADDLE_ENFORCE_GE(step_h, 0.0, "step_h should be larger than 0."); diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index b1b3da46b9..f0bcddaf9a 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -66,7 +66,6 @@ __all__ = [ 'nce', 'beam_search', 'row_conv', - 'reshape', 'reshape_with_axis', 'multiplex', 'prior_box', @@ -3103,12 +3102,11 @@ def reshape_with_axis(input, axis): """ **ReshapeWithAxis Layer** - According to the axis to merge the adjacent dim of input. Currently, the axis of - reshape_with_axis must be a scalar. + ReshapeWithAxis is used to merge adjacent dimensions according to axis. Args: input(variable): The input tensor. - axis(list): According to the axis to merge the adjacent dim. + axis(list): The axis which is used to merge the adjacent dimensions. Returns: Variable: A tensor variable. @@ -3117,7 +3115,7 @@ def reshape_with_axis(input, axis): .. code-block:: python x = fluid.layers.data(name="data", shape=[3, 32, 32], dtype="float32") - reshaped = fluid.layers.reshape_with_axis(input=x, axis=2) + reshaped = fluid.layers.reshape_with_axis(input=x, axis=[2]) reshaped.shape >> [-1, 1024] reshaped = fluid.layers.reshape_with_axis(input=x, axis=[1,3]) @@ -3151,46 +3149,17 @@ def reshape_with_axis(input, axis): return out -def reshape(input, new_shape): - """ - **Reshape Layer** - - Reshape the shape of input according to new_dim. - - Args: - input(variable): The input tensor. - new_shape(list): The new shape of input. - - Returns: - Variable: A tensor variable. - - Examples: - .. code-block:: python - - x = fluid.layers.data(name="data", shape=[3, 32, 32], dtype="float32") - reshaped = fluid.layers.reshape(input=x, new_shape=[-1, 1024]) - """ - helper = LayerHelper('reshape', **locals()) - out = helper.create_tmp_variable(helper.input_dtype()) - helper.append_op( - type='reshape', - inputs={'X': [input]}, - outputs={'Out': [out]}, - attrs={'shape': new_dim}) - return out - - def prior_box(input, image, min_sizes, max_sizes, aspect_ratios, variance, - flip, - clip, - step_w, - step_h, - offset, + flip=False, + clip=False, + step_w=0.0, + step_h=0.0, + offset=0.5, name=None): """ **Prior_box** @@ -3202,27 +3171,33 @@ def prior_box(input, sequence according to the aspect_ratios. Args: - input(variable): The input feature data of PriorBox, the layout is NCHW. - image(variable): The input image data of PriorBoxOp, the layout is NCHW. + input(variable): The input feature data of PriorBox, + the layout is NCHW. + image(variable): The input image data of PriorBox, the + layout is NCHW. min_sizes(list): the min sizes of generated prior boxes. max_sizes(list): the max sizes of generated prior boxes. aspect_ratios(list): the aspect ratios of generated prior boxes. variance(list): the variances to be encoded in prior boxes. - flip(bool): Whether to flip aspect ratios. - clip(bool): Whether to clip out-of-boundary boxes. - step_w(list): Prior boxes step across width, 0 for auto calculation. - step_h(list): Prior boxes step across height, 0 for auto calculation. - offset(float): Prior boxes center offset. - name(str): Name of the prior box layer. + flip(bool, optional, default=False): Whether to flip aspect ratios. + clip(bool, optional, default=False)): Whether to clip + out-of-boundary boxes. + step_w(int, optional, default=0.0): Prior boxes step across + width, 0.0 for auto calculation. + step_h(int, optional, default=0.0): Prior boxes step across + height, 0.0 for auto calculation. + offset(float, optional, default=0.5): Prior boxes center offset. + name(str, optional, default=None): Name of the prior box layer. Returns: boxes(variable): the output prior boxes of PriorBoxOp. The layout is [H, W, num_priors, 4]. H is the height of input, W is the width - of input, num_priors is the box count of each position. + of input, num_priors is the box count of each position. Where num_priors = + len(aspect_ratios) * len(min_sizes) + len(max_sizes) Variances(variable): the expanded variances of PriorBoxOp. The layout is [H, W, num_priors, 4]. H is the height of input, W is the width - of input, num_priors is the box count of each position. - + of input, num_priors is the box count of each position. Where num_priors = + len(aspect_ratios) * len(min_sizes) + len(max_sizes) Examples: .. code-block:: python @@ -3259,70 +3234,78 @@ def prior_box(input, return box, var -def prior_boxes(input_layers, +def prior_boxes(inputs, image, min_ratio, max_ratio, aspect_ratios, - min_dim, + base_size, steps=None, step_w=None, step_h=None, offset=0.5, variance=[0.1, 0.1, 0.1, 0.1], - flip=True, - clip=True, + flip=False, + clip=False, name=None): """ **Prior_boxes** Generate prior boxes for SSD(Single Shot MultiBox Detector) algorithm. - Each position of the input produce N prior boxes, N is determined by - the count of min_sizes, max_sizes and aspect_ratios, The size of the - box is in range(min_size, max_size) interval, which is generated in + Each position of the inputs produces many prior boxes respectly, the number + of prior boxes which is produced by inputs respectly is determined by + the count of min_ratio, max_ratio and aspect_ratios, The size of the + box is in range(min_ratio, max_ratio) interval, which is generated in sequence according to the aspect_ratios. Args: - input(list): The list of input variables, the format of all variables is NCHW. + inputs(list): The list of input variables, the format of all variables is NCHW. image(variable): The input image data of PriorBoxOp, the layout is NCHW. - min_ratio(list): the min sizes of generated prior boxes. - max_ratio(list): the max sizes of generated prior boxes. + min_ratio(int): the min ratio of generated prior boxes. + max_ratio(int): the max ratio of generated prior boxes. aspect_ratios(list): the aspect ratios of generated prior boxes. - min_dim(int): - step_w(list): Prior boxes step across width, 0 for auto calculation. - step_h(list): Prior boxes step across height, 0 for auto calculation. - offset(float): Prior boxes center offset. - variance(list): the variances to be encoded in prior boxes. - flip(bool): Whether to flip aspect ratios. - clip(bool): Whether to clip out-of-boundary boxes. - name(str): Name of the prior box layer. + The length of input and aspect_ratios must be equal. + base_size(int): the base_size is used to get min_size and max_size + according to min_ratio and max_ratio. + step_w(list, optional, default=None): Prior boxes step across width. + If step_w[i] == 0.0, the prior boxes step across width of the inputs[i] + will be automatically calculated. + step_h(list, optional, default=None): Prior boxes step across height, + If step_h[i] == 0.0, the prior boxes step across height of the inputs[i] + will be automatically calculated. + offset(float, optional, default=0.5): Prior boxes center offset. + variance(list, optional, default=[0.1, 0.1, 0.1, 0.1]): the variances + to be encoded in prior boxes. + flip(bool, optional, default=False): Whether to flip aspect ratios. + clip(bool, optional, default=False): Whether to clip out-of-boundary boxes. + name(str, optional, None): Name of the prior box layer. Returns: boxes(variable): the output prior boxes of PriorBoxOp. The layout is [num_priors, 4]. num_priors is the total box count of each - position of input_layers. + position of inputs. Variances(variable): the expanded variances of PriorBoxOp. The layout is [num_priors, 4]. num_priors is the total box count of each - position of input_layers + position of inputs Examples: .. code-block:: python prior_boxes( - input_layers = [conv1, conv2, conv3, conv4, conv5, conv6], + inputs = [conv1, conv2, conv3, conv4, conv5, conv6], image = data, - min_ratio = 0.2, - max_ratio = 0.9, + min_ratio = 20, # 0.20 + max_ratio = 90, # 0.90 steps = [8., 16., 32., 64., 100., 300.], aspect_ratios = [[2.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]], - min_dim = 300, + base_size = 300, offset = 0.5, variance = [0.1,0.1,0.1,0.1], flip=True, clip=True) """ - assert isinstance(input_layers, list), 'input_layer should be a list.' - num_layer = len(input_layers) + assert isinstance(inputs, list), 'inputs should be a list.' + num_layer = len(inputs) assert num_layer > 2 # TODO(zcd): currently, num_layer must be bigger than two. min_sizes = [] @@ -3330,30 +3313,30 @@ def prior_boxes(input_layers, if num_layer > 2: step = int(math.floor(((max_ratio - min_ratio)) / (num_layer - 2))) for ratio in xrange(min_ratio, max_ratio + 1, step): - min_sizes.append(min_dim * ratio / 100.) - max_sizes.append(min_dim * (ratio + step) / 100.) - min_sizes = [min_dim * .10] + min_sizes - max_sizes = [min_dim * .20] + max_sizes + min_sizes.append(base_size * ratio / 100.) + max_sizes.append(base_size * (ratio + step) / 100.) + min_sizes = [base_size * .10] + min_sizes + max_sizes = [base_size * .20] + max_sizes if step_h: assert isinstance(step_h,list) and len(step_h) == num_layer, \ - 'step_h should be list and input_layers and step_h should have same length' + 'step_h should be list and inputs and step_h should have same length' if step_w: assert isinstance(step_w,list) and len(step_w) == num_layer, \ - 'step_w should be list and input_layers and step_w should have same length' + 'step_w should be list and inputs and step_w should have same length' if steps: assert isinstance(steps,list) and len(steps) == num_layer, \ - 'steps should be list and input_layers and step_w should have same length' + 'steps should be list and inputs and step_w should have same length' step_w = steps step_h = steps if aspect_ratios: assert isinstance(aspect_ratios, list) and len(aspect_ratios) == num_layer, \ - 'aspect_ratios should be list and input_layers and aspect_ratios should ' \ + 'aspect_ratios should be list and inputs and aspect_ratios should ' \ 'have same length' box_results = [] var_results = [] - for i, input in enumerate(input_layers): + for i, input in enumerate(inputs): min_size = min_sizes[i] max_size = max_sizes[i] aspect_ratio = [] From df7c29e5165847f9958903f8c492f566b3df63fc Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Sat, 10 Feb 2018 20:54:50 +0800 Subject: [PATCH 09/43] override comparison operators in Python for Variable --- python/paddle/v2/fluid/layers/math_op_patch.py | 6 +++++- python/paddle/v2/fluid/learning_rate_decay.py | 4 ++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/python/paddle/v2/fluid/layers/math_op_patch.py b/python/paddle/v2/fluid/layers/math_op_patch.py index 9b5f22759c..4cf995ec85 100644 --- a/python/paddle/v2/fluid/layers/math_op_patch.py +++ b/python/paddle/v2/fluid/layers/math_op_patch.py @@ -151,7 +151,11 @@ def monkey_patch_variable(): ("__div__", "elementwise_div", False), ("__rdiv__", "elementwise_div", True), ("__pow__", "elementwise_pow", False), - ("__rpow__", "elementwise_pow", True)): + ("__rpow__", "elementwise_pow", True), + # for logical compare + ("__eq__", "equal", False), + ("__lt__", "less_then", False), + ("__le__", "less_equal", False), ): setattr(Variable, method_name, _elemwise_method_creator_(method_name, op_type, reverse)) diff --git a/python/paddle/v2/fluid/learning_rate_decay.py b/python/paddle/v2/fluid/learning_rate_decay.py index 2a2a29fd9c..0826d3da79 100644 --- a/python/paddle/v2/fluid/learning_rate_decay.py +++ b/python/paddle/v2/fluid/learning_rate_decay.py @@ -179,7 +179,7 @@ def polynomial_decay(learning_rate, shape=[1], dtype='float32', value=1.0) with layers.Switch() as switch: - with switch.case(layers.equal(x=global_step, y=zero_var)): + with switch.case(global_step == zero_var): layers.assign(input=one_var, output=div_res) decay_steps = decay_steps * div_res else: @@ -229,7 +229,7 @@ def piecewise_decay(global_step, boundaries, values): shape=[1], dtype='float32', value=float(boundaries[i])) value_var = layers.fill_constant( shape=[1], dtype='float32', value=float(values[i])) - with switch.case(layers.less_than(global_step, boundary_val)): + with switch.case(global_step < boundary_val): layers.assign(value_var, lr) last_value_var = layers.fill_constant( shape=[1], From 6ed545b0d8386c0a4ed0d978bb15cf9acd29c42f Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Sat, 10 Feb 2018 21:09:53 +0800 Subject: [PATCH 10/43] fix typo --- python/paddle/v2/fluid/layers/math_op_patch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/v2/fluid/layers/math_op_patch.py b/python/paddle/v2/fluid/layers/math_op_patch.py index 4cf995ec85..5301c3d1de 100644 --- a/python/paddle/v2/fluid/layers/math_op_patch.py +++ b/python/paddle/v2/fluid/layers/math_op_patch.py @@ -154,7 +154,7 @@ def monkey_patch_variable(): ("__rpow__", "elementwise_pow", True), # for logical compare ("__eq__", "equal", False), - ("__lt__", "less_then", False), + ("__lt__", "less_than", False), ("__le__", "less_equal", False), ): setattr(Variable, method_name, _elemwise_method_creator_(method_name, op_type, reverse)) From 4b3fadc1cdcd3d2cbc4c7cf63d4a96d7db45fce6 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Sat, 10 Feb 2018 22:16:27 +0800 Subject: [PATCH 11/43] init test_python_operator_overriding.py --- .../tests/test_python_operator_overriding.py | 54 +++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 python/paddle/v2/fluid/tests/test_python_operator_overriding.py diff --git a/python/paddle/v2/fluid/tests/test_python_operator_overriding.py b/python/paddle/v2/fluid/tests/test_python_operator_overriding.py new file mode 100644 index 0000000000..b985ae3e29 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_python_operator_overriding.py @@ -0,0 +1,54 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy +import paddle.v2.fluid.framework as framework +import paddle.v2.fluid as fluid + + +class TestPythonOperatorOverride(unittest.TestCase): + def check_result(self, fn, place, dtype='float32'): + shape = [9, 10] + + x_data = numpy.random.random(size=shape).astype(dtype) + y_data = numpy.random.random(size=shape).astype(dtype) + python_out = fn(x_data, y_data) + + x_var = fluid.layers.data(name='x', shape=shape, dtype=dtype) + y_var = fluid.layers.data(name='y', shape=shape, dtype=dtype) + out = fn(x_var, y_var) + + exe = fluid.Executor(place) + feeder = fluid.DataFeeder(feed_list=[x_var, y_var], place=place) + + exe.run(fluid.default_startup_program()) + fluid_out = exe.run(fluid.default_main_program(), + feed=feeder.feed([x_data, y_data]), + fetch_list=[out]) + + print(python_out) + self.assertAlmostEqual(python_out, fluid_out[0]) + + def test_override(self): + main_program = framework.Program() + startup_program = framework.Program() + with framework.program_guard(main_program, startup_program): + place = fluid.CPUPlace() + self.check_result(lambda _a, _b: _a == _b, place) + + +if __name__ == '__main__': + unittest.main() From d89e1449b7701d759b6e3180f12ea430320db18d Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Sat, 10 Feb 2018 22:41:54 +0800 Subject: [PATCH 12/43] optimize test --- .../tests/test_python_operator_overriding.py | 36 ++++++++++++------- 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/python/paddle/v2/fluid/tests/test_python_operator_overriding.py b/python/paddle/v2/fluid/tests/test_python_operator_overriding.py index b985ae3e29..94f3fc958e 100644 --- a/python/paddle/v2/fluid/tests/test_python_operator_overriding.py +++ b/python/paddle/v2/fluid/tests/test_python_operator_overriding.py @@ -14,40 +14,52 @@ import unittest -import numpy +import numpy as np + +import paddle.v2.fluid.layers as layers import paddle.v2.fluid.framework as framework import paddle.v2.fluid as fluid class TestPythonOperatorOverride(unittest.TestCase): - def check_result(self, fn, place, dtype='float32'): + def check_result(self, fn, x_val, y_val, place, dtype): shape = [9, 10] - x_data = numpy.random.random(size=shape).astype(dtype) - y_data = numpy.random.random(size=shape).astype(dtype) + x_data = np.full(shape, x_val).astype(dtype) + y_data = np.full(shape, y_val).astype(dtype) python_out = fn(x_data, y_data) - x_var = fluid.layers.data(name='x', shape=shape, dtype=dtype) - y_var = fluid.layers.data(name='y', shape=shape, dtype=dtype) + x_var = layers.create_global_var( + shape=shape, value=x_val, dtype=dtype, persistable=True) + y_var = layers.create_global_var( + shape=shape, value=y_val, dtype=dtype, persistable=True) out = fn(x_var, y_var) exe = fluid.Executor(place) - feeder = fluid.DataFeeder(feed_list=[x_var, y_var], place=place) exe.run(fluid.default_startup_program()) fluid_out = exe.run(fluid.default_main_program(), - feed=feeder.feed([x_data, y_data]), + feed=[], fetch_list=[out]) - print(python_out) - self.assertAlmostEqual(python_out, fluid_out[0]) + np.testing.assert_array_equal(python_out, fluid_out[0]) def test_override(self): + cpu_place = fluid.CPUPlace() + test_data = [(lambda _a, _b: _a == _b, 0.1, 1.1, cpu_place, 'float32'), + (lambda _a, _b: _a == _b, 1.2, 1.1, cpu_place, 'float32'), + (lambda _a, _b: _a < _b, 0.1, 1.1, cpu_place, 'float32'), + (lambda _a, _b: _a < _b, 2.1, 1.1, cpu_place, 'float32'), + (lambda _a, _b: _a <= _b, 0.1, 1.1, cpu_place, 'float32'), + (lambda _a, _b: _a <= _b, 1.1, 1.1, cpu_place, 'float32'), + (lambda _a, _b: _a >= _b, 1.1, 1.1, cpu_place, 'float32')] + main_program = framework.Program() startup_program = framework.Program() + with framework.program_guard(main_program, startup_program): - place = fluid.CPUPlace() - self.check_result(lambda _a, _b: _a == _b, place) + for fn, x_val, y_val, place, dtype in test_data: + self.check_result(fn, x_val, y_val, place, dtype) if __name__ == '__main__': From de469d58380dd4376d905165678ad05eee9e3e17 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Sun, 11 Feb 2018 10:17:23 +0800 Subject: [PATCH 13/43] optimize test --- .../tests/test_python_operator_overriding.py | 53 +++++++++++-------- 1 file changed, 32 insertions(+), 21 deletions(-) diff --git a/python/paddle/v2/fluid/tests/test_python_operator_overriding.py b/python/paddle/v2/fluid/tests/test_python_operator_overriding.py index 94f3fc958e..b9e2623bdd 100644 --- a/python/paddle/v2/fluid/tests/test_python_operator_overriding.py +++ b/python/paddle/v2/fluid/tests/test_python_operator_overriding.py @@ -22,44 +22,55 @@ import paddle.v2.fluid as fluid class TestPythonOperatorOverride(unittest.TestCase): - def check_result(self, fn, x_val, y_val, place, dtype): + def check_result(self, fn, place, dtype): shape = [9, 10] - x_data = np.full(shape, x_val).astype(dtype) - y_data = np.full(shape, y_val).astype(dtype) + x_data = np.random.random(size=shape).astype(dtype) + y_data = np.random.random(size=shape).astype(dtype) python_out = fn(x_data, y_data) x_var = layers.create_global_var( - shape=shape, value=x_val, dtype=dtype, persistable=True) + name='x', shape=shape, value=0.0, dtype=dtype, persistable=True) y_var = layers.create_global_var( - shape=shape, value=y_val, dtype=dtype, persistable=True) + name='y', shape=shape, value=0.0, dtype=dtype, persistable=True) out = fn(x_var, y_var) exe = fluid.Executor(place) exe.run(fluid.default_startup_program()) fluid_out = exe.run(fluid.default_main_program(), - feed=[], + feed={'x': x_data, + 'y': y_data}, fetch_list=[out]) np.testing.assert_array_equal(python_out, fluid_out[0]) def test_override(self): - cpu_place = fluid.CPUPlace() - test_data = [(lambda _a, _b: _a == _b, 0.1, 1.1, cpu_place, 'float32'), - (lambda _a, _b: _a == _b, 1.2, 1.1, cpu_place, 'float32'), - (lambda _a, _b: _a < _b, 0.1, 1.1, cpu_place, 'float32'), - (lambda _a, _b: _a < _b, 2.1, 1.1, cpu_place, 'float32'), - (lambda _a, _b: _a <= _b, 0.1, 1.1, cpu_place, 'float32'), - (lambda _a, _b: _a <= _b, 1.1, 1.1, cpu_place, 'float32'), - (lambda _a, _b: _a >= _b, 1.1, 1.1, cpu_place, 'float32')] - - main_program = framework.Program() - startup_program = framework.Program() - - with framework.program_guard(main_program, startup_program): - for fn, x_val, y_val, place, dtype in test_data: - self.check_result(fn, x_val, y_val, place, dtype) + # compare func to check + compare_fns = [ + lambda _a, _b: _a == _b, + lambda _a, _b: _a == _b, + lambda _a, _b: _a < _b, + lambda _a, _b: _a < _b, + lambda _a, _b: _a <= _b, + lambda _a, _b: _a <= _b, + lambda _a, _b: _a >= _b, + ] + + # places to check + places = [fluid.CPUPlace()] + if fluid.core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + + # dtypes to check + dtypes = ['int32', 'float32'] + + for place in places: + for dtype in dtypes: + for compare_fn in compare_fns: + with framework.program_guard(framework.Program(), + gframework.Program()): + self.check_result(compare_fn, place, dtype) if __name__ == '__main__': From 23ba79b16b7135503a5ec804071de5ba22f57ce2 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Sun, 11 Feb 2018 10:19:38 +0800 Subject: [PATCH 14/43] fix typo --- .../v2/fluid/tests/test_python_operator_overriding.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/paddle/v2/fluid/tests/test_python_operator_overriding.py b/python/paddle/v2/fluid/tests/test_python_operator_overriding.py index b9e2623bdd..aecae3332b 100644 --- a/python/paddle/v2/fluid/tests/test_python_operator_overriding.py +++ b/python/paddle/v2/fluid/tests/test_python_operator_overriding.py @@ -49,10 +49,8 @@ class TestPythonOperatorOverride(unittest.TestCase): # compare func to check compare_fns = [ lambda _a, _b: _a == _b, - lambda _a, _b: _a == _b, - lambda _a, _b: _a < _b, lambda _a, _b: _a < _b, - lambda _a, _b: _a <= _b, + lambda _a, _b: _a > _b, lambda _a, _b: _a <= _b, lambda _a, _b: _a >= _b, ] @@ -69,7 +67,7 @@ class TestPythonOperatorOverride(unittest.TestCase): for dtype in dtypes: for compare_fn in compare_fns: with framework.program_guard(framework.Program(), - gframework.Program()): + framework.Program()): self.check_result(compare_fn, place, dtype) From 6f78cb996912d056c7df131838d2c0a79a018e19 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Sun, 11 Feb 2018 10:34:24 +0800 Subject: [PATCH 15/43] add not_equal --- paddle/fluid/operators/compare_op.cc | 2 ++ paddle/fluid/operators/compare_op.cu | 1 + paddle/fluid/operators/compare_op.h | 8 ++++++++ python/paddle/v2/fluid/layers/math_op_patch.py | 3 ++- .../v2/fluid/tests/test_python_operator_overriding.py | 1 + 5 files changed, 14 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/compare_op.cc b/paddle/fluid/operators/compare_op.cc index f3414c33b5..b1f09fb002 100644 --- a/paddle/fluid/operators/compare_op.cc +++ b/paddle/fluid/operators/compare_op.cc @@ -102,3 +102,5 @@ REGISTER_LOGICAL_OP(less_equal, "Out = X <= Y"); REGISTER_LOGICAL_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor); REGISTER_LOGICAL_OP(equal, "Out = X == Y"); REGISTER_LOGICAL_KERNEL(equal, CPU, paddle::operators::EqualFunctor); +REGISTER_LOGICAL_OP(not_equal, "Out = X != Y"); +REGISTER_LOGICAL_KERNEL(not_equal, CPU, paddle::operators::NotEqualFunctor); diff --git a/paddle/fluid/operators/compare_op.cu b/paddle/fluid/operators/compare_op.cu index 3507af2ae3..00263a2ade 100644 --- a/paddle/fluid/operators/compare_op.cu +++ b/paddle/fluid/operators/compare_op.cu @@ -17,3 +17,4 @@ limitations under the License. */ REGISTER_LOGICAL_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor); REGISTER_LOGICAL_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor); REGISTER_LOGICAL_KERNEL(equal, CUDA, paddle::operators::EqualFunctor); +REGISTER_LOGICAL_KERNEL(not_equal, CUDA, paddle::operators::NotEqualFunctor); diff --git a/paddle/fluid/operators/compare_op.h b/paddle/fluid/operators/compare_op.h index 4b2ee5a9d6..c651335268 100644 --- a/paddle/fluid/operators/compare_op.h +++ b/paddle/fluid/operators/compare_op.h @@ -48,6 +48,14 @@ struct EqualFunctor { } }; +template +struct NotEqualFunctor { + using ELEM_TYPE = T; + HOSTDEVICE bool operator()(const T& a, const T& b) const { + return !EqualFunctor()(a, b); + } +}; + template class CompareOpKernel : public framework::OpKernel { diff --git a/python/paddle/v2/fluid/layers/math_op_patch.py b/python/paddle/v2/fluid/layers/math_op_patch.py index 5301c3d1de..8208629af7 100644 --- a/python/paddle/v2/fluid/layers/math_op_patch.py +++ b/python/paddle/v2/fluid/layers/math_op_patch.py @@ -154,8 +154,9 @@ def monkey_patch_variable(): ("__rpow__", "elementwise_pow", True), # for logical compare ("__eq__", "equal", False), + ("__ne__", "not_equal", False), ("__lt__", "less_than", False), - ("__le__", "less_equal", False), ): + ("__le__", "less_equal", False)): setattr(Variable, method_name, _elemwise_method_creator_(method_name, op_type, reverse)) diff --git a/python/paddle/v2/fluid/tests/test_python_operator_overriding.py b/python/paddle/v2/fluid/tests/test_python_operator_overriding.py index aecae3332b..5ef0097388 100644 --- a/python/paddle/v2/fluid/tests/test_python_operator_overriding.py +++ b/python/paddle/v2/fluid/tests/test_python_operator_overriding.py @@ -53,6 +53,7 @@ class TestPythonOperatorOverride(unittest.TestCase): lambda _a, _b: _a > _b, lambda _a, _b: _a <= _b, lambda _a, _b: _a >= _b, + lambda _a, _b: _a != _b, ] # places to check From b19ef3f05e81a9564d1b26dde474f44a6f1bc7be Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Sun, 11 Feb 2018 10:38:02 +0800 Subject: [PATCH 16/43] optimize code --- .../paddle/v2/fluid/tests/test_python_operator_overriding.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/v2/fluid/tests/test_python_operator_overriding.py b/python/paddle/v2/fluid/tests/test_python_operator_overriding.py index 5ef0097388..e5198ec17d 100644 --- a/python/paddle/v2/fluid/tests/test_python_operator_overriding.py +++ b/python/paddle/v2/fluid/tests/test_python_operator_overriding.py @@ -49,11 +49,11 @@ class TestPythonOperatorOverride(unittest.TestCase): # compare func to check compare_fns = [ lambda _a, _b: _a == _b, + lambda _a, _b: _a != _b, lambda _a, _b: _a < _b, - lambda _a, _b: _a > _b, lambda _a, _b: _a <= _b, + lambda _a, _b: _a > _b, lambda _a, _b: _a >= _b, - lambda _a, _b: _a != _b, ] # places to check From 006ef1fd7a551e109b8ac294cdb9cc012d2a5161 Mon Sep 17 00:00:00 2001 From: wanghaox Date: Sun, 11 Feb 2018 11:17:48 +0800 Subject: [PATCH 17/43] migrate detection_map code directory --- paddle/{ => fluid}/operators/detection_map_op.cc | 2 +- paddle/{ => fluid}/operators/detection_map_op.h | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) rename paddle/{ => fluid}/operators/detection_map_op.cc (99%) rename paddle/{ => fluid}/operators/detection_map_op.h (99%) diff --git a/paddle/operators/detection_map_op.cc b/paddle/fluid/operators/detection_map_op.cc similarity index 99% rename from paddle/operators/detection_map_op.cc rename to paddle/fluid/operators/detection_map_op.cc index 1ab691eb4f..cc4b6202c0 100644 --- a/paddle/operators/detection_map_op.cc +++ b/paddle/fluid/operators/detection_map_op.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/detection_map_op.h" +#include "paddle/fluid/operators/detection_map_op.h" namespace paddle { namespace operators { diff --git a/paddle/operators/detection_map_op.h b/paddle/fluid/operators/detection_map_op.h similarity index 99% rename from paddle/operators/detection_map_op.h rename to paddle/fluid/operators/detection_map_op.h index fd0ddd10aa..0379a3328a 100644 --- a/paddle/operators/detection_map_op.h +++ b/paddle/fluid/operators/detection_map_op.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once -#include "paddle/framework/eigen.h" -#include "paddle/framework/op_registry.h" +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" namespace paddle { namespace operators { From cf2ed179940e3de30bafb9e9e89587424c58e1b2 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Sun, 11 Feb 2018 11:18:32 +0800 Subject: [PATCH 18/43] fix prior_op unit test --- python/paddle/v2/fluid/tests/test_prior_box_op.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/v2/fluid/tests/test_prior_box_op.py b/python/paddle/v2/fluid/tests/test_prior_box_op.py index 25dfc4307c..a6c21af49f 100644 --- a/python/paddle/v2/fluid/tests/test_prior_box_op.py +++ b/python/paddle/v2/fluid/tests/test_prior_box_op.py @@ -65,9 +65,9 @@ class TestPriorBoxOp(OpTest): self.batch_size = 10 self.min_sizes = [2, 4] - self.min_sizes = np.array(self.min_sizes).astype('float32') + self.min_sizes = np.array(self.min_sizes).astype('float32').tolist() self.max_sizes = [5, 10] - self.max_sizes = np.array(self.max_sizes).astype('float32') + self.max_sizes = np.array(self.max_sizes).astype('float32').tolist() self.aspect_ratios = [2.0, 3.0] self.flip = True self.real_aspect_ratios = [1, 2.0, 1.0 / 2.0, 3.0, 1.0 / 3.0] From 593bec2c225fb8ababb2c50b116adc689d635575 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Sun, 11 Feb 2018 12:22:16 +0800 Subject: [PATCH 19/43] update test_layers --- python/paddle/v2/fluid/tests/test_layers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/v2/fluid/tests/test_layers.py b/python/paddle/v2/fluid/tests/test_layers.py index aea43c2517..fa46f86973 100644 --- a/python/paddle/v2/fluid/tests/test_layers.py +++ b/python/paddle/v2/fluid/tests/test_layers.py @@ -161,8 +161,8 @@ class TestBook(unittest.TestCase): label=label, chunk_scheme="IOB", num_chunk_types=(label_dict_len - 1) / 2) - self.assertNotEqual(crf, None) - self.assertNotEqual(crf_decode, None) + self.assertFalse(crf is None) + self.assertFalse(crf_decode is None) print(str(program)) From a43fac35676ba391da1aabaadd3edb19fab4e087 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Sun, 11 Feb 2018 14:38:14 +0800 Subject: [PATCH 20/43] Fix empty Vector foreach Fix #8368 --- paddle/fluid/framework/mixed_vector.h | 33 +++++++++++++-------- paddle/fluid/framework/mixed_vector_test.cu | 6 ++++ 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/framework/mixed_vector.h b/paddle/fluid/framework/mixed_vector.h index 4dc3de54de..a35ec5d1de 100644 --- a/paddle/fluid/framework/mixed_vector.h +++ b/paddle/fluid/framework/mixed_vector.h @@ -106,9 +106,9 @@ class Vector { // std::vector iterator methods. Based on CPU data access method size_t size() const { return size_; } - T* begin() { return &this->operator[](0); } + T* begin() { return size() == 0 ? &EmptyDummy() : &this->operator[](0); } - T* end() { return &this->operator[](size()); } + T* end() { return size() == 0 ? &EmptyDummy() : &this->operator[](size()); } T& front() { return *begin(); } @@ -118,12 +118,12 @@ class Vector { return *it; } - const T* begin() const { return &this->operator[](0); } - const T* end() const { return &this->operator[](size()); } - - const T* cbegin() const { return begin(); } - - const T* cend() const { return end(); } + const T* begin() const { + return size() == 0 ? &EmptyDummy() : &this->operator[](0); + } + const T* end() const { + return size() == 0 ? &EmptyDummy() : &this->operator[](size()); + } const T& back() const { auto it = end(); @@ -240,16 +240,18 @@ class Vector { // implicit cast operator. Vector can be cast to std::vector implicitly. operator std::vector() const { std::vector result; - result.resize(size()); - std::copy(begin(), end(), result.begin()); + if (size() == 0) { + result.resize(size()); + std::copy(begin(), end(), result.begin()); + } return result; } bool operator==(const Vector& other) const { if (size() != other.size()) return false; - auto it1 = cbegin(); - auto it2 = other.cbegin(); - for (; it1 < cend(); ++it1, ++it2) { + auto it1 = begin(); + auto it2 = other.begin(); + for (; it1 < end(); ++it1, ++it2) { if (*it1 != *it2) { return false; } @@ -358,6 +360,11 @@ class Vector { } } + static T& EmptyDummy() { + static T dummy = T(); + return dummy; + } + mutable int flag_; mutable Tensor cpu_vec_; mutable Tensor cuda_vec_; diff --git a/paddle/fluid/framework/mixed_vector_test.cu b/paddle/fluid/framework/mixed_vector_test.cu index 0d5a914eac..8ea574b31c 100644 --- a/paddle/fluid/framework/mixed_vector_test.cu +++ b/paddle/fluid/framework/mixed_vector_test.cu @@ -98,3 +98,9 @@ TEST(mixed_vector, InitWithCount) { ASSERT_EQ(vec[i], 10); } } + +TEST(mixed_vector, ForEach) { + vec tmp; + for (auto& v : tmp) { + } +} From caf9a09d7bee946969999130477fb5de2983007b Mon Sep 17 00:00:00 2001 From: Yancey Date: Sun, 11 Feb 2018 15:44:27 +0800 Subject: [PATCH 21/43] Merge selected rows with dynamic variable count (#8023) * dynamic send/recv selected rows * update by comment * fix by comment --- paddle/fluid/operators/listen_and_serv_op.cc | 16 +++++++++++++ paddle/fluid/operators/send_op.cc | 24 +++++++++++++++++-- .../fluid/operators/split_selected_rows_op.cc | 23 +----------------- .../fluid/operators/split_selected_rows_op.h | 1 + paddle/fluid/operators/sum_op.h | 4 +++- .../paddle/v2/fluid/distribute_transpiler.py | 4 ++++ 6 files changed, 47 insertions(+), 25 deletions(-) diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index 3730ae161f..426dd0dc0e 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -101,6 +101,9 @@ class ListenAndServOp : public framework::OperatorBase { // TODO(typhoonzero): change this to a while_op for every cluster-batch. bool exit_flag = false; + // Record received sparse variables, so that + // we could reset those after execute optimize program + std::vector sparse_vars; while (!exit_flag) { // Get from multiple trainers, we don't care about the order in which // the gradients arrives, just add suffix 0~n and merge the gradient. @@ -143,6 +146,9 @@ class ListenAndServOp : public framework::OperatorBase { PADDLE_THROW("Can not find server side var"); } detail::DeserializeFromMessage(v.second, dev_ctx, var); + if (var->IsType()) { + sparse_vars.push_back(var); + } } } VLOG(3) << "recv " << recv_var_cnt << " parmeters for one barrier."; @@ -156,9 +162,19 @@ class ListenAndServOp : public framework::OperatorBase { } catch (std::exception &e) { LOG(ERROR) << "run sub program error " << e.what(); } + + // Reset the received sparse variables, the sum operator would not + // sum the input sparse variables which rows is empty at the next + // mini-batch. + // TOOD(Yancey1989): move the reset action into an operator, we couldn't + // have any hide logic in the operator. + for (auto &var : sparse_vars) { + var->GetMutable()->mutable_rows()->clear(); + } rpc_service_->SetCond(1); rpc_service_->WaitClientGet(update_param_cnt); grads_counter_.clear(); + sparse_vars.clear(); } // while(true) } diff --git a/paddle/fluid/operators/send_op.cc b/paddle/fluid/operators/send_op.cc index a8390aa659..b241f738cb 100644 --- a/paddle/fluid/operators/send_op.cc +++ b/paddle/fluid/operators/send_op.cc @@ -24,6 +24,22 @@ limitations under the License. */ namespace paddle { namespace operators { +static bool IsVariableInitialized(const framework::Scope& scope, + const std::string& varname) { + auto* var = scope.FindVar(varname); + PADDLE_ENFORCE_NOT_NULL(var, "Can not find variable '%s' in the send side.", + varname); + if (var->IsType()) { + return var->Get().IsInitialized(); + } else if (var->IsType()) { + return var->Get().value().IsInitialized(); + } else { + PADDLE_THROW( + "Variable type in send side should be in " + "[LodTensor, SelectedRows]"); + } + return false; +} class SendOp : public framework::OperatorBase { public: @@ -51,8 +67,12 @@ class SendOp : public framework::OperatorBase { detail::RPCClient* rpc_client = client_var->GetMutable(); for (size_t i = 0; i < ins.size(); i++) { - VLOG(3) << "sending " << ins[i] << " to " << epmap[i]; - rpc_client->AsyncSendVariable(epmap[i], ctx, scope, ins[i]); + if (IsVariableInitialized(scope, ins[i])) { + VLOG(3) << "sending " << ins[i] << " to " << epmap[i]; + rpc_client->AsyncSendVariable(epmap[i], ctx, scope, ins[i]); + } else { + VLOG(3) << "don't send no-initialied variable: " << ins[i]; + } } PADDLE_ENFORCE(rpc_client->Wait()); diff --git a/paddle/fluid/operators/split_selected_rows_op.cc b/paddle/fluid/operators/split_selected_rows_op.cc index 113ce2ce10..c30280f654 100644 --- a/paddle/fluid/operators/split_selected_rows_op.cc +++ b/paddle/fluid/operators/split_selected_rows_op.cc @@ -22,7 +22,7 @@ class SplitSelectedRowsOpMaker : public framework::OpProtoAndCheckerMaker { SplitSelectedRowsOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "The input SelectedRows."); - AddOutput("Out", "The outputs of input SelectedRows.").AsDuplicable(); + AddOutput("Out", "The outputs of the input SelectedRows.").AsDuplicable(); AddAttr>("height_sections", "Height for each output SelectedRows.") .SetDefault(std::vector({})); @@ -56,27 +56,6 @@ class SplitSelectedRowsOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasInput("X"), "SplitSelectedRowsOp must has input X."); PADDLE_ENFORCE(ctx->HasOutputs("Out"), "SplitSelectedRowsOp must has output Out."); - - std::vector height_sections = - ctx->Attrs().Get>("height_sections"); - int64_t n = ctx->Outputs("Out").size(); - - std::vector outs_dims; - outs_dims.reserve(n); - - // make output dims - for (int64_t i = 0; i < n; ++i) { - auto dims = ctx->GetInputDim("X"); - if (height_sections.size()) { - PADDLE_ENFORCE_EQ( - height_sections.size(), static_cast(n), - "The size of height section should be the same with height" - " section size."); - dims[0] = height_sections[i]; - } - outs_dims.push_back(dims); - } - ctx->SetOutputsDim("Out", outs_dims); } }; diff --git a/paddle/fluid/operators/split_selected_rows_op.h b/paddle/fluid/operators/split_selected_rows_op.h index 527264bd67..af44b09b70 100644 --- a/paddle/fluid/operators/split_selected_rows_op.h +++ b/paddle/fluid/operators/split_selected_rows_op.h @@ -55,6 +55,7 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel { for (size_t i = 0; i < outs_rows_idx.size(); ++i) { auto rows_idx = outs_rows_idx[i]; + outs[i]->set_height(height_sections[i]); if (rows_idx.size() > 0) { auto dims = x->GetCompleteDims(); dims[0] = rows_idx.size(); diff --git a/paddle/fluid/operators/sum_op.h b/paddle/fluid/operators/sum_op.h index 5e1222c6ef..08218b6836 100644 --- a/paddle/fluid/operators/sum_op.h +++ b/paddle/fluid/operators/sum_op.h @@ -116,7 +116,9 @@ class SumKernel : public framework::OpKernel { int64_t offset = 0; for (int i = 0; i < N; i++) { auto &sel_row = get_selected_row(i); - + if (!sel_row.value().IsInitialized() || sel_row.rows().size() == 0) { + continue; + } PADDLE_ENFORCE_EQ(out->height(), sel_row.height()); functor(context.template device_context(), sel_row, offset, out); diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py index edef2b1b17..e4675e24b1 100644 --- a/python/paddle/v2/fluid/distribute_transpiler.py +++ b/python/paddle/v2/fluid/distribute_transpiler.py @@ -191,6 +191,7 @@ class DistributeTranspiler: for b in param_blocks: varname, block_id, _ = b.split(":") send_outputs.append(param_var_mapping[varname][int(block_id)]) + # let send_op know which endpoint to send which var to, eplist has the same # order as send_inputs. eplist = split_method(send_inputs, pserver_endpoints) @@ -274,6 +275,7 @@ class DistributeTranspiler: name="%s.block%d" % (varname, i), psersistable=False, dtype=orig_var.dtype, + type=orig_var.type, shape=splited_shape) # flattend splited var var_mapping[varname].append(var) return var_mapping @@ -335,6 +337,7 @@ class DistributeTranspiler: name="%s.trainer_%d" % (var.name, i), psersistable=var.persistable, dtype=var.dtype, + type=var.type, shape=var.shape) var_list.append(var_each) return var_list @@ -561,6 +564,7 @@ class DistributeTranspiler: persistable=True, dtype=v.dtype, shape=v.shape) + # step6 optimize_block = pserver_program.create_block(0) # step 6.1 From 18efe5aa1d8a6395dea68cfaa299fe636a22509e Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Sun, 11 Feb 2018 15:48:00 +0800 Subject: [PATCH 22/43] Fix CI --- paddle/fluid/framework/mixed_vector.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/framework/mixed_vector.h b/paddle/fluid/framework/mixed_vector.h index a35ec5d1de..b834d4633b 100644 --- a/paddle/fluid/framework/mixed_vector.h +++ b/paddle/fluid/framework/mixed_vector.h @@ -240,7 +240,7 @@ class Vector { // implicit cast operator. Vector can be cast to std::vector implicitly. operator std::vector() const { std::vector result; - if (size() == 0) { + if (size() != 0) { result.resize(size()); std::copy(begin(), end(), result.begin()); } From 2cfb2928dbe1b3c6848e9c4a8d187c3e1e4245ca Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Sun, 11 Feb 2018 16:44:52 +0800 Subject: [PATCH 23/43] Fix develop dist transpiler bug --- .../paddle/v2/fluid/distribute_transpiler.py | 78 ++++++++----------- 1 file changed, 34 insertions(+), 44 deletions(-) diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py index e4675e24b1..62d1f3434c 100644 --- a/python/paddle/v2/fluid/distribute_transpiler.py +++ b/python/paddle/v2/fluid/distribute_transpiler.py @@ -191,7 +191,6 @@ class DistributeTranspiler: for b in param_blocks: varname, block_id, _ = b.split(":") send_outputs.append(param_var_mapping[varname][int(block_id)]) - # let send_op know which endpoint to send which var to, eplist has the same # order as send_inputs. eplist = split_method(send_inputs, pserver_endpoints) @@ -230,21 +229,6 @@ class DistributeTranspiler: outputs={"Out": [orig_param]}, attrs={"axis": 0}) - self.lr_param_mapping = self._create_lr_param_mapping() - - def _create_lr_param_mapping(self): - lr_mapping = dict() - for _, opt_op in enumerate(self.optimize_ops): - if not opt_op.inputs or not opt_op.inputs.has_key("LearningRate") \ - or not opt_op.inputs.has_key("Param"): - continue - lr = opt_op.inputs["LearningRate"].name - param = opt_op.inputs["Param"].name - if not lr_mapping.has_key(lr): - lr_mapping.update({lr: list()}) - lr_mapping[lr].append(param) - return lr_mapping - def _create_vars_from_blocklist(self, program, block_list): # Create respective variables using the block_list block_map = dict() @@ -369,18 +353,19 @@ class DistributeTranspiler: pass return orig_shape - def _fetch_var_names(self, param_dict): - res = [] - if not param_dict: - return res - for _, values in param_dict.iteritems(): - if not isinstance(values, list): - values = [values] - res += [v.name for v in values] - return res + # def _fetch_var_names(self, param_dict): + # res = [] + # if not param_dict: + # return res + # for _, values in param_dict.iteritems(): + # if not isinstance(values, list): + # values = [values] + # res += [v.name for v in values] + # return res def _append_pserver_ops(self, optimize_block, opt_op, endpoint): program = optimize_block.program + pserver_block = program.global_block() new_inputs = dict() # update param/grad shape first, then other inputs like # moment can use the updated shape @@ -395,11 +380,11 @@ class DistributeTranspiler: # do not append this op if current endpoint # is not dealing with this grad block return - merged_var = program.global_block().vars[grad_block.name] + merged_var = pserver_block.vars[grad_block.name] # append merging ops if trainers > 1 if self.trainers > 1: vars2merge = self._create_var_for_trainers( - program.global_block(), grad_block, self.trainers) + pserver_block, grad_block, self.trainers) optimize_block.append_op( type="sum", inputs={"X": vars2merge}, @@ -419,29 +404,27 @@ class DistributeTranspiler: break if not param_block: return - tmpvar = program.global_block().create_var( + tmpvar = pserver_block.create_var( name=param_block.name, persistable=True, dtype=param_block.dtype, shape=param_block.shape) - new_inputs[key] = tmpvar elif key == "LearningRate": # leraning rate variable has already be created by non-optimize op, # don't create it once again. - new_inputs[key] = program.global_block().vars[opt_op.input(key)[ - 0]] + new_inputs[key] = pserver_block.vars[opt_op.input(key)[0]] for key in opt_op.input_names: new_shape = None if key in ["Param", "Grad", "LearningRate"]: continue - var = program.global_block().vars[opt_op.input(key)[0]] + var = self.program.global_block().vars[opt_op.input(key)[0]] # update accumulator variable shape param_shape = new_inputs["Param"].shape new_shape = self._get_optimizer_input_shape(opt_op.type, key, var.shape, param_shape) - tmpvar = program.global_block().create_var( + tmpvar = pserver_block.create_var( name=var.name, persistable=var.persistable, dtype=var.dtype, @@ -449,11 +432,14 @@ class DistributeTranspiler: new_inputs[key] = tmpvar # change output's ParamOut variable + outputs = self._get_output_map_from_op(self.program.global_block().vars, + opt_op) opt_op.outputs["ParamOut"] = new_inputs["Param"] + optimize_block.append_op( type=opt_op.type, inputs=new_inputs, - outputs=opt_op.outputs, + outputs=outputs, attrs=opt_op.attrs) def _append_pserver_non_opt_ops(self, optimize_block, opt_op): @@ -497,11 +483,16 @@ class DistributeTranspiler: # If one op's input is another op's output or # one op's output is another op's input, we say # the two operator is connected. - op1_input_names = self._fetch_var_names(op1.inputs) - op1_output_names = self._fetch_var_names(op1.outputs) + # op1_input_names = self._fetch_var_names(op1.inputs) + # op1_output_names = self._fetch_var_names(op1.outputs) + op1_input_names = op1.desc.input_arg_names() + op1_output_names = op1.desc.output_arg_names() + + # op2_input_names = self._fetch_var_names(op2.inputs) + # op2_output_names = self._fetch_var_names(op2.outputs) + op2_input_names = op2.desc.input_arg_names() + op2_output_names = op2.desc.output_arg_names() - op2_input_names = self._fetch_var_names(op2.inputs) - op2_output_names = self._fetch_var_names(op2.outputs) if set(op1_output_names) & set(op2_input_names) or \ set(op1_input_names) & set(op2_output_names): return True @@ -521,8 +512,8 @@ class DistributeTranspiler: def _is_opt_op(self, op): # NOTE: It's a HACK implement. # optimize op: SGDOptimize, MomentumOptimizer, AdamOptimizer and etc... - if op.inputs and op.inputs.has_key("Param") \ - and op.inputs.has_key("LearningRate"): + if "Param" in op.input_names and \ + "LearningRate" in op.input_names: return True return False @@ -530,12 +521,12 @@ class DistributeTranspiler: param_names = [ p.name for p in self.param_grad_ep_mapping[endpoint]["params"] ] - if op.inputs["Param"].name in param_names: + if op.input("Param") in param_names: return True else: for n in param_names: - param = op.inputs["Param"].name - if same_or_split_var(n, param) and n != op.inputs["Param"].name: + param = op.input("Param")[0] + if same_or_split_var(n, param) and n != param: return True return False return False @@ -564,7 +555,6 @@ class DistributeTranspiler: persistable=True, dtype=v.dtype, shape=v.shape) - # step6 optimize_block = pserver_program.create_block(0) # step 6.1 From 92ac30efd9bab1e7bcf9c0d98e3b44dd4edbc5a3 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Sun, 11 Feb 2018 16:47:10 +0800 Subject: [PATCH 24/43] remove comments --- python/paddle/v2/fluid/distribute_transpiler.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py index 62d1f3434c..ff84e609e2 100644 --- a/python/paddle/v2/fluid/distribute_transpiler.py +++ b/python/paddle/v2/fluid/distribute_transpiler.py @@ -353,16 +353,6 @@ class DistributeTranspiler: pass return orig_shape - # def _fetch_var_names(self, param_dict): - # res = [] - # if not param_dict: - # return res - # for _, values in param_dict.iteritems(): - # if not isinstance(values, list): - # values = [values] - # res += [v.name for v in values] - # return res - def _append_pserver_ops(self, optimize_block, opt_op, endpoint): program = optimize_block.program pserver_block = program.global_block() @@ -483,13 +473,9 @@ class DistributeTranspiler: # If one op's input is another op's output or # one op's output is another op's input, we say # the two operator is connected. - # op1_input_names = self._fetch_var_names(op1.inputs) - # op1_output_names = self._fetch_var_names(op1.outputs) op1_input_names = op1.desc.input_arg_names() op1_output_names = op1.desc.output_arg_names() - # op2_input_names = self._fetch_var_names(op2.inputs) - # op2_output_names = self._fetch_var_names(op2.outputs) op2_input_names = op2.desc.input_arg_names() op2_output_names = op2.desc.output_arg_names() From 628bb27a5144a3765884c6c13fc1dd1655c80a93 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Sun, 11 Feb 2018 15:29:52 +0800 Subject: [PATCH 25/43] refine prior_boxes --- python/paddle/v2/fluid/layers/__init__.py | 5 +- python/paddle/v2/fluid/layers/detection.py | 260 +++++++++++++++++++++ python/paddle/v2/fluid/layers/nn.py | 256 ++------------------ 3 files changed, 287 insertions(+), 234 deletions(-) create mode 100644 python/paddle/v2/fluid/layers/detection.py diff --git a/python/paddle/v2/fluid/layers/__init__.py b/python/paddle/v2/fluid/layers/__init__.py index a83dd3db74..f4fb2ca279 100644 --- a/python/paddle/v2/fluid/layers/__init__.py +++ b/python/paddle/v2/fluid/layers/__init__.py @@ -26,12 +26,15 @@ import device from device import * import math_op_patch from math_op_patch import * +import detection +from detection import * __all__ = [] +__all__ += math_op_patch.__all__ __all__ += nn.__all__ __all__ += io.__all__ __all__ += tensor.__all__ __all__ += control_flow.__all__ __all__ += ops.__all__ __all__ += device.__all__ -__all__ += math_op_patch.__all__ +__all__ += detection.__all__ diff --git a/python/paddle/v2/fluid/layers/detection.py b/python/paddle/v2/fluid/layers/detection.py new file mode 100644 index 0000000000..b0c25c11de --- /dev/null +++ b/python/paddle/v2/fluid/layers/detection.py @@ -0,0 +1,260 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +All layers just related to the detection neural network. +""" + +from ..layer_helper import LayerHelper +from ..framework import Variable +from ..param_attr import ParamAttr +from ..framework import Variable +from layer_function_generator import autodoc +from tensor import concat +from nn import flatten +import math + +__all__ = [ + 'prior_box', + 'prior_boxes', +] + + +def prior_box(input, + image, + min_sizes, + max_sizes, + aspect_ratios, + variance, + flip=False, + clip=False, + step_w=0.0, + step_h=0.0, + offset=0.5, + name=None): + """ + **Prior_box** + + Generate prior boxes for SSD(Single Shot MultiBox Detector) algorithm. + Each position of the input produce N prior boxes, N is determined by + the count of min_sizes, max_sizes and aspect_ratios, The size of the + box is in range(min_size, max_size) interval, which is generated in + sequence according to the aspect_ratios. + + Args: + input(variable): The input feature data of PriorBox, + the layout is NCHW. + image(variable): The input image data of PriorBox, the + layout is NCHW. + min_sizes(list): the min sizes of generated prior boxes. + max_sizes(list): the max sizes of generated prior boxes. + aspect_ratios(list): the aspect ratios of generated prior boxes. + variance(list): the variances to be encoded in prior boxes. + flip(bool, optional, default=False): Whether to flip aspect ratios. + clip(bool, optional, default=False)): Whether to clip + out-of-boundary boxes. + step_w(int, optional, default=0.0): Prior boxes step across + width, 0.0 for auto calculation. + step_h(int, optional, default=0.0): Prior boxes step across + height, 0.0 for auto calculation. + offset(float, optional, default=0.5): Prior boxes center offset. + name(str, optional, default=None): Name of the prior box layer. + + Returns: + boxes(variable): the output prior boxes of PriorBoxOp. The layout is + [H, W, num_priors, 4]. H is the height of input, W is the width + of input, num_priors is the box count of each position. Where num_priors = + len(aspect_ratios) * len(min_sizes) + len(max_sizes) + Variances(variable): the expanded variances of PriorBoxOp. The layout + is [H, W, num_priors, 4]. H is the height of input, W is the width + of input, num_priors is the box count of each position. Where num_priors = + len(aspect_ratios) * len(min_sizes) + len(max_sizes) + Examples: + .. code-block:: python + + data = fluid.layers.data(name="data", shape=[3, 32, 32], dtype="float32") + conv2d = fluid.layers.conv2d( + input=data, num_filters=2, filter_size=3) + box, var = fluid.layers.prior_box(conv2d, data, + min_size, max_size, aspect_ratio, + variance, flip, clip, + step_w, step_h, offset) + """ + helper = LayerHelper("prior_box", **locals()) + dtype = helper.input_dtype() + + box = helper.create_tmp_variable(dtype) + var = helper.create_tmp_variable(dtype) + helper.append_op( + type="prior_box", + inputs={"Input": input, + "Image": image}, + outputs={"Boxes": box, + "Variances": var}, + attrs={ + 'min_sizes': min_sizes, + 'max_sizes': max_sizes, + 'aspect_ratios': aspect_ratios, + 'variances': variance, + 'flip': flip, + 'clip': clip, + 'step_w': step_w, + 'step_h': step_h, + 'offset': offset + }) + return box, var + + +def prior_boxes(inputs, + image, + min_ratio, + max_ratio, + aspect_ratios, + base_size, + steps=None, + step_w=None, + step_h=None, + offset=0.5, + variance=[0.1, 0.1, 0.1, 0.1], + flip=False, + clip=False, + name=None): + """ + **Prior_boxes** + + Generate prior boxes for SSD(Single Shot MultiBox Detector) algorithm. + Each position of the inputs produces many prior boxes respectly, the number + of prior boxes which is produced by inputs respectly is determined by + the count of min_ratio, max_ratio and aspect_ratios, The size of the + box is in range(min_ratio, max_ratio) interval, which is generated in + sequence according to the aspect_ratios. + + Args: + inputs(list): The list of input variables, the format of all variables is NCHW. + image(variable): The input image data of PriorBoxOp, the layout is NCHW. + min_ratio(int): the min ratio of generated prior boxes. + max_ratio(int): the max ratio of generated prior boxes. + aspect_ratios(list): the aspect ratios of generated prior boxes. + The length of input and aspect_ratios must be equal. + base_size(int): the base_size is used to get min_size and max_size + according to min_ratio and max_ratio. + step_w(list, optional, default=None): Prior boxes step across width. + If step_w[i] == 0.0, the prior boxes step across width of the inputs[i] + will be automatically calculated. + step_h(list, optional, default=None): Prior boxes step across height, + If step_h[i] == 0.0, the prior boxes step across height of the inputs[i] + will be automatically calculated. + offset(float, optional, default=0.5): Prior boxes center offset. + variance(list, optional, default=[0.1, 0.1, 0.1, 0.1]): the variances + to be encoded in prior boxes. + flip(bool, optional, default=False): Whether to flip aspect ratios. + clip(bool, optional, default=False): Whether to clip out-of-boundary boxes. + name(str, optional, None): Name of the prior box layer. + + Returns: + boxes(variable): the output prior boxes of PriorBoxOp. The layout is + [num_priors, 4]. num_priors is the total box count of each + position of inputs. + Variances(variable): the expanded variances of PriorBoxOp. The layout + is [num_priors, 4]. num_priors is the total box count of each + position of inputs + + Examples: + .. code-block:: python + + prior_boxes( + inputs = [conv1, conv2, conv3, conv4, conv5, conv6], + image = data, + min_ratio = 20, # 0.20 + max_ratio = 90, # 0.90 + steps = [8., 16., 32., 64., 100., 300.], + aspect_ratios = [[2.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]], + base_size = 300, + offset = 0.5, + variance = [0.1,0.1,0.1,0.1], + flip=True, + clip=True) + """ + assert isinstance(inputs, list), 'inputs should be a list.' + num_layer = len(inputs) + assert num_layer > 2 # TODO(zcd): currently, num_layer must be bigger than two. + + min_sizes = [] + max_sizes = [] + if num_layer > 2: + step = int(math.floor(((max_ratio - min_ratio)) / (num_layer - 2))) + for ratio in xrange(min_ratio, max_ratio + 1, step): + min_sizes.append(base_size * ratio / 100.) + max_sizes.append(base_size * (ratio + step) / 100.) + min_sizes = [base_size * .10] + min_sizes + max_sizes = [base_size * .20] + max_sizes + + if step_h: + assert isinstance(step_h,list) and len(step_h) == num_layer, \ + 'step_h should be list and inputs and step_h should have same length' + if step_w: + assert isinstance(step_w,list) and len(step_w) == num_layer, \ + 'step_w should be list and inputs and step_w should have same length' + if steps: + assert isinstance(steps,list) and len(steps) == num_layer, \ + 'steps should be list and inputs and step_w should have same length' + step_w = steps + step_h = steps + if aspect_ratios: + assert isinstance(aspect_ratios, list) and len(aspect_ratios) == num_layer, \ + 'aspect_ratios should be list and inputs and aspect_ratios should ' \ + 'have same length' + + box_results = [] + var_results = [] + for i, input in enumerate(inputs): + min_size = min_sizes[i] + max_size = max_sizes[i] + aspect_ratio = [] + if not isinstance(min_size, list): + min_size = [min_size] + if not isinstance(max_size, list): + max_size = [max_size] + if aspect_ratios: + aspect_ratio = aspect_ratios[i] + if not isinstance(aspect_ratio, list): + aspect_ratio = [aspect_ratio] + + box, var = prior_box(input, image, min_size, max_size, aspect_ratio, + variance, flip, clip, step_w[i] + if step_w else 0.0, step_h[i] + if step_w else 0.0, offset) + + box_results.append(box) + var_results.append(var) + + if len(box_results) == 1: + box = box_results[0] + var = var_results[0] + else: + axis = 3 + reshaped_boxes = [] + reshaped_vars = [] + for i in range(len(box_results)): + reshaped_boxes += [flatten(box_results[i], axis=3)] + reshaped_vars += [flatten(var_results[i], axis=3)] + + helper = LayerHelper("concat", **locals()) + dtype = helper.input_dtype() + box = helper.create_tmp_variable(dtype) + var = helper.create_tmp_variable(dtype) + + box = concat(reshaped_boxes) + var = concat(reshaped_vars) + + return box, var diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index f0bcddaf9a..4d2de38c35 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -67,9 +67,8 @@ __all__ = [ 'beam_search', 'row_conv', 'reshape_with_axis', + 'flatten', 'multiplex', - 'prior_box', - 'prior_boxes', 'layer_norm', ] @@ -3149,242 +3148,33 @@ def reshape_with_axis(input, axis): return out -def prior_box(input, - image, - min_sizes, - max_sizes, - aspect_ratios, - variance, - flip=False, - clip=False, - step_w=0.0, - step_h=0.0, - offset=0.5, - name=None): +def flatten(input, axis=1): """ - **Prior_box** - - Generate prior boxes for SSD(Single Shot MultiBox Detector) algorithm. - Each position of the input produce N prior boxes, N is determined by - the count of min_sizes, max_sizes and aspect_ratios, The size of the - box is in range(min_size, max_size) interval, which is generated in - sequence according to the aspect_ratios. - + **Flatten Layer** + ReshapeWithAxis is used to merge adjacent dimensions according to axis. Args: - input(variable): The input feature data of PriorBox, - the layout is NCHW. - image(variable): The input image data of PriorBox, the - layout is NCHW. - min_sizes(list): the min sizes of generated prior boxes. - max_sizes(list): the max sizes of generated prior boxes. - aspect_ratios(list): the aspect ratios of generated prior boxes. - variance(list): the variances to be encoded in prior boxes. - flip(bool, optional, default=False): Whether to flip aspect ratios. - clip(bool, optional, default=False)): Whether to clip - out-of-boundary boxes. - step_w(int, optional, default=0.0): Prior boxes step across - width, 0.0 for auto calculation. - step_h(int, optional, default=0.0): Prior boxes step across - height, 0.0 for auto calculation. - offset(float, optional, default=0.5): Prior boxes center offset. - name(str, optional, default=None): Name of the prior box layer. - + input(variable): The input tensor. + axis(int): Returns: - boxes(variable): the output prior boxes of PriorBoxOp. The layout is - [H, W, num_priors, 4]. H is the height of input, W is the width - of input, num_priors is the box count of each position. Where num_priors = - len(aspect_ratios) * len(min_sizes) + len(max_sizes) - Variances(variable): the expanded variances of PriorBoxOp. The layout - is [H, W, num_priors, 4]. H is the height of input, W is the width - of input, num_priors is the box count of each position. Where num_priors = - len(aspect_ratios) * len(min_sizes) + len(max_sizes) + Variable: A tensor variable. Examples: .. code-block:: python - - data = fluid.layers.data(name="data", shape=[3, 32, 32], dtype="float32") - conv2d = fluid.layers.conv2d( - input=data, num_filters=2, filter_size=3) - box, var = fluid.layers.prior_box(conv2d, data, - min_size, max_size, aspect_ratio, - variance, flip, clip, - step_w, step_h, offset) + x = fluid.layers.data(name="data", shape=[3, 32, 32], dtype="float32") + reshaped = fluid.layers.reshape_with_axis(input=x, axis=2) + reshaped.shape + >> [-1, 1024] """ - helper = LayerHelper("prior_box", **locals()) - dtype = helper.input_dtype() - - box = helper.create_tmp_variable(dtype) - var = helper.create_tmp_variable(dtype) - helper.append_op( - type="prior_box", - inputs={"Input": input, - "Image": image}, - outputs={"Boxes": box, - "Variances": var}, - attrs={ - 'min_sizes': min_sizes, - 'max_sizes': max_sizes, - 'aspect_ratios': aspect_ratios, - 'variances': variance, - 'flip': flip, - 'clip': clip, - 'step_w': step_w, - 'step_h': step_h, - 'offset': offset - }) - return box, var - - -def prior_boxes(inputs, - image, - min_ratio, - max_ratio, - aspect_ratios, - base_size, - steps=None, - step_w=None, - step_h=None, - offset=0.5, - variance=[0.1, 0.1, 0.1, 0.1], - flip=False, - clip=False, - name=None): - """ - **Prior_boxes** - - Generate prior boxes for SSD(Single Shot MultiBox Detector) algorithm. - Each position of the inputs produces many prior boxes respectly, the number - of prior boxes which is produced by inputs respectly is determined by - the count of min_ratio, max_ratio and aspect_ratios, The size of the - box is in range(min_ratio, max_ratio) interval, which is generated in - sequence according to the aspect_ratios. - - Args: - inputs(list): The list of input variables, the format of all variables is NCHW. - image(variable): The input image data of PriorBoxOp, the layout is NCHW. - min_ratio(int): the min ratio of generated prior boxes. - max_ratio(int): the max ratio of generated prior boxes. - aspect_ratios(list): the aspect ratios of generated prior boxes. - The length of input and aspect_ratios must be equal. - base_size(int): the base_size is used to get min_size and max_size - according to min_ratio and max_ratio. - step_w(list, optional, default=None): Prior boxes step across width. - If step_w[i] == 0.0, the prior boxes step across width of the inputs[i] - will be automatically calculated. - step_h(list, optional, default=None): Prior boxes step across height, - If step_h[i] == 0.0, the prior boxes step across height of the inputs[i] - will be automatically calculated. - offset(float, optional, default=0.5): Prior boxes center offset. - variance(list, optional, default=[0.1, 0.1, 0.1, 0.1]): the variances - to be encoded in prior boxes. - flip(bool, optional, default=False): Whether to flip aspect ratios. - clip(bool, optional, default=False): Whether to clip out-of-boundary boxes. - name(str, optional, None): Name of the prior box layer. - - Returns: - boxes(variable): the output prior boxes of PriorBoxOp. The layout is - [num_priors, 4]. num_priors is the total box count of each - position of inputs. - Variances(variable): the expanded variances of PriorBoxOp. The layout - is [num_priors, 4]. num_priors is the total box count of each - position of inputs - - Examples: - .. code-block:: python + assert len(input.shape) > axis and axis > 0, \ + "the axis should be litter than input.shape's." + input_shape = input.shape - prior_boxes( - inputs = [conv1, conv2, conv3, conv4, conv5, conv6], - image = data, - min_ratio = 20, # 0.20 - max_ratio = 90, # 0.90 - steps = [8., 16., 32., 64., 100., 300.], - aspect_ratios = [[2.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]], - base_size = 300, - offset = 0.5, - variance = [0.1,0.1,0.1,0.1], - flip=True, - clip=True) - """ - assert isinstance(inputs, list), 'inputs should be a list.' - num_layer = len(inputs) - assert num_layer > 2 # TODO(zcd): currently, num_layer must be bigger than two. - - min_sizes = [] - max_sizes = [] - if num_layer > 2: - step = int(math.floor(((max_ratio - min_ratio)) / (num_layer - 2))) - for ratio in xrange(min_ratio, max_ratio + 1, step): - min_sizes.append(base_size * ratio / 100.) - max_sizes.append(base_size * (ratio + step) / 100.) - min_sizes = [base_size * .10] + min_sizes - max_sizes = [base_size * .20] + max_sizes - - if step_h: - assert isinstance(step_h,list) and len(step_h) == num_layer, \ - 'step_h should be list and inputs and step_h should have same length' - if step_w: - assert isinstance(step_w,list) and len(step_w) == num_layer, \ - 'step_w should be list and inputs and step_w should have same length' - if steps: - assert isinstance(steps,list) and len(steps) == num_layer, \ - 'steps should be list and inputs and step_w should have same length' - step_w = steps - step_h = steps - if aspect_ratios: - assert isinstance(aspect_ratios, list) and len(aspect_ratios) == num_layer, \ - 'aspect_ratios should be list and inputs and aspect_ratios should ' \ - 'have same length' - - box_results = [] - var_results = [] - for i, input in enumerate(inputs): - min_size = min_sizes[i] - max_size = max_sizes[i] - aspect_ratio = [] - if not isinstance(min_size, list): - min_size = [min_size] - if not isinstance(max_size, list): - max_size = [max_size] - if aspect_ratios: - aspect_ratio = aspect_ratios[i] - if not isinstance(aspect_ratio, list): - aspect_ratio = [aspect_ratio] - - box, var = prior_box(input, image, min_size, max_size, aspect_ratio, - variance, flip, clip, step_w[i] - if step_w else 0.0, step_h[i] - if step_w else 0.0, offset) - - box_results.append(box) - var_results.append(var) - - if len(box_results) == 1: - box = box_results[0] - var = var_results[0] - else: - axis = 3 - reshaped_boxes = [] - reshaped_vars = [] - for i in range(len(box_results)): - reshaped_boxes += [reshape_with_axis(box_results[i], axis=[axis])] - reshaped_vars += [reshape_with_axis(var_results[i], axis=[axis])] - - helper = LayerHelper("concat", **locals()) - dtype = helper.input_dtype() - box = helper.create_tmp_variable(dtype) - var = helper.create_tmp_variable(dtype) - - axis = 0 - helper.append_op( - type="concat", - inputs={"X": reshaped_boxes}, - outputs={"Out": box}, - attrs={'axis': axis}) + new_shape = [-1, reduce(mul, input_shape[axis:len(input_shape)], 1)] - var = helper.create_tmp_variable(dtype) - helper.append_op( - type="concat", - inputs={"X": reshaped_vars}, - outputs={"Out": var}, - attrs={'axis': axis}) - - return box, var + helper = LayerHelper('reshape', **locals()) + out = helper.create_tmp_variable(helper.input_dtype()) + helper.append_op( + type='reshape', + inputs={'X': [input]}, + outputs={'Out': [out]}, + attrs={'shape': new_shape}) + return out From 74f7aff397871b1f658e5f0d5195beb94794551f Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Sun, 11 Feb 2018 16:46:28 +0800 Subject: [PATCH 26/43] add unit test --- .../object_detection/test_prior_boxes.py | 87 +++++++++++++++++++ 1 file changed, 87 insertions(+) create mode 100644 python/paddle/v2/fluid/tests/object_detection/test_prior_boxes.py diff --git a/python/paddle/v2/fluid/tests/object_detection/test_prior_boxes.py b/python/paddle/v2/fluid/tests/object_detection/test_prior_boxes.py new file mode 100644 index 0000000000..50b5249d98 --- /dev/null +++ b/python/paddle/v2/fluid/tests/object_detection/test_prior_boxes.py @@ -0,0 +1,87 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import numpy as np +import paddle.v2.fluid as fluid +import paddle.v2.fluid.layers.detection as detection +import paddle.v2.fluid.core as core +import unittest + + +def prior_box_output(data_shape): + images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32') + conv1 = fluid.layers.conv2d( + input=images, num_filters=3, filter_size=3, stride=2, use_cudnn=False) + conv2 = fluid.layers.conv2d( + input=conv1, num_filters=3, filter_size=3, stride=2, use_cudnn=False) + conv3 = fluid.layers.conv2d( + input=conv2, num_filters=3, filter_size=3, stride=2, use_cudnn=False) + conv4 = fluid.layers.conv2d( + input=conv3, num_filters=3, filter_size=3, stride=2, use_cudnn=False) + conv5 = fluid.layers.conv2d( + input=conv4, num_filters=3, filter_size=3, stride=2, use_cudnn=False) + + box, var = detection.prior_boxes( + inputs=[conv1, conv2, conv3, conv4, conv5, conv5], + image=images, + min_ratio=20, + max_ratio=90, + # steps=[8, 16, 32, 64, 100, 300], + aspect_ratios=[[2.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]], + base_size=300, + offset=0.5, + flip=True, + clip=True) + return box, var + + +def main(use_cuda): + if use_cuda: # prior_box only support CPU. + return + + box, var = prior_box_output(data_shape=[3, 224, 224]) + + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + batch = [128] + + for i in range(1): + # print("iteration : %d" % i) + x = np.random.random(batch + data_shape).astype("float32") + tensor_x = core.LoDTensor() + tensor_x.set(x, place) + box, var = exe.run(fluid.default_main_program(), + feed={'pixel': tensor_x}, + fetch_list=[box, var]) + box_arr = np.array(box) + var_arr = np.array(var) + assert box_arr.shape[1] == 4 + assert var_arr.shape[1] == 4 + assert box_arr.shape[0] == var_arr.shape[0] + + +class TestFitALine(unittest.TestCase): + def test_cpu(self): + with self.program_scope_guard(): + main(use_cuda=False) + + def test_cuda(self): + with self.program_scope_guard(): + main(use_cuda=True) + + +if __name__ == '__main__': + unittest.main() From 01f4bcb57ee99a9d5c2d52cbc7cbce4c4a0454c8 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Sun, 11 Feb 2018 16:54:01 +0800 Subject: [PATCH 27/43] remove inputs/outputs from Operator --- python/paddle/v2/fluid/distribute_transpiler.py | 2 +- python/paddle/v2/fluid/framework.py | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py index ff84e609e2..f84481adf7 100644 --- a/python/paddle/v2/fluid/distribute_transpiler.py +++ b/python/paddle/v2/fluid/distribute_transpiler.py @@ -424,7 +424,7 @@ class DistributeTranspiler: # change output's ParamOut variable outputs = self._get_output_map_from_op(self.program.global_block().vars, opt_op) - opt_op.outputs["ParamOut"] = new_inputs["Param"] + outputs["ParamOut"] = new_inputs["Param"] optimize_block.append_op( type=opt_op.type, diff --git a/python/paddle/v2/fluid/framework.py b/python/paddle/v2/fluid/framework.py index a517db68c5..35d3df785b 100644 --- a/python/paddle/v2/fluid/framework.py +++ b/python/paddle/v2/fluid/framework.py @@ -400,9 +400,6 @@ class Operator(object): """ self.block = block self.desc = desc - # for clone a new operator - self.inputs = inputs - self.outputs = outputs self.attrs = attrs if len(self.desc.type()) != 0: return From 07bb4139776cafecbdcc11d663e38e22a2163a96 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Sun, 11 Feb 2018 16:59:57 +0800 Subject: [PATCH 28/43] Revert changes --- paddle/fluid/framework/mixed_vector.h | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/framework/mixed_vector.h b/paddle/fluid/framework/mixed_vector.h index 837b5fa7f6..a06e34d551 100644 --- a/paddle/fluid/framework/mixed_vector.h +++ b/paddle/fluid/framework/mixed_vector.h @@ -110,10 +110,6 @@ class Vector { T* end() { return size() == 0 ? &EmptyDummy() : &this->operator[](size()); } - const T* cbegin() const { return begin(); } - - const T* cend() const { return end(); } - T& front() { return *begin(); } T& back() { @@ -125,10 +121,15 @@ class Vector { const T* begin() const { return size() == 0 ? &EmptyDummy() : &this->operator[](0); } + const T* end() const { return size() == 0 ? &EmptyDummy() : &this->operator[](size()); } + const T* cbegin() const { return begin(); } + + const T* cend() const { return end(); } + const T& back() const { auto it = end(); --it; @@ -244,10 +245,8 @@ class Vector { // implicit cast operator. Vector can be cast to std::vector implicitly. operator std::vector() const { std::vector result; - if (size() != 0) { - result.resize(size()); - std::copy(begin(), end(), result.begin()); - } + result.resize(size()); + std::copy(begin(), end(), result.begin()); return result; } From 42912e48cb0ce5c62c450e42373de09b04c30513 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Sun, 11 Feb 2018 17:28:01 +0800 Subject: [PATCH 29/43] rename switch_kernel.md to kernel_selection.md --- doc/design/{switch_kernel.md => kernel_selection.md} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename doc/design/{switch_kernel.md => kernel_selection.md} (100%) diff --git a/doc/design/switch_kernel.md b/doc/design/kernel_selection.md similarity index 100% rename from doc/design/switch_kernel.md rename to doc/design/kernel_selection.md From 5d5dcedc841452e4135275087ffc0ad03ebf47f4 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Sun, 11 Feb 2018 18:45:25 +0800 Subject: [PATCH 30/43] merge build docs with build using docker --- doc/build_and_install/build_cn.md | 124 ------------------ doc/build_and_install/build_en.md | 124 ------------------ .../build_from_source_cn.rst | 111 +++++++++++++--- .../build_from_source_en.rst | 120 +++++++++++++---- 4 files changed, 182 insertions(+), 297 deletions(-) delete mode 100644 doc/build_and_install/build_cn.md delete mode 100644 doc/build_and_install/build_en.md diff --git a/doc/build_and_install/build_cn.md b/doc/build_and_install/build_cn.md deleted file mode 100644 index 4a80a52451..0000000000 --- a/doc/build_and_install/build_cn.md +++ /dev/null @@ -1,124 +0,0 @@ -# 用Docker编译和测试PaddlePaddle - -## 需要的软硬件 - -为了开发PaddlePaddle,我们需要 - -1. 一台电脑,可以装的是 Linux, BSD, Windows 或者 MacOS 操作系统,以及 -1. Docker。 - -不需要依赖其他任何软件了。即便是 Python 和 GCC 都不需要,因为我们会把所有编译工具都安装进一个 Docker image 里。 - -## 总体流程 - -1. 获取源码 - - ```bash - git clone https://github.com/paddlepaddle/paddle - ``` - -2. 安装开发工具到 Docker image 里 - - ```bash - cd paddle; docker build -t paddle:dev . - ``` - - 请注意这个命令结尾处的 `.`;它表示 `docker build` 应该读取当前目录下的 [`Dockerfile`文件](https://github.com/PaddlePaddle/Paddle/blob/develop/Dockerfile),按照其内容创建一个名为 `paddle:dev` 的 Docker image,并且把各种开发工具安装进去。 - -3. 编译 - - 以下命令启动一个 Docker container 来执行 `paddle:dev` 这个 Docker image,同时把当前目录(源码树根目录)映射为 container 里的 `/paddle` 目录,并且运行 `Dockerfile` 描述的默认入口程序 [`build.sh`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/scripts/docker/build.sh)。这个脚本调用 `cmake` 和 `make` 来编译 `/paddle` 里的源码,结果输出到 `/paddle/build`,也就是本地的源码树根目录里的 `build` 子目录。 - - ```bash - docker run --rm -v $PWD:/paddle paddle:dev - ``` - - 上述命令编译出一个 CUDA-enabled 版本。如果我们只需要编译一个只支持 CPU 的版本,可以用 - - ```bash - docker run --rm -e WITH_GPU=OFF -v $PWD:/paddle paddle:dev - ``` - -4. 运行单元测试 - - 用本机的第一个 GPU 来运行包括 GPU 单元测试在内的所有单元测试: - - ```bash - NV_GPU=0 nvidia-docker run --rm -v $PWD:/paddle paddle:dev bash -c "cd /paddle/build; ctest" - ``` - - 如果编译的时候我们用了 `WITH_GPU=OFF` 选项,那么编译过程只会产生 CPU-based 单元测试,那么我们也就不需要 nvidia-docker 来运行单元测试了。我们只需要: - - ```bash - docker run --rm -v $PWD:/paddle paddle:dev bash -c "cd /paddle/build; ctest" - ``` - - 有时候我们只想运行一个特定的单元测试,比如 `memory_test`,我们可以 - - ```bash - nvidia-docker run --rm -v $PWD:/paddle paddle:dev bash -c "cd /paddle/build; ctest -V -R memory_test" - ``` - -5. 清理 - - 有时候我们会希望清理掉已经下载的第三方依赖以及已经编译的二进制文件。此时只需要: - - ```bash - rm -rf build - ``` - -## 为什么要 Docker 呀? - -- 什么是 Docker? - - 如果您没有听说 Docker,可以把它想象为一个类似 virtualenv 的系统,但是虚拟的不仅仅是 Python 的运行环境。 - -- Docker 还是虚拟机? - - 有人用虚拟机来类比 Docker。需要强调的是:Docker 不会虚拟任何硬件,Docker container 里运行的编译工具实际上都是在本机的 CPU 和操作系统上直接运行的,性能和把编译工具安装在本机运行一样。 - -- 为什么用 Docker? - - 把工具和配置都安装在一个 Docker image 里可以标准化编译环境。这样如果遇到问题,其他人可以复现问题以便帮助。 - - 另外,对于习惯使用Windows和MacOS的开发者来说,使用Docker就不用配置交叉编译环境了。 - -- 我可以选择不用Docker吗? - - 当然可以。大家可以用把开发工具安装进入 Docker image 一样的方式,把这些工具安装到本机。这篇文档介绍基于 Docker 的开发流程,是因为这个流程比其他方法都更简便。 - -- 学习 Docker 有多难? - - 理解 Docker 并不难,大概花十分钟看一下[这篇文章](https://zhuanlan.zhihu.com/p/19902938)。这可以帮您省掉花一小时安装和配置各种开发工具,以及切换机器时需要新安装的辛苦。别忘了 PaddlePaddle 更新可能导致需要新的开发工具。更别提简化问题复现带来的好处了。 - -- 我可以用 IDE 吗? - - 当然可以,因为源码就在本机上。IDE 默认调用 make 之类的程序来编译源码,我们只需要配置 IDE 来调用 Docker 命令编译源码即可。 - - 很多 PaddlePaddle 开发者使用 Emacs。他们在自己的 `~/.emacs` 配置文件里加两行 - - ```emacs - (global-set-key "\C-cc" 'compile) - (setq compile-command - "docker run --rm -it -v $(git rev-parse --show-toplevel):/paddle paddle:dev") - ``` - - 就可以按 `Ctrl-C` 和 `c` 键来启动编译了。 - -- 可以并行编译吗? - - 是的。我们的 Docker image 运行一个 [Bash 脚本](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/scripts/docker/build.sh)。这个脚本调用 `make -j$(nproc)` 来启动和 CPU 核一样多的进程来并行编译。 - -## 可能碰到的问题 - -- Docker 需要 sudo - - 如果用自己的电脑开发,自然也就有管理员权限(sudo)了。如果用公用的电脑开发,需要请管理员安装和配置好 Docker。此外,PaddlePaddle 项目在努力开始支持其他不需要 sudo 的集装箱技术,比如 rkt。 - -- 在 Windows/MacOS 上编译很慢 - - Docker 在 Windows 和 MacOS 都可以运行。不过实际上是运行在一个 Linux 虚拟机上。可能需要注意给这个虚拟机多分配一些 CPU 和内存,以保证编译高效。具体做法请参考[这个issue](https://github.com/PaddlePaddle/Paddle/issues/627)。 - -- 磁盘不够 - - 本文中的例子里,`docker run` 命令里都用了 `--rm` 参数,这样保证运行结束之后的 containers 不会保留在磁盘上。可以用 `docker ps -a` 命令看到停止后但是没有删除的 containers。`docker build` 命令有时候会产生一些中间结果,是没有名字的 images,也会占用磁盘。可以参考[这篇文章](https://zaiste.net/posts/removing_docker_containers/)来清理这些内容。 diff --git a/doc/build_and_install/build_en.md b/doc/build_and_install/build_en.md deleted file mode 100644 index 91c41ef8ce..0000000000 --- a/doc/build_and_install/build_en.md +++ /dev/null @@ -1,124 +0,0 @@ -# Build using Docker - -## What Developers Need - -To contribute to PaddlePaddle, you need - -1. A computer -- Linux, BSD, Windows, MacOS, and -1. Docker. - -Nothing else. Not even Python and GCC, because you can install all build tools into a Docker image. We run all the tools by running this image. - -## General Process - -1. Retrieve source code. - - ```bash - git clone https://github.com/paddlepaddle/paddle - ``` - -2. Install build tools into a Docker image. - - ```bash - cd paddle; docker build -t paddle:dev . - ``` - - Please be aware of the `.` at the end of the command, which refers to the [`./Dockerfile` file](https://github.com/PaddlePaddle/Paddle/blob/develop/Dockerfile). `docker build` follows instructions in this file to create a Docker image named `paddle:dev`, and installs building tools into it. - -3. Build from source. - - This following command starts a Docker container that executes the Docker image `paddle:dev`, mapping the current directory to `/paddle/` in the container, and runs the default entry-point [`build.sh`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/scripts/docker/build.sh) as specified in the Dockefile. `build.sh` invokes `cmake` and `make` to build PaddlePaddle source code, which had been mapped to `/paddle`, and writes outputs to `/paddle/build`, which maps to `build` in the current source directory on the computer. - - ```bash - docker run -v $PWD:/paddle paddle:dev - ``` - - Above command builds a CUDA-enabled version. If we want to build a CPU-only version, we can type - - ```bash - docker run -e WITH_GPU=OFF -v $PWD:/paddle paddle:dev - ``` - -4. Run unit tests. - - To run all unit tests using the first GPU of a node: - - ```bash - NV_GPU=0 nvidia-docker run -v $PWD:/paddle paddle:dev bash -c "cd /paddle/build; ctest" - ``` - - If we used `WITH_GPU=OFF` at build time, it generates only CPU-based unit tests, and we don't need nvidia-docker to run them. We can just run - - ```bash - docker run -v $PWD:/paddle paddle:dev bash -c "cd /paddle/build; ctest" - ``` - - Sometimes we want to run a specific unit test, say `memory_test`, we can run - - ```bash - nvidia-docker run -v $PWD:/paddle paddle:dev bash -c "cd /paddle/build; ctest -V -R memory_test" - ``` - -5. Clean Build. - - Sometimes, we might want to clean all thirt-party dependents and built binaries. To do so, just - - ```bash - rm -rf build - ``` - -## Docker, Or Not? - -- What is Docker? - - If you haven't heard of it, consider it something like Python's virtualenv. - -- Docker or virtual machine? - - Some people compare Docker with VMs, but Docker doesn't virtualize any hardware nor running a guest OS, which means there is no compromise on the performance. - -- Why Docker? - - Using a Docker image of build tools standardizes the building environment, which makes it easier for others to reproduce your problems and to help. - - Also, some build tools don't run on Windows or Mac or BSD, but Docker runs almost everywhere, so developers can use whatever computer they want. - -- Can I choose not to use Docker? - - Sure, you don't have to install build tools into a Docker image; instead, you can install them in your local computer. This document exists because Docker would make the development way easier. - -- How difficult is it to learn Docker? - - It takes you ten minutes to read [an introductory article](https://docs.docker.com/get-started) and saves you more than one hour to install all required build tools, configure them, especially when new versions of PaddlePaddle require some new tools. Not even to mention the time saved when other people trying to reproduce the issue you have. - -- Can I use my favorite IDE? - - Yes, of course. The source code resides on your local computer, and you can edit it using whatever editor you like. - - Many PaddlePaddle developers are using Emacs. They add the following few lines into their `~/.emacs` configure file: - - ```emacs - (global-set-key "\C-cc" 'compile) - (setq compile-command - "docker run --rm -it -v $(git rev-parse --show-toplevel):/paddle paddle:dev") - ``` - - so they could type `Ctrl-C` and `c` to build PaddlePaddle from source. - -- Does Docker do parallel building? - - Our building Docker image runs a [Bash script](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/scripts/docker/build.sh), which calls `make -j$(nproc)` to starts as many processes as the number of your CPU cores. - -## Some Gotchas - -- Docker requires sudo - - An owner of a computer has the administrative privilege, a.k.a., sudo, and Docker requires this privilege to work properly. If you use a shared computer for development, please ask the administrator to install and configure Docker. We will do our best to support rkt, another container technology that doesn't require sudo. - -- Docker on Windows/MacOS builds slowly - - On Windows and MacOS, Docker containers run in a Linux VM. You might want to give this VM some more memory and CPUs so to make the building efficient. Please refer to [this issue](https://github.com/PaddlePaddle/Paddle/issues/627) for details. - -- Not enough disk space - - Examples in this article uses option `--rm` with the `docker run` command. This option ensures that stopped containers do not exist on hard disks. We can use `docker ps -a` to list all containers, including stopped. Sometimes `docker build` generates some intermediate dangling images, which also take disk space. To clean them, please refer to [this article](https://zaiste.net/posts/removing_docker_containers/). diff --git a/doc/build_and_install/build_from_source_cn.rst b/doc/build_and_install/build_from_source_cn.rst index ff904b1022..fec2d412f0 100644 --- a/doc/build_and_install/build_from_source_cn.rst +++ b/doc/build_and_install/build_from_source_cn.rst @@ -1,14 +1,26 @@ 从源码编译 ====================== +.. _requirements: + +需要的软硬件 +---------------- + +为了编译PaddlePaddle,我们需要 + +1. 一台电脑,可以装的是 Linux, Windows 或者 MacOS 操作系统 +1. Docker + +不需要依赖其他任何软件了。即便是 Python 和 GCC 都不需要,因为我们会把所有编译工具都安装进一个 Docker 镜像里。 + .. _build_step: 编译方法 ---------------- -PaddlePaddle主要使用 `CMake `_ 以及GCC, G++作为编译工具。 -我们推荐您使用PaddlePaddle Docker编译环境镜像完成编译,这样可以免去单独安装编译依赖的步骤,可选的不同编译环境Docker镜像 -可以在 `这里 `_ 找到。 +PaddlePaddle需要使用Docker环境完成编译,这样可以免去单独安装编译依赖的步骤,可选的不同编译环境Docker镜像 +可以在 `这里 `_ 找到。或者 +参考下述可选步骤,从源码中构建用于编译PaddlePaddle的Docker镜像。 如果您选择不使用Docker镜像,则需要在本机安装下面章节列出的 `编译依赖`_ 之后才能开始编译的步骤。 @@ -16,15 +28,19 @@ PaddlePaddle主要使用 `CMake `_ 以及GCC, G++作为编译 .. code-block:: bash + # 1. 获取源码 git clone https://github.com/PaddlePaddle/Paddle.git cd Paddle - # 如果使用Docker编译环境,执行下面的命令编译CPU-Only的二进制 + # 2. 可选步骤:源码中构建用于编译PaddlePaddle的Docker镜像 + docker build -t paddle:dev . + # 3. 执行下面的命令编译CPU-Only的二进制 docker run -it -v $PWD:/paddle -e "WITH_GPU=OFF" -e "WITH_TESTING=OFF" paddlepaddle/paddle_manylinux_devel:cuda8.0_cudnn5 bash -x /paddle/paddle/scripts/docker/build.sh - # 如果不使用Docker编译环境,执行下面的命令 - mkdir build - cd build - cmake -DWITH_GPU=OFF -DWITH_TESTING=OFF .. - make + # 4. 或者也可以使用为上述可选步骤构建的镜像(必须先执行第2步) + docker run -it -v $PWD:/paddle -e "WITH_GPU=OFF" -e "WITH_TESTING=OFF" paddle:dev + +注:上述命令把当前目录(源码树根目录)映射为 container 里的 :code:`/paddle` 目录。如果使用自行 +构建的镜像(上述第4步)会执行 :code:`Dockerfile` 描述的默认入口程序 :code:`build.sh` 可以省略步骤3中 +最后的执行脚本的命令。 编译完成后会在build/python/dist目录下生成输出的whl包,可以选在在当前机器安装也可以拷贝到目标机器安装: @@ -50,28 +66,83 @@ PaddlePaddle主要使用 `CMake `_ 以及GCC, G++作为编译 如果您期望在编译完成后立即执行所有的单元测试,可以按照下面的方法: -使用Docker的情况下,设置 :code:`RUN_TEST=ON` 和 :code:`WITH_TESTING=ON` 就会在完成编译之后,立即执行单元测试。 +设置 :code:`RUN_TEST=ON` 和 :code:`WITH_TESTING=ON` 就会在完成编译之后,立即执行单元测试。 开启 :code:`WITH_GPU=ON` 可以指定同时执行GPU上的单元测试。 .. code-block:: bash docker run -it -v $PWD:/paddle -e "WITH_GPU=OFF" -e "WITH_TESTING=ON" -e "RUN_TEST=ON" paddlepaddle/paddle_manylinux_devel:cuda8.0_cudnn5 bash -x /paddle/paddle/scripts/docker/build.sh -如果不使用Docker,可以执行ctest命令即可: +如果期望执行其中一个单元测试,(比如 :code:`test_sum_op` ): .. code-block:: bash - mkdir build - cd build - cmake -DWITH_GPU=OFF -DWITH_TESTING=OFF .. - make - ctest - # 指定执行其中一个单元测试 test_mul_op - ctest -R test_mul_op + docker run -it -v $PWD:/paddle -e "WITH_GPU=OFF" -e "WITH_TESTING=ON" -e "RUN_TEST=OFF" paddlepaddle/paddle_manylinux_devel:cuda8.0_cudnn5 /bin/bash + bash /paddle/paddle/scripts/docker/build.sh + cd /paddle/build + ctest -R test_sum_op -V + +.. _faq_docker: + +常见问题 +---------------- + +- 什么是 Docker? + + 如果您没有听说 Docker,可以把它想象为一个类似 virtualenv 的系统,但是虚拟的不仅仅是 Python 的运行环境。 + +- Docker 还是虚拟机? + + 有人用虚拟机来类比 Docker。需要强调的是:Docker 不会虚拟任何硬件,Docker container 里运行的编译工具实际上都是在本机的 CPU 和操作系统上直接运行的,性能和把编译工具安装在本机运行一样。 + +- 为什么用 Docker? + + 把工具和配置都安装在一个 Docker image 里可以标准化编译环境。这样如果遇到问题,其他人可以复现问题以便帮助。 + + 另外,对于习惯使用Windows和MacOS的开发者来说,使用Docker就不用配置交叉编译环境了。 + +- 我可以选择不用Docker吗? + + 当然可以。大家可以用把开发工具安装进入 Docker image 一样的方式,把这些工具安装到本机。这篇文档介绍基于 Docker 的开发流程,是因为这个流程比其他方法都更简便。 + +- 学习 Docker 有多难? + + 理解 Docker 并不难,大概花十分钟看一下[这篇文章](https://zhuanlan.zhihu.com/p/19902938)。这可以帮您省掉花一小时安装和配置各种开发工具,以及切换机器时需要新安装的辛苦。别忘了 PaddlePaddle 更新可能导致需要新的开发工具。更别提简化问题复现带来的好处了。 + +- 我可以用 IDE 吗? + + 当然可以,因为源码就在本机上。IDE 默认调用 make 之类的程序来编译源码,我们只需要配置 IDE 来调用 Docker 命令编译源码即可。 + + 很多 PaddlePaddle 开发者使用 Emacs。他们在自己的 `~/.emacs` 配置文件里加两行 + + ```emacs + (global-set-key "\C-cc" 'compile) + (setq compile-command + "docker run --rm -it -v $(git rev-parse --show-toplevel):/paddle paddle:dev") + ``` + + 就可以按 `Ctrl-C` 和 `c` 键来启动编译了。 + +- 可以并行编译吗? + + 是的。我们的 Docker image 运行一个 [Bash 脚本](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/scripts/docker/build.sh)。这个脚本调用 `make -j$(nproc)` 来启动和 CPU 核一样多的进程来并行编译。 + +- Docker 需要 sudo + + 如果用自己的电脑开发,自然也就有管理员权限(sudo)了。如果用公用的电脑开发,需要请管理员安装和配置好 Docker。此外,PaddlePaddle 项目在努力开始支持其他不需要 sudo 的集装箱技术,比如 rkt。 + +- 在 Windows/MacOS 上编译很慢 + + Docker 在 Windows 和 MacOS 都可以运行。不过实际上是运行在一个 Linux 虚拟机上。可能需要注意给这个虚拟机多分配一些 CPU 和内存,以保证编译高效。具体做法请参考[这个issue](https://github.com/PaddlePaddle/Paddle/issues/627)。 + +- 磁盘不够 + + 本文中的例子里,`docker run` 命令里都用了 `--rm` 参数,这样保证运行结束之后的 containers 不会保留在磁盘上。可以用 `docker ps -a` 命令看到停止后但是没有删除的 containers。`docker build` 命令有时候会产生一些中间结果,是没有名字的 images,也会占用磁盘。可以参考[这篇文章](https://zaiste.net/posts/removing_docker_containers/)来清理这些内容。 + .. _compile_deps: -编译依赖 +附录:编译依赖 ---------------- PaddlePaddle编译需要使用到下面的依赖(包含但不限于),其他的依赖软件,会自动在编译时下载。 @@ -91,7 +162,7 @@ PaddlePaddle编译需要使用到下面的依赖(包含但不限于),其 .. _build_options: -编译选项 +附录:编译选项 ---------------- PaddlePaddle的编译选项,包括生成CPU/GPU二进制文件、链接何种BLAS库等。 diff --git a/doc/build_and_install/build_from_source_en.rst b/doc/build_and_install/build_from_source_en.rst index 718fb869c2..29a1439e4c 100644 --- a/doc/build_and_install/build_from_source_en.rst +++ b/doc/build_and_install/build_from_source_en.rst @@ -1,32 +1,45 @@ Build from Sources ========================== -.. _build_step: +.. _requirements: -How To Build +Requirements ---------------- -PaddlePaddle mainly uses `CMake `_ and GCC, G++ as compile -tools. We recommend you to use our pre-built Docker image to run the build -to avoid installing dependencies by yourself. We have several build environment -Docker images `here `_ . +To build PaddlePaddle, you need + +1. A computer -- Linux, Windows, MacOS. +1. Docker. + +Nothing else. Not even Python and GCC, because you can install all build tools into a Docker image. +We run all the tools by running this image. + +.. _build_step: -If you choose not to use Docker image for your build, you need to install the -below `Compile Dependencies`_ before run the build. +How To Build +---------------- -Then run: +You need to use Docker to build PaddlePaddle +to avoid installing dependencies by yourself. We have several pre-built +Docker images `here `_ , +Or you can build your own image from source as the optional step below: .. code-block:: bash + # 1. clone the source code git clone https://github.com/PaddlePaddle/Paddle.git cd Paddle - # run the following command to build a CPU-Only binaries if you are using docker + # 2. Optional: build development docker image from source + docker build -t paddle:dev . + # 3. Run the following command to build a CPU-Only binaries docker run -it -v $PWD:/paddle -e "WITH_GPU=OFF" -e "WITH_TESTING=OFF" paddlepaddle/paddle_manylinux_devel:cuda8.0_cudnn5 bash -x /paddle/paddle/scripts/docker/build.sh - # else run these commands - mkdir build - cd build - cmake -DWITH_GPU=OFF -DWITH_TESTING=OFF .. - make + # 4. Or, use your built Docker image to build PaddlePaddle (must run step 2) + docker run -it -v $PWD:/paddle -e "WITH_GPU=OFF" -e "WITH_TESTING=OFF" paddle:dev + +NOTE: The above command try to mount the current working directory (root directory of source code) +into :code:`/paddle` directory inside docker container. If you are using your own image +(Step 4) it will run default entry-point :code:`build.sh` , so you could omit the last +command in step 3. When the compile finishes, you can get the output whl package under build/python/dist, then you can choose to install the whl on local @@ -61,22 +74,75 @@ Set :code:`WITH_GPU=ON` Can also run tests on GPU. docker run -it -v $PWD:/paddle -e "WITH_GPU=OFF" -e "WITH_TESTING=ON" -e "RUN_TEST=ON" paddlepaddle/paddle_manylinux_devel:cuda8.0_cudnn5 bash -x paddle/paddle/scripts/docker/build.sh -If you don't use Docker, just run ctest will start the tests: +If you wish to run only one unit test, like :code:`test_sum_op`: .. code-block:: bash - mkdir build - cd build - cmake -DWITH_GPU=OFF -DWITH_TESTING=ON .. - make - ctest - # run a single test like test_mul_op - ctest -R test_mul_op + docker run -it -v $PWD:/paddle -e "WITH_GPU=OFF" -e "WITH_TESTING=ON" -e "RUN_TEST=OFF" paddlepaddle/paddle_manylinux_devel:cuda8.0_cudnn5 /bin/bash + bash /paddle/paddle/scripts/docker/build.sh + cd /paddle/build + ctest -R test_sum_op -V + +.. _faq_docker: + +Frequently Asked Questions +---------------- + +- What is Docker? + + If you haven't heard of it, consider it something like Python's virtualenv. + +- Docker or virtual machine? + + Some people compare Docker with VMs, but Docker doesn't virtualize any hardware nor running a guest OS, which means there is no compromise on the performance. + +- Why Docker? + + Using a Docker image of build tools standardizes the building environment, which makes it easier for others to reproduce your problems and to help. + + Also, some build tools don't run on Windows or Mac or BSD, but Docker runs almost everywhere, so developers can use whatever computer they want. +- Can I choose not to use Docker? + + Sure, you don't have to install build tools into a Docker image; instead, you can install them on your local computer. This document exists because Docker would make the development way easier. + +- How difficult is it to learn Docker? + + It takes you ten minutes to read [an introductory article](https://docs.docker.com/get-started) and saves you more than one hour to install all required build tools, configure them, especially when new versions of PaddlePaddle require some new tools. Not even to mention the time saved when other people trying to reproduce the issue you have. + +- Can I use my favorite IDE? + + Yes, of course. The source code resides on your local computer, and you can edit it using whatever editor you like. + + Many PaddlePaddle developers are using Emacs. They add the following few lines into their `~/.emacs` configure file: + + ```emacs + (global-set-key "\C-cc" 'compile) + (setq compile-command + "docker run --rm -it -v $(git rev-parse --show-toplevel):/paddle paddle:dev") + ``` + + so they could type `Ctrl-C` and `c` to build PaddlePaddle from source. + +- Does Docker do parallel building? + + Our building Docker image runs a [Bash script](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/scripts/docker/build.sh), which calls `make -j$(nproc)` to starts as many processes as the number of your CPU cores. + +- Docker requires sudo + + An owner of a computer has the administrative privilege, a.k.a., sudo, and Docker requires this privilege to work properly. If you use a shared computer for development, please ask the administrator to install and configure Docker. We will do our best to support rkt, another container technology that doesn't require sudo. + +- Docker on Windows/MacOS builds slowly + + On Windows and MacOS, Docker containers run in a Linux VM. You might want to give this VM some more memory and CPUs so to make the building efficient. Please refer to [this issue](https://github.com/PaddlePaddle/Paddle/issues/627) for details. + +- Not enough disk space + + Examples in this article use option `--rm` with the `docker run` command. This option ensures that stopped containers do not exist on hard disks. We can use `docker ps -a` to list all containers, including stopped. Sometimes `docker build` generates some intermediate dangling images, which also take disk space. To clean them, please refer to [this article](https://zaiste.net/posts/removing_docker_containers/). .. _compile_deps: -Compile Dependencies +Appendix: Compile Dependencies ---------------- PaddlePaddle need the following dependencies when compiling, other dependencies @@ -97,17 +163,13 @@ will be downloaded automatically. .. _build_options: -Build Options +Appendix: Build Options ---------------- Build options include whether build binaries for CPU or GPU, which BLAS library to use etc. You may pass these settings when running cmake. For detailed cmake tutorial please refer to `here `_ 。 -.. _build_options_bool: - -Bool Type Options ----------------- You can add :code:`-D` argument to pass such options, like: From d641d5ac336d29471fe5206d45f717ef9cc62f4e Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Sun, 11 Feb 2018 17:09:53 +0800 Subject: [PATCH 31/43] follow comments --- python/paddle/v2/fluid/layers/detection.py | 212 ++++++++---------- python/paddle/v2/fluid/layers/nn.py | 87 ------- .../object_detection/test_prior_boxes.py | 12 +- 3 files changed, 93 insertions(+), 218 deletions(-) diff --git a/python/paddle/v2/fluid/layers/detection.py b/python/paddle/v2/fluid/layers/detection.py index b0c25c11de..cc38796042 100644 --- a/python/paddle/v2/fluid/layers/detection.py +++ b/python/paddle/v2/fluid/layers/detection.py @@ -17,11 +17,8 @@ All layers just related to the detection neural network. from ..layer_helper import LayerHelper from ..framework import Variable -from ..param_attr import ParamAttr -from ..framework import Variable -from layer_function_generator import autodoc from tensor import concat -from nn import flatten +from ops import reshape import math __all__ = [ @@ -30,91 +27,6 @@ __all__ = [ ] -def prior_box(input, - image, - min_sizes, - max_sizes, - aspect_ratios, - variance, - flip=False, - clip=False, - step_w=0.0, - step_h=0.0, - offset=0.5, - name=None): - """ - **Prior_box** - - Generate prior boxes for SSD(Single Shot MultiBox Detector) algorithm. - Each position of the input produce N prior boxes, N is determined by - the count of min_sizes, max_sizes and aspect_ratios, The size of the - box is in range(min_size, max_size) interval, which is generated in - sequence according to the aspect_ratios. - - Args: - input(variable): The input feature data of PriorBox, - the layout is NCHW. - image(variable): The input image data of PriorBox, the - layout is NCHW. - min_sizes(list): the min sizes of generated prior boxes. - max_sizes(list): the max sizes of generated prior boxes. - aspect_ratios(list): the aspect ratios of generated prior boxes. - variance(list): the variances to be encoded in prior boxes. - flip(bool, optional, default=False): Whether to flip aspect ratios. - clip(bool, optional, default=False)): Whether to clip - out-of-boundary boxes. - step_w(int, optional, default=0.0): Prior boxes step across - width, 0.0 for auto calculation. - step_h(int, optional, default=0.0): Prior boxes step across - height, 0.0 for auto calculation. - offset(float, optional, default=0.5): Prior boxes center offset. - name(str, optional, default=None): Name of the prior box layer. - - Returns: - boxes(variable): the output prior boxes of PriorBoxOp. The layout is - [H, W, num_priors, 4]. H is the height of input, W is the width - of input, num_priors is the box count of each position. Where num_priors = - len(aspect_ratios) * len(min_sizes) + len(max_sizes) - Variances(variable): the expanded variances of PriorBoxOp. The layout - is [H, W, num_priors, 4]. H is the height of input, W is the width - of input, num_priors is the box count of each position. Where num_priors = - len(aspect_ratios) * len(min_sizes) + len(max_sizes) - Examples: - .. code-block:: python - - data = fluid.layers.data(name="data", shape=[3, 32, 32], dtype="float32") - conv2d = fluid.layers.conv2d( - input=data, num_filters=2, filter_size=3) - box, var = fluid.layers.prior_box(conv2d, data, - min_size, max_size, aspect_ratio, - variance, flip, clip, - step_w, step_h, offset) - """ - helper = LayerHelper("prior_box", **locals()) - dtype = helper.input_dtype() - - box = helper.create_tmp_variable(dtype) - var = helper.create_tmp_variable(dtype) - helper.append_op( - type="prior_box", - inputs={"Input": input, - "Image": image}, - outputs={"Boxes": box, - "Variances": var}, - attrs={ - 'min_sizes': min_sizes, - 'max_sizes': max_sizes, - 'aspect_ratios': aspect_ratios, - 'variances': variance, - 'flip': flip, - 'clip': clip, - 'step_w': step_w, - 'step_h': step_h, - 'offset': offset - }) - return box, var - - def prior_boxes(inputs, image, min_ratio, @@ -128,20 +40,19 @@ def prior_boxes(inputs, variance=[0.1, 0.1, 0.1, 0.1], flip=False, clip=False, + min_sizes=None, + max_sizes=None, name=None): """ **Prior_boxes** Generate prior boxes for SSD(Single Shot MultiBox Detector) algorithm. - Each position of the inputs produces many prior boxes respectly, the number - of prior boxes which is produced by inputs respectly is determined by - the count of min_ratio, max_ratio and aspect_ratios, The size of the - box is in range(min_ratio, max_ratio) interval, which is generated in - sequence according to the aspect_ratios. + The details of this algorithm, please refer the section 2.2 of SSD paper + (SSD: Single Shot MultiBox Detector)`_ . Args: - inputs(list): The list of input variables, the format of all variables is NCHW. - image(variable): The input image data of PriorBoxOp, the layout is NCHW. + inputs(list): The list of input Variables, the format of all Variables is NCHW. + image(Variable): The input image data of PriorBoxOp, the layout is NCHW. min_ratio(int): the min ratio of generated prior boxes. max_ratio(int): the max ratio of generated prior boxes. aspect_ratios(list): the aspect ratios of generated prior boxes. @@ -159,13 +70,17 @@ def prior_boxes(inputs, to be encoded in prior boxes. flip(bool, optional, default=False): Whether to flip aspect ratios. clip(bool, optional, default=False): Whether to clip out-of-boundary boxes. + min_sizes(list, optional, default=None): If `len(inputs) <=2`, min_sizes must + be set up, and the length of min_sizes should equal to the length of inputs. + max_sizes(list, optional, default=None): If `len(inputs) <=2`, max_sizes must + be set up, and the length of min_sizes should equal to the length of inputs. name(str, optional, None): Name of the prior box layer. Returns: - boxes(variable): the output prior boxes of PriorBoxOp. The layout is + boxes(Variable): the output prior boxes of PriorBoxOp. The layout is [num_priors, 4]. num_priors is the total box count of each position of inputs. - Variances(variable): the expanded variances of PriorBoxOp. The layout + Variances(Variable): the expanded variances of PriorBoxOp. The layout is [num_priors, 4]. num_priors is the total box count of each position of inputs @@ -185,13 +100,60 @@ def prior_boxes(inputs, flip=True, clip=True) """ + + def _prior_box_(input, + image, + min_sizes, + max_sizes, + aspect_ratios, + variance, + flip=False, + clip=False, + step_w=0.0, + step_h=0.0, + offset=0.5, + name=None): + helper = LayerHelper("prior_box", **locals()) + dtype = helper.input_dtype() + + box = helper.create_tmp_variable(dtype) + var = helper.create_tmp_variable(dtype) + helper.append_op( + type="prior_box", + inputs={"Input": input, + "Image": image}, + outputs={"Boxes": box, + "Variances": var}, + attrs={ + 'min_sizes': min_sizes, + 'max_sizes': max_sizes, + 'aspect_ratios': aspect_ratios, + 'variances': variance, + 'flip': flip, + 'clip': clip, + 'step_w': step_w, + 'step_h': step_h, + 'offset': offset + }) + return box, var + + def _reshape_with_axis_(input, axis=1): + if not (axis > 0 and axis < len(input.shape)): + raise ValueError( + "The axis should be smaller than the arity of input's shape.") + new_shape = [-1, reduce(mul, input.shape[axis:len(input.shape)], 1)] + out = reshape([input], shape=new_shape) + return out + assert isinstance(inputs, list), 'inputs should be a list.' num_layer = len(inputs) - assert num_layer > 2 # TODO(zcd): currently, num_layer must be bigger than two. - min_sizes = [] - max_sizes = [] - if num_layer > 2: + if num_layer <= 2: + assert min_sizes is not None and max_sizes is not None + assert len(min_sizes) == num_layer and len(max_sizes) == num_layer + else: + min_sizes = [] + max_sizes = [] step = int(math.floor(((max_ratio - min_ratio)) / (num_layer - 2))) for ratio in xrange(min_ratio, max_ratio + 1, step): min_sizes.append(base_size * ratio / 100.) @@ -199,21 +161,29 @@ def prior_boxes(inputs, min_sizes = [base_size * .10] + min_sizes max_sizes = [base_size * .20] + max_sizes + if aspect_ratios: + if not (isinstance(aspect_ratios, list) and + len(aspect_ratios) == num_layer): + raise ValueError( + 'aspect_ratios should be list and the length of inputs ' + 'and aspect_ratios should be the same.') if step_h: - assert isinstance(step_h,list) and len(step_h) == num_layer, \ - 'step_h should be list and inputs and step_h should have same length' + if not (isinstance(step_h, list) and len(step_h) == num_layer): + raise ValueError( + 'step_h should be list and the length of inputs and ' + 'step_h should be the same.') if step_w: - assert isinstance(step_w,list) and len(step_w) == num_layer, \ - 'step_w should be list and inputs and step_w should have same length' + if not (isinstance(step_w, list) and len(step_w) == num_layer): + raise ValueError( + 'step_w should be list and the length of inputs and ' + 'step_w should be the same.') if steps: - assert isinstance(steps,list) and len(steps) == num_layer, \ - 'steps should be list and inputs and step_w should have same length' + if not (isinstance(steps, list) and len(steps) == num_layer): + raise ValueError( + 'steps should be list and the length of inputs and ' + 'step_w should be the same.') step_w = steps step_h = steps - if aspect_ratios: - assert isinstance(aspect_ratios, list) and len(aspect_ratios) == num_layer, \ - 'aspect_ratios should be list and inputs and aspect_ratios should ' \ - 'have same length' box_results = [] var_results = [] @@ -230,10 +200,10 @@ def prior_boxes(inputs, if not isinstance(aspect_ratio, list): aspect_ratio = [aspect_ratio] - box, var = prior_box(input, image, min_size, max_size, aspect_ratio, - variance, flip, clip, step_w[i] - if step_w else 0.0, step_h[i] - if step_w else 0.0, offset) + box, var = _prior_box_(input, image, min_size, max_size, aspect_ratio, + variance, flip, clip, step_w[i] + if step_w else 0.0, step_h[i] + if step_w else 0.0, offset) box_results.append(box) var_results.append(var) @@ -242,17 +212,11 @@ def prior_boxes(inputs, box = box_results[0] var = var_results[0] else: - axis = 3 reshaped_boxes = [] reshaped_vars = [] for i in range(len(box_results)): - reshaped_boxes += [flatten(box_results[i], axis=3)] - reshaped_vars += [flatten(var_results[i], axis=3)] - - helper = LayerHelper("concat", **locals()) - dtype = helper.input_dtype() - box = helper.create_tmp_variable(dtype) - var = helper.create_tmp_variable(dtype) + reshaped_boxes.append(_reshape_with_axis_(box_results[i], axis=3)) + reshaped_vars.append(_reshape_with_axis_(var_results[i], axis=3)) box = concat(reshaped_boxes) var = concat(reshaped_vars) diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index 4d2de38c35..5ebd329fc0 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -21,8 +21,6 @@ from ..framework import Variable from ..param_attr import ParamAttr from layer_function_generator import autodoc from tensor import concat -import math -from operator import mul __all__ = [ 'fc', @@ -66,8 +64,6 @@ __all__ = [ 'nce', 'beam_search', 'row_conv', - 'reshape_with_axis', - 'flatten', 'multiplex', 'layer_norm', ] @@ -3095,86 +3091,3 @@ def multiplex(inputs, index): 'Ids': index}, outputs={'Out': [out]}) return out - - -def reshape_with_axis(input, axis): - """ - **ReshapeWithAxis Layer** - - ReshapeWithAxis is used to merge adjacent dimensions according to axis. - - Args: - input(variable): The input tensor. - axis(list): The axis which is used to merge the adjacent dimensions. - - Returns: - Variable: A tensor variable. - - Examples: - .. code-block:: python - - x = fluid.layers.data(name="data", shape=[3, 32, 32], dtype="float32") - reshaped = fluid.layers.reshape_with_axis(input=x, axis=[2]) - reshaped.shape - >> [-1, 1024] - reshaped = fluid.layers.reshape_with_axis(input=x, axis=[1,3]) - reshaped.shape - >> [-1, 96, 32] - """ - assert isinstance(axis, list), "axis should be list." - assert len(input.shape) > len( - axis), "the length of axis should be litter than input.shape's." - input_shape = input.shape - temp = 0 - for ax in axis: - assert ax < len(input.shape) and ax > 0, \ - 'The data of Axis should be between 1 and len(input.shape)' - assert ax > temp, 'Axis should be incremented sequence' - temp = ax - axis += [len(input.shape)] - - new_shape = [] - for i in range(len(axis) - 1): - new_shape += [reduce(mul, input_shape[axis[i]:axis[i + 1]], 1)] - new_shape = [-1] + new_shape - - helper = LayerHelper('reshape', **locals()) - out = helper.create_tmp_variable(helper.input_dtype()) - helper.append_op( - type='reshape', - inputs={'X': [input]}, - outputs={'Out': [out]}, - attrs={'shape': new_shape}) - return out - - -def flatten(input, axis=1): - """ - **Flatten Layer** - ReshapeWithAxis is used to merge adjacent dimensions according to axis. - Args: - input(variable): The input tensor. - axis(int): - Returns: - Variable: A tensor variable. - Examples: - .. code-block:: python - x = fluid.layers.data(name="data", shape=[3, 32, 32], dtype="float32") - reshaped = fluid.layers.reshape_with_axis(input=x, axis=2) - reshaped.shape - >> [-1, 1024] - """ - assert len(input.shape) > axis and axis > 0, \ - "the axis should be litter than input.shape's." - input_shape = input.shape - - new_shape = [-1, reduce(mul, input_shape[axis:len(input_shape)], 1)] - - helper = LayerHelper('reshape', **locals()) - out = helper.create_tmp_variable(helper.input_dtype()) - helper.append_op( - type='reshape', - inputs={'X': [input]}, - outputs={'Out': [out]}, - attrs={'shape': new_shape}) - return out diff --git a/python/paddle/v2/fluid/tests/object_detection/test_prior_boxes.py b/python/paddle/v2/fluid/tests/object_detection/test_prior_boxes.py index 50b5249d98..1b093c6463 100644 --- a/python/paddle/v2/fluid/tests/object_detection/test_prior_boxes.py +++ b/python/paddle/v2/fluid/tests/object_detection/test_prior_boxes.py @@ -51,15 +51,15 @@ def main(use_cuda): if use_cuda: # prior_box only support CPU. return - box, var = prior_box_output(data_shape=[3, 224, 224]) + data_shape = [3, 224, 224] + box, var = prior_box_output(data_shape) place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() exe = fluid.Executor(place) exe.run(fluid.default_startup_program()) batch = [128] - for i in range(1): - # print("iteration : %d" % i) + for _ in range(1): x = np.random.random(batch + data_shape).astype("float32") tensor_x = core.LoDTensor() tensor_x.set(x, place) @@ -75,12 +75,10 @@ def main(use_cuda): class TestFitALine(unittest.TestCase): def test_cpu(self): - with self.program_scope_guard(): - main(use_cuda=False) + main(use_cuda=False) def test_cuda(self): - with self.program_scope_guard(): - main(use_cuda=True) + main(use_cuda=True) if __name__ == '__main__': From 892cc28c7b7c77ead20f17d5644a4e9482906404 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Sun, 11 Feb 2018 20:03:07 +0800 Subject: [PATCH 32/43] Fix bug --- paddle/fluid/framework/mixed_vector.h | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/framework/mixed_vector.h b/paddle/fluid/framework/mixed_vector.h index a06e34d551..6e5ceefadd 100644 --- a/paddle/fluid/framework/mixed_vector.h +++ b/paddle/fluid/framework/mixed_vector.h @@ -106,9 +106,11 @@ class Vector { // std::vector iterator methods. Based on CPU data access method size_t size() const { return size_; } - T* begin() { return size() == 0 ? &EmptyDummy() : &this->operator[](0); } + T* begin() { return capacity() == 0 ? &EmptyDummy() : &this->operator[](0); } - T* end() { return size() == 0 ? &EmptyDummy() : &this->operator[](size()); } + T* end() { + return capacity() == 0 ? &EmptyDummy() : &this->operator[](size()); + } T& front() { return *begin(); } @@ -119,11 +121,11 @@ class Vector { } const T* begin() const { - return size() == 0 ? &EmptyDummy() : &this->operator[](0); + return capacity() == 0 ? &EmptyDummy() : &this->operator[](0); } const T* end() const { - return size() == 0 ? &EmptyDummy() : &this->operator[](size()); + return capacity() == 0 ? &EmptyDummy() : &this->operator[](size()); } const T* cbegin() const { return begin(); } From 1dceb99e86052355fea7c83d9f636ea681aa8d18 Mon Sep 17 00:00:00 2001 From: Yuan Gao Date: Sun, 11 Feb 2018 20:17:26 +0800 Subject: [PATCH 33/43] add detection output python api (#8389) --- python/paddle/v2/fluid/layers/__init__.py | 3 + python/paddle/v2/fluid/layers/detection.py | 116 ++++++++++++++++++ .../paddle/v2/fluid/tests/test_detection.py | 53 ++++++++ 3 files changed, 172 insertions(+) create mode 100644 python/paddle/v2/fluid/layers/detection.py create mode 100644 python/paddle/v2/fluid/tests/test_detection.py diff --git a/python/paddle/v2/fluid/layers/__init__.py b/python/paddle/v2/fluid/layers/__init__.py index a83dd3db74..89b9f30668 100644 --- a/python/paddle/v2/fluid/layers/__init__.py +++ b/python/paddle/v2/fluid/layers/__init__.py @@ -16,6 +16,8 @@ import ops from ops import * import nn from nn import * +import detection +from detection import * import io from io import * import tensor @@ -28,6 +30,7 @@ import math_op_patch from math_op_patch import * __all__ = [] +__all__ += detection.__all__ __all__ += nn.__all__ __all__ += io.__all__ __all__ += tensor.__all__ diff --git a/python/paddle/v2/fluid/layers/detection.py b/python/paddle/v2/fluid/layers/detection.py new file mode 100644 index 0000000000..054443cb43 --- /dev/null +++ b/python/paddle/v2/fluid/layers/detection.py @@ -0,0 +1,116 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +All layers just related to the detection neural network. +""" + +from ..layer_helper import LayerHelper + +__all__ = ['detection_output', ] + + +def detection_output(scores, + loc, + prior_box, + prior_box_var, + background_label=0, + nms_threshold=0.3, + nms_top_k=400, + keep_top_k=200, + score_threshold=0.01, + nms_eta=1.0): + """ + **Detection Output Layer** + + This layer applies the NMS to the output of network and computes the + predict bounding box location. The output's shape of this layer could + be zero if there is no valid bounding box. + + Args: + scores(Variable): A 3-D Tensor with shape [N, C, M] represents the + predicted confidence predictions. N is the batch size, C is the + class number, M is number of bounding boxes. For each category + there are total M scores which corresponding M bounding boxes. + loc(Variable): A 3-D Tensor with shape [N, M, 4] represents the + predicted locations of M bounding bboxes. N is the batch size, + and each bounding box has four coordinate values and the layout + is [xmin, ymin, xmax, ymax]. + prior_box(Variable): A 2-D Tensor with shape [M, 4] holds M boxes, + each box is represented as [xmin, ymin, xmax, ymax], + [xmin, ymin] is the left top coordinate of the anchor box, + if the input is image feature map, they are close to the origin + of the coordinate system. [xmax, ymax] is the right bottom + coordinate of the anchor box. + prior_box_var(Variable): A 2-D Tensor with shape [M, 4] holds M group + of variance. + background_label(float): The index of background label, + the background label will be ignored. If set to -1, then all + categories will be considered. + nms_threshold(float): The threshold to be used in NMS. + nms_top_k(int): Maximum number of detections to be kept according + to the confidences aftern the filtering detections based on + score_threshold. + keep_top_k(int): Number of total bboxes to be kept per image after + NMS step. -1 means keeping all bboxes after NMS step. + score_threshold(float): Threshold to filter out bounding boxes with + low confidence score. If not provided, consider all boxes. + nms_eta(float): The parameter for adaptive NMS. + + Returns: + The detected bounding boxes which are a Tensor. + + Examples: + .. code-block:: python + + pb = layers.data(name='prior_box', shape=[10, 4], + append_batch_size=False, dtype='float32') + pbv = layers.data(name='prior_box_var', shape=[10, 4], + append_batch_size=False, dtype='float32') + loc = layers.data(name='target_box', shape=[21, 4], + append_batch_size=False, dtype='float32') + scores = layers.data(name='scores', shape=[2, 21, 10], + append_batch_size=False, dtype='float32') + nmsed_outs = fluid.layers.detection_output(scores=scores, + loc=loc, + prior_box=pb, + prior_box_var=pbv) + """ + + helper = LayerHelper("detection_output", **locals()) + decoded_box = helper.create_tmp_variable(dtype=loc.dtype) + helper.append_op( + type="box_coder", + inputs={ + 'PriorBox': prior_box, + 'PriorBoxVar': prior_box_var, + 'TargetBox': loc + }, + outputs={'OutputBox': decoded_box}, + attrs={'code_type': 'decode_center_size'}) + nmsed_outs = helper.create_tmp_variable(dtype=decoded_box.dtype) + + helper.append_op( + type="multiclass_nms", + inputs={'Scores': scores, + 'BBoxes': decoded_box}, + outputs={'Out': nmsed_outs}, + attrs={ + 'background_label': 0, + 'nms_threshold': nms_threshold, + 'nms_top_k': nms_top_k, + 'keep_top_k': keep_top_k, + 'score_threshold': score_threshold, + 'nms_eta': 1.0 + }) + return nmsed_outs diff --git a/python/paddle/v2/fluid/tests/test_detection.py b/python/paddle/v2/fluid/tests/test_detection.py new file mode 100644 index 0000000000..75498ad770 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_detection.py @@ -0,0 +1,53 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest + +import paddle.v2.fluid.layers as layers +from paddle.v2.fluid.framework import Program, program_guard + + +class TestBook(unittest.TestCase): + def test_detection_output(self): + program = Program() + with program_guard(program): + pb = layers.data( + name='prior_box', + shape=[10, 4], + append_batch_size=False, + dtype='float32') + pbv = layers.data( + name='prior_box_var', + shape=[10, 4], + append_batch_size=False, + dtype='float32') + loc = layers.data( + name='target_box', + shape=[20, 4], + append_batch_size=False, + dtype='float32') + scores = layers.data( + name='scores', + shape=[2, 20, 10], + append_batch_size=False, + dtype='float32') + out = layers.detection_output( + scores=scores, loc=loc, prior_box=pb, prior_box_var=pbv) + self.assertIsNotNone(out) + print(str(program)) + + +if __name__ == '__main__': + unittest.main() From 77a6e1c670aece73d208222e483a76fbfe361cd6 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Sun, 11 Feb 2018 21:28:14 +0800 Subject: [PATCH 34/43] Disable unstable tests --- .../{test_rnn_encoder_decoder.py => notest_rnn_encoder_decoer.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename python/paddle/v2/fluid/tests/book/{test_rnn_encoder_decoder.py => notest_rnn_encoder_decoer.py} (100%) diff --git a/python/paddle/v2/fluid/tests/book/test_rnn_encoder_decoder.py b/python/paddle/v2/fluid/tests/book/notest_rnn_encoder_decoer.py similarity index 100% rename from python/paddle/v2/fluid/tests/book/test_rnn_encoder_decoder.py rename to python/paddle/v2/fluid/tests/book/notest_rnn_encoder_decoer.py From 6f625f9c2f8861c1ec4c345e6abc33b3936cc080 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Sun, 11 Feb 2018 22:35:33 +0800 Subject: [PATCH 35/43] Disable unstable unittest --- paddle/fluid/inference/tests/book/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/inference/tests/book/CMakeLists.txt b/paddle/fluid/inference/tests/book/CMakeLists.txt index 9fe76afb58..cddd5a786c 100644 --- a/paddle/fluid/inference/tests/book/CMakeLists.txt +++ b/paddle/fluid/inference/tests/book/CMakeLists.txt @@ -29,6 +29,6 @@ inference_test(image_classification ARGS vgg resnet) inference_test(label_semantic_roles) inference_test(recognize_digits ARGS mlp) inference_test(recommender_system) -inference_test(rnn_encoder_decoder) +#inference_test(rnn_encoder_decoder) inference_test(understand_sentiment) inference_test(word2vec) From bbff442eee03df799edc74bc354ff16ad77684ca Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Sun, 11 Feb 2018 22:19:14 +0800 Subject: [PATCH 36/43] follow comments of qingqing --- python/paddle/v2/fluid/layers/detection.py | 47 +++++++++---------- .../test_prior_boxes.py | 22 +++++---- 2 files changed, 35 insertions(+), 34 deletions(-) rename python/paddle/v2/fluid/tests/{object_detection => }/test_prior_boxes.py (84%) diff --git a/python/paddle/v2/fluid/layers/detection.py b/python/paddle/v2/fluid/layers/detection.py index cc38796042..657f3e22fb 100644 --- a/python/paddle/v2/fluid/layers/detection.py +++ b/python/paddle/v2/fluid/layers/detection.py @@ -19,30 +19,28 @@ from ..layer_helper import LayerHelper from ..framework import Variable from tensor import concat from ops import reshape +from operator import mul import math -__all__ = [ - 'prior_box', - 'prior_boxes', -] - - -def prior_boxes(inputs, - image, - min_ratio, - max_ratio, - aspect_ratios, - base_size, - steps=None, - step_w=None, - step_h=None, - offset=0.5, - variance=[0.1, 0.1, 0.1, 0.1], - flip=False, - clip=False, - min_sizes=None, - max_sizes=None, - name=None): +__all__ = ['prior_box', ] + + +def prior_box(inputs, + image, + min_ratio, + max_ratio, + aspect_ratios, + base_size, + steps=None, + step_w=None, + step_h=None, + offset=0.5, + variance=[0.1, 0.1, 0.1, 0.1], + flip=False, + clip=False, + min_sizes=None, + max_sizes=None, + name=None): """ **Prior_boxes** @@ -140,9 +138,10 @@ def prior_boxes(inputs, def _reshape_with_axis_(input, axis=1): if not (axis > 0 and axis < len(input.shape)): raise ValueError( - "The axis should be smaller than the arity of input's shape.") + "The axis should be smaller than the arity of input and bigger than 0." + ) new_shape = [-1, reduce(mul, input.shape[axis:len(input.shape)], 1)] - out = reshape([input], shape=new_shape) + out = reshape(x=input, shape=new_shape) return out assert isinstance(inputs, list), 'inputs should be a list.' diff --git a/python/paddle/v2/fluid/tests/object_detection/test_prior_boxes.py b/python/paddle/v2/fluid/tests/test_prior_boxes.py similarity index 84% rename from python/paddle/v2/fluid/tests/object_detection/test_prior_boxes.py rename to python/paddle/v2/fluid/tests/test_prior_boxes.py index 1b093c6463..74d292020c 100644 --- a/python/paddle/v2/fluid/tests/object_detection/test_prior_boxes.py +++ b/python/paddle/v2/fluid/tests/test_prior_boxes.py @@ -33,7 +33,7 @@ def prior_box_output(data_shape): conv5 = fluid.layers.conv2d( input=conv4, num_filters=3, filter_size=3, stride=2, use_cudnn=False) - box, var = detection.prior_boxes( + box, var = detection.prior_box( inputs=[conv1, conv2, conv3, conv4, conv5, conv5], image=images, min_ratio=20, @@ -57,20 +57,22 @@ def main(use_cuda): place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() exe = fluid.Executor(place) exe.run(fluid.default_startup_program()) - batch = [128] + batch = [4] # batch is not used in the prior_box. + + assert box.shape[1] == 4 + assert var.shape[1] == 4 + assert box.shape == var.shape + assert len(box.shape) == 2 for _ in range(1): x = np.random.random(batch + data_shape).astype("float32") tensor_x = core.LoDTensor() tensor_x.set(x, place) - box, var = exe.run(fluid.default_main_program(), - feed={'pixel': tensor_x}, - fetch_list=[box, var]) - box_arr = np.array(box) - var_arr = np.array(var) - assert box_arr.shape[1] == 4 - assert var_arr.shape[1] == 4 - assert box_arr.shape[0] == var_arr.shape[0] + boxes, vars = exe.run(fluid.default_main_program(), + feed={'pixel': tensor_x}, + fetch_list=[box, var]) + assert vars.shape == var.shape + assert boxes.shape == box.shape class TestFitALine(unittest.TestCase): From afe63228682aa43518f6df5d183a62fa79fbcce7 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Mon, 12 Feb 2018 09:50:13 +0800 Subject: [PATCH 37/43] follow comment --- doc/build_and_install/index_cn.rst | 1 - doc/build_and_install/index_en.rst | 2 -- 2 files changed, 3 deletions(-) diff --git a/doc/build_and_install/index_cn.rst b/doc/build_and_install/index_cn.rst index 4220ff2279..c0b60f5589 100644 --- a/doc/build_and_install/index_cn.rst +++ b/doc/build_and_install/index_cn.rst @@ -13,7 +13,6 @@ PaddlePaddle提供pip和Docker的安装方式: pip_install_cn.rst docker_install_cn.rst - build_cn.md 编译流程 ++++++++ diff --git a/doc/build_and_install/index_en.rst b/doc/build_and_install/index_en.rst index db6b5be742..7e0ca5bcbd 100644 --- a/doc/build_and_install/index_en.rst +++ b/doc/build_and_install/index_en.rst @@ -13,8 +13,6 @@ You can choose either pip or Docker to complete your install: pip_install_en.rst docker_install_en.rst - build_en.md - Build from Source ----------------- From 9a05c9075043345e34b4461ded2ce92ba6501ae4 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Mon, 12 Feb 2018 10:38:31 +0800 Subject: [PATCH 38/43] fix StridedNumelCopyWithAxis --- paddle/fluid/operators/strided_memcpy.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/strided_memcpy.h b/paddle/fluid/operators/strided_memcpy.h index 385124305e..4036d1091d 100644 --- a/paddle/fluid/operators/strided_memcpy.h +++ b/paddle/fluid/operators/strided_memcpy.h @@ -58,6 +58,7 @@ inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx, int64_t before = dst_stride_numel[0] / dst_stride_numel[axis]; int64_t src_after = src_stride_numel[axis]; int64_t dst_after = dst_stride_numel[axis]; + int64_t copy_size = std::min(src_after, dst_after); auto place = ctx.GetPlace(); PADDLE_ENFORCE_EQ(src_stride_numel.size(), dst_stride_numel.size(), @@ -82,14 +83,14 @@ inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx, if (platform::is_cpu_place(place)) { auto& cpu_place = boost::get(place); memory::Copy(cpu_place, dst + i * dst_after, cpu_place, - src + i * src_after, sizeof(T) * src_after); + src + i * src_after, sizeof(T) * copy_size); } else { #ifdef PADDLE_WITH_CUDA auto& gpu_place = boost::get(place); auto& cuda_ctx = reinterpret_cast(ctx); memory::Copy(gpu_place, dst + i * dst_after, gpu_place, - src + i * src_after, sizeof(T) * src_after, + src + i * src_after, sizeof(T) * copy_size, cuda_ctx.stream()); #else PADDLE_THROW("Paddle is not compiled with GPU"); From 91a2188301b82151560c59501cca45785d34cfcb Mon Sep 17 00:00:00 2001 From: wanghaox Date: Mon, 12 Feb 2018 10:39:59 +0800 Subject: [PATCH 39/43] update detection_map --- paddle/fluid/operators/detection_map_op.cc | 98 ++++++++++++------- paddle/fluid/operators/detection_map_op.h | 44 ++++----- .../v2/fluid/tests/test_detection_map_op.py | 14 ++- 3 files changed, 87 insertions(+), 69 deletions(-) diff --git a/paddle/fluid/operators/detection_map_op.cc b/paddle/fluid/operators/detection_map_op.cc index cc4b6202c0..48308a11b4 100644 --- a/paddle/fluid/operators/detection_map_op.cc +++ b/paddle/fluid/operators/detection_map_op.cc @@ -24,25 +24,28 @@ class DetectionMAPOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("Detection"), - "Input(Detection) of DetectionMAPOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("DetectRes"), + "Input(DetectRes) of DetectionMAPOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) of DetectionMAPOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("OutPosCount"), - "Output(OutPosCount) of DetectionMAPOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("OutTruePos"), - "Output(OutTruePos) of DetectionMAPOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("OutFalsePos"), - "Output(OutFalsePos) of DetectionMAPOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("AccumPosCount"), + "Output(AccumPosCount) of DetectionMAPOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("AccumTruePos"), + "Output(AccumTruePos) of DetectionMAPOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("AccumFalsePos"), + "Output(AccumFalsePos) of DetectionMAPOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("MAP"), "Output(MAP) of DetectionMAPOp should not be null."); - auto det_dims = ctx->GetInputDim("Detection"); + auto det_dims = ctx->GetInputDim("DetectRes"); PADDLE_ENFORCE_EQ(det_dims.size(), 2UL, - "The rank of Input(Detection) must be 2, " + "The rank of Input(DetectRes) must be 2, " "the shape is [N, 6]."); PADDLE_ENFORCE_EQ(det_dims[1], 6UL, - "The shape is of Input(Detection) [N, 6]."); + "The shape is of Input(DetectRes) [N, 6]."); auto label_dims = ctx->GetInputDim("Label"); PADDLE_ENFORCE_EQ(label_dims.size(), 2UL, "The rank of Input(Label) must be 2, " @@ -50,8 +53,17 @@ class DetectionMAPOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(label_dims[1], 6UL, "The shape is of Input(Label) [N, 6]."); - auto map_dim = framework::make_ddim({1}); - ctx->SetOutputDim("MAP", map_dim); + if (ctx->HasInput("PosCount")) { + PADDLE_ENFORCE(ctx->HasInput("TruePos"), + "Input(TruePos) of DetectionMAPOp should not be null when " + "Input(TruePos) is not null."); + PADDLE_ENFORCE( + ctx->HasInput("FalsePos"), + "Input(FalsePos) of DetectionMAPOp should not be null when " + "Input(FalsePos) is not null."); + } + + ctx->SetOutputDim("MAP", framework::make_ddim({1})); } protected: @@ -59,7 +71,7 @@ class DetectionMAPOp : public framework::OperatorWithKernel { const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( framework::ToDataType( - ctx.Input("Detection")->type()), + ctx.Input("DetectRes")->type()), ctx.device_context()); } }; @@ -68,6 +80,14 @@ class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker { public: DetectionMAPOpMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("DetectRes", + "(LoDTensor) A 2-D LoDTensor with shape [M, 6] represents the " + "detections. Each row has 6 values: " + "[label, confidence, xmin, ymin, xmax, ymax], M is the total " + "number of detect results in this mini-batch. For each instance, " + "the offsets in first dimension are called LoD, the number of " + "offset is N + 1, if LoD[i + 1] - LoD[i] == 0, means there is " + "no detected data."); AddInput("Label", "(LoDTensor) A 2-D LoDTensor with shape[N, 6] represents the" "Labeled ground-truth data. Each row has 6 values: " @@ -76,38 +96,43 @@ class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker { "instance, the offsets in first dimension are called LoD, " "the number of offset is N + 1, if LoD[i + 1] - LoD[i] == 0, " "means there is no ground-truth data."); - AddInput("Detection", - "(LoDTensor) A 2-D LoDTensor with shape [M, 6] represents the " - "detections. Each row has 6 values: " - "[label, confidence, xmin, ymin, xmax, ymax], M is the total " - "number of detections in this mini-batch. For each instance, " - "the offsets in first dimension are called LoD, the number of " - "offset is N + 1, if LoD[i + 1] - LoD[i] == 0, means there is " - "no detected data."); AddInput("PosCount", "(Tensor) A tensor with shape [Ncls, 1], store the " - "input positive example count of each class.") + "input positive example count of each class, Ncls is the count of " + "input classification. " + "This input is used to pass the AccumPosCount generated by the " + "previous mini-batch when the multi mini-batches cumulative " + "calculation carried out. " + "When the input(PosCount) is empty, the cumulative " + "calculation is not carried out, and only the results of the " + "current mini-batch are calculated.") .AsDispensable(); AddInput("TruePos", - "(LodTensor) A 2-D LodTensor with shape [Ntp, 2], store the " - "input true positive example of each class.") + "(LoDTensor) A 2-D LoDTensor with shape [Ntp, 2], store the " + "input true positive example of each class." + "This input is used to pass the AccumTruePos generated by the " + "previous mini-batch when the multi mini-batches cumulative " + "calculation carried out. ") .AsDispensable(); AddInput("FalsePos", - "(LodTensor) A 2-D LodTensor with shape [Nfp, 2], store the " - "input false positive example of each class.") + "(LoDTensor) A 2-D LoDTensor with shape [Nfp, 2], store the " + "input false positive example of each class." + "This input is used to pass the AccumFalsePos generated by the " + "previous mini-batch when the multi mini-batches cumulative " + "calculation carried out. ") .AsDispensable(); - AddOutput("OutPosCount", + AddOutput("AccumPosCount", "(Tensor) A tensor with shape [Ncls, 1], store the " "positive example count of each class. It combines the input " "input(PosCount) and the positive example count computed from " "input(Detection) and input(Label)."); - AddOutput("OutTruePos", - "(LodTensor) A LodTensor with shape [Ntp', 2], store the " + AddOutput("AccumTruePos", + "(LoDTensor) A LoDTensor with shape [Ntp', 2], store the " "true positive example of each class. It combines the " "input(TruePos) and the true positive examples computed from " "input(Detection) and input(Label)."); - AddOutput("OutFalsePos", - "(LodTensor) A LodTensor with shape [Nfp', 2], store the " + AddOutput("AccumFalsePos", + "(LoDTensor) A LoDTensor with shape [Nfp', 2], store the " "false positive example of each class. It combines the " "input(FalsePos) and the false positive examples computed from " "input(Detection) and input(Label)."); @@ -115,10 +140,11 @@ class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker { "(Tensor) A tensor with shape [1], store the mAP evaluate " "result of the detection."); - AddAttr("overlap_threshold", - "(float) " - "The jaccard overlap threshold of detection output and " - "ground-truth data.") + AddAttr( + "overlap_threshold", + "(float) " + "The lower bound jaccard overlap threshold of detection output and " + "ground-truth data.") .SetDefault(.3f); AddAttr("evaluate_difficult", "(bool, default true) " diff --git a/paddle/fluid/operators/detection_map_op.h b/paddle/fluid/operators/detection_map_op.h index 0379a3328a..0f5f588e9c 100644 --- a/paddle/fluid/operators/detection_map_op.h +++ b/paddle/fluid/operators/detection_map_op.h @@ -54,7 +54,7 @@ template class DetectionMAPOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* in_detect = ctx.Input("Detection"); + auto* in_detect = ctx.Input("DetectRes"); auto* in_label = ctx.Input("Label"); auto* out_map = ctx.Output("MAP"); @@ -62,9 +62,9 @@ class DetectionMAPOpKernel : public framework::OpKernel { auto* in_true_pos = ctx.Input("TruePos"); auto* in_false_pos = ctx.Input("FalsePos"); - auto* out_pos_count = ctx.Output("OutPosCount"); - auto* out_true_pos = ctx.Output("OutTruePos"); - auto* out_false_pos = ctx.Output("OutFalsePos"); + auto* out_pos_count = ctx.Output("AccumPosCount"); + auto* out_true_pos = ctx.Output("AccumTruePos"); + auto* out_false_pos = ctx.Output("AccumFalsePos"); float overlap_threshold = ctx.Attr("overlap_threshold"); float evaluate_difficult = ctx.Attr("evaluate_difficult"); @@ -265,28 +265,22 @@ class DetectionMAPOpKernel : public framework::OpKernel { label_pos_count[i] = pos_count_data[i]; } - const T* true_pos_data = input_true_pos.data(); - auto true_pos_data_lod = input_true_pos.lod(); - for (int i = 0; i < true_pos_data_lod.size(); ++i) { - for (int j = true_pos_data_lod[0][i]; j < true_pos_data_lod[0][i + 1]; - ++j) { - T score = true_pos_data[j * 2]; - int flag = 1; - if (true_pos_data[j * 2 + 1] < kEPS) flag = 0; - true_pos[i].push_back(std::make_pair(score, flag)); - } - } - const T* false_pos_data = input_false_pos.data(); - auto false_pos_data_lod = input_false_pos.lod(); - for (int i = 0; i < false_pos_data_lod.size(); ++i) { - for (int j = false_pos_data_lod[0][i]; j < false_pos_data_lod[0][i + 1]; - ++j) { - T score = false_pos_data[j * 2]; - int flag = 1; - if (false_pos_data[j * 2 + 1] < kEPS) flag = 0; - false_pos[i].push_back(std::make_pair(score, flag)); + auto SetData = [](const framework::LoDTensor& pos_tensor, + std::map>>& pos) { + const T* pos_data = pos_tensor.data(); + auto pos_data_lod = pos_tensor.lod(); + for (int i = 0; i < pos_data_lod.size(); ++i) { + for (int j = pos_data_lod[0][i]; j < pos_data_lod[0][i + 1]; ++j) { + T score = pos_data[j * 2]; + int flag = 1; + if (pos_data[j * 2 + 1] < kEPS) flag = 0; + pos[i].push_back(std::make_pair(score, flag)); + } } - } + }; + + SetData(input_true_pos, true_pos); + SetData(input_false_pos, false_pos); return; } diff --git a/python/paddle/v2/fluid/tests/test_detection_map_op.py b/python/paddle/v2/fluid/tests/test_detection_map_op.py index ec57ca4ad5..70ccd885d8 100644 --- a/python/paddle/v2/fluid/tests/test_detection_map_op.py +++ b/python/paddle/v2/fluid/tests/test_detection_map_op.py @@ -37,7 +37,7 @@ class TestDetectionMAPOp(OpTest): self.inputs = { 'Label': (self.label, self.label_lod), - 'Detection': (self.detect, self.detect_lod), + 'DetectRes': (self.detect, self.detect_lod), 'PosCount': self.class_pos_count, 'TruePos': (self.true_pos, self.true_pos_lod), 'FalsePos': (self.false_pos, self.false_pos_lod) @@ -45,7 +45,7 @@ class TestDetectionMAPOp(OpTest): else: self.inputs = { 'Label': (self.label, self.label_lod), - 'Detection': (self.detect, self.detect_lod), + 'DetectRes': (self.detect, self.detect_lod), } self.attrs = { @@ -61,9 +61,9 @@ class TestDetectionMAPOp(OpTest): self.outputs = { 'MAP': self.mAP, - 'OutPosCount': self.out_class_pos_count, - 'OutTruePos': (self.out_true_pos, self.out_true_pos_lod), - 'OutFalsePos': (self.out_false_pos, self.out_false_pos_lod) + 'AccumPosCount': self.out_class_pos_count, + 'AccumTruePos': (self.out_true_pos, self.out_true_pos_lod), + 'AccumFalsePos': (self.out_false_pos, self.out_false_pos_lod) } def init_test_case(self): @@ -175,9 +175,7 @@ class TestDetectionMAPOp(OpTest): false_pos[label].append([score, fp]) for (label, label_pos_num) in label_count.items(): - if label_pos_num == 0 or label not in true_pos: - continue - + if label_pos_num == 0 or label not in true_pos: continue label_true_pos = true_pos[label] label_false_pos = false_pos[label] From 51912a7a77ace39de00e7b8198c37d9c85491614 Mon Sep 17 00:00:00 2001 From: Yancey Date: Mon, 12 Feb 2018 12:48:59 +0800 Subject: [PATCH 40/43] Append cmd in manylinux dockerfile.x86 (#8397) * append cmd in manylinux dockerfile.x86 * add new line --- tools/manylinux1/Dockerfile.x64 | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tools/manylinux1/Dockerfile.x64 b/tools/manylinux1/Dockerfile.x64 index 0f1b833130..93cab692e3 100644 --- a/tools/manylinux1/Dockerfile.x64 +++ b/tools/manylinux1/Dockerfile.x64 @@ -52,3 +52,5 @@ RUN wget -O /opt/swig-2.0.12.tar.gz https://sourceforge.net/projects/swig/files/ RUN mkdir -p /src && cd /src && git clone https://github.com/NVIDIA/nccl.git nccl && cd nccl &&\ make -j `nproc` install && cd .. && rm -rf nccl + +CMD ["bash", "/paddle/paddle/scripts/docker/build.sh"] From 8a0dd2409e7cebe146b8a93103c0de71577bc533 Mon Sep 17 00:00:00 2001 From: qingqing01 Date: Mon, 12 Feb 2018 12:54:24 +0800 Subject: [PATCH 41/43] Expose softmax_with_cross_entropy and smooth_l1 into Python API. (#8375) * Add softmax_with_cross_entropy and smooth_l1 in Python API. * Fix doc format. --- paddle/fluid/operators/smooth_l1_loss_op.cc | 14 +-- python/paddle/v2/fluid/layers/nn.py | 121 ++++++++++++++++++++ python/paddle/v2/fluid/tests/test_layers.py | 18 +++ 3 files changed, 145 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/operators/smooth_l1_loss_op.cc b/paddle/fluid/operators/smooth_l1_loss_op.cc index be4c7a56a8..e6eede23ee 100644 --- a/paddle/fluid/operators/smooth_l1_loss_op.cc +++ b/paddle/fluid/operators/smooth_l1_loss_op.cc @@ -44,7 +44,6 @@ class SmoothL1LossOp : public framework::OperatorWithKernel { } }; -template class SmoothL1LossOpMaker : public framework::OpProtoAndCheckerMaker { public: SmoothL1LossOpMaker(OpProto* proto, OpAttrChecker* op_checker) @@ -73,10 +72,10 @@ class SmoothL1LossOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Out", "(Tensor, default Tensor) A tensor with rank be 2. " "The output smooth l1 loss with shape [batch_size, 1]."); - AddAttr("sigma", - "Hyper parameter of smooth l1 loss op." - "A float scalar with default value 3.0.") - .SetDefault(3.0); + AddAttr("sigma", + "Hyper parameter of smooth l1 loss op." + "A float scalar with default value 3.0.") + .SetDefault(1.0); AddComment(R"DOC( Smooth L1 Loss Operator. @@ -133,9 +132,8 @@ class SmoothL1LossGradOp : public framework::OperatorWithKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(smooth_l1_loss, ops::SmoothL1LossOp, - ops::SmoothL1LossOpMaker, smooth_l1_loss_grad, - ops::SmoothL1LossGradOp); +REGISTER_OP(smooth_l1_loss, ops::SmoothL1LossOp, ops::SmoothL1LossOpMaker, + smooth_l1_loss_grad, ops::SmoothL1LossGradOp); REGISTER_OP_CPU_KERNEL( smooth_l1_loss, ops::SmoothL1LossKernel); diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index 5ebd329fc0..051b536818 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -66,6 +66,8 @@ __all__ = [ 'row_conv', 'multiplex', 'layer_norm', + 'softmax_with_cross_entropy', + 'smooth_l1', ] @@ -3091,3 +3093,122 @@ def multiplex(inputs, index): 'Ids': index}, outputs={'Out': [out]}) return out + + +def softmax_with_cross_entropy(logits, label, soft_label=False): + """ + **Softmax With Cross Entropy Operator.** + + Cross entropy loss with softmax is used as the output layer extensively. This + operator computes the softmax normalized values for each row of the input + tensor, after which cross-entropy loss is computed. This provides a more + numerically stable gradient. + + Because this operator performs a softmax on logits internally, it expects + unscaled logits. This operator should not be used with the output of + softmax operator since that would produce incorrect results. + + When the attribute soft_label is set false, this operators expects mutually + exclusive hard labels, each sample in a batch is in exactly one class with a + probability of 1.0. Each sample in the batch will have a single label. + + The equation is as follows: + + 1) Hard label (one-hot label, so every sample has exactly one class) + + .. math:: + + loss_j = -\\text{logit}_{label_j} + + \\log\\left(\\sum_{i=0}^{K}\\exp(\\text{logit}_i)\\right), j = 1,..., K + + 2) Soft label (each sample can have a distribution over all classes) + + .. math:: + + loss_j = -\\sum_{i=0}^{K}\\text{label}_i + \\left(\\text{logit}_i - \\log\\left(\\sum_{i=0}^{K} + \\exp(\\text{logit}_i)\\right)\\right), j = 1,...,K + + Args: + logits (Variable): The unscaled log probabilities, which is a 2-D tensor + with shape [N x K]. N is the batch_size, and K is the class number. + label (Variable): The ground truth which is a 2-D tensor. If soft_label + is set to false, Label is a Tensor with shape [N x 1]. If + soft_label is set to true, Label is a Tensor with + soft_label (bool): A flag to indicate whether to interpretate the given + labels as soft labels. By default, `soft_label` is set to False. + Returns: + Variable: The cross entropy loss is a 2-D tensor with shape [N x 1]. + + Examples: + .. code-block:: python + + data = fluid.layers.data(name='data', shape=[128], dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + fc = fluid.layers.fc(input=data, size=100) + out = fluid.layers.softmax_with_cross_entropy(logits=fc, label=label) + """ + helper = LayerHelper('softmax_with_cross_entropy', **locals()) + softmax = helper.create_tmp_variable(dtype=logits.dtype) + loss = helper.create_tmp_variable(dtype=logits.dtype) + helper.append_op( + type='softmax_with_cross_entropy', + inputs={'Logits': logits, + 'Label': label}, + outputs={'Softmax': softmax, + 'Loss': loss}, + attrs={'soft_label': soft_label}) + return loss + + +def smooth_l1(x, y, inside_weight=None, outside_weight=None, sigma=None): + """ + **Smooth L1 Loss Operator. ** + + This operator computes the smooth l1 loss for X and Y. + The operator takes the first dimension of X and Y as batch size. + For each instance, it computes the smooth l1 loss element by element first + and then sums all the losses. So the shape of Out is [batch_size, 1]. + + Args: + x (Variable): A tensor with rank at least 2. The input value of smooth + l1 loss op with shape [batch_size, dim1, ..., dimN]. + y (Variable): A tensor with rank at least 2. The target value of smooth + l1 loss op with same shape as x. + inside_weight (Variable|None): A tensor with rank at least 2. This + input is optional and should have same shape with x. If provided, + the result of (x - y) will be multiplied by this tensor element by + element. + outside_weight (Variable|None): A tensor with rank at least 2. This + input is optional and should have same shape with x. If provided, + the out smooth l1 loss will be multiplied by this tensor element + by element. + sigma (float|None): Hyper parameter of smooth l1 loss op. A float scalar + with default value 1.0. + Returns: + Variable: A tensor with rank be 2. The output smooth l1 loss with + shape [batch_size, 1]. + + Examples: + .. code-block:: python + + data = fluid.layers.data(name='data', shape=[128], dtype='float32') + label = fluid.layers.data(name='label', shape=[100], dtype='int64') + fc = fluid.layers.fc(input=data, size=100) + out = fluid.layers.smooth_l1(logits=fc, label=label) + """ + helper = LayerHelper('smooth_l1_loss', **locals()) + diff = helper.create_tmp_variable(dtype=x.dtype) + loss = helper.create_tmp_variable(dtype=x.dtype) + helper.append_op( + type='smooth_l1_loss', + inputs={ + 'X': x, + 'Y': y, + 'InsideWeight': inside_weight, + 'OutsideWeight': outside_weight + }, + outputs={'Diff': diff, + 'Out': loss}, + attrs={'sigma': sigma}) + return loss diff --git a/python/paddle/v2/fluid/tests/test_layers.py b/python/paddle/v2/fluid/tests/test_layers.py index fa46f86973..50ef820424 100644 --- a/python/paddle/v2/fluid/tests/test_layers.py +++ b/python/paddle/v2/fluid/tests/test_layers.py @@ -309,6 +309,24 @@ class TestBook(unittest.TestCase): self.assertIsNotNone(out) print(str(program)) + def test_softmax_with_cross_entropy(self): + program = Program() + with program_guard(program): + x = layers.data(name='x', shape=[16], dtype='float32') + y = layers.data(name='label', shape=[1], dtype='int64') + loss = layers.softmax_with_cross_entropy(x, y) + self.assertIsNotNone(loss) + print(str(program)) + + def test_smooth_l1(self): + program = Program() + with program_guard(program): + x = layers.data(name='x', shape=[4], dtype='float32') + y = layers.data(name='label', shape=[4], dtype='float32') + loss = layers.smooth_l1(x, y) + self.assertIsNotNone(loss) + print(str(program)) + if __name__ == '__main__': unittest.main() From da02a5812c4a8947a9e20d8d590e67165d7703c5 Mon Sep 17 00:00:00 2001 From: Tao Luo Date: Mon, 12 Feb 2018 13:10:45 +0800 Subject: [PATCH 42/43] refine inference_lib_dist after code move, and add it to docker/build.sh (#8379) * refine inference_lib_dist after code move, and add it to docker/build.sh * remove is_directory in inference_lib.cmake --- cmake/inference_lib.cmake | 18 ++++++++---------- paddle/scripts/docker/build.sh | 12 ++++++++++++ 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/cmake/inference_lib.cmake b/cmake/inference_lib.cmake index 7d53554358..df18663772 100644 --- a/cmake/inference_lib.cmake +++ b/cmake/inference_lib.cmake @@ -16,12 +16,10 @@ function(copy TARGET) foreach(index RANGE ${len}) list(GET copy_lib_SRCS ${index} src) list(GET copy_lib_DSTS ${index} dst) - add_custom_command(TARGET ${TARGET} PRE_BUILD COMMAND mkdir -p "${dst}") - if(IS_DIRECTORY ${src}) - add_custom_command(TARGET ${TARGET} PRE_BUILD COMMAND cp -r "${src}" "${dst}") - else() - add_custom_command(TARGET ${TARGET} PRE_BUILD COMMAND cp "${src}" "${dst}") - endif() + add_custom_command(TARGET ${TARGET} PRE_BUILD + COMMAND mkdir -p "${dst}" + COMMAND cp -r "${src}" "${dst}" + COMMENT "copying ${src} -> ${dst}") endforeach() endfunction() @@ -53,11 +51,11 @@ IF(NOT PROTOBUF_FOUND) ENDIF(NOT PROTOBUF_FOUND) # paddle fluid module -set(src_dir "${PADDLE_SOURCE_DIR}/paddle") -set(dst_dir "${CMAKE_INSTALL_PREFIX}/paddle") +set(src_dir "${PADDLE_SOURCE_DIR}/paddle/fluid") +set(dst_dir "${CMAKE_INSTALL_PREFIX}/paddle/fluid") set(module "framework") copy(framework_lib DEPS framework_py_proto - SRCS ${src_dir}/${module}/*.h ${src_dir}/${module}/details/*.h ${PADDLE_BINARY_DIR}/paddle/framework/framework.pb.h + SRCS ${src_dir}/${module}/*.h ${src_dir}/${module}/details/*.h ${PADDLE_BINARY_DIR}/paddle/fluid/framework/framework.pb.h DSTS ${dst_dir}/${module} ${dst_dir}/${module}/details ${dst_dir}/${module} ) @@ -69,7 +67,7 @@ copy(memory_lib set(module "inference") copy(inference_lib DEPENDS paddle_fluid_shared - SRCS ${src_dir}/${module}/*.h ${PADDLE_BINARY_DIR}/paddle/inference/libpaddle_fluid.so + SRCS ${src_dir}/${module}/*.h ${PADDLE_BINARY_DIR}/paddle/fluid/inference/libpaddle_fluid.so DSTS ${dst_dir}/${module} ${dst_dir}/${module} ) diff --git a/paddle/scripts/docker/build.sh b/paddle/scripts/docker/build.sh index 1486d5ed25..442a7ea883 100644 --- a/paddle/scripts/docker/build.sh +++ b/paddle/scripts/docker/build.sh @@ -204,6 +204,17 @@ function gen_capi_package() { fi } +function gen_fluid_inference_lib() { + if [ ${WITH_C_API:-OFF} == "OFF" ] ; then + cat < Date: Mon, 12 Feb 2018 13:58:41 +0800 Subject: [PATCH 43/43] pass size when copy --- paddle/fluid/operators/concat_op.h | 4 ++-- paddle/fluid/operators/split_op.h | 2 +- paddle/fluid/operators/strided_memcpy.h | 9 ++++----- python/paddle/v2/fluid/distribute_transpiler.py | 5 +++++ 4 files changed, 12 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/operators/concat_op.h b/paddle/fluid/operators/concat_op.h index 878e530585..c8a4292932 100644 --- a/paddle/fluid/operators/concat_op.h +++ b/paddle/fluid/operators/concat_op.h @@ -38,7 +38,7 @@ class ConcatKernel : public framework::OpKernel { auto in_stride = framework::stride_numel(in->dims()); StridedNumelCopyWithAxis(ctx.device_context(), axis, out->data() + output_offset, out_stride, - in->data(), in_stride); + in->data(), in_stride, in_stride[axis]); output_offset += in_stride[axis]; } } @@ -59,7 +59,7 @@ class ConcatGradKernel : public framework::OpKernel { auto out_stride = framework::stride_numel(out->dims()); StridedNumelCopyWithAxis(ctx.device_context(), axis, out->data(), out_stride, in->data() + input_offset, - in_stride); + in_stride, out_stride[axis]); input_offset += out_stride[axis]; } } diff --git a/paddle/fluid/operators/split_op.h b/paddle/fluid/operators/split_op.h index 06bcf82620..54420e1bf6 100644 --- a/paddle/fluid/operators/split_op.h +++ b/paddle/fluid/operators/split_op.h @@ -38,7 +38,7 @@ class SplitOpKernel : public framework::OpKernel { auto out_stride = framework::stride_numel(out->dims()); StridedNumelCopyWithAxis(ctx.device_context(), axis, out->data(), out_stride, in->data() + input_offset, - in_stride); + in_stride, out_stride[axis]); input_offset += out_stride[axis]; } } diff --git a/paddle/fluid/operators/strided_memcpy.h b/paddle/fluid/operators/strided_memcpy.h index 4036d1091d..4c7b90693a 100644 --- a/paddle/fluid/operators/strided_memcpy.h +++ b/paddle/fluid/operators/strided_memcpy.h @@ -54,11 +54,11 @@ inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx, int64_t axis, T* dst, const framework::DDim& dst_stride_numel, const T* src, - const framework::DDim& src_stride_numel) { + const framework::DDim& src_stride_numel, + int64_t size) { int64_t before = dst_stride_numel[0] / dst_stride_numel[axis]; int64_t src_after = src_stride_numel[axis]; int64_t dst_after = dst_stride_numel[axis]; - int64_t copy_size = std::min(src_after, dst_after); auto place = ctx.GetPlace(); PADDLE_ENFORCE_EQ(src_stride_numel.size(), dst_stride_numel.size(), @@ -83,15 +83,14 @@ inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx, if (platform::is_cpu_place(place)) { auto& cpu_place = boost::get(place); memory::Copy(cpu_place, dst + i * dst_after, cpu_place, - src + i * src_after, sizeof(T) * copy_size); + src + i * src_after, sizeof(T) * size); } else { #ifdef PADDLE_WITH_CUDA auto& gpu_place = boost::get(place); auto& cuda_ctx = reinterpret_cast(ctx); memory::Copy(gpu_place, dst + i * dst_after, gpu_place, - src + i * src_after, sizeof(T) * copy_size, - cuda_ctx.stream()); + src + i * src_after, sizeof(T) * size, cuda_ctx.stream()); #else PADDLE_THROW("Paddle is not compiled with GPU"); #endif diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py index f84481adf7..689920af0c 100644 --- a/python/paddle/v2/fluid/distribute_transpiler.py +++ b/python/paddle/v2/fluid/distribute_transpiler.py @@ -121,6 +121,7 @@ def split_dense_variable(var_list, block_size += dim1 - remains # update split_count after aligning split_count = int(math.ceil(var_numel / float(block_size))) + print("###split var ", var.name, var.shape, block_size, split_count) for block_id in xrange(split_count): curr_block_size = min(block_size, var_numel - ( (block_id) * block_size)) @@ -255,6 +256,7 @@ class DistributeTranspiler: splited_shape = [rows] if len(orig_shape) >= 2: splited_shape.extend(orig_shape[1:]) + print("###splited: ", size, rows, splited_shape) var = program.global_block().create_var( name="%s.block%d" % (varname, i), psersistable=False, @@ -262,6 +264,7 @@ class DistributeTranspiler: type=orig_var.type, shape=splited_shape) # flattend splited var var_mapping[varname].append(var) + print("###created split var ", var) return var_mapping def _clone_var(self, block, var): @@ -528,6 +531,8 @@ class DistributeTranspiler: """ # step5 pserver_program = Program() + print("param mapping on pserver: #### ", + self.param_grad_ep_mapping[endpoint]["params"]) for v in self.param_grad_ep_mapping[endpoint]["params"]: self._clone_var(pserver_program.global_block(), v) for v in self.param_grad_ep_mapping[endpoint]["grads"]: