readstat/
rs_write_config.rs

1//! Output configuration for writing Arrow data to various formats.
2//!
3//! [`WriteConfig`] captures the output file path, format, compression settings,
//! and overwrite behavior, decoupled from input path validation ([`ReadStatPath`](crate::ReadStatPath)).
5
6use std::path::{Path, PathBuf};
7
8#[cfg(feature = "parquet")]
9use parquet::basic::{BrotliLevel, Compression as ParquetCompressionCodec, GzipLevel, ZstdLevel};
10
11use log::warn;
12
13use crate::err::ReadStatError;
14
/// Output file format for data conversion.
///
/// Every variant exists unconditionally, regardless of which crate features
/// are enabled; selecting a format whose backing feature is turned off is
/// rejected at compile time in the writer code.
#[derive(Debug, Clone, Copy)]
pub enum OutFormat {
    /// Comma-separated values.
    Csv,
    /// Feather (Arrow IPC) format.
    Feather,
    /// Newline-delimited JSON.
    Ndjson,
    /// Apache Parquet columnar format.
    Parquet,
}

impl std::fmt::Display for OutFormat {
    /// Writes the lowercase format name (also the expected file extension).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Csv => "csv",
            Self::Feather => "feather",
            Self::Ndjson => "ndjson",
            Self::Parquet => "parquet",
        };
        f.write_str(name)
    }
}
42
/// Parquet compression algorithm.
#[derive(Debug, Clone, Copy)]
pub enum ParquetCompression {
    /// No compression.
    Uncompressed,
    /// Snappy compression (fast, moderate ratio).
    Snappy,
    /// Gzip compression (levels 0-9).
    Gzip,
    /// LZ4 raw compression.
    Lz4Raw,
    /// Brotli compression (levels 0-11).
    Brotli,
    /// Zstandard compression (levels 0-22).
    Zstd,
}

