|
|
|
@ -28,9 +28,9 @@ constexpr int kShapeInputNum = 1;
|
|
|
|
|
constexpr int kShapeOutputNum = 1;
|
|
|
|
|
} // namespace
|
|
|
|
|
#ifdef PRIMITIVE_WRITEABLE
|
|
|
|
|
float ConstantOfShape::GetValue() const { return this->primitive_->value.AsConstantOfShape()->value; }
|
|
|
|
|
std::vector<float> ConstantOfShape::GetValue() const { return this->primitive_->value.AsConstantOfShape()->value; }
|
|
|
|
|
|
|
|
|
|
void ConstantOfShape::SetValue(float value) { this->primitive_->value.AsConstantOfShape()->value = value; }
|
|
|
|
|
int ConstantOfShape::GetDataType() const { return this->primitive_->value.AsConstantOfShape()->dataType; }
|
|
|
|
|
|
|
|
|
|
#else
|
|
|
|
|
int ConstantOfShape::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
|
|
|
@ -41,12 +41,22 @@ int ConstantOfShape::UnPackToFlatBuilder(const schema::Primitive *primitive, fla
|
|
|
|
|
MS_LOG(ERROR) << "value_as_ConstantOfShape return nullptr";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
auto val_offset = schema::CreateConstantOfShape(*fbb, attr->value());
|
|
|
|
|
std::vector<float> value;
|
|
|
|
|
if (attr->value() != nullptr) {
|
|
|
|
|
for (int i = 0; i < static_cast<int>(attr->value()->size()); i++) {
|
|
|
|
|
value.push_back(attr->value()->data()[i]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
auto val_offset = schema::CreateConstantOfShapeDirect(*fbb, attr->dataType(), &value);
|
|
|
|
|
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_ConstantOfShape, val_offset.o);
|
|
|
|
|
fbb->Finish(prim_offset);
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
float ConstantOfShape::GetValue() const { return this->primitive_->value_as_ConstantOfShape()->value(); }
|
|
|
|
|
std::vector<float> ConstantOfShape::GetValue() const {
|
|
|
|
|
auto fb_vector = this->primitive_->value_as_ConstantOfShape()->value();
|
|
|
|
|
return std::vector<float>(fb_vector->begin(), fb_vector->end());
|
|
|
|
|
}
|
|
|
|
|
int ConstantOfShape::GetDataType() const { return this->primitive_->value_as_ConstantOfShape()->dataType(); }
|
|
|
|
|
|
|
|
|
|
PrimitiveC *ConstantOfShapeCreator(const schema::Primitive *primitive) {
|
|
|
|
|
return PrimitiveC::NewPrimitiveC<ConstantOfShape>(primitive);
|
|
|
|
@ -70,7 +80,7 @@ int ConstantOfShape::InferShape(std::vector<Tensor *> inputs_, std::vector<Tenso
|
|
|
|
|
}
|
|
|
|
|
auto in_tensor = inputs_.front();
|
|
|
|
|
auto out_tensor = outputs_.front();
|
|
|
|
|
out_tensor->set_data_type(kNumberTypeFloat32);
|
|
|
|
|
out_tensor->set_data_type(static_cast<TypeId>(GetDataType()));
|
|
|
|
|
out_tensor->SetFormat(in_tensor->GetFormat());
|
|
|
|
|
if (!GetInferFlag()) {
|
|
|
|
|
return RET_OK;
|
|
|
|
|