Skip to main content

readstat/
rs_write_config.rs

1//! Output configuration for writing Arrow data to various formats.
2//!
//! [`WriteConfig`] captures the output file path, format, compression settings,
//! and overwrite behavior, decoupled from input path validation
//! ([`ReadStatPath`](crate::ReadStatPath)).
5
6use std::path::{Path, PathBuf};
7
8#[cfg(feature = "parquet")]
9use parquet::basic::{BrotliLevel, Compression as ParquetCompressionCodec, GzipLevel, ZstdLevel};
10
11use log::warn;
12
13use crate::err::ReadStatError;
14
/// Output file format for data conversion.
///
/// All variants are always present regardless of which writer features are
/// enabled. Attempting to *write* a format whose feature is disabled returns a
/// runtime [`ReadStatError`](crate::ReadStatError) from the writer rather than
/// failing to compile.
///
/// Derives `PartialEq`, `Eq`, and `Hash` (like [`ParquetCompression`]) so
/// formats can be compared directly and used as map/set keys.
///
/// This enum is `#[non_exhaustive]`: new format variants may be added in
/// minor releases. Match with a wildcard arm to remain forward-compatible.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OutFormat {
    /// Comma-separated values.
    Csv,
    /// Feather (Arrow IPC) format.
    Feather,
    /// Newline-delimited JSON.
    Ndjson,
    /// Apache Parquet columnar format.
    Parquet,
}
36
37impl std::fmt::Display for OutFormat {
38    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39        match self {
40            Self::Csv => f.write_str("csv"),
41            Self::Feather => f.write_str("feather"),
42            Self::Ndjson => f.write_str("ndjson"),
43            Self::Parquet => f.write_str("parquet"),
44        }
45    }
46}
47
48impl std::str::FromStr for OutFormat {
49    type Err = ReadStatError;
50
51    /// Parses a format name (case-insensitive) into an [`OutFormat`].
52    ///
53    /// Accepted values: `"csv"`, `"feather"`, `"ndjson"`, `"parquet"`.
54    ///
55    /// # Errors
56    ///
57    /// Returns [`ReadStatError::UnknownFormat`] for unrecognized format strings.
58    fn from_str(s: &str) -> Result<Self, Self::Err> {
59        match s.to_lowercase().as_str() {
60            "csv" => Ok(Self::Csv),
61            "feather" => Ok(Self::Feather),
62            "ndjson" => Ok(Self::Ndjson),
63            "parquet" => Ok(Self::Parquet),
64            _ => Err(ReadStatError::UnknownFormat(s.to_string())),
65        }
66    }
67}
68
/// Parquet compression algorithm.
///
/// Codecs documented with a level range accept an optional compression level
/// within that (inclusive) range; the others take no level.
///
/// This enum is `#[non_exhaustive]`: new codec variants may be added in
/// minor releases. Match with a wildcard arm to remain forward-compatible.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ParquetCompression {
    /// No compression.
    Uncompressed,
    /// Snappy compression (fast, moderate ratio).
    Snappy,
    /// Gzip compression (levels 0-9).
    Gzip,
    /// LZ4 raw compression.
    Lz4Raw,
    /// Brotli compression (levels 0-11).
    Brotli,
    /// Zstandard compression (levels 1-22; the parquet codec rejects level 0).
    Zstd,
}
89
90impl std::fmt::Display for ParquetCompression {
91    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92        match self {
93            Self::Uncompressed => f.write_str("uncompressed"),
94            Self::Snappy => f.write_str("snappy"),
95            Self::Gzip => f.write_str("gzip"),
96            Self::Lz4Raw => f.write_str("lz4-raw"),
97            Self::Brotli => f.write_str("brotli"),
98            Self::Zstd => f.write_str("zstd"),
99        }
100    }
101}
102
103impl std::str::FromStr for ParquetCompression {
104    type Err = ReadStatError;
105
106    /// Parses a codec name (case-insensitive) into a [`ParquetCompression`].
107    ///
108    /// Accepted values: `"uncompressed"`, `"snappy"`, `"gzip"`, `"lz4-raw"`
109    /// (or `"lz4raw"`), `"brotli"`, `"zstd"`.
110    ///
111    /// # Errors
112    ///
113    /// Returns [`ReadStatError::UnknownFormat`] for unrecognized codec names.
114    fn from_str(s: &str) -> Result<Self, Self::Err> {
115        match s.to_lowercase().as_str() {
116            "uncompressed" => Ok(Self::Uncompressed),
117            "snappy" => Ok(Self::Snappy),
118            "gzip" => Ok(Self::Gzip),
119            "lz4-raw" | "lz4raw" => Ok(Self::Lz4Raw),
120            "brotli" => Ok(Self::Brotli),
121            "zstd" => Ok(Self::Zstd),
122            _ => Err(ReadStatError::UnknownFormat(s.to_string())),
123        }
124    }
125}
126
127/// Output configuration for writing Arrow data.
128///
129/// Captures the output file path, format, compression settings, and overwrite
130/// behavior. Created separately from [`ReadStatPath`](crate::ReadStatPath),
131/// which handles only input path validation.
132///
133/// Fields are private and validated by [`new`](WriteConfig::new); read them via
134/// the accessor methods. This prevents constructing a config that bypasses path,
135/// extension, and compression-level validation.
136#[derive(Debug, Clone)]
137pub struct WriteConfig {
138    /// Optional output file path.
139    pub(crate) out_path: Option<PathBuf>,
140    /// Output format (defaults to CSV).
141    pub(crate) format: OutFormat,
142    /// Whether to overwrite an existing output file.
143    pub(crate) overwrite: bool,
144    /// Optional Parquet compression algorithm.
145    pub(crate) compression: Option<ParquetCompression>,
146    /// Optional Parquet compression level.
147    pub(crate) compression_level: Option<u32>,
148}
149
150impl WriteConfig {
151    /// Creates a new `WriteConfig` after validating the output path, format,
152    /// and compression settings.
153    ///
154    /// # Errors
155    ///
156    /// Returns [`ReadStatError`] if the output path, format, or compression settings
157    /// are invalid.
158    pub fn new(
159        out_path: Option<PathBuf>,
160        format: Option<OutFormat>,
161        overwrite: bool,
162        compression: Option<ParquetCompression>,
163        compression_level: Option<u32>,
164    ) -> Result<Self, ReadStatError> {
165        let f = Self::validate_format(format);
166        let op = Self::validate_out_path(out_path, overwrite)?;
167        let op = if let Some(op) = op {
168            Self::validate_out_extension(&op, f)?
169        } else {
170            None
171        };
172        let cl = match compression {
173            None => {
174                if compression_level.is_some() {
175                    warn!("Ignoring value of --compression-level as --compression was not set");
176                }
177                None
178            }
179            Some(pc) => Self::validate_compression_level(pc, compression_level)?,
180        };
181
182        Ok(Self {
183            out_path: op,
184            format: f,
185            overwrite,
186            compression,
187            compression_level: cl,
188        })
189    }
190
191    /// The validated output path, or `None` to write CSV to stdout.
192    #[must_use]
193    pub fn out_path(&self) -> Option<&Path> {
194        self.out_path.as_deref()
195    }
196
197    /// The output format.
198    #[must_use]
199    pub const fn format(&self) -> OutFormat {
200        self.format
201    }
202
203    /// Whether an existing output file may be overwritten.
204    #[must_use]
205    pub const fn overwrite(&self) -> bool {
206        self.overwrite
207    }
208
209    /// The configured Parquet compression codec, if any.
210    #[must_use]
211    pub const fn compression(&self) -> Option<ParquetCompression> {
212        self.compression
213    }
214
215    /// The configured Parquet compression level, if any.
216    #[must_use]
217    pub const fn compression_level(&self) -> Option<u32> {
218        self.compression_level
219    }
220
221    fn validate_format(format: Option<OutFormat>) -> OutFormat {
222        format.unwrap_or(OutFormat::Csv)
223    }
224
225    /// Validates the output file extension matches the format.
226    fn validate_out_extension(
227        path: &Path,
228        format: OutFormat,
229    ) -> Result<Option<PathBuf>, ReadStatError> {
230        match path.extension().and_then(|e| e.to_str()) {
231            Some(e) if e.eq_ignore_ascii_case(&format.to_string()) => Ok(Some(path.to_owned())),
232            _ => Err(ReadStatError::OutputExtensionMismatch {
233                path: path.to_owned(),
234                expected: format.to_string(),
235            }),
236        }
237    }
238
239    /// Validates the output path exists and handles overwrite logic.
240    fn validate_out_path(
241        path: Option<PathBuf>,
242        overwrite: bool,
243    ) -> Result<Option<PathBuf>, ReadStatError> {
244        match path {
245            None => Ok(None),
246            Some(p) => {
247                let abs_path = std::path::absolute(&p)
248                    .map_err(|e| ReadStatError::Other(format!("Failed to resolve path: {e}")))?;
249
250                match abs_path.parent() {
251                    None => Err(ReadStatError::OutputParentMissing(abs_path.clone())),
252                    Some(parent) => {
253                        if parent.exists() {
254                            if abs_path.exists() {
255                                if overwrite {
256                                    warn!(
257                                        "The file {} will be overwritten!",
258                                        abs_path.to_string_lossy()
259                                    );
260                                    Ok(Some(abs_path))
261                                } else {
262                                    Err(ReadStatError::OutputFileExists(abs_path))
263                                }
264                            } else {
265                                Ok(Some(abs_path))
266                            }
267                        } else {
268                            Err(ReadStatError::OutputParentMissing(parent.to_path_buf()))
269                        }
270                    }
271                }
272            }
273        }
274    }
275
276    /// Validates compression level is valid for the given compression algorithm.
277    fn validate_compression_level(
278        compression: ParquetCompression,
279        compression_level: Option<u32>,
280    ) -> Result<Option<u32>, ReadStatError> {
281        let (name, max_level): (&str, Option<u32>) = match compression {
282            ParquetCompression::Uncompressed => ("uncompressed", None),
283            ParquetCompression::Snappy => ("snappy", None),
284            ParquetCompression::Lz4Raw => ("lz4-raw", None),
285            ParquetCompression::Gzip => ("gzip", Some(9)),
286            ParquetCompression::Brotli => ("brotli", Some(11)),
287            ParquetCompression::Zstd => ("zstd", Some(22)),
288        };
289
290        match (max_level, compression_level) {
291            (None | Some(_), None) => Ok(None),
292            (None, Some(_)) => {
293                warn!(
294                    "Compression level is not required for compression={name}, ignoring value of --compression-level"
295                );
296                Ok(None)
297            }
298            (Some(max), Some(c)) => {
299                if c <= max {
300                    Ok(Some(c))
301                } else {
302                    Err(ReadStatError::Other(format!(
303                        "The compression level of {c} is not a valid level for {name} compression. \
304                         Instead, please use values between 0-{max}."
305                    )))
306                }
307            }
308        }
309    }
310}
311
/// Resolves [`ParquetCompression`] and an optional level into a Parquet compression codec.
///
/// Defaults to Snappy when no compression is specified. The level is ignored for
/// codecs that take none (uncompressed, snappy, lz4-raw); for gzip, brotli, and
/// zstd a missing level selects that codec level type's `Default`.
///
/// # Errors
///
/// Returns [`ReadStatError::Other`] if `compression_level` is rejected by the
/// selected codec (for zstd, this includes values that do not fit in an `i32`).
#[cfg(feature = "parquet")]
pub fn resolve_parquet_compression(
    compression: Option<ParquetCompression>,
    compression_level: Option<u32>,
) -> Result<ParquetCompressionCodec, ReadStatError> {
    let codec = match compression {
        Some(ParquetCompression::Uncompressed) => ParquetCompressionCodec::UNCOMPRESSED,
        Some(ParquetCompression::Snappy) | None => ParquetCompressionCodec::SNAPPY,
        Some(ParquetCompression::Lz4Raw) => ParquetCompressionCodec::LZ4_RAW,
        Some(ParquetCompression::Gzip) => {
            let level = match compression_level {
                Some(level) => GzipLevel::try_new(level).map_err(|e| {
                    ReadStatError::Other(format!("Invalid Gzip compression level: {e}"))
                })?,
                None => GzipLevel::default(),
            };
            ParquetCompressionCodec::GZIP(level)
        }
        Some(ParquetCompression::Brotli) => {
            let level = match compression_level {
                Some(level) => BrotliLevel::try_new(level).map_err(|e| {
                    ReadStatError::Other(format!("Invalid Brotli compression level: {e}"))
                })?,
                None => BrotliLevel::default(),
            };
            ParquetCompressionCodec::BROTLI(level)
        }
        Some(ParquetCompression::Zstd) => {
            let level = match compression_level {
                // `ZstdLevel` takes an `i32`; convert with a checked conversion
                // instead of `as`, which would silently wrap values above
                // `i32::MAX` into negative levels.
                Some(level) => i32::try_from(level)
                    .map_err(|e| e.to_string())
                    .and_then(|level| ZstdLevel::try_new(level).map_err(|e| e.to_string()))
                    .map_err(|e| {
                        ReadStatError::Other(format!("Invalid Zstd compression level: {e}"))
                    })?,
                None => ZstdLevel::default(),
            };
            ParquetCompressionCodec::ZSTD(level)
        }
    };
    Ok(codec)
}
358
#[cfg(test)]
mod tests {
    use super::*;

