1use std::path::{Path, PathBuf};
7
8#[cfg(feature = "parquet")]
9use parquet::basic::{BrotliLevel, Compression as ParquetCompressionCodec, GzipLevel, ZstdLevel};
10
11use log::warn;
12
13use crate::err::ReadStatError;
14
/// Output file format for written data.
///
/// The lowercase name produced by the `Display` impl (and accepted, case-insensitively,
/// by the `FromStr` impl) is also the file extension `WriteConfig` requires on output paths.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OutFormat {
    /// Comma-separated values (`.csv`).
    Csv,
    /// Feather / Arrow IPC file (`.feather`).
    Feather,
    /// Newline-delimited JSON (`.ndjson`).
    Ndjson,
    /// Apache Parquet (`.parquet`).
    Parquet,
}
36
37impl std::fmt::Display for OutFormat {
38 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39 match self {
40 Self::Csv => f.write_str("csv"),
41 Self::Feather => f.write_str("feather"),
42 Self::Ndjson => f.write_str("ndjson"),
43 Self::Parquet => f.write_str("parquet"),
44 }
45 }
46}
47
48impl std::str::FromStr for OutFormat {
49 type Err = ReadStatError;
50
51 fn from_str(s: &str) -> Result<Self, Self::Err> {
59 match s.to_lowercase().as_str() {
60 "csv" => Ok(Self::Csv),
61 "feather" => Ok(Self::Feather),
62 "ndjson" => Ok(Self::Ndjson),
63 "parquet" => Ok(Self::Parquet),
64 _ => Err(ReadStatError::UnknownFormat(s.to_string())),
65 }
66 }
67}
68
/// Compression codec applied when writing Parquet output.
///
/// Of these, `Gzip`, `Brotli` and `Zstd` accept an explicit compression level; for the
/// other codecs a supplied level is ignored during `WriteConfig` validation.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ParquetCompression {
    /// Data pages are written without compression.
    Uncompressed,
    /// Snappy codec.
    Snappy,
    /// Gzip codec (level configurable).
    Gzip,
    /// LZ4 raw-block codec.
    Lz4Raw,
    /// Brotli codec (level configurable).
    Brotli,
    /// Zstandard codec (level configurable).
    Zstd,
}
89
90impl std::fmt::Display for ParquetCompression {
91 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92 match self {
93 Self::Uncompressed => f.write_str("uncompressed"),
94 Self::Snappy => f.write_str("snappy"),
95 Self::Gzip => f.write_str("gzip"),
96 Self::Lz4Raw => f.write_str("lz4-raw"),
97 Self::Brotli => f.write_str("brotli"),
98 Self::Zstd => f.write_str("zstd"),
99 }
100 }
101}
102
103impl std::str::FromStr for ParquetCompression {
104 type Err = ReadStatError;
105
106 fn from_str(s: &str) -> Result<Self, Self::Err> {
115 match s.to_lowercase().as_str() {
116 "uncompressed" => Ok(Self::Uncompressed),
117 "snappy" => Ok(Self::Snappy),
118 "gzip" => Ok(Self::Gzip),
119 "lz4-raw" | "lz4raw" => Ok(Self::Lz4Raw),
120 "brotli" => Ok(Self::Brotli),
121 "zstd" => Ok(Self::Zstd),
122 _ => Err(ReadStatError::UnknownFormat(s.to_string())),
123 }
124 }
125}
126
127#[derive(Debug, Clone)]
137pub struct WriteConfig {
138 pub(crate) out_path: Option<PathBuf>,
140 pub(crate) format: OutFormat,
142 pub(crate) overwrite: bool,
144 pub(crate) compression: Option<ParquetCompression>,
146 pub(crate) compression_level: Option<u32>,
148}
149
150impl WriteConfig {
151 pub fn new(
159 out_path: Option<PathBuf>,
160 format: Option<OutFormat>,
161 overwrite: bool,
162 compression: Option<ParquetCompression>,
163 compression_level: Option<u32>,
164 ) -> Result<Self, ReadStatError> {
165 let f = Self::validate_format(format);
166 let op = Self::validate_out_path(out_path, overwrite)?;
167 let op = if let Some(op) = op {
168 Self::validate_out_extension(&op, f)?
169 } else {
170 None
171 };
172 let cl = match compression {
173 None => {
174 if compression_level.is_some() {
175 warn!("Ignoring value of --compression-level as --compression was not set");
176 }
177 None
178 }
179 Some(pc) => Self::validate_compression_level(pc, compression_level)?,
180 };
181
182 Ok(Self {
183 out_path: op,
184 format: f,
185 overwrite,
186 compression,
187 compression_level: cl,
188 })
189 }
190
191 #[must_use]
193 pub fn out_path(&self) -> Option<&Path> {
194 self.out_path.as_deref()
195 }
196
197 #[must_use]
199 pub const fn format(&self) -> OutFormat {
200 self.format
201 }
202
203 #[must_use]
205 pub const fn overwrite(&self) -> bool {
206 self.overwrite
207 }
208
209 #[must_use]
211 pub const fn compression(&self) -> Option<ParquetCompression> {
212 self.compression
213 }
214
215 #[must_use]
217 pub const fn compression_level(&self) -> Option<u32> {
218 self.compression_level
219 }
220
221 fn validate_format(format: Option<OutFormat>) -> OutFormat {
222 format.unwrap_or(OutFormat::Csv)
223 }
224
225 fn validate_out_extension(
227 path: &Path,
228 format: OutFormat,
229 ) -> Result<Option<PathBuf>, ReadStatError> {
230 match path.extension().and_then(|e| e.to_str()) {
231 Some(e) if e.eq_ignore_ascii_case(&format.to_string()) => Ok(Some(path.to_owned())),
232 _ => Err(ReadStatError::OutputExtensionMismatch {
233 path: path.to_owned(),
234 expected: format.to_string(),
235 }),
236 }
237 }
238
239 fn validate_out_path(
241 path: Option<PathBuf>,
242 overwrite: bool,
243 ) -> Result<Option<PathBuf>, ReadStatError> {
244 match path {
245 None => Ok(None),
246 Some(p) => {
247 let abs_path = std::path::absolute(&p)
248 .map_err(|e| ReadStatError::Other(format!("Failed to resolve path: {e}")))?;
249
250 match abs_path.parent() {
251 None => Err(ReadStatError::OutputParentMissing(abs_path.clone())),
252 Some(parent) => {
253 if parent.exists() {
254 if abs_path.exists() {
255 if overwrite {
256 warn!(
257 "The file {} will be overwritten!",
258 abs_path.to_string_lossy()
259 );
260 Ok(Some(abs_path))
261 } else {
262 Err(ReadStatError::OutputFileExists(abs_path))
263 }
264 } else {
265 Ok(Some(abs_path))
266 }
267 } else {
268 Err(ReadStatError::OutputParentMissing(parent.to_path_buf()))
269 }
270 }
271 }
272 }
273 }
274 }
275
276 fn validate_compression_level(
278 compression: ParquetCompression,
279 compression_level: Option<u32>,
280 ) -> Result<Option<u32>, ReadStatError> {
281 let (name, max_level): (&str, Option<u32>) = match compression {
282 ParquetCompression::Uncompressed => ("uncompressed", None),
283 ParquetCompression::Snappy => ("snappy", None),
284 ParquetCompression::Lz4Raw => ("lz4-raw", None),
285 ParquetCompression::Gzip => ("gzip", Some(9)),
286 ParquetCompression::Brotli => ("brotli", Some(11)),
287 ParquetCompression::Zstd => ("zstd", Some(22)),
288 };
289
290 match (max_level, compression_level) {
291 (None | Some(_), None) => Ok(None),
292 (None, Some(_)) => {
293 warn!(
294 "Compression level is not required for compression={name}, ignoring value of --compression-level"
295 );
296 Ok(None)
297 }
298 (Some(max), Some(c)) => {
299 if c <= max {
300 Ok(Some(c))
301 } else {
302 Err(ReadStatError::Other(format!(
303 "The compression level of {c} is not a valid level for {name} compression. \
304 Instead, please use values between 0-{max}."
305 )))
306 }
307 }
308 }
309 }
310}
311
/// Converts the CLI-level compression choice into a parquet crate codec.
///
/// `None` for `compression` selects Snappy. For Gzip, Brotli and Zstd, a missing
/// `compression_level` uses the parquet crate's default level for that codec. A level
/// supplied for any other codec is ignored.
///
/// # Errors
///
/// Returns `ReadStatError::Other` when the parquet crate rejects the level. For Zstd it
/// also errors when the level does not fit in the `i32` that `ZstdLevel` expects.
#[cfg(feature = "parquet")]
pub fn resolve_parquet_compression(
    compression: Option<ParquetCompression>,
    compression_level: Option<u32>,
) -> Result<ParquetCompressionCodec, ReadStatError> {
    let codec = match compression {
        Some(ParquetCompression::Uncompressed) => ParquetCompressionCodec::UNCOMPRESSED,
        Some(ParquetCompression::Snappy) | None => ParquetCompressionCodec::SNAPPY,
        Some(ParquetCompression::Lz4Raw) => ParquetCompressionCodec::LZ4_RAW,
        Some(ParquetCompression::Gzip) => {
            let level = match compression_level {
                Some(level) => GzipLevel::try_new(level).map_err(|e| {
                    ReadStatError::Other(format!("Invalid Gzip compression level: {e}"))
                })?,
                None => GzipLevel::default(),
            };
            ParquetCompressionCodec::GZIP(level)
        }
        Some(ParquetCompression::Brotli) => {
            let level = match compression_level {
                Some(level) => BrotliLevel::try_new(level).map_err(|e| {
                    ReadStatError::Other(format!("Invalid Brotli compression level: {e}"))
                })?,
                None => BrotliLevel::default(),
            };
            ParquetCompressionCodec::BROTLI(level)
        }
        Some(ParquetCompression::Zstd) => {
            let level = match compression_level {
                Some(raw) => {
                    // Checked conversion instead of an `as` cast that could wrap.
                    let signed = i32::try_from(raw).map_err(|_| {
                        ReadStatError::Other(format!(
                            "Invalid Zstd compression level: {raw} is out of range"
                        ))
                    })?;
                    ZstdLevel::try_new(signed).map_err(|e| {
                        ReadStatError::Other(format!("Invalid Zstd compression level: {e}"))
                    })?
                }
                None => ZstdLevel::default(),
            };
            ParquetCompressionCodec::ZSTD(level)
        }
    };
    Ok(codec)
}
358
#[cfg(test)]
mod tests {
    use super::*;

