feat(chat): add message pagination with cursor-based loading

Load initial 200 messages, auto-load older messages on scroll to top.
Uses cursor-based DAO queries for stable scroll position.
Also fixes pre-existing ackMessage compile errors in test fakes.

Closes #113
This commit is contained in:
Jens Reinemann 2026-05-18 21:32:47 +02:00
parent 7b21394cc6
commit 975976fd06
10 changed files with 203 additions and 6 deletions

View file

@ -45,4 +45,40 @@ internal interface MessageDao {
@Query("UPDATE messages SET is_read = 1 WHERE receiver_id = :myId AND sender_id = :senderId") @Query("UPDATE messages SET is_read = 1 WHERE receiver_id = :myId AND sender_id = :senderId")
suspend fun markConversationAsRead(myId: String, senderId: String) suspend fun markConversationAsRead(myId: String, senderId: String)
@Query("""
SELECT * FROM messages
WHERE (sender_id = :myId AND receiver_id = :otherId)
OR (sender_id = :otherId AND receiver_id = :myId)
ORDER BY sent_at DESC
LIMIT :limit
""")
suspend fun getLatestMessages(myId: String, otherId: String, limit: Int): List<MessageEntity>
@Query("""
SELECT * FROM messages
WHERE ((sender_id = :myId AND receiver_id = :otherId)
OR (sender_id = :otherId AND receiver_id = :myId))
AND sent_at < :beforeTimestamp
ORDER BY sent_at DESC
LIMIT :limit
""")
suspend fun getMessagesBefore(myId: String, otherId: String, beforeTimestamp: Long, limit: Int): List<MessageEntity>
@Query("""
SELECT COUNT(*) FROM messages
WHERE ((sender_id = :myId AND receiver_id = :otherId)
OR (sender_id = :otherId AND receiver_id = :myId))
AND sent_at < :beforeTimestamp
""")
suspend fun countOlderMessages(myId: String, otherId: String, beforeTimestamp: Long): Int
@Query("""
SELECT * FROM messages
WHERE ((sender_id = :myId AND receiver_id = :otherId)
OR (sender_id = :otherId AND receiver_id = :myId))
AND sent_at > :afterTimestamp
ORDER BY sent_at ASC
""")
fun getNewMessagesAfter(myId: String, otherId: String, afterTimestamp: Long): Flow<List<MessageEntity>>
} }

View file

