Skip to content

Commit 787caf2

Browse files
committed
Revert "revert(amazonq): pr aws#5331 broke tests by throwing NotAMock (aws#5383)"
This reverts commit fd7eb31.
1 parent 57c9d67 commit 787caf2

File tree

3 files changed

+58
-66
lines changed

3 files changed

+58
-66
lines changed

plugins/amazonq/codewhisperer/jetbrains-community/src/software/aws/toolkits/jetbrains/services/codewhisperer/credentials/CodeWhispererClientAdaptor.kt

Lines changed: 6 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
package software.aws.toolkits.jetbrains.services.codewhisperer.credentials
55

66
import com.intellij.openapi.Disposable
7-
import com.intellij.openapi.application.ApplicationManager
87
import com.intellij.openapi.components.service
98
import com.intellij.openapi.project.Project
109
import com.intellij.util.text.nullize
@@ -39,14 +38,9 @@ import software.amazon.awssdk.services.codewhispererruntime.model.TargetCode
3938
import software.amazon.awssdk.services.codewhispererruntime.model.UserIntent
4039
import software.aws.toolkits.core.utils.debug
4140
import software.aws.toolkits.core.utils.getLogger
42-
import software.aws.toolkits.core.utils.warn
43-
import software.aws.toolkits.jetbrains.core.AwsClientManager
4441
import software.aws.toolkits.jetbrains.core.awsClient
45-
import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection
46-
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnection
4742
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
48-
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManagerListener
49-
import software.aws.toolkits.jetbrains.core.credentials.pinning.CodeWhispererConnection
43+
import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection
5044
import software.aws.toolkits.jetbrains.services.amazonq.codeWhispererUserContext
5145
import software.aws.toolkits.jetbrains.services.codewhisperer.customization.CodeWhispererCustomization
5246
import software.aws.toolkits.jetbrains.services.codewhisperer.language.CodeWhispererProgrammingLanguage
@@ -63,7 +57,7 @@ import java.util.concurrent.TimeUnit
6357

