Skip to content

Commit fca5e2f

Browse files
authored
fix(amazonq): switch to ulong to avoid overflow when input is larger than 2gb (#5558)
2GB in bytes > INT_MAX so use ULong, which can handle 18 PB
1 parent 6947870 commit fca5e2f

File tree

5 files changed

+104
-62
lines changed

5 files changed

+104
-62
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type" : "bugfix",
3+
"description" : "Fix integer overflow when local context index input is larger than 2GB"
4+
}

plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/settings/CodeWhispererConfigurable.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ class CodeWhispererConfigurable(private val project: Project) :
133133

134134
row(message("aws.settings.codewhisperer.project_context_index_thread")) {
135135
intTextField(
136-
range = IntRange(0, 50)
136+
range = CodeWhispererSettings.CONTEXT_INDEX_THREADS
137137
).bindIntText(codeWhispererSettings::getProjectContextIndexThreadCount, codeWhispererSettings::setProjectContextIndexThreadCount)
138138
.apply {
139139
connect.subscribe(
@@ -150,7 +150,7 @@ class CodeWhispererConfigurable(private val project: Project) :
150150

151151
row(message("aws.settings.codewhisperer.project_context_index_max_size")) {
152152
intTextField(
153-
range = IntRange(1, 4096)
153+
range = CodeWhispererSettings.CONTEXT_INDEX_SIZE
154154
).bindIntText(codeWhispererSettings::getProjectContextIndexMaxSize, codeWhispererSettings::setProjectContextIndexMaxSize)
155155
.apply {
156156
connect.subscribe(

plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererSettingsTest.kt

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,41 @@ class CodeWhispererSettingsTest : CodeWhispererTestBase() {
211211
assertThat(actual.autoBuildSetting).hasSize(1)
212212
assertThat(actual.autoBuildSetting["project1"]).isTrue()
213213
}
214+
215+
@Test
216+
fun `context thread count is returned in range`() {
217+
val sut = CodeWhispererSettings.getInstance()
218+
219+
mapOf(
220+
1 to 1,
221+
0 to 0,
222+
-1 to 0,
223+
123 to 50,
224+
50 to 50,
225+
51 to 50,
226+
).forEach { s, expected ->
227+
sut.setProjectContextIndexThreadCount(s)
228+
assertThat(sut.getProjectContextIndexThreadCount()).isEqualTo(expected)
229+
}
230+
}
231+
232+
@Test
233+
fun `context index size is returned in range`() {
234+
val sut = CodeWhispererSettings.getInstance()
235+
236+
mapOf(
237+
1 to 1,
238+
0 to 1,
239+
-1 to 1,
240+
123 to 123,
241+
2047 to 2047,
242+
4096 to 4096,
243+
4097 to 4096,
244+
).forEach { s, expected ->
245+
sut.setProjectContextIndexMaxSize(s)
246+
assertThat(sut.getProjectContextIndexMaxSize()).isEqualTo(expected)
247+
}
248+
}
214249
}
215250

216251
class CodeWhispererSettingUnitTest {

plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/services/amazonq/project/ProjectContextProvider.kt

Lines changed: 54 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En
7070

7171
data class FileCollectionResult(
7272
val files: List<String>,
73-
val fileSize: Int,
73+
val fileSize: Int, // in MB
7474
)
7575

7676
// TODO: move to LspMessage.kt
@@ -246,59 +246,7 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En
246246
}
247247
}
248248

249-
private fun willExceedPayloadLimit(currentTotalFileSize: Long, currentFileSize: Long): Boolean {
250-
val maxSize = CodeWhispererSettings.getInstance().getProjectContextIndexMaxSize()
251-
return currentTotalFileSize.let { totalSize -> totalSize > (maxSize * 1024 * 1024 - currentFileSize) }
252-
}
253-
254-
private fun isBuildOrBin(fileName: String): Boolean {
255-
val regex = Regex("""bin|build|node_modules|venv|\.venv|env|\.idea|\.conda""", RegexOption.IGNORE_CASE)
256-
return regex.find(fileName) != null
257-
}
258-
259-
fun collectFiles(): FileCollectionResult {
260-
val collectedFiles = mutableListOf<String>()
261-
var currentTotalFileSize = 0L
262-
val allFiles = mutableListOf<VirtualFile>()
263-
264-
val projectBaseDirectories = project.getBaseDirectories()
265-
val changeListManager = ChangeListManager.getInstance(project)
266-
267-
projectBaseDirectories.forEach {
268-
VfsUtilCore.visitChildrenRecursively(
269-
it,
270-
object : VirtualFileVisitor<Unit>(NO_FOLLOW_SYMLINKS) {
271-
// TODO: refactor this along with /dev & codescan file traversing logic
272-
override fun visitFile(file: VirtualFile): Boolean {
273-
if ((file.isDirectory && isBuildOrBin(file.name)) ||
274-
!isWorkspaceSourceContent(file, projectBaseDirectories, changeListManager, additionalGlobalIgnoreRulesForStrictSources) ||
275-
(file.isFile && file.length > 10 * 1024 * 1024)
276-
) {
277-
return false
278-
}
279-
if (file.isFile) {
280-
allFiles.add(file)
281-
return false
282-
}
283-
return true
284-
}
285-
}
286-
)
287-
}
288-
289-
for (file in allFiles) {
290-
if (willExceedPayloadLimit(currentTotalFileSize, file.length)) {
291-
break
292-
}
293-
collectedFiles.add(file.path)
294-
currentTotalFileSize += file.length
295-
}
296-
297-
return FileCollectionResult(
298-
files = collectedFiles.toList(),
299-
fileSize = (currentTotalFileSize / 1024 / 1024).toInt()
300-
)
301-
}
249+
fun collectFiles(): FileCollectionResult = collectFiles(project.getBaseDirectories(), ChangeListManager.getInstance(project))
302250

303251
private fun queryResultToRelevantDocuments(queryResult: List<Chunk>): List<RelevantDocument> {
304252
val documents: MutableList<RelevantDocument> = mutableListOf()
@@ -358,5 +306,57 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En
358306

359307
companion object {
360308
private val logger = getLogger<ProjectContextProvider>()
309+
private val regex = Regex("""bin|build|node_modules|venv|\.venv|env|\.idea|\.conda""", RegexOption.IGNORE_CASE)
310+
private val mega = (1024 * 1024).toULong()
311+
private val tenMb = 10 * mega.toInt()
312+
313+
private fun willExceedPayloadLimit(maxSize: ULong, currentTotalFileSize: ULong, currentFileSize: Long) =
314+
currentTotalFileSize.let { totalSize -> totalSize > (maxSize - currentFileSize.toUInt()) }
315+
316+
private fun isBuildOrBin(fileName: String): Boolean =
317+
regex.find(fileName) != null
318+
319+
fun collectFiles(projectBaseDirectories: Set<VirtualFile>, changeListManager: ChangeListManager): FileCollectionResult {
320+
val maxSize = CodeWhispererSettings.getInstance()
321+
.getProjectContextIndexMaxSize().toULong() * mega
322+
val collectedFiles = mutableListOf<String>()
323+
var currentTotalFileSize = 0UL
324+
val allFiles = mutableListOf<VirtualFile>()
325+
326+
projectBaseDirectories.forEach {
327+
VfsUtilCore.visitChildrenRecursively(
328+
it,
329+
object : VirtualFileVisitor<Unit>(NO_FOLLOW_SYMLINKS) {
330+
// TODO: refactor this along with /dev & codescan file traversing logic
331+
override fun visitFile(file: VirtualFile): Boolean {
332+
if ((file.isDirectory && isBuildOrBin(file.name)) ||
333+
!isWorkspaceSourceContent(file, projectBaseDirectories, changeListManager, additionalGlobalIgnoreRulesForStrictSources) ||
334+
(file.isFile && file.length > tenMb)
335+
) {
336+
return false
337+
}
338+
if (file.isFile) {
339+
allFiles.add(file)
340+
return false
341+
}
342+
return true
343+
}
344+
}
345+
)
346+
}
347+
348+
for (file in allFiles) {
349+
if (willExceedPayloadLimit(maxSize, currentTotalFileSize, file.length)) {
350+
break
351+
}
352+
collectedFiles.add(file.path)
353+
currentTotalFileSize += file.length.toUInt()
354+
}
355+
356+
return FileCollectionResult(
357+
files = collectedFiles.toList(),
358+
fileSize = (currentTotalFileSize / 1024u / 1024u).toInt()
359+
)
360+
}
361361
}
362362
}

plugins/amazonq/shared/jetbrains-community/src/software/aws/toolkits/jetbrains/settings/CodeWhispererSettings.kt

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ class CodeWhispererSettings : PersistentStateComponent<CodeWhispererConfiguratio
9292
fun getProjectContextIndexThreadCount(): Int = state.intValue.getOrDefault(
9393
CodeWhispererIntConfigurationType.ProjectContextIndexThreadCount,
9494
0
95-
)
95+
).coerceIn(CONTEXT_INDEX_THREADS)
9696

9797
fun setProjectContextIndexThreadCount(value: Int) {
9898
state.intValue[CodeWhispererIntConfigurationType.ProjectContextIndexThreadCount] = value
@@ -101,7 +101,7 @@ class CodeWhispererSettings : PersistentStateComponent<CodeWhispererConfiguratio
101101
fun getProjectContextIndexMaxSize(): Int = state.intValue.getOrDefault(
102102
CodeWhispererIntConfigurationType.ProjectContextIndexMaxSize,
103103
250
104-
)
104+
).coerceIn(CONTEXT_INDEX_SIZE)
105105

106106
fun setProjectContextIndexMaxSize(value: Int) {
107107
state.intValue[CodeWhispererIntConfigurationType.ProjectContextIndexMaxSize] = value
@@ -134,10 +134,6 @@ class CodeWhispererSettings : PersistentStateComponent<CodeWhispererConfiguratio
134134
state.value[CodeWhispererConfigurationType.IsTabAcceptPriorityNotificationShownOnce] = value
135135
}
136136

137-
companion object {
138-
fun getInstance(): CodeWhispererSettings = service()
139-
}
140-
141137
override fun getState(): CodeWhispererConfiguration = CodeWhispererConfiguration().apply {
142138
value.putAll(state.value)
143139
intValue.putAll(state.intValue)
@@ -155,6 +151,13 @@ class CodeWhispererSettings : PersistentStateComponent<CodeWhispererConfiguratio
155151
this.state.stringValue.putAll(state.stringValue)
156152
this.state.autoBuildSetting.putAll(state.autoBuildSetting)
157153
}
154+
155+
companion object {
156+
fun getInstance(): CodeWhispererSettings = service()
157+
158+
val CONTEXT_INDEX_SIZE = IntRange(1, 4096)
159+
val CONTEXT_INDEX_THREADS = IntRange(0, 50)
160+
}
158161
}
159162

160163
class CodeWhispererConfiguration : BaseState() {

0 commit comments

Comments
 (0)