    // --- format defaulting ---------------------------------------------------

    #[test]
    fn format_none_defaults_to_csv() {
        let f = WriteConfig::validate_format(None);
        assert!(matches!(f, OutFormat::Csv));
    }

    #[test]
    fn format_some_passes_through() {
        let f = WriteConfig::validate_format(Some(OutFormat::Parquet));
        assert!(matches!(f, OutFormat::Parquet));
    }

    // --- OutFormat Display / FromStr -----------------------------------------

    #[test]
    fn out_format_display_round_trips_through_from_str() {
        for fmt in [
            OutFormat::Csv,
            OutFormat::Feather,
            OutFormat::Ndjson,
            OutFormat::Parquet,
        ] {
            let name = fmt.to_string();
            let parsed: OutFormat = name.parse().unwrap();
            assert_eq!(parsed.to_string(), name);
        }
    }

    #[test]
    fn out_format_from_str_is_case_insensitive() {
        assert!(matches!("PARQUET".parse::<OutFormat>(), Ok(OutFormat::Parquet)));
        assert!(matches!("Csv".parse::<OutFormat>(), Ok(OutFormat::Csv)));
    }

    #[test]
    fn out_format_from_str_rejects_unknown() {
        assert!("xlsx".parse::<OutFormat>().is_err());
    }

