# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
This codelab demonstrates how to use Keras 3, KerasNLP and TFX pipelines to fine-tune a pre-trained GPT-2 model on the IMDb movie reviews dataset.
Why is this pipeline useful?
TFX pipelines provide a structured approach to building and managing machine learning workflows, particularly those involving large language models. Compared with ad hoc Python scripts, they offer:
- Enhanced Reproducibility: TFX pipelines capture every step and its dependencies, which helps produce consistent results and avoids the drift that manual workflows are prone to.
- Scalability and Modularity: TFX lets you break complex workflows into manageable, reusable components, which keeps the code organized.
- Streamlined Fine-Tuning and Conversion: The pipeline structure streamlines fine-tuning and converting large language models, reducing manual effort and time.
- Comprehensive Lineage Tracking: Through metadata tracking, TFX pipelines record data and model provenance, which makes debugging, auditing, and performance analysis easier.
By leveraging the benefits of TFX pipelines, organizations can effectively manage the complexity of large language model development and deployment, achieving greater efficiency and control over their machine learning processes.
Note
GPT-2 is used here only to demonstrate the end-to-end process; the techniques and tooling introduced in this codelab are potentially transferable to other generative language models such as Google T5.
Before You Begin
Colab offers different kinds of runtimes. Go to Runtime -> Change runtime type and choose a GPU hardware accelerator, since you will fine-tune the GPT-2 model.
This tutorial's interactive pipeline is designed to run on free Colab GPUs. However, running the pipeline with the LocalDagRunner orchestrator (code provided at the end of this tutorial) requires substantially more GPU memory, so Colab Pro or a local machine with a higher-capacity GPU is recommended for that approach.
Set Up
We first install the required Python packages.
Upgrade Pip
To avoid upgrading Pip on a local system, we first check that we are running in Colab. Local systems can be upgraded separately.
try:
    import google.colab  # only importable inside a Colab runtime
    !pip install --upgrade pip
except ImportError:
    pass
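If you prefer not to rely on a failing import, the same check can look for the `google.colab` package with `importlib.util.find_spec`, which locates a module without importing it. This is a minimal sketch that assumes Colab is identifiable by that package; `running_in_colab` is a hypothetical helper name.

```python
import importlib.util


def running_in_colab() -> bool:
    """Return True if the google.colab package can be found (nothing is imported)."""
    try:
        return importlib.util.find_spec("google.colab") is not None
    except ModuleNotFoundError:
        # find_spec raises this when the parent package "google" is missing.
        return False


if running_in_colab():
    print("Colab detected: upgrading pip in this runtime is safe.")
else:
    print("Not in Colab: upgrade pip in your local environment separately if needed.")
```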
Install TFX, Keras 3, KerasNLP and required libraries
!pip install -q tfx tensorflow-text more_itertools tensorflow_datasets
!pip install -q --upgrade keras-nlp
!pip install -q --upgrade keras
Did you restart the runtime?
If you are using Google Colab, the first time you run the cell above you must restart the runtime, either by clicking the "RESTART SESSION" button above or by using the "Runtime > Restart session" menu. This is because of the way that Colab loads packages.
Let's check the TensorFlow, TFX, Keras and KerasNLP library versions.
import os
# Select the Keras backend before Keras is imported anywhere.
os.environ["KERAS_BACKEND"] = "tensorflow"
import tensorflow as tf
print('TensorFlow version: {}'.format(tf.__version__))
from tfx import v1 as tfx
print('TFX version: {}'.format(tfx.__version__))
import keras
print('Keras version: {}'.format(keras.__version__))
import keras_nlp
print('Keras NLP version: {}'.format(keras_nlp.__version__))
# Compute in float16 while keeping variables in float32, which saves GPU memory.
keras.mixed_precision.set_global_policy("mixed_float16")
TensorFlow version: 2.15.1
TFX version: 1.15.1
Keras version: 3.3.3
Keras NLP version: 0.12.1
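To catch a forgotten restart early, you could compare installed versions against what the codelab expects. The sketch below uses `importlib.metadata`, which reads the versions installed on disk. These can differ from what an already-running process has imported, and that mismatch is exactly why the restart is needed. The helper names and the "Keras 3 or later" check are illustrative assumptions.

```python
from importlib import metadata
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed distribution version of `package`, or None if absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def major_version(version: str) -> int:
    """Extract the leading integer of a version string such as '3.3.3'."""
    return int(version.split(".")[0])


# Assumed expectation for this codelab: Keras 3 or later.
keras_version = installed_version("keras")
if keras_version is None:
    print("keras is not installed")
elif major_version(keras_version) < 3:
    print(f"keras {keras_version} found; reinstall and restart the runtime")
else:
    print(f"keras {keras_version} looks fine")
```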
Using TFX Interactive Context
An interactive context provides global context when running a TFX pipeline in a notebook without a runner or orchestrator such as Apache Airflow or Kubeflow. This style of development is only useful while developing a pipeline's code; it cannot currently be used to deploy a working pipeline to production.
from tfx.orchestration.experimental.interactive.interactive_context import InteractiveContext
context = InteractiveContext()
WARNING:absl:InteractiveContext pipeline_root argument not provided: using temporary directory /tmpfs/tmp/tfx-interactive-2024-06-19T10_25_02.997883-vluhbwr6 as root for pipeline outputs.
WARNING:absl:InteractiveContext metadata_connection_config not provided: using SQLite ML Metadata database at /tmpfs/tmp/tfx-interactive-2024-06-19T10_25_02.997883-vluhbwr6/metadata.sqlite.
Pipeline Overview
The pipeline is made up of the components described below. A few terms are used throughout:
Artifacts are the data that components produce or consume. Custom artifacts are artifacts that we define specifically for this pipeline. Artifact payloads are written under the pipeline root, while ML Metadata (MLMD) records each artifact's type, location and version, along with which component produced or consumed it.
Components are implementations of ML tasks that you can use as steps in your pipeline.
Parameters are arguments, other than artifacts, that are passed to components to configure their behavior.
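To make the artifact and lineage idea concrete, here is a toy, pure-Python model of what a metadata store tracks. Each artifact points to data stored at a `uri`, and the store records which component produced and consumed it, which is enough to answer a simple lineage query. This is an illustration only. The class names, method names and URIs are hypothetical and are not the MLMD API.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Artifact:
    """Toy artifact: a typed pointer to data stored at `uri`."""
    artifact_id: int
    type_name: str
    uri: str


@dataclass
class ToyMetadataStore:
    """Toy metadata store recording who produced and consumed each artifact."""
    artifacts: Dict[int, Artifact] = field(default_factory=dict)
    producers: Dict[int, str] = field(default_factory=dict)
    consumers: Dict[int, List[str]] = field(default_factory=dict)
    component_inputs: Dict[str, List[int]] = field(default_factory=dict)

    def record_output(self, component: str, artifact: Artifact) -> None:
        self.artifacts[artifact.artifact_id] = artifact
        self.producers[artifact.artifact_id] = component

    def record_input(self, component: str, artifact_id: int) -> None:
        self.consumers.setdefault(artifact_id, []).append(component)
        self.component_inputs.setdefault(component, []).append(artifact_id)

    def producer_of(self, artifact_id: int) -> str:
        return self.producers[artifact_id]

    def upstream_components(self, artifact_id: int) -> List[str]:
        """Walk back through producers and their inputs (a simple lineage query)."""
        chain: List[str] = []
        pending = [artifact_id]
        while pending:
            producer = self.producers.get(pending.pop())
            if producer is None or producer in chain:
                continue
            chain.append(producer)
            pending.extend(self.component_inputs.get(producer, []))
        return chain


store = ToyMetadataStore()
examples = Artifact(1, "Examples", "/pipeline_root/ExampleGen/examples/1")
store.record_output("ExampleGen", examples)
store.record_input("StatisticsGen", examples.artifact_id)
stats = Artifact(2, "ExampleStatistics", "/pipeline_root/StatisticsGen/statistics/2")
store.record_output("StatisticsGen", stats)
print(store.upstream_components(2))  # ['StatisticsGen', 'ExampleGen']
```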
ExampleGen
We create a custom ExampleGen component which we use to load a TensorFlow Datasets (TFDS) dataset. This uses a custom executor in a FileBasedExampleGen.
from typing import Any, Dict, List, Text
import tensorflow_datasets as tfds
import apache_beam as beam
import json
from tfx.components.example_gen.base_example_gen_executor import BaseExampleGenExecutor
from tfx.components.example_gen.component import FileBasedExampleGen
from tfx.components.example_gen import utils
from tfx.dsl.components.base import executor_spec
import os
import pprint
pp = pprint.PrettyPrinter()
@beam.ptransform_fn
@beam.typehints.with_input_types(beam.Pipeline)
@beam.typehints.with_output_types(tf.train.Example)
def _TFDatasetToExample(
    pipeline: beam.Pipeline,
    exec_properties: Dict[str, Any],
    split_pattern: str  # passed in by the ExampleGen executor; not used here
) -> beam.pvalue.PCollection:
    """Read a TensorFlow Dataset and create tf.Examples"""
    # custom_config arrives as a JSON string inside exec_properties.
    custom_config = json.loads(exec_properties['custom_config'])
    dataset_name = custom_config['dataset']
    split_name = custom_config['split']
    # Download and prepare the TFDS dataset (cached after the first run).
    builder = tfds.builder(dataset_name)
    builder.download_and_prepare()
    return (pipeline
            | 'MakeExamples' >> tfds.beam.ReadFromTFDS(builder, split=split_name)
            | 'AsNumpy' >> beam.Map(tfds.as_numpy)
            | 'ToDict' >> beam.Map(dict)
            | 'ToTFExample' >> beam.Map(utils.dict_to_example)
            )
class TFDSExecutor(BaseExampleGenExecutor):
    def GetInputSourceToExamplePTransform(self) -> beam.PTransform:
        """Returns PTransform for TF Dataset to TF examples."""
        return _TFDatasetToExample
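The `'ToDict'` and `'ToTFExample'` steps turn each TFDS record into a `tf.train.Example`, whose features each hold one of three list types: `bytes_list`, `int64_list` or `float_list`. The following is a simplified, hypothetical pure-Python analog of that type dispatch, using plain dicts in place of protos. It is not the implementation of `utils.dict_to_example`.

```python
from typing import Any, Dict, List, Union

FeatureValue = Union[bytes, int, float]


def to_feature(value: Any) -> Dict[str, List[FeatureValue]]:
    """Map a scalar or list to a tf.train.Feature-like dict (simplified)."""
    values = list(value) if isinstance(value, (list, tuple)) else [value]
    if not values:
        return {}  # no values: an empty feature
    first = values[0]
    if isinstance(first, (bytes, str)):
        return {"bytes_list": [v.encode("utf-8") if isinstance(v, str) else v for v in values]}
    if isinstance(first, int):  # note: bool is also an int in Python
        return {"int64_list": [int(v) for v in values]}
    if isinstance(first, float):
        return {"float_list": [float(v) for v in values]}
    raise TypeError(f"unsupported feature value type: {type(first).__name__}")


def dict_to_example_sketch(record: Dict[str, Any]) -> Dict[str, Dict[str, List[FeatureValue]]]:
    """Convert one record, e.g. {'text': ..., 'label': ...}, feature by feature."""
    return {key: to_feature(value) for key, value in record.items()}


example = dict_to_example_sketch({"text": b"A great film.", "label": 1})
print(example)
# {'text': {'bytes_list': [b'A great film.']}, 'label': {'int64_list': [1]}}
```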
For this demonstration, we use a subset of the IMDb reviews dataset: the first 20% of its training split (`train[:20%]`), which keeps training manageable. You can modify the `custom_config` settings to experiment with more data, up to the full dataset, depending on your computational resources.
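The split string uses TFDS split-slicing syntax. As a rough worked example, the hypothetical helper below computes how many records a `split[:N%]` slice selects. It assumes the published `imdb_reviews` sizes of 25,000 training and 25,000 test reviews and rounds to the nearest record. TFDS's exact rounding rules may differ at boundaries.

```python
import re
from typing import Dict


def percent_slice_size(spec: str, split_sizes: Dict[str, int]) -> int:
    """Records selected by a spec such as 'train[:20%]' (only the '[:N%]' form)."""
    match = re.fullmatch(r"(\w+)\[:(\d+)%\]", spec)
    if match is None:
        raise ValueError(f"unsupported split spec: {spec!r}")
    split, percent = match.group(1), int(match.group(2))
    return round(split_sizes[split] * percent / 100)


# Assumed split sizes of the TFDS imdb_reviews dataset.
IMDB_SPLIT_SIZES = {"train": 25_000, "test": 25_000}
print(percent_slice_size("train[:20%]", IMDB_SPLIT_SIZES))  # 5000
print(percent_slice_size("train[:50%]", IMDB_SPLIT_SIZES))  # 12500
```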
example_gen = FileBasedExampleGen(
input_base='dummy',
custom_config={'dataset':'imdb_reviews', 'split':'train[:20%]'},
custom_executor_spec=executor_spec.BeamExecutorSpec(TFDSExecutor))
context.run(example_gen, enable_cache=False)
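ExampleGen's default output configuration divides the ingested records into a `train` split and an `eval` split with 2 and 1 hash buckets. That is why an `eval` split exists even though only one TFDS slice was read. The minimal sketch below shows hash-bucket splitting. It uses MD5 purely for illustration and is not the hash function TFX uses. Because the assignment depends only on a record's bytes, the split is deterministic across reruns.

```python
import hashlib
from typing import Sequence, Tuple


def assign_split(record: bytes,
                 buckets: Sequence[Tuple[str, int]] = (("train", 2), ("eval", 1))) -> str:
    """Assign a record to a split by hashing it into weighted buckets."""
    total = sum(weight for _, weight in buckets)
    bucket = int.from_bytes(hashlib.md5(record).digest()[:8], "little") % total
    for name, weight in buckets:
        if bucket < weight:
            return name
        bucket -= weight
    raise AssertionError("unreachable: bucket always falls in some split")


records = [f"review {i}".encode() for i in range(3000)]
counts = {"train": 0, "eval": 0}
for r in records:
    counts[assign_split(r)] += 1
print(counts)  # roughly 2000 train / 1000 eval
```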
The following utility helps examine datasets made of tf.train.Example records. For the reviews dataset, it prints each example's `text` and `label` features.
def inspect_examples(component,
                     channel_name='examples',
                     split_name='train',
                     num_examples=1):
    # Get the URI of the output artifact, which is a directory
    full_split_name = 'Split-{}'.format(split_name)
    print('channel_name: {}, split_name: {} (\"{}\"), num_examples: {}\n'.format(
        channel_name, split_name, full_split_name, num_examples))
    train_uri = os.path.join(
        component.outputs[channel_name].get()[0].uri, full_split_name)
    print('train_uri: {}'.format(train_uri))
    # Get the list of files in this directory (all compressed TFRecord files)
    tfrecord_filenames = [os.path.join(train_uri, name)
                          for name in os.listdir(train_uri)]
    # Create a `TFRecordDataset` to read these files
    dataset = tf.data.TFRecordDataset(tfrecord_filenames, compression_type="GZIP")
    # Iterate over the records and print them
    print()
    for tfrecord in dataset.take(num_examples):
        serialized_example = tfrecord.numpy()
        example = tf.train.Example()
        example.ParseFromString(serialized_example)
        pp.pprint(example)
inspect_examples(example_gen, num_examples=1, split_name='eval')
channel_name: examples, split_name: eval ("Split-eval"), num_examples: 1
train_uri: /tmpfs/tmp/tfx-interactive-2024-06-19T10_25_02.997883-vluhbwr6/FileBasedExampleGen/examples/1/Split-eval
features {
feature {
key: "label"
value {
}
}
feature {
key: "text"
value {
bytes_list {
value: "This was an absolutely terrible movie. Don\'t be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie\'s ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor\'s like Christopher Walken\'s good name. I could barely sit through it."
}
}
}
}
StatisticsGen
The StatisticsGen component computes statistics over your dataset for data analysis, such as the number of examples, the number of features, and the data types of the features. It uses the TensorFlow Data Validation library. StatisticsGen takes as input the dataset we just ingested using ExampleGen.
Note that StatisticsGen is designed primarily for tabular data, so the statistics it produces for the free-text reviews in this LLM tutorial are of limited use.
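For text, a quick look at review lengths is often more informative than tabular statistics. For example, it helps judge how much of each review fits into the 128-token `_SEQUENCE_LENGTH` used by the Trainer. Below is a minimal sketch with a hypothetical `text_length_stats` helper. Whitespace word counts are only a rough proxy for GPT-2 tokens, and the sample strings are short fragments of reviews that appear in this notebook's outputs.

```python
from statistics import mean, median
from typing import Dict, List


def text_length_stats(texts: List[str]) -> Dict[str, float]:
    """Summarize text lengths measured in whitespace-separated words."""
    lengths = [len(t.split()) for t in texts]
    return {
        "num_examples": len(lengths),
        "min_words": min(lengths),
        "max_words": max(lengths),
        "mean_words": mean(lengths),
        "median_words": median(lengths),
    }


sample = [
    "This was an absolutely terrible movie.",
    "A family film in every sense.",
    "Wonderful performances from Cher and Nicolas Cage.",
]
print(text_length_stats(sample))
```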
from tfx.components import StatisticsGen
statistics_gen = tfx.components.StatisticsGen(
examples=example_gen.outputs['examples'], exclude_splits=['eval']
)
context.run(statistics_gen, enable_cache=False)
context.show(statistics_gen.outputs['statistics'])
SchemaGen
The SchemaGen component generates a schema based on your data statistics. (A schema defines the expected bounds, types, and properties of the features in your dataset.) It also uses the TensorFlow Data Validation library.
SchemaGen takes as input the statistics that we generated with StatisticsGen. Here, `exclude_splits=['eval']` restricts schema inference to the training split.
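Conceptually, schema inference summarizes what the training data looks like, such as each feature's type and whether it is always present. The toy sketch below illustrates that idea on plain dicts with a hypothetical `infer_schema` helper. TFDV's actual inference also covers value domains, shapes and more.

```python
from typing import Any, Dict, List


def infer_schema(examples: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
    """Infer each feature's type and whether every example contains it."""
    counts: Dict[str, int] = {}
    types: Dict[str, str] = {}
    for example in examples:
        for name, value in example.items():
            counts[name] = counts.get(name, 0) + 1
            types.setdefault(name, type(value).__name__)  # first type seen wins
    return {
        name: {"type": types[name], "required": counts[name] == len(examples)}
        for name in counts
    }


train = [{"text": "Great film.", "label": 1}, {"text": "Dull.", "label": 0}, {"text": "Meh."}]
print(infer_schema(train))
# {'text': {'type': 'str', 'required': True}, 'label': {'type': 'int', 'required': False}}
```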
schema_gen = tfx.components.SchemaGen(
statistics=statistics_gen.outputs['statistics'],
infer_feature_shape=False,
exclude_splits=['eval'],
)
context.run(schema_gen, enable_cache=False)
context.show(schema_gen.outputs['schema'])
ExampleValidator
The ExampleValidator component detects anomalies in your data, based on the expectations defined by the schema. It also uses the TensorFlow Data Validation library.
ExampleValidator will take as input the statistics from StatisticsGen, and the schema from SchemaGen.
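Validation is the reverse step: each example is checked against the schema's expectations. A toy, pure-Python sketch with a hand-written schema follows. `find_anomalies` and the schema layout are hypothetical and only illustrate the kinds of anomalies ExampleValidator reports, such as missing, mistyped or unexpected features.

```python
from typing import Any, Dict, List

SCHEMA: Dict[str, Dict[str, Any]] = {
    "text": {"type": str, "required": True},
    "label": {"type": int, "required": True},
}


def find_anomalies(example: Dict[str, Any],
                   schema: Dict[str, Dict[str, Any]] = SCHEMA) -> List[str]:
    """Return human-readable anomalies for one example against a toy schema."""
    anomalies = []
    for name, spec in schema.items():
        if name not in example:
            if spec["required"]:
                anomalies.append(f"missing required feature '{name}'")
        elif not isinstance(example[name], spec["type"]):
            anomalies.append(
                f"feature '{name}' has type {type(example[name]).__name__}, "
                f"expected {spec['type'].__name__}"
            )
    # Features that the schema does not know about at all.
    for name in sorted(example.keys() - schema.keys()):
        anomalies.append(f"unexpected feature '{name}'")
    return anomalies


print(find_anomalies({"text": "Great film.", "label": 1}))  # []
print(find_anomalies({"text": "Dull.", "label": "0"}))
# ["feature 'label' has type str, expected int"]
```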
example_validator = tfx.components.ExampleValidator(
statistics=statistics_gen.outputs['statistics'],
schema=schema_gen.outputs['schema'],
exclude_splits=['eval'],
)
context.run(example_validator, enable_cache=False)
After ExampleValidator finishes running, we can visualize the anomalies as a table.
context.show(example_validator.outputs['anomalies'])
Transform
For a structured and repeatable design of a TFX pipeline we will need a scalable approach to feature engineering. The Transform component performs feature engineering for both training and serving. It uses the TensorFlow Transform library.
The Transform component uses a module file to supply the user code for the feature engineering we want to do, so our first step is to create that module file. We only keep the review text: the module reads the `text` feature and outputs it as `summary`.
import os
if not os.path.exists("modules"):
    os.mkdir("modules")
_transform_module_file = 'modules/_transform_module.py'
%%writefile {_transform_module_file}
import tensorflow as tf
def _fill_in_missing(x, default_value):
    """Replace missing values in a SparseTensor.
    Fills in missing values of `x` with the default_value.
    Args:
      x: A `SparseTensor` of rank 2. Its dense shape should have size at most 1
        in the second dimension.
      default_value: the value with which to replace the missing values.
    Returns:
      A rank 1 tensor where missing values of `x` have been filled in.
    """
    if not isinstance(x, tf.sparse.SparseTensor):
        return x
    return tf.squeeze(
        tf.sparse.to_dense(
            tf.SparseTensor(x.indices, x.values, [x.dense_shape[0], 1]),
            default_value),
        axis=1)
def preprocessing_fn(inputs):
    """tf.Transform callback: maps raw input features to transformed features."""
    outputs = {}
    # Keep only the review text, renamed from `text` to `summary`.
    outputs["summary"] = _fill_in_missing(inputs["text"], "")
    return outputs
Writing modules/_transform_module.py
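Because SchemaGen ran with `infer_feature_shape=False`, the schema records no fixed shape for `text`, so Transform typically hands it to `preprocessing_fn` as a `SparseTensor`. `_fill_in_missing` turns that into a dense rank-1 tensor. The pure-Python sketch below mirrors the idea on explicit (row, column) indices. The helper name is hypothetical, and this is an illustration rather than TensorFlow code.

```python
from typing import List, Sequence, Tuple


def fill_in_missing_sketch(
    indices: Sequence[Tuple[int, int]],
    values: Sequence[bytes],
    batch_size: int,
    default_value: bytes = b"",
) -> List[bytes]:
    """Densify a rank-2 sparse batch that has at most one value per row.

    Rows without a value get `default_value`; the trailing size-1 dimension
    is dropped, like the tf.squeeze(..., axis=1) call in _fill_in_missing.
    """
    dense = [default_value] * batch_size
    for (row, col), value in zip(indices, values):
        if col != 0:
            raise ValueError("expected at most one value per row")
        dense[row] = value
    return dense


# A batch of three examples where the second one has no `text` value.
print(fill_in_missing_sketch([(0, 0), (2, 0)], [b"Great film.", b"Meh."], batch_size=3))
# [b'Great film.', b'', b'Meh.']
```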
preprocessor = tfx.components.Transform(
examples=example_gen.outputs['examples'],
schema=schema_gen.outputs['schema'],
module_file=os.path.abspath(_transform_module_file))
context.run(preprocessor, enable_cache=False)
Let's take a look at some of the transformed examples and check that they are indeed processed as intended.
def pprint_examples(artifact, n_examples=2):
    print("artifact:", artifact, "\n")
    uri = os.path.join(artifact.uri, "Split-eval")
    print("uri:", uri, "\n")
    tfrecord_filenames = [os.path.join(uri, name) for name in os.listdir(uri)]
    print("tfrecord_filenames:", tfrecord_filenames, "\n")
    dataset = tf.data.TFRecordDataset(tfrecord_filenames, compression_type="GZIP")
    for tfrecord in dataset.take(n_examples):
        serialized_example = tfrecord.numpy()
        example = tf.train.Example.FromString(serialized_example)
        pp.pprint(example)
pprint_examples(preprocessor.outputs['transformed_examples'].get()[0])
artifact: Artifact(artifact: id: 6
type_id: 14
uri: "/tmpfs/tmp/tfx-interactive-2024-06-19T10_25_02.997883-vluhbwr6/Transform/transformed_examples/5"
properties {
key: "split_names"
value {
string_value: "[\"train\", \"eval\"]"
}
}
custom_properties {
key: "name"
value {
string_value: "transformed_examples:2024-06-19T10:25:14.991872"
}
}
custom_properties {
key: "producer_component"
value {
string_value: "Transform"
}
}
custom_properties {
key: "tfx_version"
value {
string_value: "1.15.1"
}
}
state: LIVE
name: "transformed_examples:2024-06-19T10:25:14.991872"
, artifact_type: id: 14
name: "Examples"
properties {
key: "span"
value: INT
}
properties {
key: "split_names"
value: STRING
}
properties {
key: "version"
value: INT
}
base_type: DATASET
)
uri: /tmpfs/tmp/tfx-interactive-2024-06-19T10_25_02.997883-vluhbwr6/Transform/transformed_examples/5/Split-eval
tfrecord_filenames: ['/tmpfs/tmp/tfx-interactive-2024-06-19T10_25_02.997883-vluhbwr6/Transform/transformed_examples/5/Split-eval/transformed_examples-00000-of-00001.gz']
features {
feature {
key: "summary"
value {
bytes_list {
value: "This was an absolutely terrible movie. Don\'t be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie\'s ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor\'s like Christopher Walken\'s good name. I could barely sit through it."
}
}
}
}
features {
feature {
key: "summary"
value {
bytes_list {
value: "This is the kind of film for a snowy Sunday afternoon when the rest of the world can go ahead with its own business as you descend into a big arm-chair and mellow for a couple of hours. Wonderful performances from Cher and Nicolas Cage (as always) gently row the plot along. There are no rapids to cross, no dangerous waters, just a warm and witty paddle through New York life at its best. A family film in every sense and one that deserves the praise it received."
}
}
}
}
Trainer
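The Trainer component trains an ML model and requires user-supplied model definition code.
The run_fn function is the entry point that TFX's Trainer component calls to train a model. It is a user-supplied function that receives its inputs through an FnArgs object (for example, the training file patterns and the serving model directory). It returns nothing; instead, it saves the trained model to the serving model directory, which TFX records as the output Model artifact.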
The run_fn function is responsible for:
- Building the machine learning model.
- Training the model on the training data.
- Saving the trained model to the serving model directory.
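The contract above can be sketched without TFX: a function receives an arguments object, does its training, writes the result into `serving_model_dir`, and returns nothing. `FakeFnArgs` and `toy_run_fn` are hypothetical stand-ins that carry only the two FnArgs fields this tutorial uses. The "training" is a placeholder.

```python
import os
import tempfile
from dataclasses import dataclass
from typing import List


@dataclass
class FakeFnArgs:
    """Hypothetical stand-in for the two FnArgs fields used in this tutorial."""
    train_files: List[str]
    serving_model_dir: str


def toy_run_fn(fn_args: FakeFnArgs) -> None:
    """Toy run_fn: 'train', then save into serving_model_dir; returns nothing."""
    # Placeholder for building and training a model on fn_args.train_files.
    summary = f"trained on {len(fn_args.train_files)} file pattern(s)\n"
    # The saved directory, not a return value, is what becomes the Model output.
    os.makedirs(fn_args.serving_model_dir, exist_ok=True)
    with open(os.path.join(fn_args.serving_model_dir, "model.txt"), "w") as f:
        f.write(summary)


with tempfile.TemporaryDirectory() as tmp:
    args = FakeFnArgs(
        train_files=[os.path.join(tmp, "Split-train", "*")],
        serving_model_dir=os.path.join(tmp, "serving_model"),
    )
    toy_run_fn(args)
    print(os.listdir(args.serving_model_dir))  # ['model.txt']
```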
Write model training code
We will fine-tune the pre-trained GPT-2 model together with its GPT-2 preprocessor. First, we need to create a module that contains the run_fn function, because the TFX Trainer expects run_fn to be defined in a module.
model_file = "modules/model.py"
model_fn = "modules.model.run_fn"
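Because `model_fn` is passed to the Trainer as a dotted string, the module has to be imported and `run_fn` looked up at run time. The minimal sketch below shows that general Python mechanism with `importlib`. `resolve_dotted_path` is a hypothetical helper, not TFX's loader. Resolving `modules.model.run_fn` this way would also require the notebook's working directory to be on `sys.path`.

```python
import importlib
from typing import Any


def resolve_dotted_path(path: str) -> Any:
    """Import 'package.module.attr' and return attr (for example, a run_fn)."""
    module_name, _, attr_name = path.rpartition(".")
    if not module_name:
        raise ValueError(f"expected 'module.attribute', got {path!r}")
    module = importlib.import_module(module_name)
    return getattr(module, attr_name)


dumps = resolve_dotted_path("json.dumps")
print(dumps({"split": "train"}))  # {"split": "train"}
```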
Now, we write the run_fn function.
This run_fn function first reads the training articles from the file pattern in fn_args.train_files. Next, it loads the pre-trained GPT-2 model along with its preprocessor. The model is then fine-tuned on the training data using the model.fit() method.
Finally, the trained model weights are saved to the directory given by fn_args.serving_model_dir.
Now we are going to work with KerasNLP's GPT-2 model. You can learn about the full GPT-2 model implementation in KerasNLP on GitHub, or read about and interactively test the model in the Google I/O 2023 Colab notebook.
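GPT2CausalLMPreprocessor turns each review into next-token-prediction training data. Token ids are packed to `sequence_length + 1`, the inputs and labels are the same sequence offset by one position, and padded positions get zero sample weight. That is why run_fn uses a sparse categorical cross-entropy loss over token ids and `weighted_metrics`. Below is a pure-Python sketch of the shifting step with made-up token ids and a hypothetical pad id of 0. Real ids come from the GPT-2 tokenizer, which also adds special start and end tokens.

```python
from typing import Dict, List

PAD_ID = 0  # hypothetical padding id for this toy example


def causal_lm_example(token_ids: List[int], sequence_length: int) -> Dict[str, List[int]]:
    """Build inputs, labels and sample weights for next-token prediction."""
    # Truncate/pad to sequence_length + 1 so that the shift yields sequence_length.
    packed = token_ids[: sequence_length + 1]
    mask = [1] * len(packed) + [0] * (sequence_length + 1 - len(packed))
    packed = packed + [PAD_ID] * (sequence_length + 1 - len(packed))
    return {
        "inputs": packed[:-1],       # what the model sees
        "labels": packed[1:],        # the token that follows each input position
        "sample_weight": mask[1:],   # ignore positions whose label is padding
    }


print(causal_lm_example([11, 12, 13, 14], sequence_length=5))
# {'inputs': [11, 12, 13, 14, 0], 'labels': [12, 13, 14, 0, 0], 'sample_weight': [1, 1, 1, 0, 0]}
```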
import keras_nlp
import keras
import tensorflow as tf
%%writefile {model_file}
import os
import time
from absl import logging
import keras_nlp
import more_itertools
import pandas as pd
import tensorflow as tf
import keras
import tfx
import tfx.components.trainer.fn_args_utils
import gc
_EPOCH = 1
_BATCH_SIZE = 20
_INITIAL_LEARNING_RATE = 5e-5
_END_LEARNING_RATE = 0.0
_SEQUENCE_LENGTH = 128  # shorter than the preprocessor default, to save GPU memory
def _input_fn(file_pattern: str) -> list:
    """Retrieves training data and returns a list of articles for training.
    For each row in the TFRecordDataset, generated in the previous ExampleGen
    component, create a new tf.train.Example object and parse the TFRecord into
    the example object. Articles, which are initially in bytes objects, are
    decoded into a string.
    Args:
      file_pattern: Glob pattern (e.g. '.../Split-train/*') matching the
        TFRecord files of the training dataset.
    Returns:
      A list of training articles.
    Raises:
      FileNotFoundError: If no TFRecord file matches file_pattern.
    """
    file_paths = tf.io.gfile.glob(file_pattern)
    if not file_paths:
        raise FileNotFoundError(f"No files match the pattern: '{file_pattern}'.")
    train_articles = []
    parsed_dataset = tf.data.TFRecordDataset(file_paths, compression_type="GZIP")
    for raw_record in parsed_dataset:
        example = tf.train.Example()
        example.ParseFromString(raw_record.numpy())
        train_articles.append(
            example.features.feature["summary"].bytes_list.value[0].decode('utf-8')
        )
    return train_articles
def run_fn(fn_args: tfx.components.trainer.fn_args_utils.FnArgs) -> None:
    """Trains the model and saves it to the location given by fn_args.
    Args:
      fn_args: Arguments passed to the user-defined training function.
    """
    train_articles = pd.Series(_input_fn(
        fn_args.train_files[0],
    ))
    tf_train_ds = tf.data.Dataset.from_tensor_slices(train_articles)
    gpt2_preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
        'gpt2_base_en',
        sequence_length=_SEQUENCE_LENGTH,
        add_end_token=True,
    )
    gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset(
        'gpt2_base_en', preprocessor=gpt2_preprocessor
    )
    processed_ds = (
        tf_train_ds
        .batch(_BATCH_SIZE)
        .cache()
        .prefetch(tf.data.AUTOTUNE)
    )
    gpt2_lm.include_preprocessing = False
    # Decay linearly from the initial to the end learning rate over all steps.
    lr = keras.optimizers.schedules.PolynomialDecay(
        _INITIAL_LEARNING_RATE,
        decay_steps=int(processed_ds.cardinality().numpy()) * _EPOCH,
        end_learning_rate=_END_LEARNING_RATE,
    )
    loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    gpt2_lm.compile(
        optimizer=keras.optimizers.Adam(lr),
        loss=loss,
        weighted_metrics=['accuracy'],
    )
    gpt2_lm.fit(processed_ds, epochs=_EPOCH)
    # Start from an empty serving directory, then save only the weights.
    if tf.io.gfile.exists(fn_args.serving_model_dir):
        tf.io.gfile.rmtree(fn_args.serving_model_dir)
    tf.io.gfile.makedirs(fn_args.serving_model_dir)
    gpt2_lm.save_weights(
        filepath=os.path.join(fn_args.serving_model_dir, "model_weights.weights.h5")
    )
    # Free model and dataset memory before the next component runs.
    del gpt2_lm, gpt2_preprocessor, processed_ds, tf_train_ds
    gc.collect()
Writing modules/model.py
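In run_fn, the learning rate decays from 5e-5 to 0.0 over one pass through the batched dataset. With its default `power=1.0` and no cycling, Keras's `PolynomialDecay` reduces to a linear ramp: lr(step) = (lr0 - lr_end) * (1 - min(step, decay_steps) / decay_steps) ** power + lr_end. The plain-Python sketch below computes the same formula. The number of decay steps is the number of batches, rounded up, because `Dataset.batch` keeps the last partial batch by default. The 3,300-example training set used here is a hypothetical round number.

```python
import math


def polynomial_decay(step: int, initial_lr: float, decay_steps: int,
                     end_lr: float = 0.0, power: float = 1.0) -> float:
    """Learning rate at `step` for a non-cycling polynomial decay schedule."""
    step = min(step, decay_steps)
    fraction_left = 1.0 - step / decay_steps
    return (initial_lr - end_lr) * fraction_left ** power + end_lr


def steps_per_epoch(num_examples: int, batch_size: int) -> int:
    """Number of batches per epoch when the final partial batch is kept."""
    return math.ceil(num_examples / batch_size)


# Hypothetical training-set size; the real count depends on the ExampleGen split.
decay_steps = steps_per_epoch(num_examples=3_300, batch_size=20) * 1  # one epoch
print(decay_steps)                                       # 165
print(polynomial_decay(0, 5e-5, decay_steps))            # 5e-05
print(polynomial_decay(decay_steps, 5e-5, decay_steps))  # 0.0
```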
trainer = tfx.components.Trainer(
run_fn=model_fn,
examples=preprocessor.outputs['transformed_examples'],
train_args=tfx.proto.TrainArgs(splits=['train']),
eval_args=tfx.proto.EvalArgs(splits=['train']),
schema=schema_gen.outputs['schema'],
)
context.run(trainer, enable_cache=False)
WARNING:absl:Examples artifact does not have payload_format custom property. Falling back to FORMAT_TF_EXAMPLE