import queue
import threading
from pathlib import Path
from unittest.mock import Mock

import pytest

import api.config as config
import api.models as models
import api.streaming as streaming
from api.models import Session


@pytest.fixture(autouse=True)
def _isolate_sessions(tmp_path, monkeypatch):
    session_dir = tmp_path / "sessions"
    session_dir.mkdir()
    index_file = session_dir / "_index.json"
    monkeypatch.setattr(models, "SESSION_DIR", session_dir)
    monkeypatch.setattr(models, "SESSION_INDEX_FILE", index_file)
    monkeypatch.setattr(streaming, "SESSION_DIR", session_dir)
    monkeypatch.setattr(config, "SESSION_INDEX_FILE", index_file, raising=False)
    models.SESSIONS.clear()
    config.STREAMS.clear()
    config.CANCEL_FLAGS.clear()
    config.AGENT_INSTANCES.clear()
    config.SESSION_AGENT_LOCKS.clear()
    yield
    models.SESSIONS.clear()
    config.STREAMS.clear()
    config.CANCEL_FLAGS.clear()
    config.AGENT_INSTANCES.clear()
    config.SESSION_AGENT_LOCKS.clear()


def test_stream_writeback_requires_active_stream_ownership():
    s = Session(session_id="ownership", messages=[])
    s.active_stream_id = "current-stream"

    assert streaming._stream_writeback_is_current(s, "current-stream") is True

    s.active_stream_id = None
    assert streaming._stream_writeback_is_current(s, "current-stream") is False

    s.active_stream_id = "newer-stream"
    assert streaming._stream_writeback_is_current(s, "current-stream") is False


def test_cancel_stream_does_not_append_marker_after_stream_ownership_rotated():
    sid = "rotated_cancel_sid"
    old_stream = "old-stream"
    s = Session(
        session_id=sid,
        title="Rotated stream",
        messages=[{"role": "user", "content": "newer prompt"}],
    )
    s.active_stream_id = "newer-stream"
    s.pending_user_message = "newer prompt"
    s.pending_started_at = 456.0
    s.save()
    models.SESSIONS[sid] = s

    config.STREAMS[old_stream] = queue.Queue()
    config.CANCEL_FLAGS[old_stream] = threading.Event()
    mock_agent = Mock()
    mock_agent.session_id = sid
    mock_agent.interrupt = Mock()
    config.AGENT_INSTANCES[old_stream] = mock_agent

    assert streaming.cancel_stream(old_stream) is True

    assert s.active_stream_id == "newer-stream"
    assert s.pending_user_message == "newer prompt"
    assert [m["content"] for m in s.messages] == ["newer prompt"]
    assert all(m.get("content") != "*Task cancelled.*" for m in s.messages)


def test_success_path_checks_stream_ownership_before_persisting_result():
    src = Path("api/streaming.py").read_text(encoding="utf-8")
    guard = "if not ephemeral and not _stream_writeback_is_current(s, stream_id):"
    guard_pos = src.find(guard)
    result_merge_pos = src.find("_result_messages = result.get('messages') or _previous_context_messages")
    compression_pos = src.find("Handle context compression side effects")

    assert guard_pos != -1
    assert result_merge_pos != -1
    assert compression_pos != -1
    assert guard_pos < result_merge_pos
    assert guard_pos < compression_pos


def test_self_heal_retry_success_checks_stream_ownership_before_writeback():
    src = Path("api/streaming.py").read_text(encoding="utf-8")
    start = src.index("logger.info('[webui] self-heal (except path): retrying stream")
    end = src.index("logger.info('[webui] self-heal (except path): retry succeeded')", start)
    block = src[start:end]
    guard = "if not ephemeral and not _stream_writeback_is_current(s, stream_id):"

    assert guard in block
    assert block.index(guard) < block.index("_result_messages = _heal_result.get('messages') or _previous_context_messages")
    assert block.index(guard) < block.index("s.save()")


def test_outer_exception_path_checks_stream_ownership_before_error_writeback():
    src = Path("api/streaming.py").read_text(encoding="utf-8")
    outer_error_payload = src.index("_error_payload = _provider_error_payload(err_str, _exc_type, _exc_hint)")
    start = src.index("# Persist the error so it survives page reload.", outer_error_payload)
    end = src.index("put('apperror', _error_payload)", start)
    block = src[start:end]
    guard = "if not ephemeral and not _stream_writeback_is_current(s, stream_id):"

    assert guard in block
    assert block.index(guard) < block.index("_materialize_pending_user_turn_before_error(s)")
    assert block.index(guard) < block.index("s.active_stream_id = None")
    assert block.index(guard) < block.index("s.messages.append(_error_message)")
