@@ -28,7 +28,7 @@ void FullConnection::SetHasBias(bool has_bias) { this->primitive->value.AsFullCo
 void FullConnection::SetAxis(int axis) { this->primitive->value.AsFullConnection()->axis = axis; }
 void FullConnection::SetUseAxis(bool use_axis) { this->primitive->value.AsFullConnection()->useAxis = use_axis; }
 void FullConnection::SetActivationType(int activationType) {
-  this->primitive->value.AsFullConnection()->activationType = (schema::ActivationType) activationType;
+  this->primitive->value.AsFullConnection()->activationType = (schema::ActivationType)activationType;
 }
 
 #else
@@ -47,43 +47,58 @@ int FullConnection::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
   MS_ASSERT(this->primitive != nullptr);
   auto input0 = inputs_.front();
   MS_ASSERT(input0 != nullptr);
-  auto input1 = inputs_.at(1);
+  auto input1 = inputs_[1];
   MS_ASSERT(input1 != nullptr);
   auto output = outputs_.front();
   MS_ASSERT(output != nullptr);
+  output->set_data_type(input0->data_type());
+  output->SetFormat(input0->GetFormat());
+  if (!GetInferFlag()) {
+    return RET_OK;
+  }
   if ((GetHasBias() && inputs_.size() != kMultiNum) || (!GetHasBias() && inputs_.size() != kDoubleNum)) {
     MS_LOG(ERROR) << "Input tensors num error";
-    return 1;
+    return RET_INPUT_TENSOR_ERROR;
   }
-  if (GetAxis() < 1 || GetAxis() > static_cast<int>(input0->shape().size())) {
+  if (GetUseAxis() && (GetAxis() < 1 || GetAxis() > static_cast<int>(input0->shape().size()))) {
     MS_LOG(ERROR) << "FullConnection axis invalid";
-    return 1;
+    return RET_ERROR;
   }
   int new_k = 1;
-  for (size_t i = GetAxis(); i < input0->shape().size(); ++i) {
-    new_k *= input0->shape().at(i);
-  }
-  if (new_k != input1->shape().at(1)) {
-    MS_LOG(ERROR) << "Input1 size invalid";
-    return 1;
+  if (GetUseAxis()) {
+    for (int i = GetAxis(); i < input0->shape().size(); ++i) {
+      new_k *= input0->shape()[i];
+    }
+    if (new_k != input1->shape()[1]) {
+      MS_LOG(ERROR) << "Input1 size invalid";
+      return RET_INPUT_TENSOR_ERROR;
+    }
+  } else {
+    new_k = input1->shape()[1];
   }
   if (GetHasBias()) {
-    if (inputs_.at(2)->shape()[0] != input1->shape()[0]) {
+    if (inputs_[2]->shape()[0] != input1->shape()[0]) {
       MS_LOG(ERROR) << "bias size invalid";
-      return 1;
+      return RET_INPUT_TENSOR_ERROR;
     }
   }
   std::vector<int> out_shape{inputs_[0]->shape()};
-  out_shape.resize(GetAxis() + 1);
-  out_shape[GetAxis()] = input1->shape()[0];
+  if (GetUseAxis()) {
+    out_shape.resize(GetAxis() + 1);
+    out_shape[GetAxis()] = input1->shape()[0];
+  } else {
+    int total = 1;
+    for (int i = 0; i < input0->shape().size(); ++i) {
+      total *= input0->shape()[i];
+    }
+    out_shape.resize(2);
+    auto batch_size = total / new_k;
+    out_shape[0] = batch_size;
+    out_shape[1] = input1->shape()[0];
+  }
   output->set_shape(out_shape);
   output->set_data_type(input0->data_type());
   output->SetFormat(input0->GetFormat());
 
-  return 0;
+  return RET_OK;
 }
 }  // namespace lite
 }  // namespace mindspore
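
For reference, a minimal standalone sketch of the output-shape arithmetic the new InferShape performs. infer_fc_shape is a hypothetical helper written only for illustration, and the example shapes are made up, not taken from the patch.

// Reference sketch only (not part of this patch): infer_fc_shape mirrors the
// two output-shape paths of InferShape above (useAxis set vs. unset).
#include <cassert>
#include <vector>

std::vector<int> infer_fc_shape(const std::vector<int> &in_shape, const std::vector<int> &weight_shape,
                                bool use_axis, int axis) {
  if (use_axis) {
    // Keep dims [0, axis) and append the output channel count (weight rows).
    std::vector<int> out(in_shape.begin(), in_shape.begin() + axis);
    out.push_back(weight_shape[0]);
    return out;
  }
  // No axis: flatten the input, so batch = total elements / depth (weight cols).
  int total = 1;
  for (int d : in_shape) total *= d;
  return {total / weight_shape[1], weight_shape[0]};
}

int main() {
  // Input [2, 3, 4], weight [10, 12], axis = 1: new_k = 3 * 4 = 12, output [2, 10].
  assert((infer_fc_shape({2, 3, 4}, {10, 12}, true, 1) == std::vector<int>{2, 10}));
  // Same tensors without useAxis: batch = 24 / 12 = 2, output [2, 10].
  assert((infer_fc_shape({2, 3, 4}, {10, 12}, false, 0) == std::vector<int>{2, 10}));
  return 0;
}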