|
|
|
@ -34,6 +34,7 @@ limitations under the License. */
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
constexpr char kOptimizeBlock[] = "OptimizeBlock";
|
|
|
|
|
constexpr int kCondStart = 0;
|
|
|
|
|
constexpr int kCondRunning = 1;
|
|
|
|
|
constexpr int kCondDone = 2;
|
|
|
|
@ -99,10 +100,8 @@ class RecvOp : public framework::OperatorBase {
|
|
|
|
|
auto fan_in = Attr<int>("Fanin");
|
|
|
|
|
size_t param_count = param_list.size();
|
|
|
|
|
|
|
|
|
|
std::string program_str = Attr<std::string>("OptimizeProgram");
|
|
|
|
|
framework::proto::ProgramDesc program_desc;
|
|
|
|
|
program_desc.ParseFromString(program_str);
|
|
|
|
|
framework::ProgramDesc program(program_desc);
|
|
|
|
|
auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock);
|
|
|
|
|
auto *program = block->Program();
|
|
|
|
|
framework::Executor executor(dev_place);
|
|
|
|
|
|
|
|
|
|
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
|
|
|
|
@ -142,8 +141,9 @@ class RecvOp : public framework::OperatorBase {
|
|
|
|
|
if (exit_flag) {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
try {
|
|
|
|
|
executor.Run(program, &recv_scope, 0, /*global_block*/
|
|
|
|
|
executor.Run(*program, &recv_scope, block->ID(), /*global_block*/
|
|
|
|
|
false /*create_local_scope*/, false /*create_vars*/);
|
|
|
|
|
} catch (std::exception &e) {
|
|
|
|
|
LOG(ERROR) << "run sub program error " << e.what();
|
|
|
|
@ -175,8 +175,8 @@ This operator will recv tensor from send_op
|
|
|
|
|
"IP address to listen on.")
|
|
|
|
|
.SetDefault("127.0.0.1:6164")
|
|
|
|
|
.AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
|
|
|
|
|
AddAttr<std::string>("OptimizeProgram", "type string",
|
|
|
|
|
"Serialized ProgramDesc string for recv to run.");
|
|
|
|
|
AddAttr<framework::BlockDesc *>(
|
|
|
|
|
kOptimizeBlock, "Serialized ProgramDesc string for recv to run.");
|
|
|
|
|
AddAttr<std::vector<std::string>>(
|
|
|
|
|
"ParamList", "type list of string",
|
|
|
|
|
"grad->param name mapping to find which param to optimize.")
|
|
|
|
|