Skip to content

Commit 471eb13

Browse files
refactor(bigframes): Simplify @udf wrapper object (#16556)
1 parent 636af26 commit 471eb13

9 files changed

Lines changed: 806 additions & 1069 deletions

File tree

packages/bigframes/bigframes/dataframe.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2851,7 +2851,7 @@ def _apply_callable(self, condition):
28512851
"""Executes the possible callable condition as needed."""
28522852
if callable(condition):
28532853
# When it's a bigframes function.
2854-
if hasattr(condition, "bigframes_bigquery_function"):
2854+
if isinstance(condition, bigframes.functions.Udf):
28552855
return self.apply(condition, axis=1)
28562856

28572857
# When it's a plain Python function.
@@ -4685,7 +4685,7 @@ def _prepare_export(
46854685
return array_value, id_overrides
46864686

46874687
def map(self, func, na_action: Optional[str] = None) -> DataFrame:
4688-
if not isinstance(func, bigframes.functions.BigqueryCallableRoutine):
4688+
if not isinstance(func, bigframes.functions.Udf):
46894689
raise TypeError("the first argument must be callable")
46904690

46914691
if na_action not in {None, "ignore"}:
@@ -4709,18 +4709,12 @@ def apply(self, func, *, axis=0, args: typing.Tuple = (), **kwargs):
47094709
)
47104710
warnings.warn(msg, category=bfe.FunctionAxisOnePreviewWarning)
47114711

4712-
if not isinstance(
4713-
func,
4714-
(
4715-
bigframes.functions.BigqueryCallableRoutine,
4716-
bigframes.functions.BigqueryCallableRowRoutine,
4717-
),
4718-
):
4712+
if not isinstance(func, bigframes.functions.Udf):
47194713
raise ValueError(
47204714
"For axis=1 a BigFrames BigQuery function must be used."
47214715
)
47224716

4723-
if func.is_row_processor:
4717+
if func.udf_def.signature.is_row_processor:
47244718
# Early check whether the dataframe dtypes are currently supported
47254719
# in the bigquery function
47264720
# NOTE: Keep in sync with the value converters used in the gcf code
@@ -4849,7 +4843,7 @@ def apply(self, func, *, axis=0, args: typing.Tuple = (), **kwargs):
48494843

48504844
# At this point column-wise or element-wise bigquery function operation will
48514845
# be performed (not supported).
4852-
if hasattr(func, "bigframes_bigquery_function"):
4846+
if isinstance(func, bigframes.functions.Udf):
48534847
raise formatter.create_exception_with_feedback_link(
48544848
NotImplementedError,
48554849
"BigFrames DataFrame '.apply()' does not support BigFrames "

packages/bigframes/bigframes/functions/__init__.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,8 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from bigframes.functions.function import (
15-
BigqueryCallableRoutine,
16-
BigqueryCallableRowRoutine,
17-
)
14+
from bigframes.functions.function import Udf
1815

1916
__all__ = [
20-
"BigqueryCallableRoutine",
21-
"BigqueryCallableRowRoutine",
17+
"Udf",
2218
]

packages/bigframes/bigframes/functions/_function_session.py

Lines changed: 22 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,18 @@
2020
import inspect
2121
import sys
2222
import threading
23-
import warnings
2423
from typing import (
25-
TYPE_CHECKING,
2624
Any,
25+
cast,
2726
Dict,
2827
Literal,
2928
Mapping,
3029
Optional,
3130
Sequence,
31+
TYPE_CHECKING,
3232
Union,
33-
cast,
3433
)
34+
import warnings
3535

3636
import google.api_core.exceptions
3737
from google.cloud import (
@@ -41,9 +41,9 @@
4141
resourcemanager_v3,
4242
)
4343

44+
from bigframes import clients
4445
import bigframes.exceptions as bfe
4546
import bigframes.formatting_helpers as bf_formatting
46-
from bigframes import clients
4747
from bigframes.functions import function as bq_functions
4848
from bigframes.functions import udf_def
4949

@@ -630,25 +630,15 @@ def wrapper(func):
630630
if udf_sig.is_row_processor:
631631
msg = bfe.format_message("input_types=Series is in preview.")
632632
warnings.warn(msg, stacklevel=1, category=bfe.PreviewWarning)
633-
return decorator(
634-
bq_functions.BigqueryCallableRowRoutine(
635-
udf_definition,
636-
session,
637-
cloud_function_ref=bigframes_cloud_function,
638-
local_func=func,
639-
is_managed=False,
640-
)
641-
)
642-
else:
643-
return decorator(
644-
bq_functions.BigqueryCallableRoutine(
645-
udf_definition,
646-
session,
647-
cloud_function_ref=bigframes_cloud_function,
648-
local_func=func,
649-
is_managed=False,
650-
)
633+
return decorator(
634+
bq_functions.BigqueryCallableRoutine(
635+
udf_definition,
636+
session,
637+
cloud_function_ref=bigframes_cloud_function,
638+
local_func=func,
639+
is_managed=False,
651640
)
641+
)
652642

653643
return wrapper
654644

@@ -834,8 +824,9 @@ def wrapper(func):
834824
bq_connection_manager,
835825
session=session, # type: ignore
836826
)
827+
code_def = udf_def.CodeDef.from_func(func)
837828
config = udf_def.ManagedFunctionConfig(
838-
code=udf_def.CodeDef.from_func(func),
829+
code=code_def,
839830
signature=udf_sig,
840831
max_batching_rows=max_batching_rows,
841832
container_cpu=container_cpu,
@@ -859,28 +850,18 @@ def wrapper(func):
859850
signature=udf_sig,
860851
)
861852

862-
if not name:
863-
self._update_temp_artifacts(full_rf_name, "")
864-
865-
decorator = functools.wraps(func)
866853
if udf_sig.is_row_processor:
867854
msg = bfe.format_message("input_types=Series is in preview.")
868855
warnings.warn(msg, stacklevel=1, category=bfe.PreviewWarning)
869-
assert session is not None # appease mypy
870-
return decorator(
871-
bq_functions.BigqueryCallableRowRoutine(
872-
udf_definition, session, local_func=func, is_managed=True
873-
)
874-
)
856+
857+
if not name: # session-owned resource - will be cleaned up automatically
858+
self._update_temp_artifacts(full_rf_name, "")
859+
return bq_functions.UdfRoutine(func=func, _udf_def=udf_definition)
860+
861+
# user-managed permanent resource - will not be cleaned up automatically
875862
else:
876-
assert session is not None # appease mypy
877-
return decorator(
878-
bq_functions.BigqueryCallableRoutine(
879-
udf_definition,
880-
session,
881-
local_func=func,
882-
is_managed=True,
883-
)
863+
return bq_functions.BigqueryCallableRoutine(
864+
udf_definition, session, local_func=func, is_managed=True
884865
)
885866

886867
return wrapper

packages/bigframes/bigframes/functions/function.py

Lines changed: 30 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,9 @@
1515
from __future__ import annotations
1616

1717
import logging
18-
from typing import TYPE_CHECKING, Callable, Optional
18+
from typing import Callable, Optional, Protocol, runtime_checkable, TYPE_CHECKING
1919

20-
if TYPE_CHECKING:
21-
import bigframes.series
22-
from bigframes.session import Session
20+
import dataclasses
2321

2422
import google.api_core.exceptions
2523
from google.cloud import bigquery
@@ -28,6 +26,11 @@
2826
from bigframes.functions import _function_session as bff_session
2927
from bigframes.functions import function_typing, udf_def
3028

29+
if TYPE_CHECKING:
30+
import bigframes.core.col
31+
from bigframes.session import Session
32+
import bigframes.series
33+
3134
logger = logging.getLogger(__name__)
3235

3336

@@ -90,13 +93,13 @@ def _try_import_routine(
9093

9194
def _try_import_row_routine(
9295
routine: bigquery.Routine, session: bigframes.Session
93-
) -> BigqueryCallableRowRoutine:
96+
) -> BigqueryCallableRoutine:
9497
udf_def = _routine_as_udf_def(routine, is_row_processor=True)
9598

9699
is_remote = (
97100
hasattr(routine, "remote_function_options") and routine.remote_function_options
98101
)
99-
return BigqueryCallableRowRoutine(udf_def, session, is_managed=not is_remote)
102+
return BigqueryCallableRoutine(udf_def, session, is_managed=not is_remote)
100103

101104

102105
def _routine_as_udf_def(
@@ -117,7 +120,6 @@ def _routine_as_udf_def(
117120
)
118121

119122

120-
# TODO(b/399894805): Support managed function.
121123
def read_gbq_function(
122124
function_name: str,
123125
*,
@@ -152,6 +154,18 @@ def read_gbq_function(
152154
return _try_import_routine(routine, session)
153155

154156

157+
@runtime_checkable
158+
class Udf(Protocol):
159+
"""
160+
Protocol for all BigFrames user-defined functions.
161+
162+
Has @runtime_checkable so functions like df.apply() can dispatch UDFs with isinstance() checks.
163+
"""
164+
165+
@property
166+
def udf_def(self) -> udf_def.BigqueryUdf: ...
167+
168+
155169
class BigqueryCallableRoutine:
156170
"""
157171
A reference to a routine in the context of a session.
@@ -178,8 +192,8 @@ def __call__(self, *args, **kwargs):
178192
if self._local_fun:
179193
return self._local_fun(*args, **kwargs)
180194
# avoid circular imports
181-
import bigframes.session._io.bigquery as bf_io_bigquery
182195
from bigframes.core.compile.sqlglot import sql as sg_sql
196+
import bigframes.session._io.bigquery as bf_io_bigquery
183197

184198
args_string = ", ".join([sg_sql.to_sql(sg_sql.literal(v)) for v in args])
185199
sql = f"SELECT `{str(self._udf_def.routine_ref)}`({args_string})"
@@ -202,7 +216,7 @@ def bigframes_remote_function(self):
202216

203217
@property
204218
def is_row_processor(self) -> bool:
205-
return False
219+
return self.udf_def.signature.is_row_processor
206220

207221
@property
208222
def udf_def(self) -> udf_def.BigqueryUdf:
@@ -225,75 +239,16 @@ def bigframes_bigquery_function_output_dtype(self):
225239
return self.udf_def.signature.output.emulating_type.bf_type
226240

227241

228-
class BigqueryCallableRowRoutine:
229-
"""
230-
A reference to a routine in the context of a session.
231-
232-
Can be used both directly as a callable, or as an input to dataframe ops that take a callable.
233-
"""
234-
235-
def __init__(
236-
self,
237-
udf_def: udf_def.BigqueryUdf,
238-
session: bigframes.Session,
239-
*,
240-
local_func: Optional[Callable] = None,
241-
cloud_function_ref: Optional[str] = None,
242-
is_managed: bool = False,
243-
):
244-
assert udf_def.signature.is_row_processor
245-
self._udf_def = udf_def
246-
self._session = session
247-
self._local_fun = local_func
248-
self._cloud_function = cloud_function_ref
249-
self._is_managed = is_managed
242+
@dataclasses.dataclass(frozen=True)
243+
class UdfRoutine:
244+
func: Callable
245+
# Try not to depend on this, bq managed function creation will be deferred later
246+
# And this ref will be replaced with requirements rather to support lazy creation
247+
_udf_def: udf_def.BigqueryUdf
250248

251249
def __call__(self, *args, **kwargs):
252-
if self._local_fun:
253-
return self._local_fun(*args, **kwargs)
254-
# avoid circular imports
255-
import bigframes.session._io.bigquery as bf_io_bigquery
256-
from bigframes.core.compile.sqlglot import sql as sg_sql
257-
258-
args_string = ", ".join([sg_sql.to_sql(sg_sql.literal(v)) for v in args])
259-
sql = f"SELECT `{str(self._udf_def.routine_ref)}`({args_string})"
260-
iter, job = bf_io_bigquery.start_query_with_client(
261-
self._session.bqclient,
262-
sql=sql,
263-
query_with_job=True,
264-
job_config=bigquery.QueryJobConfig(),
265-
publisher=self._session._publisher,
266-
) # type: ignore
267-
return list(iter.to_arrow().to_pydict().values())[0][0]
268-
269-
@property
270-
def bigframes_bigquery_function(self) -> str:
271-
return str(self._udf_def.routine_ref)
272-
273-
@property
274-
def bigframes_remote_function(self):
275-
return None if self._is_managed else str(self._udf_def.routine_ref)
276-
277-
@property
278-
def is_row_processor(self) -> bool:
279-
return True
250+
return self.func(*args, **kwargs)
280251

281252
@property
282253
def udf_def(self) -> udf_def.BigqueryUdf:
283254
return self._udf_def
284-
285-
@property
286-
def bigframes_cloud_function(self) -> Optional[str]:
287-
return self._cloud_function
288-
289-
@property
290-
def input_dtypes(self):
291-
return tuple(arg.bf_type for arg in self.udf_def.signature.inputs)
292-
293-
@property
294-
def output_dtype(self):
295-
return self.udf_def.signature.output.bf_type
296-
297-
@property
298-
def bigframes_bigquery_function_output_dtype(self):
299-
return self.udf_def.signature.output.emulating_type.bf_type

packages/bigframes/bigframes/functions/function_template.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def convert_from_bq_json(type_, arg):
4343
import base64
4444
import collections
4545

46-
converters = collections.defaultdict(lambda: (lambda value: value)) # type: ignore
46+
converters = collections.defaultdict(lambda: lambda value: value) # type: ignore
4747
converters["BYTES"] = base64.b64decode
4848
converter = converters[type_]
4949
return converter(arg) if arg is not None else None
@@ -53,7 +53,7 @@ def convert_to_bq_json(type_, arg):
5353
import base64
5454
import collections
5555

56-
converters = collections.defaultdict(lambda: (lambda value: value)) # type: ignore
56+
converters = collections.defaultdict(lambda: lambda value: value) # type: ignore
5757
converters["BYTES"] = lambda value: base64.b64encode(value).decode("utf-8")
5858
converter = converters[type_]
5959
return converter(arg) if arg is not None else None

packages/bigframes/bigframes/functions/udf_def.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,14 @@ def stable_hash(self) -> bytes:
457457

458458
return hash_val.digest()
459459

460+
def to_callable(self):
461+
"""
462+
Reconstructs the python callable from the pickled code.
463+
464+
Assumption: package_requirements match local environment
465+
"""
466+
return cloudpickle.loads(self.pickled_code)
467+
460468

461469
@dataclasses.dataclass(frozen=True)
462470
class ManagedFunctionConfig:

0 commit comments

Comments
 (0)