diff --git a/plugins/amazonq/chat/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/workspace/context/ProjectContextProviderTest.kt b/plugins/amazonq/chat/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/workspace/context/ProjectContextProviderTest.kt index 700f3a2024e..8cdf146948b 100644 --- a/plugins/amazonq/chat/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/workspace/context/ProjectContextProviderTest.kt +++ b/plugins/amazonq/chat/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/amazonq/workspace/context/ProjectContextProviderTest.kt @@ -17,6 +17,8 @@ import com.intellij.openapi.application.ApplicationManager import com.intellij.openapi.project.Project import com.intellij.testFramework.DisposableRule import com.intellij.testFramework.replaceService +import io.mockk.every +import io.mockk.spyk import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.TimeoutCancellationException import kotlinx.coroutines.test.StandardTestDispatcher @@ -44,6 +46,7 @@ import software.aws.toolkits.jetbrains.services.amazonq.project.InlineBm25Chunk import software.aws.toolkits.jetbrains.services.amazonq.project.InlineContextTarget import software.aws.toolkits.jetbrains.services.amazonq.project.LspMessage import software.aws.toolkits.jetbrains.services.amazonq.project.ProjectContextProvider +import software.aws.toolkits.jetbrains.services.amazonq.project.ProjectContextProvider.FileCollectionResult import software.aws.toolkits.jetbrains.services.amazonq.project.QueryChatRequest import software.aws.toolkits.jetbrains.services.amazonq.project.QueryInlineCompletionRequest import software.aws.toolkits.jetbrains.services.amazonq.project.RelevantDocument @@ -82,8 +85,8 @@ class ProjectContextProviderTest { fun setup() { encoderServer = spy(EncoderServer(project)) encoderServer.stub { on { port } doReturn wireMock.port() } - - sut = ProjectContextProvider(project, encoderServer, TestScope(context = dispatcher)) + encoderServer.stub { on { isNodeProcessRunning() } doReturn true } + sut = spyk(ProjectContextProvider(project, encoderServer, TestScope(context = dispatcher))) // initialization stubFor(any(urlPathEqualTo("/initialize")).willReturn(aResponse().withStatus(200).withResponseBody(Body("initialize response")))) @@ -143,7 +146,10 @@ class ProjectContextProviderTest { projectRule.fixture.addFileToProject("Foo.java", "foo") projectRule.fixture.addFileToProject("Bar.java", "bar") projectRule.fixture.addFileToProject("Baz.java", "baz") - + every { sut.collectFiles() } returns FileCollectionResult( + files = listOf("Foo.java", "Bar.java", "Baz.java"), + fileSize = 10 + ) sut.index() val request = IndexRequest(listOf("/src/Foo.java", "/src/Bar.java", "/src/Baz.java"), "/src", "all", "") @@ -175,7 +181,10 @@ class ProjectContextProviderTest { projectRule.fixture.addFileToProject("Foo.java", "foo") projectRule.fixture.addFileToProject("Bar.java", "bar") projectRule.fixture.addFileToProject("Baz.java", "baz") - + every { sut.collectFiles() } returns FileCollectionResult( + files = listOf("Foo.java", "Bar.java", "Baz.java"), + fileSize = 10 + ) sut.index() val request = IndexRequest(listOf("/src/Foo.java", "/src/Bar.java", "/src/Baz.java"), "/src", "default", "") @@ -408,6 +417,10 @@ class ProjectContextProviderTest { @Test fun `test index payload is encrypted`() = runTest { whenever(encoderServer.port).thenReturn(3000) + every { sut.collectFiles() } returns FileCollectionResult( + files = listOf("Foo.java", "Bar.java", "Baz.java"), + fileSize = 10 + ) try { sut.index() } catch (e: ConnectException) { diff --git a/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/project/ProjectContextProvider.kt b/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/project/ProjectContextProvider.kt index 4a516bcfdb9..ba6897e87e8 100644 --- a/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/project/ProjectContextProvider.kt +++ b/plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/project/ProjectContextProvider.kt @@ -130,7 +130,7 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En private fun initEncryption(): Boolean { val request = encoderServer.getEncryptionRequest() val response = sendMsgToLsp(LspMessage.Initialize, request) - return response.responseCode == 200 + return response?.responseCode == 200 } fun index(): Boolean { @@ -138,6 +138,10 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En val indexStartTime = System.currentTimeMillis() val filesResult = collectFiles() + if (filesResult.files.isEmpty()) { + logger.warn { "No file found in workspace" } + return false + } var duration = (System.currentTimeMillis() - indexStartTime).toDouble() logger.debug { "time elapsed to collect project context files: ${duration}ms, collected ${filesResult.files.size} files" } @@ -149,12 +153,13 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En logger.debug { "project context index time: ${duration}ms" } val startUrl = getStartUrl(project) - if (response.responseCode == 200) { + if (response?.responseCode == 200) { val usage = getUsage() recordIndexWorkspace(duration, filesResult.files.size, filesResult.fileSize, true, usage?.memoryUsage, usage?.cpuUsage, startUrl) logger.debug { "project context index finished for ${project.name}" } return true } else { + logger.debug { "project context index failed" } recordIndexWorkspace(duration, filesResult.files.size, filesResult.fileSize, false, null, null, startUrl) return false } @@ -164,8 +169,7 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En suspend fun query(prompt: String, timeout: Long?): List = withTimeout(timeout ?: CHAT_EXPLICIT_PROJECT_CONTEXT_TIMEOUT) { cs.async { val encrypted = encryptRequest(QueryChatRequest(prompt)) - val response = sendMsgToLsp(LspMessage.QueryChat, encrypted) - + val response = sendMsgToLsp(LspMessage.QueryChat, encrypted) ?: return@async emptyList() val parsedResponse = mapper.readValue>(response.responseBody) queryResultToRelevantDocuments(parsedResponse) }.await() @@ -174,13 +178,13 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En suspend fun queryInline(query: String, filePath: String, target: InlineContextTarget): List = withTimeout(SUPPLEMENTAL_CONTEXT_TIMEOUT) { cs.async { val encrypted = encryptRequest(QueryInlineCompletionRequest(query, filePath, target.toString())) - val r = sendMsgToLsp(LspMessage.QueryInlineCompletion, encrypted) + val r = sendMsgToLsp(LspMessage.QueryInlineCompletion, encrypted) ?: return@async emptyList() return@async mapper.readValue>(r.responseBody) }.await() } fun getUsage(): Usage? { - val response = sendMsgToLsp(LspMessage.GetUsageMetrics, request = null) + val response = sendMsgToLsp(LspMessage.GetUsageMetrics, request = null) ?: return null return try { val parsedResponse = mapper.readValue(response.responseBody) parsedResponse @@ -246,7 +250,7 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En return regex.find(fileName) != null } - private fun collectFiles(): FileCollectionResult { + fun collectFiles(): FileCollectionResult { val collectedFiles = mutableListOf() var currentTotalFileSize = 0L val featureDevSessionContext = FeatureDevSessionContext(project) @@ -306,9 +310,13 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En return encoderServer.encrypt(payloadJson) } - private fun sendMsgToLsp(msgType: LspMessage, request: String?): LspResponse { + private fun sendMsgToLsp(msgType: LspMessage, request: String?): LspResponse? { logger.info { "sending message: ${msgType.endpoint} to lsp on port ${encoderServer.port}" } val url = URL("http://localhost:${encoderServer.port}/${msgType.endpoint}") + if (!encoderServer.isNodeProcessRunning()) { + logger.warn { "language server is not running" } + return null + } // use 1h as timeout for index, 5 seconds for other APIs val timeoutMs = if (msgType is LspMessage.Index) 60.minutes.inWholeMilliseconds.toInt() else 5000 return with(url.openConnection() as HttpURLConnection) {