// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "jit_dnnl_emitters.hpp"

#include <stdexcept>

#include <nodes/eltwise.h>

using namespace dnnl::impl::utils;
using namespace dnnl::impl;
using namespace dnnl::impl::cpu::x64;
using namespace Xbyak;

namespace ov {
namespace intel_cpu {

// Node-based constructor.
// `node` and `exec_prc` are forwarded to jit_emitter. The eltwise algorithm is
// NOT derived from `node`: it is fixed to tanh with alpha = beta = 0, after which
// the ISA-specific injector is created via set_injector().
jit_dnnl_emitter::jit_dnnl_emitter(jit_generator *host, cpu_isa_t host_isa, const std::shared_ptr<ngraph::Node>& node, InferenceEngine::Precision exec_prc)
    : jit_emitter(host, host_isa, node, exec_prc),
      kind(dnnl_eltwise_tanh),
      alpha(0.f),
      beta(0.f) {
    set_injector();
}

// Algorithm-based constructor.
// Stores the requested oneDNN eltwise algorithm (`algKind`) and its `alpha` /
// `beta` parameters on the emitter, then builds the ISA-specific injector from
// them via set_injector().
jit_dnnl_emitter::jit_dnnl_emitter(jit_generator *host, cpu_isa_t host_isa,
                                   dnnl_alg_kind_t algKind, float alpha, float beta,
                                   InferenceEngine::Precision exec_prc)
    : jit_emitter(host, host_isa, exec_prc),
      kind(algKind),
      alpha(alpha),  // parameter `alpha` initializes member `alpha`
      beta(beta) {   // parameter `beta` initializes member `beta`
    set_injector();
}

// Creates the oneDNN eltwise injector matching host_isa_, configured with this
// emitter's `kind`, `alpha` and `beta`. Only the injector member that matches
// host_isa_ is assigned.
//
// Throws std::runtime_error for any ISA other than sse41, avx2 or avx512_core.
// This used to be an assert(), which is compiled out under NDEBUG. In release
// builds that left every injector null, so the later emit_code()/emit_data()
// calls would dereference a null pointer.
void jit_dnnl_emitter::set_injector() {
    // The trailing `1` is the injector's fifth constructor argument
    // (presumably its scale factor -- confirm against jit_uni_eltwise_injector_f32).
    if (host_isa_ == cpu::x64::sse41) {
        eltwise_injector_sse42 = std::make_shared<jit_uni_eltwise_injector_f32<cpu::x64::sse41>>(
                h, kind, alpha, beta, 1);
    } else if (host_isa_ == cpu::x64::avx2) {
        eltwise_injector_avx2 = std::make_shared<jit_uni_eltwise_injector_f32<cpu::x64::avx2>>(
                h, kind, alpha, beta, 1);
    } else if (host_isa_ == cpu::x64::avx512_core) {
        eltwise_injector_avx512_core = std::make_shared<jit_uni_eltwise_injector_f32<cpu::x64::avx512_core>>(
                h, kind, alpha, beta, 1);
    } else {
        throw std::runtime_error("jit_dnnl_emitter: unsupported isa");
    }
}

// Unary eltwise operation: exactly one input vector register is consumed.
size_t jit_dnnl_emitter::get_inputs_num() const {
    return 1;
}

// Emits the eltwise computation for host_isa_.
// If the output register differs from the input register, the input is first
// copied into the output register (Xmm/Ymm/Zmm depending on the ISA). The
// injector then computes the eltwise function in place on the output register.
// pool_vec_idxs and pool_gpr_idxs are not used here.
//
// Throws std::runtime_error for an unsupported ISA. With the previous assert(),
// a release build would silently emit no instructions and leave the output
// register holding the unprocessed input.
void jit_dnnl_emitter::emit_code(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs,
                                 const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs) const {
    if (host_isa_ == cpu::x64::sse41) {
        if (out_vec_idxs[0] != in_vec_idxs[0])
            h->uni_vmovups(Xmm(out_vec_idxs[0]), Xmm(in_vec_idxs[0]));
        eltwise_injector_sse42->compute_vector(out_vec_idxs[0]);
    } else if (host_isa_ == cpu::x64::avx2) {
        if (out_vec_idxs[0] != in_vec_idxs[0])
            h->uni_vmovups(Ymm(out_vec_idxs[0]), Ymm(in_vec_idxs[0]));
        eltwise_injector_avx2->compute_vector(out_vec_idxs[0]);
    } else if (host_isa_ == cpu::x64::avx512_core) {
        if (out_vec_idxs[0] != in_vec_idxs[0])
            h->uni_vmovups(Zmm(out_vec_idxs[0]), Zmm(in_vec_idxs[0]));
        eltwise_injector_avx512_core->compute_vector(out_vec_idxs[0]);
    } else {
        throw std::runtime_error("jit_dnnl_emitter: unsupported isa");
    }
}

// Asks the injector matching host_isa_ to prepare its table (presumably the
// constants referenced by the code emitted in emit_code -- confirm against
// jit_uni_eltwise_injector_f32::prepare_table).
//
// Throws std::runtime_error for an unsupported ISA. The previous assert() was
// compiled out in release builds, which silently skipped emitting the table.
void jit_dnnl_emitter::emit_data() const {
    if (host_isa_ == cpu::x64::sse41) {
        eltwise_injector_sse42->prepare_table();
    } else if (host_isa_ == cpu::x64::avx2) {
        eltwise_injector_avx2->prepare_table();
    } else if (host_isa_ == cpu::x64::avx512_core) {
        eltwise_injector_avx512_core->prepare_table();
    } else {
        throw std::runtime_error("jit_dnnl_emitter: unsupported isa");
    }
}

// Forwards every argument unchanged to the algorithm-based jit_dnnl_emitter
// constructor. This constructor adds no behavior of its own.
jit_dnnl_aux_emitter::jit_dnnl_aux_emitter(jit_generator *host, cpu_isa_t host_isa,
                                           dnnl_alg_kind_t algKind, float inpAlpha, float inpBeta,
                                           InferenceEngine::Precision exec_prc)
    : jit_dnnl_emitter(host, host_isa, algKind, inpAlpha, inpBeta, exec_prc) {}

}   // namespace intel_cpu
}   // namespace ov
