You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/paddle/optimizer/optimizer.h

94 lines
2.6 KiB

#pragma once
#include <stdbool.h>
#include <stdint.h>
/**
 * @brief Optimizer library, independent of other modules.
 * It is used in two cases:
 * Case A: the gradient is optimized locally on the trainer.
 *
 * Case B: the gradient is optimized on the parameter server.
 */
#ifdef __cplusplus
extern "C" {
#endif
/*
 * Element type tag for parameter and gradient buffers; it is the type of the
 * data_type arguments of paddle_create_optimizer and paddle_update_parameter.
 * Every enumerator carries an explicit value. NOTE(review): callers may depend
 * on the numbers rather than the names, so keep them stable — confirm.
 */
typedef enum {
  PADDLE_ELEMENT_TYPE_INT32   = 0,
  PADDLE_ELEMENT_TYPE_UINT32  = 1,
  PADDLE_ELEMENT_TYPE_INT64   = 2,
  PADDLE_ELEMENT_TYPE_UINT64  = 3,
  PADDLE_ELEMENT_TYPE_FLOAT32 = 4,
  PADDLE_ELEMENT_TYPE_FLOAT64 = 5
} paddle_element_type;
/**
 * @brief Execution status codes. Presumably these are the values behind the
 *        "return exec status" results documented on the functions in this
 *        header — confirm against the implementation.
 *
 * Declared `static` so that every translation unit including this header gets
 * its own internal-linkage copy. Without `static`, each line is an external
 * definition in C, and including the header from more than one .c file yields
 * duplicate-symbol link errors. In C++ a namespace-scope const variable
 * already has internal linkage, so `static` is redundant but harmless there.
 */
static const int32_t PADDLE_SUCCESS = 0;
static const int32_t PADDLE_ERROR = -1;
/* Opaque optimizer handle: its members are not visible through this header. */
struct paddle_optimizer;
typedef struct paddle_optimizer paddle_optimizer;
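/**
 * The functions in this group are intended to be called in this order:
 * 1. create the optimizer with a config (paddle_create_optimizer)
 * 2. set the weights (NOTE: no separate setter is declared in this header;
 *    the initial weights appear to be passed as param_buffer to
 *    paddle_create_optimizer — confirm)
 * 3. update the parameters (paddle_update_parameter)
 * 4. get the weights (paddle_optimizer_get_weights)
 * 5. release the optimizer (paddle_release_optimizer)
 */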
/**
 * @brief Create an optimizer instance from a serialized configuration.
 *
 * @param config_proto     serialized optimizer protobuf; see
 *                         OptimizerConfig.proto for details.
 * @param config_proto_len length of config_proto.
 * @param data_type        element type of the parameters (see
 *                         paddle_element_type).
 * @param param_buffer     parameter buffer handed to the optimizer.
 *                         NOTE(review): whether it is copied or borrowed is
 *                         not stated here — confirm against the implementation.
 * @param num_bytes        size of param_buffer (bytes, per the name).
 * @param state            serialized optimizer state; presumably the output of
 *                         paddle_optimizer_get_state — confirm.
 * @param state_len        length of state.
 * @return the new optimizer instance. NOTE(review): behavior on failure
 *         (e.g. a NULL return) is not specified in this header.
 */
paddle_optimizer *paddle_create_optimizer(const unsigned char *config_proto,
                                          int config_proto_len,
                                          paddle_element_type data_type,
                                          void *param_buffer,
                                          int num_bytes,
                                          const char *state,
                                          int state_len);
/**
 * @brief Release an optimizer instance.
 *
 * @param o optimizer instance to release. NOTE(review): presumably one
 *          returned by paddle_create_optimizer, and invalid after this
 *          call — confirm.
 * @return execution status code (see PADDLE_SUCCESS / PADDLE_ERROR).
 */
int paddle_release_optimizer(paddle_optimizer *o);
/**
 * @brief Update the optimizer's parameters with a caller-supplied gradient.
 *
 * @param o         optimizer instance.
 * @param data_type element type of the gradient and the parameters.
 * @param gradient  gradient computed by the caller of the optimizer.
 *                  TODO(zhihong): pass just the loss instead, to reduce
 *                  communication overhead; see the Project Adam (MS '14)
 *                  paper for details.
 * @param num_bytes size of gradient (bytes, per the name).
 * @return execution status code (see PADDLE_SUCCESS / PADDLE_ERROR).
 */
int paddle_update_parameter(paddle_optimizer *o,
                            paddle_element_type data_type,
                            const void *gradient,
                            int num_bytes);
/**
 * @brief Get the optimizer's weights.
 *
 * @param o            optimizer instance.
 * @param param_buffer out-parameter that receives a pointer to the parameter
 *                     buffer. NOTE(review): ownership of that buffer is not
 *                     stated here — confirm before freeing or retaining it.
 * @return length of the returned content (unit not stated; presumably bytes).
 */
int paddle_optimizer_get_weights(paddle_optimizer *o, void **param_buffer);
/**
 * @brief Get the optimizer's serialized training state.
 *
 * @param o     optimizer instance.
 * @param state out-parameter that receives a pointer to the serialized state
 *              (presumably produced by SerializeState — confirm). Ownership of
 *              the pointed-to buffer is not stated here.
 * @return length of the state buffer.
 */
int paddle_optimizer_get_state(paddle_optimizer *o, const char **state);
#ifdef __cplusplus
}
#endif