Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into fix-dygraph-checkpoint

shanyi15-patch-1
lujun 6 years ago
commit a0d362efb7

@@ -337,13 +337,13 @@ paddle.fluid.layers.reciprocal (ArgSpec(args=['x', 'name'], varargs=None, keywor
paddle.fluid.layers.square (ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '48dfb45d773dbc30126c3a7f777de5ee'))
paddle.fluid.layers.softplus (ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '459c5781e9d1dd88283b7c5769d7872a'))
paddle.fluid.layers.softsign (ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '80846bcd4bd457207457a6d5411f4148'))
paddle.fluid.layers.uniform_random (ArgSpec(args=['shape', 'dtype', 'min', 'max', 'seed'], varargs=None, keywords=None, defaults=('float32', -1.0, 1.0, 0)), ('document', '308b619af849caa82bbc31e897f5e641'))
paddle.fluid.layers.uniform_random (ArgSpec(args=['shape', 'dtype', 'min', 'max', 'seed'], varargs=None, keywords=None, defaults=('float32', -1.0, 1.0, 0)), ('document', 'a8c4e972b7d6742c838a37abf407ed9a'))
paddle.fluid.layers.hard_shrink (ArgSpec(args=['x', 'threshold'], varargs=None, keywords=None, defaults=(None,)), ('document', 'c142f5884f3255e0d6075c286bbd531e'))
paddle.fluid.layers.cumsum (ArgSpec(args=['x', 'axis', 'exclusive', 'reverse'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '944d7c03057f5fc88bc78acd4d82f926'))
paddle.fluid.layers.thresholded_relu (ArgSpec(args=['x', 'threshold'], varargs=None, keywords=None, defaults=(None,)), ('document', '90566ea449ea4c681435546e2f70610a'))
paddle.fluid.layers.prior_box (ArgSpec(args=['input', 'image', 'min_sizes', 'max_sizes', 'aspect_ratios', 'variance', 'flip', 'clip', 'steps', 'offset', 'name', 'min_max_aspect_ratios_order'], varargs=None, keywords=None, defaults=(None, [1.0], [0.1, 0.1, 0.2, 0.2], False, False, [0.0, 0.0], 0.5, None, False)), ('document', '14cac0ee643fa6e026ad82aeeee75bd8'))
paddle.fluid.layers.density_prior_box (ArgSpec(args=['input', 'image', 'densities', 'fixed_sizes', 'fixed_ratios', 'variance', 'clip', 'steps', 'offset', 'flatten_to_2d', 'name'], varargs=None, keywords=None, defaults=(None, None, None, [0.1, 0.1, 0.2, 0.2], False, [0.0, 0.0], 0.5, False, None)), ('document', 'a0d762bb08de9ce93bc780aa57cd5cd9'))
paddle.fluid.layers.multi_box_head (ArgSpec(args=['inputs', 'image', 'base_size', 'num_classes', 'aspect_ratios', 'min_ratio', 'max_ratio', 'min_sizes', 'max_sizes', 'steps', 'step_w', 'step_h', 'offset', 'variance', 'flip', 'clip', 'kernel_size', 'pad', 'stride', 'name', 'min_max_aspect_ratios_order'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None, None, 0.5, [0.1, 0.1, 0.2, 0.2], True, False, 1, 0, 1, None, False)), ('document', 'a6ab47a2fe681e52fabb7057ddf0efdd'))
paddle.fluid.layers.multi_box_head (ArgSpec(args=['inputs', 'image', 'base_size', 'num_classes', 'aspect_ratios', 'min_ratio', 'max_ratio', 'min_sizes', 'max_sizes', 'steps', 'step_w', 'step_h', 'offset', 'variance', 'flip', 'clip', 'kernel_size', 'pad', 'stride', 'name', 'min_max_aspect_ratios_order'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None, None, 0.5, [0.1, 0.1, 0.2, 0.2], True, False, 1, 0, 1, None, False)), ('document', 'fe9afaee481dd09f28866df22756466f'))
paddle.fluid.layers.bipartite_match (ArgSpec(args=['dist_matrix', 'match_type', 'dist_threshold', 'name'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '3ddb9b966f193900193a95a3df77c3c1'))
paddle.fluid.layers.target_assign (ArgSpec(args=['input', 'matched_indices', 'negative_indices', 'mismatch_value', 'name'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', 'c0b334f917828f95056f6ebe10907b1c'))
paddle.fluid.layers.detection_output (ArgSpec(args=['loc', 'scores', 'prior_box', 'prior_box_var', 'background_label', 'nms_threshold', 'nms_top_k', 'keep_top_k', 'score_threshold', 'nms_eta'], varargs=None, keywords=None, defaults=(0, 0.3, 400, 200, 0.01, 1.0)), ('document', 'c33093a82a46e3091e789e5572588db1'))
@@ -358,7 +358,7 @@ paddle.fluid.layers.generate_mask_labels (ArgSpec(args=['im_info', 'gt_classes',
paddle.fluid.layers.iou_similarity (ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '587845f60c5d97ffdf2dfd21da52eca1'))
paddle.fluid.layers.box_coder (ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name', 'axis'], varargs=None, keywords=None, defaults=('encode_center_size', True, None, 0)), ('document', '032d0f4b7d8f6235ee5d91e473344f0e'))
paddle.fluid.layers.polygon_box_transform (ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '0e5ac2507723a0b5adec473f9556799b'))
paddle.fluid.layers.yolov3_loss (ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample_ratio', 'gtscore', 'use_label_smooth', 'name'], varargs=None, keywords=None, defaults=(None, True, None)), ('document', '57fa96922e42db8f064c3fb77f2255e8'))
paddle.fluid.layers.yolov3_loss (ArgSpec(args=['x', 'gt_box', 'gt_label', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample_ratio', 'gt_score', 'use_label_smooth', 'name'], varargs=None, keywords=None, defaults=(None, True, None)), ('document', '4d170807a13d33925d1049d2892832bf'))
paddle.fluid.layers.yolo_box (ArgSpec(args=['x', 'img_size', 'anchors', 'class_num', 'conf_thresh', 'downsample_ratio', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '5566169a5ab993d177792c023c7fb340'))
paddle.fluid.layers.box_clip (ArgSpec(args=['input', 'im_info', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '397e9e02b451d99c56e20f268fa03f2e'))
paddle.fluid.layers.multiclass_nms (ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None)), ('document', 'ca7d1107b6c5d2d6d8221039a220fde0'))

@@ -64,12 +64,19 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
auto c_dims = ctx->GetInputDim("C0");
PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2.");
PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D);
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D);
}
if (ctx->HasInput("H0")) {
auto h_dims = ctx->GetInputDim("H0");
PADDLE_ENFORCE(h_dims == c_dims,
"The dimension of Input(H0) and Input(C0) "
"should be the same.");
PADDLE_ENFORCE_EQ(h_dims.size(), 2UL, "Input(H0)'s rank must be 2.");
if (ctx->IsRuntime() ||
(framework::product(c_dims) > 0 && framework::product(h_dims) > 0)) {
PADDLE_ENFORCE(h_dims == c_dims,
"The dimension of Input(H0) and Input(C0) "
"should be the same.");
}
}
auto atten_w_dims = ctx->GetInputDim("AttentionWeight");
@@ -79,6 +86,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
"AttentionWeight shapes must be (%d + %d) * 1.", M, D);
PADDLE_ENFORCE_EQ(atten_w_dims[1], 1,
"AttentionWeight shapes must be (%d + %d) * 1.", M, D);
if (ctx->HasInput("AttentionBias")) {
auto atten_b_dims = ctx->GetInputDim("AttentionBias");
PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2,

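A note on the recurring pattern in this commit: InferShape also runs while the graph is being compiled, where a tensor dimension can still be the placeholder -1, so shape checks that compare concrete values are now guarded by ctx->IsRuntime() or by first confirming framework::product(dims) > 0. A minimal sketch of the guard, using placeholder inputs "A" and "B" rather than any real op:

// Sketch only: defer the shape-equality check while dims are still unknown.
// framework::product(dims) <= 0 at compile time whenever any dim is -1.
void InferShapeSketch(framework::InferShapeContext* ctx) {
  auto a_dims = ctx->GetInputDim("A");
  auto b_dims = ctx->GetInputDim("B");
  if (ctx->IsRuntime() ||
      (framework::product(a_dims) > 0 && framework::product(b_dims) > 0)) {
    PADDLE_ENFORCE_EQ(a_dims, b_dims,
                      "Input(A) and Input(B) should have the same shape.");
  }
}
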
@@ -32,10 +32,14 @@ class BprLossOp : public framework::OperatorWithKernel {
int rank = x_dims.size();
PADDLE_ENFORCE_EQ(rank, label_dims.size(),
"Input(X) and Input(Label) shall have the same rank.");
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
framework::slice_ddim(label_dims, 0, rank - 1),
"Input(X) and Input(Label) shall have the same shape "
"except the last dimension.");
if (ctx->IsRuntime() || (framework::product(x_dims) > 0 &&
framework::product(label_dims) > 0)) {
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
framework::slice_ddim(label_dims, 0, rank - 1),
"Input(X) and Input(Label) shall have the same shape "
"except the last dimension.");
}
auto y_dims = x_dims;
y_dims[rank - 1] = 1;

@@ -81,8 +81,10 @@ class WriteToArrayInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInput("I"), "Must set the subscript index");
PADDLE_ENFORCE_EQ(framework::product(context->GetInputDim("I")), 1,
"The number of element of subscript index must be 1");
if (context->IsRuntime()) {
PADDLE_ENFORCE_EQ(framework::product(context->GetInputDim("I")), 1,
"The number of element of subscript index must be 1");
}
if (!context->HasInput("X")) {
return;
}

@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/data_norm_op.h"
#include <memory>
#include <string>
#include "paddle/fluid/framework/data_layout.h"
#ifdef PADDLE_WITH_MKLDNN
@@ -65,9 +66,11 @@ class DataNormOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSize").size(), 1UL);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSum").size(), 1UL);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSquareSum").size(), 1UL);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSize")[0], C);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSum")[0], C);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSquareSum")[0], C);
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSize")[0], C);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSum")[0], C);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSquareSum")[0], C);
}
ctx->SetOutputDim("Y", x_dims);
ctx->SetOutputDim("Means", {C});

@@ -171,8 +171,8 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
The output of previous network is in shape [N, C, H, W], while H and W
should be the same, H and W specify the grid size, each grid point predict
given number boxes, this given number, which following will be represented as S,
is specified by the number of anchors, In the second dimension(the channel
given number of bounding boxes; this given number, hereafter denoted as S,
is specified by the number of anchor clusters in each scale. In the second dimension (the channel
dimension), C should be equal to S * (class_num + 5), class_num is the object
category number of the source dataset (such as 80 in the COCO dataset), so in the
second(channel) dimension, apart from 4 box location coordinates x, y, w, h,
@@ -203,7 +203,7 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
thresh, the confidence score loss of this anchor box will be ignored.
Therefore, the yolov3 loss consists of three major parts: box location loss,
confidence score loss, and classification loss. The L2 loss is used for
confidence score loss, and classification loss. The L1 loss is used for
box coordinates (w, h), and sigmoid cross entropy loss is used for box
coordinates (x, y), confidence score loss and classification loss.

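In symbols, the loss composition the docstring describes (per-term weights omitted; this summarizes the documentation, not the kernel's exact arithmetic):

L = L_{x,y} + L_{w,h} + L_{conf} + L_{cls}

where L_{w,h} is an L1 loss on the predicted box sizes (w, h), and the other three terms are sigmoid cross-entropy losses on the box centers (x, y), the confidence score, and the classification.
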
@@ -31,13 +31,18 @@ class HuberLossOp : public framework::OperatorWithKernel {
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_EQ(x_dims, y_dims);
PADDLE_ENFORCE_EQ(x_dims.size(), 2,
"The rank of Input(X) must be 2 and the shape is "
"[batch_size, 1].");
PADDLE_ENFORCE_EQ(x_dims[1], 1,
"Each row of Input(X) contains a real value, "
"so the 2nd dimension of Input(X) must be 1.");
if (ctx->IsRuntime() ||
(framework::product(x_dims) > 0 && framework::product(y_dims) > 0)) {
PADDLE_ENFORCE_EQ(x_dims, y_dims, "Shape of X and Y should be same");
}
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(x_dims[1], 1,
"Each row of Input(X) contains a real value, "
"so the 2nd dimension of Input(X) must be 1.");
}
ctx->SetOutputDim("Residual", x_dims);
ctx->SetOutputDim("Out", {x_dims[0], 1});

@@ -46,11 +46,18 @@ class LayerNormOp : public framework::OperatorWithKernel {
int right = static_cast<int>(matrix_dim[1]);
if (ctx->HasInput("Scale")) {
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale").size(), 1);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], right);
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], right,
"scale should with right");
}
}
if (ctx->HasInput("Bias")) {
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias").size(), 1);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias")[0], right);
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias")[0], right,
"bias should with right");
}
}
ctx->SetOutputDim("Y", ctx->GetInputDim("X"));

@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/linear_chain_crf_op.h"
#include <memory>
namespace paddle {
@@ -152,12 +153,19 @@ class LinearChainCRFOp : public framework::OperatorWithKernel {
auto transition_dims = ctx->GetInputDim("Transition");
PADDLE_ENFORCE_EQ(transition_dims.size(), 2,
"The Input(Transition) should be a 2-D tensor.");
PADDLE_ENFORCE_EQ(
transition_dims[0] - 2, transition_dims[1],
"An invalid dimension for the Input(Transition), which should "
"be a 2-D tensor with shape [(D + 2) x D].");
PADDLE_ENFORCE_EQ(
emission_dims[1], transition_dims[1],
bool check = true;
if ((!ctx->IsRuntime()) &&
(transition_dims[0] <= 0 || transition_dims[1] <= 0)) {
check = false;
}
if (check) {
PADDLE_ENFORCE_EQ(
transition_dims[0] - 2, transition_dims[1],
"An invalid dimension for the Input(Transition), which should "
"be a 2-D tensor with shape [(D + 2) x D].");
}
PADDLE_INFERSHAPE_ENFORCE_EQ(
ctx, emission_dims[1], transition_dims[1],
"The 2nd dimension of the Input(Emission) and the Input(Transition) "
"should be equal to the tag number.");
@@ -165,8 +173,8 @@ class LinearChainCRFOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(label_dims.size() == 2UL && label_dims[1] == 1UL,
"The Input(Label) should be a 2-D tensor with the 2nd "
"dimensions fixed to 1.");
PADDLE_ENFORCE_EQ(
emission_dims[0], label_dims[0],
PADDLE_INFERSHAPE_ENFORCE_EQ(
ctx, emission_dims[0], label_dims[0],
"The height of Input(Emission) and the height of Input(Label) "
"should be the same.");
@@ -211,12 +219,19 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
auto transition_exps_dims = ctx->GetInputDim("TransitionExps");
PADDLE_ENFORCE_EQ(transition_exps_dims.size(), 2,
"The Input(TransitionExps) should be a 2-D tensor.");
PADDLE_ENFORCE_EQ(
transition_exps_dims[0] - 2, transition_exps_dims[1],
"An invalid dimension for the Input(TransitionExps), which should "
"be a 2-D tensor with shape [(D + 2) x D].");
PADDLE_ENFORCE_EQ(
emission_exps_dims[1], transition_exps_dims[1],
bool check = true;
if ((!ctx->IsRuntime()) &&
(transition_exps_dims[0] <= 0 || transition_exps_dims[1] <= 0)) {
check = false;
}
if (check) {
PADDLE_ENFORCE_EQ(
transition_exps_dims[0] - 2, transition_exps_dims[1],
"An invalid dimension for the Input(TransitionExps), which should "
"be a 2-D tensor with shape [(D + 2) x D].");
}
PADDLE_INFERSHAPE_ENFORCE_EQ(
ctx, emission_exps_dims[1], transition_exps_dims[1],
"The 2nd dimension of the Input(EmissionExps) and the "
"Input(TransitionExps) should be equal to the tag number.");
@@ -224,8 +239,8 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(label_dims.size() == 2UL && label_dims[1] == 1UL,
"The Input(Label) should be a 2-D tensor with the 2nd "
"dimensions fixed to 1.");
PADDLE_ENFORCE_EQ(
emission_exps_dims[0], label_dims[0],
PADDLE_INFERSHAPE_ENFORCE_EQ(
ctx, emission_exps_dims[0], label_dims[0],
"The height of Input(EmissionExps) and the height of Input(Label) "
"should be the same.");

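PADDLE_INFERSHAPE_ENFORCE_EQ, introduced above and reused in the hunks below, folds the same runtime guard into a macro. Its real definition lives in the framework headers; a plausible sketch of its behavior, assuming it skips the comparison when either operand is still the unknown -1 placeholder outside runtime:

// Hypothetical expansion for illustration only -- not the actual definition.
#define PADDLE_INFERSHAPE_ENFORCE_EQ(CTX, VAL1, VAL2, ...)     \
  do {                                                         \
    if ((CTX)->IsRuntime() || ((VAL1) >= 0 && (VAL2) >= 0)) {  \
      PADDLE_ENFORCE_EQ(VAL1, VAL2, __VA_ARGS__);              \
    }                                                          \
  } while (0)
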
@@ -41,10 +41,11 @@ class AccuracyOp : public framework::OperatorWithKernel {
// it's the output of topk.
PADDLE_ENFORCE_EQ(label_dim.size(), 2, "label's rank must be 2.");
PADDLE_ENFORCE_EQ(label_dim[1], 1, "label's second dimension must be 1");
PADDLE_ENFORCE_EQ(inference_dim[0], label_dim[0],
"the inference tensor's num_rows must be"
" the same as label.");
PADDLE_INFERSHAPE_ENFORCE_EQ(ctx, label_dim[1], 1,
"label's second dimension must be 1");
PADDLE_INFERSHAPE_ENFORCE_EQ(ctx, inference_dim[0], label_dim[0],
"the inference tensor's num_rows must be"
" the same as label.");
ctx->SetOutputDim("Accuracy", {1});
ctx->SetOutputDim("Correct", {1});

@@ -28,12 +28,13 @@ class AucOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasInput("Label"),
"Input of Label should not be null.");
auto predict_width = ctx->GetInputDim("Predict")[1];
PADDLE_ENFORCE_EQ(predict_width, 2, "Only support binary classification");
PADDLE_INFERSHAPE_ENFORCE_EQ(ctx, predict_width, 2,
"Only support binary classification");
auto predict_height = ctx->GetInputDim("Predict")[0];
auto label_height = ctx->GetInputDim("Label")[0];
PADDLE_ENFORCE_EQ(predict_height, label_height,
"Out and Label should have same height.");
PADDLE_INFERSHAPE_ENFORCE_EQ(ctx, predict_height, label_height,
"Out and Label should have same height.");
int num_pred_buckets = ctx->Attrs().Get<int>("num_thresholds") + 1;
int slide_steps = ctx->Attrs().Get<int>("slide_steps");

@@ -40,30 +40,40 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
auto max_probs_dims = ctx->GetInputDim("MaxProbs");
auto labels_dims = ctx->GetInputDim("Labels");
PADDLE_ENFORCE_EQ(max_probs_dims[1], 1,
"Each instance contains one max probability, so the "
"shape of Input(MaxProbs) should be [batch_size, 1].");
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Indices"), max_probs_dims,
"The shape of Input(Indices) should be [batch_size, 1].");
PADDLE_ENFORCE_EQ(max_probs_dims[0], labels_dims[0],
"The 1st dimension of Input(MaxProbs) and "
"Input(Labels) both are batch_size and the shape should "
"be the same.");
PADDLE_ENFORCE_EQ(labels_dims[1], 1,
"The 2nd dimension of Input(Labels) contains instance "
"label and the shape should be equal to 1.");
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(max_probs_dims[1], 1,
"Each instance contains one max probability, so the "
"shape of Input(MaxProbs) should be [batch_size, 1].");
PADDLE_ENFORCE_EQ(
ctx->GetInputDim("Indices"), max_probs_dims,
"The shape of Input(Indices) should bes same with max_probs_dims");
PADDLE_ENFORCE_EQ(
max_probs_dims[0], labels_dims[0],
"The 1st dimension of Input(MaxProbs) and "
"Input(Labels) both are batch_size and the shape should "
"be the same.");
PADDLE_ENFORCE_EQ(labels_dims[1], 1,
"The 2nd dimension of Input(Labels) contains instance "
"label and the shape should be equal to 1.");
}
if (ctx->HasInput("Weights")) {
auto weights_dims = ctx->GetInputDim("Weights");
PADDLE_ENFORCE_EQ(weights_dims,
framework::make_ddim({max_probs_dims[0], 1}),
"The shape of Input(Weights) should be "
"[batch_size, 1].");
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(weights_dims,
framework::make_ddim({max_probs_dims[0], 1}),
"The shape of Input(Weights) should be "
"[batch_size, 1].");
}
}
if (ctx->HasInput("StatesInfo")) {
auto states_dims = ctx->GetInputDim("StatesInfo");
PADDLE_ENFORCE_EQ(states_dims, framework::make_ddim({cls_num, 4}),
"The shape of Input(StatesInfo) should be "
"[class_number, 4].");
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(states_dims, framework::make_ddim({cls_num, 4}),
"The shape of Input(StatesInfo) should be "
"[class_number, 4].");
}
}
// Layouts of BatchMetrics and AccumMetrics both are:

@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/minus_op.h"
#include <memory>
#include <string>
#include <vector>
@@ -38,9 +39,12 @@ class MinusOp : public framework::OperatorWithKernel {
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_EQ(
x_dims, y_dims,
"Minus operator must take two tensor with same num of elements");
if (ctx->IsRuntime() ||
(framework::product(x_dims) > 0 && framework::product(y_dims) > 0)) {
PADDLE_ENFORCE_EQ(
x_dims, y_dims,
"Minus operator must take two tensor with same num of elements");
}
ctx->SetOutputDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}

@@ -28,9 +28,16 @@ class ModifiedHuberLossOp : public framework::OperatorWithKernel {
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_EQ(x_dims, y_dims, "The shape of X and Y must be the same.");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "The tensor rank of X must be 2.");
PADDLE_ENFORCE_EQ(x_dims[1], 1, "The 2nd dimension of X must be 1.");
if (ctx->IsRuntime() ||
(framework::product(x_dims) > 0 && framework::product(y_dims) > 0)) {
PADDLE_ENFORCE_EQ(x_dims, y_dims,
"The shape of X and Y must be the same.");
}
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(x_dims[1], 1, "The 2nd dimension of X must be 1.");
}
ctx->SetOutputDim("IntermediateVal", x_dims);
ctx->SetOutputDim("Out", {x_dims[0], 1});
@@ -90,11 +97,13 @@ class ModifiedHuberLossGradOp : public framework::OperatorWithKernel {
auto intermediate_dims = ctx->GetInputDim("IntermediateVal");
auto out_grad_dims = ctx->GetInputDim(framework::GradVarName("Out"));
PADDLE_ENFORCE_EQ(
intermediate_dims, x_dims,
"The shape of X and intermediate value must be the same.");
PADDLE_ENFORCE_EQ(out_grad_dims, x_dims,
"The shape of Input(Out@Grad) and X must be the same.");
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(
intermediate_dims, x_dims,
"The shape of X and intermediate value must be the same.");
PADDLE_ENFORCE_EQ(out_grad_dims, x_dims,
"The shape of Input(Out@Grad) and X must be the same.");
}
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);

@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/pad_constant_like_op.h"
#include <memory>
namespace paddle {
namespace operators {
@@ -38,8 +39,16 @@ class PadConstantLikeOp : public framework::OperatorWithKernel {
"The dimention of X and Y should be the same.");
for (int i = 0; i < x_dim.size(); ++i) {
PADDLE_ENFORCE_GE(x_dim[i], y_dim[i]);
if ((!ctx->IsRuntime()) && ((x_dim[i] == -1) || (y_dim[i] == -1))) {
continue;
} else {
PADDLE_ENFORCE_GE(
x_dim[i], y_dim[i],
"expected X_dim[i] >= Y_dim[i], but received %d < %d for dim %d",
x_dim[i], y_dim[i], i);
}
}
ctx->SetOutputDim("Out", x_dim);
ctx->ShareLoD("X", /*->*/ "Out");
}
@@ -162,7 +171,14 @@ class PadConstantLikeOpGrad : public framework::OperatorWithKernel {
ctx->ShareLoD("Y", /*->*/ y_grad_name);
for (int i = 0; i < y_dim.size(); ++i) {
PADDLE_ENFORCE_GE(dout_dim[i], y_dim[i]);
if ((!ctx->IsRuntime()) && ((dout_dim[i] == -1) || (y_dim[i] == -1))) {
continue;
} else {
PADDLE_ENFORCE_GE(dout_dim[i], y_dim[i],
"expected Out_dim[i] >= Y_dim[i], but received %d "
"< %d for dim %d",
dout_dim[i], y_dim[i], i);
}
}
}
}

@@ -34,9 +34,16 @@ class PadOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(x_dim.size() * 2, int64_t(paddings.size()),
"Size of paddings should be equal to 2 * dimension size "
"of input tensor.");
for (size_t i = 0; i < paddings.size(); ++i) {
PADDLE_ENFORCE_GE(paddings[i], 0, "paddings should >= 0.");
}
std::vector<int64_t> out_dims(x_dim.size());
for (int i = 0; i < x_dim.size(); ++i) {
out_dims[i] = x_dim[i] + paddings[i * 2] + paddings[i * 2 + 1];
if ((!ctx->IsRuntime()) && (x_dim[i] == -1)) {
out_dims[i] = -1;
} else {
out_dims[i] = x_dim[i] + paddings[i * 2] + paddings[i * 2 + 1];
}
}
ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
if (out_dims[0] == x_dim[0]) {
@@ -100,18 +107,14 @@ class PadOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
auto& paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
for (int i = 0; i < dout_dims.size(); ++i) {
dout_dims[i] -= (paddings[i * 2] + paddings[i * 2 + 1]);
}
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
auto& paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
for (int i = 0; i < dout_dims.size(); ++i) {
dout_dims[i] -= (paddings[i * 2] + paddings[i * 2 + 1]);
if (ctx->IsRuntime() || (dout_dims[i] != -1)) {
dout_dims[i] -= (paddings[i * 2] + paddings[i * 2 + 1]);
}
}
ctx->SetOutputDim(x_grad_name, dout_dims);
}

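A second recurring move, visible in the two pad ops above and in sample_logits further down: when an input dim is still -1 at compile time, doing arithmetic on it would produce a misleading concrete value, so the op writes -1 through to the output dim and lets runtime resolve it. Condensed from the pad_op hunk:

// -1 in, -1 out: an unknown input extent stays unknown in the output shape.
std::vector<int64_t> out_dims(x_dim.size());
for (int i = 0; i < x_dim.size(); ++i) {
  if ((!ctx->IsRuntime()) && (x_dim[i] == -1)) {
    out_dims[i] = -1;  // resolved once real shapes are available at runtime
  } else {
    out_dims[i] = x_dim[i] + paddings[i * 2] + paddings[i * 2 + 1];
  }
}
ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
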
@@ -61,23 +61,31 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel {
auto query_dim = ctx->GetInputDim("QueryID");
PADDLE_ENFORCE_EQ(score_dim.size(), 2, "Score should be a 2-D tensor.");
PADDLE_ENFORCE_EQ(label_dim.size(), 2, "Label should be a 2-D tensor.");
PADDLE_ENFORCE_EQ(
label_dim[0], score_dim[0],
"Tensor Score and Label should have the same height (batch size).");
PADDLE_ENFORCE_EQ(label_dim[1], 1,
"The width of Label should be 1, i.e. each item should "
"have a scalar label.");
PADDLE_ENFORCE(query_dim == label_dim,
"QueryID should have the same shape as Label.");
if (ctx->HasInput("Weight")) {
PADDLE_ENFORCE(ctx->GetInputDim("Weight") == label_dim,
"Weight should have the same shape as Label.");
if (ctx->IsRuntime() ||
(score_dim[0] > 0 && label_dim[0] > 0 && query_dim[0] > 0)) {
PADDLE_ENFORCE_EQ(
label_dim[0], score_dim[0],
"Tensor Score and Label should have the same height (batch size).");
PADDLE_ENFORCE_EQ(label_dim[1], 1,
"The width of Label should be 1, i.e. each item should "
"have a scalar label.");
PADDLE_ENFORCE(query_dim == label_dim,
"QueryID should have the same shape as Label.");
if (ctx->HasInput("Weight")) {
PADDLE_ENFORCE(ctx->GetInputDim("Weight") == label_dim,
"Weight should have the same shape as Label.");
}
int column = ctx->Attrs().Get<int>("column");
auto depth = score_dim[1];
PADDLE_ENFORCE(column < depth && column >= -depth,
"Attribute column should be in the range of [-%l, %l)",
depth, depth);
}
int column = ctx->Attrs().Get<int>("column");
auto depth = score_dim[1];
PADDLE_ENFORCE(column < depth && column >= -depth,
"Attribute column should be in the range of [-%l, %l)",
depth, depth);
ctx->SetOutputDim("PositivePair", scalar_dim);
ctx->SetOutputDim("NegativePair", scalar_dim);

@@ -37,9 +37,11 @@ class ROIAlignOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(rois_dims.size() == 2,
"ROIs should be a 2-D LoDTensor of shape (num_rois, 4)"
"given as [[x1, y1, x2, y2], ...].");
PADDLE_ENFORCE(rois_dims[1] == 4,
"ROIs should be a 2-D LoDTensor of shape (num_rois, 4)"
"given as [[x1, y1, x2, y2], ...].");
if (ctx->IsRuntime()) {
PADDLE_ENFORCE(rois_dims[1] == 4,
"ROIs should be a 2-D LoDTensor of shape (num_rois, 4)"
"given as [[x1, y1, x2, y2], ...].");
}
int pooled_height = ctx->Attrs().Get<int>("pooled_height");
int pooled_width = ctx->Attrs().Get<int>("pooled_width");
float spatial_scale = ctx->Attrs().Get<float>("spatial_scale");

@@ -11,7 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/sample_logits_op.h"
#include <memory>
#include "paddle/fluid/operators/math/sample_prob.h"
@@ -141,7 +140,10 @@ class SampleLogitsOp : public framework::OperatorWithKernel {
"The labels should be a 2-D tensor.");
const int num_samples = ctx->Attrs().Get<int>("num_samples");
const int num_sampled_classes = labels_dims[1] + num_samples;
int num_sampled_classes = labels_dims[1] + num_samples;
if ((!ctx->IsRuntime()) && labels_dims[1] <= 0) {
num_sampled_classes = -1;
}
ctx->SetOutputDim("Samples", {logits_dims[0], num_sampled_classes});
ctx->SetOutputDim("Probabilities", {logits_dims[0], num_sampled_classes});
ctx->SetOutputDim("SampledLogits", {logits_dims[0], num_sampled_classes});

@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/smooth_l1_loss_op.h"
#include <memory>
namespace paddle {
namespace operators {
@@ -27,15 +28,39 @@ class SmoothL1LossOp : public framework::OperatorWithKernel {
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_EQ(x_dims, y_dims);
bool check = true;
if ((!ctx->IsRuntime()) &&
(framework::product(x_dims) <= 0 || framework::product(y_dims) <= 0)) {
check = false;
}
if (check) {
PADDLE_ENFORCE_EQ(x_dims, y_dims);
}
PADDLE_ENFORCE_GE(x_dims.size(), 2,
"The tensor rank of Input(X) should not be less than 2.");
if (ctx->HasInput("InsideWeight")) {
PADDLE_ENFORCE(ctx->HasInput("OutsideWeight"),
"If weights are provided, must specify both "
"inside and outside weights.");
PADDLE_ENFORCE_EQ(ctx->GetInputDim("InsideWeight"), x_dims);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("OutsideWeight"), x_dims);
auto dims = ctx->GetInputDim("InsideWeight");
bool check = true;
if ((!ctx->IsRuntime()) &&
(framework::product(dims) <= 0 || framework::product(x_dims) <= 0)) {
check = false;
}
if (check) {
PADDLE_ENFORCE_EQ(dims, x_dims);
}
dims = ctx->GetInputDim("OutsideWeight");
check = true;
if ((!ctx->IsRuntime()) &&
(framework::product(dims) <= 0 || framework::product(x_dims) <= 0)) {
check = false;
}
if (check) {
PADDLE_ENFORCE_EQ(dims, x_dims);
}
}
ctx->SetOutputDim("Diff", x_dims);
@@ -110,11 +135,11 @@ class SmoothL1LossGradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_GE(out_dims.size(), 2,
"The tensor rank of Input(Out@Grad) should be 2.");
PADDLE_ENFORCE_EQ(out_dims[0], in_dims[0],
"The 1st dimension of Input(Out@Grad) must be "
"same as input.");
PADDLE_ENFORCE_EQ(out_dims[1], 1,
"The 2nd dimension of Input(Out@Grad) must be 1.");
PADDLE_INFERSHAPE_ENFORCE_EQ(ctx, out_dims[0], in_dims[0],
"The 1st dimension of Input(Out@Grad) must be "
"same as input.");
PADDLE_INFERSHAPE_ENFORCE_EQ(
ctx, out_dims[1], 1, "The 2nd dimension of Input(Out@Grad) must be 1.");
auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y");

@@ -106,24 +106,36 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
auto logits_dims = ctx->GetInputDim("Logits");
auto labels_dims = ctx->GetInputDim("Label");
int rank = logits_dims.size();
PADDLE_ENFORCE_EQ(
logits_dims.size(), 2UL,
"The input of softmax_with_cross_entropy should be a 2-D tensor.");
PADDLE_ENFORCE_EQ(labels_dims.size(), 2UL,
"The labels should be a 2-D tensor.");
rank, labels_dims.size(),
"Input(logits) and Input(Label) shall have the same rank.");
bool check = ctx->IsRuntime() || (framework::product(logits_dims) > 0 &&
framework::product(labels_dims) > 0);
if (check) {
PADDLE_ENFORCE_EQ(framework::slice_ddim(logits_dims, 0, rank - 1),
framework::slice_ddim(labels_dims, 0, rank - 1),
"Input(X) and Input(Label) shall have the same shape "
"except the last dimension.");
}
if (ctx->Attrs().Get<bool>("soft_label")) {
PADDLE_ENFORCE_EQ(logits_dims[1], labels_dims[1],
"If Attr(soft_label) == true, the 2nd dimension of "
"Input(X) and Input(Label) should be equal.");
if (check) {
PADDLE_ENFORCE_EQ(logits_dims[rank - 1], labels_dims[rank - 1],
"If Attr(soft_label) == true, the last dimension of "
"Input(X) and Input(Label) should be equal.");
}
} else {
PADDLE_ENFORCE_EQ(labels_dims[1], 1UL,
"If Attr(soft_label) == false, the 2nd dimension of "
PADDLE_ENFORCE_EQ(labels_dims[rank - 1], 1UL,
"If Attr(softLabel) == false, the last dimension of "
"Input(Label) should be 1.");
}
ctx->SetOutputDim("Softmax", logits_dims);
ctx->SetOutputDim("Loss", {logits_dims[0], 1});
auto loss_dims = logits_dims;
loss_dims[rank - 1] = 1;
ctx->SetOutputDim("Loss", loss_dims);
ctx->ShareLoD("Logits", /*->*/ "Softmax");
ctx->ShareLoD("Logits", /*->*/ "Loss");
@@ -152,16 +164,33 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
auto softmax_dims = ctx->GetInputDim("Softmax");
auto labels_dims = ctx->GetInputDim("Label");
PADDLE_ENFORCE_EQ(labels_dims.size(), 2UL,
"The labels should be a 2-D tensor.");
int rank = softmax_dims.size();
PADDLE_ENFORCE_EQ(
rank, labels_dims.size(),
"Input(logits) and Input(Label) shall have the same rank.");
bool check = true;
if ((!ctx->IsRuntime()) && (framework::product(softmax_dims) <= 0 ||
framework::product(labels_dims) <= 0)) {
check = false;
}
if (check) {
PADDLE_ENFORCE_EQ(
framework::slice_ddim(softmax_dims, 0, rank - 1),
framework::slice_ddim(labels_dims, 0, rank - 1),
"Input(Softmax) and Input(Label) shall have the same shape "
"except the last dimension.");
}
if (ctx->Attrs().Get<bool>("soft_label")) {
PADDLE_ENFORCE_EQ(softmax_dims[1], labels_dims[1],
"When Attr(soft_label) == true, the 2nd dimension of "
"Input(X) and Input(Label) should be equal.");
if (check) {
PADDLE_ENFORCE_EQ(softmax_dims[rank - 1], labels_dims[rank - 1],
"If Attr(soft_label) == true, the last dimension of "
"Input( Softmax) and Input(Label) should be equal.");
}
} else {
PADDLE_ENFORCE_EQ(labels_dims[1], 1UL,
"When Attr(soft_label) == false, the 2nd dimension of "
PADDLE_ENFORCE_EQ(labels_dims[rank - 1], 1UL,
"If Attr(softLabel) == false, the last dimension of "
"Input(Label) should be 1.");
}

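The op.cc hunks above generalize softmax_with_cross_entropy from 2-D inputs to arbitrary rank: the last dimension is treated as the class dimension, and every leading dimension of Logits and Label must agree. A sketch of the resulting shape contract (illustrative, condensed from the checks above):

// Rank-R logits: [d0, d1, ..., d_{R-2}, class_num]
//   soft_label == true  -> Label: [d0, ..., d_{R-2}, class_num]
//   soft_label == false -> Label: [d0, ..., d_{R-2}, 1]
// slice_ddim(dims, 0, rank - 1) extracts [d0, ..., d_{R-2}] for comparison,
// and Loss keeps the leading dims with a single value per sample:
auto loss_dims = logits_dims;
loss_dims[rank - 1] = 1;
ctx->SetOutputDim("Loss", loss_dims);
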
@@ -400,9 +400,15 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
auto soft_label = context.Attr<bool>("soft_label");
auto ignore_index = context.Attr<int>("ignore_index");
int rank = logits->dims().size();
if (soft_label) {
int batch_size = logits->dims()[0];
int feature_size = logits->dims()[1];
int batch_size = 1;
for (int i = 0; i < rank - 1; ++i) {
batch_size *= logits->dims()[i];
}
int feature_size = logits->dims()[rank - 1];
auto* logits_data = logits->data<T>();
auto* labels_data = labels->data<T>();
SoftmaxWithCrossEntropyFusedKernel(
@@ -410,14 +416,23 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
feature_size, context.cuda_device_context().stream());
} else {
if (!context.Attr<bool>("numeric_stable_mode")) {
math::SoftmaxCUDNNFunctor<T>()(context.cuda_device_context(), logits,
softmax);
// reshape to 2d
Tensor logits_2d = framework::ReshapeToMatrix(*logits, rank - 1);
Tensor softmax_2d = framework::ReshapeToMatrix(*softmax, rank - 1);
Tensor loss_2d = framework::ReshapeToMatrix(*loss, rank - 1);
Tensor labels_2d = framework::ReshapeToMatrix(*labels, rank - 1);
math::SoftmaxCUDNNFunctor<T>()(context.cuda_device_context(),
&logits_2d, &softmax_2d);
math::CrossEntropyFunctor<platform::CUDADeviceContext, T>()(
context.cuda_device_context(), loss, softmax, labels, false,
ignore_index);
context.cuda_device_context(), &loss_2d, &softmax_2d, &labels_2d,
false, ignore_index);
} else {
int batch_size = logits->dims()[0];
int feature_size = logits->dims()[1];
int batch_size = 1;
for (int i = 0; i < rank - 1; ++i) {
batch_size *= logits->dims()[i];
}
int feature_size = logits->dims()[rank - 1];
auto* logits_data = logits->data<T>();
auto* labels_data = labels->data<int64_t>();
HardLabelSoftmaxWithCrossEntropy<T>(
@@ -443,8 +458,13 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
context.device_context(), logit_grad);
T* logit_grad_data = logit_grad->data<T>();
const int batch_size = logit_grad->dims()[0];
const int class_num = logit_grad->dims()[1];
int rank = logit_grad->dims().size();
int batch_size = 1;
for (int i = 0; i < rank - 1; ++i) {
batch_size *= logit_grad->dims()[i];
}
const int class_num = logit_grad->dims()[rank - 1];
int block = 512;
auto stream = context.cuda_device_context().stream();
auto ignore_index = context.Attr<int>("ignore_index");

@@ -40,15 +40,22 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
softmax->mutable_data<T>(context.GetPlace());
loss->mutable_data<T>(context.GetPlace());
int axis_dim = logits->dims()[logits->dims().size() - 1];
// reshape to 2D tensor
int rank = logits->dims().size();
Tensor logits_2d = framework::ReshapeToMatrix(*logits, rank - 1);
Tensor labels_2d = framework::ReshapeToMatrix(*labels, rank - 1);
Tensor loss_2d = framework::ReshapeToMatrix(*loss, rank - 1);
Tensor softmax_2d = framework::ReshapeToMatrix(*softmax, rank - 1);
int axis_dim = logits->dims()[rank - 1];
auto& dev_ctx =
context.template device_context<platform::CPUDeviceContext>();
math::SoftmaxFunctor<platform::CPUDeviceContext, T, false>()(
dev_ctx, axis_dim, logits, softmax);
dev_ctx, axis_dim, &logits_2d, &softmax_2d);
math::CrossEntropyFunctor<platform::CPUDeviceContext, T>()(
dev_ctx, loss, softmax, labels, context.Attr<bool>("soft_label"),
context.Attr<int>("ignore_index"));
dev_ctx, &loss_2d, &softmax_2d, &labels_2d,
context.Attr<bool>("soft_label"), context.Attr<int>("ignore_index"));
}
};
@@ -63,13 +70,19 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
context.Output<Tensor>(framework::GradVarName("Logits"));
logit_grad->ShareDataWith(*context.Input<Tensor>("Softmax"));
const int class_num = logit_grad->dims()[1];
auto out_grad_mat = EigenMatrix<T>::From(*out_grad);
auto logit_grad_mat = EigenMatrix<T>::From(*logit_grad);
int rank = logit_grad->dims().size();
const int class_num = logit_grad->dims()[rank - 1];
// reshape to 2d
Tensor logit_grad_2d = framework::ReshapeToMatrix(*logit_grad, rank - 1);
Tensor out_grad_2d = framework::ReshapeToMatrix(*out_grad, rank - 1);
auto out_grad_mat = EigenMatrix<T>::From(out_grad_2d);
auto logit_grad_mat = EigenMatrix<T>::From(logit_grad_2d);
auto& place = *context.template device_context<platform::CPUDeviceContext>()
.eigen_device();
if (context.Attr<bool>("soft_label")) {
auto lbl_mat = EigenMatrix<T>::From(*labels);
Tensor labels_2d = framework::ReshapeToMatrix(*labels, rank - 1);
auto lbl_mat = EigenMatrix<T>::From(labels_2d);
logit_grad_mat.device(place) =
out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, class_num)) *
(logit_grad_mat - lbl_mat);
@@ -78,7 +91,8 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
logit_grad_mat *
out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, class_num));
const int batch_size = logit_grad->dims()[0];
const int batch_size = logit_grad_2d.dims()[0];
const int64_t* label_data = labels->data<int64_t>();
T* logit_grad_data = logit_grad->data<T>();
const T* out_grad_data = out_grad->data<T>();

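The kernels keep their existing 2-D math by viewing the N-D tensors as matrices: framework::ReshapeToMatrix(t, rank - 1), as called above, folds every dimension before the split point into the row count and leaves the last dimension as columns, which should match the [batch_size, feature_size] layout the CUDA path computes with its explicit loop (assuming, as the call sites suggest, that it yields a reshaped view of the same data rather than a copy):

// [d0, d1, ..., d_{R-2}, C]  ->  [d0 * d1 * ... * d_{R-2}, C]
Tensor logits_2d = framework::ReshapeToMatrix(*logits, rank - 1);
// the CUDA kernel's manual equivalent of the row count:
int batch_size = 1;
for (int i = 0; i < rank - 1; ++i) batch_size *= logits->dims()[i];
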
@@ -40,19 +40,44 @@ class SpaceToDepthOp : public framework::OperatorWithKernel {
auto blocksize = ctx->Attrs().Get<int64_t>("blocksize");
PADDLE_ENFORCE_GT(blocksize, 1, "The blocksize should be Greater than 1");
PADDLE_ENFORCE_GT(x_dims[1], 0, "input channel should be Greater than 0");
PADDLE_ENFORCE_GT(x_dims[2], 0, "input Height should be Greater than 0");
PADDLE_ENFORCE_GT(x_dims[3], 0, "input Width should be Greater than 0");
PADDLE_ENFORCE_EQ(x_dims[1] % (blocksize * blocksize), 0,
"input channel should be divisible of the square of "
"SpaceToDepthOp blocksize");
PADDLE_ENFORCE_EQ(x_dims[2] % (blocksize), 0,
"input Height should be divisible of the square of "
"SpaceToDepthOp blocksize");
PADDLE_ENFORCE_EQ(x_dims[3] % (blocksize), 0,
"input Width should be divisible of the square of "
"SpaceToDepthOp blocksize");
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_GT(x_dims[1], 0, "input channel should be Greater than 0");
PADDLE_ENFORCE_GT(x_dims[2], 0, "input Height should be Greater than 0");
PADDLE_ENFORCE_GT(x_dims[3], 0, "input Width should be Greater than 0");
PADDLE_ENFORCE_EQ(x_dims[1] % (blocksize * blocksize), 0,
"input channel should be divisible of the square of "
"SpaceToDepthOp blocksize");
PADDLE_ENFORCE_EQ(x_dims[2] % (blocksize), 0,
"input Height should be divisible of the square of "
"SpaceToDepthOp blocksize");
PADDLE_ENFORCE_EQ(x_dims[3] % (blocksize), 0,
"input Width should be divisible of the square of "
"SpaceToDepthOp blocksize");
} else {
if (x_dims[1] != -1) {
PADDLE_ENFORCE_GT(x_dims[1], 0,
"input channel should be Greater than 0");
PADDLE_ENFORCE_EQ(x_dims[1] % (blocksize * blocksize), 0,
"input channel should be divisible of the square of "
"SpaceToDepthOp blocksize");
}
if (x_dims[2] != -1) {
PADDLE_ENFORCE_GT(x_dims[2], 0,
"input Height should be Greater than 0");
PADDLE_ENFORCE_EQ(x_dims[2] % (blocksize), 0,
"input Height should be divisible of the square of "
"SpaceToDepthOp blocksize");
}
if (x_dims[3] != -1) {
PADDLE_ENFORCE_GT(x_dims[3], 0, "input Width should be Greater than 0");
PADDLE_ENFORCE_EQ(x_dims[3] % (blocksize), 0,
"input Width should be divisible of the square of "
"SpaceToDepthOp blocksize");
}
}
VLOG(3) << "SpaceToDepthOp operator x.shape=" << x_dims
<< "Attribute blocksize" << blocksize << std::endl;

@@ -157,7 +157,9 @@ class SplitLoDTensorInferShape : public framework::InferShapeBase {
auto mask_dim = context->GetInputDim("Mask");
PADDLE_ENFORCE_EQ(mask_dim.size(), 2);
PADDLE_ENFORCE_EQ(mask_dim[1], 1);
if (context->IsRuntime()) {
PADDLE_ENFORCE_EQ(mask_dim[1], 1);
}
context->SetOutputDim("OutTrue", context->GetInputDim("X"));
context->SetOutputDim("OutFalse", context->GetInputDim("X"));

Some files were not shown because too many files have changed in this diff.