@attrs.define
class UDF(Callable[[pa.RecordBatch], pa.Array]): # type: ignore
"""User-defined function (UDF) to be applied to a Lance Table."""
# The reference to the callable
func: Callable = attrs.field()
name: str = attrs.field(default="")
cuda: Optional[bool] = attrs.field(default=False)
num_cpus: Optional[float] = attrs.field(
default=1.0,
converter=lambda v: None if v is None else float(v),
validator=valid.optional(valid.ge(0.0)),
on_setattr=[attrs.setters.convert, attrs.setters.validate],
)
num_gpus: Optional[float] = attrs.field(
default=None,
converter=lambda v: None if v is None else float(v),
validator=valid.optional(valid.ge(0.0)),
on_setattr=[attrs.setters.convert, attrs.setters.validate],
)
memory: int | None = attrs.field(default=None)
batch_size: int | None = attrs.field(default=None)
# Error handling configuration
error_handling: Optional["ErrorHandlingConfig"] = attrs.field(default=None)
def _record_batch_input(self) -> bool:
sig = inspect.signature(self.func)
if len(sig.parameters) == 1:
param = list(sig.parameters.values())[0]
return param.annotation == pa.RecordBatch
return False
@property
def arg_type(self) -> UDFArgType:
if self._record_batch_input():
return UDFArgType.RECORD_BATCH
if _is_batched_func(self.func):
return UDFArgType.ARRAY
return UDFArgType.SCALAR
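    # How the `arg_type` property above classifies a UDF (illustrative; the
    # exact criteria used by `_is_batched_func` live elsewhere in this module):
    #
    #     def whole_batch(batch: pa.RecordBatch) -> pa.Array: ...  # RECORD_BATCH
    #     def per_row(x: int, y: str) -> int: ...                  # SCALAR
    #
    # Functions recognized by `_is_batched_func` are classified as ARRAY and
    # receive whole columns (`pa.Array`s) instead of individual row values.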
input_columns: list[str] | None = attrs.field(default=None)
data_type: pa.DataType = attrs.field(default=None)
version: str = attrs.field(default="")
checkpoint_key: str = attrs.field(default="")
    # use a factory so instances do not share a single mutable dict
    field_metadata: dict[str, str] = attrs.field(factory=dict)
def __attrs_post_init__(self) -> None:
"""
Initialize UDF fields and normalize num_gpus after all fields are set:
1) if cuda=True and num_gpus is None or 0.0 -> set to 1.0
2) otherwise ignore cuda and just use num_gpus setting
"""
# Set default name
if not self.name:
if inspect.isfunction(self.func):
self.name = self.func.__name__
elif isinstance(self.func, Callable):
self.name = self.func.__class__.__name__
else:
raise ValueError(
f"func must be a function or a callable, got {self.func}"
)
# Set default input_columns
if self.input_columns is None:
sig = inspect.signature(self.func)
params = list(sig.parameters.keys())
if self._record_batch_input():
self.input_columns = None
else:
self.input_columns = params
# Validate input_columns
if self.arg_type == UDFArgType.RECORD_BATCH:
if self.input_columns is not None:
raise ValueError(
"RecordBatch input UDF must not declare any input columns. "
"RecordBatch UDFs receive the entire batch and should not "
"specify input_columns. Consider using a stateful RecordBatch "
"UDF and parameterize it or use UDF with Array inputs."
)
else:
if self.input_columns is None:
raise ValueError("Array and Scalar input UDF must declare input column")
# Set default data_type
if self.data_type is None:
if self.arg_type != UDFArgType.SCALAR:
raise ValueError(
"batched UDFs do not support data_type inference yet,"
" please specify data_type",
)
self.data_type = _infer_func_arrow_type(self.func, None) # type: ignore[arg-type]
# Validate data_type
if self.data_type is None:
raise ValueError("data_type must be set")
if not isinstance(self.data_type, pa.DataType):
raise ValueError(
f"data_type must be a pyarrow.DataType, got {self.data_type}"
)
# Set default version
if not self.version:
hasher = hashlib.md5()
hasher.update(pickle.dumps(self.func))
self.version = hasher.hexdigest()
# Set default checkpoint_key
if not self.checkpoint_key:
self.checkpoint_key = f"{self.name}:{self.version}"
# Handle cuda/num_gpus normalization
if self.cuda:
warnings.warn(
"The 'cuda' flag is deprecated. Please set 'num_gpus' explicitly "
"(0.0 for CPU, >=1.0 for GPU).",
DeprecationWarning,
stacklevel=2,
)
if self.num_gpus is None:
self.num_gpus = 1.0 if self.cuda is True else 0.0
# otherwise fall back to user specified num_gpus
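        # Normalization outcomes (illustrative, assuming `f` is a scalar
        # function with an annotated return type so data_type can be inferred):
        #   UDF(func=f, cuda=True)               -> num_gpus == 1.0 (+ warning)
        #   UDF(func=f)                          -> num_gpus == 0.0
        #   UDF(func=f, cuda=True, num_gpus=0.5) -> num_gpus stays 0.5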
def _scalar_func_record_batch_call(self, record_batch: pa.RecordBatch) -> pa.Array:
"""
We use this when the UDF uses single call like
`func(x_int, y_string, ...) -> type`
this function automatically dispatches rows to the func and returns `pa.Array`
"""
# this let's us avoid having to allocate a list in python
# to hold the results. PA will allocate once for us
def _iter(): # noqa: ANN202
            rows = (
                record_batch.to_pylist()
                if isinstance(record_batch, pa.RecordBatch)
                else record_batch
            )
            for item in rows:
                if BACKFILL_SELECTED not in item or item.get(BACKFILL_SELECTED):
                    # we know input_columns is not None here
                    args = [item[col] for col in self.input_columns]  # type: ignore
yield self.func(*args)
else:
# item was not selected, so do not compute
yield None
arr = pa.array(
_iter(),
type=self.data_type,
)
        # this should always be an Array; we should never get a ChunkedArray here
assert isinstance(arr, pa.Array)
return arr
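        # Example of the per-row dispatch above (illustrative): for a UDF
        # wrapping `def add(x: int, y: int) -> int` with input_columns
        # ["x", "y"], a batch {"x": [1, 2], "y": [10, 20]} produces
        # pa.array([11, 22], type=self.data_type); rows deselected via the
        # BACKFILL_SELECTED marker come back as nulls instead.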
def _input_columns_validator(self, attribute, value) -> None:
"""Validate input_columns attribute for attrs compatibility."""
if self.arg_type == UDFArgType.RECORD_BATCH:
if value is not None:
raise ValueError(
"RecordBatch input UDF must not declare any input columns. "
"RecordBatch UDFs receive the entire batch and should not "
"specify input_columns."
)
else:
if value is None:
raise ValueError("Array and Scalar input UDF must declare input column")
def validate_against_schema(
self, table_schema: pa.Schema, input_columns: list[str] | None = None
) -> None:
"""
Validate UDF against table schema.
This is the primary validation method that should be called before executing
a UDF. It performs comprehensive validation including:
1. **Column Existence**: Verifies all input columns exist in the table schema
2. **Type Compatibility**: Checks that column types match UDF type annotations
(if present)
3. **RecordBatch Constraints**: Ensures RecordBatch UDFs don't have
input_columns defined
The validation happens at two points in the UDF lifecycle:
- At `add_columns()` time when defining the column
- At `backfill()` time when executing (if input_columns are overridden)
Parameters
----------
table_schema: pa.Schema
The schema of the table being processed
input_columns: list[str] | None
The input column names to validate. If None, uses self.input_columns.
Raises
------
ValueError: If validation fails for any of the following reasons:
- Input columns don't exist in table schema
- Type mismatch between table and UDF expectations
- RecordBatch UDF has input_columns defined
- Array/Scalar UDF has no input_columns defined
Warns
-----
UserWarning: If type validation is skipped due to:
- UDF has no type annotations
- Type annotation can't be mapped to PyArrow types
Examples
--------
>>> @udf(data_type=pa.int32())
... def my_udf(a: int) -> int:
        ...     return a * 2
>>> my_udf.validate_against_schema(table.schema) # Validates column 'a' exists
"""
# Determine which columns to validate
cols_to_validate = (
input_columns if input_columns is not None else self.input_columns
)
# Check RecordBatch UDFs
if self.arg_type == UDFArgType.RECORD_BATCH:
# Error if input_columns are specified for RecordBatch UDFs
if cols_to_validate is not None:
raise ValueError(
f"UDF '{self.name}' is a RecordBatch UDF but has input_columns "
f"{cols_to_validate} specified. RecordBatch UDFs receive the "
f"entire batch and should not declare input_columns. "
f"Remove the input_columns parameter."
)
# RecordBatch UDFs don't need column validation
return
# For Array and Scalar UDFs, input_columns must be defined
if cols_to_validate is None:
arg_type_name = self.arg_type.name if self.arg_type else "UNKNOWN"
raise ValueError(
f"UDF '{self.name}' (type: {arg_type_name}) has no input_columns "
f"defined. Array and Scalar UDFs must specify input columns either "
f"through function parameter names or the input_columns parameter."
)
# Validate all input columns exist in table schema
missing_columns = [
col for col in cols_to_validate if col not in table_schema.names
]
if missing_columns:
raise ValueError(
f"UDF '{self.name}' expects input columns {missing_columns} which are "
f"not found in table schema. Available columns: {table_schema.names}. "
f"Check your UDF's function parameter names or input_columns parameter."
)
# Validate type compatibility for each input column
self._validate_column_types(table_schema, cols_to_validate)
def _validate_column_types(
self, table_schema: pa.Schema, input_columns: list[str]
) -> None:
"""
Validate type compatibility between table schema and UDF expectations.
This method checks if the table column types match the UDF's type annotations.
If no type annotations are present or types can't be mapped, validation is
skipped with a warning.
Parameters
----------
table_schema: pa.Schema
The schema of the table being processed
input_columns: list[str]
The input column names to validate types for
Raises
------
ValueError: If there's a type mismatch between table schema and UDF expectations
Warns
-----
UserWarning: If type validation is skipped due to missing annotations or
unmappable types
"""
# Get type annotations from the UDF function
annotations = _get_annotations(self.func)
if not annotations:
# No type annotations found - warn user
warnings.warn(
f"UDF '{self.name}' has no type annotations. Type validation will be "
f"skipped. Consider adding type hints to your UDF function parameters "
f"for better error detection.",
UserWarning,
stacklevel=4,
)
return
# For each input column, validate type if annotation exists
for col_name in input_columns:
# Get the actual type from table schema
table_field = table_schema.field(col_name)
table_type = table_field.type
# Get expected type from UDF signature if available
if col_name in annotations:
expected_type = annotations[col_name]
                # Try to map the annotated type to a PyArrow type for comparison
                try:
                    expected_pa_type = self._python_type_to_arrow_type(expected_type)
                except (ValueError, KeyError):
                    # If we can't map the annotation, skip this column's check
                    # with a warning
                    warnings.warn(
                        f"Could not validate type for column '{col_name}' in UDF "
                        f"'{self.name}' with annotation {expected_type}. Type "
                        f"validation skipped for this column.",
                        UserWarning,
                        stacklevel=4,
                    )
                    continue
                # Do the compatibility check outside the try block so a genuine
                # type mismatch raises instead of being swallowed as a warning
                if not self._types_compatible(table_type, expected_pa_type):
                    raise ValueError(
                        f"Type mismatch for column '{col_name}' in UDF "
                        f"'{self.name}': table has type {table_type}, but UDF "
                        f"expects {expected_pa_type} (from annotation "
                        f"{expected_type}). This will likely cause serialization "
                        f"or conversion errors during execution."
                    )
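        # Behavior sketch: given `def f(a: int) -> int` and a table whose
        # column "a" is int64, validation passes; if "a" were float64 instead,
        # a ValueError is raised; if the annotation were an unmapped type such
        # as list[int], only a UserWarning is emitted and the column is skipped.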
def _python_type_to_arrow_type(self, python_type) -> pa.DataType:
"""
Convert Python type annotation to PyArrow type.
Raises ValueError if type cannot be mapped.
"""
# Handle PyArrow types directly
if isinstance(python_type, pa.DataType):
return python_type
# Handle pa.Array annotation (for batched UDFs)
if python_type == pa.Array:
            # Can't determine a concrete element type, so raise to skip validation
raise ValueError("Cannot validate generic pa.Array type")
# Map Python/numpy types to PyArrow types
type_map = {
bool: pa.bool_(),
bytes: pa.binary(),
float: pa.float32(),
int: pa.int64(),
str: pa.string(),
numpy.bool_: pa.bool_(),
numpy.uint8: pa.uint8(),
numpy.uint16: pa.uint16(),
numpy.uint32: pa.uint32(),
numpy.uint64: pa.uint64(),
numpy.int8: pa.int8(),
numpy.int16: pa.int16(),
numpy.int32: pa.int32(),
numpy.int64: pa.int64(),
numpy.float16: pa.float16(),
numpy.float32: pa.float32(),
numpy.float64: pa.float64(),
numpy.str_: pa.string(),
}
if python_type in type_map:
return type_map[python_type]
raise ValueError(f"Cannot map Python type {python_type} to PyArrow type")
def _types_compatible(self, actual: pa.DataType, expected: pa.DataType) -> bool:
"""
Check if actual type is compatible with expected type.
This is more permissive than exact equality, allowing for:
- Exact matches
- Nullable vs non-nullable variants
"""
# Exact match
if actual == expected:
return True
        # For numeric types, require the same family with matching bit width;
        # all other non-identical types are rejected at the end of this method
if pa.types.is_integer(actual) and pa.types.is_integer(expected):
# Allow integer types if bit width and signedness match
return actual.bit_width == expected.bit_width and (
(
pa.types.is_signed_integer(actual)
and pa.types.is_signed_integer(expected)
)
or (
pa.types.is_unsigned_integer(actual)
and pa.types.is_unsigned_integer(expected)
)
)
if pa.types.is_floating(actual) and pa.types.is_floating(expected):
# Require exact match for floating point types (float32 vs float64 matters!)
return actual.bit_width == expected.bit_width
# For other types, require exact match
return False
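        # Compatibility examples (illustrative): int64 vs int64 -> True;
        # int32 vs int64 -> False (bit width differs); int32 vs uint32 -> False
        # (signedness differs); float32 vs float64 -> False; string vs string
        # -> True via the exact-match branch.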
def __call__(self, *args, use_applier: bool = False, **kwargs) -> pa.Array:
# dispatch coming from Applier or user calling with a `RecordBatch`
if use_applier or (len(args) == 1 and isinstance(args[0], pa.RecordBatch)):
record_batch = args[0]
match self.arg_type:
case UDFArgType.SCALAR:
return self._scalar_func_record_batch_call(record_batch)
case UDFArgType.ARRAY:
# Validate columns exist before accessing them
try:
arrs = [record_batch[col] for col in self.input_columns] # type:ignore
except KeyError as e:
raise KeyError(
f"UDF '{self.name}' failed: column {e} not found in "
f"RecordBatch. Available columns: "
f"{record_batch.schema.names}. UDF expects "
f"input_columns: {self.input_columns}."
) from e
return self.func(*arrs)
case UDFArgType.RECORD_BATCH:
if isinstance(record_batch, pa.RecordBatch):
return self.func(record_batch)
                    # a list of dicts containing BlobFiles that need to be dereferenced
assert isinstance(record_batch, list)
rb_list = []
for row in record_batch:
new_row = {}
for k, v in row.items():
if isinstance(v, BlobFile):
# read the blob file into memory
new_row[k] = v.readall()
continue
new_row[k] = v
rb_list.append(new_row)
rb = pa.RecordBatch.from_pylist(rb_list)
return self.func(rb)
        # otherwise, dispatch directly to the wrapped function's original signature
return self.func(*args, **kwargs)
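    # Dispatch summary for __call__ (illustrative):
    #   my_udf(record_batch)            -> batch path, routed by arg_type
    #   my_udf(batch, use_applier=True) -> batch path, forced by an applier
    #   my_udf(1, 2)                    -> plain call through to func(1, 2)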