Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into fix_crop

release/0.11.0
wanghaoshuang 7 years ago
commit ce4e0e90a0

@ -522,7 +522,7 @@ ParamGradInfoMap AppendBackward(
new OpDescBind("fill_constant", {}, {{"Out", {fill_one_op_out}}},
{{"shape", std::vector<int>{1}},
{"value", static_cast<float>(1.0)},
{"data_type", target.GetDataType()}}));
{"dtype", target.GetDataType()}}));
// infer var type of fill_one_op
fill_one_op->InferVarType(root_block);

@ -120,7 +120,7 @@ void Executor::Run(const ProgramDescBind& pdesc, Scope* scope, int block_id,
for (auto& op_desc : block.AllOps()) {
auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
VLOG(10) << op->DebugString();
VLOG(3) << op->DebugString();
op->Run(*local_scope, *device);
}
if (create_local_scope) {

@ -26,6 +26,8 @@ namespace framework {
const std::string kFeedOpType = "feed";
const std::string kFetchOpType = "fetch";
const std::string kDropOutOpType = "dropout";
const std::string kBatchNormOpType = "batch_norm";
bool HasDependentVar(const OpDesc& op_desc,
const std::set<std::string>& dependent_vars) {
@ -106,5 +108,26 @@ void Prune(const ProgramDesc& input, ProgramDesc* output) {
prune_impl(input, output, 0);
}
// Copies `input` into `*output`, then switches the boolean "is_test"
// attribute to true on every dropout and batch_norm op found in block
// `block_id` of the copy. Ops of those two types that carry no "is_test"
// attribute, and ops of any other type, are left exactly as copied.
void inference_optimize_impl(const ProgramDesc& input, ProgramDesc* output,
                             int block_id) {
  *output = input;

  auto* ops = output->mutable_blocks(block_id)->mutable_ops();
  for (auto& op : *ops) {
    const bool is_target_op =
        op.type() == kDropOutOpType || op.type() == kBatchNormOpType;
    if (!is_target_op) {
      continue;
    }

    // Only the first attribute named "is_test" is rewritten.
    auto* attrs = op.mutable_attrs();
    for (auto it = attrs->begin(); it != attrs->end(); ++it) {
      if (it->name() == "is_test") {
        it->set_b(true);
        break;
      }
    }
  }
}
// Public entry point: fills `*output` from `input` by delegating to
// inference_optimize_impl. Only block 0 is processed.
void InferenceOptimize(const ProgramDesc& input, ProgramDesc* output) {
  constexpr int kBlockId = 0;
  inference_optimize_impl(input, output, kBlockId);
}
} // namespace framework
} // namespace paddle

@ -22,5 +22,7 @@ namespace framework {
void Prune(const ProgramDesc& input, ProgramDesc* output);
void InferenceOptimize(const ProgramDesc& input, ProgramDesc* output);
} // namespace framework
} // namespace paddle

@ -302,7 +302,7 @@ LoDTensor TensorArray::Stack() const {
const auto& first_dims = values_.front().dims();
// check all the values have the same shape
// TODO(superjom) check the same dtypes
// TODO(superjom) check the same data_type
for (size_t idx = 1; idx < size(); idx++) {
const auto& value_dims = values_[idx].dims();
PADDLE_ENFORCE_EQ(first_dims, value_dims);

@ -1,6 +1,6 @@
add_subdirectory(detail)
cc_library(memory SRCS memory.cc DEPS place)
cc_library(memory SRCS memory.cc DEPS place enforce)
cc_library(memcpy SRCS memcpy.cc)
cc_library(paddle_memory

@ -17,6 +17,36 @@ limitations under the License. */
namespace paddle {
namespace operators {
// Callable that captures the per-step id/score arrays and the two output
// tensors up front, so that the element type T can be chosen later by
// invoking operator()<T>().
//
// The step arrays are held by const reference and the outputs by raw pointer;
// nothing is copied, so all four objects must stay alive for as long as the
// functor is used.
struct BeamSearchDecodeFunctor {
  const LoDTensorArray& step_ids_;
  const LoDTensorArray& step_scores_;
  LoDTensor* id_tensor_;
  LoDTensor* score_tensor_;

  BeamSearchDecodeFunctor(const LoDTensorArray& ids,
                          const LoDTensorArray& scores, LoDTensor* id_out,
                          LoDTensor* score_out)
      : step_ids_(ids),
        step_scores_(scores),
        id_tensor_(id_out),
        score_tensor_(score_out) {}

  // Defined out of line; T is supplied explicitly by the caller.
  template <typename T>
  void operator()() const;
};
// Generic case: creates a BeamSearchDecoder<T> and hands it the stored step
// arrays and output tensors through PackAllSteps.
template <typename T>
void BeamSearchDecodeFunctor::operator()() const {
  BeamSearchDecoder<T> decoder;
  decoder.PackAllSteps(step_ids_, step_scores_, id_tensor_, score_tensor_);
}
// Explicit specialization for T = bool: instead of instantiating
// BeamSearchDecoder<bool>, this overload unconditionally raises an error
// through PADDLE_THROW.
template <>
void BeamSearchDecodeFunctor::operator()<bool>() const {
PADDLE_THROW("beam search decode op does not support bool!");
}
class BeamSearchDecodeOp : public framework::OperatorBase {
public:
BeamSearchDecodeOp(const std::string& type,
@ -45,9 +75,9 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
LoDTensor* sentenceIds = ctx.Output<LoDTensor>("SentenceIds");
LoDTensor* sentenceScores = ctx.Output<LoDTensor>("SentenceScores");
BeamSearchDecoder<float> beam_search_decoder;
beam_search_decoder.PackAllSteps(*ids, *scores, sentenceIds,
sentenceScores);
framework::VisitDataType(
framework::ToDataType(scores->at(0).type()),
BeamSearchDecodeFunctor(*ids, *scores, sentenceIds, sentenceScores));
}
};

@ -77,11 +77,19 @@ class BilinearTensorProductOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Out", "The output of bilinear_tensor_product operator.");
AddComment(R"DOC(
Bilinear Tensor Product operator.
Given input X and Y, a 3D tensor weight, and bias. Each column of the
output is computed by one slice i = 1, . . . , k of the tensor:
M = (X W_i) \cdot Y
Out_i = \sum_i {M_i} + Bias_i
Given input X and Y, a 3D tensor Weight and a Bias. Each column of the
Output is computed by one slice $i = 1, . . . , k$ of the tensor:
$$
M = (X W_i) * Y \\
Out_i = \sum_j {M_j} + Bias_i
$$
Where $W_i$ is the $i$-th slice of Input(Weight);
$M_j$ is the $j$-th column of $M$;
$Out_i$ is the $i$-th column of Output(Out);
$Bias_i$ is a column vector, each element of it is equal to
the $i$-th element of $Bias$;
)DOC");
}

@ -25,8 +25,8 @@ class CastOpProtoMaker : public framework::OpProtoAndCheckerMaker {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input tensor of cast op");
AddOutput("Out", "The output tensor of cast op");
AddAttr<int>("out_data_type", "output data type");
AddAttr<int>("in_data_type", "input data type");
AddAttr<int>("out_dtype", "output data type");
AddAttr<int>("in_dtype", "input data type");
AddComment(R"DOC(
Cast Operator.
@ -58,8 +58,8 @@ class CastOpGradMaker : public framework::SingleGradOpDescMaker {
grad->SetType("cast");
grad->SetInput("X", OutputGrad("Out"));
grad->SetOutput("Out", InputGrad("X"));
grad->SetAttr("out_data_type", GetAttr("in_data_type"));
grad->SetAttr("in_data_type", GetAttr("out_data_type"));
grad->SetAttr("out_dtype", GetAttr("in_dtype"));
grad->SetAttr("in_dtype", GetAttr("out_dtype"));
return std::unique_ptr<framework::OpDescBind>(grad);
}
};

@ -55,7 +55,7 @@ class CastOpKernel : public framework::OpKernel<InT> {
auto* in = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out");
framework::VisitDataType(
static_cast<framework::DataType>(context.Attr<int>("out_data_type")),
static_cast<framework::DataType>(context.Attr<int>("out_dtype")),
CastOpFunctor<Place, InT>(in, out, context.device_context()));
}
};

@ -30,7 +30,7 @@ class DropoutOp : public framework::OperatorWithKernel {
auto x_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", x_dims);
if (ctx->Attrs().Get<bool>("is_training") == true) {
if (ctx->Attrs().Get<bool>("is_test") == false) {
ctx->SetOutputDim("Mask", x_dims);
}
ctx->ShareLoD("X", /*->*/ "Out");
@ -49,7 +49,7 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<float>("dropout_prob", "Probability of setting units to zero.")
.SetDefault(.5f);
AddAttr<bool>("is_training", "True if in training phase.").SetDefault(true);
AddAttr<bool>("is_test", "True if in test phase.").SetDefault(false);
AddAttr<int>("seed", "Dropout random seed.").SetDefault(0);
AddComment(R"DOC(
@ -71,8 +71,8 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("is_training"), true,
"GradOp is only callable when is_training is true");
PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("is_test"), false,
"GradOp is only callable when is_test is false");
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
PADDLE_ENFORCE(ctx->HasInput("Mask"), "Mask must not be null.");

@ -59,7 +59,7 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
auto Y = EigenMatrix<T>::Reshape(*y, 1);
auto place = context.GetEigenDevice<Place>();
if (context.Attr<bool>("is_training")) {
if (!context.Attr<bool>("is_test")) {
auto* mask = context.Output<Tensor>("Mask");
auto* mask_data = mask->mutable_data<T>(context.GetPlace());
int size = framework::product(mask->dims());

@ -35,7 +35,7 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
auto* y_data = y->mutable_data<T>(context.GetPlace());
float dropout_prob = context.Attr<float>("dropout_prob");
if (context.Attr<bool>("is_training")) {
if (!context.Attr<bool>("is_test")) {
auto* mask = context.Output<Tensor>("Mask");
auto* mask_data = mask->mutable_data<T>(context.GetPlace());
int seed = context.Attr<int>("seed");
@ -65,8 +65,8 @@ template <typename Place, typename T>
class DropoutGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE(context.Attr<bool>("is_training"),
"GradOp is only callable when is_training is true");
PADDLE_ENFORCE(!context.Attr<bool>("is_test"),
"GradOp is only callable when is_test is false");
auto* grad_x = context.Output<Tensor>(framework::GradVarName("X"));
auto* grad_y = context.Input<Tensor>(framework::GradVarName("Out"));

@ -52,7 +52,7 @@ class FillConstantBatchSizeLikeOp : public framework::OperatorWithKernel {
framework::OpKernelType GetKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
static_cast<framework::DataType>(ctx.Attr<int>("data_type")),
static_cast<framework::DataType>(ctx.Attr<int>("dtype")),
ctx.device_context());
}
};
@ -63,7 +63,7 @@ class FillConstantBatchSizeLikeOpMaker
FillConstantBatchSizeLikeOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<int>("data_type",
AddAttr<int>("dtype",
"(int, default 5 (FP32)) "
"Output data type")
.SetDefault(framework::DataType::FP32);

@ -34,7 +34,7 @@ class FillConstantOp : public framework::OperatorBase {
using framework::OperatorBase::OperatorBase;
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
auto data_type = static_cast<framework::DataType>(Attr<int>("data_type"));
auto data_type = static_cast<framework::DataType>(Attr<int>("dtype"));
auto value = Attr<float>("value");
auto force_cpu = Attr<bool>("force_cpu");
auto &out =
@ -55,7 +55,7 @@ class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker {
FillConstantOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<int>("data_type",
AddAttr<int>("dtype",
"(int, default 5 (FP32)) "
"Output data type")
.SetDefault(framework::DataType::FP32);

@ -60,7 +60,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
static_cast<framework::DataType>(ctx.Attr<int>("data_type")),
static_cast<framework::DataType>(ctx.Attr<int>("dtype")),
ctx.device_context());
}
};
@ -88,7 +88,7 @@ class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker {
"Random seed of generator."
"0 means use system wide seed.")
.SetDefault(0);
AddAttr<int>("data_type",
AddAttr<int>("dtype",
"(int, default 5(FP32)) "
"Output data type.")
.SetDefault(framework::DataType::FP32);

@ -32,19 +32,19 @@ class LinearChainCRFOpMaker : public framework::OpProtoAndCheckerMaker {
"[(D + 2) x D]. The learnable parameter for the linear_chain_crf "
"operator. See more details in the operator's comments.");
AddInput("Label",
"(LoDTensor, default LoDTensor<int>) A LoDTensor with shape "
"(LoDTensor, default LoDTensor<int64_t>) A LoDTensor with shape "
"[N x 1], where N is the total element number in a mini-batch. "
"The ground truth.");
AddOutput(
"Alpha",
"(Tensor, default Tensor<float>) A 2-D Tensor with shape [N x D]. "
"The forward vectors for the entire batch. Denote it as \f$\alpha\f$. "
"\f$\alpha$\f is a memo table used to calculate the normalization "
"factor in CRF. \f$\alpha[k, v]$\f stores the unnormalized "
"The forward vectors for the entire batch. Denote it as $\alpha$. "
"$\alpha$ is a memo table used to calculate the normalization "
"factor in CRF. $\alpha[k, v]$ stores the unnormalized "
"probabilites of all possible unfinished sequences of tags that end at "
"position \f$k$\f with tag \f$v$\f. For each \f$k$\f, "
"\f$\alpha[k, v]$\f is a vector of length \f$D$\f with a component for "
"each tag value \f$v$\f. This vector is called a forward vecotr and "
"position $k$ with tag $v$. For each $k$, "
"$\alpha[k, v]$ is a vector of length $D$ with a component for "
"each tag value $v$. This vector is called a forward vecotr and "
"will also be used in backward computations.")
.AsIntermediate();
AddOutput(
@ -73,9 +73,9 @@ LinearChainCRF Operator.
Conditional Random Field defines an undirected probabilistic graph with nodes
denoting random variables and edges denoting dependencies between these
variables. CRF learns the conditional probability \f$P(Y|X)\f$, where
\f$X = (x_1, x_2, ... , x_n)\f$ are structured inputs and
\f$Y = (y_1, y_2, ... , y_n)\f$ are labels for the inputs.
variables. CRF learns the conditional probability $P(Y|X)$, where
$X = (x_1, x_2, ... , x_n)$ are structured inputs and
$Y = (y_1, y_2, ... , y_n)$ are labels for the inputs.
Linear chain CRF is a special case of CRF that is useful for sequence labeling
task. Sequence labeling tasks do not assume a lot of conditional
@ -88,21 +88,22 @@ CRF. Please refer to http://www.cs.columbia.edu/~mcollins/fb.pdf and
http://cseweb.ucsd.edu/~elkan/250Bwinter2012/loglinearCRFs.pdf for details.
Equation:
1. Denote Input(Emission) to this operator as \f$x\f$ here.
1. Denote Input(Emission) to this operator as $x$ here.
2. The first D values of Input(Transition) to this operator are for starting
weights, denoted as \f$a\f$ here.
weights, denoted as $a$ here.
3. The next D values of Input(Transition) of this operator are for ending
weights, denoted as \f$b\f$ here.
weights, denoted as $b$ here.
4. The remaining values of Input(Transition) are for transition weights,
denoted as \f$w\f$ here.
5. Denote Input(Label) as \f$s\f$ here.
The probability of a sequence \f$s\f$ of length \f$L\f$ is defined as:
\f$P(s) = (1/Z) \exp(a_{s_1} + b_{s_L}
+ \sum_{l=1}^L x_{s_l}
+ \sum_{l=2}^L w_{s_{l-1},s_l})\f$
where \f$Z\f$ is a normalization value so that the sum of \f$P(s)\f$ over
all possible sequences is \f$1\f$, and \f$x\f$ is the emission feature weight
denoted as $w$ here.
5. Denote Input(Label) as $s$ here.
The probability of a sequence $s$ of length $L$ is defined as:
$$P(s) = (1/Z) \exp(a_{s_1} + b_{s_L}
+ \sum_{l=1}^L x_{s_l}
+ \sum_{l=2}^L w_{s_{l-1},s_l})$$
where $Z$ is a normalization value so that the sum of $P(s)$ over
all possible sequences is 1, and $x$ is the emission feature weight
to the linear chain CRF.
Finally, the linear chain CRF operator outputs the logarithm of the conditional

@ -49,7 +49,7 @@ class NCCLInitOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Communicator",
"Create Communicator for communicating between gpus");
AddAttr<std::vector<int>>("gpus", "(vector<int>) GPU id lists");
AddAttr<int>("data_type",
AddAttr<int>("dtype",
"(int, default 5 (FP32)) "
"Output data type")
.SetDefault(framework::DataType::FP32);

@ -401,7 +401,7 @@ class RecurrentGradOp : public RecurrentBase {
auto &inside_tensor = cur_scope.FindVar(inside_grad_name)
->Get<framework::LoDTensor>();
framework::AttributeMap attrs;
attrs["data_type"] = framework::ToDataType(inside_tensor.type());
attrs["dtype"] = framework::ToDataType(inside_tensor.type());
attrs["shape"] = framework::vectorize2int(inside_tensor.dims());
attrs["value"] = 0.0f;

@ -62,7 +62,7 @@ class RNNMemoryHelperOpInfoMaker : public framework::OpProtoAndCheckerMaker {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "");
AddOutput("Out", "");
AddAttr<int>("data_type",
AddAttr<int>("dtype",
"(int, default 5 (FP32)) "
"Output data type")
.SetDefault(framework::DataType::FP32);
@ -95,7 +95,7 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase {
auto &in_var_tensor = in_var->Get<framework::LoDTensor>();
framework::AttributeMap attrs;
attrs["data_type"] = framework::ToDataType(in_var_tensor.type());
attrs["dtype"] = framework::ToDataType(in_var_tensor.type());
attrs["shape"] = framework::vectorize2int(in_var_tensor.dims());
attrs["value"] = 0.0f;
@ -121,7 +121,7 @@ class RNNMemoryHelperGradOpInfoMaker
AddInput("X", "");
AddInput("Out", "");
AddOutput(framework::GradVarName("X"), "");
AddAttr<int>("data_type",
AddAttr<int>("dtype",
"(int, default 5 (FP32)) "
"Output data type")
.SetDefault(framework::DataType::FP32);

@ -59,7 +59,7 @@ Then the ratio of the exponential of the given dimension and the sum of
exponential values of all the other dimensions is the output of the softmax
operator.
For each row `i` and each column `j` in input X, we have:
For each row $i$ and each column $j$ in Input(X), we have:
$$Y[i, j] = \frac{\exp(X[i, j])}{\sum_k \exp(X[i, k])}$$
)DOC");

@ -67,15 +67,15 @@ The equation is as follows:
1) Hard label (one-hot label, so every sample has exactly one class)
$$Loss_j = \f$ -\text{Logit}_{Label_j} +
$$Loss_j = -\text{Logit}_{Label_j} +
\log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right),
j = 1, ..., K $\f$$
j = 1,..., K$$
2) Soft label (each sample can have a distribution over all classes)
$$Loss_j = \f$ -\sum_{i=0}^{K}\text{Label}_i\left(\text{Logit}_i -
$$Loss_j = -\sum_{i=0}^{K}\text{Label}_i \left(\text{Logit}_i -
\log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right)\right),
j = 1,...,K $\f$$
j = 1,...,K$$
)DOC");
}

@ -66,7 +66,7 @@ class UniformRandomOp : public framework::OperatorWithKernel {
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
static_cast<framework::DataType>(ctx.Attr<int>("data_type")),
static_cast<framework::DataType>(ctx.Attr<int>("dtype")),
ctx.device_context());
}
};
@ -99,7 +99,7 @@ uniform distribution.
"Random seed used for generating samples. "
"0 means use a seed generated by the system.")
.SetDefault(0);
AddAttr<int>("data_type", "(int, default 5(FP32)) Output tensor data type")
AddAttr<int>("dtype", "(int, default 5(FP32)) Output tensor data type")
.SetDefault(framework::DataType::FP32);
}
};

@ -180,7 +180,7 @@ class WhileGradOp : public framework::OperatorBase {
if (var->IsType<LoDTensor>()) {
auto &inside_tensor = var->Get<framework::LoDTensor>();
framework::AttributeMap attrs;
attrs["data_type"] = framework::ToDataType(inside_tensor.type());
attrs["dtype"] = framework::ToDataType(inside_tensor.type());
attrs["shape"] = framework::vectorize2int(inside_tensor.dims());
attrs["value"] = 0.0f;

@ -1,15 +1,20 @@
cc_library(cpu_info SRCS cpu_info.cc DEPS gflags glog)
if(WITH_GPU)
cc_library(enforce SRCS enforce.cc DEPS nccl)
else()
cc_library(enforce SRCS enforce.cc)
endif()
cc_test(enforce_test SRCS enforce_test.cc DEPS stringpiece enforce)
cc_library(cpu_info SRCS cpu_info.cc DEPS gflags glog enforce)
cc_test(cpu_info_test SRCS cpu_info_test.cc DEPS cpu_info)
nv_library(gpu_info SRCS gpu_info.cc DEPS gflags glog)
nv_library(gpu_info SRCS gpu_info.cc DEPS gflags glog enforce)
cc_library(place SRCS place.cc)
cc_library(place SRCS place.cc DEPS enforce)
cc_test(place_test SRCS place_test.cc DEPS place glog gflags)
add_subdirectory(dynload)
cc_test(enforce_test SRCS enforce_test.cc DEPS stringpiece)
IF(WITH_GPU)
set(GPU_CTX_DEPS dynload_cuda dynamic_loader)
ELSE()

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save