
Query

geneva.query.GenevaQuery

Bases: BaseModel

Source code in geneva/query.py
class GenevaQuery(BaseModel):
    base: Query
    shuffle: bool | None = None
    shuffle_seed: int | None = None
    fragment_ids: list[int] | None = None
    with_row_address: bool | None = None
    column_udfs: list[ColumnUDF] | None = None

    def extract_column_udfs(self, packager: UDFPackager) -> list[ExtractedTransform]:
        """
        Loads a set of transforms that reflect the column_udfs and map_batches_udfs
        of the query.
        """
        transforms = []
        if self.column_udfs is not None:
            for column_udf in self.column_udfs:
                udf = packager.unmarshal(column_udf.udf.to_attrs())
                transforms.append(
                    ExtractedTransform(
                        output_index=column_udf.output_index,
                        output_name=column_udf.output_name,
                        udf=udf,
                    )
                )
        return transforms

base

base: Query

shuffle

shuffle: bool | None = None

shuffle_seed

shuffle_seed: int | None = None

fragment_ids

fragment_ids: list[int] | None = None

with_row_address

with_row_address: bool | None = None

column_udfs

column_udfs: list[ColumnUDF] | None = None

extract_column_udfs

extract_column_udfs(
    packager: UDFPackager,
) -> list[ExtractedTransform]

Loads a set of transforms that reflect the column_udfs and map_batches_udfs of the query.

Source code in geneva/query.py
def extract_column_udfs(self, packager: UDFPackager) -> list[ExtractedTransform]:
    """
    Loads a set of transforms that reflect the column_udfs and map_batches_udfs
    of the query.
    """
    transforms = []
    if self.column_udfs is not None:
        for column_udf in self.column_udfs:
            udf = packager.unmarshal(column_udf.udf.to_attrs())
            transforms.append(
                ExtractedTransform(
                    output_index=column_udf.output_index,
                    output_name=column_udf.output_name,
                    udf=udf,
                )
            )
    return transforms
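
A minimal sketch of rehydrating a serialized query and extracting its column UDFs. The JSON payload and the packager are placeholders here (for instance, the value produced by GenevaQuery.model_dump_json() and stored in materialized-view metadata, and a UDFPackager such as table._conn._packager); neither is defined on this page.

from geneva.query import GenevaQuery

# query_json is assumed to be the output of GenevaQuery.model_dump_json(),
# e.g. the payload stored under the materialized-view metadata key.
query = GenevaQuery.model_validate_json(query_json)

# packager is assumed to be a UDFPackager instance, e.g. table._conn._packager.
transforms = query.extract_column_udfs(packager)
for t in transforms:
    print(t.output_index, t.output_name, t.udf)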

geneva.query.GenevaQueryBuilder

Bases: LanceEmptyQueryBuilder

A proxy that wraps LanceQueryBuilder and adds geneva-specific functionality.

Source code in geneva/query.py
class GenevaQueryBuilder(LanceEmptyQueryBuilder):
    """A proxy that wraps LanceQueryBuilder and adds geneva-specific functionality."""

    def __init__(self, table: "Table") -> None:
        super().__init__(table)
        self._table = table
        self._shuffle = None
        self._shuffle_seed = None
        self._fragment_ids = None
        self._with_row_address = None
        self._internal_api_enabled = False
        self._column_udfs = None
        self._with_where_as_bool_column = False

    def _internal_api_only(self) -> None:
        if not self._internal_api_enabled:
            raise ValueError(
                "This method is for internal use only and subject to change. "
                "Call enable_internal_api() first to enable."
            )

    @override
    def select(self, columns: list[str] | Mapping[str, str | UDF]) -> Self:
        """
        Select the output columns of the query.

        Parameters
        ----------
        columns: list[str] | dict[str, str] | dict[str, UDF]
            The columns to select.

            If a list of strings, each string is the name of a column to select.

            If a dictionary of strings then the key is the output name of the column
            and the value is either an SQL expression (str) or a UDF.
        """
        if isinstance(columns, dict):
            self._column_udfs = {
                key: (value, index)
                for (index, (key, value)) in enumerate(columns.items())
                if isinstance(value, UDF)
            }
            # Filter out UDFs and create a proper dict for super().select()
            filtered_columns: dict[str, str] = {
                key: str(value)  # Convert to string if needed
                for key, value in columns.items()
                if not isinstance(value, UDF)
            }
            super().select(filtered_columns)
        else:
            super().select(columns)  # type: ignore[arg-type]
        return self

    def shuffle(self, seed: int | None = None) -> Self:
        """Shuffle the rows of the table"""
        self._shuffle = True
        self._shuffle_seed = seed
        return self

    def enable_internal_api(self) -> Self:
        """
        Enable internal APIs
        WARNING: Internal APIs are subject to change
        """
        self._internal_api_enabled = True
        return self

    def with_fragments(self, fragments: list[int] | int) -> Self:
        """
        Filter the rows of the table to only include the specified fragments.
        """
        self._internal_api_only()
        self._fragment_ids = [fragments] if isinstance(fragments, int) else fragments
        return self

    def with_row_address(self) -> Self:
        """
        Include the physical row address in the result
        WARNING: INTERNAL API DETAIL
        """
        self._internal_api_only()
        self._with_row_address = True
        return self

    def with_where_as_bool_column(self) -> Self:
        """
        Include the filter selected column in the result instead of just selected rows
        """
        self._internal_api_only()
        self._with_where_as_bool_column = True
        return self

    @override
    def to_query_object(self) -> GenevaQuery:  # type: ignore
        query = super().to_query_object()
        result = GenevaQuery(
            base=query,
            shuffle=self._shuffle,
            shuffle_seed=self._shuffle_seed,
            fragment_ids=self._fragment_ids,
            with_row_address=self._with_row_address,
        )
        if self._column_udfs:
            result.column_udfs = [
                ColumnUDF(
                    output_index=index,
                    output_name=name,
                    udf=PydanticUDFSpec.from_attrs(
                        self._table._conn._packager.marshal(udf)
                    ),
                )
                for (name, (udf, index)) in self._column_udfs.items()
            ]
        return result

    @classmethod
    def from_query_object(
        cls, table: "Table", query: GenevaQuery
    ) -> "GenevaQueryBuilder":
        result = GenevaQueryBuilder(table)

        # TODO: Add from_query_object to lancedb.  For now, this will work
        # for simple (non-vector, non-fts) queries.
        if query.base.columns is not None:
            result.select(query.base.columns)
        if query.base.filter:
            result.where(query.base.filter)
        if query.base.limit:
            result.limit(query.base.limit)
        if query.base.offset:
            result.offset(query.base.offset)
        if query.base.with_row_id:
            result.with_row_id(True)

        result._shuffle = query.shuffle
        result._shuffle_seed = query.shuffle_seed
        if query.column_udfs:
            result._column_udfs = {}
            for column_udf in query.column_udfs:
                udf = table._conn._packager.unmarshal(column_udf.udf.to_attrs())
                result._column_udfs[column_udf.output_name] = (
                    udf,
                    column_udf.output_index,
                )
        result._fragment_ids = query.fragment_ids
        result._with_row_address = query.with_row_address
        result._internal_api_enabled = True
        return result

    def take_rows(self, rows: list[int]) -> pa.Table:
        query = self.to_query_object()
        return self._table.to_lance()._take_rows(rows, query.base.columns)

    def _infer_sql_expression_type(self, expr: str, dest_name: str) -> pa.DataType:
        """Infer output type of SQL expression by evaluating on source table."""
        dataset = self._table.to_lance()
        scanner = dataset.scanner(columns={dest_name: expr}, limit=1)
        return scanner.projected_schema.field(dest_name).type

    def _schema_for_query(self, include_metacols: bool = True) -> pa.Schema:
        schema = self._table.schema

        base_query = super().to_query_object()

        if base_query.columns is not None:
            if isinstance(base_query.columns, list):
                fields = [_resolve_field(schema, col) for col in base_query.columns]
            else:
                fields = []
                for dest_name, expr in base_query.columns.items():
                    try:
                        field = _resolve_field(schema, expr)
                    except KeyError:
                        if dest_name == BACKFILL_SELECTED:
                            # HACK special case for BACKFILL_SELECTED
                            field = pa.field(dest_name, pa.bool_(), True)
                        else:
                            # SQL expression - infer type by executing on sample data
                            inferred_type = self._infer_sql_expression_type(
                                expr, dest_name
                            )
                            field = pa.field(dest_name, inferred_type, nullable=True)

                    fields.append(pa.field(dest_name, field.type, field.nullable))

        else:
            fields = list(schema)

        if self._column_udfs is not None:
            for output_name, (udf, output_index) in self._column_udfs.items():
                fields.insert(
                    output_index,
                    pa.field(output_name, udf.data_type, metadata=udf.field_metadata),
                )

        if include_metacols and base_query.with_row_id:
            fields += [pa.field("_rowid", pa.uint64())]

        if include_metacols and self._with_row_address:
            fields += [pa.field("_rowaddr", pa.int64())]

        return pa.schema(fields)

    @property
    def schema(self) -> pa.Schema:
        return self._schema_for_query()

    @override
    def to_batches(
        self, /, batch_size: int | None = None, *, timeout: timedelta | None = None
    ) -> pa.RecordBatchReader:
        schema_no_meta = self._schema_for_query(include_metacols=False)

        # Collect blob columns.
        blob_columns: dict[str, int] = {
            f.name: idx
            for idx, f in enumerate(schema_no_meta)
            if f.metadata and f.metadata.get(b"lance-encoding:blob") == b"true"
        }

        base_query = super().to_query_object()
        orig_filter = base_query.filter

        # Enforce row_id if we need blobs or where-as-column
        if blob_columns or (self._with_where_as_bool_column and orig_filter):
            base_query.with_row_id = True

        # UDF extra-column bookkeeping
        extra_columns: list[str] = []
        if self._column_udfs and base_query.columns is not None:
            # collect all needed inputs
            current_cols = (
                set(base_query.columns)
                if isinstance(base_query.columns, list)
                else set(base_query.columns.keys())
            )
            for udf, _ in self._column_udfs.values():
                for inp in udf.input_columns or []:
                    if inp not in current_cols:
                        extra_columns.append(inp)
                        current_cols.add(inp)

        # append extra_columns into the query, track their positions
        added_columns: list[int] = []
        if base_query.columns is not None and extra_columns:
            if isinstance(base_query.columns, list):
                pos = len(base_query.columns)
                for col in extra_columns:
                    added_columns.append(pos)
                    base_query.columns.append(col)
                    pos += 1
            else:
                pos = len(base_query.columns)
                for col in extra_columns:
                    added_columns.append(pos)
                    base_query.columns[col] = col
                    pos += 1

        # sanity‐check unsupported features
        if self._shuffle:
            raise NotImplementedError("Shuffle is not yet implemented")
        if base_query.vector:
            raise NotImplementedError("Vector search not yet implemented")
        if base_query.full_text_query:
            raise NotImplementedError("FTS search not yet implemented")

        dataset: LanceDataset = self._table.to_lance()
        fragments = (
            [dataset.get_fragment(fid) for fid in self._fragment_ids]
            if self._fragment_ids
            else list(dataset.get_fragments())
        )

        schema_with_meta = self._schema_for_query(include_metacols=True)

        # Fragment‐by‐fragment generator
        def gen() -> Iterator[pa.RecordBatch]:
            for frag in fragments:
                # build per‐fragment matching_ids if we’re doing where-as-column
                frag_ids: set[int] | None = None
                if self._with_where_as_bool_column and orig_filter:
                    frag_ids = set()
                    id_scan = dataset.scanner(
                        columns=["_rowid"],
                        with_row_id=True,
                        filter=orig_filter,
                        fragments=[frag],
                    )
                    for id_batch in id_scan.to_batches():
                        rowid_list = id_batch["_rowid"].to_pylist()
                        # Filter out None values and convert to ints
                        valid_ids = [int(rid) for rid in rowid_list if rid is not None]
                        frag_ids.update(valid_ids)

                # choose filter for main scan
                scan_filter = None if frag_ids is not None else orig_filter

                # run the main scan over this fragment
                main_scan = dataset.scanner(
                    columns=base_query.columns,
                    with_row_id=base_query.with_row_id,
                    with_row_address=self._with_row_address,
                    filter=scan_filter,
                    batch_size=batch_size,
                    offset=base_query.offset,
                    limit=base_query.limit,
                    fragments=[frag],
                )
                for batch in main_scan.to_batches():
                    # blob injection
                    if blob_columns:
                        rowid_list = batch["_rowid"].to_pylist()  # type: ignore[index]
                        ids = [int(rid) for rid in rowid_list if rid is not None]
                        for col_name in blob_columns:
                            if hasattr(batch, "to_pylist"):
                                batch = batch.to_pylist()  # type: ignore[attr-defined]
                            else:
                                # batch is already a list
                                pass
                            try:
                                blob_files = dataset.take_blobs(col_name, ids=ids)
                                for elem, blob in zip(batch, blob_files, strict=True):  # type: ignore[arg-type]
                                    elem[col_name] = blob  # type: ignore[index]
                            except ValueError:
                                # not blobfile? (maybe because null?) return Null.
                                for elem in batch:
                                    elem[col_name] = None  # type: ignore[index]
                    # UDFs and drop UDF-only columns
                    if self._column_udfs:
                        for col_name, (udf, insert_idx) in self._column_udfs.items():
                            arr = udf(batch)
                            if hasattr(batch, "add_column"):
                                batch = batch.add_column(  # type: ignore[attr-defined]
                                    insert_idx, pa.field(col_name, arr.type), arr
                                )
                            else:
                                # Handle case where batch is a list
                                pass
                        # remove the extra_columns we only pulled for UDF inputs
                        for drop_idx in reversed(added_columns):
                            if hasattr(batch, "remove_column"):
                                batch = batch.remove_column(  # type: ignore[attr-defined]
                                    drop_idx + len(self._column_udfs)
                                )
                            else:
                                # Handle case where batch is a list
                                pass

                    # where-as-column mask
                    if frag_ids is not None:
                        if isinstance(batch, list):
                            # blob case -- a list of dicts
                            ids = [row["_rowid"] for row in batch]
                            mask = pa.array(
                                [rid in frag_ids for rid in ids], pa.bool_()
                            )
                            for i, _row in enumerate(batch):
                                batch[i][BACKFILL_SELECTED] = mask[i]

                        else:
                            # normal case - pa.RecordBatch
                            ids = batch["_rowid"].to_pylist()
                            mask = pa.array(
                                [rid in frag_ids for rid in ids], pa.bool_()
                            )
                            batch = batch.add_column(
                                batch.num_columns,
                                pa.field(BACKFILL_SELECTED, pa.bool_()),
                                mask,
                            )

                    yield batch  # type: ignore[misc]

        if blob_columns:
            return list(gen())  # type: ignore[return-value]
        return pa.RecordBatchReader.from_batches(schema_with_meta, gen())  # type: ignore[arg-type]

    @override
    def to_arrow(self, *args, timeout: timedelta | None = None) -> pa.Table:
        return pa.Table.from_batches(
            self.to_batches(*args, timeout=timeout), schema=self.schema
        )

    @override
    def rerank(self, reranker: Reranker) -> Self:
        raise NotImplementedError("rerank is not yet implemented")

    def create_materialized_view(self, conn: Connection, view_name: str) -> "Table":
        """
        Creates a materialized view of the table.

        The materialized view will be a table that contains the result of the query.
        The view will be populated via a pipeline job.

        Parameters
        ----------
        conn: Connection
            A connection to the database to create the view in.
        view_name: str
            The name of the view to create.

        Warnings
        --------
        UserWarning
            If the source table does not have stable row IDs enabled. Without stable
            row IDs, incremental refresh is only supported when refreshing to the
            same source version. Attempting to refresh to a different version will fail.
        """
        import warnings

        # Check if source table has stable row IDs enabled
        source_lance_ds = self._table.to_lance()
        fragments = list(source_lance_ds.get_fragments())

        # Validate that source table is not empty
        if not fragments:
            raise ValueError(
                f"Cannot create materialized view from empty table "
                f"'{self._table._ltbl.name}'.\n\n"
                "Materialized views require at least one row in the source "
                "table to determine storage characteristics (such as whether "
                "stable row IDs are enabled).\n\n"
                "Please add data to the source table before creating a "
                "materialized view."
            )

        source_has_stable_row_ids = has_stable_row_ids(fragments)

        if not source_has_stable_row_ids:
            warnings.warn(
                f"Creating materialized view from table '{self._table._ltbl.name}' "
                "without stable row IDs enabled.\n\n"
                "Without stable row IDs, you can only refresh the materialized view "
                "to the SAME source version it was created from. Attempting to refresh "
                "to a different version will fail because compaction operations may "
                "have changed row IDs.\n\n"
                "For full incremental refresh support across all versions, create the "
                "source table with stable row IDs enabled:\n"
                "  db.create_table(\n"
                "      name='table_name',\n"
                "      data=data,\n"
                "      storage_options={'new_table_enable_stable_row_ids': 'true'}\n"
                "  )\n\n"
                "Note: Both 'true' (string) and True (boolean) are accepted.\n\n"
                "Stable row IDs is a Lance feature (added in 0.21.0) exposed via "
                "lancedb's new_table_enable_stable_row_ids option (added in 0.25.4b3).",
                UserWarning,
                stacklevel=2,
            )

        view_schema = self._schema_for_query(include_metacols=True)
        view_schema = view_schema.insert(0, pa.field("__is_set", pa.bool_()))
        view_schema = view_schema.insert(0, pa.field("__source_row_id", pa.int64()))

        query = self.to_query_object()

        # Capture source columns at creation time if no explicit select was used.
        # This ensures that when new columns are added to the source table,
        # the MV refresh only reads the columns that existed at creation time.
        if query.base.columns is None:
            query.base.columns = [
                n
                for n in self._table.schema.names
                if n not in ["__is_set", "__source_row_id"]
            ]

        view_schema = view_schema.with_metadata(
            {
                MATVIEW_META_QUERY: query.model_dump_json(),
                MATVIEW_META_BASE_TABLE: self._table._ltbl.name,
                MATVIEW_META_BASE_DBURI: self._table._conn_uri,
                MATVIEW_META_BASE_VERSION: str(self._table._ltbl.version),
                # Store materialized view format version.
                # Version 1: fragment+offset encoding (fragment_id << 32 | offset)
                #   - Used for v0.7.x and earlier (always)
                #   - Used for v0.8.x+ without stable row IDs
                #   - Source table does NOT have stable row IDs
                #   - Refresh only supported to same source version
                # Version 2: stable row IDs (v0.8.x+ with stable row IDs enabled)
                #   - Source table HAS stable row IDs
                #   - Refresh supported across source versions
                MATVIEW_META_VERSION: "2" if source_has_stable_row_ids else "1",
                # TODO: Add the base DB URI (should be possible
                # to get from lancedb table in future)
            }
        )

        row_ids_query = GenevaQuery(
            fragment_ids=query.fragment_ids,
            base=query.base,
        )
        row_ids_query.base.with_row_id = True
        row_ids_query.base.columns = []
        row_ids_query.column_udfs = None
        row_ids_query.with_row_address = None

        row_ids_query_builder = GenevaQueryBuilder.from_query_object(
            self._table, row_ids_query
        )

        row_ids_table = row_ids_query_builder.to_arrow()
        row_ids_table = row_ids_table.combine_chunks()
        # Copy is needed so that the array is not read-only
        row_ids = row_ids_table["_rowid"].to_numpy().copy()

        if query.shuffle:
            rng = default_rng(query.shuffle_seed)
            rng.shuffle(row_ids)

        initial_view_table_data = pa.table(
            [
                pa.array(row_ids, type=pa.int64()),
                pa.array([False] * len(row_ids), type=pa.bool_()),
            ],
            names=["__source_row_id", "__is_set"],
        )

        # Create the MV table with stable row IDs enabled to support chaining MVs
        # We use lance directly when not using namespaces, since lancedb doesn't
        # expose new_table_enable_stable_row_ids yet. For namespace connections,
        # fall back to lancedb (without stable row IDs on the MV table).
        from .table import Table

        # Check if using namespace connection
        is_namespace = conn.namespace_impl is not None

        if not is_namespace:
            # Use lance directly to enable stable row IDs
            import lance

            table_uri = f"{conn._uri}/{view_name}.lance"

            # Create empty table with full schema first
            # This ensures all columns exist (including UDF columns)
            empty_data = view_schema.empty_table()
            lance.write_dataset(
                empty_data,
                table_uri,
                mode="create",
                enable_stable_row_ids=True,
            )

            # Open the table through lancedb
            view_table = Table(conn, view_name)

            # Add the initial row ID data
            view_table.add(initial_view_table_data)
        else:
            # For namespaces, use lancedb's API (no stable row IDs on MV table yet)
            # Need to create table in two steps because partial schema is not allowed
            # on initial create_table call.
            view_table = conn.create_table(
                view_name, data=None, schema=view_schema, mode="create"
            )
            view_table.add(initial_view_table_data)

        for udf_col, (udf, _output_index) in (self._column_udfs or {}).items():
            input_cols = udf.input_columns
            view_table._configure_computed_column(udf_col, udf, input_cols)

        return view_table
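
A brief end-to-end sketch of the builder (an illustration, not part of the source above). Here tbl is assumed to be a geneva Table and embed a geneva UDF; both are placeholders.

import pyarrow as pa
from geneva.query import GenevaQueryBuilder

builder = (
    GenevaQueryBuilder(tbl)
    .select({"id": "id", "embedding": embed})  # mix of SQL expression and UDF
    .where("id > 10")
    .limit(100)
)

result: pa.Table = builder.to_arrow()
print(result.schema)  # includes the UDF output column "embedding"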

schema

schema: Schema

select

select(
    columns: list[str] | Mapping[str, str | UDF],
) -> Self

Select the output columns of the query.

Parameters:

  • columns (list[str] | Mapping[str, str | UDF]) –

    The columns to select.

    If a list of strings, each string is the name of a column to select.

    If a dictionary of strings then the key is the output name of the column and the value is either an SQL expression (str) or a UDF.

Source code in geneva/query.py
@override
def select(self, columns: list[str] | Mapping[str, str | UDF]) -> Self:
    """
    Select the output columns of the query.

    Parameters
    ----------
    columns: list[str] | dict[str, str] | dict[str, UDF]
        The columns to select.

        If a list of strings, each string is the name of a column to select.

        If a dictionary of strings then the key is the output name of the column
        and the value is either an SQL expression (str) or a UDF.
    """
    if isinstance(columns, dict):
        self._column_udfs = {
            key: (value, index)
            for (index, (key, value)) in enumerate(columns.items())
            if isinstance(value, UDF)
        }
        # Filter out UDFs and create a proper dict for super().select()
        filtered_columns: dict[str, str] = {
            key: str(value)  # Convert to string if needed
            for key, value in columns.items()
            if not isinstance(value, UDF)
        }
        super().select(filtered_columns)
    else:
        super().select(columns)  # type: ignore[arg-type]
    return self
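
A hedged sketch of both forms of select; tbl and embed are placeholders (a geneva Table and a geneva UDF).

# list form: plain column names
builder = GenevaQueryBuilder(tbl).select(["id", "text"])

# dict form: output name -> SQL expression or UDF
builder = GenevaQueryBuilder(tbl).select(
    {
        "id": "id",            # passthrough column
        "doubled": "id * 2",   # SQL expression, type inferred from the dataset
        "embedding": embed,    # UDF evaluated when the query is executed
    }
)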

shuffle

shuffle(seed: int | None = None) -> Self

Shuffle the rows of the table

Source code in geneva/query.py
def shuffle(self, seed: int | None = None) -> Self:
    """Shuffle the rows of the table"""
    self._shuffle = True
    self._shuffle_seed = seed
    return self
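
Note that to_batches currently raises NotImplementedError when shuffle is set; the flag is honored when creating a materialized view, where the captured source row IDs are shuffled. A sketch, with tbl and db as placeholders:

builder = GenevaQueryBuilder(tbl).select(["id"]).shuffle(seed=42)
view = builder.create_materialized_view(db, "shuffled_view")  # row order follows the shuffled seed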

enable_internal_api

enable_internal_api() -> Self

Enable internal APIs.

WARNING: Internal APIs are subject to change.

Source code in geneva/query.py
def enable_internal_api(self) -> Self:
    """
    Enable internal APIs
    WARNING: Internal APIs are subject to change
    """
    self._internal_api_enabled = True
    return self

with_fragments

with_fragments(fragments: list[int] | int) -> Self

Filter the rows of the table to only include the specified fragments.

Source code in geneva/query.py
def with_fragments(self, fragments: list[int] | int) -> Self:
    """
    Filter the rows of the table to only include the specified fragments.
    """
    self._internal_api_only()
    self._fragment_ids = [fragments] if isinstance(fragments, int) else fragments
    return self
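
An illustrative sketch of the internal fragment filter. These calls require enable_internal_api() first; tbl is a placeholder.

builder = (
    GenevaQueryBuilder(tbl)
    .enable_internal_api()
    .with_fragments([0, 2])   # only scan fragments 0 and 2
    .with_row_address()       # also include the physical _rowaddr column
)
reader = builder.to_batches(batch_size=1024)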

with_row_address

with_row_address() -> Self

Include the physical row address in the result.

WARNING: INTERNAL API DETAIL

Source code in geneva/query.py
def with_row_address(self) -> Self:
    """
    Include the physical row address in the result
    WARNING: INTERNAL API DETAIL
    """
    self._internal_api_only()
    self._with_row_address = True
    return self

with_where_as_bool_column

with_where_as_bool_column() -> Self

Include a boolean column indicating which rows matched the filter, instead of returning only the matching rows

Source code in geneva/query.py
def with_where_as_bool_column(self) -> Self:
    """
    Include the filter selected column in the result instead of just selected rows
    """
    self._internal_api_only()
    self._with_where_as_bool_column = True
    return self
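
A sketch of the where-as-column behavior (internal API); tbl is a placeholder. All rows are returned, and an extra boolean column (named by the BACKFILL_SELECTED constant in the source) marks which rows matched the filter.

builder = (
    GenevaQueryBuilder(tbl)
    .enable_internal_api()
    .where("id > 10")
    .with_where_as_bool_column()
)
for batch in builder.to_batches():
    ...  # each batch carries the boolean filter-match column alongside the data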

to_query_object

to_query_object() -> GenevaQuery
Source code in geneva/query.py
@override
def to_query_object(self) -> GenevaQuery:  # type: ignore
    query = super().to_query_object()
    result = GenevaQuery(
        base=query,
        shuffle=self._shuffle,
        shuffle_seed=self._shuffle_seed,
        fragment_ids=self._fragment_ids,
        with_row_address=self._with_row_address,
    )
    if self._column_udfs:
        result.column_udfs = [
            ColumnUDF(
                output_index=index,
                output_name=name,
                udf=PydanticUDFSpec.from_attrs(
                    self._table._conn._packager.marshal(udf)
                ),
            )
            for (name, (udf, index)) in self._column_udfs.items()
        ]
    return result

from_query_object

from_query_object(
    table: Table, query: GenevaQuery
) -> GenevaQueryBuilder
Source code in geneva/query.py
@classmethod
def from_query_object(
    cls, table: "Table", query: GenevaQuery
) -> "GenevaQueryBuilder":
    result = GenevaQueryBuilder(table)

    # TODO: Add from_query_object to lancedb.  For now, this will work
    # for simple (non-vector, non-fts) queries.
    if query.base.columns is not None:
        result.select(query.base.columns)
    if query.base.filter:
        result.where(query.base.filter)
    if query.base.limit:
        result.limit(query.base.limit)
    if query.base.offset:
        result.offset(query.base.offset)
    if query.base.with_row_id:
        result.with_row_id(True)

    result._shuffle = query.shuffle
    result._shuffle_seed = query.shuffle_seed
    if query.column_udfs:
        result._column_udfs = {}
        for column_udf in query.column_udfs:
            udf = table._conn._packager.unmarshal(column_udf.udf.to_attrs())
            result._column_udfs[column_udf.output_name] = (
                udf,
                column_udf.output_index,
            )
    result._fragment_ids = query.fragment_ids
    result._with_row_address = query.with_row_address
    result._internal_api_enabled = True
    return result
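
A round-trip sketch: serialize a builder to a GenevaQuery, persist it as JSON, and rebuild a builder later. tbl is a placeholder, and per the TODO above only simple (non-vector, non-FTS) queries survive the round trip.

from geneva.query import GenevaQuery, GenevaQueryBuilder

query = GenevaQueryBuilder(tbl).select(["id"]).where("id > 10").to_query_object()
payload = query.model_dump_json()  # GenevaQuery is a pydantic model

restored = GenevaQuery.model_validate_json(payload)
builder = GenevaQueryBuilder.from_query_object(tbl, restored)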

take_rows

take_rows(rows: list[int]) -> Table
Source code in geneva/query.py
def take_rows(self, rows: list[int]) -> pa.Table:
    query = self.to_query_object()
    return self._table.to_lance()._take_rows(rows, query.base.columns)
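
A short sketch; tbl is a placeholder. The rows argument is forwarded, together with the selected columns, to the underlying Lance dataset's _take_rows.

builder = GenevaQueryBuilder(tbl).select(["id", "text"])
subset = builder.take_rows([0, 5, 42])  # returns a pyarrow Table with just those rows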

to_batches

to_batches(
    batch_size: int | None = None,
    *,
    timeout: timedelta | None = None,
) -> RecordBatchReader
Source code in geneva/query.py
@override
def to_batches(
    self, /, batch_size: int | None = None, *, timeout: timedelta | None = None
) -> pa.RecordBatchReader:
    schema_no_meta = self._schema_for_query(include_metacols=False)

    # Collect blob columns.
    blob_columns: dict[str, int] = {
        f.name: idx
        for idx, f in enumerate(schema_no_meta)
        if f.metadata and f.metadata.get(b"lance-encoding:blob") == b"true"
    }

    base_query = super().to_query_object()
    orig_filter = base_query.filter

    # Enforce row_id if we need blobs or where-as-column
    if blob_columns or (self._with_where_as_bool_column and orig_filter):
        base_query.with_row_id = True

    # UDF extra-column bookkeeping
    extra_columns: list[str] = []
    if self._column_udfs and base_query.columns is not None:
        # collect all needed inputs
        current_cols = (
            set(base_query.columns)
            if isinstance(base_query.columns, list)
            else set(base_query.columns.keys())
        )
        for udf, _ in self._column_udfs.values():
            for inp in udf.input_columns or []:
                if inp not in current_cols:
                    extra_columns.append(inp)
                    current_cols.add(inp)

    # append extra_columns into the query, track their positions
    added_columns: list[int] = []
    if base_query.columns is not None and extra_columns:
        if isinstance(base_query.columns, list):
            pos = len(base_query.columns)
            for col in extra_columns:
                added_columns.append(pos)
                base_query.columns.append(col)
                pos += 1
        else:
            pos = len(base_query.columns)
            for col in extra_columns:
                added_columns.append(pos)
                base_query.columns[col] = col
                pos += 1

    # sanity‐check unsupported features
    if self._shuffle:
        raise NotImplementedError("Shuffle is not yet implemented")
    if base_query.vector:
        raise NotImplementedError("Vector search not yet implemented")
    if base_query.full_text_query:
        raise NotImplementedError("FTS search not yet implemented")

    dataset: LanceDataset = self._table.to_lance()
    fragments = (
        [dataset.get_fragment(fid) for fid in self._fragment_ids]
        if self._fragment_ids
        else list(dataset.get_fragments())
    )

    schema_with_meta = self._schema_for_query(include_metacols=True)

    # Fragment‐by‐fragment generator
    def gen() -> Iterator[pa.RecordBatch]:
        for frag in fragments:
            # build per‐fragment matching_ids if we’re doing where-as-column
            frag_ids: set[int] | None = None
            if self._with_where_as_bool_column and orig_filter:
                frag_ids = set()
                id_scan = dataset.scanner(
                    columns=["_rowid"],
                    with_row_id=True,
                    filter=orig_filter,
                    fragments=[frag],
                )
                for id_batch in id_scan.to_batches():
                    rowid_list = id_batch["_rowid"].to_pylist()
                    # Filter out None values and convert to ints
                    valid_ids = [int(rid) for rid in rowid_list if rid is not None]
                    frag_ids.update(valid_ids)

            # choose filter for main scan
            scan_filter = None if frag_ids is not None else orig_filter

            # run the main scan over this fragment
            main_scan = dataset.scanner(
                columns=base_query.columns,
                with_row_id=base_query.with_row_id,
                with_row_address=self._with_row_address,
                filter=scan_filter,
                batch_size=batch_size,
                offset=base_query.offset,
                limit=base_query.limit,
                fragments=[frag],
            )
            for batch in main_scan.to_batches():
                # blob injection
                if blob_columns:
                    rowid_list = batch["_rowid"].to_pylist()  # type: ignore[index]
                    ids = [int(rid) for rid in rowid_list if rid is not None]
                    for col_name in blob_columns:
                        if hasattr(batch, "to_pylist"):
                            batch = batch.to_pylist()  # type: ignore[attr-defined]
                        else:
                            # batch is already a list
                            pass
                        try:
                            blob_files = dataset.take_blobs(col_name, ids=ids)
                            for elem, blob in zip(batch, blob_files, strict=True):  # type: ignore[arg-type]
                                elem[col_name] = blob  # type: ignore[index]
                        except ValueError:
                            # not blobfile? (maybe because null?) return Null.
                            for elem in batch:
                                elem[col_name] = None  # type: ignore[index]
                # UDFs and drop UDF-only columns
                if self._column_udfs:
                    for col_name, (udf, insert_idx) in self._column_udfs.items():
                        arr = udf(batch)
                        if hasattr(batch, "add_column"):
                            batch = batch.add_column(  # type: ignore[attr-defined]
                                insert_idx, pa.field(col_name, arr.type), arr
                            )
                        else:
                            # Handle case where batch is a list
                            pass
                    # remove the extra_columns we only pulled for UDF inputs
                    for drop_idx in reversed(added_columns):
                        if hasattr(batch, "remove_column"):
                            batch = batch.remove_column(  # type: ignore[attr-defined]
                                drop_idx + len(self._column_udfs)
                            )
                        else:
                            # Handle case where batch is a list
                            pass

                # where-as-column mask
                if frag_ids is not None:
                    if isinstance(batch, list):
                        # blob case -- a list of dicts
                        ids = [row["_rowid"] for row in batch]
                        mask = pa.array(
                            [rid in frag_ids for rid in ids], pa.bool_()
                        )
                        for i, _row in enumerate(batch):
                            batch[i][BACKFILL_SELECTED] = mask[i]

                    else:
                        # normal case - pa.RecordBatch
                        ids = batch["_rowid"].to_pylist()
                        mask = pa.array(
                            [rid in frag_ids for rid in ids], pa.bool_()
                        )
                        batch = batch.add_column(
                            batch.num_columns,
                            pa.field(BACKFILL_SELECTED, pa.bool_()),
                            mask,
                        )

                yield batch  # type: ignore[misc]

    if blob_columns:
        return list(gen())  # type: ignore[return-value]
    return pa.RecordBatchReader.from_batches(schema_with_meta, gen())  # type: ignore[arg-type]
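
A small consumption sketch; tbl is a placeholder. Note the blob special case in the source: when blob-encoded columns are selected, to_batches returns a list of row dicts rather than a RecordBatchReader.

reader = GenevaQueryBuilder(tbl).select(["id", "text"]).to_batches(batch_size=512)
for batch in reader:
    # each item is a pyarrow.RecordBatch matching the builder's schema
    print(batch.num_rows)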

to_arrow

to_arrow(*args, timeout: timedelta | None = None) -> Table
Source code in geneva/query.py
@override
def to_arrow(self, *args, timeout: timedelta | None = None) -> pa.Table:
    return pa.Table.from_batches(
        self.to_batches(*args, timeout=timeout), schema=self.schema
    )

rerank

rerank(reranker: Reranker) -> Self
Source code in geneva/query.py
@override
def rerank(self, reranker: Reranker) -> Self:
    raise NotImplementedError("rerank is not yet implemented")

create_materialized_view

create_materialized_view(
    conn: Connection, view_name: str
) -> Table

Creates a materialized view of the table.

The materialized view will be a table that contains the result of the query. The view will be populated via a pipeline job.

Parameters:

  • conn (Connection) –

    A connection to the database to create the view in.

  • view_name (str) –

    The name of the view to create.

Warnings:

  • UserWarning –

    If the source table does not have stable row IDs enabled. Without stable row IDs, incremental refresh is only supported when refreshing to the same source version; attempting to refresh to a different version will fail.

Source code in geneva/query.py
def create_materialized_view(self, conn: Connection, view_name: str) -> "Table":
    """
    Creates a materialized view of the table.

    The materialized view will be a table that contains the result of the query.
    The view will be populated via a pipeline job.

    Parameters
    ----------
    conn: Connection
        A connection to the database to create the view in.
    view_name: str
        The name of the view to create.

    Warnings
    --------
    UserWarning
        If the source table does not have stable row IDs enabled. Without stable
        row IDs, incremental refresh is only supported when refreshing to the
        same source version. Attempting to refresh to a different version will fail.
    """
    import warnings

    # Check if source table has stable row IDs enabled
    source_lance_ds = self._table.to_lance()
    fragments = list(source_lance_ds.get_fragments())

    # Validate that source table is not empty
    if not fragments:
        raise ValueError(
            f"Cannot create materialized view from empty table "
            f"'{self._table._ltbl.name}'.\n\n"
            "Materialized views require at least one row in the source "
            "table to determine storage characteristics (such as whether "
            "stable row IDs are enabled).\n\n"
            "Please add data to the source table before creating a "
            "materialized view."
        )

    source_has_stable_row_ids = has_stable_row_ids(fragments)

    if not source_has_stable_row_ids:
        warnings.warn(
            f"Creating materialized view from table '{self._table._ltbl.name}' "
            "without stable row IDs enabled.\n\n"
            "Without stable row IDs, you can only refresh the materialized view "
            "to the SAME source version it was created from. Attempting to refresh "
            "to a different version will fail because compaction operations may "
            "have changed row IDs.\n\n"
            "For full incremental refresh support across all versions, create the "
            "source table with stable row IDs enabled:\n"
            "  db.create_table(\n"
            "      name='table_name',\n"
            "      data=data,\n"
            "      storage_options={'new_table_enable_stable_row_ids': 'true'}\n"
            "  )\n\n"
            "Note: Both 'true' (string) and True (boolean) are accepted.\n\n"
            "Stable row IDs is a Lance feature (added in 0.21.0) exposed via "
            "lancedb's new_table_enable_stable_row_ids option (added in 0.25.4b3).",
            UserWarning,
            stacklevel=2,
        )

    view_schema = self._schema_for_query(include_metacols=True)
    view_schema = view_schema.insert(0, pa.field("__is_set", pa.bool_()))
    view_schema = view_schema.insert(0, pa.field("__source_row_id", pa.int64()))

    query = self.to_query_object()

    # Capture source columns at creation time if no explicit select was used.
    # This ensures that when new columns are added to the source table,
    # the MV refresh only reads the columns that existed at creation time.
    if query.base.columns is None:
        query.base.columns = [
            n
            for n in self._table.schema.names
            if n not in ["__is_set", "__source_row_id"]
        ]

    view_schema = view_schema.with_metadata(
        {
            MATVIEW_META_QUERY: query.model_dump_json(),
            MATVIEW_META_BASE_TABLE: self._table._ltbl.name,
            MATVIEW_META_BASE_DBURI: self._table._conn_uri,
            MATVIEW_META_BASE_VERSION: str(self._table._ltbl.version),
            # Store materialized view format version.
            # Version 1: fragment+offset encoding (fragment_id << 32 | offset)
            #   - Used for v0.7.x and earlier (always)
            #   - Used for v0.8.x+ without stable row IDs
            #   - Source table does NOT have stable row IDs
            #   - Refresh only supported to same source version
            # Version 2: stable row IDs (v0.8.x+ with stable row IDs enabled)
            #   - Source table HAS stable row IDs
            #   - Refresh supported across source versions
            MATVIEW_META_VERSION: "2" if source_has_stable_row_ids else "1",
            # TODO: Add the base DB URI (should be possible
            # to get from lancedb table in future)
        }
    )

    row_ids_query = GenevaQuery(
        fragment_ids=query.fragment_ids,
        base=query.base,
    )
    row_ids_query.base.with_row_id = True
    row_ids_query.base.columns = []
    row_ids_query.column_udfs = None
    row_ids_query.with_row_address = None

    row_ids_query_builder = GenevaQueryBuilder.from_query_object(
        self._table, row_ids_query
    )

    row_ids_table = row_ids_query_builder.to_arrow()
    row_ids_table = row_ids_table.combine_chunks()
    # Copy is needed so that the array is not read-only
    row_ids = row_ids_table["_rowid"].to_numpy().copy()

    if query.shuffle:
        rng = default_rng(query.shuffle_seed)
        rng.shuffle(row_ids)

    initial_view_table_data = pa.table(
        [
            pa.array(row_ids, type=pa.int64()),
            pa.array([False] * len(row_ids), type=pa.bool_()),
        ],
        names=["__source_row_id", "__is_set"],
    )

    # Create the MV table with stable row IDs enabled to support chaining MVs
    # We use lance directly when not using namespaces, since lancedb doesn't
    # expose new_table_enable_stable_row_ids yet. For namespace connections,
    # fall back to lancedb (without stable row IDs on the MV table).
    from .table import Table

    # Check if using namespace connection
    is_namespace = conn.namespace_impl is not None

    if not is_namespace:
        # Use lance directly to enable stable row IDs
        import lance

        table_uri = f"{conn._uri}/{view_name}.lance"

        # Create empty table with full schema first
        # This ensures all columns exist (including UDF columns)
        empty_data = view_schema.empty_table()
        lance.write_dataset(
            empty_data,
            table_uri,
            mode="create",
            enable_stable_row_ids=True,
        )

        # Open the table through lancedb
        view_table = Table(conn, view_name)

        # Add the initial row ID data
        view_table.add(initial_view_table_data)
    else:
        # For namespaces, use lancedb's API (no stable row IDs on MV table yet)
        # Need to create table in two steps because partial schema is not allowed
        # on initial create_table call.
        view_table = conn.create_table(
            view_name, data=None, schema=view_schema, mode="create"
        )
        view_table.add(initial_view_table_data)

    for udf_col, (udf, _output_index) in (self._column_udfs or {}).items():
        input_cols = udf.input_columns
        view_table._configure_computed_column(udf_col, udf, input_cols)

    return view_table
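
A sketch of creating a materialized view over a UDF-computed column; db, tbl, and embed are placeholders (a geneva Connection, a source Table, and a geneva UDF). The view is created with bookkeeping columns (__source_row_id, __is_set) and is populated afterwards by a pipeline job, which is not shown on this page.

builder = GenevaQueryBuilder(tbl).select({"id": "id", "embedding": embed})
view = builder.create_materialized_view(db, "embeddings_view")
print(view.schema)  # __source_row_id, __is_set, id, embedding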