@ -115,6 +115,18 @@ internal class MessageRepositoryImpl @Inject constructor(
override fun getConversation(myId: String, otherId: String): Flow<List<MessageEntity>> = override fun getConversation(myId: String, otherId: String): Flow<List<MessageEntity>> =
dao.getConversation(myId, otherId) dao.getConversation(myId, otherId)
override suspend fun getLatestMessages(myId: String, otherId: String, limit: Int): List<MessageEntity> =
withContext(Dispatchers.IO) { dao.getLatestMessages(myId, otherId, limit) }
override suspend fun getMessagesBefore(myId: String, otherId: String, beforeTimestamp: Long, limit: Int): List<MessageEntity> =
withContext(Dispatchers.IO) { dao.getMessagesBefore(myId, otherId, beforeTimestamp, limit) }
override suspend fun countOlderMessages(myId: String, otherId: String, beforeTimestamp: Long): Int =
withContext(Dispatchers.IO) { dao.countOlderMessages(myId, otherId, beforeTimestamp) }
override fun getNewMessagesAfter(myId: String, otherId: String, afterTimestamp: Long): Flow<List<MessageEntity>> =
dao.getNewMessagesAfter(myId, otherId, afterTimestamp)
override suspend fun sendMessage(recipientId: String, body: String) { override suspend fun sendMessage(recipientId: String, body: String) {
val myId = settingsRepository.getStringOrNull(StringKey.AuthUserId) ?: return val myId = settingsRepository.getStringOrNull(StringKey.AuthUserId) ?: return
val myUsername = settingsRepository.getString(StringKey.AuthUsername) val myUsername = settingsRepository.getString(StringKey.AuthUsername)

View file

@ -6,6 +6,10 @@ import kotlinx.coroutines.flow.Flow
internal interface MessageRepository { internal interface MessageRepository {
fun getConversation(myId: String, otherId: String): Flow<List<MessageEntity>> fun getConversation(myId: String, otherId: String): Flow<List<MessageEntity>>
suspend fun getLatestMessages(myId: String, otherId: String, limit: Int): List<MessageEntity>
suspend fun getMessagesBefore(myId: String, otherId: String, beforeTimestamp: Long, limit: Int): List<MessageEntity>
suspend fun countOlderMessages(myId: String, otherId: String, beforeTimestamp: Long): Int
fun getNewMessagesAfter(myId: String, otherId: String, afterTimestamp: Long): Flow<List<MessageEntity>>
suspend fun sendMessage(recipientId: String, body: String) suspend fun sendMessage(recipientId: String, body: String)
suspend fun fetchUsers(): Result<List<UserListItemDto>> suspend fun fetchUsers(): Result<List<UserListItemDto>>
suspend fun getMyUserId(): String? suspend fun getMyUserId(): String?

View file

@ -1,6 +1,7 @@
package de.bollwerk.app.ui.messaging package de.bollwerk.app.ui.messaging
import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.PaddingValues import androidx.compose.foundation.layout.PaddingValues
import androidx.compose.foundation.layout.Row import androidx.compose.foundation.layout.Row
@ -9,6 +10,7 @@ import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.imePadding import androidx.compose.foundation.layout.imePadding
import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.layout.width import androidx.compose.foundation.layout.width
import androidx.compose.foundation.layout.widthIn import androidx.compose.foundation.layout.widthIn
import androidx.compose.foundation.lazy.LazyColumn import androidx.compose.foundation.lazy.LazyColumn
@ -18,6 +20,7 @@ import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.filled.ArrowBack import androidx.compose.material.icons.automirrored.filled.ArrowBack
import androidx.compose.material.icons.automirrored.filled.Send import androidx.compose.material.icons.automirrored.filled.Send
import androidx.compose.material3.CircularProgressIndicator
import androidx.compose.material3.ExperimentalMaterial3Api import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton import androidx.compose.material3.IconButton
@ -29,14 +32,18 @@ import androidx.compose.material3.Text
import androidx.compose.material3.TopAppBar import androidx.compose.material3.TopAppBar
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.derivedStateOf
import androidx.compose.runtime.getValue import androidx.compose.runtime.getValue
import androidx.compose.runtime.remember import androidx.compose.runtime.remember
import androidx.compose.runtime.snapshotFlow
import androidx.compose.ui.Alignment import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import androidx.hilt.navigation.compose.hiltViewModel import androidx.hilt.navigation.compose.hiltViewModel
import androidx.lifecycle.compose.collectAsStateWithLifecycle import androidx.lifecycle.compose.collectAsStateWithLifecycle
import de.bollwerk.app.data.db.entity.MessageEntity import de.bollwerk.app.data.db.entity.MessageEntity
import kotlinx.coroutines.flow.distinctUntilChanged
import kotlinx.coroutines.flow.filter
import java.text.SimpleDateFormat import java.text.SimpleDateFormat
import java.util.Date import java.util.Date
import java.util.Locale import java.util.Locale
@ -48,14 +55,39 @@ internal fun ChatScreen(
viewModel: ChatViewModel = hiltViewModel() viewModel: ChatViewModel = hiltViewModel()
) { ) {
val uiState by viewModel.uiState.collectAsStateWithLifecycle() val uiState by viewModel.uiState.collectAsStateWithLifecycle()
val prependedCount by viewModel.prependedCount.collectAsStateWithLifecycle()
val listState = rememberLazyListState() val listState = rememberLazyListState()
// Scroll to bottom when new messages arrive (only if already near bottom)
val isNearBottom by remember {
derivedStateOf {
val lastVisible = listState.layoutInfo.visibleItemsInfo.lastOrNull()?.index ?: 0
lastVisible >= listState.layoutInfo.totalItemsCount - 3
}
}
LaunchedEffect(uiState.messages.size) { LaunchedEffect(uiState.messages.size) {
if (uiState.messages.isNotEmpty()) { if (uiState.messages.isNotEmpty() && isNearBottom) {
listState.animateScrollToItem(uiState.messages.size - 1) listState.animateScrollToItem(uiState.messages.size - 1)
} }
} }
// Maintain scroll position after prepending older messages
LaunchedEffect(prependedCount) {
if (prependedCount > 0) {
val targetIndex = listState.firstVisibleItemIndex + prependedCount
listState.scrollToItem(targetIndex, listState.firstVisibleItemScrollOffset)
viewModel.onPrependHandled()
}
}
// Auto-load more when scrolled to top
LaunchedEffect(listState) {
snapshotFlow { listState.firstVisibleItemIndex }
.distinctUntilChanged()
.filter { it == 0 }
.collect { viewModel.loadMore() }
}
Scaffold( Scaffold(
topBar = { topBar = {
TopAppBar( TopAppBar(
@ -83,6 +115,18 @@ internal fun ChatScreen(
contentPadding = PaddingValues(horizontal = 8.dp, vertical = 8.dp), contentPadding = PaddingValues(horizontal = 8.dp, vertical = 8.dp),
verticalArrangement = Arrangement.spacedBy(4.dp) verticalArrangement = Arrangement.spacedBy(4.dp)
) { ) {
if (uiState.isLoadingMore) {
item(key = "loading_indicator") {
Box(
modifier = Modifier
.fillMaxWidth()
.padding(8.dp),
contentAlignment = Alignment.Center
) {
CircularProgressIndicator(modifier = Modifier.size(24.dp))
}
}
}
items(uiState.messages, key = { it.id }) { message -> items(uiState.messages, key = { it.id }) { message ->
MessageBubble( MessageBubble(
message = message, message = message,

View file

@ -18,7 +18,9 @@ internal data class ChatUiState(
val myUserId: String = "", val myUserId: String = "",
val recipientUsername: String = "", val recipientUsername: String = "",
val inputText: String = "", val inputText: String = "",
val isSending: Boolean = false val isSending: Boolean = false,
val isLoadingMore: Boolean = false,
val hasOlderMessages: Boolean = true
) )
@HiltViewModel @HiltViewModel
@ -34,6 +36,14 @@ internal class ChatViewModel @Inject constructor(
private val _uiState = MutableStateFlow(ChatUiState(recipientUsername = recipientUsername)) private val _uiState = MutableStateFlow(ChatUiState(recipientUsername = recipientUsername))
val uiState: StateFlow<ChatUiState> = _uiState val uiState: StateFlow<ChatUiState> = _uiState
private val loadedMessages = mutableListOf<MessageEntity>()
private var oldestTimestamp: Long = Long.MAX_VALUE
private var newestTimestamp: Long = 0L
/** Number of items prepended in the last loadMore call, consumed by the UI for scroll adjustment. */
private val _prependedCount = MutableStateFlow(0)
val prependedCount: StateFlow<Int> = _prependedCount
init { init {
notificationHelper.setActiveChat(recipientId) notificationHelper.setActiveChat(recipientId)
notificationHelper.cancelNotificationForSender(recipientId) notificationHelper.cancelNotificationForSender(recipientId)
@ -41,14 +51,66 @@ internal class ChatViewModel @Inject constructor(
val myId = messageRepository.getMyUserId() ?: "" val myId = messageRepository.getMyUserId() ?: ""
_uiState.update { it.copy(myUserId = myId) } _uiState.update { it.copy(myUserId = myId) }
if (myId.isNotEmpty()) { if (myId.isNotEmpty()) {
messageRepository.getConversation(myId, recipientId).collect { messages -> loadInitialMessages(myId)
_uiState.update { it.copy(messages = messages) } observeNewMessages(myId)
messageRepository.markConversationAsRead(recipientId)
}
} }
} }
} }
private suspend fun loadInitialMessages(myId: String) {
val initial = messageRepository.getLatestMessages(myId, recipientId, PAGE_SIZE)
// getLatestMessages returns DESC order, reverse to ASC for display
val sorted = initial.sortedBy { it.sentAt }
loadedMessages.addAll(sorted)
if (sorted.isNotEmpty()) {
oldestTimestamp = sorted.first().sentAt
newestTimestamp = sorted.last().sentAt
}
val hasOlder = if (sorted.isNotEmpty()) {
messageRepository.countOlderMessages(myId, recipientId, oldestTimestamp) > 0
} else {
false
}
_uiState.update { it.copy(messages = loadedMessages.toList(), hasOlderMessages = hasOlder) }
messageRepository.markConversationAsRead(recipientId)
}
private suspend fun observeNewMessages(myId: String) {
messageRepository.getNewMessagesAfter(myId, recipientId, newestTimestamp).collect { newMessages ->
val truly = newMessages.filter { msg -> loadedMessages.none { it.id == msg.id } }
if (truly.isNotEmpty()) {
loadedMessages.addAll(truly)
newestTimestamp = loadedMessages.maxOf { it.sentAt }
_uiState.update { it.copy(messages = loadedMessages.toList()) }
messageRepository.markConversationAsRead(recipientId)
}
}
}
fun loadMore() {
if (_uiState.value.isLoadingMore || !_uiState.value.hasOlderMessages) return
val myId = _uiState.value.myUserId
if (myId.isEmpty()) return
viewModelScope.launch {
_uiState.update { it.copy(isLoadingMore = true) }
val older = messageRepository.getMessagesBefore(myId, recipientId, oldestTimestamp, PAGE_SIZE)
val sorted = older.sortedBy { it.sentAt }
if (sorted.isNotEmpty()) {
oldestTimestamp = sorted.first().sentAt
loadedMessages.addAll(0, sorted)
val hasOlder = messageRepository.countOlderMessages(myId, recipientId, oldestTimestamp) > 0
_prependedCount.value = sorted.size
_uiState.update { it.copy(messages = loadedMessages.toList(), isLoadingMore = false, hasOlderMessages = hasOlder) }
} else {
_uiState.update { it.copy(isLoadingMore = false, hasOlderMessages = false) }
}
}
}
fun onPrependHandled() {
_prependedCount.value = 0
}
override fun onCleared() { override fun onCleared() {
super.onCleared() super.onCleared()
notificationHelper.setActiveChat(null) notificationHelper.setActiveChat(null)
@ -67,4 +129,8 @@ internal class ChatViewModel @Inject constructor(
_uiState.update { it.copy(isSending = false) } _uiState.update { it.copy(isSending = false) }
} }
} }
private companion object {
const val PAGE_SIZE = 200
}
} }

View file

@ -176,6 +176,7 @@ private class FakeWebSocketClient : WebSocketClient {
suspend fun emit(event: WebSocketEvent) { _events.emit(event) } suspend fun emit(event: WebSocketEvent) { _events.emit(event) }
override fun connect(serverUrl: String, accessToken: String) {} override fun connect(serverUrl: String, accessToken: String) {}
override fun disconnect() {} override fun disconnect() {}
override fun ackMessage(messageId: String) = Unit
} }
private fun buildItem( private fun buildItem(

View file

@ -68,6 +68,26 @@ private class FakeMessageDao : MessageDao {
override suspend fun markConversationAsRead(myId: String, senderId: String) { override suspend fun markConversationAsRead(myId: String, senderId: String) {
markedAsRead.add(myId to senderId) markedAsRead.add(myId to senderId)
} }
override suspend fun getLatestMessages(myId: String, otherId: String, limit: Int): List<MessageEntity> =
upserted.filter {
(it.senderId == myId && it.receiverId == otherId) ||
(it.senderId == otherId && it.receiverId == myId)
}.sortedByDescending { it.sentAt }.take(limit)
override suspend fun getMessagesBefore(myId: String, otherId: String, beforeTimestamp: Long, limit: Int): List<MessageEntity> =
upserted.filter {
((it.senderId == myId && it.receiverId == otherId) ||
(it.senderId == otherId && it.receiverId == myId)) && it.sentAt < beforeTimestamp
}.sortedByDescending { it.sentAt }.take(limit)
override suspend fun countOlderMessages(myId: String, otherId: String, beforeTimestamp: Long): Int =
upserted.count {
((it.senderId == myId && it.receiverId == otherId) ||
(it.senderId == otherId && it.receiverId == myId)) && it.sentAt < beforeTimestamp
}
override fun getNewMessagesAfter(myId: String, otherId: String, afterTimestamp: Long): Flow<List<MessageEntity>> =
flowOf(upserted.filter {
((it.senderId == myId && it.receiverId == otherId) ||
(it.senderId == otherId && it.receiverId == myId)) && it.sentAt > afterTimestamp
})
} }
private class FakeMessageSettingsRepository : SettingsRepository { private class FakeMessageSettingsRepository : SettingsRepository {
@ -100,6 +120,7 @@ private class FakeMessageWsClient : WebSocketClient {
_connectionState _connectionState
override fun connect(serverUrl: String, accessToken: String) = Unit override fun connect(serverUrl: String, accessToken: String) = Unit
override fun disconnect() = Unit override fun disconnect() = Unit
override fun ackMessage(messageId: String) = Unit
} }
private fun buildFakeE2EEKeyManager(): E2EEKeyManager = mockk<E2EEKeyManager>(relaxed = true).also { private fun buildFakeE2EEKeyManager(): E2EEKeyManager = mockk<E2EEKeyManager>(relaxed = true).also {

View file

@ -32,6 +32,14 @@ private class FakeChatMessageRepository(
override val totalUnreadCount: Flow<Int> = flowOf(0) override val totalUnreadCount: Flow<Int> = flowOf(0)
override fun getConversation(myId: String, otherId: String): Flow<List<MessageEntity>> = conversation override fun getConversation(myId: String, otherId: String): Flow<List<MessageEntity>> = conversation
override suspend fun getLatestMessages(myId: String, otherId: String, limit: Int): List<MessageEntity> =
conversation.value.sortedByDescending { it.sentAt }.take(limit)
override suspend fun getMessagesBefore(myId: String, otherId: String, beforeTimestamp: Long, limit: Int): List<MessageEntity> =
conversation.value.filter { it.sentAt < beforeTimestamp }.sortedByDescending { it.sentAt }.take(limit)
override suspend fun countOlderMessages(myId: String, otherId: String, beforeTimestamp: Long): Int =
conversation.value.count { it.sentAt < beforeTimestamp }
override fun getNewMessagesAfter(myId: String, otherId: String, afterTimestamp: Long): Flow<List<MessageEntity>> =
flowOf(emptyList())
override fun getUnreadCountsBySender(): Flow<Map<String, Int>> = flowOf(emptyMap()) override fun getUnreadCountsBySender(): Flow<Map<String, Int>> = flowOf(emptyMap())
override suspend fun markConversationAsRead(senderId: String) = Unit override suspend fun markConversationAsRead(senderId: String) = Unit
override suspend fun sendMessage(recipientId: String, body: String) { override suspend fun sendMessage(recipientId: String, body: String) {

View file

@ -31,6 +31,10 @@ private class FakeUserListMessageRepository(
override val totalUnreadCount: Flow<Int> = flowOf(0) override val totalUnreadCount: Flow<Int> = flowOf(0)
override fun getConversation(myId: String, otherId: String): Flow<List<MessageEntity>> = override fun getConversation(myId: String, otherId: String): Flow<List<MessageEntity>> =
MutableStateFlow(emptyList()) MutableStateFlow(emptyList())
override suspend fun getLatestMessages(myId: String, otherId: String, limit: Int): List<MessageEntity> = emptyList()
override suspend fun getMessagesBefore(myId: String, otherId: String, beforeTimestamp: Long, limit: Int): List<MessageEntity> = emptyList()
override suspend fun countOlderMessages(myId: String, otherId: String, beforeTimestamp: Long): Int = 0
override fun getNewMessagesAfter(myId: String, otherId: String, afterTimestamp: Long): Flow<List<MessageEntity>> = flowOf(emptyList())
override fun getUnreadCountsBySender(): Flow<Map<String, Int>> = flowOf(emptyMap()) override fun getUnreadCountsBySender(): Flow<Map<String, Int>> = flowOf(emptyMap())
override suspend fun markConversationAsRead(senderId: String) = Unit override suspend fun markConversationAsRead(senderId: String) = Unit
override suspend fun sendMessage(recipientId: String, body: String) = Unit override suspend fun sendMessage(recipientId: String, body: String) = Unit

View file

@ -912,6 +912,7 @@ private class FakeWebSocketClient : WebSocketClient {
var connectedUrl: String? = null var connectedUrl: String? = null
override fun connect(serverUrl: String, accessToken: String) { connectedUrl = serverUrl } override fun connect(serverUrl: String, accessToken: String) { connectedUrl = serverUrl }
override fun disconnect() { connectedUrl = null } override fun disconnect() { connectedUrl = null }
override fun ackMessage(messageId: String) = Unit
} }
// endregion // endregion