1use std::path::{Path, PathBuf};
7
8#[cfg(feature = "parquet")]
9use parquet::basic::{BrotliLevel, Compression as ParquetCompressionCodec, GzipLevel, ZstdLevel};
10
11use log::warn;
12
13use crate::err::ReadStatError;
14
/// Output file formats that converted data can be written to.
///
/// The enum is `#[non_exhaustive]` so that further formats can be added
/// without breaking `match` expressions in downstream crates.
///
/// It is a plain, field-less enum, so it derives the full set of cheap value
/// traits (`Copy`, `PartialEq`, `Eq`, `Hash`), which lets callers compare
/// formats with `==` and use them as map or set keys.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OutFormat {
    /// Comma-separated values.
    Csv,
    /// Feather.
    Feather,
    /// Newline-delimited JSON.
    Ndjson,
    /// Parquet.
    Parquet,
}
35
36impl std::fmt::Display for OutFormat {
37 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38 match self {
39 Self::Csv => f.write_str("csv"),
40 Self::Feather => f.write_str("feather"),
41 Self::Ndjson => f.write_str("ndjson"),
42 Self::Parquet => f.write_str("parquet"),
43 }
44 }
45}
46
47impl std::str::FromStr for OutFormat {
48 type Err = ReadStatError;
49
50 fn from_str(s: &str) -> Result<Self, Self::Err> {
58 match s.to_lowercase().as_str() {
59 "csv" => Ok(Self::Csv),
60 "feather" => Ok(Self::Feather),
61 "ndjson" => Ok(Self::Ndjson),
62 "parquet" => Ok(Self::Parquet),
63 _ => Err(ReadStatError::Other(format!(
64 "Unknown format: {s:?}. Expected one of: csv, feather, ndjson, parquet"
65 ))),
66 }
67 }
68}
69
/// Compression codecs that can be requested for Parquet output.
///
/// The enum is `#[non_exhaustive]` so that further codecs can be added
/// without breaking `match` expressions in downstream crates.
///
/// It is a plain, field-less enum, so it derives the full set of cheap value
/// traits (`Copy`, `PartialEq`, `Eq`, `Hash`), which lets callers compare
/// codecs with `==` and use them as map or set keys.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ParquetCompression {
    /// No compression.
    Uncompressed,
    /// Snappy compression.
    Snappy,
    /// Gzip compression.
    Gzip,
    /// LZ4 (raw) compression.
    Lz4Raw,
    /// Brotli compression.
    Brotli,
    /// Zstd compression.
    Zstd,
}
90
91impl std::fmt::Display for ParquetCompression {
92 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93 match self {
94 Self::Uncompressed => f.write_str("uncompressed"),
95 Self::Snappy => f.write_str("snappy"),
96 Self::Gzip => f.write_str("gzip"),
97 Self::Lz4Raw => f.write_str("lz4-raw"),
98 Self::Brotli => f.write_str("brotli"),
99 Self::Zstd => f.write_str("zstd"),
100 }
101 }
102}
103
104#[derive(Debug, Clone)]
110pub struct WriteConfig {
111 pub out_path: Option<PathBuf>,
113 pub format: OutFormat,
115 pub overwrite: bool,
117 pub compression: Option<ParquetCompression>,
119 pub compression_level: Option<u32>,
121}
122
123impl WriteConfig {
124 pub fn new(
132 out_path: Option<PathBuf>,
133 format: Option<OutFormat>,
134 overwrite: bool,
135 compression: Option<ParquetCompression>,
136 compression_level: Option<u32>,
137 ) -> Result<Self, ReadStatError> {
138 let f = Self::validate_format(format);
139 let op = Self::validate_out_path(out_path, overwrite)?;
140 let op = if let Some(op) = op {
141 Self::validate_out_extension(&op, f)?
142 } else {
143 None
144 };
145 let cl = match compression {
146 None => {
147 if compression_level.is_some() {
148 warn!("Ignoring value of --compression-level as --compression was not set");
149 }
150 None
151 }
152 Some(pc) => Self::validate_compression_level(pc, compression_level)?,
153 };
154
155 Ok(Self {
156 out_path: op,
157 format: f,
158 overwrite,
159 compression,
160 compression_level: cl,
161 })
162 }
163
164 fn validate_format(format: Option<OutFormat>) -> OutFormat {
165 format.unwrap_or(OutFormat::Csv)
166 }
167
168 fn validate_out_extension(
170 path: &Path,
171 format: OutFormat,
172 ) -> Result<Option<PathBuf>, ReadStatError> {
173 path.extension()
174 .and_then(|e| e.to_str())
175 .map(std::borrow::ToOwned::to_owned)
176 .map_or(
177 Err(ReadStatError::Other(format!(
178 "File {} does not have an extension! Expecting extension {}.",
179 path.to_string_lossy(),
180 format
181 ))),
182 |e| {
183 if e == format.to_string() {
184 Ok(Some(path.to_owned()))
185 } else {
186 Err(ReadStatError::Other(format!(
187 "Expecting extension {}. Instead, file {} has extension {}.",
188 format,
189 path.to_string_lossy(),
190 e
191 )))
192 }
193 },
194 )
195 }
196
197 fn validate_out_path(
199 path: Option<PathBuf>,
200 overwrite: bool,
201 ) -> Result<Option<PathBuf>, ReadStatError> {
202 match path {
203 None => Ok(None),
204 Some(p) => {
205 let abs_path = std::path::absolute(&p)
206 .map_err(|e| ReadStatError::Other(format!("Failed to resolve path: {e}")))?;
207
208 match abs_path.parent() {
209 None => Err(ReadStatError::Other(format!(
210 "The parent directory of the value of the parameter --output ({}) does not exist",
211 abs_path.to_string_lossy()
212 ))),
213 Some(parent) => {
214 if parent.exists() {
215 if abs_path.exists() {
216 if overwrite {
217 warn!(
218 "The file {} will be overwritten!",
219 abs_path.to_string_lossy()
220 );
221 Ok(Some(abs_path))
222 } else {
223 Err(ReadStatError::Other(format!(
224 "The output file - {} - already exists! To overwrite the file, utilize the --overwrite parameter",
225 abs_path.to_string_lossy()
226 )))
227 }
228 } else {
229 Ok(Some(abs_path))
230 }
231 } else {
232 Err(ReadStatError::Other(format!(
233 "The parent directory of the value of the parameter --output ({}) does not exist",
234 parent.to_string_lossy()
235 )))
236 }
237 }
238 }
239 }
240 }
241 }
242
243 fn validate_compression_level(
245 compression: ParquetCompression,
246 compression_level: Option<u32>,
247 ) -> Result<Option<u32>, ReadStatError> {
248 let (name, max_level): (&str, Option<u32>) = match compression {
249 ParquetCompression::Uncompressed => ("uncompressed", None),
250 ParquetCompression::Snappy => ("snappy", None),
251 ParquetCompression::Lz4Raw => ("lz4-raw", None),
252 ParquetCompression::Gzip => ("gzip", Some(9)),
253 ParquetCompression::Brotli => ("brotli", Some(11)),
254 ParquetCompression::Zstd => ("zstd", Some(22)),
255 };
256
257 match (max_level, compression_level) {
258 (None | Some(_), None) => Ok(None),
259 (None, Some(_)) => {
260 warn!(
261 "Compression level is not required for compression={name}, ignoring value of --compression-level"
262 );
263 Ok(None)
264 }
265 (Some(max), Some(c)) => {
266 if c <= max {
267 Ok(Some(c))
268 } else {
269 Err(ReadStatError::Other(format!(
270 "The compression level of {c} is not a valid level for {name} compression. \
271 Instead, please use values between 0-{max}."
272 )))
273 }
274 }
275 }
276 }
277}
278
/// Translates the configured codec and level into a Parquet compression codec.
///
/// When `compression` is `None`, Snappy is used. For Gzip, Brotli and Zstd a
/// supplied `compression_level` is validated by the codec's own `try_new`
/// constructor; without a level the codec's default level is used. The level
/// is ignored for the uncompressed, Snappy and LZ4-raw codecs.
///
/// # Errors
///
/// Returns [`ReadStatError::Other`] when the level is rejected by the codec,
/// or (Zstd only) when it does not fit into the signed integer the Zstd
/// level constructor expects.
#[cfg(feature = "parquet")]
pub fn resolve_parquet_compression(
    compression: Option<ParquetCompression>,
    compression_level: Option<u32>,
) -> Result<ParquetCompressionCodec, ReadStatError> {
    let codec = match compression {
        Some(ParquetCompression::Uncompressed) => ParquetCompressionCodec::UNCOMPRESSED,
        Some(ParquetCompression::Snappy) | None => ParquetCompressionCodec::SNAPPY,
        Some(ParquetCompression::Lz4Raw) => ParquetCompressionCodec::LZ4_RAW,
        Some(ParquetCompression::Gzip) => ParquetCompressionCodec::GZIP(match compression_level {
            Some(level) => GzipLevel::try_new(level).map_err(|e| {
                ReadStatError::Other(format!("Invalid Gzip compression level: {e}"))
            })?,
            None => GzipLevel::default(),
        }),
        Some(ParquetCompression::Brotli) => {
            ParquetCompressionCodec::BROTLI(match compression_level {
                Some(level) => BrotliLevel::try_new(level).map_err(|e| {
                    ReadStatError::Other(format!("Invalid Brotli compression level: {e}"))
                })?,
                None => BrotliLevel::default(),
            })
        }
        Some(ParquetCompression::Zstd) => ParquetCompressionCodec::ZSTD(match compression_level {
            Some(level) => {
                // A checked conversion instead of `as i32`: values above
                // `i32::MAX` would otherwise wrap to a negative number and be
                // reported as a level the caller never supplied.
                let signed = i32::try_from(level).map_err(|_| {
                    ReadStatError::Other(format!(
                        "Invalid Zstd compression level: {level} is out of range"
                    ))
                })?;
                ZstdLevel::try_new(signed).map_err(|e| {
                    ReadStatError::Other(format!("Invalid Zstd compression level: {e}"))
                })?
            }
            None => ZstdLevel::default(),
        }),
    };
    Ok(codec)
}
325
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the extension check for `file` against `format`.
    fn check_extension(file: &str, format: OutFormat) -> Result<Option<PathBuf>, ReadStatError> {
        WriteConfig::validate_out_extension(Path::new(file), format)
    }

    /// Runs the compression-level check for `codec` with the given `level`.
    fn check_level(
        codec: ParquetCompression,
        level: Option<u32>,
    ) -> Result<Option<u32>, ReadStatError> {
        WriteConfig::validate_compression_level(codec, level)
    }

    // --- validate_format ---

    #[test]
    fn format_none_defaults_to_csv() {
        assert!(matches!(WriteConfig::validate_format(None), OutFormat::Csv));
    }

    #[test]
    fn format_some_passes_through() {
        let chosen = WriteConfig::validate_format(Some(OutFormat::Parquet));
        assert!(matches!(chosen, OutFormat::Parquet));
    }

    // --- validate_out_extension ---

    #[test]
    fn valid_csv_out_extension() {
        let checked = check_extension("/some/output.csv", OutFormat::Csv).unwrap();
        assert!(checked.is_some());
    }

    #[test]
    fn valid_parquet_out_extension() {
        let checked = check_extension("/some/output.parquet", OutFormat::Parquet).unwrap();
        assert!(checked.is_some());
    }

    #[test]
    fn valid_feather_out_extension() {
        let checked = check_extension("/some/output.feather", OutFormat::Feather).unwrap();
        assert!(checked.is_some());
    }

    #[test]
    fn valid_ndjson_out_extension() {
        let checked = check_extension("/some/output.ndjson", OutFormat::Ndjson).unwrap();
        assert!(checked.is_some());
    }

    #[test]
    fn mismatched_out_extension() {
        assert!(check_extension("/some/output.csv", OutFormat::Parquet).is_err());
    }

    #[test]
    fn no_out_extension() {
        assert!(check_extension("/some/output", OutFormat::Csv).is_err());
    }

    // --- validate_compression_level ---

    #[test]
    fn uncompressed_ignores_level() {
        let level = check_level(ParquetCompression::Uncompressed, Some(5)).unwrap();
        assert_eq!(level, None);
    }

    #[test]
    fn snappy_ignores_level() {
        let level = check_level(ParquetCompression::Snappy, Some(5)).unwrap();
        assert_eq!(level, None);
    }

    #[test]
    fn lz4raw_ignores_level() {
        let level = check_level(ParquetCompression::Lz4Raw, Some(5)).unwrap();
        assert_eq!(level, None);
    }

    #[test]
    fn gzip_valid_level() {
        let level = check_level(ParquetCompression::Gzip, Some(5)).unwrap();
        assert_eq!(level, Some(5));
    }

    #[test]
    fn gzip_max_valid_level() {
        let level = check_level(ParquetCompression::Gzip, Some(9)).unwrap();
        assert_eq!(level, Some(9));
    }

    #[test]
    fn gzip_invalid_level() {
        assert!(check_level(ParquetCompression::Gzip, Some(10)).is_err());
    }

    #[test]
    fn brotli_valid_level() {
        let level = check_level(ParquetCompression::Brotli, Some(11)).unwrap();
        assert_eq!(level, Some(11));
    }

    #[test]
    fn brotli_invalid_level() {
        assert!(check_level(ParquetCompression::Brotli, Some(12)).is_err());
    }

    #[test]
    fn zstd_valid_level() {
        let level = check_level(ParquetCompression::Zstd, Some(22)).unwrap();
        assert_eq!(level, Some(22));
    }

    #[test]
    fn zstd_invalid_level() {
        assert!(check_level(ParquetCompression::Zstd, Some(23)).is_err());
    }

    #[test]
    fn no_level_passes_through() {
        let level = check_level(ParquetCompression::Gzip, None).unwrap();
        assert_eq!(level, None);
    }

    // --- validate_out_path ---

    #[test]
    fn validate_out_path_none() {
        let resolved = WriteConfig::validate_out_path(None, false).unwrap();
        assert!(resolved.is_none());
    }
}