1use std::path::{Path, PathBuf};
7
8#[cfg(feature = "parquet")]
9use parquet::basic::{BrotliLevel, Compression as ParquetCompressionCodec, GzipLevel, ZstdLevel};
10
11use log::warn;
12
13use crate::err::ReadStatError;
14
/// Output file formats supported by the writer.
#[derive(Debug, Clone, Copy)]
pub enum OutFormat {
    /// Comma-separated values.
    Csv,
    /// Arrow IPC file format (Feather).
    Feather,
    /// Newline-delimited JSON.
    Ndjson,
    /// Apache Parquet.
    Parquet,
}
31
32impl std::fmt::Display for OutFormat {
33 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34 match self {
35 Self::Csv => f.write_str("csv"),
36 Self::Feather => f.write_str("feather"),
37 Self::Ndjson => f.write_str("ndjson"),
38 Self::Parquet => f.write_str("parquet"),
39 }
40 }
41}
42
/// Compression codecs selectable for Parquet output.
#[derive(Debug, Clone, Copy)]
pub enum ParquetCompression {
    /// No compression.
    Uncompressed,
    /// Snappy (no compression level).
    Snappy,
    /// Gzip (accepts a compression level).
    Gzip,
    /// LZ4 raw (no compression level).
    Lz4Raw,
    /// Brotli (accepts a compression level).
    Brotli,
    /// Zstandard (accepts a compression level).
    Zstd,
}
59
60impl std::fmt::Display for ParquetCompression {
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 match self {
63 Self::Uncompressed => f.write_str("uncompressed"),
64 Self::Snappy => f.write_str("snappy"),
65 Self::Gzip => f.write_str("gzip"),
66 Self::Lz4Raw => f.write_str("lz4-raw"),
67 Self::Brotli => f.write_str("brotli"),
68 Self::Zstd => f.write_str("zstd"),
69 }
70 }
71}
72
/// Validated configuration for writing output data.
///
/// Construct via [`WriteConfig::new`], which validates the path, extension,
/// and compression settings together.
#[derive(Debug, Clone)]
pub struct WriteConfig {
    // Absolute output path, if writing to a file; `None` means no file output.
    pub out_path: Option<PathBuf>,
    // Output format; defaults to CSV when not specified.
    pub format: OutFormat,
    // Whether an existing output file may be overwritten.
    pub overwrite: bool,
    // Parquet compression codec, if requested.
    pub compression: Option<ParquetCompression>,
    // Compression level; only retained for codecs that accept one.
    pub compression_level: Option<u32>,
}
91
92impl WriteConfig {
93 pub fn new(
101 out_path: Option<PathBuf>,
102 format: Option<OutFormat>,
103 overwrite: bool,
104 compression: Option<ParquetCompression>,
105 compression_level: Option<u32>,
106 ) -> Result<Self, ReadStatError> {
107 let f = Self::validate_format(format);
108 let op = Self::validate_out_path(out_path, overwrite)?;
109 let op = if let Some(op) = op {
110 Self::validate_out_extension(&op, f)?
111 } else {
112 None
113 };
114 let cl = match compression {
115 None => {
116 if compression_level.is_some() {
117 warn!("Ignoring value of --compression-level as --compression was not set");
118 }
119 None
120 }
121 Some(pc) => Self::validate_compression_level(pc, compression_level)?,
122 };
123
124 Ok(Self {
125 out_path: op,
126 format: f,
127 overwrite,
128 compression,
129 compression_level: cl,
130 })
131 }
132
133 fn validate_format(format: Option<OutFormat>) -> OutFormat {
134 format.unwrap_or(OutFormat::Csv)
135 }
136
137 fn validate_out_extension(
139 path: &Path,
140 format: OutFormat,
141 ) -> Result<Option<PathBuf>, ReadStatError> {
142 path.extension()
143 .and_then(|e| e.to_str())
144 .map(std::borrow::ToOwned::to_owned)
145 .map_or(
146 Err(ReadStatError::Other(format!(
147 "File {} does not have an extension! Expecting extension {}.",
148 path.to_string_lossy(),
149 format
150 ))),
151 |e| {
152 if e == format.to_string() {
153 Ok(Some(path.to_owned()))
154 } else {
155 Err(ReadStatError::Other(format!(
156 "Expecting extension {}. Instead, file {} has extension {}.",
157 format,
158 path.to_string_lossy(),
159 e
160 )))
161 }
162 },
163 )
164 }
165
166 fn validate_out_path(
168 path: Option<PathBuf>,
169 overwrite: bool,
170 ) -> Result<Option<PathBuf>, ReadStatError> {
171 match path {
172 None => Ok(None),
173 Some(p) => {
174 let abs_path = std::path::absolute(&p)
175 .map_err(|e| ReadStatError::Other(format!("Failed to resolve path: {e}")))?;
176
177 match abs_path.parent() {
178 None => Err(ReadStatError::Other(format!(
179 "The parent directory of the value of the parameter --output ({}) does not exist",
180 abs_path.to_string_lossy()
181 ))),
182 Some(parent) => {
183 if parent.exists() {
184 if abs_path.exists() {
185 if overwrite {
186 warn!(
187 "The file {} will be overwritten!",
188 abs_path.to_string_lossy()
189 );
190 Ok(Some(abs_path))
191 } else {
192 Err(ReadStatError::Other(format!(
193 "The output file - {} - already exists! To overwrite the file, utilize the --overwrite parameter",
194 abs_path.to_string_lossy()
195 )))
196 }
197 } else {
198 Ok(Some(abs_path))
199 }
200 } else {
201 Err(ReadStatError::Other(format!(
202 "The parent directory of the value of the parameter --output ({}) does not exist",
203 parent.to_string_lossy()
204 )))
205 }
206 }
207 }
208 }
209 }
210 }
211
212 fn validate_compression_level(
214 compression: ParquetCompression,
215 compression_level: Option<u32>,
216 ) -> Result<Option<u32>, ReadStatError> {
217 let (name, max_level): (&str, Option<u32>) = match compression {
218 ParquetCompression::Uncompressed => ("uncompressed", None),
219 ParquetCompression::Snappy => ("snappy", None),
220 ParquetCompression::Lz4Raw => ("lz4-raw", None),
221 ParquetCompression::Gzip => ("gzip", Some(9)),
222 ParquetCompression::Brotli => ("brotli", Some(11)),
223 ParquetCompression::Zstd => ("zstd", Some(22)),
224 };
225
226 match (max_level, compression_level) {
227 (None | Some(_), None) => Ok(None),
228 (None, Some(_)) => {
229 warn!(
230 "Compression level is not required for compression={name}, ignoring value of --compression-level"
231 );
232 Ok(None)
233 }
234 (Some(max), Some(c)) => {
235 if c <= max {
236 Ok(Some(c))
237 } else {
238 Err(ReadStatError::Other(format!(
239 "The compression level of {c} is not a valid level for {name} compression. \
240 Instead, please use values between 0-{max}."
241 )))
242 }
243 }
244 }
245 }
246}
247
#[cfg(feature = "parquet")]
/// Maps the CLI compression choice (and optional level) onto the `parquet`
/// crate's `Compression` codec.
///
/// `None` compression defaults to Snappy. Codecs without a supplied level use
/// the crate's default level.
///
/// # Errors
/// Returns `ReadStatError::Other` when the level is rejected by the `parquet`
/// crate, or (for Zstd) when the `u32` level does not fit in an `i32`.
pub fn resolve_parquet_compression(
    compression: Option<ParquetCompression>,
    compression_level: Option<u32>,
) -> Result<ParquetCompressionCodec, ReadStatError> {
    let codec = match compression {
        Some(ParquetCompression::Uncompressed) => ParquetCompressionCodec::UNCOMPRESSED,
        // Snappy is the default codec when none is requested.
        Some(ParquetCompression::Snappy) | None => ParquetCompressionCodec::SNAPPY,
        Some(ParquetCompression::Lz4Raw) => ParquetCompressionCodec::LZ4_RAW,
        Some(ParquetCompression::Gzip) => {
            let gzip_level = match compression_level {
                Some(level) => GzipLevel::try_new(level).map_err(|e| {
                    ReadStatError::Other(format!("Invalid Gzip compression level: {e}"))
                })?,
                None => GzipLevel::default(),
            };
            ParquetCompressionCodec::GZIP(gzip_level)
        }
        Some(ParquetCompression::Brotli) => {
            let brotli_level = match compression_level {
                Some(level) => BrotliLevel::try_new(level).map_err(|e| {
                    ReadStatError::Other(format!("Invalid Brotli compression level: {e}"))
                })?,
                None => BrotliLevel::default(),
            };
            ParquetCompressionCodec::BROTLI(brotli_level)
        }
        Some(ParquetCompression::Zstd) => {
            let zstd_level = match compression_level {
                Some(level) => {
                    // `ZstdLevel::try_new` takes an i32; use a checked conversion
                    // instead of an `as` cast, which would silently wrap values
                    // above i32::MAX to negative levels.
                    let level = i32::try_from(level).map_err(|_| {
                        ReadStatError::Other(format!(
                            "Invalid Zstd compression level: {level}"
                        ))
                    })?;
                    ZstdLevel::try_new(level).map_err(|e| {
                        ReadStatError::Other(format!("Invalid Zstd compression level: {e}"))
                    })?
                }
                None => ZstdLevel::default(),
            };
            ParquetCompressionCodec::ZSTD(zstd_level)
        }
    };
    Ok(codec)
}
294
#[cfg(test)]
mod tests {
    use super::*;

