diff --git a/.github/workflows/android_test.yml b/.github/workflows/android_test.yml index c19b8eb9..37500bf9 100644 --- a/.github/workflows/android_test.yml +++ b/.github/workflows/android_test.yml @@ -23,4 +23,4 @@ jobs: run: chmod +x ./gradlew - name: Run Tests with Gradle - run: ./gradlew test \ No newline at end of file + run: ./gradlew clean testDebugUnitTest \ No newline at end of file diff --git a/app/src/main/java/com/shifthackz/aisdv1/app/di/ProvidersModule.kt b/app/src/main/java/com/shifthackz/aisdv1/app/di/ProvidersModule.kt index 4f0eb1bb..933011dc 100755 --- a/app/src/main/java/com/shifthackz/aisdv1/app/di/ProvidersModule.kt +++ b/app/src/main/java/com/shifthackz/aisdv1/app/di/ProvidersModule.kt @@ -7,6 +7,7 @@ import com.shifthackz.aisdv1.core.common.appbuild.BuildVersion import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor import com.shifthackz.aisdv1.core.common.links.LinksProvider import com.shifthackz.aisdv1.core.common.schedulers.SchedulersProvider +import com.shifthackz.aisdv1.core.common.time.TimeProvider import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.domain.feature.auth.AuthorizationCredentials import com.shifthackz.aisdv1.domain.feature.auth.AuthorizationStore @@ -24,6 +25,7 @@ import io.reactivex.rxjava3.core.Scheduler import io.reactivex.rxjava3.schedulers.Schedulers import org.koin.android.ext.koin.androidApplication import org.koin.dsl.module +import java.util.Date import java.util.concurrent.Executor import java.util.concurrent.Executors @@ -134,6 +136,14 @@ val providersModule = module { } } + single { + object : TimeProvider { + override fun nanoTime(): Long = System.nanoTime() + override fun currentTimeMillis(): Long = System.currentTimeMillis() + override fun currentDate(): Date = Date() + } + } + single { object : FileProviderDescriptor { override val providerPath: String = "${androidApplication().packageName}.fileprovider" diff --git a/core/common/src/main/java/com/shifthackz/aisdv1/core/common/extensions/StringExtensions.kt b/core/common/src/main/java/com/shifthackz/aisdv1/core/common/extensions/StringExtensions.kt index a686b66b..0fe94bfb 100644 --- a/core/common/src/main/java/com/shifthackz/aisdv1/core/common/extensions/StringExtensions.kt +++ b/core/common/src/main/java/com/shifthackz/aisdv1/core/common/extensions/StringExtensions.kt @@ -3,13 +3,6 @@ package com.shifthackz.aisdv1.core.common.extensions private const val PROTOCOL_DELIMITER = "://" private const val PROTOCOL_HOLDER = "[[_PROTOCOL_]]" -fun String.withoutUrlProtocol(): String { - if (!this.contains(PROTOCOL_DELIMITER)) return this - val decomposed = this.split(PROTOCOL_DELIMITER) - if (decomposed.size < 2) return this - return decomposed.last() -} - fun String.fixUrlSlashes(): String = this .replace(PROTOCOL_DELIMITER, PROTOCOL_HOLDER) .replace(Regex("/{2,}"), "/") diff --git a/core/common/src/main/java/com/shifthackz/aisdv1/core/common/time/TimeProvider.kt b/core/common/src/main/java/com/shifthackz/aisdv1/core/common/time/TimeProvider.kt new file mode 100644 index 00000000..7564bcb1 --- /dev/null +++ b/core/common/src/main/java/com/shifthackz/aisdv1/core/common/time/TimeProvider.kt @@ -0,0 +1,9 @@ +package com.shifthackz.aisdv1.core.common.time + +import java.util.Date + +interface TimeProvider { + fun nanoTime(): Long + fun currentTimeMillis(): Long + fun currentDate(): Date +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/entity/AppVersionTest.kt b/core/common/src/test/java/com/shifthackz/aisdv1/core/common/appbuild/BuildVersionTest.kt similarity index 92% rename from domain/src/test/java/com/shifthackz/aisdv1/domain/entity/AppVersionTest.kt rename to core/common/src/test/java/com/shifthackz/aisdv1/core/common/appbuild/BuildVersionTest.kt index 0dda593b..7383f94e 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/entity/AppVersionTest.kt +++ b/core/common/src/test/java/com/shifthackz/aisdv1/core/common/appbuild/BuildVersionTest.kt @@ -1,10 +1,9 @@ -package com.shifthackz.aisdv1.domain.entity +package com.shifthackz.aisdv1.core.common.appbuild -import com.shifthackz.aisdv1.core.common.appbuild.BuildVersion import org.junit.Assert import org.junit.Test -class AppVersionTest { +class BuildVersionTest { @Test fun `Parse 1_0_0, expected success`() { diff --git a/core/common/src/test/java/com/shifthackz/aisdv1/core/common/extensions/DateExtensionsTest.kt b/core/common/src/test/java/com/shifthackz/aisdv1/core/common/extensions/DateExtensionsTest.kt new file mode 100644 index 00000000..b4f77ff5 --- /dev/null +++ b/core/common/src/test/java/com/shifthackz/aisdv1/core/common/extensions/DateExtensionsTest.kt @@ -0,0 +1,33 @@ +package com.shifthackz.aisdv1.core.common.extensions + +import org.junit.Assert +import org.junit.Test +import java.util.Date + +class DateExtensionsTest { + + companion object { + private val date = Date(894333955000) // 1998-05-05 05:05:55 + } + + @Test + fun `given date 05_05_1998, then getRawDay, expected 5`() { + val expected = 5 + val actual = date.getRawDay() + Assert.assertEquals(expected, actual) + } + + @Test + fun `given date 05_05_1998, then getRawMonth, expected 5`() { + val expected = 5 + val actual = date.getRawMonth() + Assert.assertEquals(expected, actual) + } + + @Test + fun `given date 05_05_1998, then getRawYear, expected 1998`() { + val expected = 1998 + val actual = date.getRawYear() + Assert.assertEquals(expected, actual) + } +} diff --git a/core/common/src/test/java/com/shifthackz/aisdv1/core/common/extensions/KotlinExtensionsTest.kt b/core/common/src/test/java/com/shifthackz/aisdv1/core/common/extensions/KotlinExtensionsTest.kt new file mode 100644 index 00000000..5fb165f4 --- /dev/null +++ b/core/common/src/test/java/com/shifthackz/aisdv1/core/common/extensions/KotlinExtensionsTest.kt @@ -0,0 +1,41 @@ +package com.shifthackz.aisdv1.core.common.extensions + +import org.junit.Assert +import org.junit.Test + +class KotlinExtensionsTest { + + @Test + fun `given TestClass with null value, then applyIf with true predicate, expected test value changed`() { + class TestClass { + var testValue: String? = null + } + + val instance = TestClass() + val predicate = true + instance.applyIf(predicate) { + testValue = "5598" + } + + val expected = "5598" + val actual = instance.testValue + Assert.assertEquals(expected, actual) + } + + @Test + fun `given TestClass with null value, then applyIf with false predicate, expected test value NOT changed`() { + class TestClass { + var testValue: String? = null + } + + val instance = TestClass() + val predicate = false + instance.applyIf(predicate) { + testValue = "5598" + } + + val expected = null + val actual = instance.testValue + Assert.assertEquals(expected, actual) + } +} diff --git a/core/common/src/test/java/com/shifthackz/aisdv1/core/common/math/MathUtilsTest.kt b/core/common/src/test/java/com/shifthackz/aisdv1/core/common/math/MathUtilsTest.kt new file mode 100644 index 00000000..f06ceea6 --- /dev/null +++ b/core/common/src/test/java/com/shifthackz/aisdv1/core/common/math/MathUtilsTest.kt @@ -0,0 +1,23 @@ +package com.shifthackz.aisdv1.core.common.math + +import org.junit.Assert +import org.junit.Test + +class MathUtilsTest { + + @Test + fun `given Double with 8 fraction digits, then roundTo(2), expected Double with 2 fraction digits`() { + val value = 55.98238462 + val expected = 55.98.toString() + val actual = value.roundTo(2).toString() + Assert.assertEquals(expected, actual) + } + + @Test + fun `given Float with 6 fraction digits, then roundTo(2), expected Float with 2 fraction digits`() { + val value = 55.982384f + val expected = 55.98f.toString() + val actual = value.roundTo(2).toString() + Assert.assertEquals(expected, actual) + } +} diff --git a/core/validation/build.gradle b/core/validation/build.gradle index 3ea1fdd5..eb43efd2 100644 --- a/core/validation/build.gradle +++ b/core/validation/build.gradle @@ -12,4 +12,6 @@ android { dependencies { implementation di.koinCore implementation reactive.rxkotlin + testImplementation test.junit + testImplementation test.mockk } diff --git a/core/validation/src/main/java/com/shifthackz/aisdv1/core/validation/horde/CommonStringValidator.kt b/core/validation/src/main/java/com/shifthackz/aisdv1/core/validation/common/CommonStringValidator.kt similarity index 81% rename from core/validation/src/main/java/com/shifthackz/aisdv1/core/validation/horde/CommonStringValidator.kt rename to core/validation/src/main/java/com/shifthackz/aisdv1/core/validation/common/CommonStringValidator.kt index abe4a2ff..fd9c3ac3 100644 --- a/core/validation/src/main/java/com/shifthackz/aisdv1/core/validation/horde/CommonStringValidator.kt +++ b/core/validation/src/main/java/com/shifthackz/aisdv1/core/validation/common/CommonStringValidator.kt @@ -1,4 +1,4 @@ -package com.shifthackz.aisdv1.core.validation.horde +package com.shifthackz.aisdv1.core.validation.common import com.shifthackz.aisdv1.core.validation.ValidationResult diff --git a/core/validation/src/main/java/com/shifthackz/aisdv1/core/validation/horde/CommonStringValidatorImpl.kt b/core/validation/src/main/java/com/shifthackz/aisdv1/core/validation/common/CommonStringValidatorImpl.kt similarity index 89% rename from core/validation/src/main/java/com/shifthackz/aisdv1/core/validation/horde/CommonStringValidatorImpl.kt rename to core/validation/src/main/java/com/shifthackz/aisdv1/core/validation/common/CommonStringValidatorImpl.kt index 170f1e97..bf3c69ae 100644 --- a/core/validation/src/main/java/com/shifthackz/aisdv1/core/validation/horde/CommonStringValidatorImpl.kt +++ b/core/validation/src/main/java/com/shifthackz/aisdv1/core/validation/common/CommonStringValidatorImpl.kt @@ -1,4 +1,4 @@ -package com.shifthackz.aisdv1.core.validation.horde +package com.shifthackz.aisdv1.core.validation.common import com.shifthackz.aisdv1.core.validation.ValidationResult diff --git a/core/validation/src/main/java/com/shifthackz/aisdv1/core/validation/di/ValidatorsModule.kt b/core/validation/src/main/java/com/shifthackz/aisdv1/core/validation/di/ValidatorsModule.kt index 33a24e52..7c0d819c 100644 --- a/core/validation/src/main/java/com/shifthackz/aisdv1/core/validation/di/ValidatorsModule.kt +++ b/core/validation/src/main/java/com/shifthackz/aisdv1/core/validation/di/ValidatorsModule.kt @@ -1,9 +1,9 @@ package com.shifthackz.aisdv1.core.validation.di +import com.shifthackz.aisdv1.core.validation.common.CommonStringValidator +import com.shifthackz.aisdv1.core.validation.common.CommonStringValidatorImpl import com.shifthackz.aisdv1.core.validation.dimension.DimensionValidator import com.shifthackz.aisdv1.core.validation.dimension.DimensionValidatorImpl -import com.shifthackz.aisdv1.core.validation.horde.CommonStringValidator -import com.shifthackz.aisdv1.core.validation.horde.CommonStringValidatorImpl import com.shifthackz.aisdv1.core.validation.url.UrlValidator import com.shifthackz.aisdv1.core.validation.url.UrlValidatorImpl import org.koin.core.module.dsl.factoryOf @@ -13,7 +13,7 @@ import org.koin.dsl.module val validatorsModule = module { // !!! Do not use [factoryOf] for DimensionValidatorImpl, it has 2 default Ints in constructor factory { DimensionValidatorImpl() } + factory { UrlValidatorImpl() } - factoryOf(::UrlValidatorImpl) bind UrlValidator::class factoryOf(::CommonStringValidatorImpl) bind CommonStringValidator::class } diff --git a/core/validation/src/main/java/com/shifthackz/aisdv1/core/validation/url/UrlValidatorImpl.kt b/core/validation/src/main/java/com/shifthackz/aisdv1/core/validation/url/UrlValidatorImpl.kt index c3e45da2..c6ab7e57 100644 --- a/core/validation/src/main/java/com/shifthackz/aisdv1/core/validation/url/UrlValidatorImpl.kt +++ b/core/validation/src/main/java/com/shifthackz/aisdv1/core/validation/url/UrlValidatorImpl.kt @@ -4,8 +4,11 @@ import android.util.Patterns import android.webkit.URLUtil import com.shifthackz.aisdv1.core.validation.ValidationResult import java.net.URI +import java.util.regex.Pattern -internal class UrlValidatorImpl : UrlValidator { +internal class UrlValidatorImpl( + private val webUrlPattern: Pattern = Patterns.WEB_URL, +) : UrlValidator { override operator fun invoke(input: String?): ValidationResult = when { input == null -> ValidationResult( @@ -32,7 +35,7 @@ internal class UrlValidatorImpl : UrlValidator { isValid = false, validationError = UrlValidator.Error.Invalid, ) - !Patterns.WEB_URL.matcher(input).matches() -> ValidationResult( + !webUrlPattern.matcher(input).matches() -> ValidationResult( isValid = false, validationError = UrlValidator.Error.Invalid, ) diff --git a/core/validation/src/test/java/com/shifthackz/aisdv1/core/validation/common/CommonStringValidatorImplTest.kt b/core/validation/src/test/java/com/shifthackz/aisdv1/core/validation/common/CommonStringValidatorImplTest.kt new file mode 100644 index 00000000..771f8f0c --- /dev/null +++ b/core/validation/src/test/java/com/shifthackz/aisdv1/core/validation/common/CommonStringValidatorImplTest.kt @@ -0,0 +1,47 @@ +package com.shifthackz.aisdv1.core.validation.common + +import com.shifthackz.aisdv1.core.validation.ValidationResult +import org.junit.Assert +import org.junit.Test + +class CommonStringValidatorImplTest { + + private val validator = CommonStringValidatorImpl() + + @Test + fun `given input is null, expected not valid with Empty error`() { + val expected = ValidationResult( + isValid = false, + validationError = CommonStringValidator.Error.Empty, + ) + val actual = validator(null) + Assert.assertEquals(expected, actual) + } + + @Test + fun `given input is empty, expected not valid with Empty error`() { + val expected = ValidationResult( + isValid = false, + validationError = CommonStringValidator.Error.Empty, + ) + val actual = validator("") + Assert.assertEquals(expected, actual) + } + + @Test + fun `given input is blank, expected not valid with Empty error`() { + val expected = ValidationResult( + isValid = false, + validationError = CommonStringValidator.Error.Empty, + ) + val actual = validator(" ") + Assert.assertEquals(expected, actual) + } + + @Test + fun `given input is non empty string, expected valid`() { + val expected = ValidationResult(true) + val actual = validator("5598 is my favorite") + Assert.assertEquals(expected, actual) + } +} diff --git a/core/validation/src/test/java/com/shifthackz/aisdv1/core/validation/dimension/DimensionValidatorImplTest.kt b/core/validation/src/test/java/com/shifthackz/aisdv1/core/validation/dimension/DimensionValidatorImplTest.kt new file mode 100644 index 00000000..fff6b4ae --- /dev/null +++ b/core/validation/src/test/java/com/shifthackz/aisdv1/core/validation/dimension/DimensionValidatorImplTest.kt @@ -0,0 +1,72 @@ +package com.shifthackz.aisdv1.core.validation.dimension + +import com.shifthackz.aisdv1.core.validation.ValidationResult +import org.junit.Assert +import org.junit.Test + +class DimensionValidatorImplTest { + + private val validator = DimensionValidatorImpl(MIN, MAX) + + @Test + fun `given input is null, expected not valid with Empty error`() { + val expected = ValidationResult( + isValid = false, + validationError = DimensionValidator.Error.Empty, + ) + val actual = validator(null) + Assert.assertEquals(expected, actual) + } + + @Test + fun `given input is empty, expected not valid with Empty error`() { + val expected = ValidationResult( + isValid = false, + validationError = DimensionValidator.Error.Empty, + ) + val actual = validator("") + Assert.assertEquals(expected, actual) + } + + @Test + fun `given input is unparsable to int, expected not valid with Unexpected error`() { + val expected = ValidationResult( + isValid = false, + validationError = DimensionValidator.Error.Unexpected, + ) + val actual = validator("5598❤") + Assert.assertEquals(expected, actual) + } + + @Test + fun `given input is less than minimum allowed value, expected not valid with LessThanMinimum error`() { + val expected = ValidationResult( + isValid = false, + validationError = DimensionValidator.Error.LessThanMinimum(MIN), + ) + val actual = validator("55") + Assert.assertEquals(expected, actual) + } + + @Test + fun `given input is bigger than maximum allowed value, expected not valid with BiggerThanMaximum error`() { + val expected = ValidationResult( + isValid = false, + validationError = DimensionValidator.Error.BiggerThanMaximum(MAX), + ) + val actual = validator("5598") + Assert.assertEquals(expected, actual) + } + + @Test + fun `given input is valid parsable int value, expected valid`() { + val expected = ValidationResult(true) + val actual = validator("1024") + Assert.assertEquals(expected, actual) + } + + companion object { + private const val MIN = 64 + private const val MAX = 2048 + } +} diff --git a/core/validation/src/test/java/com/shifthackz/aisdv1/core/validation/url/UrlValidatorImplTest.kt b/core/validation/src/test/java/com/shifthackz/aisdv1/core/validation/url/UrlValidatorImplTest.kt new file mode 100644 index 00000000..549f8a61 --- /dev/null +++ b/core/validation/src/test/java/com/shifthackz/aisdv1/core/validation/url/UrlValidatorImplTest.kt @@ -0,0 +1,126 @@ +package com.shifthackz.aisdv1.core.validation.url + +import android.webkit.URLUtil +import com.shifthackz.aisdv1.core.validation.ValidationResult +import io.mockk.every +import io.mockk.mockk +import io.mockk.mockkConstructor +import io.mockk.mockkStatic +import org.junit.Assert +import org.junit.Test +import java.util.regex.Matcher +import java.util.regex.Pattern + +class UrlValidatorImplTest { + + private val stubMatcher = mockk() + private val stubPattern = mockk() + + private val validator = UrlValidatorImpl(stubPattern) + + @Test + fun `given input is null, expected not valid with Empty error`() { + val expected = ValidationResult( + isValid = false, + validationError = UrlValidator.Error.Empty, + ) + val actual = validator(null) + Assert.assertEquals(expected, actual) + } + + @Test + fun `given input is empty, expected not valid with Empty error`() { + val expected = ValidationResult( + isValid = false, + validationError = UrlValidator.Error.Empty, + ) + val actual = validator("") + Assert.assertEquals(expected, actual) + } + + @Test + fun `given input is blank, expected not valid with Empty error`() { + val expected = ValidationResult( + isValid = false, + validationError = UrlValidator.Error.Empty, + ) + val actual = validator(" ") + Assert.assertEquals(expected, actual) + } + + @Test + fun `given input is url with ftp protocol, expected not valid with BadScheme error`() { + val expected = ValidationResult( + isValid = false, + validationError = UrlValidator.Error.BadScheme, + ) + val actual = validator("ftp://5598.is.my.favorite.com:21/i_failed.dat") + Assert.assertEquals(expected, actual) + } + + @Test + fun `given input is url with port 99999, expected not valid with BadPort error`() { + val expected = ValidationResult( + isValid = false, + validationError = UrlValidator.Error.BadPort, + ) + val actual = validator("http://5598.is.my.favorite.com:99999") + Assert.assertEquals(expected, actual) + } + + @Test + fun `given input is http localhost ipv4 address, expected not valid with Localhost error`() { + val expected = ValidationResult( + isValid = false, + validationError = UrlValidator.Error.Localhost, + ) + val actual = validator("http://127.0.0.1:7860") + Assert.assertEquals(expected, actual) + } + + @Test + fun `given input is https localhost ipv4 address, expected not valid with Localhost error`() { + val expected = ValidationResult( + isValid = false, + validationError = UrlValidator.Error.Localhost, + ) + val actual = validator("https://127.0.0.1:7860") + Assert.assertEquals(expected, actual) + } + + @Test + fun `given input is not valid url, expected not valid with Invalid error`() { + val mockInput = "https://968.666.777.5598:00000000" + mockkConstructor(URLUtil::class) + mockkStatic(URLUtil::class) + every { + URLUtil.isValidUrl(mockInput) + } returns false + val expected = ValidationResult( + isValid = false, + validationError = UrlValidator.Error.Invalid, + ) + val actual = validator(mockInput) + Assert.assertEquals(expected, actual) + } + + @Test + fun `given input is valid url, expected valid`() { + val mockInput = "https://192.168.0.1:7860" + mockkConstructor(URLUtil::class) + mockkStatic(URLUtil::class) + every { + URLUtil.isValidUrl(mockInput) + } returns true + every { + stubMatcher.matches() + } returns true + every { + stubPattern.matcher(any()) + } returns stubMatcher + + val expected = ValidationResult(true) + val actual = validator(mockInput) + Assert.assertEquals(expected, actual) + } +} diff --git a/data/build.gradle b/data/build.gradle index d04d6e1f..8c09db1e 100755 --- a/data/build.gradle +++ b/data/build.gradle @@ -8,6 +8,14 @@ apply from: "$project.rootDir/gradle/common.gradle" android { namespace 'com.shifthackz.aisdv1.data' + testOptions { + unitTests.all { + jvmArgs( + "--add-opens", "java.base/java.lang=ALL-UNNAMED", + "--add-opens", "java.base/java.lang.reflect=ALL-UNNAMED" + ) + } + } } dependencies { @@ -20,4 +28,7 @@ dependencies { implementation di.koinAndroid implementation reactive.rxkotlin implementation google.gson + testImplementation test.junit + testImplementation test.mockito + testImplementation test.mockk } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/core/CoreMediaStoreRepository.kt b/data/src/main/java/com/shifthackz/aisdv1/data/core/CoreMediaStoreRepository.kt index 18c24987..281befda 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/core/CoreMediaStoreRepository.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/core/CoreMediaStoreRepository.kt @@ -37,8 +37,8 @@ internal abstract class CoreMediaStoreRepository( val stream = ByteArrayOutputStream() bmp.compress(Bitmap.CompressFormat.PNG, 100, stream) mediaStoreGateway.exportToFile( - "sdai_${System.currentTimeMillis()}", - stream.toByteArray() + fileName = "sdai_${System.currentTimeMillis()}", + content = stream.toByteArray(), ) } .onErrorComplete { t -> diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/di/LocalDataSourceModule.kt b/data/src/main/java/com/shifthackz/aisdv1/data/di/LocalDataSourceModule.kt index 9bf5fd7b..97410914 100755 --- a/data/src/main/java/com/shifthackz/aisdv1/data/di/LocalDataSourceModule.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/di/LocalDataSourceModule.kt @@ -31,7 +31,8 @@ import org.koin.dsl.module val localDataSourceModule = module { singleOf(::DatabaseClearGatewayImpl) bind DatabaseClearGateway::class - singleOf(::StabilityAiCreditsLocalDataSource) bind StabilityAiCreditsDataSource.Local::class + // !!! Do not use [factoryOf] for StabilityAiCreditsLocalDataSource, it has default constructor + single { StabilityAiCreditsLocalDataSource() } factoryOf(::StableDiffusionModelsLocalDataSource) bind StableDiffusionModelsDataSource.Local::class factoryOf(::StableDiffusionSamplersLocalDataSource) bind StableDiffusionSamplersDataSource.Local::class factoryOf(::StableDiffusionLorasLocalDataSource) bind StableDiffusionLorasDataSource.Local::class diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSource.kt index 5284b92d..0e4b9f02 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSource.kt @@ -11,7 +11,6 @@ import com.shifthackz.aisdv1.domain.preference.PreferenceManager import com.shifthackz.aisdv1.storage.db.persistent.dao.LocalModelDao import com.shifthackz.aisdv1.storage.db.persistent.entity.LocalModelEntity import io.reactivex.rxjava3.core.Completable -import io.reactivex.rxjava3.core.Flowable import io.reactivex.rxjava3.core.Observable import io.reactivex.rxjava3.core.Single import java.io.File @@ -23,7 +22,8 @@ internal class DownloadableModelLocalDataSource( private val buildInfoProvider: BuildInfoProvider, ) : DownloadableModelDataSource.Local { - override fun getAll(): Single> = dao.query() + override fun getAll() = dao + .query() .map(List::mapEntityToDomain) .map { models -> buildList { @@ -34,20 +34,24 @@ internal class DownloadableModelLocalDataSource( .flatMap { models -> models.withLocalData() } override fun getById(id: String): Single { - val chain = if (id == LocalAiModel.CUSTOM.id) Single.just(LocalAiModel.CUSTOM) - else dao - .queryById(id) - .map(LocalModelEntity::mapEntityToDomain) + val chain = if (id == LocalAiModel.CUSTOM.id) { + Single.just(LocalAiModel.CUSTOM) + } else { + dao + .queryById(id) + .map(LocalModelEntity::mapEntityToDomain) + } return chain.flatMap { model -> model.withLocalData() } } - override fun getSelected(): Single = Single + override fun getSelected() = Single .just(preferenceManager.localModelId) + .onErrorResumeNext { Single.error(IllegalStateException("No selected model.")) } .flatMap(::getById) - .onErrorResumeNext { Single.error(Throwable("No selected model")) } + .onErrorResumeNext { Single.error(IllegalStateException("No selected model.")) } - override fun observeAll(): Flowable> = dao + override fun observeAll() = dao .observe() .map(List::mapEntityToDomain) .map { models -> @@ -58,7 +62,7 @@ internal class DownloadableModelLocalDataSource( } .flatMap { models -> models.withLocalData().toFlowable() } - override fun select(id: String): Completable = Completable.fromAction { + override fun select(id: String) = Completable.fromAction { preferenceManager.localModelId = id } @@ -67,15 +71,13 @@ internal class DownloadableModelLocalDataSource( .mapDomainToEntity() .let(dao::insertList) - override fun isDownloaded(id: String): Single = Single.create { emitter -> + override fun isDownloaded(id: String) = Single.create { emitter -> try { if (id == LocalAiModel.CUSTOM.id) { if (!emitter.isDisposed) emitter.onSuccess(true) } else { - val localModelDir = getLocalModelDirectory(id) - val files = - (localModelDir.listFiles()?.filter { it.isDirectory }) ?: emptyList() - if (!emitter.isDisposed) emitter.onSuccess(localModelDir.exists() && files.size == 4) + val files = getLocalModelFiles(id) + if (!emitter.isDisposed) emitter.onSuccess(files.size == 4) } } catch (e: Exception) { if (!emitter.isDisposed) emitter.onSuccess(false) @@ -90,12 +92,18 @@ internal class DownloadableModelLocalDataSource( return File("${fileProviderDescriptor.localModelDirPath}/${id}") } - private fun List.withLocalData(): Single> = Observable + private fun getLocalModelFiles(id: String): List { + val localModelDir = getLocalModelDirectory(id) + if (!localModelDir.exists()) return emptyList() + return localModelDir.listFiles()?.filter { it.isDirectory } ?: emptyList() + } + + private fun List.withLocalData() = Observable .fromIterable(this) .flatMapSingle { model -> model.withLocalData() } .toList() - private fun LocalAiModel.withLocalData(): Single = isDownloaded(id) + private fun LocalAiModel.withLocalData() = isDownloaded(id) .map { downloaded -> copy( downloaded = downloaded, diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/local/GenerationResultLocalDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/local/GenerationResultLocalDataSource.kt index ffc8c8cc..99922e71 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/local/GenerationResultLocalDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/local/GenerationResultLocalDataSource.kt @@ -12,8 +12,9 @@ internal class GenerationResultLocalDataSource( private val dao: GenerationResultDao, ) : GenerationResultDataSource.Local { - override fun insert(result: AiGenerationResult) = dao - .insert(result.mapDomainToEntity()) + override fun insert(result: AiGenerationResult) = result + .mapDomainToEntity() + .let(dao::insert) override fun queryAll(): Single> = dao .query() diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/local/ServerConfigurationLocalDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/local/ServerConfigurationLocalDataSource.kt index 37f4e7c0..fb6219d5 100755 --- a/data/src/main/java/com/shifthackz/aisdv1/data/local/ServerConfigurationLocalDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/local/ServerConfigurationLocalDataSource.kt @@ -6,17 +6,15 @@ import com.shifthackz.aisdv1.domain.datasource.ServerConfigurationDataSource import com.shifthackz.aisdv1.domain.entity.ServerConfiguration import com.shifthackz.aisdv1.storage.db.cache.dao.ServerConfigurationDao import com.shifthackz.aisdv1.storage.db.cache.entity.ServerConfigurationEntity -import io.reactivex.rxjava3.core.Completable -import io.reactivex.rxjava3.core.Single internal class ServerConfigurationLocalDataSource( private val dao: ServerConfigurationDao, ) : ServerConfigurationDataSource.Local { - override fun save(configuration: ServerConfiguration): Completable = dao + override fun save(configuration: ServerConfiguration) = dao .insert(configuration.mapToEntity()) - override fun get(): Single = dao + override fun get() = dao .query() .map(ServerConfigurationEntity::mapToDomain) } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/local/StabilityAiCreditsLocalDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/local/StabilityAiCreditsLocalDataSource.kt index 67ae0a58..54e65ddf 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/local/StabilityAiCreditsLocalDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/local/StabilityAiCreditsLocalDataSource.kt @@ -5,9 +5,9 @@ import io.reactivex.rxjava3.core.BackpressureStrategy import io.reactivex.rxjava3.core.Completable import io.reactivex.rxjava3.subjects.BehaviorSubject -internal class StabilityAiCreditsLocalDataSource : StabilityAiCreditsDataSource.Local { - - private val creditsSubject: BehaviorSubject = BehaviorSubject.createDefault(0f) +internal class StabilityAiCreditsLocalDataSource( + private val creditsSubject: BehaviorSubject = BehaviorSubject.createDefault(0f), +) : StabilityAiCreditsDataSource.Local { override fun get() = creditsSubject .firstOrError() diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/local/StableDiffusionLorasLocalDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/local/StableDiffusionLorasLocalDataSource.kt index cf3a722c..e130ade2 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/local/StableDiffusionLorasLocalDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/local/StableDiffusionLorasLocalDataSource.kt @@ -10,6 +10,7 @@ import com.shifthackz.aisdv1.storage.db.cache.entity.StableDiffusionLoraEntity internal class StableDiffusionLorasLocalDataSource( private val dao: StableDiffusionLoraDao, ) : StableDiffusionLorasDataSource.Local { + override fun getLoras() = dao .queryAll() .map(List::mapEntityToDomain) diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/local/StableDiffusionModelsLocalDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/local/StableDiffusionModelsLocalDataSource.kt index 0491255e..b3a1012c 100755 --- a/data/src/main/java/com/shifthackz/aisdv1/data/local/StableDiffusionModelsLocalDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/local/StableDiffusionModelsLocalDataSource.kt @@ -6,18 +6,16 @@ import com.shifthackz.aisdv1.domain.datasource.StableDiffusionModelsDataSource import com.shifthackz.aisdv1.domain.entity.StableDiffusionModel import com.shifthackz.aisdv1.storage.db.cache.dao.StableDiffusionModelDao import com.shifthackz.aisdv1.storage.db.cache.entity.StableDiffusionModelEntity -import io.reactivex.rxjava3.core.Completable -import io.reactivex.rxjava3.core.Single internal class StableDiffusionModelsLocalDataSource( private val dao: StableDiffusionModelDao, ) : StableDiffusionModelsDataSource.Local { - override fun insertModels(models: List): Completable = dao - .deleteAll() - .andThen(dao.insertList(models.mapDomainToEntity())) - - override fun getModels(): Single> = dao + override fun getModels() = dao .queryAll() .map(List::mapEntityToDomain) + + override fun insertModels(models: List) = dao + .deleteAll() + .andThen(dao.insertList(models.mapDomainToEntity())) } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/local/StableDiffusionSamplersLocalDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/local/StableDiffusionSamplersLocalDataSource.kt index 35b444d5..7cf608a9 100755 --- a/data/src/main/java/com/shifthackz/aisdv1/data/local/StableDiffusionSamplersLocalDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/local/StableDiffusionSamplersLocalDataSource.kt @@ -6,7 +6,6 @@ import com.shifthackz.aisdv1.domain.datasource.StableDiffusionSamplersDataSource import com.shifthackz.aisdv1.domain.entity.StableDiffusionSampler import com.shifthackz.aisdv1.storage.db.cache.dao.StableDiffusionSamplerDao import com.shifthackz.aisdv1.storage.db.cache.entity.StableDiffusionSamplerEntity -import io.reactivex.rxjava3.core.Completable internal class StableDiffusionSamplersLocalDataSource( private val dao: StableDiffusionSamplerDao, @@ -16,7 +15,7 @@ internal class StableDiffusionSamplersLocalDataSource( .queryAll() .map(List::mapEntityToDomain) - override fun insertSamplers(samplers: List): Completable = dao + override fun insertSamplers(samplers: List) = dao .deleteAll() .andThen(dao.insertList(samplers.mapDomainToEntity())) } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt index 9d5fa110..6c23ab0e 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt @@ -42,7 +42,6 @@ class PreferenceManagerImpl( .apply() .also { onPreferencesChanged() } - override var autoSaveAiResults: Boolean get() = preferences.getBoolean(KEY_AI_AUTO_SAVE, true) set(value) = preferences.edit() @@ -78,6 +77,7 @@ class PreferenceManagerImpl( .putString(KEY_SERVER_SOURCE, value.key) .apply() .also { onPreferencesChanged() } + override var sdModel: String get() = preferences.getString(KEY_SD_MODEL, "") ?: "" set(value) = preferences.edit() @@ -190,6 +190,7 @@ class PreferenceManagerImpl( override fun observe(): Flowable = preferencesChangedSubject .toFlowable(BackpressureStrategy.LATEST) + .distinctUntilChanged() .map { Settings( serverUrl = serverUrl, @@ -214,28 +215,28 @@ class PreferenceManagerImpl( private fun onPreferencesChanged() = preferencesChangedSubject.onNext(Unit) companion object { - private const val KEY_SERVER_URL = "key_server_url" - private const val KEY_DEMO_MODE = "key_demo_mode" - private const val KEY_MONITOR_CONNECTIVITY = "key_monitor_connectivity" - private const val KEY_AI_AUTO_SAVE = "key_ai_auto_save" - private const val KEY_SAVE_TO_MEDIA_STORE = "key_save_to_media_store" - private const val KEY_FORM_ALWAYS_SHOW_ADVANCED_OPTIONS = "key_always_show_advanced_options" - private const val KEY_FORM_PROMPT_TAGGED_INPUT = "key_prompt_tagged_input" - private const val KEY_SERVER_SOURCE = "key_server_source" - private const val KEY_SD_MODEL = "key_sd_model" - private const val KEY_HORDE_API_KEY = "key_horde_api_key" - private const val KEY_OPEN_AI_API_KEY = "key_open_ai_api_key" - private const val KEY_HUGGING_FACE_API_KEY = "key_hugging_face_api_key" - private const val KEY_HUGGING_FACE_MODEL_KEY = "key_hugging_face_model_key" - private const val KEY_STABILITY_AI_API_KEY = "key_stability_ai_api_key" - private const val KEY_STABILITY_AI_ENGINE_ID_KEY = "key_stability_ai_engine_id_key" - private const val KEY_LOCAL_NN_API = "key_local_nn_api" - private const val KEY_LOCAL_MODEL_ID = "key_local_model_id" - private const val KEY_DESIGN_DYNAMIC_COLORS = "key_design_dynamic_colors" - private const val KEY_DESIGN_SYSTEM_DARK_THEME = "key_design_system_dark_theme" - private const val KEY_DESIGN_DARK_THEME = "key_design_dark_theme" - private const val KEY_DESIGN_COLOR_TOKEN = "key_design_color_token_theme" - private const val KEY_DESIGN_DARK_TOKEN = "key_design_dark_color_token_theme" - private const val KEY_FORCE_SETUP_AFTER_UPDATE = "force_upd_setup_v0.x.x-v0.5.8" + const val KEY_SERVER_URL = "key_server_url" + const val KEY_DEMO_MODE = "key_demo_mode" + const val KEY_MONITOR_CONNECTIVITY = "key_monitor_connectivity" + const val KEY_AI_AUTO_SAVE = "key_ai_auto_save" + const val KEY_SAVE_TO_MEDIA_STORE = "key_save_to_media_store" + const val KEY_FORM_ALWAYS_SHOW_ADVANCED_OPTIONS = "key_always_show_advanced_options" + const val KEY_FORM_PROMPT_TAGGED_INPUT = "key_prompt_tagged_input" + const val KEY_SERVER_SOURCE = "key_server_source" + const val KEY_SD_MODEL = "key_sd_model" + const val KEY_HORDE_API_KEY = "key_horde_api_key" + const val KEY_OPEN_AI_API_KEY = "key_open_ai_api_key" + const val KEY_HUGGING_FACE_API_KEY = "key_hugging_face_api_key" + const val KEY_HUGGING_FACE_MODEL_KEY = "key_hugging_face_model_key" + const val KEY_STABILITY_AI_API_KEY = "key_stability_ai_api_key" + const val KEY_STABILITY_AI_ENGINE_ID_KEY = "key_stability_ai_engine_id_key" + const val KEY_FORCE_SETUP_AFTER_UPDATE = "force_upd_setup_v0.x.x-v0.5.8" + const val KEY_LOCAL_MODEL_ID = "key_local_model_id" + const val KEY_LOCAL_NN_API = "key_local_nn_api" + const val KEY_DESIGN_DYNAMIC_COLORS = "key_design_dynamic_colors" + const val KEY_DESIGN_SYSTEM_DARK_THEME = "key_design_system_dark_theme" + const val KEY_DESIGN_DARK_THEME = "key_design_dark_theme" + const val KEY_DESIGN_COLOR_TOKEN = "key_design_color_token_theme" + const val KEY_DESIGN_DARK_TOKEN = "key_design_dark_color_token_theme" } } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSource.kt index eda865d7..9241dcc6 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSource.kt @@ -20,7 +20,8 @@ internal class DownloadableModelRemoteDataSource( .fetchDownloadableModels() .map(List::mapRawToDomain) - override fun download(id: String, url: String): Observable = Completable + override fun download(id: String, url: String): Observable = + Completable .fromAction { val dir = File("${fileProviderDescriptor.localModelDirPath}/${id}") val destination = File(getDestinationPath(id)) diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/remote/HordeGenerationRemoteDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/remote/HordeGenerationRemoteDataSource.kt index 1d8a14c2..86b447da 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/remote/HordeGenerationRemoteDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/remote/HordeGenerationRemoteDataSource.kt @@ -45,7 +45,7 @@ internal class HordeGenerationRemoteDataSource( override fun interruptGeneration() = statusSource.id ?.let(hordeApi::cancelRequest) - ?: Completable.error(Throwable("No cached request id")) + ?: Completable.error(IllegalStateException("No cached request id")) private fun executeRequestChain(request: HordeGenerationAsyncRequest) = hordeApi .generateAsync(request) diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/repository/GenerationResultRepositoryImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/repository/GenerationResultRepositoryImpl.kt index 7f106654..2fb396c8 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/repository/GenerationResultRepositoryImpl.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/repository/GenerationResultRepositoryImpl.kt @@ -14,8 +14,11 @@ internal class GenerationResultRepositoryImpl( mediaStoreGateway: MediaStoreGateway, base64ToBitmapConverter: Base64ToBitmapConverter, private val localDataSource: GenerationResultDataSource.Local, -) : CoreMediaStoreRepository(preferenceManager, mediaStoreGateway, base64ToBitmapConverter), - GenerationResultRepository { +) : CoreMediaStoreRepository( + preferenceManager, + mediaStoreGateway, + base64ToBitmapConverter, +), GenerationResultRepository { override fun getAll() = localDataSource.queryAll() diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/repository/HuggingFaceModelsRepositoryImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/repository/HuggingFaceModelsRepositoryImpl.kt index e4240621..abf2e578 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/repository/HuggingFaceModelsRepositoryImpl.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/repository/HuggingFaceModelsRepositoryImpl.kt @@ -13,6 +13,7 @@ internal class HuggingFaceModelsRepositoryImpl( .concatMapCompletable(localDataSource::save) override fun fetchAndGetHuggingFaceModels() = fetchHuggingFaceModels() + .onErrorComplete() .andThen(getHuggingFaceModels()) override fun getHuggingFaceModels() = localDataSource.getAll() diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImpl.kt index 1fff92e4..c4a4b119 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImpl.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImpl.kt @@ -36,7 +36,7 @@ internal class LocalDiffusionGenerationRepositoryImpl( .getSelected() .flatMap { model -> if (model.downloaded) generate(payload) - else Single.error(Throwable("Model not downloaded")) + else Single.error(IllegalStateException("Model not downloaded.")) } override fun interruptGeneration() = localDiffusion.interrupt() diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/repository/ServerConfigurationRepositoryImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/repository/ServerConfigurationRepositoryImpl.kt index 1cdb4227..f7c20567 100755 --- a/data/src/main/java/com/shifthackz/aisdv1/data/repository/ServerConfigurationRepositoryImpl.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/repository/ServerConfigurationRepositoryImpl.kt @@ -3,24 +3,22 @@ package com.shifthackz.aisdv1.data.repository import com.shifthackz.aisdv1.domain.datasource.ServerConfigurationDataSource import com.shifthackz.aisdv1.domain.entity.ServerConfiguration import com.shifthackz.aisdv1.domain.repository.ServerConfigurationRepository -import io.reactivex.rxjava3.core.Completable -import io.reactivex.rxjava3.core.Single internal class ServerConfigurationRepositoryImpl( private val remoteDataSource: ServerConfigurationDataSource.Remote, private val localDataSource: ServerConfigurationDataSource.Local, ) : ServerConfigurationRepository { - override fun fetchConfiguration(): Completable = remoteDataSource + override fun fetchConfiguration() = remoteDataSource .fetchConfiguration() .flatMapCompletable(localDataSource::save) - override fun fetchAndGetConfiguration(): Single = - fetchConfiguration() - .andThen(getConfiguration()) + override fun fetchAndGetConfiguration() = fetchConfiguration() + .onErrorComplete() + .andThen(getConfiguration()) - override fun getConfiguration(): Single = localDataSource.get() + override fun getConfiguration() = localDataSource.get() - override fun updateConfiguration(configuration: ServerConfiguration): Completable = - remoteDataSource.updateConfiguration(configuration) + override fun updateConfiguration(configuration: ServerConfiguration) = remoteDataSource + .updateConfiguration(configuration) } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/repository/StabilityAiCreditsRepositoryImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/repository/StabilityAiCreditsRepositoryImpl.kt index 0079d1ab..ed9ef0d6 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/repository/StabilityAiCreditsRepositoryImpl.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/repository/StabilityAiCreditsRepositoryImpl.kt @@ -18,30 +18,30 @@ internal class StabilityAiCreditsRepositoryImpl( onValid = remoteDataSource .fetch() .flatMapCompletable(localDataSource::save), - onNotValid = Completable.error(Throwable("Wrong server source selected.")), + onNotValid = Completable.error(IllegalStateException("Wrong server source selected.")), ) override fun fetchAndGet() = checkServerSource( onValid = fetch().onErrorComplete().andThen(get()), - onNotValid = Single.error(Throwable("Wrong server source selected.")), + onNotValid = Single.error(IllegalStateException("Wrong server source selected.")), ) override fun fetchAndObserve() = checkServerSource( onValid = fetch().onErrorComplete().andThen(observe()), - onNotValid = Flowable.error(Throwable("Wrong server source selected.")), + onNotValid = Flowable.error(IllegalStateException("Wrong server source selected.")), ) override fun get() = checkServerSource( onValid = localDataSource.get(), - onNotValid = Single.error(Throwable("Wrong server source selected.")), + onNotValid = Single.error(IllegalStateException("Wrong server source selected.")), ) override fun observe() = checkServerSource( onValid = localDataSource.observe(), - onNotValid = Flowable.error(Throwable("Wrong server source selected.")), + onNotValid = Flowable.error(IllegalStateException("Wrong server source selected.")), ) - private fun checkServerSource( + private fun checkServerSource( onValid: T, onNotValid: T, ): T = when (preferenceManager.source) { diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/repository/StabilityAiGenerationRepositoryImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/repository/StabilityAiGenerationRepositoryImpl.kt index c28e5d42..e5185774 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/repository/StabilityAiGenerationRepositoryImpl.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/repository/StabilityAiGenerationRepositoryImpl.kt @@ -61,5 +61,6 @@ internal class StabilityAiGenerationRepositoryImpl( private fun refreshCredits(ai: AiGenerationResult) = creditsRds .fetch() .flatMapCompletable(creditsLds::save) + .onErrorComplete() .andThen(Single.just(ai)) } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/repository/TemporaryGenerationResultRepositoryImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/repository/TemporaryGenerationResultRepositoryImpl.kt index 84cb7a45..e53acfa7 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/repository/TemporaryGenerationResultRepositoryImpl.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/repository/TemporaryGenerationResultRepositoryImpl.kt @@ -16,6 +16,6 @@ internal class TemporaryGenerationResultRepositoryImpl : TemporaryGenerationResu override fun get(): Single { return lastCachedResult ?.let { Single.just(it) } - ?: Single.error(Throwable("No last cached result")) + ?: Single.error(IllegalStateException("No last cached result.")) } } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/repository/WakeLockRepositoryImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/repository/WakeLockRepositoryImpl.kt index 0c15059c..90263707 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/repository/WakeLockRepositoryImpl.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/repository/WakeLockRepositoryImpl.kt @@ -5,7 +5,7 @@ import com.shifthackz.aisdv1.domain.repository.WakeLockRepository internal class WakeLockRepositoryImpl( val powerManager: () -> PowerManager, -): WakeLockRepository { +) : WakeLockRepository { private var _wakeLock: PowerManager.WakeLock? = null override val wakeLock: PowerManager.WakeLock diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/gateway/DatabaseClearGatewayImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/gateway/DatabaseClearGatewayImplTest.kt new file mode 100644 index 00000000..6bd64483 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/gateway/DatabaseClearGatewayImplTest.kt @@ -0,0 +1,75 @@ +package com.shifthackz.aisdv1.data.gateway + +import com.shifthackz.aisdv1.storage.gateway.GatewayClearCacheDb +import com.shifthackz.aisdv1.storage.gateway.GatewayClearPersistentDb +import io.mockk.every +import io.mockk.mockk +import org.junit.Test + +class DatabaseClearGatewayImplTest { + + private val stubException = Throwable("Error occurred.") + private val stubGatewayClearCacheDb = mockk() + private val stubGatewayClearPersistentDb = mockk() + + private val gateway = DatabaseClearGatewayImpl( + gatewayClearCacheDb = stubGatewayClearCacheDb, + gatewayClearPersistentDb = stubGatewayClearPersistentDb, + ) + + @Test + fun `given attempt to clearSessionScopeDb, operation succeed, expected complete value`() { + every { + stubGatewayClearCacheDb() + } returns Unit + + gateway + .clearSessionScopeDb() + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to clearSessionScopeDb, operation failed, expected error value`() { + every { + stubGatewayClearCacheDb() + } throws stubException + + gateway + .clearSessionScopeDb() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to clearStorageScopeDb, operation succeed, expected complete value`() { + every { + stubGatewayClearPersistentDb() + } returns Unit + + gateway + .clearStorageScopeDb() + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to clearStorageScopeDb, operation failed, expected error value`() { + every { + stubGatewayClearPersistentDb() + } throws stubException + + gateway + .clearStorageScopeDb() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/gateway/ServerConnectivityGatewayImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/gateway/ServerConnectivityGatewayImplTest.kt new file mode 100644 index 00000000..572b308e --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/gateway/ServerConnectivityGatewayImplTest.kt @@ -0,0 +1,116 @@ +package com.shifthackz.aisdv1.data.gateway + +import com.shifthackz.aisdv1.data.provider.ServerUrlProvider +import com.shifthackz.aisdv1.network.connectivity.ConnectivityMonitor +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Observable +import io.reactivex.rxjava3.core.Single +import io.reactivex.rxjava3.subjects.BehaviorSubject +import org.junit.Before +import org.junit.Test + +class ServerConnectivityGatewayImplTest { + + private val stubException = Throwable("Internal server error.") + private val stubConnectivityState = BehaviorSubject.create() + private val stubConnectivityMonitor = mockk< ConnectivityMonitor>() + private val stubServerUrlProvider = mockk() + + private val gateway = ServerConnectivityGatewayImpl( + connectivityMonitor = stubConnectivityMonitor, + serverUrlProvider = stubServerUrlProvider, + ) + + @Before + fun initialize() { + every { + stubServerUrlProvider(any()) + } returns Single.just("http://192.168.0.1:7860") + + every { + stubConnectivityMonitor.observe(any()) + } returns stubConnectivityState + } + + @Test + fun `given initially offline, then go online, expected false, then true`() { + val stubObserver = gateway.observe().test() + + stubConnectivityState.onNext(false) + + stubObserver + .assertNoErrors() + .assertValueAt(0, false) + + stubConnectivityState.onNext(true) + + stubObserver + .assertNoErrors() + .assertValueAt(1, true) + } + + @Test + fun `given initially online, then go offline, expected true, then false`() { + val stubObserver = gateway.observe().test() + + stubConnectivityState.onNext(true) + + stubObserver + .assertNoErrors() + .assertValueAt(0, true) + + stubConnectivityState.onNext(false) + + stubObserver + .assertNoErrors() + .assertValueAt(1, false) + } + + @Test + fun `given received online signal twice, expected true, twice`() { + val stubObserver = gateway.observe().test() + + stubConnectivityState.onNext(true) + + stubObserver + .assertNoErrors() + .assertValueAt(0, true) + + stubConnectivityState.onNext(true) + + stubObserver + .assertNoErrors() + .assertValueAt(1, true) + } + + @Test + fun `given received offline signal twice, expected false, twice`() { + val stubObserver = gateway.observe().test() + + stubConnectivityState.onNext(false) + + stubObserver + .assertNoErrors() + .assertValueAt(0, false) + + stubConnectivityState.onNext(false) + + stubObserver + .assertNoErrors() + .assertValueAt(1, false) + } + + @Test + fun `given connectivity monitor throws error, expected error value`() { + every { + stubConnectivityMonitor.observe(any()) + } returns Observable.error(stubException) + + gateway.observe().test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/gateway/mediastore/MediaStoreGatewayFactoryTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/gateway/mediastore/MediaStoreGatewayFactoryTest.kt new file mode 100644 index 00000000..a4cbefe4 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/gateway/mediastore/MediaStoreGatewayFactoryTest.kt @@ -0,0 +1,81 @@ +package com.shifthackz.aisdv1.data.gateway.mediastore + +import android.content.Context +import android.os.Build +import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor +import io.mockk.mockk +import org.junit.Assert +import org.junit.Test +import java.lang.reflect.Field +import java.lang.reflect.Method +import java.lang.reflect.Modifier + +class MediaStoreGatewayFactoryTest { + + private val stubContext = mockk() + private val stubFileProviderDescriptor = mockk() + + private val factory = MediaStoreGatewayFactory( + context = stubContext, + fileProviderDescriptor = stubFileProviderDescriptor, + ) + + @Test + fun `given app running on Android SDK 26 (O), expected factory returned instance of type MediaStoreGatewayOldImpl`() { + mockSdkInt(Build.VERSION_CODES.O) + val actual = factory.invoke() + Assert.assertEquals(true, actual is MediaStoreGatewayOldImpl) + } + + @Test + fun `given app running on Android SDK 31 (S), expected factory returned instance of type MediaStoreGatewayOldImpl`() { + mockSdkInt(Build.VERSION_CODES.S) + val actual = factory.invoke() + Assert.assertEquals(true, actual is MediaStoreGatewayOldImpl) + } + + @Test + fun `given app running on Android SDK 32 (S_V2), expected factory returned instance of type MediaStoreGatewayOldImpl`() { + mockSdkInt(Build.VERSION_CODES.S_V2) + val actual = factory.invoke() + Assert.assertEquals(true, actual is MediaStoreGatewayImpl) + } + + @Test + fun `given app running on Android SDK 34 (UPSIDE_DOWN_CAKE), expected factory returned instance of type MediaStoreGatewayOldImpl`() { + mockSdkInt(Build.VERSION_CODES.UPSIDE_DOWN_CAKE) + val actual = factory.invoke() + Assert.assertEquals(true, actual is MediaStoreGatewayImpl) + } + + private fun mockSdkInt(sdkInt: Int) { + val sdkIntField = Build.VERSION::class.java.getField("SDK_INT") + sdkIntField.isAccessible = true + getModifiersField().also { + it.isAccessible = true + it.set(sdkIntField, sdkIntField.modifiers and Modifier.FINAL.inv()) + } + sdkIntField.set(null, sdkInt) + } + + private fun getModifiersField(): Field { + return try { + Field::class.java.getDeclaredField("modifiers") + } catch (e: NoSuchFieldException) { + try { + val getDeclaredFields0: Method = + Class::class.java.getDeclaredMethod("getDeclaredFields0", Boolean::class.javaPrimitiveType) + getDeclaredFields0.isAccessible = true + val fields = getDeclaredFields0.invoke(Field::class.java, false) as Array + for (field in fields) { + if ("modifiers" == field.name) { + return field + } + } + } catch (ex: ReflectiveOperationException) { + e.addSuppressed(ex) + } + throw e + } + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSourceTest.kt new file mode 100644 index 00000000..e1c2c830 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSourceTest.kt @@ -0,0 +1,470 @@ +package com.shifthackz.aisdv1.data.local + +import com.shifthackz.aisdv1.core.common.appbuild.BuildInfoProvider +import com.shifthackz.aisdv1.core.common.appbuild.BuildType +import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor +import com.shifthackz.aisdv1.data.mappers.mapEntityToDomain +import com.shifthackz.aisdv1.data.mocks.mockLocalAiModels +import com.shifthackz.aisdv1.data.mocks.mockLocalModelEntities +import com.shifthackz.aisdv1.data.mocks.mockLocalModelEntity +import com.shifthackz.aisdv1.domain.entity.LocalAiModel +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.storage.db.persistent.dao.LocalModelDao +import com.shifthackz.aisdv1.storage.db.persistent.entity.LocalModelEntity +import io.mockk.every +import io.mockk.mockk +import io.mockk.mockkConstructor +import io.mockk.mockkStatic +import io.reactivex.rxjava3.core.BackpressureStrategy +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Flowable +import io.reactivex.rxjava3.core.Single +import io.reactivex.rxjava3.subjects.BehaviorSubject +import org.junit.Assert +import org.junit.Test +import java.io.File + +class DownloadableModelLocalDataSourceTest { + + private val stubException = Throwable("Database error.") + private val stubLocalModels = BehaviorSubject.create>() + private val stubFileProviderDescriptor = mockk() + private val stubDao = mockk() + private val stubPreferenceManager = mockk() + private val stubBuildInfoProvider = mockk() + + private val localDataSource = DownloadableModelLocalDataSource( + fileProviderDescriptor = stubFileProviderDescriptor, + dao = stubDao, + preferenceManager = stubPreferenceManager, + buildInfoProvider = stubBuildInfoProvider, + ) + + @Test + fun `given attempt to get all models, dao returns models list, app build type is PLAY, expected valid domain models list`() { + every { + stubDao.query() + } returns Single.just(mockLocalModelEntities) + + every { + stubBuildInfoProvider.type + } returns BuildType.PLAY + + every { + stubPreferenceManager.localModelId + } returns "" + + every { + stubFileProviderDescriptor.localModelDirPath + } returns "/tmp/local" + + val expected = mockLocalModelEntities.mapEntityToDomain() + + localDataSource + .getAll() + .test() + .assertNoErrors() + .assertValue { actual -> + expected == actual && expected.size == actual.size + } + .await() + .assertComplete() + } + + @Test + fun `given attempt to get all models, dao returns empty models list, app build type is PLAY, expected empty domain models list`() { + every { + stubDao.query() + } returns Single.just(emptyList()) + + every { + stubBuildInfoProvider.type + } returns BuildType.PLAY + + every { + stubPreferenceManager.localModelId + } returns "" + + localDataSource + .getAll() + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get all models, dao returns models list, app build type is FOSS, expected valid domain models list with CUSTOM model included`() { + every { + stubDao.query() + } returns Single.just(mockLocalModelEntities) + + every { + stubBuildInfoProvider.type + } returns BuildType.FOSS + + every { + stubPreferenceManager.localModelId + } returns "" + + every { + stubFileProviderDescriptor.localModelDirPath + } returns "/tmp/local" + + val expected = buildList { + addAll(mockLocalModelEntities.mapEntityToDomain()) + add(LocalAiModel.CUSTOM.copy(downloaded = true)) + } + + localDataSource + .getAll() + .test() + .assertNoErrors() + .assertValue { actual -> + expected == actual && expected.size == actual.size + } + .await() + .assertComplete() + } + + @Test + fun `given attempt to get all models, dao returns empty models list, app build type is FOSS, expected domain models list with only CUSTOM model included`() { + every { + stubDao.query() + } returns Single.just(emptyList()) + + every { + stubBuildInfoProvider.type + } returns BuildType.FOSS + + every { + stubPreferenceManager.localModelId + } returns "" + + localDataSource + .getAll() + .test() + .assertNoErrors() + .assertValue(listOf(LocalAiModel.CUSTOM.copy(downloaded = true))) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get all models, dao throws exception, expected error value`() { + every { + stubDao.query() + } returns Single.error(stubException) + + localDataSource + .getAll() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to get model by id, dao returns model, model id does not match local model id in preference, expected valid domain model value with selected equals false`() { + every { + stubDao.queryById(any()) + } returns Single.just(mockLocalModelEntity) + + every { + stubPreferenceManager.localModelId + } returns "" + + every { + stubFileProviderDescriptor.localModelDirPath + } returns "/tmp/local" + + val expected = mockLocalModelEntity.mapEntityToDomain() + + localDataSource + .getById("5598") + .test() + .assertNoErrors() + .assertValue(expected) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get model by id, dao returns model, model id matches local model id in preference, expected valid domain model value with selected equals true`() { + every { + stubDao.queryById(any()) + } returns Single.just(mockLocalModelEntity) + + every { + stubPreferenceManager.localModelId + } returns "5598" + + every { + stubFileProviderDescriptor.localModelDirPath + } returns "/tmp/local" + + val expected = mockLocalModelEntity.mapEntityToDomain().copy(selected = true) + + localDataSource + .getById("5598") + .test() + .assertNoErrors() + .assertValue(expected) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get model by id, dao throws exception, expected error true`() { + every { + stubDao.queryById(any()) + } returns Single.error(stubException) + + localDataSource + .getById("5598") + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to get selected model, dao has model with provided id in db, expected domain valid model value`() { + every { + stubDao.queryById(any()) + } returns Single.just(mockLocalModelEntity) + + every { + stubPreferenceManager.localModelId + } returns "5598" + + every { + stubFileProviderDescriptor.localModelDirPath + } returns "/tmp/local" + + val expected = mockLocalModelEntity.mapEntityToDomain().copy(selected = true) + + localDataSource + .getSelected() + .test() + .assertNoErrors() + .assertValue(expected) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get selected model, preference throws exception, expected error value`() { + every { + stubPreferenceManager.localModelId + } returns "" + + localDataSource + .getSelected() + .test() + .assertError { t -> + t is IllegalStateException && t.message == "No selected model." + } + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to observe all models, dao emits empty list, then list with two items, app build type is PLAY, expected empty list, then domain list with two items`() { + every { + stubDao.observe() + } returns stubLocalModels.toFlowable(BackpressureStrategy.LATEST) + + every { + stubBuildInfoProvider.type + } returns BuildType.PLAY + + every { + stubPreferenceManager.localModelId + } returns "" + + every { + stubFileProviderDescriptor.localModelDirPath + } returns "/tmp/local" + + val stubObserver = localDataSource + .observeAll() + .test() + + stubLocalModels.onNext(emptyList()) + + stubObserver + .assertNoErrors() + .assertValueAt(0, emptyList()) + + stubLocalModels.onNext(mockLocalModelEntities) + + stubObserver + .assertNoErrors() + .assertValueAt(1, mockLocalModelEntities.mapEntityToDomain()) + } + + @Test + fun `given attempt to observe all models, dao emits empty list, then list with two items, app build type is FOSS, expected list with only CUSTOM model included, then domain list with two items and CUSTOM`() { + every { + stubDao.observe() + } returns stubLocalModels.toFlowable(BackpressureStrategy.LATEST) + + every { + stubBuildInfoProvider.type + } returns BuildType.FOSS + + every { + stubPreferenceManager.localModelId + } returns "" + + every { + stubFileProviderDescriptor.localModelDirPath + } returns "/tmp/local" + + val stubObserver = localDataSource + .observeAll() + .test() + + stubLocalModels.onNext(emptyList()) + + stubObserver + .assertNoErrors() + .assertValueAt(0, listOf(LocalAiModel.CUSTOM.copy(downloaded = true))) + + stubLocalModels.onNext(mockLocalModelEntities) + + stubObserver + .assertNoErrors() + .assertValueAt(1, buildList { + addAll(mockLocalModelEntities.mapEntityToDomain()) + add(LocalAiModel.CUSTOM.copy(downloaded = true)) + }) + } + + @Test + fun `given attempt to observe all models, dao throws exception, expected error value`() { + every { + stubDao.observe() + } returns Flowable.error(stubException) + + localDataSource + .observeAll() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to select model, preference changed, expected preference returns changed selected model id value`() { + every { + stubPreferenceManager.localModelId + } returns "" + + every { + stubPreferenceManager::localModelId.set(any()) + } returns Unit + + localDataSource + .select("5598") + .test() + .assertNoErrors() + .await() + .assertComplete() + + every { + stubPreferenceManager.localModelId + } returns "5598" + + Assert.assertEquals("5598", stubPreferenceManager.localModelId) + } + + @Test + fun `given attempt to select model, preference throws exception, expected error value`() { + every { + stubPreferenceManager::localModelId.set(any()) + } throws stubException + + localDataSource + .select("5598") + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to save local model list, dao insert success, expected complete value`() { + every { + stubDao.insertList(any()) + } returns Completable.complete() + + localDataSource + .save(mockLocalAiModels) + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to save local model list, dao throws exception, expected error value`() { + every { + stubDao.insertList(any()) + } returns Completable.error(stubException) + + localDataSource + .save(mockLocalAiModels) + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + //-- + + @Test + fun `given attempt to delete file, delete operation success, expected complete value`() { + every { + stubFileProviderDescriptor.localModelDirPath + } returns "/tmp/local" + + localDataSource + .delete("5598") + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to delete file, delete operation failed, expected error value`() { + every { + stubFileProviderDescriptor.localModelDirPath + } throws stubException + + localDataSource + .delete("5598") + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to check if CUSTOM model is downloaded, expected true`() { + localDataSource + .isDownloaded(LocalAiModel.CUSTOM.id) + .test() + .assertNoErrors() + .assertValue(true) + .await() + .assertComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/local/GenerationResultLocalDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/local/GenerationResultLocalDataSourceTest.kt new file mode 100644 index 00000000..8d7b194d --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/local/GenerationResultLocalDataSourceTest.kt @@ -0,0 +1,224 @@ +package com.shifthackz.aisdv1.data.local + +import com.shifthackz.aisdv1.data.mocks.mockAiGenerationResult +import com.shifthackz.aisdv1.data.mocks.mockAiGenerationResults +import com.shifthackz.aisdv1.data.mocks.mockGenerationResultEntities +import com.shifthackz.aisdv1.data.mocks.mockGenerationResultEntity +import com.shifthackz.aisdv1.domain.entity.AiGenerationResult +import com.shifthackz.aisdv1.storage.db.persistent.dao.GenerationResultDao +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class GenerationResultLocalDataSourceTest { + + private val stubException = Throwable("Database error.") + private val stubDao = mockk() + + private val localDataSource = GenerationResultLocalDataSource(stubDao) + + @Test + fun `given attempt to insert ai generation result, operation successful, expected id of inserted result value`() { + every { + stubDao.insert(any()) + } returns Single.just(mockAiGenerationResult.id) + + localDataSource + .insert(mockAiGenerationResult) + .test() + .assertNoErrors() + .assertValue(mockAiGenerationResult.id) + .await() + .assertComplete() + } + + @Test + fun `given attempt to insert ai generation result, operation failed, expected error value`() { + every { + stubDao.insert(any()) + } returns Single.error(stubException) + + localDataSource + .insert(mockAiGenerationResult) + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to query ai generation results, dao returns list, expected valid domain model list value`() { + every { + stubDao.query() + } returns Single.just(mockGenerationResultEntities) + + localDataSource + .queryAll() + .test() + .assertNoErrors() + .assertValue(mockAiGenerationResults) + .await() + .assertComplete() + } + + @Test + fun `given attempt to query ai generation results, dao returns empty list, expected empty domain model list value`() { + every { + stubDao.query() + } returns Single.just(emptyList()) + + localDataSource + .queryAll() + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given attempt to query ai generation results, dao throws exception, expected error value`() { + every { + stubDao.query() + } returns Single.error(stubException) + + localDataSource + .queryAll() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given two attempts to query by page, dao has only one page, expected valid page value, then empty page value`() { + every { + stubDao.queryPage(20, 0) + } returns Single.just((0 until 20).map { mockGenerationResultEntity }) + + every { + stubDao.queryPage(20, 1) + } returns Single.just(emptyList()) + + localDataSource + .queryPage(20, 0) + .test() + .assertNoErrors() + .assertValue { actual -> actual is List && actual.size == 20 } + .await() + .assertComplete() + + localDataSource + .queryPage(20, 1) + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given attempt to query by page, dao throws error, expected error value`() { + every { + stubDao.queryPage(any(), any()) + } returns Single.error(stubException) + + localDataSource + .queryPage(20, 0) + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to query by id, dao returns item, expected valid domain model value`() { + every { + stubDao.queryById(5598L) + } returns Single.just(mockGenerationResultEntity) + + localDataSource + .queryById(5598L) + .test() + .assertNoErrors() + .assertValue(mockAiGenerationResult) + .await() + .assertComplete() + } + + @Test + fun `given attempt to query by id, dao throws exception, expected error value`() { + every { + stubDao.queryById(5598L) + } returns Single.error(stubException) + + localDataSource + .queryById(5598L) + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to delete by id, dao deleted successfully, expected complete value`() { + every { + stubDao.deleteById(any()) + } returns Completable.complete() + + localDataSource + .deleteById(5598L) + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to delete by id, dao delete failure, expected error value`() { + every { + stubDao.deleteById(any()) + } returns Completable.error(stubException) + + localDataSource + .deleteById(5598L) + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to delete all, dao deleted successfully, expected complete value`() { + every { + stubDao.deleteAll() + } returns Completable.complete() + + localDataSource + .deleteAll() + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to delete all, dao delete failure, expected error value`() { + every { + stubDao.deleteAll() + } returns Completable.error(stubException) + + localDataSource + .deleteAll() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/local/HuggingFaceModelsLocalDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/local/HuggingFaceModelsLocalDataSourceTest.kt new file mode 100644 index 00000000..5248a1de --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/local/HuggingFaceModelsLocalDataSourceTest.kt @@ -0,0 +1,85 @@ +package com.shifthackz.aisdv1.data.local + +import com.shifthackz.aisdv1.data.mappers.mapEntityToDomain +import com.shifthackz.aisdv1.data.mocks.mockHuggingFaceModelEntities +import com.shifthackz.aisdv1.data.mocks.mockHuggingFaceModels +import com.shifthackz.aisdv1.storage.db.persistent.dao.HuggingFaceModelDao +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class HuggingFaceModelsLocalDataSourceTest { + + private val stubException = Throwable("Database error.") + private val stubDao = mockk() + + private val localDataSource = HuggingFaceModelsLocalDataSource(stubDao) + + @Test + fun `given attempt to get all, dao returns list, expected valid domain model list value`() { + every { + stubDao.query() + } returns Single.just(mockHuggingFaceModelEntities) + + localDataSource + .getAll() + .test() + .assertNoErrors() + .assertValue(mockHuggingFaceModelEntities.mapEntityToDomain()) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get all, dao throws exception, expected error value`() { + every { + stubDao.query() + } returns Single.error(stubException) + + localDataSource + .getAll() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to insert list, dao insert success, expected complete value`() { + every { + stubDao.deleteAll() + } returns Completable.complete() + + every { + stubDao.insertList(any()) + } returns Completable.complete() + + localDataSource + .save(mockHuggingFaceModels) + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to insert list, dao throws exception, expected error value`() { + every { + stubDao.deleteAll() + } returns Completable.complete() + + every { + stubDao.insertList(any()) + } returns Completable.error(stubException) + + localDataSource + .save(mockHuggingFaceModels) + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/local/ServerConfigurationLocalDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/local/ServerConfigurationLocalDataSourceTest.kt new file mode 100644 index 00000000..8edcd977 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/local/ServerConfigurationLocalDataSourceTest.kt @@ -0,0 +1,76 @@ +package com.shifthackz.aisdv1.data.local + +import com.shifthackz.aisdv1.data.mocks.mockServerConfiguration +import com.shifthackz.aisdv1.data.mocks.mockServerConfigurationEntity +import com.shifthackz.aisdv1.storage.db.cache.dao.ServerConfigurationDao +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class ServerConfigurationLocalDataSourceTest { + + private val stubException = Throwable("Database error.") + private val stubDao = mockk() + + private val localDataSource = ServerConfigurationLocalDataSource(stubDao) + + @Test + fun `given attempt to save server configuration, dao insert success, expected complete value`() { + every { + stubDao.insert(any()) + } returns Completable.complete() + + localDataSource + .save(mockServerConfiguration) + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to save server configuration, dao insert failed, expected error value`() { + every { + stubDao.insert(any()) + } returns Completable.error(stubException) + + localDataSource + .save(mockServerConfiguration) + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to get server configuration, dao returned record, expected valid domain value`() { + every { + stubDao.query() + } returns Single.just(mockServerConfigurationEntity) + + localDataSource + .get() + .test() + .assertNoErrors() + .assertValue(mockServerConfiguration) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get server configuration, dao throws exception, expected error value`() { + every { + stubDao.query() + } returns Single.error(stubException) + + localDataSource + .get() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/local/StabilityAiCreditsLocalDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/local/StabilityAiCreditsLocalDataSourceTest.kt new file mode 100644 index 00000000..b703cb98 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/local/StabilityAiCreditsLocalDataSourceTest.kt @@ -0,0 +1,57 @@ +package com.shifthackz.aisdv1.data.local + +import io.reactivex.rxjava3.subjects.BehaviorSubject +import org.junit.Test + +class StabilityAiCreditsLocalDataSourceTest { + + private val stubSubject = BehaviorSubject.createDefault(0f) + + private val localDataSource = StabilityAiCreditsLocalDataSource(stubSubject) + + @Test + fun `given attempt to get, then save and get, expected default value, save complete, then saved value`() { + localDataSource + .get() + .test() + .assertNoErrors() + .assertValue(0f) + .await() + .assertComplete() + + localDataSource + .save(5598f) + .test() + .assertNoErrors() + .await() + .assertComplete() + + localDataSource + .get() + .test() + .assertNoErrors() + .assertValue(5598f) + .await() + .assertComplete() + } + + @Test + fun `given attempt to observe changes, value changed from default to another, expected default value, save complete, then saved value`() { + val stubObserver = localDataSource.observe().test() + + stubObserver + .assertNoErrors() + .assertValueAt(0, 0f) + + localDataSource + .save(5598f) + .test() + .assertNoErrors() + .await() + .assertComplete() + + stubObserver + .assertNoErrors() + .assertValueAt(1, 5598f) + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/local/StableDiffusionEmbeddingsLocalDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/local/StableDiffusionEmbeddingsLocalDataSourceTest.kt new file mode 100644 index 00000000..09cc3b27 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/local/StableDiffusionEmbeddingsLocalDataSourceTest.kt @@ -0,0 +1,117 @@ +package com.shifthackz.aisdv1.data.local + +import com.shifthackz.aisdv1.data.mocks.mockStableDiffusionEmbeddingEntities +import com.shifthackz.aisdv1.data.mocks.mockStableDiffusionEmbeddings +import com.shifthackz.aisdv1.storage.db.cache.dao.StableDiffusionEmbeddingDao +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class StableDiffusionEmbeddingsLocalDataSourceTest { + + private val stubException = Throwable("Database error.") + private val stubDao = mockk() + + private val localDataSource = StableDiffusionEmbeddingsLocalDataSource(stubDao) + + @Test + fun `given attempt to get embeddings, dao returns list, expected valid domain model list value`() { + every { + stubDao.queryAll() + } returns Single.just(mockStableDiffusionEmbeddingEntities) + + localDataSource + .getEmbeddings() + .test() + .assertNoErrors() + .assertValue(mockStableDiffusionEmbeddings) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get embeddings, dao returns empty list, expected empty domain model list value`() { + every { + stubDao.queryAll() + } returns Single.just(emptyList()) + + localDataSource + .getEmbeddings() + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get embeddings, dao throws exception, expected error value`() { + every { + stubDao.queryAll() + } returns Single.error(stubException) + + localDataSource + .getEmbeddings() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to insert embeddings, dao replaces list, expected complete value`() { + every { + stubDao.deleteAll() + } returns Completable.complete() + + every { + stubDao.insertList(any()) + } returns Completable.complete() + + localDataSource + .insertEmbeddings(mockStableDiffusionEmbeddings) + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to insert embeddings, dao throws exception during delete, expected error value`() { + every { + stubDao.deleteAll() + } returns Completable.error(stubException) + + every { + stubDao.insertList(any()) + } returns Completable.complete() + + localDataSource + .insertEmbeddings(mockStableDiffusionEmbeddings) + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to insert embeddings, dao throws exception during insertion, expected error value`() { + every { + stubDao.deleteAll() + } returns Completable.complete() + + every { + stubDao.insertList(any()) + } returns Completable.error(stubException) + + localDataSource + .insertEmbeddings(mockStableDiffusionEmbeddings) + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/local/StableDiffusionHyperNetworksLocalDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/local/StableDiffusionHyperNetworksLocalDataSourceTest.kt new file mode 100644 index 00000000..be9daa46 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/local/StableDiffusionHyperNetworksLocalDataSourceTest.kt @@ -0,0 +1,117 @@ +package com.shifthackz.aisdv1.data.local + +import com.shifthackz.aisdv1.data.mocks.mockStableDiffusionHyperNetworkEntities +import com.shifthackz.aisdv1.data.mocks.mockStableDiffusionHyperNetworks +import com.shifthackz.aisdv1.storage.db.cache.dao.StableDiffusionHyperNetworkDao +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class StableDiffusionHyperNetworksLocalDataSourceTest { + + private val stubException = Throwable("Database error.") + private val stubDao = mockk() + + private val localDataSource = StableDiffusionHyperNetworksLocalDataSource(stubDao) + + @Test + fun `given attempt to get hypernetworks, dao returns list, expected valid domain model list value`() { + every { + stubDao.queryAll() + } returns Single.just(mockStableDiffusionHyperNetworkEntities) + + localDataSource + .getHyperNetworks() + .test() + .assertNoErrors() + .assertValue(mockStableDiffusionHyperNetworks) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get hypernetworks, dao returns empty list, expected empty domain model list value`() { + every { + stubDao.queryAll() + } returns Single.just(emptyList()) + + localDataSource + .getHyperNetworks() + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get hypernetworks, dao throws exception, expected error value`() { + every { + stubDao.queryAll() + } returns Single.error(stubException) + + localDataSource + .getHyperNetworks() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to insert hypernetworks, dao replaces list, expected complete value`() { + every { + stubDao.deleteAll() + } returns Completable.complete() + + every { + stubDao.insertList(any()) + } returns Completable.complete() + + localDataSource + .insertHyperNetworks(mockStableDiffusionHyperNetworks) + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to insert hypernetworks, dao throws exception during delete, expected error value`() { + every { + stubDao.deleteAll() + } returns Completable.error(stubException) + + every { + stubDao.insertList(any()) + } returns Completable.complete() + + localDataSource + .insertHyperNetworks(mockStableDiffusionHyperNetworks) + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to insert hypernetworks, dao throws exception during insertion, expected error value`() { + every { + stubDao.deleteAll() + } returns Completable.complete() + + every { + stubDao.insertList(any()) + } returns Completable.error(stubException) + + localDataSource + .insertHyperNetworks(mockStableDiffusionHyperNetworks) + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/local/StableDiffusionLorasLocalDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/local/StableDiffusionLorasLocalDataSourceTest.kt new file mode 100644 index 00000000..6a7baaa3 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/local/StableDiffusionLorasLocalDataSourceTest.kt @@ -0,0 +1,117 @@ +package com.shifthackz.aisdv1.data.local + +import com.shifthackz.aisdv1.data.mocks.mockStableDiffusionLoraEntities +import com.shifthackz.aisdv1.data.mocks.mockStableDiffusionLoras +import com.shifthackz.aisdv1.storage.db.cache.dao.StableDiffusionLoraDao +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class StableDiffusionLorasLocalDataSourceTest { + + private val stubException = Throwable("Database error.") + private val stubDao = mockk() + + private val localDataSource = StableDiffusionLorasLocalDataSource(stubDao) + + @Test + fun `given attempt to get loras, dao returns list, expected valid domain model list value`() { + every { + stubDao.queryAll() + } returns Single.just(mockStableDiffusionLoraEntities) + + localDataSource + .getLoras() + .test() + .assertNoErrors() + .assertValue(mockStableDiffusionLoras) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get loras, dao returns empty list, expected empty domain model list value`() { + every { + stubDao.queryAll() + } returns Single.just(emptyList()) + + localDataSource + .getLoras() + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get loras, dao throws exception, expected error value`() { + every { + stubDao.queryAll() + } returns Single.error(stubException) + + localDataSource + .getLoras() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to insert loras, dao replaces list, expected complete value`() { + every { + stubDao.deleteAll() + } returns Completable.complete() + + every { + stubDao.insertList(any()) + } returns Completable.complete() + + localDataSource + .insertLoras(mockStableDiffusionLoras) + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to insert loras, dao throws exception during delete, expected error value`() { + every { + stubDao.deleteAll() + } returns Completable.error(stubException) + + every { + stubDao.insertList(any()) + } returns Completable.complete() + + localDataSource + .insertLoras(mockStableDiffusionLoras) + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to insert loras, dao throws exception during insertion, expected error value`() { + every { + stubDao.deleteAll() + } returns Completable.complete() + + every { + stubDao.insertList(any()) + } returns Completable.error(stubException) + + localDataSource + .insertLoras(mockStableDiffusionLoras) + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/local/StableDiffusionModelsLocalDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/local/StableDiffusionModelsLocalDataSourceTest.kt new file mode 100644 index 00000000..c51db3fc --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/local/StableDiffusionModelsLocalDataSourceTest.kt @@ -0,0 +1,117 @@ +package com.shifthackz.aisdv1.data.local + +import com.shifthackz.aisdv1.data.mocks.mockStableDiffusionModelEntities +import com.shifthackz.aisdv1.data.mocks.mockStableDiffusionModels +import com.shifthackz.aisdv1.storage.db.cache.dao.StableDiffusionModelDao +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class StableDiffusionModelsLocalDataSourceTest { + + private val stubException = Throwable("Database error.") + private val stubDao = mockk() + + private val localDataSource = StableDiffusionModelsLocalDataSource(stubDao) + + @Test + fun `given attempt to get models, dao returns list, expected valid domain model list value`() { + every { + stubDao.queryAll() + } returns Single.just(mockStableDiffusionModelEntities) + + localDataSource + .getModels() + .test() + .assertNoErrors() + .assertValue(mockStableDiffusionModels) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get models, dao returns empty list, expected empty domain model list value`() { + every { + stubDao.queryAll() + } returns Single.just(emptyList()) + + localDataSource + .getModels() + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get models, dao throws exception, expected error value`() { + every { + stubDao.queryAll() + } returns Single.error(stubException) + + localDataSource + .getModels() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to insert models, dao replaces list, expected complete value`() { + every { + stubDao.deleteAll() + } returns Completable.complete() + + every { + stubDao.insertList(any()) + } returns Completable.complete() + + localDataSource + .insertModels(mockStableDiffusionModels) + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to insert models, dao throws exception during delete, expected error value`() { + every { + stubDao.deleteAll() + } returns Completable.error(stubException) + + every { + stubDao.insertList(any()) + } returns Completable.complete() + + localDataSource + .insertModels(mockStableDiffusionModels) + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to insert models, dao throws exception during insertion, expected error value`() { + every { + stubDao.deleteAll() + } returns Completable.complete() + + every { + stubDao.insertList(any()) + } returns Completable.error(stubException) + + localDataSource + .insertModels(mockStableDiffusionModels) + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/local/StableDiffusionSamplersLocalDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/local/StableDiffusionSamplersLocalDataSourceTest.kt new file mode 100644 index 00000000..b7376d68 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/local/StableDiffusionSamplersLocalDataSourceTest.kt @@ -0,0 +1,117 @@ +package com.shifthackz.aisdv1.data.local + +import com.shifthackz.aisdv1.data.mocks.mockStableDiffusionSamplerEntities +import com.shifthackz.aisdv1.data.mocks.mockStableDiffusionSamplers +import com.shifthackz.aisdv1.storage.db.cache.dao.StableDiffusionSamplerDao +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class StableDiffusionSamplersLocalDataSourceTest { + + private val stubException = Throwable("Database error.") + private val stubDao = mockk() + + private val localDataSource = StableDiffusionSamplersLocalDataSource(stubDao) + + @Test + fun `given attempt to get samplers, dao returns list, expected valid domain model list value`() { + every { + stubDao.queryAll() + } returns Single.just(mockStableDiffusionSamplerEntities) + + localDataSource + .getSamplers() + .test() + .assertNoErrors() + .assertValue(mockStableDiffusionSamplers) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get samplers, dao returns empty list, expected empty domain model list value`() { + every { + stubDao.queryAll() + } returns Single.just(emptyList()) + + localDataSource + .getSamplers() + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get samplers, dao throws exception, expected error value`() { + every { + stubDao.queryAll() + } returns Single.error(stubException) + + localDataSource + .getSamplers() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to insert samplers, dao replaces list, expected complete value`() { + every { + stubDao.deleteAll() + } returns Completable.complete() + + every { + stubDao.insertList(any()) + } returns Completable.complete() + + localDataSource + .insertSamplers(mockStableDiffusionSamplers) + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to insert samplers, dao throws exception during delete, expected error value`() { + every { + stubDao.deleteAll() + } returns Completable.error(stubException) + + every { + stubDao.insertList(any()) + } returns Completable.complete() + + localDataSource + .insertSamplers(mockStableDiffusionSamplers) + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to insert samplers, dao throws exception during insertion, expected error value`() { + every { + stubDao.deleteAll() + } returns Completable.complete() + + every { + stubDao.insertList(any()) + } returns Completable.error(stubException) + + localDataSource + .insertSamplers(mockStableDiffusionSamplers) + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/AiGenerationResultMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/AiGenerationResultMocks.kt new file mode 100644 index 00000000..354198f9 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/AiGenerationResultMocks.kt @@ -0,0 +1,26 @@ +package com.shifthackz.aisdv1.data.mocks + +import com.shifthackz.aisdv1.domain.entity.AiGenerationResult +import java.util.Date + +val mockAiGenerationResult = AiGenerationResult( + id = 5598L, + image = "img", + inputImage = "inp", + createdAt = Date(0), + type = AiGenerationResult.Type.IMAGE_TO_IMAGE, + prompt = "prompt", + negativePrompt = "negative", + width = 512, + height = 512, + samplingSteps = 7, + cfgScale = 0.7f, + restoreFaces = true, + sampler = "sampler", + seed = "5598", + subSeed = "1504", + subSeedStrength = 5598f, + denoisingStrength = 1504f, +) + +val mockAiGenerationResults = listOf(mockAiGenerationResult) diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/DownloadableModelResponseMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/DownloadableModelResponseMocks.kt new file mode 100644 index 00000000..deb6dd4d --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/DownloadableModelResponseMocks.kt @@ -0,0 +1,12 @@ +package com.shifthackz.aisdv1.data.mocks + +import com.shifthackz.aisdv1.network.response.DownloadableModelResponse + +val mockDownloadableModelsResponse = listOf( + DownloadableModelResponse( + id = "1", + name = "Model 1", + size = "5 Gb", + sources = listOf("https://example.com/1.html"), + ) +) diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/GenerationResultEntityMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/GenerationResultEntityMocks.kt new file mode 100644 index 00000000..e1ad84b4 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/GenerationResultEntityMocks.kt @@ -0,0 +1,27 @@ +package com.shifthackz.aisdv1.data.mocks + +import com.shifthackz.aisdv1.domain.entity.AiGenerationResult +import com.shifthackz.aisdv1.storage.db.persistent.entity.GenerationResultEntity +import java.util.Date + +val mockGenerationResultEntity = GenerationResultEntity( + id = 5598L, + imageBase64 = "img", + originalImageBase64 = "inp", + createdAt = Date(0), + generationType = AiGenerationResult.Type.IMAGE_TO_IMAGE.key, + prompt = "prompt", + negativePrompt = "negative", + width = 512, + height = 512, + samplingSteps = 7, + cfgScale = 0.7f, + restoreFaces = true, + sampler = "sampler", + seed = "5598", + subSeed = "1504", + subSeedStrength = 5598f, + denoisingStrength = 1504f, +) + +val mockGenerationResultEntities = listOf(mockGenerationResultEntity) diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/HuggingFaceModelEntityMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/HuggingFaceModelEntityMocks.kt new file mode 100644 index 00000000..25922897 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/HuggingFaceModelEntityMocks.kt @@ -0,0 +1,14 @@ +package com.shifthackz.aisdv1.data.mocks + +import com.shifthackz.aisdv1.storage.db.persistent.entity.HuggingFaceModelEntity + +val mockHuggingFaceModelEntity = HuggingFaceModelEntity( + id = "050598", + name = "Super model", + alias = "❤", + source = "https://life.archive.org/models/unique/050598", +) + +val mockHuggingFaceModelEntities = listOf( + mockHuggingFaceModelEntity, +) diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/HuggingFaceModelMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/HuggingFaceModelMocks.kt new file mode 100644 index 00000000..21ad6d07 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/HuggingFaceModelMocks.kt @@ -0,0 +1,12 @@ +package com.shifthackz.aisdv1.data.mocks + +import com.shifthackz.aisdv1.domain.entity.HuggingFaceModel + +val mockHuggingFaceModel = HuggingFaceModel( + id = "050598", + name = "Super model", + alias = "❤", + source = "https://life.archive.org/models/unique/050598", +) + +val mockHuggingFaceModels = listOf(mockHuggingFaceModel) diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/HuggingFaceModelRawMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/HuggingFaceModelRawMocks.kt new file mode 100644 index 00000000..efb83fcf --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/HuggingFaceModelRawMocks.kt @@ -0,0 +1,18 @@ +package com.shifthackz.aisdv1.data.mocks + +import com.shifthackz.aisdv1.network.model.HuggingFaceModelRaw + +val mockHuggingFaceModelsRaw = listOf( + HuggingFaceModelRaw( + id = "151297", + name = "Not so super model", + alias = "❤", + source = "https://life.archive.org/models/unique/151297", + ), + HuggingFaceModelRaw( + id = "050598", + name = "Super model", + alias = "❤", + source = "https://life.archive.org/models/unique/050598", + ), +) diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/ImageToImagePayloadMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/ImageToImagePayloadMocks.kt new file mode 100644 index 00000000..3138ce97 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/ImageToImagePayloadMocks.kt @@ -0,0 +1,29 @@ +package com.shifthackz.aisdv1.data.mocks + +import com.shifthackz.aisdv1.domain.entity.ImageToImagePayload + +val mockImageToImagePayload = ImageToImagePayload( + base64Image = "", + base64MaskImage = "", + denoisingStrength = 7f, + prompt = "prompt", + negativePrompt = "negative", + samplingSteps = 12, + cfgScale = 0.7f, + width = 512, + height = 512, + restoreFaces = true, + seed = "5598", + subSeed = "1504", + subSeedStrength = 5598f, + sampler = "sampler", + nsfw = true, + batchCount = 1, + inPaintingMaskInvert = 0, + inPaintFullResPadding = 0, + inPaintingFill = 0, + inPaintFullRes = false, + maskBlur = 0, + stabilityAiClipGuidance = null, + stabilityAiStylePreset = null, +) diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/LocalAiModelMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/LocalAiModelMocks.kt new file mode 100644 index 00000000..6eee82bc --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/LocalAiModelMocks.kt @@ -0,0 +1,14 @@ +package com.shifthackz.aisdv1.data.mocks + +import com.shifthackz.aisdv1.domain.entity.LocalAiModel + +val mockLocalAiModel = LocalAiModel( + id = "5598", + name = "Model 5598", + size = "5 Gb", + sources = listOf("https://example.com/1.html"), + downloaded = false, + selected = false, +) + +val mockLocalAiModels = listOf(mockLocalAiModel) diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/LocalModelEntityMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/LocalModelEntityMocks.kt new file mode 100644 index 00000000..33b45cd9 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/LocalModelEntityMocks.kt @@ -0,0 +1,20 @@ +package com.shifthackz.aisdv1.data.mocks + +import com.shifthackz.aisdv1.storage.db.persistent.entity.LocalModelEntity + +val mockLocalModelEntity = LocalModelEntity( + id = "5598", + name = "Best model in entire universe", + size = "5598 Gb", + sources = listOf("https://5598.is.my.favourite.com"), +) + +val mockLocalModelEntities = listOf( + LocalModelEntity( + id = "1", + name = "Model 1", + size = "1 Gb", + sources = listOf("https://example.com/1.php"), + ), + mockLocalModelEntity, +) diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/OpenAiResponseMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/OpenAiResponseMocks.kt new file mode 100644 index 00000000..adb54d03 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/OpenAiResponseMocks.kt @@ -0,0 +1,20 @@ +package com.shifthackz.aisdv1.data.mocks + +import com.shifthackz.aisdv1.network.model.OpenAiImageRaw +import com.shifthackz.aisdv1.network.response.OpenAiResponse + +val mockSuccessOpenAiResponse = OpenAiResponse( + created = System.currentTimeMillis(), + data = listOf( + OpenAiImageRaw( + "base64", + "https://openai.com", + "prompt", + ), + ), +) + +val mockBadOpenAiResponse = OpenAiResponse( + created = System.currentTimeMillis(), + data = emptyList(), +) diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/SdEmbeddingsResponseMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/SdEmbeddingsResponseMocks.kt new file mode 100644 index 00000000..d2250d17 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/SdEmbeddingsResponseMocks.kt @@ -0,0 +1,11 @@ +package com.shifthackz.aisdv1.data.mocks + +import com.shifthackz.aisdv1.network.response.SdEmbeddingsResponse + +val mockSdEmbeddingsResponse = SdEmbeddingsResponse( + loaded = mapOf("1504" to "5598"), +) + +val mockEmptySdEmbeddingsResponse = SdEmbeddingsResponse( + loaded = emptyMap(), +) diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/SdGenerationResponseMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/SdGenerationResponseMocks.kt new file mode 100644 index 00000000..fa6aec9c --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/SdGenerationResponseMocks.kt @@ -0,0 +1,8 @@ +package com.shifthackz.aisdv1.data.mocks + +import com.shifthackz.aisdv1.network.response.SdGenerationResponse + +val mockSdGenerationResponse = SdGenerationResponse( + images = listOf("base64"), + info = "info", +) diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/ServerConfigurationEntityMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/ServerConfigurationEntityMocks.kt new file mode 100644 index 00000000..a649dc5b --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/ServerConfigurationEntityMocks.kt @@ -0,0 +1,8 @@ +package com.shifthackz.aisdv1.data.mocks + +import com.shifthackz.aisdv1.storage.db.cache.entity.ServerConfigurationEntity + +val mockServerConfigurationEntity = ServerConfigurationEntity( + serverId = "5598", + sdModelCheckpoint = "checkpoint", +) diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/ServerConfigurationMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/ServerConfigurationMocks.kt new file mode 100644 index 00000000..f6a0fb6f --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/ServerConfigurationMocks.kt @@ -0,0 +1,5 @@ +package com.shifthackz.aisdv1.data.mocks + +import com.shifthackz.aisdv1.domain.entity.ServerConfiguration + +val mockServerConfiguration = ServerConfiguration("checkpoint") diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/ServerConfigurationRawMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/ServerConfigurationRawMocks.kt new file mode 100644 index 00000000..885f7969 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/ServerConfigurationRawMocks.kt @@ -0,0 +1,5 @@ +package com.shifthackz.aisdv1.data.mocks + +import com.shifthackz.aisdv1.network.model.ServerConfigurationRaw + +val mockServerConfigurationRaw = ServerConfigurationRaw("5598") diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StabilityAiEngineMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StabilityAiEngineMocks.kt new file mode 100644 index 00000000..1616763f --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StabilityAiEngineMocks.kt @@ -0,0 +1,10 @@ +package com.shifthackz.aisdv1.data.mocks + +import com.shifthackz.aisdv1.domain.entity.StabilityAiEngine + +val mockStabilityAiEngines = listOf( + StabilityAiEngine( + id = "5598", + name = "engine_5598", + ), +) diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StabilityAiEngineRawMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StabilityAiEngineRawMocks.kt new file mode 100644 index 00000000..9dcd0305 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StabilityAiEngineRawMocks.kt @@ -0,0 +1,12 @@ +package com.shifthackz.aisdv1.data.mocks + +import com.shifthackz.aisdv1.network.model.StabilityAiEngineRaw + +val mockStabilityAiEnginesRaw = listOf( + StabilityAiEngineRaw( + description = "❤", + id = "5598", + name = "Super engine", + type = "unique", + ), +) diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StabilityGenerationResponseMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StabilityGenerationResponseMocks.kt new file mode 100644 index 00000000..5199d519 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StabilityGenerationResponseMocks.kt @@ -0,0 +1,15 @@ +package com.shifthackz.aisdv1.data.mocks + +import com.shifthackz.aisdv1.network.response.StabilityGenerationResponse + +val mockStabilityGenerationResponse = StabilityGenerationResponse( + artifacts = listOf( + StabilityGenerationResponse.Artifact( + base64 = "base64", + finishReason = "reasonable reason", + seed = 5598L, + ), + ), +) + +val mockBadStabilityGenerationResponse = StabilityGenerationResponse(emptyList()) diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionEmbeddingEntityMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionEmbeddingEntityMocks.kt new file mode 100644 index 00000000..90405489 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionEmbeddingEntityMocks.kt @@ -0,0 +1,14 @@ +package com.shifthackz.aisdv1.data.mocks + +import com.shifthackz.aisdv1.storage.db.cache.entity.StableDiffusionEmbeddingEntity + +val mockStableDiffusionEmbeddingEntities = listOf( + StableDiffusionEmbeddingEntity( + id = "5598", + keyword = "keyword_5598", + ), + StableDiffusionEmbeddingEntity( + id = "151297", + keyword = "keyword_151297", + ), +) diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionEmbeddingMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionEmbeddingMocks.kt new file mode 100644 index 00000000..a7286a4e --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionEmbeddingMocks.kt @@ -0,0 +1,8 @@ +package com.shifthackz.aisdv1.data.mocks + +import com.shifthackz.aisdv1.domain.entity.StableDiffusionEmbedding + +val mockStableDiffusionEmbeddings = listOf( + StableDiffusionEmbedding("keyword_5598"), + StableDiffusionEmbedding("keyword_151297"), +) diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionHyperNetworkEntityMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionHyperNetworkEntityMocks.kt new file mode 100644 index 00000000..9d618111 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionHyperNetworkEntityMocks.kt @@ -0,0 +1,16 @@ +package com.shifthackz.aisdv1.data.mocks + +import com.shifthackz.aisdv1.storage.db.cache.entity.StableDiffusionHyperNetworkEntity + +val mockStableDiffusionHyperNetworkEntities = listOf( + StableDiffusionHyperNetworkEntity( + id = "5598", + name = "net_5598", + path = "/unknown", + ), + StableDiffusionHyperNetworkEntity( + id = "151297", + name = "net_151297", + path = "/unknown", + ), +) diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionHyperNetworkMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionHyperNetworkMocks.kt new file mode 100644 index 00000000..6e4c5931 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionHyperNetworkMocks.kt @@ -0,0 +1,14 @@ +package com.shifthackz.aisdv1.data.mocks + +import com.shifthackz.aisdv1.domain.entity.StableDiffusionHyperNetwork + +val mockStableDiffusionHyperNetworks = listOf( + StableDiffusionHyperNetwork( + name = "net_5598", + path = "/unknown", + ), + StableDiffusionHyperNetwork( + name = "net_151297", + path = "/unknown", + ) +) diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionHyperNetworkRawMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionHyperNetworkRawMocks.kt new file mode 100644 index 00000000..d3de9fb2 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionHyperNetworkRawMocks.kt @@ -0,0 +1,10 @@ +package com.shifthackz.aisdv1.data.mocks + +import com.shifthackz.aisdv1.network.model.StableDiffusionHyperNetworkRaw + +val mockStableDiffusionHyperNetworkRaw = listOf( + StableDiffusionHyperNetworkRaw( + name = "5598", + path = "Unknown", + ), +) diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionLoraEntityMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionLoraEntityMocks.kt new file mode 100644 index 00000000..d0ed9985 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionLoraEntityMocks.kt @@ -0,0 +1,18 @@ +package com.shifthackz.aisdv1.data.mocks + +import com.shifthackz.aisdv1.storage.db.cache.entity.StableDiffusionLoraEntity + +val mockStableDiffusionLoraEntities = listOf( + StableDiffusionLoraEntity( + id = "5598", + name = "name_5598", + alias = "alias_5598", + path = "/unknown", + ), + StableDiffusionLoraEntity( + id = "151297", + name = "name_151297", + alias = "alias_151297", + path = "/unknown", + ), +) diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionLoraMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionLoraMocks.kt new file mode 100644 index 00000000..07b40191 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionLoraMocks.kt @@ -0,0 +1,16 @@ +package com.shifthackz.aisdv1.data.mocks + +import com.shifthackz.aisdv1.domain.entity.StableDiffusionLora + +val mockStableDiffusionLoras = listOf( + StableDiffusionLora( + name = "name_5598", + alias = "alias_5598", + path = "/unknown", + ), + StableDiffusionLora( + name = "name_151297", + alias = "alias_151297", + path = "/unknown", + ), +) diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionLoraRawMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionLoraRawMocks.kt new file mode 100644 index 00000000..0f2e8bea --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionLoraRawMocks.kt @@ -0,0 +1,11 @@ +package com.shifthackz.aisdv1.data.mocks + +import com.shifthackz.aisdv1.network.model.StableDiffusionLoraRaw + +val mockStableDiffusionLoraRaw = listOf( + StableDiffusionLoraRaw( + name = "Super lora", + alias = "5598", + path = "Unknown", + ), +) diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionModelEntityMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionModelEntityMocks.kt new file mode 100644 index 00000000..51e9496c --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionModelEntityMocks.kt @@ -0,0 +1,24 @@ +package com.shifthackz.aisdv1.data.mocks + +import com.shifthackz.aisdv1.storage.db.cache.entity.StableDiffusionModelEntity + +val mockStableDiffusionModelEntities = listOf( + StableDiffusionModelEntity( + id = "5598", + title = "title_5598", + name = "name_5598", + hash = "hash_5598", + sha256 = "sha_5598", + filename = "file_5598", + config = "config_5598", + ), + StableDiffusionModelEntity( + id = "151297", + title = "title_151297", + name = "name_151297", + hash = "hash_151297", + sha256 = "sha_151297", + filename = "file_151297", + config = "config_151297", + ), +) diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionModelMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionModelMocks.kt new file mode 100644 index 00000000..25222e86 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionModelMocks.kt @@ -0,0 +1,22 @@ +package com.shifthackz.aisdv1.data.mocks + +import com.shifthackz.aisdv1.domain.entity.StableDiffusionModel + +val mockStableDiffusionModels = listOf( + StableDiffusionModel( + title = "title_5598", + modelName = "name_5598", + hash = "hash_5598", + sha256 = "sha_5598", + filename = "file_5598", + config = "config_5598", + ), + StableDiffusionModel( + title = "title_151297", + modelName = "name_151297", + hash = "hash_151297", + sha256 = "sha_151297", + filename = "file_151297", + config = "config_151297", + ), +) diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionModelRawMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionModelRawMocks.kt new file mode 100644 index 00000000..c9332886 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionModelRawMocks.kt @@ -0,0 +1,14 @@ +package com.shifthackz.aisdv1.data.mocks + +import com.shifthackz.aisdv1.network.model.StableDiffusionModelRaw + +val mockStableDiffusionModelRaw = listOf( + StableDiffusionModelRaw( + title = "5598", + modelName = "5598", + hash = "hash5598", + sha256 = "sha5598", + filename = "Unknown", + config = "Unconfigurable", + ), +) diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionSamplerEntityMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionSamplerEntityMocks.kt new file mode 100644 index 00000000..860e7283 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionSamplerEntityMocks.kt @@ -0,0 +1,18 @@ +package com.shifthackz.aisdv1.data.mocks + +import com.shifthackz.aisdv1.storage.db.cache.entity.StableDiffusionSamplerEntity + +val mockStableDiffusionSamplerEntities = listOf( + StableDiffusionSamplerEntity( + id = "5598", + name = "name_5598", + aliases = emptyList(), + options = emptyMap(), + ), + StableDiffusionSamplerEntity( + id = "151297", + name = "name_151297", + aliases = emptyList(), + options = emptyMap(), + ), +) diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionSamplerMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionSamplerMocks.kt new file mode 100644 index 00000000..f8640a7a --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionSamplerMocks.kt @@ -0,0 +1,16 @@ +package com.shifthackz.aisdv1.data.mocks + +import com.shifthackz.aisdv1.domain.entity.StableDiffusionSampler + +val mockStableDiffusionSamplers = listOf( + StableDiffusionSampler( + name = "name_5598", + aliases = emptyList(), + options = emptyMap(), + ), + StableDiffusionSampler( + name = "name_151297", + aliases = emptyList(), + options = emptyMap(), + ), +) diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionSamplerRawMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionSamplerRawMocks.kt new file mode 100644 index 00000000..f92d5abc --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/StableDiffusionSamplerRawMocks.kt @@ -0,0 +1,11 @@ +package com.shifthackz.aisdv1.data.mocks + +import com.shifthackz.aisdv1.network.model.StableDiffusionSamplerRaw + +val mockStableDiffusionSamplerRaw = listOf( + StableDiffusionSamplerRaw( + "5598", + listOf(), + mapOf(), + ) +) diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/TextToImagePayloadMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/TextToImagePayloadMocks.kt new file mode 100644 index 00000000..ea38da27 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/TextToImagePayloadMocks.kt @@ -0,0 +1,24 @@ +package com.shifthackz.aisdv1.data.mocks + +import com.shifthackz.aisdv1.domain.entity.TextToImagePayload + +val mockTextToImagePayload = TextToImagePayload( + prompt = "prompt", + negativePrompt = "negative", + samplingSteps = 12, + cfgScale = 0.7f, + width = 512, + height = 512, + restoreFaces = true, + seed = "5598", + subSeed = "1504", + subSeedStrength = 5598f, + sampler = "sampler", + nsfw = true, + batchCount = 1, + quality = null, + style = null, + openAiModel = null, + stabilityAiClipGuidance = null, + stabilityAiStylePreset = null, +) diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImplTest.kt new file mode 100644 index 00000000..0da16f96 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImplTest.kt @@ -0,0 +1,511 @@ +package com.shifthackz.aisdv1.data.preference + +import android.content.SharedPreferences +import com.nhaarman.mockitokotlin2.any +import com.nhaarman.mockitokotlin2.doNothing +import com.nhaarman.mockitokotlin2.eq +import com.nhaarman.mockitokotlin2.mock +import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.data.preference.PreferenceManagerImpl.Companion.KEY_AI_AUTO_SAVE +import com.shifthackz.aisdv1.data.preference.PreferenceManagerImpl.Companion.KEY_DEMO_MODE +import com.shifthackz.aisdv1.data.preference.PreferenceManagerImpl.Companion.KEY_DESIGN_COLOR_TOKEN +import com.shifthackz.aisdv1.data.preference.PreferenceManagerImpl.Companion.KEY_DESIGN_DARK_THEME +import com.shifthackz.aisdv1.data.preference.PreferenceManagerImpl.Companion.KEY_DESIGN_DARK_TOKEN +import com.shifthackz.aisdv1.data.preference.PreferenceManagerImpl.Companion.KEY_DESIGN_DYNAMIC_COLORS +import com.shifthackz.aisdv1.data.preference.PreferenceManagerImpl.Companion.KEY_DESIGN_SYSTEM_DARK_THEME +import com.shifthackz.aisdv1.data.preference.PreferenceManagerImpl.Companion.KEY_FORCE_SETUP_AFTER_UPDATE +import com.shifthackz.aisdv1.data.preference.PreferenceManagerImpl.Companion.KEY_FORM_ALWAYS_SHOW_ADVANCED_OPTIONS +import com.shifthackz.aisdv1.data.preference.PreferenceManagerImpl.Companion.KEY_FORM_PROMPT_TAGGED_INPUT +import com.shifthackz.aisdv1.data.preference.PreferenceManagerImpl.Companion.KEY_HORDE_API_KEY +import com.shifthackz.aisdv1.data.preference.PreferenceManagerImpl.Companion.KEY_HUGGING_FACE_API_KEY +import com.shifthackz.aisdv1.data.preference.PreferenceManagerImpl.Companion.KEY_HUGGING_FACE_MODEL_KEY +import com.shifthackz.aisdv1.data.preference.PreferenceManagerImpl.Companion.KEY_LOCAL_MODEL_ID +import com.shifthackz.aisdv1.data.preference.PreferenceManagerImpl.Companion.KEY_LOCAL_NN_API +import com.shifthackz.aisdv1.data.preference.PreferenceManagerImpl.Companion.KEY_MONITOR_CONNECTIVITY +import com.shifthackz.aisdv1.data.preference.PreferenceManagerImpl.Companion.KEY_OPEN_AI_API_KEY +import com.shifthackz.aisdv1.data.preference.PreferenceManagerImpl.Companion.KEY_SAVE_TO_MEDIA_STORE +import com.shifthackz.aisdv1.data.preference.PreferenceManagerImpl.Companion.KEY_SD_MODEL +import com.shifthackz.aisdv1.data.preference.PreferenceManagerImpl.Companion.KEY_SERVER_SOURCE +import com.shifthackz.aisdv1.data.preference.PreferenceManagerImpl.Companion.KEY_SERVER_URL +import com.shifthackz.aisdv1.data.preference.PreferenceManagerImpl.Companion.KEY_STABILITY_AI_API_KEY +import com.shifthackz.aisdv1.data.preference.PreferenceManagerImpl.Companion.KEY_STABILITY_AI_ENGINE_ID_KEY +import com.shifthackz.aisdv1.domain.entity.ColorToken +import com.shifthackz.aisdv1.domain.entity.DarkThemeToken +import com.shifthackz.aisdv1.domain.entity.HuggingFaceModel +import com.shifthackz.aisdv1.domain.entity.ServerSource +import org.junit.Assert +import org.junit.Before +import org.junit.Test + +class PreferenceManagerImplTest { + + private val stubEditor = mock() + private val stubPreference = mock() + + private val preferenceManager = PreferenceManagerImpl(stubPreference) + + @Before + fun initialize() { + doNothing() + .whenever(stubEditor) + .apply() + + whenever(stubEditor.putString(any(), any())) + .thenReturn(stubEditor) + + whenever(stubEditor.putBoolean(any(), any())) + .thenReturn(stubEditor) + + whenever(stubPreference.edit()) + .thenReturn(stubEditor) + + whenever(stubPreference.getString(any(), any())) + .thenReturn("") + + whenever(stubPreference.getBoolean(any(), any())) + .thenReturn(false) + } + + @Test + fun `given user reads default serverUrl, changes it, expected default value, then changed value, observer emits changed value`() { + whenever(stubPreference.getString(eq(KEY_SERVER_URL), any())) + .thenReturn("") + + Assert.assertEquals("", preferenceManager.serverUrl) + + whenever(stubPreference.getString(eq(KEY_SERVER_URL), any())) + .thenReturn("https://192.168.0.1:7860") + + preferenceManager.serverUrl = "https://192.168.0.1:7860" + + Assert.assertEquals("https://192.168.0.1:7860", preferenceManager.serverUrl) + + preferenceManager + .observe() + .test() + .assertNoErrors() + .assertValueAt(0) { settings -> + settings.serverUrl == "https://192.168.0.1:7860" + } + } + + @Test + fun `given user reads default demoMode, changes it, expected default value, then changed value, observer emits changed value`() { + whenever(stubPreference.getBoolean(eq(KEY_DEMO_MODE), any())) + .thenReturn(false) + + Assert.assertEquals(false, preferenceManager.demoMode) + + whenever(stubPreference.getBoolean(eq(KEY_DEMO_MODE), any())) + .thenReturn(true) + + preferenceManager.demoMode = true + + Assert.assertEquals(true, preferenceManager.demoMode) + + preferenceManager + .observe() + .test() + .assertNoErrors() + .assertValueAt(0) { settings -> settings.demoMode } + } + + @Test + fun `given user reads default monitorConnectivity, changes it, expected default value, then changed value, observer emits changed value`() { + whenever(stubPreference.getBoolean(eq(KEY_MONITOR_CONNECTIVITY), any())) + .thenReturn(false) + + Assert.assertEquals(false, preferenceManager.monitorConnectivity) + + whenever(stubPreference.getBoolean(eq(KEY_MONITOR_CONNECTIVITY), any())) + .thenReturn(true) + + preferenceManager.monitorConnectivity = true + + Assert.assertEquals(true, preferenceManager.monitorConnectivity) + + preferenceManager + .observe() + .test() + .assertNoErrors() + .assertValueAt(0) { settings -> settings.monitorConnectivity } + } + + @Test + fun `given user reads default autoSaveAiResults, changes it, expected default value, then changed value, observer emits changed value`() { + whenever(stubPreference.getBoolean(eq(KEY_AI_AUTO_SAVE), any())) + .thenReturn(false) + + Assert.assertEquals(false, preferenceManager.autoSaveAiResults) + + whenever(stubPreference.getBoolean(eq(KEY_AI_AUTO_SAVE), any())) + .thenReturn(true) + + preferenceManager.autoSaveAiResults = true + + Assert.assertEquals(true, preferenceManager.autoSaveAiResults) + + preferenceManager + .observe() + .test() + .assertNoErrors() + .assertValueAt(0) { settings -> settings.autoSaveAiResults } + } + + @Test + fun `given user reads default saveToMediaStore, changes it, expected default value, then changed value, observer emits changed value`() { + whenever(stubPreference.getBoolean(eq(KEY_SAVE_TO_MEDIA_STORE), any())) + .thenReturn(false) + + Assert.assertEquals(false, preferenceManager.saveToMediaStore) + + whenever(stubPreference.getBoolean(eq(KEY_SAVE_TO_MEDIA_STORE), any())) + .thenReturn(true) + + preferenceManager.saveToMediaStore = true + + Assert.assertEquals(true, preferenceManager.saveToMediaStore) + + preferenceManager + .observe() + .test() + .assertNoErrors() + .assertValueAt(0) { settings -> settings.saveToMediaStore } + } + + @Test + fun `given user reads default formAdvancedOptionsAlwaysShow, changes it, expected default value, then changed value, observer emits changed value`() { + whenever(stubPreference.getBoolean(eq(KEY_FORM_ALWAYS_SHOW_ADVANCED_OPTIONS), any())) + .thenReturn(false) + + Assert.assertEquals(false, preferenceManager.formAdvancedOptionsAlwaysShow) + + whenever(stubPreference.getBoolean(eq(KEY_FORM_ALWAYS_SHOW_ADVANCED_OPTIONS), any())) + .thenReturn(true) + + preferenceManager.formAdvancedOptionsAlwaysShow = true + + Assert.assertEquals(true, preferenceManager.formAdvancedOptionsAlwaysShow) + + preferenceManager + .observe() + .test() + .assertNoErrors() + .assertValueAt(0) { settings -> settings.formAdvancedOptionsAlwaysShow } + } + + @Test + fun `given user reads default formPromptTaggedInput, changes it, expected default value, then changed value, observer emits changed value`() { + whenever(stubPreference.getBoolean(eq(KEY_FORM_PROMPT_TAGGED_INPUT), any())) + .thenReturn(false) + + Assert.assertEquals(false, preferenceManager.formPromptTaggedInput) + + whenever(stubPreference.getBoolean(eq(KEY_FORM_PROMPT_TAGGED_INPUT), any())) + .thenReturn(true) + + preferenceManager.formPromptTaggedInput = true + + Assert.assertEquals(true, preferenceManager.formPromptTaggedInput) + + preferenceManager + .observe() + .test() + .assertNoErrors() + .assertValueAt(0) { settings -> settings.formPromptTaggedInput } + } + + @Test + fun `given user reads default source, changes it, expected default value, then changed value, observer emits changed value`() { + whenever(stubPreference.getString(eq(KEY_SERVER_SOURCE), any())) + .thenReturn(ServerSource.AUTOMATIC1111.key) + + Assert.assertEquals(ServerSource.AUTOMATIC1111, preferenceManager.source) + + whenever(stubPreference.getString(eq(KEY_SERVER_SOURCE), any())) + .thenReturn(ServerSource.LOCAL.key) + + preferenceManager.source = ServerSource.LOCAL + + Assert.assertEquals(ServerSource.LOCAL, preferenceManager.source) + + preferenceManager + .observe() + .test() + .assertNoErrors() + .assertValueAt(0) { settings -> settings.source == ServerSource.LOCAL } + } + + @Test + fun `given user reads default sdModel, changes it, expected default value, then changed value, observer emits changed value`() { + whenever(stubPreference.getString(eq(KEY_SD_MODEL), any())) + .thenReturn("model5598") + + Assert.assertEquals("model5598", preferenceManager.sdModel) + + whenever(stubPreference.getString(eq(KEY_SD_MODEL), any())) + .thenReturn("model1504") + + preferenceManager.sdModel = "model1504" + + Assert.assertEquals("model1504", preferenceManager.sdModel) + + preferenceManager + .observe() + .test() + .assertNoErrors() + .assertValueAt(0) { settings -> settings.sdModel == "model1504" } + } + + @Test + fun `given user reads default hordeApiKey, changes it, expected default value, then changed value, observer emits changed value`() { + whenever(stubPreference.getString(eq(KEY_HORDE_API_KEY), any())) + .thenReturn("00000000") + + Assert.assertEquals("00000000", preferenceManager.hordeApiKey) + + whenever(stubPreference.getString(eq(KEY_HORDE_API_KEY), any())) + .thenReturn("key") + + preferenceManager.hordeApiKey = "key" + + Assert.assertEquals("key", preferenceManager.hordeApiKey) + + preferenceManager + .observe() + .test() + .assertNoErrors() + .assertValueAt(0) { settings -> settings.hordeApiKey == "key" } + } + + @Test + fun `given user reads default openAiApiKey, changes it, expected default value, then changed value`() { + whenever(stubPreference.getString(eq(KEY_OPEN_AI_API_KEY), any())) + .thenReturn("00000000") + + Assert.assertEquals("00000000", preferenceManager.openAiApiKey) + + whenever(stubPreference.getString(eq(KEY_OPEN_AI_API_KEY), any())) + .thenReturn("key") + + preferenceManager.openAiApiKey = "key" + + Assert.assertEquals("key", preferenceManager.openAiApiKey) + } + + @Test + fun `given user reads default huggingFaceApiKey, changes it, expected default value, then changed value`() { + whenever(stubPreference.getString(eq(KEY_HUGGING_FACE_API_KEY), any())) + .thenReturn("00000000") + + Assert.assertEquals("00000000", preferenceManager.huggingFaceApiKey) + + whenever(stubPreference.getString(eq(KEY_HUGGING_FACE_API_KEY), any())) + .thenReturn("key") + + preferenceManager.huggingFaceApiKey = "key" + + Assert.assertEquals("key", preferenceManager.huggingFaceApiKey) + } + + @Test + fun `given user reads default huggingFaceModel, changes it, expected default value, then changed value`() { + whenever(stubPreference.getString(eq(KEY_HUGGING_FACE_MODEL_KEY), any())) + .thenReturn(HuggingFaceModel.default.alias) + + Assert.assertEquals(HuggingFaceModel.default.alias, preferenceManager.huggingFaceModel) + + whenever(stubPreference.getString(eq(KEY_HUGGING_FACE_MODEL_KEY), any())) + .thenReturn("key") + + preferenceManager.huggingFaceModel = "key" + + Assert.assertEquals("key", preferenceManager.huggingFaceModel) + } + + @Test + fun `given user reads default stabilityAiApiKey, changes it, expected default value, then changed value`() { + whenever(stubPreference.getString(eq(KEY_STABILITY_AI_API_KEY), any())) + .thenReturn("") + + Assert.assertEquals("", preferenceManager.stabilityAiApiKey) + + whenever(stubPreference.getString(eq(KEY_STABILITY_AI_API_KEY), any())) + .thenReturn("key") + + preferenceManager.stabilityAiApiKey = "key" + + Assert.assertEquals("key", preferenceManager.stabilityAiApiKey) + } + + @Test + fun `given user reads default stabilityAiEngineId, changes it, expected default value, then changed value`() { + whenever(stubPreference.getString(eq(KEY_STABILITY_AI_ENGINE_ID_KEY), any())) + .thenReturn("") + + Assert.assertEquals("", preferenceManager.stabilityAiEngineId) + + whenever(stubPreference.getString(eq(KEY_STABILITY_AI_ENGINE_ID_KEY), any())) + .thenReturn("key") + + preferenceManager.stabilityAiEngineId = "key" + + Assert.assertEquals("key", preferenceManager.stabilityAiEngineId) + } + + @Test + fun `given user reads default forceSetupAfterUpdate, changes it, expected default value, then changed value, observer emits changed value`() { + whenever(stubPreference.getBoolean(eq(KEY_FORCE_SETUP_AFTER_UPDATE), any())) + .thenReturn(false) + + Assert.assertEquals(false, preferenceManager.forceSetupAfterUpdate) + + whenever(stubPreference.getBoolean(eq(KEY_FORCE_SETUP_AFTER_UPDATE), any())) + .thenReturn(true) + + preferenceManager.forceSetupAfterUpdate = true + + Assert.assertEquals(true, preferenceManager.forceSetupAfterUpdate) + } + + @Test + fun `given user reads default localModelId, changes it, expected default value, then changed value`() { + whenever(stubPreference.getString(eq(KEY_LOCAL_MODEL_ID), any())) + .thenReturn("") + + Assert.assertEquals("", preferenceManager.localModelId) + + whenever(stubPreference.getString(eq(KEY_LOCAL_MODEL_ID), any())) + .thenReturn("key") + + preferenceManager.localModelId = "key" + + Assert.assertEquals("key", preferenceManager.localModelId) + } + + @Test + fun `given user reads default localUseNNAPI, changes it, expected default value, then changed value, observer emits changed value`() { + whenever(stubPreference.getBoolean(eq(KEY_LOCAL_NN_API), any())) + .thenReturn(false) + + Assert.assertEquals(false, preferenceManager.localUseNNAPI) + + whenever(stubPreference.getBoolean(eq(KEY_LOCAL_NN_API), any())) + .thenReturn(true) + + preferenceManager.localUseNNAPI = true + + Assert.assertEquals(true, preferenceManager.localUseNNAPI) + + preferenceManager + .observe() + .test() + .assertNoErrors() + .assertValueAt(0) { settings -> settings.localUseNNAPI } + } + + @Test + fun `given user reads default designUseSystemColorPalette, changes it, expected default value, then changed value, observer emits changed value`() { + whenever(stubPreference.getBoolean(eq(KEY_DESIGN_DYNAMIC_COLORS), any())) + .thenReturn(false) + + Assert.assertEquals(false, preferenceManager.designUseSystemColorPalette) + + whenever(stubPreference.getBoolean(eq(KEY_DESIGN_DYNAMIC_COLORS), any())) + .thenReturn(true) + + preferenceManager.designUseSystemColorPalette = true + + Assert.assertEquals(true, preferenceManager.designUseSystemColorPalette) + + preferenceManager + .observe() + .test() + .assertNoErrors() + .assertValueAt(0) { settings -> settings.designUseSystemColorPalette } + } + + @Test + fun `given user reads default designUseSystemDarkTheme, changes it, expected default value, then changed value, observer emits changed value`() { + whenever(stubPreference.getBoolean(eq(KEY_DESIGN_SYSTEM_DARK_THEME), any())) + .thenReturn(false) + + Assert.assertEquals(false, preferenceManager.designUseSystemDarkTheme) + + whenever(stubPreference.getBoolean(eq(KEY_DESIGN_SYSTEM_DARK_THEME), any())) + .thenReturn(true) + + preferenceManager.designUseSystemDarkTheme = true + + Assert.assertEquals(true, preferenceManager.designUseSystemDarkTheme) + + preferenceManager + .observe() + .test() + .assertNoErrors() + .assertValueAt(0) { settings -> settings.designUseSystemDarkTheme } + } + + @Test + fun `given user reads default designDarkTheme, changes it, expected default value, then changed value, observer emits changed value`() { + whenever(stubPreference.getBoolean(eq(KEY_DESIGN_DARK_THEME), any())) + .thenReturn(false) + + Assert.assertEquals(false, preferenceManager.designDarkTheme) + + whenever(stubPreference.getBoolean(eq(KEY_DESIGN_DARK_THEME), any())) + .thenReturn(true) + + preferenceManager.designDarkTheme = true + + Assert.assertEquals(true, preferenceManager.designDarkTheme) + + preferenceManager + .observe() + .test() + .assertNoErrors() + .assertValueAt(0) { settings -> settings.designDarkTheme } + } + + @Test + fun `given user reads default designColorToken, changes it, expected default value, then changed value, observer emits changed value`() { + whenever(stubPreference.getString(eq(KEY_DESIGN_COLOR_TOKEN), any())) + .thenReturn("${ColorToken.MAUVE}") + + Assert.assertEquals("${ColorToken.MAUVE}", preferenceManager.designColorToken) + + whenever(stubPreference.getString(eq(KEY_DESIGN_COLOR_TOKEN), any())) + .thenReturn("${ColorToken.PEACH}") + + preferenceManager.designColorToken = "${ColorToken.PEACH}" + + Assert.assertEquals("${ColorToken.PEACH}", preferenceManager.designColorToken) + + preferenceManager + .observe() + .test() + .assertNoErrors() + .assertValueAt(0) { settings -> settings.designColorToken == "${ColorToken.PEACH}" } + } + + @Test + fun `given user reads default designDarkThemeToken, changes it, expected default value, then changed value, observer emits changed value`() { + whenever(stubPreference.getString(eq(KEY_DESIGN_DARK_TOKEN), any())) + .thenReturn("${DarkThemeToken.FRAPPE}") + + Assert.assertEquals("${DarkThemeToken.FRAPPE}", preferenceManager.designDarkThemeToken) + + whenever(stubPreference.getString(eq(KEY_DESIGN_DARK_TOKEN), any())) + .thenReturn("${DarkThemeToken.MOCHA}") + + preferenceManager.designDarkThemeToken = "${DarkThemeToken.MOCHA}" + + Assert.assertEquals("${DarkThemeToken.MOCHA}", preferenceManager.designDarkThemeToken) + + preferenceManager + .observe() + .test() + .assertNoErrors() + .assertValueAt(0) { settings -> settings.designDarkThemeToken == "${DarkThemeToken.MOCHA}" } + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/preference/SessionPreferenceImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/preference/SessionPreferenceImplTest.kt new file mode 100644 index 00000000..7d6c9df4 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/preference/SessionPreferenceImplTest.kt @@ -0,0 +1,21 @@ +package com.shifthackz.aisdv1.data.preference + +import org.junit.Assert +import org.junit.Test + +class SessionPreferenceImplTest { + + private val sessionPreference = SessionPreferenceImpl() + + @Test + fun `given user reads default coinsPerDay value, expected -1`() { + Assert.assertEquals(-1, sessionPreference.coinsPerDay) + } + + @Test + fun `given user reads default coinsPerDay value, then changes it, expected -1, then changed value`() { + Assert.assertEquals(-1, sessionPreference.coinsPerDay) + sessionPreference.coinsPerDay = 5598 + Assert.assertEquals(5598, sessionPreference.coinsPerDay) + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSourceTest.kt new file mode 100644 index 00000000..c8e7339d --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSourceTest.kt @@ -0,0 +1,66 @@ +package com.shifthackz.aisdv1.data.remote + +import com.nhaarman.mockitokotlin2.mock +import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor +import com.shifthackz.aisdv1.data.mappers.mapRawToDomain +import com.shifthackz.aisdv1.data.mocks.mockDownloadableModelsResponse +import com.shifthackz.aisdv1.network.api.sdai.DownloadableModelsApi +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class DownloadableModelRemoteDataSourceTest { + + private val stubException = Throwable("Internal server error.") + private val stubApi = mock() + private val stubFileProviderDescriptor = mock() + + private val remoteDataSource = DownloadableModelRemoteDataSource( + api = stubApi, + fileProviderDescriptor = stubFileProviderDescriptor, + ) + + @Test + fun `given attempt to fetch models list, api returns data, expected valid domain models list`() { + whenever(stubApi.fetchDownloadableModels()) + .thenReturn(Single.just(mockDownloadableModelsResponse)) + + val expected = mockDownloadableModelsResponse.mapRawToDomain() + + remoteDataSource + .fetch() + .test() + .assertNoErrors() + .assertValue(expected) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch models list, api returns empty data, expected empty domain models list`() { + whenever(stubApi.fetchDownloadableModels()) + .thenReturn(Single.just(emptyList())) + + remoteDataSource + .fetch() + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch models list, api returns error, expected error value`() { + whenever(stubApi.fetchDownloadableModels()) + .thenReturn(Single.error(stubException)) + + remoteDataSource + .fetch() + .test() + .assertError(stubException) + .assertValueCount(0) + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/remote/HordeGenerationRemoteDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/remote/HordeGenerationRemoteDataSourceTest.kt new file mode 100644 index 00000000..fe6cc4a0 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/remote/HordeGenerationRemoteDataSourceTest.kt @@ -0,0 +1,136 @@ +package com.shifthackz.aisdv1.data.remote + + +import android.graphics.Bitmap +import android.graphics.BitmapFactory +import com.shifthackz.aisdv1.data.mocks.mockTextToImagePayload +import com.shifthackz.aisdv1.core.imageprocessing.BitmapToBase64Converter +import com.shifthackz.aisdv1.domain.datasource.HordeGenerationDataSource +import com.shifthackz.aisdv1.network.api.horde.HordeRestApi +import com.shifthackz.aisdv1.network.response.HordeGenerationAsyncResponse +import com.shifthackz.aisdv1.network.response.HordeGenerationCheckFullResponse +import com.shifthackz.aisdv1.network.response.HordeGenerationCheckResponse +import com.shifthackz.aisdv1.network.response.HordeUserResponse +import io.mockk.every +import io.mockk.mockk +import io.mockk.mockkConstructor +import io.mockk.mockkStatic +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single +import io.reactivex.rxjava3.schedulers.Schedulers +import org.junit.Test +import java.net.URL +import java.util.concurrent.TimeUnit + +class HordeGenerationRemoteDataSourceTest { + + private val stubBytes = ByteArray(1024) + private val stubBitmap = mockk() + private val stubException = Throwable("Internal server error.") + private val stubApi = mockk() + private val stubBmpToBase64Converter = mockk() + private val stubHordeStatusSource = mockk() + + private val remoteDataSource = HordeGenerationRemoteDataSource( + hordeApi = stubApi, + converter = stubBmpToBase64Converter, + statusSource = stubHordeStatusSource, + ) + + @Test + fun `given attempt to validate api key, api returns user with valid id, expected true value`() { + every { + stubApi.checkHordeApiKey() + } returns Single.just(HordeUserResponse(5598)) + + remoteDataSource + .validateApiKey() + .test() + .assertNoErrors() + .assertValue(true) + .await() + .assertComplete() + } + + @Test + fun `given attempt to validate api key, api returns null, expected false value`() { + every { + stubApi.checkHordeApiKey() + } returns Single.just(HordeUserResponse(null)) + + remoteDataSource + .validateApiKey() + .test() + .assertNoErrors() + .assertValue(false) + .await() + .assertComplete() + } + + @Test + fun `given attempt to validate api key, api throws exception, expected false value`() { + every { + stubApi.checkHordeApiKey() + } returns Single.error(stubException) + + remoteDataSource + .validateApiKey() + .test() + .assertNoErrors() + .assertValue(false) + .await() + .assertComplete() + } + + @Test + fun `given attempt to interrupt generation, id present in cache, api returns success, expected complete value`() { + every { + stubHordeStatusSource.id + } returns "5598" + + every { + stubApi.cancelRequest(any()) + } returns Completable.complete() + + remoteDataSource + .interruptGeneration() + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to interrupt generation, id present in cache, api throws exception, expected error value`() { + every { + stubHordeStatusSource.id + } returns "5598" + + every { + stubApi.cancelRequest(any()) + } returns Completable.error(stubException) + + remoteDataSource + .interruptGeneration() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to interrupt generation, no id present in cache, expected error value`() { + every { + stubHordeStatusSource.id + } returns null + + remoteDataSource + .interruptGeneration() + .test() + .assertError { t -> + t is IllegalStateException && t.message == "No cached request id" + } + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/remote/HuggingFaceGenerationRemoteDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/remote/HuggingFaceGenerationRemoteDataSourceTest.kt new file mode 100644 index 00000000..43606065 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/remote/HuggingFaceGenerationRemoteDataSourceTest.kt @@ -0,0 +1,166 @@ +package com.shifthackz.aisdv1.data.remote + +import android.graphics.Bitmap +import com.shifthackz.aisdv1.core.imageprocessing.BitmapToBase64Converter +import com.shifthackz.aisdv1.data.mocks.mockTextToImagePayload +import com.shifthackz.aisdv1.data.mocks.mockImageToImagePayload +import com.shifthackz.aisdv1.domain.entity.AiGenerationResult +import com.shifthackz.aisdv1.network.api.huggingface.HuggingFaceApi +import com.shifthackz.aisdv1.network.api.huggingface.HuggingFaceInferenceApi +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class HuggingFaceGenerationRemoteDataSourceTest { + + private val stubConverterException = Throwable("Converter failure.") + private val stubApiException = Throwable("Internal server error.") + private val stubBitmap = mockk() + private val stubApi = mockk() + private val stubInferenceApi = mockk() + private val stubBmpToBase64Converter = mockk() + + private val remoteDataSource = HuggingFaceGenerationRemoteDataSource( + huggingFaceApi = stubApi, + huggingFaceInferenceApi = stubInferenceApi, + converter = stubBmpToBase64Converter, + ) + + @Test + fun `given attempt to validate api key, api request succeeds, expected true`() { + every { + stubApi.validateBearerToken() + } returns Completable.complete() + + remoteDataSource + .validateApiKey() + .test() + .assertNoErrors() + .assertValue(true) + .await() + .assertComplete() + } + + @Test + fun `given attempt to validate api key, api request fails, expected false`() { + every { + stubApi.validateBearerToken() + } returns Completable.error(Throwable("Invalid api key.")) + + remoteDataSource + .validateApiKey() + .test() + .assertNoErrors() + .assertValue(false) + .await() + .assertComplete() + } + + @Test + fun `given attempt to generate txt2img, api returns result, converter proceed successfully, expected valid ai generation result value`() { + every { + stubInferenceApi.generate(any(), any()) + } returns Single.just(stubBitmap) + + every { + stubBmpToBase64Converter(any()) + } returns Single.just(BitmapToBase64Converter.Output("base64")) + + remoteDataSource + .textToImage("model", mockTextToImagePayload) + .test() + .assertNoErrors() + .assertValue { it is AiGenerationResult } + .await() + .assertComplete() + } + + @Test + fun `given attempt to generate txt2img, api fails, expected error value`() { + every { + stubInferenceApi.generate(any(), any()) + } returns Single.error(stubApiException) + + remoteDataSource + .textToImage("model", mockTextToImagePayload) + .test() + .assertError(stubApiException) + .assertValueCount(0) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to generate txt2img, api returns result, converter proceed with fail, expected error value`() { + every { + stubInferenceApi.generate(any(), any()) + } returns Single.just(stubBitmap) + + every { + stubBmpToBase64Converter(any()) + } returns Single.error(stubConverterException) + + remoteDataSource + .textToImage("model", mockTextToImagePayload) + .test() + .assertError(stubConverterException) + .assertValueCount(0) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to generate img2img, api returns result, converter proceed successfully, expected valid ai generation result value`() { + every { + stubInferenceApi.generate(any(), any()) + } returns Single.just(stubBitmap) + + every { + stubBmpToBase64Converter(any()) + } returns Single.just(BitmapToBase64Converter.Output("base64")) + + remoteDataSource + .imageToImage("model", mockImageToImagePayload) + .test() + .assertNoErrors() + .assertValue { it is AiGenerationResult } + .await() + .assertComplete() + } + + @Test + fun `given attempt to generate img2img, api fails, expected error value`() { + every { + stubInferenceApi.generate(any(), any()) + } returns Single.error(stubApiException) + + remoteDataSource + .imageToImage("model", mockImageToImagePayload) + .test() + .assertError(stubApiException) + .assertValueCount(0) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to generate img2img, api returns result, converter proceed with fail, expected error value`() { + every { + stubInferenceApi.generate(any(), any()) + } returns Single.just(stubBitmap) + + every { + stubBmpToBase64Converter(any()) + } returns Single.error(stubConverterException) + + remoteDataSource + .imageToImage("model", mockImageToImagePayload) + .test() + .assertError(stubConverterException) + .assertValueCount(0) + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/remote/HuggingFaceModelsRemoteDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/remote/HuggingFaceModelsRemoteDataSourceTest.kt new file mode 100644 index 00000000..4588f114 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/remote/HuggingFaceModelsRemoteDataSourceTest.kt @@ -0,0 +1,56 @@ +package com.shifthackz.aisdv1.data.remote + +import com.shifthackz.aisdv1.data.mocks.mockHuggingFaceModelsRaw +import com.shifthackz.aisdv1.domain.entity.HuggingFaceModel +import com.shifthackz.aisdv1.network.api.sdai.HuggingFaceModelsApi +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class HuggingFaceModelsRemoteDataSourceTest { + + private val stubException = Throwable("Error to communicate with api.") + private val stubApi = mockk() + + private val remoteDataSource = HuggingFaceModelsRemoteDataSource(stubApi) + + @Test + fun `given attempt to fetch hugging face models, api returns two models, expected list value with two domain models`() { + every { + stubApi.fetchHuggingFaceModels() + } returns Single.just(mockHuggingFaceModelsRaw) + + remoteDataSource + .fetchHuggingFaceModels() + .test() + .assertNoErrors() + .assertValue { models -> + models is List + && models.size == mockHuggingFaceModelsRaw.size + && models.any { it.id == "050598" } + && models.any { it.id == "151297" } + } + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch hugging face models, api throws exception, expected list value with one default domain model`() { + every { + stubApi.fetchHuggingFaceModels() + } returns Single.error(stubException) + + remoteDataSource + .fetchHuggingFaceModels() + .test() + .assertNoErrors() + .assertValue { models -> + models is List + && models.size == 1 + && models.first() == HuggingFaceModel.default + } + .await() + .assertComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/remote/OpenAiGenerationRemoteDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/remote/OpenAiGenerationRemoteDataSourceTest.kt new file mode 100644 index 00000000..a4f7069f --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/remote/OpenAiGenerationRemoteDataSourceTest.kt @@ -0,0 +1,98 @@ +package com.shifthackz.aisdv1.data.remote + +import com.shifthackz.aisdv1.data.mocks.mockBadOpenAiResponse +import com.shifthackz.aisdv1.data.mocks.mockSuccessOpenAiResponse +import com.shifthackz.aisdv1.data.mocks.mockTextToImagePayload +import com.shifthackz.aisdv1.domain.entity.AiGenerationResult +import com.shifthackz.aisdv1.network.api.openai.OpenAiApi +import com.shifthackz.aisdv1.network.response.OpenAiResponse +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class OpenAiGenerationRemoteDataSourceTest { + + private val stubApiException = Throwable("Internal server error.") + private val stubApi = mockk() + + private val remoteDataSource = OpenAiGenerationRemoteDataSource(stubApi) + + @Test + fun `given attempt to validate bearer token, api returns success response, expected true`() { + every { + stubApi.validateBearerToken() + } returns Completable.complete() + + remoteDataSource + .validateApiKey() + .test() + .assertNoErrors() + .assertValue(true) + .await() + .assertComplete() + } + + @Test + fun `given attempt to validate bearer token, api returns error response, expected false`() { + every { + stubApi.validateBearerToken() + } returns Completable.error(Throwable("Bad api key.")) + + remoteDataSource + .validateApiKey() + .test() + .assertNoErrors() + .assertValue(false) + .await() + .assertComplete() + } + + @Test + fun `given attempt to generate txt2img, api returns result, expected valid ai generation result value`() { + every { + stubApi.generateImage(any()) + } returns Single.just(mockSuccessOpenAiResponse) + + remoteDataSource + .textToImage(mockTextToImagePayload) + .test() + .assertNoErrors() + .assertValue { it is AiGenerationResult } + .await() + .assertComplete() + } + + @Test + fun `given attempt to generate txt2img, api returns empty result, expected error value`() { + every { + stubApi.generateImage(any()) + } returns Single.just(mockBadOpenAiResponse) + + remoteDataSource + .textToImage(mockTextToImagePayload) + .test() + .assertError { t -> + t is IllegalStateException && t.message == "Got null data object from API." + } + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to generate txt2img, api request fails, expected error value`() { + every { + stubApi.generateImage(any()) + } returns Single.error(stubApiException) + + remoteDataSource + .textToImage(mockTextToImagePayload) + .test() + .assertError(stubApiException) + .assertNoValues() + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/remote/RandomImageRemoteDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/remote/RandomImageRemoteDataSourceTest.kt new file mode 100644 index 00000000..c8cf8151 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/remote/RandomImageRemoteDataSourceTest.kt @@ -0,0 +1,47 @@ +package com.shifthackz.aisdv1.data.remote + +import android.graphics.Bitmap +import com.shifthackz.aisdv1.network.api.imagecdn.ImageCdnRestApi +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class RandomImageRemoteDataSourceTest { + + private val stubException = Throwable("Internal server error.") + private val stubBitmap = mockk() + private val stubApi = mockk() + + private val remoteDataSource = RandomImageRemoteDataSource(stubApi) + + @Test + fun `given attempt to fetch bitmap with random image, api request succeed, expected valid bitmap value`() { + every { + stubApi.fetchRandomImage() + } returns Single.just(stubBitmap) + + remoteDataSource + .fetch() + .test() + .assertNoErrors() + .assertValue(stubBitmap) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch bitmap with random image, api request fails, expected error value`() { + every { + stubApi.fetchRandomImage() + } returns Single.error(stubException) + + remoteDataSource + .fetch() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/remote/ServerConfigurationRemoteDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/remote/ServerConfigurationRemoteDataSourceTest.kt new file mode 100644 index 00000000..b3edaacc --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/remote/ServerConfigurationRemoteDataSourceTest.kt @@ -0,0 +1,89 @@ +package com.shifthackz.aisdv1.data.remote + +import com.shifthackz.aisdv1.data.mocks.mockServerConfigurationRaw +import com.shifthackz.aisdv1.data.provider.ServerUrlProvider +import com.shifthackz.aisdv1.domain.entity.ServerConfiguration +import com.shifthackz.aisdv1.network.api.automatic1111.Automatic1111RestApi +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single +import org.junit.Before +import org.junit.Test + +class ServerConfigurationRemoteDataSourceTest { + + private val stubException = Throwable("Internal server error.") + private val stubUrlProvider = mockk() + private val stubApi = mockk() + + private val remoteDataSource = ServerConfigurationRemoteDataSource( + serverUrlProvider = stubUrlProvider, + api = stubApi, + ) + + @Before + fun initialize() { + every { + stubUrlProvider(any()) + } returns Single.just("http://192.168.0.1:7860") + } + + @Test + fun `given attempt to fetch configuration, api returns success response, expected valid server configuration value`() { + every { + stubApi.fetchConfiguration(any()) + } returns Single.just(mockServerConfigurationRaw) + + remoteDataSource + .fetchConfiguration() + .test() + .assertNoErrors() + .assertValue(ServerConfiguration("5598")) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch configuration, api returns error response, expected error value`() { + every { + stubApi.fetchConfiguration(any()) + } returns Single.error(stubException) + + remoteDataSource + .fetchConfiguration() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to update configuration, api returns success response, expected complete value`() { + every { + stubApi.updateConfiguration(any(), any()) + } returns Completable.complete() + + remoteDataSource + .updateConfiguration(ServerConfiguration("5598")) + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to update configuration, api returns error response, expected error value`() { + every { + stubApi.updateConfiguration(any(), any()) + } returns Completable.error(stubException) + + remoteDataSource + .updateConfiguration(ServerConfiguration("5598")) + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/remote/StabilityAiCreditsRemoteDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/remote/StabilityAiCreditsRemoteDataSourceTest.kt new file mode 100644 index 00000000..c621f5fb --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/remote/StabilityAiCreditsRemoteDataSourceTest.kt @@ -0,0 +1,76 @@ +package com.shifthackz.aisdv1.data.remote + +import com.shifthackz.aisdv1.network.api.stabilityai.StabilityAiApi +import com.shifthackz.aisdv1.network.response.StabilityCreditsResponse +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class StabilityAiCreditsRemoteDataSourceTest { + + private val stubException = Throwable("Internal server error.") + private val stubApi = mockk() + + private val remoteDataSource = StabilityAiCreditsRemoteDataSource(stubApi) + + @Test + fun `given attempt to fetch credits, api returns response with normal value, expected valid credits value`() { + every { + stubApi.fetchCredits() + } returns Single.just(StabilityCreditsResponse(5598f)) + + remoteDataSource + .fetch() + .test() + .assertNoErrors() + .assertValue(5598f) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch credits, api returns response with zero value, expected zero credits value`() { + every { + stubApi.fetchCredits() + } returns Single.just(StabilityCreditsResponse(0f)) + + remoteDataSource + .fetch() + .test() + .assertNoErrors() + .assertValue(0f) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch credits, api returns response with null value, expected zero credits value`() { + every { + stubApi.fetchCredits() + } returns Single.just(StabilityCreditsResponse(null)) + + remoteDataSource + .fetch() + .test() + .assertNoErrors() + .assertValue(0f) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch credits, api returns error response, expected error value`() { + every { + stubApi.fetchCredits() + } returns Single.error(stubException) + + remoteDataSource + .fetch() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/remote/StabilityAiEnginesRemoteDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/remote/StabilityAiEnginesRemoteDataSourceTest.kt new file mode 100644 index 00000000..57f97915 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/remote/StabilityAiEnginesRemoteDataSourceTest.kt @@ -0,0 +1,66 @@ +package com.shifthackz.aisdv1.data.remote + +import com.shifthackz.aisdv1.data.mocks.mockStabilityAiEnginesRaw +import com.shifthackz.aisdv1.domain.entity.StabilityAiEngine +import com.shifthackz.aisdv1.network.api.stabilityai.StabilityAiApi +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class StabilityAiEnginesRemoteDataSourceTest { + + private val stubException = Throwable("Internal server error.") + private val stubApi = mockk() + + private val remoteDataSource = StabilityAiEnginesRemoteDataSource(stubApi) + + @Test + fun `given attempt to fetch engines, api returns success response, expected valid engines list value`() { + every { + stubApi.fetchEngines() + } returns Single.just(mockStabilityAiEnginesRaw) + + remoteDataSource + .fetch() + .test() + .assertNoErrors() + .assertValue { engines -> + engines is List + && engines.size == mockStabilityAiEnginesRaw.size + && engines.any { it.id == "5598" } + } + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch engines, api returns empty response, expected empty engines list value`() { + every { + stubApi.fetchEngines() + } returns Single.just(emptyList()) + + remoteDataSource + .fetch() + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch engines, api returns error response, expected error value`() { + every { + stubApi.fetchEngines() + } returns Single.error(stubException) + + remoteDataSource + .fetch() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/remote/StabilityAiGenerationRemoteDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/remote/StabilityAiGenerationRemoteDataSourceTest.kt new file mode 100644 index 00000000..04665a14 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/remote/StabilityAiGenerationRemoteDataSourceTest.kt @@ -0,0 +1,108 @@ +package com.shifthackz.aisdv1.data.remote + +import com.shifthackz.aisdv1.data.mocks.mockBadStabilityGenerationResponse +import com.shifthackz.aisdv1.data.mocks.mockStabilityGenerationResponse +import com.shifthackz.aisdv1.data.mocks.mockTextToImagePayload +import com.shifthackz.aisdv1.domain.entity.AiGenerationResult +import com.shifthackz.aisdv1.network.api.stabilityai.StabilityAiApi +import com.shifthackz.aisdv1.network.error.StabilityAiErrorMapper +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single +import org.junit.Before +import org.junit.Test + +class StabilityAiGenerationRemoteDataSourceTest { + + private val stubException = Throwable("Internal server error.") + private val stubApi = mockk() + private val stubErrorMapper = mockk() + + private val remoteDataSource = StabilityAiGenerationRemoteDataSource( + api = stubApi, + stabilityAiErrorMapper = stubErrorMapper, + ) + + @Before + fun initialize() { + every { + stubErrorMapper.invoke(any()) + } returns Single.error(stubException) + } + + @Test + fun `given attempt to validate bearer token, api returns success response, expected true`() { + every { + stubApi.validateBearerToken() + } returns Completable.complete() + + remoteDataSource + .validateApiKey() + .test() + .assertNoErrors() + .assertValue(true) + .await() + .assertComplete() + } + + @Test + fun `given attempt to validate bearer token, api returns error response, expected false`() { + every { + stubApi.validateBearerToken() + } returns Completable.error(Throwable("Invalid api key.")) + + remoteDataSource + .validateApiKey() + .test() + .assertNoErrors() + .assertValue(false) + .await() + .assertComplete() + } + + @Test + fun `given attempt to generate txt2img, api returns result, expected valid ai generation result value`() { + every { + stubApi.textToImage(any(), any()) + } returns Single.just(mockStabilityGenerationResponse) + + remoteDataSource + .textToImage("5598", mockTextToImagePayload) + .test() + .assertNoErrors() + .assertValue { it is AiGenerationResult } + .await() + .assertComplete() + } + + @Test + fun `given attempt to generate txt2img, api returns empty result, expected error value`() { + every { + stubApi.textToImage(any(), any()) + } returns Single.just(mockBadStabilityGenerationResponse) + + remoteDataSource + .textToImage("5598", mockTextToImagePayload) + .test() + .assertError { true } + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to generate txt2img, api returns error response, expected error value`() { + every { + stubApi.textToImage(any(), any()) + } returns Single.error(stubException) + + remoteDataSource + .textToImage("5598", mockTextToImagePayload) + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/remote/StableDiffusionEmbeddingsRemoteDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/remote/StableDiffusionEmbeddingsRemoteDataSourceTest.kt new file mode 100644 index 00000000..bc31e80f --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/remote/StableDiffusionEmbeddingsRemoteDataSourceTest.kt @@ -0,0 +1,76 @@ +package com.shifthackz.aisdv1.data.remote + +import com.shifthackz.aisdv1.data.mocks.mockEmptySdEmbeddingsResponse +import com.shifthackz.aisdv1.data.mocks.mockSdEmbeddingsResponse +import com.shifthackz.aisdv1.data.provider.ServerUrlProvider +import com.shifthackz.aisdv1.domain.entity.StableDiffusionEmbedding +import com.shifthackz.aisdv1.network.api.automatic1111.Automatic1111RestApi +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Single +import org.junit.Before +import org.junit.Test + +class StableDiffusionEmbeddingsRemoteDataSourceTest { + + private val stubException = Throwable("Internal server error.") + private val stubUrlProvider = mockk() + private val stubApi = mockk() + + private val remoteDataSource = StableDiffusionEmbeddingsRemoteDataSource( + serverUrlProvider = stubUrlProvider, + api = stubApi, + ) + + @Before + fun initialize() { + every { + stubUrlProvider(any()) + } returns Single.just("http://192.168.0.1:7860") + } + + @Test + fun `given attempt to fetch embeddings, api returns success response, expected valid embeddings value`() { + every { + stubApi.fetchEmbeddings(any()) + } returns Single.just(mockSdEmbeddingsResponse) + + remoteDataSource + .fetchEmbeddings() + .test() + .assertNoErrors() + .assertValue(listOf(StableDiffusionEmbedding("1504"))) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch embeddings, api returns empty response, expected empty embeddings value`() { + every { + stubApi.fetchEmbeddings(any()) + } returns Single.just(mockEmptySdEmbeddingsResponse) + + remoteDataSource + .fetchEmbeddings() + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch embeddings, api returns error response, expected error value`() { + every { + stubApi.fetchEmbeddings(any()) + } returns Single.error(stubException) + + remoteDataSource + .fetchEmbeddings() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/remote/StableDiffusionGenerationRemoteDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/remote/StableDiffusionGenerationRemoteDataSourceTest.kt new file mode 100644 index 00000000..b38b133f --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/remote/StableDiffusionGenerationRemoteDataSourceTest.kt @@ -0,0 +1,150 @@ +package com.shifthackz.aisdv1.data.remote + +import com.shifthackz.aisdv1.data.mocks.mockImageToImagePayload +import com.shifthackz.aisdv1.data.mocks.mockSdGenerationResponse +import com.shifthackz.aisdv1.data.mocks.mockTextToImagePayload +import com.shifthackz.aisdv1.data.provider.ServerUrlProvider +import com.shifthackz.aisdv1.domain.entity.AiGenerationResult +import com.shifthackz.aisdv1.network.api.automatic1111.Automatic1111RestApi +import com.shifthackz.aisdv1.network.response.SdGenerationResponse +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single +import org.junit.Before +import org.junit.Test + +class StableDiffusionGenerationRemoteDataSourceTest { + + private val stubException = Throwable("Internal server error.") + private val stubUrlProvider = mockk() + private val stubApi = mockk() + + private val remoteDataSource = StableDiffusionGenerationRemoteDataSource( + serverUrlProvider = stubUrlProvider, + api = stubApi, + ) + + @Before + fun initialize() { + every { + stubUrlProvider(any()) + } returns Single.just("http://192.168.0.1:7860") + } + + @Test + fun `given attempt to do health check, api returns success response, expected complete value`() { + every { + stubApi.healthCheck(any()) + } returns Completable.complete() + + remoteDataSource + .checkAvailability() + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to do health check, api returns error response, expected error value`() { + every { + stubApi.healthCheck(any()) + } returns Completable.error(stubException) + + remoteDataSource + .checkAvailability() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to generate txt2img, api returns result, expected valid ai generation result value`() { + every { + stubApi.textToImage(any(), any()) + } returns Single.just(mockSdGenerationResponse) + + remoteDataSource + .textToImage(mockTextToImagePayload) + .test() + .assertNoErrors() + .assertValue { it is AiGenerationResult } + .await() + .assertComplete() + } + + @Test + fun `given attempt to generate txt2img, api returns error, expected error value`() { + every { + stubApi.textToImage(any(), any()) + } returns Single.error(stubException) + + remoteDataSource + .textToImage(mockTextToImagePayload) + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to generate img2img, api returns result, expected valid ai generation result value`() { + every { + stubApi.imageToImage(any(), any()) + } returns Single.just(mockSdGenerationResponse) + + remoteDataSource + .imageToImage(mockImageToImagePayload) + .test() + .assertNoErrors() + .assertValue { it is AiGenerationResult } + .await() + .assertComplete() + } + + @Test + fun `given attempt to generate img2img, api returns error, expected error value`() { + every { + stubApi.imageToImage(any(), any()) + } returns Single.error(stubException) + + remoteDataSource + .imageToImage(mockImageToImagePayload) + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to interrupt generation, api returns success response, expected complete value`() { + every { + stubApi.interrupt(any()) + } returns Completable.complete() + + remoteDataSource + .interruptGeneration() + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to interrupt generation, api returns error response, expected error value`() { + every { + stubApi.interrupt(any()) + } returns Completable.error(stubException) + + remoteDataSource + .interruptGeneration() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/remote/StableDiffusionHyperNetworksRemoteDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/remote/StableDiffusionHyperNetworksRemoteDataSourceTest.kt new file mode 100644 index 00000000..610b95d7 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/remote/StableDiffusionHyperNetworksRemoteDataSourceTest.kt @@ -0,0 +1,79 @@ +package com.shifthackz.aisdv1.data.remote + +import com.shifthackz.aisdv1.data.mocks.mockStableDiffusionHyperNetworkRaw +import com.shifthackz.aisdv1.data.provider.ServerUrlProvider +import com.shifthackz.aisdv1.domain.entity.StableDiffusionHyperNetwork +import com.shifthackz.aisdv1.network.api.automatic1111.Automatic1111RestApi +import com.shifthackz.aisdv1.network.model.StableDiffusionHyperNetworkRaw +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Single +import org.junit.Before +import org.junit.Test + +class StableDiffusionHyperNetworksRemoteDataSourceTest { + + private val stubException = Throwable("Internal server error.") + private val stubUrlProvider = mockk() + private val stubApi = mockk() + + private val remoteDataSource = StableDiffusionHyperNetworksRemoteDataSource( + serverUrlProvider = stubUrlProvider, + api = stubApi, + ) + + @Before + fun initialize() { + every { + stubUrlProvider(any()) + } returns Single.just("http://192.168.0.1:7860") + } + + @Test + fun `given attempt to fetch hyper networks, api returns success response, expected valid hyper networks list value`() { + every { + stubApi.fetchHyperNetworks(any()) + } returns Single.just(mockStableDiffusionHyperNetworkRaw) + + remoteDataSource + .fetchHyperNetworks() + .test() + .assertNoErrors() + .assertValue { networks -> + networks is List + && networks.size == mockStableDiffusionHyperNetworkRaw.size + } + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch hyper networks, api returns empty response, expected empty hyper networks list value`() { + every { + stubApi.fetchHyperNetworks(any()) + } returns Single.just(emptyList()) + + remoteDataSource + .fetchHyperNetworks() + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch hyper networks, api returns error response, expected error value`() { + every { + stubApi.fetchHyperNetworks(any()) + } returns Single.error(stubException) + + remoteDataSource + .fetchHyperNetworks() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/remote/StableDiffusionLorasRemoteDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/remote/StableDiffusionLorasRemoteDataSourceTest.kt new file mode 100644 index 00000000..3c999085 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/remote/StableDiffusionLorasRemoteDataSourceTest.kt @@ -0,0 +1,78 @@ +package com.shifthackz.aisdv1.data.remote + +import com.shifthackz.aisdv1.data.mocks.mockStableDiffusionLoraRaw +import com.shifthackz.aisdv1.data.provider.ServerUrlProvider +import com.shifthackz.aisdv1.domain.entity.StableDiffusionLora +import com.shifthackz.aisdv1.network.api.automatic1111.Automatic1111RestApi +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Single +import org.junit.Before +import org.junit.Test + +class StableDiffusionLorasRemoteDataSourceTest { + + private val stubException = Throwable("Internal server error.") + private val stubUrlProvider = mockk() + private val stubApi = mockk() + + private val remoteDataSource = StableDiffusionLorasRemoteDataSource( + serverUrlProvider = stubUrlProvider, + api = stubApi, + ) + + @Before + fun initialize() { + every { + stubUrlProvider(any()) + } returns Single.just("http://192.168.0.1:7860") + } + + @Test + fun `given attempt to fetch loras, api returns success response, expected valid loras list value`() { + every { + stubApi.fetchLoras(any()) + } returns Single.just(mockStableDiffusionLoraRaw) + + remoteDataSource + .fetchLoras() + .test() + .assertNoErrors() + .assertValue { loras -> + loras is List + && loras.size == mockStableDiffusionLoraRaw.size + } + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch loras, api returns empty response, expected empty loras value`() { + every { + stubApi.fetchLoras(any()) + } returns Single.just(emptyList()) + + remoteDataSource + .fetchLoras() + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch loras, api returns error response, expected error value`() { + every { + stubApi.fetchLoras(any()) + } returns Single.error(stubException) + + remoteDataSource + .fetchLoras() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/remote/StableDiffusionModelsRemoteDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/remote/StableDiffusionModelsRemoteDataSourceTest.kt new file mode 100644 index 00000000..6cbb03c6 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/remote/StableDiffusionModelsRemoteDataSourceTest.kt @@ -0,0 +1,78 @@ +package com.shifthackz.aisdv1.data.remote + +import com.shifthackz.aisdv1.data.mocks.mockStableDiffusionModelRaw +import com.shifthackz.aisdv1.data.provider.ServerUrlProvider +import com.shifthackz.aisdv1.domain.entity.StableDiffusionModel +import com.shifthackz.aisdv1.network.api.automatic1111.Automatic1111RestApi +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Single +import org.junit.Before +import org.junit.Test + +class StableDiffusionModelsRemoteDataSourceTest { + + private val stubException = Throwable("Internal server error.") + private val stubUrlProvider = mockk() + private val stubApi = mockk() + + private val remoteDataSource = StableDiffusionModelsRemoteDataSource( + serverUrlProvider = stubUrlProvider, + api = stubApi, + ) + + @Before + fun initialize() { + every { + stubUrlProvider(any()) + } returns Single.just("http://192.168.0.1:7860") + } + + @Test + fun `given attempt to fetch models, api returns success response, expected valid models list value`() { + every { + stubApi.fetchSdModels(any()) + } returns Single.just(mockStableDiffusionModelRaw) + + remoteDataSource + .fetchSdModels() + .test() + .assertNoErrors() + .assertValue { models -> + models is List + && models.size == mockStableDiffusionModelRaw.size + } + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch models, api returns empty response, expected empty models value`() { + every { + stubApi.fetchSdModels(any()) + } returns Single.just(emptyList()) + + remoteDataSource + .fetchSdModels() + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch models, api returns error response, expected error value`() { + every { + stubApi.fetchSdModels(any()) + } returns Single.error(stubException) + + remoteDataSource + .fetchSdModels() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/remote/StableDiffusionSamplersRemoteDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/remote/StableDiffusionSamplersRemoteDataSourceTest.kt new file mode 100644 index 00000000..f8a757eb --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/remote/StableDiffusionSamplersRemoteDataSourceTest.kt @@ -0,0 +1,78 @@ +package com.shifthackz.aisdv1.data.remote + +import com.shifthackz.aisdv1.data.mocks.mockStableDiffusionSamplerRaw +import com.shifthackz.aisdv1.data.provider.ServerUrlProvider +import com.shifthackz.aisdv1.domain.entity.StableDiffusionSampler +import com.shifthackz.aisdv1.network.api.automatic1111.Automatic1111RestApi +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Single +import org.junit.Before +import org.junit.Test + +class StableDiffusionSamplersRemoteDataSourceTest { + + private val stubException = Throwable("Internal server error.") + private val stubUrlProvider = mockk() + private val stubApi = mockk() + + private val remoteDataSource = StableDiffusionSamplersRemoteDataSource( + stubUrlProvider, + stubApi, + ) + + @Before + fun initialize() { + every { + stubUrlProvider(any()) + } returns Single.just("http://192.168.0.1:7860") + } + + @Test + fun `given attempt to fetch samplers, api returns success response, expected valid samplers list value`() { + every { + stubApi.fetchSamplers(any()) + } returns Single.just(mockStableDiffusionSamplerRaw) + + remoteDataSource + .fetchSamplers() + .test() + .assertNoErrors() + .assertValue { samplers -> + samplers is List + && samplers.size == mockStableDiffusionSamplerRaw.size + } + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch samplers, api returns empty response, expected empty samplers value`() { + every { + stubApi.fetchSamplers(any()) + } returns Single.just(emptyList()) + + remoteDataSource + .fetchSamplers() + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch samplers, api returns error response, expected error value`() { + every { + stubApi.fetchSamplers(any()) + } returns Single.error(stubException) + + remoteDataSource + .fetchSamplers() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImplTest.kt new file mode 100644 index 00000000..2daf1b5f --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImplTest.kt @@ -0,0 +1,414 @@ +package com.shifthackz.aisdv1.data.repository + +import com.shifthackz.aisdv1.data.mocks.mockLocalAiModel +import com.shifthackz.aisdv1.data.mocks.mockLocalAiModels +import com.shifthackz.aisdv1.domain.datasource.DownloadableModelDataSource +import com.shifthackz.aisdv1.domain.entity.DownloadState +import com.shifthackz.aisdv1.domain.entity.LocalAiModel +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.BackpressureStrategy +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Flowable +import io.reactivex.rxjava3.core.Observable +import io.reactivex.rxjava3.core.Single +import io.reactivex.rxjava3.subjects.BehaviorSubject +import org.junit.Before +import org.junit.Test +import java.io.File + +class DownloadableModelRepositoryImplTest { + + private val stubException = Throwable("Something went wrong.") + private val stubFile = mockk() + private val stubLocalModels = BehaviorSubject.create>() + private val stubDownloadState = BehaviorSubject.create() + private val stubRemoteDataSource = mockk() + private val stubLocalDataSource = mockk() + + private val repository = DownloadableModelRepositoryImpl( + remoteDataSource = stubRemoteDataSource, + localDataSource = stubLocalDataSource, + ) + + @Before + fun initialize() { + every { + stubLocalDataSource.observeAll() + } returns stubLocalModels.toFlowable(BackpressureStrategy.LATEST) + + every { + stubRemoteDataSource.download(any(), any()) + } returns stubDownloadState + } + + @Test + fun `given attempt to check if model downloaded, local data source returns true, expected true value`() { + every { + stubLocalDataSource.isDownloaded(any()) + } returns Single.just(true) + + repository + .isModelDownloaded("5598") + .test() + .assertNoErrors() + .assertValue(true) + .await() + .assertComplete() + } + + @Test + fun `given attempt to check if model downloaded, local data source returns false, expected false value`() { + every { + stubLocalDataSource.isDownloaded(any()) + } returns Single.just(false) + + repository + .isModelDownloaded("5598") + .test() + .assertNoErrors() + .assertValue(false) + .await() + .assertComplete() + } + + @Test + fun `given attempt to check if model downloaded, local data source throws exception, expected error value`() { + every { + stubLocalDataSource.isDownloaded(any()) + } returns Single.error(stubException) + + repository + .isModelDownloaded("5598") + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to delete model, local data source completes, expected complete value`() { + every { + stubLocalDataSource.delete(any()) + } returns Completable.complete() + + repository + .delete("5598") + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to delete model, local data source throws exception, expected error value`() { + every { + stubLocalDataSource.delete(any()) + } returns Completable.error(stubException) + + repository + .delete("5598") + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to select model, local data source completes, expected complete value`() { + every { + stubLocalDataSource.select(any()) + } returns Completable.complete() + + repository + .select("5598") + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to select model, local data source throws exception, expected error value`() { + every { + stubLocalDataSource.select(any()) + } returns Completable.error(stubException) + + repository + .select("5598") + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to get all, remote returns list, save success, local query success, expected valid domain model list value`() { + every { + stubRemoteDataSource.fetch() + } returns Single.just(mockLocalAiModels) + + every { + stubLocalDataSource.save(any()) + } returns Completable.complete() + + every { + stubLocalDataSource.getAll() + } returns Single.just(mockLocalAiModels) + + repository + .getAll() + .test() + .assertNoErrors() + .assertValue(mockLocalAiModels) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get all, remote returns list, save fails, local query success, expected valid domain model list value`() { + every { + stubRemoteDataSource.fetch() + } returns Single.just(mockLocalAiModels) + + every { + stubLocalDataSource.save(any()) + } returns Completable.error(stubException) + + every { + stubLocalDataSource.getAll() + } returns Single.just(mockLocalAiModels) + + repository + .getAll() + .test() + .assertNoErrors() + .assertValue(mockLocalAiModels) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get all, remote fails, local query success, expected valid domain model list value`() { + every { + stubRemoteDataSource.fetch() + } returns Single.error(stubException) + + every { + stubLocalDataSource.save(any()) + } returns Completable.complete() + + every { + stubLocalDataSource.getAll() + } returns Single.just(mockLocalAiModels) + + repository + .getAll() + .test() + .assertNoErrors() + .assertValue(mockLocalAiModels) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get all, remote returns list, save success, local query fails, expected error value`() { + every { + stubRemoteDataSource.fetch() + } returns Single.just(mockLocalAiModels) + + every { + stubLocalDataSource.save(any()) + } returns Completable.complete() + + every { + stubLocalDataSource.getAll() + } returns Single.error(stubException) + + repository + .getAll() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to get by id, local data source returns data, expected valid domain model value`() { + every { + stubLocalDataSource.getById(any()) + } returns Single.just(mockLocalAiModel) + + repository + .getById("5598") + .test() + .assertNoErrors() + .assertValue(mockLocalAiModel) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get by id, local data source fails, expected error value`() { + every { + stubLocalDataSource.getById(any()) + } returns Single.error(stubException) + + repository + .getById("5598") + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given observe all models, local data source emits empty list, then another list, expected empty value, then valid domain models list value`() { + val stubObserver = repository.observeAll().test() + + stubLocalModels.onNext(emptyList()) + + stubObserver + .assertNoErrors() + .assertValueAt(0, emptyList()) + + stubLocalModels.onNext(mockLocalAiModels) + + stubObserver + .assertNoErrors() + .assertValueAt(1, mockLocalAiModels) + } + + @Test + fun `given observe all models, local data source emits list, then changed list, expected valid domain models list value, then changed value`() { + val stubObserver = repository.observeAll().test() + + stubLocalModels.onNext(mockLocalAiModels) + + stubObserver + .assertNoErrors() + .assertValueAt(0, mockLocalAiModels) + + stubLocalModels.onNext(mockLocalAiModels.map { it.copy(id = "1") }) + + stubObserver + .assertNoErrors() + .assertValueAt(1, listOf(mockLocalAiModel.copy(id = "1"))) + } + + @Test + fun `given observe all models, local data source throws exception, expected error value`() { + every { + stubLocalDataSource.observeAll() + } returns Flowable.error(stubException) + + repository + .observeAll() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to download model, local data source has no such model, expected error value`() { + every { + stubLocalDataSource.getById(any()) + } returns Single.error(stubException) + + repository + .download("5598") + .test() + .assertNoValues() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to download model, local data source has such model, download succeeds, expected unknown, downloading, complete values`() { + every { + stubLocalDataSource.getById(any()) + } returns Single.just(mockLocalAiModel) + + val stubObserver = repository + .download("5598") + .test() + + stubDownloadState.onNext(DownloadState.Unknown) + + stubObserver + .assertNoErrors() + .assertValueAt(0, DownloadState.Unknown) + + stubDownloadState.onNext(DownloadState.Downloading(44)) + + stubObserver + .assertNoErrors() + .assertValueAt(1, DownloadState.Downloading(44)) + + stubDownloadState.onNext(DownloadState.Downloading(100)) + + stubObserver + .assertNoErrors() + .assertValueAt(2, DownloadState.Downloading(100)) + + stubDownloadState.onNext(DownloadState.Complete(stubFile)) + + stubObserver + .assertNoErrors() + .assertValueAt(3, DownloadState.Complete(stubFile)) + } + + @Test + fun `given attempt to download model, local data source has such model, download fails, expected unknown, downloading, error values`() { + every { + stubLocalDataSource.getById(any()) + } returns Single.just(mockLocalAiModel) + + val stubObserver = repository + .download("5598") + .test() + + stubDownloadState.onNext(DownloadState.Unknown) + + stubObserver + .assertNoErrors() + .assertValueAt(0, DownloadState.Unknown) + + stubDownloadState.onNext(DownloadState.Downloading(44)) + + stubObserver + .assertNoErrors() + .assertValueAt(1, DownloadState.Downloading(44)) + + stubDownloadState.onNext(DownloadState.Error(stubException)) + + stubObserver + .assertNoErrors() + .assertValueAt(2, DownloadState.Error(stubException)) + } + + @Test + fun `given attempt to download model, local data source has such model, remote data source throws exception, expected error value`() { + every { + stubLocalDataSource.getById(any()) + } returns Single.just(mockLocalAiModel) + + every { + stubRemoteDataSource.download(any(), any()) + } returns Observable.error(stubException) + + repository + .download("5598") + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/repository/GenerationResultRepositoryImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/repository/GenerationResultRepositoryImplTest.kt new file mode 100644 index 00000000..6242c59a --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/repository/GenerationResultRepositoryImplTest.kt @@ -0,0 +1,274 @@ +package com.shifthackz.aisdv1.data.repository + +import com.shifthackz.aisdv1.core.imageprocessing.Base64ToBitmapConverter +import com.shifthackz.aisdv1.data.mocks.mockAiGenerationResult +import com.shifthackz.aisdv1.data.mocks.mockAiGenerationResults +import com.shifthackz.aisdv1.domain.datasource.GenerationResultDataSource +import com.shifthackz.aisdv1.domain.entity.MediaStoreInfo +import com.shifthackz.aisdv1.domain.gateway.MediaStoreGateway +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class GenerationResultRepositoryImplTest { + + private val stubException = Throwable("Something went wrong.") + private val stubPreferenceManager = mockk() + private val stubMediaStoreGateway = mockk() + private val stubBase64ToBitmapConverter = mockk() + private val stubLocalDataSource = mockk() + + private val repository = GenerationResultRepositoryImpl( + preferenceManager = stubPreferenceManager, + mediaStoreGateway = stubMediaStoreGateway, + base64ToBitmapConverter = stubBase64ToBitmapConverter, + localDataSource = stubLocalDataSource, + ) + + @Test + fun `given attempt to get all, local returns data, expected valid domain model list value`() { + every { + stubLocalDataSource.queryAll() + } returns Single.just(mockAiGenerationResults) + + repository + .getAll() + .test() + .assertNoErrors() + .assertValue(mockAiGenerationResults) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get all, local returns empty data, expected empty domain model list value`() { + every { + stubLocalDataSource.queryAll() + } returns Single.just(emptyList()) + + repository + .getAll() + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get all, local throws exception, expected error value`() { + every { + stubLocalDataSource.queryAll() + } returns Single.error(stubException) + + repository + .getAll() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to get page, local returns data, expected valid domain model list value`() { + every { + stubLocalDataSource.queryPage(any(), any()) + } returns Single.just(mockAiGenerationResults) + + repository + .getPage(20, 0) + .test() + .assertNoErrors() + .assertValue(mockAiGenerationResults) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get page, local returns empty data, expected empty domain model list value`() { + every { + stubLocalDataSource.queryPage(any(), any()) + } returns Single.just(emptyList()) + + repository + .getPage(20, 0) + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get page, local throws exception, expected error value`() { + every { + stubLocalDataSource.queryPage(any(), any()) + } returns Single.error(stubException) + + repository + .getPage(20, 0) + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to get media store info, gateway returned data, expected valid media store info value`() { + every { + stubMediaStoreGateway.getInfo() + } returns MediaStoreInfo() + + repository + .getMediaStoreInfo() + .test() + .assertNoErrors() + .assertValue(MediaStoreInfo()) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get media store info, gateway throws exception, expected error value`() { + every { + stubMediaStoreGateway.getInfo() + } throws stubException + + repository + .getMediaStoreInfo() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to get by id, local returns data, expected valid domain model value`() { + every { + stubLocalDataSource.queryById(any()) + } returns Single.just(mockAiGenerationResult) + + repository + .getById(5598L) + .test() + .assertNoErrors() + .assertValue(mockAiGenerationResult) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get by id, local throws exception, expected error value`() { + every { + stubLocalDataSource.queryById(any()) + } returns Single.error(stubException) + + repository + .getById(5598L) + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to delete by id, local delete success, expected complete value`() { + every { + stubLocalDataSource.deleteById(any()) + } returns Completable.complete() + + repository + .deleteById(5598L) + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to delete by id, local delete fails, expected error value`() { + every { + stubLocalDataSource.deleteById(any()) + } returns Completable.error(stubException) + + repository + .deleteById(5598L) + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to delete all, local delete success, expected complete value`() { + every { + stubLocalDataSource.deleteAll() + } returns Completable.complete() + + repository + .deleteAll() + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to delete all, local delete fails, expected complete value`() { + every { + stubLocalDataSource.deleteAll() + } returns Completable.error(stubException) + + repository + .deleteAll() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to insert data, local insert success, expected id of inserted model value`() { + every { + stubPreferenceManager.saveToMediaStore + } returns false + + every { + stubLocalDataSource.insert(any()) + } returns Single.just(mockAiGenerationResult.id) + + repository + .insert(mockAiGenerationResult) + .test() + .assertNoErrors() + .assertValue(5598L) + .await() + .assertComplete() + } + + @Test + fun `given attempt to insert data, local insert fails, expected error value`() { + every { + stubPreferenceManager.saveToMediaStore + } returns false + + every { + stubLocalDataSource.insert(any()) + } returns Single.error(stubException) + + repository + .insert(mockAiGenerationResult) + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/repository/HordeGenerationRepositoryImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/repository/HordeGenerationRepositoryImplTest.kt new file mode 100644 index 00000000..31ceb2d1 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/repository/HordeGenerationRepositoryImplTest.kt @@ -0,0 +1,217 @@ +package com.shifthackz.aisdv1.data.repository + +import com.shifthackz.aisdv1.core.imageprocessing.Base64ToBitmapConverter +import com.shifthackz.aisdv1.data.mocks.mockAiGenerationResult +import com.shifthackz.aisdv1.data.mocks.mockImageToImagePayload +import com.shifthackz.aisdv1.data.mocks.mockTextToImagePayload +import com.shifthackz.aisdv1.domain.datasource.GenerationResultDataSource +import com.shifthackz.aisdv1.domain.datasource.HordeGenerationDataSource +import com.shifthackz.aisdv1.domain.entity.HordeProcessStatus +import com.shifthackz.aisdv1.domain.gateway.MediaStoreGateway +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.BackpressureStrategy +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Flowable +import io.reactivex.rxjava3.core.Single +import io.reactivex.rxjava3.subjects.BehaviorSubject +import org.junit.Before +import org.junit.Test + +class HordeGenerationRepositoryImplTest { + + private val stubException = Throwable("Something went wrong.") + private val stubStatus = BehaviorSubject.create() + private val stubMediaStoreGateway = mockk() + private val stubBase64ToBitmapConverter = mockk() + private val stubLocalDataSource = mockk() + private val stubPreferenceManager = mockk() + private val stubRemoteDataSource = mockk() + private val stubStatusSource = mockk() + + private val repository = HordeGenerationRepositoryImpl( + mediaStoreGateway = stubMediaStoreGateway, + base64ToBitmapConverter = stubBase64ToBitmapConverter, + localDataSource = stubLocalDataSource, + preferenceManager = stubPreferenceManager, + remoteDataSource = stubRemoteDataSource, + statusSource = stubStatusSource, + ) + + @Before + fun initialize() { + every { + stubStatusSource.observe() + } returns stubStatus.toFlowable(BackpressureStrategy.LATEST) + + every { + stubPreferenceManager.autoSaveAiResults + } returns false + } + + @Test + fun `given attempt to observe status, status source emits two values, expected valid values in same order`() { + val stubObserver = repository.observeStatus().test() + + stubStatus.onNext(HordeProcessStatus(5598, 1504)) + + stubObserver + .assertNoErrors() + .assertValueAt(0, HordeProcessStatus(5598, 1504)) + + stubStatus.onNext(HordeProcessStatus(0, 0)) + + stubObserver + .assertNoErrors() + .assertValueAt(1, HordeProcessStatus(0, 0)) + } + + @Test + fun `given attempt to observe status, status source throws exception, expected error value`() { + every { + stubStatusSource.observe() + } returns Flowable.error(stubException) + + repository + .observeStatus() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to validate api key, remote returns true, expected true value`() { + every { + stubRemoteDataSource.validateApiKey() + } returns Single.just(true) + + repository + .validateApiKey() + .test() + .assertNoErrors() + .assertValue(true) + .await() + .assertComplete() + } + + @Test + fun `given attempt to validate api key, remote returns false, expected false value`() { + every { + stubRemoteDataSource.validateApiKey() + } returns Single.just(false) + + repository + .validateApiKey() + .test() + .assertNoErrors() + .assertValue(false) + .await() + .assertComplete() + } + + @Test + fun `given attempt to validate api key, remote throws exception, expected error value`() { + every { + stubRemoteDataSource.validateApiKey() + } returns Single.error(stubException) + + repository + .validateApiKey() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to generate from text, remote returns result, expected valid domain model value`() { + every { + stubRemoteDataSource.textToImage(any()) + } returns Single.just(mockAiGenerationResult) + + repository + .generateFromText(mockTextToImagePayload) + .test() + .assertNoErrors() + .assertValue(mockAiGenerationResult) + .await() + .assertComplete() + } + + @Test + fun `given attempt to generate from text, remote throws exception, expected error value`() { + every { + stubRemoteDataSource.textToImage(any()) + } returns Single.error(stubException) + + repository + .generateFromText(mockTextToImagePayload) + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to generate from image, remote returns result, expected valid domain model value`() { + every { + stubRemoteDataSource.imageToImage(any()) + } returns Single.just(mockAiGenerationResult) + + repository + .generateFromImage(mockImageToImagePayload) + .test() + .assertNoErrors() + .assertValue(mockAiGenerationResult) + .await() + .assertComplete() + } + + @Test + fun `given attempt to generate from image, remote throws exception, expected error value`() { + every { + stubRemoteDataSource.imageToImage(any()) + } returns Single.error(stubException) + + repository + .generateFromImage(mockImageToImagePayload) + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to interrupt generation, remote completes, expected complete value`() { + every { + stubRemoteDataSource.interruptGeneration() + } returns Completable.complete() + + repository + .interruptGeneration() + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to interrupt generation, remote throws exception, expected error value`() { + every { + stubRemoteDataSource.interruptGeneration() + } returns Completable.error(stubException) + + repository + .interruptGeneration() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/repository/HuggingFaceGenerationRepositoryImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/repository/HuggingFaceGenerationRepositoryImplTest.kt new file mode 100644 index 00000000..67c5df99 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/repository/HuggingFaceGenerationRepositoryImplTest.kt @@ -0,0 +1,149 @@ +package com.shifthackz.aisdv1.data.repository + +import com.shifthackz.aisdv1.core.imageprocessing.Base64ToBitmapConverter +import com.shifthackz.aisdv1.data.mocks.mockAiGenerationResult +import com.shifthackz.aisdv1.data.mocks.mockImageToImagePayload +import com.shifthackz.aisdv1.data.mocks.mockTextToImagePayload +import com.shifthackz.aisdv1.domain.datasource.GenerationResultDataSource +import com.shifthackz.aisdv1.domain.datasource.HuggingFaceGenerationDataSource +import com.shifthackz.aisdv1.domain.gateway.MediaStoreGateway +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Single +import org.junit.Before +import org.junit.Test + +class HuggingFaceGenerationRepositoryImplTest { + + private val stubException = Throwable("Something went wrong.") + private val stubMediaStoreGateway = mockk() + private val stubBase64ToBitmapConverter = mockk() + private val stubLocalDataSource = mockk() + private val stubPreferenceManager = mockk() + private val stubRemoteDataSource = mockk() + + private val repository = HuggingFaceGenerationRepositoryImpl( + mediaStoreGateway = stubMediaStoreGateway, + base64ToBitmapConverter = stubBase64ToBitmapConverter, + localDataSource = stubLocalDataSource, + preferenceManager = stubPreferenceManager, + remoteDataSource = stubRemoteDataSource, + ) + + @Before + fun initialize() { + every { + stubPreferenceManager.autoSaveAiResults + } returns false + + every { + stubPreferenceManager.huggingFaceModel + } returns "hf_5598" + } + + @Test + fun `given attempt to validate api key, remote returns true, expected true value`() { + every { + stubRemoteDataSource.validateApiKey() + } returns Single.just(true) + + repository + .validateApiKey() + .test() + .assertNoErrors() + .assertValue(true) + .await() + .assertComplete() + } + + @Test + fun `given attempt to validate api key, remote returns false, expected false value`() { + every { + stubRemoteDataSource.validateApiKey() + } returns Single.just(false) + + repository + .validateApiKey() + .test() + .assertNoErrors() + .assertValue(false) + .await() + .assertComplete() + } + + @Test + fun `given attempt to validate api key, remote throws exception, expected error value`() { + every { + stubRemoteDataSource.validateApiKey() + } returns Single.error(stubException) + + repository + .validateApiKey() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to generate from text, remote returns result, expected valid domain model value`() { + every { + stubRemoteDataSource.textToImage(any(), any()) + } returns Single.just(mockAiGenerationResult) + + repository + .generateFromText(mockTextToImagePayload) + .test() + .assertNoErrors() + .assertValue(mockAiGenerationResult) + .await() + .assertComplete() + } + + @Test + fun `given attempt to generate from text, remote throws exception, expected error value`() { + every { + stubRemoteDataSource.textToImage(any(), any()) + } returns Single.error(stubException) + + repository + .generateFromText(mockTextToImagePayload) + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to generate from image, remote returns result, expected valid domain model value`() { + every { + stubRemoteDataSource.imageToImage(any(), any()) + } returns Single.just(mockAiGenerationResult) + + repository + .generateFromImage(mockImageToImagePayload) + .test() + .assertNoErrors() + .assertValue(mockAiGenerationResult) + .await() + .assertComplete() + } + + @Test + fun `given attempt to generate from image, remote throws exception, expected error value`() { + every { + stubRemoteDataSource.imageToImage(any(), any()) + } returns Single.error(stubException) + + repository + .generateFromImage(mockImageToImagePayload) + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/repository/HuggingFaceModelsRepositoryImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/repository/HuggingFaceModelsRepositoryImplTest.kt new file mode 100644 index 00000000..e6cb8716 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/repository/HuggingFaceModelsRepositoryImplTest.kt @@ -0,0 +1,185 @@ +package com.shifthackz.aisdv1.data.repository + +import com.shifthackz.aisdv1.data.mocks.mockHuggingFaceModels +import com.shifthackz.aisdv1.domain.datasource.HuggingFaceModelsDataSource +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class HuggingFaceModelsRepositoryImplTest { + + private val stubException = Throwable("Something went wrong.") + private val stubRemoteDataSource = mockk() + private val stubLocalDataSource = mockk() + + private val repository = HuggingFaceModelsRepositoryImpl( + remoteDataSource = stubRemoteDataSource, + localDataSource = stubLocalDataSource, + ) + + @Test + fun `given attempt to fetch models, remote returns data, local insert success, expected complete value`() { + every { + stubRemoteDataSource.fetchHuggingFaceModels() + } returns Single.just(mockHuggingFaceModels) + + every { + stubLocalDataSource.save(any()) + } returns Completable.complete() + + repository + .fetchHuggingFaceModels() + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch models, remote throws exception, local insert success, expected error value`() { + every { + stubRemoteDataSource.fetchHuggingFaceModels() + } returns Single.error(stubException) + + every { + stubLocalDataSource.save(any()) + } returns Completable.complete() + + repository + .fetchHuggingFaceModels() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to fetch models, remote returns data, local insert fails, expected error value`() { + every { + stubRemoteDataSource.fetchHuggingFaceModels() + } returns Single.just(mockHuggingFaceModels) + + every { + stubLocalDataSource.save(any()) + } returns Completable.error(stubException) + + repository + .fetchHuggingFaceModels() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to get models, local data source returns list, expected valid domain models list value`() { + every { + stubLocalDataSource.getAll() + } returns Single.just(mockHuggingFaceModels) + + repository + .getHuggingFaceModels() + .test() + .assertNoErrors() + .assertValue(mockHuggingFaceModels) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get models, local data source returns empty list, expected empty domain models list value`() { + every { + stubLocalDataSource.getAll() + } returns Single.just(emptyList()) + + repository + .getHuggingFaceModels() + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get models, local data source throws exception, expected error value`() { + every { + stubLocalDataSource.getAll() + } returns Single.error(stubException) + + repository + .getHuggingFaceModels() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to fetch and get models, remote returns data, local returns data, expected valid domain models list value`() { + every { + stubRemoteDataSource.fetchHuggingFaceModels() + } returns Single.just(mockHuggingFaceModels) + + every { + stubLocalDataSource.save(any()) + } returns Completable.complete() + + every { + stubLocalDataSource.getAll() + } returns Single.just(mockHuggingFaceModels) + + repository + .fetchAndGetHuggingFaceModels() + .test() + .assertNoErrors() + .assertValue(mockHuggingFaceModels) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch and get models, remote fails, local returns data, expected valid domain models list value`() { + every { + stubRemoteDataSource.fetchHuggingFaceModels() + } returns Single.error(stubException) + + every { + stubLocalDataSource.save(any()) + } returns Completable.complete() + + every { + stubLocalDataSource.getAll() + } returns Single.just(mockHuggingFaceModels) + + repository + .fetchAndGetHuggingFaceModels() + .test() + .assertNoErrors() + .assertValue(mockHuggingFaceModels) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch and get models, remote fails, local fails, expected valid error value`() { + every { + stubRemoteDataSource.fetchHuggingFaceModels() + } returns Single.error(stubException) + + every { + stubLocalDataSource.getAll() + } returns Single.error(stubException) + + repository + .fetchAndGetHuggingFaceModels() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImplTest.kt new file mode 100644 index 00000000..2b776d5d --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImplTest.kt @@ -0,0 +1,215 @@ +package com.shifthackz.aisdv1.data.repository + +import android.graphics.Bitmap +import com.shifthackz.aisdv1.core.common.schedulers.SchedulersProvider +import com.shifthackz.aisdv1.core.imageprocessing.Base64ToBitmapConverter +import com.shifthackz.aisdv1.core.imageprocessing.BitmapToBase64Converter +import com.shifthackz.aisdv1.data.mocks.mockLocalAiModel +import com.shifthackz.aisdv1.data.mocks.mockTextToImagePayload +import com.shifthackz.aisdv1.domain.datasource.DownloadableModelDataSource +import com.shifthackz.aisdv1.domain.datasource.GenerationResultDataSource +import com.shifthackz.aisdv1.domain.entity.AiGenerationResult +import com.shifthackz.aisdv1.domain.feature.diffusion.LocalDiffusion +import com.shifthackz.aisdv1.domain.gateway.MediaStoreGateway +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Observable +import io.reactivex.rxjava3.core.Scheduler +import io.reactivex.rxjava3.core.Single +import io.reactivex.rxjava3.schedulers.Schedulers +import io.reactivex.rxjava3.subjects.BehaviorSubject +import org.junit.Before +import org.junit.Test +import java.util.concurrent.Executor +import java.util.concurrent.Executors + +class LocalDiffusionGenerationRepositoryImplTest { + + private val stubBitmap = mockk() + private val stubException = Throwable("Something went wrong.") + private val stubStatus = BehaviorSubject.create() + private val stubMediaStoreGateway = mockk() + private val stubBase64ToBitmapConverter = mockk() + private val stubBitmapToBase64Converter = mockk() + private val stubLocalDataSource = mockk() + private val stubPreferenceManager = mockk() + private val stubLocalDiffusion = mockk() + private val stubDownloadableLocalDataSource = mockk() + + private val stubSchedulersProvider = object : SchedulersProvider { + override val io: Scheduler = Schedulers.trampoline() + override val ui: Scheduler = Schedulers.trampoline() + override val computation: Scheduler = Schedulers.trampoline() + override val singleThread: Executor = Executors.newSingleThreadExecutor() + } + + private val repository = LocalDiffusionGenerationRepositoryImpl( + mediaStoreGateway = stubMediaStoreGateway, + base64ToBitmapConverter = stubBase64ToBitmapConverter, + localDataSource = stubLocalDataSource, + preferenceManager = stubPreferenceManager, + localDiffusion = stubLocalDiffusion, + downloadableLocalDataSource = stubDownloadableLocalDataSource, + bitmapToBase64Converter = stubBitmapToBase64Converter, + schedulersProvider = stubSchedulersProvider, + ) + + @Before + fun initialize() { + every { + stubLocalDiffusion.observeStatus() + } returns stubStatus + + every { + stubPreferenceManager.autoSaveAiResults + } returns false + } + + @Test + fun `given attempt to observe status, local emits two values, expected same values with same order`() { + val stubObserver = repository.observeStatus().test() + + stubStatus.onNext(LocalDiffusion.Status(1, 2)) + + stubObserver + .assertNoErrors() + .assertValueAt(0, LocalDiffusion.Status(1, 2)) + + stubStatus.onNext(LocalDiffusion.Status(2, 2)) + + stubObserver + .assertNoErrors() + .assertValueAt(1, LocalDiffusion.Status(2, 2)) + } + + @Test + fun `given attempt to observe status, local throws exception, expected error value`() { + every { + stubLocalDiffusion.observeStatus() + } returns Observable.error(stubException) + + repository + .observeStatus() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to interrupt generation, remote completes, expected complete value`() { + every { + stubLocalDiffusion.interrupt() + } returns Completable.complete() + + repository + .interruptGeneration() + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to interrupt generation, remote throws exception, expected error value`() { + every { + stubLocalDiffusion.interrupt() + } returns Completable.error(stubException) + + repository + .interruptGeneration() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to generate from text, no selected model, expected error value`() { + every { + stubDownloadableLocalDataSource.getSelected() + } returns Single.error(stubException) + + repository + .generateFromText(mockTextToImagePayload) + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to generate from text, has selected not downloaded model, expected IllegalStateException error value`() { + every { + stubDownloadableLocalDataSource.getSelected() + } returns Single.just(mockLocalAiModel.copy(downloaded = false)) + + every { + stubLocalDiffusion.process(any()) + } returns Single.just(stubBitmap) + + every { + stubBitmapToBase64Converter(any()) + } returns Single.just(BitmapToBase64Converter.Output("base64")) + + repository + .generateFromText(mockTextToImagePayload) + .test() + .assertError { t -> + t is IllegalStateException && t.message == "Model not downloaded." + } + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to generate from text, has selected downloaded model, local process success, expected valid domain model value`() { + every { + stubDownloadableLocalDataSource.getSelected() + } returns Single.just(mockLocalAiModel.copy(downloaded = true)) + + every { + stubLocalDiffusion.process(any()) + } returns Single.just(stubBitmap) + + every { + stubBitmapToBase64Converter(any()) + } returns Single.just(BitmapToBase64Converter.Output("base64")) + + repository + .generateFromText(mockTextToImagePayload) + .test() + .assertNoErrors() + .assertValue { it is AiGenerationResult } + .await() + .assertComplete() + } + + @Test + fun `given attempt to generate from text, has selected downloaded model, local process fails, expected error value`() { + every { + stubDownloadableLocalDataSource.getSelected() + } returns Single.just(mockLocalAiModel.copy(downloaded = true)) + + every { + stubLocalDiffusion.process(any()) + } returns Single.error(stubException) + + every { + stubBitmapToBase64Converter(any()) + } returns Single.just(BitmapToBase64Converter.Output("base64")) + + repository + .generateFromText(mockTextToImagePayload) + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } +} \ No newline at end of file diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/repository/OpenAiGenerationRepositoryImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/repository/OpenAiGenerationRepositoryImplTest.kt new file mode 100644 index 00000000..cd7133df --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/repository/OpenAiGenerationRepositoryImplTest.kt @@ -0,0 +1,114 @@ +package com.shifthackz.aisdv1.data.repository + +import com.shifthackz.aisdv1.core.imageprocessing.Base64ToBitmapConverter +import com.shifthackz.aisdv1.data.mocks.mockAiGenerationResult +import com.shifthackz.aisdv1.data.mocks.mockTextToImagePayload +import com.shifthackz.aisdv1.domain.datasource.GenerationResultDataSource +import com.shifthackz.aisdv1.domain.datasource.OpenAiGenerationDataSource +import com.shifthackz.aisdv1.domain.gateway.MediaStoreGateway +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Single +import org.junit.Before +import org.junit.Test + +class OpenAiGenerationRepositoryImplTest { + + private val stubException = Throwable("Something went wrong.") + private val stubMediaStoreGateway = mockk() + private val stubBase64ToBitmapConverter = mockk() + private val stubLocalDataSource = mockk() + private val stubPreferenceManager = mockk() + private val stubRemoteDataSource = mockk() + + private val repository = OpenAiGenerationRepositoryImpl( + mediaStoreGateway = stubMediaStoreGateway, + base64ToBitmapConverter = stubBase64ToBitmapConverter, + localDataSource = stubLocalDataSource, + preferenceManager = stubPreferenceManager, + remoteDataSource = stubRemoteDataSource, + ) + + @Before + fun initialize() { + every { + stubPreferenceManager.autoSaveAiResults + } returns false + } + + @Test + fun `given attempt to validate api key, remote returns true, expected true value`() { + every { + stubRemoteDataSource.validateApiKey() + } returns Single.just(true) + + repository + .validateApiKey() + .test() + .assertNoErrors() + .assertValue(true) + .await() + .assertComplete() + } + + @Test + fun `given attempt to validate api key, remote returns false, expected false value`() { + every { + stubRemoteDataSource.validateApiKey() + } returns Single.just(false) + + repository + .validateApiKey() + .test() + .assertNoErrors() + .assertValue(false) + .await() + .assertComplete() + } + + @Test + fun `given attempt to validate api key, remote throws exception, expected error value`() { + every { + stubRemoteDataSource.validateApiKey() + } returns Single.error(stubException) + + repository + .validateApiKey() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to generate from text, remote returns result, expected valid domain model value`() { + every { + stubRemoteDataSource.textToImage(any()) + } returns Single.just(mockAiGenerationResult) + + repository + .generateFromText(mockTextToImagePayload) + .test() + .assertNoErrors() + .assertValue(mockAiGenerationResult) + .await() + .assertComplete() + } + + @Test + fun `given attempt to generate from text, remote throws exception, expected error value`() { + every { + stubRemoteDataSource.textToImage(any()) + } returns Single.error(stubException) + + repository + .generateFromText(mockTextToImagePayload) + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/repository/RandomImageRepositoryImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/repository/RandomImageRepositoryImplTest.kt new file mode 100644 index 00000000..4bbef387 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/repository/RandomImageRepositoryImplTest.kt @@ -0,0 +1,47 @@ +package com.shifthackz.aisdv1.data.repository + +import android.graphics.Bitmap +import com.shifthackz.aisdv1.domain.datasource.RandomImageDataSource +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class RandomImageRepositoryImplTest { + + private val stubBitmap = mockk() + private val stubException = Throwable("Something went wrong.") + private val stubRemoteDataSource = mockk() + + private val repository = RandomImageRepositoryImpl(stubRemoteDataSource) + + @Test + fun `given attempt to fetch and get bitmap, remote returns image, expected valid bitmap value`() { + every { + stubRemoteDataSource.fetch() + } returns Single.just(stubBitmap) + + repository + .fetchAndGet() + .test() + .assertNoErrors() + .assertValue(stubBitmap) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch and get bitmap, remote throws exception, expected valid bitmap value`() { + every { + stubRemoteDataSource.fetch() + } returns Single.error(stubException) + + repository + .fetchAndGet() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/repository/ServerConfigurationRepositoryImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/repository/ServerConfigurationRepositoryImplTest.kt new file mode 100644 index 00000000..4667890c --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/repository/ServerConfigurationRepositoryImplTest.kt @@ -0,0 +1,202 @@ +package com.shifthackz.aisdv1.data.repository + +import com.shifthackz.aisdv1.data.mocks.mockServerConfiguration +import com.shifthackz.aisdv1.domain.datasource.ServerConfigurationDataSource +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class ServerConfigurationRepositoryImplTest { + + private val stubException = Throwable("Something went wrong.") + private val stubRemoteDataSource = mockk() + private val stubLocalDataSource = mockk() + + private val repository = ServerConfigurationRepositoryImpl( + remoteDataSource = stubRemoteDataSource, + localDataSource = stubLocalDataSource, + ) + + @Test + fun `given attempt to update configuration, remote completes, expected complete value`() { + every { + stubRemoteDataSource.updateConfiguration(any()) + } returns Completable.complete() + + repository + .updateConfiguration(mockServerConfiguration) + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to update configuration, remote throws exception, expected error value`() { + every { + stubRemoteDataSource.updateConfiguration(any()) + } returns Completable.error(stubException) + + repository + .updateConfiguration(mockServerConfiguration) + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to get configuration, local returns data, expected valid domain model value`() { + every { + stubLocalDataSource.get() + } returns Single.just(mockServerConfiguration) + + repository + .getConfiguration() + .test() + .assertNoErrors() + .assertValue(mockServerConfiguration) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get configuration, local throws exception, expected error value`() { + every { + stubLocalDataSource.get() + } returns Single.error(stubException) + + repository + .getConfiguration() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `attempt to fetch configuration, remote returns data, local save success, expected complete value`() { + every { + stubRemoteDataSource.fetchConfiguration() + } returns Single.just(mockServerConfiguration) + + every { + stubLocalDataSource.save(any()) + } returns Completable.complete() + + repository + .fetchConfiguration() + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `attempt to fetch configuration, remote throws exception, expected error value`() { + every { + stubRemoteDataSource.fetchConfiguration() + } returns Single.error(stubException) + + every { + stubLocalDataSource.save(any()) + } returns Completable.complete() + + repository + .fetchConfiguration() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `attempt to fetch configuration, remote returns data, local save fails, expected error value`() { + every { + stubRemoteDataSource.fetchConfiguration() + } returns Single.just(mockServerConfiguration) + + every { + stubLocalDataSource.save(any()) + } returns Completable.error(stubException) + + repository + .fetchConfiguration() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to fetch and get, fetch success, get success, expected valid domain model value`() { + every { + stubRemoteDataSource.fetchConfiguration() + } returns Single.just(mockServerConfiguration) + + every { + stubLocalDataSource.save(any()) + } returns Completable.complete() + + every { + stubLocalDataSource.get() + } returns Single.just(mockServerConfiguration) + + repository + .fetchAndGetConfiguration() + .test() + .assertNoErrors() + .assertValue(mockServerConfiguration) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch and get, fetch fails, get success, expected valid domain model value`() { + every { + stubRemoteDataSource.fetchConfiguration() + } returns Single.error(stubException) + + every { + stubLocalDataSource.save(any()) + } returns Completable.complete() + + every { + stubLocalDataSource.get() + } returns Single.just(mockServerConfiguration) + + repository + .fetchAndGetConfiguration() + .test() + .assertNoErrors() + .assertValue(mockServerConfiguration) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch and get, fetch fails, get fails, expected error value`() { + every { + stubRemoteDataSource.fetchConfiguration() + } returns Single.error(stubException) + + every { + stubLocalDataSource.save(any()) + } returns Completable.complete() + + every { + stubLocalDataSource.get() + } returns Single.error(stubException) + + repository + .fetchAndGetConfiguration() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/repository/StabilityAiCreditsRepositoryImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/repository/StabilityAiCreditsRepositoryImplTest.kt new file mode 100644 index 00000000..3cc310df --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/repository/StabilityAiCreditsRepositoryImplTest.kt @@ -0,0 +1,427 @@ +package com.shifthackz.aisdv1.data.repository + +import com.shifthackz.aisdv1.domain.datasource.StabilityAiCreditsDataSource +import com.shifthackz.aisdv1.domain.entity.ServerSource +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.BackpressureStrategy +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Flowable +import io.reactivex.rxjava3.core.Single +import io.reactivex.rxjava3.observers.TestObserver +import io.reactivex.rxjava3.subjects.BehaviorSubject +import io.reactivex.rxjava3.subscribers.TestSubscriber +import org.junit.Before +import org.junit.Test + +class StabilityAiCreditsRepositoryImplTest { + + private val stubException = Throwable("Something went wrong.") + private val stubCredits = BehaviorSubject.create() + private val stubRemoteDataSource = mockk() + private val stubLocalDataSource = mockk() + private val stubPreferenceManager = mockk() + + private val repository = StabilityAiCreditsRepositoryImpl( + remoteDataSource = stubRemoteDataSource, + localDataSource = stubLocalDataSource, + preferenceManager = stubPreferenceManager, + ) + + @Before + fun initialize() { + every { + stubLocalDataSource.observe() + } returns stubCredits.toFlowable(BackpressureStrategy.LATEST) + } + + @Test + fun `given server source is not STABILITY_AI, attempt to fetch, expected IllegalStateException error value`() { + every { + stubPreferenceManager.source + } returns ServerSource.LOCAL + + every { + stubRemoteDataSource.fetch() + } returns Single.just(5598f) + + every { + stubLocalDataSource.save(any()) + } returns Completable.complete() + + repository + .fetch() + .test() + .assertWrongServerSourceSelected() + .await() + .assertNotComplete() + } + + @Test + fun `given server source is not STABILITY_AI, attempt to fetch and get, expected IllegalStateException error value`() { + every { + stubPreferenceManager.source + } returns ServerSource.LOCAL + + every { + stubRemoteDataSource.fetch() + } returns Single.just(5598f) + + every { + stubLocalDataSource.get() + } returns Single.just(5598f) + + every { + stubLocalDataSource.save(any()) + } returns Completable.complete() + + repository + .fetchAndGet() + .test() + .assertWrongServerSourceSelected() + .await() + .assertNotComplete() + } + + @Test + fun `given server source is not STABILITY_AI, attempt to fetch and observe, expected IllegalStateException error value`() { + every { + stubPreferenceManager.source + } returns ServerSource.LOCAL + + every { + stubRemoteDataSource.fetch() + } returns Single.just(5598f) + + every { + stubLocalDataSource.save(any()) + } returns Completable.complete() + + repository + .fetchAndObserve() + .test() + .assertWrongServerSourceSelected() + .await() + .assertNotComplete() + } + + @Test + fun `given server source is not STABILITY_AI, attempt to get, expected IllegalStateException error value`() { + every { + stubPreferenceManager.source + } returns ServerSource.LOCAL + + every { + stubLocalDataSource.get() + } returns Single.just(5598f) + + repository + .get() + .test() + .assertWrongServerSourceSelected() + .await() + .assertNotComplete() + } + + @Test + fun `given server source is not STABILITY_AI, attempt to observe, expected IllegalStateException error value`() { + every { + stubPreferenceManager.source + } returns ServerSource.LOCAL + + repository + .observe() + .test() + .assertWrongServerSourceSelected() + .await() + .assertNotComplete() + } + + @Test + fun `given server source is STABILITY_AI, attempt to fetch, remote returns data, local save success, expected complete value`() { + every { + stubPreferenceManager.source + } returns ServerSource.STABILITY_AI + + every { + stubRemoteDataSource.fetch() + } returns Single.just(5598f) + + every { + stubLocalDataSource.save(any()) + } returns Completable.complete() + + repository + .fetch() + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given server source is STABILITY_AI, attempt to fetch, remote returns error, local save success, expected error value`() { + every { + stubPreferenceManager.source + } returns ServerSource.STABILITY_AI + + every { + stubRemoteDataSource.fetch() + } returns Single.error(stubException) + + every { + stubLocalDataSource.save(any()) + } returns Completable.complete() + + repository + .fetch() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given server source is STABILITY_AI, attempt to fetch, remote returns data, local save fails, expected error value`() { + every { + stubPreferenceManager.source + } returns ServerSource.STABILITY_AI + + every { + stubRemoteDataSource.fetch() + } returns Single.just(5598f) + + every { + stubLocalDataSource.save(any()) + } returns Completable.error(stubException) + + repository + .fetch() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given server source is STABILITY_AI, attempt to fetch and get, fetch success, get success, expected valid credits value`() { + every { + stubPreferenceManager.source + } returns ServerSource.STABILITY_AI + + every { + stubRemoteDataSource.fetch() + } returns Single.just(5598f) + + every { + stubLocalDataSource.save(any()) + } returns Completable.complete() + + every { + stubLocalDataSource.get() + } returns Single.just(5598f) + + repository + .fetchAndGet() + .test() + .assertNoErrors() + .assertValue(5598f) + .await() + .assertComplete() + } + + @Test + fun `given server source is STABILITY_AI, attempt to fetch and get, fetch fails, get success, expected valid credits value`() { + every { + stubPreferenceManager.source + } returns ServerSource.STABILITY_AI + + every { + stubRemoteDataSource.fetch() + } returns Single.error(stubException) + + every { + stubLocalDataSource.save(any()) + } returns Completable.complete() + + every { + stubLocalDataSource.get() + } returns Single.just(5598f) + + repository + .fetchAndGet() + .test() + .assertNoErrors() + .assertValue(5598f) + .await() + .assertComplete() + } + + @Test + fun `given server source is STABILITY_AI, attempt to fetch and get, fetch fails, get fails, expected error value`() { + every { + stubPreferenceManager.source + } returns ServerSource.STABILITY_AI + + every { + stubRemoteDataSource.fetch() + } returns Single.error(stubException) + + every { + stubLocalDataSource.save(any()) + } returns Completable.complete() + + every { + stubLocalDataSource.get() + } returns Single.error(stubException) + + repository + .fetchAndGet() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given server source is STABILITY_AI, attempt to fetch and observe, fetch success, expected valid credits value`() { + every { + stubPreferenceManager.source + } returns ServerSource.STABILITY_AI + + every { + stubRemoteDataSource.fetch() + } returns Single.just(5598f) + + every { + stubLocalDataSource.save(any()) + } returns Completable.complete() + + val stubObserver = repository + .fetchAndObserve() + .test() + + stubCredits.onNext(5598f) + + stubObserver + .assertNoErrors() + .assertValueAt(0, 5598f) + } + + @Test + fun `given server source is STABILITY_AI, attempt to fetch and observe, fetch fails, expected valid credits value`() { + every { + stubPreferenceManager.source + } returns ServerSource.STABILITY_AI + + every { + stubRemoteDataSource.fetch() + } returns Single.error(stubException) + + every { + stubLocalDataSource.save(any()) + } returns Completable.complete() + + val stubObserver = repository + .fetchAndObserve() + .test() + + stubCredits.onNext(0f) + + stubObserver + .assertNoErrors() + .assertValueAt(0, 0f) + } + + @Test + fun `given server source is STABILITY_AI, attempt to get, local returns data, expected valid credits value`() { + every { + stubPreferenceManager.source + } returns ServerSource.STABILITY_AI + + every { + stubLocalDataSource.get() + } returns Single.just(5598f) + + repository + .get() + .test() + .assertNoErrors() + .assertValue(5598f) + .await() + .assertComplete() + } + + @Test + fun `given server source is STABILITY_AI, attempt to get, local throws exception, expected error value`() { + every { + stubPreferenceManager.source + } returns ServerSource.STABILITY_AI + + every { + stubLocalDataSource.get() + } returns Single.error(stubException) + + repository + .get() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given server source is STABILITY_AI, attempt to observe, local emits two values, expected valid credits values in same order`() { + every { + stubPreferenceManager.source + } returns ServerSource.STABILITY_AI + + val stubObserver = repository + .observe() + .test() + + stubCredits.onNext(0f) + + stubObserver + .assertNoErrors() + .assertValueAt(0, 0f) + + stubCredits.onNext(5598f) + + stubObserver + .assertNoErrors() + .assertValueAt(1, 5598f) + } + + @Test + fun `given server source is STABILITY_AI, attempt to observe, local throws exception, expected error value`() { + every { + stubPreferenceManager.source + } returns ServerSource.STABILITY_AI + + every { + stubLocalDataSource.observe() + } returns Flowable.error(stubException) + + repository + .observe() + .test() + .assertNoValues() + .assertError(stubException) + .await() + .assertNotComplete() + } + + private fun TestObserver.assertWrongServerSourceSelected() = this + .assertError { t -> t.wrongSourceErrorPredicate() } + + private fun TestSubscriber.assertWrongServerSourceSelected() = this + .assertError { t -> t.wrongSourceErrorPredicate() } + + private fun Throwable.wrongSourceErrorPredicate(): Boolean { + return this is IllegalStateException && message == "Wrong server source selected." + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/repository/StabilityAiEnginesRepositoryImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/repository/StabilityAiEnginesRepositoryImplTest.kt new file mode 100644 index 00000000..24a1b319 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/repository/StabilityAiEnginesRepositoryImplTest.kt @@ -0,0 +1,61 @@ +package com.shifthackz.aisdv1.data.repository + +import com.shifthackz.aisdv1.data.mocks.mockStabilityAiEngines +import com.shifthackz.aisdv1.domain.datasource.StabilityAiEnginesDataSource +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class StabilityAiEnginesRepositoryImplTest { + + private val stubException = Throwable("Something went wrong.") + private val stubRemoteDataSource = mockk() + + private val repository = StabilityAiEnginesRepositoryImpl(stubRemoteDataSource) + + @Test + fun `given attempt to fetch and get engines, remote returns data, expected valid domain model list value`() { + every { + stubRemoteDataSource.fetch() + } returns Single.just(mockStabilityAiEngines) + + repository + .fetchAndGet() + .test() + .assertNoErrors() + .assertValue(mockStabilityAiEngines) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch and get engines, remote returns empty data, expected empty domain model list value`() { + every { + stubRemoteDataSource.fetch() + } returns Single.just(emptyList()) + + repository + .fetchAndGet() + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch and get engines, remote throws exception, expected error value`() { + every { + stubRemoteDataSource.fetch() + } returns Single.error(stubException) + + repository + .fetchAndGet() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/repository/StabilityAiGenerationRepositoryImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/repository/StabilityAiGenerationRepositoryImplTest.kt new file mode 100644 index 00000000..f9f4afad --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/repository/StabilityAiGenerationRepositoryImplTest.kt @@ -0,0 +1,177 @@ +package com.shifthackz.aisdv1.data.repository + +import android.graphics.Bitmap +import com.shifthackz.aisdv1.core.imageprocessing.Base64ToBitmapConverter +import com.shifthackz.aisdv1.data.mocks.mockAiGenerationResult +import com.shifthackz.aisdv1.data.mocks.mockImageToImagePayload +import com.shifthackz.aisdv1.data.mocks.mockTextToImagePayload +import com.shifthackz.aisdv1.domain.datasource.GenerationResultDataSource +import com.shifthackz.aisdv1.domain.datasource.StabilityAiCreditsDataSource +import com.shifthackz.aisdv1.domain.datasource.StabilityAiGenerationDataSource +import com.shifthackz.aisdv1.domain.gateway.MediaStoreGateway +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single +import org.junit.Before +import org.junit.Test + +class StabilityAiGenerationRepositoryImplTest { + + private val stubBitmap = mockk() + private val stubException = Throwable("Something went wrong.") + private val stubMediaStoreGateway = mockk() + private val stubBase64ToBitmapConverter = mockk() + private val stubLocalDataSource = mockk() + private val stubPreferenceManager = mockk() + private val stubRemoteDataSource = mockk() + private val stubCreditsRds = mockk() + private val stubCreditsLds = mockk() + + private val repository = StabilityAiGenerationRepositoryImpl( + mediaStoreGateway = stubMediaStoreGateway, + base64ToBitmapConverter = stubBase64ToBitmapConverter, + localDataSource = stubLocalDataSource, + preferenceManager = stubPreferenceManager, + generationRds = stubRemoteDataSource, + creditsRds = stubCreditsRds, + creditsLds = stubCreditsLds, + ) + + @Before + fun initialize() { + every { + stubPreferenceManager.autoSaveAiResults + } returns false + + every { + stubPreferenceManager.stabilityAiEngineId + } returns "engine_5598" + + every { + stubCreditsRds.fetch() + } returns Single.just(5598f) + + every { + stubCreditsLds.save(any()) + } returns Completable.complete() + + every { + stubBitmap.compress(any(), any(), any()) + } returns true + } + + @Test + fun `given attempt to validate api key, remote returns true, expected true value`() { + every { + stubRemoteDataSource.validateApiKey() + } returns Single.just(true) + + repository + .validateApiKey() + .test() + .assertNoErrors() + .assertValue(true) + .await() + .assertComplete() + } + + @Test + fun `given attempt to validate api key, remote returns false, expected false value`() { + every { + stubRemoteDataSource.validateApiKey() + } returns Single.just(false) + + repository + .validateApiKey() + .test() + .assertNoErrors() + .assertValue(false) + .await() + .assertComplete() + } + + @Test + fun `given attempt to validate api key, remote throws exception, expected error value`() { + every { + stubRemoteDataSource.validateApiKey() + } returns Single.error(stubException) + + repository + .validateApiKey() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to generate from text, remote returns result, expected valid domain model value`() { + every { + stubRemoteDataSource.textToImage(any(), any()) + } returns Single.just(mockAiGenerationResult) + + repository + .generateFromText(mockTextToImagePayload) + .test() + .assertNoErrors() + .assertValue(mockAiGenerationResult) + .await() + .assertComplete() + } + + @Test + fun `given attempt to generate from text, remote throws exception, expected error value`() { + every { + stubRemoteDataSource.textToImage(any(), any()) + } returns Single.error(stubException) + + repository + .generateFromText(mockTextToImagePayload) + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to generate from image, remote returns result, expected valid domain model value`() { + every { + stubRemoteDataSource.imageToImage(any(), any(), any()) + } returns Single.just(mockAiGenerationResult) + + every { + stubBase64ToBitmapConverter(any()) + } returns Single.just(Base64ToBitmapConverter.Output(stubBitmap)) + + repository + .generateFromImage(mockImageToImagePayload) + .test() + .assertNoErrors() + .assertValue(mockAiGenerationResult) + .await() + .assertComplete() + } + + @Test + fun `given attempt to generate from image, remote throws exception, expected error value`() { + every { + stubRemoteDataSource.imageToImage(any(), any(), any()) + } returns Single.error(stubException) + + every { + stubBase64ToBitmapConverter(any()) + } returns Single.just(Base64ToBitmapConverter.Output(stubBitmap)) + + repository + .generateFromImage(mockImageToImagePayload) + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/repository/StableDiffusionEmbeddingsRepositoryImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/repository/StableDiffusionEmbeddingsRepositoryImplTest.kt new file mode 100644 index 00000000..933c19c3 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/repository/StableDiffusionEmbeddingsRepositoryImplTest.kt @@ -0,0 +1,185 @@ +package com.shifthackz.aisdv1.data.repository + +import com.shifthackz.aisdv1.data.mocks.mockStableDiffusionEmbeddings +import com.shifthackz.aisdv1.domain.datasource.StableDiffusionEmbeddingsDataSource +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class StableDiffusionEmbeddingsRepositoryImplTest { + + private val stubException = Throwable("Something went wrong.") + private val stubRemoteDataSource = mockk() + private val stubLocalDataSource = mockk() + + private val repository = StableDiffusionEmbeddingsRepositoryImpl( + remoteDataSource = stubRemoteDataSource, + localDataSource = stubLocalDataSource, + ) + + @Test + fun `given attempt to fetch embeddings, remote returns data, local insert success, expected complete value`() { + every { + stubRemoteDataSource.fetchEmbeddings() + } returns Single.just(mockStableDiffusionEmbeddings) + + every { + stubLocalDataSource.insertEmbeddings(any()) + } returns Completable.complete() + + repository + .fetchEmbeddings() + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch embeddings, remote throws exception, local insert success, expected error value`() { + every { + stubRemoteDataSource.fetchEmbeddings() + } returns Single.error(stubException) + + every { + stubLocalDataSource.insertEmbeddings(any()) + } returns Completable.complete() + + repository + .fetchEmbeddings() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to fetch embeddings, remote returns data, local insert fails, expected error value`() { + every { + stubRemoteDataSource.fetchEmbeddings() + } returns Single.just(mockStableDiffusionEmbeddings) + + every { + stubLocalDataSource.insertEmbeddings(any()) + } returns Completable.error(stubException) + + repository + .fetchEmbeddings() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to get embeddings, local data source returns list, expected valid domain models list value`() { + every { + stubLocalDataSource.getEmbeddings() + } returns Single.just(mockStableDiffusionEmbeddings) + + repository + .getEmbeddings() + .test() + .assertNoErrors() + .assertValue(mockStableDiffusionEmbeddings) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get embeddings, local data source returns empty list, expected empty domain models list value`() { + every { + stubLocalDataSource.getEmbeddings() + } returns Single.just(emptyList()) + + repository + .getEmbeddings() + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get embeddings, local data source throws exception, expected error value`() { + every { + stubLocalDataSource.getEmbeddings() + } returns Single.error(stubException) + + repository + .getEmbeddings() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to fetch and get embeddings, remote returns data, local returns data, expected valid domain models list value`() { + every { + stubRemoteDataSource.fetchEmbeddings() + } returns Single.just(mockStableDiffusionEmbeddings) + + every { + stubLocalDataSource.insertEmbeddings(any()) + } returns Completable.complete() + + every { + stubLocalDataSource.getEmbeddings() + } returns Single.just(mockStableDiffusionEmbeddings) + + repository + .fetchAndGetEmbeddings() + .test() + .assertNoErrors() + .assertValue(mockStableDiffusionEmbeddings) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch and get embeddings, remote fails, local returns data, expected valid domain models list value`() { + every { + stubRemoteDataSource.fetchEmbeddings() + } returns Single.error(stubException) + + every { + stubLocalDataSource.insertEmbeddings(any()) + } returns Completable.complete() + + every { + stubLocalDataSource.getEmbeddings() + } returns Single.just(mockStableDiffusionEmbeddings) + + repository + .fetchAndGetEmbeddings() + .test() + .assertNoErrors() + .assertValue(mockStableDiffusionEmbeddings) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch and get embeddings, remote fails, local fails, expected valid error value`() { + every { + stubRemoteDataSource.fetchEmbeddings() + } returns Single.error(stubException) + + every { + stubLocalDataSource.getEmbeddings() + } returns Single.error(stubException) + + repository + .fetchAndGetEmbeddings() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/repository/StableDiffusionGenerationRepositoryImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/repository/StableDiffusionGenerationRepositoryImplTest.kt new file mode 100644 index 00000000..8ad5c9b5 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/repository/StableDiffusionGenerationRepositoryImplTest.kt @@ -0,0 +1,283 @@ +package com.shifthackz.aisdv1.data.repository + +import com.shifthackz.aisdv1.core.imageprocessing.Base64ToBitmapConverter +import com.shifthackz.aisdv1.data.mocks.mockAiGenerationResult +import com.shifthackz.aisdv1.data.mocks.mockImageToImagePayload +import com.shifthackz.aisdv1.data.mocks.mockTextToImagePayload +import com.shifthackz.aisdv1.domain.datasource.GenerationResultDataSource +import com.shifthackz.aisdv1.domain.datasource.StableDiffusionGenerationDataSource +import com.shifthackz.aisdv1.domain.demo.ImageToImageDemo +import com.shifthackz.aisdv1.domain.demo.TextToImageDemo +import com.shifthackz.aisdv1.domain.gateway.MediaStoreGateway +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single +import org.junit.Before +import org.junit.Test + +class StableDiffusionGenerationRepositoryImplTest { + + private val stubException = Throwable("Something went wrong.") + private val stubMediaStoreGateway = mockk() + private val stubBase64ToBitmapConverter = mockk() + private val stubLocalDataSource = mockk() + private val stubRemoteDataSource = mockk() + private val stubPreferenceManager = mockk() + private val stubTextToImageDemo = mockk() + private val stubImageToImageDemo = mockk() + + private val repository = StableDiffusionGenerationRepositoryImpl( + mediaStoreGateway = stubMediaStoreGateway, + base64ToBitmapConverter = stubBase64ToBitmapConverter, + localDataSource = stubLocalDataSource, + remoteDataSource = stubRemoteDataSource, + preferenceManager = stubPreferenceManager, + textToImageDemo = stubTextToImageDemo, + imageToImageDemo = stubImageToImageDemo, + ) + + @Before + fun initialize() { + every { + stubPreferenceManager.autoSaveAiResults + } returns false + } + + @Test + fun `given attempt to check api availability, remote completes, expected complete value`() { + every { + stubRemoteDataSource.checkAvailability() + } returns Completable.complete() + + repository + .checkApiAvailability() + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to check api availability, remote throws exception, expected error value`() { + every { + stubRemoteDataSource.checkAvailability() + } returns Completable.error(stubException) + + repository + .checkApiAvailability() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to check api availability by url, remote completes, expected complete value`() { + every { + stubRemoteDataSource.checkAvailability(any()) + } returns Completable.complete() + + repository + .checkApiAvailability("https://5598.is.my.favourite.com") + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to check api availability by url, remote throws exception, expected error value`() { + every { + stubRemoteDataSource.checkAvailability(any()) + } returns Completable.error(stubException) + + repository + .checkApiAvailability("https://5598.is.my.favourite.com") + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to generate from text, demo mode is on, demo returns result, expected valid domain model value`() { + every { + stubPreferenceManager.demoMode + } returns true + + every { + stubTextToImageDemo.getDemoBase64(any()) + } returns Single.just(mockAiGenerationResult) + + repository + .generateFromText(mockTextToImagePayload) + .test() + .assertNoErrors() + .assertValue(mockAiGenerationResult) + .await() + .assertComplete() + } + + @Test + fun `given attempt to generate from text, demo mode is on, demo throws exception, expected error value`() { + every { + stubPreferenceManager.demoMode + } returns true + + every { + stubTextToImageDemo.getDemoBase64(any()) + } returns Single.error(stubException) + + repository + .generateFromText(mockTextToImagePayload) + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to generate from text, demo mode is off, remote returns result, expected valid domain model value`() { + every { + stubPreferenceManager.demoMode + } returns false + + every { + stubRemoteDataSource.textToImage(any()) + } returns Single.just(mockAiGenerationResult) + + repository + .generateFromText(mockTextToImagePayload) + .test() + .assertNoErrors() + .assertValue(mockAiGenerationResult) + .await() + .assertComplete() + } + + @Test + fun `given attempt to generate from text, demo mode is off, remote throws exception, expected error value`() { + every { + stubPreferenceManager.demoMode + } returns false + + every { + stubRemoteDataSource.textToImage(any()) + } returns Single.error(stubException) + + repository + .generateFromText(mockTextToImagePayload) + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to generate from image, demo mode is on, demo returns result, expected valid domain model value`() { + every { + stubPreferenceManager.demoMode + } returns true + + every { + stubImageToImageDemo.getDemoBase64(any()) + } returns Single.just(mockAiGenerationResult) + + repository + .generateFromImage(mockImageToImagePayload) + .test() + .assertNoErrors() + .assertValue(mockAiGenerationResult) + .await() + .assertComplete() + } + + @Test + fun `given attempt to generate from image, demo mode is on, demo throws exception, expected error value`() { + every { + stubPreferenceManager.demoMode + } returns true + + every { + stubImageToImageDemo.getDemoBase64(any()) + } returns Single.error(stubException) + + repository + .generateFromImage(mockImageToImagePayload) + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to generate from image, demo mode is off, remote returns result, expected valid domain model value`() { + every { + stubPreferenceManager.demoMode + } returns false + + every { + stubRemoteDataSource.imageToImage(any()) + } returns Single.just(mockAiGenerationResult) + + repository + .generateFromImage(mockImageToImagePayload) + .test() + .assertNoErrors() + .assertValue(mockAiGenerationResult) + .await() + .assertComplete() + } + + @Test + fun `given attempt to generate from image, demo mode is off, remote throws exception, expected error value`() { + every { + stubPreferenceManager.demoMode + } returns false + + every { + stubRemoteDataSource.imageToImage(any()) + } returns Single.error(stubException) + + repository + .generateFromImage(mockImageToImagePayload) + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to interrupt generation, remote completes, expected complete value`() { + every { + stubRemoteDataSource.interruptGeneration() + } returns Completable.complete() + + repository + .interruptGeneration() + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to interrupt generation, remote throws exception, expected error value`() { + every { + stubRemoteDataSource.interruptGeneration() + } returns Completable.error(stubException) + + repository + .interruptGeneration() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/repository/StableDiffusionHyperNetworksRepositoryImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/repository/StableDiffusionHyperNetworksRepositoryImplTest.kt new file mode 100644 index 00000000..7764e2c5 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/repository/StableDiffusionHyperNetworksRepositoryImplTest.kt @@ -0,0 +1,185 @@ +package com.shifthackz.aisdv1.data.repository + +import com.shifthackz.aisdv1.data.mocks.mockStableDiffusionHyperNetworks +import com.shifthackz.aisdv1.domain.datasource.StableDiffusionHyperNetworksDataSource +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class StableDiffusionHyperNetworksRepositoryImplTest { + + private val stubException = Throwable("Something went wrong.") + private val stubRemoteDataSource = mockk() + private val stubLocalDataSource = mockk() + + private val repository = StableDiffusionHyperNetworksRepositoryImpl( + stubRemoteDataSource, + stubLocalDataSource, + ) + + @Test + fun `given attempt to fetch hyper networks, remote returns data, local insert success, expected complete value`() { + every { + stubRemoteDataSource.fetchHyperNetworks() + } returns Single.just(mockStableDiffusionHyperNetworks) + + every { + stubLocalDataSource.insertHyperNetworks(any()) + } returns Completable.complete() + + repository + .fetchHyperNetworks() + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch hyper networks, remote throws exception, local insert success, expected error value`() { + every { + stubRemoteDataSource.fetchHyperNetworks() + } returns Single.error(stubException) + + every { + stubLocalDataSource.insertHyperNetworks(any()) + } returns Completable.complete() + + repository + .fetchHyperNetworks() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to fetch hyper networks, remote returns data, local insert fails, expected error value`() { + every { + stubRemoteDataSource.fetchHyperNetworks() + } returns Single.just(mockStableDiffusionHyperNetworks) + + every { + stubLocalDataSource.insertHyperNetworks(any()) + } returns Completable.error(stubException) + + repository + .fetchHyperNetworks() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to get hyper networks, local data source returns list, expected valid domain models list value`() { + every { + stubLocalDataSource.getHyperNetworks() + } returns Single.just(mockStableDiffusionHyperNetworks) + + repository + .getHyperNetworks() + .test() + .assertNoErrors() + .assertValue(mockStableDiffusionHyperNetworks) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get hyper networks, local data source returns empty list, expected empty domain models list value`() { + every { + stubLocalDataSource.getHyperNetworks() + } returns Single.just(emptyList()) + + repository + .getHyperNetworks() + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get hyper networks, local data source throws exception, expected error value`() { + every { + stubLocalDataSource.getHyperNetworks() + } returns Single.error(stubException) + + repository + .getHyperNetworks() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to fetch and get hyper networks, remote returns data, local returns data, expected valid domain models list value`() { + every { + stubRemoteDataSource.fetchHyperNetworks() + } returns Single.just(mockStableDiffusionHyperNetworks) + + every { + stubLocalDataSource.insertHyperNetworks(any()) + } returns Completable.complete() + + every { + stubLocalDataSource.getHyperNetworks() + } returns Single.just(mockStableDiffusionHyperNetworks) + + repository + .fetchAndGetHyperNetworks() + .test() + .assertNoErrors() + .assertValue(mockStableDiffusionHyperNetworks) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch and get hyper networks, remote fails, local returns data, expected valid domain models list value`() { + every { + stubRemoteDataSource.fetchHyperNetworks() + } returns Single.error(stubException) + + every { + stubLocalDataSource.insertHyperNetworks(any()) + } returns Completable.complete() + + every { + stubLocalDataSource.getHyperNetworks() + } returns Single.just(mockStableDiffusionHyperNetworks) + + repository + .fetchAndGetHyperNetworks() + .test() + .assertNoErrors() + .assertValue(mockStableDiffusionHyperNetworks) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch and get hyper networks, remote fails, local fails, expected valid error value`() { + every { + stubRemoteDataSource.fetchHyperNetworks() + } returns Single.error(stubException) + + every { + stubLocalDataSource.getHyperNetworks() + } returns Single.error(stubException) + + repository + .fetchAndGetHyperNetworks() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/repository/StableDiffusionLorasRepositoryImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/repository/StableDiffusionLorasRepositoryImplTest.kt new file mode 100644 index 00000000..5fbf1aaa --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/repository/StableDiffusionLorasRepositoryImplTest.kt @@ -0,0 +1,186 @@ +package com.shifthackz.aisdv1.data.repository + +import com.shifthackz.aisdv1.data.mocks.mockStableDiffusionLoras +import com.shifthackz.aisdv1.data.mocks.mockStableDiffusionModels +import com.shifthackz.aisdv1.domain.datasource.StableDiffusionLorasDataSource +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class StableDiffusionLorasRepositoryImplTest { + + private val stubException = Throwable("Something went wrong.") + private val stubRemoteDataSource = mockk() + private val stubLocalDataSource = mockk() + + private val repository = StableDiffusionLorasRepositoryImpl( + remoteDataSource = stubRemoteDataSource, + localDataSource = stubLocalDataSource, + ) + + @Test + fun `given attempt to fetch loras, remote returns data, local insert success, expected complete value`() { + every { + stubRemoteDataSource.fetchLoras() + } returns Single.just(mockStableDiffusionLoras) + + every { + stubLocalDataSource.insertLoras(any()) + } returns Completable.complete() + + repository + .fetchLoras() + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch loras, remote throws exception, local insert success, expected error value`() { + every { + stubRemoteDataSource.fetchLoras() + } returns Single.error(stubException) + + every { + stubLocalDataSource.insertLoras(any()) + } returns Completable.complete() + + repository + .fetchLoras() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to fetch loras, remote returns data, local insert fails, expected error value`() { + every { + stubRemoteDataSource.fetchLoras() + } returns Single.just(mockStableDiffusionLoras) + + every { + stubLocalDataSource.insertLoras(any()) + } returns Completable.error(stubException) + + repository + .fetchLoras() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to get loras, local data source returns list, expected valid domain models list value`() { + every { + stubLocalDataSource.getLoras() + } returns Single.just(mockStableDiffusionLoras) + + repository + .getLoras() + .test() + .assertNoErrors() + .assertValue(mockStableDiffusionLoras) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get loras, local data source returns empty list, expected empty domain models list value`() { + every { + stubLocalDataSource.getLoras() + } returns Single.just(emptyList()) + + repository + .getLoras() + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get loras, local data source throws exception, expected error value`() { + every { + stubLocalDataSource.getLoras() + } returns Single.error(stubException) + + repository + .getLoras() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to fetch and get loras, remote returns data, local returns data, expected valid domain models list value`() { + every { + stubRemoteDataSource.fetchLoras() + } returns Single.just(mockStableDiffusionLoras) + + every { + stubLocalDataSource.insertLoras(any()) + } returns Completable.complete() + + every { + stubLocalDataSource.getLoras() + } returns Single.just(mockStableDiffusionLoras) + + repository + .fetchAndGetLoras() + .test() + .assertNoErrors() + .assertValue(mockStableDiffusionLoras) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch and get loras, remote fails, local returns data, expected valid domain models list value`() { + every { + stubRemoteDataSource.fetchLoras() + } returns Single.error(stubException) + + every { + stubLocalDataSource.insertLoras(any()) + } returns Completable.complete() + + every { + stubLocalDataSource.getLoras() + } returns Single.just(mockStableDiffusionLoras) + + repository + .fetchAndGetLoras() + .test() + .assertNoErrors() + .assertValue(mockStableDiffusionLoras) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch and get loras, remote fails, local fails, expected valid error value`() { + every { + stubRemoteDataSource.fetchLoras() + } returns Single.error(stubException) + + every { + stubLocalDataSource.getLoras() + } returns Single.error(stubException) + + repository + .fetchAndGetLoras() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/repository/StableDiffusionModelsRepositoryImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/repository/StableDiffusionModelsRepositoryImplTest.kt new file mode 100644 index 00000000..6ab990f6 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/repository/StableDiffusionModelsRepositoryImplTest.kt @@ -0,0 +1,187 @@ +package com.shifthackz.aisdv1.data.repository + +import com.nhaarman.mockitokotlin2.mock +import com.shifthackz.aisdv1.data.mocks.mockStableDiffusionModels +import com.shifthackz.aisdv1.data.mocks.mockStableDiffusionSamplers +import com.shifthackz.aisdv1.domain.datasource.StableDiffusionModelsDataSource +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class StableDiffusionModelsRepositoryImplTest { + + private val stubException = Throwable("Something went wrong.") + private val stubRemoteDataSource = mockk() + private val stubLocalDataSource = mockk() + + private val repository = StableDiffusionModelsRepositoryImpl( + remoteDataSource = stubRemoteDataSource, + localDataSource = stubLocalDataSource, + ) + + @Test + fun `given attempt to fetch models, remote returns data, local insert success, expected complete value`() { + every { + stubRemoteDataSource.fetchSdModels() + } returns Single.just(mockStableDiffusionModels) + + every { + stubLocalDataSource.insertModels(any()) + } returns Completable.complete() + + repository + .fetchModels() + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch models, remote throws exception, local insert success, expected error value`() { + every { + stubRemoteDataSource.fetchSdModels() + } returns Single.error(stubException) + + every { + stubLocalDataSource.insertModels(any()) + } returns Completable.complete() + + repository + .fetchModels() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to fetch models, remote returns data, local insert fails, expected error value`() { + every { + stubRemoteDataSource.fetchSdModels() + } returns Single.just(mockStableDiffusionModels) + + every { + stubLocalDataSource.insertModels(any()) + } returns Completable.error(stubException) + + repository + .fetchModels() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to get models, local data source returns list, expected valid domain models list value`() { + every { + stubLocalDataSource.getModels() + } returns Single.just(mockStableDiffusionModels) + + repository + .getModels() + .test() + .assertNoErrors() + .assertValue(mockStableDiffusionModels) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get models, local data source returns empty list, expected empty domain models list value`() { + every { + stubLocalDataSource.getModels() + } returns Single.just(emptyList()) + + repository + .getModels() + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get models, local data source throws exception, expected error value`() { + every { + stubLocalDataSource.getModels() + } returns Single.error(stubException) + + repository + .getModels() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to fetch and get models, remote returns data, local returns data, expected valid domain models list value`() { + every { + stubRemoteDataSource.fetchSdModels() + } returns Single.just(mockStableDiffusionModels) + + every { + stubLocalDataSource.insertModels(any()) + } returns Completable.complete() + + every { + stubLocalDataSource.getModels() + } returns Single.just(mockStableDiffusionModels) + + repository + .fetchAndGetModels() + .test() + .assertNoErrors() + .assertValue(mockStableDiffusionModels) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch and get models, remote fails, local returns data, expected valid domain models list value`() { + every { + stubRemoteDataSource.fetchSdModels() + } returns Single.error(stubException) + + every { + stubLocalDataSource.insertModels(any()) + } returns Completable.complete() + + every { + stubLocalDataSource.getModels() + } returns Single.just(mockStableDiffusionModels) + + repository + .fetchAndGetModels() + .test() + .assertNoErrors() + .assertValue(mockStableDiffusionModels) + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch and get models, remote fails, local fails, expected valid error value`() { + every { + stubRemoteDataSource.fetchSdModels() + } returns Single.error(stubException) + + every { + stubLocalDataSource.getModels() + } returns Single.error(stubException) + + repository + .fetchAndGetModels() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/repository/StableDiffusionSamplersRepositoryImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/repository/StableDiffusionSamplersRepositoryImplTest.kt new file mode 100644 index 00000000..d703e3b1 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/repository/StableDiffusionSamplersRepositoryImplTest.kt @@ -0,0 +1,120 @@ +package com.shifthackz.aisdv1.data.repository + +import com.shifthackz.aisdv1.data.mocks.mockStableDiffusionSamplers +import com.shifthackz.aisdv1.domain.datasource.StableDiffusionSamplersDataSource +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class StableDiffusionSamplersRepositoryImplTest { + + private val stubException = Throwable("Something went wrong.") + private val stubRemoteDataSource = mockk() + private val stubLocalDataSource = mockk() + + private val repository = StableDiffusionSamplersRepositoryImpl( + remoteDataSource = stubRemoteDataSource, + localDataSource = stubLocalDataSource, + ) + + @Test + fun `given attempt to fetch samplers, remote returns data, local insert success, expected complete value`() { + every { + stubRemoteDataSource.fetchSamplers() + } returns Single.just(mockStableDiffusionSamplers) + + every { + stubLocalDataSource.insertSamplers(any()) + } returns Completable.complete() + + repository + .fetchSamplers() + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given attempt to fetch samplers, remote throws exception, local insert success, expected error value`() { + every { + stubRemoteDataSource.fetchSamplers() + } returns Single.error(stubException) + + every { + stubLocalDataSource.insertSamplers(any()) + } returns Completable.complete() + + repository + .fetchSamplers() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to fetch samplers, remote returns data, local insert fails, expected error value`() { + every { + stubRemoteDataSource.fetchSamplers() + } returns Single.just(mockStableDiffusionSamplers) + + every { + stubLocalDataSource.insertSamplers(any()) + } returns Completable.error(stubException) + + repository + .fetchSamplers() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given attempt to get samplers, local data source returns list, expected valid domain models list value`() { + every { + stubLocalDataSource.getSamplers() + } returns Single.just(mockStableDiffusionSamplers) + + repository + .getSamplers() + .test() + .assertNoErrors() + .assertValue(mockStableDiffusionSamplers) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get samplers, local data source returns empty list, expected empty domain models list value`() { + every { + stubLocalDataSource.getSamplers() + } returns Single.just(emptyList()) + + repository + .getSamplers() + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given attempt to get samplers, local data source throws exception, expected error value`() { + every { + stubLocalDataSource.getSamplers() + } returns Single.error(stubException) + + repository + .getSamplers() + .test() + .assertError(stubException) + .assertNoValues() + .await() + .assertNotComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/repository/TemporaryGenerationResultRepositoryImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/repository/TemporaryGenerationResultRepositoryImplTest.kt new file mode 100644 index 00000000..989f8060 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/repository/TemporaryGenerationResultRepositoryImplTest.kt @@ -0,0 +1,39 @@ +package com.shifthackz.aisdv1.data.repository + +import com.shifthackz.aisdv1.data.mocks.mockAiGenerationResult +import org.junit.Test + +class TemporaryGenerationResultRepositoryImplTest { + + private val repository = TemporaryGenerationResultRepositoryImpl() + + @Test + fun `given cache is empty, then get, expected error value`() { + repository + .get() + .test() + .assertError { t -> + t is IllegalStateException && t.message == "No last cached result." + } + .await() + .assertNotComplete() + } + + @Test + fun `given cache contains value, then get, expected valid cached value`() { + repository + .put(mockAiGenerationResult) + .test() + .assertNoErrors() + .await() + .assertComplete() + + repository + .get() + .test() + .assertNoErrors() + .assertValue(mockAiGenerationResult) + .await() + .assertComplete() + } +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/repository/WakeLockRepositoryImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/repository/WakeLockRepositoryImplTest.kt new file mode 100644 index 00000000..6e796482 --- /dev/null +++ b/data/src/test/java/com/shifthackz/aisdv1/data/repository/WakeLockRepositoryImplTest.kt @@ -0,0 +1,43 @@ +package com.shifthackz.aisdv1.data.repository + +import android.os.PowerManager +import io.mockk.every +import io.mockk.mockk +import org.junit.Assert +import org.junit.Test + +class WakeLockRepositoryImplTest { + + private val stubWakeLock = mockk() + private val stubPowerManager = mockk() + + private val repository = WakeLockRepositoryImpl { stubPowerManager } + + @Test + fun `given repository is not yet initialized, attempt to get wakelock, expected repository initializes and returns wakelock`() { + every { + stubPowerManager.newWakeLock(any(), any()) + } returns stubWakeLock + + val actual = repository.wakeLock + val expected = stubWakeLock + Assert.assertEquals(expected, actual) + } + + @Test + fun `given repository already initialized, attempt to get wakelock, expected repository returns existing wakelock`() { + every { + stubPowerManager.newWakeLock(any(), any()) + } returns stubWakeLock + + val actualBeforeInit = repository.wakeLock + val expectedBeforeInit = stubWakeLock + Assert.assertEquals(expectedBeforeInit, actualBeforeInit) + + val actualAfterInit = repository.wakeLock + val expectedAfterInit = stubWakeLock + Assert.assertEquals(expectedAfterInit, actualAfterInit) + + Assert.assertEquals(actualBeforeInit, actualAfterInit) + } +} diff --git a/demo/build.gradle b/demo/build.gradle index 0cf74044..97eb5ef0 100644 --- a/demo/build.gradle +++ b/demo/build.gradle @@ -11,7 +11,10 @@ android { } dependencies { + implementation project(":core:common") implementation project(":domain") implementation di.koinCore implementation reactive.rxkotlin + testImplementation test.junit + testImplementation test.mockk } diff --git a/demo/src/main/java/com/shifthackz/aisdv1/demo/ImageToImageDemoImpl.kt b/demo/src/main/java/com/shifthackz/aisdv1/demo/ImageToImageDemoImpl.kt index 234afb68..37a15f5a 100644 --- a/demo/src/main/java/com/shifthackz/aisdv1/demo/ImageToImageDemoImpl.kt +++ b/demo/src/main/java/com/shifthackz/aisdv1/demo/ImageToImageDemoImpl.kt @@ -1,20 +1,21 @@ package com.shifthackz.aisdv1.demo +import com.shifthackz.aisdv1.core.common.time.TimeProvider import com.shifthackz.aisdv1.demo.serialize.DemoDataSerializer import com.shifthackz.aisdv1.domain.demo.ImageToImageDemo import com.shifthackz.aisdv1.domain.entity.AiGenerationResult import com.shifthackz.aisdv1.domain.entity.ImageToImagePayload -import java.util.Date internal class ImageToImageDemoImpl( override val demoDataSerializer: DemoDataSerializer, + private val timeProvider: TimeProvider, ) : ImageToImageDemo, DemoFeature { override fun mapper(input: ImageToImagePayload, base64: String) = AiGenerationResult( id = 0L, image = base64, inputImage = input.base64Image, - createdAt = Date(), + createdAt = timeProvider.currentDate(), type = AiGenerationResult.Type.IMAGE_TO_IMAGE, prompt = input.prompt, negativePrompt = input.negativePrompt, @@ -24,8 +25,8 @@ internal class ImageToImageDemoImpl( cfgScale = input.cfgScale, restoreFaces = input.restoreFaces, sampler = input.sampler, - seed = System.currentTimeMillis().toString(), - subSeed = System.currentTimeMillis().toString(), + seed = timeProvider.currentTimeMillis().toString(), + subSeed = timeProvider.currentTimeMillis().toString(), subSeedStrength = 0f, denoisingStrength = 0f, ) diff --git a/demo/src/main/java/com/shifthackz/aisdv1/demo/TextToImageDemoImpl.kt b/demo/src/main/java/com/shifthackz/aisdv1/demo/TextToImageDemoImpl.kt index ce01f893..ee183fee 100644 --- a/demo/src/main/java/com/shifthackz/aisdv1/demo/TextToImageDemoImpl.kt +++ b/demo/src/main/java/com/shifthackz/aisdv1/demo/TextToImageDemoImpl.kt @@ -1,20 +1,21 @@ package com.shifthackz.aisdv1.demo +import com.shifthackz.aisdv1.core.common.time.TimeProvider import com.shifthackz.aisdv1.demo.serialize.DemoDataSerializer import com.shifthackz.aisdv1.domain.demo.TextToImageDemo import com.shifthackz.aisdv1.domain.entity.AiGenerationResult import com.shifthackz.aisdv1.domain.entity.TextToImagePayload -import java.util.Date internal class TextToImageDemoImpl( override val demoDataSerializer: DemoDataSerializer, + private val timeProvider: TimeProvider, ) : TextToImageDemo, DemoFeature { override fun mapper(input: TextToImagePayload, base64: String) = AiGenerationResult( id = 0L, image = base64, inputImage = "", - createdAt = Date(), + createdAt = timeProvider.currentDate(), type = AiGenerationResult.Type.TEXT_TO_IMAGE, prompt = input.prompt, negativePrompt = input.negativePrompt, @@ -24,8 +25,8 @@ internal class TextToImageDemoImpl( cfgScale = input.cfgScale, restoreFaces = input.restoreFaces, sampler = input.sampler, - seed = System.currentTimeMillis().toString(), - subSeed = System.currentTimeMillis().toString(), + seed = timeProvider.currentTimeMillis().toString(), + subSeed = timeProvider.currentTimeMillis().toString(), subSeedStrength = 0f, denoisingStrength = 0f, ) diff --git a/demo/src/test/java/com/shifthackz/aisdv1/demo/ImageToImageDemoImplTest.kt b/demo/src/test/java/com/shifthackz/aisdv1/demo/ImageToImageDemoImplTest.kt new file mode 100644 index 00000000..77b5c39c --- /dev/null +++ b/demo/src/test/java/com/shifthackz/aisdv1/demo/ImageToImageDemoImplTest.kt @@ -0,0 +1,46 @@ +package com.shifthackz.aisdv1.demo + +import com.shifthackz.aisdv1.core.common.time.TimeProvider +import com.shifthackz.aisdv1.demo.mocks.mockImageToImagePayload +import com.shifthackz.aisdv1.demo.serialize.DemoDataSerializer +import io.mockk.every +import io.mockk.mockk +import org.junit.Before +import org.junit.Test +import java.util.Date + +class ImageToImageDemoImplTest { + + private val stubSerializer = DemoDataSerializer() + private val stubTimeProvider = mockk() + + private val demo = ImageToImageDemoImpl( + demoDataSerializer = stubSerializer, + timeProvider = stubTimeProvider, + ) + + @Before + fun initialize() { + every { + stubTimeProvider.currentTimeMillis() + } returns 5598L + + every { + stubTimeProvider.currentDate() + } returns Date(5598L) + } + + @Test + fun `given done demo generation, expected generation result base64 is from demo serializer`() { + demo.getDemoBase64(mockImageToImagePayload) + .test() + .await() + .assertNoErrors() + .assertValue { actual -> + stubSerializer + .readDemoAssets() + .contains(actual.image) + } + .assertComplete() + } +} diff --git a/demo/src/test/java/com/shifthackz/aisdv1/demo/TextToImageDemoImplTest.kt b/demo/src/test/java/com/shifthackz/aisdv1/demo/TextToImageDemoImplTest.kt new file mode 100644 index 00000000..a96626e4 --- /dev/null +++ b/demo/src/test/java/com/shifthackz/aisdv1/demo/TextToImageDemoImplTest.kt @@ -0,0 +1,46 @@ +package com.shifthackz.aisdv1.demo + +import com.shifthackz.aisdv1.core.common.time.TimeProvider +import com.shifthackz.aisdv1.demo.mocks.mockTextToImagePayload +import com.shifthackz.aisdv1.demo.serialize.DemoDataSerializer +import io.mockk.every +import io.mockk.mockk +import org.junit.Before +import org.junit.Test +import java.util.Date + +class TextToImageDemoImplTest { + + private val stubSerializer = DemoDataSerializer() + private val stubTimeProvider = mockk() + + private val demo = TextToImageDemoImpl( + demoDataSerializer = stubSerializer, + timeProvider = stubTimeProvider, + ) + + @Before + fun initialize() { + every { + stubTimeProvider.currentTimeMillis() + } returns 5598L + + every { + stubTimeProvider.currentDate() + } returns Date(5598L) + } + + @Test + fun `given done demo generation, expected generation result base64 is from demo serializer`() { + demo.getDemoBase64(mockTextToImagePayload) + .test() + .await() + .assertNoErrors() + .assertValue { actual -> + stubSerializer + .readDemoAssets() + .contains(actual.image) + } + .assertComplete() + } +} diff --git a/demo/src/test/java/com/shifthackz/aisdv1/demo/mocks/ImageToImagePayloadMocks.kt b/demo/src/test/java/com/shifthackz/aisdv1/demo/mocks/ImageToImagePayloadMocks.kt new file mode 100644 index 00000000..1da3962e --- /dev/null +++ b/demo/src/test/java/com/shifthackz/aisdv1/demo/mocks/ImageToImagePayloadMocks.kt @@ -0,0 +1,29 @@ +package com.shifthackz.aisdv1.demo.mocks + +import com.shifthackz.aisdv1.domain.entity.ImageToImagePayload + +val mockImageToImagePayload = ImageToImagePayload( + base64Image = "", + base64MaskImage = "", + denoisingStrength = 7f, + prompt = "prompt", + negativePrompt = "negative", + samplingSteps = 12, + cfgScale = 0.7f, + width = 512, + height = 512, + restoreFaces = true, + seed = "5598", + subSeed = "1504", + subSeedStrength = 5598f, + sampler = "sampler", + nsfw = true, + batchCount = 1, + inPaintingMaskInvert = 0, + inPaintFullResPadding = 0, + inPaintingFill = 0, + inPaintFullRes = false, + maskBlur = 0, + stabilityAiClipGuidance = null, + stabilityAiStylePreset = null, +) diff --git a/demo/src/test/java/com/shifthackz/aisdv1/demo/mocks/TextToImagePayloadMocks.kt b/demo/src/test/java/com/shifthackz/aisdv1/demo/mocks/TextToImagePayloadMocks.kt new file mode 100644 index 00000000..819dc51c --- /dev/null +++ b/demo/src/test/java/com/shifthackz/aisdv1/demo/mocks/TextToImagePayloadMocks.kt @@ -0,0 +1,24 @@ +package com.shifthackz.aisdv1.demo.mocks + +import com.shifthackz.aisdv1.domain.entity.TextToImagePayload + +val mockTextToImagePayload = TextToImagePayload( + prompt = "prompt", + negativePrompt = "negative", + samplingSteps = 12, + cfgScale = 0.7f, + width = 512, + height = 512, + restoreFaces = true, + seed = "5598", + subSeed = "1504", + subSeedStrength = 5598f, + sampler = "sampler", + nsfw = true, + batchCount = 1, + quality = null, + style = null, + openAiModel = null, + stabilityAiClipGuidance = null, + stabilityAiStylePreset = null, +) diff --git a/dependencies.gradle b/dependencies.gradle index 54c4dd6a..05348ecc 100755 --- a/dependencies.gradle +++ b/dependencies.gradle @@ -26,12 +26,17 @@ ext { googleMaterialVersion = '1.9.0' accompanistSystemUiControllerVersion = '0.30.1' cryptoVersion = '1.0.0' + exifinterfaceVersion = '1.3.6' onnxruntimeVersion = '1.16.3' catppuccinVersion = '0.1.1' composeGesturesVersion = '3.1' composeEasyCropVersion = '0.1.1' testJunitVersion = '4.13.2' + testMockitoVersion = '2.2.0' + testMockkVersion = '1.13.11' + testCoroutinesVersion = '1.8.1' + testTurbibeVersion = '1.1.0' androidx = [ core : "androidx.core:core-ktx:$coreKtxVersion", @@ -53,6 +58,7 @@ ext { pagingRx3 : "androidx.paging:paging-rxjava3:$pagingVersion", pagingCompose : "androidx.paging:paging-compose:$pagingComposeVersion", crypto : "androidx.security:security-crypto:$cryptoVersion", + exifinterface : "androidx.exifinterface:exifinterface:$exifinterfaceVersion", ] google = [ gson : "com.google.code.gson:gson:$gsonVersion", @@ -104,6 +110,10 @@ ext { stringutils: "org.apache.commons:commons-lang3:$apacheLangVersion" ] test = [ - junit: "junit:junit:$testJunitVersion", + junit : "junit:junit:$testJunitVersion", + mockito : "com.nhaarman.mockitokotlin2:mockito-kotlin:$testMockitoVersion", + mockk : "io.mockk:mockk:$testMockkVersion", + coroutines: "org.jetbrains.kotlinx:kotlinx-coroutines-test:$testCoroutinesVersion", + turbine : "app.cash.turbine:turbine:$testTurbibeVersion", ] } diff --git a/domain/build.gradle b/domain/build.gradle index 445a1fb6..b196bb2c 100755 --- a/domain/build.gradle +++ b/domain/build.gradle @@ -15,4 +15,6 @@ dependencies { implementation di.koinCore implementation reactive.rxkotlin testImplementation test.junit + testImplementation test.mockito + testImplementation test.mockk } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/AiGenerationResult.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/AiGenerationResult.kt index 10a77d58..a541bd73 100755 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/AiGenerationResult.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/AiGenerationResult.kt @@ -26,7 +26,7 @@ data class AiGenerationResult( IMAGE_TO_IMAGE("img2img"); companion object { - fun parse(input: String?) = values() + fun parse(input: String?) = entries .find { it.key == input } ?: TEXT_TO_IMAGE } } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Configuration.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Configuration.kt index 9d550ce3..1f2007d9 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Configuration.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Configuration.kt @@ -3,15 +3,15 @@ package com.shifthackz.aisdv1.domain.entity import com.shifthackz.aisdv1.domain.feature.auth.AuthorizationCredentials data class Configuration( - val serverUrl: String, - val demoMode: Boolean, - val source: ServerSource, - val hordeApiKey: String, - val openAiApiKey: String, - val huggingFaceApiKey: String, - val huggingFaceModel: String, - val stabilityAiApiKey: String, - val stabilityAiEngineId: String, - val authCredentials: AuthorizationCredentials, - val localModelId: String, + val serverUrl: String = "", + val demoMode: Boolean = false, + val source: ServerSource = ServerSource.AUTOMATIC1111, + val hordeApiKey: String = "", + val openAiApiKey: String = "", + val huggingFaceApiKey: String = "", + val huggingFaceModel: String = "", + val stabilityAiApiKey: String = "", + val stabilityAiEngineId: String = "", + val authCredentials: AuthorizationCredentials = AuthorizationCredentials.None, + val localModelId: String = "", ) diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Settings.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Settings.kt index 2267d535..84425022 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Settings.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Settings.kt @@ -1,20 +1,20 @@ package com.shifthackz.aisdv1.domain.entity data class Settings( - val serverUrl: String, - val sdModel: String, - val demoMode: Boolean, - val monitorConnectivity: Boolean, - val autoSaveAiResults: Boolean, - val saveToMediaStore: Boolean, - val formAdvancedOptionsAlwaysShow: Boolean, - val formPromptTaggedInput: Boolean, - val source: ServerSource, - val hordeApiKey: String, - val localUseNNAPI: Boolean, - val designUseSystemColorPalette: Boolean, - val designUseSystemDarkTheme: Boolean, - val designDarkTheme: Boolean, - val designColorToken: String, - val designDarkThemeToken: String, + val serverUrl: String = "", + val sdModel: String = "", + val demoMode: Boolean = false, + val monitorConnectivity: Boolean = false, + val autoSaveAiResults: Boolean = false, + val saveToMediaStore: Boolean = false, + val formAdvancedOptionsAlwaysShow: Boolean = false, + val formPromptTaggedInput: Boolean = false, + val source: ServerSource = ServerSource.AUTOMATIC1111, + val hordeApiKey: String = "", + val localUseNNAPI: Boolean = false, + val designUseSystemColorPalette: Boolean = false, + val designUseSystemDarkTheme: Boolean = false, + val designDarkTheme: Boolean = false, + val designColorToken: String = "", + val designDarkThemeToken: String = "", ) diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/connectivity/ObserveSeverConnectivityUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/connectivity/ObserveSeverConnectivityUseCaseImpl.kt index 295de376..f01e4c26 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/connectivity/ObserveSeverConnectivityUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/connectivity/ObserveSeverConnectivityUseCaseImpl.kt @@ -6,5 +6,7 @@ internal class ObserveSeverConnectivityUseCaseImpl( private val serverConnectivityGateway: ServerConnectivityGateway, ) : ObserveSeverConnectivityUseCase { - override fun invoke() = serverConnectivityGateway.observe() + override fun invoke() = serverConnectivityGateway + .observe() + .distinctUntilChanged() } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalAiModelsUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalAiModelsUseCaseImpl.kt index 21f14410..5fae54e1 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalAiModelsUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalAiModelsUseCaseImpl.kt @@ -6,5 +6,7 @@ internal class ObserveLocalAiModelsUseCaseImpl( private val repository: DownloadableModelRepository, ) : ObserveLocalAiModelsUseCase { - override fun invoke() = repository.observeAll() + override fun invoke() = repository + .observeAll() + .distinctUntilChanged() } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveHordeProcessStatusUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveHordeProcessStatusUseCaseImpl.kt index e106224f..c7154bed 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveHordeProcessStatusUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveHordeProcessStatusUseCaseImpl.kt @@ -6,5 +6,7 @@ internal class ObserveHordeProcessStatusUseCaseImpl( private val hordeGenerationRepository: HordeGenerationRepository, ) : ObserveHordeProcessStatusUseCase { - override fun invoke() = hordeGenerationRepository.observeStatus() + override fun invoke() = hordeGenerationRepository + .observeStatus() + .distinctUntilChanged() } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveLocalDiffusionProcessStatusUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveLocalDiffusionProcessStatusUseCaseImpl.kt index 68e48a19..8e8430dd 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveLocalDiffusionProcessStatusUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveLocalDiffusionProcessStatusUseCaseImpl.kt @@ -6,5 +6,7 @@ internal class ObserveLocalDiffusionProcessStatusUseCaseImpl( private val localDiffusionGenerationRepository: LocalDiffusionGenerationRepository, ) : ObserveLocalDiffusionProcessStatusUseCase { - override fun invoke() = localDiffusionGenerationRepository.observeStatus() + override fun invoke() = localDiffusionGenerationRepository + .observeStatus() + .distinctUntilChanged() } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToHordeUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToHordeUseCaseImpl.kt index a2756980..07458212 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToHordeUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToHordeUseCaseImpl.kt @@ -30,7 +30,7 @@ internal class ConnectToHordeUseCaseImpl( .andThen(testHordeApiKeyUseCase()) .flatMap { if (it) Single.just(Result.success(Unit)) - else Single.error(Throwable("Bad key")) + else Single.error(IllegalStateException("Bad key")) } .onErrorResumeNext { t -> val chain = configuration?.let(setServerConfigurationUseCase::invoke) diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToHuggingFaceUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToHuggingFaceUseCaseImpl.kt index 94fcc33c..9a4cb873 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToHuggingFaceUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToHuggingFaceUseCaseImpl.kt @@ -31,7 +31,7 @@ internal class ConnectToHuggingFaceUseCaseImpl( .andThen(testHuggingFaceApiKeyUseCase()) .flatMap { if (it) Single.just(Result.success(Unit)) - else Single.error(Throwable("Bad key")) + else Single.error(IllegalStateException("Bad key")) } .onErrorResumeNext { t -> val chain = configuration?.let(setServerConfigurationUseCase::invoke) diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToOpenAiUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToOpenAiUseCaseImpl.kt index 5ead7f86..85e8881a 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToOpenAiUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToOpenAiUseCaseImpl.kt @@ -30,7 +30,7 @@ internal class ConnectToOpenAiUseCaseImpl( .andThen(testOpenAiApiKeyUseCase()) .flatMap { if (it) Single.just(Result.success(Unit)) - else Single.error(Throwable("Bad key")) + else Single.error(IllegalStateException("Bad key")) } .onErrorResumeNext { t -> val chain = configuration?.let(setServerConfigurationUseCase::invoke) diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToStabilityAiUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToStabilityAiUseCaseImpl.kt index 0b0f9d00..400767c0 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToStabilityAiUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToStabilityAiUseCaseImpl.kt @@ -30,7 +30,7 @@ internal class ConnectToStabilityAiUseCaseImpl( .andThen(testStabilityAiApiKeyUseCase()) .flatMap { if (it) Single.just(Result.success(Unit)) - else Single.error(Throwable("Bad key")) + else Single.error(IllegalStateException("Bad key")) } .onErrorResumeNext { t -> val chain = configuration?.let(setServerConfigurationUseCase::invoke) diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/splash/SplashNavigationUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/splash/SplashNavigationUseCaseImpl.kt index ff63fd11..f50bfc26 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/splash/SplashNavigationUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/splash/SplashNavigationUseCaseImpl.kt @@ -15,14 +15,8 @@ internal class SplashNavigationUseCaseImpl( Action.LAUNCH_SERVER_SETUP } - preferenceManager.source == ServerSource.LOCAL - || preferenceManager.source == ServerSource.HORDE - || preferenceManager.source == ServerSource.OPEN_AI - || preferenceManager.source == ServerSource.HUGGING_FACE -> { - Action.LAUNCH_HOME - } - - preferenceManager.serverUrl.isEmpty() -> { + preferenceManager.serverUrl.isEmpty() + && preferenceManager.source == ServerSource.AUTOMATIC1111 -> { Action.LAUNCH_SERVER_SETUP } diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/AiGenerationResultMocks.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/AiGenerationResultMocks.kt new file mode 100644 index 00000000..da9d09d9 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/AiGenerationResultMocks.kt @@ -0,0 +1,26 @@ +package com.shifthackz.aisdv1.domain.mocks + +import com.shifthackz.aisdv1.domain.entity.AiGenerationResult +import java.util.Date + +val mockAiGenerationResult = AiGenerationResult( + id = 5598L, + image = "img", + inputImage = "inp", + createdAt = Date(), + type = AiGenerationResult.Type.IMAGE_TO_IMAGE, + prompt = "prompt", + negativePrompt = "negative", + width = 512, + height = 512, + samplingSteps = 7, + cfgScale = 0.7f, + restoreFaces = true, + sampler = "sampler", + seed = "5598", + subSeed = "1504", + subSeedStrength = 5598f, + denoisingStrength = 1504f, +) + +val mockAiGenerationResults = listOf(mockAiGenerationResult) diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/ConfigurationMocks.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/ConfigurationMocks.kt new file mode 100644 index 00000000..84e9f9da --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/ConfigurationMocks.kt @@ -0,0 +1,14 @@ +package com.shifthackz.aisdv1.domain.mocks + +import com.shifthackz.aisdv1.domain.entity.Configuration + +val mockConfiguration = Configuration( + serverUrl = "http://5598.is.my.favorite.com", + hordeApiKey = "5598", + openAiApiKey = "5598", + huggingFaceApiKey = "5598", + huggingFaceModel = "5598", + stabilityAiApiKey = "5598", + stabilityAiEngineId = "5598", + localModelId = "5598", +) diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/HuggingFaceModelMocks.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/HuggingFaceModelMocks.kt new file mode 100644 index 00000000..d9e8e956 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/HuggingFaceModelMocks.kt @@ -0,0 +1,13 @@ +package com.shifthackz.aisdv1.domain.mocks + +import com.shifthackz.aisdv1.domain.entity.HuggingFaceModel + +val mockHuggingFaceModels = listOf( + HuggingFaceModel.default, + HuggingFaceModel( + "80974f2d-7ee0-48e5-97bc-448de3c1d634", + "Analog Diffusion", + "wavymulder/Analog-Diffusion", + "https://huggingface.co/wavymulder/Analog-Diffusion", + ), +) diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/ImageToImagePayloadMocks.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/ImageToImagePayloadMocks.kt new file mode 100644 index 00000000..d72aa0ff --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/ImageToImagePayloadMocks.kt @@ -0,0 +1,29 @@ +package com.shifthackz.aisdv1.domain.mocks + +import com.shifthackz.aisdv1.domain.entity.ImageToImagePayload + +val mockImageToImagePayload = ImageToImagePayload( + base64Image = "", + base64MaskImage = "", + denoisingStrength = 7f, + prompt = "prompt", + negativePrompt = "negative", + samplingSteps = 12, + cfgScale = 0.7f, + width = 512, + height = 512, + restoreFaces = true, + seed = "5598", + subSeed = "1504", + subSeedStrength = 5598f, + sampler = "sampler", + nsfw = true, + batchCount = 1, + inPaintingMaskInvert = 0, + inPaintFullResPadding = 0, + inPaintingFill = 0, + inPaintFullRes = false, + maskBlur = 0, + stabilityAiClipGuidance = null, + stabilityAiStylePreset = null, +) diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/LocalAiModelMocks.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/LocalAiModelMocks.kt new file mode 100644 index 00000000..69fcf3f0 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/LocalAiModelMocks.kt @@ -0,0 +1,15 @@ +package com.shifthackz.aisdv1.domain.mocks + +import com.shifthackz.aisdv1.domain.entity.LocalAiModel + +val mockLocalAiModels = listOf( + LocalAiModel.CUSTOM, + LocalAiModel( + id = "1", + name = "Model 1", + size = "5 Gb", + sources = listOf("https://example.com/1.html"), + downloaded = false, + selected = false, + ), +) diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/MediaStoreInfoMocks.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/MediaStoreInfoMocks.kt new file mode 100644 index 00000000..e3049220 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/MediaStoreInfoMocks.kt @@ -0,0 +1,9 @@ +package com.shifthackz.aisdv1.domain.mocks + +import android.net.Uri +import com.shifthackz.aisdv1.domain.entity.MediaStoreInfo + +val mockMediaStoreInfo = MediaStoreInfo( + count = 5598, + folderUri = Uri.EMPTY, +) diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/ServerConfigurationMocks.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/ServerConfigurationMocks.kt new file mode 100644 index 00000000..9008a5bf --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/ServerConfigurationMocks.kt @@ -0,0 +1,5 @@ +package com.shifthackz.aisdv1.domain.mocks + +import com.shifthackz.aisdv1.domain.entity.ServerConfiguration + +val mockServerConfiguration = ServerConfiguration("checkpoint") diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/SettingsMocks.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/SettingsMocks.kt new file mode 100644 index 00000000..4dfca863 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/SettingsMocks.kt @@ -0,0 +1,23 @@ +package com.shifthackz.aisdv1.domain.mocks + +import com.shifthackz.aisdv1.domain.entity.ServerSource +import com.shifthackz.aisdv1.domain.entity.Settings + +val mockSettings = Settings( + serverUrl = "", + sdModel = "", + demoMode = false, + monitorConnectivity = true, + autoSaveAiResults = true, + saveToMediaStore = true, + formAdvancedOptionsAlwaysShow = true, + formPromptTaggedInput = true, + source = ServerSource.STABILITY_AI, + hordeApiKey = "", + localUseNNAPI = false, + designUseSystemColorPalette = true, + designUseSystemDarkTheme = true, + designDarkTheme = true, + designColorToken = "", + designDarkThemeToken = "", +) diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/StabilityAiEngineMocks.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/StabilityAiEngineMocks.kt new file mode 100644 index 00000000..330dad13 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/StabilityAiEngineMocks.kt @@ -0,0 +1,14 @@ +package com.shifthackz.aisdv1.domain.mocks + +import com.shifthackz.aisdv1.domain.entity.StabilityAiEngine + +val mockStabilityAiEngines = listOf( + StabilityAiEngine( + id = "engine_1", + name = "Engine 1", + ), + StabilityAiEngine( + id = "engine_2", + name = "Engine 2", + ), +) diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/StableDiffusionEmbeddingMocks.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/StableDiffusionEmbeddingMocks.kt new file mode 100644 index 00000000..8cf54d7e --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/StableDiffusionEmbeddingMocks.kt @@ -0,0 +1,9 @@ +package com.shifthackz.aisdv1.domain.mocks + +import com.shifthackz.aisdv1.domain.entity.StableDiffusionEmbedding + +val mockStableDiffusionEmbeddings = listOf( + StableDiffusionEmbedding("embedding_1"), + StableDiffusionEmbedding("embedding_2"), + StableDiffusionEmbedding("embedding_3"), +) diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/StableDiffusionHyperNetworkMocks.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/StableDiffusionHyperNetworkMocks.kt new file mode 100644 index 00000000..e7e697eb --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/StableDiffusionHyperNetworkMocks.kt @@ -0,0 +1,14 @@ +package com.shifthackz.aisdv1.domain.mocks + +import com.shifthackz.aisdv1.domain.entity.StableDiffusionHyperNetwork + +val mockStableDiffusionHyperNetworks = listOf( + StableDiffusionHyperNetwork( + name = "hyper_net_1", + path = "", + ), + StableDiffusionHyperNetwork( + name = "hyper_net_2", + path = "", + ), +) diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/StableDiffusionModelMocks.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/StableDiffusionModelMocks.kt new file mode 100644 index 00000000..f55d4e38 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/StableDiffusionModelMocks.kt @@ -0,0 +1,22 @@ +package com.shifthackz.aisdv1.domain.mocks + +import com.shifthackz.aisdv1.domain.entity.StableDiffusionModel + +val mockStableDiffusionModels = listOf( + StableDiffusionModel( + title = "model", + modelName = "name", + hash = "hash", + sha256 = "sha256", + filename = "filename", + config = "config", + ), + StableDiffusionModel( + title = "checkpoint", + modelName = "checkpoint", + hash = "hash_2", + sha256 = "sha256_2", + filename = "filename_2", + config = "config_2", + ), +) diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/StableDiffusionSamplerMocks.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/StableDiffusionSamplerMocks.kt new file mode 100644 index 00000000..2cece698 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/StableDiffusionSamplerMocks.kt @@ -0,0 +1,16 @@ +package com.shifthackz.aisdv1.domain.mocks + +import com.shifthackz.aisdv1.domain.entity.StableDiffusionSampler + +val mockStableDiffusionSamplers = listOf( + StableDiffusionSampler( + name = "sampler_1", + aliases = listOf("alias_1"), + options = mapOf("option" to "value"), + ), + StableDiffusionSampler( + name = "sampler_2", + aliases = listOf("alias_2"), + options = mapOf("option" to "value"), + ), +) diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/TextToImagePayloadMocks.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/TextToImagePayloadMocks.kt new file mode 100644 index 00000000..1817e0e5 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/TextToImagePayloadMocks.kt @@ -0,0 +1,24 @@ +package com.shifthackz.aisdv1.domain.mocks + +import com.shifthackz.aisdv1.domain.entity.TextToImagePayload + +val mockTextToImagePayload = TextToImagePayload( + prompt = "prompt", + negativePrompt = "negative", + samplingSteps = 12, + cfgScale = 0.7f, + width = 512, + height = 512, + restoreFaces = true, + seed = "5598", + subSeed = "1504", + subSeedStrength = 5598f, + sampler = "sampler", + nsfw = true, + batchCount = 1, + quality = null, + style = null, + openAiModel = null, + stabilityAiClipGuidance = null, + stabilityAiStylePreset = null, +) diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/caching/ClearAppCacheUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/caching/ClearAppCacheUseCaseImplTest.kt new file mode 100644 index 00000000..5a6dc4bf --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/caching/ClearAppCacheUseCaseImplTest.kt @@ -0,0 +1,115 @@ +package com.shifthackz.aisdv1.domain.usecase.caching + +import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor +import com.shifthackz.aisdv1.core.common.log.FileLoggingTree +import com.shifthackz.aisdv1.domain.repository.GenerationResultRepository +import io.mockk.every +import io.mockk.mockk +import io.mockk.mockkObject +import io.mockk.unmockkObject +import io.reactivex.rxjava3.core.Completable +import org.junit.After +import org.junit.Before +import org.junit.Test + +class ClearAppCacheUseCaseImplTest { + + private val stubException = Throwable("Fatal error.") + private val stubFileProviderDescriptor = mockk() + private val stubRepository = mockk() + + private val useCase = ClearAppCacheUseCaseImpl( + fileProviderDescriptor = stubFileProviderDescriptor, + repository = stubRepository, + ) + + @Before + fun initialize() { + mockkObject(FileLoggingTree) + } + + @After + fun finalize() { + unmockkObject(FileLoggingTree) + } + + @Test + fun `given repository and logs clear success, expected complete value`() { + every { + stubRepository.deleteAll() + } returns Completable.complete() + + every { + stubFileProviderDescriptor.logsCacheDirPath + } returns "/tmp/cache" + + every { + FileLoggingTree.clearLog(any()) + } returns Unit + + useCase() + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given repository clear fails, logs clear success, expected error value`() { + every { + stubRepository.deleteAll() + } returns Completable.error(stubException) + + every { + stubFileProviderDescriptor.logsCacheDirPath + } returns "/tmp/cache" + + every { + FileLoggingTree.clearLog(any()) + } returns Unit + + useCase() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given repository clear success, logs clear fails, expected error value`() { + every { + stubRepository.deleteAll() + } returns Completable.complete() + + every { + stubFileProviderDescriptor.logsCacheDirPath + } returns "/tmp/cache" + + every { + FileLoggingTree.clearLog(any()) + } throws stubException + + useCase() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given repository clear success, file provider throws exception, expected error value`() { + every { + stubRepository.deleteAll() + } returns Completable.complete() + + every { + stubFileProviderDescriptor.logsCacheDirPath + } throws stubException + + useCase() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/caching/DataPreLoaderUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/caching/DataPreLoaderUseCaseImplTest.kt new file mode 100644 index 00000000..372373b5 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/caching/DataPreLoaderUseCaseImplTest.kt @@ -0,0 +1,232 @@ +package com.shifthackz.aisdv1.domain.usecase.caching + +import com.nhaarman.mockitokotlin2.mock +import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.domain.repository.ServerConfigurationRepository +import com.shifthackz.aisdv1.domain.repository.StableDiffusionEmbeddingsRepository +import com.shifthackz.aisdv1.domain.repository.StableDiffusionHyperNetworksRepository +import com.shifthackz.aisdv1.domain.repository.StableDiffusionLorasRepository +import com.shifthackz.aisdv1.domain.repository.StableDiffusionModelsRepository +import com.shifthackz.aisdv1.domain.repository.StableDiffusionSamplersRepository +import io.reactivex.rxjava3.core.Completable +import org.junit.Test + +class DataPreLoaderUseCaseImplTest { + + private val stubServerConfigurationRepository = mock() + private val stubStableDiffusionModelsRepository = mock() + private val stubStableDiffusionSamplersRepository = mock() + private val stubStableDiffusionLorasRepository = mock() + private val stubStableDiffusionHyperNetworksRepository = mock() + private val stubStableDiffusionEmbeddingsRepository = mock() + + private val useCase = DataPreLoaderUseCaseImpl( + serverConfigurationRepository = stubServerConfigurationRepository, + sdModelsRepository = stubStableDiffusionModelsRepository, + sdSamplersRepository = stubStableDiffusionSamplersRepository, + sdLorasRepository = stubStableDiffusionLorasRepository, + sdHyperNetworksRepository = stubStableDiffusionHyperNetworksRepository, + sdEmbeddingsRepository = stubStableDiffusionEmbeddingsRepository, + ) + + @Test + fun `given all data fetched successfully, expected complete value`() { + whenever(stubServerConfigurationRepository.fetchConfiguration()) + .thenReturn(Completable.complete()) + + whenever(stubStableDiffusionModelsRepository.fetchModels()) + .thenReturn(Completable.complete()) + + whenever(stubStableDiffusionSamplersRepository.fetchSamplers()) + .thenReturn(Completable.complete()) + + whenever(stubStableDiffusionLorasRepository.fetchLoras()) + .thenReturn(Completable.complete()) + + whenever(stubStableDiffusionHyperNetworksRepository.fetchHyperNetworks()) + .thenReturn(Completable.complete()) + + whenever(stubStableDiffusionEmbeddingsRepository.fetchEmbeddings()) + .thenReturn(Completable.complete()) + + useCase() + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given configuration fetch failed, expected error value`() { + val stubException = Throwable("Can not fetch configuration.") + + whenever(stubServerConfigurationRepository.fetchConfiguration()) + .thenReturn(Completable.error(stubException)) + + whenever(stubStableDiffusionModelsRepository.fetchModels()) + .thenReturn(Completable.complete()) + + whenever(stubStableDiffusionSamplersRepository.fetchSamplers()) + .thenReturn(Completable.complete()) + + whenever(stubStableDiffusionLorasRepository.fetchLoras()) + .thenReturn(Completable.complete()) + + whenever(stubStableDiffusionHyperNetworksRepository.fetchHyperNetworks()) + .thenReturn(Completable.complete()) + + whenever(stubStableDiffusionEmbeddingsRepository.fetchEmbeddings()) + .thenReturn(Completable.complete()) + + useCase() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given models fetch failed, expected error value`() { + val stubException = Throwable("Can not fetch models.") + + whenever(stubServerConfigurationRepository.fetchConfiguration()) + .thenReturn(Completable.complete()) + + whenever(stubStableDiffusionModelsRepository.fetchModels()) + .thenReturn(Completable.error(stubException)) + + whenever(stubStableDiffusionSamplersRepository.fetchSamplers()) + .thenReturn(Completable.complete()) + + whenever(stubStableDiffusionLorasRepository.fetchLoras()) + .thenReturn(Completable.complete()) + + whenever(stubStableDiffusionHyperNetworksRepository.fetchHyperNetworks()) + .thenReturn(Completable.complete()) + + whenever(stubStableDiffusionEmbeddingsRepository.fetchEmbeddings()) + .thenReturn(Completable.complete()) + + useCase() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given samplers fetch failed, expected error value`() { + val stubException = Throwable("Can not fetch samplers.") + + whenever(stubServerConfigurationRepository.fetchConfiguration()) + .thenReturn(Completable.complete()) + + whenever(stubStableDiffusionModelsRepository.fetchModels()) + .thenReturn(Completable.complete()) + + whenever(stubStableDiffusionSamplersRepository.fetchSamplers()) + .thenReturn(Completable.error(stubException)) + + whenever(stubStableDiffusionLorasRepository.fetchLoras()) + .thenReturn(Completable.complete()) + + whenever(stubStableDiffusionHyperNetworksRepository.fetchHyperNetworks()) + .thenReturn(Completable.complete()) + + whenever(stubStableDiffusionEmbeddingsRepository.fetchEmbeddings()) + .thenReturn(Completable.complete()) + + useCase() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given loras fetch failed, expected error value`() { + val stubException = Throwable("Can not fetch loras.") + + whenever(stubServerConfigurationRepository.fetchConfiguration()) + .thenReturn(Completable.complete()) + + whenever(stubStableDiffusionModelsRepository.fetchModels()) + .thenReturn(Completable.complete()) + + whenever(stubStableDiffusionSamplersRepository.fetchSamplers()) + .thenReturn(Completable.complete()) + + whenever(stubStableDiffusionLorasRepository.fetchLoras()) + .thenReturn(Completable.error(stubException)) + + whenever(stubStableDiffusionHyperNetworksRepository.fetchHyperNetworks()) + .thenReturn(Completable.complete()) + + whenever(stubStableDiffusionEmbeddingsRepository.fetchEmbeddings()) + .thenReturn(Completable.complete()) + + useCase() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given hypernetworks fetch failed, expected error value`() { + val stubException = Throwable("Can not fetch hypernetworks.") + + whenever(stubServerConfigurationRepository.fetchConfiguration()) + .thenReturn(Completable.complete()) + + whenever(stubStableDiffusionModelsRepository.fetchModels()) + .thenReturn(Completable.complete()) + + whenever(stubStableDiffusionSamplersRepository.fetchSamplers()) + .thenReturn(Completable.complete()) + + whenever(stubStableDiffusionLorasRepository.fetchLoras()) + .thenReturn(Completable.complete()) + + whenever(stubStableDiffusionHyperNetworksRepository.fetchHyperNetworks()) + .thenReturn(Completable.error(stubException)) + + whenever(stubStableDiffusionEmbeddingsRepository.fetchEmbeddings()) + .thenReturn(Completable.complete()) + + useCase() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given embeddings fetch failed, expected error value`() { + val stubException = Throwable("Can not fetch embeddings.") + + whenever(stubServerConfigurationRepository.fetchConfiguration()) + .thenReturn(Completable.complete()) + + whenever(stubStableDiffusionModelsRepository.fetchModels()) + .thenReturn(Completable.complete()) + + whenever(stubStableDiffusionSamplersRepository.fetchSamplers()) + .thenReturn(Completable.complete()) + + whenever(stubStableDiffusionLorasRepository.fetchLoras()) + .thenReturn(Completable.complete()) + + whenever(stubStableDiffusionHyperNetworksRepository.fetchHyperNetworks()) + .thenReturn(Completable.complete()) + + whenever(stubStableDiffusionEmbeddingsRepository.fetchEmbeddings()) + .thenReturn(Completable.error(stubException)) + + useCase() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/caching/GetLastResultFromCacheUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/caching/GetLastResultFromCacheUseCaseImplTest.kt new file mode 100644 index 00000000..3d275385 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/caching/GetLastResultFromCacheUseCaseImplTest.kt @@ -0,0 +1,41 @@ +package com.shifthackz.aisdv1.domain.usecase.caching + +import com.nhaarman.mockitokotlin2.mock +import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.domain.mocks.mockAiGenerationResult +import com.shifthackz.aisdv1.domain.repository.TemporaryGenerationResultRepository +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class GetLastResultFromCacheUseCaseImplTest { + + private val stubException = Throwable("No last cached result") + private val stubRepository = mock() + + private val useCase = GetLastResultFromCacheUseCaseImpl(stubRepository) + + @Test + fun `given repository returned last ai result, expected valid result value`() { + whenever(stubRepository.get()) + .thenReturn(Single.just(mockAiGenerationResult)) + + useCase() + .test() + .assertNoErrors() + .assertValue(mockAiGenerationResult) + .await() + .assertComplete() + } + + @Test + fun `given repository has no last ai result, expected error value`() { + whenever(stubRepository.get()) + .thenReturn(Single.error(stubException)) + + useCase() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/caching/SaveLastResultToCacheUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/caching/SaveLastResultToCacheUseCaseImplTest.kt new file mode 100644 index 00000000..7507d5f1 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/caching/SaveLastResultToCacheUseCaseImplTest.kt @@ -0,0 +1,66 @@ +package com.shifthackz.aisdv1.domain.usecase.caching + +import com.nhaarman.mockitokotlin2.any +import com.nhaarman.mockitokotlin2.mock +import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.domain.mocks.mockAiGenerationResult +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.domain.repository.TemporaryGenerationResultRepository +import io.reactivex.rxjava3.core.Completable +import org.junit.Test + +class SaveLastResultToCacheUseCaseImplTest { + + private val stubException = Throwable("No last cached result") + private val stubRepository = mock() + private val stubPreferenceManager = mock() + + private val useCase = SaveLastResultToCacheUseCaseImpl( + temporaryGenerationResultRepository = stubRepository, + preferenceManager = stubPreferenceManager, + ) + + @Test + fun `given user has enabled autosave, try to save, expected valid ai result value`() { + whenever(stubPreferenceManager.autoSaveAiResults) + .thenReturn(true) + + useCase(mockAiGenerationResult) + .test() + .assertNoErrors() + .assertValue(mockAiGenerationResult) + .await() + .assertComplete() + } + + @Test + fun `given user has disabled autosave, save completed successfully, expected valid ai result value`() { + whenever(stubPreferenceManager.autoSaveAiResults) + .thenReturn(false) + + whenever(stubRepository.put(any())) + .thenReturn(Completable.complete()) + + useCase(mockAiGenerationResult) + .test() + .assertNoErrors() + .assertValue(mockAiGenerationResult) + .await() + .assertComplete() + } + + @Test + fun `given user has disabled autosave, save fails, expected error value`() { + whenever(stubPreferenceManager.autoSaveAiResults) + .thenReturn(false) + + whenever(stubRepository.put(any())) + .thenReturn(Completable.error(stubException)) + + useCase(mockAiGenerationResult) + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/connectivity/ObserveSeverConnectivityUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/connectivity/ObserveSeverConnectivityUseCaseImplTest.kt new file mode 100644 index 00000000..777ade05 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/connectivity/ObserveSeverConnectivityUseCaseImplTest.kt @@ -0,0 +1,88 @@ +package com.shifthackz.aisdv1.domain.usecase.connectivity + +import com.nhaarman.mockitokotlin2.mock +import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.domain.gateway.ServerConnectivityGateway +import io.reactivex.rxjava3.core.BackpressureStrategy +import io.reactivex.rxjava3.core.Flowable +import io.reactivex.rxjava3.subjects.BehaviorSubject +import org.junit.Before +import org.junit.Test + +class ObserveSeverConnectivityUseCaseImplTest { + + private val stubException = Throwable("Unexpected Flowable termination.") + private val stubConnectivityValue = BehaviorSubject.create() + private val stubGateway = mock() + + private val useCase = ObserveSeverConnectivityUseCaseImpl(stubGateway) + + @Before + fun initialize() { + whenever(stubGateway.observe()) + .thenReturn(stubConnectivityValue.toFlowable(BackpressureStrategy.LATEST)) + } + + @Test + fun `given server not connected, then connection establishes, expected false, then true`() { + val stubObserver = useCase().test() + + stubConnectivityValue.onNext(false) + + stubObserver + .assertNoErrors() + .assertValueAt(0, false) + + stubConnectivityValue.onNext(true) + + stubObserver + .assertNoErrors() + .assertValueAt(1, true) + } + + @Test + fun `given server connected, then connection lost, expected true, then false`() { + val stubObserver = useCase().test() + + stubConnectivityValue.onNext(true) + + stubObserver + .assertNoErrors() + .assertValueAt(0, true) + + stubConnectivityValue.onNext(false) + + stubObserver + .assertNoErrors() + .assertValueAt(1, false) + } + + @Test + fun `given server connected, gateway emits value twice, expected true, only once`() { + val stubObserver = useCase().test() + + stubConnectivityValue.onNext(true) + + stubObserver + .assertNoErrors() + .assertValueAt(0, true) + + stubConnectivityValue.onNext(true) + + stubObserver + .assertNoErrors() + .assertValueCount(1) + } + + @Test + fun `given gateway throws unexpected flowable termination, expected error value`() { + whenever(stubGateway.observe()) + .thenReturn(Flowable.error(stubException)) + + useCase() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/connectivity/PingStableDiffusionServiceUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/connectivity/PingStableDiffusionServiceUseCaseImplTest.kt new file mode 100644 index 00000000..741e5e34 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/connectivity/PingStableDiffusionServiceUseCaseImplTest.kt @@ -0,0 +1,39 @@ +package com.shifthackz.aisdv1.domain.usecase.connectivity + +import com.nhaarman.mockitokotlin2.mock +import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.domain.repository.StableDiffusionGenerationRepository +import io.reactivex.rxjava3.core.Completable +import org.junit.Test + +class PingStableDiffusionServiceUseCaseImplTest { + + private val stubException = Throwable("Can not establish connection to server.") + private val stubRepository = mock() + + private val useCase = PingStableDiffusionServiceUseCaseImpl(stubRepository) + + @Test + fun `given connection to server can be established, expected complete value`() { + whenever(stubRepository.checkApiAvailability()) + .thenReturn(Completable.complete()) + + useCase() + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given connection to server can not be established, expected error value`() { + whenever(stubRepository.checkApiAvailability()) + .thenReturn(Completable.error(stubException)) + + useCase() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/connectivity/TestConnectivityUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/connectivity/TestConnectivityUseCaseImplTest.kt new file mode 100644 index 00000000..6d5ffcc1 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/connectivity/TestConnectivityUseCaseImplTest.kt @@ -0,0 +1,44 @@ +package com.shifthackz.aisdv1.domain.usecase.connectivity + +import com.nhaarman.mockitokotlin2.any +import com.nhaarman.mockitokotlin2.mock +import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.domain.repository.StableDiffusionGenerationRepository +import io.reactivex.rxjava3.core.Completable +import org.junit.Test + +class TestConnectivityUseCaseImplTest { + + companion object { + private const val STUB_URL = "https://5598.is.my.favourite.com" + } + + private val stubException = Throwable("Can not establish connection to server.") + private val stubRepository = mock() + + private val useCase = TestConnectivityUseCaseImpl(stubRepository) + + @Test + fun `given connection to server can be established, expected complete value`() { + whenever(stubRepository.checkApiAvailability(any())) + .thenReturn(Completable.complete()) + + useCase(STUB_URL) + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given connection to server can not be established, expected error value`() { + whenever(stubRepository.checkApiAvailability(any())) + .thenReturn(Completable.error(stubException)) + + useCase(STUB_URL) + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/connectivity/TestHordeApiKeyUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/connectivity/TestHordeApiKeyUseCaseImplTest.kt new file mode 100644 index 00000000..769d5f36 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/connectivity/TestHordeApiKeyUseCaseImplTest.kt @@ -0,0 +1,53 @@ +package com.shifthackz.aisdv1.domain.usecase.connectivity + +import com.nhaarman.mockitokotlin2.mock +import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.domain.repository.HordeGenerationRepository +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class TestHordeApiKeyUseCaseImplTest { + + private val stubException = Throwable("Can not connect to Horde AI.") + private val stubRepository = mock() + + private val useCase = TestHordeApiKeyUseCaseImpl(stubRepository) + + @Test + fun `given horde api key passed validation, expected true`() { + whenever(stubRepository.validateApiKey()) + .thenReturn(Single.just(true)) + + useCase() + .test() + .assertNoErrors() + .assertValue(true) + .await() + .assertComplete() + } + + @Test + fun `given horde api key not passed validation, expected false`() { + whenever(stubRepository.validateApiKey()) + .thenReturn(Single.just(false)) + + useCase() + .test() + .assertNoErrors() + .assertValue(false) + .await() + .assertComplete() + } + + @Test + fun `given validator thrown exception, expected error value`() { + whenever(stubRepository.validateApiKey()) + .thenReturn(Single.error(stubException)) + + useCase() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/connectivity/TestHuggingFaceApiKeyUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/connectivity/TestHuggingFaceApiKeyUseCaseImplTest.kt new file mode 100644 index 00000000..10d4a46c --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/connectivity/TestHuggingFaceApiKeyUseCaseImplTest.kt @@ -0,0 +1,53 @@ +package com.shifthackz.aisdv1.domain.usecase.connectivity + +import com.nhaarman.mockitokotlin2.mock +import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.domain.repository.HuggingFaceGenerationRepository +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class TestHuggingFaceApiKeyUseCaseImplTest { + + private val stubException = Throwable("Can not connect to Hugging Face AI.") + private val stubRepository = mock() + + private val useCase = TestHuggingFaceApiKeyUseCaseImpl(stubRepository) + + @Test + fun `given hugging face api key passed validation, expected true`() { + whenever(stubRepository.validateApiKey()) + .thenReturn(Single.just(true)) + + useCase() + .test() + .assertNoErrors() + .assertValue(true) + .await() + .assertComplete() + } + + @Test + fun `given hugging face api key not passed validation, expected false`() { + whenever(stubRepository.validateApiKey()) + .thenReturn(Single.just(false)) + + useCase() + .test() + .assertNoErrors() + .assertValue(false) + .await() + .assertComplete() + } + + @Test + fun `given validator thrown exception, expected error value`() { + whenever(stubRepository.validateApiKey()) + .thenReturn(Single.error(stubException)) + + useCase() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/connectivity/TestOpenAiApiKeyUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/connectivity/TestOpenAiApiKeyUseCaseImplTest.kt new file mode 100644 index 00000000..40877cc8 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/connectivity/TestOpenAiApiKeyUseCaseImplTest.kt @@ -0,0 +1,53 @@ +package com.shifthackz.aisdv1.domain.usecase.connectivity + +import com.nhaarman.mockitokotlin2.mock +import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.domain.repository.OpenAiGenerationRepository +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class TestOpenAiApiKeyUseCaseImplTest { + + private val stubException = Throwable("Can not connect to OpenAI.") + private val stubRepository = mock() + + private val useCase = TestOpenAiApiKeyUseCaseImpl(stubRepository) + + @Test + fun `given openai api key passed validation, expected true`() { + whenever(stubRepository.validateApiKey()) + .thenReturn(Single.just(true)) + + useCase() + .test() + .assertNoErrors() + .assertValue(true) + .await() + .assertComplete() + } + + @Test + fun `given openai api key not passed validation, expected false`() { + whenever(stubRepository.validateApiKey()) + .thenReturn(Single.just(false)) + + useCase() + .test() + .assertNoErrors() + .assertValue(false) + .await() + .assertComplete() + } + + @Test + fun `given validator thrown exception, expected error value`() { + whenever(stubRepository.validateApiKey()) + .thenReturn(Single.error(stubException)) + + useCase() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/connectivity/TestStabilityAiApiKeyUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/connectivity/TestStabilityAiApiKeyUseCaseImplTest.kt new file mode 100644 index 00000000..23a1d673 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/connectivity/TestStabilityAiApiKeyUseCaseImplTest.kt @@ -0,0 +1,53 @@ +package com.shifthackz.aisdv1.domain.usecase.connectivity + +import com.nhaarman.mockitokotlin2.mock +import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.domain.repository.StabilityAiGenerationRepository +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class TestStabilityAiApiKeyUseCaseImplTest { + + private val stubException = Throwable("Can not connect to Stability AI.") + private val stubRepository = mock() + + private val useCase = TestStabilityAiApiKeyUseCaseImpl(stubRepository) + + @Test + fun `given Stability AI api key passed validation, expected true`() { + whenever(stubRepository.validateApiKey()) + .thenReturn(Single.just(true)) + + useCase() + .test() + .assertNoErrors() + .assertValue(true) + .await() + .assertComplete() + } + + @Test + fun `given Stability AI api key not passed validation, expected false`() { + whenever(stubRepository.validateApiKey()) + .thenReturn(Single.just(false)) + + useCase() + .test() + .assertNoErrors() + .assertValue(false) + .await() + .assertComplete() + } + + @Test + fun `given validator thrown exception, expected error value`() { + whenever(stubRepository.validateApiKey()) + .thenReturn(Single.error(stubException)) + + useCase() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/debug/DebugInsertBadBase64UseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/debug/DebugInsertBadBase64UseCaseImplTest.kt new file mode 100644 index 00000000..1d56c81b --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/debug/DebugInsertBadBase64UseCaseImplTest.kt @@ -0,0 +1,41 @@ +package com.shifthackz.aisdv1.domain.usecase.debug + +import com.nhaarman.mockitokotlin2.any +import com.nhaarman.mockitokotlin2.mock +import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.domain.repository.GenerationResultRepository +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class DebugInsertBadBase64UseCaseImplTest { + + private val stubRepository = mock() + + private val useCase = DebugInsertBadBase64UseCaseImpl(stubRepository) + + @Test + fun `given inserted value with bad BASE64 into DB, expected complete value`() { + whenever(stubRepository.insert(any())) + .thenReturn(Single.just(5598L)) + + useCase() + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given failed to insert value with bad BASE64 into DB, expected error value`() { + val stubException = Throwable("DB error.") + + whenever(stubRepository.insert(any())) + .thenReturn(Single.error(stubException)) + + useCase() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/DeleteModelUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/DeleteModelUseCaseImplTest.kt new file mode 100644 index 00000000..f3504939 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/DeleteModelUseCaseImplTest.kt @@ -0,0 +1,41 @@ +package com.shifthackz.aisdv1.domain.usecase.downloadable + +import com.nhaarman.mockitokotlin2.any +import com.nhaarman.mockitokotlin2.mock +import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.domain.repository.DownloadableModelRepository +import io.reactivex.rxjava3.core.Completable +import org.junit.Test + +class DeleteModelUseCaseImplTest { + + private val stubRepository = mock() + + private val useCase = DeleteModelUseCaseImpl(stubRepository) + + @Test + fun `given model deleted successfully, expected completion`() { + whenever(stubRepository.delete(any())) + .thenReturn(Completable.complete()) + + useCase("5598") + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given model delete failed, expected error value`() { + val stubException = Throwable("Failed to delete model.") + + whenever(stubRepository.delete(any())) + .thenReturn(Completable.error(stubException)) + + useCase("5598") + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/DownloadModelUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/DownloadModelUseCaseImplTest.kt new file mode 100644 index 00000000..669200f4 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/DownloadModelUseCaseImplTest.kt @@ -0,0 +1,152 @@ +package com.shifthackz.aisdv1.domain.usecase.downloadable + +import com.nhaarman.mockitokotlin2.any +import com.nhaarman.mockitokotlin2.mock +import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.domain.entity.DownloadState +import com.shifthackz.aisdv1.domain.repository.DownloadableModelRepository +import io.reactivex.rxjava3.core.Observable +import io.reactivex.rxjava3.subjects.PublishSubject +import org.junit.Before +import org.junit.Test +import java.io.File + +class DownloadModelUseCaseImplTest { + + private val stubFile = File("/storage/emulated/0/file.dat") + private val stubException = Throwable("Error downloading file.") + private val stubTerminateException = Throwable("Unexpected Observable termination.") + private val stubDownloadStatus = PublishSubject.create() + private val stubRepository = mock() + + private val useCase = DownloadModelUseCaseImpl(stubRepository) + + @Before + fun initialize() { + whenever(stubRepository.download(any())) + .thenReturn(stubDownloadStatus) + } + + @Test + fun `given download running, then finishes successfully, expected final state is Complete`() { + val stubObserver = useCase("5598").test() + + stubDownloadStatus.onNext(DownloadState.Unknown) + + stubObserver + .assertNoErrors() + .assertValueAt(0, DownloadState.Unknown) + + stubDownloadStatus.onNext(DownloadState.Downloading(33)) + + stubObserver + .assertNoErrors() + .assertValueAt(1, DownloadState.Downloading(33)) + + stubDownloadStatus.onNext(DownloadState.Downloading(100)) + + stubObserver + .assertNoErrors() + .assertValueAt(2, DownloadState.Downloading(100)) + + stubDownloadStatus.onNext(DownloadState.Complete(stubFile)) + + stubObserver + .assertNoErrors() + .assertValueAt(3, DownloadState.Complete(stubFile)) + } + + @Test + fun `given download running, then fails, expected final state is Error`() { + val stubObserver = useCase("5598").test() + + stubDownloadStatus.onNext(DownloadState.Unknown) + + stubObserver + .assertNoErrors() + .assertValueAt(0, DownloadState.Unknown) + + stubDownloadStatus.onNext(DownloadState.Downloading(33)) + + stubObserver + .assertNoErrors() + .assertValueAt(1, DownloadState.Downloading(33)) + + stubDownloadStatus.onNext(DownloadState.Downloading(100)) + + stubObserver + .assertNoErrors() + .assertValueAt(2, DownloadState.Downloading(100)) + + stubDownloadStatus.onNext(DownloadState.Error(stubException)) + + stubObserver + .assertNoErrors() + .assertValueAt(3, DownloadState.Error(stubException)) + } + + @Test + fun `given download running, then fails, then user restarts download, then completes, expected state Error on 1st try, final state is Complete`() { + val stubObserver = useCase("5598").test() + + stubDownloadStatus.onNext(DownloadState.Unknown) + + stubObserver + .assertNoErrors() + .assertValueAt(0, DownloadState.Unknown) + + stubDownloadStatus.onNext(DownloadState.Downloading(33)) + + stubObserver + .assertNoErrors() + .assertValueAt(1, DownloadState.Downloading(33)) + + stubDownloadStatus.onNext(DownloadState.Downloading(100)) + + stubObserver + .assertNoErrors() + .assertValueAt(2, DownloadState.Downloading(100)) + + stubDownloadStatus.onNext(DownloadState.Error(stubException)) + + stubObserver + .assertNoErrors() + .assertValueAt(3, DownloadState.Error(stubException)) + + stubDownloadStatus.onNext(DownloadState.Unknown) + + stubObserver + .assertNoErrors() + .assertValueAt(4, DownloadState.Unknown) + + stubDownloadStatus.onNext(DownloadState.Downloading(33)) + + stubObserver + .assertNoErrors() + .assertValueAt(5, DownloadState.Downloading(33)) + + stubDownloadStatus.onNext(DownloadState.Downloading(100)) + + stubObserver + .assertNoErrors() + .assertValueAt(6, DownloadState.Downloading(100)) + + stubDownloadStatus.onNext(DownloadState.Complete(stubFile)) + + stubObserver + .assertNoErrors() + .assertValueAt(7, DownloadState.Complete(stubFile)) + } + + @Test + fun `given observable terminated with unexpected error, expected error value`() { + whenever(stubRepository.download(any())) + .thenReturn(Observable.error(stubTerminateException)) + + useCase("5598") + .test() + .assertError(stubTerminateException) + .await() + .assertNotComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalAiModelsUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalAiModelsUseCaseImplTest.kt new file mode 100644 index 00000000..679ca297 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalAiModelsUseCaseImplTest.kt @@ -0,0 +1,55 @@ +package com.shifthackz.aisdv1.domain.usecase.downloadable + +import com.nhaarman.mockitokotlin2.mock +import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.domain.mocks.mockLocalAiModels +import com.shifthackz.aisdv1.domain.repository.DownloadableModelRepository +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class GetLocalAiModelsUseCaseImplTest { + + private val stubRepository = mock() + + private val useCase = GetLocalAiModelsUseCaseImpl(stubRepository) + + @Test + fun `given repository returned models list, expected valid models list value`() { + whenever(stubRepository.getAll()) + .thenReturn(Single.just(mockLocalAiModels)) + + useCase() + .test() + .assertNoErrors() + .assertValue(mockLocalAiModels) + .await() + .assertComplete() + } + + @Test + fun `given repository returned empty models list, expected empty models list value`() { + whenever(stubRepository.getAll()) + .thenReturn(Single.just(emptyList())) + + useCase() + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given repository thrown exception, expected error value`() { + val stubException = Throwable("Unable to collect local models.") + + whenever(stubRepository.getAll()) + .thenReturn(Single.error(stubException)) + + useCase() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalAiModelsUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalAiModelsUseCaseImplTest.kt new file mode 100644 index 00000000..00fde27a --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalAiModelsUseCaseImplTest.kt @@ -0,0 +1,109 @@ +package com.shifthackz.aisdv1.domain.usecase.downloadable + +import com.nhaarman.mockitokotlin2.mock +import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.domain.entity.LocalAiModel +import com.shifthackz.aisdv1.domain.mocks.mockLocalAiModels +import com.shifthackz.aisdv1.domain.repository.DownloadableModelRepository +import io.reactivex.rxjava3.core.BackpressureStrategy +import io.reactivex.rxjava3.core.Flowable +import io.reactivex.rxjava3.subjects.BehaviorSubject +import org.junit.Before +import org.junit.Test + +class ObserveLocalAiModelsUseCaseImplTest { + + private val stubLocalModels = BehaviorSubject.create>() + private val stubRepository = mock() + + private val useCase = ObserveLocalAiModelsUseCaseImpl(stubRepository) + + @Before + fun initialize() { + whenever(stubRepository.observeAll()) + .thenReturn(stubLocalModels.toFlowable(BackpressureStrategy.LATEST)) + } + + @Test + fun `given repository has empty model list, then list inserted, expected receive empty list value, then valid list value`() { + val stubObserver = useCase().test() + + stubLocalModels.onNext(emptyList()) + + stubObserver + .assertNoErrors() + .assertValueAt(0, emptyList()) + + stubLocalModels.onNext(mockLocalAiModels) + + stubObserver + .assertNoErrors() + .assertValueAt(1, mockLocalAiModels) + } + + @Test + fun `given repository has model list, then clear, expected receive valid list value, then empty list value`() { + val stubObserver = useCase().test() + + stubLocalModels.onNext(mockLocalAiModels) + + stubObserver + .assertNoErrors() + .assertValueAt(0, mockLocalAiModels) + + stubLocalModels.onNext(emptyList()) + + stubObserver + .assertNoErrors() + .assertValueAt(1, emptyList()) + } + + @Test + fun `given repository has model list, then list changes, expected receive valid list value, then changed list value`() { + val stubObserver = useCase().test() + + stubLocalModels.onNext(mockLocalAiModels) + + stubObserver + .assertNoErrors() + .assertValueAt(0, mockLocalAiModels) + + val changedLocalAiModels = listOf(LocalAiModel.CUSTOM) + stubLocalModels.onNext(changedLocalAiModels) + + stubObserver + .assertNoErrors() + .assertValueAt(1, changedLocalAiModels) + } + + @Test + fun `given repository observer has model list, emits twice, expected receive valid list value once`() { + val stubObserver = useCase().test() + + stubLocalModels.onNext(mockLocalAiModels) + + stubObserver + .assertNoErrors() + .assertValueAt(0, mockLocalAiModels) + + stubLocalModels.onNext(mockLocalAiModels) + + stubObserver + .assertNoErrors() + .assertValueCount(1) + } + + @Test + fun `given observer terminates with unexpected error, expected receive error value`() { + val stubException = Throwable("Unexpected Flowable termination.") + + whenever(stubRepository.observeAll()) + .thenReturn(Flowable.error(stubException)) + + useCase() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/gallery/DeleteGalleryItemUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/gallery/DeleteGalleryItemUseCaseImplTest.kt new file mode 100644 index 00000000..13cbfad4 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/gallery/DeleteGalleryItemUseCaseImplTest.kt @@ -0,0 +1,41 @@ +package com.shifthackz.aisdv1.domain.usecase.gallery + +import com.nhaarman.mockitokotlin2.any +import com.nhaarman.mockitokotlin2.mock +import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.domain.repository.GenerationResultRepository +import io.reactivex.rxjava3.core.Completable +import org.junit.Test + +class DeleteGalleryItemUseCaseImplTest { + + private val stubRepository = mock() + + private val useCase = DeleteGalleryItemUseCaseImpl(stubRepository) + + @Test + fun `given repository deleted date successfully, expected complete`() { + whenever(stubRepository.deleteById(any())) + .thenReturn(Completable.complete()) + + useCase(5598L) + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given repository deleted date failed, expected error`() { + val stubException = Throwable("Database communication error.") + + whenever(stubRepository.deleteById(any())) + .thenReturn(Completable.error(stubException)) + + useCase(5598L) + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/gallery/GetAllGalleryUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/gallery/GetAllGalleryUseCaseImplTest.kt new file mode 100644 index 00000000..4ba7d2c4 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/gallery/GetAllGalleryUseCaseImplTest.kt @@ -0,0 +1,55 @@ +package com.shifthackz.aisdv1.domain.usecase.gallery + +import com.nhaarman.mockitokotlin2.mock +import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.domain.mocks.mockAiGenerationResults +import com.shifthackz.aisdv1.domain.repository.GenerationResultRepository +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class GetAllGalleryUseCaseImplTest { + + private val stubRepository = mock() + + private val useCase = GetAllGalleryUseCaseImpl(stubRepository) + + @Test + fun `given repository returned list of generations, expected valid list value`() { + whenever(stubRepository.getAll()) + .thenReturn(Single.just(mockAiGenerationResults)) + + useCase() + .test() + .assertNoErrors() + .assertValue(mockAiGenerationResults) + .await() + .assertComplete() + } + + @Test + fun `given repository returned empty list of generations, expected empty list value`() { + whenever(stubRepository.getAll()) + .thenReturn(Single.just(emptyList())) + + useCase() + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given repository thrown exception, expected error value`() { + val stubException = Throwable("Database communication error.") + + whenever(stubRepository.getAll()) + .thenReturn(Single.error(stubException)) + + useCase() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/gallery/GetMediaStoreInfoUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/gallery/GetMediaStoreInfoUseCaseImplTest.kt new file mode 100644 index 00000000..18a11c0e --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/gallery/GetMediaStoreInfoUseCaseImplTest.kt @@ -0,0 +1,56 @@ +package com.shifthackz.aisdv1.domain.usecase.gallery + +import com.nhaarman.mockitokotlin2.mock +import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.domain.entity.MediaStoreInfo +import com.shifthackz.aisdv1.domain.mocks.mockMediaStoreInfo +import com.shifthackz.aisdv1.domain.repository.GenerationResultRepository +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class GetMediaStoreInfoUseCaseImplTest { + + private val stubRepository = mock() + + private val useCase = GetMediaStoreInfoUseCaseImpl(stubRepository) + + @Test + fun `given repository provided media store info, expected valid media store info`() { + whenever(stubRepository.getMediaStoreInfo()) + .thenReturn(Single.just(mockMediaStoreInfo)) + + useCase() + .test() + .assertNoErrors() + .assertValue(mockMediaStoreInfo) + .await() + .assertComplete() + } + + @Test + fun `given repository provided empty media store info, expected default media store info`() { + whenever(stubRepository.getMediaStoreInfo()) + .thenReturn(Single.just(MediaStoreInfo())) + + useCase() + .test() + .assertNoErrors() + .assertValue(MediaStoreInfo()) + .await() + .assertComplete() + } + + @Test + fun `given repository thrown exception, expected error value`() { + val stubException = Throwable("Error communicating with MediaStore.") + + whenever(stubRepository.getMediaStoreInfo()) + .thenReturn(Single.error(stubException)) + + useCase() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/GetGenerationResultPagedUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/GetGenerationResultPagedUseCaseImplTest.kt new file mode 100644 index 00000000..5ca053da --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/GetGenerationResultPagedUseCaseImplTest.kt @@ -0,0 +1,55 @@ +package com.shifthackz.aisdv1.domain.usecase.generation + +import com.nhaarman.mockitokotlin2.any +import com.nhaarman.mockitokotlin2.mock +import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.domain.mocks.mockAiGenerationResults +import com.shifthackz.aisdv1.domain.repository.GenerationResultRepository +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class GetGenerationResultPagedUseCaseImplTest { + + private val stubException = Throwable("Can not read DB.") + private val stubRepository = mock() + + private val useCase = GetGenerationResultPagedUseCaseImpl(stubRepository) + + @Test + fun `given repository returned page with items, expected valid list value`() { + whenever(stubRepository.getPage(any(), any())) + .thenReturn(Single.just(mockAiGenerationResults)) + + useCase(20, 0) + .test() + .assertNoErrors() + .assertValue(mockAiGenerationResults) + .await() + .assertComplete() + } + + @Test + fun `given repository returned empty page with no items, expected empty list value`() { + whenever(stubRepository.getPage(any(), any())) + .thenReturn(Single.just(emptyList())) + + useCase(20, 5598) + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given repository thrown exception, expected error value`() { + whenever(stubRepository.getPage(any(), any())) + .thenReturn(Single.error(stubException)) + + useCase(20, 5598) + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/GetGenerationResultUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/GetGenerationResultUseCaseImplTest.kt new file mode 100644 index 00000000..07a2f519 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/GetGenerationResultUseCaseImplTest.kt @@ -0,0 +1,42 @@ +package com.shifthackz.aisdv1.domain.usecase.generation + +import com.nhaarman.mockitokotlin2.any +import com.nhaarman.mockitokotlin2.mock +import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.domain.mocks.mockAiGenerationResult +import com.shifthackz.aisdv1.domain.repository.GenerationResultRepository +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class GetGenerationResultUseCaseImplTest { + + private val stubException = Throwable("Ai generation result not found.") + private val stubRepository = mock() + + private val useCase = GetGenerationResultUseCaseImpl(stubRepository) + + @Test + fun `given repository has ai result with provided id, expected valid ai generation result value`() { + whenever(stubRepository.getById(any())) + .thenReturn(Single.just(mockAiGenerationResult)) + + useCase(5598L) + .test() + .assertNoErrors() + .assertValue(mockAiGenerationResult) + .await() + .assertComplete() + } + + @Test + fun `given repository has no ai result with provided id, expected error value`() { + whenever(stubRepository.getById(any())) + .thenReturn(Single.error(stubException)) + + useCase(5598L) + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/GetRandomImageUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/GetRandomImageUseCaseImplTest.kt new file mode 100644 index 00000000..8cbe8a82 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/GetRandomImageUseCaseImplTest.kt @@ -0,0 +1,42 @@ +package com.shifthackz.aisdv1.domain.usecase.generation + +import android.graphics.Bitmap +import com.nhaarman.mockitokotlin2.mock +import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.domain.repository.RandomImageRepository +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class GetRandomImageUseCaseImplTest { + + private val stubException = Throwable("Can not generate random image.") + private val stubBitmap = mock() + private val stubRepository = mock() + + private val useCase = GetRandomImageUseCaseImpl(stubRepository) + + @Test + fun `given repository provided bitmap with random image, expected valid bitmap value`() { + whenever(stubRepository.fetchAndGet()) + .thenReturn(Single.just(stubBitmap)) + + useCase() + .test() + .assertNoErrors() + .assertValue(stubBitmap) + .await() + .assertComplete() + } + + @Test + fun `given repository thrown exception, expected error value`() { + whenever(stubRepository.fetchAndGet()) + .thenReturn(Single.error(stubException)) + + useCase() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ImageToImageUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ImageToImageUseCaseImplTest.kt new file mode 100644 index 00000000..2a1544b3 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ImageToImageUseCaseImplTest.kt @@ -0,0 +1,319 @@ +package com.shifthackz.aisdv1.domain.usecase.generation + +import com.nhaarman.mockitokotlin2.any +import com.nhaarman.mockitokotlin2.mock +import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.domain.entity.ServerSource +import com.shifthackz.aisdv1.domain.mocks.mockAiGenerationResult +import com.shifthackz.aisdv1.domain.mocks.mockImageToImagePayload +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.domain.repository.HordeGenerationRepository +import com.shifthackz.aisdv1.domain.repository.HuggingFaceGenerationRepository +import com.shifthackz.aisdv1.domain.repository.StabilityAiGenerationRepository +import com.shifthackz.aisdv1.domain.repository.StableDiffusionGenerationRepository +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class ImageToImageUseCaseImplTest { + + private val stubException = Throwable("Unable to generate image.") + private val stubStableDiffusionGenerationRepository = mock() + private val stubHordeGenerationRepository = mock() + private val stubHuggingFaceGenerationRepository = mock() + private val stubStabilityAiGenerationRepository = mock() + private val stubPreferenceManager = mock() + + private val useCase = ImageToImageUseCaseImpl( + stableDiffusionGenerationRepository = stubStableDiffusionGenerationRepository, + hordeGenerationRepository = stubHordeGenerationRepository, + huggingFaceGenerationRepository = stubHuggingFaceGenerationRepository, + stabilityAiGenerationRepository = stubStabilityAiGenerationRepository, + preferenceManager = stubPreferenceManager, + ) + + @Test + fun `given source is AUTOMATIC1111, batch count is 1, generated successfully, expected generations list with size 1`() { + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.AUTOMATIC1111) + + whenever(stubStableDiffusionGenerationRepository.generateFromImage(any())) + .thenReturn(Single.just(mockAiGenerationResult)) + + val stubBatchCount = 1 + val stubPayload = mockImageToImagePayload.copy(batchCount = stubBatchCount) + + val expectedResult = listOf(mockAiGenerationResult) + + useCase(stubPayload) + .test() + .assertNoErrors() + .assertValue { generations -> + generations.size == stubBatchCount && expectedResult == generations + } + .await() + .assertComplete() + } + + @Test + fun `given source is AUTOMATIC1111, batch count is 10, generated successfully, expected generations list with size 10`() { + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.AUTOMATIC1111) + + whenever(stubStableDiffusionGenerationRepository.generateFromImage(any())) + .thenReturn(Single.just(mockAiGenerationResult)) + + val stubBatchCount = 10 + val stubPayload = mockImageToImagePayload.copy(batchCount = stubBatchCount) + + val expectedResult = (0 until 10).map { mockAiGenerationResult } + + useCase(stubPayload) + .test() + .assertNoErrors() + .assertValue { generations -> + generations.size == stubBatchCount && expectedResult == generations + } + .await() + .assertComplete() + } + + @Test + fun `given source is AUTOMATIC1111, batch count is 1, generate failed, expected error`() { + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.AUTOMATIC1111) + + whenever(stubStableDiffusionGenerationRepository.generateFromImage(any())) + .thenReturn(Single.error(stubException)) + + val stubBatchCount = 1 + val stubPayload = mockImageToImagePayload.copy(batchCount = stubBatchCount) + + useCase(stubPayload) + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given source is HORDE, batch count is 1, generated successfully, expected generations list with size 1`() { + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.HORDE) + + whenever(stubHordeGenerationRepository.generateFromImage(any())) + .thenReturn(Single.just(mockAiGenerationResult)) + + val stubBatchCount = 1 + val stubPayload = mockImageToImagePayload.copy(batchCount = stubBatchCount) + + val expectedResult = listOf(mockAiGenerationResult) + + useCase(stubPayload) + .test() + .assertNoErrors() + .assertValue { generations -> + generations.size == stubBatchCount && expectedResult == generations + } + .await() + .assertComplete() + } + + @Test + fun `given source is HORDE, batch count is 10, generated successfully, expected generations list with size 10`() { + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.HORDE) + + whenever(stubHordeGenerationRepository.generateFromImage(any())) + .thenReturn(Single.just(mockAiGenerationResult)) + + val stubBatchCount = 10 + val stubPayload = mockImageToImagePayload.copy(batchCount = stubBatchCount) + + val expectedResult = (0 until 10).map { mockAiGenerationResult } + + useCase(stubPayload) + .test() + .assertNoErrors() + .assertValue { generations -> + generations.size == stubBatchCount && expectedResult == generations + } + .await() + .assertComplete() + } + + @Test + fun `given source is HORDE, batch count is 1, generate failed, expected error`() { + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.HORDE) + + whenever(stubHordeGenerationRepository.generateFromImage(any())) + .thenReturn(Single.error(stubException)) + + val stubBatchCount = 1 + val stubPayload = mockImageToImagePayload.copy(batchCount = stubBatchCount) + + useCase(stubPayload) + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given source is HUGGING_FACE, batch count is 1, generated successfully, expected generations list with size 1`() { + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.HUGGING_FACE) + + whenever(stubHuggingFaceGenerationRepository.generateFromImage(any())) + .thenReturn(Single.just(mockAiGenerationResult)) + + val stubBatchCount = 1 + val stubPayload = mockImageToImagePayload.copy(batchCount = stubBatchCount) + + val expectedResult = listOf(mockAiGenerationResult) + + useCase(stubPayload) + .test() + .assertNoErrors() + .assertValue { generations -> + generations.size == stubBatchCount && expectedResult == generations + } + .await() + .assertComplete() + } + + @Test + fun `given source is HUGGING_FACE, batch count is 10, generated successfully, expected generations list with size 10`() { + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.HUGGING_FACE) + + whenever(stubHuggingFaceGenerationRepository.generateFromImage(any())) + .thenReturn(Single.just(mockAiGenerationResult)) + + val stubBatchCount = 10 + val stubPayload = mockImageToImagePayload.copy(batchCount = stubBatchCount) + + val expectedResult = (0 until 10).map { mockAiGenerationResult } + + useCase(stubPayload) + .test() + .assertNoErrors() + .assertValue { generations -> + generations.size == stubBatchCount && expectedResult == generations + } + .await() + .assertComplete() + } + + @Test + fun `given source is HUGGING_FACE, batch count is 1, generate failed, expected error`() { + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.HUGGING_FACE) + + whenever(stubHuggingFaceGenerationRepository.generateFromImage(any())) + .thenReturn(Single.error(stubException)) + + val stubBatchCount = 1 + val stubPayload = mockImageToImagePayload.copy(batchCount = stubBatchCount) + + useCase(stubPayload) + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given source is STABILITY_AI, batch count is 1, generated successfully, expected generations list with size 1`() { + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.STABILITY_AI) + + whenever(stubStabilityAiGenerationRepository.generateFromImage(any())) + .thenReturn(Single.just(mockAiGenerationResult)) + + val stubBatchCount = 1 + val stubPayload = mockImageToImagePayload.copy(batchCount = stubBatchCount) + + val expectedResult = listOf(mockAiGenerationResult) + + useCase(stubPayload) + .test() + .assertNoErrors() + .assertValue { generations -> + generations.size == stubBatchCount && expectedResult == generations + } + .await() + .assertComplete() + } + + @Test + fun `given source is STABILITY_AI, batch count is 10, generated successfully, expected generations list with size 10`() { + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.STABILITY_AI) + + whenever(stubStabilityAiGenerationRepository.generateFromImage(any())) + .thenReturn(Single.just(mockAiGenerationResult)) + + val stubBatchCount = 10 + val stubPayload = mockImageToImagePayload.copy(batchCount = stubBatchCount) + + val expectedResult = (0 until 10).map { mockAiGenerationResult } + + useCase(stubPayload) + .test() + .assertNoErrors() + .assertValue { generations -> + generations.size == stubBatchCount && expectedResult == generations + } + .await() + .assertComplete() + } + + @Test + fun `given source is STABILITY_AI, batch count is 1, generate failed, expected error`() { + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.STABILITY_AI) + + whenever(stubStabilityAiGenerationRepository.generateFromImage(any())) + .thenReturn(Single.error(stubException)) + + val stubBatchCount = 1 + val stubPayload = mockImageToImagePayload.copy(batchCount = stubBatchCount) + + useCase(stubPayload) + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given source is LOCAL, expected Img2Img not yet supported error`() { + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.LOCAL) + + useCase(mockImageToImagePayload) + .test() + .assertError { + it is IllegalStateException + && it.message?.startsWith("Img2Img not yet supported") == true + } + .await() + .assertNotComplete() + } + + @Test + fun `given source is OPEN_AI, expected Img2Img not yet supported error`() { + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.OPEN_AI) + + useCase(mockImageToImagePayload) + .test() + .assertError { + it is IllegalStateException + && it.message?.startsWith("Img2Img not yet supported") == true + } + .await() + .assertNotComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/InterruptGenerationUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/InterruptGenerationUseCaseImplTest.kt new file mode 100644 index 00000000..5ef6334c --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/InterruptGenerationUseCaseImplTest.kt @@ -0,0 +1,155 @@ +package com.shifthackz.aisdv1.domain.usecase.generation + +import com.nhaarman.mockitokotlin2.mock +import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.domain.entity.ServerSource +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.domain.repository.HordeGenerationRepository +import com.shifthackz.aisdv1.domain.repository.LocalDiffusionGenerationRepository +import com.shifthackz.aisdv1.domain.repository.StableDiffusionGenerationRepository +import io.reactivex.rxjava3.core.Completable +import org.junit.Test + +class InterruptGenerationUseCaseImplTest { + + private val stubException = Throwable("Can not interrupt generation.") + private val stubStableDiffusionGenerationRepository = mock() + private val stubHordeGenerationRepository = mock() + private val stubLocalDiffusionGenerationRepository = mock() + private val stubPreferenceManager = mock() + + private val useCase = InterruptGenerationUseCaseImpl( + stableDiffusionGenerationRepository = stubStableDiffusionGenerationRepository, + hordeGenerationRepository = stubHordeGenerationRepository, + localDiffusionGenerationRepository = stubLocalDiffusionGenerationRepository, + preferenceManager = stubPreferenceManager, + ) + + @Test + fun `given source is AUTOMATIC1111, api interrupt success, expected complete value`() { + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.AUTOMATIC1111) + + whenever(stubStableDiffusionGenerationRepository.interruptGeneration()) + .thenReturn(Completable.complete()) + + useCase() + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given source is AUTOMATIC1111, api interrupt fail, expected error value`() { + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.AUTOMATIC1111) + + whenever(stubStableDiffusionGenerationRepository.interruptGeneration()) + .thenReturn(Completable.error(stubException)) + + useCase() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given source is HORDE, api interrupt success, expected complete value`() { + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.HORDE) + + whenever(stubHordeGenerationRepository.interruptGeneration()) + .thenReturn(Completable.complete()) + + useCase() + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given source is HORDE, api interrupt fail, expected error value`() { + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.HORDE) + + whenever(stubHordeGenerationRepository.interruptGeneration()) + .thenReturn(Completable.error(stubException)) + + useCase() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given source is LOCAL, api interrupt success, expected complete value`() { + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.LOCAL) + + whenever(stubLocalDiffusionGenerationRepository.interruptGeneration()) + .thenReturn(Completable.complete()) + + useCase() + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given source is LOCAL, api interrupt fail, expected error value`() { + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.LOCAL) + + whenever(stubLocalDiffusionGenerationRepository.interruptGeneration()) + .thenReturn(Completable.error(stubException)) + + useCase() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + //-- + + @Test + fun `given source is HUGGING_FACE, expected complete value`() { + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.HUGGING_FACE) + + useCase() + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given source is OPEN_AI, expected complete value`() { + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.OPEN_AI) + + useCase() + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given source is STABILITY_AI, expected complete value`() { + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.STABILITY_AI) + + useCase() + .test() + .assertNoErrors() + .await() + .assertComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveHordeProcessStatusUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveHordeProcessStatusUseCaseImplTest.kt new file mode 100644 index 00000000..094b0879 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveHordeProcessStatusUseCaseImplTest.kt @@ -0,0 +1,73 @@ +package com.shifthackz.aisdv1.domain.usecase.generation + +import com.nhaarman.mockitokotlin2.mock +import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.domain.entity.HordeProcessStatus +import com.shifthackz.aisdv1.domain.repository.HordeGenerationRepository +import io.reactivex.rxjava3.core.BackpressureStrategy +import io.reactivex.rxjava3.core.Flowable +import io.reactivex.rxjava3.subjects.BehaviorSubject +import org.junit.Before +import org.junit.Test + +class ObserveHordeProcessStatusUseCaseImplTest { + + private val stubException = Throwable("Error communicating with Horde.") + private val stubHordeStatus = BehaviorSubject.create() + private val stubRepository = mock() + + private val useCase = ObserveHordeProcessStatusUseCaseImpl(stubRepository) + + @Before + fun initialize() { + whenever(stubRepository.observeStatus()) + .thenReturn(stubHordeStatus.toFlowable(BackpressureStrategy.LATEST)) + } + + @Test + fun `given repository emits two different values, expected two valid values`() { + val stubObserver = useCase().test() + + stubHordeStatus.onNext(HordeProcessStatus(5598, 1504)) + + stubObserver + .assertNoErrors() + .assertValueAt(0, HordeProcessStatus(5598, 1504)) + + stubHordeStatus.onNext(HordeProcessStatus(0, 0)) + + stubObserver + .assertNoErrors() + .assertValueAt(1, HordeProcessStatus(0, 0)) + .assertValueCount(2) + } + + @Test + fun `given repository emits two same values, expected one valid value`() { + val stubObserver = useCase().test() + + stubHordeStatus.onNext(HordeProcessStatus(5598, 1504)) + + stubObserver + .assertNoErrors() + .assertValueAt(0, HordeProcessStatus(5598, 1504)) + + stubHordeStatus.onNext(HordeProcessStatus(5598, 1504)) + + stubObserver + .assertNoErrors() + .assertValueCount(1) + } + + @Test + fun `given repository throws exception, expected error value`() { + whenever(stubRepository.observeStatus()) + .thenReturn(Flowable.error(stubException)) + + useCase() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveLocalDiffusionProcessStatusUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveLocalDiffusionProcessStatusUseCaseImplTest.kt new file mode 100644 index 00000000..528e2e6e --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveLocalDiffusionProcessStatusUseCaseImplTest.kt @@ -0,0 +1,86 @@ +package com.shifthackz.aisdv1.domain.usecase.generation + +import com.nhaarman.mockitokotlin2.mock +import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.domain.feature.diffusion.LocalDiffusion +import com.shifthackz.aisdv1.domain.repository.LocalDiffusionGenerationRepository +import io.reactivex.rxjava3.core.Observable +import io.reactivex.rxjava3.subjects.BehaviorSubject +import org.junit.Before +import org.junit.Test + +class ObserveLocalDiffusionProcessStatusUseCaseImplTest { + + private val stubException = Throwable("Error loading Local Diffusion.") + private val stubLocalStatus = BehaviorSubject.create() + private val stubRepository = mock() + + private val useCase = ObserveLocalDiffusionProcessStatusUseCaseImpl(stubRepository) + + @Before + fun initialize() { + whenever(stubRepository.observeStatus()) + .thenReturn(stubLocalStatus) + } + + @Test + fun `given repository processes three steps, expected three valid status values`() { + val stubObserver = useCase().test() + + stubLocalStatus.onNext(LocalDiffusion.Status(1, 3)) + + stubObserver + .assertNoErrors() + .assertValueAt(0, LocalDiffusion.Status(1, 3)) + + stubLocalStatus.onNext(LocalDiffusion.Status(2, 3)) + + stubObserver + .assertNoErrors() + .assertValueAt(1, LocalDiffusion.Status(2, 3)) + + stubLocalStatus.onNext(LocalDiffusion.Status(3, 3)) + + stubObserver + .assertNoErrors() + .assertValueAt(2, LocalDiffusion.Status(3, 3)) + .assertValueCount(3) + } + + @Test + fun `given repository processes two steps, emits same step twice, expected two valid status values`() { + val stubObserver = useCase().test() + + stubLocalStatus.onNext(LocalDiffusion.Status(1, 2)) + + stubObserver + .assertNoErrors() + .assertValueAt(0, LocalDiffusion.Status(1, 2)) + + stubLocalStatus.onNext(LocalDiffusion.Status(1, 2)) + + stubObserver + .assertNoErrors() + .assertValueCount(1) + + stubLocalStatus.onNext(LocalDiffusion.Status(2, 2)) + + stubObserver + .assertNoErrors() + .assertValueAt(1, LocalDiffusion.Status(2, 2)) + .assertValueCount(2) + } + + @Test + fun `given repository throws exception, expected error value`() { + whenever(stubRepository.observeStatus()) + .thenReturn(Observable.error(stubException)) + + useCase() + .test() + .assertError(stubException) + .await() + .assertValueCount(0) + .assertNotComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/SaveGenerationResultUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/SaveGenerationResultUseCaseImplTest.kt new file mode 100644 index 00000000..750dc81b --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/SaveGenerationResultUseCaseImplTest.kt @@ -0,0 +1,41 @@ +package com.shifthackz.aisdv1.domain.usecase.generation + +import com.nhaarman.mockitokotlin2.any +import com.nhaarman.mockitokotlin2.mock +import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.domain.mocks.mockAiGenerationResult +import com.shifthackz.aisdv1.domain.repository.GenerationResultRepository +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class SaveGenerationResultUseCaseImplTest { + + private val stubException = Throwable("Error inserting into DB.") + private val stubRepository = mock() + + private val useCase = SaveGenerationResultUseCaseImpl(stubRepository) + + @Test + fun `given repository saved generation result successfully, expected complete value`() { + whenever(stubRepository.insert(any())) + .thenReturn(Single.just(5598L)) + + useCase(mockAiGenerationResult) + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `given repository failed to save generation result, expected error value`() { + whenever(stubRepository.insert(any())) + .thenReturn(Single.error(stubException)) + + useCase(mockAiGenerationResult) + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImplTest.kt new file mode 100644 index 00000000..1283e409 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImplTest.kt @@ -0,0 +1,423 @@ +package com.shifthackz.aisdv1.domain.usecase.generation + +import com.nhaarman.mockitokotlin2.any +import com.nhaarman.mockitokotlin2.mock +import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.domain.entity.ServerSource +import com.shifthackz.aisdv1.domain.mocks.mockAiGenerationResult +import com.shifthackz.aisdv1.domain.mocks.mockTextToImagePayload +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.domain.repository.HordeGenerationRepository +import com.shifthackz.aisdv1.domain.repository.HuggingFaceGenerationRepository +import com.shifthackz.aisdv1.domain.repository.LocalDiffusionGenerationRepository +import com.shifthackz.aisdv1.domain.repository.OpenAiGenerationRepository +import com.shifthackz.aisdv1.domain.repository.StabilityAiGenerationRepository +import com.shifthackz.aisdv1.domain.repository.StableDiffusionGenerationRepository +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class TextToImageUseCaseImplTest { + + private val stubException = Throwable("Unable to generate image.") + private val stubStableDiffusionGenerationRepository = mock() + private val stubHordeGenerationRepository = mock() + private val stubHuggingFaceGenerationRepository = mock() + private val stubOpenAiGenerationRepository = mock() + private val stubStabilityAiGenerationRepository = mock() + private val stubLocalDiffusionGenerationRepository = mock() + private val stubPreferenceManager = mock() + + private val useCase = TextToImageUseCaseImpl( + stableDiffusionGenerationRepository = stubStableDiffusionGenerationRepository, + hordeGenerationRepository = stubHordeGenerationRepository, + huggingFaceGenerationRepository = stubHuggingFaceGenerationRepository, + openAiGenerationRepository = stubOpenAiGenerationRepository, + stabilityAiGenerationRepository = stubStabilityAiGenerationRepository, + localDiffusionGenerationRepository = stubLocalDiffusionGenerationRepository, + preferenceManager = stubPreferenceManager, + ) + + @Test + fun `given source is AUTOMATIC1111, batch count is 1, generated successfully, expected generations list with size 1`() { + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.AUTOMATIC1111) + + whenever(stubStableDiffusionGenerationRepository.generateFromText(any())) + .thenReturn(Single.just(mockAiGenerationResult)) + + val stubBatchCount = 1 + val stubPayload = mockTextToImagePayload.copy(batchCount = stubBatchCount) + + val expectedResult = listOf(mockAiGenerationResult) + + useCase(stubPayload) + .test() + .assertNoErrors() + .assertValue { generations -> + generations.size == stubBatchCount && expectedResult == generations + } + .await() + .assertComplete() + } + + @Test + fun `given source is AUTOMATIC1111, batch count is 10, generated successfully, expected generations list with size 10`() { + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.AUTOMATIC1111) + + whenever(stubStableDiffusionGenerationRepository.generateFromText(any())) + .thenReturn(Single.just(mockAiGenerationResult)) + + val stubBatchCount = 10 + val stubPayload = mockTextToImagePayload.copy(batchCount = stubBatchCount) + + val expectedResult = (0 until 10).map { mockAiGenerationResult } + + useCase(stubPayload) + .test() + .assertNoErrors() + .assertValue { generations -> + generations.size == stubBatchCount && expectedResult == generations + } + .await() + .assertComplete() + } + + @Test + fun `given source is AUTOMATIC1111, batch count is 1, generate failed, expected error`() { + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.AUTOMATIC1111) + + whenever(stubStableDiffusionGenerationRepository.generateFromText(any())) + .thenReturn(Single.error(stubException)) + + val stubBatchCount = 1 + val stubPayload = mockTextToImagePayload.copy(batchCount = stubBatchCount) + + useCase(stubPayload) + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given source is HORDE, batch count is 1, generated successfully, expected generations list with size 1`() { + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.HORDE) + + whenever(stubHordeGenerationRepository.generateFromText(any())) + .thenReturn(Single.just(mockAiGenerationResult)) + + val stubBatchCount = 1 + val stubPayload = mockTextToImagePayload.copy(batchCount = stubBatchCount) + + val expectedResult = listOf(mockAiGenerationResult) + + useCase(stubPayload) + .test() + .assertNoErrors() + .assertValue { generations -> + generations.size == stubBatchCount && expectedResult == generations + } + .await() + .assertComplete() + } + + @Test + fun `given source is HORDE, batch count is 10, generated successfully, expected generations list with size 10`() { + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.HORDE) + + whenever(stubHordeGenerationRepository.generateFromText(any())) + .thenReturn(Single.just(mockAiGenerationResult)) + + val stubBatchCount = 10 + val stubPayload = mockTextToImagePayload.copy(batchCount = stubBatchCount) + + val expectedResult = (0 until 10).map { mockAiGenerationResult } + + useCase(stubPayload) + .test() + .assertNoErrors() + .assertValue { generations -> + generations.size == stubBatchCount && expectedResult == generations + } + .await() + .assertComplete() + } + + @Test + fun `given source is HORDE, batch count is 1, generate failed, expected error`() { + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.HORDE) + + whenever(stubHordeGenerationRepository.generateFromText(any())) + .thenReturn(Single.error(stubException)) + + val stubBatchCount = 1 + val stubPayload = mockTextToImagePayload.copy(batchCount = stubBatchCount) + + useCase(stubPayload) + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given source is HUGGING_FACE, batch count is 1, generated successfully, expected generations list with size 1`() { + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.HUGGING_FACE) + + whenever(stubHuggingFaceGenerationRepository.generateFromText(any())) + .thenReturn(Single.just(mockAiGenerationResult)) + + val stubBatchCount = 1 + val stubPayload = mockTextToImagePayload.copy(batchCount = stubBatchCount) + + val expectedResult = listOf(mockAiGenerationResult) + + useCase(stubPayload) + .test() + .assertNoErrors() + .assertValue { generations -> + generations.size == stubBatchCount && expectedResult == generations + } + .await() + .assertComplete() + } + + @Test + fun `given source is HUGGING_FACE, batch count is 10, generated successfully, expected generations list with size 10`() { + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.HUGGING_FACE) + + whenever(stubHuggingFaceGenerationRepository.generateFromText(any())) + .thenReturn(Single.just(mockAiGenerationResult)) + + val stubBatchCount = 10 + val stubPayload = mockTextToImagePayload.copy(batchCount = stubBatchCount) + + val expectedResult = (0 until 10).map { mockAiGenerationResult } + + useCase(stubPayload) + .test() + .assertNoErrors() + .assertValue { generations -> + generations.size == stubBatchCount && expectedResult == generations + } + .await() + .assertComplete() + } + + @Test + fun `given source is HUGGING_FACE, batch count is 1, generate failed, expected error`() { + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.HUGGING_FACE) + + whenever(stubHuggingFaceGenerationRepository.generateFromText(any())) + .thenReturn(Single.error(stubException)) + + val stubBatchCount = 1 + val stubPayload = mockTextToImagePayload.copy(batchCount = stubBatchCount) + + useCase(stubPayload) + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given source is OPEN_AI, batch count is 1, generated successfully, expected generations list with size 1`() { + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.OPEN_AI) + + whenever(stubOpenAiGenerationRepository.generateFromText(any())) + .thenReturn(Single.just(mockAiGenerationResult)) + + val stubBatchCount = 1 + val stubPayload = mockTextToImagePayload.copy(batchCount = stubBatchCount) + + val expectedResult = listOf(mockAiGenerationResult) + + useCase(stubPayload) + .test() + .assertNoErrors() + .assertValue { generations -> + generations.size == stubBatchCount && expectedResult == generations + } + .await() + .assertComplete() + } + + @Test + fun `given source is OPEN_AI, batch count is 10, generated successfully, expected generations list with size 10`() { + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.OPEN_AI) + + whenever(stubOpenAiGenerationRepository.generateFromText(any())) + .thenReturn(Single.just(mockAiGenerationResult)) + + val stubBatchCount = 10 + val stubPayload = mockTextToImagePayload.copy(batchCount = stubBatchCount) + + val expectedResult = (0 until 10).map { mockAiGenerationResult } + + useCase(stubPayload) + .test() + .assertNoErrors() + .assertValue { generations -> + generations.size == stubBatchCount && expectedResult == generations + } + .await() + .assertComplete() + } + + @Test + fun `given source is OPEN_AI, batch count is 1, generate failed, expected error`() { + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.OPEN_AI) + + whenever(stubOpenAiGenerationRepository.generateFromText(any())) + .thenReturn(Single.error(stubException)) + + val stubBatchCount = 1 + val stubPayload = mockTextToImagePayload.copy(batchCount = stubBatchCount) + + useCase(stubPayload) + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given source is STABILITY_AI, batch count is 1, generated successfully, expected generations list with size 1`() { + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.STABILITY_AI) + + whenever(stubStabilityAiGenerationRepository.generateFromText(any())) + .thenReturn(Single.just(mockAiGenerationResult)) + + val stubBatchCount = 1 + val stubPayload = mockTextToImagePayload.copy(batchCount = stubBatchCount) + + val expectedResult = listOf(mockAiGenerationResult) + + useCase(stubPayload) + .test() + .assertNoErrors() + .assertValue { generations -> + generations.size == stubBatchCount && expectedResult == generations + } + .await() + .assertComplete() + } + + @Test + fun `given source is STABILITY_AI, batch count is 10, generated successfully, expected generations list with size 10`() { + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.STABILITY_AI) + + whenever(stubStabilityAiGenerationRepository.generateFromText(any())) + .thenReturn(Single.just(mockAiGenerationResult)) + + val stubBatchCount = 10 + val stubPayload = mockTextToImagePayload.copy(batchCount = stubBatchCount) + + val expectedResult = (0 until 10).map { mockAiGenerationResult } + + useCase(stubPayload) + .test() + .assertNoErrors() + .assertValue { generations -> + generations.size == stubBatchCount && expectedResult == generations + } + .await() + .assertComplete() + } + + @Test + fun `given source is STABILITY_AI, batch count is 1, generate failed, expected error`() { + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.STABILITY_AI) + + whenever(stubStabilityAiGenerationRepository.generateFromText(any())) + .thenReturn(Single.error(stubException)) + + val stubBatchCount = 1 + val stubPayload = mockTextToImagePayload.copy(batchCount = stubBatchCount) + + useCase(stubPayload) + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given source is LOCAL, batch count is 1, generated successfully, expected generations list with size 1`() { + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.LOCAL) + + whenever(stubLocalDiffusionGenerationRepository.generateFromText(any())) + .thenReturn(Single.just(mockAiGenerationResult)) + + val stubBatchCount = 1 + val stubPayload = mockTextToImagePayload.copy(batchCount = stubBatchCount) + + val expectedResult = listOf(mockAiGenerationResult) + + useCase(stubPayload) + .test() + .assertNoErrors() + .assertValue { generations -> + generations.size == stubBatchCount && expectedResult == generations + } + .await() + .assertComplete() + } + + @Test + fun `given source is LOCAL, batch count is 10, generated successfully, expected generations list with size 10`() { + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.LOCAL) + + whenever(stubLocalDiffusionGenerationRepository.generateFromText(any())) + .thenReturn(Single.just(mockAiGenerationResult)) + + val stubBatchCount = 10 + val stubPayload = mockTextToImagePayload.copy(batchCount = stubBatchCount) + + val expectedResult = (0 until 10).map { mockAiGenerationResult } + + useCase(stubPayload) + .test() + .assertNoErrors() + .assertValue { generations -> + generations.size == stubBatchCount && expectedResult == generations + } + .await() + .assertComplete() + } + + @Test + fun `given source is LOCAL, batch count is 1, generate failed, expected error`() { + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.LOCAL) + + whenever(stubLocalDiffusionGenerationRepository.generateFromText(any())) + .thenReturn(Single.error(stubException)) + + val stubBatchCount = 1 + val stubPayload = mockTextToImagePayload.copy(batchCount = stubBatchCount) + + useCase(stubPayload) + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/huggingface/FetchAndGetHuggingFaceModelsUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/huggingface/FetchAndGetHuggingFaceModelsUseCaseImplTest.kt new file mode 100644 index 00000000..77da1737 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/huggingface/FetchAndGetHuggingFaceModelsUseCaseImplTest.kt @@ -0,0 +1,55 @@ +package com.shifthackz.aisdv1.domain.usecase.huggingface + +import com.nhaarman.mockitokotlin2.mock +import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.domain.mocks.mockHuggingFaceModels +import com.shifthackz.aisdv1.domain.repository.HuggingFaceModelsRepository +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class FetchAndGetHuggingFaceModelsUseCaseImplTest { + + private val stubRepository = mock() + + private val useCase = FetchAndGetHuggingFaceModelsUseCaseImpl(stubRepository) + + @Test + fun `given repository provided models list, expected valid list value`() { + whenever(stubRepository.fetchAndGetHuggingFaceModels()) + .thenReturn(Single.just(mockHuggingFaceModels)) + + useCase() + .test() + .assertNoErrors() + .assertValue(mockHuggingFaceModels) + .await() + .assertComplete() + } + + @Test + fun `given repository provided empty models list, expected empty list value`() { + whenever(stubRepository.fetchAndGetHuggingFaceModels()) + .thenReturn(Single.just(emptyList())) + + useCase() + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given repository thrown exception, expected error value`() { + val stubException = Throwable("Unknown error occurred.") + + whenever(stubRepository.fetchAndGetHuggingFaceModels()) + .thenReturn(Single.error(stubException)) + + useCase() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/sdembedding/FetchAndGetEmbeddingsUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/sdembedding/FetchAndGetEmbeddingsUseCaseImplTest.kt new file mode 100644 index 00000000..28952d80 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/sdembedding/FetchAndGetEmbeddingsUseCaseImplTest.kt @@ -0,0 +1,56 @@ +package com.shifthackz.aisdv1.domain.usecase.sdembedding + +import com.nhaarman.mockitokotlin2.doReturn +import com.nhaarman.mockitokotlin2.mock +import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.domain.mocks.mockStableDiffusionEmbeddings +import com.shifthackz.aisdv1.domain.repository.StableDiffusionEmbeddingsRepository +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class FetchAndGetEmbeddingsUseCaseImplTest { + + private val stubRepository = mock() + + private val useCase = FetchAndGetEmbeddingsUseCaseImpl(stubRepository) + + @Test + fun `given repository provided embeddings list, expected valid list value`() { + whenever(stubRepository.fetchAndGetEmbeddings()) + .doReturn(Single.just(mockStableDiffusionEmbeddings)) + + useCase() + .test() + .assertNoErrors() + .assertValue(mockStableDiffusionEmbeddings) + .await() + .assertComplete() + } + + @Test + fun `given repository provided empty embeddings list, expected empty list value`() { + whenever(stubRepository.fetchAndGetEmbeddings()) + .doReturn(Single.just(emptyList())) + + useCase() + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given repository thrown exception, expected error value`() { + val stubException = Throwable("Unknown error occurred.") + + whenever(stubRepository.fetchAndGetEmbeddings()) + .doReturn(Single.error(stubException)) + + useCase() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/sdhypernet/FetchAndGetHyperNetworksUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/sdhypernet/FetchAndGetHyperNetworksUseCaseImplTest.kt new file mode 100644 index 00000000..a8122d41 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/sdhypernet/FetchAndGetHyperNetworksUseCaseImplTest.kt @@ -0,0 +1,55 @@ +package com.shifthackz.aisdv1.domain.usecase.sdhypernet + +import com.nhaarman.mockitokotlin2.mock +import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.domain.mocks.mockStableDiffusionHyperNetworks +import com.shifthackz.aisdv1.domain.repository.StableDiffusionHyperNetworksRepository +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class FetchAndGetHyperNetworksUseCaseImplTest { + + private val stubRepository = mock() + + private val useCase = FetchAndGetHyperNetworksUseCaseImpl(stubRepository) + + @Test + fun `given repository provided list of hypernetworks, expected valid list value`() { + whenever(stubRepository.fetchAndGetHyperNetworks()) + .thenReturn(Single.just(mockStableDiffusionHyperNetworks)) + + useCase() + .test() + .assertNoErrors() + .assertValue(mockStableDiffusionHyperNetworks) + .await() + .assertComplete() + } + + @Test + fun `given repository provided empty list of hypernetworks, expected empty list value`() { + whenever(stubRepository.fetchAndGetHyperNetworks()) + .thenReturn(Single.just(emptyList())) + + useCase() + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given repository thrown exception, expected error value`() { + val stubException = Throwable("Unknown error occurred.") + + whenever(stubRepository.fetchAndGetHyperNetworks()) + .thenReturn(Single.error(stubException)) + + useCase() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/sdmodel/GetStableDiffusionModelsUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/sdmodel/GetStableDiffusionModelsUseCaseImplTest.kt new file mode 100644 index 00000000..6d6450b0 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/sdmodel/GetStableDiffusionModelsUseCaseImplTest.kt @@ -0,0 +1,108 @@ +package com.shifthackz.aisdv1.domain.usecase.sdmodel + +import com.nhaarman.mockitokotlin2.mock +import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.domain.entity.ServerConfiguration +import com.shifthackz.aisdv1.domain.mocks.mockServerConfiguration +import com.shifthackz.aisdv1.domain.mocks.mockStableDiffusionModels +import com.shifthackz.aisdv1.domain.repository.ServerConfigurationRepository +import com.shifthackz.aisdv1.domain.repository.StableDiffusionModelsRepository +import io.reactivex.rxjava3.core.Single +import org.junit.Assert +import org.junit.Test + +class GetStableDiffusionModelsUseCaseImplTest { + + private val stubServerConfigurationRepository = mock() + private val stubSdModelsRepository = mock() + + private val useCase = GetStableDiffusionModelsUseCaseImpl( + serverConfigurationRepository = stubServerConfigurationRepository, + sdModelsRepository = stubSdModelsRepository, + ) + + @Test + fun `given repository returns list with value present in configuration, expected list with selected value`() { + whenever(stubServerConfigurationRepository.fetchAndGetConfiguration()) + .thenReturn(Single.just(mockServerConfiguration)) + + whenever(stubSdModelsRepository.fetchAndGetModels()) + .thenReturn(Single.just(mockStableDiffusionModels)) + + val expectedValue = mockStableDiffusionModels.map { + it to (it.title == mockServerConfiguration.sdModelCheckpoint) + } + + useCase() + .test() + .assertNoErrors() + .assertValue(expectedValue) + .also { + Assert.assertEquals( + true, + expectedValue.any { (_, selected) -> selected }, + ) + } + .await() + .assertComplete() + } + + @Test + fun `given repository returns list with no value present in configuration, expected list without selected value`() { + val stubServerConfiguration = ServerConfiguration("nonsense") + + whenever(stubServerConfigurationRepository.fetchAndGetConfiguration()) + .thenReturn(Single.just(stubServerConfiguration)) + + whenever(stubSdModelsRepository.fetchAndGetModels()) + .thenReturn(Single.just(mockStableDiffusionModels)) + + val expectedValue = mockStableDiffusionModels.map { + it to (it.title == stubServerConfiguration.sdModelCheckpoint) + } + + useCase() + .test() + .assertNoErrors() + .assertValue(expectedValue) + .also { + Assert.assertEquals( + true, + !expectedValue.any { (_, selected) -> selected }, + ) + } + .await() + .assertComplete() + } + + @Test + fun `given exception while fetching configuration, expected error value`() { + val stubException = Throwable("Network error.") + + whenever(stubServerConfigurationRepository.fetchAndGetConfiguration()) + .thenReturn(Single.error(stubException)) + + useCase() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `given exception while fetching models, expected error value`() { + val stubException = Throwable("Network error.") + + whenever(stubServerConfigurationRepository.fetchAndGetConfiguration()) + .thenReturn(Single.just(mockServerConfiguration)) + + whenever(stubSdModelsRepository.fetchAndGetModels()) + .thenReturn(Single.error(stubException)) + + useCase() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/sdmodel/SelectStableDiffusionModelUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/sdmodel/SelectStableDiffusionModelUseCaseImplTest.kt new file mode 100644 index 00000000..8ac4b20b --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/sdmodel/SelectStableDiffusionModelUseCaseImplTest.kt @@ -0,0 +1,96 @@ +package com.shifthackz.aisdv1.domain.usecase.sdmodel + +import com.nhaarman.mockitokotlin2.any +import com.nhaarman.mockitokotlin2.mock +import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.domain.mocks.mockServerConfiguration +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.domain.repository.ServerConfigurationRepository +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class SelectStableDiffusionModelUseCaseImplTest { + + private val stubException = Throwable("Unknown error occurred.") + private val stubServerConfigurationRepository = mock() + private val stubPreferenceManager = mock() + + private val useCase = SelectStableDiffusionModelUseCaseImpl( + serverConfigurationRepository = stubServerConfigurationRepository, + preferenceManager = stubPreferenceManager, + ) + + @Test + fun `expected get, update, fetch completed, expected complete without errors`() { + whenever(stubServerConfigurationRepository.getConfiguration()) + .thenReturn(Single.just(mockServerConfiguration)) + + whenever(stubServerConfigurationRepository.updateConfiguration(any())) + .thenReturn(Completable.complete()) + + whenever(stubServerConfigurationRepository.fetchConfiguration()) + .thenReturn(Completable.complete()) + + useCase("model") + .test() + .assertNoErrors() + .await() + .assertComplete() + } + + @Test + fun `expected get failed, update, fetch completed, expected complete with error`() { + whenever(stubServerConfigurationRepository.getConfiguration()) + .thenReturn(Single.error(stubException)) + + whenever(stubServerConfigurationRepository.updateConfiguration(any())) + .thenReturn(Completable.complete()) + + whenever(stubServerConfigurationRepository.fetchConfiguration()) + .thenReturn(Completable.complete()) + + useCase("model") + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + + @Test + fun `expected update failed, get, fetch completed, expected complete with error`() { + whenever(stubServerConfigurationRepository.getConfiguration()) + .thenReturn(Single.just(mockServerConfiguration)) + + whenever(stubServerConfigurationRepository.updateConfiguration(any())) + .thenReturn(Completable.error(stubException)) + + whenever(stubServerConfigurationRepository.fetchConfiguration()) + .thenReturn(Completable.complete()) + + useCase("model") + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } + + @Test + fun `expected get, update completed, fetch failed, expected complete with error`() { + whenever(stubServerConfigurationRepository.getConfiguration()) + .thenReturn(Single.just(mockServerConfiguration)) + + whenever(stubServerConfigurationRepository.updateConfiguration(any())) + .thenReturn(Completable.complete()) + + whenever(stubServerConfigurationRepository.fetchConfiguration()) + .thenReturn(Completable.error(stubException)) + + useCase("model") + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/sdsampler/GetStableDiffusionSamplersUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/sdsampler/GetStableDiffusionSamplersUseCaseImplTest.kt new file mode 100644 index 00000000..5da15b29 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/sdsampler/GetStableDiffusionSamplersUseCaseImplTest.kt @@ -0,0 +1,57 @@ +package com.shifthackz.aisdv1.domain.usecase.sdsampler + +import com.nhaarman.mockitokotlin2.mock +import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.domain.mocks.mockStableDiffusionSamplers +import com.shifthackz.aisdv1.domain.repository.StableDiffusionSamplersRepository +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class GetStableDiffusionSamplersUseCaseImplTest { + + private val stubStableDiffusionSamplersRepository = mock() + + private val useCase = GetStableDiffusionSamplersUseCaseImpl( + repository = stubStableDiffusionSamplersRepository, + ) + + @Test + fun `given got samplers from repository, expected valid samplers value`() { + whenever(stubStableDiffusionSamplersRepository.getSamplers()) + .thenReturn(Single.just(mockStableDiffusionSamplers)) + + useCase() + .test() + .assertNoErrors() + .assertValue(mockStableDiffusionSamplers) + .await() + .assertComplete() + } + + @Test + fun `given got empty list from repository, expected empty value`() { + whenever(stubStableDiffusionSamplersRepository.getSamplers()) + .thenReturn(Single.just(emptyList())) + + useCase() + .test() + .assertNoErrors() + .assertValue(emptyList()) + .await() + .assertComplete() + } + + @Test + fun `given got error from repository, expected the same error`() { + val stubException = Throwable("Error query database.") + + whenever(stubStableDiffusionSamplersRepository.getSamplers()) + .thenReturn(Single.error(stubException)) + + useCase() + .test() + .assertError(stubException) + .await() + .assertNotComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToA1111UseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToA1111UseCaseImplTest.kt new file mode 100644 index 00000000..f0ee7799 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToA1111UseCaseImplTest.kt @@ -0,0 +1,79 @@ +package com.shifthackz.aisdv1.domain.usecase.settings + +import com.shifthackz.aisdv1.domain.feature.auth.AuthorizationCredentials +import com.shifthackz.aisdv1.domain.mocks.mockConfiguration +import com.shifthackz.aisdv1.domain.usecase.caching.DataPreLoaderUseCase +import com.shifthackz.aisdv1.domain.usecase.connectivity.TestConnectivityUseCase +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class ConnectToA1111UseCaseImplTest { + + private val stubThrowable = Throwable("Something went wrong.") + private val stubGetConfigurationUseCase = mockk() + private val stubSetServerConfigurationUseCase = mockk() + private val stubTestConnectivityUseCase = mockk() + private val stubDataPreLoaderUseCase = mockk() + + private val useCase = ConnectToA1111UseCaseImpl( + getConfigurationUseCase = stubGetConfigurationUseCase, + setServerConfigurationUseCase = stubSetServerConfigurationUseCase, + testConnectivityUseCase = stubTestConnectivityUseCase, + dataPreLoaderUseCase = stubDataPreLoaderUseCase, + ) + + @Test + fun `given connection process successful, expected success result value`() { + every { + stubGetConfigurationUseCase() + } returns Single.just(mockConfiguration) + + every { + stubSetServerConfigurationUseCase(any()) + } returns Completable.complete() + + every { + stubTestConnectivityUseCase(any()) + } returns Completable.complete() + + every { + stubDataPreLoaderUseCase() + } returns Completable.complete() + + useCase("5598", false, AuthorizationCredentials.None) + .test() + .assertNoErrors() + .await() + .assertValue(Result.success(Unit)) + .assertComplete() + } + + @Test + fun `given connection process failed, expected error result value`() { + every { + stubGetConfigurationUseCase() + } returns Single.error(stubThrowable) + + every { + stubSetServerConfigurationUseCase(any()) + } returns Completable.error(stubThrowable) + + every { + stubTestConnectivityUseCase(any()) + } returns Completable.error(stubThrowable) + + every { + stubDataPreLoaderUseCase() + } returns Completable.error(stubThrowable) + + useCase("5598", false, AuthorizationCredentials.None) + .test() + .assertNoErrors() + .await() + .assertValue(Result.failure(stubThrowable)) + .assertComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToHordeUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToHordeUseCaseImplTest.kt new file mode 100644 index 00000000..aa12885a --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToHordeUseCaseImplTest.kt @@ -0,0 +1,93 @@ +package com.shifthackz.aisdv1.domain.usecase.settings + +import com.shifthackz.aisdv1.domain.mocks.mockConfiguration +import com.shifthackz.aisdv1.domain.usecase.connectivity.TestHordeApiKeyUseCase +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class ConnectToHordeUseCaseImplTest { + + private val stubThrowable = Throwable("Something went wrong.") + private val stubGetConfigurationUseCase = mockk() + private val stubSetServerConfigurationUseCase = mockk() + private val stubTestHordeApiKeyUseCase = mockk() + + private val useCase = ConnectToHordeUseCaseImpl( + getConfigurationUseCase = stubGetConfigurationUseCase, + setServerConfigurationUseCase = stubSetServerConfigurationUseCase, + testHordeApiKeyUseCase = stubTestHordeApiKeyUseCase, + ) + + @Test + fun `given connection process successful, API key is valid, expected success result value`() { + every { + stubGetConfigurationUseCase() + } returns Single.just(mockConfiguration) + + every { + stubSetServerConfigurationUseCase(any()) + } returns Completable.complete() + + every { + stubTestHordeApiKeyUseCase() + } returns Single.just(true) + + useCase("5598") + .test() + .assertNoErrors() + .await() + .assertValue(Result.success(Unit)) + .assertComplete() + } + + @Test + fun `given connection process successful, API key is NOT valid, expected success result value`() { + every { + stubGetConfigurationUseCase() + } returns Single.just(mockConfiguration) + + every { + stubSetServerConfigurationUseCase(any()) + } returns Completable.complete() + + every { + stubTestHordeApiKeyUseCase() + } returns Single.just(false) + + useCase("5598") + .test() + .assertNoErrors() + .await() + .assertValue { actual -> + actual.isFailure + && actual.exceptionOrNull() is IllegalStateException + && actual.exceptionOrNull()?.message == "Bad key" + } + .assertComplete() + } + + @Test + fun `given connection process failed, expected error result value`() { + every { + stubGetConfigurationUseCase() + } returns Single.error(stubThrowable) + + every { + stubSetServerConfigurationUseCase(any()) + } returns Completable.error(stubThrowable) + + every { + stubTestHordeApiKeyUseCase() + } returns Single.error(stubThrowable) + + useCase("5598") + .test() + .assertNoErrors() + .await() + .assertValue(Result.failure(stubThrowable)) + .assertComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToHuggingFaceUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToHuggingFaceUseCaseImplTest.kt new file mode 100644 index 00000000..c66c336d --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToHuggingFaceUseCaseImplTest.kt @@ -0,0 +1,93 @@ +package com.shifthackz.aisdv1.domain.usecase.settings + +import com.shifthackz.aisdv1.domain.mocks.mockConfiguration +import com.shifthackz.aisdv1.domain.usecase.connectivity.TestHuggingFaceApiKeyUseCase +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class ConnectToHuggingFaceUseCaseImplTest { + + private val stubThrowable = Throwable("Something went wrong.") + private val stubGetConfigurationUseCase = mockk() + private val stubSetServerConfigurationUseCase = mockk() + private val stubTestHuggingFaceApiKeyUseCase = mockk() + + private val useCase = ConnectToHuggingFaceUseCaseImpl( + getConfigurationUseCase = stubGetConfigurationUseCase, + setServerConfigurationUseCase = stubSetServerConfigurationUseCase, + testHuggingFaceApiKeyUseCase = stubTestHuggingFaceApiKeyUseCase, + ) + + @Test + fun `given connection process successful, API key is valid, expected success result value`() { + every { + stubGetConfigurationUseCase() + } returns Single.just(mockConfiguration) + + every { + stubSetServerConfigurationUseCase(any()) + } returns Completable.complete() + + every { + stubTestHuggingFaceApiKeyUseCase() + } returns Single.just(true) + + useCase("5598", "5598") + .test() + .assertNoErrors() + .await() + .assertValue(Result.success(Unit)) + .assertComplete() + } + + @Test + fun `given connection process successful, API key is NOT valid, expected success result value`() { + every { + stubGetConfigurationUseCase() + } returns Single.just(mockConfiguration) + + every { + stubSetServerConfigurationUseCase(any()) + } returns Completable.complete() + + every { + stubTestHuggingFaceApiKeyUseCase() + } returns Single.just(false) + + useCase("5598", "5598") + .test() + .assertNoErrors() + .await() + .assertValue { actual -> + actual.isFailure + && actual.exceptionOrNull() is IllegalStateException + && actual.exceptionOrNull()?.message == "Bad key" + } + .assertComplete() + } + + @Test + fun `given connection process failed, expected error result value`() { + every { + stubGetConfigurationUseCase() + } returns Single.error(stubThrowable) + + every { + stubSetServerConfigurationUseCase(any()) + } returns Completable.error(stubThrowable) + + every { + stubTestHuggingFaceApiKeyUseCase() + } returns Single.error(stubThrowable) + + useCase("5598", "5598") + .test() + .assertNoErrors() + .await() + .assertValue(Result.failure(stubThrowable)) + .assertComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToLocalDiffusionUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToLocalDiffusionUseCaseImplTest.kt new file mode 100644 index 00000000..33a09c57 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToLocalDiffusionUseCaseImplTest.kt @@ -0,0 +1,56 @@ +package com.shifthackz.aisdv1.domain.usecase.settings + +import com.shifthackz.aisdv1.domain.mocks.mockConfiguration +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class ConnectToLocalDiffusionUseCaseImplTest { + + private val stubThrowable = Throwable("Something went wrong.") + private val stubGetConfigurationUseCase = mockk() + private val stubSetServerConfigurationUseCase = mockk() + + private val useCase = ConnectToLocalDiffusionUseCaseImpl( + getConfigurationUseCase = stubGetConfigurationUseCase, + setServerConfigurationUseCase = stubSetServerConfigurationUseCase, + ) + + @Test + fun `given connection process successful, expected success result value`() { + every { + stubGetConfigurationUseCase() + } returns Single.just(mockConfiguration) + + every { + stubSetServerConfigurationUseCase(any()) + } returns Completable.complete() + + useCase("5598") + .test() + .assertNoErrors() + .assertValue(Result.success(Unit)) + .await() + .assertComplete() + } + + @Test + fun `given connection process failed, expected error result value`() { + every { + stubGetConfigurationUseCase() + } returns Single.error(stubThrowable) + + every { + stubSetServerConfigurationUseCase(any()) + } returns Completable.error(stubThrowable) + + useCase("5598") + .test() + .assertNoErrors() + .assertValue(Result.failure(stubThrowable)) + .await() + .assertComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToOpenAiUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToOpenAiUseCaseImplTest.kt new file mode 100644 index 00000000..470b2657 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToOpenAiUseCaseImplTest.kt @@ -0,0 +1,93 @@ +package com.shifthackz.aisdv1.domain.usecase.settings + +import com.shifthackz.aisdv1.domain.mocks.mockConfiguration +import com.shifthackz.aisdv1.domain.usecase.connectivity.TestOpenAiApiKeyUseCase +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class ConnectToOpenAiUseCaseImplTest { + + private val stubThrowable = Throwable("Something went wrong.") + private val stubGetConfigurationUseCase = mockk() + private val stubSetServerConfigurationUseCase = mockk() + private val stubTestOpenAiApiKeyUseCase = mockk() + + private val useCase = ConnectToOpenAiUseCaseImpl( + getConfigurationUseCase = stubGetConfigurationUseCase, + setServerConfigurationUseCase = stubSetServerConfigurationUseCase, + testOpenAiApiKeyUseCase = stubTestOpenAiApiKeyUseCase, + ) + + @Test + fun `given connection process successful, API key is valid, expected success result value`() { + every { + stubGetConfigurationUseCase() + } returns Single.just(mockConfiguration) + + every { + stubSetServerConfigurationUseCase(any()) + } returns Completable.complete() + + every { + stubTestOpenAiApiKeyUseCase() + } returns Single.just(true) + + useCase("5598") + .test() + .assertNoErrors() + .await() + .assertValue(Result.success(Unit)) + .assertComplete() + } + + @Test + fun `given connection process successful, API key is NOT valid, expected success result value`() { + every { + stubGetConfigurationUseCase() + } returns Single.just(mockConfiguration) + + every { + stubSetServerConfigurationUseCase(any()) + } returns Completable.complete() + + every { + stubTestOpenAiApiKeyUseCase() + } returns Single.just(false) + + useCase("5598") + .test() + .assertNoErrors() + .await() + .assertValue { actual -> + actual.isFailure + && actual.exceptionOrNull() is IllegalStateException + && actual.exceptionOrNull()?.message == "Bad key" + } + .assertComplete() + } + + @Test + fun `given connection process failed, expected error result value`() { + every { + stubGetConfigurationUseCase() + } returns Single.error(stubThrowable) + + every { + stubSetServerConfigurationUseCase(any()) + } returns Completable.error(stubThrowable) + + every { + stubTestOpenAiApiKeyUseCase() + } returns Single.error(stubThrowable) + + useCase("5598") + .test() + .assertNoErrors() + .await() + .assertValue(Result.failure(stubThrowable)) + .assertComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToStabilityAiUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToStabilityAiUseCaseImplTest.kt new file mode 100644 index 00000000..a3d2a4e7 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToStabilityAiUseCaseImplTest.kt @@ -0,0 +1,93 @@ +package com.shifthackz.aisdv1.domain.usecase.settings + +import com.shifthackz.aisdv1.domain.mocks.mockConfiguration +import com.shifthackz.aisdv1.domain.usecase.connectivity.TestStabilityAiApiKeyUseCase +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class ConnectToStabilityAiUseCaseImplTest { + + private val stubThrowable = Throwable("Something went wrong.") + private val stubGetConfigurationUseCase = mockk() + private val stubSetServerConfigurationUseCase = mockk() + private val stubTestStabilityAiApiKeyUseCase = mockk() + + private val useCase = ConnectToStabilityAiUseCaseImpl( + getConfigurationUseCase = stubGetConfigurationUseCase, + setServerConfigurationUseCase = stubSetServerConfigurationUseCase, + testStabilityAiApiKeyUseCase = stubTestStabilityAiApiKeyUseCase, + ) + + @Test + fun `given connection process successful, API key is valid, expected success result value`() { + every { + stubGetConfigurationUseCase() + } returns Single.just(mockConfiguration) + + every { + stubSetServerConfigurationUseCase(any()) + } returns Completable.complete() + + every { + stubTestStabilityAiApiKeyUseCase() + } returns Single.just(true) + + useCase("5598") + .test() + .assertNoErrors() + .await() + .assertValue(Result.success(Unit)) + .assertComplete() + } + + @Test + fun `given connection process successful, API key is NOT valid, expected success result value`() { + every { + stubGetConfigurationUseCase() + } returns Single.just(mockConfiguration) + + every { + stubSetServerConfigurationUseCase(any()) + } returns Completable.complete() + + every { + stubTestStabilityAiApiKeyUseCase() + } returns Single.just(false) + + useCase("5598") + .test() + .assertNoErrors() + .await() + .assertValue { actual -> + actual.isFailure + && actual.exceptionOrNull() is IllegalStateException + && actual.exceptionOrNull()?.message == "Bad key" + } + .assertComplete() + } + + @Test + fun `given connection process failed, expected error result value`() { + every { + stubGetConfigurationUseCase() + } returns Single.error(stubThrowable) + + every { + stubSetServerConfigurationUseCase(any()) + } returns Completable.error(stubThrowable) + + every { + stubTestStabilityAiApiKeyUseCase() + } returns Single.error(stubThrowable) + + useCase("5598") + .test() + .assertNoErrors() + .await() + .assertValue(Result.failure(stubThrowable)) + .assertComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImplTest.kt new file mode 100644 index 00000000..c23282ba --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImplTest.kt @@ -0,0 +1,75 @@ +package com.shifthackz.aisdv1.domain.usecase.settings + +import com.shifthackz.aisdv1.domain.feature.auth.AuthorizationCredentials +import com.shifthackz.aisdv1.domain.feature.auth.AuthorizationStore +import com.shifthackz.aisdv1.domain.mocks.mockConfiguration +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import io.mockk.every +import io.mockk.mockk +import org.junit.Test + +class GetConfigurationUseCaseImplTest { + + private val stubPreferenceManager = mockk() + private val stubAuthorizationStore = mockk() + + private val useCase = GetConfigurationUseCaseImpl( + preferenceManager = stubPreferenceManager, + authorizationStore = stubAuthorizationStore, + ) + + @Test + fun `given configuration read success, expected valid configuration domain model value`() { + every { + stubAuthorizationStore.getAuthorizationCredentials() + } returns AuthorizationCredentials.None + + every { + stubPreferenceManager::serverUrl.get() + } returns mockConfiguration.serverUrl + + every { + stubPreferenceManager::demoMode.get() + } returns mockConfiguration.demoMode + + every { + stubPreferenceManager::source.get() + } returns mockConfiguration.source + + every { + stubPreferenceManager::hordeApiKey.get() + } returns mockConfiguration.hordeApiKey + + every { + stubPreferenceManager::openAiApiKey.get() + } returns mockConfiguration.openAiApiKey + + every { + stubPreferenceManager::huggingFaceApiKey.get() + } returns mockConfiguration.huggingFaceApiKey + + every { + stubPreferenceManager::huggingFaceModel.get() + } returns mockConfiguration.huggingFaceModel + + every { + stubPreferenceManager::stabilityAiApiKey.get() + } returns mockConfiguration.stabilityAiApiKey + + every { + stubPreferenceManager::stabilityAiEngineId.get() + } returns mockConfiguration.stabilityAiEngineId + + every { + stubPreferenceManager::localModelId.get() + } returns mockConfiguration.localModelId + + useCase + .invoke() + .test() + .assertNoErrors() + .assertValue(mockConfiguration) + .await() + .assertComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImplTest.kt new file mode 100644 index 00000000..fabc2fe4 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImplTest.kt @@ -0,0 +1,73 @@ +package com.shifthackz.aisdv1.domain.usecase.settings + +import com.shifthackz.aisdv1.domain.feature.auth.AuthorizationStore +import com.shifthackz.aisdv1.domain.mocks.mockConfiguration +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import io.mockk.every +import io.mockk.mockk +import org.junit.Test + +class SetServerConfigurationUseCaseImplTest { + + private val stubPreferenceManager = mockk() + private val stubAuthorizationStore = mockk() + + private val useCase = SetServerConfigurationUseCaseImpl( + preferenceManager = stubPreferenceManager, + authorizationStore = stubAuthorizationStore, + ) + + @Test + fun `given configuration apply success, expected complete value`() { + every { + stubAuthorizationStore.storeAuthorizationCredentials(any()) + } returns Unit + + every { + stubPreferenceManager::source.set(any()) + } returns Unit + + every { + stubPreferenceManager::serverUrl.set(any()) + } returns Unit + + every { + stubPreferenceManager::demoMode.set(any()) + } returns Unit + + every { + stubPreferenceManager::hordeApiKey.set(any()) + } returns Unit + + every { + stubPreferenceManager::openAiApiKey.set(any()) + } returns Unit + + every { + stubPreferenceManager::huggingFaceApiKey.set(any()) + } returns Unit + + every { + stubPreferenceManager::huggingFaceModel.set(any()) + } returns Unit + + every { + stubPreferenceManager::stabilityAiApiKey.set(any()) + } returns Unit + + every { + stubPreferenceManager::stabilityAiEngineId.set(any()) + } returns Unit + + every { + stubPreferenceManager::localModelId.set(any()) + } returns Unit + + useCase + .invoke(mockConfiguration) + .test() + .assertNoErrors() + .await() + .assertComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/splash/SplashNavigationUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/splash/SplashNavigationUseCaseImplTest.kt new file mode 100644 index 00000000..c2787eeb --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/splash/SplashNavigationUseCaseImplTest.kt @@ -0,0 +1,93 @@ +package com.shifthackz.aisdv1.domain.usecase.splash + +import com.nhaarman.mockitokotlin2.mock +import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.domain.entity.ServerSource +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import org.junit.Test + +class SplashNavigationUseCaseImplTest { + + private val stubPreferenceManager = mock() + + private val useCase = SplashNavigationUseCaseImpl(stubPreferenceManager) + + @Test + fun `given forceSetupAfterUpdate is true, expected LAUNCH_SERVER_SETUP`() { + whenever(stubPreferenceManager.forceSetupAfterUpdate) + .thenReturn(true) + + useCase() + .test() + .assertNoErrors() + .assertValue(SplashNavigationUseCase.Action.LAUNCH_SERVER_SETUP) + } + + @Test + fun `given source is AUTOMATIC1111 and server url empty, expected LAUNCH_SERVER_SETUP`() { + whenever(stubPreferenceManager.forceSetupAfterUpdate) + .thenReturn(false) + + whenever(stubPreferenceManager.serverUrl) + .thenReturn("") + + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.AUTOMATIC1111) + + useCase() + .test() + .assertNoErrors() + .assertValue(SplashNavigationUseCase.Action.LAUNCH_SERVER_SETUP) + } + + @Test + fun `given source is AUTOMATIC1111 and server url not empty, expected LAUNCH_HOME`() { + whenever(stubPreferenceManager.forceSetupAfterUpdate) + .thenReturn(false) + + whenever(stubPreferenceManager.serverUrl) + .thenReturn("http://192.168.0.1:7860") + + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.AUTOMATIC1111) + + useCase() + .test() + .assertNoErrors() + .assertValue(SplashNavigationUseCase.Action.LAUNCH_HOME) + } + + @Test + fun `given source is LOCAL, and server url is empty, expected LAUNCH_HOME`() { + whenever(stubPreferenceManager.forceSetupAfterUpdate) + .thenReturn(false) + + whenever(stubPreferenceManager.serverUrl) + .thenReturn("") + + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.LOCAL) + + useCase() + .test() + .assertNoErrors() + .assertValue(SplashNavigationUseCase.Action.LAUNCH_HOME) + } + + @Test + fun `given source is LOCAL, and server url is not empty, expected LAUNCH_HOME`() { + whenever(stubPreferenceManager.forceSetupAfterUpdate) + .thenReturn(false) + + whenever(stubPreferenceManager.serverUrl) + .thenReturn("http://192.168.0.1:7860") + + whenever(stubPreferenceManager.source) + .thenReturn(ServerSource.LOCAL) + + useCase() + .test() + .assertNoErrors() + .assertValue(SplashNavigationUseCase.Action.LAUNCH_HOME) + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/stabilityai/FetchAndGetStabilityAiEnginesUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/stabilityai/FetchAndGetStabilityAiEnginesUseCaseImplTest.kt new file mode 100644 index 00000000..68636239 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/stabilityai/FetchAndGetStabilityAiEnginesUseCaseImplTest.kt @@ -0,0 +1,49 @@ +package com.shifthackz.aisdv1.domain.usecase.stabilityai + +import com.nhaarman.mockitokotlin2.mock +import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.domain.mocks.mockStabilityAiEngines +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.domain.repository.StabilityAiEnginesRepository +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class FetchAndGetStabilityAiEnginesUseCaseImplTest { + + private val stubRepository = mock() + + private val stubPreferenceManager = mock() + + private val useCase = FetchAndGetStabilityAiEnginesUseCaseImpl( + repository = stubRepository, + preferenceManager = stubPreferenceManager, + ) + + @Test + fun `given repository returned engines list, id present in preference, expected the same engines list, id not changed`() { + whenever(stubRepository.fetchAndGet()) + .thenReturn(Single.just(mockStabilityAiEngines)) + + whenever(stubPreferenceManager::stabilityAiEngineId.get()) + .thenReturn("engine_1") + + useCase() + .test() + .assertNoErrors() + .assertValue(mockStabilityAiEngines) + .assertComplete() + } + + @Test + fun `given repository thrown exception, expected the same exception`() { + val stubException = Throwable("Network exception") + + whenever(stubRepository.fetchAndGet()) + .thenReturn(Single.error(stubException)) + + useCase() + .test() + .assertError(stubException) + .assertNotComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/stabilityai/ObserveStabilityAiCreditsUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/stabilityai/ObserveStabilityAiCreditsUseCaseImplTest.kt new file mode 100644 index 00000000..a1b189bc --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/stabilityai/ObserveStabilityAiCreditsUseCaseImplTest.kt @@ -0,0 +1,109 @@ +package com.shifthackz.aisdv1.domain.usecase.stabilityai + +import com.nhaarman.mockitokotlin2.mock +import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.domain.entity.Settings +import com.shifthackz.aisdv1.domain.mocks.mockSettings +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.domain.repository.StabilityAiCreditsRepository +import io.reactivex.rxjava3.core.BackpressureStrategy +import io.reactivex.rxjava3.core.Flowable +import io.reactivex.rxjava3.subjects.BehaviorSubject +import org.junit.Before +import org.junit.Test + +class ObserveStabilityAiCreditsUseCaseImplTest { + + private val stubSettingsObserver = BehaviorSubject.create() + private val stubCreditsObserver = BehaviorSubject.create>() + + private val stubStabilityAiCreditsRepository = mock() + + private val stubPreferenceManager = mock() + + private val useCase = ObserveStabilityAiCreditsUseCaseImpl( + repository = stubStabilityAiCreditsRepository, + preferenceManager = stubPreferenceManager, + ) + + @Before + fun initialize() { + whenever(stubPreferenceManager.observe()) + .thenReturn(stubSettingsObserver.toFlowable(BackpressureStrategy.LATEST)) + + whenever(stubStabilityAiCreditsRepository.fetchAndObserve()) + .thenReturn( + stubCreditsObserver + .toFlowable(BackpressureStrategy.LATEST) + .flatMap { result -> + result.fold( + onSuccess = { credits -> Flowable.just(credits) }, + onFailure = { t -> Flowable.error(t) }, + ) + } + ) + } + + @Test + fun `given successfully got settings and credits, expected valid credits value`() { + val stubObserver = useCase().test() + + stubCreditsObserver.onNext(Result.success(5598f)) + stubSettingsObserver.onNext(mockSettings) + + stubObserver + .assertValueAt(0, 5598f) + .assertNoErrors() + } + + @Test + fun `given successfully got settings and credits, then change settings, expected credits value not changed`() { + val stubObserver = useCase().test() + + stubCreditsObserver.onNext(Result.success(5598f)) + stubSettingsObserver.onNext(mockSettings) + + stubObserver + .assertValueAt(0, 5598f) + .assertNoErrors() + + stubSettingsObserver.onNext(mockSettings.copy(formPromptTaggedInput = false)) + + stubObserver + .assertValueAt(1, 5598f) + .assertNoErrors() + } + + @Test + fun `given successfully got settings and credits, then credits changed, expected credits value changed`() { + val stubObserver = useCase().test() + + stubCreditsObserver.onNext(Result.success(5598f)) + stubSettingsObserver.onNext(mockSettings) + + stubObserver + .assertValueAt(0, 5598f) + .assertNoErrors() + + stubCreditsObserver.onNext(Result.success(2211f)) + + stubObserver + .assertValueAt(1, 2211f) + .assertNoErrors() + } + + @Test + fun `given exception from credits repository, expected zero credits value`() { + val stubObserver = useCase().test() + val stubException = Throwable("Wrong server source selected.") + + stubCreditsObserver.onNext(Result.failure(stubException)) + stubSettingsObserver.onNext(mockSettings) + + stubObserver + .assertNoErrors() + .assertValueAt(0, 0f) + .await() + .assertComplete() + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/wakelock/AcquireWakelockUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/wakelock/AcquireWakelockUseCaseImplTest.kt new file mode 100644 index 00000000..997c0263 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/wakelock/AcquireWakelockUseCaseImplTest.kt @@ -0,0 +1,48 @@ +package com.shifthackz.aisdv1.domain.usecase.wakelock + +import android.os.PowerManager +import com.nhaarman.mockitokotlin2.any +import com.nhaarman.mockitokotlin2.doNothing +import com.nhaarman.mockitokotlin2.given +import com.nhaarman.mockitokotlin2.mock +import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.domain.repository.WakeLockRepository +import org.junit.Assert +import org.junit.Before +import org.junit.Test + +class AcquireWakelockUseCaseImplTest { + + private val stubException = Throwable("Can not acquire wakelock.") + private val stubWakeLock = mock() + private val stubRepository = mock() + + private val useCase = AcquireWakelockUseCaseImpl(stubRepository) + + @Before + fun initialize() { + whenever(stubRepository.wakeLock) + .thenReturn(stubWakeLock) + } + + @Test + fun `given wakelock was acquired successfully, expected result success`() { + doNothing() + .whenever(stubWakeLock) + .acquire(any()) + + val expected = Result.success(Unit) + val actual = useCase() + Assert.assertEquals(expected, actual) + } + + @Test + fun `given wakelock acquire failed, expected result failure`() { + given(stubWakeLock.acquire(any())) + .willAnswer { throw stubException } + + val expected = Result.failure(stubException) + val actual = useCase() + Assert.assertEquals(expected, actual) + } +} diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/wakelock/ReleaseWakeLockUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/wakelock/ReleaseWakeLockUseCaseImplTest.kt new file mode 100644 index 00000000..1aae0795 --- /dev/null +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/wakelock/ReleaseWakeLockUseCaseImplTest.kt @@ -0,0 +1,47 @@ +package com.shifthackz.aisdv1.domain.usecase.wakelock + +import android.os.PowerManager +import com.nhaarman.mockitokotlin2.doNothing +import com.nhaarman.mockitokotlin2.given +import com.nhaarman.mockitokotlin2.mock +import com.nhaarman.mockitokotlin2.whenever +import com.shifthackz.aisdv1.domain.repository.WakeLockRepository +import org.junit.Assert +import org.junit.Before +import org.junit.Test + +class ReleaseWakeLockUseCaseImplTest { + + private val stubException = Throwable("Can not release wakelock.") + private val stubWakeLock = mock() + private val stubRepository = mock() + + private val useCase = ReleaseWakeLockUseCaseImpl(stubRepository) + + @Before + fun initialize() { + whenever(stubRepository.wakeLock) + .thenReturn(stubWakeLock) + } + + @Test + fun `given wakelock was released successfully, expected result success`() { + doNothing() + .whenever(stubWakeLock) + .release() + + val expected = Result.success(Unit) + val actual = useCase() + Assert.assertEquals(expected, actual) + } + + @Test + fun `given wakelock release failed, expected result failure`() { + given(stubWakeLock.release()) + .willAnswer { throw stubException } + + val expected = Result.failure(stubException) + val actual = useCase() + Assert.assertEquals(expected, actual) + } +} diff --git a/network/build.gradle b/network/build.gradle index 58b953ed..1c273e78 100755 --- a/network/build.gradle +++ b/network/build.gradle @@ -22,4 +22,6 @@ dependencies { implementation di.koinCore implementation reactive.rxkotlin implementation reactive.rxnetwork + testImplementation test.junit + testImplementation test.mockk } diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/authenticator/RestAuthenticator.kt b/network/src/main/java/com/shifthackz/aisdv1/network/authenticator/RestAuthenticator.kt index 51a503c9..1406e011 100644 --- a/network/src/main/java/com/shifthackz/aisdv1/network/authenticator/RestAuthenticator.kt +++ b/network/src/main/java/com/shifthackz/aisdv1/network/authenticator/RestAuthenticator.kt @@ -1,6 +1,5 @@ package com.shifthackz.aisdv1.network.authenticator -import com.shifthackz.aisdv1.core.common.log.debugLog import com.shifthackz.aisdv1.network.qualifiers.CredentialsProvider import com.shifthackz.aisdv1.network.qualifiers.NetworkHeaders import okhttp3.Authenticator diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/di/NetworkModule.kt b/network/src/main/java/com/shifthackz/aisdv1/network/di/NetworkModule.kt index 1599eaca..0995039c 100755 --- a/network/src/main/java/com/shifthackz/aisdv1/network/di/NetworkModule.kt +++ b/network/src/main/java/com/shifthackz/aisdv1/network/di/NetworkModule.kt @@ -157,7 +157,7 @@ val networkModule = module { singleOf(::DownloadableModelsApiImpl) bind DownloadableModelsApi::class singleOf(::HuggingFaceInferenceApiImpl) bind HuggingFaceInferenceApi::class - factory {params -> + factory { params -> ConnectivityMonitor(params.get()) } diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/qualifiers/CredentialsProvider.kt b/network/src/main/java/com/shifthackz/aisdv1/network/qualifiers/CredentialsProvider.kt index 37541dde..14179f28 100644 --- a/network/src/main/java/com/shifthackz/aisdv1/network/qualifiers/CredentialsProvider.kt +++ b/network/src/main/java/com/shifthackz/aisdv1/network/qualifiers/CredentialsProvider.kt @@ -4,7 +4,7 @@ interface CredentialsProvider { operator fun invoke(): Data sealed interface Data { - object None : Data - data class HttpBasic(val login: String, val password: String): Data + data object None : Data + data class HttpBasic(val login: String, val password: String) : Data } } diff --git a/network/src/test/java/com/shifthackz/aisdv1/network/authenticator/RestAuthenticatorTest.kt b/network/src/test/java/com/shifthackz/aisdv1/network/authenticator/RestAuthenticatorTest.kt new file mode 100644 index 00000000..cf500c13 --- /dev/null +++ b/network/src/test/java/com/shifthackz/aisdv1/network/authenticator/RestAuthenticatorTest.kt @@ -0,0 +1,87 @@ +package com.shifthackz.aisdv1.network.authenticator + +import com.shifthackz.aisdv1.network.qualifiers.CredentialsProvider +import com.shifthackz.aisdv1.network.qualifiers.NetworkHeaders +import io.mockk.every +import io.mockk.mockk +import okhttp3.Address +import okhttp3.Authenticator +import okhttp3.Dns +import okhttp3.Protocol +import okhttp3.Request +import okhttp3.Response +import okhttp3.Route +import org.junit.Assert +import org.junit.Test +import java.net.InetSocketAddress +import java.net.Proxy +import java.net.ProxySelector +import javax.net.SocketFactory + +class RestAuthenticatorTest { + + private val stubResponse + get() = Response.Builder() + .request( + Request.Builder() + .url("http://192.168.0.1:8080") + .build() + ) + .protocol(Protocol.HTTP_1_0) + .message("msg") + .code(333) + .build() + + private val stubRoute + get() = Route( + address = Address( + uriHost = "192.168.0.1", + uriPort = 8080, + dns = Dns.SYSTEM, + socketFactory = SocketFactory.getDefault(), + sslSocketFactory = null, + hostnameVerifier = null, + certificatePinner = null, + proxyAuthenticator = Authenticator.NONE, + proxy = null, + protocols = emptyList(), + connectionSpecs = emptyList(), + proxySelector = ProxySelector.getDefault(), + ), + proxy = Proxy.NO_PROXY, + socketAddress = InetSocketAddress.createUnresolved("192.168.0.1", 8080) + ) + + private val stubCredentialsProvider = mockk() + + private val authenticator = RestAuthenticator(stubCredentialsProvider) + + @Test + fun `given provider has no credentials, expected authenticator returned null response value`() { + every { + stubCredentialsProvider() + } returns CredentialsProvider.Data.None + + val expected: Response? = null + val actual = authenticator.authenticate(stubRoute, stubResponse) + Assert.assertEquals(expected, actual) + } + + @Test + fun `given provider has HTTP credentials, expected authenticator returned response with AUTHORIZATION header`() { + val stubLogin = "5598" + val stubPassword = "is_my_favorite" + + every { + stubCredentialsProvider() + } returns CredentialsProvider.Data.HttpBasic( + login = stubLogin, + password = stubPassword, + ) + + val expected = "Basic NTU5ODppc19teV9mYXZvcml0ZQ==" + val actual = authenticator.authenticate(stubRoute, stubResponse) + Assert.assertEquals(true, actual != null) + Assert.assertEquals(expected, actual?.headers?.get(NetworkHeaders.AUTHORIZATION)) + } +} diff --git a/presentation/build.gradle b/presentation/build.gradle index d815875f..d71c2e4e 100755 --- a/presentation/build.gradle +++ b/presentation/build.gradle @@ -14,6 +14,15 @@ android { composeOptions { kotlinCompilerExtensionVersion = "1.5.7" } + testOptions { + unitTests.returnDefaultValues = true + unitTests.all { + jvmArgs( + "--add-opens", "java.base/java.lang=ALL-UNNAMED", + "--add-opens", "java.base/java.lang.reflect=ALL-UNNAMED" + ) + } + } } dependencies { @@ -31,6 +40,7 @@ dependencies { implementation androidx.appcompat implementation androidx.activity implementation androidx.pagingRx3 + implementation androidx.exifinterface implementation google.material implementation apache.stringutils @@ -45,5 +55,9 @@ dependencies { implementation ui.catppuccinSplashscreen implementation ui.composeGestures implementation ui.composeEasyCrop - implementation "androidx.exifinterface:exifinterface:1.3.6" + + testImplementation test.junit + testImplementation test.mockk + testImplementation test.coroutines + testImplementation test.turbine } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/core/GenerationMviViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/core/GenerationMviViewModel.kt index 15cc7e91..e6cdf698 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/core/GenerationMviViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/core/GenerationMviViewModel.kt @@ -32,25 +32,21 @@ import com.shifthackz.android.core.mvi.MviEffect import io.reactivex.rxjava3.core.Observable import io.reactivex.rxjava3.disposables.Disposable import io.reactivex.rxjava3.kotlin.subscribeBy -import org.koin.core.component.KoinComponent -import org.koin.core.component.inject import java.util.concurrent.TimeUnit -abstract class GenerationMviViewModel : - MviRxViewModel(), KoinComponent { - - private val preferenceManager: PreferenceManager by inject() - private val schedulersProvider: SchedulersProvider by inject() - private val saveLastResultToCacheUseCase: SaveLastResultToCacheUseCase by inject() - private val saveGenerationResultUseCase: SaveGenerationResultUseCase by inject() - private val getStableDiffusionSamplersUseCase: GetStableDiffusionSamplersUseCase by inject() - private val observeHordeProcessStatusUseCase: ObserveHordeProcessStatusUseCase by inject() - private val observeLocalDiffusionProcessStatusUseCase: ObserveLocalDiffusionProcessStatusUseCase by inject() - private val interruptGenerationUseCase: InterruptGenerationUseCase by inject() - - private val mainRouter: MainRouter by inject() - private val drawerRouter: DrawerRouter by inject() - private val dimensionValidator: DimensionValidator by inject() +abstract class GenerationMviViewModel( + preferenceManager: PreferenceManager, + getStableDiffusionSamplersUseCase: GetStableDiffusionSamplersUseCase, + observeHordeProcessStatusUseCase: ObserveHordeProcessStatusUseCase, + observeLocalDiffusionProcessStatusUseCase: ObserveLocalDiffusionProcessStatusUseCase, + private val saveLastResultToCacheUseCase: SaveLastResultToCacheUseCase, + private val saveGenerationResultUseCase: SaveGenerationResultUseCase, + private val interruptGenerationUseCase: InterruptGenerationUseCase, + private val mainRouter: MainRouter, + private val drawerRouter: DrawerRouter, + private val dimensionValidator: DimensionValidator, + private val schedulersProvider: SchedulersProvider, +) : MviRxViewModel() { private var generationDisposable: Disposable? = null private var randomImageDisposable: Disposable? = null @@ -121,6 +117,14 @@ abstract class GenerationMviViewModel updateGenerationState { diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/ModalRenderer.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/ModalRenderer.kt index 0d584946..f878949b 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/ModalRenderer.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/ModalRenderer.kt @@ -171,7 +171,7 @@ fun ModalRenderer( items = screenModal.models, onItemSelected = { selectedItem = it }, ) - } + }, ) } @@ -240,7 +240,7 @@ fun ModalRenderer( is Modal.Image.Crop -> CropImageModal( bitmap = screenModal.bitmap, onDismissRequest = dismiss, - onResult = { processIntent(ImageToImageIntent.UpdateImage(it)) } + onResult = { processIntent(ImageToImageIntent.UpdateImage(it)) }, ) Modal.ConnectLocalHost -> DecisionInteractiveDialog( diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasViewModel.kt index 5b27706f..d5bcd1d8 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasViewModel.kt @@ -3,6 +3,7 @@ package com.shifthackz.aisdv1.presentation.modal.extras import com.shifthackz.aisdv1.core.common.log.errorLog import com.shifthackz.aisdv1.core.common.schedulers.SchedulersProvider import com.shifthackz.aisdv1.core.common.schedulers.subscribeOnMainThread +import com.shifthackz.aisdv1.core.common.time.TimeProvider import com.shifthackz.aisdv1.core.viewmodel.MviRxViewModel import com.shifthackz.aisdv1.domain.entity.StableDiffusionHyperNetwork import com.shifthackz.aisdv1.domain.entity.StableDiffusionLora @@ -17,6 +18,7 @@ class ExtrasViewModel( private val fetchAndGetLorasUseCase: FetchAndGetLorasUseCase, private val fetchAndGetHyperNetworksUseCase: FetchAndGetHyperNetworksUseCase, private val schedulersProvider: SchedulersProvider, + private val timeProvider: TimeProvider, ) : MviRxViewModel() { override val initialState = ExtrasState() @@ -79,7 +81,7 @@ class ExtrasViewModel( when (it) { is StableDiffusionLora -> ExtraItemUi( type = type, - key = "${it.name}_${type}_${System.nanoTime()}", + key = "${it.name}_${type}_${timeProvider.nanoTime()}", name = it.name, alias = it.alias, isApplied = isApplied, @@ -88,7 +90,7 @@ class ExtrasViewModel( is StableDiffusionHyperNetwork -> ExtraItemUi( type = type, - key = "${it.name}_${type}_${System.nanoTime()}", + key = "${it.name}_${type}_${timeProvider.nanoTime()}", name = it.name, alias = null, isApplied = isApplied, diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/navigation/router/drawer/DrawerRouterImpl.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/navigation/router/drawer/DrawerRouterImpl.kt index 6efa94aa..d0bba577 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/navigation/router/drawer/DrawerRouterImpl.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/navigation/router/drawer/DrawerRouterImpl.kt @@ -1,16 +1,13 @@ package com.shifthackz.aisdv1.presentation.navigation.router.drawer import com.shifthackz.aisdv1.presentation.navigation.NavigationEffect -import io.reactivex.rxjava3.core.Observable import io.reactivex.rxjava3.subjects.PublishSubject internal class DrawerRouterImpl : DrawerRouter { private val effectSubject: PublishSubject = PublishSubject.create() - override fun observe(): Observable { - return effectSubject - } + override fun observe() = effectSubject.distinctUntilChanged() override fun openDrawer() { effectSubject.onNext(NavigationEffect.Drawer.Open) diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/navigation/router/main/MainRouterImpl.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/navigation/router/main/MainRouterImpl.kt index ff2a4170..23399f95 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/navigation/router/main/MainRouterImpl.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/navigation/router/main/MainRouterImpl.kt @@ -52,7 +52,7 @@ internal class MainRouterImpl( } override fun navigateToDebugMenu() { - if (debugMenuAccessor.invoke()) { + if (debugMenuAccessor()) { effectSubject.onNext(NavigationEffect.Navigate.Route(Constants.ROUTE_DEBUG)) } } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/gallery/detail/GalleryDetailIntent.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/gallery/detail/GalleryDetailIntent.kt index 43017236..ba3bc591 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/gallery/detail/GalleryDetailIntent.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/gallery/detail/GalleryDetailIntent.kt @@ -15,7 +15,6 @@ sealed interface GalleryDetailIntent : MviIntent { } enum class Export : GalleryDetailIntent { - Image, Params; } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/gallery/detail/GalleryDetailViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/gallery/detail/GalleryDetailViewModel.kt index 783dd884..f4879243 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/gallery/detail/GalleryDetailViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/gallery/detail/GalleryDetailViewModel.kt @@ -76,7 +76,6 @@ class GalleryDetailViewModel( GalleryDetailIntent.DismissDialog -> setActiveModal(Modal.None) } - } private fun share() { @@ -122,7 +121,7 @@ class GalleryDetailViewModel( } private fun getGenerationResult(id: Long): Single { - if (id <= 0) return getLastResultFromCacheUseCase.invoke() + if (id <= 0) return getLastResultFromCacheUseCase() return getGenerationResultUseCase(id) } } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/gallery/list/GalleryViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/gallery/list/GalleryViewModel.kt index 5a553932..1ed1e7f6 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/gallery/list/GalleryViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/gallery/list/GalleryViewModel.kt @@ -65,7 +65,6 @@ class GalleryViewModel( } } - private fun launchGalleryExport() = galleryExporter() .doOnSubscribe { setActiveModal(Modal.ExportInProgress) } .subscribeOnMainThread(schedulersProvider) diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageViewModel.kt index 6ae37485..05c238d1 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageViewModel.kt @@ -7,18 +7,26 @@ import com.shifthackz.aisdv1.core.imageprocessing.Base64ToBitmapConverter import com.shifthackz.aisdv1.core.imageprocessing.BitmapToBase64Converter import com.shifthackz.aisdv1.core.model.UiText import com.shifthackz.aisdv1.core.model.asUiText +import com.shifthackz.aisdv1.core.validation.dimension.DimensionValidator import com.shifthackz.aisdv1.domain.entity.AiGenerationResult import com.shifthackz.aisdv1.domain.entity.HordeProcessStatus import com.shifthackz.aisdv1.domain.interactor.wakelock.WakeLockInterActor import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.domain.usecase.caching.SaveLastResultToCacheUseCase import com.shifthackz.aisdv1.domain.usecase.generation.GetRandomImageUseCase import com.shifthackz.aisdv1.domain.usecase.generation.ImageToImageUseCase +import com.shifthackz.aisdv1.domain.usecase.generation.InterruptGenerationUseCase +import com.shifthackz.aisdv1.domain.usecase.generation.ObserveHordeProcessStatusUseCase +import com.shifthackz.aisdv1.domain.usecase.generation.ObserveLocalDiffusionProcessStatusUseCase +import com.shifthackz.aisdv1.domain.usecase.generation.SaveGenerationResultUseCase +import com.shifthackz.aisdv1.domain.usecase.sdsampler.GetStableDiffusionSamplersUseCase import com.shifthackz.aisdv1.presentation.R import com.shifthackz.aisdv1.presentation.core.GenerationFormUpdateEvent import com.shifthackz.aisdv1.presentation.core.GenerationMviIntent import com.shifthackz.aisdv1.presentation.core.GenerationMviViewModel import com.shifthackz.aisdv1.presentation.core.ImageToImageIntent import com.shifthackz.aisdv1.presentation.model.Modal +import com.shifthackz.aisdv1.presentation.navigation.router.drawer.DrawerRouter import com.shifthackz.aisdv1.presentation.navigation.router.main.MainRouter import com.shifthackz.aisdv1.presentation.notification.SdaiPushNotificationManager import com.shifthackz.aisdv1.presentation.screen.inpaint.InPaintStateProducer @@ -29,6 +37,14 @@ import io.reactivex.rxjava3.kotlin.subscribeBy class ImageToImageViewModel( generationFormUpdateEvent: GenerationFormUpdateEvent, + getStableDiffusionSamplersUseCase: GetStableDiffusionSamplersUseCase, + observeHordeProcessStatusUseCase: ObserveHordeProcessStatusUseCase, + observeLocalDiffusionProcessStatusUseCase: ObserveLocalDiffusionProcessStatusUseCase, + saveLastResultToCacheUseCase: SaveLastResultToCacheUseCase, + saveGenerationResultUseCase: SaveGenerationResultUseCase, + interruptGenerationUseCase: InterruptGenerationUseCase, + drawerRouter: DrawerRouter, + dimensionValidator: DimensionValidator, private val imageToImageUseCase: ImageToImageUseCase, private val getRandomImageUseCase: GetRandomImageUseCase, private val bitmapToBase64Converter: BitmapToBase64Converter, @@ -39,7 +55,19 @@ class ImageToImageViewModel( private val wakeLockInterActor: WakeLockInterActor, private val inPaintStateProducer: InPaintStateProducer, private val mainRouter: MainRouter, -) : GenerationMviViewModel() { +) : GenerationMviViewModel( + preferenceManager = preferenceManager, + getStableDiffusionSamplersUseCase = getStableDiffusionSamplersUseCase, + observeHordeProcessStatusUseCase = observeHordeProcessStatusUseCase, + observeLocalDiffusionProcessStatusUseCase = observeLocalDiffusionProcessStatusUseCase, + saveLastResultToCacheUseCase = saveLastResultToCacheUseCase, + saveGenerationResultUseCase = saveGenerationResultUseCase, + interruptGenerationUseCase = interruptGenerationUseCase, + mainRouter = mainRouter, + drawerRouter = drawerRouter, + dimensionValidator = dimensionValidator, + schedulersProvider = schedulersProvider, +) { override val initialState = ImageToImageState() @@ -56,7 +84,7 @@ class ImageToImageViewModel( .observeInPaint() .subscribeOnMainThread(schedulersProvider) .subscribeBy(::errorLog) { inPaint -> - updateState { it.copy(inPaintModel = inPaint) } + updateState { it.copy(inPaintModel = inPaint) } } } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/inpaint/InPaintViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/inpaint/InPaintViewModel.kt index 57b014c5..dc8ddde1 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/inpaint/InPaintViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/inpaint/InPaintViewModel.kt @@ -1,6 +1,5 @@ package com.shifthackz.aisdv1.presentation.screen.inpaint -import com.shifthackz.aisdv1.core.common.log.debugLog import com.shifthackz.aisdv1.core.common.log.errorLog import com.shifthackz.aisdv1.core.common.schedulers.SchedulersProvider import com.shifthackz.aisdv1.core.common.schedulers.subscribeOnMainThread @@ -32,7 +31,6 @@ class InPaintViewModel( } override fun processIntent(intent: InPaintIntent) { - debugLog("INTENT : $intent") when (intent) { is InPaintIntent.DrawPath -> updateState { state -> state.copy( diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupState.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupState.kt index 589a0976..58c46cc4 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupState.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupState.kt @@ -83,6 +83,7 @@ data class ServerSetupState( ) } +//ToDo refactor key to enum ordinal enum class ServerSetupLaunchSource(val key: Int) { SPLASH(0), SETTINGS(1); diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt index 18a72b5e..9d74b8d5 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt @@ -4,7 +4,7 @@ import com.shifthackz.aisdv1.core.common.log.errorLog import com.shifthackz.aisdv1.core.common.schedulers.SchedulersProvider import com.shifthackz.aisdv1.core.common.schedulers.subscribeOnMainThread import com.shifthackz.aisdv1.core.model.asUiText -import com.shifthackz.aisdv1.core.validation.horde.CommonStringValidator +import com.shifthackz.aisdv1.core.validation.common.CommonStringValidator import com.shifthackz.aisdv1.core.validation.url.UrlValidator import com.shifthackz.aisdv1.core.viewmodel.MviRxViewModel import com.shifthackz.aisdv1.domain.entity.DownloadState @@ -58,7 +58,8 @@ class ServerSetupViewModel( getLocalAiModelsUseCase(), fetchAndGetHuggingFaceModelsUseCase(), ::Triple, - ).subscribeOnMainThread(schedulersProvider) + ) + .subscribeOnMainThread(schedulersProvider) .subscribeBy(::errorLog) { (configuration, localModels, hfModels) -> updateState { state -> state.copy( @@ -381,7 +382,8 @@ class ServerSetupViewModel( } downloadDisposable?.dispose() downloadDisposable = null - downloadDisposable = downloadModelUseCase(localModel.id).distinctUntilChanged() + downloadDisposable = downloadModelUseCase(localModel.id) + .distinctUntilChanged() .doOnSubscribe { wakeLockInterActor.acquireWakelockUseCase() } .doFinally { wakeLockInterActor.releaseWakeLockUseCase() } .subscribeOnMainThread(schedulersProvider).subscribeBy( diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/ServerSetupValidationStringErrorMapper.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/ServerSetupValidationStringErrorMapper.kt index db0d9484..fa330c61 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/ServerSetupValidationStringErrorMapper.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/ServerSetupValidationStringErrorMapper.kt @@ -3,7 +3,7 @@ package com.shifthackz.aisdv1.presentation.screen.setup.mappers import com.shifthackz.aisdv1.core.model.UiText import com.shifthackz.aisdv1.core.model.asUiText import com.shifthackz.aisdv1.core.validation.ValidationResult -import com.shifthackz.aisdv1.core.validation.horde.CommonStringValidator +import com.shifthackz.aisdv1.core.validation.common.CommonStringValidator import com.shifthackz.aisdv1.presentation.R fun ValidationResult.mapToUi(): UiText? { diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageViewModel.kt index 1623e4b6..3d788e90 100755 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageViewModel.kt @@ -4,29 +4,59 @@ import com.shifthackz.aisdv1.core.common.log.errorLog import com.shifthackz.aisdv1.core.common.schedulers.SchedulersProvider import com.shifthackz.aisdv1.core.common.schedulers.subscribeOnMainThread import com.shifthackz.aisdv1.core.model.asUiText +import com.shifthackz.aisdv1.core.validation.dimension.DimensionValidator import com.shifthackz.aisdv1.domain.entity.HordeProcessStatus import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.domain.feature.diffusion.LocalDiffusion import com.shifthackz.aisdv1.domain.interactor.wakelock.WakeLockInterActor import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.domain.usecase.caching.SaveLastResultToCacheUseCase +import com.shifthackz.aisdv1.domain.usecase.generation.InterruptGenerationUseCase +import com.shifthackz.aisdv1.domain.usecase.generation.ObserveHordeProcessStatusUseCase +import com.shifthackz.aisdv1.domain.usecase.generation.ObserveLocalDiffusionProcessStatusUseCase +import com.shifthackz.aisdv1.domain.usecase.generation.SaveGenerationResultUseCase import com.shifthackz.aisdv1.domain.usecase.generation.TextToImageUseCase +import com.shifthackz.aisdv1.domain.usecase.sdsampler.GetStableDiffusionSamplersUseCase import com.shifthackz.aisdv1.presentation.R import com.shifthackz.aisdv1.presentation.core.GenerationFormUpdateEvent import com.shifthackz.aisdv1.presentation.core.GenerationMviIntent import com.shifthackz.aisdv1.presentation.core.GenerationMviViewModel import com.shifthackz.aisdv1.presentation.model.Modal +import com.shifthackz.aisdv1.presentation.navigation.router.drawer.DrawerRouter +import com.shifthackz.aisdv1.presentation.navigation.router.main.MainRouter import com.shifthackz.aisdv1.presentation.notification.SdaiPushNotificationManager import com.shifthackz.android.core.mvi.EmptyEffect import io.reactivex.rxjava3.kotlin.subscribeBy class TextToImageViewModel( generationFormUpdateEvent: GenerationFormUpdateEvent, + getStableDiffusionSamplersUseCase: GetStableDiffusionSamplersUseCase, + observeHordeProcessStatusUseCase: ObserveHordeProcessStatusUseCase, + observeLocalDiffusionProcessStatusUseCase: ObserveLocalDiffusionProcessStatusUseCase, + saveLastResultToCacheUseCase: SaveLastResultToCacheUseCase, + saveGenerationResultUseCase: SaveGenerationResultUseCase, + interruptGenerationUseCase: InterruptGenerationUseCase, + mainRouter: MainRouter, + drawerRouter: DrawerRouter, + dimensionValidator: DimensionValidator, private val textToImageUseCase: TextToImageUseCase, private val schedulersProvider: SchedulersProvider, private val preferenceManager: PreferenceManager, private val notificationManager: SdaiPushNotificationManager, private val wakeLockInterActor: WakeLockInterActor, -) : GenerationMviViewModel() { +) : GenerationMviViewModel( + preferenceManager = preferenceManager, + getStableDiffusionSamplersUseCase = getStableDiffusionSamplersUseCase, + observeHordeProcessStatusUseCase = observeHordeProcessStatusUseCase, + observeLocalDiffusionProcessStatusUseCase = observeLocalDiffusionProcessStatusUseCase, + saveLastResultToCacheUseCase = saveLastResultToCacheUseCase, + saveGenerationResultUseCase = saveGenerationResultUseCase, + interruptGenerationUseCase = interruptGenerationUseCase, + mainRouter = mainRouter, + drawerRouter = drawerRouter, + dimensionValidator = dimensionValidator, + schedulersProvider = schedulersProvider, +) { private val progressModal: Modal get() { diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModel.kt index af32aeee..33582b3b 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModel.kt @@ -6,6 +6,7 @@ import com.shifthackz.aisdv1.core.common.model.Quintuple import com.shifthackz.aisdv1.core.common.schedulers.SchedulersProvider import com.shifthackz.aisdv1.core.common.schedulers.subscribeOnMainThread import com.shifthackz.aisdv1.core.viewmodel.MviRxViewModel +import com.shifthackz.aisdv1.domain.entity.Configuration import com.shifthackz.aisdv1.domain.entity.LocalAiModel import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.domain.preference.PreferenceManager @@ -36,6 +37,7 @@ class EngineSelectionViewModel( val configuration = preferenceManager .observe() .flatMap { getConfigurationUseCase().toFlowable() } + .onErrorReturn { Configuration() } val a1111Models = getStableDiffusionModelsUseCase() .onErrorReturn { emptyList() } diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/activity/AiStableDiffusionViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/activity/AiStableDiffusionViewModelTest.kt new file mode 100644 index 00000000..4e70d5b8 --- /dev/null +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/activity/AiStableDiffusionViewModelTest.kt @@ -0,0 +1,167 @@ +package com.shifthackz.aisdv1.presentation.activity + +import androidx.navigation.NavOptionsBuilder +import app.cash.turbine.test +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.presentation.core.CoreViewModelInitializeStrategy +import com.shifthackz.aisdv1.presentation.core.CoreViewModelTest +import com.shifthackz.aisdv1.presentation.navigation.NavigationEffect +import com.shifthackz.aisdv1.presentation.navigation.router.drawer.DrawerRouter +import com.shifthackz.aisdv1.presentation.navigation.router.main.MainRouter +import com.shifthackz.aisdv1.presentation.stub.stubSchedulersProvider +import io.mockk.every +import io.mockk.mockk +import io.mockk.verify +import io.reactivex.rxjava3.subjects.BehaviorSubject +import kotlinx.coroutines.flow.firstOrNull +import kotlinx.coroutines.test.runTest +import org.junit.Assert +import org.junit.Before +import org.junit.Test + +class AiStableDiffusionViewModelTest : CoreViewModelTest() { + + private val stubNavigationEffect = BehaviorSubject.create() + private val stubDrawerNavigationEffect = BehaviorSubject.create() + private val stubNavBuilder: NavOptionsBuilder.() -> Unit = { + popUpTo("splash") { inclusive = true } + } + + private val stubMainRouter = mockk() + private val stubDrawerRouter = mockk() + private val stubPreferenceManager = mockk() + + override val testViewModelStrategy = CoreViewModelInitializeStrategy.InitializeEveryTime + + override fun initializeViewModel() = AiStableDiffusionViewModel( + schedulersProvider = stubSchedulersProvider, + mainRouter = stubMainRouter, + drawerRouter = stubDrawerRouter, + preferenceManager = stubPreferenceManager, + ) + + @Before + override fun initialize() { + super.initialize() + + every { + stubMainRouter.observe() + } returns stubNavigationEffect + + every { + stubDrawerRouter.observe() + } returns stubDrawerNavigationEffect + } + + @Test + fun `given onStoragePermissionsGranted was called, expected VM sets field saveToMediaStore with true in preference manager`() { + every { + stubPreferenceManager::saveToMediaStore.set(any()) + } returns Unit + + viewModel.onStoragePermissionsGranted() + + verify { + stubPreferenceManager::saveToMediaStore.set(true) + } + } + + @Test + fun `given route event from main router, expected domain model delivered to effect collector`() { + stubNavigationEffect.onNext(NavigationEffect.Navigate.Route("route")) + runTest { + val expected = NavigationEffect.Navigate.Route("route") + val actual = viewModel.effect.firstOrNull() + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given route pop up event from main router, expected domain model delivered to effect collector`() { + stubNavigationEffect.onNext(NavigationEffect.Navigate.RoutePopUp("route")) + runTest { + val expected = NavigationEffect.Navigate.RoutePopUp("route") + val actual = viewModel.effect.firstOrNull() + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given route builder event from main router, expected domain model delivered to effect collector`() { + stubNavigationEffect.onNext( + NavigationEffect.Navigate.RouteBuilder("route", stubNavBuilder) + ) + runTest { + val expected = NavigationEffect.Navigate.RouteBuilder("route", stubNavBuilder) + val actual = viewModel.effect.firstOrNull() + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given back event from main router, expected domain model delivered to effect collector`() { + stubNavigationEffect.onNext(NavigationEffect.Back) + runTest { + val expected = NavigationEffect.Back + val actual = viewModel.effect.firstOrNull() + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given route then back events from main router, expected two domain models delivered to effect collector in same order`() { + runTest { + viewModel.effect.test { + stubNavigationEffect.onNext(NavigationEffect.Navigate.Route("route2")) + Assert.assertEquals(NavigationEffect.Navigate.Route("route2"), awaitItem()) + + stubNavigationEffect.onNext(NavigationEffect.Back) + Assert.assertEquals(NavigationEffect.Back, awaitItem()) + + cancelAndIgnoreRemainingEvents() + } + } + } + + @Test + fun `given mixed six events from main router, expected six domain models delivered to effect collector in same order`() { + runTest { + viewModel.effect.test { + stubNavigationEffect.onNext(NavigationEffect.Navigate.Route("route2")) + Assert.assertEquals(NavigationEffect.Navigate.Route("route2"), awaitItem()) + + stubNavigationEffect.onNext(NavigationEffect.Navigate.Route("route4")) + Assert.assertEquals(NavigationEffect.Navigate.Route("route4"), awaitItem()) + + stubNavigationEffect.onNext(NavigationEffect.Back) + Assert.assertEquals(NavigationEffect.Back, awaitItem()) + + stubNavigationEffect.onNext(NavigationEffect.Navigate.Route("route3")) + Assert.assertEquals(NavigationEffect.Navigate.Route("route3"), awaitItem()) + + stubNavigationEffect.onNext(NavigationEffect.Back) + Assert.assertEquals(NavigationEffect.Back, awaitItem()) + + stubNavigationEffect.onNext(NavigationEffect.Back) + Assert.assertEquals(NavigationEffect.Back, awaitItem()) + + cancelAndIgnoreRemainingEvents() + } + } + } + + @Test + fun `given open then close events from drawer router, expected two domain models delivered to effect collector in same order`() { + runTest { + viewModel.effect.test { + stubDrawerNavigationEffect.onNext(NavigationEffect.Drawer.Open) + Assert.assertEquals(NavigationEffect.Drawer.Open, awaitItem()) + + stubDrawerNavigationEffect.onNext(NavigationEffect.Drawer.Close) + Assert.assertEquals(NavigationEffect.Drawer.Close, awaitItem()) + + cancelAndIgnoreRemainingEvents() + } + } + } +} diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/core/CoreGenerationMviViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/core/CoreGenerationMviViewModelTest.kt new file mode 100644 index 00000000..32fe1567 --- /dev/null +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/core/CoreGenerationMviViewModelTest.kt @@ -0,0 +1,106 @@ +package com.shifthackz.aisdv1.presentation.core + +import com.shifthackz.aisdv1.core.common.schedulers.SchedulersProvider +import com.shifthackz.aisdv1.core.validation.dimension.DimensionValidator +import com.shifthackz.aisdv1.domain.entity.AiGenerationResult +import com.shifthackz.aisdv1.domain.entity.HordeProcessStatus +import com.shifthackz.aisdv1.domain.entity.Settings +import com.shifthackz.aisdv1.domain.feature.diffusion.LocalDiffusion +import com.shifthackz.aisdv1.domain.interactor.wakelock.WakeLockInterActor +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.domain.usecase.caching.SaveLastResultToCacheUseCase +import com.shifthackz.aisdv1.domain.usecase.generation.InterruptGenerationUseCase +import com.shifthackz.aisdv1.domain.usecase.generation.ObserveHordeProcessStatusUseCase +import com.shifthackz.aisdv1.domain.usecase.generation.ObserveLocalDiffusionProcessStatusUseCase +import com.shifthackz.aisdv1.domain.usecase.generation.SaveGenerationResultUseCase +import com.shifthackz.aisdv1.domain.usecase.sdsampler.GetStableDiffusionSamplersUseCase +import com.shifthackz.aisdv1.domain.usecase.wakelock.AcquireWakelockUseCase +import com.shifthackz.aisdv1.domain.usecase.wakelock.ReleaseWakeLockUseCase +import com.shifthackz.aisdv1.presentation.navigation.router.drawer.DrawerRouter +import com.shifthackz.aisdv1.presentation.navigation.router.main.MainRouter +import com.shifthackz.aisdv1.presentation.notification.SdaiPushNotificationManager +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.android.schedulers.AndroidSchedulers +import io.reactivex.rxjava3.core.BackpressureStrategy +import io.reactivex.rxjava3.core.Scheduler +import io.reactivex.rxjava3.core.Single +import io.reactivex.rxjava3.schedulers.Schedulers +import io.reactivex.rxjava3.subjects.BehaviorSubject +import org.junit.Before +import java.util.concurrent.Executor +import java.util.concurrent.Executors + +abstract class CoreGenerationMviViewModelTest> : + CoreViewModelTest() { + + protected val stubSettings = BehaviorSubject.createDefault(Settings()) + protected val stubAiForm = BehaviorSubject.create() + + protected val stubPreferenceManager = mockk() + protected val stubSaveLastResultToCacheUseCase = mockk() + protected val stubSaveGenerationResultUseCase = mockk() + protected val stubGetStableDiffusionSamplersUseCase = mockk() + protected val stubObserveHordeProcessStatusUseCase = mockk() + protected val stubObserveLocalDiffusionProcessStatusUseCase = mockk() + protected val stubInterruptGenerationUseCase = mockk() + protected val stubMainRouter = mockk() + protected val stubDrawerRouter = mockk() + protected val stubDimensionValidator = mockk() + protected val stubSdaiPushNotificationManager = mockk() + + protected val stubAcquireWakelockUseCase = mockk() + protected val stubReleaseWakelockUseCase = mockk() + protected val stubWakeLockInterActor = mockk() + + private val stubHordeProcessStatus = BehaviorSubject.create() + private val stubLdStatus = BehaviorSubject.create() + + protected val stubCustomSchedulers = object : SchedulersProvider { + override val io: Scheduler = Schedulers.io() + override val ui: Scheduler = AndroidSchedulers.mainThread() + override val computation: Scheduler = Schedulers.trampoline() + override val singleThread: Executor = Executors.newSingleThreadExecutor() + } + + @Before + override fun initialize() { + super.initialize() + + every { + stubPreferenceManager.observe() + } returns stubSettings.toFlowable(BackpressureStrategy.LATEST) + + every { + stubObserveHordeProcessStatusUseCase() + } returns stubHordeProcessStatus.toFlowable(BackpressureStrategy.LATEST) + + every { + stubObserveLocalDiffusionProcessStatusUseCase() + } returns stubLdStatus + + every { + stubGetStableDiffusionSamplersUseCase() + } returns Single.just(emptyList()) + + every { + stubAcquireWakelockUseCase.invoke(any()) + } returns Result.success(Unit) + + every { + stubAcquireWakelockUseCase.invoke() + } returns Result.success(Unit) + + every { + stubReleaseWakelockUseCase.invoke() + } returns Result.success(Unit) + + every { + stubWakeLockInterActor::acquireWakelockUseCase.get() + } returns stubAcquireWakelockUseCase + + every { + stubWakeLockInterActor::releaseWakeLockUseCase.get() + } returns stubReleaseWakelockUseCase + } +} diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/core/CoreViewModelInitializeStrategy.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/core/CoreViewModelInitializeStrategy.kt new file mode 100644 index 00000000..5a247d66 --- /dev/null +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/core/CoreViewModelInitializeStrategy.kt @@ -0,0 +1,6 @@ +package com.shifthackz.aisdv1.presentation.core + +enum class CoreViewModelInitializeStrategy { + InitializeOnce, + InitializeEveryTime; +} diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/core/CoreViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/core/CoreViewModelTest.kt new file mode 100644 index 00000000..5b8cde70 --- /dev/null +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/core/CoreViewModelTest.kt @@ -0,0 +1,50 @@ +@file:OptIn(ExperimentalCoroutinesApi::class) + +package com.shifthackz.aisdv1.presentation.core + +import androidx.lifecycle.ViewModel +import kotlinx.coroutines.CoroutineDispatcher +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.test.UnconfinedTestDispatcher +import kotlinx.coroutines.test.resetMain +import kotlinx.coroutines.test.setMain +import org.junit.After +import org.junit.Before + +abstract class CoreViewModelTest { + + private var _viewModel: V? = null + + protected val viewModel: V + get() = when (testViewModelStrategy) { + CoreViewModelInitializeStrategy.InitializeOnce -> _viewModel ?: run { + val vm = initializeViewModel() + _viewModel = vm + vm + } + CoreViewModelInitializeStrategy.InitializeEveryTime -> { + val vm = initializeViewModel() + _viewModel = vm + vm + } + } + + open val testViewModelStrategy: CoreViewModelInitializeStrategy + get() = CoreViewModelInitializeStrategy.InitializeOnce + + open val testDispatcher: CoroutineDispatcher + get() = UnconfinedTestDispatcher() + + @Before + open fun initialize() { + Dispatchers.setMain(testDispatcher) + } + + @After + open fun finalize() { + Dispatchers.resetMain() + } + + abstract fun initializeViewModel(): V +} diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/AiGenerationResultMocks.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/AiGenerationResultMocks.kt new file mode 100644 index 00000000..96cf8a88 --- /dev/null +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/AiGenerationResultMocks.kt @@ -0,0 +1,24 @@ +package com.shifthackz.aisdv1.presentation.mocks + +import com.shifthackz.aisdv1.domain.entity.AiGenerationResult +import java.util.Date + +val mockAiGenerationResult = AiGenerationResult( + id = 5598L, + image = "img", + inputImage = "inp", + createdAt = Date(0), + type = AiGenerationResult.Type.IMAGE_TO_IMAGE, + prompt = "prompt", + negativePrompt = "negative", + width = 512, + height = 512, + samplingSteps = 7, + cfgScale = 0.7f, + restoreFaces = true, + sampler = "sampler", + seed = "5598", + subSeed = "1504", + subSeedStrength = 5598f, + denoisingStrength = 1504f, +) diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/HuggingFaceModelMocks.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/HuggingFaceModelMocks.kt new file mode 100644 index 00000000..a18196bb --- /dev/null +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/HuggingFaceModelMocks.kt @@ -0,0 +1,13 @@ +package com.shifthackz.aisdv1.presentation.mocks + +import com.shifthackz.aisdv1.domain.entity.HuggingFaceModel + +val mockHuggingFaceModels = listOf( + HuggingFaceModel.default, + HuggingFaceModel( + "80974f2d-7ee0-48e5-97bc-448de3c1d634", + "Analog Diffusion", + "wavymulder/Analog-Diffusion", + "https://huggingface.co/wavymulder/Analog-Diffusion", + ), +) diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/LocalAiModelMocks.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/LocalAiModelMocks.kt new file mode 100644 index 00000000..8b9eb04a --- /dev/null +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/LocalAiModelMocks.kt @@ -0,0 +1,26 @@ +package com.shifthackz.aisdv1.presentation.mocks + +import com.shifthackz.aisdv1.domain.entity.DownloadState +import com.shifthackz.aisdv1.domain.entity.LocalAiModel +import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupState + +val mockLocalAiModels = listOf( + LocalAiModel.CUSTOM, + LocalAiModel( + id = "1", + name = "Model 1", + size = "5 Gb", + sources = listOf("https://example.com/1.html"), + downloaded = false, + selected = false, + ), +) + +val mockServerSetupStateLocalModel = ServerSetupState.LocalModel( + id = "1", + name = "Model 1", + size = "5 Gb", + downloaded = false, + downloadState = DownloadState.Unknown, + selected = false, +) diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/StabilityAiEngineMocks.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/StabilityAiEngineMocks.kt new file mode 100644 index 00000000..e4f514ea --- /dev/null +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/StabilityAiEngineMocks.kt @@ -0,0 +1,10 @@ +package com.shifthackz.aisdv1.presentation.mocks + +import com.shifthackz.aisdv1.domain.entity.StabilityAiEngine + +val mockStabilityAiEngines = listOf( + StabilityAiEngine( + id = "5598", + name = "engine_5598", + ), +) diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/StableDiffusionEmbeddingMocks.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/StableDiffusionEmbeddingMocks.kt new file mode 100644 index 00000000..5c83e46e --- /dev/null +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/StableDiffusionEmbeddingMocks.kt @@ -0,0 +1,8 @@ +package com.shifthackz.aisdv1.presentation.mocks + +import com.shifthackz.aisdv1.domain.entity.StableDiffusionEmbedding + +val mockStableDiffusionEmbeddings = listOf( + StableDiffusionEmbedding("5598"), + StableDiffusionEmbedding("151297"), +) diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/StableDiffusionLoraMocks.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/StableDiffusionLoraMocks.kt new file mode 100644 index 00000000..dd20509a --- /dev/null +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/StableDiffusionLoraMocks.kt @@ -0,0 +1,16 @@ +package com.shifthackz.aisdv1.presentation.mocks + +import com.shifthackz.aisdv1.domain.entity.StableDiffusionLora + +val mockStableDiffusionLoras = listOf( + StableDiffusionLora( + name = "name_5598", + alias = "alias_5598", + path = "/unknown", + ), + StableDiffusionLora( + name = "name_151297", + alias = "alias_151297", + path = "/unknown", + ), +) diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/StableDiffusionModelMocks.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/StableDiffusionModelMocks.kt new file mode 100644 index 00000000..29149f1d --- /dev/null +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/StableDiffusionModelMocks.kt @@ -0,0 +1,22 @@ +package com.shifthackz.aisdv1.presentation.mocks + +import com.shifthackz.aisdv1.domain.entity.StableDiffusionModel + +val mockStableDiffusionModels = listOf( + StableDiffusionModel( + title = "title_5598", + modelName = "name_5598", + hash = "hash_5598", + sha256 = "sha_5598", + filename = "file_5598", + config = "config_5598", + ) to true, + StableDiffusionModel( + title = "title_151297", + modelName = "name_151297", + hash = "hash_151297", + sha256 = "sha_151297", + filename = "file_151297", + config = "config_151297", + ) to false, +) diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/StableDiffusionSamplerMocks.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/StableDiffusionSamplerMocks.kt new file mode 100644 index 00000000..53d1f943 --- /dev/null +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/StableDiffusionSamplerMocks.kt @@ -0,0 +1,16 @@ +package com.shifthackz.aisdv1.presentation.mocks + +import com.shifthackz.aisdv1.domain.entity.StableDiffusionSampler + +val mockStableDiffusionSamplers = listOf( + StableDiffusionSampler( + name = "sampler_1", + aliases = listOf("alias_1"), + options = mapOf("option" to "value"), + ), + StableDiffusionSampler( + name = "sampler_2", + aliases = listOf("alias_2"), + options = mapOf("option" to "value"), + ), +) diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/modal/embedding/EmbeddingViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/modal/embedding/EmbeddingViewModelTest.kt new file mode 100644 index 00000000..0d5535c8 --- /dev/null +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/modal/embedding/EmbeddingViewModelTest.kt @@ -0,0 +1,138 @@ +package com.shifthackz.aisdv1.presentation.modal.embedding + +import com.shifthackz.aisdv1.domain.usecase.sdembedding.FetchAndGetEmbeddingsUseCase +import com.shifthackz.aisdv1.presentation.core.CoreViewModelTest +import com.shifthackz.aisdv1.presentation.mocks.mockStableDiffusionEmbeddings +import com.shifthackz.aisdv1.presentation.modal.extras.ExtrasEffect +import com.shifthackz.aisdv1.presentation.model.ErrorState +import com.shifthackz.aisdv1.presentation.stub.stubSchedulersProvider +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Single +import kotlinx.coroutines.flow.firstOrNull +import kotlinx.coroutines.test.runTest +import org.junit.Assert +import org.junit.Test + +class EmbeddingViewModelTest : CoreViewModelTest() { + + private val stubException = Throwable("Something went wrong.") + private val stubFetchAndGetEmbeddingsUseCase = mockk() + + override fun initializeViewModel() = EmbeddingViewModel( + fetchAndGetEmbeddingsUseCase = stubFetchAndGetEmbeddingsUseCase, + schedulersProvider = stubSchedulersProvider, + ) + + @Test + fun `given update data, fetch embeddings successful, expected UI state with embeddings list`() { + every { + stubFetchAndGetEmbeddingsUseCase() + } returns Single.just(mockStableDiffusionEmbeddings) + + viewModel.updateData("prompt", "negative") + + runTest { + val expected = EmbeddingState( + loading = false, + error = ErrorState.None, + prompt = "prompt", + negativePrompt = "negative", + embeddings = listOf( + EmbeddingItemUi( + keyword = "5598", + isInPrompt = false, + isInNegativePrompt = false, + ), + EmbeddingItemUi( + keyword = "151297", + isInPrompt = false, + isInNegativePrompt = false, + ), + ), + selector = false, + ) + val actual = viewModel.state.value + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given update data, fetch embeddings failed, expected UI state with Generic error`() { + every { + stubFetchAndGetEmbeddingsUseCase() + } returns Single.error(stubException) + + viewModel.updateData("prompt", "negative") + + runTest { + val expected = ErrorState.Generic + val actual = viewModel.state.value.error + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received ApplyNewPrompts intent, expected ApplyPrompts effect delivered to effect collector`() { + viewModel.processIntent(EmbeddingIntent.ApplyNewPrompts) + runTest { + val expected = ExtrasEffect.ApplyPrompts("", "") + val actual = viewModel.effect.firstOrNull() + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received ChangeSelector intent, expected selector field updated in UI state from intent`() { + viewModel.processIntent(EmbeddingIntent.ChangeSelector(true)) + runTest { + val expected = true + val actual = viewModel.state.value.selector + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Close intent, expected Close effect delivered to effect collector`() { + viewModel.processIntent(EmbeddingIntent.Close) + runTest { + val expected = ExtrasEffect.Close + val actual = viewModel.effect.firstOrNull() + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received ToggleItem intent, expected item from intent isInNegativePrompt changed in UI state`() { + every { + stubFetchAndGetEmbeddingsUseCase() + } returns Single.just(mockStableDiffusionEmbeddings) + + viewModel.updateData("prompt", "negative") + + val embedding = EmbeddingItemUi( + keyword = "5598", + isInPrompt = false, + isInNegativePrompt = false, + ) + val intent = EmbeddingIntent.ToggleItem(embedding) + viewModel.processIntent(intent) + + runTest { + val expected = listOf( + EmbeddingItemUi( + keyword = "5598", + isInPrompt = false, + isInNegativePrompt = true, + ), + EmbeddingItemUi( + keyword = "151297", + isInPrompt = false, + isInNegativePrompt = false, + ), + ) + val actual = viewModel.state.value.embeddings + Assert.assertEquals(expected, actual) + } + } +} diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasViewModelTest.kt new file mode 100644 index 00000000..93b872c0 --- /dev/null +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/modal/extras/ExtrasViewModelTest.kt @@ -0,0 +1,159 @@ +package com.shifthackz.aisdv1.presentation.modal.extras + +import com.shifthackz.aisdv1.core.common.time.TimeProvider +import com.shifthackz.aisdv1.domain.usecase.sdhypernet.FetchAndGetHyperNetworksUseCase +import com.shifthackz.aisdv1.domain.usecase.sdlora.FetchAndGetLorasUseCase +import com.shifthackz.aisdv1.presentation.core.CoreViewModelTest +import com.shifthackz.aisdv1.presentation.mocks.mockStableDiffusionLoras +import com.shifthackz.aisdv1.presentation.model.ErrorState +import com.shifthackz.aisdv1.presentation.model.ExtraType +import com.shifthackz.aisdv1.presentation.stub.stubSchedulersProvider +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.Single +import kotlinx.coroutines.flow.firstOrNull +import kotlinx.coroutines.test.runTest +import org.junit.Assert +import org.junit.Before +import org.junit.Test + +class ExtrasViewModelTest : CoreViewModelTest() { + + private val stubException = Throwable("Something went wrong.") + private val stubFetchAndGetLorasUseCase = mockk() + private val stubFetchAndGetHyperNetworksUseCase = mockk() + private val stubTimeProvider = mockk() + + override fun initializeViewModel() = ExtrasViewModel( + stubFetchAndGetLorasUseCase, + stubFetchAndGetHyperNetworksUseCase, + stubSchedulersProvider, + stubTimeProvider, + ) + + @Before + override fun initialize() { + super.initialize() + + every { + stubTimeProvider.nanoTime() + } returns MOCK_SYS_TIME + } + + @Test + fun `given update data, fetch loras successful, expected UI state with loras list`() { + mockInitialData() + runTest { + val expected = ExtrasState( + loading = false, + error = ErrorState.None, + prompt = "prompt ", + negativePrompt = "negative", + type = ExtraType.Lora, + loras = listOf( + ExtraItemUi( + type = ExtraType.Lora, + key = "name_5598_lora_$MOCK_SYS_TIME", + name = "name_5598", + alias = "alias_5598", + isApplied = true, + value = "1", + ), + ExtraItemUi( + type = ExtraType.Lora, + key = "name_151297_lora_$MOCK_SYS_TIME", + name = "name_151297", + alias = "alias_151297", + isApplied = false, + value = null, + ), + ), + ) + val actual = viewModel.state.value + Assert.assertEquals(expected.type, actual.type) + Assert.assertEquals(expected.error, actual.error) + Assert.assertEquals(expected.prompt, actual.prompt) + Assert.assertEquals(expected.negativePrompt, actual.negativePrompt) + Assert.assertEquals(expected.type, actual.type) + Assert.assertEquals(true, actual.loras.any { it.name == "name_5598" && it.isApplied }) + } + } + + @Test + fun `given update data, fetch loras failed, expected UI state with Generic error`() { + every { + stubFetchAndGetLorasUseCase() + } returns Single.error(stubException) + + viewModel.updateData( + prompt = "prompt ", + negativePrompt = "negative", + type = ExtraType.Lora, + ) + + runTest { + val state = viewModel.state.value + Assert.assertEquals(false, state.loading) + Assert.assertEquals(ErrorState.Generic, state.error) + } + } + + @Test + fun `given received ApplyPrompts intent, expected ApplyPrompts effect delivered to effect collector`() { + mockInitialData() + viewModel.processIntent(ExtrasIntent.ApplyPrompts) + runTest { + val expected = ExtrasEffect.ApplyPrompts( + prompt = "prompt ", + negativePrompt = "negative", + ) + val actual = viewModel.effect.firstOrNull() + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Close intent, expected Close effect delivered to effect collector`() { + viewModel.processIntent(ExtrasIntent.Close) + runTest { + val expected = ExtrasEffect.Close + val actual = viewModel.effect.firstOrNull() + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received ToggleItem intent, expected prompt updated in UI state`() { + mockInitialData() + Thread.sleep(1000L) + val item = ExtraItemUi( + type = ExtraType.Lora, + key = "name_5598_lora_$MOCK_SYS_TIME", + name = "name_5598", + alias = "alias_5598", + isApplied = true, + value = "1", + ) + viewModel.processIntent(ExtrasIntent.ToggleItem(item)) + runTest { + val state = viewModel.state.value + Assert.assertEquals("prompt", state.prompt) + } + } + + private fun mockInitialData() { + every { + stubFetchAndGetLorasUseCase() + } returns Single.just(mockStableDiffusionLoras) + + viewModel.updateData( + prompt = "prompt ", + negativePrompt = "negative", + type = ExtraType.Lora, + ) + } + + companion object { + private const val MOCK_SYS_TIME = 5598L + } +} diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/modal/tag/EditTagViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/modal/tag/EditTagViewModelTest.kt new file mode 100644 index 00000000..d5a06a44 --- /dev/null +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/modal/tag/EditTagViewModelTest.kt @@ -0,0 +1,99 @@ +package com.shifthackz.aisdv1.presentation.modal.tag + +import com.shifthackz.aisdv1.presentation.core.CoreViewModelTest +import com.shifthackz.aisdv1.presentation.model.ExtraType +import kotlinx.coroutines.flow.firstOrNull +import kotlinx.coroutines.test.runTest +import org.junit.Assert +import org.junit.Test + +class EditTagViewModelTest : CoreViewModelTest() { + + override fun initializeViewModel() = EditTagViewModel() + + @Test + fun `given received InitialData intent, expected UI state updated witch correct stub values`() { + mockInitialData() + runTest { + val expected = EditTagState( + prompt = "prompt ", + negativePrompt = "negative", + originalTag = "", + currentTag = "", + extraType = ExtraType.Lora, + isNegative = false, + ) + val actual = viewModel.state.value + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received UpdateTag intent, expected field currentTag changed, field originalTag not changed in UI state`() { + mockInitialData() + viewModel.processIntent(EditTagIntent.UpdateTag("")) + runTest { + val state = viewModel.state.value + Assert.assertEquals("", state.currentTag) + Assert.assertEquals("", state.originalTag) + Assert.assertNotEquals(state.originalTag, state.currentTag) + } + } + + @Test + fun `given received Value Increment intent, expected field currentTag changed, field originalTag not changed in UI state`() { + mockInitialData() + viewModel.processIntent(EditTagIntent.Value.Increment) + runTest { + val state = viewModel.state.value + Assert.assertEquals("", state.currentTag) + Assert.assertEquals("", state.originalTag) + Assert.assertNotEquals(state.originalTag, state.currentTag) + } + } + + @Test + fun `given received Value Decrement intent, expected field currentTag changed, field originalTag not changed in UI state`() { + mockInitialData() + viewModel.processIntent(EditTagIntent.Value.Decrement) + runTest { + val state = viewModel.state.value + Assert.assertEquals("", state.currentTag) + Assert.assertEquals("", state.originalTag) + Assert.assertNotEquals(state.originalTag, state.currentTag) + } + } + + @Test + fun `given received Action Apply intent, expected ApplyPrompts effect with valid prompt delivered to effect collector`() { + mockInitialData() + viewModel.processIntent(EditTagIntent.Value.Increment) + viewModel.processIntent(EditTagIntent.Action.Apply) + runTest { + val expected = "prompt " + val actual = (viewModel.effect.firstOrNull() as? EditTagEffect.ApplyPrompts)?.prompt + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Action Delete intent, expected ApplyPrompts effect with prompt that does not contain tag delivered to effect collector`() { + mockInitialData() + viewModel.processIntent(EditTagIntent.Action.Delete) + runTest { + val expected = "prompt" + val actual = (viewModel.effect.firstOrNull() as? EditTagEffect.ApplyPrompts)?.prompt + Assert.assertEquals(expected, actual) + } + } + + private fun mockInitialData() { + val intent = EditTagIntent.InitialData( + prompt = "prompt ", + negativePrompt = "negative", + tag = "", + isNegative = false, + ) + viewModel.processIntent(intent) + } +} diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/navigation/router/drawer/DrawerRouterImplTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/navigation/router/drawer/DrawerRouterImplTest.kt new file mode 100644 index 00000000..a54941a3 --- /dev/null +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/navigation/router/drawer/DrawerRouterImplTest.kt @@ -0,0 +1,68 @@ +package com.shifthackz.aisdv1.presentation.navigation.router.drawer + +import com.shifthackz.aisdv1.presentation.navigation.NavigationEffect +import org.junit.Test + +class DrawerRouterImplTest { + + private val router = DrawerRouterImpl() + + @Test + fun `given user opens drawer, expected router emits Open event`() { + router + .observe() + .test() + .also { router.openDrawer() } + .assertNoErrors() + .assertValueAt(0, NavigationEffect.Drawer.Open) + } + + @Test + fun `given user closes drawer, expected router emits Close event`() { + router + .observe() + .test() + .also { router.closeDrawer() } + .assertNoErrors() + .assertValueAt(0, NavigationEffect.Drawer.Close) + } + + @Test + fun `given user opens than closes drawer, expected router emits Open than Close events`() { + router + .observe() + .test() + .also { router.openDrawer() } + .assertNoErrors() + .assertValueAt(0, NavigationEffect.Drawer.Open) + .also { router.closeDrawer() } + .assertNoErrors() + .assertValueAt(1, NavigationEffect.Drawer.Close) + } + + @Test + fun `given user opens drawer twice, expected router emits one Open event`() { + router + .observe() + .test() + .also { router.openDrawer() } + .assertNoErrors() + .assertValueAt(0, NavigationEffect.Drawer.Open) + .also { router.openDrawer() } + .assertNoErrors() + .assertValueCount(1) + } + + @Test + fun `given user closes drawer twice, expected router emits one Close event`() { + router + .observe() + .test() + .also { router.closeDrawer() } + .assertNoErrors() + .assertValueAt(0, NavigationEffect.Drawer.Close) + .also { router.closeDrawer() } + .assertNoErrors() + .assertValueCount(1) + } +} diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/navigation/router/main/MainRouterImplTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/navigation/router/main/MainRouterImplTest.kt new file mode 100644 index 00000000..891ff322 --- /dev/null +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/navigation/router/main/MainRouterImplTest.kt @@ -0,0 +1,167 @@ +package com.shifthackz.aisdv1.presentation.navigation.router.main + +import com.shifthackz.aisdv1.core.common.appbuild.BuildInfoProvider +import com.shifthackz.aisdv1.presentation.navigation.NavigationEffect +import com.shifthackz.aisdv1.presentation.screen.debug.DebugMenuAccessor +import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupLaunchSource +import com.shifthackz.aisdv1.presentation.utils.Constants +import io.mockk.every +import io.mockk.mockk +import org.junit.Test + +class MainRouterImplTest { + + private val stubBuildInfoProvider = mockk() + private val stubDebugMenuAccessor = DebugMenuAccessor(stubBuildInfoProvider) + + private val router = MainRouterImpl(stubDebugMenuAccessor) + + @Test + fun `given user navigates back, expected router emits Back event`() { + router + .observe() + .test() + .also { router.navigateBack() } + .assertNoErrors() + .assertValueAt(0, NavigationEffect.Back) + } + + @Test + fun `given user navigates to splash config loader, expected router emits RouteBuilder event with ROUTE_CONFIG_LOADER route`() { + router + .observe() + .test() + .also { router.navigateToPostSplashConfigLoader() } + .assertNoErrors() + .assertValueAt(0) { actual -> + actual is NavigationEffect.Navigate.RouteBuilder + && actual.route == Constants.ROUTE_CONFIG_LOADER + } + } + + @Test + fun `given user navigates to home screen, expected router emits RoutePopUp event with ROUTE_HOME route`() { + router + .observe() + .test() + .also { router.navigateToHomeScreen() } + .assertNoErrors() + .assertValueAt(0, NavigationEffect.Navigate.RoutePopUp(Constants.ROUTE_HOME)) + } + + @Test + fun `given user navigates to server setup from splash, expected router emits RouteBuilder event with ROUTE_SERVER_SETUP route and SPLASH source`() { + router + .observe() + .test() + .also { router.navigateToServerSetup(ServerSetupLaunchSource.SPLASH) } + .assertNoErrors() + .assertValueAt(0) { actual -> + val expectedRoute = + "${Constants.ROUTE_SERVER_SETUP}/${ServerSetupLaunchSource.SPLASH.key}" + actual is NavigationEffect.Navigate.RouteBuilder + && actual.route == expectedRoute + } + } + + @Test + fun `given user navigates to server setup from settings, expected router emits RouteBuilder event with ROUTE_SERVER_SETUP route and SETTINGS source`() { + router + .observe() + .test() + .also { router.navigateToServerSetup(ServerSetupLaunchSource.SETTINGS) } + .assertNoErrors() + .assertValueAt(0) { actual -> + val expectedRoute = + "${Constants.ROUTE_SERVER_SETUP}/${ServerSetupLaunchSource.SETTINGS.key}" + actual is NavigationEffect.Navigate.RouteBuilder + && actual.route == expectedRoute + } + } + + @Test + fun `given user navigates to gallery details for item 5598, expected router emits Route event with ROUTE_GALLERY_DETAIL route and id 5598`() { + router + .observe() + .test() + .also { router.navigateToGalleryDetails(5598L) } + .assertNoErrors() + .assertValueAt( + 0, + NavigationEffect.Navigate.Route("${Constants.ROUTE_GALLERY_DETAIL}/5598"), + ) + } + + @Test + fun `given user navigates to in paint, expected router emits Route event with ROUTE_IN_PAINT route`() { + router + .observe() + .test() + .also { router.navigateToInPaint() } + .assertNoErrors() + .assertValueAt( + 0, + NavigationEffect.Navigate.Route(Constants.ROUTE_IN_PAINT), + ) + } + + @Test + fun `given user tapped hidden menu 6 times, build is debuggable, expected router emits no events`() { + every { + stubBuildInfoProvider.isDebug + } returns true + + val stubObserver = router.observe().test() + + repeat(6) { router.navigateToDebugMenu() } + + stubObserver + .assertNoErrors() + .assertNoValues() + } + + @Test + fun `given user tapped hidden menu 7 times, build is debuggable, expected router emits Route event with ROUTE_DEBUG route`() { + every { + stubBuildInfoProvider.isDebug + } returns true + + val stubObserver = router.observe().test() + + repeat(7) { router.navigateToDebugMenu() } + + stubObserver + .assertNoErrors() + .assertValueAt(0, NavigationEffect.Navigate.Route(Constants.ROUTE_DEBUG)) + } + + @Test + fun `given user tapped hidden menu 6 times, build is NOT debuggable, expected router emits no events`() { + every { + stubBuildInfoProvider.isDebug + } returns false + + val stubObserver = router.observe().test() + + repeat(6) { router.navigateToDebugMenu() } + + stubObserver + .assertNoErrors() + .assertNoValues() + } + + @Test + fun `given user tapped hidden menu 7 times, build is NOT debuggable, expected router emits no events`() { + every { + stubBuildInfoProvider.isDebug + } returns false + + val stubObserver = router.observe().test() + + repeat(7) { router.navigateToDebugMenu() } + + stubObserver + .assertNoErrors() + .assertNoValues() + } +} diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/debug/DebugMenuViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/debug/DebugMenuViewModelTest.kt new file mode 100644 index 00000000..e501b0f3 --- /dev/null +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/debug/DebugMenuViewModelTest.kt @@ -0,0 +1,49 @@ +package com.shifthackz.aisdv1.presentation.screen.debug + +import com.shifthackz.aisdv1.domain.usecase.debug.DebugInsertBadBase64UseCase +import com.shifthackz.aisdv1.presentation.core.CoreViewModelTest +import com.shifthackz.aisdv1.presentation.navigation.router.main.MainRouter +import com.shifthackz.aisdv1.presentation.stub.stubSchedulersProvider +import io.mockk.every +import io.mockk.mockk +import io.mockk.verify +import io.reactivex.rxjava3.core.Completable +import org.junit.Test + +class DebugMenuViewModelTest : CoreViewModelTest() { + + private val stubDebugInsertBadBase64UseCase = mockk() + private val stubMainRouter = mockk() + + override fun initializeViewModel() = DebugMenuViewModel( + debugInsertBadBase64UseCase = stubDebugInsertBadBase64UseCase, + schedulersProvider = stubSchedulersProvider, + mainRouter = stubMainRouter, + ) + + @Test + fun `given received NavigateBack intent, expected router navigateBack() method called`() { + every { + stubMainRouter.navigateBack() + } returns Unit + + viewModel.processIntent(DebugMenuIntent.NavigateBack) + + verify { + stubMainRouter.navigateBack() + } + } + + @Test + fun `given received InsertBadBase64 intent, expected debugInsertBadBase64UseCase() method called`() { + every { + stubDebugInsertBadBase64UseCase() + } returns Completable.complete() + + viewModel.processIntent(DebugMenuIntent.InsertBadBase64) + + verify { + stubDebugInsertBadBase64UseCase() + } + } +} diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/drawer/DrawerViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/drawer/DrawerViewModelTest.kt new file mode 100644 index 00000000..0b926db7 --- /dev/null +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/drawer/DrawerViewModelTest.kt @@ -0,0 +1,41 @@ +package com.shifthackz.aisdv1.presentation.screen.drawer + +import com.shifthackz.aisdv1.presentation.core.CoreViewModelTest +import com.shifthackz.aisdv1.presentation.navigation.router.drawer.DrawerRouter +import io.mockk.every +import io.mockk.mockk +import io.mockk.verify +import org.junit.Test + +class DrawerViewModelTest : CoreViewModelTest() { + + private val stubDrawerRouter = mockk() + + override fun initializeViewModel() = DrawerViewModel(stubDrawerRouter) + + @Test + fun `given received Close intent, expected router closeDrawer() method called`() { + every { + stubDrawerRouter.closeDrawer() + } returns Unit + + viewModel.processIntent(DrawerIntent.Close) + + verify { + stubDrawerRouter.closeDrawer() + } + } + + @Test + fun `given received Open intent, expected router openDrawer() method called`() { + every { + stubDrawerRouter.openDrawer() + } returns Unit + + viewModel.processIntent(DrawerIntent.Open) + + verify { + stubDrawerRouter.openDrawer() + } + } +} diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/gallery/detail/GalleryDetailViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/gallery/detail/GalleryDetailViewModelTest.kt new file mode 100644 index 00000000..b60f2579 --- /dev/null +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/gallery/detail/GalleryDetailViewModelTest.kt @@ -0,0 +1,273 @@ +@file:OptIn(ExperimentalCoroutinesApi::class) + +package com.shifthackz.aisdv1.presentation.screen.gallery.detail + +import android.graphics.Bitmap +import com.shifthackz.aisdv1.core.imageprocessing.Base64ToBitmapConverter +import com.shifthackz.aisdv1.core.model.asUiText +import com.shifthackz.aisdv1.domain.entity.AiGenerationResult +import com.shifthackz.aisdv1.domain.usecase.caching.GetLastResultFromCacheUseCase +import com.shifthackz.aisdv1.domain.usecase.gallery.DeleteGalleryItemUseCase +import com.shifthackz.aisdv1.domain.usecase.generation.GetGenerationResultUseCase +import com.shifthackz.aisdv1.presentation.core.CoreViewModelTest +import com.shifthackz.aisdv1.presentation.core.GenerationFormUpdateEvent +import com.shifthackz.aisdv1.presentation.extensions.mapToUi +import com.shifthackz.aisdv1.presentation.mocks.mockAiGenerationResult +import com.shifthackz.aisdv1.presentation.model.Modal +import com.shifthackz.aisdv1.presentation.navigation.router.main.MainRouter +import com.shifthackz.aisdv1.presentation.stub.stubSchedulersProvider +import io.mockk.every +import io.mockk.mockk +import io.mockk.unmockkAll +import io.mockk.verify +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.flow.firstOrNull +import kotlinx.coroutines.test.runTest +import org.junit.After +import org.junit.Assert +import org.junit.Before +import org.junit.Test +import java.io.File + +class GalleryDetailViewModelTest : CoreViewModelTest() { + + private val stubBitmap = mockk() + private val stubFile = mockk() + private val stubGetGenerationResultUseCase = mockk() + private val stubGetLastResultFromCacheUseCase = mockk() + private val stubDeleteGalleryItemUseCase = mockk() + private val stubGalleryDetailBitmapExporter = mockk() + private val stubBase64ToBitmapConverter = mockk() + private val stubGenerationFormUpdateEvent = mockk() + private val stubMainRouter = mockk() + + override fun initializeViewModel() = GalleryDetailViewModel( + itemId = 5598L, + getGenerationResultUseCase = stubGetGenerationResultUseCase, + getLastResultFromCacheUseCase = stubGetLastResultFromCacheUseCase, + deleteGalleryItemUseCase = stubDeleteGalleryItemUseCase, + galleryDetailBitmapExporter = stubGalleryDetailBitmapExporter, + base64ToBitmapConverter = stubBase64ToBitmapConverter, + schedulersProvider = stubSchedulersProvider, + generationFormUpdateEvent = stubGenerationFormUpdateEvent, + mainRouter = stubMainRouter, + ) + + @Before + override fun initialize() { + super.initialize() + + every { + stubGetLastResultFromCacheUseCase() + } returns Single.just(mockAiGenerationResult) + + every { + stubGetGenerationResultUseCase(any()) + } returns Single.just(mockAiGenerationResult) + + every { + stubBase64ToBitmapConverter(any()) + } returns Single.just(Base64ToBitmapConverter.Output(stubBitmap)) + } + + @After + override fun finalize() { + super.finalize() + unmockkAll() + } + + @Test + fun `initialized, loaded ai generation result, expected UI state is Content`() { + runTest { + val expected = GalleryDetailState.Content( + tabs = GalleryDetailState.Tab.consume(mockAiGenerationResult.type), + generationType = mockAiGenerationResult.type, + id = mockAiGenerationResult.id, + bitmap = stubBitmap, + inputBitmap = stubBitmap, + createdAt = mockAiGenerationResult.createdAt.toString().asUiText(), + type = mockAiGenerationResult.type.key.asUiText(), + prompt = mockAiGenerationResult.prompt.asUiText(), + negativePrompt = mockAiGenerationResult.negativePrompt.asUiText(), + size = "512 X 512".asUiText(), + samplingSteps = mockAiGenerationResult.samplingSteps.toString().asUiText(), + cfgScale = mockAiGenerationResult.cfgScale.toString().asUiText(), + restoreFaces = mockAiGenerationResult.restoreFaces.mapToUi(), + sampler = mockAiGenerationResult.sampler.asUiText(), + seed = mockAiGenerationResult.seed.asUiText(), + subSeed = mockAiGenerationResult.subSeed.asUiText(), + subSeedStrength = mockAiGenerationResult.subSeedStrength.toString().asUiText(), + denoisingStrength = mockAiGenerationResult.denoisingStrength.toString().asUiText(), + ) + val actual = viewModel.state.value + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received CopyToClipboard intent, expected ShareClipBoard effect delivered to effect collector`() { + viewModel.processIntent(GalleryDetailIntent.CopyToClipboard("text")) + runTest { + val expected = GalleryDetailEffect.ShareClipBoard("text") + val actual = viewModel.effect.firstOrNull() + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Delete Request intent, expected modal field in UI state is DeleteImageConfirm`() { + viewModel.processIntent(GalleryDetailIntent.Delete.Request) + runTest { + val expected = Modal.DeleteImageConfirm + val actual = (viewModel.state.value as? GalleryDetailState.Content)?.screenModal + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Delete Confirm intent, expected screenModal field in UI state is None, deleteGalleryItemUseCase() method is called`() { + every { + stubDeleteGalleryItemUseCase(any()) + } returns Completable.complete() + + every { + stubMainRouter.navigateBack() + } returns Unit + + viewModel.processIntent(GalleryDetailIntent.Delete.Confirm) + + runTest { + val expected = Modal.None + val actual = (viewModel.state.value as? GalleryDetailState.Content)?.screenModal + Assert.assertEquals(expected, actual) + } + verify { + stubDeleteGalleryItemUseCase(5598L) + } + } + + @Test + fun `given received Export Image intent, expected galleryDetailBitmapExporter() method is called, ShareImageFile effect delivered to effect collector`() { + every { + stubGalleryDetailBitmapExporter(any()) + } returns Single.just(stubFile) + viewModel.processIntent(GalleryDetailIntent.Export.Image) + verify { + stubGalleryDetailBitmapExporter(stubBitmap) + } + runTest { + val expected = GalleryDetailEffect.ShareImageFile(stubFile) + val actual = viewModel.effect.firstOrNull() + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Export Params intent, expected ShareGenerationParams effect delivered to effect collector`() { + viewModel.processIntent(GalleryDetailIntent.Export.Params) + runTest { + val expected = GalleryDetailEffect.ShareGenerationParams(viewModel.state.value) + val actual = viewModel.effect.firstOrNull() + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received NavigateBack intent, expected router navigateBack() method called`() { + every { + stubMainRouter.navigateBack() + } returns Unit + viewModel.processIntent(GalleryDetailIntent.NavigateBack) + verify { + stubMainRouter.navigateBack() + } + } + + @Test + fun `given received SelectTab intent with IMAGE tab, expected expected selectedTab field in UI state is IMAGE`() { + viewModel.processIntent(GalleryDetailIntent.SelectTab(GalleryDetailState.Tab.IMAGE)) + runTest { + val expected = GalleryDetailState.Tab.IMAGE + val actual = viewModel.state.value.selectedTab + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received SelectTab intent with INFO tab, expected expected selectedTab field in UI state is INFO`() { + viewModel.processIntent(GalleryDetailIntent.SelectTab(GalleryDetailState.Tab.INFO)) + runTest { + val expected = GalleryDetailState.Tab.INFO + val actual = viewModel.state.value.selectedTab + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received SelectTab intent with ORIGINAL tab, expected expected selectedTab field in UI state is ORIGINAL`() { + viewModel.processIntent(GalleryDetailIntent.SelectTab(GalleryDetailState.Tab.ORIGINAL)) + runTest { + val expected = GalleryDetailState.Tab.ORIGINAL + val actual = viewModel.state.value.selectedTab + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received SendTo Txt2Img intent, expected router navigateBack() and form event update() methods called`() { + every { + stubGenerationFormUpdateEvent.update(any(), any()) + } returns Unit + + every { + stubMainRouter.navigateBack() + } returns Unit + + viewModel.processIntent(GalleryDetailIntent.SendTo.Txt2Img) + + verify { + stubMainRouter.navigateBack() + } + verify { + stubGenerationFormUpdateEvent.update( + mockAiGenerationResult, + AiGenerationResult.Type.TEXT_TO_IMAGE, + ) + } + } + + @Test + fun `given received SendTo Img2Img intent, expected router navigateBack() and form event update() methods called`() { + every { + stubGenerationFormUpdateEvent.update(any(), any()) + } returns Unit + + every { + stubMainRouter.navigateBack() + } returns Unit + + viewModel.processIntent(GalleryDetailIntent.SendTo.Img2Img) + + verify { + stubMainRouter.navigateBack() + } + verify { + stubGenerationFormUpdateEvent.update( + mockAiGenerationResult, + AiGenerationResult.Type.IMAGE_TO_IMAGE, + ) + } + } + + @Test + fun `given received DismissDialog intent, expected screenModal field in UI state is None`() { + viewModel.processIntent(GalleryDetailIntent.DismissDialog) + runTest { + val expected = Modal.None + val actual = viewModel.state.value.screenModal + Assert.assertEquals(expected, actual) + } + } +} diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/gallery/list/GalleryViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/gallery/list/GalleryViewModelTest.kt new file mode 100644 index 00000000..dd811ef2 --- /dev/null +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/gallery/list/GalleryViewModelTest.kt @@ -0,0 +1,137 @@ +@file:OptIn(ExperimentalCoroutinesApi::class) + +package com.shifthackz.aisdv1.presentation.screen.gallery.list + +import android.graphics.Bitmap +import android.net.Uri +import com.shifthackz.aisdv1.core.imageprocessing.Base64ToBitmapConverter +import com.shifthackz.aisdv1.domain.entity.MediaStoreInfo +import com.shifthackz.aisdv1.domain.usecase.gallery.GetMediaStoreInfoUseCase +import com.shifthackz.aisdv1.domain.usecase.generation.GetGenerationResultPagedUseCase +import com.shifthackz.aisdv1.presentation.core.CoreViewModelTest +import com.shifthackz.aisdv1.presentation.model.Modal +import com.shifthackz.aisdv1.presentation.navigation.router.main.MainRouter +import com.shifthackz.aisdv1.presentation.stub.stubSchedulersProvider +import io.mockk.every +import io.mockk.mockk +import io.mockk.verify +import io.reactivex.rxjava3.core.Single +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.flow.firstOrNull +import kotlinx.coroutines.test.UnconfinedTestDispatcher +import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.test.setMain +import org.junit.Assert +import org.junit.Before +import org.junit.Test +import java.io.File + +class GalleryViewModelTest : CoreViewModelTest() { + + private val stubMediaStoreInfo = MediaStoreInfo(5598) + private val stubFile = mockk() + private val stubBitmap = mockk() + private val stubUri = mockk() + private val stubGetMediaStoreInfoUseCase = mockk() + private val stubGetGenerationResultPagedUseCase = mockk() + private val stubBase64ToBitmapConverter = mockk() + private val stubGalleryExporter = mockk() + private val stubMainRouter = mockk() + + override fun initializeViewModel() = GalleryViewModel( + getMediaStoreInfoUseCase = stubGetMediaStoreInfoUseCase, + getGenerationResultPagedUseCase = stubGetGenerationResultPagedUseCase, + base64ToBitmapConverter = stubBase64ToBitmapConverter, + galleryExporter = stubGalleryExporter, + schedulersProvider = stubSchedulersProvider, + mainRouter = stubMainRouter, + ) + + @Before + override fun initialize() { + super.initialize() + + every { + stubGetMediaStoreInfoUseCase() + } returns Single.just(stubMediaStoreInfo) + } + + @Test + fun `initialized, expected mediaStoreInfo field in UI state equals stubMediaStoreInfo`() { + runTest { + val expected = stubMediaStoreInfo + val actual = viewModel.state.value.mediaStoreInfo + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received DismissDialog intent, expected screenModal field in UI state is None`() { + viewModel.processIntent(GalleryIntent.DismissDialog) + runTest { + val expected = Modal.None + val actual = viewModel.state.value.screenModal + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Export Request intent, expected screenModal field in UI state is ConfirmExport`() { + viewModel.processIntent(GalleryIntent.Export.Request) + runTest { + val expected = Modal.ConfirmExport + val actual = viewModel.state.value.screenModal + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Export Confirm intent, expected screenModal field in UI state is None, Share effect delivered to effect collector`() { + every { + stubGalleryExporter() + } returns Single.just(stubFile) + + Dispatchers.setMain(UnconfinedTestDispatcher()) + + viewModel.processIntent(GalleryIntent.Export.Confirm) + + runTest { + val expectedUiState = Modal.None + val actualUiState = viewModel.state.value.screenModal + Assert.assertEquals(expectedUiState, actualUiState) + + val expectedEffect = GalleryEffect.Share(stubFile) + val actualEffect = viewModel.effect.firstOrNull() + Assert.assertEquals(expectedEffect, actualEffect) + } + verify { + stubGalleryExporter() + } + } + + @Test + fun `given received OpenItem intent, expected router navigateToGalleryDetails() method called`() { + every { + stubMainRouter.navigateToGalleryDetails(any()) + } returns Unit + + val item = GalleryGridItemUi(5598L, stubBitmap) + viewModel.processIntent(GalleryIntent.OpenItem(item)) + + verify { + stubMainRouter.navigateToGalleryDetails(5598L) + } + } + + @Test + fun `given received OpenMediaStoreFolder intent, expected OpenUri effect delivered to effect collector`() { + Dispatchers.setMain(UnconfinedTestDispatcher()) + viewModel.processIntent(GalleryIntent.OpenMediaStoreFolder(stubUri)) + runTest { + val expected = GalleryEffect.OpenUri(stubUri) + val actual = viewModel.effect.firstOrNull() + Assert.assertEquals(expected, actual) + } + } +} diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageViewModelTest.kt new file mode 100644 index 00000000..209ea8d4 --- /dev/null +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageViewModelTest.kt @@ -0,0 +1,503 @@ +package com.shifthackz.aisdv1.presentation.screen.img2img + +import android.graphics.Bitmap +import com.shifthackz.aisdv1.core.imageprocessing.Base64ToBitmapConverter +import com.shifthackz.aisdv1.core.imageprocessing.BitmapToBase64Converter +import com.shifthackz.aisdv1.core.validation.ValidationResult +import com.shifthackz.aisdv1.core.validation.dimension.DimensionValidator.Error +import com.shifthackz.aisdv1.domain.entity.OpenAiModel +import com.shifthackz.aisdv1.domain.entity.OpenAiQuality +import com.shifthackz.aisdv1.domain.entity.OpenAiSize +import com.shifthackz.aisdv1.domain.entity.OpenAiStyle +import com.shifthackz.aisdv1.domain.entity.ServerSource +import com.shifthackz.aisdv1.domain.entity.Settings +import com.shifthackz.aisdv1.domain.entity.StableDiffusionSampler +import com.shifthackz.aisdv1.domain.usecase.generation.GetRandomImageUseCase +import com.shifthackz.aisdv1.domain.usecase.generation.ImageToImageUseCase +import com.shifthackz.aisdv1.presentation.core.CoreGenerationMviViewModelTest +import com.shifthackz.aisdv1.presentation.core.GenerationFormUpdateEvent +import com.shifthackz.aisdv1.presentation.core.GenerationMviIntent +import com.shifthackz.aisdv1.presentation.mocks.mockAiGenerationResult +import com.shifthackz.aisdv1.presentation.model.InPaintModel +import com.shifthackz.aisdv1.presentation.model.Modal +import com.shifthackz.aisdv1.presentation.screen.drawer.DrawerIntent +import com.shifthackz.aisdv1.presentation.screen.inpaint.InPaintStateProducer +import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupLaunchSource +import io.mockk.every +import io.mockk.mockk +import io.mockk.unmockkAll +import io.mockk.verify +import io.reactivex.rxjava3.core.BackpressureStrategy +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single +import io.reactivex.rxjava3.subjects.BehaviorSubject +import kotlinx.coroutines.test.runTest +import org.junit.After +import org.junit.Assert +import org.junit.Before +import org.junit.Test + +class ImageToImageViewModelTest : CoreGenerationMviViewModelTest() { + + private val stubBitmap = mockk() + private val stubInPainModel = BehaviorSubject.create() + + private val stubGenerationFormUpdateEvent = mockk() + private val stubImageToImageUseCase = mockk() + private val stubGetRandomImageUseCase = mockk() + private val stubBitmapToBase64Converter = mockk() + private val stubBase64ToBitmapConverter = mockk() + private val stubInPaintStateProducer = mockk() + + override fun initializeViewModel() = ImageToImageViewModel( + generationFormUpdateEvent = stubGenerationFormUpdateEvent, + getStableDiffusionSamplersUseCase = stubGetStableDiffusionSamplersUseCase, + observeHordeProcessStatusUseCase = stubObserveHordeProcessStatusUseCase, + observeLocalDiffusionProcessStatusUseCase = stubObserveLocalDiffusionProcessStatusUseCase, + saveLastResultToCacheUseCase = stubSaveLastResultToCacheUseCase, + saveGenerationResultUseCase = stubSaveGenerationResultUseCase, + interruptGenerationUseCase = stubInterruptGenerationUseCase, + drawerRouter = stubDrawerRouter, + dimensionValidator = stubDimensionValidator, + imageToImageUseCase = stubImageToImageUseCase, + getRandomImageUseCase = stubGetRandomImageUseCase, + bitmapToBase64Converter = stubBitmapToBase64Converter, + base64ToBitmapConverter = stubBase64ToBitmapConverter, + preferenceManager = stubPreferenceManager, + schedulersProvider = stubCustomSchedulers, + notificationManager = stubSdaiPushNotificationManager, + wakeLockInterActor = stubWakeLockInterActor, + inPaintStateProducer = stubInPaintStateProducer, + mainRouter = stubMainRouter, + ) + + @Before + override fun initialize() { + super.initialize() + + every { + stubGenerationFormUpdateEvent.observeImg2ImgForm() + } returns stubAiForm.toFlowable(BackpressureStrategy.LATEST) + + every { + stubInPaintStateProducer.observeInPaint() + } returns stubInPainModel.toFlowable(BackpressureStrategy.LATEST) + + stubSettings.onNext(Settings(source = ServerSource.AUTOMATIC1111)) + } + + @After + override fun finalize() { + super.finalize() + unmockkAll() + } + + @Test + fun `initialized, expected UI state update with correct stub values`() { + runTest { + val state = viewModel.state.value + Assert.assertNotNull(viewModel) + Assert.assertNotNull(viewModel.initialState) + Assert.assertNotNull(viewModel.state.value) + Assert.assertEquals(ServerSource.AUTOMATIC1111, state.mode) + Assert.assertEquals(emptyList(), state.availableSamplers) + } + verify { + stubGetStableDiffusionSamplersUseCase() + } + verify { + stubPreferenceManager.observe() + } + } + + @Test + fun `given received NewPrompts intent, expected prompt, negativePrompt updated in UI state`() { + val intent = GenerationMviIntent.NewPrompts( + positive = "prompt", + negative = "negative", + ) + viewModel.processIntent(intent) + runTest { + val state = viewModel.state.value + Assert.assertEquals("prompt", state.prompt) + Assert.assertEquals("negative", state.negativePrompt) + } + } + + @Test + fun `given received SetAdvancedOptionsVisibility intent, expected advancedOptionsVisible updated in UI state`() { + val intent = GenerationMviIntent.SetAdvancedOptionsVisibility(true) + viewModel.processIntent(intent) + runTest { + val expected = true + val actual = viewModel.state.value.advancedOptionsVisible + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Update Prompt intent, expected prompt updated in UI state`() { + val intent = GenerationMviIntent.Update.Prompt("5598") + viewModel.processIntent(intent) + runTest { + val expected = "5598" + val actual = viewModel.state.value.prompt + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Update NegativePrompt intent, expected negativePrompt updated in UI state`() { + val intent = GenerationMviIntent.Update.NegativePrompt("5598") + viewModel.processIntent(intent) + runTest { + val expected = "5598" + val actual = viewModel.state.value.negativePrompt + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Update Size Width intent with valid value, expected width updated, widthValidationError is null in UI state`() { + every { + stubDimensionValidator(any()) + } returns ValidationResult(true) + + val intent = GenerationMviIntent.Update.Size.Width("512") + viewModel.processIntent(intent) + + runTest { + val state = viewModel.state.value + Assert.assertEquals("512", state.width) + Assert.assertNull(state.widthValidationError) + } + } + + @Test + fun `given received Update Size Width intent with invalid value, expected width updated, widthValidationError is NOT null in UI state`() { + every { + stubDimensionValidator(any()) + } returns ValidationResult(false, Error.Unexpected) + + val intent = GenerationMviIntent.Update.Size.Width("512d") + viewModel.processIntent(intent) + + runTest { + val state = viewModel.state.value + Assert.assertEquals("512d", state.width) + Assert.assertNotNull(state.widthValidationError) + } + } + + @Test + fun `given received Update Size Height intent with valid value, expected height updated, heightValidationError is null in UI state`() { + every { + stubDimensionValidator(any()) + } returns ValidationResult(true) + + val intent = GenerationMviIntent.Update.Size.Height("512") + viewModel.processIntent(intent) + + runTest { + val state = viewModel.state.value + Assert.assertEquals("512", state.height) + Assert.assertNull(state.heightValidationError) + } + } + + @Test + fun `given received Update Size Height intent with invalid value, expected height updated, heightValidationError is NOT null in UI state`() { + every { + stubDimensionValidator(any()) + } returns ValidationResult(false, Error.Unexpected) + + val intent = GenerationMviIntent.Update.Size.Height("512d") + viewModel.processIntent(intent) + + runTest { + val state = viewModel.state.value + Assert.assertEquals("512d", state.height) + Assert.assertNotNull(state.heightValidationError) + } + } + + @Test + fun `given received Update SamplingSteps intent, expected samplingSteps updated in UI state`() { + val intent = GenerationMviIntent.Update.SamplingSteps(12) + viewModel.processIntent(intent) + runTest { + val expected = 12 + val actual = viewModel.state.value.samplingSteps + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Update CfgScale intent, expected cfgScale updated in UI state`() { + val intent = GenerationMviIntent.Update.CfgScale(12f) + viewModel.processIntent(intent) + runTest { + val expected = 12f + val actual = viewModel.state.value.cfgScale + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Update RestoreFaces intent, expected restoreFaces updated in UI state`() { + val intent = GenerationMviIntent.Update.RestoreFaces(true) + viewModel.processIntent(intent) + runTest { + val expected = true + val actual = viewModel.state.value.restoreFaces + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Update Seed intent, expected seed updated in UI state`() { + val intent = GenerationMviIntent.Update.Seed("5598") + viewModel.processIntent(intent) + runTest { + val expected = "5598" + val actual = viewModel.state.value.seed + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Update SubSeed intent, expected subSeed updated in UI state`() { + val intent = GenerationMviIntent.Update.SubSeed("5598") + viewModel.processIntent(intent) + runTest { + val expected = "5598" + val actual = viewModel.state.value.subSeed + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Update SubSeedStrength intent, expected subSeed updated in UI state`() { + val intent = GenerationMviIntent.Update.SubSeedStrength(7f) + viewModel.processIntent(intent) + runTest { + val expected = 7f + val actual = viewModel.state.value.subSeedStrength + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Update Sampler intent, expected selectedSampler updated in UI state`() { + val intent = GenerationMviIntent.Update.Sampler("5598") + viewModel.processIntent(intent) + runTest { + val expected = "5598" + val actual = viewModel.state.value.selectedSampler + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Update Nsfw intent, expected nsfw updated in UI state`() { + val intent = GenerationMviIntent.Update.Nsfw(true) + viewModel.processIntent(intent) + runTest { + val expected = true + val actual = viewModel.state.value.nsfw + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Update Batch intent, expected batchCount updated in UI state`() { + val intent = GenerationMviIntent.Update.Batch(26) + viewModel.processIntent(intent) + runTest { + val expected = 26 + val actual = viewModel.state.value.batchCount + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Update OpenAi Model intent, expected openAiModel updated in UI state`() { + val intent = GenerationMviIntent.Update.OpenAi.Model(OpenAiModel.DALL_E_2) + viewModel.processIntent(intent) + runTest { + val expected = OpenAiModel.DALL_E_2 + val actual = viewModel.state.value.openAiModel + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Update OpenAi Size intent, expected openAiSize updated in UI state`() { + val intent = GenerationMviIntent.Update.OpenAi.Size(OpenAiSize.W256_H256) + viewModel.processIntent(intent) + runTest { + val expected = OpenAiSize.W256_H256 + val actual = viewModel.state.value.openAiSize + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Update OpenAi Quality intent, expected openAiQuality updated in UI state`() { + val intent = GenerationMviIntent.Update.OpenAi.Quality(OpenAiQuality.HD) + viewModel.processIntent(intent) + runTest { + val expected = OpenAiQuality.HD + val actual = viewModel.state.value.openAiQuality + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Update OpenAi Style intent, expected openAiStyle updated in UI state`() { + val intent = GenerationMviIntent.Update.OpenAi.Style(OpenAiStyle.NATURAL) + viewModel.processIntent(intent) + runTest { + val expected = OpenAiStyle.NATURAL + val actual = viewModel.state.value.openAiStyle + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Result Save intent, expected screenModal is None in UI state`() { + every { + stubSaveGenerationResultUseCase(any()) + } returns Completable.complete() + + val intent = GenerationMviIntent.Result.Save(listOf(mockAiGenerationResult)) + viewModel.processIntent(intent) + + runTest { + val state = viewModel.state.value + Assert.assertEquals(Modal.None, state.screenModal) + } + } + + @Test + fun `given received Result View intent, expected saveGenerationResultUseCase() called`() { + every { + stubSaveLastResultToCacheUseCase(any()) + } returns Single.just(mockAiGenerationResult) + + every { + stubMainRouter.navigateToGalleryDetails(any()) + } returns Unit + + val intent = GenerationMviIntent.Result.View(mockAiGenerationResult) + viewModel.processIntent(intent) + + verify { + stubSaveLastResultToCacheUseCase(mockAiGenerationResult) + } + } + + @Test + fun `given received SetModal intent, expected screenModal updated in UI state`() { + val intent = GenerationMviIntent.SetModal(Modal.Communicating()) + viewModel.processIntent(intent) + runTest { + val expected = Modal.Communicating() + val actual = viewModel.state.value.screenModal + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Cancel Generation intent, expected interruptGenerationUseCase() called`() { + every { + stubInterruptGenerationUseCase() + } returns Completable.complete() + + val intent = GenerationMviIntent.Cancel.Generation + viewModel.processIntent(intent) + + verify { + stubInterruptGenerationUseCase() + } + } + + @Test + fun `given received Cancel FetchRandomImage intent, expected screenModal is None in UI state`() { + val intent = GenerationMviIntent.Cancel.FetchRandomImage + viewModel.processIntent(intent) + runTest { + Assert.assertEquals( + Modal.None, + viewModel.state.value.screenModal, + ) + } + } + + @Test + fun `given received Configuration intent, expected router navigateToServerSetup() called`() { + every { + stubMainRouter.navigateToServerSetup(any()) + } returns Unit + + val intent = GenerationMviIntent.Configuration + viewModel.processIntent(intent) + + verify { + stubMainRouter.navigateToServerSetup(ServerSetupLaunchSource.SETTINGS) + } + } + + @Test + fun `given received UpdateFromGeneration intent, expected UI state fields are same as intent model`() { + every { + stubBase64ToBitmapConverter.invoke(any()) + } returns Single.just(Base64ToBitmapConverter.Output(stubBitmap)) + + val intent = GenerationMviIntent.UpdateFromGeneration(mockAiGenerationResult) + viewModel.processIntent(intent) + + runTest { + val state = viewModel.state.value + Assert.assertEquals(true, state.advancedOptionsVisible) + Assert.assertEquals(mockAiGenerationResult.prompt, state.prompt) + Assert.assertEquals(mockAiGenerationResult.negativePrompt, state.negativePrompt) + Assert.assertEquals(mockAiGenerationResult.width.toString(), state.width) + Assert.assertEquals(mockAiGenerationResult.height.toString(), state.height) + Assert.assertEquals(mockAiGenerationResult.seed, state.seed) + Assert.assertEquals(mockAiGenerationResult.subSeed, state.subSeed) + Assert.assertEquals(mockAiGenerationResult.subSeedStrength, state.subSeedStrength) + Assert.assertEquals(mockAiGenerationResult.samplingSteps, state.samplingSteps) + Assert.assertEquals(mockAiGenerationResult.cfgScale, state.cfgScale) + Assert.assertEquals(mockAiGenerationResult.restoreFaces, state.restoreFaces) + } + } + + @Test + fun `given received Drawer Open intent, expected router openDrawer() called`() { + every { + stubDrawerRouter.openDrawer() + } returns Unit + + val intent = GenerationMviIntent.Drawer(DrawerIntent.Open) + viewModel.processIntent(intent) + + verify { + stubDrawerRouter.openDrawer() + } + } + + @Test + fun `given received Drawer Close intent, expected router closeDrawer() called`() { + every { + stubDrawerRouter.closeDrawer() + } returns Unit + + val intent = GenerationMviIntent.Drawer(DrawerIntent.Close) + viewModel.processIntent(intent) + + verify { + stubDrawerRouter.closeDrawer() + } + } +} diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/inpaint/InPaintViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/inpaint/InPaintViewModelTest.kt new file mode 100644 index 00000000..acbf687b --- /dev/null +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/inpaint/InPaintViewModelTest.kt @@ -0,0 +1,208 @@ +package com.shifthackz.aisdv1.presentation.screen.inpaint + +import android.graphics.Bitmap +import androidx.compose.ui.graphics.Path +import com.shifthackz.aisdv1.presentation.core.CoreViewModelTest +import com.shifthackz.aisdv1.presentation.model.InPaintModel +import com.shifthackz.aisdv1.presentation.model.Modal +import com.shifthackz.aisdv1.presentation.navigation.router.main.MainRouter +import com.shifthackz.aisdv1.presentation.stub.stubSchedulersProvider +import io.mockk.every +import io.mockk.mockk +import io.mockk.verify +import io.reactivex.rxjava3.core.BackpressureStrategy +import io.reactivex.rxjava3.subjects.BehaviorSubject +import kotlinx.coroutines.test.runTest +import org.junit.Assert +import org.junit.Before +import org.junit.Test + +class InPaintViewModelTest : CoreViewModelTest() { + + private val stubBitmap = mockk() + private val stubPath = mockk() + private val stubInPainSubject = BehaviorSubject.create() + private val stubBitmapSubject = BehaviorSubject.create() + private val stubInPaintStateProducer = mockk() + private val stubMainRouter = mockk() + + override fun initializeViewModel() = InPaintViewModel( + schedulersProvider = stubSchedulersProvider, + stateProducer = stubInPaintStateProducer, + mainRouter = stubMainRouter, + ) + + @Before + override fun initialize() { + super.initialize() + + every { + stubInPaintStateProducer.observeInPaint() + } returns stubInPainSubject.toFlowable(BackpressureStrategy.LATEST) + + every { + stubInPaintStateProducer.observeBitmap() + } returns stubBitmapSubject.toFlowable(BackpressureStrategy.LATEST) + } + + @Test + fun `initialized, expected UI state updated with correct stub values`() { + stubBitmapSubject.onNext(stubBitmap) + stubInPainSubject.onNext(InPaintModel()) + runTest { + val state = viewModel.state.value + Assert.assertEquals(InPaintModel(), state.model) + Assert.assertEquals(stubBitmap, state.bitmap) + } + } + + @Test + fun `given received DrawPath intent, expected last path in UI state added from intent`() { + viewModel.processIntent(InPaintIntent.DrawPath(stubPath)) + runTest { + val expected = stubPath to 16 + val actual = viewModel.state.value.model.paths.last() + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received DrawPathBmp intent, expected model bitmap updated in UI state from intent`() { + viewModel.processIntent(InPaintIntent.DrawPathBmp(stubBitmap)) + runTest { + val expected = stubBitmap + val actual = viewModel.state.value.model.bitmap + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received NavigateBack intent, expected state producer updateInPaint(), router navigateBack() methods called`() { + every { + stubInPaintStateProducer.updateInPaint(any()) + } returns Unit + + every { + stubMainRouter.navigateBack() + } returns Unit + + viewModel.processIntent(InPaintIntent.NavigateBack) + + verify { + stubInPaintStateProducer.updateInPaint(viewModel.state.value.model) + } + verify { + stubMainRouter.navigateBack() + } + } + + @Test + fun `given received SelectTab intent, expected selectedTab field updated in UI state`() { + viewModel.processIntent(InPaintIntent.SelectTab(InPaintState.Tab.FORM)) + runTest { + val expected = InPaintState.Tab.FORM + val actual = viewModel.state.value.selectedTab + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received ChangeCapSize intent, expected size field updated in UI state`() { + viewModel.processIntent(InPaintIntent.ChangeCapSize(5598)) + runTest { + val expected = 5598 + val actual = viewModel.state.value.size + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Action Undo intent, expected last path in UI state removed`() { + viewModel.processIntent(InPaintIntent.Action.Undo) + runTest { + val expected = emptyList>() + val actual = viewModel.state.value.model.paths + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received ScreenModal Dismiss intent, expected screenModal in UI state is None`() { + viewModel.processIntent(InPaintIntent.ScreenModal.Dismiss) + runTest { + val expected = Modal.None + val actual = viewModel.state.value.screenModal + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received ScreenModal Show intent, expected screenModal in UI state is updated`() { + viewModel.processIntent(InPaintIntent.ScreenModal.Show(Modal.Language)) + runTest { + val expected = Modal.Language + val actual = viewModel.state.value.screenModal + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Action Clear intent, expected screenModal is None, paths is empty in UI state`() { + viewModel.processIntent(InPaintIntent.Action.Clear) + runTest { + val state = viewModel.state.value + Assert.assertEquals(Modal.None, state.screenModal) + Assert.assertEquals(emptyList>(), state.model.paths) + } + } + + @Test + fun `given received Update MaskBlur intent, expected maskBlur is updated in UI state`() { + viewModel.processIntent(InPaintIntent.Update.MaskBlur(5598)) + runTest { + val expected = 5598 + val actual = viewModel.state.value.model.maskBlur + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Update OnlyMaskedPadding intent, expected onlyMaskedPaddingPx is updated in UI state`() { + viewModel.processIntent(InPaintIntent.Update.OnlyMaskedPadding(5598)) + runTest { + val expected = 5598 + val actual = viewModel.state.value.model.onlyMaskedPaddingPx + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Update Area intent, expected inPaintArea is updated in UI state`() { + viewModel.processIntent(InPaintIntent.Update.Area(InPaintModel.Area.WholePicture)) + runTest { + val expected = InPaintModel.Area.WholePicture + val actual = viewModel.state.value.model.inPaintArea + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Update MaskContent intent, expected maskContent is updated in UI state`() { + viewModel.processIntent(InPaintIntent.Update.MaskContent(InPaintModel.MaskContent.Fill)) + runTest { + val expected = InPaintModel.MaskContent.Fill + val actual = viewModel.state.value.model.maskContent + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Update MaskMode intent, expected maskMode is updated in UI state`() { + viewModel.processIntent(InPaintIntent.Update.MaskMode(InPaintModel.MaskMode.InPaintNotMasked)) + runTest { + val expected = InPaintModel.MaskMode.InPaintNotMasked + val actual = viewModel.state.value.model.maskMode + Assert.assertEquals(expected, actual) + } + } +} diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/loader/ConfigurationLoaderViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/loader/ConfigurationLoaderViewModelTest.kt new file mode 100644 index 00000000..ac2180a5 --- /dev/null +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/loader/ConfigurationLoaderViewModelTest.kt @@ -0,0 +1,68 @@ +package com.shifthackz.aisdv1.presentation.screen.loader + +import com.shifthackz.aisdv1.domain.usecase.caching.DataPreLoaderUseCase +import com.shifthackz.aisdv1.presentation.core.CoreViewModelTest +import com.shifthackz.aisdv1.presentation.navigation.router.main.MainRouter +import com.shifthackz.aisdv1.presentation.stub.stubSchedulersProvider +import io.mockk.every +import io.mockk.mockk +import io.mockk.verify +import io.reactivex.rxjava3.core.Completable +import kotlinx.coroutines.test.runTest +import org.junit.Assert +import org.junit.Test + +class ConfigurationLoaderViewModelTest : CoreViewModelTest() { + + private val stubException = Throwable("Something went wrong.") + private val stubDataPreLoaderUseCase = mockk() + private val stubMainRouter = mockk() + + override fun initializeViewModel() = ConfigurationLoaderViewModel( + dataPreLoaderUseCase = stubDataPreLoaderUseCase, + schedulersProvider = stubSchedulersProvider, + mainRouter = stubMainRouter, + ) + + @Test + fun `given initialized, data loaded successfully, expected UI state is StatusNotification, router navigateToHomeScreen() method called`() { + every { + stubMainRouter.navigateToHomeScreen() + } returns Unit + + every { + stubDataPreLoaderUseCase() + } returns Completable.complete() + + runTest { + val expected = true + val actual = viewModel.state.value is ConfigurationLoaderState.StatusNotification + Assert.assertEquals(expected, actual) + } + + verify { + stubMainRouter.navigateToHomeScreen() + } + } + + @Test + fun `given initialized, data not loaded, expected UI state is StatusNotification, router navigateToHomeScreen() method called`() { + every { + stubMainRouter.navigateToHomeScreen() + } returns Unit + + every { + stubDataPreLoaderUseCase() + } returns Completable.error(stubException) + + runTest { + val expected = true + val actual = viewModel.state.value is ConfigurationLoaderState.StatusNotification + Assert.assertEquals(expected, actual) + } + + verify { + stubMainRouter.navigateToHomeScreen() + } + } +} diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsViewModelTest.kt new file mode 100644 index 00000000..1439e58e --- /dev/null +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsViewModelTest.kt @@ -0,0 +1,460 @@ +package com.shifthackz.aisdv1.presentation.screen.settings + +import android.os.Build +import app.cash.turbine.test +import com.shifthackz.aisdv1.core.common.appbuild.BuildInfoProvider +import com.shifthackz.aisdv1.domain.entity.ColorToken +import com.shifthackz.aisdv1.domain.entity.DarkThemeToken +import com.shifthackz.aisdv1.domain.entity.Settings +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.domain.usecase.caching.ClearAppCacheUseCase +import com.shifthackz.aisdv1.domain.usecase.sdmodel.GetStableDiffusionModelsUseCase +import com.shifthackz.aisdv1.domain.usecase.sdmodel.SelectStableDiffusionModelUseCase +import com.shifthackz.aisdv1.domain.usecase.stabilityai.ObserveStabilityAiCreditsUseCase +import com.shifthackz.aisdv1.presentation.core.CoreViewModelTest +import com.shifthackz.aisdv1.presentation.mocks.mockStableDiffusionModels +import com.shifthackz.aisdv1.presentation.model.Modal +import com.shifthackz.aisdv1.presentation.navigation.router.main.MainRouter +import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupLaunchSource +import com.shifthackz.aisdv1.presentation.stub.stubSchedulersProvider +import io.mockk.every +import io.mockk.mockk +import io.mockk.verify +import io.reactivex.rxjava3.core.BackpressureStrategy +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single +import io.reactivex.rxjava3.subjects.BehaviorSubject +import kotlinx.coroutines.flow.firstOrNull +import kotlinx.coroutines.test.runTest +import org.junit.Assert +import org.junit.Before +import org.junit.Test +import java.lang.reflect.Field +import java.lang.reflect.Method +import java.lang.reflect.Modifier + +class SettingsViewModelTest : CoreViewModelTest() { + + private val stubSettings = BehaviorSubject.createDefault(Settings()) + private val stubStabilityCredits = BehaviorSubject.createDefault(5598f) + private val stubGetStableDiffusionModelsUseCase = mockk() + private val stubObserveStabilityAiCreditsUseCase = mockk() + private val stubSelectStableDiffusionModelUseCase = mockk() + private val stubClearAppCacheUseCase = mockk() + private val stubPreferenceManager = mockk() + private val stubBuildInfoProvider = mockk() + private val stubMainRouter = mockk() + + override fun initializeViewModel() = SettingsViewModel( + getStableDiffusionModelsUseCase = stubGetStableDiffusionModelsUseCase, + observeStabilityAiCreditsUseCase = stubObserveStabilityAiCreditsUseCase, + selectStableDiffusionModelUseCase = stubSelectStableDiffusionModelUseCase, + clearAppCacheUseCase = stubClearAppCacheUseCase, + schedulersProvider = stubSchedulersProvider, + preferenceManager = stubPreferenceManager, + buildInfoProvider = stubBuildInfoProvider, + mainRouter = stubMainRouter, + ) + + @Before + override fun initialize() { + super.initialize() + + every { + stubBuildInfoProvider.toString() + } returns "5.5.98" + + every { + stubGetStableDiffusionModelsUseCase() + } returns Single.just(mockStableDiffusionModels) + + every { + stubPreferenceManager.observe() + } returns stubSettings.toFlowable(BackpressureStrategy.LATEST) + + every { + stubObserveStabilityAiCreditsUseCase() + } returns stubStabilityCredits.toFlowable(BackpressureStrategy.LATEST) + } + + @Test + fun `initialized, expected UI state updated with correct stub values`() { + runTest { + val expected = "5.5.98" + val actual = viewModel.state.value.appVersion + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Action AppVersion intent, expected router navigateToDebugMenu() method called`() { + every { + stubMainRouter.navigateToDebugMenu() + } returns Unit + + viewModel.processIntent(SettingsIntent.Action.AppVersion) + + verify { + stubMainRouter.navigateToDebugMenu() + } + } + + @Test + fun `given received Action ClearAppCache Request intent, expected screenModal field in UI state is ClearAppCache`() { + viewModel.processIntent(SettingsIntent.Action.ClearAppCache.Request) + runTest { + val expected = Modal.ClearAppCache + val actual = viewModel.state.value.screenModal + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Action ClearAppCache Confirm intent, expected screenModal field in UI state is None`() { + every { + stubClearAppCacheUseCase() + } returns Completable.complete() + + viewModel.processIntent(SettingsIntent.Action.ClearAppCache.Confirm) + + runTest { + val expected = Modal.None + val actual = viewModel.state.value.screenModal + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Action ReportProblem intent, expected ShareLogFile effect delivered to effect collector`() { + viewModel.processIntent(SettingsIntent.Action.ReportProblem) + runTest { + viewModel.effect.test { + Assert.assertEquals(SettingsEffect.ShareLogFile, awaitItem()) + } + } + } + + @Test + fun `given received DismissDialog intent, expected screenModal field in UI state is None`() { + viewModel.processIntent(SettingsIntent.DismissDialog) + runTest { + val expected = Modal.None + val actual = viewModel.state.value.screenModal + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received NavigateConfiguration intent, expected router navigateToServerSetup() method called`() { + every { + stubMainRouter.navigateToServerSetup(any()) + } returns Unit + + viewModel.processIntent(SettingsIntent.NavigateConfiguration) + + verify { + stubMainRouter.navigateToServerSetup(ServerSetupLaunchSource.SETTINGS) + } + } + + @Test + fun `given received SdModel OpenChooser intent, expected screenModal field in UI state is SelectSdModel`() { + viewModel.processIntent(SettingsIntent.SdModel.OpenChooser) + runTest { + val expected = Modal.SelectSdModel( + models = listOf("title_5598", "title_151297"), + selected = "title_5598", + ) + val actual = viewModel.state.value.screenModal + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received SdModel OpenChooser intent, expected screenModal field in UI state is None`() { + every { + stubSelectStableDiffusionModelUseCase(any()) + } returns Completable.complete() + + viewModel.processIntent(SettingsIntent.SdModel.Select("title_151297")) + + runTest { + val expected = Modal.None + val actual = viewModel.state.value.screenModal + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received UpdateFlag AdvancedFormVisibility intent, expected formAdvancedOptionsAlwaysShow preference updated`() { + every { + stubPreferenceManager::formAdvancedOptionsAlwaysShow.set(any()) + } returns Unit + + viewModel.processIntent(SettingsIntent.UpdateFlag.AdvancedFormVisibility(true)) + + verify { + stubPreferenceManager::formAdvancedOptionsAlwaysShow.set(true) + } + } + + @Test + fun `given received UpdateFlag AutoSaveResult intent, expected autoSaveAiResults preference updated`() { + every { + stubPreferenceManager::autoSaveAiResults.set(any()) + } returns Unit + + viewModel.processIntent(SettingsIntent.UpdateFlag.AutoSaveResult(true)) + + verify { + stubPreferenceManager::autoSaveAiResults.set(true) + } + } + + @Test + fun `given received UpdateFlag MonitorConnection intent, expected monitorConnectivity preference updated`() { + every { + stubPreferenceManager::monitorConnectivity.set(any()) + } returns Unit + + viewModel.processIntent(SettingsIntent.UpdateFlag.MonitorConnection(true)) + + verify { + stubPreferenceManager::monitorConnectivity.set(true) + } + } + + @Test + fun `given received UpdateFlag NNAPI intent, expected localUseNNAPI preference updated`() { + every { + stubPreferenceManager::localUseNNAPI.set(any()) + } returns Unit + + viewModel.processIntent(SettingsIntent.UpdateFlag.NNAPI(true)) + + verify { + stubPreferenceManager::localUseNNAPI.set(true) + } + } + + @Test + fun `given received UpdateFlag TaggedInput intent, expected formPromptTaggedInput preference updated`() { + every { + stubPreferenceManager::formPromptTaggedInput.set(any()) + } returns Unit + + viewModel.processIntent(SettingsIntent.UpdateFlag.TaggedInput(true)) + + verify { + stubPreferenceManager::formPromptTaggedInput.set(true) + } + } + + @Test + fun `given received UpdateFlag SaveToMediaStore intent with true value, app running on Android SDK 34, expected newImpl() called, saveToMediaStore preference updated, saveToMediaStore field in UI state is false`() { + every { + stubPreferenceManager::saveToMediaStore.set(any()) + } returns Unit + + mockSdkInt(Build.VERSION_CODES.UPSIDE_DOWN_CAKE) + + viewModel.processIntent(SettingsIntent.UpdateFlag.SaveToMediaStore(true)) + + runTest { + val expected = false + val actual = viewModel.state.value.saveToMediaStore + Assert.assertEquals(expected, actual) + } + verify { + stubPreferenceManager::saveToMediaStore.set(true) + } + } + + @Test + fun `given received UpdateFlag SaveToMediaStore intent with false value, app running on Android SDK 34, expected newImpl() called, saveToMediaStore preference updated, saveToMediaStore field in UI state is false`() { + every { + stubPreferenceManager::saveToMediaStore.set(any()) + } returns Unit + + mockSdkInt(Build.VERSION_CODES.UPSIDE_DOWN_CAKE) + + viewModel.processIntent(SettingsIntent.UpdateFlag.SaveToMediaStore(false)) + + runTest { + val expected = false + val actual = viewModel.state.value.saveToMediaStore + Assert.assertEquals(expected, actual) + } + verify { + stubPreferenceManager::saveToMediaStore.set(false) + } + } + + @Test + fun `given received UpdateFlag SaveToMediaStore intent with true value, app running on Android SDK 26, expected oldImpl() called, RequestStoragePermission effect delivered to effect collector`() { + mockSdkInt(Build.VERSION_CODES.O) + + viewModel.processIntent(SettingsIntent.UpdateFlag.SaveToMediaStore(true)) + + runTest { + val expected = SettingsEffect.RequestStoragePermission + val actual = viewModel.effect.firstOrNull() + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received UpdateFlag SaveToMediaStore intent with false value, app running on Android SDK 26, expected oldImpl() called, saveToMediaStore preference updated, saveToMediaStore field in UI state is false`() { + every { + stubPreferenceManager::saveToMediaStore.set(any()) + } returns Unit + + mockSdkInt(Build.VERSION_CODES.O) + + viewModel.processIntent(SettingsIntent.UpdateFlag.SaveToMediaStore(false)) + + runTest { + val expected = false + val actual = viewModel.state.value.saveToMediaStore + Assert.assertEquals(expected, actual) + } + verify { + stubPreferenceManager::saveToMediaStore.set(false) + } + } + + @Test + fun `given received LaunchUrl intent, expected OpenUrl effect delivered to effect collector`() { + val intent = mockk() + every { + intent::url.get() + } returns "https://5598.is.my.favorite.com" + + viewModel.processIntent(intent) + + runTest { + val expected = SettingsEffect.OpenUrl("https://5598.is.my.favorite.com") + val actual = viewModel.effect.firstOrNull() + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received StoragePermissionGranted intent, expected saveToMediaStore preference set to true`() { + every { + stubPreferenceManager::saveToMediaStore.set(any()) + } returns Unit + + viewModel.processIntent(SettingsIntent.StoragePermissionGranted) + + verify { + stubPreferenceManager::saveToMediaStore.set(true) + } + } + + @Test + fun `given received UpdateFlag DynamicColors intent, expected designUseSystemColorPalette preference updated`() { + every { + stubPreferenceManager::designUseSystemColorPalette.set(any()) + } returns Unit + + viewModel.processIntent(SettingsIntent.UpdateFlag.DynamicColors(true)) + + verify { + stubPreferenceManager::designUseSystemColorPalette.set(true) + } + } + + @Test + fun `given received UpdateFlag SystemDarkTheme intent, expected designUseSystemDarkTheme preference updated`() { + every { + stubPreferenceManager::designUseSystemDarkTheme.set(any()) + } returns Unit + + viewModel.processIntent(SettingsIntent.UpdateFlag.SystemDarkTheme(true)) + + verify { + stubPreferenceManager::designUseSystemDarkTheme.set(true) + } + } + + @Test + fun `given received UpdateFlag DarkTheme intent, expected designDarkTheme preference updated`() { + every { + stubPreferenceManager::designDarkTheme.set(any()) + } returns Unit + + viewModel.processIntent(SettingsIntent.UpdateFlag.DarkTheme(true)) + + verify { + stubPreferenceManager::designDarkTheme.set(true) + } + } + + @Test + fun `given received NewColorToken intent, expected designColorToken preference updated`() { + every { + stubPreferenceManager::designColorToken.set(any()) + } returns Unit + + viewModel.processIntent(SettingsIntent.NewColorToken(ColorToken.MAUVE)) + + verify { + stubPreferenceManager::designColorToken.set("${ColorToken.MAUVE}") + } + } + + @Test + fun `given received NewDarkThemeToken intent, expected designDarkThemeToken preference updated`() { + every { + stubPreferenceManager::designDarkThemeToken.set(any()) + } returns Unit + + viewModel.processIntent(SettingsIntent.NewDarkThemeToken(DarkThemeToken.MACCHIATO)) + + verify { + stubPreferenceManager::designDarkThemeToken.set("${DarkThemeToken.MACCHIATO}") + } + } + + @Test + fun `given received Action PickLanguage intent, expected screenModal field in UI state is Language`() { + viewModel.processIntent(SettingsIntent.Action.PickLanguage) + runTest { + val expected = Modal.Language + val actual = viewModel.state.value.screenModal + Assert.assertEquals(expected, actual) + } + } + + private fun mockSdkInt(sdkInt: Int) { + val sdkIntField = Build.VERSION::class.java.getField("SDK_INT") + sdkIntField.isAccessible = true + getModifiersField().also { + it.isAccessible = true + it.set(sdkIntField, sdkIntField.modifiers and Modifier.FINAL.inv()) + } + sdkIntField.set(null, sdkInt) + } + + private fun getModifiersField(): Field { + return try { + Field::class.java.getDeclaredField("modifiers") + } catch (e: NoSuchFieldException) { + try { + val getDeclaredFields0: Method = + Class::class.java.getDeclaredMethod("getDeclaredFields0", Boolean::class.javaPrimitiveType) + getDeclaredFields0.isAccessible = true + val fields = getDeclaredFields0.invoke(Field::class.java, false) as Array + for (field in fields) { + if ("modifiers" == field.name) { + return field + } + } + } catch (ex: ReflectiveOperationException) { + e.addSuppressed(ex) + } + throw e + } + } +} diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModelTest.kt new file mode 100644 index 00000000..e6bb063c --- /dev/null +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModelTest.kt @@ -0,0 +1,398 @@ +package com.shifthackz.aisdv1.presentation.screen.setup + +import com.shifthackz.aisdv1.core.validation.common.CommonStringValidator +import com.shifthackz.aisdv1.core.validation.url.UrlValidator +import com.shifthackz.aisdv1.domain.entity.Configuration +import com.shifthackz.aisdv1.domain.entity.DownloadState +import com.shifthackz.aisdv1.domain.entity.ServerSource +import com.shifthackz.aisdv1.domain.interactor.settings.SetupConnectionInterActor +import com.shifthackz.aisdv1.domain.interactor.wakelock.WakeLockInterActor +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.domain.usecase.downloadable.DeleteModelUseCase +import com.shifthackz.aisdv1.domain.usecase.downloadable.DownloadModelUseCase +import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalAiModelsUseCase +import com.shifthackz.aisdv1.domain.usecase.huggingface.FetchAndGetHuggingFaceModelsUseCase +import com.shifthackz.aisdv1.domain.usecase.settings.GetConfigurationUseCase +import com.shifthackz.aisdv1.presentation.core.CoreViewModelTest +import com.shifthackz.aisdv1.presentation.mocks.mockHuggingFaceModels +import com.shifthackz.aisdv1.presentation.mocks.mockLocalAiModels +import com.shifthackz.aisdv1.presentation.mocks.mockServerSetupStateLocalModel +import com.shifthackz.aisdv1.presentation.model.Modal +import com.shifthackz.aisdv1.presentation.navigation.router.main.MainRouter +import com.shifthackz.aisdv1.presentation.stub.stubSchedulersProvider +import io.mockk.every +import io.mockk.mockk +import io.mockk.verify +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Observable +import io.reactivex.rxjava3.core.Single +import kotlinx.coroutines.flow.firstOrNull +import kotlinx.coroutines.test.runTest +import org.junit.Assert +import org.junit.Before +import org.junit.Test + +class ServerSetupViewModelTest : CoreViewModelTest() { + + private val stubGetConfigurationUseCase = mockk() + private val stubGetLocalAiModelsUseCase = mockk() + private val stubFetchAndGetHuggingFaceModelsUseCase = + mockk() + private val stubUrlValidator = mockk() + private val stubCommonStringValidator = mockk() + private val stubSetupConnectionInterActor = mockk() + private val stubDownloadModelUseCase = mockk() + private val stubDeleteModelUseCase = mockk() + private val stubPreferenceManager = mockk() + private val stubWakeLockInterActor = mockk() + private val stubMainRouter = mockk() + + override fun initializeViewModel() = ServerSetupViewModel( + launchSource = ServerSetupLaunchSource.SETTINGS, + getConfigurationUseCase = stubGetConfigurationUseCase, + getLocalAiModelsUseCase = stubGetLocalAiModelsUseCase, + fetchAndGetHuggingFaceModelsUseCase = stubFetchAndGetHuggingFaceModelsUseCase, + urlValidator = stubUrlValidator, + stringValidator = stubCommonStringValidator, + setupConnectionInterActor = stubSetupConnectionInterActor, + downloadModelUseCase = stubDownloadModelUseCase, + deleteModelUseCase = stubDeleteModelUseCase, + schedulersProvider = stubSchedulersProvider, + preferenceManager = stubPreferenceManager, + wakeLockInterActor = stubWakeLockInterActor, + mainRouter = stubMainRouter, + ) + + @Before + override fun initialize() { + super.initialize() + + every { + stubGetConfigurationUseCase() + } returns Single.just(Configuration(serverUrl = "https://5598.is.my.favorite.com")) + + every { + stubGetLocalAiModelsUseCase() + } returns Single.just(mockLocalAiModels) + + every { + stubFetchAndGetHuggingFaceModelsUseCase() + } returns Single.just(mockHuggingFaceModels) + } + + @Test + fun `initialized, expected UI state updated with correct stub values`() { + runTest { + val state = viewModel.state.value + Assert.assertEquals(true, state.huggingFaceModels.isNotEmpty()) + Assert.assertEquals(true, state.localModels.isNotEmpty()) + Assert.assertEquals("https://5598.is.my.favorite.com", state.serverUrl) + Assert.assertEquals(ServerSetupState.AuthType.ANONYMOUS, state.authType) + } + } + + @Test + fun `given received AllowLocalCustomModel intent, expected Custom local model selected in UI state`() { + viewModel.processIntent(ServerSetupIntent.AllowLocalCustomModel(true)) + runTest { + val state = viewModel.state.value + val expectedLocalModels = listOf( + ServerSetupState.LocalModel( + id = "CUSTOM", + name = "Custom", + size = "NaN", + downloaded = false, + selected = true, + ), + ServerSetupState.LocalModel( + id = "1", + name = "Model 1", + size = "5 Gb", + downloaded = false, + selected = false, + ) + ) + Assert.assertEquals(true, state.localCustomModel) + Assert.assertEquals(expectedLocalModels, state.localModels) + } + } + + @Test + fun `given received DismissDialog intent, expected screenModal field in UI state is None`() { + viewModel.processIntent(ServerSetupIntent.DismissDialog) + runTest { + val expected = Modal.None + val actual = viewModel.state.value.screenModal + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received LocalModel ClickReduce intent, model not downloaded, expected UI state is Downloading, wakeLocks called`() { + every { + stubDownloadModelUseCase(any()) + } returns Observable.just(DownloadState.Downloading(22)) + + every { + stubWakeLockInterActor.acquireWakelockUseCase() + } returns Result.success(Unit) + + every { + stubWakeLockInterActor.releaseWakeLockUseCase() + } returns Result.success(Unit) + + val localModel = mockServerSetupStateLocalModel.copy( + downloadState = DownloadState.Unknown, + ) + val intent = ServerSetupIntent.LocalModel.ClickReduce(localModel) + viewModel.processIntent(intent) + + runTest { + val state = viewModel.state.value + val expected = true + val actual = state.localModels.any { + it.downloadState == DownloadState.Downloading(22) + } + Assert.assertEquals(expected, actual) + } + verify { + stubWakeLockInterActor.acquireWakelockUseCase() + } + verify { + stubWakeLockInterActor.releaseWakeLockUseCase() + } + verify { + stubDownloadModelUseCase("1") + } + } + + @Test + fun `given received LocalModel ClickReduce intent, model downloaded, expected screenModal field in UI state is DeleteLocalModelConfirm`() { + val localModel = mockServerSetupStateLocalModel.copy( + downloaded = true, + downloadState = DownloadState.Unknown, + ) + val intent = ServerSetupIntent.LocalModel.ClickReduce(localModel) + viewModel.processIntent(intent) + + runTest { + val expected = Modal.DeleteLocalModelConfirm(localModel) + val actual = viewModel.state.value.screenModal + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received LocalModel ClickReduce intent, model is downloading, expected UI state is Unknown`() { + val localModel = mockServerSetupStateLocalModel.copy( + downloadState = DownloadState.Downloading(22), + ) + val intent = ServerSetupIntent.LocalModel.ClickReduce(localModel) + viewModel.processIntent(intent) + + runTest { + val state = viewModel.state.value + val expected = false + val actual = state.localModels.any { + it.downloadState == DownloadState.Downloading(22) + } + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received LocalModel DeleteConfirm intent, expected downloaded field is false for LocalModel UI state`() { + every { + stubDeleteModelUseCase(any()) + } returns Completable.complete() + + val localModel = mockServerSetupStateLocalModel.copy( + downloaded = true, + downloadState = DownloadState.Unknown, + ) + val intent = ServerSetupIntent.LocalModel.DeleteConfirm(localModel) + viewModel.processIntent(intent) + + runTest { + val state = viewModel.state.value + Assert.assertEquals(Modal.None, state.screenModal) + Assert.assertEquals(false, state.localModels.find { it.id == "1" }!!.downloaded) + } + verify { + stubDeleteModelUseCase("1") + } + } + + @Test + fun `given received SelectLocalModel intent, expected passed LocalModel is selected in UI state`() { + viewModel.processIntent(ServerSetupIntent.SelectLocalModel(mockServerSetupStateLocalModel)) + runTest { + val state = viewModel.state.value + Assert.assertEquals(true, state.localModels.find { it.id == "1" }!!.selected) + } + } + + @Test + fun `given received MainButtonClick intent, expected step field in UI state is CONFIGURE`() { + viewModel.processIntent(ServerSetupIntent.MainButtonClick) + runTest { + val expected = ServerSetupState.Step.CONFIGURE + val actual = viewModel.state.value.step + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received UpdateDemoMode intent, expected demoMode field in UI state is true`() { + viewModel.processIntent(ServerSetupIntent.UpdateDemoMode(true)) + runTest { + val expected = true + val actual = viewModel.state.value.demoMode + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received UpdateHordeApiKey intent, expected hordeApiKey field in UI state is 5598`() { + viewModel.processIntent(ServerSetupIntent.UpdateHordeApiKey("5598")) + runTest { + val expected = "5598" + val actual = viewModel.state.value.hordeApiKey + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received UpdateHordeDefaultApiKey intent, expected hordeDefaultApiKey field in UI state is true`() { + viewModel.processIntent(ServerSetupIntent.UpdateHordeDefaultApiKey(true)) + runTest { + val expected = true + val actual = viewModel.state.value.hordeDefaultApiKey + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received UpdateHuggingFaceApiKey intent, expected huggingFaceApiKey field in UI state is 5598`() { + viewModel.processIntent(ServerSetupIntent.UpdateHuggingFaceApiKey("5598")) + runTest { + val expected = "5598" + val actual = viewModel.state.value.huggingFaceApiKey + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received UpdateHuggingFaceModel intent, expected huggingFaceModel field in UI state is 5598`() { + viewModel.processIntent(ServerSetupIntent.UpdateHuggingFaceModel("5598")) + runTest { + val expected = "5598" + val actual = viewModel.state.value.huggingFaceModel + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received UpdateLogin intent, expected login field is 5598, loginValidationError is null in UI state`() { + viewModel.processIntent(ServerSetupIntent.UpdateLogin("5598")) + runTest { + val state = viewModel.state.value + Assert.assertEquals("5598", state.login) + Assert.assertEquals(null, state.loginValidationError) + } + } + + @Test + fun `given received UpdatePassword intent, expected password field is 5598, passwordValidationError is null in UI state`() { + viewModel.processIntent(ServerSetupIntent.UpdatePassword("5598")) + runTest { + val state = viewModel.state.value + Assert.assertEquals("5598", state.password) + Assert.assertEquals(null, state.passwordValidationError) + } + } + + @Test + fun `given received UpdateOpenAiApiKey intent, expected openAiApiKey field in UI state is 5598`() { + viewModel.processIntent(ServerSetupIntent.UpdateOpenAiApiKey("5598")) + runTest { + val expected = "5598" + val actual = viewModel.state.value.openAiApiKey + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received UpdatePasswordVisibility intent, expected passwordVisible field in UI state is false`() { + viewModel.processIntent(ServerSetupIntent.UpdatePasswordVisibility(true)) + runTest { + val expected = false + val actual = viewModel.state.value.passwordVisible + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received UpdateServerMode intent, expected mode field in UI state is LOCAL`() { + viewModel.processIntent(ServerSetupIntent.UpdateServerMode(ServerSource.LOCAL)) + runTest { + val expected = ServerSource.LOCAL + val actual = viewModel.state.value.mode + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received LaunchManageStoragePermission intent, expected LaunchManageStoragePermission effect delivered to effect collector`() { + viewModel.processIntent(ServerSetupIntent.LaunchManageStoragePermission) + runTest { + val expected = ServerSetupEffect.LaunchManageStoragePermission + val actual = viewModel.effect.firstOrNull() + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received NavigateBack intent, expected router navigateBack() method called`() { + every { + stubMainRouter.navigateBack() + } returns Unit + + viewModel.processIntent(ServerSetupIntent.NavigateBack) + + verify { + stubMainRouter.navigateBack() + } + } + + @Test + fun `given received UpdateStabilityAiApiKey intent, expected stabilityAiApiKey field in UI state is 5598`() { + viewModel.processIntent(ServerSetupIntent.UpdateStabilityAiApiKey("5598")) + runTest { + val expected = "5598" + val actual = viewModel.state.value.stabilityAiApiKey + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received ConnectToLocalHost intent, expected success, router navigateToHomeScreen() method called, preference forceSetupAfterUpdate is false, dialog is None`() { + every { + stubSetupConnectionInterActor.connectToA1111(any(), any(), any()) + } returns Single.just(Result.success(Unit)) + + every { + stubMainRouter.navigateToHomeScreen() + } returns Unit + + every { + stubPreferenceManager::forceSetupAfterUpdate.set(any()) + } returns Unit + + viewModel.processIntent(ServerSetupIntent.ConnectToLocalHost) + + verify { + stubMainRouter.navigateToHomeScreen() + } + } +} diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/splash/SplashViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/splash/SplashViewModelTest.kt new file mode 100644 index 00000000..a115c47b --- /dev/null +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/splash/SplashViewModelTest.kt @@ -0,0 +1,79 @@ +package com.shifthackz.aisdv1.presentation.screen.splash + +import com.shifthackz.aisdv1.domain.usecase.splash.SplashNavigationUseCase +import com.shifthackz.aisdv1.presentation.core.CoreViewModelInitializeStrategy +import com.shifthackz.aisdv1.presentation.core.CoreViewModelTest +import com.shifthackz.aisdv1.presentation.navigation.router.main.MainRouter +import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupLaunchSource +import com.shifthackz.aisdv1.presentation.stub.stubSchedulersProvider +import io.mockk.every +import io.mockk.mockk +import io.mockk.verify +import io.reactivex.rxjava3.core.Single +import org.junit.Test + +class SplashViewModelTest : CoreViewModelTest() { + + private val stubMainRouter = mockk() + private val stubSplashNavigationUseCase = mockk() + + override val testViewModelStrategy = CoreViewModelInitializeStrategy.InitializeEveryTime + + override fun initializeViewModel() = SplashViewModel( + mainRouter = stubMainRouter, + splashNavigationUseCase = stubSplashNavigationUseCase, + schedulersProvider = stubSchedulersProvider, + ) + + @Test + fun `given initialized, use case emits LAUNCH_ONBOARDING action, expected nothing happens`() { + every { + stubSplashNavigationUseCase() + } returns Single.just(SplashNavigationUseCase.Action.LAUNCH_ONBOARDING) + + viewModel.hashCode() + + verify(inverse = true) { + stubMainRouter.navigateToServerSetup(ServerSetupLaunchSource.SPLASH) + } + verify(inverse = true) { + stubMainRouter.navigateToPostSplashConfigLoader() + } + } + + @Test + fun `given initialized, use case emits LAUNCH_SERVER_SETUP action, expected router navigateToServerSetup() method called`() { + every { + stubMainRouter.navigateToServerSetup(any()) + } returns Unit + + every { + stubSplashNavigationUseCase() + } returns Single.just(SplashNavigationUseCase.Action.LAUNCH_SERVER_SETUP) + + viewModel.hashCode() + + verify { + stubMainRouter.navigateToServerSetup( + ServerSetupLaunchSource.SPLASH + ) + } + } + + @Test + fun `given initialized, use case emits LAUNCH_HOME action, expected router navigateToServerSetup() method called`() { + every { + stubMainRouter.navigateToPostSplashConfigLoader() + } returns Unit + + every { + stubSplashNavigationUseCase() + } returns Single.just(SplashNavigationUseCase.Action.LAUNCH_HOME) + + viewModel.hashCode() + + verify { + stubMainRouter.navigateToPostSplashConfigLoader() + } + } +} diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageViewModelTest.kt new file mode 100644 index 00000000..49667fae --- /dev/null +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageViewModelTest.kt @@ -0,0 +1,499 @@ +package com.shifthackz.aisdv1.presentation.screen.txt2img + +import com.shifthackz.aisdv1.core.validation.ValidationResult +import com.shifthackz.aisdv1.core.validation.dimension.DimensionValidator.Error +import com.shifthackz.aisdv1.domain.entity.OpenAiModel +import com.shifthackz.aisdv1.domain.entity.OpenAiQuality +import com.shifthackz.aisdv1.domain.entity.OpenAiSize +import com.shifthackz.aisdv1.domain.entity.OpenAiStyle +import com.shifthackz.aisdv1.domain.entity.ServerSource +import com.shifthackz.aisdv1.domain.entity.Settings +import com.shifthackz.aisdv1.domain.entity.StableDiffusionSampler +import com.shifthackz.aisdv1.domain.usecase.generation.TextToImageUseCase +import com.shifthackz.aisdv1.presentation.core.CoreGenerationMviViewModelTest +import com.shifthackz.aisdv1.presentation.core.GenerationFormUpdateEvent +import com.shifthackz.aisdv1.presentation.core.GenerationMviIntent +import com.shifthackz.aisdv1.presentation.mocks.mockAiGenerationResult +import com.shifthackz.aisdv1.presentation.model.Modal +import com.shifthackz.aisdv1.presentation.screen.drawer.DrawerIntent +import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupLaunchSource +import io.mockk.every +import io.mockk.mockk +import io.mockk.unmockkAll +import io.mockk.verify +import io.reactivex.rxjava3.core.BackpressureStrategy +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single +import kotlinx.coroutines.test.runTest +import org.junit.After +import org.junit.Assert +import org.junit.Before +import org.junit.Test + +class TextToImageViewModelTest : CoreGenerationMviViewModelTest() { + + private val stubGenerationFormUpdateEvent = mockk() + private val stubTextToImageUseCase = mockk() + + override fun initializeViewModel() = TextToImageViewModel( + generationFormUpdateEvent = stubGenerationFormUpdateEvent, + getStableDiffusionSamplersUseCase = stubGetStableDiffusionSamplersUseCase, + observeHordeProcessStatusUseCase = stubObserveHordeProcessStatusUseCase, + observeLocalDiffusionProcessStatusUseCase = stubObserveLocalDiffusionProcessStatusUseCase, + saveLastResultToCacheUseCase = stubSaveLastResultToCacheUseCase, + saveGenerationResultUseCase = stubSaveGenerationResultUseCase, + interruptGenerationUseCase = stubInterruptGenerationUseCase, + mainRouter = stubMainRouter, + drawerRouter = stubDrawerRouter, + dimensionValidator = stubDimensionValidator, + textToImageUseCase = stubTextToImageUseCase, + schedulersProvider = stubCustomSchedulers, + preferenceManager = stubPreferenceManager, + notificationManager = stubSdaiPushNotificationManager, + wakeLockInterActor = stubWakeLockInterActor, + ) + + @Before + override fun initialize() { + super.initialize() + + every { + stubGenerationFormUpdateEvent.observeTxt2ImgForm() + } returns stubAiForm.toFlowable(BackpressureStrategy.LATEST) + + stubSettings.onNext(Settings(source = ServerSource.AUTOMATIC1111)) + } + + @After + override fun finalize() { + super.finalize() + unmockkAll() + } + + @Test + fun `initialized, expected UI state update with correct stub values`() { + runTest { + val state = viewModel.state.value + Assert.assertNotNull(viewModel) + Assert.assertNotNull(viewModel.initialState) + Assert.assertNotNull(viewModel.state.value) + Assert.assertEquals(ServerSource.AUTOMATIC1111, state.mode) + Assert.assertEquals(emptyList(), state.availableSamplers) + } + verify { + stubGetStableDiffusionSamplersUseCase() + } + verify { + stubPreferenceManager.observe() + } + } + + @Test + fun `given received NewPrompts intent, expected prompt, negativePrompt updated in UI state`() { + val intent = GenerationMviIntent.NewPrompts( + positive = "prompt", + negative = "negative", + ) + viewModel.processIntent(intent) + runTest { + val state = viewModel.state.value + Assert.assertEquals("prompt", state.prompt) + Assert.assertEquals("negative", state.negativePrompt) + } + } + + @Test + fun `given received SetAdvancedOptionsVisibility intent, expected advancedOptionsVisible updated in UI state`() { + val intent = GenerationMviIntent.SetAdvancedOptionsVisibility(true) + viewModel.processIntent(intent) + runTest { + val expected = true + val actual = viewModel.state.value.advancedOptionsVisible + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Update Prompt intent, expected prompt updated in UI state`() { + val intent = GenerationMviIntent.Update.Prompt("5598") + viewModel.processIntent(intent) + runTest { + val expected = "5598" + val actual = viewModel.state.value.prompt + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Update NegativePrompt intent, expected negativePrompt updated in UI state`() { + val intent = GenerationMviIntent.Update.NegativePrompt("5598") + viewModel.processIntent(intent) + runTest { + val expected = "5598" + val actual = viewModel.state.value.negativePrompt + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Update Size Width intent with valid value, expected width updated, widthValidationError is null in UI state`() { + every { + stubDimensionValidator(any()) + } returns ValidationResult(true) + + val intent = GenerationMviIntent.Update.Size.Width("512") + viewModel.processIntent(intent) + + runTest { + val state = viewModel.state.value + Assert.assertEquals("512", state.width) + Assert.assertNull(state.widthValidationError) + } + } + + @Test + fun `given received Update Size Width intent with invalid value, expected width updated, widthValidationError is NOT null in UI state`() { + every { + stubDimensionValidator(any()) + } returns ValidationResult(false, Error.Unexpected) + + val intent = GenerationMviIntent.Update.Size.Width("512d") + viewModel.processIntent(intent) + + runTest { + val state = viewModel.state.value + Assert.assertEquals("512d", state.width) + Assert.assertNotNull(state.widthValidationError) + } + } + + @Test + fun `given received Update Size Height intent with valid value, expected height updated, heightValidationError is null in UI state`() { + every { + stubDimensionValidator(any()) + } returns ValidationResult(true) + + val intent = GenerationMviIntent.Update.Size.Height("512") + viewModel.processIntent(intent) + + runTest { + val state = viewModel.state.value + Assert.assertEquals("512", state.height) + Assert.assertNull(state.heightValidationError) + } + } + + @Test + fun `given received Update Size Height intent with invalid value, expected height updated, heightValidationError is NOT null in UI state`() { + every { + stubDimensionValidator(any()) + } returns ValidationResult(false, Error.Unexpected) + + val intent = GenerationMviIntent.Update.Size.Height("512d") + viewModel.processIntent(intent) + + runTest { + val state = viewModel.state.value + Assert.assertEquals("512d", state.height) + Assert.assertNotNull(state.heightValidationError) + } + } + + @Test + fun `given received Update SamplingSteps intent, expected samplingSteps updated in UI state`() { + val intent = GenerationMviIntent.Update.SamplingSteps(12) + viewModel.processIntent(intent) + runTest { + val expected = 12 + val actual = viewModel.state.value.samplingSteps + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Update CfgScale intent, expected cfgScale updated in UI state`() { + val intent = GenerationMviIntent.Update.CfgScale(12f) + viewModel.processIntent(intent) + runTest { + val expected = 12f + val actual = viewModel.state.value.cfgScale + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Update RestoreFaces intent, expected restoreFaces updated in UI state`() { + val intent = GenerationMviIntent.Update.RestoreFaces(true) + viewModel.processIntent(intent) + runTest { + val expected = true + val actual = viewModel.state.value.restoreFaces + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Update Seed intent, expected seed updated in UI state`() { + val intent = GenerationMviIntent.Update.Seed("5598") + viewModel.processIntent(intent) + runTest { + val expected = "5598" + val actual = viewModel.state.value.seed + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Update SubSeed intent, expected subSeed updated in UI state`() { + val intent = GenerationMviIntent.Update.SubSeed("5598") + viewModel.processIntent(intent) + runTest { + val expected = "5598" + val actual = viewModel.state.value.subSeed + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Update SubSeedStrength intent, expected subSeed updated in UI state`() { + val intent = GenerationMviIntent.Update.SubSeedStrength(7f) + viewModel.processIntent(intent) + runTest { + val expected = 7f + val actual = viewModel.state.value.subSeedStrength + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Update Sampler intent, expected selectedSampler updated in UI state`() { + val intent = GenerationMviIntent.Update.Sampler("5598") + viewModel.processIntent(intent) + runTest { + val expected = "5598" + val actual = viewModel.state.value.selectedSampler + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Update Nsfw intent, expected nsfw updated in UI state`() { + val intent = GenerationMviIntent.Update.Nsfw(true) + viewModel.processIntent(intent) + runTest { + val expected = true + val actual = viewModel.state.value.nsfw + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Update Batch intent, expected batchCount updated in UI state`() { + val intent = GenerationMviIntent.Update.Batch(26) + viewModel.processIntent(intent) + runTest { + val expected = 26 + val actual = viewModel.state.value.batchCount + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Update OpenAi Model intent, expected openAiModel updated in UI state`() { + val intent = GenerationMviIntent.Update.OpenAi.Model(OpenAiModel.DALL_E_2) + viewModel.processIntent(intent) + runTest { + val expected = OpenAiModel.DALL_E_2 + val actual = viewModel.state.value.openAiModel + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Update OpenAi Size intent, expected openAiSize updated in UI state`() { + val intent = GenerationMviIntent.Update.OpenAi.Size(OpenAiSize.W256_H256) + viewModel.processIntent(intent) + runTest { + val expected = OpenAiSize.W256_H256 + val actual = viewModel.state.value.openAiSize + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Update OpenAi Quality intent, expected openAiQuality updated in UI state`() { + val intent = GenerationMviIntent.Update.OpenAi.Quality(OpenAiQuality.HD) + viewModel.processIntent(intent) + runTest { + val expected = OpenAiQuality.HD + val actual = viewModel.state.value.openAiQuality + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Update OpenAi Style intent, expected openAiStyle updated in UI state`() { + val intent = GenerationMviIntent.Update.OpenAi.Style(OpenAiStyle.NATURAL) + viewModel.processIntent(intent) + runTest { + val expected = OpenAiStyle.NATURAL + val actual = viewModel.state.value.openAiStyle + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Result Save intent, expected screenModal is None in UI state`() { + every { + stubSaveGenerationResultUseCase(any()) + } returns Completable.complete() + + val intent = GenerationMviIntent.Result.Save(listOf(mockAiGenerationResult)) + viewModel.processIntent(intent) + + runTest { + val state = viewModel.state.value + Assert.assertEquals(Modal.None, state.screenModal) + } + } + + @Test + fun `given received Result View intent, expected saveGenerationResultUseCase() called`() { + every { + stubSaveLastResultToCacheUseCase(any()) + } returns Single.just(mockAiGenerationResult) + + every { + stubMainRouter.navigateToGalleryDetails(any()) + } returns Unit + + val intent = GenerationMviIntent.Result.View(mockAiGenerationResult) + viewModel.processIntent(intent) + + verify { + stubSaveLastResultToCacheUseCase(mockAiGenerationResult) + } + } + + @Test + fun `given received SetModal intent, expected screenModal updated in UI state`() { + val intent = GenerationMviIntent.SetModal(Modal.Communicating()) + viewModel.processIntent(intent) + runTest { + val expected = Modal.Communicating() + val actual = viewModel.state.value.screenModal + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received Cancel Generation intent, expected interruptGenerationUseCase() called`() { + every { + stubInterruptGenerationUseCase() + } returns Completable.complete() + + val intent = GenerationMviIntent.Cancel.Generation + viewModel.processIntent(intent) + + verify { + stubInterruptGenerationUseCase() + } + } + + @Test + fun `given received Cancel FetchRandomImage intent, expected screenModal is None in UI state`() { + val intent = GenerationMviIntent.Cancel.FetchRandomImage + viewModel.processIntent(intent) + runTest { + Assert.assertEquals( + Modal.None, + viewModel.state.value.screenModal, + ) + } + } + + @Test + fun `given received Generate intent, expected textToImageUseCase() called`() { + every { + stubPreferenceManager::autoSaveAiResults.get() + } returns true + + every { + stubSdaiPushNotificationManager.show(any(), any()) + } returns Unit + + every { + stubTextToImageUseCase.invoke(any()) + } returns Single.just(listOf(mockAiGenerationResult)) + + val payload = viewModel.state.value.mapToPayload() + val intent = GenerationMviIntent.Generate + viewModel.processIntent(intent) + + verify { + stubTextToImageUseCase(payload) + } + } + + @Test + fun `given received Configuration intent, expected router navigateToServerSetup() called`() { + every { + stubMainRouter.navigateToServerSetup(any()) + } returns Unit + + val intent = GenerationMviIntent.Configuration + viewModel.processIntent(intent) + + verify { + stubMainRouter.navigateToServerSetup(ServerSetupLaunchSource.SETTINGS) + } + } + + @Test + fun `given received UpdateFromGeneration intent, expected UI state fields are same as intent model`() { + val intent = GenerationMviIntent.UpdateFromGeneration(mockAiGenerationResult) + viewModel.processIntent(intent) + runTest { + val state = viewModel.state.value + Assert.assertEquals(true, state.advancedOptionsVisible) + Assert.assertEquals(mockAiGenerationResult.prompt, state.prompt) + Assert.assertEquals(mockAiGenerationResult.negativePrompt, state.negativePrompt) + Assert.assertEquals(mockAiGenerationResult.width.toString(), state.width) + Assert.assertEquals(mockAiGenerationResult.height.toString(), state.height) + Assert.assertEquals(mockAiGenerationResult.seed, state.seed) + Assert.assertEquals(mockAiGenerationResult.subSeed, state.subSeed) + Assert.assertEquals(mockAiGenerationResult.subSeedStrength, state.subSeedStrength) + Assert.assertEquals(mockAiGenerationResult.samplingSteps, state.samplingSteps) + Assert.assertEquals(mockAiGenerationResult.cfgScale, state.cfgScale) + Assert.assertEquals(mockAiGenerationResult.restoreFaces, state.restoreFaces) + } + } + + @Test + fun `given received Drawer Open intent, expected router openDrawer() called`() { + every { + stubDrawerRouter.openDrawer() + } returns Unit + + val intent = GenerationMviIntent.Drawer(DrawerIntent.Open) + viewModel.processIntent(intent) + + verify { + stubDrawerRouter.openDrawer() + } + } + + @Test + fun `given received Drawer Close intent, expected router closeDrawer() called`() { + every { + stubDrawerRouter.closeDrawer() + } returns Unit + + val intent = GenerationMviIntent.Drawer(DrawerIntent.Close) + viewModel.processIntent(intent) + + verify { + stubDrawerRouter.closeDrawer() + } + } +} diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/stub/SchedulersProviderStub.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/stub/SchedulersProviderStub.kt new file mode 100644 index 00000000..bc25478d --- /dev/null +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/stub/SchedulersProviderStub.kt @@ -0,0 +1,14 @@ +package com.shifthackz.aisdv1.presentation.stub + +import com.shifthackz.aisdv1.core.common.schedulers.SchedulersProvider +import io.reactivex.rxjava3.core.Scheduler +import io.reactivex.rxjava3.schedulers.Schedulers +import java.util.concurrent.Executor +import java.util.concurrent.Executors + +val stubSchedulersProvider = object : SchedulersProvider { + override val io: Scheduler = Schedulers.trampoline() + override val ui: Scheduler = Schedulers.trampoline() + override val computation: Scheduler = Schedulers.trampoline() + override val singleThread: Executor = Executors.newSingleThreadExecutor() +} diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/widget/connectivity/ConnectivityViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/widget/connectivity/ConnectivityViewModelTest.kt new file mode 100644 index 00000000..3ea39ebf --- /dev/null +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/widget/connectivity/ConnectivityViewModelTest.kt @@ -0,0 +1,112 @@ +package com.shifthackz.aisdv1.presentation.widget.connectivity + +import com.shifthackz.aisdv1.domain.entity.Settings +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.domain.usecase.connectivity.ObserveSeverConnectivityUseCase +import com.shifthackz.aisdv1.presentation.core.CoreViewModelTest +import com.shifthackz.aisdv1.presentation.stub.stubSchedulersProvider +import io.mockk.every +import io.mockk.mockk +import io.reactivex.rxjava3.core.BackpressureStrategy +import io.reactivex.rxjava3.subjects.BehaviorSubject +import kotlinx.coroutines.test.runTest +import org.junit.Assert +import org.junit.Before +import org.junit.Test + +class ConnectivityViewModelTest : CoreViewModelTest() { + + private val stubSettings = BehaviorSubject.create() + private val stubConnected = BehaviorSubject.create() + private val stubPreferenceManager = mockk() + private val stubObserveSeverConnectivityUseCase = mockk() + + override fun initializeViewModel() = ConnectivityViewModel( + preferenceManager = stubPreferenceManager, + observeServerConnectivityUseCase = stubObserveSeverConnectivityUseCase, + schedulersProvider = stubSchedulersProvider, + ) + + @Before + override fun initialize() { + super.initialize() + + every { + stubPreferenceManager.observe() + } returns stubSettings.toFlowable(BackpressureStrategy.LATEST) + + every { + stubPreferenceManager::monitorConnectivity.get() + } returns true + + every { + stubObserveSeverConnectivityUseCase() + } returns stubConnected.toFlowable(BackpressureStrategy.LATEST) + } + + @Test + fun `initialized, monitorConnectivity true, expected UI state is Uninitialized, enabled is true`() { + runTest { + val expected = ConnectivityState.Uninitialized(true) + val actual = viewModel.state.value + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `initialized, monitorConnectivity false, expected UI state is Uninitialized, enabled is false`() { + every { + stubPreferenceManager::monitorConnectivity.get() + } returns false + + runTest { + val expected = ConnectivityState.Uninitialized(false) + val actual = viewModel.state.value + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given monitorConnectivity true, connected, expected UI state is Connected, enabled is true`() { + stubSettings.onNext(Settings(monitorConnectivity = true)) + stubConnected.onNext(true) + runTest { + val expected = ConnectivityState.Connected(true) + val actual = viewModel.state.value + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given monitorConnectivity false, connected, expected UI state is Connected, enabled is false`() { + stubSettings.onNext(Settings(monitorConnectivity = false)) + stubConnected.onNext(true) + runTest { + val expected = ConnectivityState.Connected(false) + val actual = viewModel.state.value + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given monitorConnectivity true, disconnected, expected UI state is Disconnected, enabled is true`() { + stubSettings.onNext(Settings(monitorConnectivity = true)) + stubConnected.onNext(false) + runTest { + val expected = ConnectivityState.Disconnected(true) + val actual = viewModel.state.value + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given monitorConnectivity false, disconnected, expected UI state is Disconnected, enabled is false`() { + stubSettings.onNext(Settings(monitorConnectivity = false)) + stubConnected.onNext(false) + runTest { + val expected = ConnectivityState.Disconnected(false) + val actual = viewModel.state.value + Assert.assertEquals(expected, actual) + } + } +} diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModelTest.kt new file mode 100644 index 00000000..cf2d83cb --- /dev/null +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModelTest.kt @@ -0,0 +1,290 @@ +package com.shifthackz.aisdv1.presentation.widget.engine + +import com.shifthackz.aisdv1.domain.entity.Configuration +import com.shifthackz.aisdv1.domain.entity.LocalAiModel +import com.shifthackz.aisdv1.domain.entity.ServerSource +import com.shifthackz.aisdv1.domain.entity.Settings +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.domain.usecase.downloadable.ObserveLocalAiModelsUseCase +import com.shifthackz.aisdv1.domain.usecase.huggingface.FetchAndGetHuggingFaceModelsUseCase +import com.shifthackz.aisdv1.domain.usecase.sdmodel.GetStableDiffusionModelsUseCase +import com.shifthackz.aisdv1.domain.usecase.sdmodel.SelectStableDiffusionModelUseCase +import com.shifthackz.aisdv1.domain.usecase.settings.GetConfigurationUseCase +import com.shifthackz.aisdv1.domain.usecase.stabilityai.FetchAndGetStabilityAiEnginesUseCase +import com.shifthackz.aisdv1.presentation.core.CoreViewModelTest +import com.shifthackz.aisdv1.presentation.mocks.mockHuggingFaceModels +import com.shifthackz.aisdv1.presentation.mocks.mockLocalAiModels +import com.shifthackz.aisdv1.presentation.mocks.mockStabilityAiEngines +import com.shifthackz.aisdv1.presentation.mocks.mockStableDiffusionModels +import com.shifthackz.aisdv1.presentation.stub.stubSchedulersProvider +import io.mockk.every +import io.mockk.mockk +import io.mockk.unmockkAll +import io.mockk.verify +import io.reactivex.rxjava3.core.BackpressureStrategy +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Flowable +import io.reactivex.rxjava3.core.Single +import io.reactivex.rxjava3.subjects.BehaviorSubject +import kotlinx.coroutines.test.runTest +import org.junit.After +import org.junit.Assert +import org.junit.Before +import org.junit.Test + +class EngineSelectionViewModelTest : CoreViewModelTest() { + + private val stubSettings = BehaviorSubject.create>() + private val stubLocalAiModels = BehaviorSubject.create>>() + private val stubException = Throwable("Something went wrong.") + private val stubPreferenceManager = mockk() + private val stubGetConfigurationUseCase = mockk() + private val stubSelectStableDiffusionModelUseCase = mockk() + private val stubGetStableDiffusionModelsUseCase = mockk() + private val stubObserveLocalAiModelsUseCase = mockk() + private val stubFetchAndGetStabilityAiEnginesUseCase = + mockk() + private val stubFetchAndGetHuggingFaceModelsUseCase = + mockk() + + override fun initializeViewModel() = EngineSelectionViewModel( + preferenceManager = stubPreferenceManager, + schedulersProvider = stubSchedulersProvider, + getConfigurationUseCase = stubGetConfigurationUseCase, + selectStableDiffusionModelUseCase = stubSelectStableDiffusionModelUseCase, + getStableDiffusionModelsUseCase = stubGetStableDiffusionModelsUseCase, + observeLocalAiModelsUseCase = stubObserveLocalAiModelsUseCase, + fetchAndGetStabilityAiEnginesUseCase = stubFetchAndGetStabilityAiEnginesUseCase, + getHuggingFaceModelsUseCase = stubFetchAndGetHuggingFaceModelsUseCase, + ) + + @Before + override fun initialize() { + super.initialize() + + every { + stubPreferenceManager.observe() + } returns stubSettings + .toFlowable(BackpressureStrategy.LATEST) + .flatMap { result -> + result.fold( + onSuccess = { settings -> Flowable.just(settings) }, + onFailure = { t -> Flowable.error(t) }, + ) + } + + every { + stubObserveLocalAiModelsUseCase() + } returns stubLocalAiModels + .toFlowable(BackpressureStrategy.LATEST) + .flatMap { result -> + result.fold( + onSuccess = { list -> Flowable.just(list) }, + onFailure = { t -> Flowable.error(t) }, + ) + } + } + + @After + override fun finalize() { + super.finalize() + unmockkAll() + } + + @Test + fun `initialized, use cases returned data, expected UI state with correct valid stub data`() { + mockInitialData(DataTestCase.Mock) + runTest { + val expected = EngineSelectionState( + loading = false, + mode = ServerSource.AUTOMATIC1111, + sdModels = listOf("title_5598", "title_151297"), + selectedSdModel = "title_5598", + hfModels = listOf("prompthero/openjourney-v4", "wavymulder/Analog-Diffusion"), + selectedHfModel = "prompthero/openjourney-v4", + stEngines = listOf("5598"), + selectedStEngine = "5598", + localAiModels = listOf(LocalAiModel.CUSTOM), + selectedLocalAiModelId = "CUSTOM", + ) + val actual = viewModel.state.value + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `initialized, use cases returned empty data, expected UI state with empty stub data`() { + mockInitialData(DataTestCase.Empty) + runTest { + val expected = EngineSelectionState( + loading = false, + mode = ServerSource.AUTOMATIC1111, + sdModels = emptyList(), + selectedSdModel = "", + hfModels = emptyList(), + selectedHfModel = "", + stEngines = emptyList(), + selectedStEngine = "", + localAiModels = emptyList(), + selectedLocalAiModelId = "", + ) + val actual = viewModel.state.value + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `initialized, use cases thrown exceptions, expected UI state with empty stub data`() { + mockInitialData(DataTestCase.Exception) + runTest { + val expected = EngineSelectionState( + loading = false, + mode = ServerSource.AUTOMATIC1111, + sdModels = emptyList(), + selectedSdModel = "", + hfModels = emptyList(), + selectedHfModel = "", + stEngines = emptyList(), + selectedStEngine = "", + localAiModels = emptyList(), + selectedLocalAiModelId = "", + ) + val actual = viewModel.state.value + Assert.assertEquals(expected, actual) + } + } + + @Test + fun `given received EngineSelectionIntent, source is AUTOMATIC1111, expected selectedSdModel changed in UI state`() { + mockInitialData(DataTestCase.Mock, ServerSource.AUTOMATIC1111) + + every { + stubSelectStableDiffusionModelUseCase(any()) + } returns Completable.complete() + + every { + stubGetStableDiffusionModelsUseCase() + } returns Single.just(mockStableDiffusionModels.map { (f, s) -> f to !s }) + + viewModel.processIntent(EngineSelectionIntent("title_151297")) + + runTest { + val state = viewModel.state.value + Assert.assertEquals(false, state.loading) + Assert.assertEquals(listOf("title_5598", "title_151297"), state.sdModels) + Assert.assertEquals("title_151297", state.selectedSdModel) + } + } + + @Test + fun `given received EngineSelectionIntent, source is HUGGING_FACE, expected huggingFaceModel changed in preference`() { + mockInitialData(DataTestCase.Mock, ServerSource.HUGGING_FACE) + + every { + stubPreferenceManager::huggingFaceModel.set(any()) + } returns Unit + + viewModel.processIntent(EngineSelectionIntent("hf_5598")) + + verify { + stubPreferenceManager::huggingFaceModel.set("hf_5598") + } + } + + @Test + fun `given received EngineSelectionIntent, source is STABILITY_AI, expected stabilityAiEngineId changed in preference`() { + mockInitialData(DataTestCase.Mock, ServerSource.STABILITY_AI) + + every { + stubPreferenceManager::stabilityAiEngineId.set(any()) + } returns Unit + + viewModel.processIntent(EngineSelectionIntent("st_5598")) + + verify { + stubPreferenceManager::stabilityAiEngineId.set("st_5598") + } + } + + @Test + fun `given received EngineSelectionIntent, source is LOCAL, expected localModelId changed in preference`() { + mockInitialData(DataTestCase.Mock, ServerSource.LOCAL) + + every { + stubPreferenceManager::localModelId.set(any()) + } returns Unit + + viewModel.processIntent(EngineSelectionIntent("llm_5598")) + + verify { + stubPreferenceManager::localModelId.set("llm_5598") + } + } + + private fun mockInitialData( + testCase: DataTestCase, + source: ServerSource = ServerSource.AUTOMATIC1111, + ) { + stubSettings.onNext( + when (testCase) { + DataTestCase.Mock -> Result.success(Settings()) + DataTestCase.Empty -> Result.success(Settings()) + DataTestCase.Exception -> Result.failure(stubException) + } + ) + + every { + stubGetConfigurationUseCase() + } returns when (testCase) { + DataTestCase.Mock -> Single.just( + Configuration( + huggingFaceModel = "prompthero/openjourney-v4", + stabilityAiEngineId = "5598", + localModelId = "CUSTOM", + source = source, + ), + ) + + DataTestCase.Empty -> Single.just(Configuration()) + DataTestCase.Exception -> Single.error(stubException) + } + + every { + stubGetStableDiffusionModelsUseCase() + } returns when (testCase) { + DataTestCase.Mock -> Single.just(mockStableDiffusionModels) + DataTestCase.Empty -> Single.just(emptyList()) + DataTestCase.Exception -> Single.error(stubException) + } + + every { + stubFetchAndGetHuggingFaceModelsUseCase() + } returns when (testCase) { + DataTestCase.Mock -> Single.just(mockHuggingFaceModels) + DataTestCase.Empty -> Single.just(emptyList()) + DataTestCase.Exception -> Single.error(stubException) + } + + every { + stubFetchAndGetStabilityAiEnginesUseCase() + } returns when (testCase) { + DataTestCase.Mock -> Single.just(mockStabilityAiEngines) + DataTestCase.Empty -> Single.just(emptyList()) + DataTestCase.Exception -> Single.error(stubException) + } + + stubLocalAiModels.onNext( + when (testCase) { + DataTestCase.Mock -> Result.success(mockLocalAiModels) + DataTestCase.Empty -> Result.success(emptyList()) + DataTestCase.Exception -> Result.failure(stubException) + }, + ) + } + + private enum class DataTestCase { + Mock, + Empty, + Exception, + } +}