Skip to content

refactor(amazonq): stateless Q client #5481

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Mar 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@
serviceImplementation="software.aws.toolkits.jetbrains.services.codewhisperer.customization.DefaultCodeWhispererModelConfigurator"/>

<projectService serviceInterface="software.aws.toolkits.jetbrains.services.codewhisperer.credentials.CodeWhispererClientAdaptor"
serviceImplementation="software.aws.toolkits.jetbrains.services.codewhisperer.credentials.CodeWhispererClientAdaptorImpl"
testServiceImplementation="software.aws.toolkits.jetbrains.services.codewhisperer.credentials.MockCodeWhispererClientAdaptor"/>
serviceImplementation="software.aws.toolkits.jetbrains.services.codewhisperer.credentials.CodeWhispererClientAdaptorImpl"/>
<projectService serviceInterface="software.aws.toolkits.jetbrains.services.codewhisperer.util.FileContextProvider"
serviceImplementation="software.aws.toolkits.jetbrains.services.codewhisperer.util.DefaultCodeWhispererFileContextProvider"/>
<projectService serviceImplementation="software.aws.toolkits.jetbrains.services.codewhisperer.codescan.CodeWhispererCodeScanManager"/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@

package software.aws.toolkits.jetbrains.services.codewhisperer.credentials

import com.intellij.openapi.Disposable
import com.intellij.openapi.application.ApplicationManager
import com.intellij.openapi.components.service
import com.intellij.openapi.project.Project
import com.intellij.util.text.nullize
Expand Down Expand Up @@ -39,14 +37,9 @@ import software.amazon.awssdk.services.codewhispererruntime.model.TargetCode
import software.amazon.awssdk.services.codewhispererruntime.model.UserIntent
import software.aws.toolkits.core.utils.debug
import software.aws.toolkits.core.utils.getLogger
import software.aws.toolkits.core.utils.warn
import software.aws.toolkits.jetbrains.core.AwsClientManager
import software.aws.toolkits.jetbrains.core.awsClient
import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnection
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManagerListener
import software.aws.toolkits.jetbrains.core.credentials.pinning.CodeWhispererConnection
import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection
import software.aws.toolkits.jetbrains.services.amazonq.codeWhispererUserContext
import software.aws.toolkits.jetbrains.services.codewhisperer.customization.CodeWhispererCustomization
import software.aws.toolkits.jetbrains.services.codewhisperer.language.CodeWhispererProgrammingLanguage
Expand All @@ -62,8 +55,11 @@ import java.time.Instant
import java.util.concurrent.TimeUnit