    // --- validate_format ---

    #[test]
    fn format_none_defaults_to_csv() {
        let f = WriteConfig::validate_format(None);
        assert!(matches!(f, OutFormat::Csv));
    }

    #[test]
    fn format_some_passes_through() {
        let f = WriteConfig::validate_format(Some(OutFormat::Parquet));
        assert!(matches!(f, OutFormat::Parquet));
    }

    // --- OutFormat / ParquetCompression parsing and display ---

    #[test]
    fn out_format_round_trips_through_display() {
        for format in [
            OutFormat::Csv,
            OutFormat::Feather,
            OutFormat::Ndjson,
            OutFormat::Parquet,
        ] {
            let parsed: OutFormat = format.to_string().parse().unwrap();
            assert_eq!(parsed.to_string(), format.to_string());
        }
    }

    #[test]
    fn out_format_parse_is_case_insensitive() {
        assert!(matches!("PARQUET".parse::<OutFormat>(), Ok(OutFormat::Parquet)));
        assert!(matches!("Csv".parse::<OutFormat>(), Ok(OutFormat::Csv)));
    }

    #[test]
    fn out_format_parse_rejects_unknown() {
        assert!("xlsx".parse::<OutFormat>().is_err());
    }

    #[test]
    fn parquet_compression_round_trips_through_display() {
        for codec in [
            ParquetCompression::Uncompressed,
            ParquetCompression::Snappy,
            ParquetCompression::Gzip,
            ParquetCompression::Lz4Raw,
            ParquetCompression::Brotli,
            ParquetCompression::Zstd,
        ] {
            let parsed: ParquetCompression = codec.to_string().parse().unwrap();
            assert_eq!(parsed, codec);
        }
    }

