Skip to content

Commit 945fb49

Browse files
committed
Merge branch 'master' into statistics-fixes
2 parents 63ee929 + 79bd076 commit 945fb49

File tree

8 files changed

+302
-122
lines changed

8 files changed

+302
-122
lines changed
File renamed without changes.

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/api/parse.kt

+13-16
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,13 @@ import org.jetbrains.kotlinx.dataframe.api.ParserOptions
2020
import org.jetbrains.kotlinx.dataframe.api.asColumnGroup
2121
import org.jetbrains.kotlinx.dataframe.api.asDataColumn
2222
import org.jetbrains.kotlinx.dataframe.api.cast
23-
import org.jetbrains.kotlinx.dataframe.api.emptyDataFrame
24-
import org.jetbrains.kotlinx.dataframe.api.getColumnsWithPaths
23+
import org.jetbrains.kotlinx.dataframe.api.convert
2524
import org.jetbrains.kotlinx.dataframe.api.isColumnGroup
2625
import org.jetbrains.kotlinx.dataframe.api.isFrameColumn
2726
import org.jetbrains.kotlinx.dataframe.api.isSubtypeOf
28-
import org.jetbrains.kotlinx.dataframe.api.toColumn
27+
import org.jetbrains.kotlinx.dataframe.api.map
28+
import org.jetbrains.kotlinx.dataframe.api.parse
29+
import org.jetbrains.kotlinx.dataframe.api.to
2930
import org.jetbrains.kotlinx.dataframe.api.tryParse
3031
import org.jetbrains.kotlinx.dataframe.columns.TypeSuggestion
3132
import org.jetbrains.kotlinx.dataframe.columns.size
@@ -531,17 +532,16 @@ internal fun <T> DataColumn<String?>.parse(parser: StringParser<T>, options: Par
531532
)
532533
}
533534

