4
4
package software.aws.toolkits.jetbrains.services.codewhisperer.credentials
5
5
6
6
import com.intellij.openapi.Disposable
7
+ import com.intellij.openapi.application.ApplicationManager
7
8
import com.intellij.openapi.components.service
8
9
import com.intellij.openapi.project.Project
9
10
import com.intellij.util.text.nullize
@@ -40,10 +41,14 @@ import software.amazon.awssdk.services.codewhispererruntime.model.TargetCode
40
41
import software.amazon.awssdk.services.codewhispererruntime.model.UserIntent
41
42
import software.aws.toolkits.core.utils.debug
42
43
import software.aws.toolkits.core.utils.getLogger
44
+ import software.aws.toolkits.core.utils.warn
43
45
import software.aws.toolkits.jetbrains.core.AwsClientManager
44
46
import software.aws.toolkits.jetbrains.core.awsClient
47
+ import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection
48
+ import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnection
45
49
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
46
- import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection
50
+ import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManagerListener
51
+ import software.aws.toolkits.jetbrains.core.credentials.pinning.CodeWhispererConnection
47
52
import software.aws.toolkits.jetbrains.services.amazonq.codeWhispererUserContext
48
53
import software.aws.toolkits.jetbrains.services.codewhisperer.customization.CodeWhispererCustomization
49
54
import software.aws.toolkits.jetbrains.services.codewhisperer.explorer.CodeWhispererExplorerActionManager
@@ -62,6 +67,7 @@ import java.util.concurrent.TimeUnit
62
67
import kotlin.reflect.KProperty0
63
68
import kotlin.reflect.jvm.isAccessible
64
69
70
+ // TODO: move this file to package "/client"
65
71
// As the connection is project-level, we need to make this project-level too
66
72
@Deprecated(" Methods can throw a NullPointerException if callee does not check if connection is valid" )
67
73
interface CodeWhispererClientAdaptor : Disposable {
@@ -277,16 +283,37 @@ interface CodeWhispererClientAdaptor : Disposable {
277
283
open class CodeWhispererClientAdaptorImpl (override val project : Project ) : CodeWhispererClientAdaptor {
278
284
private val mySigv4Client by lazy { createUnmanagedSigv4Client() }
279
285
286
+ @Volatile
287
+ private var myBearerClient: CodeWhispererRuntimeClient ? = null
288
+
280
289
private val KProperty0 <* >.isLazyInitialized: Boolean
281
290
get() {
282
291
isAccessible = true
283
292
return (getDelegate() as Lazy <* >).isInitialized()
284
293
}
285
294
286
- fun bearerClient (): CodeWhispererRuntimeClient =
287
- ToolkitConnectionManager .getInstance(project).activeConnectionForFeature(QConnection .getInstance())?.getConnectionSettings()
288
- ?.awsClient<CodeWhispererRuntimeClient >()
289
- ? : throw Exception (" attempt to get bearer client while there is no valid credential" )
295
+ init {
296
+ initClientUpdateListener()
297
+ }
298
+
299
+ private fun initClientUpdateListener () {
300
+ ApplicationManager .getApplication().messageBus.connect(this ).subscribe(
301
+ ToolkitConnectionManagerListener .TOPIC ,
302
+ object : ToolkitConnectionManagerListener {
303
+ override fun activeConnectionChanged (newConnection : ToolkitConnection ? ) {
304
+ if (newConnection is AwsBearerTokenConnection ) {
305
+ myBearerClient = getBearerClient(newConnection.getConnectionSettings().providerId)
306
+ }
307
+ }
308
+ }
309
+ )
310
+ }
311
+
312
+ private fun bearerClient (): CodeWhispererRuntimeClient {
313
+ if (myBearerClient != null ) return myBearerClient as CodeWhispererRuntimeClient
314
+ myBearerClient = getBearerClient()
315
+ return myBearerClient as CodeWhispererRuntimeClient
316
+ }
290
317
291
318
override fun generateCompletionsPaginator (firstRequest : GenerateCompletionsRequest ) = sequence<GenerateCompletionsResponse > {
292
319
var nextToken: String? = firstRequest.nextToken()
@@ -827,6 +854,28 @@ open class CodeWhispererClientAdaptorImpl(override val project: Project) : CodeW
827
854
if (this ::mySigv4Client.isLazyInitialized) {
828
855
mySigv4Client.close()
829
856
}
857
+ myBearerClient?.close()
858
+ }
859
+
860
+ /* *
861
+ * Every different SSO/AWS Builder ID connection requires a new client which has its corresponding bearer token provider,
862
+ * thus we have to create them dynamically.
863
+ * Invalidate and recycle the old client first, and create a new client with the new connection.
864
+ * This makes sure when we invoke CW, we always use the up-to-date connection.
865
+ * In case this fails to close the client, myBearerClient is already set to null thus next time when we invoke CW,
866
+ * it will go through this again which should get the current up-to-date connection. This stale client would be
867
+ * unused and stay in memory for a while until eventually closed by ToolkitClientManager.
868
+ */
869
+ open fun getBearerClient (oldProviderIdToRemove : String = ""): CodeWhispererRuntimeClient ? {
870
+ myBearerClient = null
871
+
872
+ val connection = ToolkitConnectionManager .getInstance(project).activeConnectionForFeature(CodeWhispererConnection .getInstance())
873
+ connection as ? AwsBearerTokenConnection ? : run {
874
+ LOG .warn { " $connection is not a bearer token connection" }
875
+ return null
876
+ }
877
+
878
+ return AwsClientManager .getInstance().getClient<CodeWhispererRuntimeClient >(connection.getConnectionSettings())
830
879
}
831
880
832
881
companion object {
@@ -840,6 +889,7 @@ open class CodeWhispererClientAdaptorImpl(override val project: Project) : CodeW
840
889
}
841
890
842
891
class MockCodeWhispererClientAdaptor (override val project : Project ) : CodeWhispererClientAdaptorImpl(project) {
892
+ override fun getBearerClient (oldProviderIdToRemove : String ): CodeWhispererRuntimeClient = project.awsClient()
843
893
override fun dispose () {}
844
894
}
845
895
0 commit comments