Skip to content

Commit 5ac6810

Browse files
authored
feat: add bigquery.ml.generate_text function (#2403)
Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [ ] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/python-bigquery-dataframes/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [ ] Ensure the tests and linter pass - [ ] Code coverage does not decrease (if any source code was changed) - [ ] Appropriate docs were updated (if necessary) Fixes #<issue_number_goes_here> 🦕
1 parent 6fef9be commit 5ac6810

File tree

11 files changed

+237
-11
lines changed

11 files changed

+237
-11
lines changed

bigframes/bigquery/_operations/ml.py

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from __future__ import annotations
1616

17-
from typing import cast, Mapping, Optional, Union
17+
from typing import cast, List, Mapping, Optional, Union
1818

1919
import bigframes_vendored.constants
2020
import google.cloud.bigquery
@@ -431,3 +431,92 @@ def transform(
431431
return bpd.read_gbq_query(sql)
432432
else:
433433
return session.read_gbq_query(sql)
434+
435+
436+
@log_adapter.method_logger(custom_base_name="bigquery_ml")
def generate_text(
    model: Union[bigframes.ml.base.BaseEstimator, str, pd.Series],
    input_: Union[pd.DataFrame, dataframe.DataFrame, str],
    *,
    temperature: Optional[float] = None,
    max_output_tokens: Optional[int] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    flatten_json_output: Optional[bool] = None,
    stop_sequences: Optional[List[str]] = None,
    ground_with_google_search: Optional[bool] = None,
    request_type: Optional[str] = None,
) -> dataframe.DataFrame:
    """
    Generates text using a BigQuery ML model.

    See the `BigQuery ML GENERATE_TEXT function syntax
    <https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-generate-text>`_
    for additional reference.

    Args:
        model (bigframes.ml.base.BaseEstimator, str, or pandas.Series):
            The model to use for text generation.
        input_ (Union[bigframes.pandas.DataFrame, pandas.DataFrame, str]):
            The DataFrame or SQL query to use for text generation.
        temperature (float, optional):
            A FLOAT64 value that controls the degree of randomness in token
            selection. The value must be in the range ``[0.0, 1.0]``. A lower
            temperature works well for prompts that expect a more
            deterministic and less open-ended or creative response, while a
            higher temperature can lead to more diverse or creative results.
            A temperature of ``0`` is deterministic, meaning that the highest
            probability response is always selected.
        max_output_tokens (int, optional):
            An INT64 value that sets the maximum number of tokens in the
            generated text.
        top_k (int, optional):
            An INT64 value that changes how the model selects tokens for
            output. A ``top_k`` of ``1`` means the next selected token is the
            most probable among all tokens in the model's vocabulary. A
            ``top_k`` of ``3`` means that the next token is selected from
            among the three most probable tokens by using temperature. The
            default value is ``40``.
        top_p (float, optional):
            A FLOAT64 value that changes how the model selects tokens for
            output. Tokens are selected from most probable to least probable
            until the sum of their probabilities equals the ``top_p`` value.
            For example, if tokens A, B, and C have a probability of 0.3,
            0.2, and 0.1 and the ``top_p`` value is ``0.5``, then the model
            will select either A or B as the next token by using temperature.
            The default value is ``0.95``.
        flatten_json_output (bool, optional):
            A BOOL value that determines the content of the generated JSON
            column.
        stop_sequences (List[str], optional):
            An ARRAY<STRING> value that contains the stop sequences for the
            model.
        ground_with_google_search (bool, optional):
            A BOOL value that determines whether to ground the model with
            Google Search.
        request_type (str, optional):
            A STRING value that contains the request type for the model.

    Returns:
        bigframes.pandas.DataFrame:
            The generated text.
    """
    # Imported locally to avoid a circular import at module load time.
    import bigframes.pandas as bpd

    model_name, session = _get_model_name_and_session(model, input_)
    table_sql = _to_sql(input_)

    sql = bigframes.core.sql.ml.generate_text(
        model_name=model_name,
        table=table_sql,
        temperature=temperature,
        max_output_tokens=max_output_tokens,
        top_k=top_k,
        top_p=top_p,
        flatten_json_output=flatten_json_output,
        stop_sequences=stop_sequences,
        ground_with_google_search=ground_with_google_search,
        request_type=request_type,
    )

    # When the input carries no session (e.g. a plain pandas DataFrame or a
    # raw query string), fall back to the global session to run the query.
    if session is None:
        return bpd.read_gbq_query(sql)
    else:
        return session.read_gbq_query(sql)

bigframes/bigquery/ml.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
create_model,
2424
evaluate,
2525
explain_predict,
26+
generate_text,
2627
global_explain,
2728
predict,
2829
transform,
@@ -35,4 +36,5 @@
3536
"explain_predict",
3637
"global_explain",
3738
"transform",
39+
"generate_text",
3840
]

bigframes/core/sql/ml.py

Lines changed: 77 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414

1515
from __future__ import annotations
1616

17-
from typing import Dict, Mapping, Optional, Union
17+
import collections.abc
18+
import json
19+
from typing import Any, Dict, List, Mapping, Optional, Union
1820

1921
import bigframes.core.compile.googlesql as googlesql
2022
import bigframes.core.sql
@@ -100,14 +102,41 @@ def create_model_ddl(
100102

101103

102104
def _build_struct_sql(
103-
struct_options: Mapping[str, Union[str, int, float, bool]]
105+
struct_options: Mapping[
106+
str,
107+
Union[str, int, float, bool, Mapping[str, str], List[str], Mapping[str, Any]],
108+
]
104109
) -> str:
105110
if not struct_options:
106111
return ""
107112

108113
rendered_options = []
109114
for option_name, option_value in struct_options.items():
110-
rendered_val = bigframes.core.sql.simple_literal(option_value)
115+
if option_name == "model_params":
116+
json_str = json.dumps(option_value)
117+
# Escape single quotes for SQL string literal
118+
sql_json_str = json_str.replace("'", "''")
119+
rendered_val = f"JSON'{sql_json_str}'"
120+
elif isinstance(option_value, collections.abc.Mapping):
121+
struct_body = ", ".join(
122+
[
123+
f"{bigframes.core.sql.simple_literal(v)} AS {k}"
124+
for k, v in option_value.items()
125+
]
126+
)
127+
rendered_val = f"STRUCT({struct_body})"
128+
elif isinstance(option_value, list):
129+
rendered_val = (
130+
"["
131+
+ ", ".join(
132+
[bigframes.core.sql.simple_literal(v) for v in option_value]
133+
)
134+
+ "]"
135+
)
136+
elif isinstance(option_value, bool):
137+
rendered_val = str(option_value).lower()
138+
else:
139+
rendered_val = bigframes.core.sql.simple_literal(option_value)
111140
rendered_options.append(f"{rendered_val} AS {option_name}")
112141
return f", STRUCT({', '.join(rendered_options)})"
113142

@@ -151,7 +180,7 @@ def predict(
151180
"""Encode the ML.PREDICT statement.
152181
See https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-predict for reference.
153182
"""
154-
struct_options = {}
183+
struct_options: Dict[str, Union[str, int, float, bool]] = {}
155184
if threshold is not None:
156185
struct_options["threshold"] = threshold
157186
if keep_original_columns is not None:
@@ -205,7 +234,7 @@ def global_explain(
205234
"""Encode the ML.GLOBAL_EXPLAIN statement.
206235
See https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-global-explain for reference.
207236
"""
208-
struct_options = {}
237+
struct_options: Dict[str, Union[str, int, float, bool]] = {}
209238
if class_level_explain is not None:
210239
struct_options["class_level_explain"] = class_level_explain
211240

@@ -224,3 +253,46 @@ def transform(
224253
"""
225254
sql = f"SELECT * FROM ML.TRANSFORM(MODEL {googlesql.identifier(model_name)}, ({table}))\n"
226255
return sql
256+
257+
258+
def generate_text(
    model_name: str,
    table: str,
    *,
    temperature: Optional[float] = None,
    max_output_tokens: Optional[int] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    flatten_json_output: Optional[bool] = None,
    stop_sequences: Optional[List[str]] = None,
    ground_with_google_search: Optional[bool] = None,
    request_type: Optional[str] = None,
) -> str:
    """Encode the ML.GENERATE_TEXT statement.

    See https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-generate-text for reference.

    Args:
        model_name: Fully qualified name of the model to invoke.
        table: SQL query text producing the input rows.
        temperature: Sampling temperature in ``[0.0, 1.0]``.
        max_output_tokens: Maximum number of tokens to generate.
        top_k: Token-selection top-k value.
        top_p: Token-selection nucleus-sampling value.
        flatten_json_output: Whether to flatten the generated JSON column.
        stop_sequences: Stop sequences for the model.
        ground_with_google_search: Whether to ground with Google Search.
        request_type: Request type for the model.

    Returns:
        The complete ``SELECT * FROM ML.GENERATE_TEXT(...)`` statement.
    """
    struct_options: Dict[
        str,
        Union[str, int, float, bool, Mapping[str, str], List[str], Mapping[str, Any]],
    ] = {}
    # Only include options the caller supplied; options omitted from the
    # STRUCT fall back to BigQuery's server-side defaults.
    if temperature is not None:
        struct_options["temperature"] = temperature
    if max_output_tokens is not None:
        struct_options["max_output_tokens"] = max_output_tokens
    if top_k is not None:
        struct_options["top_k"] = top_k
    if top_p is not None:
        struct_options["top_p"] = top_p
    if flatten_json_output is not None:
        struct_options["flatten_json_output"] = flatten_json_output
    if stop_sequences is not None:
        struct_options["stop_sequences"] = stop_sequences
    if ground_with_google_search is not None:
        struct_options["ground_with_google_search"] = ground_with_google_search
    if request_type is not None:
        struct_options["request_type"] = request_type

    sql = f"SELECT * FROM ML.GENERATE_TEXT(MODEL {googlesql.identifier(model_name)}, ({table})"
    sql += _build_struct_sql(struct_options)
    sql += ")\n"
    return sql

notebooks/ml/bq_dataframes_ml_cross_validation.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -991,7 +991,7 @@
991991
],
992992
"metadata": {
993993
"kernelspec": {
994-
"display_name": "venv",
994+
"display_name": "venv (3.10.14)",
995995
"language": "python",
996996
"name": "python3"
997997
},
@@ -1005,7 +1005,7 @@
10051005
"name": "python",
10061006
"nbconvert_exporter": "python",
10071007
"pygments_lexer": "ipython3",
1008-
"version": "3.10.15"
1008+
"version": "3.10.14"
10091009
}
10101010
},
10111011
"nbformat": 4,

tests/unit/bigquery/test_ml.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,3 +163,40 @@ def test_transform_with_pandas_dataframe(read_pandas_mock, read_gbq_query_mock):
163163
assert "ML.TRANSFORM" in generated_sql
164164
assert f"MODEL `{MODEL_NAME}`" in generated_sql
165165
assert "(SELECT * FROM `pandas_df`)" in generated_sql
166+
167+
168+
@mock.patch("bigframes.pandas.read_gbq_query")
@mock.patch("bigframes.pandas.read_pandas")
def test_generate_text_with_pandas_dataframe(read_pandas_mock, read_gbq_query_mock):
    """ML.GENERATE_TEXT SQL is built correctly from a pandas DataFrame input."""
    input_df = pd.DataFrame({"col1": [1, 2, 3]})
    read_pandas_mock.return_value._to_sql_query.return_value = (
        "SELECT * FROM `pandas_df`",
        [],
        [],
    )

    ml_ops.generate_text(
        MODEL_SERIES,
        input_=input_df,
        temperature=0.5,
        max_output_tokens=128,
        top_k=20,
        top_p=0.9,
        flatten_json_output=True,
        stop_sequences=["a", "b"],
        ground_with_google_search=True,
        request_type="TYPE",
    )

    read_pandas_mock.assert_called_once()
    read_gbq_query_mock.assert_called_once()
    generated_sql = read_gbq_query_mock.call_args[0][0]

    # Every option supplied above must appear in the rendered STRUCT.
    expected_fragments = [
        "ML.GENERATE_TEXT",
        f"MODEL `{MODEL_NAME}`",
        "(SELECT * FROM `pandas_df`)",
        "STRUCT(0.5 AS temperature",
        "128 AS max_output_tokens",
        "20 AS top_k",
        "0.9 AS top_p",
        "true AS flatten_json_output",
        "['a', 'b'] AS stop_sequences",
        "true AS ground_with_google_search",
        "'TYPE' AS request_type",
    ]
    for fragment in expected_fragments:
        assert fragment in generated_sql
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
SELECT * FROM ML.EVALUATE(MODEL `my_model`, STRUCT(False AS perform_aggregation, 10 AS horizon, 0.95 AS confidence_level))
1+
SELECT * FROM ML.EVALUATE(MODEL `my_model`, STRUCT(false AS perform_aggregation, 10 AS horizon, 0.95 AS confidence_level))
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
SELECT * FROM ML.GENERATE_TEXT(MODEL `my_project.my_dataset.my_model`, (SELECT * FROM new_data))
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
SELECT * FROM ML.GENERATE_TEXT(MODEL `my_project.my_dataset.my_model`, (SELECT * FROM new_data), STRUCT(0.5 AS temperature, 128 AS max_output_tokens, 20 AS top_k, 0.9 AS top_p, true AS flatten_json_output, ['a', 'b'] AS stop_sequences, true AS ground_with_google_search, 'TYPE' AS request_type))
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
SELECT * FROM ML.GLOBAL_EXPLAIN(MODEL `my_model`, STRUCT(True AS class_level_explain))
1+
SELECT * FROM ML.GLOBAL_EXPLAIN(MODEL `my_model`, STRUCT(true AS class_level_explain))
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
SELECT * FROM ML.PREDICT(MODEL `my_model`, (SELECT * FROM new_data), STRUCT(True AS keep_original_columns))
1+
SELECT * FROM ML.PREDICT(MODEL `my_model`, (SELECT * FROM new_data), STRUCT(true AS keep_original_columns))

0 commit comments

Comments (0)