impl std::fmt::Display for ParquetCompression {
    /// Writes the CLI-facing name of the algorithm.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::Uncompressed => "uncompressed",
            Self::Snappy => "snappy",
            Self::Gzip => "gzip",
            Self::Lz4Raw => "lz4-raw",
            Self::Brotli => "brotli",
            Self::Zstd => "zstd",
        })
    }
}
72
/// Output configuration for writing Arrow data.
///
/// Captures the output file path, format, compression settings, and overwrite
/// behavior. Created separately from [`ReadStatPath`](crate::ReadStatPath),
/// which handles only input path validation.
#[derive(Debug, Clone)]
pub struct WriteConfig {
    /// Optional output file path; `None` means no file output was requested.
    pub out_path: Option<PathBuf>,
    /// Output format (defaults to CSV when none is supplied at construction).
    pub format: OutFormat,
    /// Whether to overwrite an existing output file.
    pub overwrite: bool,
    /// Optional Parquet compression algorithm; `None` falls back to the
    /// writer's default codec (Snappy).
    pub compression: Option<ParquetCompression>,
    /// Optional Parquet compression level; only retained for algorithms that
    /// accept a level (gzip, brotli, zstd), otherwise dropped with a warning.
    pub compression_level: Option<u32>,
}
91
92impl WriteConfig {
93    /// Creates a new `WriteConfig` after validating the output path, format,
94    /// and compression settings.
95    ///
96    /// # Errors
97    ///
98    /// Returns [`ReadStatError`] if the output path, format, or compression settings
99    /// are invalid.
100    pub fn new(
101        out_path: Option<PathBuf>,
102        format: Option<OutFormat>,
103        overwrite: bool,
104        compression: Option<ParquetCompression>,
105        compression_level: Option<u32>,
106    ) -> Result<Self, ReadStatError> {
107        let f = Self::validate_format(format);
108        let op = Self::validate_out_path(out_path, overwrite)?;
109        let op = if let Some(op) = op {
110            Self::validate_out_extension(&op, f)?
111        } else {
112            None
113        };
114        let cl = match compression {
115            None => {
116                if compression_level.is_some() {
117                    warn!("Ignoring value of --compression-level as --compression was not set");
118                }
119                None
120            }
121            Some(pc) => Self::validate_compression_level(pc, compression_level)?,
122        };
123
124        Ok(Self {
125            out_path: op,
126            format: f,
127            overwrite,
128            compression,
129            compression_level: cl,
130        })
131    }
132
133    fn validate_format(format: Option<OutFormat>) -> OutFormat {
134        format.unwrap_or(OutFormat::Csv)
135    }
136
137    /// Validates the output file extension matches the format.
138    fn validate_out_extension(
139        path: &Path,
140        format: OutFormat,
141    ) -> Result<Option<PathBuf>, ReadStatError> {
142        path.extension()
143            .and_then(|e| e.to_str())
144            .map(std::borrow::ToOwned::to_owned)
145            .map_or(
146                Err(ReadStatError::Other(format!(
147                    "File {} does not have an extension! Expecting extension {}.",
148                    path.to_string_lossy(),
149                    format
150                ))),
151                |e| {
152                    if e == format.to_string() {
153                        Ok(Some(path.to_owned()))
154                    } else {
155                        Err(ReadStatError::Other(format!(
156                            "Expecting extension {}. Instead, file {} has extension {}.",
157                            format,
158                            path.to_string_lossy(),
159                            e
160                        )))
161                    }
162                },
163            )
164    }
165
166    /// Validates the output path exists and handles overwrite logic.
167    fn validate_out_path(
168        path: Option<PathBuf>,
169        overwrite: bool,
170    ) -> Result<Option<PathBuf>, ReadStatError> {
171        match path {
172            None => Ok(None),
173            Some(p) => {
174                let abs_path = std::path::absolute(&p)
175                    .map_err(|e| ReadStatError::Other(format!("Failed to resolve path: {e}")))?;
176
177                match abs_path.parent() {
178                    None => Err(ReadStatError::Other(format!(
179                        "The parent directory of the value of the parameter --output ({}) does not exist",
180                        abs_path.to_string_lossy()
181                    ))),
182                    Some(parent) => {
183                        if parent.exists() {
184                            if abs_path.exists() {
185                                if overwrite {
186                                    warn!(
187                                        "The file {} will be overwritten!",
188                                        abs_path.to_string_lossy()
189                                    );
190                                    Ok(Some(abs_path))
191                                } else {
192                                    Err(ReadStatError::Other(format!(
193                                        "The output file - {} - already exists! To overwrite the file, utilize the --overwrite parameter",
194                                        abs_path.to_string_lossy()
195                                    )))
196                                }
197                            } else {
198                                Ok(Some(abs_path))
199                            }
200                        } else {
201                            Err(ReadStatError::Other(format!(
202                                "The parent directory of the value of the parameter --output ({}) does not exist",
203                                parent.to_string_lossy()
204                            )))
205                        }
206                    }
207                }
208            }
209        }
210    }
211
212    /// Validates compression level is valid for the given compression algorithm.
213    fn validate_compression_level(
214        compression: ParquetCompression,
215        compression_level: Option<u32>,
216    ) -> Result<Option<u32>, ReadStatError> {
217        let (name, max_level): (&str, Option<u32>) = match compression {
218            ParquetCompression::Uncompressed => ("uncompressed", None),
219            ParquetCompression::Snappy => ("snappy", None),
220            ParquetCompression::Lz4Raw => ("lz4-raw", None),
221            ParquetCompression::Gzip => ("gzip", Some(9)),
222            ParquetCompression::Brotli => ("brotli", Some(11)),
223            ParquetCompression::Zstd => ("zstd", Some(22)),
224        };
225
226        match (max_level, compression_level) {
227            (None | Some(_), None) => Ok(None),
228            (None, Some(_)) => {
229                warn!(
230                    "Compression level is not required for compression={name}, ignoring value of --compression-level"
231                );
232                Ok(None)
233            }
234            (Some(max), Some(c)) => {
235                if c <= max {
236                    Ok(Some(c))
237                } else {
238                    Err(ReadStatError::Other(format!(
239                        "The compression level of {c} is not a valid level for {name} compression. \
240                         Instead, please use values between 0-{max}."
241                    )))
242                }
243            }
244        }
245    }
246}
247
/// Resolves [`ParquetCompression`] and an optional level into a Parquet compression codec.
///
/// Defaults to Snappy when no compression is specified; when a level is omitted
/// for gzip/brotli/zstd, the codec's default level is used.
///
/// # Errors
///
/// Returns [`ReadStatError`] if the level is out of range for the chosen codec.
#[cfg(feature = "parquet")]
pub fn resolve_parquet_compression(
    compression: Option<ParquetCompression>,
    compression_level: Option<u32>,
) -> Result<ParquetCompressionCodec, ReadStatError> {
    let codec = match compression {
        Some(ParquetCompression::Uncompressed) => ParquetCompressionCodec::UNCOMPRESSED,
        Some(ParquetCompression::Snappy) | None => ParquetCompressionCodec::SNAPPY,
        Some(ParquetCompression::Lz4Raw) => ParquetCompressionCodec::LZ4_RAW,
        Some(ParquetCompression::Gzip) => {
            let gzip_level = match compression_level {
                Some(level) => GzipLevel::try_new(level).map_err(|e| {
                    ReadStatError::Other(format!("Invalid Gzip compression level: {e}"))
                })?,
                None => GzipLevel::default(),
            };
            ParquetCompressionCodec::GZIP(gzip_level)
        }
        Some(ParquetCompression::Brotli) => {
            let brotli_level = match compression_level {
                Some(level) => BrotliLevel::try_new(level).map_err(|e| {
                    ReadStatError::Other(format!("Invalid Brotli compression level: {e}"))
                })?,
                None => BrotliLevel::default(),
            };
            ParquetCompressionCodec::BROTLI(brotli_level)
        }
        Some(ParquetCompression::Zstd) => {
            let zstd_level = match compression_level {
                Some(level) => {
                    // ZstdLevel takes an i32; use a checked conversion instead of
                    // `as i32`, which silently wraps for levels > i32::MAX.
                    let level = i32::try_from(level).map_err(|e| {
                        ReadStatError::Other(format!("Invalid Zstd compression level: {e}"))
                    })?;
                    ZstdLevel::try_new(level).map_err(|e| {
                        ReadStatError::Other(format!("Invalid Zstd compression level: {e}"))
                    })?
                }
                None => ZstdLevel::default(),
            };
            ParquetCompressionCodec::ZSTD(zstd_level)
        }
    };
    Ok(codec)
}
294
#[cfg(test)]
mod tests {
    use super::*;

