address comments and resolve conflicts.

test=develop
revert-15207-remove_op_handle_lock_and_fix_var
Xin Pan 6 years ago
parent b91a7a9d30
commit c132c79011

@ -100,26 +100,6 @@ class Autograd {
}
};
void CreateVariable(const std::string& name, const framework::DDim& dim,
float val, bool random_name, framework::Variable* var) {
if (var->IsInitialized()) return;
std::string varname = name;
if (random_name) {
std::mt19937 rng;
rng.seed(std::random_device()());
std::uniform_int_distribution<std::mt19937::result_type> dist6(
1, std::numeric_limits<int>::max());
int id = dist6(rng);
varname = string::Sprintf("%s@%d", varname, id);
}
VLOG(3) << "creating var " << varname;
framework::LoDTensor* tensor = var->GetMutable<framework::LoDTensor>();
float* data = tensor->mutable_data<float>(dim, platform::CPUPlace());
std::fill(data, data + tensor->numel(), val);
}
framework::LoDTensor& VarBase::Grad() {
VLOG(3) << "get var grad " << var_desc_->Name();
return *grads_->GetMutable<framework::LoDTensor>();

@ -36,7 +36,6 @@ class PreparedOp {
static PreparedOp Prepare(const framework::RuntimeContext& ctx,
const framework::OperatorWithKernel& op,
const platform::Place& place) {
framework::Scope dummy_scope;
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);
@ -52,7 +51,7 @@ class PreparedOp {
framework::OperatorWithKernel::OpKernelMap& kernels = kernels_iter->second;
auto expected_kernel_key = op.GetExpectedKernelType(
framework::ExecutionContext(op, dummy_scope, *dev_ctx, ctx));
framework::ExecutionContext(op, framework::Scope(), *dev_ctx, ctx));
VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
auto kernel_iter = kernels.find(expected_kernel_key);

Loading…
Cancel
Save