    #[test]
    fn parquet_compression_accepts_lz4raw_alias() {
        assert_eq!(
            "LZ4RAW".parse::<ParquetCompression>().unwrap(),
            ParquetCompression::Lz4Raw
        );
    }

    #[test]
    fn parquet_compression_rejects_unknown() {
        assert!("lzo".parse::<ParquetCompression>().is_err());
    }

    // --- validate_out_extension ---

    #[test]
    fn valid_csv_out_extension() {
        let path = Path::new("/some/output.csv");
        let result = WriteConfig::validate_out_extension(path, OutFormat::Csv).unwrap();
        assert!(result.is_some());
    }

    #[test]
    fn valid_parquet_out_extension() {
        let path = Path::new("/some/output.parquet");
        let result = WriteConfig::validate_out_extension(path, OutFormat::Parquet).unwrap();
        assert!(result.is_some());
    }

    #[test]
    fn valid_feather_out_extension() {
        let path = Path::new("/some/output.feather");
        let result = WriteConfig::validate_out_extension(path, OutFormat::Feather).unwrap();
        assert!(result.is_some());
    }

    #[test]
    fn valid_ndjson_out_extension() {
        let path = Path::new("/some/output.ndjson");
        let result = WriteConfig::validate_out_extension(path, OutFormat::Ndjson).unwrap();
        assert!(result.is_some());
    }

