Using AI to predict prices and manage investments can give you a competitive edge in the crypto space, and doing so isn't as difficult as you might think.
This is the fifth of a six-part blog series on real-time crypto price predictions with AI. In this blog, I'll deploy the AI models I've built and tested previously on a real-time feed of crypto prices. Keep up with the rest of the blog series:
- Acquire up-to-date crypto data with Apache Airflow
- Implement an LSTM model with TensorFlow
- Implement a linear regression model with Nvidia RAPIDS
- Test the models on simulated real-time data
- Implement the models on real-time crypto data from Coinbase
- Share AI predictions with URIs
Now that model building, training, and testing are complete, it's time for the last step in the machine learning workflow: deployment. For this application, deployment means applying the models to real-time crypto prices obtained from the Coinbase WebSocket API. Jake Mulford previously wrote a blog that details how to use this API, and I'll be building on his work for this application.
Pulling from Coinbase
I've made small changes to Jake's code so that it only ingests the data my AI models need: trades (the "matches" channel) for the BTC-USD product. To match the time steps used to train the models, the code averages the prices of all trades within each minute and writes one row per minute. I'll use this workflow to predict Bitcoin prices with both models.
# Required imports
from websocket import create_connection, WebSocketConnectionClosedException
from deephaven.time import to_datetime, lower_bin
from deephaven import DynamicTableWriter
import deephaven.dtypes as dht
from threading import Thread
import json
# Connect to the Coinbase API and subscribe to the BTC-USD "matches" (trades) channel
ws = create_connection("wss://ws-feed.exchange.coinbase.com")
ws.send(
    json.dumps(
        {
            "type": "subscribe",
            "product_ids": ["BTC-USD"],
            "channels": ["matches"],
        }
    )
)
# Convert a Coinbase time string (e.g. "...T12:34:56.789Z") to a Deephaven time
def coinbase_time_to_datetime(strn):
    return to_datetime(strn[0:-1] + " UTC")
# Create the real-time BTC price table only once, even if this code is rerun
if "coinbase_websocket_table" not in globals():
    # A dict that defines the schema for the real-time BTC price table
    dtw_columns = {
        'time': dht.DateTime,
        'price': dht.float_
    }
    # Create the DynamicTableWriter and the table it writes to
    dtw = DynamicTableWriter(dtw_columns)
    coinbase_websocket_table = dtw.table
# Per-minute [trade count, price sum], the minute currently being filled,
# and a flag indicating connection status
time_dict = {}
current_minute = None
connection_open = True
# A function to pull data from Coinbase and populate the DynamicTableWriter
def pull_from_coinbase():
    global connection_open, current_minute
    # Write data while the connection stays open
    while connection_open:
        try:
            data = json.loads(ws.recv())
            time = coinbase_time_to_datetime(data["time"])
            price = float(data["price"])
            # Round the trade time down to the start of its minute (60 s in nanoseconds)
            time_mins = lower_bin(time, 60_000_000_000)
            if time_mins in time_dict:
                time_dict[time_mins][0] += 1
                time_dict[time_mins][1] += price
            else:
                # A new minute has started: write the average price of the previous minute
                if current_minute is not None:
                    count, total = time_dict[current_minute]
                    dtw.write_row(current_minute, total / count)
                time_dict[time_mins] = [1, price]
                current_minute = time_mins
        # Handle a KeyError - messages such as the initial "subscriptions" reply have no time or price
        except KeyError as key_error:
            print(f"Warning: The message has no {key_error.args[0]!r} field.")
        # Handle a connection closed error - the connection is no longer open
        except WebSocketConnectionClosedException:
            print("The connection to the host has been closed.")
            ws.close()
            connection_open = False
# A thread to load Coinbase data in the background
thread = Thread(target=pull_from_coinbase)
thread.start()
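To make the per-minute averaging easier to follow, here is a minimal, standard-library-only sketch of the same idea that runs without Deephaven or a live connection. It assumes raw JSON messages shaped like Coinbase match messages (a "type", an ISO-8601 "time" ending in Z, and a string "price"). The MinuteAverager class and minute_bin helper are hypothetical names for illustration, and Python's datetime stands in for Deephaven's lower_bin.

```python
import json
from datetime import datetime, timezone


def minute_bin(ts):
    """Round a Coinbase-style timestamp such as '2022-03-01T12:34:56.789Z' down to its minute (UTC)."""
    # The first 19 characters are 'YYYY-MM-DDTHH:MM:SS'; fractional seconds and 'Z' are dropped
    dt = datetime.strptime(ts[:19], "%Y-%m-%dT%H:%M:%S").replace(tzinfo=timezone.utc)
    return dt.replace(second=0)


class MinuteAverager:
    """Hypothetical helper: accumulate trade prices per minute, emit (minute, average) when a new minute starts."""

    def __init__(self):
        self.current_minute = None
        self.count = 0
        self.total = 0.0

    def add(self, message):
        """Process one raw JSON message and return a finished (minute, average) row, or None."""
        data = json.loads(message)
        # Only trade messages carry a price; skip 'subscriptions' and other message types
        if data.get("type") not in ("match", "last_match"):
            return None
        minute = minute_bin(data["time"])
        price = float(data["price"])
        finished = None
        if minute != self.current_minute:
            # A new minute has started: close out the previous one, if any
            if self.current_minute is not None:
                finished = (self.current_minute, self.total / self.count)
            self.current_minute, self.count, self.total = minute, 0, 0.0
        self.count += 1
        self.total += price
        return finished


# Usage with a few hand-written messages (not real Coinbase data)
averager = MinuteAverager()
messages = [
    '{"type": "subscriptions", "channels": []}',
    '{"type": "match", "time": "2022-03-01T12:00:05.100Z", "price": "100.0"}',
    '{"type": "match", "time": "2022-03-01T12:00:40.200Z", "price": "102.0"}',
    '{"type": "match", "time": "2022-03-01T12:01:01.000Z", "price": "105.0"}',
]
rows = [row for row in map(averager.add, messages) if row is not None]
print(rows)  # one row: the 12:00 UTC minute with an average price of 101.0
```

Like the Deephaven code, this sketch only emits a minute's average once a trade from the next minute arrives, so the minute currently being filled is never written early.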
Real-time price prediction
In the previous blog of the series, we tested both the TensorFlow and Nvidia RAPIDS models on simulated real-time feeds. If you're keeping up with the series, you'll see that the code to implement them here is remarkably similar.
The only notable difference in the code is the use of a flag that indicates if the first price has been received from Coinbase.
TensorFlow LSTM model
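The model variable in the code below is the TensorFlow LSTM model we previously implemented. To run the code below, you'll first need to run the code from that blog so that model and scaler are defined.
Upon receipt of the first price, the first_time flag gets set to False and all four model inputs are filled with that price. Each time a new price comes in after that, the inputs shift, the oldest price drops off, and the newest price gets rolled to the end of the model input. Each row's prediction is made from the inputs as they stood before that row's price was rolled in.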
# Required imports (model and scaler come from the LSTM blog in this series)
from deephaven import learn
from deephaven.learn import gather
from deephaven import numpy as dhnp
import numpy as np
# Helper function - gather table data into a 2d numpy array of doubles
def table_to_numpy(rows, cols):
    return gather.table_to_numpy_2d(rows, cols, dtype = np.double)
# Helper function - scatter the model prediction for row idx back into a table
def get_predicted_price(data, idx):
    return data[idx]
# Model parameters from a previous blog in the series
n_input = 4
n_features = 1
# Globals to keep track of the model input and bookkeeping (first_time)
first_time = True
last_four = np.zeros((1, n_input, n_features), dtype = np.double)
# Use the trained model to predict prices
def predict_with_model(data):
    global last_four, first_time
    # Predict from the inputs collected so far (all zeros before the first price)
    current_pred = model.predict(last_four)
    current_pred = scaler.inverse_transform(current_pred)
    current_pred = current_pred.reshape(1, -1)[0]
    # Refit the scaler on all live prices so far and take the newest scaled price
    scaled_live_prices = scaler.fit_transform(dhnp.to_numpy(coinbase_websocket_table.view(["price"])).reshape(-1, 1))
    value = scaled_live_prices[-1].item()
    # If this is the first price, the inputs are just the price four times
    if first_time:
        last_four = np.full((1, n_input, n_features), value, dtype = np.double)
        first_time = False
    # Shift prices left (the oldest drops off) and put the newest at the end
    last_four = np.roll(last_four, -1, axis=1)
    last_four[0][-1][0] = value
    return current_pred
# Put it all together
real_time_prediction = learn.learn(
    table = coinbase_websocket_table,
    model_func = predict_with_model,
    inputs = [learn.Input(["price"], table_to_numpy)],
    outputs = [learn.Output("Predicted_price", get_predicted_price, "double")],
    batch_size = 1
)
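The first_time and rolling bookkeeping in predict_with_model above can be sketched in plain Python to show the order of operations. In this sketch, RollingWindow and predict_then_update are hypothetical names, a simple mean stands in for model.predict, and scaling is omitted. A full collections.deque with maxlen plays the role of np.roll(last_four, -1, axis=1) followed by overwriting the last slot.

```python
from collections import deque


class RollingWindow:
    """Fixed-length model input: oldest price first, newest price last."""

    def __init__(self, size=4):
        self.size = size
        self.window = None  # None until the first price arrives (like first_time = True)

    def inputs(self):
        # Before the first price, the model sees all zeros
        if self.window is None:
            return [0.0] * self.size
        return list(self.window)

    def update(self, price):
        if self.window is None:
            # First price: fill every slot with it
            self.window = deque([price] * self.size, maxlen=self.size)
        else:
            # Appending to a full deque drops the oldest price from the left
            self.window.append(price)


def predict_then_update(window, price, predict):
    """Predict from the current inputs, then roll the new price in."""
    prediction = predict(window.inputs())
    window.update(price)
    return prediction


# Hypothetical stand-in model: the mean of its inputs
def mean_model(xs):
    return sum(xs) / len(xs)


window = RollingWindow(size=4)
for price in [10.0, 20.0, 30.0]:
    print(predict_then_update(window, price, mean_model))  # prints 0.0, then 10.0, then 12.5
```

The first prediction comes from all-zero inputs, just as in predict_with_model, because the model is called before the first price is stored.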
Nvidia RAPIDS linear regression model
The code below uses the fitted Nvidia RAPIDS linear regression model previously implemented. To run it, first run the code from that blog so that linear_regression_gpu is defined.
This code is also remarkably similar to that of the previous blog. Just like with the LSTM above, we roll the newest price into the model input each time Coinbase delivers a new one.
# Required imports (linear_regression_gpu comes from the RAPIDS blog in this series)
from deephaven import learn
from deephaven.learn import gather
import numpy as np
# Helper function - gather table data into a 2d numpy array of doubles
def table_to_numpy(rows, cols):
    return gather.table_to_numpy_2d(rows, cols, dtype = np.double)
# Helper function - scatter the model prediction for row idx back into a table
def get_predicted_price(data, idx):
    return data[idx]
# Globals for most recent prices and a bookkeeping variable
last_three = np.array([[0, 0, 0]], dtype = np.double)
first_time = True
# Use the fitted model to predict prices
def use_fitted_model(data):
    global last_three, first_time
    value = data[0][0]
    # Upon receipt of the first price, initialize the last_three variable
    if first_time:
        first_time = False
        last_three = np.array([[value, value, value]], dtype = np.double)
    # Circularly shift values right, then replace the oldest (now first) with the current price
    last_three = np.roll(last_three, 1, axis=1)
    last_three[0, 0] = value
    predictions = linear_regression_gpu.predict(last_three)
    return predictions
# Put it all together
Predict_table = learn.learn(
    table = coinbase_websocket_table,
    model_func = use_fitted_model,
    inputs = [learn.Input(["price"], table_to_numpy)],
    outputs = [learn.Output("Predicted_Price", get_predicted_price, "double")],
    batch_size = 1
)
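The LSTM and RAPIDS code keep their inputs in opposite orders: np.roll(last_four, -1, axis=1) puts the newest price last, while np.roll(last_three, 1, axis=1) followed by last_three[0, 0] = value puts it first. A plain-Python sketch of both updates makes the difference concrete; the function names are hypothetical. Whichever order you use presumably needs to match the feature order the model was trained on in the earlier blogs.

```python
def roll_oldest_first(window, value):
    """Like np.roll(x, -1) then x[-1] = value: drop the oldest (first) price, append the newest."""
    return window[1:] + [value]


def roll_newest_first(window, value):
    """Like np.roll(x, 1) then x[0] = value: drop the oldest (last) price, prepend the newest."""
    return [value] + window[:-1]


# Start the way use_fitted_model does after the first price: every slot holds it
window = [100.0] * 3
for price in [101.0, 102.0]:
    window = roll_newest_first(window, price)
print(window)  # [102.0, 101.0, 100.0]

# The LSTM-style update keeps the same prices in the opposite order
window = [100.0] * 3
for price in [101.0, 102.0]:
    window = roll_oldest_first(window, price)
print(window)  # [100.0, 101.0, 102.0]
```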
Plot the results
I like to watch my model work in real time. It lets me see how it performs and make informed decisions based on its behavior.
from deephaven.plot.figure import Figure
# Plot the live BTC price and the LSTM's predicted price against time
rt_plot = Figure()\
    .plot_xy(series_name="price", t=real_time_prediction, x="time", y="price")\
    .plot_xy(series_name="Predicted_price", t=real_time_prediction, x="time", y="Predicted_price")\
    .show()
Here's a screenshot of what the plot looks like after the first seven predictions.
Try this out for yourself! Both models presented in this series can be modified to suit your needs. In the next and final blog of the series, we'll share our AI predictions with URIs.
Reach out on Slack if you have any questions or feedback for us.