diff --git a/model_zoo/official/cv/psenet/src/ETSNET/pse/Makefile b/model_zoo/official/cv/psenet/src/ETSNET/pse/Makefile index 1c3aa09213..e3b5c8898e 100644 --- a/model_zoo/official/cv/psenet/src/ETSNET/pse/Makefile +++ b/model_zoo/official/cv/psenet/src/ETSNET/pse/Makefile @@ -13,15 +13,20 @@ # limitations under the License. # ============================================================================ - -CXXFLAGS = -I include -std=c++11 -O3 +mindspore_home = ${MINDSPORE_HOME} +CXXFLAGS = -I include -I ${mindspore_home}/mindspore/official/cv/psenet -std=c++11 -O3 CXX_SOURCES = adaptor.cpp -OPENCV = `pkg-config --cflags --libs opencv` +opencv_home = ${OPENCV_HOME} +OPENCV = -I$(opencv_home)/include -L$(opencv_home)/lib64 -lopencv_superres -lopencv_ml -lopencv_objdetect \ +-lopencv_highgui -lopencv_dnn -lopencv_stitching -lopencv_videostab -lopencv_calib3d -lopencv_videoio \ +-lopencv_features2d -lopencv_photo -lopencv_flann -lopencv_shape -lopencv_video -lopencv_imgcodecs -lopencv_imgproc \ +-lopencv_core +PYBIND11 = `python -m pybind11 --includes` LIB_SO = adaptor.so $(LIB_SO): $(CXX_SOURCES) $(DEPS) - $(CXX) -o $@ $(CXXFLAGS) $(LDFLAGS) $(CXX_SOURCES) --shared -fPIC $(OPENCV) + $(CXX) -o $@ $(CXXFLAGS) $(LDFLAGS) $(CXX_SOURCES) --shared -fPIC $(OPENCV) $(PYBIND11) clean: rm -rf $(LIB_SO) diff --git a/model_zoo/official/cv/psenet/src/ETSNET/pse/adaptor.cpp b/model_zoo/official/cv/psenet/src/ETSNET/pse/adaptor.cpp index 95125278ae..a7092b4c5d 100644 --- a/model_zoo/official/cv/psenet/src/ETSNET/pse/adaptor.cpp +++ b/model_zoo/official/cv/psenet/src/ETSNET/pse/adaptor.cpp @@ -20,7 +20,6 @@ #include #include #include -#include #include #include #include @@ -28,12 +27,14 @@ #include using std::vector; -using cv::vector; +using std::queue; +using cv::Mat; +using cv::Point; namespace py = pybind11; namespace pse_adaptor { - void get_kernals(const int *data, vector data_shape, const vector &kernals) { + void get_kernals(const int *data, vector data_shape, vector *kernals) { for (int i = 0; i < data_shape[0]; ++i) { Mat kernal = Mat::zeros(data_shape[1], data_shape[2], CV_8UC1); for (int x = 0; x < kernal.rows; ++x) { @@ -41,15 +42,14 @@ namespace pse_adaptor { kernal.at(x, y) = data[i * data_shape[1] * data_shape[2] + x * data_shape[2] + y]; } } - kernals.emplace_back(kernal); + kernals->emplace_back(kernal); } } - void growing_text_line(const vector, const vector> &text_line, float min_area) { + void growing_text_line(const vector &kernals, vector> *text_line, float min_area) { Mat label_mat; int label_num = connectedComponents(kernals[kernals.size() - 1], label_mat, 4); - vector area(label_num + 1, 0) - memset_s(area, 0, sizeof(area)); + vector area(label_num + 1, 0); for (int x = 0; x < label_mat.rows; ++x) { for (int y = 0; y < label_mat.cols; ++y) { int label = label_mat.at(x, y); @@ -69,7 +69,7 @@ namespace pse_adaptor { queue.push(point); row[y] = label; } - text_line.emplace_back(row); + text_line->emplace_back(row); } int dx[] = {-1, 1, 0, 0}; @@ -81,20 +81,20 @@ namespace pse_adaptor { queue.pop(); int x = point.x; int y = point.y; - int label = text_line[x][y]; + int label = text_line->at(x)[y]; bool is_edge = true; for (int d = 0; d < 4; ++d) { int tmp_x = x + dx[d]; int tmp_y = y + dy[d]; - if (tmp_x < 0 || tmp_x >= (static_cast)text_line.size()) continue; - if (tmp_y < 0 || tmp_y >= (static_cast)text_line[1].size()) continue; + if (tmp_x < 0 || tmp_x >= static_cast(text_line->size())) continue; + if (tmp_y < 0 || tmp_y >= static_cast(text_line->at(1).size())) continue; if (kernals[kernal_id].at(tmp_x, tmp_y) == 0) continue; - if (text_line[tmp_x][tmp_y] > 0) continue; + if (text_line->at(tmp_x)[tmp_y] > 0) continue; Point point_tmp(tmp_x, tmp_y); queue.push(point_tmp); - text_line[tmp_x][tmp_y] = label; + text_line->at(tmp_x)[tmp_y] = label; is_edge = false; } @@ -110,9 +110,9 @@ namespace pse_adaptor { auto buf = quad_n9.request(); auto data = static_cast(buf.ptr); vector kernals; - get_kernals(data, buf.shape, kernals); + get_kernals(data, buf.shape, &kernals); vector> text_line; - growing_text_line(kernals, text_line, min_area); + growing_text_line(kernals, &text_line, min_area); return text_line; } diff --git a/model_zoo/official/cv/psenet/train.py b/model_zoo/official/cv/psenet/train.py index f9f4cb5285..cb5b29ea5b 100644 --- a/model_zoo/official/cv/psenet/train.py +++ b/model_zoo/official/cv/psenet/train.py @@ -50,7 +50,7 @@ def train(): rank_id = get_rank() # dataset/network/criterion/optim - ds = train_dataset_creator(args.device_id, args.device_num) + ds = train_dataset_creator(rank_id, args.device_num) step_size = ds.get_dataset_size() print('Create dataset done!')