    // --- ParquetCompression Display / FromStr --------------------------------

    #[test]
    fn parquet_compression_display_round_trips_through_from_str() {
        for codec in [
            ParquetCompression::Uncompressed,
            ParquetCompression::Snappy,
            ParquetCompression::Gzip,
            ParquetCompression::Lz4Raw,
            ParquetCompression::Brotli,
            ParquetCompression::Zstd,
        ] {
            let parsed: ParquetCompression = codec.to_string().parse().unwrap();
            assert_eq!(parsed, codec);
        }
    }

    #[test]
    fn parquet_compression_accepts_lz4_aliases_any_case() {
        assert_eq!(
            "lz4raw".parse::<ParquetCompression>().unwrap(),
            ParquetCompression::Lz4Raw
        );
        assert_eq!(
            "LZ4-RAW".parse::<ParquetCompression>().unwrap(),
            ParquetCompression::Lz4Raw
        );
    }

    #[test]
    fn parquet_compression_rejects_unknown() {
        assert!("lzo".parse::<ParquetCompression>().is_err());
    }

    // --- output extension ----------------------------------------------------

    #[test]
    fn valid_csv_out_extension() {
        let path = Path::new("/some/output.csv");
        let result = WriteConfig::validate_out_extension(path, OutFormat::Csv).unwrap();
        assert!(result.is_some());
    }

    #[test]
    fn valid_parquet_out_extension() {
        let path = Path::new("/some/output.parquet");
        let result = WriteConfig::validate_out_extension(path, OutFormat::Parquet).unwrap();
        assert!(result.is_some());
    }

