Using Other ML Frameworks in TFX
TFX as a platform is framework-neutral, and can be used with other ML
frameworks, e.g., JAX and scikit-learn.
For model developers, this means they do not need to rewrite model code
implemented in another ML framework, but can instead reuse the bulk of their
training code as-is in TFX, and benefit from the other capabilities TFX and
the rest of the TensorFlow ecosystem offer.
The TFX pipeline SDK and most modules in TFX, e.g., pipeline orchestrator,
don't have any direct dependency on TensorFlow, but there are some aspects
which are oriented towards TensorFlow, such as data formats. With some
consideration of the needs of a particular modeling framework, a TFX pipeline
can be used to train models in any other Python-based ML framework. This includes
Scikit-learn, XGBoost, and PyTorch, among others. Some of the considerations for
using the standard TFX components with other frameworks include:
- **ExampleGen** outputs
  [tf.train.Example](https://www.tensorflow.org/tutorials/load_data/tfrecord)
  in TFRecord files. It's a generic representation for training data, and
  downstream components use
  [TFXIO](https://github.com/tensorflow/community/blob/master/rfcs/20191017-tfx-standardized-inputs.md)
  to read it as Arrow/RecordBatch in memory, which can be further converted to
  `tf.data.Dataset`, `Tensors`, or other formats. Payload/file formats other
  than tf.train.Example/TFRecord are being considered, but for TFXIO users
  they should be a black box.
- **Transform** can be used to generate transformed training examples no
  matter what framework is used for training, but if the model format is not
  `saved_model`, users won't be able to embed the transform graph into the
  model. In that case, model prediction needs to take transformed features
  instead of raw features, and users can run the transform as a preprocessing
  step before calling model prediction when serving.
- **Trainer** supports
  [GenericTraining](https://www.tensorflow.org/tfx/guide/trainer#generic_trainer)
  so users can train their models using any ML framework.
- **Evaluator** by default only supports `saved_model`, but users can provide
  a UDF that generates predictions for model evaluation.
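As a sketch of the first point above, the helper below parses `tf.train.Example` records from a TFRecord file into NumPy arrays that any Python ML framework can consume. The function name, feature spec, and GZIP compression setting are illustrative assumptions, not part of the TFX API:

```python
import numpy as np
import tensorflow as tf


def load_examples_as_numpy(tfrecord_path, feature_spec):
    """Parse tf.train.Example records into NumPy arrays for non-TF frameworks.

    `feature_spec` maps feature names to tf.io.FixedLenFeature specs,
    mirroring what a schema-derived spec would look like in a real pipeline.
    """
    # TFX's ExampleGen typically writes GZIP-compressed TFRecords; adjust
    # compression_type if your files are uncompressed.
    dataset = tf.data.TFRecordDataset(tfrecord_path, compression_type="GZIP")
    parsed = dataset.map(
        lambda raw: tf.io.parse_single_example(raw, feature_spec))

    # Materialize every record as NumPy, grouped per feature name.
    features = {name: [] for name in feature_spec}
    for record in parsed:
        for name, tensor in record.items():
            features[name].append(tensor.numpy())
    return {name: np.stack(values) for name, values in features.items()}
```

The resulting dict of arrays can be fed directly to, for example, a scikit-learn or XGBoost estimator.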
Training a model in a non-Python-based framework requires isolating the
custom training component in a Docker container, as part of a pipeline
running in a containerized environment such as Kubernetes.
JAX
---

[JAX](https://github.com/jax-ml/jax) is Autograd and XLA, brought together
for high-performance machine learning research.
[Flax](https://github.com/google/flax) is a neural network library and
ecosystem for JAX, designed for flexibility.

With [jax2tf](https://github.com/jax-ml/jax/tree/main/jax/experimental/jax2tf),
trained JAX/Flax models can be converted into `saved_model` format, which can
be used seamlessly in TFX with generic training and model evaluation. For
details, check this
[example](https://github.com/tensorflow/tfx/blob/master/tfx/examples/penguin/penguin_utils_flax_experimental.py).
scikit-learn
------------

[Scikit-learn](https://scikit-learn.org/stable/) is a machine learning
library for the Python programming language. We have an e2e
[example](https://github.com/tensorflow/tfx-addons/tree/main/examples/sklearn_penguins)
with customized training and evaluation in TFX-Addons.
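For orientation, the framework-specific core of a scikit-learn trainer (as in the TFX-Addons example) boils down to fitting the model and pickling it to the serving directory. The helper below is a simplified, hypothetical stand-in that assumes features and labels have already been materialized as NumPy arrays; in a real TFX Trainer, `run_fn(fn_args)` would first read them from `fn_args.train_files`:

```python
import os
import pickle

import numpy as np
from sklearn.linear_model import LogisticRegression


def run_fn_core(features: np.ndarray, labels: np.ndarray,
                serving_model_dir: str):
    """Train a scikit-learn model and persist it with pickle.

    Only the framework-specific part of a custom run_fn is shown here;
    reading and parsing the training data is covered by the TFRecord
    loading sketch earlier on this page.
    """
    model = LogisticRegression(max_iter=200)
    model.fit(features, labels)

    # Serve the pickled model with a custom serving layer, or run Evaluator
    # with a UDF that loads it and generates predictions.
    os.makedirs(serving_model_dir, exist_ok=True)
    with open(os.path.join(serving_model_dir, "model.pkl"), "wb") as f:
        pickle.dump(model, f)
    return model
```

Because the artifact is not a `saved_model`, evaluation requires the Evaluator UDF route described above.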
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2024-12-09 UTC.