    // --- format defaulting ---

    #[test]
    fn format_none_defaults_to_csv() {
        assert!(matches!(WriteConfig::validate_format(None), OutFormat::Csv));
    }

    #[test]
    fn format_some_passes_through() {
        assert!(matches!(
            WriteConfig::validate_format(Some(OutFormat::Parquet)),
            OutFormat::Parquet
        ));
    }

    // --- extension validation ---

    #[test]
    fn valid_csv_out_extension() {
        assert!(
            WriteConfig::validate_out_extension(Path::new("/some/output.csv"), OutFormat::Csv)
                .unwrap()
                .is_some()
        );
    }

    #[test]
    fn valid_parquet_out_extension() {
        assert!(
            WriteConfig::validate_out_extension(
                Path::new("/some/output.parquet"),
                OutFormat::Parquet
            )
            .unwrap()
            .is_some()
        );
    }

    #[test]
    fn valid_feather_out_extension() {
        assert!(
            WriteConfig::validate_out_extension(
                Path::new("/some/output.feather"),
                OutFormat::Feather
            )
            .unwrap()
            .is_some()
        );
    }

    #[test]
    fn valid_ndjson_out_extension() {
        assert!(
            WriteConfig::validate_out_extension(
                Path::new("/some/output.ndjson"),
                OutFormat::Ndjson
            )
            .unwrap()
            .is_some()
        );
    }

