!10429 Optimize ExtractChannel op for lite cv

From: @jiangzhiwen8
Reviewed-by: @xulei2020,@pandoublefeng,@heleiwang,@liucunwei
Signed-off-by: @liucunwei
pull/10429/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 9f9c132440

@ -659,42 +659,34 @@ static void PadWithConstant(const LiteMat &src, LiteMat &dst, const int top, con
}
}
bool ExtractChannel(const LiteMat &src, LiteMat &dst, int col) {
template <typename T>
void ExtractChannelImpl(T *src_ptr, T *dst_ptr, int height, int width, int channel, int col) {
int total = height * width;
int i = 0;
int src_idx = col;
for (; i < total; i++, src_idx += channel) {
dst_ptr[i] = src_ptr[src_idx];
}
}
bool ExtractChannel(LiteMat &src, LiteMat &dst, int col) {
if (src.IsEmpty() || col < 0 || col > src.channel_ - 1) {
return false;
}
if (src.data_type_ == LDataType::FLOAT32) {
(void)dst.Init(src.width_, src.height_, 1, src.data_type_);
const float *src_start_p = src;
float *dst_start_p = dst;
for (int h = 0; h < src.height_; h++) {
uint32_t src_start = h * src.width_ * src.channel_ + col;
uint32_t dst_start = h * dst.width_;
for (int w = 0; w < src.width_; w++) {
uint32_t src_index = src_start + w * src.channel_;
uint32_t dst_index = dst_start + w;
dst_start_p[dst_index] = src_start_p[src_index];
}
}
return true;
} else if (src.data_type_ == LDataType::UINT8) {
(void)dst.Init(src.width_, src.height_, 1, src.data_type_);
const uint8_t *src_start_p = src;
uint8_t *dst_start_p = dst;
for (int h = 0; h < src.height_; h++) {
uint32_t src_start = h * src.width_ * src.channel_ + col;
uint32_t dst_start = h * dst.width_;
for (int w = 0; w < src.width_; w++) {
uint32_t src_index = src_start + w * src.channel_;
uint32_t dst_index = dst_start + w;
dst_start_p[dst_index] = src_start_p[src_index];
}
}
return true;
if (dst.IsEmpty() || dst.width_ != src.width_ || dst.height_ != src.height_ || dst.channel_ != 1 ||
dst.data_type_ != src.data_type_) {
dst.Init(src.width_, src.height_, 1, src.data_type_);
}
if (dst.data_type_ == LDataType::FLOAT32) {
ExtractChannelImpl<float>(src, dst, src.height_, src.width_, src.channel_, col);
} else if (dst.data_type_ == LDataType::UINT8) {
ExtractChannelImpl<uint8_t>(src, dst, src.height_, src.width_, src.channel_, col);
} else {
return false;
}
return false;
return true;
}
bool Split(const LiteMat &src, std::vector<LiteMat> &mv) {

@ -82,7 +82,7 @@ bool Pad(const LiteMat &src, LiteMat &dst, int top, int bottom, int left, int ri
uint8_t fill_b_or_gray, uint8_t fill_g, uint8_t fill_r);
/// \brief Extract image channel by index
bool ExtractChannel(const LiteMat &src, LiteMat &dst, int col);
bool ExtractChannel(LiteMat &src, LiteMat &dst, int col);
/// \brief Split image channels to single channel
bool Split(const LiteMat &src, std::vector<LiteMat> &mv);

Loading…
Cancel
Save