@@ -166,6 +166,8 @@ class JdbcTest {
166
166
val dataSchema = DataFrame .getSchemaForSqlTable(connection, tableName)
167
167
dataSchema.columns.size shouldBe 2
168
168
dataSchema.columns[" characterCol" ]!! .type shouldBe typeOf<String ?>()
169
+
170
+ connection.createStatement().execute(" DROP TABLE EmptyTestTable" )
169
171
}
170
172
171
173
@Test
@@ -299,6 +301,8 @@ class JdbcTest {
299
301
schema.columns[" realCol" ]!! .type shouldBe typeOf<Float ?>()
300
302
schema.columns[" doublePrecisionCol" ]!! .type shouldBe typeOf<Double ?>()
301
303
schema.columns[" decFloatCol" ]!! .type shouldBe typeOf<BigDecimal ?>()
304
+
305
+ connection.createStatement().execute(" DROP TABLE $tableName " )
302
306
}
303
307
304
308
@Test
@@ -441,7 +445,7 @@ class JdbcTest {
441
445
442
446
rs.beforeFirst()
443
447
444
- val dataSchema1 = DataFrame .getSchemaForResultSet(rs, connection )
448
+ val dataSchema1 = DataFrame .getSchemaForResultSet(rs, H2 ( MySql ) )
445
449
dataSchema1.columns.size shouldBe 3
446
450
dataSchema1.columns[" name" ]!! .type shouldBe typeOf<String ?>()
447
451
}
@@ -493,7 +497,7 @@ class JdbcTest {
493
497
494
498
rs.beforeFirst()
495
499
496
- val dataSchema1 = rs.getDataFrameSchema(connection )
500
+ val dataSchema1 = rs.getDataFrameSchema(H2 ( MySql ) )
497
501
dataSchema1.columns.size shouldBe 3
498
502
dataSchema1.columns[" name" ]!! .type shouldBe typeOf<String ?>()
499
503
}
@@ -613,6 +617,7 @@ class JdbcTest {
613
617
"""
614
618
615
619
DataFrame .readSqlQuery(connection, selectFromWeirdTableSQL).rowsCount() shouldBe 0
620
+ connection.createStatement().execute(" DROP TABLE \" ALTER\" " )
616
621
}
617
622
618
623
@Test
@@ -967,4 +972,127 @@ class JdbcTest {
967
972
}
968
973
exception.message shouldBe " H2 database could not be specified with H2 dialect!"
969
974
}
975
+
976
+ // helper object created for API testing purposes
977
+ object CustomDB : H2(MySql )
978
+
979
+ @Test
980
+ fun `read from table from custom database` () {
981
+ val tableName = " Customer"
982
+ val df = DataFrame .readSqlTable(connection, tableName, dbType = CustomDB ).cast<Customer >()
983
+
984
+ df.rowsCount() shouldBe 4
985
+ df.filter { it[Customer ::age] != null && it[Customer ::age]!! > 30 }.rowsCount() shouldBe 2
986
+ df[0 ][1 ] shouldBe " John"
987
+
988
+ val dataSchema = DataFrame .getSchemaForSqlTable(connection, tableName, dbType = CustomDB )
989
+ dataSchema.columns.size shouldBe 3
990
+ dataSchema.columns[" name" ]!! .type shouldBe typeOf<String ?>()
991
+
992
+ val dbConfig = DbConnectionConfig (url = URL )
993
+ val df2 = DataFrame .readSqlTable(dbConfig, tableName, dbType = CustomDB ).cast<Customer >()
994
+
995
+ df2.rowsCount() shouldBe 4
996
+ df2.filter { it[Customer ::age] != null && it[Customer ::age]!! > 30 }.rowsCount() shouldBe 2
997
+ df2[0 ][1 ] shouldBe " John"
998
+
999
+ val dataSchema1 = DataFrame .getSchemaForSqlTable(dbConfig, tableName, dbType = CustomDB )
1000
+ dataSchema1.columns.size shouldBe 3
1001
+ dataSchema1.columns[" name" ]!! .type shouldBe typeOf<String ?>()
1002
+ }
1003
+
1004
+ @Test
1005
+ fun `read from query from custom database` () {
1006
+ @Language(" SQL" )
1007
+ val sqlQuery =
1008
+ """
1009
+ SELECT c.name as customerName, SUM(s.amount) as totalSalesAmount
1010
+ FROM Sale s
1011
+ INNER JOIN Customer c ON s.customerId = c.id
1012
+ WHERE c.age > 35
1013
+ GROUP BY s.customerId, c.name
1014
+ """ .trimIndent()
1015
+
1016
+ val df = DataFrame .readSqlQuery(connection, sqlQuery, dbType = CustomDB ).cast<CustomerSales >()
1017
+
1018
+ df.rowsCount() shouldBe 2
1019
+ df.filter { it[CustomerSales ::totalSalesAmount]!! > 100 }.rowsCount() shouldBe 1
1020
+ df[0 ][0 ] shouldBe " John"
1021
+
1022
+ val dataSchema = DataFrame .getSchemaForSqlQuery(connection, sqlQuery, dbType = CustomDB )
1023
+ dataSchema.columns.size shouldBe 2
1024
+ dataSchema.columns[" name" ]!! .type shouldBe typeOf<String ?>()
1025
+
1026
+ val dbConfig = DbConnectionConfig (url = URL )
1027
+ val df2 = DataFrame .readSqlQuery(dbConfig, sqlQuery, dbType = CustomDB ).cast<CustomerSales >()
1028
+
1029
+ df2.rowsCount() shouldBe 2
1030
+ df2.filter { it[CustomerSales ::totalSalesAmount]!! > 100 }.rowsCount() shouldBe 1
1031
+ df2[0 ][0 ] shouldBe " John"
1032
+
1033
+ val dataSchema1 = DataFrame .getSchemaForSqlQuery(dbConfig, sqlQuery, dbType = CustomDB )
1034
+ dataSchema1.columns.size shouldBe 2
1035
+ dataSchema1.columns[" name" ]!! .type shouldBe typeOf<String ?>()
1036
+ }
1037
+
1038
+ @Test
1039
+ fun `read from all tables from custom database` () {
1040
+ val dataFrameMap = DataFrame .readAllSqlTables(connection, dbType = CustomDB )
1041
+ dataFrameMap.containsKey(" Customer" ) shouldBe true
1042
+ dataFrameMap.containsKey(" Sale" ) shouldBe true
1043
+
1044
+ val dataframes = dataFrameMap.values.toList()
1045
+
1046
+ val customerDf = dataframes[0 ].cast<Customer >()
1047
+
1048
+ customerDf.rowsCount() shouldBe 4
1049
+ customerDf.filter { it[Customer ::age] != null && it[Customer ::age]!! > 30 }.rowsCount() shouldBe 2
1050
+ customerDf[0 ][1 ] shouldBe " John"
1051
+
1052
+ val saleDf = dataframes[1 ].cast<Sale >()
1053
+
1054
+ saleDf.rowsCount() shouldBe 4
1055
+ saleDf.filter { it[Sale ::amount] > 40 }.rowsCount() shouldBe 3
1056
+ (saleDf[0 ][2 ] as BigDecimal ).compareTo(BigDecimal (100.50 )) shouldBe 0
1057
+
1058
+ val dataFrameSchemaMap = DataFrame .getSchemaForAllSqlTables(connection, dbType = CustomDB )
1059
+ dataFrameSchemaMap.containsKey(" Customer" ) shouldBe true
1060
+ dataFrameSchemaMap.containsKey(" Sale" ) shouldBe true
1061
+
1062
+ val dataSchemas = dataFrameSchemaMap.values.toList()
1063
+
1064
+ val customerDataSchema = dataSchemas[0 ]
1065
+ customerDataSchema.columns.size shouldBe 3
1066
+ customerDataSchema.columns[" name" ]!! .type shouldBe typeOf<String ?>()
1067
+
1068
+ val saleDataSchema = dataSchemas[1 ]
1069
+ saleDataSchema.columns.size shouldBe 3
1070
+ // TODO: fix nullability
1071
+ saleDataSchema.columns[" amount" ]!! .type shouldBe typeOf<BigDecimal >()
1072
+
1073
+ val dbConfig = DbConnectionConfig (url = URL )
1074
+ val dataframes2 = DataFrame .readAllSqlTables(dbConfig, dbType = CustomDB ).values.toList()
1075
+
1076
+ val customerDf2 = dataframes2[0 ].cast<Customer >()
1077
+
1078
+ customerDf2.rowsCount() shouldBe 4
1079
+ customerDf2.filter { it[Customer ::age] != null && it[Customer ::age]!! > 30 }.rowsCount() shouldBe 2
1080
+ customerDf2[0 ][1 ] shouldBe " John"
1081
+
1082
+ val saleDf2 = dataframes2[1 ].cast<Sale >()
1083
+
1084
+ saleDf2.rowsCount() shouldBe 4
1085
+ saleDf2.filter { it[Sale ::amount] > 40 }.rowsCount() shouldBe 3
1086
+ (saleDf[0 ][2 ] as BigDecimal ).compareTo(BigDecimal (100.50 )) shouldBe 0
1087
+
1088
+ val dataSchemas1 = DataFrame .getSchemaForAllSqlTables(dbConfig, dbType = CustomDB ).values.toList()
1089
+
1090
+ val customerDataSchema1 = dataSchemas1[0 ]
1091
+ customerDataSchema1.columns.size shouldBe 3
1092
+ customerDataSchema1.columns[" name" ]!! .type shouldBe typeOf<String ?>()
1093
+
1094
+ val saleDataSchema1 = dataSchemas1[1 ]
1095
+ saleDataSchema1.columns.size shouldBe 3
1096
+ saleDataSchema1.columns[" amount" ]!! .type shouldBe typeOf<BigDecimal >()
1097
+ }
970
1098
}
0 commit comments