534-
internal fun <T> DataFrame<T>.parseImpl(options: ParserOptions?, columns: ColumnsSelector<T, Any?>): DataFrame<T> {
535-
val convertedCols = getColumnsWithPaths(columns).map { col ->
535+
internal fun <T> DataFrame<T>.parseImpl(options: ParserOptions?, columns: ColumnsSelector<T, Any?>): DataFrame<T> =
536+
convert(columns).to { col ->
536537
when {
537538
// when a frame column is requested to be parsed,
538539
// parse each value/frame column at any depth inside each DataFrame in the frame column
539-
col.isFrameColumn() ->
540-
col.values.map {
541-
it.parseImpl(options) {
542-
colsAtAnyDepth { !it.isColumnGroup() }
543-
}
544-
}.toColumn(col.name)
540+
col.isFrameColumn() -> col.map {
541+
it.parseImpl(options) {
542+
colsAtAnyDepth { !it.isColumnGroup() }
543+
}
544+
}
545545

546546
// when a column group is requested to be parsed,
547547
// parse each column in the group
@@ -552,11 +552,8 @@ internal fun <T> DataFrame<T>.parseImpl(options: ParserOptions?, columns: Column
552552

553553
// Base case, parse the column if it's a `String?` column
554554
col.isSubtypeOf<String?>() ->
555-
col.cast<String?>().tryParse(options)
555+
col.cast<String?>().tryParseImpl(options)
556556

557557
else -> col
558-
}.let { ColumnToInsert(col.path, it) }
558+
}
559559
}
560-
561-
return emptyDataFrame<T>().insertImpl(convertedCols)
562-
}

dataframe-jdbc/api/dataframe-jdbc.api

+37-31
Large diffs are not rendered by default.

dataframe-jdbc/build.gradle.kts

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ dependencies {
2828
testImplementation(libs.h2db)
2929
testImplementation(libs.mssql)
3030
testImplementation(libs.junit)
31-
testImplementation(libs.sl4j)
31+
testImplementation(libs.sl4jsimple)
3232
testImplementation(libs.jts)
3333
testImplementation(libs.kotestAssertions) {
3434
exclude("org.jetbrains.kotlin", "kotlin-stdlib-jdk8")

dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/H2.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import kotlin.reflect.KType
1515
*
1616
* NOTE: All date and timestamp-related types are converted to String to avoid java.sql.* types.
1717
*/
18-
public class H2(public val dialect: DbType = MySql) : DbType("h2") {
18+
public open class H2(public val dialect: DbType = MySql) : DbType("h2") {
1919
init {
2020
require(dialect::class != H2::class) { "H2 database could not be specified with H2 dialect!" }
2121
}

dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt

+118-70
Large diffs are not rendered by default.

dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/h2Test.kt

+130-2
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,8 @@ class JdbcTest {
166166
val dataSchema = DataFrame.getSchemaForSqlTable(connection, tableName)
167167
dataSchema.columns.size shouldBe 2
168168
dataSchema.columns["characterCol"]!!.type shouldBe typeOf<String?>()
169+
170+
connection.createStatement().execute("DROP TABLE EmptyTestTable")
169171
}
170172

171173
@Test
@@ -299,6 +301,8 @@ class JdbcTest {
299301
schema.columns["realCol"]!!.type shouldBe typeOf<Float?>()
300302
schema.columns["doublePrecisionCol"]!!.type shouldBe typeOf<Double?>()
301303
schema.columns["decFloatCol"]!!.type shouldBe typeOf<BigDecimal?>()
304+
305+
connection.createStatement().execute("DROP TABLE $tableName")
302306
}
303307

304308
@Test
@@ -441,7 +445,7 @@ class JdbcTest {
441445

442446
rs.beforeFirst()
443447

444-
val dataSchema1 = DataFrame.getSchemaForResultSet(rs, connection)
448+
val dataSchema1 = DataFrame.getSchemaForResultSet(rs, H2(MySql))
445449
dataSchema1.columns.size shouldBe 3
446450
dataSchema1.columns["name"]!!.type shouldBe typeOf<String?>()
447451
}
@@ -493,7 +497,7 @@ class JdbcTest {
493497

494498
rs.beforeFirst()
495499

496-
val dataSchema1 = rs.getDataFrameSchema(connection)
500+
val dataSchema1 = rs.getDataFrameSchema(H2(MySql))
497501
dataSchema1.columns.size shouldBe 3
498502
dataSchema1.columns["name"]!!.type shouldBe typeOf<String?>()
499503
}
@@ -613,6 +617,7 @@ class JdbcTest {
613617
"""
614618

615619
DataFrame.readSqlQuery(connection, selectFromWeirdTableSQL).rowsCount() shouldBe 0
620+
connection.createStatement().execute("DROP TABLE \"ALTER\"")
616621
}
617622

618623
@Test
@@ -967,4 +972,127 @@ class JdbcTest {
967972
}
968973
exception.message shouldBe "H2 database could not be specified with H2 dialect!"
969974
}
975+
976+
// helper object created for API testing purposes
977+
object CustomDB : H2(MySql)
978+
979+
@Test
980+
fun `read from table from custom database`() {
981+
val tableName = "Customer"
982+
val df = DataFrame.readSqlTable(connection, tableName, dbType = CustomDB).cast<Customer>()
983+
984+
df.rowsCount() shouldBe 4
985+
df.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 2
986+
df[0][1] shouldBe "John"
987+
988+
val dataSchema = DataFrame.getSchemaForSqlTable(connection, tableName, dbType = CustomDB)
989+
dataSchema.columns.size shouldBe 3
990+
dataSchema.columns["name"]!!.type shouldBe typeOf<String?>()
991+
992+
val dbConfig = DbConnectionConfig(url = URL)
993+
val df2 = DataFrame.readSqlTable(dbConfig, tableName, dbType = CustomDB).cast<Customer>()
994+
995+
df2.rowsCount() shouldBe 4
996+
df2.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 2
997+
df2[0][1] shouldBe "John"
998+
999+
val dataSchema1 = DataFrame.getSchemaForSqlTable(dbConfig, tableName, dbType = CustomDB)
1000+
dataSchema1.columns.size shouldBe 3
1001+
dataSchema1.columns["name"]!!.type shouldBe typeOf<String?>()
1002+
}
1003+
1004+
@Test
1005+
fun `read from query from custom database`() {
1006+
@Language("SQL")
1007+
val sqlQuery =
1008+
"""
1009+
SELECT c.name as customerName, SUM(s.amount) as totalSalesAmount
1010+
FROM Sale s
1011+
INNER JOIN Customer c ON s.customerId = c.id
1012+
WHERE c.age > 35
1013+
GROUP BY s.customerId, c.name
1014+
""".trimIndent()
1015+
1016+
val df = DataFrame.readSqlQuery(connection, sqlQuery, dbType = CustomDB).cast<CustomerSales>()
1017+
1018+
df.rowsCount() shouldBe 2
1019+
df.filter { it[CustomerSales::totalSalesAmount]!! > 100 }.rowsCount() shouldBe 1
1020+
df[0][0] shouldBe "John"
1021+
1022+
val dataSchema = DataFrame.getSchemaForSqlQuery(connection, sqlQuery, dbType = CustomDB)
1023+
dataSchema.columns.size shouldBe 2
1024+
dataSchema.columns["name"]!!.type shouldBe typeOf<String?>()
1025+
1026+
val dbConfig = DbConnectionConfig(url = URL)
1027+
val df2 = DataFrame.readSqlQuery(dbConfig, sqlQuery, dbType = CustomDB).cast<CustomerSales>()
1028+
1029+
df2.rowsCount() shouldBe 2
1030+
df2.filter { it[CustomerSales::totalSalesAmount]!! > 100 }.rowsCount() shouldBe 1
1031+
df2[0][0] shouldBe "John"
1032+
1033+
val dataSchema1 = DataFrame.getSchemaForSqlQuery(dbConfig, sqlQuery, dbType = CustomDB)
1034+
dataSchema1.columns.size shouldBe 2
1035+
dataSchema1.columns["name"]!!.type shouldBe typeOf<String?>()
1036+
}
1037+
1038+
@Test
1039+
fun `read from all tables from custom database`() {
1040+
val dataFrameMap = DataFrame.readAllSqlTables(connection, dbType = CustomDB)
1041+
dataFrameMap.containsKey("Customer") shouldBe true
1042+
dataFrameMap.containsKey("Sale") shouldBe true
1043+
1044+
val dataframes = dataFrameMap.values.toList()
1045+
1046+
val customerDf = dataframes[0].cast<Customer>()
1047+
1048+
customerDf.rowsCount() shouldBe 4
1049+
customerDf.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 2
1050+
customerDf[0][1] shouldBe "John"
1051+
1052+
val saleDf = dataframes[1].cast<Sale>()
1053+
1054+
saleDf.rowsCount() shouldBe 4
1055+
saleDf.filter { it[Sale::amount] > 40 }.rowsCount() shouldBe 3
1056+
(saleDf[0][2] as BigDecimal).compareTo(BigDecimal(100.50)) shouldBe 0
1057+
1058+
val dataFrameSchemaMap = DataFrame.getSchemaForAllSqlTables(connection, dbType = CustomDB)
1059+
dataFrameSchemaMap.containsKey("Customer") shouldBe true
1060+
dataFrameSchemaMap.containsKey("Sale") shouldBe true
1061+
1062+
val dataSchemas = dataFrameSchemaMap.values.toList()
1063+
1064+
val customerDataSchema = dataSchemas[0]
1065+
customerDataSchema.columns.size shouldBe 3
1066+
customerDataSchema.columns["name"]!!.type shouldBe typeOf<String?>()
1067+
1068+
val saleDataSchema = dataSchemas[1]
1069+
saleDataSchema.columns.size shouldBe 3
1070+
// TODO: fix nullability
1071+
saleDataSchema.columns["amount"]!!.type shouldBe typeOf<BigDecimal>()
1072+
1073+
val dbConfig = DbConnectionConfig(url = URL)
1074+
val dataframes2 = DataFrame.readAllSqlTables(dbConfig, dbType = CustomDB).values.toList()
1075+
1076+
val customerDf2 = dataframes2[0].cast<Customer>()
1077+
1078+
customerDf2.rowsCount() shouldBe 4
1079+
customerDf2.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 2
1080+
customerDf2[0][1] shouldBe "John"
1081+
1082+
val saleDf2 = dataframes2[1].cast<Sale>()
1083+
1084+
saleDf2.rowsCount() shouldBe 4
1085+
saleDf2.filter { it[Sale::amount] > 40 }.rowsCount() shouldBe 3
1086+
(saleDf[0][2] as BigDecimal).compareTo(BigDecimal(100.50)) shouldBe 0
1087+
1088+
val dataSchemas1 = DataFrame.getSchemaForAllSqlTables(dbConfig, dbType = CustomDB).values.toList()
1089+
1090+
val customerDataSchema1 = dataSchemas1[0]
1091+
customerDataSchema1.columns.size shouldBe 3
1092+
customerDataSchema1.columns["name"]!!.type shouldBe typeOf<String?>()
1093+
1094+
val saleDataSchema1 = dataSchemas1[1]
1095+
saleDataSchema1.columns.size shouldBe 3
1096+
saleDataSchema1.columns["amount"]!!.type shouldBe typeOf<BigDecimal>()
1097+
}
9701098
}

gradle/libs.versions.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@ kotlinpoet = { group = "com.squareup", name = "kotlinpoet", version.ref = "kotli
111111
swagger = { group = "io.swagger.parser.v3", name = "swagger-parser", version.ref = "openapi" }
112112

113113
kotlinLogging = { group = "io.github.oshai", name = "kotlin-logging", version.ref = "kotlinLogging" }
114-
sl4j = { group = "org.slf4j", name = "slf4j-simple", version.ref = "sl4j" }
114+
sl4j = { group = "org.slf4j", name = "slf4j-api", version.ref = "sl4j" }
115+
sl4jsimple = { group = "org.slf4j", name = "slf4j-simple", version.ref = "sl4j" }
115116
android-gradle-api = { group = "com.android.tools.build", name = "gradle-api", version.ref = "android-gradle-api" }
116117
android-gradle = { group = "com.android.tools.build", name = "gradle", version.ref = "android-gradle-api" }
117118
kotlin-gradle-plugin = { group = "org.jetbrains.kotlin", name = "kotlin-gradle-plugin" }

0 commit comments

Comments
 (0)