Flux.jl is a machine learning framework built in Julia. It has some similarities to PyTorch, and like most modern frameworks it includes automatic differentiation. It is still a work in progress, but it is being actively developed (including several GSoC projects this summer).

I was curious about how easy/difficult it might be to convert a PyTorch model into Flux.jl. I found a fairly simple PyTorch tutorial on RNNs to translate. My goal here isn’t to explain RNNs (see the linked article for that) - my intent is to see what is required to go from the PyTorch/Python ecosystem to the Flux.jl/Julia ecosystem.

Let’s start with the PyTorch code:

```python
import torch
from torch.autograd import Variable
import numpy as np
import pylab as pl
import torch.nn.init as init

torch.manual_seed(1)

dtype = torch.FloatTensor
input_size, hidden_size, output_size = 7, 6, 1
epochs = 200
seq_length = 20
lr = 0.1

data_time_steps = np.linspace(2, 10, seq_length + 1)
data = np.sin(data_time_steps)
data.resize((seq_length + 1, 1))

x = Variable(torch.Tensor(data[:-1]).type(dtype), requires_grad=False)
y = Variable(torch.Tensor(data[1:]).type(dtype), requires_grad=False)

w1 = torch.FloatTensor(input_size, hidden_size).type(dtype)
init.normal(w1, 0.0, 0.4)
w1 = Variable(w1, requires_grad=True)

w2 = torch.FloatTensor(hidden_size, output_size).type(dtype)
init.normal(w2, 0.0, 0.3)
w2 = Variable(w2, requires_grad=True)

def forward(input, context_state, w1, w2):
    xh = torch.cat((input, context_state), 1)
    context_state = torch.tanh(xh.mm(w1))
    out = context_state.mm(w2)
    return (out, context_state)

for i in range(epochs):
    total_loss = 0
    context_state = Variable(torch.zeros((1, hidden_size)).type(dtype), requires_grad=True)
    for j in range(x.size(0)):
        input = x[j:(j + 1)]
        target = y[j:(j + 1)]
        (pred, context_state) = forward(input, context_state, w1, w2)
        loss = (pred - target).pow(2).sum() / 2
        total_loss += loss
        loss.backward()
        w1.data -= lr * w1.grad.data
        w2.data -= lr * w2.grad.data
        w1.grad.data.zero_()
        w2.grad.data.zero_()
        context_state = Variable(context_state.data)
    if i % 10 == 0:
        print("Epoch: {} loss {}".format(i, total_loss.data[0]))

context_state = Variable(torch.zeros((1, hidden_size)).type(dtype), requires_grad=False)
predictions = []

for i in range(x.size(0)):
    input = x[i:i + 1]
    (pred, context_state) = forward(input, context_state, w1, w2)
    context_state = context_state
    predictions.append(pred.data.numpy().ravel()[0])

pl.scatter(data_time_steps[:-1], x.data.numpy(), s=90, label="Actual")
pl.scatter(data_time_steps[1:], predictions, label="Predicted")
pl.legend()
pl.show()
```

And now the Flux.jl version in Julia:

```julia
using Flux
using Flux.Tracker
using Plots

srand(1)

input_size, hidden_size, output_size = 7, 6, 1
epochs = 200
seq_length = 20
lr = 0.1

data_time_steps = linspace(2, 10, seq_length + 1)
data = sin.(data_time_steps)

x = data[1:end-1]
y = data[2:end]

w1 = param(randn(input_size, hidden_size))
w2 = param(randn(hidden_size, output_size))

function forward(input, context_state, W1, W2)
    #xh = cat(2, input, context_state)
    # Due to a Flux bug you have to do:
    xh = cat(2, Tracker.collect(input), context_state)
    context_state = tanh.(xh * W1)
    out = context_state * W2
    return out, context_state
end

function train()
    for i in 1:epochs
        total_loss = 0
        context_state = param(zeros(1, hidden_size))
        for j in 1:length(x)
            input = x[j]
            target = y[j]
            pred, context_state = forward(input, context_state, w1, w2)
            loss = sum((pred .- target) .^ 2) / 2
            total_loss += loss
            back!(loss)
            w1.data .-= lr .* w1.grad
            w2.data .-= lr .* w2.grad
            w1.grad .= 0.0
            w2.grad .= 0.0
            context_state = param(context_state.data)
        end
        if (i % 10 == 0)
            println("Epoch: $i loss: $total_loss")
        end
    end
end

train()

context_state = param(zeros(1, hidden_size))
predictions = []

for i in 1:length(x)
    input = x[i]
    pred, context_state = forward(input, context_state, w1, w2)
    append!(predictions, pred.data)
end

scatter(data_time_steps[1:end-1], x, label="actual")
scatter!(data_time_steps[2:end], predictions, label="predicted")
```

The Flux.jl/Julia version is very similar to the PyTorch/Python version. A few notable differences:

NumPy-style functionality is built into Julia, so there is no need to import NumPy.
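For example, the NumPy calls in the PyTorch version have direct built-in equivalents (a small sketch, assuming the Julia 0.6-era API used above, where linspace is a base function):

```julia
# NumPy: np.linspace(2, 10, 21), np.sin(data_time_steps)
data_time_steps = linspace(2, 10, 21)  # 21 evenly spaced points from 2 to 10
data = sin.(data_time_steps)           # element-wise sin via broadcasting
```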

torch.Variable maps to Flux.param

x and y are of type torch.Variable in the PyTorch version, while they're just plain built-in Julia arrays on the Julia side.
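To illustrate what that slicing looks like without any Variable wrapper (a sketch based on the code above; Julia ranges are 1-based and inclusive):

```julia
data = sin.(linspace(2, 10, 21))
x = data[1:end-1]  # all but the last element  (Python: data[:-1])
y = data[2:end]    # all but the first element (Python: data[1:])
```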

Flux.param(var) indicates that the variable var will be tracked for the purpose of computing gradients (just as torch.Variable with requires_grad=True does).
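Here is a minimal sketch of how that tracking gets used, with the same Tracker-era API as the code above (the toy loss is made up for illustration and is not part of the tutorial):

```julia
using Flux
using Flux.Tracker

w = param(randn(3, 2))    # tracked parameter, like Variable(..., requires_grad=True)
loss = sum(w * ones(2))   # a toy scalar loss built from w
back!(loss)               # backpropagate; populates w.grad
w.data .-= 0.1 .* w.grad  # gradient-descent step on the underlying values
w.grad .= 0.0             # reset the gradient before the next step
```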

I did run into a bug in Flux.jl; you'll notice the workaround in the forward function, where the input is wrapped in Tracker.collect before the cat call. Once the bug is fixed, you'll be able to uncomment the simpler cat line above it and drop the workaround. The bug had to do with how certain tracked collections are converted to scalar types. The tracking is required for backpropagation, and the problem was that the input passed into the forward function picked up another layer of unnecessary tracking each time forward was called.

A ‘.’ before an operator (as in the weight update w1.data .-= lr .* w1.grad in the Julia code) indicates a broadcasting operation in Julia. Note also the ‘.’ after tanh in tanh.(xh * W1): it indicates that tanh is broadcast over the matrix that results from xh * W1. (From what I can tell, NumPy automatically determines whether an operation should broadcast based on the dimensions of the operands; Julia is more explicit about this.)
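A quick sketch of the dot syntax on its own (not part of the tutorial code):

```julia
v = [1.0, 2.0, 3.0]
v .+ 1.0      # broadcast addition: [2.0, 3.0, 4.0]
tanh.(v)      # broadcast tanh over each element, as in tanh.(xh * W1)
v .-= 0.5     # in-place broadcast update, like w1.data .-= lr .* w1.grad
```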

Even the plotting at the end is very similar between the two versions.

In the next post I’ll modify the Julia version to use the GPU.