diff --git a/plugins/amazonq/chat/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/webview/BrowserConnector.kt b/plugins/amazonq/chat/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/webview/BrowserConnector.kt index d44a33b51e9..8db334596df 100644 --- a/plugins/amazonq/chat/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/webview/BrowserConnector.kt +++ b/plugins/amazonq/chat/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/webview/BrowserConnector.kt @@ -140,6 +140,7 @@ class BrowserConnector( // this is sent when the named agents UI is ready "ui-is-ready" -> { uiReady.complete(true) + chatCommunicationManager.setUiReady() RunOnceUtil.runOnceForApp("AmazonQ-UI-Ready") { MeetQSettings.getInstance().reinvent2024OnboardingCount += 1 } @@ -324,6 +325,7 @@ class BrowserConnector( CHAT_READY -> { handleChatNotification(node) { server, _ -> uiReady.complete(true) + chatCommunicationManager.setUiReady() RunOnceUtil.runOnceForApp("AmazonQ-UI-Ready") { MeetQSettings.getInstance().reinvent2024OnboardingCount += 1 } @@ -349,7 +351,7 @@ class BrowserConnector( } CHAT_OPEN_TAB -> { val response = serializer.deserializeChatMessages(node) - ChatCommunicationManager.completeTabOpen( + chatCommunicationManager.completeTabOpen( response.requestId, response.params.result.tabId ) @@ -420,7 +422,7 @@ class BrowserConnector( GET_SERIALIZED_CHAT_REQUEST_METHOD -> { val response = serializer.deserializeChatMessages(node) - ChatCommunicationManager.completeSerializedChatResponse( + chatCommunicationManager.completeSerializedChatResponse( response.requestId, response.params.result.content ) diff --git a/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/AmazonQLanguageClientImpl.kt b/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/AmazonQLanguageClientImpl.kt index 8da47a7d33a..f1264d1603f 100644 --- a/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/AmazonQLanguageClientImpl.kt +++ b/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/AmazonQLanguageClientImpl.kt @@ -137,9 +137,10 @@ class AmazonQLanguageClientImpl(private val project: Project) : AmazonQLanguageC override fun openTab(params: OpenTabParams): CompletableFuture { val requestId = UUID.randomUUID().toString() val result = CompletableFuture() - ChatCommunicationManager.pendingTabRequests[requestId] = result + val chatManager = ChatCommunicationManager.getInstance(project) + chatManager.addTabOpenRequest(requestId, result) - AsyncChatUiListener.notifyPartialMessageUpdate( + chatManager.notifyUi( FlareUiMessage( command = CHAT_OPEN_TAB, params = params, @@ -149,7 +150,7 @@ class AmazonQLanguageClientImpl(private val project: Project) : AmazonQLanguageC result.orTimeout(30000, TimeUnit.MILLISECONDS) .whenComplete { _, error -> - ChatCommunicationManager.pendingTabRequests.remove(requestId) + chatManager.removeTabOpenRequest(requestId) } return result @@ -188,10 +189,10 @@ class AmazonQLanguageClientImpl(private val project: Project) : AmazonQLanguageC override fun getSerializedChat(params: GetSerializedChatParams): CompletableFuture { val requestId = UUID.randomUUID().toString() val result = CompletableFuture() + val chatManager = ChatCommunicationManager.getInstance(project) + chatManager.addSerializedChatRequest(requestId, result) - ChatCommunicationManager.pendingSerializedChatRequests[requestId] = result - - AsyncChatUiListener.notifyPartialMessageUpdate( + chatManager.notifyUi( FlareUiMessage( command = GET_SERIALIZED_CHAT_REQUEST_METHOD, params = params, @@ -201,7 +202,7 @@ class AmazonQLanguageClientImpl(private val project: Project) : AmazonQLanguageC result.orTimeout(30000, TimeUnit.MILLISECONDS) .whenComplete { _, error -> - ChatCommunicationManager.pendingSerializedChatRequests.remove(requestId) + chatManager.removeSerializedChatRequest(requestId) } return result @@ -340,13 +341,13 @@ class AmazonQLanguageClientImpl(private val project: Project) : AmazonQLanguageC ) override fun sendContextCommands(params: LSPAny): CompletableFuture { - AsyncChatUiListener.notifyPartialMessageUpdate( + val chatManager = ChatCommunicationManager.getInstance(project) + chatManager.notifyUi( FlareUiMessage( command = CHAT_SEND_CONTEXT_COMMANDS, params = params ?: error("received empty payload for $CHAT_SEND_CONTEXT_COMMANDS"), ) ) - return CompletableFuture.completedFuture(Unit) } diff --git a/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/flareChat/ChatCommunicationManager.kt b/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/flareChat/ChatCommunicationManager.kt index 2fdb765ec53..95255ac02bd 100644 --- a/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/flareChat/ChatCommunicationManager.kt +++ b/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/lsp/flareChat/ChatCommunicationManager.kt @@ -7,6 +7,9 @@ import com.google.gson.Gson import com.intellij.openapi.components.Service import com.intellij.openapi.components.service import com.intellij.openapi.project.Project +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.launch import org.eclipse.lsp4j.ProgressParams import software.aws.toolkits.core.utils.getLogger import software.aws.toolkits.core.utils.warn @@ -29,12 +32,23 @@ import java.util.concurrent.CompletableFuture import java.util.concurrent.ConcurrentHashMap @Service(Service.Level.PROJECT) -class ChatCommunicationManager { +class ChatCommunicationManager(private val cs: CoroutineScope) { + val uiReady = CompletableDeferred() private val chatPartialResultMap = ConcurrentHashMap() - private fun getPartialChatMessage(partialResultToken: String): String? = - chatPartialResultMap.getOrDefault(partialResultToken, null) - private val inflightRequestByTabId = ConcurrentHashMap>() + private val pendingSerializedChatRequests = ConcurrentHashMap>() + private val pendingTabRequests = ConcurrentHashMap>() + + fun setUiReady() { + uiReady.complete(true) + } + + fun notifyUi(uiMessage: FlareUiMessage) { + cs.launch { + uiReady.await() + AsyncChatUiListener.notifyPartialMessageUpdate(uiMessage) + } + } fun setInflightRequestForTab(tabId: String, result: CompletableFuture) { inflightRequestByTabId[tabId] = result @@ -53,9 +67,36 @@ class ChatCommunicationManager { return partialResultToken } + private fun getPartialChatMessage(partialResultToken: String): String? = + chatPartialResultMap.getOrDefault(partialResultToken, null) + fun removePartialChatMessage(partialResultToken: String) = chatPartialResultMap.remove(partialResultToken) + fun addSerializedChatRequest(requestId: String, result: CompletableFuture) { + pendingSerializedChatRequests[requestId] = result + } + + fun completeSerializedChatResponse(requestId: String, content: String) { + pendingSerializedChatRequests.remove(requestId)?.complete(GetSerializedChatResult((content))) + } + + fun removeSerializedChatRequest(requestId: String) { + pendingSerializedChatRequests.remove(requestId) + } + + fun addTabOpenRequest(requestId: String, result: CompletableFuture) { + pendingTabRequests[requestId] = result + } + + fun completeTabOpen(requestId: String, tabId: String) { + pendingTabRequests.remove(requestId)?.complete(OpenTabResult(tabId)) + } + + fun removeTabOpenRequest(requestId: String) { + pendingTabRequests.remove(requestId) + } + fun handlePartialResultProgressNotification(project: Project, params: ProgressParams) { val token = ProgressNotificationUtils.getToken(params) val tabId = getPartialChatMessage(token) @@ -134,11 +175,6 @@ class ChatCommunicationManager { private val LOG = getLogger() - val pendingSerializedChatRequests = ConcurrentHashMap>() - fun completeSerializedChatResponse(requestId: String, content: String) { - pendingSerializedChatRequests.remove(requestId)?.complete(GetSerializedChatResult((content))) - } - fun convertToJsonToSendToChat(command: String, tabId: String, params: String, isPartialResult: Boolean): String = """ { @@ -148,11 +184,5 @@ class ChatCommunicationManager { "isPartialResult": $isPartialResult } """.trimIndent() - - val pendingTabRequests = ConcurrentHashMap>() - - fun completeTabOpen(requestId: String, tabId: String) { - pendingTabRequests.remove(requestId)?.complete(OpenTabResult(tabId)) - } } }