3
3
4
4
package software.aws.toolkits.jetbrains.core.credentials.sso
5
5
6
+ import com.intellij.openapi.application.ApplicationManager
6
7
import com.intellij.testFramework.ApplicationRule
8
+ import com.intellij.testFramework.DisposableRule
7
9
import com.intellij.testFramework.RuleChain
10
+ import com.intellij.testFramework.replaceService
8
11
import kotlinx.coroutines.runBlocking
9
12
import org.assertj.core.api.Assertions.assertThat
10
13
import org.assertj.core.api.Assertions.assertThatThrownBy
11
14
import org.junit.Before
12
15
import org.junit.Rule
13
16
import org.junit.Test
17
+ import org.junit.jupiter.api.assertThrows
14
18
import org.mockito.kotlin.KStubbing
15
19
import org.mockito.kotlin.any
16
20
import org.mockito.kotlin.eq
@@ -33,6 +37,7 @@ import software.amazon.awssdk.services.ssooidc.model.StartDeviceAuthorizationRes
33
37
import software.aws.toolkits.core.region.aRegionId
34
38
import software.aws.toolkits.core.utils.delegateMock
35
39
import software.aws.toolkits.core.utils.test.aString
40
+ import software.aws.toolkits.jetbrains.core.credentials.sso.pkce.ToolkitOAuthService
36
41
import software.aws.toolkits.jetbrains.utils.rules.SsoLoginCallbackProviderRule
37
42
import java.time.Clock
38
43
import java.time.Duration
@@ -53,11 +58,12 @@ class SsoAccessTokenProviderTest {
53
58
private lateinit var ssoCache: SsoCache
54
59
55
60
private val applicationRule = ApplicationRule ()
61
+ private val disposableRule = DisposableRule ()
56
62
private val ssoCallbackRule = SsoLoginCallbackProviderRule ()
57
63
58
64
@JvmField
59
65
@Rule
60
- val ruleChain = RuleChain (applicationRule, ssoCallbackRule)
66
+ val ruleChain = RuleChain (applicationRule, ssoCallbackRule, disposableRule )
61
67
62
68
@Before
63
69
fun setUp () {
@@ -163,6 +169,58 @@ class SsoAccessTokenProviderTest {
163
169
verify(ssoCache).saveAccessToken(ssoUrl, accessToken)
164
170
}
165
171
172
+ @Test
173
+ fun `initiates authorizatation_grant registration when scopes are requested in a commercial region` () {
174
+ val sut = SsoAccessTokenProvider (ssoUrl, " us-east-1" , ssoCache, ssoOidcClient, scopes = listOf (" dummy:scope" ), clock = clock)
175
+ setupCacheStub(returnValue = null )
176
+
177
+ val oauth = mock<ToolkitOAuthService >()
178
+ ApplicationManager .getApplication().replaceService(ToolkitOAuthService ::class .java, oauth, disposableRule.disposable)
179
+
180
+ ssoOidcClient.stub {
181
+ on(
182
+ ssoOidcClient.registerClient(any<RegisterClientRequest >())
183
+ ).thenReturn(
184
+ RegisterClientResponse .builder()
185
+ .clientId(clientId)
186
+ .clientSecret(clientSecret)
187
+ .clientSecretExpiresAt(clock.instant().plusSeconds(180 ).toEpochMilli())
188
+ .build()
189
+ )
190
+ }
191
+
192
+ // flow is not completely stubbed out
193
+ assertThrows<Exception > { sut.accessToken() }
194
+
195
+ verify(ssoCache).saveClientRegistration(any<PKCEClientRegistrationCacheKey >(), any())
196
+ }
197
+
198
+ @Test
199
+ fun `initiates device code registration when scopes are requested in a non-commercial region` () {
200
+ val sut = SsoAccessTokenProvider (ssoUrl, " us-gov-east-1" , ssoCache, ssoOidcClient, scopes = listOf (" dummy:scope" ), clock = clock)
201
+ setupCacheStub(returnValue = null )
202
+
203
+ val oauth = mock<ToolkitOAuthService >()
204
+ ApplicationManager .getApplication().replaceService(ToolkitOAuthService ::class .java, oauth, disposableRule.disposable)
205
+
206
+ ssoOidcClient.stub {
207
+ on(
208
+ ssoOidcClient.registerClient(any<RegisterClientRequest >())
209
+ ).thenReturn(
210
+ RegisterClientResponse .builder()
211
+ .clientId(clientId)
212
+ .clientSecret(clientSecret)
213
+ .clientSecretExpiresAt(clock.instant().plusSeconds(180 ).toEpochMilli())
214
+ .build()
215
+ )
216
+ }
217
+
218
+ // flow is not completely stubbed out
219
+ assertThrows<Exception > { sut.accessToken() }
220
+
221
+ verify(ssoCache).saveClientRegistration(any<DeviceAuthorizationClientRegistrationCacheKey >(), any())
222
+ }
223
+
166
224
@Test
167
225
fun getAccessTokenWithoutCachesMultiplePolls () {
168
226
val expirationClientRegistration = clock.instant().plusSeconds(120 )
0 commit comments