diff --git a/android/app/src/main/AndroidManifest.xml b/android/app/src/main/AndroidManifest.xml index a83efcc..d2d92cf 100644 --- a/android/app/src/main/AndroidManifest.xml +++ b/android/app/src/main/AndroidManifest.xml @@ -1,6 +1,6 @@ _instance; + LocalDbService._internal(); + + Future get database async { + if (_database != null) return _database!; + _database = await _initDb(); + return _database!; + } + + Future _initDb() async { + String path = join(await getDatabasesPath(), 'chat_app.db'); + return await openDatabase( + path, + version: 1, + onCreate: (db, version) async { + await db.execute(''' + CREATE TABLE messages( + id INTEGER PRIMARY KEY, + sender_id INTEGER, + receiver_id INTEGER, + content TEXT, + timestamp TEXT + ) + '''); + }, + ); + } + + // Сохранение списка сообщений (из истории) + Future saveMessages(List messages) async { + final db = await database; + Batch batch = db.batch(); + for (var msg in messages) { + if (msg is MessageModel) { + batch.insert('messages', { + 'id': msg.id, + 'sender_id': msg.senderId, + 'receiver_id': msg.receiverId, + 'content': msg.text, // ВАЖНО: сохраняй зашифрованный текст! + 'timestamp': msg.createdAt.toIso8601String(), + }, conflictAlgorithm: ConflictAlgorithm.replace); + } else { + // Если это Map из API + batch.insert('messages', { + 'id': msg['id'], + 'sender_id': msg['sender_id'], + 'receiver_id': msg['receiver_id'], // Убедись, что ключ совпадает с API + 'content': msg['content'], + 'timestamp': msg['timestamp'], + }, conflictAlgorithm: ConflictAlgorithm.replace); + } + } + await batch.commit(noResult: true); + } + + // Получение сообщений конкретного чата + Future>> getChatHistory( + int contactId, + int myId, + ) async { + final db = await database; + return await db.query( + 'messages', + where: + '(sender_id = ? AND receiver_id = ?) OR (sender_id = ? AND receiver_id = ?)', + whereArgs: [contactId, myId, myId, contactId], + orderBy: 'timestamp ASC', + ); + } +} diff --git a/lib/data/datasources/ws_client.dart b/lib/data/datasources/ws_client.dart index 48409b8..1cf79d8 100644 --- a/lib/data/datasources/ws_client.dart +++ b/lib/data/datasources/ws_client.dart @@ -6,8 +6,15 @@ import 'package:web_socket_channel/status.dart' as status; import 'package:chepuhagram/core/constants.dart'; class SocketService { + static final SocketService _instance = SocketService._internal(); + + factory SocketService() { + return _instance; + } + SocketService._internal(); + WebSocketChannel? _channel; - final StreamController> _messageController = + final StreamController> _messageController = StreamController>.broadcast(); // Поток, который будут слушать провайдеры @@ -19,7 +26,7 @@ class SocketService { // В FastAPI эндпоинт обычно ожидает токен в URL или подзаголовке final uri = Uri.parse("ws://${AppConstants.baseUrl}/ws?token=$token"); - + _channel = WebSocketChannel.connect(uri); _channel!.stream.listen( @@ -38,11 +45,25 @@ class SocketService { } void sendMessage(Map data) { - _channel?.sink.add(jsonEncode(data)); + if (_channel == null) { + print("❌ ОШИБКА: Попытка отправить сообщение через NULL канал."); + return; + } + try { + final encodedData = jsonEncode(data); + + // 1. Проверяем, не закрыт ли sink (у некоторых провайдеров это доступно) + _channel!.sink.add(encodedData); + + // 2. Добавляем принт подтверждения + print("🚀 СООБЩЕНИЕ ОТПРАВЛЕНО В SINK: $encodedData"); + } catch (e) { + print("❌ КРИТИЧЕСКАЯ ОШИБКА ПРИ ОТПРАВКЕ: $e"); + } } void disconnect() { _channel?.sink.close(status.goingAway); _channel = null; } -} \ No newline at end of file +} diff --git a/lib/data/models/contact_model.dart b/lib/data/models/contact_model.dart index 5cf8cd8..ca38142 100644 --- a/lib/data/models/contact_model.dart +++ b/lib/data/models/contact_model.dart @@ -8,6 +8,7 @@ class Contact { final DateTime? lastMessageTime; final bool isOnline; final int unreadCount; + final String? publicKey; Contact({ required this.id, @@ -19,6 +20,7 @@ class Contact { this.lastMessageTime, this.isOnline = false, this.unreadCount = 0, + this.publicKey, }); factory Contact.fromJson(Map json) { @@ -27,7 +29,7 @@ class Contact { username: json['username'] ?? 'Unknown', name: json['name'] ?? 'Unknown', surname: json['surname'] ?? 'Unknown', - // Другие поля можно добавить позже + publicKey: json['public_key'], ); } } \ No newline at end of file diff --git a/lib/data/models/message_model.dart b/lib/data/models/message_model.dart index 1d9c3ab..88dd875 100644 --- a/lib/data/models/message_model.dart +++ b/lib/data/models/message_model.dart @@ -1,6 +1,7 @@ class MessageModel { final int? id; // ID из базы данных (null, если сообщение еще не отправлено) final int senderId; // ID отправителя + final int receiverId; // ID отправителя final String text; // Текст сообщения final DateTime createdAt; // Время отправки final bool isMe; // Локальный флаг для UI (мое/чужое) @@ -8,6 +9,7 @@ class MessageModel { MessageModel({ this.id, required this.senderId, + required this.receiverId, required this.text, required this.createdAt, this.isMe = false, @@ -18,6 +20,7 @@ class MessageModel { return MessageModel( id: json['id'], senderId: json['sender_id'], + receiverId: json['receiverId'], text: json['text'] ?? '', // Парсим дату из ISO строки или временной метки createdAt: DateTime.parse(json['created_at']), diff --git a/lib/data/repositories/contact_repository.dart b/lib/data/repositories/contact_repository.dart index 10d8837..1e8dfe0 100644 --- a/lib/data/repositories/contact_repository.dart +++ b/lib/data/repositories/contact_repository.dart @@ -29,4 +29,22 @@ class ContactRepository { throw Exception('Failed to load contacts'); } } -} \ No newline at end of file + + Future fetchContactById(int userId) async { + final token = await _apiService.getAccessToken(); + final response = await _client.get( + Uri.http(AppConstants.baseUrl, 'users/$userId'), + headers: { + 'Authorization': 'Bearer $token', + 'Content-Type': 'application/json', + }, + ); + + if (response.statusCode == 200) { + final data = jsonDecode(utf8.decode(response.bodyBytes)); + return Contact.fromJson(data); + } else { + throw Exception('Не удалось загрузить данные контакта'); + } + } +} diff --git a/lib/domain/services/api_service.dart b/lib/domain/services/api_service.dart index ada0dc8..660554c 100644 --- a/lib/domain/services/api_service.dart +++ b/lib/domain/services/api_service.dart @@ -90,4 +90,19 @@ class ApiService extends ChangeNotifier { notifyListeners(); } } + + Future> getChatHistory(int contactId) async { + final token = await getAccessToken(); + final response = await http.get( + Uri.http( + AppConstants.baseUrl, + 'messages/history/${contactId.toString()}', + ), + headers: { + 'Content-Type': 'application/json', + "Authorization": "Bearer $token", + }, + ); + return jsonDecode(response.body) as List; + } } diff --git a/lib/logic/contact_provider.dart b/lib/logic/contact_provider.dart index 23ad63e..7d9877c 100644 --- a/lib/logic/contact_provider.dart +++ b/lib/logic/contact_provider.dart @@ -20,6 +20,10 @@ class ContactProvider extends ChangeNotifier { notifyListeners(); } + int? getCurrentUserId() { + return _currentUserId; + } + Future loadContacts() async { _isLoading = true; _error = null; diff --git a/lib/main.dart b/lib/main.dart index 4494988..4b6ef79 100644 --- a/lib/main.dart +++ b/lib/main.dart @@ -1,7 +1,7 @@ import 'package:chepuhagram/presentation/screens/splash_screen.dart'; import 'package:flutter/material.dart'; import 'package:provider/provider.dart'; - +import 'data/datasources/ws_client.dart'; import 'logic/auth_provider.dart'; import 'logic/contact_provider.dart'; import 'core/theme_manager.dart'; @@ -13,7 +13,10 @@ void main() { MultiProvider( providers: [ ChangeNotifierProvider(create: (_) => AuthProvider()), - ChangeNotifierProvider(create: (_) => ThemeProvider()), ChangeNotifierProvider(create: (_) => ContactProvider()), ], + ChangeNotifierProvider(create: (_) => ThemeProvider()), + ChangeNotifierProvider(create: (_) => ContactProvider()), + Provider(create: (_) => SocketService()), + ], child: const MyApp(), ), ); diff --git a/lib/presentation/screens/chat_screen.dart b/lib/presentation/screens/chat_screen.dart index bc69a75..fa99e89 100644 --- a/lib/presentation/screens/chat_screen.dart +++ b/lib/presentation/screens/chat_screen.dart @@ -2,6 +2,14 @@ import 'package:flutter/material.dart'; import '/data/models/message_model.dart'; import '/data/models/contact_model.dart'; import 'package:chepuhagram/presentation/widgets/message_bubble.dart'; +import 'package:chepuhagram/data/repositories/contact_repository.dart'; +import 'package:chepuhagram/domain/services/crypto_service.dart'; +import 'package:chepuhagram/data/datasources/ws_client.dart'; +import 'dart:convert'; +import 'package:provider/provider.dart'; +import '/logic/contact_provider.dart'; +import '../../domain/services/api_service.dart'; +import 'package:chepuhagram/data/datasources/local_db_service.dart'; class ChatScreen extends StatefulWidget { final Contact contact; @@ -13,8 +21,48 @@ class ChatScreen extends StatefulWidget { } class _ChatScreenState extends State { + int myId = 0; + late Contact _currentContact; + bool _isKeyLoading = false; final TextEditingController _controller = TextEditingController(); - final List messages = []; + final ContactRepository _contactRepository = ContactRepository(); + final apiService = ApiService(); + final CryptoService _cryptoService = CryptoService(); + List messages = []; + + @override + void initState() { + super.initState(); + _currentContact = widget.contact; + final contactProvider = context.read(); + myId = contactProvider.getCurrentUserId() ?? 0; + _loadHistory(); + // Если ключа нет, загружаем его при входе + if (_currentContact.publicKey == null) { + _loadContactKey(); + } + } + + Future _loadContactKey() async { + setState(() => _isKeyLoading = true); + try { + final updatedContact = await _contactRepository.fetchContactById( + _currentContact.id, + ); + setState(() { + _currentContact = updatedContact; + _isKeyLoading = false; + }); + print(updatedContact.publicKey); + } catch (e) { + setState(() => _isKeyLoading = false); + ScaffoldMessenger.of(context).showSnackBar( + const SnackBar( + content: Text("Не удалось получить ключ шифрования собеседника"), + ), + ); + } + } @override void dispose() { @@ -25,7 +73,7 @@ class _ChatScreenState extends State { @override Widget build(BuildContext context) { return Scaffold( - appBar: AppBar(title: Text(widget.contact.name)), + appBar: AppBar(title: Text(_currentContact.name)), body: Column( children: [ Expanded( @@ -49,25 +97,210 @@ class _ChatScreenState extends State { } Widget _buildMessageInput() { - return Padding( - padding: const EdgeInsets.all(8.0), - child: Row( - children: [ - Expanded( - child: TextField( - controller: _controller, - decoration: const InputDecoration(hintText: "Напиши сообщение..."), + return SafeArea( + // Добавляем SafeArea здесь + child: Padding( + padding: const EdgeInsets.all(8.0), + child: Row( + children: [ + Expanded( + child: TextField( + controller: _controller, + decoration: const InputDecoration( + hintText: "Напиши сообщение...", + ), + ), ), - ), - IconButton( - icon: const Icon(Icons.send), - onPressed: () { - // Логика отправки через WebSocket или API - _controller.clear(); - }, - ), - ], + IconButton( + icon: const Icon(Icons.send), + onPressed: () { + _sendMessage(); + }, + ), + ], + ), ), ); } -} \ No newline at end of file + + Future _sendMessage() async { + final rawText = _controller.text.trim(); + if (rawText.isEmpty) return; + _controller.clear(); + + if (_currentContact.publicKey == null) { + await _loadContactKey(); + if (_currentContact.publicKey == null) return; + } + + try { + final myPrivKey = await _cryptoService.getPrivateKey(); + + final sharedSecret = await _cryptoService.deriveSharedSecret( + myPrivKey!, + _currentContact.publicKey!, + ); + + final encryptedText = await _cryptoService.encryptMessage( + rawText, + sharedSecret, + ); + + // Формируем payload для сервера + final payload = { + "type": "private_message", + "receiver_id": _currentContact.id, + "content": encryptedText, + }; + + // Отправляем + print("ОТПРАВКА: $payload"); + Provider.of(context, listen: false).sendMessage(payload); + + // Обновляем UI (себе показываем расшифрованный текст) + + setState(() { + messages.add( + MessageModel( + text: rawText, + isMe: true, + senderId: myId, + receiverId: _currentContact.id, + createdAt: DateTime.now(), + ), + ); + }); + + _controller.clear(); + } catch (e) { + _controller.text = rawText; + ScaffoldMessenger.of( + context, + ).showSnackBar(SnackBar(content: Text("Ошибка шифрования: $e"))); + } + } + + @override + void didChangeDependencies() { + super.didChangeDependencies(); + // Подписываемся на поток сообщений из сокета + final socketService = Provider.of(context, listen: false); + + socketService.messages.listen((rawData) { + _handleIncomingMessage(rawData); + }); + } + + void _handleIncomingMessage(Map data) async { + if (data['type'] == 'private_message') { + final int senderId = int.parse(data['sender_id'].toString()); + + // 1. Проверяем, что сообщение именно от того, с кем мы сейчас общаемся + if (senderId == widget.contact.id) { + try { + final myPrivKey = await _cryptoService.getPrivateKey(); + + // 2. Вычисляем общий секрет для расшифровки + final sharedSecret = await _cryptoService.deriveSharedSecret( + myPrivKey!, + widget.contact.publicKey!, + ); + + // 3. Расшифровываем контент + final decryptedText = await _cryptoService.decryptMessage( + data['content'], + sharedSecret, + ); + + // 4. Добавляем в список и обновляем экран + await LocalDbService().saveMessages([data]); + setState(() { + messages.add( + MessageModel( + text: decryptedText, + isMe: false, + senderId: senderId, + receiverId: myId, + createdAt: DateTime.parse(data['timestamp']), + ), + ); + }); + } catch (e) { + print("Ошибка расшифровки входящего сообщения: $e"); + } + } else { + print( + "Сообщение от другого пользователя (ID: $senderId), игнорируем в этом чате", + ); + // Тут можно добавить логику уведомления для списка чатов + } + } + } + + Future _loadHistory() async { + try { + final myPrivKey = await _cryptoService.getPrivateKey(); + final sharedSecret = await _cryptoService.deriveSharedSecret( + myPrivKey!, + widget.contact.publicKey!, + ); + final localDb = LocalDbService(); + final cached = await localDb.getChatHistory(widget.contact.id, myId); + + try { + List loadedLocalMessages = []; + for (var msg in cached) { + final decrypted = await _cryptoService.decryptMessage( + msg['content'], + sharedSecret, + ); + loadedLocalMessages.add( + MessageModel( + text: decrypted, + isMe: msg['sender_id'] == myId, + senderId: msg['sender_id'], + receiverId: msg['receiver_id'], + createdAt: DateTime.parse(msg['timestamp']), + ), + ); + } + if (cached.isNotEmpty) { + setState(() { + messages = loadedLocalMessages; + _isKeyLoading = false; + }); + } + } catch (e) { + print(e); + } + + final history = await apiService.getChatHistory(widget.contact.id); + + List loadedMessages = []; + for (var msg in history) { + final decrypted = await _cryptoService.decryptMessage( + msg['content'], + sharedSecret, + ); + loadedMessages.add( + MessageModel( + text: decrypted, + isMe: msg['sender_id'] == myId, + senderId: msg['sender_id'], + receiverId: msg['receiver_id'], + createdAt: DateTime.parse(msg['timestamp']), + ), + ); + } + await localDb.saveMessages(history); + + setState(() { + messages = loadedMessages; + _isKeyLoading = false; + }); + } catch (e) { + print("Ошибка загрузки истории: $e"); + setState(() => _isKeyLoading = false); + } + } +} diff --git a/macos/Flutter/GeneratedPluginRegistrant.swift b/macos/Flutter/GeneratedPluginRegistrant.swift index 61d01d0..d378756 100644 --- a/macos/Flutter/GeneratedPluginRegistrant.swift +++ b/macos/Flutter/GeneratedPluginRegistrant.swift @@ -7,8 +7,10 @@ import Foundation import flutter_secure_storage_darwin import path_provider_foundation +import sqflite_darwin func RegisterGeneratedPlugins(registry: FlutterPluginRegistry) { FlutterSecureStorageDarwinPlugin.register(with: registry.registrar(forPlugin: "FlutterSecureStorageDarwinPlugin")) PathProviderPlugin.register(with: registry.registrar(forPlugin: "PathProviderPlugin")) + SqflitePlugin.register(with: registry.registrar(forPlugin: "SqflitePlugin")) } diff --git a/pubspec.lock b/pubspec.lock index d9bdf0b..0efd1b5 100644 --- a/pubspec.lock +++ b/pubspec.lock @@ -273,7 +273,7 @@ packages: source: hosted version: "2.2.0" path: - dependency: transitive + dependency: "direct main" description: name: path sha256: "75cca69d1490965be98c73ceaea117e8a04dd21217b37b292c9ddbec0d955bc5" @@ -365,6 +365,46 @@ packages: url: "https://pub.dev" source: hosted version: "1.10.2" + sqflite: + dependency: "direct main" + description: + name: sqflite + sha256: "564cfed0746fe53140c23b70b308e045c3b31f17778f2f326ccb7d804ea0250a" + url: "https://pub.dev" + source: hosted + version: "2.4.2+1" + sqflite_android: + dependency: transitive + description: + name: sqflite_android + sha256: "881e28efdcc9950fd8e9bb42713dcf1103e62a2e7168f23c9338d82db13dec40" + url: "https://pub.dev" + source: hosted + version: "2.4.2+3" + sqflite_common: + dependency: transitive + description: + name: sqflite_common + sha256: "5e8377564d95166761a968ed96104e0569b6b6cc611faac92a36ab8a169112c3" + url: "https://pub.dev" + source: hosted + version: "2.5.6+1" + sqflite_darwin: + dependency: transitive + description: + name: sqflite_darwin + sha256: "279832e5cde3fe99e8571879498c9211f3ca6391b0d818df4e17d9fff5c6ccb3" + url: "https://pub.dev" + source: hosted + version: "2.4.2" + sqflite_platform_interface: + dependency: transitive + description: + name: sqflite_platform_interface + sha256: "8dd4515c7bdcae0a785b0062859336de775e8c65db81ae33dd5445f35be61920" + url: "https://pub.dev" + source: hosted + version: "2.4.0" stack_trace: dependency: transitive description: @@ -389,6 +429,14 @@ packages: url: "https://pub.dev" source: hosted version: "1.4.1" + synchronized: + dependency: transitive + description: + name: synchronized + sha256: c254ade258ec8282947a0acbbc90b9575b4f19673533ee46f2f6e9b3aeefd7c0 + url: "https://pub.dev" + source: hosted + version: "3.4.0" term_glyph: dependency: transitive description: @@ -471,4 +519,4 @@ packages: version: "1.1.0" sdks: dart: ">=3.10.0 <4.0.0" - flutter: ">=3.35.6" + flutter: ">=3.38.0" diff --git a/pubspec.yaml b/pubspec.yaml index 4b7842d..7390ecd 100644 --- a/pubspec.yaml +++ b/pubspec.yaml @@ -40,6 +40,8 @@ dependencies: jwt_decoder: ^2.0.1 web_socket_channel: ^3.0.3 cryptography: ^2.5.0 + sqflite: ^2.3.0 + path: ^1.9.0 dev_dependencies: flutter_test: diff --git a/srv/app/api/endpoints/messages.py b/srv/app/api/endpoints/messages.py new file mode 100644 index 0000000..3e46fdd --- /dev/null +++ b/srv/app/api/endpoints/messages.py @@ -0,0 +1,35 @@ +from fastapi import Depends, APIRouter, HTTPException, Depends +from sqlalchemy.orm import Session +from app.db import models +from app.core.security import get_current_user +from app.api import schemas + + +# бд +def get_db(): + db = models.SessionLocal() + try: + yield db + finally: + db.close() + + +messagesRouter = APIRouter( + prefix="/messages", + tags=[], +) + +@messagesRouter.get("/history/{contact_id}") +async def get_chat_history( + contact_id: int, + current_user: models.User = Depends(get_current_user), + db: Session = Depends(get_db), + limit: int = 50 +): + messages = db.query(models.Message).filter( + (models.Message.sender_id == current_user.id) & (models.Message.receiver_id == contact_id) | + (models.Message.sender_id == contact_id) & (models.Message.receiver_id == current_user.id) + ).order_by(models.Message.timestamp.asc()).limit(limit).all() + + return messages + diff --git a/srv/app/api/endpoints/users.py b/srv/app/api/endpoints/users.py index 51e87f9..94acfbd 100644 --- a/srv/app/api/endpoints/users.py +++ b/srv/app/api/endpoints/users.py @@ -1,8 +1,9 @@ -from fastapi import Depends, APIRouter +from fastapi import Depends, APIRouter, HTTPException, Depends from sqlalchemy.orm import Session from app.db import models from app.core.security import get_current_user +from app.api import schemas # бд @@ -13,17 +14,38 @@ def get_db(): finally: db.close() + usersRouter = APIRouter( prefix="/users", tags=[], ) # Пример защищенного роута + + @usersRouter.get("/me") async def read_users_me(current_user: models.User = Depends(get_current_user)): return {"id": current_user.id, "username": current_user.username, "first_name": current_user.first_name, "last_name": current_user.last_name, "public_key": current_user.public_key, "encrypted_private_key": current_user.encrypted_private_key} + @usersRouter.get("/all") async def read_users_all(current_user: models.User = Depends(get_current_user), db: Session = Depends(get_db)): users = db.query(models.User).all() - return [{"id": user.id, "username": user.username, "name": f"{user.first_name} {user.last_name or ''}".strip(), "public_key": user.public_key} for user in users] \ No newline at end of file + return [{"id": user.id, "username": user.username, "name": f"{user.first_name} {user.last_name or ''}".strip(), "public_key": user.public_key} for user in users] + + +@usersRouter.get("/{user_id}", response_model=schemas.UserPublic) +def get_user_by_id( + user_id: int, + db: Session = Depends(get_db), + current_user: models.User = Depends(get_current_user) +): + """ + Получить публичную информацию о пользователе, включая его публичный ключ. + """ + user = db.query(models.User).filter(models.User.id == user_id).first() + + if not user: + raise HTTPException(status_code=404, detail="Пользователь не найден") + + return user diff --git a/srv/app/api/schemas.py b/srv/app/api/schemas.py index c51496a..224c2bd 100644 --- a/srv/app/api/schemas.py +++ b/srv/app/api/schemas.py @@ -1,4 +1,5 @@ from pydantic import BaseModel +from typing import Optional class SetPublicKey(BaseModel): public_key: str @@ -10,4 +11,14 @@ class SetupAccount(BaseModel): first_name: str last_name: str public_key: str - encrypted_private_key: str \ No newline at end of file + encrypted_private_key: str + +class UserPublic(BaseModel): + id: int + username: str + first_name: Optional[str] = None + last_name: Optional[str] = None + public_key: Optional[str] = None + + class Config: + from_attributes = True \ No newline at end of file diff --git a/srv/app/db/models.py b/srv/app/db/models.py index d9d36d1..0c21fc5 100644 --- a/srv/app/db/models.py +++ b/srv/app/db/models.py @@ -1,6 +1,8 @@ from sqlalchemy import Column, Integer, String, create_engine from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker +from sqlalchemy import Column, Integer, Text, ForeignKey, DateTime +from sqlalchemy.sql import func SQLALCHEMY_DATABASE_URL = "sqlite:///./chepuhagram.db" engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}) @@ -20,5 +22,13 @@ class User(Base): hashed_password = Column(String) public_key = Column(String, nullable=True) encrypted_private_key = Column(String, nullable=True) + +class Message(Base): + __tablename__ = "messages" + id = Column(Integer, primary_key=True, index=True) + sender_id = Column(Integer, ForeignKey("users.id")) + receiver_id = Column(Integer, ForeignKey("users.id")) + content = Column(Text) + timestamp = Column(DateTime(timezone=True), server_default=func.now()) Base.metadata.create_all(bind=engine) \ No newline at end of file diff --git a/srv/app/websocket/connection_manager.py b/srv/app/websocket/connection_manager.py index 7de7c18..1879d20 100644 --- a/srv/app/websocket/connection_manager.py +++ b/srv/app/websocket/connection_manager.py @@ -1,16 +1,28 @@ -from fastapi import HTTPException, status, APIRouter, WebSocket, WebSocketDisconnect, Query +from fastapi import HTTPException, status, APIRouter, WebSocket, WebSocketDisconnect, Query, Depends from app.core.security import test_token from typing import Dict from datetime import datetime +import json +from sqlalchemy.orm import Session +from app.db import models + + +# бд +def get_db(): + db = models.SessionLocal() + try: + yield db + finally: + db.close() + wsRouter = APIRouter( - prefix="/ws", - tags=[], + prefix='/ws' ) @wsRouter.websocket("") -async def websocket_endpoint(websocket: WebSocket, token: str = Query(None)): +async def websocket_endpoint(websocket: WebSocket, token: str = Query(None), db: Session = Depends(get_db)): if token is None: await websocket.close(code=status.WS_1008_POLICY_VIOLATION) return @@ -19,20 +31,39 @@ async def websocket_endpoint(websocket: WebSocket, token: str = Query(None)): except HTTPException: await websocket.close(code=status.WS_1008_POLICY_VIOLATION) return + print("ПОДКЛЮЧЕНИЕ") await manager.connect(user_id, websocket) + print("ПОДКЛЮЧЕНО") try: while True: - data = await websocket.receive_json() + print("ОЖИДАНИЕ СООБЩЕНИЙ") + data = await websocket.receive_text() + message_data = json.loads(data) + print(f"DEBUG: Получены данные: {message_data}") + + if message_data.get("type") == "private_message": + receiver_id = message_data.get("receiver_id") + content = message_data.get("content") + new_msg = models.Message( + sender_id=user_id, + receiver_id=receiver_id, + content=content + ) + db.add(new_msg) + db.commit() + db.refresh(new_msg) + # Формируем пакет для получателя + outgoing_message = { + "id": new_msg.id, + "type": "private_message", + "sender_id": user_id, + "reciever_id": receiver_id, + "content": message_data.get("content"), + "timestamp": datetime.now().isoformat() + } - receiver_id = str(data.get("receiver_id")) - message_to_send = { - "sender_id": user_id, - "text": data.get("text"), - "created_at": datetime.now().isoformat() - } - - await manager.send_personal_message(message_to_send, receiver_id) - await manager.send_personal_message(message_to_send, user_id) + # Пересылаем получателю, если он в сети + await manager.send_personal_message(outgoing_message, str(receiver_id)) except WebSocketDisconnect: pass finally: @@ -53,8 +84,11 @@ class ConnectionManager: del self.active_connections[user_id] async def send_personal_message(self, message: dict, user_id: str): - if user_id in self.active_connections: - await self.active_connections[user_id].send_json(message) + if str(user_id) in self.active_connections: + await self.active_connections[str(user_id)].send_json(message) + print('Sent to socket') + else: + print('User not active') async def broadcast(self, message: dict): # Рассылка вообще всем (например, системное уведомление) diff --git a/srv/main.py b/srv/main.py index d6389f9..ee50fd3 100644 --- a/srv/main.py +++ b/srv/main.py @@ -1,5 +1,5 @@ from fastapi import FastAPI -from app.api.endpoints import users, auth +from app.api.endpoints import users, auth, messages from app.websocket.connection_manager import wsRouter from fastapi.middleware.cors import CORSMiddleware @@ -7,6 +7,7 @@ app = FastAPI() app.include_router(auth.authRouter) app.include_router(users.usersRouter) +app.include_router(messages.messagesRouter) app.include_router(wsRouter) app.add_middleware(