Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,18 @@ class SafeBoxTest {
assertEquals("Secured", value)
}

@Test
fun getString_withRealIoDispatcher_shouldReturnCorrectValue() {
safeBox = createSafeBox(ioDispatcher = Dispatchers.IO)
safeBox.edit()
.putString("SafeBox", "Secured")
.apply()

val value = safeBox.getString("SafeBox", null)

assertEquals("Secured", value)
}

@Test
fun getString_afterRemove_shouldReturnDefaultValue() {
safeBox = createSafeBox()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,11 @@ import android.content.Context
import androidx.test.core.app.ApplicationProvider
import androidx.test.ext.junit.runners.AndroidJUnit4
import com.harrytmthy.safebox.extensions.toBytes
import com.harrytmthy.safebox.state.SafeBoxState
import com.harrytmthy.safebox.state.SafeBoxStateListener
import com.harrytmthy.safebox.state.SafeBoxStateManager
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.test.UnconfinedTestDispatcher
import kotlinx.coroutines.test.runTest
import org.junit.After
import org.junit.runner.RunWith
import java.io.File
import java.util.concurrent.CopyOnWriteArrayList
import kotlin.test.Test
import kotlin.test.assertContentEquals
import kotlin.test.assertEquals
Expand All @@ -44,21 +39,12 @@ class SafeBoxBlobStoreTest {

private val fileName: String = "safebox_blob_test"

private val observedStates = CopyOnWriteArrayList<SafeBoxState>()

private val stateListener = SafeBoxStateListener(observedStates::add)

private val blobStore = SafeBoxBlobStore.create(
context,
fileName,
SafeBoxStateManager(fileName, stateListener, UnconfinedTestDispatcher()),
)
private val blobStore = SafeBoxBlobStore.create(context, fileName)

@After
fun teardown() {
blobStore.close()
File(context.noBackupFilesDir, "$fileName.bin").delete()
observedStates.clear()
}

@Test
Expand All @@ -70,7 +56,7 @@ class SafeBoxBlobStoreTest {
blobStore.write(firstKey, firstValue)
blobStore.write(secondKey, secondValue)

val result = blobStore.getAll()
val result = blobStore.loadPersistedEntries()

assertEquals(2, result.size)
assertContentEquals(firstValue, result[firstKey])
Expand All @@ -91,7 +77,7 @@ class SafeBoxBlobStoreTest {

blobStore.delete(firstKey)

val result = blobStore.getAll()
val result = blobStore.loadPersistedEntries()
assertEquals(2, result.size)
assertFalse(result.any { it == firstKey })
assertContentEquals(secondValue, result[secondKey])
Expand All @@ -113,7 +99,7 @@ class SafeBoxBlobStoreTest {

blobStore.delete(secondKey)

val result = blobStore.getAll()
val result = blobStore.loadPersistedEntries()
assertEquals(2, result.size)
assertFalse(result.any { it == secondKey })
assertContentEquals(firstValue, result[firstKey])
Expand All @@ -135,7 +121,7 @@ class SafeBoxBlobStoreTest {

blobStore.delete(thirdKey)

val result = blobStore.getAll()
val result = blobStore.loadPersistedEntries()
assertEquals(2, result.size)
assertFalse(result.any { it == thirdKey })
assertContentEquals(firstValue, result[firstKey])
Expand All @@ -157,7 +143,7 @@ class SafeBoxBlobStoreTest {
val thirdValue = "789".toByteArray()
blobStore.write(thirdKey, thirdValue)

val result = blobStore.getAll()
val result = blobStore.loadPersistedEntries()
assertEquals(2, result.size)
assertFalse(result.any { it == firstKey })
assertContentEquals(secondValue, result[secondKey])
Expand All @@ -176,7 +162,7 @@ class SafeBoxBlobStoreTest {
blobStore.write(key, firstValue)
blobStore.write(key, secondValue)

val result = blobStore.getAll()
val result = blobStore.loadPersistedEntries()
assertEquals(1, result.size)
assertContentEquals(secondValue, result[key])
assertTrue(blobStore.entryMetas.containsKey(key))
Expand All @@ -191,7 +177,7 @@ class SafeBoxBlobStoreTest {
blobStore.write(key, firstValue)
blobStore.write(key, secondValue)

val result = blobStore.getAll()
val result = blobStore.loadPersistedEntries()
assertEquals(1, result.size)
assertContentEquals(secondValue, result[key])
assertTrue(blobStore.entryMetas.containsKey(key))
Expand All @@ -206,7 +192,7 @@ class SafeBoxBlobStoreTest {
blobStore.write(key, firstValue)
blobStore.write(key, secondValue)

val result = blobStore.getAll()
val result = blobStore.loadPersistedEntries()
assertEquals(1, result.size)
assertContentEquals(secondValue, result[key])
assertTrue(blobStore.entryMetas.containsKey(key))
Expand All @@ -216,13 +202,4 @@ class SafeBoxBlobStoreTest {
fun getFileName_shouldReturnFileName() {
assertEquals(fileName, blobStore.getFileName())
}

@Test
fun init_shouldEmitStartingState() {
val expected = listOf(
SafeBoxState.STARTING,
SafeBoxState.IDLE,
)
assertEquals(expected, observedStates)
}
}
78 changes: 60 additions & 18 deletions safebox/src/main/java/com/harrytmthy/safebox/SafeBox.kt
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import kotlinx.coroutines.CoroutineExceptionHandler
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.CopyOnWriteArrayList
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicReference
Expand Down Expand Up @@ -77,7 +78,9 @@ public class SafeBox private constructor(
private val stateManager: SafeBoxStateManager,
) : SharedPreferences {

private val castFailureStrategy = AtomicReference<ValueFallbackStrategy>(WARN)
private val entries: MutableMap<Bytes, ByteArray> = ConcurrentHashMap()

private val castFailureStrategy = AtomicReference(WARN)

private val byteDecoder = ByteDecoder(castFailureStrategy::get)

Expand All @@ -92,34 +95,79 @@ public class SafeBox private constructor(

private val delegate = object : Delegate {

private val updateLock = Any()

private val exceptionHandler = CoroutineExceptionHandler { _, throwable ->
Log.e("SafeBox", "Failed to apply changes.", throwable)
applyCompleted.complete(Unit)
}

override fun commit(entries: LinkedHashMap<String, Action>, cleared: Boolean): Boolean =
stateManager.launchCommitWithWritingState {
override fun commit(entries: LinkedHashMap<String, Action>, cleared: Boolean): Boolean {
val entriesToWrite = LinkedHashMap(entries)
synchronized(updateLock) {
entries.clear() // Prevents stale mutations on reused editor instance
updateEntries(entriesToWrite, cleared)
}
return stateManager.launchCommitWithWritingState {
try {
applyCompleted.await()
commitMutex.withLock {
applyChanges(entries, cleared)
applyChanges(entriesToWrite, cleared)
}
true
} catch (e: Exception) {
Log.e("SafeBox", "Failed to commit changes.", e)
false
}
}
}

override fun apply(entries: LinkedHashMap<String, Action>, cleared: Boolean) {
val entriesToWrite = LinkedHashMap(entries)
synchronized(updateLock) {
entries.clear() // Prevents stale mutations on reused editor instance
updateEntries(entriesToWrite, cleared)
}
stateManager.launchApplyWithWritingState(exceptionHandler) {
applyCompleted = CompletableDeferred()
applyMutex.withLock {
applyChanges(entries, cleared)
applyChanges(entriesToWrite, cleared)
}
applyCompleted.complete(Unit)
}
}

private fun updateEntries(entries: LinkedHashMap<String, Action>, cleared: Boolean) {
if (cleared) {
val keys = this@SafeBox.entries.keys.toHashSet()
this@SafeBox.entries.clear()
for (encryptedKey in keys) {
val key = keyCipherProvider.decrypt(encryptedKey.value).toString(Charsets.UTF_8)
listeners.forEach { it.onSharedPreferenceChanged(this@SafeBox, key) }
}
}
for ((key, action) in entries) {
when (action) {
is Put -> {
val encryptedKey = key.toEncryptedKey()
val encryptedValue = action.encodedValue.value
.let(valueCipherProvider::encrypt)
this@SafeBox.entries[encryptedKey] = encryptedValue
}
is Remove -> {
val encryptedKey = key.toEncryptedKey()
this@SafeBox.entries.remove(encryptedKey)
}
}
listeners.forEach { it.onSharedPreferenceChanged(this@SafeBox, key) }
}
}
}

init {
stateManager.launchWithStartingState {
entries += blobStore.loadPersistedEntries()
}
}

/**
Expand Down Expand Up @@ -185,9 +233,8 @@ public class SafeBox private constructor(
}

override fun getAll(): Map<String, Any?> {
val encryptedEntries = blobStore.getAll()
val decryptedEntries = HashMap<String, Any?>(encryptedEntries.size, 1f)
for (entry in encryptedEntries) {
val decryptedEntries = HashMap<String, Any?>(entries.size, 1f)
for (entry in entries) {
val key = keyCipherProvider.decrypt(entry.key.value).toString(Charsets.UTF_8)
val value = valueCipherProvider.decrypt(entry.value)
decryptedEntries[key] = byteDecoder.decodeAny(value)
Expand Down Expand Up @@ -226,7 +273,7 @@ public class SafeBox private constructor(
?: defValue

override fun contains(key: String): Boolean =
blobStore.contains(key.toEncryptedKey())
entries.containsKey(key.toEncryptedKey())

override fun edit(): SharedPreferences.Editor = Editor(delegate)

Expand All @@ -243,15 +290,12 @@ public class SafeBox private constructor(
}

private fun getDecryptedValue(key: String): ByteArray? =
blobStore.get(key.toEncryptedKey())
entries[key.toEncryptedKey()]
?.let(valueCipherProvider::decrypt)

private suspend fun applyChanges(entries: LinkedHashMap<String, Action>, cleared: Boolean) {
if (cleared) {
blobStore.deleteAll().forEach { encryptedKey ->
val key = keyCipherProvider.decrypt(encryptedKey.value).toString(Charsets.UTF_8)
listeners.forEach { it.onSharedPreferenceChanged(this, key) }
}
blobStore.deleteAll()
}
for ((key, action) in entries) {
when (action) {
Expand All @@ -267,9 +311,7 @@ public class SafeBox private constructor(
}
}
}
listeners.forEach { it.onSharedPreferenceChanged(this, key) }
}
entries.clear()
}

private fun String.toEncryptedKey(): Bytes =
Expand Down Expand Up @@ -384,7 +426,7 @@ public class SafeBox private constructor(
val keyCipherProvider = ChaCha20CipherProvider(keyProvider, deterministic = true)
val valueCipherProvider = ChaCha20CipherProvider(keyProvider, deterministic = false)
val stateManager = SafeBoxStateManager(fileName, stateListener, ioDispatcher)
val blobStore = SafeBoxBlobStore.create(context, fileName, stateManager)
val blobStore = SafeBoxBlobStore.create(context, fileName)
return SafeBox(blobStore, keyCipherProvider, valueCipherProvider, stateManager)
}

Expand Down Expand Up @@ -421,7 +463,7 @@ public class SafeBox private constructor(
): SafeBox {
SafeBoxBlobFileRegistry.register(fileName)
val stateManager = SafeBoxStateManager(fileName, stateListener, ioDispatcher)
val blobStore = SafeBoxBlobStore.create(context, fileName, stateManager)
val blobStore = SafeBoxBlobStore.create(context, fileName)
return SafeBox(blobStore, keyCipherProvider, valueCipherProvider, stateManager)
}
}
Expand Down
Loading