// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <ngraph/pass/graph_rewrite.hpp>
#include <transformations_visibility.hpp>

namespace ngraph {
namespace pass {

// Forward declaration of the pass. TRANSFORMATIONS_API (from
// transformations_visibility.hpp) is attached here, presumably as the
// symbol-visibility/export macro; the full class definition follows below.
class TRANSFORMATIONS_API BinarizeWeights;

}  // namespace pass
}  // namespace ngraph

// clang-format off
/**
 * @ingroup ie_transformation_common_api
 * @brief This transformation converts weights to -1/+1 form
 * and applies normalization factors to the output low/high values and after the Convolution.
 * For example, the following graph
 *
 *         .... ....  out_low  out_high           weights ..    ..  out_low out_high
 *           |    |      |        |                  |     |    |      |     |
 *          +--------------------------+           +--------------------------+
 *          | FakeQuantize (levels==2) |           | FakeQuantize (levels==2) |
 *          |     (on activations)     |           |       (on weights)       |
 *          +--------------------------+           +--------------------------+
 *                        |                                      |
 *                        |                                      |
 *                        -----------------    -------------------
 *                                        |    |
 *                                        v    v
 *                                   +-------------+
 *                                   | Convolution |
 *                                   +-------------+
 *                                          |
 *                                          v
 *
 * is transformed to:
 *
 *                  normalized normalized
 *         .... ....  out_low   out_high
 *           |    |      |         |
 *          +--------------------------+           +--------------------------+
 *          | FakeQuantize (levels==2) |           |         Constant         |
 *          |     (on activations)     |           | (with converted weights) |
 *          +--------------------------+           +--------------------------+
 *                        |                                      |
 *                        |                                      |
 *                        -----------------    -------------------
 *                                        |    |
 *                                        v    v
 *                                   +-------------+
 *                                   | Convolution |
 *                                   +-------------+
 *                                          |
 *                                          v
 *                                   +------------+     +---------------------------------------------------------------+
 *                                   |  Multiply  | <---| Constant (normalization factor coming from FQ on activations) |
 *                                   +------------+     +---------------------------------------------------------------+
 *                                          |
 *                                          v
 *                                   +------------+     +-----------------------------------------------------------+
 *                                   |  Multiply  | <---| Constant (normalization factor coming from FQ on weights) |
 *                                   +------------+     +-----------------------------------------------------------+
 *                                          |
 *                                          v
 *
 * Normalization factors are chosen based on the output_high value:
 * if output_high is zero, the normalization factor is equal to output_low; otherwise it is equal to output_high.
 */
// clang-format on

class ngraph::pass::BinarizeWeights : public ngraph::pass::MatcherPass {
public:
    // Runtime type information for this pass: name "BinarizeWeights", version "0".
    OPENVINO_RTTI("BinarizeWeights", "0");

    // Default constructor; only declared here, its definition lives outside this header.
    BinarizeWeights();
};
