You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Paddle/paddle/optimizer/optimizer.cc

76 lines
2.3 KiB

#include "optimizer.h"
#include <string>
#include "parameter_optimizer.h"
template <class T>
struct EnumToType {};
template <class T>
struct TypeToEnum {};
#define MATCH_ENUM_TYPE(TYPE, ENUM) \
template <> \
struct TypeToEnum<ENUM> { \
static paddle_element_type v() { return ENUM; }; \
static constexpr TYPE value = ENUM;
}
;
template <>
struct EnumToType<ENUM> {
typedef TYPE Type;
}
MATCH_ENUM_TYPE(int32_t, PADDLE_ELEMENT_TYPE_INT32);
MATCH_ENUM_TYPE(uint32_t, PADDLE_ELEMENT_TYPE_UINT32);
MATCH_ENUM_TYPE(int64_t, PADDLE_ELEMENT_TYPE_INT64);
MATCH_ENUM_TYPE(uint64_t, PADDLE_ELEMENT_TYPE_UINT64);
MATCH_ENUM_TYPE(float, PADDLE_ELEMENT_TYPE_FLOAT32);
MATCH_ENUM_TYPE(double, PADDLE_ELEMENT_TYPE_FLOAT64);
/*!
 * \brief Handle type passed through the paddle_*optimizer* functions in this
 * file; it wraps the C++ optimizer object.
 */
struct paddle_optimizer {
  /*!
   * \brief optimizer on the C++ side.
   *
   * Starts out null so that a handle whose implementation was never attached
   * can still be released safely (deleting a null pointer is a no-op).
   */
  paddle::optimizer::ParameterOptimzier* impl = nullptr;
};
/*!
 * \brief Create an optimizer handle from a configuration byte buffer.
 * \param config_proto pointer to the configuration bytes
 * \param config_proto_len number of bytes readable at config_proto
 * \return a newly allocated handle, or nullptr when the arguments are invalid
 *         or no optimizer implementation was produced
 */
paddle_optimizer* paddle_create_optimizer(const unsigned char* config_proto,
                                          int config_proto_len) {
  if (config_proto == nullptr || config_proto_len < 0) {
    return nullptr;
  }

  // Build the string from an explicit [begin, end) range so that exactly
  // config_proto_len bytes are copied, including any embedded zero bytes.
  const std::string config(config_proto, config_proto + config_proto_len);

  // NOTE(review): assumes ParameterOptimzier::create is a static factory that
  // accepts the configuration string and returns a heap-allocated instance
  // (paddle_release_optimizer deletes impl) -- confirm against
  // parameter_optimizer.h.
  paddle::optimizer::ParameterOptimzier* impl =
      paddle::optimizer::ParameterOptimzier::create(config);
  if (impl == nullptr) {
    return nullptr;
  }

  // Value-initialize the handle so impl is never left indeterminate.
  paddle_optimizer* optimizer = new paddle_optimizer();
  optimizer->impl = impl;
  return optimizer;
}
/*!
 * \brief Destroy an optimizer handle and the C++ optimizer it wraps.
 * \param o handle to destroy; a null handle is accepted and ignored
 * \return PADDLE_SUCCESS
 */
int paddle_release_optimizer(paddle_optimizer* o) {
  if (o != nullptr) {
    delete o->impl;
    // Free the handle object as well; releasing only impl leaks the wrapper.
    // NOTE(review): assumes handles are allocated with `new` by
    // paddle_create_optimizer -- confirm no caller passes other storage.
    delete o;
  }
  return PADDLE_SUCCESS;
}
int paddle_update_parameter(paddle_optimizer* o,
paddle_element_type data_type,
const void* grad_buffer,
int num_bytes) {
auto type = EnumToType<data_type>::Type;
paddle::Tensor<type> gradient(reinterpret_cast<type*>(grad_buffer),
num_bytes);
o->impl->update(gradient);
return PADDLE_SUCCESS;
}
int paddle_optimizer_set_weights(paddle_optimizer* o,
paddle_element_type data_type,
void* param_buffer,
int num_bytes) {
auto type = EnumToType<data_type>::Type;
paddle::Tensor<type>* param = new paddle::Tensor<type>(
reinterpret_cast<type*>(param_buffer), num_bytes);
o->impl->set_weight(param);
return PADDLE_SUCCESS;
}
/*!
 * \brief Return whatever the wrapped optimizer's get_weight() yields, as an
 * untyped pointer.
 * \param o optimizer handle; dereferenced without a null check
 * \return the result of o->impl->get_weight() converted to void*
 */
void* paddle_optimizer_get_weights(paddle_optimizer* o) {
  // The C-style cast is kept on purpose: get_weight()'s return type is
  // declared outside this file, so a narrower named cast might not compile.
  return (void*)o->impl->get_weight();
}