Skip to main content

readstat/
rs_write_config.rs

//! Output configuration for writing Arrow data to various formats.
//!
//! [`WriteConfig`] captures the output file path, format, compression settings,
//! and overwrite behavior, decoupled from input path validation
//! ([`ReadStatPath`](crate::ReadStatPath)).
5
6use std::path::{Path, PathBuf};
7
8#[cfg(feature = "parquet")]
9use parquet::basic::{BrotliLevel, Compression as ParquetCompressionCodec, GzipLevel, ZstdLevel};
10
11use log::warn;
12
13use crate::err::ReadStatError;
14
/// Output file format for data conversion.
///
/// All variants are always present regardless of enabled features.
/// Attempting to use a format whose feature is not enabled will
/// result in a compile-time error in the writer code.
///
/// This enum is `#[non_exhaustive]`: new format variants may be added in
/// minor releases. Match with a wildcard arm to remain forward-compatible.
///
/// The type is a plain field-less tag, so it implements the common
/// comparison and hashing traits: values can be compared with `==` and used
/// as keys in hash-based collections.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OutFormat {
    /// Comma-separated values.
    Csv,
    /// Feather (Arrow IPC) format.
    Feather,
    /// Newline-delimited JSON.
    Ndjson,
    /// Apache Parquet columnar format.
    Parquet,
}
35
36impl std::fmt::Display for OutFormat {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        match self {
39            Self::Csv => f.write_str("csv"),
40            Self::Feather => f.write_str("feather"),
41            Self::Ndjson => f.write_str("ndjson"),
42            Self::Parquet => f.write_str("parquet"),
43        }
44    }
45}
46
47impl std::str::FromStr for OutFormat {
48    type Err = ReadStatError;
49
50    /// Parses a format name (case-insensitive) into an [`OutFormat`].
51    ///
52    /// Accepted values: `"csv"`, `"feather"`, `"ndjson"`, `"parquet"`.
53    ///
54    /// # Errors
55    ///
56    /// Returns [`ReadStatError::Other`] for unrecognized format strings.
57    fn from_str(s: &str) -> Result<Self, Self::Err> {
58        match s.to_lowercase().as_str() {
59            "csv" => Ok(Self::Csv),
60            "feather" => Ok(Self::Feather),
61            "ndjson" => Ok(Self::Ndjson),
62            "parquet" => Ok(Self::Parquet),
63            _ => Err(ReadStatError::Other(format!(
64                "Unknown format: {s:?}. Expected one of: csv, feather, ndjson, parquet"
65            ))),
66        }
67    }
68}
69
/// Parquet compression algorithm.
///
/// This enum is `#[non_exhaustive]`: new codec variants may be added in
/// minor releases. Match with a wildcard arm to remain forward-compatible.
///
/// The type is a plain field-less tag, so it implements the common
/// comparison and hashing traits: values can be compared with `==` and used
/// as keys in hash-based collections.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ParquetCompression {
    /// No compression.
    Uncompressed,
    /// Snappy compression (fast, moderate ratio).
    Snappy,
    /// Gzip compression (levels 0-9).
    Gzip,
    /// LZ4 raw compression.
    Lz4Raw,
    /// Brotli compression (levels 0-11).
    Brotli,
    /// Zstandard compression (levels 0-22).
    Zstd,
}
90
91impl std::fmt::Display for ParquetCompression {
92    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93        match self {
94            Self::Uncompressed => f.write_str("uncompressed"),
95            Self::Snappy => f.write_str("snappy"),
96            Self::Gzip => f.write_str("gzip"),
97            Self::Lz4Raw => f.write_str("lz4-raw"),
98            Self::Brotli => f.write_str("brotli"),
99            Self::Zstd => f.write_str("zstd"),
100        }
101    }
102}
103
104/// Output configuration for writing Arrow data.
105///
106/// Captures the output file path, format, compression settings, and overwrite
107/// behavior. Created separately from [`ReadStatPath`](crate::ReadStatPath),
108/// which handles only input path validation.
109#[derive(Debug, Clone)]
110pub struct WriteConfig {
111    /// Optional output file path.
112    pub out_path: Option<PathBuf>,
113    /// Output format (defaults to CSV).
114    pub format: OutFormat,
115    /// Whether to overwrite an existing output file.
116    pub overwrite: bool,
117    /// Optional Parquet compression algorithm.
118    pub compression: Option<ParquetCompression>,
119    /// Optional Parquet compression level.
120    pub compression_level: Option<u32>,
121}
122
123impl WriteConfig {
124    /// Creates a new `WriteConfig` after validating the output path, format,
125    /// and compression settings.
126    ///
127    /// # Errors
128    ///
129    /// Returns [`ReadStatError`] if the output path, format, or compression settings
130    /// are invalid.
131    pub fn new(
132        out_path: Option<PathBuf>,
133        format: Option<OutFormat>,
134        overwrite: bool,
135        compression: Option<ParquetCompression>,
136        compression_level: Option<u32>,
137    ) -> Result<Self, ReadStatError> {
138        let f = Self::validate_format(format);
139        let op = Self::validate_out_path(out_path, overwrite)?;
140        let op = if let Some(op) = op {
141            Self::validate_out_extension(&op, f)?
142        } else {
143            None
144        };
145        let cl = match compression {
146            None => {
147                if compression_level.is_some() {
148                    warn!("Ignoring value of --compression-level as --compression was not set");
149                }
150                None
151            }
152            Some(pc) => Self::validate_compression_level(pc, compression_level)?,
153        };
154
155        Ok(Self {
156            out_path: op,
157            format: f,
158            overwrite,
159            compression,
160            compression_level: cl,
161        })
162    }
163
164    fn validate_format(format: Option<OutFormat>) -> OutFormat {
165        format.unwrap_or(OutFormat::Csv)
166    }
167
168    /// Validates the output file extension matches the format.
169    fn validate_out_extension(
170        path: &Path,
171        format: OutFormat,
172    ) -> Result<Option<PathBuf>, ReadStatError> {
173        path.extension()
174            .and_then(|e| e.to_str())
175            .map(std::borrow::ToOwned::to_owned)
176            .map_or(
177                Err(ReadStatError::Other(format!(
178                    "File {} does not have an extension! Expecting extension {}.",
179                    path.to_string_lossy(),
180                    format
181                ))),
182                |e| {
183                    if e == format.to_string() {
184                        Ok(Some(path.to_owned()))
185                    } else {
186                        Err(ReadStatError::Other(format!(
187                            "Expecting extension {}. Instead, file {} has extension {}.",
188                            format,
189                            path.to_string_lossy(),
190                            e
191                        )))
192                    }
193                },
194            )
195    }
196
197    /// Validates the output path exists and handles overwrite logic.
198    fn validate_out_path(
199        path: Option<PathBuf>,
200        overwrite: bool,
201    ) -> Result<Option<PathBuf>, ReadStatError> {
202        match path {
203            None => Ok(None),
204            Some(p) => {
205                let abs_path = std::path::absolute(&p)
206                    .map_err(|e| ReadStatError::Other(format!("Failed to resolve path: {e}")))?;
207
208                match abs_path.parent() {
209                    None => Err(ReadStatError::Other(format!(
210                        "The parent directory of the value of the parameter --output ({}) does not exist",
211                        abs_path.to_string_lossy()
212                    ))),
213                    Some(parent) => {
214                        if parent.exists() {
215                            if abs_path.exists() {
216                                if overwrite {
217                                    warn!(
218                                        "The file {} will be overwritten!",
219                                        abs_path.to_string_lossy()
220                                    );
221                                    Ok(Some(abs_path))
222                                } else {
223                                    Err(ReadStatError::Other(format!(
224                                        "The output file - {} - already exists! To overwrite the file, utilize the --overwrite parameter",
225                                        abs_path.to_string_lossy()
226                                    )))
227                                }
228                            } else {
229                                Ok(Some(abs_path))
230                            }
231                        } else {
232                            Err(ReadStatError::Other(format!(
233                                "The parent directory of the value of the parameter --output ({}) does not exist",
234                                parent.to_string_lossy()
235                            )))
236                        }
237                    }
238                }
239            }
240        }
241    }
242
243    /// Validates compression level is valid for the given compression algorithm.
244    fn validate_compression_level(
245        compression: ParquetCompression,
246        compression_level: Option<u32>,
247    ) -> Result<Option<u32>, ReadStatError> {
248        let (name, max_level): (&str, Option<u32>) = match compression {
249            ParquetCompression::Uncompressed => ("uncompressed", None),
250            ParquetCompression::Snappy => ("snappy", None),
251            ParquetCompression::Lz4Raw => ("lz4-raw", None),
252            ParquetCompression::Gzip => ("gzip", Some(9)),
253            ParquetCompression::Brotli => ("brotli", Some(11)),
254            ParquetCompression::Zstd => ("zstd", Some(22)),
255        };
256
257        match (max_level, compression_level) {
258            (None | Some(_), None) => Ok(None),
259            (None, Some(_)) => {
260                warn!(
261                    "Compression level is not required for compression={name}, ignoring value of --compression-level"
262                );
263                Ok(None)
264            }
265            (Some(max), Some(c)) => {
266                if c <= max {
267                    Ok(Some(c))
268                } else {
269                    Err(ReadStatError::Other(format!(
270                        "The compression level of {c} is not a valid level for {name} compression. \
271                         Instead, please use values between 0-{max}."
272                    )))
273                }
274            }
275        }
276    }
277}
278
/// Resolves [`ParquetCompression`] and an optional level into a Parquet compression codec.
///
/// Defaults to Snappy when no compression is specified. For gzip, brotli and
/// zstd the supplied level is checked by the corresponding Parquet level
/// type; when no level is supplied that type's default level is used. The
/// level is ignored for codecs that do not take one.
///
/// NOTE(review): the accepted range for each codec is whatever the Parquet
/// level types enforce; confirm it agrees with the ranges checked by
/// `WriteConfig::validate_compression_level`.
///
/// # Errors
///
/// Returns [`ReadStatError::Other`] when the level is rejected by the
/// Parquet level type, or when a zstd level does not fit in an `i32`.
#[cfg(feature = "parquet")]
pub fn resolve_parquet_compression(
    compression: Option<ParquetCompression>,
    compression_level: Option<u32>,
) -> Result<ParquetCompressionCodec, ReadStatError> {
    let codec = match compression {
        Some(ParquetCompression::Uncompressed) => ParquetCompressionCodec::UNCOMPRESSED,
        Some(ParquetCompression::Snappy) | None => ParquetCompressionCodec::SNAPPY,
        Some(ParquetCompression::Lz4Raw) => ParquetCompressionCodec::LZ4_RAW,
        Some(ParquetCompression::Gzip) => {
            let gzip_level = match compression_level {
                Some(level) => GzipLevel::try_new(level).map_err(|e| {
                    ReadStatError::Other(format!("Invalid Gzip compression level: {e}"))
                })?,
                None => GzipLevel::default(),
            };
            ParquetCompressionCodec::GZIP(gzip_level)
        }
        Some(ParquetCompression::Brotli) => {
            let brotli_level = match compression_level {
                Some(level) => BrotliLevel::try_new(level).map_err(|e| {
                    ReadStatError::Other(format!("Invalid Brotli compression level: {e}"))
                })?,
                None => BrotliLevel::default(),
            };
            ParquetCompressionCodec::BROTLI(brotli_level)
        }
        Some(ParquetCompression::Zstd) => {
            let zstd_level = match compression_level {
                Some(level) => {
                    // Zstd levels are signed. Convert with a checked cast so
                    // that a huge unsigned value is rejected instead of
                    // wrapping around to a negative level.
                    let signed = i32::try_from(level).map_err(|e| {
                        ReadStatError::Other(format!("Invalid Zstd compression level: {e}"))
                    })?;
                    ZstdLevel::try_new(signed).map_err(|e| {
                        ReadStatError::Other(format!("Invalid Zstd compression level: {e}"))
                    })?
                }
                None => ZstdLevel::default(),
            };
            ParquetCompressionCodec::ZSTD(zstd_level)
        }
    };
    Ok(codec)
}
325
#[cfg(test)]
mod tests {
    //! Unit tests for format parsing and display, and for each validation
    //! step behind `WriteConfig::new`.
    //!
    //! Enum values are compared with `matches!` or `std::mem::discriminant`
    //! so the tests do not depend on the enums implementing `PartialEq`.

