diff --git a/datafusion/functions/benches/initcap.rs b/datafusion/functions/benches/initcap.rs index ba055d58f5664..a92c2cc2d5758 100644 --- a/datafusion/functions/benches/initcap.rs +++ b/datafusion/functions/benches/initcap.rs @@ -22,12 +22,14 @@ use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; -use criterion::{Criterion, criterion_group, criterion_main}; +use criterion::{Criterion, SamplingMode, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::unicode; use std::hint::black_box; use std::sync::Arc; +use std::time::Duration; fn create_args( size: usize, @@ -49,60 +51,87 @@ fn create_args( fn criterion_benchmark(c: &mut Criterion) { let initcap = unicode::initcap(); - for size in [1024, 4096] { - let args = create_args::(size, 8, true); - let arg_fields = args - .iter() - .enumerate() - .map(|(idx, arg)| { - Field::new(format!("arg_{idx}"), arg.data_type(), true).into() - }) - .collect::>(); - let config_options = Arc::new(ConfigOptions::default()); - - c.bench_function( - format!("initcap string view shorter than 12 [size={size}]").as_str(), - |b| { - b.iter(|| { - black_box(initcap.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - arg_fields: arg_fields.clone(), - number_rows: size, - return_field: Field::new("f", DataType::Utf8View, true).into(), - config_options: Arc::clone(&config_options), - })) - }) - }, - ); - - let args = create_args::(size, 16, true); - c.bench_function( - format!("initcap string view longer than 12 [size={size}]").as_str(), - |b| { - b.iter(|| { - black_box(initcap.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - arg_fields: arg_fields.clone(), - number_rows: size, - return_field: Field::new("f", DataType::Utf8View, true).into(), - config_options: Arc::clone(&config_options), - })) - }) - }, - ); - - let args = create_args::(size, 16, false); - c.bench_function(format!("initcap string [size={size}]").as_str(), |b| { + let config_options = Arc::new(ConfigOptions::default()); + + // Grouped benchmarks for array sizes - to compare with scalar performance + for size in [1024, 4096, 8192] { + let mut group = c.benchmark_group(format!("initcap size={size}")); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + group.measurement_time(Duration::from_secs(10)); + + // Array benchmark - Utf8 + let array_args = create_args::(size, 16, false); + let array_arg_fields = vec![Field::new("arg_0", DataType::Utf8, true).into()]; + let batch_len = size; + + group.bench_function("array_utf8", |b| { b.iter(|| { black_box(initcap.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - arg_fields: arg_fields.clone(), - number_rows: size, + args: array_args.clone(), + arg_fields: array_arg_fields.clone(), + number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), })) }) }); + + // Array benchmark - Utf8View + let array_view_args = create_args::(size, 16, true); + let array_view_arg_fields = + vec![Field::new("arg_0", DataType::Utf8View, true).into()]; + + group.bench_function("array_utf8view", |b| { + b.iter(|| { + black_box(initcap.invoke_with_args(ScalarFunctionArgs { + args: array_view_args.clone(), + arg_fields: array_view_arg_fields.clone(), + number_rows: batch_len, + return_field: Field::new("f", DataType::Utf8View, true).into(), + config_options: Arc::clone(&config_options), + })) + }) + }); + + // Scalar benchmark - Utf8 (the optimization we added) + let scalar_args = vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some( + "hello world test string".to_string(), + )))]; + let scalar_arg_fields = vec![Field::new("arg_0", DataType::Utf8, false).into()]; + + group.bench_function("scalar_utf8", |b| { + b.iter(|| { + black_box(initcap.invoke_with_args(ScalarFunctionArgs { + args: scalar_args.clone(), + arg_fields: scalar_arg_fields.clone(), + number_rows: 1, + return_field: Field::new("f", DataType::Utf8, false).into(), + config_options: Arc::clone(&config_options), + })) + }) + }); + + // Scalar benchmark - Utf8View + let scalar_view_args = vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + "hello world test string".to_string(), + )))]; + let scalar_view_arg_fields = + vec![Field::new("arg_0", DataType::Utf8View, false).into()]; + + group.bench_function("scalar_utf8view", |b| { + b.iter(|| { + black_box(initcap.invoke_with_args(ScalarFunctionArgs { + args: scalar_view_args.clone(), + arg_fields: scalar_view_arg_fields.clone(), + number_rows: 1, + return_field: Field::new("f", DataType::Utf8View, false).into(), + config_options: Arc::clone(&config_options), + })) + }) + }); + + group.finish(); } } diff --git a/datafusion/functions/src/unicode/initcap.rs b/datafusion/functions/src/unicode/initcap.rs index 929b0c316951b..e2fc9130992db 100644 --- a/datafusion/functions/src/unicode/initcap.rs +++ b/datafusion/functions/src/unicode/initcap.rs @@ -26,7 +26,7 @@ use arrow::datatypes::DataType; use crate::utils::{make_scalar_function, utf8_to_str_type}; use datafusion_common::cast::{as_generic_string_array, as_string_view_array}; use datafusion_common::types::logical_string; -use datafusion_common::{Result, exec_err}; +use datafusion_common::{Result, ScalarValue, exec_err}; use datafusion_expr::{ Coercion, ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, @@ -99,6 +99,39 @@ impl ScalarUDFImpl for InitcapFunc { &self, args: datafusion_expr::ScalarFunctionArgs, ) -> Result { + let arg = &args.args[0]; + + // Scalar fast path - handle directly without array conversion + if let ColumnarValue::Scalar(scalar) = arg { + return match scalar { + ScalarValue::Utf8(None) + | ScalarValue::LargeUtf8(None) + | ScalarValue::Utf8View(None) => Ok(arg.clone()), + ScalarValue::Utf8(Some(s)) => { + let mut result = String::new(); + initcap_string(s, &mut result); + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result)))) + } + ScalarValue::LargeUtf8(Some(s)) => { + let mut result = String::new(); + initcap_string(s, &mut result); + Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(result)))) + } + ScalarValue::Utf8View(Some(s)) => { + let mut result = String::new(); + initcap_string(s, &mut result); + Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(result)))) + } + other => { + exec_err!( + "Unsupported data type {:?} for function `initcap`", + other.data_type() + ) + } + }; + } + + // Array path let args = &args.args; match args[0].data_type() { DataType::Utf8 => make_scalar_function(initcap::, vec![])(args),