// As the connection is project-level, we need to make this project-level too
@Deprecated("Methods can throw a NullPointerException if callee does not check if connection is valid")
interface CodeWhispererClientAdaptor : Disposable {
@Deprecated(
"It was needed as we were supporting two service models (sigv4 & bearer), " +
"it's no longer the case as we remove sigv4 support, should use AwsClientManager.getClient() directly"
)
interface CodeWhispererClientAdaptor {
val project: Project

fun generateCompletionsPaginator(
Expand Down Expand Up @@ -261,32 +257,11 @@ interface CodeWhispererClientAdaptor : Disposable {
}
}

open class CodeWhispererClientAdaptorImpl(override val project: Project) : CodeWhispererClientAdaptor {
@Volatile
private var myBearerClient: CodeWhispererRuntimeClient? = null

init {
initClientUpdateListener()
}

private fun initClientUpdateListener() {
ApplicationManager.getApplication().messageBus.connect(this).subscribe(
ToolkitConnectionManagerListener.TOPIC,
object : ToolkitConnectionManagerListener {
override fun activeConnectionChanged(newConnection: ToolkitConnection?) {
if (newConnection is AwsBearerTokenConnection) {
myBearerClient = getBearerClient(newConnection.getConnectionSettings().providerId)
}
}
}
)
}

private fun bearerClient(): CodeWhispererRuntimeClient {
if (myBearerClient != null) return myBearerClient as CodeWhispererRuntimeClient
myBearerClient = getBearerClient()
return myBearerClient as CodeWhispererRuntimeClient
}
class CodeWhispererClientAdaptorImpl(override val project: Project) : CodeWhispererClientAdaptor {
fun bearerClient(): CodeWhispererRuntimeClient =
ToolkitConnectionManager.getInstance(project).activeConnectionForFeature(QConnection.getInstance())?.getConnectionSettings()
?.awsClient<CodeWhispererRuntimeClient>()
?: throw Exception("attempt to get bearer client while there is no valid credential")

override fun generateCompletionsPaginator(firstRequest: GenerateCompletionsRequest) = sequence<GenerateCompletionsResponse> {
var nextToken: String? = firstRequest.nextToken()
Expand Down Expand Up @@ -809,41 +784,11 @@ open class CodeWhispererClientAdaptorImpl(override val project: Project) : CodeW
requestBuilder.userContext(codeWhispererUserContext())
}

override fun dispose() {
myBearerClient?.close()
}

/**
* Every different SSO/AWS Builder ID connection requires a new client which has its corresponding bearer token provider,
* thus we have to create them dynamically.
* Invalidate and recycle the old client first, and create a new client with the new connection.
* This makes sure when we invoke CW, we always use the up-to-date connection.
* In case this fails to close the client, myBearerClient is already set to null thus next time when we invoke CW,
* it will go through this again which should get the current up-to-date connection. This stale client would be
* unused and stay in memory for a while until eventually closed by ToolkitClientManager.
*/
open fun getBearerClient(oldProviderIdToRemove: String = ""): CodeWhispererRuntimeClient? {
myBearerClient = null

val connection = ToolkitConnectionManager.getInstance(project).activeConnectionForFeature(CodeWhispererConnection.getInstance())
connection as? AwsBearerTokenConnection ?: run {
LOG.warn { "$connection is not a bearer token connection" }
return null
}

return AwsClientManager.getInstance().getClient<CodeWhispererRuntimeClient>(connection.getConnectionSettings())
}

companion object {
private val LOG = getLogger<CodeWhispererClientAdaptorImpl>()
}
}

class MockCodeWhispererClientAdaptor(override val project: Project) : CodeWhispererClientAdaptorImpl(project) {
override fun getBearerClient(oldProviderIdToRemove: String): CodeWhispererRuntimeClient = project.awsClient()
override fun dispose() {}
}

private fun CodewhispererSuggestionState.toCodeWhispererSdkType() = when {
this == CodewhispererSuggestionState.Accept -> SuggestionState.ACCEPT
this == CodewhispererSuggestionState.Reject -> SuggestionState.REJECT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
package software.aws.toolkits.jetbrains.services.codewhisperer

import com.intellij.openapi.application.ApplicationManager
import com.intellij.openapi.util.Disposer
import com.intellij.openapi.util.SystemInfo
import com.intellij.testFramework.DisposableRule
import com.intellij.testFramework.RuleChain
Expand All @@ -15,6 +14,7 @@ import org.junit.After
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import org.junit.jupiter.api.assertThrows
import org.mockito.kotlin.any
import org.mockito.kotlin.argThat
import org.mockito.kotlin.argumentCaptor
Expand All @@ -24,7 +24,6 @@ import org.mockito.kotlin.mock
import org.mockito.kotlin.stub
import org.mockito.kotlin.times
import org.mockito.kotlin.verify
import org.mockito.kotlin.whenever
import software.amazon.awssdk.services.codewhispererruntime.CodeWhispererRuntimeClient
import software.amazon.awssdk.services.codewhispererruntime.model.ArtifactType
import software.amazon.awssdk.services.codewhispererruntime.model.CodeAnalysisFindingsSchema
Expand Down Expand Up @@ -54,21 +53,22 @@ import software.amazon.awssdk.services.codewhispererruntime.model.SuggestionStat
import software.amazon.awssdk.services.codewhispererruntime.paginators.GenerateCompletionsIterable
import software.amazon.awssdk.services.codewhispererruntime.paginators.ListAvailableCustomizationsIterable
import software.amazon.awssdk.services.ssooidc.SsoOidcClient
import software.aws.toolkits.core.TokenConnectionSettings
import software.aws.toolkits.core.utils.test.aString
import software.aws.toolkits.jetbrains.core.MockClientManagerRule
import software.aws.toolkits.jetbrains.core.credentials.AwsBearerTokenConnection
import software.aws.toolkits.jetbrains.core.credentials.ManagedSsoProfile
import software.aws.toolkits.jetbrains.core.credentials.MockCredentialManagerRule
import software.aws.toolkits.jetbrains.core.credentials.MockToolkitAuthManagerRule
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
import software.aws.toolkits.jetbrains.core.credentials.logoutFromSsoConnection
import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection
import software.aws.toolkits.jetbrains.core.credentials.sono.Q_SCOPES
import software.aws.toolkits.jetbrains.core.credentials.sono.SONO_REGION
import software.aws.toolkits.jetbrains.services.amazonq.FEATURE_EVALUATION_PRODUCT_NAME
import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.metadata
import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.pythonRequest
import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.pythonResponseWithToken
import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.sdkHttpResponse
import software.aws.toolkits.jetbrains.services.codewhisperer.credentials.CodeWhispererClientAdaptor
import software.aws.toolkits.jetbrains.services.codewhisperer.credentials.CodeWhispererClientAdaptorImpl
import software.aws.toolkits.jetbrains.services.codewhisperer.customization.CodeWhispererCustomization
import software.aws.toolkits.jetbrains.services.codewhisperer.customization.CodeWhispererModelConfigurator
Expand All @@ -93,13 +93,12 @@ class CodeWhispererClientAdaptorTest {

@Rule
@JvmField
val ruleChain = RuleChain(projectRule, mockCredentialRule, mockClientManagerRule, disposableRule)
val ruleChain = RuleChain(projectRule, mockCredentialRule, mockClientManagerRule, authManagerRule, disposableRule)

private lateinit var bearerClient: CodeWhispererRuntimeClient
private lateinit var ssoClient: SsoOidcClient

private lateinit var sut: CodeWhispererClientAdaptor
private lateinit var connectionManager: ToolkitConnectionManager
private lateinit var sut: CodeWhispererClientAdaptorImpl
private var isTelemetryEnabledDefault: Boolean = false

@Before
Expand All @@ -117,15 +116,8 @@ class CodeWhispererClientAdaptorTest {
on { listFeatureEvaluations(any<ListFeatureEvaluationsRequest>()) } doReturn listFeatureEvaluationsResponse
}

val mockConnection = mock<AwsBearerTokenConnection>()
whenever(mockConnection.getConnectionSettings()) doReturn mock<TokenConnectionSettings>()

connectionManager = mock {
on {
activeConnectionForFeature(any())
} doReturn authManagerRule.createConnection(ManagedSsoProfile("us-east-1", aString(), listOf("scopes"))) as AwsBearerTokenConnection
}
projectRule.project.replaceService(ToolkitConnectionManager::class.java, connectionManager, disposableRule.disposable)
val conn = authManagerRule.createConnection(ManagedSsoProfile("us-east-1", "url", Q_SCOPES))
ToolkitConnectionManager.getInstance(projectRule.project).switchConnection(conn)

isTelemetryEnabledDefault = AwsSettings.getInstance().isTelemetryEnabled
}
Expand All @@ -135,16 +127,37 @@ class CodeWhispererClientAdaptorTest {
AwsSettings.getInstance().isTelemetryEnabled = isTelemetryEnabledDefault
}

@After
fun cleanup() {
Disposer.dispose(sut)
}

@Test
fun `Sono region is us-east-1`() {
assertThat("us-east-1").isEqualTo(SONO_REGION)
}

@Test
fun `should throw if there is no valid credential, otherwise return codewhispererRuntimeClient`() {
val connectionManager = ToolkitConnectionManager.getInstance(projectRule.project)

assertThat(connectionManager.activeConnectionForFeature(QConnection.getInstance()))
.isNotNull
assertThat(sut.bearerClient())
.isNotNull
.isInstanceOf(CodeWhispererRuntimeClient::class.java)

logoutFromSsoConnection(projectRule.project, connectionManager.activeConnectionForFeature(QConnection.getInstance()) as AwsBearerTokenConnection)
assertThat(connectionManager.activeConnectionForFeature(QConnection.getInstance())).isNull()
assertThrows<Exception>("attempt to get bearer client while there is no valid credential") {
sut.bearerClient()
}

val anotherQConnection = authManagerRule.createConnection(ManagedSsoProfile("us-east-1", aString(), Q_SCOPES))
connectionManager.switchConnection(anotherQConnection)
assertThat(connectionManager.activeConnectionForFeature(QConnection.getInstance()))
.isNotNull
.isEqualTo(anotherQConnection)
assertThat(sut.bearerClient())
.isNotNull
.isInstanceOf(CodeWhispererRuntimeClient::class.java)
}

@Test
fun `listCustomizations`() {
val sdkIterable = ListAvailableCustomizationsIterable(bearerClient, ListAvailableCustomizationsRequest.builder().build())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class CodeWhispererSettingsTest : CodeWhispererTestBase() {
stateManager.loadState(CodeWhispererExploreActionState())
CodeWhispererSettings.getInstance().loadState(CodeWhispererConfiguration())

val problemsWindow = ProblemsView.getToolWindow(projectRule.project) ?: fail("Problems window not found")
ProblemsView.getToolWindow(projectRule.project) ?: fail("Problems window not found")
val codeReferenceWindow = ToolWindowManager.getInstance(projectRule.project).getToolWindow(
CodeWhispererCodeReferenceToolWindowFactory.id
) ?: fail("Code Reference Log window not found")
Expand All @@ -114,7 +114,6 @@ class CodeWhispererSettingsTest : CodeWhispererTestBase() {
} ?: fail("CodeWhisperer status bar widget not found")

runInEdtAndWait {
assertThat(problemsWindow.contentManager.contentCount).isEqualTo(0)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems to be flaky, sometimes it will be 1 with code scan sometimes not.

assertThat(codeReferenceWindow.isAvailable).isFalse
assertThat(statusBarWidgetFactory.isAvailable(projectRule.project)).isTrue
assertThat(settingsManager.isIncludeCodeWithReference()).isFalse
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,13 @@ import org.mockito.kotlin.verify
import software.amazon.awssdk.services.codewhispererruntime.CodeWhispererRuntimeClient
import software.amazon.awssdk.services.codewhispererruntime.model.GenerateCompletionsRequest
import software.amazon.awssdk.services.codewhispererruntime.paginators.GenerateCompletionsIterable
import software.amazon.awssdk.services.ssooidc.SsoOidcClient
import software.aws.toolkits.jetbrains.core.MockClientManagerRule
import software.aws.toolkits.jetbrains.core.credentials.ManagedSsoProfile
import software.aws.toolkits.jetbrains.core.credentials.MockCredentialManagerRule
import software.aws.toolkits.jetbrains.core.credentials.MockToolkitAuthManagerRule
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
import software.aws.toolkits.jetbrains.core.credentials.sono.Q_SCOPES
import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.codeWhispererRecommendationActionId
import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.pythonFileName
import software.aws.toolkits.jetbrains.services.codewhisperer.CodeWhispererTestUtil.pythonResponse
Expand Down Expand Up @@ -65,10 +70,11 @@ open class CodeWhispererTestBase {
val mockClientManagerRule = MockClientManagerRule()
val mockCredentialRule = MockCredentialManagerRule()
val disposableRule = DisposableRule()
val authManagerRule = MockToolkitAuthManagerRule()

@Rule
@JvmField
val ruleChain = RuleChain(projectRule, mockCredentialRule, mockClientManagerRule, disposableRule)
val ruleChain = RuleChain(projectRule, mockCredentialRule, mockClientManagerRule, authManagerRule, disposableRule)

protected lateinit var mockClient: CodeWhispererRuntimeClient

Expand All @@ -86,6 +92,7 @@ open class CodeWhispererTestBase {
@Before
open fun setUp() {
mockClient = mockClientManagerRule.create()
mockClientManagerRule.create<SsoOidcClient>()
val requestCaptor = argumentCaptor<GenerateCompletionsRequest>()
mockClient.stub {
on {
Expand Down Expand Up @@ -159,6 +166,9 @@ open class CodeWhispererTestBase {
projectRule.project.replaceService(CodeWhispererClientAdaptor::class.java, clientAdaptorSpy, disposableRule.disposable)
ApplicationManager.getApplication().replaceService(CodeWhispererExplorerActionManager::class.java, stateManager, disposableRule.disposable)
stateManager.setAutoEnabled(false)

val conn = authManagerRule.createConnection(ManagedSsoProfile("us-east-1", "url", Q_SCOPES))
ToolkitConnectionManager.getInstance(projectRule.project).switchConnection(conn)
}

@After
Expand Down
Loading