    #[test]
    fn out_extension_is_case_insensitive() {
        let path = Path::new("/some/OUTPUT.CSV");
        let result = WriteConfig::validate_out_extension(path, OutFormat::Csv).unwrap();
        assert!(result.is_some());
    }

    #[test]
    fn mismatched_out_extension() {
        let path = Path::new("/some/output.csv");
        assert!(WriteConfig::validate_out_extension(path, OutFormat::Parquet).is_err());
    }

    #[test]
    fn no_out_extension() {
        let path = Path::new("/some/output");
        assert!(WriteConfig::validate_out_extension(path, OutFormat::Csv).is_err());
    }

    // --- validate_compression_level ---

    #[test]
    fn uncompressed_ignores_level() {
        let result =
            WriteConfig::validate_compression_level(ParquetCompression::Uncompressed, Some(5))
                .unwrap();
        assert_eq!(result, None);
    }

    #[test]
    fn snappy_ignores_level() {
        let result =
            WriteConfig::validate_compression_level(ParquetCompression::Snappy, Some(5)).unwrap();
        assert_eq!(result, None);
    }

    #[test]
    fn lz4raw_ignores_level() {
        let result =
            WriteConfig::validate_compression_level(ParquetCompression::Lz4Raw, Some(5)).unwrap();
        assert_eq!(result, None);
    }

