PSENet scripts update.

pull/7288/head
linqingke 4 years ago
parent 81d9015ddb
commit 11c4d7e8e3

@ -13,15 +13,20 @@
# limitations under the License.
# ============================================================================
CXXFLAGS = -I include -std=c++11 -O3
mindspore_home = ${MINDSPORE_HOME}
CXXFLAGS = -I include -I ${mindspore_home}/mindspore/official/cv/psenet -std=c++11 -O3
CXX_SOURCES = adaptor.cpp
OPENCV = `pkg-config --cflags --libs opencv`
opencv_home = ${OPENCV_HOME}
OPENCV = -I$(opencv_home)/include -L$(opencv_home)/lib64 -lopencv_superres -lopencv_ml -lopencv_objdetect \
-lopencv_highgui -lopencv_dnn -lopencv_stitching -lopencv_videostab -lopencv_calib3d -lopencv_videoio \
-lopencv_features2d -lopencv_photo -lopencv_flann -lopencv_shape -lopencv_video -lopencv_imgcodecs -lopencv_imgproc \
-lopencv_core
PYBIND11 = `python -m pybind11 --includes`
LIB_SO = adaptor.so
$(LIB_SO): $(CXX_SOURCES) $(DEPS)
$(CXX) -o $@ $(CXXFLAGS) $(LDFLAGS) $(CXX_SOURCES) --shared -fPIC $(OPENCV)
$(CXX) -o $@ $(CXXFLAGS) $(LDFLAGS) $(CXX_SOURCES) --shared -fPIC $(OPENCV) $(PYBIND11)
clean:
rm -rf $(LIB_SO)

@ -20,7 +20,6 @@
#include <pybind11/stl_bind.h>
#include <iostream>
#include <queue>
#include <utility>
#include <vector>
#include <opencv2/opencv.hpp>
#include <opencv2/core/core.hpp>
@ -28,12 +27,14 @@
#include <opencv2/imgproc/imgproc.hpp>
using std::vector;
using cv::vector;
using std::queue;
using cv::Mat;
using cv::Point;
namespace py = pybind11;
namespace pse_adaptor {
void get_kernals(const int *data, vector<int> data_shape, const vector<Mat> &kernals) {
void get_kernals(const int *data, vector<int64> data_shape, vector<Mat> *kernals) {
for (int i = 0; i < data_shape[0]; ++i) {
Mat kernal = Mat::zeros(data_shape[1], data_shape[2], CV_8UC1);
for (int x = 0; x < kernal.rows; ++x) {
@ -41,15 +42,14 @@ namespace pse_adaptor {
kernal.at<char>(x, y) = data[i * data_shape[1] * data_shape[2] + x * data_shape[2] + y];
}
}
kernals.emplace_back(kernal);
kernals->emplace_back(kernal);
}
}
void growing_text_line(const vector, const vector<vector<>> &text_line, float min_area) {
void growing_text_line(const vector<Mat> &kernals, vector<vector<int>> *text_line, float min_area) {
Mat label_mat;
int label_num = connectedComponents(kernals[kernals.size() - 1], label_mat, 4);
vector<int> area(label_num + 1, 0)
memset_s(area, 0, sizeof(area));
vector<int> area(label_num + 1, 0);
for (int x = 0; x < label_mat.rows; ++x) {
for (int y = 0; y < label_mat.cols; ++y) {
int label = label_mat.at<int>(x, y);
@ -69,7 +69,7 @@ namespace pse_adaptor {
queue.push(point);
row[y] = label;
}
text_line.emplace_back(row);
text_line->emplace_back(row);
}
int dx[] = {-1, 1, 0, 0};
@ -81,20 +81,20 @@ namespace pse_adaptor {
queue.pop();
int x = point.x;
int y = point.y;
int label = text_line[x][y];
int label = text_line->at(x)[y];
bool is_edge = true;
for (int d = 0; d < 4; ++d) {
int tmp_x = x + dx[d];
int tmp_y = y + dy[d];
if (tmp_x < 0 || tmp_x >= (static_cast)<int>text_line.size()) continue;
if (tmp_y < 0 || tmp_y >= (static_cast)<int>text_line[1].size()) continue;
if (tmp_x < 0 || tmp_x >= static_cast<int>(text_line->size())) continue;
if (tmp_y < 0 || tmp_y >= static_cast<int>(text_line->at(1).size())) continue;
if (kernals[kernal_id].at<char>(tmp_x, tmp_y) == 0) continue;
if (text_line[tmp_x][tmp_y] > 0) continue;
if (text_line->at(tmp_x)[tmp_y] > 0) continue;
Point point_tmp(tmp_x, tmp_y);
queue.push(point_tmp);
text_line[tmp_x][tmp_y] = label;
text_line->at(tmp_x)[tmp_y] = label;
is_edge = false;
}
@ -110,9 +110,9 @@ namespace pse_adaptor {
auto buf = quad_n9.request();
auto data = static_cast<int *>(buf.ptr);
vector<Mat> kernals;
get_kernals(data, buf.shape, kernals);
get_kernals(data, buf.shape, &kernals);
vector<vector<int>> text_line;
growing_text_line(kernals, text_line, min_area);
growing_text_line(kernals, &text_line, min_area);
return text_line;
}

@ -50,7 +50,7 @@ def train():
rank_id = get_rank()
# dataset/network/criterion/optim
ds = train_dataset_creator(args.device_id, args.device_num)
ds = train_dataset_creator(rank_id, args.device_num)
step_size = ds.get_dataset_size()
print('Create dataset done!')

Loading…
Cancel
Save