Skip to content

Commit

Permalink
feat: add example for copy to
Browse files Browse the repository at this point in the history
  • Loading branch information
tshauck committed Jun 29, 2024
1 parent 7a7797c commit d983b60
Showing 1 changed file with 203 additions and 0 deletions.
203 changes: 203 additions & 0 deletions datafusion-examples/examples/custom_file_format.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::{any::Any, sync::Arc};

use arrow::array::{RecordBatch, StringArray, UInt8Array};
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use datafusion::{
datasource::{
file_format::{
csv::CsvFormatFactory, file_compression_type::FileCompressionType,
FileFormat, FileFormatFactory,
},
physical_plan::{FileScanConfig, FileSinkConfig},
MemTable,
},
error::Result,
execution::{context::SessionState, runtime_env::RuntimeEnv},
physical_plan::ExecutionPlan,
prelude::{SessionConfig, SessionContext},
};
use datafusion_common::{GetExt, Statistics};
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement};
use object_store::{ObjectMeta, ObjectStore};

#[derive(Debug)]
struct TSVFileFormat {
csv_file_format: Arc<dyn FileFormat>,
}

impl TSVFileFormat {
pub fn new(csv_file_format: Arc<dyn FileFormat>) -> Self {
Self { csv_file_format }
}
}

#[async_trait::async_trait]
impl FileFormat for TSVFileFormat {
fn as_any(&self) -> &dyn Any {
self
}

fn get_ext(&self) -> String {
"tsv".to_string()
}

fn get_ext_with_compression(
&self,
c: &FileCompressionType,
) -> datafusion::error::Result<String> {
if c == &FileCompressionType::UNCOMPRESSED {
Ok("tsv".to_string())
} else {
todo!("Compression not supported")
}
}

async fn infer_schema(
&self,
state: &SessionState,
store: &Arc<dyn ObjectStore>,
objects: &[ObjectMeta],
) -> Result<SchemaRef> {
self.csv_file_format
.infer_schema(state, store, objects)
.await
}

async fn infer_stats(
&self,
state: &SessionState,
store: &Arc<dyn ObjectStore>,
table_schema: SchemaRef,
object: &ObjectMeta,
) -> Result<Statistics> {
self.csv_file_format
.infer_stats(state, store, table_schema, object)
.await
}

async fn create_physical_plan(
&self,
state: &SessionState,
conf: FileScanConfig,
filters: Option<&Arc<dyn PhysicalExpr>>,
) -> Result<Arc<dyn ExecutionPlan>> {
self.csv_file_format
.create_physical_plan(state, conf, filters)
.await
}

async fn create_writer_physical_plan(
&self,
input: Arc<dyn ExecutionPlan>,
state: &SessionState,
conf: FileSinkConfig,
order_requirements: Option<Vec<PhysicalSortRequirement>>,
) -> Result<Arc<dyn ExecutionPlan>> {
self.csv_file_format
.create_writer_physical_plan(input, state, conf, order_requirements)
.await
}
}

#[derive(Default)]
pub struct TSVFileFactory {
csv_file_factory: CsvFormatFactory,
}

impl TSVFileFactory {
pub fn new() -> Self {
Self {
csv_file_factory: CsvFormatFactory::new(),
}
}
}

impl FileFormatFactory for TSVFileFactory {
fn create(
&self,
state: &SessionState,
format_options: &std::collections::HashMap<String, String>,
) -> Result<std::sync::Arc<dyn FileFormat>> {
let mut new_options = format_options.clone();
new_options.insert("format.delimiter".to_string(), "\t".to_string());

let csv_file_format = self.csv_file_factory.create(state, &new_options)?;
let tsv_file_format = Arc::new(TSVFileFormat::new(csv_file_format));

Ok(tsv_file_format)
}

fn default(&self) -> std::sync::Arc<dyn FileFormat> {
todo!()
}
}

impl GetExt for TSVFileFactory {
fn get_ext(&self) -> String {
"tsv".to_string()
}
}

#[tokio::main]
async fn main() -> Result<()> {
// Create a new context with the default configuration
let config = SessionConfig::new();
let runtime = RuntimeEnv::default();
let mut state = SessionState::new_with_config_rt(config, Arc::new(runtime));

// Register the custom file format
let file_format = Arc::new(TSVFileFactory::new());
state.register_file_format(file_format, true).unwrap();

// Create a new context with the custom file format
let ctx = SessionContext::new_with_state(state);

let mem_table = create_mem_table();
ctx.register_table("mem_table", mem_table).unwrap();

let d = ctx
.sql("COPY mem_table TO 'mem_table.tsv' STORED AS TSV;")
.await?;

let results = d.collect().await?;
println!("Number of inserted rows: {:?}", results[0]);

Ok(())
}

// create a simple mem table
fn create_mem_table() -> Arc<MemTable> {
let fields = vec![
Field::new("id", DataType::UInt8, false),
Field::new("data", DataType::Utf8, false),
];
let schema = Arc::new(Schema::new(fields));

let partitions = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(UInt8Array::from(vec![1, 2])),
Arc::new(StringArray::from(vec!["foo", "bar"])),
],
)
.unwrap();

Arc::new(MemTable::try_new(schema, vec![vec![partitions]]).unwrap())
}

0 comments on commit d983b60

Please sign in to comment.