Skip to content

fix(amazonq): switch to ulong to avoid overflow when input is larger than 2gb #5558

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 17, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type" : "bugfix",
"description" : "Fix integer overflow when local context index input is larger than 2GB"
}
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ class CodeWhispererConfigurable(private val project: Project) :

row(message("aws.settings.codewhisperer.project_context_index_thread")) {
intTextField(
range = IntRange(0, 50)
range = CodeWhispererSettings.CONTEXT_INDEX_THREADS
).bindIntText(codeWhispererSettings::getProjectContextIndexThreadCount, codeWhispererSettings::setProjectContextIndexThreadCount)
.apply {
connect.subscribe(
Expand All @@ -150,7 +150,7 @@ class CodeWhispererConfigurable(private val project: Project) :

row(message("aws.settings.codewhisperer.project_context_index_max_size")) {
intTextField(
range = IntRange(1, 4096)
range = CodeWhispererSettings.CONTEXT_INDEX_SIZE
).bindIntText(codeWhispererSettings::getProjectContextIndexMaxSize, codeWhispererSettings::setProjectContextIndexMaxSize)
.apply {
connect.subscribe(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,42 @@ class CodeWhispererSettingsTest : CodeWhispererTestBase() {
assertThat(actual.autoBuildSetting).hasSize(1)
assertThat(actual.autoBuildSetting["project1"]).isTrue()
}

@Test
fun `context thread count is returned in range`() {
val sut = CodeWhispererSettings.getInstance()

mapOf(
1 to 1,
0 to 0,
-1 to 0,
123 to 50,
50 to 50,
51 to 50,
).forEach { s, expected ->
sut.setProjectContextIndexThreadCount(s)
assertThat(sut.getProjectContextIndexThreadCount()).isEqualTo(expected)
}

}

@Test
fun `context index size is returned in range`() {
val sut = CodeWhispererSettings.getInstance()

mapOf(
1 to 1,
0 to 1,
-1 to 1,
123 to 123,
2047 to 2047,
4096 to 4096,
4097 to 4096,
).forEach { s, expected ->
sut.setProjectContextIndexMaxSize(s)
assertThat(sut.getProjectContextIndexMaxSize()).isEqualTo(expected)
}
}
}

class CodeWhispererSettingUnitTest {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En

data class FileCollectionResult(
val files: List<String>,
val fileSize: Int,
val fileSize: Int, // in MB
)

// TODO: move to LspMessage.kt
Expand Down Expand Up @@ -241,59 +241,7 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En
}
}

private fun willExceedPayloadLimit(currentTotalFileSize: Long, currentFileSize: Long): Boolean {
val maxSize = CodeWhispererSettings.getInstance().getProjectContextIndexMaxSize()
return currentTotalFileSize.let { totalSize -> totalSize > (maxSize * 1024 * 1024 - currentFileSize) }
}

private fun isBuildOrBin(fileName: String): Boolean {
val regex = Regex("""bin|build|node_modules|venv|\.venv|env|\.idea|\.conda""", RegexOption.IGNORE_CASE)
return regex.find(fileName) != null
}

fun collectFiles(): FileCollectionResult {
val collectedFiles = mutableListOf<String>()
var currentTotalFileSize = 0L
val allFiles = mutableListOf<VirtualFile>()

val projectBaseDirectories = project.getBaseDirectories()
val changeListManager = ChangeListManager.getInstance(project)

projectBaseDirectories.forEach {
VfsUtilCore.visitChildrenRecursively(
it,
object : VirtualFileVisitor<Unit>(NO_FOLLOW_SYMLINKS) {
// TODO: refactor this along with /dev & codescan file traversing logic
override fun visitFile(file: VirtualFile): Boolean {
if ((file.isDirectory && isBuildOrBin(file.name)) ||
!isWorkspaceSourceContent(file, projectBaseDirectories, changeListManager, additionalGlobalIgnoreRulesForStrictSources) ||
(file.isFile && file.length > 10 * 1024 * 1024)
) {
return false
}
if (file.isFile) {
allFiles.add(file)
return false
}
return true
}
}
)
}

for (file in allFiles) {
if (willExceedPayloadLimit(currentTotalFileSize, file.length)) {
break
}
collectedFiles.add(file.path)
currentTotalFileSize += file.length
}

return FileCollectionResult(
files = collectedFiles.toList(),
fileSize = (currentTotalFileSize / 1024 / 1024).toInt()
)
}
fun collectFiles(): FileCollectionResult = collectFiles(project.getBaseDirectories(), ChangeListManager.getInstance(project))

private fun queryResultToRelevantDocuments(queryResult: List<Chunk>): List<RelevantDocument> {
val documents: MutableList<RelevantDocument> = mutableListOf()
Expand Down Expand Up @@ -353,5 +301,58 @@ class ProjectContextProvider(val project: Project, private val encoderServer: En

companion object {
private val logger = getLogger<ProjectContextProvider>()

private fun willExceedPayloadLimit(maxSize: ULong, currentTotalFileSize: ULong, currentFileSize: Long) =
currentTotalFileSize.let { totalSize -> totalSize > (maxSize - currentFileSize.toUInt()) }

private fun isBuildOrBin(fileName: String): Boolean {
val regex = Regex("""bin|build|node_modules|venv|\.venv|env|\.idea|\.conda""", RegexOption.IGNORE_CASE)
return regex.find(fileName) != null
}

fun collectFiles(projectBaseDirectories: Set<VirtualFile>, changeListManager: ChangeListManager): FileCollectionResult {
val mega = 1024u * 1024u
val maxSize = CodeWhispererSettings.getInstance()
.getProjectContextIndexMaxSize().toULong() * mega
val tenMb = 10 * mega.toInt()
val collectedFiles = mutableListOf<String>()
var currentTotalFileSize = 0UL
val allFiles = mutableListOf<VirtualFile>()

projectBaseDirectories.forEach {
VfsUtilCore.visitChildrenRecursively(
it,
object : VirtualFileVisitor<Unit>(NO_FOLLOW_SYMLINKS) {
// TODO: refactor this along with /dev & codescan file traversing logic
override fun visitFile(file: VirtualFile): Boolean {
if ((file.isDirectory && isBuildOrBin(file.name)) ||
!isWorkspaceSourceContent(file, projectBaseDirectories, changeListManager, additionalGlobalIgnoreRulesForStrictSources) ||
(file.isFile && file.length > tenMb)
) {
return false
}
if (file.isFile) {
allFiles.add(file)
return false
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be true?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

files have no children to visit so it could go either way

}
return true
}
}
)
}

for (file in allFiles) {
if (willExceedPayloadLimit(maxSize, currentTotalFileSize, file.length)) {
break
}
collectedFiles.add(file.path)
currentTotalFileSize += file.length.toUInt()
}

return FileCollectionResult(
files = collectedFiles.toList(),
fileSize = (currentTotalFileSize / 1024u / 1024u).toInt()
)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class CodeWhispererSettings : PersistentStateComponent<CodeWhispererConfiguratio
fun getProjectContextIndexThreadCount(): Int = state.intValue.getOrDefault(
CodeWhispererIntConfigurationType.ProjectContextIndexThreadCount,
0
)
).coerceIn(CONTEXT_INDEX_THREADS)

fun setProjectContextIndexThreadCount(value: Int) {
state.intValue[CodeWhispererIntConfigurationType.ProjectContextIndexThreadCount] = value
Expand All @@ -101,7 +101,7 @@ class CodeWhispererSettings : PersistentStateComponent<CodeWhispererConfiguratio
fun getProjectContextIndexMaxSize(): Int = state.intValue.getOrDefault(
CodeWhispererIntConfigurationType.ProjectContextIndexMaxSize,
250
)
).coerceIn(CONTEXT_INDEX_SIZE)

fun setProjectContextIndexMaxSize(value: Int) {
state.intValue[CodeWhispererIntConfigurationType.ProjectContextIndexMaxSize] = value
Expand Down Expand Up @@ -134,10 +134,6 @@ class CodeWhispererSettings : PersistentStateComponent<CodeWhispererConfiguratio
state.value[CodeWhispererConfigurationType.IsTabAcceptPriorityNotificationShownOnce] = value
}

companion object {
fun getInstance(): CodeWhispererSettings = service()
}

override fun getState(): CodeWhispererConfiguration = CodeWhispererConfiguration().apply {
value.putAll(state.value)
intValue.putAll(state.intValue)
Expand All @@ -155,6 +151,13 @@ class CodeWhispererSettings : PersistentStateComponent<CodeWhispererConfiguratio
this.state.stringValue.putAll(state.stringValue)
this.state.autoBuildSetting.putAll(state.autoBuildSetting)
}

companion object {
fun getInstance(): CodeWhispererSettings = service()

val CONTEXT_INDEX_SIZE = IntRange(1, 4096)
val CONTEXT_INDEX_THREADS = IntRange(0, 50)
}
}

class CodeWhispererConfiguration : BaseState() {
Expand Down
Loading