parent cf70d5b350, commit 328cb289ed
@@ -0,0 +1,217 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <cstring>
#include <fstream>
#include <future>  // NOLINT
#include <memory>
#include <ostream>
#include <string>
#include <vector>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/version.h"
#include "paddle/fluid/platform/profiler.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;

// Allocates the tensor's buffer for the data type selected by VisitDataType
// and exposes the raw pointer through *buf_.
struct DeserializedDataFunctor {
  DeserializedDataFunctor(void **buf, Tensor *tensor,
                          const platform::Place &place)
      : buf_(buf), tensor_(tensor), place_(place) {}

  template <typename T>
  void apply() {
    *buf_ = tensor_->mutable_data<T>(place_);
  }

  void **buf_;
  Tensor *tensor_;
  platform::Place place_;
};

template <typename DeviceContext, typename T>
class SparseTensorLoadKernel : public paddle::framework::OpKernel<T> {
 public:
  void Compute(const paddle::framework::ExecutionContext &ctx) const override {
    auto place = ctx.GetPlace();
    auto filename = ctx.Attr<std::string>("file_path");
    std::ifstream fin(filename, std::ios::binary);
    PADDLE_ENFORCE_EQ(static_cast<bool>(fin), true,
                      platform::errors::Unavailable(
                          "Load operator failed to open file %s, please check "
                          "whether the model file is complete or damaged.",
                          filename));
    auto name = ctx.OutputNames("Out")[0];
    VLOG(4) << "Sparse Load Var name: " << name;
    auto *out_var = ctx.OutputVar("Out");
    PADDLE_ENFORCE_NOT_NULL(
        out_var, platform::errors::InvalidArgument(
                     "The variable %s to be loaded cannot be found.", name));
    PADDLE_ENFORCE_EQ(out_var->IsType<paddle::framework::LoDTensor>(), true,
                      platform::errors::InvalidArgument(
                          "SparseLoad OP only supports LoDTensor."));
    LoadLodTensor(fin, place, out_var, ctx);
  }

  void LoadLodTensor(std::istream &is, const platform::Place &place,
                     paddle::framework::Variable *var,
                     const paddle::framework::ExecutionContext &ctx) const {
    auto *tensor = var->GetMutable<paddle::framework::LoDTensor>();

    auto node_index = ctx.Attr<int64_t>("node_index");
    auto node_num = ctx.Attr<int64_t>("node_num");
    auto shape = ctx.Attr<std::vector<int64_t>>("shape");
    VLOG(4) << "Sparse LoadLodTensor node_num: " << node_num;
    VLOG(4) << "Sparse LoadLodTensor node_index: " << node_index;
    VLOG(4) << "Sparse LoadLodTensor shape[0]: " << shape[0];
    PADDLE_ENFORCE_GE(node_index, 0,
                      platform::errors::InvalidArgument(
                          "node_index must be greater than or equal to 0."));
    PADDLE_ENFORCE_GE(node_num, 1,
                      platform::errors::InvalidArgument(
                          "node_num must be greater than or equal to 1."));

    {
      // the 1st field, uint32_t version for LoDTensor
      uint32_t version;
      is.read(reinterpret_cast<char *>(&version), sizeof(version));
      PADDLE_ENFORCE_EQ(paddle::framework::IsTensorVersionSupported(version),
                        true,
                        platform::errors::InvalidArgument(
                            "Tensor version %u is not supported.", version));
      PADDLE_ENFORCE_EQ(version, 0U, platform::errors::InvalidArgument(
                                         "Tensor version %u is not supported, "
                                         "only version 0 is supported.",
                                         version));
    }

    {
      // the 2nd field, LoD information
      // TODO: sparse load needs to change the LoDTensor's lod level
      uint64_t lod_level;
      is.read(reinterpret_cast<char *>(&lod_level), sizeof(lod_level));
      auto &lod = *tensor->mutable_lod();
      lod.resize(lod_level);
    }

    // the 3rd field, Tensor
    uint32_t version;
    is.read(reinterpret_cast<char *>(&version), sizeof(version));

    PADDLE_ENFORCE_EQ(
        version, 0U,
        platform::errors::InvalidArgument(
            "Tensor version %u is not supported, only version 0 is supported.",
            version));

    paddle::framework::proto::VarType::TensorDesc desc;

    {  // int32_t size
      // proto buffer
      int32_t size;
      is.read(reinterpret_cast<char *>(&size), sizeof(size));
      std::unique_ptr<char[]> buf(new char[size]);
      is.read(reinterpret_cast<char *>(buf.get()), size);
      PADDLE_ENFORCE_EQ(
          desc.ParseFromArray(buf.get(), size), true,
          platform::errors::InvalidArgument("Cannot parse tensor desc"));
    }

    {  // read tensor
      std::vector<int64_t> dims;
      dims.reserve(static_cast<size_t>(desc.dims().size()));
      std::copy(desc.dims().begin(), desc.dims().end(),
                std::back_inserter(dims));

      // numel of one row in the file, and the number of rows in the file
      int64_t line_numel = 1;
      for (size_t dim = 1; dim < dims.size(); dim++) {
        line_numel *= dims[dim];
      }
      auto total_line = dims[0];

      tensor->Resize(paddle::framework::make_ddim(shape));

      void *buf;
      auto cpu_ctx = platform::CPUDeviceContext();

      paddle::framework::VisitDataType(
          desc.data_type(),
          DeserializedDataFunctor(&buf, tensor, cpu_ctx.GetPlace()));

      auto line_size =
          line_numel * paddle::framework::SizeOfType(desc.data_type());
      char *cur_buf = static_cast<char *>(buf);
      std::unique_ptr<char[]> temp_row(new char[line_size]);
      VLOG(4) << "TensorFromStream: line_size " << line_size;
      VLOG(4) << "TensorFromStream: total_line " << total_line;
      // Shard by row: this node keeps row i iff i % node_num == node_index.
      for (size_t line_index = 0; line_index < static_cast<size_t>(total_line);
           ++line_index) {
        is.read(temp_row.get(), line_size);
        if (static_cast<int64_t>(line_index) % node_num == node_index) {
          memcpy(cur_buf, temp_row.get(), line_size);
          cur_buf += line_size;
        }
      }
    }
  }
};

class SparseTensorLoadOp : public paddle::framework::OperatorWithKernel {
 public:
  using paddle::framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(paddle::framework::InferShapeContext *ctx) const override {}

 protected:
  paddle::framework::OpKernelType GetExpectedKernelType(
      const paddle::framework::ExecutionContext &ctx) const override {
    paddle::framework::OpKernelType kt = paddle::framework::OpKernelType(
        paddle::framework::proto::VarType::FP32, ctx.GetPlace());
    return kt;
  }
};

class SparseTensorLoadOpMaker
    : public paddle::framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddOutput("Out", "The LoDTensor / SelectedRows to be loaded");
    AddAttr<std::string>("file_path",
                         R"(Variable will be loaded from "file_path")")
        .AddCustomChecker(
            [](const std::string &path) { return !path.empty(); });
    AddAttr<int64_t>("node_index", "role id, in range [0, node_num).")
        .SetDefault(0);
    AddAttr<int64_t>("node_num",
                     "number of roles that need to load the current variable.")
        .SetDefault(0);
    AddAttr<std::vector<int64_t>>("shape",
                                  "(vector<int64_t>) The shape of the output")
        .SetDefault({});
    AddComment(R"DOC(
SparseTensorLoad OP, Load sparse tensor on parameter server
)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(sparse_tensor_load, ops::SparseTensorLoadOp,
                  ops::SparseTensorLoadOpMaker);

REGISTER_OP_CPU_KERNEL(
    sparse_tensor_load,
    ops::SparseTensorLoadKernel<paddle::platform::CPUDeviceContext, float>);
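
A note on the kernel above: it shards a dense checkpoint by rows. With node_num parameter servers, server node_index keeps row i of the on-disk tensor if and only if i % node_num == node_index. A minimal Python sketch of that rule (illustrative only, not part of the commit; rows_for_server is a hypothetical helper):

    def rows_for_server(total_rows, node_num, node_index):
        # Mirrors the C++ loop: row i is kept iff i % node_num == node_index.
        return [i for i in range(total_rows) if i % node_num == node_index]

    # With 10 rows and 2 servers:
    #   rows_for_server(10, 2, 0) -> [0, 2, 4, 6, 8]   (embedding.block0 in the tests below)
    #   rows_for_server(10, 2, 1) -> [1, 3, 5, 7, 9]   (embedding.block1 in the tests below)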
@@ -0,0 +1,122 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import os
import unittest
import numpy as np
import tempfile
import shutil
from op_test import OpTest, randomize_probability
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.distributed.fleet.base.role_maker as role_maker
from paddle.distributed.fleet import fleet

class SparseLoadOp(unittest.TestCase):
    """Test sparse load operator."""

    def net(self, emb_array, fc_array):
        with fluid.unique_name.guard():
            dense_input = fluid.data('input', shape=[None, 1], dtype="int64")

            emb = fluid.layers.embedding(
                input=dense_input,
                is_sparse=True,
                size=[10, 10],
                param_attr=fluid.ParamAttr(
                    name="embedding",
                    initializer=fluid.initializer.NumpyArrayInitializer(
                        emb_array)))

            fc1 = fluid.layers.fc(
                input=emb,
                size=10,
                act="relu",
                param_attr=fluid.ParamAttr(
                    name='fc',
                    initializer=fluid.initializer.NumpyArrayInitializer(
                        fc_array)))
            loss = fluid.layers.reduce_mean(fc1)
            return loss

    def save_origin_model(self, emb_array, fc_array):
        startup_program = fluid.framework.Program()
        test_program = fluid.framework.Program()
        with fluid.framework.program_guard(test_program, startup_program):
            with fluid.unique_name.guard():
                loss = self.net(emb_array, fc_array)
                optimizer = fluid.optimizer.Adam(1e-3)
                optimizer.minimize(loss)

                exe = fluid.Executor(fluid.CPUPlace())
                exe.run(startup_program)
                model_path = tempfile.mkdtemp()
                fluid.io.save_persistables(executor=exe, dirname=model_path)
                return model_path

class TestSparseLoadOpCase1(SparseLoadOp):
    def test_2ps_0_load(self):
        # initialize the env of pserver No.0
        env = {}
        env["PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:4001,127.0.0.1:4002"
        env["PADDLE_TRAINERS_NUM"] = str(2)
        env["TRAINING_ROLE"] = "PSERVER"
        env["PADDLE_PORT"] = "4001"
        env["POD_IP"] = "127.0.0.1"
        for k, v in env.items():
            os.environ[k] = str(v)
        """
        emb_array and fc_array below are:
        array([[0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
               [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
               [0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2],
               [0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3],
               [0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4],
               [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
               [0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6],
               [0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7],
               [0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8],
               [0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9]])
        """
        emb_array = np.arange(0, 1, 0.1).repeat(10).reshape(10, 10)
        fc_array = np.arange(0, 1, 0.1).repeat(10).reshape(10, 10)
        model_path = self.save_origin_model(emb_array, fc_array)

        role = role_maker.PaddleCloudRoleMaker()
        fleet.init(role)
        loss = self.net(emb_array, fc_array)
        strategy = paddle.distributed.fleet.DistributedStrategy()
        strategy.a_sync = True
        optimizer = fluid.optimizer.Adam(1e-3)
        optimizer = fleet.distributed_optimizer(optimizer, strategy)
        optimizer.minimize(loss)
        fleet.init_server(model_path)

        fc_w = np.array(fluid.global_scope().find_var("fc").get_tensor())

        emb = np.array(fluid.global_scope().find_var("embedding.block0")
                       .get_tensor())

        # server No.0 holds the full dense fc weight and the even rows
        # of the embedding table
        assert np.allclose(fc_w, fc_array)
        assert np.allclose(emb, emb_array[::2])
        shutil.rmtree(model_path)


if __name__ == "__main__":
    paddle.enable_static()
    unittest.main()
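
The shard that this test expects on pserver No.0 can be checked directly with numpy; a quick sketch under the same 10x10 emb_array (illustrative only, not part of the commit):

    import numpy as np

    emb_array = np.arange(0, 1, 0.1).repeat(10).reshape(10, 10)
    # node_index=0 of node_num=2 keeps the even rows 0, 2, 4, 6, 8,
    # whose values are 0.0, 0.2, 0.4, 0.6, 0.8 respectively:
    expected_block0 = emb_array[::2]  # shape (5, 10)
    assert np.allclose(expected_block0[:, 0], [0.0, 0.2, 0.4, 0.6, 0.8])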
@@ -0,0 +1,76 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import os
import unittest
import numpy as np
import tempfile
import shutil
from op_test import OpTest, randomize_probability
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.distributed.fleet.base.role_maker as role_maker
from paddle.distributed.fleet import fleet
from test_dist_sparse_load_ps0 import SparseLoadOp

class TestSparseLoadOpCase2(SparseLoadOp):
    def test_2ps_1_load(self):
        # initialize the env of pserver No.1
        env = {}
        env["PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:4001,127.0.0.1:4002"
        env["PADDLE_TRAINERS_NUM"] = str(2)
        env["TRAINING_ROLE"] = "PSERVER"
        env["PADDLE_PORT"] = "4002"
        env["POD_IP"] = "127.0.0.1"
        for k, v in env.items():
            os.environ[k] = str(v)
        """
        emb_array and fc_array below are:
        array([[0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
               [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
               [0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2],
               [0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3],
               [0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4],
               [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
               [0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6],
               [0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7],
               [0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8],
               [0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9]])
        """
        emb_array = np.arange(0, 1, 0.1).repeat(10).reshape(10, 10)
        fc_array = np.arange(0, 1, 0.1).repeat(10).reshape(10, 10)
        model_path = self.save_origin_model(emb_array, fc_array)

        role = role_maker.PaddleCloudRoleMaker()
        fleet.init(role)
        loss = self.net(emb_array, fc_array)
        strategy = paddle.distributed.fleet.DistributedStrategy()
        strategy.a_sync = True
        optimizer = fluid.optimizer.Adam(1e-3)
        optimizer = fleet.distributed_optimizer(optimizer, strategy)
        optimizer.minimize(loss)
        fleet.init_server(model_path)
        emb = np.array(fluid.global_scope().find_var("embedding.block1")
                       .get_tensor())
        # server No.1 holds the odd rows of the embedding table
        assert np.allclose(emb, emb_array[1::2])
        shutil.rmtree(model_path)


if __name__ == "__main__":
    paddle.enable_static()
    unittest.main()
@@ -0,0 +1,48 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import os
import unittest
import numpy as np
import tempfile
import shutil
from op_test import OpTest, randomize_probability
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.distributed.fleet.base.role_maker as role_maker
from paddle.distributed.fleet import fleet
from test_dist_sparse_tensor_load_sgd import TestSparseLoadProgram


class TestSparseLoadProgramAdagrad(TestSparseLoadProgram):
    """
    Test Sparse load operator.
    """

    def test_server_init(self):
        scope, train_program, startup_program, loss = self.net()
        with fluid.scope_guard(scope):
            with fluid.program_guard(train_program, startup_program):
                optimizer = fluid.optimizer.Adagrad(1e-3)
                optimizer = fleet.distributed_optimizer(optimizer,
                                                        self.strategy)
                optimizer.minimize(loss)
                fleet.init_server()


if __name__ == "__main__":
    paddle.enable_static()
    unittest.main()
@@ -0,0 +1,48 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import os
import unittest
import numpy as np
import tempfile
import shutil
from op_test import OpTest, randomize_probability
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.distributed.fleet.base.role_maker as role_maker
from paddle.distributed.fleet import fleet
from test_dist_sparse_tensor_load_sgd import TestSparseLoadProgram


class TestSparseLoadProgramAdam(TestSparseLoadProgram):
    """
    Test Sparse load operator.
    """

    def test_server_init(self):
        scope, train_program, startup_program, loss = self.net()
        with fluid.scope_guard(scope):
            with fluid.program_guard(train_program, startup_program):
                optimizer = fluid.optimizer.Adam(1e-3)
                optimizer = fleet.distributed_optimizer(optimizer,
                                                        self.strategy)
                optimizer.minimize(loss)
                fleet.init_server()


if __name__ == "__main__":
    paddle.enable_static()
    unittest.main()
@@ -0,0 +1,48 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import os
import unittest
import numpy as np
import tempfile
import shutil
from op_test import OpTest, randomize_probability
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.distributed.fleet.base.role_maker as role_maker
from paddle.distributed.fleet import fleet
from test_dist_sparse_tensor_load_sgd import TestSparseLoadProgram


class TestSparseLoadProgramFtrl(TestSparseLoadProgram):
    """
    Test Sparse load operator.
    """

    def test_server_init(self):
        scope, train_program, startup_program, loss = self.net()
        with fluid.scope_guard(scope):
            with fluid.program_guard(train_program, startup_program):
                optimizer = fluid.optimizer.Ftrl(1e-3)
                optimizer = fleet.distributed_optimizer(optimizer,
                                                        self.strategy)
                optimizer.minimize(loss)
                fleet.init_server()


if __name__ == "__main__":
    paddle.enable_static()
    unittest.main()
@@ -0,0 +1,48 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import os
import unittest
import numpy as np
import tempfile
import shutil
from op_test import OpTest, randomize_probability
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.distributed.fleet.base.role_maker as role_maker
from paddle.distributed.fleet import fleet
from test_dist_sparse_tensor_load_sgd import TestSparseLoadProgram


class TestSparseLoadProgramMomentum(TestSparseLoadProgram):
    """
    Test Sparse load operator.
    """

    def test_server_init(self):
        scope, train_program, startup_program, loss = self.net()
        with fluid.scope_guard(scope):
            with fluid.program_guard(train_program, startup_program):
                optimizer = fluid.optimizer.Momentum(1e-3, 0.9)
                optimizer = fleet.distributed_optimizer(optimizer,
                                                        self.strategy)
                optimizer.minimize(loss)
                fleet.init_server()


if __name__ == "__main__":
    paddle.enable_static()
    unittest.main()
@@ -0,0 +1,48 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import os
import unittest
import numpy as np
import tempfile
import shutil
from op_test import OpTest, randomize_probability
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.distributed.fleet.base.role_maker as role_maker
from paddle.distributed.fleet import fleet
from test_dist_sparse_tensor_load_sgd import TestSparseLoadProgram


class TestSparseLoadProgramRmsprop(TestSparseLoadProgram):
    """
    Test Sparse load operator.
    """

    def test_server_init(self):
        scope, train_program, startup_program, loss = self.net()
        with fluid.scope_guard(scope):
            with fluid.program_guard(train_program, startup_program):
                optimizer = fluid.optimizer.RMSProp(1e-3)
                optimizer = fleet.distributed_optimizer(optimizer,
                                                        self.strategy)
                optimizer.minimize(loss)
                fleet.init_server()


if __name__ == "__main__":
    paddle.enable_static()
    unittest.main()
@@ -0,0 +1,76 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import os
import unittest
import numpy as np
import tempfile
import shutil
from op_test import OpTest, randomize_probability
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.distributed.fleet.base.role_maker as role_maker
from paddle.distributed.fleet import fleet


class TestSparseLoadProgram(unittest.TestCase):
    """
    Test Sparse load operator.
    """

    def setUp(self):
        os.environ[
            "PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:4001,127.0.0.1:4002"
        os.environ["PADDLE_TRAINERS_NUM"] = str(2)
        os.environ["TRAINING_ROLE"] = "PSERVER"
        os.environ["PADDLE_PORT"] = "4001"
        os.environ["POD_IP"] = "127.0.0.1"
        role = role_maker.PaddleCloudRoleMaker()
        fleet.init(role)
        self.strategy = paddle.distributed.fleet.DistributedStrategy()
        self.strategy.a_sync = True

    def net(self):
        train_program = fluid.Program()
        startup_program = fluid.Program()
        scope = fluid.Scope()
        with fluid.scope_guard(scope):
            with fluid.program_guard(train_program, startup_program):
                with fluid.unique_name.guard():
                    inputs = fluid.data('input', shape=[None, 1], dtype="int64")
                    emb = fluid.layers.embedding(
                        inputs, is_sparse=True, size=[10000, 128])
                    fc1 = fluid.layers.fc(input=emb, size=128, act="relu")
                    fc2 = fluid.layers.fc(input=fc1, size=64, act="relu")
                    loss = fluid.layers.reduce_mean(fc2)
                    return scope, train_program, startup_program, loss


class TestSparseLoadProgramSGD(TestSparseLoadProgram):
    def test_server_init(self):
        scope, train_program, startup_program, loss = self.net()
        with fluid.scope_guard(scope):
            with fluid.program_guard(train_program, startup_program):
                optimizer = fluid.optimizer.SGD(1e-3)
                optimizer = fleet.distributed_optimizer(optimizer,
                                                        self.strategy)
                optimizer.minimize(loss)
                fleet.init_server()


if __name__ == "__main__":
    paddle.enable_static()
    unittest.main()
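
All of the tests above configure a two-pserver fleet through the same five environment variables that setUp sets by hand. A hypothetical helper consolidating that setup (illustrative only, not part of the commit):

    import os

    def set_pserver_env(port, ip="127.0.0.1",
                        endpoints="127.0.0.1:4001,127.0.0.1:4002",
                        trainers=2):
        # The role maker reads these variables to decide that this process
        # is the PSERVER listening on ip:port within the endpoint list.
        os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = endpoints
        os.environ["PADDLE_TRAINERS_NUM"] = str(trainers)
        os.environ["TRAINING_ROLE"] = "PSERVER"
        os.environ["PADDLE_PORT"] = str(port)
        os.environ["POD_IP"] = ip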