// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "subgraph_tests/matmul_multiply_fusion.hpp"

using namespace SubgraphTestsDefinitions;

namespace {
// Each table entry is a MatMulMultiplyFusionShapeParams aggregate. Judging by
// the data (the second shape has its last two dimensions swapped whenever the
// bool is true, e.g. {10, 7}/false vs {7, 10}/true), the fields are presumably
// {input shape, weights shape, transpose-weights flag, shape of the constant
// fed to Multiply} -- TODO confirm against
// subgraph_tests/matmul_multiply_fusion.hpp.
//
// The tables are only read (via ::testing::ValuesIn), so they are declared
// const to prevent accidental mutation of shared global test data.

// Cases where the MatMul -> Multiply pair is expected to be fused; the suite
// below passes `true` ("can be fused").
const std::vector<MatMulMultiplyFusionShapeParams> shape_params = {
    {{2, 2}, {2, 2}, false, {}},
    {{2, 2}, {2, 2}, false, {1}},
    {{2, 2}, {2, 2}, false, {1, 2}},
    {{2, 2}, {2, 2}, true, {1, 2}},
    {{5}, {5}, false, {}},
    {{5}, {5, 1}, false, {}},
    {{5}, {5, 1}, false, {1}},
    {{5}, {5, 3}, false, {3}},
    {{5}, {3, 5}, true, {3}},
    {{5, 10}, {10, 7}, false, {}},
    {{5, 10}, {7, 10}, true, {}},
    {{5, 10}, {10, 7}, false, {7}},
    {{5, 10}, {7, 10}, true, {7}},
    {{5, 10}, {10, 7}, false, {1, 7}},
    {{5, 10}, {7, 10}, true, {1, 7}},
    {{5, 10}, {2, 10, 7}, false, {2, 1, 7}},
    {{5, 10}, {2, 7, 10}, true, {2, 1, 7}},
    {{5, 10}, {2, 3, 10, 7}, false, {7}},
    {{5, 10}, {2, 3, 7, 10}, true, {7}},
    {{5, 10}, {2, 3, 10, 7}, false, {1, 7}},
    {{5, 10}, {2, 3, 7, 10}, true, {1, 7}},
    {{5, 10}, {2, 3, 10, 7}, false, {1, 1, 7}},
    {{5, 10}, {2, 3, 7, 10}, true, {1, 1, 7}},
    {{5, 10}, {2, 3, 10, 7}, false, {1, 1, 1, 7}},
    {{5, 10}, {2, 3, 7, 10}, true, {1, 1, 1, 7}},
    {{5, 10}, {2, 3, 10, 7}, false, {1, 3, 1, 7}},
    {{5, 10}, {2, 3, 7, 10}, true, {1, 3, 1, 7}},
    {{5, 10}, {2, 3, 10, 7}, false, {2, 3, 1, 7}},
    {{5, 10}, {2, 3, 7, 10}, true, {2, 3, 1, 7}},
    {{5, 10}, {10}, false, {}},
    {{5, 10}, {10}, false, {1}},
    {{2, 3, 5, 10}, {10, 7}, false, {1}},
    {{2, 3, 5, 10}, {7, 10}, true, {1}},
    {{2, 3, 5, 10}, {10, 7}, false, {1, 1, 7}},
    {{2, 3, 5, 10}, {7, 10}, true, {1, 1, 7}},
    {{2, 3, 5, 10}, {10, 7}, false, {1, 1, 1, 7}},
    {{2, 3, 5, 10}, {7, 10}, true, {1, 1, 1, 7}},
    {{2, 3, 5, 10}, {10, 7}, false, {2, 3, 1, 7}},
    {{2, 3, 5, 10}, {7, 10}, true, {2, 3, 1, 7}},
    {{1, 1, 5, 10}, {10, 7}, false, {1, 1, 1, 7}},
    {{1, 1, 5, 10}, {7, 10}, true, {1, 1, 1, 7}},
    {{2, 3, 5, 10}, {3, 10, 7}, false, {7}},
    {{2, 3, 5, 10}, {3, 7, 10}, true, {7}},
    {{2, 3, 5, 10}, {3, 10, 7}, false, {1, 1, 1, 7}},
    {{2, 3, 5, 10}, {3, 7, 10}, true, {1, 1, 1, 7}},
    {{2, 3, 5, 10}, {2, 3, 10, 7}, false, {}},
    {{2, 3, 5, 10}, {2, 3, 7, 10}, true, {}},
    {{2, 3, 5, 10}, {2, 3, 10, 7}, false, {7}},
    {{2, 3, 5, 10}, {2, 3, 7, 10}, true, {7}},
    {{2, 3, 5, 10}, {2, 3, 10, 7}, false, {1, 7}},
    {{2, 3, 5, 10}, {2, 3, 7, 10}, true, {1, 7}},
    {{2, 3, 5, 10}, {2, 3, 10, 7}, false, {1, 1, 1, 7}},
    {{2, 3, 5, 10}, {2, 3, 7, 10}, true, {1, 1, 1, 7}},
    {{2, 3, 5, 10}, {2, 3, 10, 7}, false, {1, 3, 1, 7}},
    {{2, 3, 5, 10}, {2, 3, 7, 10}, true, {1, 3, 1, 7}},
    {{2, 3, 5, 10}, {2, 3, 10, 7}, false, {2, 3, 1, 7}},
    {{2, 3, 5, 10}, {2, 3, 7, 10}, true, {2, 3, 1, 7}},
};

INSTANTIATE_TEST_SUITE_P(smoke_MatMulMultiplyFusion, MatMulMultiplyFusion,
                        ::testing::Combine(
                                ::testing::ValuesIn(shape_params),
                                ::testing::Values(true), // can be fused
                                ::testing::Values(CommonTestUtils::DEVICE_CPU)),
                        MatMulMultiplyFusion::getTestCaseName);

// Cases where fusion is expected NOT to happen; the suite below passes
// `false` ("cannot be fused"). NOTE(review): these appear to be constants that
// vary along non-trailing output axes (e.g. {5, 7} on a {5, 7} output) or that
// would change the output rank (e.g. the rank-5 constant in the last entry) --
// confirm against the transformation's rules.
const std::vector<MatMulMultiplyFusionShapeParams> negative_shape_params = {
    {{5}, {5}, false, {1}},
    {{5}, {5}, false, {5}},
    {{5}, {5}, false, {5, 1}},
    {{5}, {5, 3}, false, {1, 3}},
    {{2, 2}, {2, 2}, false, {2, 2}},
    {{2, 2}, {2, 2}, true, {2, 2}},
    {{5, 5}, {5, 5}, false, {5, 5}},
    {{5, 5}, {5, 5}, true, {5, 5}},
    {{5, 10}, {10}, false, {5, 1}},
    {{5, 10}, {10, 7}, false, {5, 7}},
    {{5, 10}, {7, 10}, true, {5, 7}},
    {{5, 10}, {10, 5}, false, {5, 5}},
    {{5, 10}, {5, 10}, true, {5, 5}},
    {{1, 1, 5, 10}, {10, 7}, false, {2, 3, 1, 7}},
    {{1, 1, 5, 10}, {7, 10}, true, {2, 3, 1, 7}},
    {{1, 1, 5, 10}, {1, 1, 10, 7}, false, {2, 3, 1, 7}},
    {{1, 1, 5, 10}, {1, 1, 7, 10}, true, {2, 3, 1, 7}},
    {{2, 1, 5, 10}, {1, 1, 10, 7}, false, {2, 3, 1, 7}},
    {{1, 1, 5, 10}, {1, 3, 10, 7}, false, {2, 3, 1, 7}},
    {{2, 3, 5, 10}, {10}, false, {5}},
    {{2, 3, 5, 10}, {10, 7}, false, {2, 3, 5, 7}},
    {{2, 3, 5, 10}, {7, 10}, true, {2, 3, 5, 7}},
    {{2, 3, 5, 10}, {3, 10, 7}, false, {5, 7}},
    {{2, 3, 5, 10}, {3, 7, 10}, true, {5, 7}},
    {{2, 3, 5, 10}, {3, 10, 7}, false, {1, 3, 5, 7}},
    {{2, 3, 5, 10}, {3, 7, 10}, true, {1, 3, 5, 7}},
    {{2, 3, 5, 10}, {3, 10, 7}, false, {2, 3, 5, 7}},
    {{2, 3, 5, 10}, {3, 7, 10}, true, {2, 3, 5, 7}},
    {{2, 3, 5, 10}, {2, 3, 10, 7}, false, {2, 3, 5, 7}},
    {{2, 3, 5, 10}, {2, 3, 7, 10}, true, {2, 3, 5, 7}},
    {{2, 3, 5, 10}, {2, 3, 10, 7}, false, {1, 1, 1, 1, 7}},
};

INSTANTIATE_TEST_SUITE_P(smoke_NegativeMatMulMultiplyFusion, MatMulMultiplyFusion,
                        ::testing::Combine(
                                ::testing::ValuesIn(negative_shape_params),
                                ::testing::Values(false), // cannot be fused
                                ::testing::Values(CommonTestUtils::DEVICE_CPU)),
                        MatMulMultiplyFusion::getTestCaseName);

// Fusable cases for the QuantizedMatMulMultiplyFusion fixture. The accepted
// set is narrower than for the plain fixture: there are no 1-D inputs, and
// some shapes accepted above (e.g. {2, 3, 5, 10} x {10, 7} with a
// {1, 1, 1, 7} constant) are listed as non-fusable in the negative table
// below.
const std::vector<MatMulMultiplyFusionShapeParams> shape_params2 = {
    {{2, 2}, {2, 2}, false, {}},
    {{2, 2}, {2, 2}, false, {1}},
    {{2, 2}, {2, 2}, false, {1, 2}},
    {{2, 2}, {2, 2}, true, {1, 2}},
    {{5, 10}, {10, 7}, false, {}},
    {{5, 10}, {7, 10}, true, {}},
    {{5, 10}, {10, 7}, false, {7}},
    {{5, 10}, {7, 10}, true, {7}},
    {{5, 10}, {10, 7}, false, {1, 7}},
    {{5, 10}, {7, 10}, true, {1, 7}},
    {{5, 10}, {2, 10, 7}, false, {2, 1, 7}},
    {{5, 10}, {2, 7, 10}, true, {2, 1, 7}},
    {{5, 10}, {2, 3, 10, 7}, false, {7}},
    {{5, 10}, {2, 3, 7, 10}, true, {7}},
    {{5, 10}, {2, 3, 10, 7}, false, {1, 7}},
    {{5, 10}, {2, 3, 7, 10}, true, {1, 7}},
    {{5, 10}, {2, 3, 10, 7}, false, {1, 1, 7}},
    {{5, 10}, {2, 3, 7, 10}, true, {1, 1, 7}},
    {{5, 10}, {2, 3, 10, 7}, false, {1, 1, 1, 7}},
    {{5, 10}, {2, 3, 7, 10}, true, {1, 1, 1, 7}},
    {{5, 10}, {2, 3, 10, 7}, false, {1, 3, 1, 7}},
    {{5, 10}, {2, 3, 7, 10}, true, {1, 3, 1, 7}},
    {{5, 10}, {2, 3, 10, 7}, false, {2, 3, 1, 7}},
    {{5, 10}, {2, 3, 7, 10}, true, {2, 3, 1, 7}},
    {{2, 3, 5, 10}, {10, 7}, false, {1}},
    {{2, 3, 5, 10}, {10, 7}, false, {1, 1}},
    {{2, 3, 5, 10}, {3, 10, 7}, false, {7}},
    {{2, 3, 5, 10}, {3, 7, 10}, true, {7}},
    {{2, 3, 5, 10}, {2, 3, 10, 7}, false, {}},
    {{2, 3, 5, 10}, {2, 3, 7, 10}, true, {}},
    {{2, 3, 5, 10}, {2, 3, 10, 7}, false, {7}},
    {{2, 3, 5, 10}, {2, 3, 7, 10}, true, {7}},
    {{2, 3, 5, 10}, {2, 3, 10, 7}, false, {1, 7}},
    {{2, 3, 5, 10}, {2, 3, 7, 10}, true, {1, 7}},
    {{2, 3, 5, 10}, {2, 3, 10, 7}, false, {1, 1, 1, 7}},
    {{2, 3, 5, 10}, {2, 3, 7, 10}, true, {1, 1, 1, 7}},
    {{2, 3, 5, 10}, {2, 3, 10, 7}, false, {1, 3, 1, 7}},
    {{2, 3, 5, 10}, {2, 3, 7, 10}, true, {1, 3, 1, 7}},
    {{2, 3, 5, 10}, {2, 3, 10, 7}, false, {2, 3, 1, 7}},
    {{2, 3, 5, 10}, {2, 3, 7, 10}, true, {2, 3, 1, 7}},
};

INSTANTIATE_TEST_SUITE_P(smoke_QuantizedMatMulMultiplyFusion, QuantizedMatMulMultiplyFusion,
                        ::testing::Combine(
                                ::testing::ValuesIn(shape_params2),
                                ::testing::Values(true), // can be fused
                                ::testing::Values(CommonTestUtils::DEVICE_CPU)),
                        QuantizedMatMulMultiplyFusion::getTestCaseName);

// Non-fusable cases for the QuantizedMatMulMultiplyFusion fixture; the suite
// below passes `false` ("cannot be fused").
const std::vector<MatMulMultiplyFusionShapeParams> negative_shape_params2 = {
    {{2, 2}, {2, 2}, false, {2, 2}},
    {{2, 2}, {2, 2}, true, {2, 2}},
    {{5, 5}, {5, 5}, false, {5, 5}},
    {{5, 5}, {5, 5}, true, {5, 5}},
    {{5, 10}, {10, 7}, false, {5, 7}},
    {{5, 10}, {7, 10}, true, {5, 7}},
    {{5, 10}, {10, 5}, false, {5, 5}},
    {{5, 10}, {5, 10}, true, {5, 5}},
    {{2, 3, 5, 10}, {10, 7}, false, {1, 1, 1}},
    {{1, 1, 5, 10}, {10, 7}, false, {2, 3, 1, 7}},
    {{1, 1, 5, 10}, {7, 10}, true, {2, 3, 1, 7}},
    {{2, 3, 5, 10}, {10, 7}, false, {2, 3, 5, 7}},
    {{2, 3, 5, 10}, {7, 10}, true, {2, 3, 5, 7}},
    {{2, 3, 5, 10}, {10, 7}, false, {1, 1, 1, 1}},
    {{2, 3, 5, 10}, {10, 7}, false, {1, 1, 7}},
    {{2, 3, 5, 10}, {7, 10}, true, {1, 1, 7}},
    {{2, 3, 5, 10}, {10, 7}, false, {1, 1, 1, 7}},
    {{2, 3, 5, 10}, {7, 10}, true, {1, 1, 1, 7}},
    {{2, 3, 5, 10}, {10, 7}, false, {2, 3, 1, 7}},
    {{2, 3, 5, 10}, {7, 10}, true, {2, 3, 1, 7}},
    {{1, 1, 5, 10}, {10, 7}, false, {1, 1, 1, 7}},
    {{1, 1, 5, 10}, {7, 10}, true, {1, 1, 1, 7}},
    {{2, 3, 5, 10}, {3, 10, 7}, false, {5, 7}},
    {{2, 3, 5, 10}, {3, 7, 10}, true, {5, 7}},
    {{2, 3, 5, 10}, {3, 10, 7}, false, {1, 1, 1, 7}},
    {{2, 3, 5, 10}, {3, 7, 10}, true, {1, 1, 1, 7}},
    {{2, 3, 5, 10}, {3, 10, 7}, false, {1, 3, 5, 7}},
    {{2, 3, 5, 10}, {3, 7, 10}, true, {1, 3, 5, 7}},
    {{2, 3, 5, 10}, {3, 10, 7}, false, {2, 3, 5, 7}},
    {{2, 3, 5, 10}, {3, 7, 10}, true, {2, 3, 5, 7}},
};

INSTANTIATE_TEST_SUITE_P(smoke_NegativeQuantizedMatMulMultiplyFusion, QuantizedMatMulMultiplyFusion,
                        ::testing::Combine(
                                ::testing::ValuesIn(negative_shape_params2),
                                ::testing::Values(false), // cannot be fused
                                ::testing::Values(CommonTestUtils::DEVICE_CPU)),
                        QuantizedMatMulMultiplyFusion::getTestCaseName);

} // namespace
