Skip to content

Commit 4971993

Browse files
authored
refactor(Q tech debt): make cwsprClientAdapr stateless and obtain sdk client from AwsClientManager (#5331)
* No need to store a client as a state within this class, AwsClientManager already handle the connection change resolving comments #5290 (comment)
1 parent 3621194 commit 4971993

File tree

3 files changed

+57
-58
lines changed

3 files changed

+57
-58
lines changed

plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/credentials/CodeWhispererClientAdaptor.kt

Lines changed: 5 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
package software.aws.toolkits.jetbrains.services.codewhisperer.credentials
55

66
import com.intellij.openapi.Disposable
7-
import com.intellij.openapi.application.ApplicationManager
87
import com.intellij.openapi.components.service
98
import com.intellij.openapi.project.Project
109
import com.intellij.util.text.nullize
@@ -41,14 +40,10 @@ import software.amazon.awssdk.services.codewhispererruntime.model.TargetCode
4140
import software.amazon.awssdk.services.codewhispererruntime.model.UserIntent
4241
import software.aws.toolkits.core.utils.debug
4342
import software.aws.toolkits.core.utils.getLogger
44-
import software.aws.toolkits.core.utils.warn
4543
import software.aws.toolkits.jetbrains.core.AwsClientManager
4644
import software.aws.toolkits.jetbrains.core.awsClient
47-
import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection
48-
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnection
4945
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
50-
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManagerListener
51-
import software.aws.toolkits.jetbrains.core.credentials.pinning.CodeWhispererConnection
46+
import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection
5247
import software.aws.toolkits.jetbrains.services.amazonq.codeWhispererUserContext
5348
import software.aws.toolkits.jetbrains.services.codewhisperer.customization.CodeWhispererCustomization
5449
import software.aws.toolkits.jetbrains.services.codewhisperer.explorer.CodeWhispererExplorerActionManager
@@ -67,7 +62,6 @@ import java.util.concurrent.TimeUnit
6762
import kotlin.reflect.KProperty0
6863
import kotlin.reflect.jvm.isAccessible
6964

70-
// TODO: move this file to package "/client"
7165
// As the connection is project-level, we need to make this project-level too
7266
@Deprecated("Methods can throw a NullPointerException if callee does not check if connection is valid")
7367
interface CodeWhispererClientAdaptor : Disposable {
@@ -283,37 +277,16 @@ interface CodeWhispererClientAdaptor : Disposable {
283277
open class CodeWhispererClientAdaptorImpl(override val project: Project) : CodeWhispererClientAdaptor {
284278
private val mySigv4Client by lazy { createUnmanagedSigv4Client() }
285279

286-
@Volatile
287-
private var myBearerClient: CodeWhispererRuntimeClient? = null
288-
289280
private val KProperty0<*>.isLazyInitialized: Boolean
290281
get() {
291282
isAccessible = true
292283
return (getDelegate() as Lazy<*>).isInitialized()
293284
}
294285

295-
init {
296-
initClientUpdateListener()
297-
}
298-
299-
private fun initClientUpdateListener() {
300-
ApplicationManager.getApplication().messageBus.connect(this).subscribe(
301-
ToolkitConnectionManagerListener.TOPIC,
302-
object : ToolkitConnectionManagerListener {
303-
override fun activeConnectionChanged(newConnection: ToolkitConnection?) {
304-
if (newConnection is AwsBearerTokenConnection) {
305-
myBearerClient = getBearerClient(newConnection.getConnectionSettings().providerId)
306-
}
307-
}
308-
}
309-
)
310-
}
311-
312-
private fun bearerClient(): CodeWhispererRuntimeClient {
313-
if (myBearerClient != null) return myBearerClient as CodeWhispererRuntimeClient
314-
myBearerClient = getBearerClient()
315-
return myBearerClient as CodeWhispererRuntimeClient
316-
}
286+
fun bearerClient(): CodeWhispererRuntimeClient =
287+
ToolkitConnectionManager.getInstance(project).activeConnectionForFeature(QConnection.getInstance())?.getConnectionSettings()
288+
?.awsClient<CodeWhispererRuntimeClient>()
289+
?: throw Exception("attempt to get bearer client while there is no valid credential")
317290

318291
override fun generateCompletionsPaginator(firstRequest: GenerateCompletionsRequest) = sequence<GenerateCompletionsResponse> {
319292
var nextToken: String? = firstRequest.nextToken()
@@ -854,28 +827,6 @@ open class CodeWhispererClientAdaptorImpl(override val project: Project) : CodeW
854827
if (this::mySigv4Client.isLazyInitialized) {
855828
mySigv4Client.close()
856829
}
857-
myBearerClient?.close()
858-
}
859-
860-
/**
861-
* Every different SSO/AWS Builder ID connection requires a new client which has its corresponding bearer token provider,
862-
* thus we have to create them dynamically.
863-
* Invalidate and recycle the old client first, and create a new client with the new connection.
864-
* This makes sure when we invoke CW, we always use the up-to-date connection.
865-
* In case this fails to close the client, myBearerClient is already set to null thus next time when we invoke CW,
866-
* it will go through this again which should get the current up-to-date connection. This stale client would be
867-
* unused and stay in memory for a while until eventually closed by ToolkitClientManager.
868-
*/
869-
open fun getBearerClient(oldProviderIdToRemove: String = ""): CodeWhispererRuntimeClient? {
870-
myBearerClient = null
871-
872-
val connection = ToolkitConnectionManager.getInstance(project).activeConnectionForFeature(CodeWhispererConnection.getInstance())
873-
connection as? AwsBearerTokenConnection ?: run {
874-
LOG.warn { "$connection is not a bearer token connection" }
875-
return null
876-
}
877-
878-
return AwsClientManager.getInstance().getClient<CodeWhispererRuntimeClient>(connection.getConnectionSettings())
879830
}
880831

881832
companion object {
@@ -889,7 +840,6 @@ open class CodeWhispererClientAdaptorImpl(override val project: Project) : CodeW
889840
}
890841

891842
class MockCodeWhispererClientAdaptor(override val project: Project) : CodeWhispererClientAdaptorImpl(project) {
892-
override fun getBearerClient(oldProviderIdToRemove: String): CodeWhispererRuntimeClient = project.awsClient()
893843
override fun dispose() {}
894844
}
895845

plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererClientAdaptorTest.kt

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import org.junit.After
1515
import org.junit.Before
1616
import org.junit.Rule
1717
import org.junit.Test
18+
import org.junit.jupiter.api.assertThrows
1819
import org.mockito.kotlin.any
1920
import org.mockito.kotlin.argThat
2021
import org.mockito.kotlin.argumentCaptor
@@ -68,17 +69,20 @@ import software.aws.toolkits.core.TokenConnectionSettings
6869
import software.aws.toolkits.core.utils.test.aString
6970
import software.aws.toolkits.jetbrains.core.MockClientManagerRule
7071
import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection
72+
import software.aws.toolkits.jetbrains.core.credentials.DefaultToolkitConnectionManager
7173
import software.aws.toolkits.jetbrains.core.credentials.ManagedSsoProfile
7274
import software.aws.toolkits.jetbrains.core.credentials.MockCredentialManagerRule
7375
import software.aws.toolkits.jetbrains.core.credentials.MockToolkitAuthManagerRule
7476
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
77+
import software.aws.toolkits.jetbrains.core.credentials.logoutFromSsoConnection
78+
import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection
79+
import software.aws.toolkits.jetbrains.core.credentials.sono.Q_SCOPES
7580
import software.aws.toolkits.jetbrains.core.credentials.sono.SONO_REGION
7681
import software.aws.toolkits.jetbrains.services.amazonq.FEATURE_EVALUATION_PRODUCT_NAME
7782
import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.metadata
7883
import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.pythonRequest
7984
import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.pythonResponseWithToken
8085
import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.sdkHttpResponse
81-
import software.aws.toolkits.jetbrains.services.codewhisperer.credentials.CodeWhispererClientAdaptor
8286
import software.aws.toolkits.jetbrains.services.codewhisperer.credentials.CodeWhispererClientAdaptorImpl
8387
import software.aws.toolkits.jetbrains.services.codewhisperer.customization.CodeWhispererCustomization
8488
import software.aws.toolkits.jetbrains.services.codewhisperer.customization.CodeWhispererModelConfigurator
@@ -109,7 +113,7 @@ class CodeWhispererClientAdaptorTest {
109113
private lateinit var bearerClient: CodeWhispererRuntimeClient
110114
private lateinit var ssoClient: SsoOidcClient
111115

112-
private lateinit var sut: CodeWhispererClientAdaptor
116+
private lateinit var sut: CodeWhispererClientAdaptorImpl
113117
private lateinit var connectionManager: ToolkitConnectionManager
114118
private var isTelemetryEnabledDefault: Boolean = false
115119

@@ -163,6 +167,41 @@ class CodeWhispererClientAdaptorTest {
163167
assertThat("us-east-1").isEqualTo(SONO_REGION)
164168
}
165169

170+
@Test
171+
fun `should throw if there is no valid credential, otherwise return codewhispererRuntimeClient`() {
172+
val connectionManager = DefaultToolkitConnectionManager()
173+
projectRule.project.replaceService(ToolkitConnectionManager::class.java, DefaultToolkitConnectionManager(), disposableRule.disposable)
174+
175+
assertThat(ToolkitConnectionManager.getInstance(projectRule.project).activeConnectionForFeature(QConnection.getInstance())).isNull()
176+
assertThrows<Exception>("attempt to get bearer client while there is no valid credential") {
177+
sut.bearerClient()
178+
}
179+
180+
val qConnection = authManagerRule.createConnection(ManagedSsoProfile("us-east-1", aString(), Q_SCOPES))
181+
connectionManager.switchConnection(qConnection)
182+
assertThat(connectionManager.activeConnectionForFeature(QConnection.getInstance()))
183+
.isNotNull
184+
.isEqualTo(qConnection)
185+
assertThat(sut.bearerClient())
186+
.isNotNull
187+
.isInstanceOf(CodeWhispererRuntimeClient::class.java)
188+
189+
logoutFromSsoConnection(projectRule.project, qConnection as AwsBearerTokenConnection)
190+
assertThat(connectionManager.activeConnectionForFeature(QConnection.getInstance())).isNull()
191+
assertThrows<Exception>("attempt to get bearer client while there is no valid credential") {
192+
sut.bearerClient()
193+
}
194+
195+
val anotherQConnection = authManagerRule.createConnection(ManagedSsoProfile("us-east-1", aString(), Q_SCOPES))
196+
connectionManager.switchConnection(anotherQConnection)
197+
assertThat(connectionManager.activeConnectionForFeature(QConnection.getInstance()))
198+
.isNotNull
199+
.isEqualTo(anotherQConnection)
200+
assertThat(sut.bearerClient())
201+
.isNotNull
202+
.isInstanceOf(CodeWhispererRuntimeClient::class.java)
203+
}
204+
166205
@Test
167206
fun `listCustomizations`() {
168207
val sdkIterable = ListAvailableCustomizationsIterable(bearerClient, ListAvailableCustomizationsRequest.builder().build())

plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererTestBase.kt

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,13 @@ import org.mockito.kotlin.verify
3030
import software.amazon.awssdk.services.codewhispererruntime.CodeWhispererRuntimeClient
3131
import software.amazon.awssdk.services.codewhispererruntime.model.GenerateCompletionsRequest
3232
import software.amazon.awssdk.services.codewhispererruntime.paginators.GenerateCompletionsIterable
33+
import software.amazon.awssdk.services.ssooidc.SsoOidcClient
3334
import software.aws.toolkits.jetbrains.core.MockClientManagerRule
35+
import software.aws.toolkits.jetbrains.core.credentials.ManagedSsoProfile
3436
import software.aws.toolkits.jetbrains.core.credentials.MockCredentialManagerRule
37+
import software.aws.toolkits.jetbrains.core.credentials.MockToolkitAuthManagerRule
38+
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
39+
import software.aws.toolkits.jetbrains.core.credentials.sono.Q_SCOPES
3540
import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.codeWhispererRecommendationActionId
3641
import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.pythonFileName
3742
import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.pythonResponse
@@ -65,10 +70,11 @@ open class CodeWhispererTestBase {
6570
val mockClientManagerRule = MockClientManagerRule()
6671
val mockCredentialRule = MockCredentialManagerRule()
6772
val disposableRule = DisposableRule()
73+
val authManagerRule = MockToolkitAuthManagerRule()
6874

6975
@Rule
7076
@JvmField
71-
val ruleChain = RuleChain(projectRule, mockCredentialRule, mockClientManagerRule, disposableRule)
77+
val ruleChain = RuleChain(projectRule, mockCredentialRule, mockClientManagerRule, authManagerRule, disposableRule)
7278

7379
protected lateinit var mockClient: CodeWhispererRuntimeClient
7480

@@ -86,6 +92,7 @@ open class CodeWhispererTestBase {
8692
@Before
8793
open fun setUp() {
8894
mockClient = mockClientManagerRule.create()
95+
mockClientManagerRule.create<SsoOidcClient>()
8996
val requestCaptor = argumentCaptor<GenerateCompletionsRequest>()
9097
mockClient.stub {
9198
on {
@@ -159,6 +166,9 @@ open class CodeWhispererTestBase {
159166
projectRule.project.replaceService(CodeWhispererClientAdaptor::class.java, clientAdaptorSpy, disposableRule.disposable)
160167
ApplicationManager.getApplication().replaceService(CodeWhispererExplorerActionManager::class.java, stateManager, disposableRule.disposable)
161168
stateManager.setAutoEnabled(false)
169+
170+
val conn = authManagerRule.createConnection(ManagedSsoProfile("us-east-1", "url", Q_SCOPES))
171+
ToolkitConnectionManager.getInstance(projectRule.project).switchConnection(conn)
162172
}
163173

164174
@After

0 commit comments

Comments
 (0)