|
|
|
@ -1,3 +1,17 @@
|
|
|
|
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
|
|
|
|
|
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
you may not use this file except in compliance with the License.
|
|
|
|
|
You may obtain a copy of the License at
|
|
|
|
|
|
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
|
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "optimizer.h"
|
|
|
|
|
#include <glog/logging.h>
|
|
|
|
|
#include <cstdlib>
|
|
|
|
@ -6,8 +20,8 @@
|
|
|
|
|
|
|
|
|
|
#include "parameter_optimizer.h"
|
|
|
|
|
|
|
|
|
|
using namespace paddle;
|
|
|
|
|
using namespace paddle::optimizer;
|
|
|
|
|
using paddle::optimizer::ParameterOptimizer;
|
|
|
|
|
using paddle::optimizer::Tensor;
|
|
|
|
|
|
|
|
|
|
template <paddle_element_type VALUE>
|
|
|
|
|
struct EnumToType {};
|
|
|
|
@ -15,22 +29,21 @@ struct EnumToType {};
|
|
|
|
|
template <class T>
|
|
|
|
|
struct TypeToEnum {};
|
|
|
|
|
|
|
|
|
|
#define MATCH_ENUM_TYPE(TYPE, ENUM) \
|
|
|
|
|
template <> \
|
|
|
|
|
struct TypeToEnum<TYPE> { \
|
|
|
|
|
static paddle_element_type v() { return ENUM; }; \
|
|
|
|
|
static constexpr TYPE value = ENUM; \
|
|
|
|
|
}; \
|
|
|
|
|
template <> \
|
|
|
|
|
struct EnumToType<ENUM> { \
|
|
|
|
|
typedef TYPE Type; \
|
|
|
|
|
#define MATCH_ENUM_TYPE(TYPE, ENUM) \
|
|
|
|
|
template <> \
|
|
|
|
|
struct TypeToEnum<TYPE> { \
|
|
|
|
|
static paddle_element_type v() { return ENUM; } \
|
|
|
|
|
static constexpr TYPE value = ENUM; \
|
|
|
|
|
}; \
|
|
|
|
|
template <> \
|
|
|
|
|
struct EnumToType<ENUM> { \
|
|
|
|
|
typedef TYPE Type; \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
MATCH_ENUM_TYPE(int32_t, PADDLE_ELEMENT_TYPE_INT32);
|
|
|
|
|
MATCH_ENUM_TYPE(uint32_t, PADDLE_ELEMENT_TYPE_UINT32);
|
|
|
|
|
MATCH_ENUM_TYPE(int64_t, PADDLE_ELEMENT_TYPE_INT64);
|
|
|
|
|
MATCH_ENUM_TYPE(uint64_t, PADDLE_ELEMENT_TYPE_UINT64);
|
|
|
|
|
// TODO(zhihong): only implement below type, need to fix
|
|
|
|
|
MATCH_ENUM_TYPE(float, PADDLE_ELEMENT_TYPE_FLOAT32);
|
|
|
|
|
MATCH_ENUM_TYPE(double, PADDLE_ELEMENT_TYPE_FLOAT64);
|
|
|
|
|
|
|
|
|
|