Merge remote-tracking branch 'ups/develop' into refine/op/fusion_lstm

fix-develop-build.sh
tensor-tang 7 years ago
commit 5f586e2223

@@ -11,7 +11,6 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
 #include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
 #include <string>
 #include "paddle/fluid/framework/lod_tensor.h"

@@ -14,7 +14,7 @@ else
 fi
 PREFIX=inference-vis-demos%2F
-URL_ROOT=http://paddlemodels.bj.bcebos.com/${PREFIX}
+URL_ROOT=http://paddlemodels.cdn.bcebos.com/${PREFIX}
 # download vis_demo data
 function download() {
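For reference, the only change in this hunk is the download host: the demo archives are now fetched from paddlemodels.cdn.bcebos.com instead of paddlemodels.bj.bcebos.com. Below is a minimal Python sketch of what the script's download step amounts to; the actual script is a bash function, and the archive name used here is a made-up example, so this only illustrates how URL_ROOT and an archive name compose into a request.

import tarfile
import urllib.request

# New download root from the diff; %2F is a URL-encoded "/" in the object key.
URL_ROOT = "http://paddlemodels.cdn.bcebos.com/inference-vis-demos%2F"

def download(name):
    """Fetch <URL_ROOT><name>.tar.gz and unpack it into the current directory.
    Mirrors the shape of the bash download() function, not its exact code."""
    archive = name + ".tar.gz"
    urllib.request.urlretrieve(URL_ROOT + archive, archive)
    with tarfile.open(archive) as tar:
        tar.extractall()

if __name__ == "__main__":
    download("se_resnext50")  # hypothetical demo name, for illustration only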

@@ -39,8 +39,17 @@ bool RequestSendHandler::Handle(const std::string& varname,
                                 const std::string& out_var_name) {
   VLOG(4) << "RequestSendHandler:" << varname;
+  // Sync
+  if (varname == BATCH_BARRIER_MESSAGE) {
+    VLOG(3) << "sync: recv BATCH_BARRIER_MESSAGE";
+    rpc_server_->IncreaseBatchBarrier(kRequestSend);
+  } else if (varname == COMPLETE_MESSAGE) {
+    VLOG(3) << "sync: recv complete message";
+    rpc_server_->Complete();
+  } else {
   // Async
   if (!sync_mode_) {
+    VLOG(3) << "async process var: " << varname;
     rpc_server_->Profiler().OneStep();
     try {
       executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(),
@@ -50,17 +59,7 @@ bool RequestSendHandler::Handle(const std::string& varname,
       return false;
     }
     return true;
-  }
-  // Sync
-  if (varname == BATCH_BARRIER_MESSAGE) {
-    VLOG(3) << "sync: recv BATCH_BARRIER_MESSAGE";
-    rpc_server_->IncreaseBatchBarrier(kRequestSend);
-  } else if (varname == COMPLETE_MESSAGE) {
-    VLOG(3) << "sync: recv complete message";
-    rpc_server_->Complete();
-  } else {
-    VLOG(3) << "sync: received var_name: " << varname;
+  } else {  // sync
     rpc_server_->WaitCond(kRequestSend);
     VLOG(3) << "sync: processing received var: " << varname;
@@ -68,11 +67,13 @@ bool RequestSendHandler::Handle(const std::string& varname,
       LOG(FATAL) << "sync: Can not find server side var: " << varname;
       return false;
     }
     if (invar->IsType<framework::SelectedRows>()) {
       std::unique_lock<std::mutex> lock(mutex_sparse_vars_);
       sparse_vars_.push_back(invar);
     }
   }
+  }
   return true;
 }
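The net effect of the three hunks above is a reordering of RequestSendHandler::Handle: the BATCH_BARRIER_MESSAGE / COMPLETE_MESSAGE checks now run before the async/sync split instead of only on the sync path. Below is a rough, standalone Python model of the resulting branch order; it is not the C++ implementation, and the print statements are placeholders for the rpc_server_ / executor_ calls in the diff.

BATCH_BARRIER_MESSAGE = "BATCH_BARRIER_MESSAGE"
COMPLETE_MESSAGE = "COMPLETE_MESSAGE"

def handle_send(varname, sync_mode):
    """Branch order after the change: control messages first, then async or sync."""
    if varname == BATCH_BARRIER_MESSAGE:
        print("increase batch barrier for kRequestSend")       # rpc_server_->IncreaseBatchBarrier
    elif varname == COMPLETE_MESSAGE:
        print("mark trainer as complete")                       # rpc_server_->Complete
    elif not sync_mode:
        print("async: run prepared sub-program for", varname)   # executor_->RunPreparedContext
    else:
        print("sync: wait on kRequestSend, then record", varname)  # rpc_server_->WaitCond
    return True

if __name__ == "__main__":
    for name in ["w@GRAD", BATCH_BARRIER_MESSAGE, COMPLETE_MESSAGE]:
        handle_send(name, sync_mode=True)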

File diff suppressed because it is too large.

@@ -3546,11 +3546,6 @@ def topk(input, k, name=None):
         top5_values, top5_indices = layers.topk(input, k=5)
     """
-    shape = input.shape
-    if k < 1 or k >= shape[-1]:
-        raise ValueError("k must be greater than 0 and less than %d." %
-                         (shape[-1]))
     helper = LayerHelper("top_k", **locals())
     values = helper.create_tmp_variable(dtype=input.dtype)
     indices = helper.create_tmp_variable(dtype="int64")
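The hunk above removes the Python-side range check on k in layers.topk. One plausible motivation (not stated in the diff) is that input.shape can contain -1 for dimensions unknown at network-construction time, which makes the shape[-1] comparison unreliable; validation is then presumably left to the top_k operator itself. For reference, the removed guard behaves like the small standalone helper below.

import numpy as np

def check_topk_k(shape, k):
    """Standalone copy of the removed guard: k must satisfy 1 <= k < shape[-1]."""
    if k < 1 or k >= shape[-1]:
        raise ValueError("k must be greater than 0 and less than %d." % shape[-1])

x = np.random.rand(8, 100).astype("float32")
check_topk_k(x.shape, k=5)    # fine
# check_topk_k(x.shape, k=0)  # would raise ValueError
# check_topk_k((8, -1), k=5)  # an unknown last dim (-1) would also raise, which is
#                             # the kind of false positive that may have motivated
#                             # removing the check (assumption).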

@@ -58,6 +58,7 @@ class TestFusionLSTMOp(OpTest):
         self.act_cell = 'tanh'
         self.act_cand = 'tanh'
         self.use_peepholes = False
+        self.use_seq = False
         self.set_conf()
         T = sum(self.lod[0])
@@ -107,6 +108,7 @@ class TestFusionLSTMOp(OpTest):
         }
         self.attrs = {
             'use_peepholes': self.use_peepholes,
+            'use_seq': self.use_seq,
             'is_reverse': self.is_reverse,
             'gate_activation': self.act_gate,
             'cell_activation': self.act_cell,
@@ -159,5 +161,68 @@ class TestFusionLSTMOpBS1(TestFusionLSTMOp):
         self.D = 16
 
 
+class TestFusionLSTMOpPeepholes(TestFusionLSTMOp):
+    def set_conf(self):
+        self.use_peepholes = True
+
+
+class TestFusionLSTMOpPeepholesInit(TestFusionLSTMOp):
+    def set_conf(self):
+        self.use_peepholes = True
+        self.has_initial_state = True
+
+
+class TestFusionLSTMOpPeepholesReverse(TestFusionLSTMOp):
+    def set_conf(self):
+        self.use_peepholes = True
+        self.is_reverse = True
+
+
+class TestFusionLSTMOpPeepholesBS1(TestFusionLSTMOp):
+    def set_conf(self):
+        self.use_peepholes = True
+        self.lod = [[3]]
+        self.D = 16
+
+
+class TestFusionLSTMOpSeqInit(TestFusionLSTMOp):
+    def set_conf(self):
+        self.use_seq = True
+        self.has_initial_state = True
+
+
+class TestFusionLSTMOpSeqReverse(TestFusionLSTMOp):
+    def set_conf(self):
+        self.use_seq = True
+        self.is_reverse = True
+
+
+class TestFusionLSTMOpSeqInitReverse(TestFusionLSTMOp):
+    def set_conf(self):
+        self.use_seq = True
+        self.has_initial_state = True
+        self.is_reverse = True
+
+
+class TestFusionLSTMOpSeqPeepholes(TestFusionLSTMOp):
+    def set_conf(self):
+        self.use_seq = True
+        self.use_peepholes = True
+
+
+class TestFusionLSTMOpSeqPeepholesInit(TestFusionLSTMOp):
+    def set_conf(self):
+        self.use_seq = True
+        self.use_peepholes = True
+        self.has_initial_state = True
+
+
+class TestFusionLSTMOpSeqPeepholesReverse(TestFusionLSTMOp):
+    def set_conf(self):
+        self.use_seq = True
+        self.use_peepholes = True
+        self.is_reverse = True
+
+
 if __name__ == '__main__':
     unittest.main()
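The remaining hunks thread the new use_seq attribute through the base test and add one subclass per flag combination. Each subclass only overrides set_conf(), which the base class calls during its setup before building the inputs and attrs. The standalone sketch below illustrates that override pattern with made-up class names; it is not part of the Paddle test file.

import unittest

class BaseConfTest(unittest.TestCase):
    """Illustrative stand-in for TestFusionLSTMOp: set defaults, then call set_conf()."""

    def setUp(self):
        self.use_peepholes = False
        self.use_seq = False
        self.is_reverse = False
        self.set_conf()  # subclass hook, mirrors TestFusionLSTMOp.set_conf

    def set_conf(self):
        pass  # base case keeps the defaults

    def test_flags_are_booleans(self):
        for flag in (self.use_peepholes, self.use_seq, self.is_reverse):
            self.assertIsInstance(flag, bool)

class SeqReverseConfTest(BaseConfTest):
    """Same override style as TestFusionLSTMOpSeqReverse in the diff."""

    def set_conf(self):
        self.use_seq = True
        self.is_reverse = True

if __name__ == '__main__':
    unittest.main()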
