Merge branch 'develop' of github.com:PaddlePaddle/Paddle into rnn_varilen_design

cblas_new
Superjom 8 years ago
commit 01626be9b3

@ -22,9 +22,11 @@
hooks:
- id: clang-formater
- repo: https://github.com/PaddlePaddle/pre-commit-golang
sha: 16398aeccf263adaf53b2495eed0406347d76281
sha: 8337620115c25ff8333f1b1a493bd031049bd7c0
hooks:
- id: go-fmt
types: [go]
- id: gometalinter
types: [go]
- id: go-fmt
types:
- go
- id: gometalinter
types:
- go

@ -18,7 +18,6 @@ package main
#include <stdlib.h>
#include <string.h>
#include <stdio.h>
#define PADDLE_MASTER_OK 0
#define PADDLE_MASTER_ERROR -1
@ -101,6 +100,12 @@ func paddle_release_master_client(client C.paddle_master_client) {
remove(client)
}
// paddle_start_get_records looks up the master client behind the given
// handle and starts record fetching on it for the given pass number.
//
//export paddle_start_get_records
func paddle_start_get_records(client C.paddle_master_client, pass C.int) {
	get(client).StartGetRecords(int(pass))
}
//export paddle_set_dataset
func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int) C.int {
c := get(client)
@ -121,15 +126,19 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int
// paddle_next_record gets the next training record.
//
// returns number of bytes of the records if success, -1 if failed.
// returns number of bytes of the records if success, -1 if failed, -2 if pass end.
//
//export paddle_next_record
func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
c := get(client)
r, err := c.NextRecord()
if err != nil {
// Error
// TODO: return the type of error?
// NOTE: use errors to indicate pass ends
if err.Error() == master.ErrAllTaskFailed.Error() ||
err.Error() == master.ErrNoMoreAvailable.Error() ||
err.Error() == master.ErrPassBefore.Error() {
return -2
}
*record = (*C.uchar)(nil)
return -1
}

@ -16,7 +16,6 @@ package master
import (
"os"
"sync"
"time"
"github.com/PaddlePaddle/Paddle/go/connection"
@ -27,9 +26,9 @@ import (
// Client is the client of the master server.
type Client struct {
conn *connection.Conn
ch chan record
initChOnce sync.Once
conn *connection.Conn
ch chan record
bufSize int
}
type record struct {
@ -46,11 +45,7 @@ func WithBuffer(bufSize int) func(*Client) error {
if bufSize <= 0 {
return nil
}
c.initChOnce.Do(func() {
c.ch = make(chan record, bufSize)
go c.getRecords()
})
c.bufSize = bufSize
return nil
}
}
@ -104,25 +99,41 @@ func NewClient(opts ...func(*Client) error) (*Client, error) {
if err != nil {
return nil, err
}
}
c.ch = make(chan record, c.bufSize)
// FIXME: connection is created asynchronously in the monitorMaster goroutine,
// ensure the connection is ready for use before calling c.addClient.
time.Sleep(time.Second)
return c, nil
}
func (c *Client) getRecords() {
// StartGetRecords starts a background goroutine that runs getRecords for
// the pass identified by passID.
//
// It must be called at the beginning of each pass. It returns immediately;
// the fetching happens in the spawned goroutine.
func (c *Client) StartGetRecords(passID int) {
	go c.getRecords(passID)
}
func (c *Client) getRecords(passID int) {
for {
t, err := c.getTask()
t, err := c.getTask(passID)
if err != nil {
log.Errorf("Get task failed, sleep 3 seconds and continue, %s", err)
time.Sleep(3 * time.Second)
continue
if err.Error() == ErrPassBefore.Error() ||
err.Error() == ErrNoMoreAvailable.Error() ||
err.Error() == ErrAllTaskFailed.Error() {
c.ch <- record{nil, err}
break
}
if err.Error() == ErrPassAfter.Error() {
// wait until the last pass finishes
time.Sleep(time.Second * 3)
continue
}
log.Errorf("getTask error: %s", err)
}
for _, chunk := range t.Chunks {
f, err := os.Open(chunk.Path)
if err != nil {
log.Errorln(err)
f, e := os.Open(chunk.Path)
if e != nil {
log.Errorln(e)
continue
}
@ -178,18 +189,21 @@ func (c *Client) monitorMaster(addrCh <-chan string) {
}
}
// SetDataset set dataset for the master server to dispatch.
// SetDataset sets dataset to dispatch for the master server.
//
// SetDataset can be called multiple times in one pass. But only the first call
// will be honored.
//
// SetDataset can be call multiple times from different nodes. But
// only the first call will be honored.
// After all tasks are done, another call of SetDataset will start another pass.
func (c *Client) SetDataset(globPaths []string) error {
return c.conn.Call("Service.SetDataset", globPaths, nil)
err := c.conn.Call("Service.SetDataset", globPaths, nil)
return err
}
// getTask gets a new task from the master server.
func (c *Client) getTask() (Task, error) {
func (c *Client) getTask(passID int) (Task, error) {
var t Task
err := c.conn.Call("Service.GetTask", 0, &t)
err := c.conn.Call("Service.GetTask", passID, &t)
return t, err
}
@ -208,12 +222,6 @@ func (c *Client) taskFailed(meta TaskMeta) error {
// NextRecord will block until the next record is available. It is
// thread-safe.
func (c *Client) NextRecord() ([]byte, error) {
c.initChOnce.Do(func() {
// initialize with in case WithBuffer is not used.
c.ch = make(chan record, 0)
go c.getRecords()
})
r := <-c.ch
return r.r, r.err
}

@ -54,22 +54,22 @@ func TestGetFinishTask(t *testing.T) {
panic(err)
}
go func(l net.Listener) {
s, err := NewService(&InMemStore{}, chunkPerTask, time.Second, 1)
if err != nil {
panic(err)
s, sErr := NewService(&InMemStore{}, chunkPerTask, time.Second, 1)
if sErr != nil {
panic(sErr)
}
server := rpc.NewServer()
err = server.Register(s)
if err != nil {
panic(err)
sErr = server.Register(s)
if sErr != nil {
panic(sErr)
}
mux := http.NewServeMux()
mux.Handle(rpc.DefaultRPCPath, server)
err = http.Serve(l, mux)
if err != nil {
panic(err)
sErr = http.Serve(l, mux)
if sErr != nil {
panic(sErr)
}
}(l)
@ -103,6 +103,7 @@ func TestGetFinishTask(t *testing.T) {
ch := make(chan string, 1)
ch <- addr
go c.monitorMaster(ch)
err = c.SetDataset([]string{path})
if err != nil {
panic(err)
@ -111,44 +112,47 @@ func TestGetFinishTask(t *testing.T) {
checkOnePass := func(i int) {
var tasks []Task
for idx := 0; idx < totalTask; idx++ {
task, err := c.getTask()
if err != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i)
task, cErr := c.getTask(i)
if cErr != nil && cErr.Error() != ErrNoMoreAvailable.Error() && cErr.Error() != ErrPassAfter.Error() {
t.Fatalf("error: %v, pass: %d\n", cErr, i)
}
tasks = append(tasks, task)
}
_, err = c.getTask()
if err == nil {
// getting task before task finishes should return error
_, cErr := c.getTask(i)
if cErr == nil {
t.Fatalf("Should get error, pass: %d\n", i)
}
err = c.taskFinished(tasks[0].Meta.ID)
if err != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i)
cErr = c.taskFinished(tasks[0].Meta.ID)
if cErr != nil {
t.Fatalf("Error: %v, pass: %d\n", cErr, i)
}
err = c.taskFailed(tasks[0].Meta)
if err != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i)
// call taskFailed once won't put the task to failed queue, just ensure
// the call
cErr = c.taskFailed(tasks[0].Meta)
if cErr != nil {
t.Fatalf("Error: %v, pass: %d\n", cErr, i)
}
tasks = tasks[1:]
task, err := c.getTask()
if err != nil {
t.Fatal(err)
_, cErr = c.getTask(i)
if cErr != nil && cErr.Error() != ErrNoMoreAvailable.Error() && cErr.Error() != ErrPassAfter.Error() {
t.Fatalf("Should be ErrNoMoreAvailable or ErrPassAfter: %s", cErr)
}
tasks = append(tasks, task)
for _, task := range tasks {
err = c.taskFinished(task.Meta.ID)
if err != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i)
cErr = c.taskFinished(task.Meta.ID)
if cErr != nil {
t.Fatal(cErr)
}
}
}
for i := 0; i < 10; i++ {
// init pass data
c.StartGetRecords(i)
checkOnePass(i)
}
}

@ -20,8 +20,10 @@ import (
"net/http"
"net/rpc"
"os"
"runtime"
"strconv"
"strings"
"sync"
"testing"
"time"
@ -29,6 +31,18 @@ import (
"github.com/PaddlePaddle/recordio"
)
// goid returns the numeric id of the calling goroutine. It is a helper for
// test output only.
//
// The id is taken from the first whitespace-separated field that follows the
// "goroutine " prefix of the current stack trace. It panics if that field is
// not an integer.
func goid() int {
	var stack [64]byte
	written := runtime.Stack(stack[:], false)

	header := strings.TrimPrefix(string(stack[:written]), "goroutine ")
	fields := strings.Fields(header)

	id, err := strconv.Atoi(fields[0])
	if err != nil {
		panic(fmt.Sprintf("cannot get goroutine id: %v", err))
	}
	return id
}
func TestNextRecord(t *testing.T) {
const (
path = "/tmp/master_client_TestFull"
@ -45,7 +59,7 @@ func TestNextRecord(t *testing.T) {
panic(err)
}
go func(l net.Listener) {
s, err := master.NewService(&master.InMemStore{}, 10, time.Second, 1)
s, err := master.NewService(&master.InMemStore{}, 1, time.Second*60, 1)
if err != nil {
panic(err)
}
@ -69,7 +83,7 @@ func TestNextRecord(t *testing.T) {
panic(err)
}
w := recordio.NewWriter(f, -1, -1)
w := recordio.NewWriter(f, 1, -1)
for i := 0; i < total; i++ {
_, err = w.Write([]byte{byte(i)})
if err != nil {
@ -87,32 +101,49 @@ func TestNextRecord(t *testing.T) {
panic(err)
}
c, err := master.NewClient(master.WithAddr(fmt.Sprintf(":%d", p)), master.WithBuffer(10))
if err != nil {
panic(err)
}
err = c.SetDataset([]string{path})
if err != nil {
panic(err)
}
for pass := 0; pass < 50; pass++ {
received := make(map[byte]bool)
for i := 0; i < total; i++ {
r, err := c.NextRecord()
if err != nil {
t.Fatal(pass, i, "Read error:", err)
// start several client to test task fetching
var wg sync.WaitGroup
for i := 0; i < 4; i++ {
wg.Add(1)
// test for multiple concurrent clients
go func() {
defer wg.Done()
// each go-routine needs a single client connection instance
c, e := master.NewClient(master.WithAddr(fmt.Sprintf(":%d", p)), master.WithBuffer(1))
if e != nil {
t.Fatal(e)
}
if len(r) != 1 {
t.Fatal(pass, i, "Length should be 1.", r)
e = c.SetDataset([]string{path})
if e != nil {
panic(e)
}
if received[r[0]] {
t.Fatal(pass, i, "Received duplicate.", received, r)
// test for n passes
for pass := 0; pass < 10; pass++ {
c.StartGetRecords(pass)
received := make(map[byte]bool)
taskid := 0
for {
r, e := c.NextRecord()
if e != nil {
// ErrorPassAfter will wait, else break for next pass
if e.Error() == master.ErrPassBefore.Error() ||
e.Error() == master.ErrNoMoreAvailable.Error() {
break
}
t.Fatal(pass, taskid, "Read error:", e)
}
if len(r) != 1 {
t.Fatal(pass, taskid, "Length should be 1.", r)
}
if received[r[0]] {
t.Fatal(pass, taskid, "Received duplicate.", received, r)
}
taskid++
received[r[0]] = true
}
}
received[r[0]] = true
}
}()
}
wg.Wait()
}

@ -19,6 +19,7 @@ import (
"compress/gzip"
"encoding/gob"
"errors"
"math/rand"
"os"
"path/filepath"
"sync"
@ -33,6 +34,18 @@ const (
dialTimeout = 5 * time.Second
)
// Sentinel errors returned by Service.GetTask. Callers in this change compare
// them by message (err.Error()) rather than by identity, so the message
// strings must stay stable.
var (
	// ErrAllTaskFailed is returned when the todo queue is empty and there
	// are no done or pending tasks either.
	ErrAllTaskFailed = errors.New("all task finished")

	// ErrNoMoreAvailable is returned when the todo queue is empty but some
	// tasks are still pending or done.
	ErrNoMoreAvailable = errors.New("no more available task")

	// ErrPassBefore is returned when the client's pass number is smaller
	// than the master's current pass.
	ErrPassBefore = errors.New("pass number smaller than master")

	// ErrPassAfter is returned when the client's pass number is larger than
	// the master's current pass.
	ErrPassAfter = errors.New("pass number larger than master")
)
// Store is the interface for save and load the master state.
type Store interface {
Save([]byte) error
@ -75,17 +88,26 @@ type Service struct {
chunksPerTask int
timeoutDur time.Duration
failureMax int
ready chan struct{}
store Store
mu sync.Mutex
initDone bool
taskQueues taskQueues
ready chan struct{}
initDone bool
mu sync.Mutex
taskQueues taskQueues
currPass int
jobTasks []taskEntry
savingTrainer string
}
func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
id := 0
// generate uniq id across job using nanosecond + randint + counter
// FIXME(typhoonzero): this is a workaround, use uuid
randStart := rand.Int()
counter := 0
timestamp := time.Now().Nanosecond()
id := timestamp + randStart + counter
if chunksPerTask <= 0 {
chunksPerTask = 1
}
@ -95,7 +117,8 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
for i, c := range chunks {
if i%chunksPerTask == 0 && len(cur.Task.Chunks) > 0 {
cur.Task.Meta.ID = id
id++
counter++
id = timestamp + randStart + counter
result = append(result, cur)
cur.Task.Chunks = nil
}
@ -266,19 +289,21 @@ func (s *Service) SetDataset(globPaths []string, _ *int) error {
return err
}
s.taskQueues.Todo = partition(chunks, s.chunksPerTask)
s.jobTasks = partition(chunks, s.chunksPerTask)
s.taskQueues.Todo = s.jobTasks
err = s.snapshot()
if err != nil {
log.Errorln(err)
return err
}
close(s.ready)
s.initDone = true
return nil
}
// processFailedTask retry s.failureMax times for failed task.
// return true if all task are done or failed.
func (s *Service) processFailedTask(t taskEntry, epoch int) {
if t.Task.Meta.Epoch != epoch {
// new epoch, task launched after the
@ -302,8 +327,9 @@ func (s *Service) processFailedTask(t taskEntry, epoch int) {
return
}
log.Warningf("Task %v failed %d times, discard.", t.Task, t.NumFailure)
log.Warningf("Task %v failed %d times, re-dispatch.", t.Task, t.NumFailure)
s.taskQueues.Todo = append(s.taskQueues.Todo, t)
return
}
func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
@ -331,37 +357,30 @@ func (s *Service) logFields() log.Fields {
}
// GetTask gets a new task from the service.
func (s *Service) GetTask(_ int, task *Task) error {
// passID is the client side pass count
func (s *Service) GetTask(passID int, task *Task) error {
select {
case <-s.ready:
}
s.mu.Lock()
defer s.mu.Unlock()
if passID < s.currPass {
return ErrPassBefore
}
if passID > s.currPass {
// Client may get run to pass after master when one client faster than the
// other
return ErrPassAfter
}
if len(s.taskQueues.Todo) == 0 {
if len(s.taskQueues.Done) == 0 {
if len(s.taskQueues.Pending) == 0 {
err := errors.New("all task failed")
log.WithFields(s.logFields()).Warningln("All tasks failed.")
return err
}
// TODO(helin): client need to retry in this
// error case. Gotcha: RPC client can't
// compare returned error with predefined
// errors like io.EOF, because the error
// instance deserialized from RPC is a
// different instance than the error defined
// in package. So we need to figure out a way
// for client to check this error correctly.
err := errors.New("no more available task")
log.WithFields(s.logFields()).Warningln("No more available task.")
return err
if len(s.taskQueues.Done) == 0 && len(s.taskQueues.Pending) == 0 {
log.WithFields(s.logFields()).Warningln("All tasks failed, may start next pass")
return ErrAllTaskFailed
}
s.taskQueues.Todo = s.taskQueues.Done
s.taskQueues.Done = nil
log.WithFields(s.logFields()).Infoln("No more todo task, but trainer is requesting task to do. Move all done task to todo.")
log.WithFields(s.logFields()).Warningln("No more available task.")
return ErrNoMoreAvailable
}
t := s.taskQueues.Todo[0]
@ -381,7 +400,7 @@ func (s *Service) GetTask(_ int, task *Task) error {
}
// TaskFinished tell the service that a task is finished.
func (s *Service) TaskFinished(taskID int, _ *int) error {
func (s *Service) TaskFinished(taskID int, dummy *int) error {
select {
case <-s.ready:
}
@ -401,11 +420,14 @@ func (s *Service) TaskFinished(taskID int, _ *int) error {
delete(s.taskQueues.Pending, taskID)
log.WithFields(s.logFields()).Infof("Task #%d finished.", taskID)
if len(s.taskQueues.Pending) == 0 && len(s.taskQueues.Todo) == 0 {
log.WithFields(s.logFields()).Infoln("No more todo and pending task, start a new pass.")
s.taskQueues.Todo = append(s.taskQueues.Todo, s.taskQueues.Done...)
s.taskQueues.Done = nil
if len(s.taskQueues.Todo) == 0 && len(s.taskQueues.Pending) == 0 {
// increase master side pass count if all tasks finished
s.currPass++
s.taskQueues.Todo = s.jobTasks
s.taskQueues.Done = []taskEntry{}
// TODO(typhoonzero): deal with failed tasks
s.taskQueues.Failed = []taskEntry{}
log.WithFields(s.logFields()).Warningf("all task finished, add new pass data, newpass: %d.", s.currPass)
}
err := s.snapshot()
@ -416,7 +438,7 @@ func (s *Service) TaskFinished(taskID int, _ *int) error {
}
// TaskFailed tells the service that a task is failed.
func (s *Service) TaskFailed(meta TaskMeta, _ *int) error {
func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error {
select {
case <-s.ready:
}

@ -44,7 +44,8 @@ func TestPartionIndex(t *testing.T) {
cs := make([]Chunk, 100)
ts := partition(cs, 20)
for i := range ts {
if ts[i].Task.Meta.ID != i {
// test that ids auto-increment
if i > 0 && ts[i].Task.Meta.ID != ts[i-1].Task.Meta.ID+1 {
t.Error(ts[i], i)
}
}

@ -6,16 +6,19 @@ import cPickle as pickle
etcd_ip = os.getenv("MASTER_IP", "127.0.0.1")
etcd_endpoint = "http://" + etcd_ip + ":2379"
print "connecting to master, etcd endpoints: ", etcd_endpoint
master_client = master.client(etcd_endpoint, 5, 64)
def cloud_reader():
print "connecting to master, etcd endpoints: ", etcd_endpoint
master_client = master.client(etcd_endpoint, 5, 64)
global master_client
master_client.set_dataset(
["/pfs/dlnel/public/dataset/uci_housing/uci_housing-*-of-*"])
["/pfs/dlnel/public/dataset/uci_housing/uci_housing-*"], passes=30)
while 1:
r, e = master_client.next_record()
if not r:
if e != -2: # other errors
print "get record error:", e
break
yield pickle.loads(r)
@ -27,10 +30,12 @@ def main():
# network config
x = paddle.layer.data(name='x', type=paddle.data_type.dense_vector(13))
y_predict = paddle.layer.fc(input=x,
param_attr=paddle.attr.Param(name='w'),
param_attr=paddle.attr.Param(
name='w', learning_rate=1e-3),
size=1,
act=paddle.activation.Linear(),
bias_attr=paddle.attr.Param(name='b'))
bias_attr=paddle.attr.Param(
name='b', learning_rate=1e-3))
y = paddle.layer.data(name='y', type=paddle.data_type.dense_vector(1))
cost = paddle.layer.mse_cost(input=y_predict, label=y)
@ -38,9 +43,8 @@ def main():
parameters = paddle.parameters.create(cost)
# create optimizer of new remote updater to pserver
optimizer = paddle.optimizer.Momentum(momentum=0)
optimizer = paddle.optimizer.Momentum(momentum=0, learning_rate=1e-3)
print "etcd endoint: ", etcd_endpoint
trainer = paddle.trainer.SGD(cost=cost,
parameters=parameters,
update_equation=optimizer,
@ -51,6 +55,8 @@ def main():
# event_handler to print training and testing info
def event_handler(event):
if isinstance(event, paddle.event.EndIteration):
# FIXME: for cloud data reader, pass number is managed by master
# should print the server side pass number
if event.batch_id % 100 == 0:
print "Pass %d, Batch %d, Cost %f" % (
event.pass_id, event.batch_id, event.cost)

@ -37,7 +37,7 @@ std::vector<std::string> Evaluator::getNames() const {
double Evaluator::getValue(const std::string name) const {
paddle::Error err;
double v = m->rawPtr->getValue(name, &err);
if (err) {
if (!err.isOK()) {
throw std::runtime_error(err.msg());
}
return v;

@ -3,7 +3,7 @@ cc_library(ddim SRCS ddim.cc DEPS eigen3)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
nv_test(dim_test SRCS dim_test.cu DEPS ddim)
cc_library(tensor SRCS tensor.cc DEPS ddim place paddle_memory)
cc_library(tensor SRCS tensor.cc DEPS ddim place paddle_memory device_context)
cc_test(tensor_test SRCS tensor_test.cc DEPS tensor)
cc_test(eigen_test SRCS eigen_test.cc DEPS tensor)
@ -29,7 +29,5 @@ py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.
add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
add_dependencies(framework_py_proto framework_py_proto_init)
proto_library(net_proto SRCS net_proto.proto DEPS op_proto)
# cc_library(net SRCS net.cc DEPS operator net_proto op_registry fc_op)
cc_library(net SRCS net.cc DEPS operator net_proto op_registry)
cc_library(net SRCS net.cc DEPS op_registry)
cc_test(net_op_test SRCS net_op_test.cc DEPS net add_op mul_op sigmoid_op softmax_op fc_op)

@ -0,0 +1,160 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/memory/memcpy.h"
namespace paddle {
namespace framework {
/**
 * Verifies that this tensor holds a buffer large enough for its current
 * dimensions when the elements are interpreted as type T.
 *
 * Fails through PADDLE_ENFORCE when no buffer is held, or when the held
 * buffer is smaller than product(dims_) * sizeof(T) bytes plus offset_.
 */
template <typename T>
inline void Tensor::check_memory_size() const {
  // NOTE: the spelling of this message is kept as-is; tensor_test.cc builds
  // its expected message from the same text.
  PADDLE_ENFORCE(holder_ != nullptr,
                 "Tenosr holds no memory. Call Tensor::mutable_data first.");

  // Bytes required from the start of the buffer: payload plus leading offset.
  const auto required_bytes = product(dims_) * sizeof(T) + offset_;
  PADDLE_ENFORCE(holder_->size() >= required_bytes,
                 "Tensor's dims_ is out of bound. Call Tensor::mutable_data "
                 "first to re-allocate memory.");
}
template <typename T>
inline const T* Tensor::data() const {
check_memory_size<T>();
return reinterpret_cast<const T*>(
reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
}
template <typename T>
inline T* Tensor::data() {
check_memory_size<T>();
return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_);
}
/**
 * Sets the tensor's dimensions to `dims`, then returns a writable buffer on
 * `place` by delegating to mutable_data<T>(place).
 *
 * T must be a POD type (checked at compile time).
 */
template <typename T>
inline T* Tensor::mutable_data(DDim dims, platform::Place place) {
  static_assert(std::is_pod<T>::value, "T must be POD");
  Resize(dims);
  T* buffer = mutable_data<T>(place);
  return buffer;
}
/**
 * Returns a writable pointer to a buffer on `place` that is large enough for
 * the tensor's current dimensions.
 *
 * A new buffer is allocated (and offset_ reset to 0) when the tensor holds no
 * memory, holds memory on a different place, or holds a buffer that is too
 * small. Otherwise the existing buffer is reused.
 *
 * T must be a POD type (checked at compile time). Fails through
 * PADDLE_ENFORCE when the tensor has no elements, or when `place` is not a
 * kind this build can allocate on (e.g. a GPU place in a PADDLE_ONLY_CPU
 * build).
 */
template <typename T>
inline T* Tensor::mutable_data(platform::Place place) {
  static_assert(std::is_pod<T>::value, "T must be POD");
  PADDLE_ENFORCE(product(dims_) > 0,
                 "Tensor's numel must be larger than zero to call "
                 "Tensor::mutable_data. Call Tensor::set_dim first.");
  /* some versions of boost::variant don't have operator!= */
  size_t size = product(dims_) * sizeof(T);
  if (holder_ == nullptr || !(holder_->place() == place) ||
      holder_->size() < size + offset_) {
    if (platform::is_cpu_place(place)) {
      holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
          boost::get<platform::CPUPlace>(place), size));
    }
#ifndef PADDLE_ONLY_CPU
    else if (platform::is_gpu_place(place)) {
      holder_.reset(new PlaceholderImpl<T, platform::GPUPlace>(
          boost::get<platform::GPUPlace>(place), size));
    }
#endif
    else {
      // Without this branch an unsupported place would leave holder_ null
      // (or pointing at a stale buffer on another place) and the return
      // statement below would dereference or expose it.
      PADDLE_ENFORCE(false,
                     "Tensor::mutable_data: the requested place is not "
                     "supported by this build.");
    }
    offset_ = 0;
  }
  return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
                              offset_);
}
/**
 * Makes this tensor a copy of `src` by plain assignment, after verifying via
 * check_memory_size<T>() that `src` holds enough memory for its dimensions.
 *
 * NOTE(review): the name suggests the underlying buffer is shared rather
 * than duplicated; that depends on holder_ being a shared pointer, which is
 * declared outside this file -- confirm in tensor.h.
 */
template <typename T>
inline void Tensor::ShareDataWith(const Tensor& src) {
  src.check_memory_size<T>();
  *this = src;
}
/**
 * Copies the contents of `src` into this tensor, placing the destination
 * buffer on the place reported by the CPU device context `ctx`.
 *
 * This tensor is resized to src's dimensions and its buffer is obtained via
 * mutable_data<T>(). A CPU source is copied with the five-argument
 * memory::Copy; in non-PADDLE_ONLY_CPU builds a GPU source is copied with an
 * extra trailing argument of 0.
 */
template <typename T>
inline void Tensor::CopyFrom(const Tensor& src,
                             const platform::CPUDeviceContext& ctx) {
  src.check_memory_size<T>();
  Resize(src.dims());

  // Source side.
  auto from_place = src.holder_->place();
  auto from_ptr = static_cast<const void*>(src.data<T>());

  // Destination side.
  auto to_place = ctx.GetPlace();
  auto to_ptr = static_cast<void*>(mutable_data<T>(to_place));

  auto num_bytes = product(src.dims_) * sizeof(T);

  if (platform::is_cpu_place(from_place)) {
    memory::Copy(boost::get<platform::CPUPlace>(to_place), to_ptr,
                 boost::get<platform::CPUPlace>(from_place), from_ptr,
                 num_bytes);
  }
#ifndef PADDLE_ONLY_CPU
  else if (platform::is_gpu_place(from_place)) {
    memory::Copy(boost::get<platform::CPUPlace>(to_place), to_ptr,
                 boost::get<platform::GPUPlace>(from_place), from_ptr,
                 num_bytes, 0);
  }
#endif
}
#ifndef PADDLE_ONLY_CPU
/**
 * Copies the contents of `src` into this tensor, placing the destination
 * buffer on the place reported by the CUDA device context `ctx`.
 *
 * This tensor is resized to src's dimensions and its buffer is obtained via
 * mutable_data<T>(). Both CPU and GPU sources are copied with memory::Copy,
 * passing ctx.stream() as the last argument. Only compiled when
 * PADDLE_ONLY_CPU is not defined.
 */
template <typename T>
inline void Tensor::CopyFrom(const Tensor& src,
                             const platform::CUDADeviceContext& ctx) {
  src.check_memory_size<T>();
  Resize(src.dims());

  // Source side.
  auto from_place = src.holder_->place();
  auto from_ptr = static_cast<const void*>(src.data<T>());

  // Destination side.
  auto to_place = ctx.GetPlace();
  auto to_ptr = static_cast<void*>(mutable_data<T>(to_place));

  auto num_bytes = product(src.dims_) * sizeof(T);

  if (platform::is_cpu_place(from_place)) {
    memory::Copy(boost::get<platform::GPUPlace>(to_place), to_ptr,
                 boost::get<platform::CPUPlace>(from_place), from_ptr,
                 num_bytes, ctx.stream());
  } else if (platform::is_gpu_place(from_place)) {
    memory::Copy(boost::get<platform::GPUPlace>(to_place), to_ptr,
                 boost::get<platform::GPUPlace>(from_place), from_ptr,
                 num_bytes, ctx.stream());
  }
}
#endif
/**
 * Returns a tensor covering entries [begin_idx, end_idx) along the first
 * dimension of this tensor.
 *
 * The result reuses this tensor's holder_ (no data is copied here); only its
 * first dimension and byte offset differ.
 *
 * Fails through PADDLE_ENFORCE when the tensor holds insufficient memory,
 * when begin_idx < 0, when end_idx > dims_[0], when begin_idx >= end_idx, or
 * when dims_[0] == 1.
 */
template <typename T>
inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const {
  check_memory_size<T>();
  PADDLE_ENFORCE(begin_idx >= 0, "Slice begin index is less than zero.");
  PADDLE_ENFORCE(end_idx <= dims_[0], "Slice end index is out of bound.");
  PADDLE_ENFORCE(begin_idx < end_idx,
                 "Begin index must be less than end index.");
  PADDLE_ENFORCE(dims_[0] != 1, "Can not slice a tensor with dims_[0] = 1.");

  // Number of elements in one entry along the first dimension. Kept in
  // size_t so that the offset arithmetic below is not evaluated in int,
  // where begin_idx * row_elems could overflow for large tensors.
  const size_t row_elems = static_cast<size_t>(product(dims_) / dims_[0]);

  Tensor dst;
  dst.holder_ = holder_;
  DDim dst_dims = dims_;
  dst_dims[0] = end_idx - begin_idx;
  dst.Resize(dst_dims);
  dst.offset_ =
      offset_ + static_cast<size_t>(begin_idx) * row_elems * sizeof(T);
  return dst;
}
/**
 * Replaces the tensor's dimensions with `dims`. Only dims_ is updated; no
 * memory is allocated or released here.
 */
inline void Tensor::Resize(const DDim& dims) {
  dims_ = dims;
}
/** Returns a const reference to the tensor's current dimensions. */
inline const DDim& Tensor::dims() const {
  return dims_;
}
} // namespace framework
} // namespace paddle

@ -20,17 +20,7 @@
namespace paddle {
namespace framework {
std::shared_ptr<PlainNet> AddBackwardOp(std::shared_ptr<PlainNet> ForwardOps) {
auto grad_ops = std::make_shared<PlainNet>();
for (auto& op : ForwardOps->ops_) {
auto op_grad = OpRegistry::CreateGradOp(op);
grad_ops->AddOp(op_grad);
}
grad_ops->CompleteAddOp();
return grad_ops;
}
void PlainNet::CompleteAddOp(bool calc) {
void NetOp::CompleteAddOp(bool calc) {
add_op_done_ = true;
if (!calc) return;
std::unordered_set<std::string> input_set;
@ -70,7 +60,7 @@ void PlainNet::CompleteAddOp(bool calc) {
attrs_["temporary_index"] = tmp_index;
}
std::string PlainNet::DebugString() const {
std::string NetOp::DebugString() const {
std::ostringstream os;
os << OperatorBase::DebugString() << std::endl;
for (auto& op : ops_) {
@ -82,5 +72,7 @@ std::string PlainNet::DebugString() const {
return os.str();
}
bool NetOp::IsNetOp() const { return true; }
} // namespace framework
} // namespace paddle

@ -37,21 +37,7 @@ namespace framework {
* This is the base class of network, all the networks should implement the APIs
* it defines.
*/
class Net : public OperatorBase {
public:
virtual void AddOp(const std::shared_ptr<OperatorBase>& op) = 0;
virtual void CompleteAddOp(bool calc) = 0;
};
using NetPtr = std::shared_ptr<Net>;
/**
* @brief a basic implementation of Net.
*
* PlainNet is a very simple Net, it create a list of operators, and run them
* sequentially following the order they added.
*/
class PlainNet : public Net {
class NetOp : public OperatorBase {
public:
/**
* Infer all the operators' input and output variables' shapes, will be called
@ -80,15 +66,17 @@ class PlainNet : public Net {
/**
* @brief Add an operator by ptr
*/
void AddOp(const std::shared_ptr<OperatorBase>& op) override {
void AddOp(const std::shared_ptr<OperatorBase>& op) {
PADDLE_ENFORCE(!add_op_done_, "Cannot AddOp when this network is sealed");
ops_.push_back(op);
}
void CompleteAddOp(bool calculate = true) override;
void CompleteAddOp(bool calculate = true);
std::string DebugString() const override;
bool IsNetOp() const override;
std::vector<std::shared_ptr<OperatorBase>> ops_;
private:
@ -100,7 +88,5 @@ class PlainNet : public Net {
}
};
std::shared_ptr<PlainNet> AddBackwardOp(std::shared_ptr<PlainNet> ForwardOps);
} // namespace framework
} // namespace paddle

@ -40,7 +40,7 @@ void AssertSameVectorWithoutOrder(const std::vector<T>& expected,
}
TEST(OpKernel, all) {
auto net = std::make_shared<PlainNet>();
auto net = std::make_shared<NetOp>();
ASSERT_NE(net, nullptr);
auto op1 = std::make_shared<TestOp>();
@ -69,30 +69,23 @@ TEST(OpKernel, all) {
net->Run(scope, dev_ctx);
ASSERT_EQ(2, infer_shape_cnt);
ASSERT_EQ(2, run_cnt);
ASSERT_THROW(net->AddOp(op2), std::runtime_error);
}
TEST(AddBackwardOp, TestGradOp) {
auto net = std::make_shared<PlainNet>();
ASSERT_NE(net, nullptr);
net->AddOp(framework::OpRegistry::CreateOp("mul", {"X", "Y"}, {"Out"}, {}));
net->AddOp(
framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {"Out"}, {}));
net->AddOp(framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {""}, {}));
auto grad_ops = AddBackwardOp(net);
for (auto& op : grad_ops->ops_) {
op->DebugString();
}
ASSERT_THROW(net->AddOp(op2), paddle::platform::EnforceNotMet);
}
// TODO(zhihong): add fc grad without registering.
// TEST(AddBackwardOp, TestNoGradOp) {
// auto net = std::make_shared<PlainNet>();
// ASSERT_NE(net, nullptr);
// net->AddOp(framework::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Y"},
// {})); auto grad_ops = AddBackwardOp(net); for (auto& op : grad_ops->ops_) {
// op->DebugString();
// }
// }
//! TODO(yuyang18): Refine Backward Op.
// TEST(AddBackwardOp, TestGradOp) {
// auto net = std::make_shared<NetOp>();
// ASSERT_NE(net, nullptr);
// net->AddOp(framework::OpRegistry::CreateOp("mul", {"X", "Y"}, {"Out"}, {}));
// net->AddOp(
// framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {"Out"}, {}));
// net->AddOp(framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {""},
// {}));
// auto grad_ops = AddBackwardOp(net);
// for (auto& op : grad_ops->ops_) {
// op->DebugString();
// }
//}
} // namespace framework
} // namespace paddle

@ -1,15 +0,0 @@
syntax="proto2";
package paddle.framework;
import "op_proto.proto";
message NetDesc {
// network identification
optional string name = 1;
// operator contains in network
repeated OpProto operators = 2;
// network type to run with. e.g "plainNet", "DAG"
optional string net_type = 3;
// num worker always
optional int32 num_workers = 4;
}

@ -403,15 +403,16 @@ class GradOpRegisterHelper {
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op_kernel_##type##_##DEVICE_TYPE##__, \
"REGISTER_OP_KERNEL must be in global namespace"); \
struct __op_kernel_register__##type##__ { \
__op_kernel_register__##type##__() { \
struct __op_kernel_register__##type##__##DEVICE_TYPE##__ { \
__op_kernel_register__##type##__##DEVICE_TYPE##__() { \
::paddle::framework::OperatorWithKernel::OpKernelKey key; \
key.place_ = PlaceType(); \
::paddle::framework::OperatorWithKernel::AllOpKernels()[#type][key] \
.reset(new __VA_ARGS__()); \
} \
}; \
static __op_kernel_register__##type##__ __reg_kernel_##type##__; \
static __op_kernel_register__##type##__##DEVICE_TYPE##__ \
__reg_kernel_##type##__##DEVICE_TYPE##__; \
int __op_kernel_register_##type##_handle_##DEVICE_TYPE##__() { return 0; }
// (type, KernelType)

@ -90,7 +90,7 @@ TEST(OpRegistry, IllegalAttr) {
bool caught = false;
try {
paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (std::runtime_error& err) {
} catch (paddle::platform::EnforceNotMet err) {
caught = true;
std::string msg = "larger_than check fail";
const char* err_msg = err.what();
@ -136,7 +136,7 @@ TEST(OpRegistry, CustomChecker) {
bool caught = false;
try {
paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (std::runtime_error& err) {
} catch (paddle::platform::EnforceNotMet err) {
caught = true;
std::string msg = "Attribute 'test_attr' is required!";
const char* err_msg = err.what();
@ -154,7 +154,7 @@ TEST(OpRegistry, CustomChecker) {
caught = false;
try {
paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (std::runtime_error& err) {
} catch (paddle::platform::EnforceNotMet err) {
caught = true;
std::string msg = "'test_attr' must be even!";
const char* err_msg = err.what();
@ -192,7 +192,7 @@ TEST(ProtoMaker, DuplicatedAttr) {
pd::OpProto op_proto;
pd::OpAttrChecker op_checker;
auto proto_maker = TestAttrProtoMaker(&op_proto, &op_checker);
ASSERT_THROW(proto_maker.Validate(), std::runtime_error);
ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
}
class TestInOutProtoMaker : public pd::OpProtoAndCheckerMaker {
@ -208,5 +208,5 @@ TEST(ProtoMaker, DuplicatedInOut) {
pd::OpProto op_proto;
pd::OpAttrChecker op_checker;
auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker);
ASSERT_THROW(proto_maker.Validate(), std::runtime_error);
ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
}

@ -90,15 +90,17 @@ class OperatorBase {
virtual void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const = 0;
// Get a input with argument's name described in `op_proto`
virtual bool IsNetOp() const { return false; }
//! Get a input with argument's name described in `op_proto`
const std::string& Input(const std::string& name) const;
// Get a input which has multiple variables.
// TODO add a vector_view to prevent memory copy.
//! Get a input which has multiple variables.
//! TODO add a vector_view to prevent memory copy.
std::vector<std::string> Inputs(const std::string& name) const;
// Get a output with argument's name described in `op_proto`
//! Get a output with argument's name described in `op_proto`
const std::string& Output(const std::string& name) const;
// Get an output which has multiple variables.
// TODO add a vector_view to prevent memory copy.
//! Get an output which has multiple variables.
//! TODO add a vector_view to prevent memory copy.
std::vector<std::string> Outputs(const std::string& name) const;
public:
@ -199,7 +201,9 @@ class OperatorWithKernel : public OperatorBase {
place_ = dev_ctx.GetPlace();
}
bool operator==(const OpKernelKey& o) const { return place_ == o.place_; }
bool operator==(const OpKernelKey& o) const {
return platform::places_are_same_class(place_, o.place_);
}
};
struct OpKernelHash {

@ -12,7 +12,7 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include <paddle/framework/tensor.h>
#include "paddle/framework/tensor.h"
namespace paddle {
namespace framework {}

File diff suppressed because it is too large Load Diff

@ -33,7 +33,7 @@ TEST(Tensor, DataAssert) {
bool caught = false;
try {
src_tensor.data<double>();
} catch (std::runtime_error& err) {
} catch (paddle::platform::EnforceNotMet err) {
caught = true;
std::string msg =
"Tenosr holds no memory. Call Tensor::mutable_data first.";
@ -72,7 +72,8 @@ TEST(Tensor, MutableData) {
p2 = src_tensor.mutable_data<float>(make_ddim({2, 2}), CPUPlace());
EXPECT_EQ(p1, p2);
}
#ifdef __CUDACC__
#ifndef PADDLE_ONLY_CPU
{
Tensor src_tensor;
float* p1 = nullptr;
@ -107,7 +108,7 @@ TEST(Tensor, ShareDataWith) {
bool caught = false;
try {
dst_tensor.ShareDataWith<float>(src_tensor);
} catch (std::runtime_error& err) {
} catch (paddle::platform::EnforceNotMet err) {
caught = true;
std::string msg =
"Tenosr holds no memory. Call Tensor::mutable_data first.";
@ -123,7 +124,7 @@ TEST(Tensor, ShareDataWith) {
ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>());
}
#ifdef __CUDACC__
#ifndef PADDLE_ONLY_CPU
{
Tensor src_tensor;
Tensor dst_tensor;
@ -160,7 +161,7 @@ TEST(Tensor, Slice) {
EXPECT_EQ(src_data_address + 3 * 4 * 1 * sizeof(int), slice_data_address);
}
#ifdef __CUDACC__
#ifndef PADDLE_ONLY_CPU
{
Tensor src_tensor;
src_tensor.mutable_data<double>(make_ddim({6, 9}), GPUPlace());
@ -188,25 +189,74 @@ TEST(Tensor, Slice) {
TEST(Tensor, CopyFrom) {
using namespace paddle::framework;
using namespace paddle::platform;
{
Tensor src_tensor;
Tensor dst_tensor;
int* src_ptr = src_tensor.mutable_data<int>(make_ddim({3, 3}), CPUPlace());
int arr[9] = {1, 2, 3, 4, 5, 6, 7, 8, 9};
memcpy(src_ptr, arr, 9 * sizeof(int));
Tensor src_tensor;
int* src_ptr = src_tensor.mutable_data<int>(make_ddim({3, 3}), CPUPlace());
int arr[9] = {1, 2, 3, 4, 5, 6, 7, 8, 9};
memcpy(src_ptr, arr, 9 * sizeof(int));
Tensor dst_tensor;
dst_tensor.CopyFrom<int>(src_tensor, CPUPlace());
const int* dst_ptr = dst_tensor.data<int>();
ASSERT_NE(src_ptr, dst_ptr);
for (size_t i = 0; i < 9; ++i) {
EXPECT_EQ(src_ptr[i], dst_ptr[i]);
auto* cpu_ctx = new paddle::platform::CPUDeviceContext();
dst_tensor.CopyFrom<int>(src_tensor, *cpu_ctx);
const int* dst_ptr = dst_tensor.data<int>();
ASSERT_NE(src_ptr, dst_ptr);
for (size_t i = 0; i < 9; ++i) {
EXPECT_EQ(src_ptr[i], dst_ptr[i]);
}
Tensor slice_tensor = src_tensor.Slice<int>(1, 2);
dst_tensor.CopyFrom<int>(slice_tensor, *cpu_ctx);
const int* slice_ptr = slice_tensor.data<int>();
dst_ptr = dst_tensor.data<int>();
ASSERT_NE(dst_ptr, slice_ptr);
for (size_t i = 0; i < 3; ++i) {
EXPECT_EQ(dst_ptr[i], slice_ptr[i]);
}
}
#ifndef PADDLE_ONLY_CPU
{
Tensor src_tensor;
Tensor gpu_tensor;
Tensor dst_tensor;
int* src_ptr = src_tensor.mutable_data<int>(make_ddim({3, 3}), CPUPlace());
int arr[9] = {1, 2, 3, 4, 5, 6, 7, 8, 9};
memcpy(src_ptr, arr, 9 * sizeof(int));
// CPU Tensor to GPU Tensor
auto gpu_ctx = new paddle::platform::CUDADeviceContext(0);
gpu_tensor.CopyFrom<int>(src_tensor, *gpu_ctx);
// GPU Tensor to CPU Tensor
auto cpu_ctx = new paddle::platform::CPUDeviceContext();
dst_tensor.CopyFrom<int>(gpu_tensor, *cpu_ctx);
// Compare Tensors
const int* dst_ptr = dst_tensor.data<int>();
ASSERT_NE(src_ptr, dst_ptr);
for (size_t i = 0; i < 9; ++i) {
EXPECT_EQ(src_ptr[i], dst_ptr[i]);
}
Tensor slice_tensor = src_tensor.Slice<int>(1, 2);
// CPU Slice Tensor to GPU Tensor
gpu_tensor.CopyFrom<int>(slice_tensor, *gpu_ctx);
Tensor slice_tensor = src_tensor.Slice<int>(1, 2);
dst_tensor.CopyFrom<int>(slice_tensor, CPUPlace());
const int* slice_ptr = slice_tensor.data<int>();
dst_ptr = dst_tensor.data<int>();
ASSERT_NE(dst_ptr, slice_ptr);
for (size_t i = 0; i < 3; ++i) {
EXPECT_EQ(dst_ptr[i], slice_ptr[i]);
// GPU Tensor to CPU Tensor
dst_tensor.CopyFrom<int>(gpu_tensor, *cpu_ctx);
// Compare Slice Tensors
const int* slice_ptr = slice_tensor.data<int>();
dst_ptr = dst_tensor.data<int>();
ASSERT_NE(dst_ptr, slice_ptr);
for (size_t i = 0; i < 3; ++i) {
EXPECT_EQ(dst_ptr[i], slice_ptr[i]);
}
}
#endif
}

@ -207,8 +207,8 @@ Error __must_check backward(Argument& act) {
argument_.value->setData(act.value->getData() + offset, 1UL, size);
argument_.grad->setData(act.grad->getData() + offset, 1UL, size);
Error status = softmax_.backward(argument_);
if (!status) return status;
Error err = softmax_.backward(argument_);
if (!err.isOK()) return err;
}
return Error();
}

@ -1,7 +1,7 @@
add_subdirectory(detail)
cc_library(memory SRCS memory.cc)
cc_library(memcpy SRCS memcpy.cc DEPS device_context)
cc_library(memcpy SRCS memcpy.cc)
cc_library(paddle_memory
DEPS

@ -27,12 +27,11 @@ BuddyAllocator::BuddyAllocator(SystemAllocator* system_allocator,
system_allocator_(std::move(system_allocator)) {}
BuddyAllocator::~BuddyAllocator() {
DLOG(INFO) << "BuddyAllocator Disconstructor makes sure that all of these "
"have actually been freed";
VLOG(3) << "BuddyAllocator Disconstructor makes sure that all of these "
"have actually been freed";
while (!pool_.empty()) {
auto block = static_cast<MemoryBlock*>(std::get<2>(*pool_.begin()));
DLOG(INFO) << "Free from block (" << block << ", " << max_chunk_size_
<< ")";
VLOG(3) << "Free from block (" << block << ", " << max_chunk_size_ << ")";
system_allocator_->Free(block, max_chunk_size_, block->index(cache_));
cache_.invalidate(block);
@ -52,12 +51,11 @@ void* BuddyAllocator::Alloc(size_t unaligned_size) {
// acquire the allocator lock
std::lock_guard<std::mutex> lock(mutex_);
DLOG(INFO) << "Allocate " << unaligned_size << " bytes from chunk size "
<< size;
VLOG(3) << "Allocate " << unaligned_size << " bytes from chunk size " << size;
// if the allocation is huge, send directly to the system allocator
if (size > max_chunk_size_) {
DLOG(INFO) << "Allocate from system allocator.";
VLOG(3) << "Allocate from system allocator.";
return SystemAlloc(size);
}
@ -72,9 +70,9 @@ void* BuddyAllocator::Alloc(size_t unaligned_size) {
return nullptr;
}
} else {
DLOG(INFO) << "Allocation from existing memory block " << std::get<2>(*it)
<< " at address "
<< reinterpret_cast<MemoryBlock*>(std::get<2>(*it))->data();
VLOG(3) << "Allocation from existing memory block " << std::get<2>(*it)
<< " at address "
<< reinterpret_cast<MemoryBlock*>(std::get<2>(*it))->data();
}
total_used_ += size;
@ -91,10 +89,10 @@ void BuddyAllocator::Free(void* p) {
// Acquire the allocator lock
std::lock_guard<std::mutex> lock(mutex_);
DLOG(INFO) << "Free from address " << block;
VLOG(3) << "Free from address " << block;
if (block->type(cache_) == MemoryBlock::HUGE_CHUNK) {
DLOG(INFO) << "Free directly from system allocator";
VLOG(3) << "Free directly from system allocator";
system_allocator_->Free(block, block->total_size(cache_),
block->index(cache_));
@ -111,8 +109,8 @@ void BuddyAllocator::Free(void* p) {
// Trying to merge the right buddy
if (block->has_right_buddy(cache_)) {
DLOG(INFO) << "Merging this block " << block << " with its right buddy "
<< block->right_buddy(cache_);
VLOG(3) << "Merging this block " << block << " with its right buddy "
<< block->right_buddy(cache_);
auto right_buddy = block->right_buddy(cache_);
@ -129,8 +127,8 @@ void BuddyAllocator::Free(void* p) {
// Trying to merge the left buddy
if (block->has_left_buddy(cache_)) {
DLOG(INFO) << "Merging this block " << block << " with its left buddy "
<< block->left_buddy(cache_);
VLOG(3) << "Merging this block " << block << " with its left buddy "
<< block->left_buddy(cache_);
auto left_buddy = block->left_buddy(cache_);
@ -146,8 +144,8 @@ void BuddyAllocator::Free(void* p) {
}
// Dumping this block into pool
DLOG(INFO) << "Inserting free block (" << block << ", "
<< block->total_size(cache_) << ")";
VLOG(3) << "Inserting free block (" << block << ", "
<< block->total_size(cache_) << ")";
pool_.insert(
IndexSizeAddress(block->index(cache_), block->total_size(cache_), block));
@ -166,7 +164,7 @@ void* BuddyAllocator::SystemAlloc(size_t size) {
size_t index = 0;
void* p = system_allocator_->Alloc(index, size);
DLOG(INFO) << "Allocated " << p << " from system allocator.";
VLOG(3) << "Allocated " << p << " from system allocator.";
if (p == nullptr) return nullptr;
@ -192,8 +190,8 @@ BuddyAllocator::PoolSet::iterator BuddyAllocator::RefillPool() {
if (p == nullptr) return pool_.end();
DLOG(INFO) << "Creating and inserting new block " << p
<< " from system allocator";
VLOG(3) << "Creating and inserting new block " << p
<< " from system allocator";
static_cast<MemoryBlock*>(p)->init(cache_, MemoryBlock::FREE_CHUNK, index,
max_chunk_size_, nullptr, nullptr);
@ -237,19 +235,19 @@ void* BuddyAllocator::SplitToAlloc(BuddyAllocator::PoolSet::iterator it,
auto block = static_cast<MemoryBlock*>(std::get<2>(*it));
pool_.erase(it);
DLOG(INFO) << "Split block (" << block << ", " << block->total_size(cache_)
<< ") into";
VLOG(3) << "Split block (" << block << ", " << block->total_size(cache_)
<< ") into";
block->split(cache_, size);
DLOG(INFO) << "Left block (" << block << ", " << block->total_size(cache_)
<< ")";
VLOG(3) << "Left block (" << block << ", " << block->total_size(cache_)
<< ")";
block->set_type(cache_, MemoryBlock::ARENA_CHUNK);
// the rest of memory if exist
if (block->has_right_buddy(cache_)) {
if (block->right_buddy(cache_)->type(cache_) == MemoryBlock::FREE_CHUNK) {
DLOG(INFO) << "Insert right block (" << block->right_buddy(cache_) << ", "
<< block->right_buddy(cache_)->total_size(cache_) << ")";
VLOG(3) << "Insert right block (" << block->right_buddy(cache_) << ", "
<< block->right_buddy(cache_)->total_size(cache_) << ")";
pool_.insert(
IndexSizeAddress(block->right_buddy(cache_)->index(cache_),
@ -276,7 +274,7 @@ void BuddyAllocator::CleanIdleFallBackAlloc() {
return;
}
DLOG(INFO) << "Return block " << block << " to fallback allocator.";
VLOG(3) << "Return block " << block << " to fallback allocator.";
system_allocator_->Free(block, max_chunk_size_, block->index(cache_));
cache_.invalidate(block);
@ -312,7 +310,7 @@ void BuddyAllocator::CleanIdleNormalAlloc() {
MemoryBlock* block = static_cast<MemoryBlock*>(std::get<2>(*pool));
DLOG(INFO) << "Return block " << block << " to base allocator.";
VLOG(3) << "Return block " << block << " to base allocator.";
system_allocator_->Free(block, max_chunk_size_, block->index(cache_));
cache_.invalidate(block);

@ -35,7 +35,7 @@ void Copy<platform::CPUPlace, platform::GPUPlace>(platform::CPUPlace dst_place,
platform::GPUPlace src_place,
const void* src, size_t num,
cudaStream_t stream) {
platform::GPUPlaceGuard g(src_place.device);
platform::SetDeviceId(src_place.device);
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
}
@ -45,7 +45,7 @@ void Copy<platform::GPUPlace, platform::CPUPlace>(platform::GPUPlace dst_place,
platform::CPUPlace src_place,
const void* src, size_t num,
cudaStream_t stream) {
platform::GPUPlaceGuard g(dst_place.device);
platform::SetDeviceId(dst_place.device);
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
}
@ -56,7 +56,7 @@ void Copy<platform::GPUPlace, platform::GPUPlace>(platform::GPUPlace dst_place,
const void* src, size_t num,
cudaStream_t stream) {
if (dst_place == src_place) {
platform::GPUPlaceGuard g(src_place.device);
platform::SetDeviceId(src_place.device);
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToDevice, stream);
} else {
platform::GpuMemcpyPeer(dst, dst_place.device, src, src_place.device, num,

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save