We are going to be extending the code we wrote in the last tutorial to render an insightful visualization of the environment using Matplotlib. If you haven’t read my first article on Creating custom Gym environments from scratch, you should stop here and read that first.

If you are unfamiliar with the matplotlib library, don’t worry. We will be going over every line so you can create your own custom visualizations of your gym environments. As always, the code for this tutorial will be available on my Github.

Here is a sneak preview of what we will be creating in this article:

If it looks complicated, it’s actually not that bad. Just a couple of graphs updating on each step, annotated with some key information. Let’s get started!

Stock Trading Visualization

In our last tutorial, we wrote a simple render method using print statements to display the agent’s net worth and other important metrics. Let’s move that logic to a new method called _render_to_file, so we can save a session’s trading metrics to a file, if necessary.

def _render_to_file(self, filename='render.txt'):
    profit = self.net_worth - INITIAL_ACCOUNT_BALANCE

    file = open(filename, 'a+')
    file.write(f'Step: {self.current_step}\n')
    file.write(f'Balance: {self.balance}\n')
    file.write(f'Shares held: {self.shares_held} '
               f'(Total sold: {self.total_shares_sold})\n')
    file.write(f'Avg cost for held shares: {self.cost_basis} '
               f'(Total sales value: {self.total_sales_value})\n')
    file.write(f'Net worth: {self.net_worth} '
               f'(Max net worth: {self.max_net_worth})\n')
    file.write(f'Profit: {profit}\n\n')
    file.close()
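As a minimal, self-contained sketch of the same append-to-file pattern, here is a variant that uses a with block, which closes the file even if one of the writes fails. The write_metrics helper and the sample numbers are hypothetical and not part of the environment.

```python
import os
import tempfile


def write_metrics(filename, step, balance, net_worth, initial_balance):
    """Append one step's metrics to a text file (hypothetical helper)."""
    profit = net_worth - initial_balance
    # 'a+' appends, so repeated calls build up a log of the whole session
    with open(filename, 'a+') as file:
        file.write(f'Step: {step}\n')
        file.write(f'Balance: {balance}\n')
        file.write(f'Net worth: {net_worth}\n')
        file.write(f'Profit: {profit}\n\n')


# Write two steps to a temporary file, then read the log back
path = os.path.join(tempfile.mkdtemp(), 'render.txt')
write_metrics(path, step=1, balance=10000, net_worth=10000, initial_balance=10000)
write_metrics(path, step=2, balance=9500, net_worth=10050, initial_balance=10000)

with open(path) as file:
    log = file.read()
print(log)
```

Because the file is opened in append mode, each call adds a new block of metrics rather than overwriting the previous step.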

Now, let’s move on to creating our new render method. It’s going to utilize our new StockTradingGraph class, which we haven’t written yet. We’ll get to that next.

def render(self, mode='live', title=None, **kwargs):
    # Render the environment to the screen
    if mode == 'file':
        self._render_to_file(kwargs.get('filename', 'render.txt'))
    elif mode == 'live':
        if self.visualization is None:
            self.visualization = StockTradingGraph(self.df, title)

        if self.current_step > LOOKBACK_WINDOW_SIZE:
            self.visualization.render(self.current_step, self.net_worth,
                                      self.trades,
                                      window_size=LOOKBACK_WINDOW_SIZE)

We pass title as a regular keyword argument, which goes on to the StockTradingGraph, and use kwargs to pick up the optional filename for _render_to_file. If you are unfamiliar with kwargs, it is basically a dictionary for passing optional keyword arguments to functions.
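If kwargs is new to you, this small sketch shows the idea in isolation. The render function here is a toy stand-in written for illustration, not the environment's method.

```python
def render(mode='live', title=None, **kwargs):
    """Toy stand-in for the environment's render method (illustration only)."""
    if mode == 'file':
        # kwargs is a plain dict, so .get() supplies a default when the key is absent
        return f"writing to {kwargs.get('filename', 'render.txt')}"
    return f"drawing graph titled {title!r}"


default_file = render(mode='file')
custom_file = render(mode='file', filename='session_1.txt')
live = render(title='MSFT')

print(default_file)
print(custom_file)
print(live)
```

Any keyword the function does not name explicitly, such as filename here, ends up in the kwargs dictionary.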

We also pass self.trades for the visualization to render, but have not defined it yet, so let’s do that. Back in our _take_action method, whenever we buy or sell shares, we are now going to add the details of that transaction to the self.trades object, which we’ve initialized to [] in our reset method.

def _take_action(self, action):
    ...

    if action_type < 1:
        ...

        if shares_bought > 0:
            self.trades.append({'step': self.current_step,
                                'shares': shares_bought,
                                'total': additional_cost,
                                'type': "buy"})

    elif action_type < 2:
        ...

        if shares_sold > 0:
            self.trades.append({'step': self.current_step,
                                'shares': shares_sold,
                                'total': shares_sold * current_price,
                                'type': "sell"})
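To make the shape of those trade records concrete, here is a small standalone sketch. The record_trade helper is hypothetical and the prices are made up; it only mirrors the "append when shares > 0" logic above.

```python
trades = []  # reset() would start each episode with an empty list


def record_trade(trades, step, shares, total, trade_type):
    """Append one transaction, skipping no-op trades of zero shares."""
    if shares > 0:
        trades.append({'step': step, 'shares': shares,
                       'total': total, 'type': trade_type})


record_trade(trades, step=12, shares=5, total=5 * 101.5, trade_type='buy')
record_trade(trades, step=13, shares=0, total=0.0, trade_type='sell')  # ignored
record_trade(trades, step=20, shares=2, total=2 * 104.0, trade_type='sell')

for trade in trades:
    print(trade)
```

Each entry carries the step it happened on, which is what lets the graph place the trade on the right candle later.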

Now our StockTradingGraph has all of the information it needs to render the stock’s price history and trade volume, along with our agent’s net worth and any trades it’s made. Let’s get started rendering our visualization.

First, we’ll define our StockTradingGraph and its __init__ method. Here is where we will create our pyplot figure, and set up each of the subplots to be rendered to. The date2num function is used to reformat dates into timestamps, necessary later in the rendering process.

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.dates as mdates


def date2num(date):
    # Parse a 'YYYY-MM-DD' date string into Matplotlib's numeric date format
    return mdates.datestr2num(date)


class StockTradingGraph:
    """A stock trading visualization using matplotlib made to render
    OpenAI gym environments"""

    def __init__(self, df, title=None):
        self.df = df
        self.net_worths = np.zeros(len(df['Date']))

        # Create a figure on screen and set the title
        fig = plt.figure()
        fig.suptitle(title)

        # Create top subplot for net worth axis
        self.net_worth_ax = plt.subplot2grid((6, 1), (0, 0), rowspan=2,
                                             colspan=1)

        # Create bottom subplot for shared price/volume axis
        # (rows 2-5 of the 6-row grid)
        self.price_ax = plt.subplot2grid((6, 1), (2, 0), rowspan=4,
                                         colspan=1, sharex=self.net_worth_ax)

        # Create a new axis for volume which shares its x-axis with price
        self.volume_ax = self.price_ax.twinx()

        # Add padding to make graph easier to view
        plt.subplots_adjust(left=0.11, bottom=0.24, right=0.90,
                            top=0.90, wspace=0.2, hspace=0)

        # Show the graph without blocking the rest of the program
        plt.show(block=False)

We use the plt.subplot2grid(...) method to first create a subplot at the top of our figure to render our net worth graph, and then create another subplot below it for our price graph. The first argument of subplot2grid is the shape of the overall grid (rows, columns) and the second is the location of the subplot within that grid.

To render our trade volume bars, we call the twinx() method on self.price_ax, which allows us to overlay another grid on top that will share the same x-axis. Finally, and most importantly, we will render our figure to the screen using plt.show(block=False). If you forget to pass block=False, you will only ever see the first step rendered, after which the agent will be blocked from continuing.

Next, let’s write our render method. This will take all of the information from the current time step and render a live representation to the screen.

def render(self, current_step, net_worth, trades, window_size=40):
    self.net_worths[current_step] = net_worth

    window_start = max(current_step - window_size, 0)
    step_range = range(window_start, current_step + 1)

    # Format dates as timestamps, necessary for candlestick graph
    dates = np.array([date2num(x)
                      for x in self.df['Date'].values[step_range]])

    # Pass step_range (not window_size) so the net worth plot can index with it
    self._render_net_worth(current_step, net_worth, step_range, dates)
    self._render_price(current_step, net_worth, dates, step_range)
    self._render_volume(current_step, net_worth, dates, step_range)
    self._render_trades(current_step, trades, step_range)

    # Format the date ticks to be more easily read
    self.price_ax.set_xticklabels(self.df['Date'].values[step_range],
                                  rotation=45, horizontalalignment='right')

    # Hide duplicate net worth date labels
    plt.setp(self.net_worth_ax.get_xticklabels(), visible=False)

    # Necessary to view frames before they are unrendered
    plt.pause(0.001)

Here, we save the net_worth, and then render each graph from top to bottom. We’re also going to annotate the price graph with the trades the agent has taken in the self._render_trades method. It’s important to call plt.pause() here, otherwise each frame will be cleared by the next call to render before it is actually shown on screen.
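The sliding window that render computes is easy to check on its own. This sketch pulls the two window lines into a hypothetical visible_steps function so you can see which steps are drawn on a given frame.

```python
def visible_steps(current_step, window_size=40):
    """Return the range of steps drawn for the current frame."""
    # Clamp at 0 so early steps don't reach back before the start of the data
    window_start = max(current_step - window_size, 0)
    return range(window_start, current_step + 1)


early = visible_steps(10)
later = visible_steps(100)

print(early.start, early.stop - 1, len(early))
print(later.start, later.stop - 1, len(later))
```

Note that once the window is full it covers window_size + 1 steps, because both the current step and the step window_size bars back are included.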

Now, let’s look at each of the graph’s render methods, starting with net worth.

def _render_net_worth(self, current_step, net_worth, step_range, dates):
    # Clear the frame rendered last step
    self.net_worth_ax.clear()

    # Plot net worths
    self.net_worth_ax.plot_date(dates, self.net_worths[step_range], '-',
                                label='Net Worth')

    # Show legend, which uses the label we defined for the plot above
    legend = self.net_worth_ax.legend(loc=2, ncol=2, prop={'size': 8})
    legend.get_frame().set_alpha(0.4)

    last_date = date2num(self.df['Date'].values[current_step])
    last_net_worth = self.net_worths[current_step]

    # Annotate the current net worth on the net worth graph
    self.net_worth_ax.annotate('{0:.2f}'.format(net_worth),
                               (last_date, last_net_worth),
                               xytext=(last_date, last_net_worth),
                               bbox=dict(boxstyle='round', fc='w', ec='k', lw=1),
                               color="black",
                               fontsize="small")

    # Add space above and below min/max net worth
    self.net_worth_ax.set_ylim(
        min(self.net_worths[np.nonzero(self.net_worths)]) / 1.25,
        max(self.net_worths) * 1.25)

We just call plot_date(...) on our net worth subplot to plot a simple line graph, then annotate it with the agent’s current net_worth , and add a legend.
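The y-axis limits are the one subtle part of that method: net_worths starts out as all zeros, so the minimum has to skip entries that haven't been filled in yet. Here is a pure-Python sketch of the same calculation; padded_limits is a hypothetical helper and the sample values are made up.

```python
def padded_limits(net_worths, padding=1.25):
    """Compute y-limits with headroom, ignoring unfilled (zero) entries."""
    # Steps that haven't happened yet are still 0 and would drag the minimum down
    seen = [value for value in net_worths if value != 0]
    return min(seen) / padding, max(net_worths) * padding


# Three recorded steps followed by two steps that haven't been filled in yet
net_worths = [10000.0, 9800.0, 10400.0, 0.0, 0.0]
lower, upper = padded_limits(net_worths)
print(lower, upper)
```

Dividing the minimum and multiplying the maximum by the same factor leaves some breathing room on both sides of the line.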

Rendering the price graph is a bit more complicated. To keep things simple, we are going to render the OHLC (open, high, low, close) price bars in a separate method from the volume bars. First, you need to pip install mpl_finance if you don’t already have it, as this package is needed for the candlestick graphs we’ll be using. Then add this line to the top of your file.

from mpl_finance import candlestick_ochl as candlestick

Great, let’s clear the previous frame, zip up the price data in the open, close, high, low order that candlestick_ochl expects, and render a candlestick graph to the self.price_ax subplot.

def _render_price(self, current_step, net_worth, dates, step_range):
    self.price_ax.clear()

    # Format data for the candlestick graph (open, close, high, low order)
    candlesticks = zip(dates,
                       self.df['Open'].values[step_range],
                       self.df['Close'].values[step_range],
                       self.df['High'].values[step_range],
                       self.df['Low'].values[step_range])

    # Plot price using candlestick graph from mpl_finance
    candlestick(self.price_ax, candlesticks, width=1,
                colorup=UP_COLOR, colordown=DOWN_COLOR)

    last_date = date2num(self.df['Date'].values[current_step])
    last_close = self.df['Close'].values[current_step]
    last_high = self.df['High'].values[current_step]

    # Print the current price to the price axis
    self.price_ax.annotate('{0:.2f}'.format(last_close),
                           (last_date, last_close),
                           xytext=(last_date, last_high),
                           bbox=dict(boxstyle='round', fc='w', ec='k', lw=1),
                           color="black",
                           fontsize="small")

    # Shift price axis up to give volume chart space
    ylim = self.price_ax.get_ylim()
    self.price_ax.set_ylim(ylim[0] - (ylim[1] - ylim[0])
                           * VOLUME_CHART_HEIGHT, ylim[1])

We’ve annotated the graph with the stock’s current price, and shifted the chart up to prevent it from overlapping with the volume bars. Next let’s look at the volume render method, which is much simpler as there are no annotations.
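The axis shift is just a bit of arithmetic, sketched below outside of matplotlib. The value 0.33 for VOLUME_CHART_HEIGHT is an assumption made for this example, since the constant is not defined in this article, and shifted_price_limits is a hypothetical helper.

```python
VOLUME_CHART_HEIGHT = 0.33  # assumed value; not defined in this article


def shifted_price_limits(y_min, y_max, volume_height=VOLUME_CHART_HEIGHT):
    """Extend the lower bound so empty space opens up beneath the candles."""
    return y_min - (y_max - y_min) * volume_height, y_max


lower, upper = shifted_price_limits(100.0, 120.0)

# Fraction of the new axis height that lies below the original lower bound
band = (100.0 - lower) / (upper - lower)
print(lower, upper, round(band, 3))
```

The upper limit is left alone, so the candles keep their headroom while the extra space appears only at the bottom of the price axis.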

def _render_volume(self, current_step, net_worth, dates, step_range):
    self.volume_ax.clear()

    volume = np.array(self.df['Volume'].values[step_range])

    pos = self.df['Open'].values[step_range] - \
        self.df['Close'].values[step_range] < 0
    neg = self.df['Open'].values[step_range] - \
        self.df['Close'].values[step_range] > 0

    # Color volume bars based on price direction on that date
    self.volume_ax.bar(dates[pos], volume[pos], color=UP_COLOR,
                       alpha=0.4, width=1, align='center')
    self.volume_ax.bar(dates[neg], volume[neg], color=DOWN_COLOR,
                       alpha=0.4, width=1, align='center')

    # Cap volume axis height below price chart and hide ticks
    self.volume_ax.set_ylim(0, max(volume) / VOLUME_CHART_HEIGHT)
    self.volume_ax.yaxis.set_ticks([])

Just a simple bar graph, with each bar colored either green or red, depending on whether the price moved up or down in that time step.
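The up/down split uses the same open-minus-close comparisons as the method above. This sketch reproduces them with plain lists; split_by_direction is a hypothetical helper and the prices are made up.

```python
def split_by_direction(opens, closes):
    """Return index lists for bars that closed up and bars that closed down."""
    pairs = list(zip(opens, closes))
    up = [i for i, (open_, close) in enumerate(pairs) if open_ - close < 0]
    down = [i for i, (open_, close) in enumerate(pairs) if open_ - close > 0]
    return up, down


opens = [10.0, 11.0, 12.0, 12.5]
closes = [11.0, 10.5, 12.0, 13.0]
up, down = split_by_direction(opens, closes)

print('up:', up)      # bars drawn in the "up" color
print('down:', down)  # bars drawn in the "down" color
```

With strict comparisons on both sides, a bar whose open equals its close (index 2 here) lands in neither group, so it would not be drawn at all.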

Finally, let’s get to the fun part: _render_trades. In this method, we’ll render an arrow on the price graph where the agent has made a trade, annotated with the total amount transacted.

def _render_trades(self, current_step, trades, step_range):
    for trade in trades:
        if trade['step'] in step_range:
            date = date2num(self.df['Date'].values[trade['step']])
            high = self.df['High'].values[trade['step']]
            low = self.df['Low'].values[trade['step']]

            if trade['type'] == 'buy':
                high_low = low
                color = UP_TEXT_COLOR
            else:
                high_low = high
                color = DOWN_TEXT_COLOR

            total = '{0:.2f}'.format(trade['total'])

            # Annotate the total amount transacted on the price axis
            self.price_ax.annotate(f'${total}', (date, high_low),
                                   xytext=(date, high_low),
                                   color=color,
                                   fontsize=8,
                                   arrowprops=(dict(color=color)))
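Stripped of the plotting calls, the selection logic looks like this. It is only a sketch: trade_annotations is a hypothetical helper, and the highs, lows, and trades are made-up sample data.

```python
def trade_annotations(trades, step_range, highs, lows):
    """Build (step, price, label) tuples for trades inside the visible window."""
    annotations = []
    for trade in trades:
        if trade['step'] in step_range:
            # Buys are anchored at the bar's low, sells at its high
            if trade['type'] == 'buy':
                price = lows[trade['step']]
            else:
                price = highs[trade['step']]
            label = '${0:.2f}'.format(trade['total'])
            annotations.append((trade['step'], price, label))
    return annotations


highs = [10.5, 11.0, 11.4, 11.9]
lows = [9.8, 10.2, 10.9, 11.1]
trades = [
    {'step': 0, 'shares': 3, 'total': 30.0, 'type': 'buy'},   # outside the window
    {'step': 2, 'shares': 5, 'total': 55.5, 'type': 'buy'},
    {'step': 3, 'shares': 2, 'total': 23.2, 'type': 'sell'},
]
visible = trade_annotations(trades, range(1, 4), highs, lows)

for annotation in visible:
    print(annotation)
```

Trades that have scrolled out of the window are simply skipped, which keeps old annotations from piling up at the edge of the graph.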

And that’s it! We now have a beautiful, live-rendering visualization of the stock trading environment we created in the last article! It’s too bad we still haven’t put much time into teaching the agent how to make money… We’ll leave that for next time!

Not too shabby! Next week we are going to build upon the code from this tutorial to create Bitcoin trading bots that don’t lose money.