|
|
|
@ -25,26 +25,18 @@ using mindspore::kernel::KERNEL_ARCH::kCPU;
|
|
|
|
|
using mindspore::lite::KernelRegistrar;
|
|
|
|
|
using mindspore::lite::RET_ERROR;
|
|
|
|
|
using mindspore::lite::RET_OK;
|
|
|
|
|
using mindspore::lite::RET_PARAM_INVALID;
|
|
|
|
|
using mindspore::schema::PrimitiveType_QuantDTypeCast;
|
|
|
|
|
|
|
|
|
|
namespace mindspore::kernel {
|
|
|
|
|
namespace {
|
|
|
|
|
constexpr int kQuantDTypeCastInputNum = 1;
|
|
|
|
|
constexpr int kQuantDTypeCastOutputNum = 1;
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
int QuantDTypeCastCPUKernel::Init() {
|
|
|
|
|
if (context_->infer_shape_interrupt_ && !context_->running_) {
|
|
|
|
|
SetNeedReInit();
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
if (inputs_.size() != 1) {
|
|
|
|
|
MS_LOG(ERROR) << "inputs number should be 1, but " << inputs_.size() << " is given.";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
return RET_PARAM_INVALID;
|
|
|
|
|
}
|
|
|
|
|
if (outputs_.size() != 1) {
|
|
|
|
|
MS_LOG(ERROR) << "outputs number should be 1, but " << inputs_.size() << " is given.";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
return RET_PARAM_INVALID;
|
|
|
|
|
}
|
|
|
|
|
auto in_tensor = inputs_.front();
|
|
|
|
|
auto out_tensor = outputs_.front();
|
|
|
|
@ -63,18 +55,23 @@ int QuantDTypeCastCPUKernel::Init() {
|
|
|
|
|
inverse_ = true;
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << "param data type not supported:" << " src: " << param->srcT << " dst: " << param->dstT;
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
return RET_PARAM_INVALID;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!InferShapeDone()) {
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
return ReSize();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int QuantDTypeCastCPUKernel::ReSize() {
|
|
|
|
|
auto in_tensor = inputs_.front();
|
|
|
|
|
num_unit_ = static_cast<int>(in_tensor->DataSize());
|
|
|
|
|
thread_n_num_ = MSMIN(thread_num_, num_unit_);
|
|
|
|
|
thread_n_stride_ = UP_DIV(num_unit_, thread_n_num_);
|
|
|
|
|
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int QuantDTypeCastCPUKernel::ReSize() { return RET_OK; }
|
|
|
|
|
|
|
|
|
|
int QuantDTypeCastCPUKernel::QuantDTypeCast(int task_id) {
|
|
|
|
|
int num_unit_thread = MSMIN(thread_n_stride_, num_unit_ - task_id * thread_n_stride_);
|
|
|
|
|
if (num_unit_thread <= 0) {
|
|
|
|
@ -108,6 +105,11 @@ int QuantDTypeCastRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int QuantDTypeCastCPUKernel::Run() {
|
|
|
|
|
auto prepare_ret = Prepare();
|
|
|
|
|
if (prepare_ret != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
|
|
|
|
|
return prepare_ret;
|
|
|
|
|
}
|
|
|
|
|
if (inverse_) {
|
|
|
|
|
int8_ptr_ = reinterpret_cast<int8_t *>(inputs_[0]->Data());
|
|
|
|
|
float32_ptr_ = reinterpret_cast<float *>(outputs_[0]->Data());
|
|
|
|
|