Fix ImageExpandFunction.

cblas_new
hedaoyuan 8 years ago
parent 07cde439aa
commit 9e6ed83cc4

@ -45,9 +45,7 @@ public:
numOutputs_ = 1;
}
virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {}
void check(const TensorShape& image, const TensorShape& sequence) const {
void checkShape(const TensorShape& image, const TensorShape& sequence) const {
// image shape should be 4-dimensional.
CHECK_EQ(image.ndims(), (size_t)4);
// sequence shape should be 3-dimensional.
@ -108,12 +106,18 @@ public:
ImageExpandFunction::init(config);
}
void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
const TensorShape& image = inputs[0].shape();
const TensorShape& sequence = outputs[0].shape();
checkShape(image, sequence);
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(numInputs_, inputs.size());
CHECK_EQ(numOutputs_, outputs.size());
check(inputs, outputs);
const TensorShape& image = inputs[0].shape();
const TensorShape& sequence = outputs[0].shape();
check(image, sequence);
TensorShape imShape = TensorShape({image[1], image[2], image[3]});
TensorShape colShape = getColShape(image, sequence);
@ -149,15 +153,21 @@ public:
ImageExpandFunction::init(config);
}
void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
const TensorShape& image = outputs[0].shape();
const TensorShape& sequence = inputs[0].shape();
checkShape(image, sequence);
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(numInputs_, inputs.size());
CHECK_EQ(numOutputs_, outputs.size());
check(inputs, outputs);
// Since the implementation of Col2ImFunctor is ADD_TO,
// this function only supports ADD_TO mode.
CHECK_EQ(outputs[0].getArgType(), ADD_TO);
const TensorShape& image = outputs[0].shape();
const TensorShape& sequence = inputs[0].shape();
check(image, sequence);
TensorShape imShape = TensorShape({image[1], image[2], image[3]});
TensorShape colShape = getColShape(image, sequence);

Loading…
Cancel
Save