tf_agents.bandits.multi_objective.multi_objective_scalarizer.LinearScalarizer
Scalarizes multiple objectives by a linear combination.
Inherits From: Scalarizer
tf_agents.bandits.multi_objective.multi_objective_scalarizer.LinearScalarizer(
weights: Sequence[tf_agents.bandits.multi_objective.multi_objective_scalarizer.ScalarFloat],
multi_objective_transform: Optional[Callable[[tf.Tensor], tf.Tensor]] = None
)
Args |
weights
|
A Sequence of weights for linearly combining the objectives.
|
multi_objective_transform
|
An optional Callable that takes in a
tf.Tensor of multiple objective values and applies an arbitrary
transform that returns a tf.Tensor of transformed multiple objectives.
This transform is applied before the linear scalarization. The transform
should apply to each objective so that the multi-objective tensor and the
transformed tensor have the same shape. This is verified in
_validate_scalarization_parameter_shape via __call__.
|
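The computation described by these arguments can be illustrated with a minimal NumPy sketch. This is not the TF-Agents implementation: `linear_scalarize` and its parameter names are hypothetical and only mirror the documented behavior, where an optional shape-preserving transform is applied first and the objectives are then combined with the weights.

```python
import numpy as np


def linear_scalarize(multi_objectives, weights, transform=None):
    """Hypothetical stand-in for the documented behavior, not library code."""
    x = np.asarray(multi_objectives, dtype=np.float64)
    if transform is not None:
        transformed = np.asarray(transform(x), dtype=np.float64)
        # The transform must keep the [batch_size, num_objectives] shape.
        if transformed.shape != x.shape:
            raise ValueError("transform changed the shape of the objectives")
        x = transformed
    w = np.asarray(weights, dtype=np.float64)
    # Weighted sum over the objective axis: one scalar reward per batch row.
    return x @ w


objectives = [[1.0, 2.0], [3.0, 4.0]]

# Equal weights, no transform.
rewards = linear_scalarize(objectives, weights=[0.5, 0.5])

# Same weights, but each objective is capped at 2.0 before combining.
clipped = linear_scalarize(
    objectives, weights=[0.5, 0.5], transform=lambda t: np.minimum(t, 2.0)
)
```

With the real class, the assumption is that the same roles are played by the `weights` and `multi_objective_transform` constructor arguments, operating on `tf.Tensor` values instead of NumPy arrays.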
Methods
set_parameters
set_parameters(
weights: tf.Tensor
)
Set the scalarization parameter of the LinearScalarizer.
Args |
weights
|
A rank-2 tf.Tensor of weights shaped as [batch_size,
self._num_of_objectives], where batch_size should match the batch size
of the multi_objectives passed to the scalarizer call.
|
Raises |
ValueError
|
if the weights tensor is not rank-2, or has a last dimension
size that does not match self._num_of_objectives.
|
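As a rough illustration of per-example weights, the sketch below uses NumPy rather than TensorFlow. `validate_weights` is a hypothetical helper that mirrors the two documented ValueError conditions; it is not part of TF-Agents. Each batch row gets its own weight vector, so the scalarization becomes a row-wise weighted sum.

```python
import numpy as np


def validate_weights(weights, num_of_objectives):
    """Hypothetical check mirroring the documented ValueError conditions."""
    w = np.asarray(weights, dtype=np.float64)
    if w.ndim != 2:
        raise ValueError(f"weights must be rank-2, got rank {w.ndim}")
    if w.shape[1] != num_of_objectives:
        raise ValueError(
            f"expected {num_of_objectives} weights per row, got {w.shape[1]}"
        )
    return w


# One weight vector per batch element: shape [batch_size, num_objectives].
batch_weights = validate_weights([[1.0, 0.0], [0.25, 0.75]], num_of_objectives=2)
multi_objectives = np.array([[10.0, 20.0], [10.0, 20.0]])

# Row-wise weighted sum: each example is scalarized with its own weights.
per_example = np.sum(batch_weights * multi_objectives, axis=1)
```

Here the two rows share the same objective values but receive different rewards because their weights differ.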
__call__
__call__(
multi_objectives: tf.Tensor
) -> tf.Tensor
Returns a single reward by scalarizing multiple objectives.
Args |
multi_objectives
|
A Tensor of shape [batch_size, number_of_objectives],
where each column represents an objective.
|
Returns |
A Tensor of shape [batch_size] representing scalarized rewards.
|
Raises |
ValueError
|
if multi_objectives.shape.rank != 2.
|
ValueError
|
if multi_objectives.shape.dims[1] != self._num_of_objectives.
|
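The input contract for the call can be sketched in NumPy as follows. `check_multi_objectives` and `raises_value_error` are hypothetical helpers written for this illustration, under the assumption that the real method rejects inputs in the two documented cases; they are not TF-Agents functions.

```python
import numpy as np


def check_multi_objectives(multi_objectives, num_of_objectives):
    """Hypothetical check mirroring the two documented ValueError cases."""
    x = np.asarray(multi_objectives, dtype=np.float64)
    if x.ndim != 2:
        raise ValueError(f"multi_objectives must be rank-2, got rank {x.ndim}")
    if x.shape[1] != num_of_objectives:
        raise ValueError(
            f"expected {num_of_objectives} objectives, got {x.shape[1]}"
        )
    return x


def raises_value_error(candidate, num_of_objectives):
    """Returns True if the candidate input is rejected."""
    try:
        check_multi_objectives(candidate, num_of_objectives)
    except ValueError:
        return True
    return False


# A batch of one example with three objectives passes the check.
valid = check_multi_objectives([[0.1, 0.9, 0.3]], num_of_objectives=3)

# A rank-1 input and an input with too few columns are both rejected.
wrong_rank = raises_value_error([0.1, 0.9, 0.3], 3)
wrong_width = raises_value_error([[0.1, 0.9]], 3)
```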
Last updated 2024-04-26 UTC.