@@ -1581,6 +1581,7 @@ def _task_to_record_batches(
15811581 partition_spec : PartitionSpec | None = None ,
15821582 format_version : TableVersion = TableProperties .DEFAULT_FORMAT_VERSION ,
15831583 downcast_ns_timestamp_to_us : bool | None = None ,
1584+ batch_size : int | None = None ,
15841585) -> Iterator [pa .RecordBatch ]:
15851586 arrow_format = _get_file_format (task .file .file_format , pre_buffer = True , buffer_size = (ONE_MEGABYTE * 8 ))
15861587 with io .new_input (task .file .file_path ).open () as fin :
@@ -1612,14 +1613,18 @@ def _task_to_record_batches(
16121613
16131614 file_project_schema = prune_columns (file_schema , projected_field_ids , select_full_types = False )
16141615
1615- fragment_scanner = ds . Scanner . from_fragment (
1616- fragment = fragment ,
1617- schema = physical_schema ,
1616+ scanner_kwargs : dict [ str , Any ] = {
1617+ "fragment" : fragment ,
1618+ "schema" : physical_schema ,
16181619 # This will push down the query to Arrow.
16191620 # But in case there are positional deletes, we have to apply them first
1620- filter = pyarrow_filter if not positional_deletes else None ,
1621- columns = [col .name for col in file_project_schema .columns ],
1622- )
1621+ "filter" : pyarrow_filter if not positional_deletes else None ,
1622+ "columns" : [col .name for col in file_project_schema .columns ],
1623+ }
1624+ if batch_size is not None :
1625+ scanner_kwargs ["batch_size" ] = batch_size
1626+
1627+ fragment_scanner = ds .Scanner .from_fragment (** scanner_kwargs )
16231628
16241629 next_index = 0
16251630 batches = fragment_scanner .to_batches ()
@@ -1756,7 +1761,7 @@ def to_table(self, tasks: Iterable[FileScanTask]) -> pa.Table:
17561761
17571762 return result
17581763
1759- def to_record_batches (self , tasks : Iterable [FileScanTask ]) -> Iterator [pa .RecordBatch ]:
1764+ def to_record_batches (self , tasks : Iterable [FileScanTask ], batch_size : int | None = None ) -> Iterator [pa .RecordBatch ]:
17601765 """Scan the Iceberg table and return an Iterator[pa.RecordBatch].
17611766
17621767 Returns an Iterator of pa.RecordBatch with data from the Iceberg table
@@ -1783,7 +1788,7 @@ def batches_for_task(task: FileScanTask) -> list[pa.RecordBatch]:
17831788 # Materialize the iterator here to ensure execution happens within the executor.
17841789 # Otherwise, the iterator would be lazily consumed later (in the main thread),
17851790 # defeating the purpose of using executor.map.
1786- return list (self ._record_batches_from_scan_tasks_and_deletes ([task ], deletes_per_file ))
1791+ return list (self ._record_batches_from_scan_tasks_and_deletes ([task ], deletes_per_file , batch_size ))
17871792
17881793 limit_reached = False
17891794 for batches in executor .map (batches_for_task , tasks ):
@@ -1803,7 +1808,7 @@ def batches_for_task(task: FileScanTask) -> list[pa.RecordBatch]:
18031808 break
18041809
18051810 def _record_batches_from_scan_tasks_and_deletes (
1806- self , tasks : Iterable [FileScanTask ], deletes_per_file : dict [str , list [ChunkedArray ]]
1811+ self , tasks : Iterable [FileScanTask ], deletes_per_file : dict [str , list [ChunkedArray ]], batch_size : int | None = None
18071812 ) -> Iterator [pa .RecordBatch ]:
18081813 total_row_count = 0
18091814 for task in tasks :
@@ -1822,6 +1827,7 @@ def _record_batches_from_scan_tasks_and_deletes(
18221827 self ._table_metadata .specs ().get (task .file .spec_id ),
18231828 self ._table_metadata .format_version ,
18241829 self ._downcast_ns_timestamp_to_us ,
1830+ batch_size ,
18251831 )
18261832 for batch in batches :
18271833 if self ._limit is not None :
0 commit comments