|
|
|
@ -4951,8 +4951,7 @@ class LRN(PrimitiveWithInfer):
|
|
|
|
|
bias (float): An offset (usually positive to avoid dividing by 0).
|
|
|
|
|
alpha (float): A scale factor, usually positive.
|
|
|
|
|
beta (float): An exponent.
|
|
|
|
|
norm_region (str): Specify normalization region. Options: "ACROSS_CHANNELS", "WITHIN_CHANNEL".
|
|
|
|
|
Default: "ACROSS_CHANNELS".
|
|
|
|
|
norm_region (str): Specify normalization region. Options: "ACROSS_CHANNELS". Default: "ACROSS_CHANNELS".
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
- **x** (Tensor) - A 4D Tensor with float16 or float32 data type.
|
|
|
|
@ -4974,6 +4973,7 @@ class LRN(PrimitiveWithInfer):
|
|
|
|
|
validator.check_value_type("alpha", alpha, [float], self.name)
|
|
|
|
|
validator.check_value_type("beta", beta, [float], self.name)
|
|
|
|
|
validator.check_value_type("norm_region", norm_region, [str], self.name)
|
|
|
|
|
validator.check_string('norm_region', norm_region, ['ACROSS_CHANNELS'], self.name)
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype):
|
|
|
|
|
validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32,), self.name)
|
|
|
|
|