Fix: ApplyAdagrad CodeDEX.

pull/11272/head
yang_chun 4 years ago
parent 823ba8d71f
commit 8ba9987460

@ -84,6 +84,7 @@ void ApplyAdagradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs)
if (batch_size == 0) {
MS_LOG(EXCEPTION) << "Error occur in launch kernel";
return;
}
while (start < length) {
size_t end = (start + batch_size) > length ? length : (start + batch_size);
@ -98,7 +99,8 @@ void ApplyAdagradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs)
}
template <typename T>
void ApplyAdagradCPUKernel::LaunchApplyAdagrad(T var, T accum, T lr, T gradient, size_t start, size_t end) {
void ApplyAdagradCPUKernel::LaunchApplyAdagrad(T const var, T const accum, const T lr, const T gradient, size_t start,
size_t end) {
// DataType can only be float32 or float16, so eps will not be zero.
using DataType = typename std::iterator_traits<T>::value_type;
const DataType one = DataType(1);

@ -38,7 +38,7 @@ class ApplyAdagradCPUKernel : public CPUKernel {
template <typename T>
void LaunchKernel(const std::vector<AddressPtr> &inputs);
template <typename T>
void LaunchApplyAdagrad(T var, T accum, T lr, T gradient, size_t start, size_t end);
void LaunchApplyAdagrad(T const var, T const accum, const T lr, const T gradient, size_t start, size_t end);
bool update_slots_{true};
TypeId dtype_{kTypeUnknown};
};

Loading…
Cancel
Save