diff --git a/mindspore/lite/src/ops/cast.cc b/mindspore/lite/src/ops/cast.cc index 48c1eab193..6e23da4732 100644 --- a/mindspore/lite/src/ops/cast.cc +++ b/mindspore/lite/src/ops/cast.cc @@ -25,6 +25,39 @@ int Cast::GetDstT() const { return this->primitive_->value.AsCast()->dstT; } void Cast::SetSrcT(int src_t) { this->primitive_->value.AsCast()->srcT = src_t; } void Cast::SetDstT(int dst_t) { this->primitive_->value.AsCast()->dstT = dst_t; } +int Cast::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitiveT failed"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_Cast; + } + if (this->primitive_->value.type != schema::PrimitiveType_Cast) { + MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + auto attr = new (std::nothrow) schema::CastT(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new primitiveT value failed"; + return RET_ERROR; + } + auto srcAnf = reinterpret_cast(prim.GetAttr("SrcT").get()); + auto dstAnf = reinterpret_cast(prim.GetAttr("DstT").get()); + attr->srcT = srcAnf->number_type(); + attr->dstT = dstAnf->number_type(); + this->primitive_->value.value = attr; + if (this->primitive_->value.value == nullptr) { + MS_LOG(ERROR) << "primitive value is nullptr"; + return RET_ERROR; + } + } + + return RET_OK; +} + #else int Cast::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { MS_ASSERT(nullptr != primitive); diff --git a/mindspore/lite/src/ops/cast.h b/mindspore/lite/src/ops/cast.h index 55dcf7663a..7e375c0cf6 100644 --- a/mindspore/lite/src/ops/cast.h +++ b/mindspore/lite/src/ops/cast.h @@ -33,6 +33,7 @@ class Cast : public PrimitiveC { explicit Cast(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} void SetSrcT(int src_t); void SetDstT(int dst_t); + int UnPackAttr(const Primitive &prim, const std::vector &inputs); #else Cast() = default; diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index cb1871b1a5..0cd86145d1 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -399,8 +399,8 @@ std::shared_ptr PrimitiveC::Create(const Primitive &prim, const std: return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "StridedSlice") { return NewPrimitiveC(prim, inputs, quantType); - } else if (op_type == "AvgPool") { - return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "Cast") { + return NewPrimitiveC(prim, inputs, quantType); #ifdef SUPPORT_TRAIN