    use super::*;

    /// Unwraps the message carried by a `ReadStatError::Other`, panicking on
    /// any other outcome.
    fn other_message<T: std::fmt::Debug>(result: Result<T, ReadStatError>) -> String {
        match result {
            Err(ReadStatError::Other(msg)) => msg,
            other => panic!("expected ReadStatError::Other, got {other:?}"),
        }
    }

    // --- OutFormat: FromStr / Display ---

    #[test]
    fn out_format_parses_known_names() {
        assert!(matches!("csv".parse::<OutFormat>(), Ok(OutFormat::Csv)));
        assert!(matches!("feather".parse::<OutFormat>(), Ok(OutFormat::Feather)));
        assert!(matches!("ndjson".parse::<OutFormat>(), Ok(OutFormat::Ndjson)));
        assert!(matches!("parquet".parse::<OutFormat>(), Ok(OutFormat::Parquet)));
    }

    #[test]
    fn out_format_parsing_ignores_case() {
        assert!(matches!("CSV".parse::<OutFormat>(), Ok(OutFormat::Csv)));
        assert!(matches!("PaRqUeT".parse::<OutFormat>(), Ok(OutFormat::Parquet)));
    }

    #[test]
    fn out_format_rejects_unknown_name() {
        let msg = other_message("xlsx".parse::<OutFormat>());
        assert!(msg.contains("\"xlsx\""));
        assert!("".parse::<OutFormat>().is_err());
    }