    // --- validate_format ---

    #[test]
    fn format_none_defaults_to_csv() {
        assert!(matches!(WriteConfig::validate_format(None), OutFormat::Csv));
    }

    #[test]
    fn format_some_passes_through() {
        assert!(matches!(
            WriteConfig::validate_format(Some(OutFormat::Parquet)),
            OutFormat::Parquet
        ));
    }

    // --- validate_out_extension ---

    /// Shorthand: validate a literal path against a format.
    fn check_ext(path: &str, format: OutFormat) -> Result<Option<PathBuf>, ReadStatError> {
        WriteConfig::validate_out_extension(Path::new(path), format)
    }

    #[test]
    fn valid_csv_out_extension() {
        assert!(check_ext("/some/output.csv", OutFormat::Csv).unwrap().is_some());
    }

    #[test]
    fn valid_parquet_out_extension() {
        assert!(
            check_ext("/some/output.parquet", OutFormat::Parquet)
                .unwrap()
                .is_some()
        );
    }

    #[test]
    fn valid_feather_out_extension() {
        assert!(
            check_ext("/some/output.feather", OutFormat::Feather)
                .unwrap()
                .is_some()
        );
    }

    #[test]
    fn valid_ndjson_out_extension() {
        assert!(
            check_ext("/some/output.ndjson", OutFormat::Ndjson)
                .unwrap()
                .is_some()
        );
    }

    #[test]
    fn mismatched_out_extension() {
        assert!(check_ext("/some/output.csv", OutFormat::Parquet).is_err());
    }

    #[test]
    fn no_out_extension() {
        assert!(check_ext("/some/output", OutFormat::Csv).is_err());
    }

    // --- validate_compression_level ---

    /// Shorthand for the compression-level validator.
    fn check_level(
        compression: ParquetCompression,
        level: Option<u32>,
    ) -> Result<Option<u32>, ReadStatError> {
        WriteConfig::validate_compression_level(compression, level)
    }

    #[test]
    fn uncompressed_ignores_level() {
        assert_eq!(
            check_level(ParquetCompression::Uncompressed, Some(5)).unwrap(),
            None
        );
    }

    #[test]
    fn snappy_ignores_level() {
        assert_eq!(check_level(ParquetCompression::Snappy, Some(5)).unwrap(), None);
    }

    #[test]
    fn lz4raw_ignores_level() {
        assert_eq!(check_level(ParquetCompression::Lz4Raw, Some(5)).unwrap(), None);
    }

    #[test]
    fn gzip_valid_level() {
        assert_eq!(check_level(ParquetCompression::Gzip, Some(5)).unwrap(), Some(5));
    }

    #[test]
    fn gzip_max_valid_level() {
        assert_eq!(check_level(ParquetCompression::Gzip, Some(9)).unwrap(), Some(9));
    }

    #[test]
    fn gzip_invalid_level() {
        assert!(check_level(ParquetCompression::Gzip, Some(10)).is_err());
    }

    #[test]
    fn brotli_valid_level() {
        assert_eq!(
            check_level(ParquetCompression::Brotli, Some(11)).unwrap(),
            Some(11)
        );
    }

    #[test]
    fn brotli_invalid_level() {
        assert!(check_level(ParquetCompression::Brotli, Some(12)).is_err());
    }

    #[test]
    fn zstd_valid_level() {
        assert_eq!(check_level(ParquetCompression::Zstd, Some(22)).unwrap(), Some(22));
    }

    #[test]
    fn zstd_invalid_level() {
        assert!(check_level(ParquetCompression::Zstd, Some(23)).is_err());
    }

    #[test]
    fn no_level_passes_through() {
        assert_eq!(check_level(ParquetCompression::Gzip, None).unwrap(), None);
    }

    // --- validate_out_path ---

    #[test]
    fn validate_out_path_none() {
        assert!(WriteConfig::validate_out_path(None, false).unwrap().is_none());
    }
}