    #[test]
    fn mismatched_out_extension() {
        let result =
            WriteConfig::validate_out_extension(Path::new("/some/output.csv"), OutFormat::Parquet);
        assert!(result.is_err());
    }

    #[test]
    fn no_out_extension() {
        let result =
            WriteConfig::validate_out_extension(Path::new("/some/output"), OutFormat::Csv);
        assert!(result.is_err());
    }

    // --- compression level validation ---

    #[test]
    fn uncompressed_ignores_level() {
        let level =
            WriteConfig::validate_compression_level(ParquetCompression::Uncompressed, Some(5));
        assert_eq!(level.unwrap(), None);
    }

    #[test]
    fn snappy_ignores_level() {
        let level = WriteConfig::validate_compression_level(ParquetCompression::Snappy, Some(5));
        assert_eq!(level.unwrap(), None);
    }

    #[test]
    fn lz4raw_ignores_level() {
        let level = WriteConfig::validate_compression_level(ParquetCompression::Lz4Raw, Some(5));
        assert_eq!(level.unwrap(), None);
    }

    #[test]
    fn gzip_valid_level() {
        let level = WriteConfig::validate_compression_level(ParquetCompression::Gzip, Some(5));
        assert_eq!(level.unwrap(), Some(5));
    }

    #[test]
    fn gzip_max_valid_level() {
        let level = WriteConfig::validate_compression_level(ParquetCompression::Gzip, Some(9));
        assert_eq!(level.unwrap(), Some(9));
    }

    #[test]
    fn gzip_invalid_level() {
        assert!(
            WriteConfig::validate_compression_level(ParquetCompression::Gzip, Some(10)).is_err()
        );
    }

    #[test]
    fn brotli_valid_level() {
        let level = WriteConfig::validate_compression_level(ParquetCompression::Brotli, Some(11));
        assert_eq!(level.unwrap(), Some(11));
    }

    #[test]
    fn brotli_invalid_level() {
        assert!(
            WriteConfig::validate_compression_level(ParquetCompression::Brotli, Some(12)).is_err()
        );
    }

    #[test]
    fn zstd_valid_level() {
        let level = WriteConfig::validate_compression_level(ParquetCompression::Zstd, Some(22));
        assert_eq!(level.unwrap(), Some(22));
    }

    #[test]
    fn zstd_invalid_level() {
        assert!(
            WriteConfig::validate_compression_level(ParquetCompression::Zstd, Some(23)).is_err()
        );
    }

    #[test]
    fn no_level_passes_through() {
        let level = WriteConfig::validate_compression_level(ParquetCompression::Gzip, None);
        assert_eq!(level.unwrap(), None);
    }

    // --- out-path validation ---

    #[test]
    fn validate_out_path_none() {
        let result = WriteConfig::validate_out_path(None, false).unwrap();
        assert!(result.is_none());
    }
}