    #[test]
    fn out_format_display_round_trips_through_from_str() {
        for format in [
            OutFormat::Csv,
            OutFormat::Feather,
            OutFormat::Ndjson,
            OutFormat::Parquet,
        ] {
            let parsed: OutFormat = format.to_string().parse().unwrap();
            assert_eq!(
                std::mem::discriminant(&parsed),
                std::mem::discriminant(&format)
            );
        }
    }

    // --- ParquetCompression: Display ---

    #[test]
    fn parquet_compression_display_names() {
        assert_eq!(ParquetCompression::Uncompressed.to_string(), "uncompressed");
        assert_eq!(ParquetCompression::Snappy.to_string(), "snappy");
        assert_eq!(ParquetCompression::Gzip.to_string(), "gzip");
        assert_eq!(ParquetCompression::Lz4Raw.to_string(), "lz4-raw");
        assert_eq!(ParquetCompression::Brotli.to_string(), "brotli");
        assert_eq!(ParquetCompression::Zstd.to_string(), "zstd");
    }

    // --- validate_format ---

    #[test]
    fn format_none_defaults_to_csv() {
        let f = WriteConfig::validate_format(None);
        assert!(matches!(f, OutFormat::Csv));
    }

    #[test]
    fn format_some_passes_through() {
        let f = WriteConfig::validate_format(Some(OutFormat::Parquet));
        assert!(matches!(f, OutFormat::Parquet));
    }

