!4502 ops add ReSize()

Merge pull request !4502 from zhaozhenlong/lite/issue/ops_need_call_resize
pull/4502/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 03ea6bfc11

@ -29,12 +29,10 @@ using mindspore::schema::PrimitiveType_ExpandDims;
namespace mindspore::kernel {
int ExpandDimsCPUKernel::Init() {
if (context_->infer_shape_interrupt_ && !context_->running_) {
set_need_reinit();
if (!InferShapeDone()) {
return RET_OK;
}
int ret = ReSize();
return ret;
return ReSize();
}
int ExpandDimsCPUKernel::ReSize() {

@ -35,18 +35,19 @@ constexpr int kOutputNum = 1;
} // namespace
int FillCPUKernel::Init() {
if (context_->infer_shape_interrupt_ && !context_->running_) {
set_need_reinit();
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}
int FillCPUKernel::ReSize() {
data_size_ = out_tensors_.front()->ElementsNum();
thread_sz_count_ = MSMIN(thread_count_, data_size_);
thread_sz_stride_ = UP_DIV(data_size_, thread_sz_count_);
return RET_OK;
}
int FillCPUKernel::ReSize() {
  // Nothing shape-dependent is cached by this kernel, so a resize needs no
  // recomputation; report success unconditionally.
  return RET_OK;
}
int FillCPUKernel::DoFill(int task_id) {
int size = MSMIN(thread_sz_stride_, data_size_ - task_id * thread_sz_stride_);
if (size <= 0) {

@ -32,7 +32,10 @@ namespace mindspore::kernel {
int GatherCPUKernel::Init() {
axis_ = (reinterpret_cast<GatherParameter *>(op_parameter_))->axis_;
batchDims_ = (reinterpret_cast<GatherParameter *>(op_parameter_))->batchDims_;
return RET_OK;
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}
int GatherCPUKernel::ReSize() {
  // Gather's parameters (axis_, batchDims_) are read once in Init() and do not
  // depend on tensor shapes, so resizing is a no-op.
  return RET_OK;
}

@ -38,10 +38,17 @@ GatherNdCPUKernel::~GatherNdCPUKernel() {
}
int GatherNdCPUKernel::Init() {
if (context_->infer_shape_interrupt_ && !context_->running_) {
set_need_reinit();
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}
int GatherNdCPUKernel::ReSize() {
if (in_offset_ != nullptr) {
free(in_offset_);
in_offset_ = nullptr;
}
auto indices_tensor = in_tensors_.at(1);
auto indices_shape = indices_tensor->shape();
int indices_rank = indices_shape.size();
@ -59,16 +66,9 @@ int GatherNdCPUKernel::Init() {
thread_sz_count_ = MSMIN(thread_count_, count_);
thread_sz_stride_ = UP_DIV(count_, thread_sz_count_);
int ret = ReSize();
return ret;
}
int GatherNdCPUKernel::ReSize() {
auto in_shape = in_tensors_.front()->shape();
int in_rank = in_shape.size();
auto indices_tensor = in_tensors_.at(1);
auto indices_shape = indices_tensor->shape();
int indices_rank = indices_shape.size();
int idx_lastshape = indices_shape[indices_rank - 1];
auto indices_ptr = reinterpret_cast<int *>(indices_tensor->Data());
area_ = 1;

@ -35,40 +35,49 @@ constexpr size_t kOutputNum = 1;
} // namespace
int OneHotCPUKernel::Init() {
if (context_->infer_shape_interrupt_ && !context_->running_) {
set_need_reinit();
return RET_OK;
}
// indices depth on_value off_value
if (in_tensors_.size() != kInputNum || out_tensors_.size() != kOutputNum) {
MS_LOG(ERROR) << "OneHot input size should be " << kInputNum << ", got " << in_tensors_.size()
<< ", output size should be" << kOutputNum << ", got " << out_tensors_.size();
return RET_ERROR;
}
if (context_ == nullptr) {
MS_LOG(ERROR) << "OneHot context nullptr";
return RET_NULL_PTR;
}
thread_num_ = context_->thread_num_;
auto param = reinterpret_cast<OneHotParameter *>(op_parameter_);
if (param == nullptr) {
MS_LOG(ERROR) << "OneHot op_parameter_ nullptr";
return RET_NULL_PTR;
}
axis_ = param->axis_;
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}
int OneHotCPUKernel::ReSize() {
auto indices = in_tensors_.at(0);
if (indices == nullptr) {
MS_LOG(ERROR) << "OneHot inputs[0] indices nullptr";
return RET_NULL_PTR;
}
auto indices_shape = indices->shape();
const int indices_rank = static_cast<int>(indices_shape.size());
if (axis_ < 0) {
axis_ += indices_rank + 1;
}
outer_size_ = 1;
for (size_t i = 0; i < static_cast<size_t>(axis_); i++) {
outer_size_ *= indices_shape[i];
}
inner_size_ = indices->ElementsNum() / outer_size_;
if (context_ == nullptr) {
MS_LOG(ERROR) << "OneHot context nullptr";
return RET_NULL_PTR;
}
thread_num_ = context_->thread_num_;
const int indices_rank = static_cast<int>(in_tensors_.at(0)->shape().size());
if (axis_ < 0) {
axis_ += indices_rank + 1;
}
return RET_OK;
}

@ -26,12 +26,12 @@ class OneHotCPUKernel : public LiteKernel {
OneHotCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
const lite::Primitive *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive), context_(ctx) {}
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
~OneHotCPUKernel() override = default;
int Init() override;
int ReSize() override { return 0; };
int ReSize() override;
int Run() override;
int OneHotImpl(int task_id);
@ -39,7 +39,6 @@ class OneHotCPUKernel : public LiteKernel {
int GetParams();
private:
const lite::Context *context_;
int thread_num_;
int axis_;
int outer_size_;

@ -36,16 +36,19 @@ constexpr int kOutputNum = 1;
} // namespace
int PadCPUKernel::Init() {
if (context_->infer_shape_interrupt_ && !context_->running_) {
set_need_reinit();
return RET_OK;
}
if (in_tensors_.size() != kInputNum || out_tensors_.size() != kOutputNum) {
MS_LOG(ERROR) << "Pad input size should be " << kInputNum << ", got " << in_tensors_.size()
<< ", output size should be" << kOutputNum << ", got " << out_tensors_.size();
return RET_ERROR;
}
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}
int PadCPUKernel::ReSize() {
auto input = in_tensors_.at(0);
auto output = out_tensors_.at(0);
if (input == nullptr || output == nullptr) {

@ -35,7 +35,7 @@ class PadCPUKernel : public LiteKernel {
~PadCPUKernel() {}
int Init() override;
int ReSize() override { return 0; };
int ReSize() override;
int Run() override;
int RunImpl(int task_id);

@ -44,10 +44,7 @@ int ReduceCPUKernel::Init() {
if (ret != RET_OK) {
return ret;
}
ret = MallocTmpBuffer();
if (ret != RET_OK) {
return ret;
}
switch (mode_) {
case static_cast<int>(ReduceMode_ReduceSum): {
reducer_ = ReduceSum;
@ -77,12 +74,15 @@ int ReduceCPUKernel::Init() {
MS_LOG(ERROR) << "Reduce unsupported reduce mode: " << mode_;
return RET_ERROR;
}
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}
int ReduceCPUKernel::ReSize() {
  // Input shapes may have changed: rebuild the intermediate reduction buffers
  // and propagate any allocation failure to the caller.
  return MallocTmpBuffer();
}
int ReduceCPUKernel::CallReduceUnit(int task_id) {
auto ret = reducer_(outer_size_, inner_size_, axis_size_, src_data_, tmp_shape_.data(), dst_data_, task_id,
context_->thread_num_);
@ -149,6 +149,14 @@ int ReduceCPUKernel::Run() {
}
int ReduceCPUKernel::MallocTmpBuffer() {
for (auto buffer : data_buffers_) {
if (buffer != nullptr) {
free(buffer);
buffer = nullptr;
}
}
data_buffers_.clear();
auto input_shape = in_tensors_.at(0)->shape();
for (auto i = 0; i < num_axes_ - 1; i++) {
int axis = axes_[i];

@ -48,15 +48,15 @@ class ReduceCPUKernel : public ReduceBaseCPUKernel {
}
int Init() override;
int ReSize() override { return 0; };
int ReSize() override;
int Run() override;
int CallReduceUnit(int task_id);
private:
Reducer reducer_;
Reducer reducer_ = nullptr;
std::vector<float *> data_buffers_;
const float *src_data_;
float *dst_data_;
const float *src_data_ = nullptr;
float *dst_data_ = nullptr;
private:
int MallocTmpBuffer();

@ -38,6 +38,10 @@ int ReverseCPUKernel::Stride(int index) {
}
int ReverseCPUKernel::ReSize() {
data_size_ = in_tensors_.at(0)->ElementsNum();
thread_sz_count_ = MSMIN(thread_count_, data_size_);
thread_sz_stride_ = UP_DIV(data_size_, thread_sz_count_);
auto *param = reinterpret_cast<ReverseParameter *>(op_parameter_);
auto input_shape = in_tensors_[0]->shape();
if (param->num_axis_ > input_shape.size()) {
@ -89,13 +93,9 @@ int ReverseCPUKernel::ReSize() {
}
int ReverseCPUKernel::Init() {
if (context_->infer_shape_interrupt_ && !context_->running_) {
set_need_reinit();
if (!InferShapeDone()) {
return RET_OK;
}
data_size_ = in_tensors_.at(0)->ElementsNum();
thread_sz_count_ = MSMIN(thread_count_, data_size_);
thread_sz_stride_ = UP_DIV(data_size_, thread_sz_count_);
int ret = ReSize();
return ret;
}

Loading…
Cancel
Save