|
|
|
@ -84,6 +84,7 @@ void ApplyAdagradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs)
|
|
|
|
|
|
|
|
|
|
if (batch_size == 0) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Error occur in launch kernel";
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
while (start < length) {
|
|
|
|
|
size_t end = (start + batch_size) > length ? length : (start + batch_size);
|
|
|
|
@ -98,7 +99,8 @@ void ApplyAdagradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void ApplyAdagradCPUKernel::LaunchApplyAdagrad(T var, T accum, T lr, T gradient, size_t start, size_t end) {
|
|
|
|
|
void ApplyAdagradCPUKernel::LaunchApplyAdagrad(T const var, T const accum, const T lr, const T gradient, size_t start,
|
|
|
|
|
size_t end) {
|
|
|
|
|
// DataType can only be float32 or float16, so eps will not be zero.
|
|
|
|
|
using DataType = typename std::iterator_traits<T>::value_type;
|
|
|
|
|
const DataType one = DataType(1);
|
|
|
|
|