readstat/
run.rs

1//! CLI dispatch logic for the readstat binary.
2
3use colored::Colorize;
4use crossbeam::channel::bounded;
5use indicatif::{ProgressBar, ProgressStyle};
6use log::debug;
7use path_abs::{PathAbs, PathInfo};
8use rayon::prelude::*;
9use std::path::PathBuf;
10use std::sync::Arc;
11use std::thread;
12
13use readstat::{
14    OutFormat, ProgressCallback, ReadStatData, ReadStatError, ReadStatMetadata, ReadStatPath,
15    ReadStatWriter, WriteConfig, build_offsets,
16};
17
18use crate::cli::{ReadStatCli, ReadStatCliCommands, Reader};
19
/// Default number of rows to read per streaming chunk when the user does not
/// supply an explicit stream-rows value.
const STREAM_ROWS: u32 = 10000;

/// Capacity of the bounded channel between reader and writer threads.
/// Also used as the batch size for bounded-batch parallel writes.
const CHANNEL_CAPACITY: usize = 10;
26
27/// Determine stream row count based on reader type.
28fn resolve_stream_rows(reader: Option<Reader>, stream_rows: Option<u32>, total_rows: u32) -> u32 {
29    match reader {
30        Some(Reader::Stream) | None => stream_rows.unwrap_or(STREAM_ROWS),
31        Some(Reader::Mem) => total_rows,
32    }
33}
34
/// [`ProgressCallback`] implementation backed by an `indicatif::ProgressBar`.
struct IndicatifProgress {
    // Underlying terminal progress bar; shared across reader threads via Arc
    // by the callers in this module.
    pb: ProgressBar,
}
39
40impl ProgressCallback for IndicatifProgress {
41    fn inc(&self, n: u64) {
42        self.pb.inc(n);
43    }
44
45    fn parsing_started(&self, path: &str) {
46        if let Ok(style) =
47            ProgressStyle::default_spinner().template("[{spinner:.green} {elapsed_precise}] {msg}")
48        {
49            self.pb.set_style(style);
50        }
51        self.pb
52            .set_message(format!("Parsing sas7bdat data from file {path}"));
53        self.pb
54            .enable_steady_tick(std::time::Duration::from_millis(120));
55    }
56}
57
58/// Create a progress bar if progress is enabled.
59fn create_progress(
60    no_progress: bool,
61    total_rows: u32,
62) -> Result<Option<Arc<IndicatifProgress>>, ReadStatError> {
63    if no_progress {
64        return Ok(None);
65    }
66    let pb = ProgressBar::new(u64::from(total_rows));
67    pb.set_style(
68        ProgressStyle::default_bar()
69            .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos:>7}/{len:7} rows {msg}")
70            .map_err(|e| ReadStatError::Other(format!("Progress bar template error: {e}")))?
71            .progress_chars("##-"),
72    );
73    Ok(Some(Arc::new(IndicatifProgress { pb })))
74}
75
76/// Resolve column names from `--columns` or `--columns-file` CLI options.
77fn resolve_columns(
78    columns: Option<Vec<String>>,
79    columns_file: Option<PathBuf>,
80) -> Result<Option<Vec<String>>, ReadStatError> {
81    if let Some(path) = columns_file {
82        let names = ReadStatMetadata::parse_columns_file(&path)?;
83        if names.is_empty() {
84            Ok(None)
85        } else {
86            Ok(Some(names))
87        }
88    } else {
89        Ok(columns)
90    }
91}
92
/// Resolve the SQL query from `--sql` or `--sql-file` CLI options.
///
/// `--sql-file` takes precedence: when present, the file contents become the
/// query; otherwise the inline `--sql` value (if any) is used.
#[cfg(feature = "sql")]
fn resolve_sql(
    sql: Option<String>,
    sql_file: Option<PathBuf>,
) -> Result<Option<String>, ReadStatError> {
    match sql_file {
        Some(path) => Ok(Some(readstat::read_sql_file(&path)?)),
        None => Ok(sql),
    }
}
105
/// Extract a table name from the input file stem (e.g. "cars" from "cars.sas7bdat").
///
/// Falls back to `"data"` when the path has no file stem or the stem is not
/// valid UTF-8.
#[cfg(feature = "sql")]
fn table_name_from_path(path: &std::path::Path) -> String {
    match path.file_stem().and_then(std::ffi::OsStr::to_str) {
        Some(stem) => stem.to_owned(),
        None => String::from("data"),
    }
}
114
115/// Executes the CLI command specified by the parsed [`ReadStatCli`] arguments.
116///
117/// This is the main entry point for the CLI binary, dispatching to the
118/// `metadata`, `preview`, or `data` subcommand.
119pub fn run(rs: ReadStatCli) -> Result<(), ReadStatError> {
120    env_logger::init();
121
122    match rs.command {
123        cmd @ ReadStatCliCommands::Metadata { .. } => run_metadata(cmd),
124        cmd @ ReadStatCliCommands::Preview { .. } => run_preview(cmd),
125        cmd @ ReadStatCliCommands::Data { .. } => run_data(cmd),
126    }
127}
128
129/// Handle the `metadata` subcommand: read and display SAS file metadata.
130fn run_metadata(cmd: ReadStatCliCommands) -> Result<(), ReadStatError> {
131    let ReadStatCliCommands::Metadata {
132        input: in_path,
133        as_json,
134        no_progress: _,
135        skip_row_count,
136    } = cmd
137    else {
138        unreachable!()
139    };
140    let sas_path = PathAbs::new(in_path)?.as_path().to_path_buf();
141    debug!(
142        "Retrieving metadata from the file {}",
143        &sas_path.to_string_lossy()
144    );
145
146    let rsp = ReadStatPath::new(sas_path)?;
147    let mut md = ReadStatMetadata::new();
148    md.read_metadata(&rsp, skip_row_count)?;
149    ReadStatWriter::new().write_metadata(&md, &rsp, as_json)?;
150    Ok(())
151}
152
/// Handle the `preview` subcommand: read a limited number of rows and write to stdout as CSV.
///
/// Reads at most `rows` rows from the input sas7bdat file in streaming
/// chunks, optionally filters columns and applies a SQL query, then writes
/// every resulting record batch to stdout in CSV form. Requires the `csv`
/// feature; without it an error is returned after reading.
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
fn run_preview(cmd: ReadStatCliCommands) -> Result<(), ReadStatError> {
    let ReadStatCliCommands::Preview {
        input,
        rows,
        reader,
        stream_rows,
        no_progress,
        columns,
        columns_file,
        #[cfg(feature = "sql")]
        sql,
        #[cfg(feature = "sql")]
        sql_file,
    } = cmd
    else {
        // `run` only dispatches Preview commands here.
        unreachable!()
    };

    #[cfg(feature = "sql")]
    let sql_query = resolve_sql(sql, sql_file)?;

    // Canonicalize the input path before handing it to readstat.
    let sas_path = PathAbs::new(input)?.as_path().to_path_buf();
    debug!(
        "Generating data preview from the file {}",
        &sas_path.to_string_lossy()
    );

    let rsp = ReadStatPath::new(sas_path)?;
    let mut md = ReadStatMetadata::new();
    md.read_metadata(&rsp, false)?;

    // Resolve column selection (`--columns` / `--columns-file`) and shrink
    // the metadata to just the selected columns.
    let col_names = resolve_columns(columns, columns_file)?;
    let column_filter = md.resolve_selected_columns(col_names)?;
    let original_var_count = md.var_count;
    if let Some(ref mapping) = column_filter {
        md = md.filter_to_selected_columns(mapping);
    }

    let column_filter = column_filter.map(Arc::new);
    // Never request more rows than the file actually contains.
    let total_rows_to_process = std::cmp::min(rows, md.row_count as u32);
    let total_rows_to_stream = resolve_stream_rows(reader, stream_rows, total_rows_to_process);
    let total_rows_processed = Arc::new(std::sync::atomic::AtomicUsize::new(0));
    let progress = create_progress(no_progress, total_rows_to_process)?;

    // Chunk boundaries: each adjacent pair [start, end) is one streamed read.
    let offsets = build_offsets(total_rows_to_process, total_rows_to_stream)?;
    let offsets_pairs = offsets.windows(2);

    // Shared, read-only metadata handed to each chunk read.
    let var_count = md.var_count;
    let vars_shared = Arc::new(md.vars);
    let schema_shared = Arc::new(md.schema);

    // Read all chunks into batches
    let mut all_batches: Vec<arrow_array::RecordBatch> = Vec::new();
    for w in offsets_pairs {
        let row_start = w[0];
        let row_end = w[1];

        let mut d = ReadStatData::new()
            .set_column_filter(column_filter.clone(), original_var_count)
            .set_no_progress(no_progress)
            .set_total_rows_to_process(total_rows_to_process as usize)
            .set_total_rows_processed(total_rows_processed.clone())
            .init_shared(
                var_count,
                vars_shared.clone(),
                schema_shared.clone(),
                row_start,
                row_end,
            );

        if let Some(ref p) = progress {
            d = d.set_progress(p.clone() as Arc<dyn ProgressCallback>);
        }

        d.read_data(&rsp)?;

        if let Some(batch) = d.batch {
            all_batches.push(batch);
        }
    }

    if let Some(p) = progress {
        p.pb.finish_with_message("Done");
    }

    // Apply SQL query if provided
    #[cfg(feature = "sql")]
    let all_batches = if let Some(ref query) = sql_query {
        // Table name defaults to the input file stem (e.g. "cars").
        let table_name = table_name_from_path(&rsp.path);
        readstat::execute_sql(all_batches, schema_shared.clone(), &table_name, query)?
    } else {
        all_batches
    };

    // Write all batches to stdout as CSV
    #[cfg(feature = "csv")]
    {
        let stdout = std::io::stdout();
        let mut csv_writer = arrow_csv::WriterBuilder::new()
            .with_header(true)
            .build(stdout);
        for batch in &all_batches {
            csv_writer.write(batch)?;
        }
    }
    #[cfg(not(feature = "csv"))]
    {
        let _ = all_batches;
        return Err(ReadStatError::Other(
            "CSV feature is required for preview output".to_string(),
        ));
    }
    #[cfg(feature = "csv")]
    Ok(())
}
271
/// Handle the `data` subcommand: read SAS data and write to an output file.
///
/// Reads the input sas7bdat in streaming chunks on a dedicated reader thread
/// (optionally parallelized via a rayon pool) and sends each chunk over a
/// bounded channel to this thread, which writes it in one of three modes:
/// SQL post-processing (`--sql`/`--sql-file`), bounded-batch parallel parquet
/// writes via temp files, or sequential writes (the default). With no
/// `--output` path, only metadata is displayed.
#[allow(
    clippy::too_many_lines,
    clippy::cast_sign_loss,
    clippy::cast_possible_truncation
)]
fn run_data(cmd: ReadStatCliCommands) -> Result<(), ReadStatError> {
    let ReadStatCliCommands::Data {
        input,
        output,
        format,
        rows,
        reader,
        stream_rows,
        no_progress,
        overwrite,
        parallel,
        parallel_write,
        #[cfg(feature = "parquet")]
        parallel_write_buffer_mb,
        #[cfg(not(feature = "parquet"))]
            parallel_write_buffer_mb: _,
        compression,
        compression_level,
        columns,
        columns_file,
        #[cfg(feature = "sql")]
        sql,
        #[cfg(feature = "sql")]
        sql_file,
    } = cmd
    else {
        // `run` only dispatches Data commands here.
        unreachable!()
    };

    #[cfg(feature = "sql")]
    let sql_query = resolve_sql(sql, sql_file)?;

    // Canonicalize the input path before handing it to readstat.
    let sas_path = PathAbs::new(input)?.as_path().to_path_buf();
    debug!(
        "Generating data from the file {}",
        &sas_path.to_string_lossy()
    );

    let rsp = ReadStatPath::new(sas_path)?;
    let wc = WriteConfig::new(
        output,
        format.map(Into::into),
        overwrite,
        compression.map(Into::into),
        compression_level,
    )?;

    let mut md = ReadStatMetadata::new();
    md.read_metadata(&rsp, false)?;

    // Resolve column selection (`--columns` / `--columns-file`) and shrink
    // the metadata to just the selected columns.
    let col_names = resolve_columns(columns, columns_file)?;
    let column_filter = md.resolve_selected_columns(col_names)?;
    let original_var_count = md.var_count;
    if let Some(ref mapping) = column_filter {
        md = md.filter_to_selected_columns(mapping);
    }

    let column_filter = column_filter.map(Arc::new);

    // If no output path then only read metadata; otherwise read data
    match &wc.out_path {
        None => {
            println!(
                "{}: a value was not provided for the parameter {}, thus displaying metadata only\n",
                "Warning".bright_yellow(),
                "--output".bright_cyan()
            );

            let mut md = ReadStatMetadata::new();
            md.read_metadata(&rsp, false)?;
            ReadStatWriter::new().write_metadata(&md, &rsp, false)?;
            Ok(())
        }
        Some(p) => {
            println!(
                "Writing parsed data to file {}",
                p.to_string_lossy().bright_yellow()
            );

            // Determine row count
            let total_rows_to_process = if let Some(r) = rows {
                std::cmp::min(r, md.row_count as u32)
            } else {
                md.row_count as u32
            };

            let total_rows_to_stream =
                resolve_stream_rows(reader, stream_rows, total_rows_to_process);
            let total_rows_processed = Arc::new(std::sync::atomic::AtomicUsize::new(0));
            let progress = create_progress(no_progress, total_rows_to_process)?;

            // Chunk boundaries: each adjacent pair [start, end) is one read.
            let offsets = build_offsets(total_rows_to_process, total_rows_to_stream)?;

            // Parallel writes are only supported for parquet output.
            let use_parallel_writes =
                parallel && parallel_write && matches!(wc.format, OutFormat::Parquet);

            let input_path = rsp.path.clone();

            #[cfg(feature = "parquet")]
            let out_path_clone = wc.out_path.clone();
            #[cfg(feature = "parquet")]
            let compression_clone = wc.compression;
            #[cfg(feature = "parquet")]
            let compression_level_clone = wc.compression_level;
            #[cfg(feature = "parquet")]
            let buffer_size_bytes = parallel_write_buffer_mb * 1024 * 1024;

            // Shared, read-only metadata handed to each chunk read.
            let var_count = md.var_count;
            let vars_shared = Arc::new(md.vars);
            let schema_shared = Arc::new(md.schema);

            #[cfg(feature = "sql")]
            let sql_schema = schema_shared.clone();
            #[cfg(feature = "sql")]
            let sql_table_name = table_name_from_path(&rsp.path);
            #[cfg(feature = "sql")]
            let sql_format = wc.format;

            // Bounded channel applies backpressure so the reader cannot run
            // arbitrarily far ahead of the writer.
            let (s, r) = bounded(CHANNEL_CAPACITY);
            let progress_thread = progress.clone();
            let wc_thread = wc.clone();

            // Process data in batches (i.e. stream chunks of rows)
            let reader_handle = thread::spawn(move || -> Result<(), ReadStatError> {
                let offsets_pairs: Vec<_> = offsets.windows(2).collect();
                let pairs_cnt = offsets_pairs.len();

                // rayon treats a requested thread count of 0 as "use the
                // default pool size", so `parallel` maps to 0 (full pool)
                // while `!parallel` maps to 1 (single reader thread).
                let num_threads = usize::from(!parallel);
                let pool = rayon::ThreadPoolBuilder::new()
                    .num_threads(num_threads)
                    .build()
                    .map_err(|e| {
                        ReadStatError::Other(format!("Failed to build thread pool: {e}"))
                    })?;

                let results: Vec<Result<(ReadStatData, WriteConfig, usize), ReadStatError>> = pool
                    .install(|| {
                        offsets_pairs
                            .par_iter()
                            .map(
                                |w| -> Result<(ReadStatData, WriteConfig, usize), ReadStatError> {
                                    let row_start = w[0];
                                    let row_end = w[1];

                                    let mut d = ReadStatData::new()
                                        .set_column_filter(
                                            column_filter.clone(),
                                            original_var_count,
                                        )
                                        .set_no_progress(no_progress)
                                        .set_total_rows_to_process(total_rows_to_process as usize)
                                        .set_total_rows_processed(total_rows_processed.clone())
                                        .init_shared(
                                            var_count,
                                            vars_shared.clone(),
                                            schema_shared.clone(),
                                            row_start,
                                            row_end,
                                        );

                                    if let Some(ref p) = progress_thread {
                                        d = d.set_progress(p.clone() as Arc<dyn ProgressCallback>);
                                    }

                                    d.read_data(&rsp)?;

                                    Ok((d, wc_thread.clone(), pairs_cnt))
                                },
                            )
                            .collect()
                    });

                // Forward successful chunks to the writer; collect failures.
                let mut errors = Vec::new();
                for result in results {
                    match result {
                        Ok(data) => {
                            if s.send(data).is_err() {
                                errors.push(ReadStatError::Other(
                                    "Error when attempting to send read data for writing"
                                        .to_string(),
                                ));
                            }
                        }
                        Err(e) => errors.push(e),
                    }
                }

                // Close the channel so the writer's receive loops terminate.
                drop(s);

                // NOTE(review): read errors are reported to stderr but the
                // thread still returns Ok(()), so they do not affect the
                // process exit status — confirm this is intended.
                if !errors.is_empty() {
                    eprintln!("The following errors occurred when processing data:");
                    for e in &errors {
                        eprintln!("    Error: {e:#?}");
                    }
                }

                Ok(())
            });

            // Write

            #[cfg(feature = "sql")]
            let has_sql = sql_query.is_some();
            #[cfg(not(feature = "sql"))]
            let has_sql = false;

            if has_sql {
                // SQL mode: buffer every batch in memory, run the query, then
                // write (or discard) the result set.
                #[cfg(feature = "sql")]
                {
                    let query = sql_query
                        .as_ref()
                        .expect("sql_query must be set when has_sql is true");
                    if let Some(out_path) = &out_path_clone {
                        let mut all_batches = Vec::new();
                        for (d, _wc, _) in r.iter() {
                            if let Some(batch) = d.batch {
                                all_batches.push(batch);
                            }
                        }
                        let results =
                            readstat::execute_sql(all_batches, sql_schema, &sql_table_name, query)?;
                        readstat::write_sql_results(
                            &results,
                            out_path,
                            sql_format,
                            compression_clone,
                            compression_level_clone,
                        )?;
                    } else {
                        // No output path: execute the query but drop the
                        // results (the channel must still be drained).
                        let mut all_batches = Vec::new();
                        for (d, _wc, _) in r.iter() {
                            if let Some(batch) = d.batch {
                                all_batches.push(batch);
                            }
                        }
                        let _results =
                            readstat::execute_sql(all_batches, sql_schema, &sql_table_name, query)?;
                    }
                }
            } else if use_parallel_writes {
                #[cfg(feature = "parquet")]
                {
                    // Temp files go next to the final output, falling back to
                    // the current directory when the parent is unavailable.
                    let temp_dir = if let Some(out_path) = &out_path_clone {
                        match out_path.parent() {
                            Ok(parent) => parent.to_path_buf(),
                            Err(_) => std::env::current_dir()?,
                        }
                    } else {
                        return Err(ReadStatError::Other(
                            "No output path specified for parallel write".to_string(),
                        ));
                    };

                    let mut all_temp_files: Vec<PathBuf> = Vec::new();
                    let mut schema: Option<Arc<arrow_schema::Schema>> = None;
                    let mut batch_idx: usize = 0;

                    loop {
                        // Pull at most CHANNEL_CAPACITY chunks per iteration,
                        // bounding peak memory while still batching the
                        // parallel temp-file writes.
                        let mut batch_group: Vec<(ReadStatData, WriteConfig, usize)> =
                            Vec::with_capacity(CHANNEL_CAPACITY);
                        for item in &r {
                            batch_group.push(item);
                            if batch_group.len() >= CHANNEL_CAPACITY {
                                break;
                            }
                        }

                        if batch_group.is_empty() {
                            break;
                        }

                        // Capture the schema from the first chunk seen.
                        if schema.is_none() {
                            schema = Some(batch_group[0].0.schema.clone());
                        }
                        let schema_ref = schema
                            .as_ref()
                            .expect("schema must be set after first batch group");

                        // Write each chunk of this group to its own temp
                        // parquet file in parallel; names stay globally
                        // ordered via batch_idx.
                        let temp_files: Vec<PathBuf> = batch_group
                            .par_iter()
                            .enumerate()
                            .map(|(i, (d, _wc, _))| -> Result<PathBuf, ReadStatError> {
                                let temp_file = temp_dir
                                    .join(format!(".readstat_temp_{}.parquet", batch_idx + i));

                                if let Some(batch) = &d.batch {
                                    ReadStatWriter::write_batch_to_parquet(
                                        batch,
                                        schema_ref,
                                        &temp_file,
                                        compression_clone,
                                        compression_level_clone,
                                        buffer_size_bytes as usize,
                                    )?;
                                }

                                Ok(temp_file)
                            })
                            .collect::<Result<Vec<_>, _>>()?;

                        batch_idx += batch_group.len();
                        // batch_group is implicitly dropped here at the end of the loop body,
                        // freeing ReadStatData/RecordBatch memory before the next iteration
                        all_temp_files.extend(temp_files);
                    }

                    // Merge all temp files into final output
                    if !all_temp_files.is_empty()
                        && let Some(out_path) = &out_path_clone
                    {
                        ReadStatWriter::merge_parquet_files(
                            &all_temp_files,
                            out_path,
                            schema
                                .as_ref()
                                .expect("schema must be set when temp files exist"),
                            compression_clone,
                            compression_level_clone,
                        )?;
                    }
                }
                #[cfg(not(feature = "parquet"))]
                {
                    return Err(ReadStatError::Other(
                        "Parallel writes require the parquet feature".to_string(),
                    ));
                }
            } else {
                // Sequential write mode (default) with BufWriter optimizations
                let mut wtr = ReadStatWriter::new();

                // d (ReadStatData) is implicitly dropped at each iteration boundary,
                // preventing accumulation of RecordBatch memory across chunks
                for (i, (d, wc, pairs_cnt)) in r.iter().enumerate() {
                    wtr.write(&d, &wc)?;

                    // The final chunk triggers writer finalization (flush /
                    // footer / summary).
                    if i == (pairs_cnt - 1) {
                        wtr.finish(&d, &wc, &input_path)?;
                    }
                }
            }

            if let Some(p) = progress {
                p.pb.finish_with_message("Done");
            }

            // Surface reader-thread failures after the writer has drained.
            match reader_handle.join() {
                Ok(Ok(())) => {}
                Ok(Err(e)) => return Err(e),
                Err(_) => {
                    return Err(ReadStatError::Other("Reader thread panicked".to_string()));
                }
            }

            Ok(())
        }
    }
}