    #[test]
    fn gzip_valid_level() {
        let result =
            WriteConfig::validate_compression_level(ParquetCompression::Gzip, Some(5)).unwrap();
        assert_eq!(result, Some(5));
    }

    #[test]
    fn gzip_min_valid_level() {
        let result =
            WriteConfig::validate_compression_level(ParquetCompression::Gzip, Some(0)).unwrap();
        assert_eq!(result, Some(0));
    }

    #[test]
    fn gzip_max_valid_level() {
        let result =
            WriteConfig::validate_compression_level(ParquetCompression::Gzip, Some(9)).unwrap();
        assert_eq!(result, Some(9));
    }

    #[test]
    fn gzip_invalid_level() {
        assert!(
            WriteConfig::validate_compression_level(ParquetCompression::Gzip, Some(10),).is_err()
        );
    }

    #[test]
    fn brotli_valid_level() {
        let result =
            WriteConfig::validate_compression_level(ParquetCompression::Brotli, Some(11)).unwrap();
        assert_eq!(result, Some(11));
    }

    #[test]
    fn brotli_invalid_level() {
        assert!(
            WriteConfig::validate_compression_level(ParquetCompression::Brotli, Some(12),).is_err()
        );
    }

    #[test]
    fn zstd_valid_level() {
        let result =
            WriteConfig::validate_compression_level(ParquetCompression::Zstd, Some(22)).unwrap();
        assert_eq!(result, Some(22));
    }

    #[test]
    fn zstd_invalid_level() {
        assert!(
            WriteConfig::validate_compression_level(ParquetCompression::Zstd, Some(23),).is_err()
        );
    }

    #[test]
    fn no_level_passes_through() {
        let result =
            WriteConfig::validate_compression_level(ParquetCompression::Gzip, None).unwrap();
        assert_eq!(result, None);
    }

    // --- validate_out_path ---

    #[test]
    fn validate_out_path_none() {
        assert!(
            WriteConfig::validate_out_path(None, false)
                .unwrap()
                .is_none()
        );
    }

    #[test]
    fn validate_out_path_existing_file_requires_overwrite() {
        // The running test binary is guaranteed to exist in an existing directory.
        let existing = std::env::current_exe().unwrap();
        assert!(WriteConfig::validate_out_path(Some(existing.clone()), false).is_err());
        assert!(
            WriteConfig::validate_out_path(Some(existing), true)
                .unwrap()
                .is_some()
        );
    }

    #[test]
    fn validate_out_path_missing_parent() {
        let path = std::env::temp_dir()
            .join("readstat_write_config_missing_parent_dir")
            .join("output.csv");
        assert!(WriteConfig::validate_out_path(Some(path), false).is_err());
    }

    // --- new ---

    #[test]
    fn new_without_path_defaults_to_csv_and_drops_stray_level() {
        let config = WriteConfig::new(None, None, false, None, Some(5)).unwrap();
        assert!(config.out_path().is_none());
        assert!(matches!(config.format(), OutFormat::Csv));
        assert!(!config.overwrite());
        assert_eq!(config.compression(), None);
        assert_eq!(config.compression_level(), None);
    }

    #[test]
    fn new_keeps_valid_compression_level() {
        let config = WriteConfig::new(
            None,
            Some(OutFormat::Parquet),
            false,
            Some(ParquetCompression::Zstd),
            Some(3),
        )
        .unwrap();
        assert!(matches!(config.format(), OutFormat::Parquet));
        assert_eq!(config.compression(), Some(ParquetCompression::Zstd));
        assert_eq!(config.compression_level(), Some(3));
    }

    #[test]
    fn new_rejects_out_of_range_level() {
        assert!(
            WriteConfig::new(
                None,
                Some(OutFormat::Parquet),
                false,
                Some(ParquetCompression::Brotli),
                Some(12),
            )
            .is_err()
        );
    }
}