

kerasteam2 | 2 years ago

At this time, there are no backend-agnostic APIs for implementing training steps or training loops. Each backend handles training very differently, so no shared abstraction can exist (especially for JAX). When customizing fit(), you have to use backend-native APIs.

If you want to make a model with a custom train_step that is cross-backend, you can do something like:

  def train_step(self, *args, **kwargs):
    # Dispatch to a backend-specific implementation.
    if keras.config.backend() == "tensorflow":
      return self._tf_train_step(*args, **kwargs)
    elif ...
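
A slightly fuller sketch of the same dispatch idea, under stated assumptions. The mixin name `BackendDispatchMixin`, the lookup table, and the method names `_jax_train_step` and `_torch_train_step` are hypothetical (only `_tf_train_step` appears above). Forwarding `*args, **kwargs` is useful because the backends don't share one signature. For example, Keras 3's JAX backend passes an explicit state alongside the data (worth checking against the Keras docs for your version).

```python
# Hypothetical table: Keras backend name -> name of the per-backend method.
_TRAIN_STEP_BY_BACKEND = {
    "tensorflow": "_tf_train_step",
    "jax": "_jax_train_step",
    "torch": "_torch_train_step",
}


class BackendDispatchMixin:
    """Routes train_step() to a backend-specific method (hypothetical helper)."""

    def _current_backend(self):
        # Imported lazily so the class can be defined without Keras installed.
        import keras
        return keras.config.backend()

    def train_step(self, *args, **kwargs):
        backend = self._current_backend()
        name = _TRAIN_STEP_BY_BACKEND.get(backend)
        if name is None or not hasattr(self, name):
            raise NotImplementedError(f"no train_step for backend {backend!r}")
        # Forward all arguments unchanged: signatures differ per backend.
        return getattr(self, name)(*args, **kwargs)
```

You would list the mixin before the model class, e.g. `class MyModel(BackendDispatchMixin, keras.Model)`, so its `train_step` wins in the MRO. Then you define only the backend methods you actually support. Any other backend fails loudly instead of silently falling back.
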

BTW, it looks like the previous account is being rate-limited to less than one post per hour (maybe even locked for the day), so I will be very slow to answer questions.