    // --- validate_out_extension ---

    #[test]
    fn valid_csv_out_extension() {
        let path = Path::new("/some/output.csv");
        let result = WriteConfig::validate_out_extension(path, OutFormat::Csv).unwrap();
        assert_eq!(result.as_deref(), Some(path));
    }

    #[test]
    fn valid_parquet_out_extension() {
        let path = Path::new("/some/output.parquet");
        let result = WriteConfig::validate_out_extension(path, OutFormat::Parquet).unwrap();
        assert_eq!(result.as_deref(), Some(path));
    }

    #[test]
    fn valid_feather_out_extension() {
        let path = Path::new("/some/output.feather");
        let result = WriteConfig::validate_out_extension(path, OutFormat::Feather).unwrap();
        assert_eq!(result.as_deref(), Some(path));
    }

    #[test]
    fn valid_ndjson_out_extension() {
        let path = Path::new("/some/output.ndjson");
        let result = WriteConfig::validate_out_extension(path, OutFormat::Ndjson).unwrap();
        assert_eq!(result.as_deref(), Some(path));
    }

    #[test]
    fn mismatched_out_extension() {
        let path = Path::new("/some/output.csv");
        let msg = other_message(WriteConfig::validate_out_extension(path, OutFormat::Parquet));
        assert!(msg.contains("Expecting extension parquet"));
        assert!(msg.contains("has extension csv"));
    }

    #[test]
    fn no_out_extension() {
        let path = Path::new("/some/output");
        let msg = other_message(WriteConfig::validate_out_extension(path, OutFormat::Csv));
        assert!(msg.contains("does not have an extension"));
    }

    // --- validate_compression_level ---

    #[test]
    fn uncompressed_ignores_level() {
        let result =
            WriteConfig::validate_compression_level(ParquetCompression::Uncompressed, Some(5))
                .unwrap();
        assert_eq!(result, None);
    }

    #[test]
    fn snappy_ignores_level() {
        let result =
            WriteConfig::validate_compression_level(ParquetCompression::Snappy, Some(5)).unwrap();
        assert_eq!(result, None);
    }

    #[test]
    fn lz4raw_ignores_level() {
        let result =
            WriteConfig::validate_compression_level(ParquetCompression::Lz4Raw, Some(5)).unwrap();
        assert_eq!(result, None);
    }

