tensorflow-ruby TensorBoard

Now that tensorflow-ruby supports linear regression, it's time to add support for TensorBoard, which helps you visualize how your model is working.

TensorFlow 2 supports TensorBoard via the Summary API. A high-level API is provided in Python via tf.summary. A lower-level API is provided by a set of C++ operations, and that is what the Ruby implementation uses.

So let's add TensorBoard support to the linear regression example. Note that the example uses graph mode (à la TensorFlow 1.x), so every operation has to be executed explicitly with Session#run.

First, we create a step variable and a file writer:

require 'tmpdir' # provides Dir.tmpdir

# Set up a variable to keep track of the epoch and get an op to increment it
epoch_var = Tf::Variable.new(1, dtype: :int64)
epoch_var_add_op = epoch_var.assign_add(1)

# Enable logging to TensorBoard - create a file writer and initialize it
path = File.join(Dir.tmpdir, 'tensorflow-ruby')
writer = Tf::Summary.create_file_writer(path)
writer.step = epoch_var
writer_flush_op = writer.flush

Notice the use of a variable to keep track of the current epoch, or step. Each time you log data for TensorBoard, you need to specify the current step.

Now let's write out the graph (once again, in eager mode you could skip the Session#run calls):

# Log the graph

Now let's open TensorBoard, pointing it at the log directory:

 tensorboard --logdir <temp_dir>\tensorflow-ruby
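If you are not sure what <temp_dir> resolves to on your machine, this stdlib-only Ruby snippet rebuilds the same path the writer was created with (File.join(Dir.tmpdir, 'tensorflow-ruby')) and prints a matching command. Wrapping the path in double quotes is an assumption that your shell accepts double-quoted arguments.

```ruby
require 'tmpdir'

# Same log directory the file writer was created with
log_dir = File.join(Dir.tmpdir, 'tensorflow-ruby')

# Quote the path in case it contains spaces
command = %(tensorboard --logdir "#{log_dir}")
puts command
```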

And here is the graph – complicated!

Part of the reason it's complicated is that the Ruby bindings don't namespace operations as well as Python does, so there is some future work there. Otherwise, this is the model's graph along with the operations used to compute gradients for backpropagation.
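TensorBoard's graph view groups nodes whose names share a "/"-separated prefix into a single collapsible box, and in Python tf.name_scope is what produces those prefixes. The sketch below is a hypothetical, library-independent illustration of the idea: the NameScope class and the op names are invented for this example and are not part of tensorflow-ruby.

```ruby
# Hypothetical illustration only: NameScope is not a tensorflow-ruby class.
class NameScope
  def initialize
    @stack = []
  end

  # Push a scope for the duration of the block, then pop it again
  def scope(name)
    @stack.push(name)
    yield
  ensure
    @stack.pop
  end

  # Prefix an op name with every active scope, joined by '/'
  def full_name(op_name)
    (@stack + [op_name]).join('/')
  end
end

names = []
ns = NameScope.new
ns.scope('loss') do
  names << ns.full_name('Sub')
  names << ns.full_name('Square')
end
ns.scope('gradients') do
  names << ns.full_name('MatMul')
  ns.scope('loss') { names << ns.full_name('Square_grad') }
end
names << ns.full_name('epoch_var')

# Group by top-level scope, roughly how TensorBoard collapses the graph view
groups = names.group_by { |name| name.split('/').first }
groups.each { |scope, ops| puts "#{scope}: #{ops.size} op(s)" }
```

With scoped names like these, the graph view would start from three top-level nodes (loss, gradients and epoch_var) instead of showing every individual op at once.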

Now let’s train our model and log the loss function.

  # Log the cost
  session.run(write_cost_op, {x_value => train_x, y_value => train_y})
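The fragment above runs inside the training loop. To show the shape of that loop without depending on tensorflow-ruby's API, here is a plain-Ruby sketch: mini-batch gradient descent on a toy linear regression, where a hypothetical log_scalar lambda stands in for session.run(write_cost_op, ...) and a plain integer epoch stands in for epoch_var. The data, learning rate, batch size and epoch count are made-up values.

```ruby
# Toy data on the line y = 2x + 1 (made-up values for illustration)
xs = (0...20).map { |i| i / 10.0 }
ys = xs.map { |x| 2.0 * x + 1.0 }

w = 0.0
b = 0.0
learning_rate = 0.1
batch_size = 5
epochs = 50

# Hypothetical stand-in for session.run(write_cost_op, ...): records [step, tag, value]
log = []
log_scalar = ->(step, tag, value) { log << [step, tag, value] }

# Mean squared error over the whole training set (reads the current w and b)
cost = -> { xs.zip(ys).sum { |x, y| (w * x + b - y)**2 } / xs.size }

epoch = 1 # plays the role of epoch_var, which also starts at 1
epochs.times do
  xs.zip(ys).each_slice(batch_size) do |batch|
    # Gradients of the batch MSE with respect to w and b
    grad_w = batch.sum { |x, y| 2 * (w * x + b - y) * x } / batch.size
    grad_b = batch.sum { |x, y| 2 * (w * x + b - y) } / batch.size
    w -= learning_rate * grad_w
    b -= learning_rate * grad_b
  end

  # Log once per epoch rather than once per batch: one slow write instead of four
  log_scalar.call(epoch, 'cost', cost.call)
  epoch += 1 # analogous to running epoch_var_add_op
end

puts "logged #{log.size} cost values, final cost #{log.last[2]}"
```

Each recorded entry pairs the cost with the step it belongs to, which is what TensorBoard plots along the x-axis of a scalar chart.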

Note that logging is slow, so the example logs the cost only once per epoch. Our finished graph looks like this:

Nice! The Ruby bindings provide full support for TensorBoard: logging scalars, images, audio, graphs and more.
