import tensorflow as tf
TensorFlow 2.x includes significant changes to the tf.summary API used to write summary data for visualization in TensorBoard.
What's changed
It's useful to think of the tf.summary API as two sub-APIs:
- A set of ops for recording individual summaries - summary.scalar(), summary.histogram(), summary.image(), summary.audio(), and summary.text() - which are called inline from your model code.
- Writing logic that collects these individual summaries and writes them to a specially formatted log file (which TensorBoard then reads to generate visualizations).
In TF 1.x
The two halves had to be manually wired together by fetching the summary op outputs via Session.run() and calling FileWriter.add_summary(output, step). The v1.summary.merge_all() op made this easier by using a graph collection to aggregate all summary op outputs, but this approach still worked poorly for eager execution and control flow, making it especially ill-suited for TF 2.x.
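For comparison, the manual wiring described above might look roughly like the following when run through the tf.compat.v1 endpoints. This is a minimal sketch rather than code from the original guide: the temporary log directory, the placeholder input, and the three-step loop are assumptions made for illustration.

```python
import os
import tempfile

import tensorflow as tf

logdir = tempfile.mkdtemp()  # throwaway log directory (assumption for this sketch)

g = tf.compat.v1.Graph()
with g.as_default():
    value = tf.compat.v1.placeholder(tf.float32, shape=[])
    # First half: a summary op that only produces a serialized summary.
    tf.compat.v1.summary.scalar("my_metric", value)
    # merge_all() aggregates every summary op registered in the graph collection.
    merged = tf.compat.v1.summary.merge_all()

with tf.compat.v1.Session(graph=g) as sess:
    # Second half: the writer, connected to the op outputs by hand.
    file_writer = tf.compat.v1.summary.FileWriter(logdir)
    for step in range(3):
        output = sess.run(merged, feed_dict={value: 0.5})
        file_writer.add_summary(output, step)
    file_writer.close()

event_files = os.listdir(logdir)
```

Note that nothing reaches disk unless the fetched output is explicitly handed to add_summary(), which is the coupling that TF 2.x removes.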
In TF 2.x
The two halves are tightly integrated, and now individual tf.summary ops write their data immediately when executed. Using the API from your model code should still look familiar, but it's now friendly to eager execution while remaining graph-mode compatible. Integrating both halves of the API means the summary.FileWriter is now part of the TensorFlow execution context and gets accessed directly by tf.summary ops, so configuring writers is the main part that looks different.
Example usage with eager execution, the default in TF 2.x:
writer = tf.summary.create_file_writer("/tmp/mylogs/eager")
with writer.as_default():
    for step in range(100):
        # other model code would go here
        tf.summary.scalar("my_metric", 0.5, step=step)
        writer.flush()
ls /tmp/mylogs/eager
events.out.tfevents.1750936029.kokoro-gcp-ubuntu-prod-1050903991.7254.0.v2
Example usage with tf.function graph execution:
writer = tf.summary.create_file_writer("/tmp/mylogs/tf_function")
@tf.function
def my_func(step):
    with writer.as_default():
        # other model code would go here
        tf.summary.scalar("my_metric", 0.5, step=step)

for step in tf.range(100, dtype=tf.int64):
    my_func(step)
    writer.flush()
ls /tmp/mylogs/tf_function
events.out.tfevents.1750936030.kokoro-gcp-ubuntu-prod-1050903991.7254.1.v2
Example usage with legacy TF 1.x graph execution:
g = tf.compat.v1.Graph()
with g.as_default():
    step = tf.Variable(0, dtype=tf.int64)
    step_update = step.assign_add(1)
    writer = tf.summary.create_file_writer("/tmp/mylogs/session")
    with writer.as_default():
        tf.summary.scalar("my_metric", 0.5, step=step)
    all_summary_ops = tf.compat.v1.summary.all_v2_summary_ops()
    writer_flush = writer.flush()

with tf.compat.v1.Session(graph=g) as sess:
    sess.run([writer.init(), step.initializer])

    for i in range(100):
        sess.run(all_summary_ops)
        sess.run(step_update)
        sess.run(writer_flush)
ls /tmp/mylogs/session
events.out.tfevents.1750936030.kokoro-gcp-ubuntu-prod-1050903991.7254.2.v2
Converting your code
Converting existing tf.summary usage to the TF 2.x API cannot be reliably automated, so the tf_upgrade_v2 script just rewrites it all to tf.compat.v1.summary and will not enable the TF 2.x behaviors automatically.
Partial Migration
To make migration to TF 2.x easier for users of model code that still depends heavily on TF 1.x summary logging ops like tf.compat.v1.summary.scalar(), it is possible to migrate only the writer APIs first, leaving the individual TF 1.x summary ops inside your model code to be fully migrated at a later point.
To support this style of migration, tf.compat.v1.summary ops will automatically forward to their TF 2.x equivalents under the following conditions:
- The outermost context is eager mode
- A default TF 2.x summary writer has been set
- A non-empty value for step has been set for the writer (using tf.summary.SummaryWriter.as_default, tf.summary.experimental.set_step, or alternatively tf.compat.v1.train.create_global_step)
Note that when the TF 2.x summary implementation is invoked, the return value will be an empty bytestring tensor, to avoid duplicate summary writing. Additionally, the input argument forwarding is best-effort and not all arguments will be preserved (for instance, the family argument will be supported whereas collections will be removed).
Example to invoke tf.summary.scalar behaviors in tf.compat.v1.summary.scalar:
# Enable eager execution.
tf.compat.v1.enable_v2_behavior()
# A default TF 2.x summary writer is available.
writer = tf.summary.create_file_writer("/tmp/mylogs/enable_v2_in_v1")
# A step is set for the writer.
with writer.as_default(step=0):
    # Below invokes `tf.summary.scalar`, and the return value is an empty bytestring.
    tf.compat.v1.summary.scalar('float', tf.constant(1.0), family="family")
Full Migration
To fully migrate to TF 2.x, you'll need to adapt your code as follows:
A default writer set via .as_default() must be present to use summary ops
- This means executing ops eagerly or using ops in graph construction
- Without a default writer, summary ops become silent no-ops
- Default writers do not (yet) propagate across the @tf.function execution boundary - they are only detected when the function is traced - so best practice is to call writer.as_default() within the function body, and to ensure that the writer object continues to exist as long as the @tf.function is being used
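The points about default writers can be sketched as follows. The function name train_step and the temporary log directory are hypothetical, chosen only for illustration; the sketch assumes eager execution is enabled and that no default writer has been set earlier in the process.

```python
import tempfile

import tensorflow as tf

# Without a default writer, the op is a silent no-op and reports False.
wrote_without_writer = bool(tf.summary.scalar("my_metric", 0.5, step=0))

writer = tf.summary.create_file_writer(tempfile.mkdtemp())


@tf.function
def train_step(step):
    # Entering the writer inside the function body means it is picked up
    # when the function is traced, whatever the default is at the call site.
    with writer.as_default():
        return tf.summary.scalar("my_metric", 0.5, step=step)


# `writer` stays referenced at module level for as long as train_step is used.
wrote_in_function = bool(train_step(tf.constant(0, dtype=tf.int64)))
writer.flush()
```

Checking the boolean return value, as done here, is one way to notice a missing default writer that would otherwise go undetected.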
The "step" value must be passed into each op via the step argument
- TensorBoard requires a step value to render the data as a time series
- Explicit passing is necessary because the global step from TF 1.x has been removed, so each op must know the desired step variable to read
- To reduce boilerplate, experimental support for registering a default step value is available as tf.summary.experimental.set_step(), but this is provisional functionality that may be changed without notice
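As a rough illustration of the default-step option, the sketch below registers the step once per iteration and omits the step argument on the individual ops. The metric names, values, and temporary log directory are made up for the example, and since the API is experimental its behavior may differ between releases.

```python
import tempfile

import tensorflow as tf

writer = tf.summary.create_file_writer(tempfile.mkdtemp())

with writer.as_default():
    for step in range(3):
        # Register the default step once per iteration...
        tf.summary.experimental.set_step(step)
        # ...so the individual ops can omit the `step` argument.
        tf.summary.scalar("loss", 0.5)
        tf.summary.scalar("accuracy", 0.9)
    writer.flush()

last_step = tf.summary.experimental.get_step()
# Clear the default step so later code has to pass `step` explicitly again.
tf.summary.experimental.set_step(None)
```

Resetting the step to None at the end is a defensive choice for this sketch, since the registered default would otherwise persist for later summary calls.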
Function signatures of individual summary ops have changed
- Return value is now a boolean (indicating if a summary was actually written)
- The second parameter name (if used) has changed from tensor to data
- The collections parameter has been removed; collections are TF 1.x only
- The family parameter has been removed; just use tf.name_scope()
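A short sketch of the new signature in use, assuming eager execution: the value is passed as data, and tf.name_scope() takes over the grouping role of family. The scope name "losses" and the temporary log directory are assumptions for illustration; the tags are read back with tf.compat.v1.train.summary_iterator only to show the effect of the name scope.

```python
import glob
import os
import tempfile

import tensorflow as tf

logdir = tempfile.mkdtemp()
writer = tf.summary.create_file_writer(logdir)

with writer.as_default():
    # tf.name_scope() prefixes the tag, replacing the old `family` argument.
    with tf.name_scope("losses"):
        # The value is passed as `data` (TF 1.x named this parameter `tensor`).
        wrote = tf.summary.scalar("my_metric", data=0.5, step=0)
    writer.flush()

# Collect the tags that ended up in the event file.
tags = []
for path in glob.glob(os.path.join(logdir, "events.out.tfevents.*")):
    for event in tf.compat.v1.train.summary_iterator(path):
        tags.extend(value.tag for value in event.summary.value)
```

Here the returned value wrote is a boolean tensor, and the recorded tag carries the active name scope as a prefix.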
[Only for legacy graph mode / session execution users]
- First initialize the writer with v1.Session.run(writer.init())
- Use v1.summary.all_v2_summary_ops() to get all TF 2.x summary ops for the current graph, e.g. to execute them via Session.run()
- Flush the writer with v1.Session.run(writer.flush()) and likewise for close()
If your TF 1.x code was instead using the tf.contrib.summary API, it's much more similar to the TF 2.x API, so the tf_upgrade_v2 script will automate most of the migration steps (and emit warnings or errors for any usage that cannot be fully migrated). For the most part it just rewrites the API calls to tf.compat.v2.summary; if you only need compatibility with TF 2.x you can drop the compat.v2 and just reference it as tf.summary.
Additional tips
In addition to the critical areas above, some auxiliary aspects have also changed:
Conditional recording (like "log every 100 steps") has a new look
- To control ops and associated code, wrap them in a regular if statement (which works in eager mode and in @tf.function via autograph) or a tf.cond
- To control just summaries, use the new tf.summary.record_if() context manager, and pass it the boolean condition of your choosing
- These replace the TF 1.x pattern: if condition: writer.add_summary()
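The "log every 100 steps" case could be written with tf.summary.record_if() along these lines. This is a sketch under the assumption of eager execution; the 300-step loop and the temporary log directory are illustrative only.

```python
import tempfile

import tensorflow as tf

writer = tf.summary.create_file_writer(tempfile.mkdtemp())
recorded_steps = []

with writer.as_default():
    for step in range(300):
        # Summary ops inside this block only record when the condition holds.
        with tf.summary.record_if(step % 100 == 0):
            wrote = tf.summary.scalar("my_metric", 0.5, step=step)
        if bool(wrote):
            recorded_steps.append(step)
    writer.flush()
```

Only the summary write is skipped on the other steps; any model code placed alongside it would still run, which is the difference from wrapping everything in an if statement.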
No direct writing of tf.compat.v1.Graph - instead use trace functions
- Graph execution in TF 2.x uses @tf.function instead of the explicit Graph
- In TF 2.x, use the new tracing-style APIs tf.summary.trace_on() and tf.summary.trace_export() to record executed function graphs
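A minimal sketch of the tracing-style APIs, assuming a trivial @tf.function: the function my_func, the trace name, and the temporary log directory are hypothetical, and profiling is left at its default (off).

```python
import os
import tempfile

import tensorflow as tf

logdir = tempfile.mkdtemp()
writer = tf.summary.create_file_writer(logdir)


@tf.function
def my_func(x):
    return x * 2.0


# Start recording graphs of functions that are traced from here on.
tf.summary.trace_on(graph=True)
result = my_func(tf.constant(3.0))  # the first call traces the function

with writer.as_default():
    # Write the recorded graph to the event file and stop tracing.
    tf.summary.trace_export(name="my_func_trace", step=0)
    writer.flush()

event_files = os.listdir(logdir)
```

The function needs to be called between trace_on() and trace_export() for its graph to be captured.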
No more global writer caching per logdir with tf.summary.FileWriterCache
- Users should either implement their own caching/sharing of writer objects, or just use separate writers (TensorBoard support for the latter is in progress)
The event file binary representation has changed
- TensorBoard 1.x already supports the new format; this difference only affects users who are manually parsing summary data from event files
- Summary data is now stored as tensor bytes; you can use tf.make_ndarray(event.summary.value[0].tensor) to convert it to numpy
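For users who do parse event files manually, reading the tensor-encoded data back might look like the sketch below. It writes a single scalar to a temporary directory (an assumption for the example) and reads it with tf.compat.v1.train.summary_iterator; events without summary values, such as the initial file-version record, are skipped naturally by the inner loop.

```python
import glob
import os
import tempfile

import tensorflow as tf

logdir = tempfile.mkdtemp()
writer = tf.summary.create_file_writer(logdir)
with writer.as_default():
    tf.summary.scalar("my_metric", 0.5, step=0)
    writer.flush()

values = []
for path in glob.glob(os.path.join(logdir, "events.out.tfevents.*")):
    for event in tf.compat.v1.train.summary_iterator(path):
        for value in event.summary.value:
            # The payload is stored in `value.tensor` as a TensorProto.
            array = tf.make_ndarray(value.tensor)
            values.append((value.tag, event.step, array))
```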