99 "sort"
1010 "strconv"
1111 "strings"
12+ "sync"
1213 "time"
1314
1415 "github.com/sirupsen/logrus"
@@ -23,6 +24,8 @@ type CSVSource struct {
2324 currentFile string // Current file being processed
2425 columns []string // Column names from CSV header
2526 totalRows int // Total rows across all files
27+ totalRowsOnce sync.Once // Ensures totalRows is computed only once
28+ totalRowsErr error // Error from computing totalRows
2629}
2730
2831func NewCSVSource (cfg * config.Config ) (* CSVSource , error ) {
@@ -92,33 +95,47 @@ func (s *CSVSource) AdjustBatchSizeAccordingToSourceDbTable() uint64 {
9295}
9396
9497// GetSourceReadRowsCount returns the total number of rows in all CSV files
98+ // This method is thread-safe and caches the result after first call
9599func (s * CSVSource ) GetSourceReadRowsCount () (int , error ) {
96- if s . totalRows > 0 {
97- return s . totalRows , nil
98- }
99-
100- totalRows := 0
101- for _ , file := range s . files {
102- count , err := countCSVRows ( file )
103- if err != nil {
104- return 0 , fmt . Errorf ( "failed to count rows in %s: %w" , file , err )
100+ s . totalRowsOnce . Do ( func () {
101+ totalRows := 0
102+ for _ , file := range s . files {
103+ count , err := s . countCSVRows ( file )
104+ if err != nil {
105+ s . totalRowsErr = fmt . Errorf ( "failed to count rows in %s: %w" , file , err )
106+ return
107+ }
108+ totalRows += count
105109 }
106- totalRows += count
107- }
110+ s . totalRows = totalRows
111+ })
108112
109- s .totalRows = totalRows
110- return totalRows , nil
113+ if s .totalRowsErr != nil {
114+ return 0 , s .totalRowsErr
115+ }
116+ return s .totalRows , nil
111117}
112118
113119// countCSVRows counts the number of data rows in a CSV file (excluding header)
114- func countCSVRows (filename string ) (int , error ) {
120+ func ( s * CSVSource ) countCSVRows (filename string ) (int , error ) {
115121 file , err := os .Open (filename )
116122 if err != nil {
117123 return 0 , err
118124 }
119125 defer file .Close ()
120126
121127 reader := csv .NewReader (file )
128+
129+ // Configure CSV delimiter if specified
130+ if s .cfg .SourceCSVDelimiter != "" {
131+ delimiter := s .cfg .SourceCSVDelimiter
132+ if delimiter == "tab" || delimiter == "\\ t" {
133+ reader .Comma = '\t'
134+ } else if len (delimiter ) > 0 {
135+ reader .Comma = rune (delimiter [0 ])
136+ }
137+ }
138+
122139 count := 0
123140
124141 // Skip header
@@ -209,7 +226,18 @@ func (s *CSVSource) QueryTableData(threadNum int, conditionSql string) ([][]inte
209226 }
210227
211228 if len (columns ) == 0 {
229+ // Use the first non-empty file's columns as the reference schema
212230 columns = cols
231+ } else if len (cols ) > 0 {
232+ // Validate that subsequent files have the same columns
233+ if len (cols ) != len (columns ) {
234+ return nil , nil , fmt .Errorf ("header mismatch in file %s: expected %d columns, got %d columns" , file , len (columns ), len (cols ))
235+ }
236+ for i := range columns {
237+ if columns [i ] != cols [i ] {
238+ return nil , nil , fmt .Errorf ("header mismatch in file %s at column %d: expected %q, got %q" , file , i , columns [i ], cols [i ])
239+ }
240+ }
213241 }
214242
215243 allData = append (allData , data ... )
@@ -229,6 +257,7 @@ func (s *CSVSource) QueryTableData(threadNum int, conditionSql string) ([][]inte
229257}
230258
231259// readCSVFile reads a specific range of rows from a CSV file
260+ // Optimized to skip rows before startRow to improve performance for parallel processing
232261func (s * CSVSource ) readCSVFile (filename string , startRow , endRow , currentRow uint64 ) ([][]interface {}, []string , uint64 , error ) {
233262 file , err := os .Open (filename )
234263 if err != nil {
@@ -238,6 +267,16 @@ func (s *CSVSource) readCSVFile(filename string, startRow, endRow, currentRow ui
238267
239268 reader := csv .NewReader (file )
240269
270+ // Configure CSV delimiter if specified
271+ if s .cfg .SourceCSVDelimiter != "" {
272+ delimiter := s .cfg .SourceCSVDelimiter
273+ if delimiter == "tab" || delimiter == "\\ t" {
274+ reader .Comma = '\t'
275+ } else if len (delimiter ) > 0 {
276+ reader .Comma = rune (delimiter [0 ])
277+ }
278+ }
279+
241280 // Read header
242281 header , err := reader .Read ()
243282 if err != nil {
@@ -247,8 +286,21 @@ func (s *CSVSource) readCSVFile(filename string, startRow, endRow, currentRow ui
247286 var data [][]interface {}
248287 rowNum := currentRow
249288
250- // Read all rows
251- for {
289+ // Skip rows before startRow for better performance
290+ for rowNum < startRow {
291+ _ , err := reader .Read ()
292+ if err == io .EOF {
293+ // Reached end of file before startRow
294+ return data , header , rowNum - 1 , nil
295+ }
296+ if err != nil {
297+ return nil , nil , 0 , fmt .Errorf ("failed to skip row: %w" , err )
298+ }
299+ rowNum ++
300+ }
301+
302+ // Read rows in the desired range
303+ for rowNum < endRow {
252304 record , err := reader .Read ()
253305 if err == io .EOF {
254306 break
@@ -257,29 +309,27 @@ func (s *CSVSource) readCSVFile(filename string, startRow, endRow, currentRow ui
257309 return nil , nil , 0 , fmt .Errorf ("failed to read row: %w" , err )
258310 }
259311
260- // Check if this row is in the desired range
261- if rowNum >= startRow && rowNum < endRow {
262- // Convert string values to interface{}
263- row := make ([]interface {}, len (record ))
264- for i , val := range record {
265- row [i ] = convertCSVValue (val )
266- }
267- data = append (data , row )
312+ // Convert string values to interface{}
313+ row := make ([]interface {}, len (record ))
314+ for i , val := range record {
315+ row [i ] = convertCSVValue (val )
268316 }
269-
317+ data = append ( data , row )
270318 rowNum ++
271-
272- // If we've passed the end row, stop reading this file
273- if rowNum >= endRow {
274- break
275- }
276319 }
277320
278321 return data , header , rowNum - 1 , nil
279322}
280323
281324// convertCSVValue attempts to convert CSV string values to appropriate types
325+ // Empty strings are returned as empty strings (not nil) to maintain consistency
282326func convertCSVValue (val string ) interface {} {
327+ // Handle empty strings - return as-is
328+ // Note: Empty CSV cells will be imported as empty strings in Databend
329+ if val == "" {
330+ return val
331+ }
332+
283333 // Try to parse as integer
284334 if intVal , err := strconv .ParseInt (val , 10 , 64 ); err == nil {
285335 return intVal
@@ -303,7 +353,8 @@ func convertCSVValue(val string) interface{} {
303353}
304354
305355// parseRowCondition parses a condition like "(row_num >= 1 and row_num < 1001)"
306- // and returns the start and end row numbers
356+ // Supports both >= and > for start condition, and both < and <= for end condition
357+ // Returns the start and end row numbers
307358func parseRowCondition (condition string ) (uint64 , uint64 , error ) {
308359 // Remove parentheses and split by "and"
309360 condition = strings .Trim (condition , "()" )
@@ -316,7 +367,7 @@ func parseRowCondition(condition string) (uint64, uint64, error) {
316367 var startRow , endRow uint64
317368 var err error
318369
319- // Parse first part (e.g., "row_num >= 1")
370+ // Parse first part (e.g., "row_num >= 1" or "row_num > 0" )
320371 if strings .Contains (parts [0 ], ">=" ) {
321372 fields := strings .Split (parts [0 ], ">=" )
322373 if len (fields ) != 2 {
@@ -326,9 +377,21 @@ func parseRowCondition(condition string) (uint64, uint64, error) {
326377 if err != nil {
327378 return 0 , 0 , fmt .Errorf ("failed to parse start row: %w" , err )
328379 }
380+ } else if strings .Contains (parts [0 ], ">" ) {
381+ fields := strings .Split (parts [0 ], ">" )
382+ if len (fields ) != 2 {
383+ return 0 , 0 , fmt .Errorf ("invalid start condition: %s" , parts [0 ])
384+ }
385+ startRow , err = strconv .ParseUint (strings .TrimSpace (fields [1 ]), 10 , 64 )
386+ if err != nil {
387+ return 0 , 0 , fmt .Errorf ("failed to parse start row: %w" , err )
388+ }
389+ startRow ++ // Convert > to >=
390+ } else {
391+ return 0 , 0 , fmt .Errorf ("invalid start condition (missing >= or >): %s" , parts [0 ])
329392 }
330393
331- // Parse second part (e.g., "row_num < 1001")
394+ // Parse second part (e.g., "row_num < 1001" or "row_num <= 1000" )
332395 if strings .Contains (parts [1 ], "<=" ) {
333396 fields := strings .Split (parts [1 ], "<=" )
334397 if len (fields ) != 2 {
@@ -348,6 +411,8 @@ func parseRowCondition(condition string) (uint64, uint64, error) {
348411 if err != nil {
349412 return 0 , 0 , fmt .Errorf ("failed to parse end row: %w" , err )
350413 }
414+ } else {
415+ return 0 , 0 , fmt .Errorf ("invalid end condition (missing < or <=): %s" , parts [1 ])
351416 }
352417
353418 return startRow , endRow , nil
0 commit comments