|
|
|
@ -26,46 +26,46 @@
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace ops {
|
|
|
|
|
void Lrn::set_depth_radius(const int64_t depth_radius) {
|
|
|
|
|
void LRN::set_depth_radius(const int64_t depth_radius) {
|
|
|
|
|
CheckAndConvertUtils::CheckInteger(kDepthRadius, depth_radius, kGreaterEqual, 0, this->name());
|
|
|
|
|
this->AddAttr(kDepthRadius, MakeValue(depth_radius));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int64_t Lrn::get_depth_radius() const {
|
|
|
|
|
int64_t LRN::get_depth_radius() const {
|
|
|
|
|
auto value_ptr = GetAttr(kDepthRadius);
|
|
|
|
|
return GetValue<int64_t>(value_ptr);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Lrn::set_bias(const float bias) { this->AddAttr(kBias, MakeValue(bias)); }
|
|
|
|
|
void LRN::set_bias(const float bias) { this->AddAttr(kBias, MakeValue(bias)); }
|
|
|
|
|
|
|
|
|
|
float Lrn::get_bias() const {
|
|
|
|
|
float LRN::get_bias() const {
|
|
|
|
|
auto value_ptr = GetAttr(kBias);
|
|
|
|
|
return GetValue<float>(value_ptr);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Lrn::set_alpha(const float alpha) { this->AddAttr(kAlpha, MakeValue(alpha)); }
|
|
|
|
|
void LRN::set_alpha(const float alpha) { this->AddAttr(kAlpha, MakeValue(alpha)); }
|
|
|
|
|
|
|
|
|
|
float Lrn::get_alpha() const {
|
|
|
|
|
float LRN::get_alpha() const {
|
|
|
|
|
auto value_ptr = GetAttr(kAlpha);
|
|
|
|
|
return GetValue<float>(value_ptr);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Lrn::set_beta(const float beta) { this->AddAttr(kBeta, MakeValue(beta)); }
|
|
|
|
|
void LRN::set_beta(const float beta) { this->AddAttr(kBeta, MakeValue(beta)); }
|
|
|
|
|
|
|
|
|
|
float Lrn::get_beta() const {
|
|
|
|
|
float LRN::get_beta() const {
|
|
|
|
|
auto value_ptr = GetAttr(kBeta);
|
|
|
|
|
return GetValue<float>(value_ptr);
|
|
|
|
|
}
|
|
|
|
|
void Lrn::set_norm_region(const std::string &norm_region) {
|
|
|
|
|
void LRN::set_norm_region(const std::string &norm_region) {
|
|
|
|
|
CheckAndConvertUtils::CheckString(kNormRegion, norm_region, {"ACROSS_CHANNELS"}, this->name());
|
|
|
|
|
this->AddAttr(kNormRegion, MakeValue(norm_region));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string Lrn::get_norm_region() const {
|
|
|
|
|
std::string LRN::get_norm_region() const {
|
|
|
|
|
auto value_ptr = GetAttr(kNormRegion);
|
|
|
|
|
return GetValue<std::string>(value_ptr);
|
|
|
|
|
}
|
|
|
|
|
void Lrn::Init(const int64_t depth_radius, const float bias, const float alpha, const float beta,
|
|
|
|
|
void LRN::Init(const int64_t depth_radius, const float bias, const float alpha, const float beta,
|
|
|
|
|
const std::string &norm_region) {
|
|
|
|
|
this->set_depth_radius(depth_radius);
|
|
|
|
|
this->set_bias(bias);
|
|
|
|
@ -102,6 +102,7 @@ AbstractBasePtr LrnInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr
|
|
|
|
|
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
|
|
|
|
InferShape(primitive, input_args)->shape());
|
|
|
|
|
}
|
|
|
|
|
REGISTER_PRIMITIVE_C(kNameLrn, Lrn);
|
|
|
|
|
REGISTER_PRIMITIVE_EVAL_IMPL(LRN, prim::kPrimLrn, LrnInfer);
|
|
|
|
|
REGISTER_PRIMITIVE_C(kNameLRN, LRN);
|
|
|
|
|
} // namespace ops
|
|
|
|
|
} // namespace mindspore
|
|
|
|
|