|
|
|
@ -44,8 +44,14 @@ int Cast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
|
|
|
|
|
MS_LOG(ERROR) << "tensor number is error.";
|
|
|
|
|
return RET_INPUT_TENSOR_ERROR;
|
|
|
|
|
}
|
|
|
|
|
output->SetFormat(input->GetFormat());
|
|
|
|
|
auto cast_prim = this->primitive->value_as_Cast();
|
|
|
|
|
MS_ASSERT(cast_prim != nullptr);
|
|
|
|
|
output->set_data_type(static_cast<TypeId>(cast_prim->dstT()));
|
|
|
|
|
if (!GetInferFlag()) {
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (input->data_type() != cast_prim->srcT()) {
|
|
|
|
|
MS_LOG(ERROR) << "input dataType is error";
|
|
|
|
|
return RET_INPUT_TENSOR_ERROR;
|
|
|
|
@ -54,13 +60,8 @@ int Cast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
|
|
|
|
|
MS_LOG(ERROR) << "Unsupported input data type " << input->data_type();
|
|
|
|
|
return RET_INPUT_TENSOR_ERROR;
|
|
|
|
|
}
|
|
|
|
|
if (cast_prim->dstT() != kNumberTypeFloat && cast_prim->dstT() != kNumberTypeFloat32) {
|
|
|
|
|
MS_LOG(ERROR) << "Invalid output datatype " << cast_prim->dstT();
|
|
|
|
|
return RET_INPUT_TENSOR_ERROR;
|
|
|
|
|
}
|
|
|
|
|
output->SetFormat(input->GetFormat());
|
|
|
|
|
|
|
|
|
|
output->set_shape(input->shape());
|
|
|
|
|
output->set_data_type(TypeId::kNumberTypeFloat32);
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
} // namespace lite
|
|
|
|
|