"""
Mostly the same as example/run_demo.py, but written as a unit test that uses
the mock LLM provider.
"""

import unittest
from pathlib import Path
from unittest.mock import patch

from kai.constants import PATH_GIT_ROOT
from kai.models.file_solution import FileSolutionContent
from kai.kai_config import (
    KaiConfig,
    KaiConfigIncidentStore,
    KaiConfigIncidentStoreSQLiteArgs,
    KaiConfigModels,
)
from kai.analyzer_types import Report
from kai.service.kai_application.kai_application import KaiApplication


class TestRunDemo(unittest.TestCase):
    """Run the coolstore demo flow through ``KaiApplication`` with a fake LLM.

    ``guess_language`` and ``parse_file_solution_content`` are patched inside
    the ``kai_application`` module, so the assertions depend only on the
    canned values configured in ``setUp`` and on the single canned LLM
    response configured in the test itself.
    """

    def add_mock(self, target: str):
        """Patch ``target`` for the duration of the test and return the mock.

        The patcher is stopped automatically through ``addCleanup``, so no
        ``tearDown`` override is needed.
        """
        patcher = patch(target)
        mock = patcher.start()
        # Register the cleanup only once the patch is actually active.
        self.addCleanup(patcher.stop)
        return mock

    def setUp(self):
        """Install the mocks shared by every test in this class."""
        pkg = "kai.service.kai_application.kai_application"

        self.mock_guess_language = self.add_mock(f"{pkg}.guess_language")
        self.mock_guess_language.return_value = "java"

        self.mock_parse_file_solution_content = self.add_mock(
            f"{pkg}.parse_file_solution_content"
        )
        self.mock_parse_file_solution_content.return_value = FileSolutionContent(
            reasoning="Mocked reasoning",
            updated_file="Mocked updated file",
            additional_info="Mocked additional info",
        )

    def _check_file_solution(
        self,
        kai_application,
        source_root,
        file_path,
        incidents,
        llm_response,
        expected_count,
        **solution_kwargs,
    ):
        """Request solutions for one file and verify the mocked response.

        ``expected_count`` is the length every per-incident list on the
        response must have. ``solution_kwargs`` are forwarded verbatim to
        ``get_incident_solutions_for_file`` so a caller can leave
        ``batch_mode`` unset and exercise the method's default.
        """
        # Read as UTF-8 explicitly so decoding does not depend on the
        # platform's locale encoding.
        file_contents = (source_root / file_path).read_text(encoding="utf-8")

        response = kai_application.get_incident_solutions_for_file(
            file_name=str(file_path),
            file_contents=file_contents,
            application_name="coolstore",
            incidents=incidents,
            include_llm_results=True,
            **solution_kwargs,
        )

        # NOTE: assert_called_with only inspects the most recent call.
        self.mock_guess_language.assert_called_with(
            file_contents, filename=str(file_path)
        )
        self.mock_parse_file_solution_content.assert_called_with(
            "java", llm_response
        )

        # Name the file and call arguments in every failure message.
        context = f"file={file_path!s} kwargs={solution_kwargs!r}"
        self.assertEqual(response.updated_file, "Mocked updated file", context)
        self.assertEqual(len(response.total_reasoning), expected_count, context)
        self.assertEqual(len(response.used_prompts), expected_count, context)
        self.assertEqual(
            response.model_id, kai_application.model_provider.model_id, context
        )
        self.assertEqual(
            len(response.additional_information), expected_count, context
        )
        self.assertEqual(len(response.llm_results), expected_count, context)

    def test_run_demo(self):
        """Process every impacted coolstore file twice and check the results.

        The first pass passes ``batch_mode="none"`` and expects one entry per
        incident in each result list; the second pass leaves ``batch_mode``
        unset and expects exactly one entry in each list.
        """
        custom_llm_response = "Mocked LLM response"

        config = KaiConfig(
            log_level="info",
            demo_mode=False,
            trace_enabled=False,
            incident_store=KaiConfigIncidentStore(
                solution_detectors="naive",
                solution_producers="text_only",
                args=KaiConfigIncidentStoreSQLiteArgs(
                    provider="sqlite", connection_string="sqlite:///:memory:"
                ),
            ),
            models=KaiConfigModels(
                provider="FakeListChatModel",
                args={
                    "responses": [custom_llm_response],
                    "sleep": None,
                },
            ),
            solution_consumers=["diff_only"],
        )

        kai_application = KaiApplication(config)
        kai_application.incident_store.create_tables()

        path_example = Path(PATH_GIT_ROOT) / "example"
        report = Report.load_report_from_file(
            path_example / "analysis/coolstore/output.yaml"
        )
        impacted_files = report.get_impacted_files()

        # (extra keyword arguments, function mapping incidents -> expected length)
        passes = (
            ({"batch_mode": "none"}, len),
            ({}, lambda _incidents: 1),
        )

        # All files are processed for the first pass before the second starts.
        for solution_kwargs, expected_count_for in passes:
            for file_path, incidents in impacted_files.items():
                self._check_file_solution(
                    kai_application,
                    path_example / "coolstore",
                    file_path,
                    incidents,
                    custom_llm_response,
                    expected_count_for(incidents),
                    **solution_kwargs,
                )