    #[test]
    fn gzip_valid_level() {
        let result =
            WriteConfig::validate_compression_level(ParquetCompression::Gzip, Some(5)).unwrap();
        assert_eq!(result, Some(5));
    }

    #[test]
    fn gzip_max_valid_level() {
        let result =
            WriteConfig::validate_compression_level(ParquetCompression::Gzip, Some(9)).unwrap();
        assert_eq!(result, Some(9));
    }

    #[test]
    fn gzip_invalid_level() {
        let msg = other_message(WriteConfig::validate_compression_level(
            ParquetCompression::Gzip,
            Some(10),
        ));
        assert!(msg.contains("gzip"));
        assert!(msg.contains("0-9"));
    }

    #[test]
    fn brotli_valid_level() {
        let result =
            WriteConfig::validate_compression_level(ParquetCompression::Brotli, Some(11)).unwrap();
        assert_eq!(result, Some(11));
    }

    #[test]
    fn brotli_invalid_level() {
        assert!(
            WriteConfig::validate_compression_level(ParquetCompression::Brotli, Some(12)).is_err()
        );
    }

    #[test]
    fn zstd_valid_level() {
        let result =
            WriteConfig::validate_compression_level(ParquetCompression::Zstd, Some(22)).unwrap();
        assert_eq!(result, Some(22));
    }

    #[test]
    fn zstd_invalid_level() {
        assert!(
            WriteConfig::validate_compression_level(ParquetCompression::Zstd, Some(23)).is_err()
        );
    }

    #[test]
    fn no_level_passes_through() {
        let result =
            WriteConfig::validate_compression_level(ParquetCompression::Gzip, None).unwrap();
        assert_eq!(result, None);
    }

    #[test]
    fn no_level_for_levelless_codec_passes_through() {
        let result =
            WriteConfig::validate_compression_level(ParquetCompression::Snappy, None).unwrap();
        assert_eq!(result, None);
    }

    // --- validate_out_path ---

    #[test]
    fn validate_out_path_none() {
        assert!(
            WriteConfig::validate_out_path(None, false)
                .unwrap()
                .is_none()
        );
    }

    #[test]
    fn validate_out_path_missing_parent_is_rejected() {
        let path = std::env::temp_dir()
            .join("readstat-write-config-test-missing-parent")
            .join("output.csv");
        let msg = other_message(WriteConfig::validate_out_path(Some(path), false));
        assert!(msg.contains("does not exist"));
    }

    #[test]
    fn validate_out_path_existing_target_requires_overwrite() {
        // The temporary directory is used as a path that is known to exist;
        // the function only checks existence, not whether it is a file.
        let existing = std::env::temp_dir();

        let msg = other_message(WriteConfig::validate_out_path(Some(existing.clone()), false));
        assert!(msg.contains("already exists"));

        let accepted = WriteConfig::validate_out_path(Some(existing), true).unwrap();
        assert!(accepted.is_some_and(|p| p.is_absolute()));
    }

    // --- new ---

    #[test]
    fn new_defaults_to_csv_and_drops_orphan_level() {
        let config = WriteConfig::new(None, None, false, None, Some(5)).unwrap();
        assert!(matches!(config.format, OutFormat::Csv));
        assert!(config.out_path.is_none());
        assert!(config.compression.is_none());
        assert_eq!(config.compression_level, None);
        assert!(!config.overwrite);
    }

    #[test]
    fn new_keeps_valid_compression_settings() {
        let config = WriteConfig::new(
            None,
            Some(OutFormat::Parquet),
            true,
            Some(ParquetCompression::Zstd),
            Some(7),
        )
        .unwrap();
        assert!(matches!(config.format, OutFormat::Parquet));
        assert!(matches!(config.compression, Some(ParquetCompression::Zstd)));
        assert_eq!(config.compression_level, Some(7));
        assert!(config.overwrite);
    }

    #[test]
    fn new_rejects_out_of_range_level() {
        assert!(
            WriteConfig::new(
                None,
                Some(OutFormat::Parquet),
                false,
                Some(ParquetCompression::Zstd),
                Some(23),
            )
            .is_err()
        );
    }

    #[test]
    fn new_rejects_mismatched_extension() {
        // Fails whether or not the file happens to exist: either the
        // existing-file check or the extension check rejects it.
        let path = std::env::temp_dir().join("readstat-write-config-test-output.csv");
        assert!(WriteConfig::new(Some(path), Some(OutFormat::Parquet), false, None, None).is_err());
    }
}
476}