6458
// As the connection is project-level, we need to make this project-level too
6559
@Deprecated("Methods can throw a NullPointerException if callee does not check if connection is valid")
66-
interface CodeWhispererClientAdaptor : Disposable {
60+
interface CodeWhispererClientAdaptor {
6761
val project: Project
6862

6963
fun generateCompletionsPaginator(
@@ -262,31 +256,10 @@ interface CodeWhispererClientAdaptor : Disposable {
262256
}
263257

264258
open class CodeWhispererClientAdaptorImpl(override val project: Project) : CodeWhispererClientAdaptor {
265-
@Volatile
266-
private var myBearerClient: CodeWhispererRuntimeClient? = null
267-
268-
init {
269-
initClientUpdateListener()
270-
}
271-
272-
private fun initClientUpdateListener() {
273-
ApplicationManager.getApplication().messageBus.connect(this).subscribe(
274-
ToolkitConnectionManagerListener.TOPIC,
275-
object : ToolkitConnectionManagerListener {
276-
override fun activeConnectionChanged(newConnection: ToolkitConnection?) {
277-
if (newConnection is AwsBearerTokenConnection) {
278-
myBearerClient = getBearerClient(newConnection.getConnectionSettings().providerId)
279-
}
280-
}
281-
}
282-
)
283-
}
284-
285-
private fun bearerClient(): CodeWhispererRuntimeClient {
286-
if (myBearerClient != null) return myBearerClient as CodeWhispererRuntimeClient
287-
myBearerClient = getBearerClient()
288-
return myBearerClient as CodeWhispererRuntimeClient
289-
}
259+
fun bearerClient(): CodeWhispererRuntimeClient =
260+
ToolkitConnectionManager.getInstance(project).activeConnectionForFeature(QConnection.getInstance())?.getConnectionSettings()
261+
?.awsClient<CodeWhispererRuntimeClient>()
262+
?: throw Exception("attempt to get bearer client while there is no valid credential")
290263

291264
override fun generateCompletionsPaginator(firstRequest: GenerateCompletionsRequest) = sequence<GenerateCompletionsResponse> {
292265
var nextToken: String? = firstRequest.nextToken()
@@ -809,41 +782,11 @@ open class CodeWhispererClientAdaptorImpl(override val project: Project) : CodeW
809782
requestBuilder.userContext(codeWhispererUserContext())
810783
}
811784

812-
override fun dispose() {
813-
myBearerClient?.close()
814-
}
815-
816-
/**
817-
* Every different SSO/AWS Builder ID connection requires a new client which has its corresponding bearer token provider,
818-
* thus we have to create them dynamically.
819-
* Invalidate and recycle the old client first, and create a new client with the new connection.
820-
* This makes sure when we invoke CW, we always use the up-to-date connection.
821-
* In case this fails to close the client, myBearerClient is already set to null thus next time when we invoke CW,
822-
* it will go through this again which should get the current up-to-date connection. This stale client would be
823-
* unused and stay in memory for a while until eventually closed by ToolkitClientManager.
824-
*/
825-
open fun getBearerClient(oldProviderIdToRemove: String = ""): CodeWhispererRuntimeClient? {
826-
myBearerClient = null
827-
828-
val connection = ToolkitConnectionManager.getInstance(project).activeConnectionForFeature(CodeWhispererConnection.getInstance())
829-
connection as? AwsBearerTokenConnection ?: run {
830-
LOG.warn { "$connection is not a bearer token connection" }
831-
return null
832-
}
833-
834-
return AwsClientManager.getInstance().getClient<CodeWhispererRuntimeClient>(connection.getConnectionSettings())
835-
}
836-
837785
companion object {
838786
private val LOG = getLogger<CodeWhispererClientAdaptorImpl>()
839787
}
840788
}
841789

842-
class MockCodeWhispererClientAdaptor(override val project: Project) : CodeWhispererClientAdaptorImpl(project) {
843-
override fun getBearerClient(oldProviderIdToRemove: String): CodeWhispererRuntimeClient = project.awsClient()
844-
override fun dispose() {}
845-
}
846-
847790
private fun CodewhispererSuggestionState.toCodeWhispererSdkType() = when {
848791
this == CodewhispererSuggestionState.Accept -> SuggestionState.ACCEPT
849792
this == CodewhispererSuggestionState.Reject -> SuggestionState.REJECT

plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererClientAdaptorTest.kt

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import org.junit.After
1515
import org.junit.Before
1616
import org.junit.Rule
1717
import org.junit.Test
18+
import org.junit.jupiter.api.assertThrows
1819
import org.mockito.kotlin.any
1920
import org.mockito.kotlin.argThat
2021
import org.mockito.kotlin.argumentCaptor
@@ -58,17 +59,20 @@ import software.aws.toolkits.core.TokenConnectionSettings
5859
import software.aws.toolkits.core.utils.test.aString
5960
import software.aws.toolkits.jetbrains.core.MockClientManagerRule
6061
import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection
62+
import software.aws.toolkits.jetbrains.core.credentials.DefaultToolkitConnectionManager
6163
import software.aws.toolkits.jetbrains.core.credentials.ManagedSsoProfile
6264
import software.aws.toolkits.jetbrains.core.credentials.MockCredentialManagerRule
6365
import software.aws.toolkits.jetbrains.core.credentials.MockToolkitAuthManagerRule
6466
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
67+
import software.aws.toolkits.jetbrains.core.credentials.logoutFromSsoConnection
68+
import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection
69+
import software.aws.toolkits.jetbrains.core.credentials.sono.Q_SCOPES
6570
import software.aws.toolkits.jetbrains.core.credentials.sono.SONO_REGION
6671
import software.aws.toolkits.jetbrains.services.amazonq.FEATURE_EVALUATION_PRODUCT_NAME
6772
import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.metadata
6873
import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.pythonRequest
6974
import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.pythonResponseWithToken
7075
import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.sdkHttpResponse
71-
import software.aws.toolkits.jetbrains.services.codewhisperer.credentials.CodeWhispererClientAdaptor
7276
import software.aws.toolkits.jetbrains.services.codewhisperer.credentials.CodeWhispererClientAdaptorImpl
7377
import software.aws.toolkits.jetbrains.services.codewhisperer.customization.CodeWhispererCustomization
7478
import software.aws.toolkits.jetbrains.services.codewhisperer.customization.CodeWhispererModelConfigurator
@@ -98,7 +102,7 @@ class CodeWhispererClientAdaptorTest {
98102
private lateinit var bearerClient: CodeWhispererRuntimeClient
99103
private lateinit var ssoClient: SsoOidcClient
100104

101-
private lateinit var sut: CodeWhispererClientAdaptor
105+
private lateinit var sut: CodeWhispererClientAdaptorImpl
102106
private lateinit var connectionManager: ToolkitConnectionManager
103107
private var isTelemetryEnabledDefault: Boolean = false
104108

@@ -145,6 +149,41 @@ class CodeWhispererClientAdaptorTest {
145149
assertThat("us-east-1").isEqualTo(SONO_REGION)
146150
}
147151

152+
@Test
153+
fun `should throw if there is no valid credential, otherwise return codewhispererRuntimeClient`() {
154+
val connectionManager = DefaultToolkitConnectionManager()
155+
projectRule.project.replaceService(ToolkitConnectionManager::class.java, DefaultToolkitConnectionManager(), disposableRule.disposable)
156+
157+
assertThat(ToolkitConnectionManager.getInstance(projectRule.project).activeConnectionForFeature(QConnection.getInstance())).isNull()
158+
assertThrows<Exception>("attempt to get bearer client while there is no valid credential") {
159+
sut.bearerClient()
160+
}
161+
162+
val qConnection = authManagerRule.createConnection(ManagedSsoProfile("us-east-1", aString(), Q_SCOPES))
163+
connectionManager.switchConnection(qConnection)
164+
assertThat(connectionManager.activeConnectionForFeature(QConnection.getInstance()))
165+
.isNotNull
166+
.isEqualTo(qConnection)
167+
assertThat(sut.bearerClient())
168+
.isNotNull
169+
.isInstanceOf(CodeWhispererRuntimeClient::class.java)
170+
171+
logoutFromSsoConnection(projectRule.project, qConnection as AwsBearerTokenConnection)
172+
assertThat(connectionManager.activeConnectionForFeature(QConnection.getInstance())).isNull()
173+
assertThrows<Exception>("attempt to get bearer client while there is no valid credential") {
174+
sut.bearerClient()
175+
}
176+
177+
val anotherQConnection = authManagerRule.createConnection(ManagedSsoProfile("us-east-1", aString(), Q_SCOPES))
178+
connectionManager.switchConnection(anotherQConnection)
179+
assertThat(connectionManager.activeConnectionForFeature(QConnection.getInstance()))
180+
.isNotNull
181+
.isEqualTo(anotherQConnection)
182+
assertThat(sut.bearerClient())
183+
.isNotNull
184+
.isInstanceOf(CodeWhispererRuntimeClient::class.java)
185+
}
186+
148187
@Test
149188
fun `listCustomizations`() {
150189
val sdkIterable = ListAvailableCustomizationsIterable(bearerClient, ListAvailableCustomizationsRequest.builder().build())

plugins/amazonq/codewhisperer/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/codewhisperer/CodeWhispererTestBase.kt

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,13 @@ import org.mockito.kotlin.verify
3030
import software.amazon.awssdk.services.codewhispererruntime.CodeWhispererRuntimeClient
3131
import software.amazon.awssdk.services.codewhispererruntime.model.GenerateCompletionsRequest
3232
import software.amazon.awssdk.services.codewhispererruntime.paginators.GenerateCompletionsIterable
33+
import software.amazon.awssdk.services.ssooidc.SsoOidcClient
3334
import software.aws.toolkits.jetbrains.core.MockClientManagerRule
35+
import software.aws.toolkits.jetbrains.core.credentials.ManagedSsoProfile
3436
import software.aws.toolkits.jetbrains.core.credentials.MockCredentialManagerRule
37+
import software.aws.toolkits.jetbrains.core.credentials.MockToolkitAuthManagerRule
38+
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
39+
import software.aws.toolkits.jetbrains.core.credentials.sono.Q_SCOPES
3540
import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.codeWhispererRecommendationActionId
3641
import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.pythonFileName
3742
import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.pythonResponse
@@ -65,10 +70,11 @@ open class CodeWhispererTestBase {
6570
val mockClientManagerRule = MockClientManagerRule()
6671
val mockCredentialRule = MockCredentialManagerRule()
6772
val disposableRule = DisposableRule()
73+
val authManagerRule = MockToolkitAuthManagerRule()
6874

6975
@Rule
7076
@JvmField
71-
val ruleChain = RuleChain(projectRule, mockCredentialRule, mockClientManagerRule, disposableRule)
77+
val ruleChain = RuleChain(projectRule, mockCredentialRule, mockClientManagerRule, authManagerRule, disposableRule)
7278

7379
protected lateinit var mockClient: CodeWhispererRuntimeClient
7480

@@ -86,6 +92,7 @@ open class CodeWhispererTestBase {
8692
@Before
8793
open fun setUp() {
8894
mockClient = mockClientManagerRule.create()
95+
mockClientManagerRule.create<SsoOidcClient>()
8996
val requestCaptor = argumentCaptor<GenerateCompletionsRequest>()
9097
mockClient.stub {
9198
on {
@@ -159,6 +166,9 @@ open class CodeWhispererTestBase {
159166
projectRule.project.replaceService(CodeWhispererClientAdaptor::class.java, clientAdaptorSpy, disposableRule.disposable)
160167
ApplicationManager.getApplication().replaceService(CodeWhispererExplorerActionManager::class.java, stateManager, disposableRule.disposable)
161168
stateManager.setAutoEnabled(false)
169+
170+
val conn = authManagerRule.createConnection(ManagedSsoProfile("us-east-1", "url", Q_SCOPES))
171+
ToolkitConnectionManager.getInstance(projectRule.project).switchConnection(conn)
162172
}
163173

164174
@After

0 commit comments

Comments
 (0)