    #[test]
    fn valid_feather_out_extension() {
        let path = Path::new("/some/output.feather");
        let result = WriteConfig::validate_out_extension(path, OutFormat::Feather).unwrap();
        assert!(result.is_some());
    }

    #[test]
    fn valid_ndjson_out_extension() {
        let path = Path::new("/some/output.ndjson");
        let result = WriteConfig::validate_out_extension(path, OutFormat::Ndjson).unwrap();
        assert!(result.is_some());
    }

    #[test]
    fn uppercase_out_extension_is_accepted() {
        let path = Path::new("/some/output.CSV");
        assert!(WriteConfig::validate_out_extension(path, OutFormat::Csv).is_ok());
    }

    #[test]
    fn mismatched_out_extension() {
        let path = Path::new("/some/output.csv");
        assert!(WriteConfig::validate_out_extension(path, OutFormat::Parquet).is_err());
    }

    #[test]
    fn no_out_extension() {
        let path = Path::new("/some/output");
        assert!(WriteConfig::validate_out_extension(path, OutFormat::Csv).is_err());
    }

    // --- compression level ---------------------------------------------------

    #[test]
    fn uncompressed_ignores_level() {
        let result =
            WriteConfig::validate_compression_level(ParquetCompression::Uncompressed, Some(5))
                .unwrap();
        assert_eq!(result, None);
    }

    #[test]
    fn snappy_ignores_level() {
        let result =
            WriteConfig::validate_compression_level(ParquetCompression::Snappy, Some(5)).unwrap();
        assert_eq!(result, None);
    }

    #[test]
    fn lz4raw_ignores_level() {
        let result =
            WriteConfig::validate_compression_level(ParquetCompression::Lz4Raw, Some(5)).unwrap();
        assert_eq!(result, None);
    }

    #[test]
    fn gzip_valid_level() {
        let result =
            WriteConfig::validate_compression_level(ParquetCompression::Gzip, Some(5)).unwrap();
        assert_eq!(result, Some(5));
    }

    #[test]
    fn gzip_min_valid_level() {
        let result =
            WriteConfig::validate_compression_level(ParquetCompression::Gzip, Some(0)).unwrap();
        assert_eq!(result, Some(0));
    }

    #[test]
    fn gzip_max_valid_level() {
        let result =
            WriteConfig::validate_compression_level(ParquetCompression::Gzip, Some(9)).unwrap();
        assert_eq!(result, Some(9));
    }

    #[test]
    fn gzip_invalid_level() {
        assert!(
            WriteConfig::validate_compression_level(ParquetCompression::Gzip, Some(10)).is_err()
        );
    }

    #[test]
    fn brotli_valid_level() {
        let result =
            WriteConfig::validate_compression_level(ParquetCompression::Brotli, Some(11)).unwrap();
        assert_eq!(result, Some(11));
    }

    #[test]
    fn brotli_invalid_level() {
        assert!(
            WriteConfig::validate_compression_level(ParquetCompression::Brotli, Some(12)).is_err()
        );
    }

    #[test]
    fn zstd_valid_level() {
        let result =
            WriteConfig::validate_compression_level(ParquetCompression::Zstd, Some(22)).unwrap();
        assert_eq!(result, Some(22));
    }

    #[test]
    fn zstd_invalid_level() {
        assert!(
            WriteConfig::validate_compression_level(ParquetCompression::Zstd, Some(23)).is_err()
        );
    }

    #[test]
    fn no_level_passes_through() {
        let result =
            WriteConfig::validate_compression_level(ParquetCompression::Gzip, None).unwrap();
        assert_eq!(result, None);
    }

    // --- output path ---------------------------------------------------------

    #[test]
    fn validate_out_path_none() {
        assert!(
            WriteConfig::validate_out_path(None, false)
                .unwrap()
                .is_none()
        );
    }

    #[test]
    fn validate_out_path_missing_parent() {
        let path = PathBuf::from("/readstat-nonexistent-dir-for-tests/output.csv");
        assert!(WriteConfig::validate_out_path(Some(path), false).is_err());
    }

    #[test]
    fn validate_out_path_existing_file_requires_overwrite() {
        // The crate manifest is guaranteed to exist while tests run.
        let existing = Path::new(env!("CARGO_MANIFEST_DIR")).join("Cargo.toml");
        assert!(WriteConfig::validate_out_path(Some(existing.clone()), false).is_err());

        let kept = WriteConfig::validate_out_path(Some(existing), true)
            .unwrap()
            .expect("an existing file with overwrite=true must be accepted");
        assert!(kept.is_absolute());
    }
}