You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
53 lines
1.5 KiB
53 lines
1.5 KiB
8 years ago
|
#include <stdlib.h>
|
||
|
|
||
|
#include "optimizer.h"
|
||
|
|
||
8 years ago
|
typedef int (*update_func)(void*, void*, paddle_element_type, const void*, int);
|
||
8 years ago
|
typedef void (*release_func)(void*);
|
||
8 years ago
|
|
||
8 years ago
|
typedef struct paddle_optimizer {
|
||
8 years ago
|
update_func update;
|
||
|
release_func release;
|
||
8 years ago
|
void* optimizer;
|
||
|
} paddle_optimizer;
|
||
8 years ago
|
|
||
|
/* Destroys an optimizer handle.
 *
 * First lets the implementation dispose of its private state through
 * o->release, then frees the handle itself. A NULL handle is accepted and
 * ignored, mirroring free(NULL); previously it was dereferenced (UB).
 * The handle must not be used after this call. */
void paddle_release_optimizer(paddle_optimizer* o) {
  if (o == NULL) {
    return;
  }
  o->release(o->optimizer);
  free(o);
}
|
||
|
|
||
8 years ago
|
int paddle_update_parameter(paddle_optimizer* o,
|
||
|
void* buffer,
|
||
|
paddle_element_type element_type,
|
||
|
const void* gradient,
|
||
|
int num_bytes) {
|
||
8 years ago
|
return o->update(o->optimizer, buffer, element_type, gradient, num_bytes);
|
||
8 years ago
|
}
|
||
|
|
||
8 years ago
|
/* Private state of the SGD optimizer, set by paddle_create_SGD_optimizer. */
typedef struct {
  double learning_rate; /* value passed to paddle_create_SGD_optimizer */
} SGD_optimizer;
|
||
8 years ago
|
|
||
8 years ago
|
int update_SGD(void* optimizer,
|
||
|
void* buffer,
|
||
|
paddle_element_type element_type,
|
||
|
const void* gradient,
|
||
|
int num_bytes) {
|
||
8 years ago
|
SGD_optimizer* o = (SGD_optimizer*)optimizer;
|
||
8 years ago
|
// TODO
|
||
|
return 0;
|
||
|
}
|
||
8 years ago
|
|
||
8 years ago
|
/* Release callback installed by paddle_create_SGD_optimizer.
 *
 * The SGD_optimizer state is obtained with malloc() in
 * paddle_create_SGD_optimizer, so it must be freed here; the outer
 * paddle_optimizer handle is freed separately by paddle_release_optimizer.
 * Without this free() every created optimizer leaked its state.
 * NULL is accepted (free(NULL) is a no-op). */
void release_SGD(void* optimizer) {
  free(optimizer);
}
|
||
|
|
||
8 years ago
|
paddle_optimizer* paddle_create_SGD_optimizer(double learning_rate) {
|
||
8 years ago
|
SGD_optimizer* impl = (SGD_optimizer*)malloc(sizeof(SGD_optimizer));
|
||
|
impl->learning_rate = learning_rate;
|
||
|
paddle_optimizer* opt = (paddle_optimizer*)malloc(sizeof(paddle_optimizer));
|
||
|
opt->update = update_SGD;
|
||
|
opt->release = release_SGD;
|
||
|
opt->optimizer = impl;
|
||
|
return opt;
|
||
8 years ago
|
}
|