from typing import Optional, List
from datetime import datetime, timedelta
from sqlalchemy import select, delete
from sqlalchemy.exc import IntegrityError
from src.models.computadores import Computer, ComputerOnline
from sqlalchemy.ext.asyncio import AsyncSession
from src.schemas.computadores import ComputerCreate, ComputerUpdate
from src.services.usuarios import UsuarioService
from src.crud.computadores import crud_online


class ComputadorService:
    """Service layer for ``Computer`` records and their online heartbeats.

    Every method receives the ``AsyncSession`` to operate on. The methods
    that write ``Computer`` rows (``criar``, ``atualizar``, ``deletar``)
    commit the session themselves.
    """

    def __init__(self):
        # Resolves ``nome_usuario`` values into user rows (get or create).
        self.usuario_service = UsuarioService()

    async def _buscar_por_id(
        self,
        db: AsyncSession,
        computador_id: str,
    ) -> Optional[Computer]:
        """Return the ``Computer`` whose id equals ``str(computador_id)``, or ``None``."""
        result = await db.execute(
            select(Computer).where(Computer.id == str(computador_id))
        )
        return result.scalar_one_or_none()

    @staticmethod
    async def _commit(db: AsyncSession) -> None:
        """Commit ``db``; on ``IntegrityError`` roll back first, then re-raise.

        Rolling back after a failed flush/commit returns the session to a
        usable state for the caller.
        """
        try:
            await db.commit()
        except IntegrityError:
            await db.rollback()
            raise

    async def criar(
        self,
        db: AsyncSession,
        payload: ComputerCreate,
    ) -> Computer:
        """Create, commit and return a new ``Computer`` built from ``payload``.

        If ``payload.nome_usuario`` is truthy, the user is fetched or created
        through ``UsuarioService`` and linked via ``usuario_id``; otherwise
        ``usuario_id`` is set to ``None``.

        Raises:
            IntegrityError: if the commit violates a database constraint
                (the session is rolled back before re-raising).
        """
        usuario_id = None
        if payload.nome_usuario:
            usuario = await self.usuario_service.get_or_create_usuario(
                db, payload.nome_usuario
            )
            usuario_id = usuario.id

        data = payload.model_dump()
        data['usuario_id'] = usuario_id
        # NOTE(review): ``data`` still carries ``nome_usuario``; this relies on
        # ``Computer`` accepting that keyword — confirm against the model.
        obj = Computer(**data)
        db.add(obj)
        await self._commit(db)
        await db.refresh(obj)
        return obj

    async def listar(
        self,
        db: AsyncSession,
        *,
        empresa_id: Optional[str] = None,
    ) -> List[Computer]:
        """Return all computers, filtered by ``empresa_id`` when it is truthy."""
        stmt = select(Computer)
        if empresa_id:
            stmt = stmt.where(Computer.empresa_id == str(empresa_id))
        result = await db.execute(stmt)
        return list(result.scalars().all())

    async def atualizar(
        self,
        db: AsyncSession,
        computador_id: str,
        payload: ComputerUpdate,
    ) -> Optional[Computer]:
        """Apply the fields explicitly set in ``payload`` to a ``Computer``.

        Returns the refreshed object, or ``None`` when no computer has
        ``computador_id``. A truthy ``nome_usuario`` is resolved to a user
        (fetched or created) and stored as ``usuario_id``.

        Raises:
            IntegrityError: if the commit violates a database constraint
                (the session is rolled back before re-raising).
        """
        obj = await self._buscar_por_id(db, computador_id)
        if obj is None:
            return None

        # ``exclude_unset`` makes this a partial update: only fields the
        # caller actually provided are written.
        data = payload.model_dump(exclude_unset=True)
        nome_usuario = data.get('nome_usuario')
        if nome_usuario:
            usuario = await self.usuario_service.get_or_create_usuario(
                db, nome_usuario
            )
            data['usuario_id'] = usuario.id
        # NOTE(review): an explicitly empty/None ``nome_usuario`` leaves
        # ``usuario_id`` untouched (``criar`` would set it to None) — confirm
        # this is intended.

        for field, value in data.items():
            setattr(obj, field, value)
        await self._commit(db)
        await db.refresh(obj)
        return obj

    async def deletar(self, db: AsyncSession, computador_id: str) -> bool:
        """Delete the ``Computer`` with ``computador_id`` and commit.

        Returns ``True`` if the computer existed and was deleted, ``False``
        if no such computer exists (nothing is executed in that case).

        Raises:
            IntegrityError: if the delete violates a database constraint
                (depending on the schema, e.g. a foreign-key reference); the
                session is rolled back before re-raising.
        """
        result = await db.execute(
            select(Computer.id).where(Computer.id == str(computador_id))
        )
        if result.scalar_one_or_none() is None:
            return False
        await db.execute(
            delete(Computer).where(Computer.id == str(computador_id))
        )
        await self._commit(db)
        return True

    async def registrar_heartbeat(
        self,
        db: AsyncSession,
        computador_id: str,
        nome_usuario: str,
    ) -> Optional[ComputerOnline]:
        """Record a heartbeat for an existing computer.

        Returns ``None`` (without calling ``crud_online``) when the computer
        does not exist; otherwise returns whatever
        ``crud_online.create_or_update`` returns.
        """
        # NOTE(review): whether ``crud_online.create_or_update`` commits is
        # not visible here — confirm.
        if await self._buscar_por_id(db, computador_id) is None:
            return None
        return await crud_online.create_or_update(
            db, computador_id, nome_usuario
        )

    async def verificar_heartbeat(
        self,
        db: AsyncSession,
        computador_id: str,
    ) -> Optional[Computer]:
        """Return the ``Computer`` with ``computador_id``, or ``None``.

        Despite its name, this only looks up the computer; it does not read
        ``ComputerOnline`` or check how recent the last heartbeat was.
        """
        return await self._buscar_por_id(db, computador_id)

    async def listar_computadores_online(
        self,
        db: AsyncSession,
        threshold_minutes: int = 5,
    ) -> List[Computer]:
        """Return computers whose ``ComputerOnline.ultimo_visto`` is at or after
        ``threshold_minutes`` minutes ago.
        """
        # NOTE(review): ``datetime.now()`` is naive local time; it must match
        # the clock/timezone used when ``ultimo_visto`` is written — confirm.
        cutoff_time = datetime.now() - timedelta(minutes=threshold_minutes)
        # The join condition is inferred by SQLAlchemy from the relationship
        # between Computer and ComputerOnline.
        stmt = (
            select(Computer)
            .join(ComputerOnline)
            .where(ComputerOnline.ultimo_visto >= cutoff_time)
        )
        result = await db.execute(stmt)
        return list(result.scalars().all())
