|
|
|
@ -3,7 +3,7 @@
|
|
|
|
|
|
|
|
|
|
#include "parameter_optimizer.h"
|
|
|
|
|
|
|
|
|
|
template <class T>
|
|
|
|
|
template <paddle_element_type T>
|
|
|
|
|
struct EnumToType {};
|
|
|
|
|
|
|
|
|
|
template <class T>
|
|
|
|
@ -11,15 +11,14 @@ struct TypeToEnum {};
|
|
|
|
|
|
|
|
|
|
#define MATCH_ENUM_TYPE(TYPE, ENUM) \
|
|
|
|
|
template <> \
|
|
|
|
|
struct TypeToEnum<ENUM> { \
|
|
|
|
|
struct TypeToEnum<TYPE> { \
|
|
|
|
|
static paddle_element_type v() { return ENUM; }; \
|
|
|
|
|
static constexpr TYPE value = ENUM;
|
|
|
|
|
}
|
|
|
|
|
;
|
|
|
|
|
template <>
|
|
|
|
|
struct EnumToType<ENUM> {
|
|
|
|
|
typedef TYPE Type;
|
|
|
|
|
}
|
|
|
|
|
static constexpr TYPE value = ENUM; \
|
|
|
|
|
}; \
|
|
|
|
|
template <> \
|
|
|
|
|
struct EnumToType<ENUM> { \
|
|
|
|
|
typedef TYPE Type; \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
MATCH_ENUM_TYPE(int32_t, PADDLE_ELEMENT_TYPE_INT32);
|
|
|
|
|
MATCH_ENUM_TYPE(uint32_t, PADDLE_ELEMENT_TYPE_UINT32);
|
|
|
|
@ -27,11 +26,10 @@ MATCH_ENUM_TYPE(int64_t, PADDLE_ELEMENT_TYPE_INT64);
|
|
|
|
|
MATCH_ENUM_TYPE(uint64_t, PADDLE_ELEMENT_TYPE_UINT64);
|
|
|
|
|
MATCH_ENUM_TYPE(float, PADDLE_ELEMENT_TYPE_FLOAT32);
|
|
|
|
|
MATCH_ENUM_TYPE(double, PADDLE_ELEMENT_TYPE_FLOAT64);
|
|
|
|
|
|
|
|
|
|
struct paddle_optimizer {
|
|
|
|
|
struct paddle_optimizer {
|
|
|
|
|
/*! \brief optmizer in C++ side */
|
|
|
|
|
|
|
|
|
|
paddle::optimizer::ParameterOptimzier* impl;
|
|
|
|
|
paddle::optimizer::ParameterOptimizerBase* impl;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
paddle_optimizer* paddle_create_optimizer(const unsigned char* config_proto,
|
|
|
|
@ -48,7 +46,7 @@ int paddle_release_optimizer(paddle_optimizer* o) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int paddle_update_parameter(paddle_optimizer* o,
|
|
|
|
|
paddle_element_type data_type,
|
|
|
|
|
const paddle_element_type data_type,
|
|
|
|
|
const void* grad_buffer,
|
|
|
|
|
int num_bytes) {
|
|
|
|
|
auto type = EnumToType<data_type>::Type;
|
|
|
|
@ -59,7 +57,7 @@ int paddle_update_parameter(paddle_optimizer* o,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int paddle_optimizer_set_weights(paddle_optimizer* o,
|
|
|
|
|
paddle_element_type data_type,
|
|
|
|
|
const paddle_element_type data_type,
|
|
|
|
|
void* param_buffer,
|
|
|
|
|
int num_bytes) {
|
|
|
|
|
auto type = EnumToType<data_type>::Type;
|
|
|
|
|