Fix layers.uniform_random (#13823)

* fix layers.uniform_random

* fix uniform_random
test=develop

* remove var type set
test=develop

* fix similar error
test=develop
revert-13821-fix
chengduo 6 years ago committed by GitHub
parent 5f2e837847
commit 9c77b65c06
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -115,14 +115,14 @@ class UniformRandomOpVarTypeInference : public framework::VarTypeInference {
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
auto out_var_name = op_desc.Output("Out").front();
if (block->FindRecursiveOrCreateVar(out_var_name).GetType() ==
framework::proto::VarType::SELECTED_ROWS) {
block->FindRecursiveOrCreateVar(out_var_name)
.SetType(framework::proto::VarType::SELECTED_ROWS);
} else {
block->FindRecursiveOrCreateVar(out_var_name)
.SetType(framework::proto::VarType::LOD_TENSOR);
auto var_data_type = static_cast<framework::proto::VarType::Type>(
boost::get<int>(op_desc.GetAttr("dtype")));
auto out_var = block->FindRecursiveOrCreateVar(out_var_name);
if (out_var.GetType() != framework::proto::VarType::SELECTED_ROWS) {
out_var.SetType(framework::proto::VarType::LOD_TENSOR);
}
out_var.SetDataType(var_data_type);
}
};

@ -14,6 +14,8 @@
from __future__ import print_function
from .layer_function_generator import generate_layer_fn, generate_layer_fn_noattr
from .. import core
from ..framework import convert_np_dtype_to_dtype_
__activations_noattr__ = [
'sigmoid',
@ -58,8 +60,11 @@ _uniform_random_ = generate_layer_fn('uniform_random')
def uniform_random(shape, dtype=None, min=None, max=None, seed=None):
locals_var = locals().keys()
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
kwargs = dict()
for name in locals():
for name in locals_var:
val = locals()[name]
if val is not None:
kwargs[name] = val
@ -78,8 +83,9 @@ _hard_shrink_ = generate_layer_fn('hard_shrink')
def hard_shrink(x, threshold=None):
locals_var = locals().keys()
kwargs = dict()
for name in locals():
for name in locals_var:
val = locals()[name]
if val is not None:
kwargs[name] = val
@ -99,12 +105,12 @@ _cum_sum_ = generate_layer_fn('cumsum')
def cumsum(x, axis=None, exclusive=None, reverse=None):
locals_var = locals().keys()
kwargs = dict()
for name in locals():
for name in locals_var:
val = locals()[name]
if val is not None:
kwargs[name] = val
return _cum_sum_(**kwargs)
@ -121,8 +127,9 @@ _thresholded_relu_ = generate_layer_fn('thresholded_relu')
def thresholded_relu(x, threshold=None):
locals_var = locals().keys()
kwargs = dict()
for name in locals():
for name in locals_var:
val = locals()[name]
if val is not None:
kwargs[name] = val

Loading…
Cancel
Save