100 lines
3.4 KiB
Python
100 lines
3.4 KiB
Python
from fastapi import HTTPException, status, APIRouter, WebSocket, WebSocketDisconnect, Query, Depends
|
|
from app.core.security import test_token
|
|
from typing import Dict
|
|
from datetime import datetime
|
|
import json
|
|
from sqlalchemy.orm import Session
|
|
from app.db import models
|
|
|
|
|
|
# Database session dependency
|
|
def get_db():
    """Provide a SQLAlchemy session for the lifetime of one consumer.

    Generator dependency meant for FastAPI's ``Depends``: a session from
    ``models.SessionLocal`` is handed out via ``yield`` and closed in the
    ``finally`` clause once the consumer is done, including when the
    consumer raises.
    """
    session = models.SessionLocal()
    try:
        yield session
    finally:
        # Always release the session, however the consumer finished.
        session.close()
|
|
|
|
|
|
# Router for the WebSocket endpoints; every route registered on it is served under "/ws".
wsRouter = APIRouter(prefix='/ws')
|
|
|
|
|
|
@wsRouter.websocket("")
async def websocket_endpoint(websocket: WebSocket, token: str = Query(None), db: Session = Depends(get_db)):
    """Authenticate a WebSocket client and relay its private messages.

    Authentication:
    - The token comes from the ``token`` query parameter and is checked with
      ``test_token``.
    - A missing or rejected token closes the socket with code 1008 (policy
      violation) before the socket is accepted.

    Message handling:
    - Once connected, every text frame is parsed as JSON.
    - Frames shaped like
      ``{"type": "private_message", "receiver_id": ..., "content": ...}``
      are stored as ``models.Message`` rows.
    - Each stored message is forwarded to the receiver if the receiver has
      an open socket registered in ``manager``.
    - Malformed or incomplete frames are skipped; they do not end the
      connection.

    Args:
        websocket: The client connection.
        token: Auth token taken from the query string.
        db: Session from the ``get_db`` generator dependency, used for every
            message stored during this connection.
    """
    if token is None:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    try:
        # NOTE(review): the type of the returned id (str vs int) is not
        # visible here; routing below normalises the receiver id with str().
        user_id = await test_token(token=token)
    except HTTPException:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    print("ПОДКЛЮЧЕНИЕ")  # "connecting"
    await manager.connect(user_id, websocket)
    print("ПОДКЛЮЧЕНО")  # "connected"

    try:
        while True:
            print("ОЖИДАНИЕ СООБЩЕНИЙ")  # "waiting for messages"
            data = await websocket.receive_text()

            try:
                message_data = json.loads(data)
            except json.JSONDecodeError:
                # One malformed frame must not tear down the whole connection.
                print("DEBUG: ignoring non-JSON frame")
                continue
            if not isinstance(message_data, dict):
                # Valid JSON that is not an object (e.g. a list) has no .get().
                print("DEBUG: ignoring non-object JSON frame")
                continue

            # NOTE(review): this logs private message content verbatim.
            print(f"DEBUG: Получены данные: {message_data}")  # "received data"

            if message_data.get("type") != "private_message":
                continue

            receiver_id = message_data.get("receiver_id")
            content = message_data.get("content")
            if receiver_id is None or content is None:
                # Without this check the row would be stored with NULL fields.
                print("DEBUG: private_message without receiver_id/content ignored")
                continue

            # NOTE(review): these synchronous SQLAlchemy calls block the event
            # loop while they run; consider a threadpool or an async session.
            new_msg = models.Message(
                sender_id=user_id,
                receiver_id=receiver_id,
                content=content,
            )
            db.add(new_msg)
            db.commit()
            db.refresh(new_msg)  # load DB-generated fields such as new_msg.id

            # Build the packet for the receiver.
            outgoing_message = {
                "id": new_msg.id,
                "type": "private_message",
                "sender_id": user_id,
                "receiver_id": receiver_id,
                # Legacy misspelled key, kept so clients that already read it keep working.
                "reciever_id": receiver_id,
                "content": content,
                # Naive server-local time, not necessarily what the DB row stores.
                "timestamp": datetime.now().isoformat(),
            }

            # Forward to the receiver if they are currently connected.
            await manager.send_personal_message(outgoing_message, str(receiver_id))
    except WebSocketDisconnect:
        # Normal client disconnect; cleanup happens in ``finally``.
        pass
    finally:
        # NOTE(review): the manager keeps one socket per user id; if this user
        # opened a second connection meanwhile, this call also drops that
        # newer connection's registry entry.
        manager.disconnect(user_id)
|
|
|
|
|
|
class ConnectionManager:
    """In-memory registry of open WebSocket connections, keyed by user id.

    Keying:
    - All keys are normalised with ``str()`` so that ``connect``,
      ``disconnect`` and ``send_personal_message`` agree.
    - This holds even when callers pass an ``int`` id.

    Ownership:
    - One socket is kept per user; a new ``connect`` for the same user
      replaces the previous entry.

    NOTE(review): the registry lives in this process only; with several
    worker processes, users connected to another worker are not visible here.
    """

    def __init__(self):
        # Active connections: {str(user_id): websocket}
        self.active_connections: Dict[str, "WebSocket"] = {}

    async def connect(self, user_id: str, websocket: "WebSocket"):
        """Accept the WebSocket handshake and register it for ``user_id``."""
        await websocket.accept()
        self.active_connections[str(user_id)] = websocket

    def disconnect(self, user_id: str, websocket: "WebSocket | None" = None):
        """Remove the entry for ``user_id``; a no-op if there is none.

        Args:
            user_id: User whose entry should be removed (normalised with str()).
            websocket: Optional socket that is closing. If given, the entry is
                removed only while it still refers to this exact socket, so a
                stale connection cannot evict a newer one for the same user.
        """
        key = str(user_id)
        if websocket is not None and self.active_connections.get(key) is not websocket:
            return
        self.active_connections.pop(key, None)

    async def send_personal_message(self, message: dict, user_id: str) -> bool:
        """Send ``message`` as JSON to ``user_id`` if that user is connected.

        A failing send (e.g. the recipient's socket already closed) is logged.
        The dead entry is then dropped, and the error is not propagated, so it
        cannot abort the caller's own receive loop.

        Returns:
            True if the message was handed to the socket, False otherwise.
        """
        key = str(user_id)
        connection = self.active_connections.get(key)
        if connection is None:
            print('User not active')
            return False
        try:
            await connection.send_json(message)
        except Exception as exc:  # best-effort delivery to another user's socket
            print(f'Failed to send to {key}: {exc!r}')
            self.disconnect(key, connection)
            return False
        print('Sent to socket')
        return True

    async def broadcast(self, message: dict):
        """Send ``message`` to every connected user (e.g. system notices).

        Delivery is best-effort per socket: one failing socket is dropped
        and does not stop delivery to the rest.
        """
        # Iterate a snapshot: connect/disconnect may mutate the dict while we await.
        for user_key, connection in list(self.active_connections.items()):
            try:
                await connection.send_json(message)
            except Exception as exc:  # one dead socket must not abort the broadcast
                print(f'Failed to send to {user_key}: {exc!r}')
                self.disconnect(user_key, connection)


# Process-wide registry shared by the WebSocket endpoint.
manager = ConnectionManager()
|