Skip to content

Commit 57c9d67

Browse files
authored
fix(inline-completion): potential inline completion failure due to input validation exception of supplemental context (aws#5466)
1 parent 06e8c47 commit 57c9d67

File tree

6 files changed

+163
-2
lines changed

6 files changed

+163
-2
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type" : "bugfix",
3+
"description" : "Fix inline completion failure due to context length exceeding the threshold"
4+
}

plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/util/CodeWhispererConstants.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,8 @@ object CodeWhispererConstants {
188188
const val NUMBER_OF_LINE_IN_CHUNK = 50
189189
const val NUMBER_OF_CHUNK_TO_FETCH = 3
190190
const val MAX_TOTAL_LENGTH = 20480
191+
const val MAX_LENGTH_PER_CHUNK = 10240
192+
const val MAX_CONTEXT_COUNT = 5
191193
}
192194

193195
object Utg {

plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/util/CodeWhispererFileContextProvider.kt

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -327,10 +327,30 @@ class DefaultCodeWhispererFileContextProvider(private val project: Project) : Fi
327327
return truncateContext(contextBeforeTruncation)
328328
}
329329

330+
/**
331+
* Requirement
332+
* - Maximum 5 supplemental context.
333+
* - Each chunk can't exceed 10240 characters
334+
* - Sum of all chunks can't exceed 20480 characters
335+
*/
330336
fun truncateContext(context: SupplementalContextInfo): SupplementalContextInfo {
331-
var c = context.contents
332-
while (c.sumOf { it.content.length } >= CodeWhispererConstants.CrossFile.MAX_TOTAL_LENGTH) {
337+
var c = context.contents.map {
338+
return@map if (it.content.length > CodeWhispererConstants.CrossFile.MAX_LENGTH_PER_CHUNK) {
339+
it.copy(content = truncateLineByLine(it.content, CodeWhispererConstants.CrossFile.MAX_LENGTH_PER_CHUNK))
340+
} else {
341+
it
342+
}
343+
}
344+
345+
if (c.size > CodeWhispererConstants.CrossFile.MAX_CONTEXT_COUNT) {
346+
c = c.subList(0, CodeWhispererConstants.CrossFile.MAX_CONTEXT_COUNT)
347+
}
348+
349+
var curTotalLength = c.sumOf { it.content.length }
350+
while (curTotalLength >= CodeWhispererConstants.CrossFile.MAX_TOTAL_LENGTH) {
351+
val last = c.last()
333352
c = c.dropLast(1)
353+
curTotalLength -= last.content.length
334354
}
335355

336356
return context.copy(contents = c)

plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/util/CodeWhispererUtil.kt

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,28 @@ suspend fun String.toCodeChunk(path: String): List<Chunk> {
112112
}
113113
}
114114

115+
fun truncateLineByLine(input: String, l: Int): String {
116+
val maxLength = if (l > 0) l else -1 * l
117+
if (input.isEmpty()) {
118+
return ""
119+
}
120+
val shouldAddNewLineBack = input.last() == '\n'
121+
var lines = input.trim().split("\n")
122+
var curLen = input.length
123+
while (curLen > maxLength) {
124+
val last = lines.last()
125+
lines = lines.dropLast(1)
126+
curLen -= last.length + 1
127+
}
128+
129+
val r = lines.joinToString("\n")
130+
return if (shouldAddNewLineBack) {
131+
r + "\n"
132+
} else {
133+
r
134+
}
135+
}
136+
115137
fun getAuthType(project: Project): CredentialSourceId? {
116138
val connection = checkBearerConnectionValidity(project, BearerTokenFeatureSet.Q)
117139
var authType: CredentialSourceId? = null

plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererFileContextProviderTest.kt

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,70 @@ class CodeWhispererFileContextProviderTest {
529529
assertThat(r.targetFileName).isEqualTo("foo")
530530
}
531531

532+
@Test
533+
fun `truncate context should make context item lte 5`() {
534+
val supplementalContext = SupplementalContextInfo(
535+
isUtg = false,
536+
contents = listOf(
537+
Chunk(content = "a", path = "a.java"),
538+
Chunk(content = "b", path = "b.java"),
539+
Chunk(content = "c", path = "c.java"),
540+
Chunk(content = "d", path = "d.java"),
541+
Chunk(content = "e", path = "e.java"),
542+
Chunk(content = "f", path = "e.java"),
543+
Chunk(content = "g", path = "e.java"),
544+
),
545+
targetFileName = "foo",
546+
strategy = CrossFileStrategy.Codemap
547+
)
548+
549+
val r = sut.truncateContext(supplementalContext)
550+
assertThat(r.contents).hasSize(5)
551+
assertThat(r.strategy).isEqualTo(CrossFileStrategy.Codemap)
552+
assertThat(r.targetFileName).isEqualTo("foo")
553+
}
554+
555+
@Test
556+
fun `truncate context should make context length per item fit in 10240 cap`() {
557+
val chunkA = Chunk(content = "a\n".repeat(4000), path = "a.java")
558+
val chunkB = Chunk(content = "b\n".repeat(6000), path = "b.java")
559+
val chunkC = Chunk(content = "c\n".repeat(1000), path = "c.java")
560+
val chunkD = Chunk(content = "d\n".repeat(1500), path = "d.java")
561+
562+
assertThat(chunkA.content.length).isEqualTo(8000)
563+
assertThat(chunkB.content.length).isEqualTo(12000)
564+
assertThat(chunkC.content.length).isEqualTo(2000)
565+
assertThat(chunkD.content.length).isEqualTo(3000)
566+
assertThat(chunkA.content.length + chunkB.content.length + chunkC.content.length + chunkD.content.length).isEqualTo(25000)
567+
568+
val supplementalContext = SupplementalContextInfo(
569+
isUtg = false,
570+
contents = listOf(
571+
chunkA,
572+
chunkB,
573+
chunkC,
574+
chunkD,
575+
),
576+
targetFileName = "foo",
577+
strategy = CrossFileStrategy.Codemap
578+
)
579+
580+
val r = sut.truncateContext(supplementalContext)
581+
582+
assertThat(r.contents).hasSize(3)
583+
val truncatedChunkA = r.contents[0]
584+
val truncatedChunkB = r.contents[1]
585+
val truncatedChunkC = r.contents[2]
586+
587+
assertThat(truncatedChunkA.content.length).isEqualTo(8000)
588+
assertThat(truncatedChunkB.content.length).isEqualTo(10240)
589+
assertThat(truncatedChunkC.content.length).isEqualTo(2000)
590+
591+
assertThat(r.contentLength).isEqualTo(20240)
592+
assertThat(r.strategy).isEqualTo(CrossFileStrategy.Codemap)
593+
assertThat(r.targetFileName).isEqualTo("foo")
594+
}
595+
532596
private fun setupFixture(fixture: JavaCodeInsightTestFixture): List<PsiFile> {
533597
val psiFile1 = fixture.addFileToProject("Main.java", JAVA_MAIN)
534598
val psiFile2 = fixture.addFileToProject("UtilClass.java", JAVA_UTILCLASS)

plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererUtilTest.kt

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import software.aws.toolkits.jetbrains.services.codewhisperer.util.CodeWhisperer
2727
import software.aws.toolkits.jetbrains.services.codewhisperer.util.isWithin
2828
import software.aws.toolkits.jetbrains.services.codewhisperer.util.runIfIdcConnectionOrTelemetryEnabled
2929
import software.aws.toolkits.jetbrains.services.codewhisperer.util.toCodeChunk
30+
import software.aws.toolkits.jetbrains.services.codewhisperer.util.truncateLineByLine
3031
import software.aws.toolkits.jetbrains.settings.AwsSettings
3132
import software.aws.toolkits.jetbrains.utils.rules.JavaCodeInsightTestFixtureRule
3233
import software.aws.toolkits.telemetry.CodewhispererCompletionType
@@ -61,6 +62,54 @@ class CodeWhispererUtilTest {
6162
AwsSettings.getInstance().isTelemetryEnabled = isTelemetryEnabledDefault
6263
}
6364

65+
@Test
66+
fun `truncateLineByLine should drop the last line if max length is greater than threshold`() {
67+
val input: String = """
68+
${"a".repeat(11)}
69+
${"b".repeat(11)}
70+
${"c".repeat(11)}
71+
${"d".repeat(11)}
72+
${"e".repeat(11)}
73+
""".trimIndent()
74+
assertThat(input.length).isGreaterThan(50)
75+
val actual = truncateLineByLine(input, 50)
76+
assertThat(actual).isEqualTo(
77+
"""
78+
${"a".repeat(11)}
79+
${"b".repeat(11)}
80+
${"c".repeat(11)}
81+
${"d".repeat(11)}
82+
""".trimIndent()
83+
)
84+
85+
val input2 = "b\n".repeat(10)
86+
val actual2 = truncateLineByLine(input2, 8)
87+
assertThat(actual2.length).isEqualTo(8)
88+
}
89+
90+
@Test
91+
fun `truncateLineByLine should return empty if empty string is provided`() {
92+
val input = ""
93+
val actual = truncateLineByLine(input, 50)
94+
assertThat(actual).isEqualTo("")
95+
}
96+
97+
@Test
98+
fun `truncateLineByLine should return empty if 0 max length is provided`() {
99+
val input = "aaaaa"
100+
val actual = truncateLineByLine(input, 0)
101+
assertThat(actual).isEqualTo("")
102+
}
103+
104+
@Test
105+
fun `truncateLineByLine should return flip the value if negative max length is provided`() {
106+
val input = "aaaaa\nbbbbb"
107+
val actual = truncateLineByLine(input, -6)
108+
val expected1 = truncateLineByLine(input, 6)
109+
assertThat(actual).isEqualTo(expected1)
110+
assertThat(actual).isEqualTo("aaaaa")
111+
}
112+
64113
@Test
65114
fun `checkIfIdentityCenterLoginOrTelemetryEnabled will execute callback if the connection is IamIdentityCenter`() {
66115
val modificationTracker = SimpleModificationTracker()

0 commit comments

Comments
 (0)