// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <vector>

#include <gtest/gtest.h>

#include "low_precision_transformations/reduce_mean_transformation.hpp"


using namespace LayerTestsDefinitions;

namespace {
// Model precisions each case is run with.
const std::vector<ngraph::element::Type> netPrecisions = { ngraph::element::f32 };

// Low-precision transformation parameter sets; only the U8/I8 configuration is exercised.
const std::vector<ngraph::pass::low_precision::LayerTransformation::Params> transformationParamValues = {
    LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParamsU8I8()
};

// Each entry lists, in field order (NOTE(review): meanings inferred from usage here —
// confirm against ReduceMeanTransformationParam in reduce_mean_transformation.hpp):
//   1. FakeQuantize on the input: levels, constant shape, input low/high, output low/high[, precision]
//   2. optional Convert
//   3. dequantization before ReduceMean (presumably convert / subtract / multiply)
//   4. ReduceMean: reduction axes + boolean flag (presumably keep_dims)
//   5. dequantization after ReduceMean
//   6. name of the layer whose runtime precision is checked
//   7. expected runtime precision of that layer
// All cases run on an input of shape { 1, 3, 10, 10 } (see the instantiation below).
const std::vector<LayerTestsDefinitions::ReduceMeanTransformationParam> params = {
    // ---- FakeQuantize with constant shape { 1, 1, 1, 1 } (one interval for the whole tensor) ----

    // axes { 2, 3 }, flag = true
    {
        { 256ul, ngraph::Shape{ 1, 1, 1, 1 }, { 0.f }, { 1.27f }, { 0.f }, { 1.27f } },
        {},
        {},
        {{ 2, 3 }, true},
        {},
        "Output_original",
        "U8"
    },

    // FakeQuantize -> Convert(u8) -> dequantization { f32, 128.f, 0.01f }, axes { 2, 3 }.
    // NOTE(review): input low -128.f paired with input high 1.27f looks asymmetric; if the
    // dequantization is (x - 128) * 0.01 it maps [0, 255] onto [-1.28, 1.27], so -1.28f may
    // have been intended. Left unchanged pending confirmation.
    {
        { 256ul, ngraph::Shape{ 1, 1, 1, 1 }, { -128.f }, { 1.27f }, { 0.f }, { 255.f }, ov::element::f32 },
        { ov::element::u8 },
        {
            { ov::element::f32 },
            { 128.f },
            { 0.01f }
        },
        {{ 2, 3 }, true},
        {},
        "Output_original",
        "U8"
    },

    // Same as above but the dequantization has no subtract step.
    // NOTE(review): same -128.f / 1.27f input-range concern as the previous case.
    {
        { 256ul, ngraph::Shape{ 1, 1, 1, 1 }, { -128.f }, { 1.27f }, { 0.f }, { 255.f }, ov::element::f32 },
        { ov::element::u8 },
        {
            { ov::element::f32 },
            {},
            { 0.01f }
        },
        {{ 2, 3 }, true},
        {},
        "Output_original",
        "U8"
    },

    // Input [0, 255] -> output [0, 127]; axes { 2, 3 }, flag = false
    {
        { 256ul, ngraph::Shape{ 1, 1, 1, 1 }, { 0.f }, { 255.f }, { 0.f }, { 127.f } },
        {},
        {},
        {{ 2, 3 }, false},
        {},
        "Output_original",
        "U8"
    },

    // Input [0, 255] -> output [0, 127]; axis { 1 }, flag = true
    {
        { 256ul, ngraph::Shape{ 1, 1, 1, 1 }, { 0.f }, { 255.f }, { 0.f }, { 127.f } },
        {},
        {},
        {{ 1 }, true},
        {},
        "Output_original",
        "U8"
    },

    // Input [0, 255] -> output [0, 127]; axis { 1 }, flag = false
    {
        { 256ul, ngraph::Shape{ 1, 1, 1, 1 }, { 0.f }, { 255.f }, { 0.f }, { 127.f } },
        {},
        {},
        {{ 1 }, false},
        {},
        "Output_original",
        "U8"
    },

    // ---- FakeQuantize with constant shape { 1, 3, 1, 1 } (one interval per value along dim 1) ----

    // axes { 2, 3 }, flag = true
    {
        {
            256ul, ngraph::Shape{ 1, 3, 1, 1 },
            { -127.f, -127.f, -127.f },
            { 128.f, 128.f, 128.f },
            { 0.f, 0.f, 0.f },
            { 255.f, 25.5f, 2.55f }
        },
        {},
        {},
        {{ 2, 3 }, true},
        {},
        "Output_original",
        "U8"
    },

    // axes { 2, 3 }, flag = false
    {
        {
            256ul, ngraph::Shape{ 1, 3, 1, 1 },
            { -127.f, -127.f, -127.f },
            { 128.f, 128.f, 128.f },
            { 0.f, 0.f, 0.f },
            { 255.f, 25.5f, 2.55f }
        },
        {},
        {},
        {{ 2, 3 }, false},
        {},
        "Output_original",
        "U8"
    },

    // axes { 0, 1 } include the dimension the intervals vary along; the checked layer is
    // "Output" and the expected precision is FP32.
    {
        {
            256ul, ngraph::Shape{ 1, 3, 1, 1 },
            { -127.f, -127.f, -127.f },
            { 128.f, 128.f, 128.f },
            { 0.f, 0.f, 0.f },
            { 255.f, 25.5f, 2.55f }
        },
        {},
        {},
        {{ 0, 1 }, true},
        {},
        "Output",
        "FP32"
    },

    // Same as above with flag = false.
    {
        {
            256ul, ngraph::Shape{ 1, 3, 1, 1 },
            { -127.f, -127.f, -127.f },
            { 128.f, 128.f, 128.f },
            { 0.f, 0.f, 0.f },
            { 255.f, 25.5f, 2.55f }
        },
        {},
        {},
        {{ 0, 1 }, false},
        {},
        "Output",
        "FP32"
    },
};

// Cartesian product: precision x input shape x device x LPT params x case.
INSTANTIATE_TEST_SUITE_P(smoke_LPT, ReduceMeanTransformation,
    ::testing::Combine(
        ::testing::ValuesIn(netPrecisions),
        ::testing::Values(ngraph::PartialShape({ 1, 3, 10, 10 })),
        ::testing::Values(CommonTestUtils::DEVICE_CPU),
        ::testing::ValuesIn(transformationParamValues),
        ::testing::ValuesIn(params)),
    ReduceMeanTransformation::getTestCaseName);
}  // namespace
