Fix onehot input 3 convert

pull/9009/head
zhaozhenlong 4 years ago
parent 63bb481ddb
commit 01267949b1

@ -82,7 +82,8 @@ Registry OneHotRegistry(schema::PrimitiveType_OneHot, OneHotCreator);
namespace {
constexpr size_t kOneHotInputNum = 4;
}
constexpr size_t kOneHotInputNumOpt = 3;
} // namespace
int OneHot::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) {
if (this->primitive_ == nullptr) {
return RET_NULL_PTR;
@ -90,8 +91,10 @@ int OneHot::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outpu
int axis = GetAxis();
// indices, depth, on_value, off_value
if (inputs.size() != kOneHotInputNum) {
MS_LOG(ERROR) << "OneHot got inputs num " << inputs.size() << ", should be " << kOneHotInputNum;
// indices, depth, on_off_value(contain 2 values);
if (inputs.size() != kOneHotInputNum && inputs.size() != kOneHotInputNumOpt) {
MS_LOG(ERROR) << "OneHot got inputs num " << inputs.size() << ", should be " << kOneHotInputNum << " or "
<< kOneHotInputNumOpt;
return RET_ERROR;
}
auto depth_tensor = inputs.at(1);

@ -43,7 +43,7 @@ int SpaceToBatchCPUKernel::ReSize() {
MS_ASSERT(input_tensor);
auto output_tensor = out_tensors_.at(0);
MS_ASSERT(output_tensor);
MS_ASSERT(param);
MS_ASSERT(param_);
for (size_t i = 0; i < DIMENSION_4D; i++) {
param_->input_shape_[i] = input_tensor->shape().at(i);
param_->output_shape_[i] = output_tensor->shape().at(i);

@ -34,15 +34,18 @@ int SqueezeCPUKernel::ReSize() { return RET_OK; }
int SqueezeCPUKernel::Run() {
mindspore::lite::STATUS ret = RET_ERROR;
size_t data_size = in_tensors_.front()->Size();
MS_ASSERT(input_ptr);
MS_ASSERT(output_ptr);
if (in_tensors_.front()->data_type() == kNumberTypeInt32) {
auto input_ptr = reinterpret_cast<int32_t *>(in_tensors_.front()->MutableData());
auto output_ptr = reinterpret_cast<int32_t *>(out_tensors_.front()->MutableData());
MS_ASSERT(input_ptr);
MS_ASSERT(output_ptr);
ret = DoSqueezeInt32(input_ptr, output_ptr, data_size);
} else {
auto input_ptr = reinterpret_cast<float *>(in_tensors_.front()->MutableData());
auto output_ptr = reinterpret_cast<float *>(out_tensors_.front()->MutableData());
MS_ASSERT(input_ptr);
MS_ASSERT(output_ptr);
ret = DoSqueeze(input_ptr, output_ptr, data_size);
}

@ -61,7 +61,7 @@ int SqueezeInt8CPUKernel::Init() {
return RET_ERROR;
}
auto in_quant_args = in_tensors_.front()->quant_params();
MS_ASSERT(quant_args.size() > 0);
MS_ASSERT(in_quant_args.size() > 0);
quant_squeeze_param_->in_quant_args_->scale_ = in_quant_args.front().scale;
quant_squeeze_param_->in_quant_args_->zp_ = in_quant_args.front().zeroPoint;

Loading…
Cancel
Save