Merge branch 'develop' into support_scalar_kernels

gangliao-patch-1
Liu Yiqun 8 years ago
commit c3705602de

@ -70,7 +70,7 @@ before looking into the
We provide [English](http://www.paddlepaddle.org/develop/doc/) and We provide [English](http://www.paddlepaddle.org/develop/doc/) and
[Chinese](http://www.paddlepaddle.org/doc_cn/) documentation. [Chinese](http://www.paddlepaddle.org/doc_cn/) documentation.
- [Deep Learning 101](http://book.paddlepaddle.org/index.en.html) - [Deep Learning 101](http://book.paddlepaddle.org/index.html)
You might want to start from the this online interactive book that can run in Jupyter Notebook. You might want to start from the this online interactive book that can run in Jupyter Notebook.

@ -11,11 +11,23 @@ find_path(CUDNN_INCLUDE_DIR cudnn.h
get_filename_component(__libpath_hist ${CUDA_CUDART_LIBRARY} PATH) get_filename_component(__libpath_hist ${CUDA_CUDART_LIBRARY} PATH)
if(NOT ${CMAKE_HOST_SYSTEM_PROCESSOR})
execute_process(
COMMAND uname -m COMMAND tr -d '\n'
OUTPUT_VARIABLE HOST_ARCH
RESULT_VARIABLE UNAME_RESULT)
if(${UNAME_RESULT})
set(HOST_ARCH "x86_64")
endif(${UNAME_RESULT})
else(NOT ${CMAKE_HOST_SYSTEM_PROCESSOR})
set(HOST_ARCH ${CMAKE_HOST_SYSTEM_PROCESSOR})
endif(NOT ${CMAKE_HOST_SYSTEM_PROCESSOR})
list(APPEND CUDNN_CHECK_LIBRARY_DIRS list(APPEND CUDNN_CHECK_LIBRARY_DIRS
${CUDNN_ROOT} ${CUDNN_ROOT}
${CUDNN_ROOT}/lib64 ${CUDNN_ROOT}/lib64
${CUDNN_ROOT}/lib ${CUDNN_ROOT}/lib
${CUDNN_ROOT}/lib/x86_64-linux-gnu ${CUDNN_ROOT}/lib/${HOST_ARCH}-linux-gnu
$ENV{CUDNN_ROOT} $ENV{CUDNN_ROOT}
$ENV{CUDNN_ROOT}/lib64 $ENV{CUDNN_ROOT}/lib64
$ENV{CUDNN_ROOT}/lib $ENV{CUDNN_ROOT}/lib

@ -37,7 +37,7 @@ IF(NOT ${CBLAS_FOUND})
SET(OPTIONAL_ARGS HOSTCC=${HOST_C_COMPILER} TARGET=ARMV7 USE_THREAD=0 libs) SET(OPTIONAL_ARGS HOSTCC=${HOST_C_COMPILER} TARGET=ARMV7 USE_THREAD=0 libs)
ELSE() ELSE()
SET(OPENBLAS_COMMIT "v0.2.19") SET(OPENBLAS_COMMIT "v0.2.19")
SET(OPENBLAS_ARGS DYNAMIC_ARCH=1 libs) SET(OPTIONAL_ARGS DYNAMIC_ARCH=1 libs NUM_THREADS=64)
ENDIF() ENDIF()
ExternalProject_Add( ExternalProject_Add(

@ -136,6 +136,9 @@ int paddle_send_grads(paddle_pserver_client* client, const paddle_gradient* grad
/** /**
* @brief paddle_get_params gets parameters from parameter servers. * @brief paddle_get_params gets parameters from parameter servers.
* *
* paddle_get_params will block until parameters are initialized on
* the parameter servers.
*
* @param names the array of names of the parameters to get. * @param names the array of names of the parameters to get.
* @param dst the destination array of parameters to save to. * @param dst the destination array of parameters to save to.
* @param len the length of the names array and the paddle_parameter * @param len the length of the names array and the paddle_parameter

@ -38,7 +38,7 @@ endif()
mark_as_advanced(CMAKE_Go_COMPILER) mark_as_advanced(CMAKE_Go_COMPILER)
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/cmake/CMakeGoCompiler.cmake.in configure_file(${CMAKE_MODULE_PATH}/CMakeGoCompiler.cmake.in
${CMAKE_PLATFORM_INFO_DIR}/CMakeGoCompiler.cmake @ONLY) ${CMAKE_PLATFORM_INFO_DIR}/CMakeGoCompiler.cmake @ONLY)
set(CMAKE_Go_COMPILER_ENV_VAR "GO_COMPILER") set(CMAKE_Go_COMPILER_ENV_VAR "GO_COMPILER")

@ -21,7 +21,7 @@ function(CheckCompilerCXX11Flag)
if (${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 3.3) if (${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 3.3)
message(FATAL_ERROR "Unsupported Clang version. Clang >= 3.3 required.") message(FATAL_ERROR "Unsupported Clang version. Clang >= 3.3 required.")
endif() endif()
endif() endif()
endif() endif()
endfunction() endfunction()
@ -42,4 +42,4 @@ if (CUDA_VERSION VERSION_GREATER "8.0" OR CUDA_VERSION VERSION_EQUAL "8.0")
list(APPEND __arch_flags " -gencode arch=compute_60,code=sm_60") list(APPEND __arch_flags " -gencode arch=compute_60,code=sm_60")
endif() endif()
set(CUDA_NVCC_FLAGS ${__arch_flags} ${CUDA_NVCC_FLAGS}) set(CUDA_NVCC_FLAGS ${__arch_flags} ${CUDA_NVCC_FLAGS})

@ -0,0 +1,50 @@
set(GOPATH "${CMAKE_CURRENT_BINARY_DIR}/go")
file(MAKE_DIRECTORY ${GOPATH})
set(PADDLE_IN_GOPATH "${GOPATH}/src/github.com/PaddlePaddle")
file(MAKE_DIRECTORY ${PADDLE_IN_GOPATH})
function(GO_LIBRARY NAME BUILD_TYPE)
if(BUILD_TYPE STREQUAL "STATIC")
set(BUILD_MODE -buildmode=c-archive)
set(LIB_NAME "lib${NAME}.a")
else()
set(BUILD_MODE -buildmode=c-shared)
if(APPLE)
set(LIB_NAME "lib${NAME}.dylib")
else()
set(LIB_NAME "lib${NAME}.so")
endif()
endif()
file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go")
file(RELATIVE_PATH rel ${CMAKE_BINARY_DIR} ${CMAKE_CURRENT_SOURCE_DIR})
# find Paddle directory.
get_filename_component(PARENT_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY)
get_filename_component(PARENT_DIR ${PARENT_DIR} DIRECTORY)
get_filename_component(PADDLE_DIR ${PARENT_DIR} DIRECTORY)
# automatically get all dependencies specified in the source code
# for given target.
add_custom_target(goGet env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} get -d ${rel}/...)
# make a symlink that references Paddle inside $GOPATH, so go get
# will use the local changes in Paddle rather than checkout Paddle
# in github.
add_custom_target(copyPaddle
COMMAND ln -sf ${PADDLE_DIR} ${PADDLE_IN_GOPATH})
add_dependencies(goGet copyPaddle)
add_custom_command(OUTPUT ${OUTPUT_DIR}/.timestamp
COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build ${BUILD_MODE}
-o "${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME}"
${CMAKE_GO_FLAGS} ${GO_SOURCE}
WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR})
add_custom_target(${NAME} ALL DEPENDS ${OUTPUT_DIR}/.timestamp ${ARGN})
add_dependencies(${NAME} goGet)
if(NOT BUILD_TYPE STREQUAL "STATIC")
install(PROGRAMS ${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME} DESTINATION bin)
endif()
endfunction(GO_LIBRARY)

@ -13,8 +13,8 @@ import (
"github.com/namsral/flag" "github.com/namsral/flag"
"github.com/PaddlePaddle/Paddle/paddle/go/master" "github.com/PaddlePaddle/Paddle/go/master"
"github.com/PaddlePaddle/Paddle/paddle/go/recordio" "github.com/PaddlePaddle/Paddle/go/recordio"
) )
func main() { func main() {

@ -8,7 +8,7 @@ import (
"github.com/namsral/flag" "github.com/namsral/flag"
"github.com/PaddlePaddle/Paddle/paddle/go/pserver" "github.com/PaddlePaddle/Paddle/go/pserver"
) )
func main() { func main() {

@ -6,7 +6,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/PaddlePaddle/Paddle/paddle/go/recordio" "github.com/PaddlePaddle/Paddle/go/recordio"
) )
const ( const (

@ -0,0 +1,13 @@
cmake_minimum_required(VERSION 3.0)
get_filename_component(PARENT_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY)
get_filename_component(PARENT_DIR ${PARENT_DIR} DIRECTORY)
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${PARENT_DIR}/cmake")
project(cxx_go C Go)
include(golang)
include(flags)
go_library(client STATIC)
add_subdirectory(test)

@ -39,10 +39,11 @@ import "C"
import ( import (
"log" "log"
"strings"
"sync" "sync"
"unsafe" "unsafe"
"github.com/PaddlePaddle/Paddle/paddle/go/pserver" "github.com/PaddlePaddle/Paddle/go/pserver"
) )
var nullPtr = unsafe.Pointer(uintptr(0)) var nullPtr = unsafe.Pointer(uintptr(0))
@ -78,34 +79,54 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte {
return nil return nil
} }
// create a Go clice backed by a C array, // create a Go clice backed by a C array, reference:
// reference: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices // https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
//
// Go garbage collector will not interact with this data, need
// to be freed properly.
return (*[1 << 30]byte)(p)[:len:len] return (*[1 << 30]byte)(p)[:len:len]
} }
type selector bool
func (s selector) Select() bool {
return bool(s)
}
type lister []pserver.Server
func (l lister) List() []pserver.Server {
return l
}
//export paddle_new_pserver_client //export paddle_new_pserver_client
func paddle_new_pserver_client(addr *C.char) C.client { func paddle_new_pserver_client(addrs *C.char, selected int) C.client {
c := pserver.NewClient(C.GoString(addr)) a := C.GoString(addrs)
as := strings.Split(a, ",")
servers := make([]pserver.Server, len(as))
for i := range as {
servers[i].Index = i
servers[i].Addr = as[i]
}
c := pserver.NewClient(lister(servers), len(as), selector(selected != 0))
return add(c) return add(c)
} }
//export paddle_new_etcd_pserver_client
func paddle_new_etcd_pserver_client(etcd_addr *C.char) C.client {
// TODO(helin): fault tolerant pserver client using etcd.
panic("not implemented.")
}
//export paddle_pserver_client_release //export paddle_pserver_client_release
func paddle_pserver_client_release(client C.client) { func paddle_pserver_client_release(client C.client) {
c := remove(client) remove(client)
c.Cleanup()
} }
//export paddle_begin_init_params //export paddle_begin_init_params
func paddle_begin_init_params(client C.client, pserver_config unsafe.Pointer, config_len C.int) C.int { func paddle_begin_init_params(client C.client) C.int {
c := get(client) c := get(client)
b := cArrayToSlice(pserver_config, int(config_len)) if selected := c.BeginInitParams(); selected {
selected, err := c.BeginInitParams(b)
if err != nil {
log.Println(err)
return -1
}
if selected {
return 1 return 1
} }
return 0 return 0
@ -227,7 +248,7 @@ func paddle_get_params(client C.client, names **C.char, dst **C.paddle_parameter
func paddle_save_model(client C.client, path *C.char) C.int { func paddle_save_model(client C.client, path *C.char) C.int {
p := C.GoString(path) p := C.GoString(path)
c := get(client) c := get(client)
err := c.SaveModel(p) err := c.Save(p)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
return -1 return -1

@ -1,18 +1,19 @@
#include "libclient.h" #include <stdio.h>
//#include "gtest/gtest.h" #include "libclient.h"
void panic() { void fail() {
// TODO(helin): fix: gtest using cmake is not working, using this // TODO(helin): fix: gtest using cmake is not working, using this
// hacky way for now. // hacky way for now.
*(void*)0; printf("test failed.\n");
exit(-1);
} }
int main() { int main() {
char addr[] = "localhost:3000"; char addr[] = "localhost:3000";
client c = paddle_new_pserver_client(addr); client c = paddle_new_pserver_client(addr, 1);
retry: retry:
if (paddle_begin_init_params(c, NULL, 0)) { if (paddle_begin_init_params(c)) {
paddle_parameter param; paddle_parameter param;
char name_a[] = "param_a"; char name_a[] = "param_a";
char name_b[] = "param_b"; char name_b[] = "param_b";
@ -35,7 +36,7 @@ retry:
goto retry; goto retry;
} }
} else { } else {
panic(); fail();
} }
char content[] = {0x00, 0x11, 0x22}; char content[] = {0x00, 0x11, 0x22};
@ -44,25 +45,25 @@ retry:
{"param_b", PADDLE_ELEMENT_TYPE_FLOAT32, content, 3}}; {"param_b", PADDLE_ELEMENT_TYPE_FLOAT32, content, 3}};
if (!paddle_send_grads(c, grads, 2)) { if (!paddle_send_grads(c, grads, 2)) {
panic(); fail();
} }
paddle_parameter* params[2] = {NULL, NULL}; paddle_parameter* params[2] = {NULL, NULL};
char* names[] = {"param_a", "param_b"}; char* names[] = {"param_a", "param_b"};
if (!paddle_get_params(c, names, params, 2)) { if (!paddle_get_params(c, names, params, 2)) {
panic(); fail();
} }
// get parameters again by reusing the allocated parameter buffers. // get parameters again by reusing the allocated parameter buffers.
if (!paddle_get_params(c, names, params, 2)) { if (!paddle_get_params(c, names, params, 2)) {
panic(); fail();
} }
paddle_release_param(params[0]); paddle_release_param(params[0]);
paddle_release_param(params[1]); paddle_release_param(params[1]);
if (!paddle_save_model(c, "/tmp/")) { if (!paddle_save_model(c, "/tmp/")) {
panic(); fail();
} }
return 0; return 0;

@ -0,0 +1,232 @@
package pserver
import (
"hash/fnv"
"log"
"sort"
"time"
"github.com/PaddlePaddle/Paddle/go/pserver/internal/connection"
)
// TODO(helin): add RPC call retry logic
// Selector selects if the client should initialize parameter servers.
type Selector interface {
Select() bool
}
// Server is the identification of a parameter Server.
type Server struct {
Index int
Addr string
}
// Lister lists currently available parameter servers.
type Lister interface {
List() []Server
}
// Client is the client to parameter servers.
type Client struct {
sel Selector
pservers []*connection.Conn
}
// NewClient creates a new client.
func NewClient(l Lister, pserverNum int, sel Selector) *Client {
c := &Client{sel: sel}
c.pservers = make([]*connection.Conn, pserverNum)
for i := 0; i < pserverNum; i++ {
c.pservers[i] = connection.New()
}
go c.monitorPservers(l, pserverNum)
return c
}
// monitorPservers monitors pserver addresses, and updates connection
// when the address changes.
func (c *Client) monitorPservers(l Lister, pserverNum int) {
knownServers := make([]Server, pserverNum)
ticker := time.NewTicker(10 * time.Second)
monitor := func() {
curServers := make([]Server, pserverNum)
list := l.List()
for _, l := range list {
curServers[l.Index] = l
}
for i := range knownServers {
if knownServers[i].Addr != curServers[i].Addr {
err := c.pservers[i].Connect(curServers[i].Addr)
if err != nil {
log.Println(err)
// connect to addr failed, set
// to last known addr in order
// to retry next time.
curServers[i].Addr = knownServers[i].Addr
}
}
}
knownServers = curServers
}
monitor()
for _ = range ticker.C {
monitor()
}
}
// BeginInitParams begins to initialize parameters on parameter
// servers.
//
// BeginInitParams will be called from multiple trainers, only one
// trainer will be selected to initialize the parameters on parameter
// servers. Other trainers will be blocked until the initialization is
// done, and they need to get the initialized parameters from
// parameter servers using GetParams.
func (c *Client) BeginInitParams() bool {
return c.sel.Select()
}
// InitParam initializes the parameter on parameter servers.
func (c *Client) InitParam(paramWithConfigs ParameterWithConfig) error {
var dummy int
return c.pservers[c.partition(paramWithConfigs.Param.Name)].Call("Service.InitParam", paramWithConfigs, &dummy)
}
// FinishInitParams tells parameter servers client has sent all
// parameters to parameter servers as initialization.
func (c *Client) FinishInitParams() error {
for _, p := range c.pservers {
var dummy int
err := p.Call("Service.FinishInitParams", dummy, &dummy)
if err != nil {
return err
}
}
return nil
}
// SendGrads sends gradients to parameter servers for updating
// parameters.
func (c *Client) SendGrads(grads []Gradient) error {
errCh := make(chan error, len(grads))
for _, g := range grads {
go func(g Gradient) {
var dummy int
err := c.pservers[c.partition(g.Name)].Call("Service.SendGrad", g, &dummy)
errCh <- err
}(g)
}
recv := 0
for err := range errCh {
if err != nil {
return err
}
recv++
if recv == len(grads) {
break
}
}
return nil
}
type result struct {
idx int
param Parameter
err error
}
type results []result
func (r results) Len() int {
return len(r)
}
func (r results) Less(i int, j int) bool {
return r[i].idx < r[j].idx
}
func (r results) Swap(i int, j int) {
r[i], r[j] = r[j], r[i]
}
// GetParams gets parameters from parameter servers.
func (c *Client) GetParams(names []string) ([]Parameter, error) {
rCh := make(chan result, len(names))
for idx, name := range names {
go func(name string, idx int) {
var parameter Parameter
err := c.pservers[c.partition(name)].Call("Service.GetParam", name, &parameter)
rCh <- result{idx: idx, param: parameter, err: err}
}(name, idx)
}
var rs results
recv := 0
for r := range rCh {
if r.err != nil {
return nil, r.err
}
rs = append(rs, r)
recv++
if recv == len(names) {
break
}
}
sort.Sort(rs)
ps := make([]Parameter, len(rs))
for i := range rs {
ps[i] = rs[i].param
}
return ps, nil
}
// Save indicates parameters to save the parameter to the given path.
func (c *Client) Save(path string) error {
errCh := make(chan error, len(c.pservers))
for _, p := range c.pservers {
var dummy int
err := p.Call("Service.Save", path, &dummy)
errCh <- err
}
recv := 0
for err := range errCh {
if err != nil {
return err
}
recv++
if recv == len(c.pservers) {
break
}
}
// TODO(helin): there will be many files under path, need to
// merge them into a single file.
return nil
}
func strHash(s string) uint32 {
h := fnv.New32a()
h.Write([]byte(s))
return h.Sum32()
}
// TODO(helin): now partition only select which parameter server to
// send the entire parameter. We need to partition a parameter into
// small blocks and send to different parameter servers.
func (c *Client) partition(key string) int {
return int(strHash(key) % uint32(len(c.pservers)))
}

@ -0,0 +1,123 @@
package pserver_test
import (
"net"
"net/http"
"net/rpc"
"strconv"
"strings"
"testing"
"github.com/PaddlePaddle/Paddle/go/pserver"
)
const numPserver = 10
var port [numPserver]int
func init() {
for i := 0; i < numPserver; i++ {
l, err := net.Listen("tcp", ":0")
if err != nil {
panic(err)
}
ss := strings.Split(l.Addr().String(), ":")
p, err := strconv.Atoi(ss[len(ss)-1])
if err != nil {
panic(err)
}
port[i] = p
go func(l net.Listener) {
s := pserver.NewService()
server := rpc.NewServer()
err := server.Register(s)
if err != nil {
panic(err)
}
mux := http.NewServeMux()
mux.Handle(rpc.DefaultRPCPath, server)
err = http.Serve(l, mux)
if err != nil {
panic(err)
}
}(l)
}
}
type selector bool
func (s selector) Select() bool {
return bool(s)
}
type lister []pserver.Server
func (l lister) List() []pserver.Server {
return l
}
func TestClientFull(t *testing.T) {
servers := make([]pserver.Server, numPserver)
for i := 0; i < numPserver; i++ {
servers[i] = pserver.Server{Index: i, Addr: ":" + strconv.Itoa(port[i])}
}
c := pserver.NewClient(lister(servers), len(servers), selector(true))
selected := c.BeginInitParams()
if !selected {
t.Fatal("should be selected.")
}
const numParameter = 100
for i := 0; i < numParameter; i++ {
var p pserver.Parameter
p.Name = "p_" + strconv.Itoa(i)
p.ElementType = pserver.Float32
p.Content = make([]byte, (i+1)*100)
err := c.InitParam(pserver.ParameterWithConfig{Param: p})
if err != nil {
t.Fatal(err)
}
}
err := c.FinishInitParams()
if err != nil {
t.Fatal(err)
}
var grads []pserver.Gradient
for i := 0; i < numParameter/2; i++ {
var g pserver.Gradient
g.Name = "p_" + strconv.Itoa(i)
g.ElementType = pserver.Float32
g.Content = make([]byte, (i+1)*100)
grads = append(grads, g)
}
err = c.SendGrads(grads)
if err != nil {
t.Fatal(err)
}
names := make([]string, numParameter)
for i := 0; i < numParameter; i++ {
names[i] = "p_" + strconv.Itoa(i)
}
params, err := c.GetParams(names)
if err != nil {
t.Fatal(err)
}
if len(names) != len(params) {
t.Fatalf("parameter size not match, need: %d, have: %d", len(names), len(params))
}
for i := range params {
if names[i] != params[i].Name {
t.Fatalf("order of returned parameter does not required: parameter name: %s, required name: %s", names[i], params[i])
}
}
}

@ -0,0 +1,84 @@
package connection
import (
"errors"
"net/rpc"
"sync"
)
// TODO(helin): add TCP re-connect logic
// Conn is a connection to a parameter server
type Conn struct {
mu sync.Mutex
client *rpc.Client
waitConn chan struct{}
}
// New creates a new connection.
func New() *Conn {
c := &Conn{}
return c
}
// Connect connects the connection to a address.
func (c *Conn) Connect(addr string) error {
c.mu.Lock()
if c.client != nil {
err := c.client.Close()
if err != nil {
c.mu.Unlock()
return err
}
c.client = nil
}
c.mu.Unlock()
client, err := rpc.DialHTTP("tcp", addr)
if err != nil {
return err
}
c.mu.Lock()
defer c.mu.Unlock()
if c.client == nil {
c.client = client
if c.waitConn != nil {
close(c.waitConn)
c.waitConn = nil
}
} else {
return errors.New("client already set from a concurrent goroutine")
}
return nil
}
// Call make a RPC call.
//
// Call will be blocked until the connection to remote RPC service
// being established.
func (c *Conn) Call(serviceMethod string, args interface{}, reply interface{}) error {
c.mu.Lock()
client := c.client
var waitCh chan struct{}
if client == nil {
if c.waitConn != nil {
waitCh = c.waitConn
} else {
waitCh = make(chan struct{})
c.waitConn = waitCh
}
}
c.mu.Unlock()
if waitCh != nil {
// wait until new connection being established
<-waitCh
return c.Call(serviceMethod, args, reply)
}
return client.Call(serviceMethod, args, reply)
}

@ -29,11 +29,11 @@ func newOptimizer(t optimizerType, learning_rate float64) *optimizer {
func (o *optimizer) UpdateParameter(p Parameter, g Gradient) error { func (o *optimizer) UpdateParameter(p Parameter, g Gradient) error {
if len(p.Content) != len(g.Content) { if len(p.Content) != len(g.Content) {
return fmt.Errorf("parameter and gradient length not match, parameter: %d, gradient: %d", len(p.Content), len(g.Content)) return fmt.Errorf("Name: %s, parameter and gradient length not match, parameter: %d, gradient: %d", p.Name, len(p.Content), len(g.Content))
} }
if p.ElementType != g.ElementType { if p.ElementType != g.ElementType {
return fmt.Errorf("parameter and gradient element type not match, parameter: %v, gradient: %v", p.ElementType, g.ElementType) return fmt.Errorf("Name: %s, parameter and gradient element type not match, parameter: %v, gradient: %v", p.Name, p.ElementType, g.ElementType)
} }
r := C.paddle_update_parameter(o.opt, unsafe.Pointer(&p.Content[0]), C.paddle_element_type(p.ElementType), unsafe.Pointer(&g.Content[0]), C.int(len(g.Content))) r := C.paddle_update_parameter(o.opt, unsafe.Pointer(&p.Content[0]), C.paddle_element_type(p.ElementType), unsafe.Pointer(&g.Content[0]), C.int(len(g.Content)))

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save