How to Build Your First Recommendation System (Easy)
A step-by-step guide to training and serving a collaborative filtering model that recommends content to users
While generative AI has sent discussions about AI's impact skyrocketing, I'd argue recommendation systems are the AI most people should be concerned about. They've been around for over a decade; they choose what content people consume and what ideas they see, and they even influence how people think.
Software engineers should understand recommendation systems because nearly every company that serves content to users runs a system similar to this one. Collaborative filtering, the technique we'll build here, is simple, intuitive, and very effective.
This article is a follow-on to a previous case study detailing how collaborative filtering is used by Spotify and how a system like this has impacted the music industry. If you haven't read that article, read it first. It puts everything below into context.
Spotify ML Case Study: AI Has Fundamentally Changed the Music Industry
That case study is part one of this series. It details how Spotify's recommendation system works and the real-world impact it has had, both intended and unintended. This article is the next part, where we build a simple recommendation system similar to Spotify's.
When you're finished with this article, you'll have trained your own collaborative filtering model using matrix factorization and be able to explore it in a UI that shows how interacting with content and retraining the model change recommendations over time.
Housekeeping & Things You Should Know
The complete code for building this collaborative filtering system can be found in the ML for SWEs GitHub repo. Please star it to support ML for SWEs and stay updated when new tutorials are added. This is the first of many ML system tutorials I'll be putting out.
Don't forget about our Machine Learning Roadmap. It's a guide to ML fundamentals that can be completed entirely for free. I spent some time in 2024 curating it and can confidently say it's the best free ML roadmap available.
I'm going to start including this Housekeeping & Things You Should Know section in each article and make each article focus on one topic. I felt the roundups were too shallow, and I wasn't having fun or learning enough spending my time on them. Instead, each article will include a short roundup section.
ML for SWEs is looking for sponsors! If you have a job opportunity, developer tool, or want to share anything else that would be beneficial for software engineers working in AI, reach out to me to get it in front of over 10,000 developers. I reserve the right to deny anything I don't think is helpful. There's a high bar for what I share with my audience to ensure it's a good fit for both readers and sponsors.
I'll be posting more frequent job updates, who's-hiring posts, and the skills you should acquire for paid subscribers in the ML for SWEs Substack chat. Upgrade to paid if you want those.
Part of Machine Learning for Software Engineers is keeping you abreast of the happenings in AI that are actually important. Here are the most important items since our last article:
OpenAI and Amazon announced a multi-year, $38B partnership for AWS to provide large-scale compute infrastructure, including NVIDIA GB200s and GB300s.
Apple is reportedly nearing a deal to pay Google ~$1B annually to use a custom 1.2T parameter Gemini model to power a major Siri update.
OpenAI announced it now has over 1 million paying business customers and 7 million ChatGPT for Work seats.
Moonshot AI's Kimi K2 Thinking is a new 1T parameter Mixture-of-Experts model (32B active) that uses native INT4 inference for a ~2x speedup.
NVIDIA reports achieving 4x faster inference for math problem solving using FP8 quantization and kernel optimizations.
Researchers propose Continuous Autoregressive Language Models (CALM), which compress tokens into continuous vectors to cut training FLOPs by 44%.
Terminal-Bench 2.0 was released alongside Harbor, a new framework for testing AI agents in containerized developer environments.
OpenAI published a post on understanding prompt injections, which it calls a frontier security challenge requiring multi-layered defenses.
Wikipedia is urging AI companies to stop scraping and use its paid Enterprise API to support the nonprofit's servers and mission.
A new report details cybersecurity in the era of AI and quantum, highlighting threats from AI-automated attacks and quantum decryption.
A Stanford study found that 22- to 25-year-olds in AI-exposed roles, like software development, have seen a 13% decline in employment since ChatGPT's launch.
Platforms like Anthropic's Claude Code are pushing a shift toward agentic coding, where developers orchestrate agent fleets rather than coding line-by-line.
An article explains how models like Qwen3-Next and Kimi Linear are using hybrid attention mechanisms to achieve O(n) scaling for long contexts.
OpenAI is offering a free year of ChatGPT Plus to transitioning U.S. servicemembers and veterans.
If you're particularly interested in one of these things and would like a deep dive, leave a comment and I'll see what I can do.
Now onto our collaborative filtering system!
Step 1: Retrieve the Data
For this project, we'll use the HetRec 2011 Last.fm 2k dataset to train our model initially and then to retrain it on simulated interactions between users and artists. HetRec 2011 Last.fm 2k is a great example of implicit feedback, which makes it a great fit for a recommendation system: instead of explicit ratings, it records how many times each user listened to each artist. Each user-artist pair has a weight (the listen count) from which we infer preference; a higher listen count suggests a user likes an artist more.
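To make the shape of this data concrete, here is a tiny, made-up example in the same userID / artistID / weight layout (the numbers are invented for illustration, not taken from the dataset). Pivoting it into a user-by-artist matrix shows roughly what collaborative filtering works with: most cells are empty because each user has only listened to a few artists.

```python
import pandas as pd

# Made-up interactions in the same layout as user_artists.dat (not real data)
interactions = pd.DataFrame({
    "userID":   [2, 2, 3, 3, 4],
    "artistID": [51, 52, 51, 53, 52],
    "weight":   [13883, 11690, 212, 481, 1017],  # listen counts = implicit feedback
})

# Rows are users, columns are artists; NaN means "never listened" (unknown, not a dislike)
user_item = interactions.pivot(index="userID", columns="artistID", values="weight")
print(user_item)

# Fraction of cells with no interaction at all
sparsity = user_item.isna().to_numpy().mean()
print(f"Sparsity: {sparsity:.2f}")
```

A missing cell means the user never played that artist, which is different from an explicit low rating. That distinction is what makes this implicit feedback.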
First, create a file named last_fm_loader.py. It will load our dataset and prepare it for training. Put our imports at the top of the file; we'll see how each one is used below.
```python
import requests
import zipfile
import io
import pandas as pd
import os
```
Define a class LastFmLoader to encapsulate all the data-loading logic. In that class, create two class attributes: one holding the URL for downloading our dataset and one naming the directory we'll store the dataset in. The __init__ method initializes placeholders for our dataframes and defines the file paths we expect to find inside the extracted zip.
```python
class LastFmLoader:
    _ZIP_URL = "https://files.grouplens.org/datasets/hetrec2011/hetrec2011-lastfm-2k.zip"
    _DATA_DIR = "lastfm-2k"

    def __init__(self):
        # Populated by load_data()
        self.interactions = None
        self.artists = None
        self._interactions_file = os.path.join(self._DATA_DIR, "user_artists.dat")
        self._artists_file = os.path.join(self._DATA_DIR, "artists.dat")
```
Add a private method _download_data to this class. It downloads the zip file from the URL and extracts its contents into _DATA_DIR, skipping the whole process if that folder already exists. The print statements are niceties for debugging.
```python
    def _download_data(self):
        # Skip the download if the data directory is already present
        if os.path.exists(self._DATA_DIR):
            print(f"Directory {self._DATA_DIR} already exists. Skipping download.")
            return
        print(f"Downloading data from {self._ZIP_URL}...")
        try:
            response = requests.get(self._ZIP_URL, timeout=60)
            response.raise_for_status()
            # Create the directory only after a successful download, so a failed
            # download doesn't leave an empty folder that makes future runs skip it
            os.makedirs(self._DATA_DIR, exist_ok=True)
            print("Extracting data...")
            with zipfile.ZipFile(io.BytesIO(response.content)) as z:
                z.extractall(self._DATA_DIR)
            print("Download and extraction complete.")
        except requests.exceptions.RequestException as e:
            print(f"Error downloading file: {e}")
            raise
        except zipfile.BadZipFile as e:
            print(f"Error extracting file: {e}")
            raise
        except Exception as e:
            print(f"An error occurred during download/extraction: {e}")
            raise
```
Add a public load_data method. This is the function we'll call from our training script. It runs _download_data to ensure the data is present, then uses pandas.read_csv to load the two files we need to train our model: user_artists.dat (columns userID, artistID, weight) and artists.dat (from which we keep only the id and name columns). Both files are tab-separated.
```python
    def load_data(self):
        self._download_data()
        try:
            print("Loading interactions data...")
            # Tab-separated columns: userID, artistID, weight
            self.interactions = pd.read_csv(
                self._interactions_file,
                sep="\t",
                header=0,
                encoding="utf-8"
            )
            print("Loading artists data...")
            # artists.dat also has url and pictureURL columns; we only need id and name
            self.artists = pd.read_csv(
                self._artists_file,
                sep="\t",
                header=0,
                encoding="utf-8",
                usecols=["id", "name"]
            )
            print("Data loading complete.")
        except FileNotFoundError as e:
            print(f"Error loading data: {e}")
            raise
        except Exception as e:
            print(f"An error occurred during data loading: {e}")
            raise
```
Lastly, add a test block at the end of last_fm_loader.py. When you run python last_fm_loader.py directly, it prints the first few rows of each dataframe so you can see the columns present in our data. We won't run training or serving from this file, but it's a handy way to test its functionality.
```python
if __name__ == "__main__":
    loader = LastFmLoader()
    loader.load_data()
    if loader.interactions is not None:
        print(loader.interactions.head())
    if loader.artists is not None:
        print(loader.artists.head())
```
Step 2: Define the Model
Create model.py. This file defines our MatrixFactorization class. Start with the imports from torch.
```python
import torch
import torch.nn as nn
```
Define the MatrixFactorization class, inheriting from torch.nn.Module. The __init__ method sets up our learnable parameters: the two embedding matrices our model will learn. nn.Embedding is a PyTorch layer that acts as a lookup table: given an integer index, it returns the corresponding row of its weight matrix. self.user_embedding stores the learned user representations, and self.artist_embedding does the same for artists.
embedding_dim is the length of each of those vectors. In __init__, we also initialize both embedding matrices with small random values drawn uniformly between 0 and 0.05.
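A quick way to convince yourself that nn.Embedding really is a lookup table: calling it on a tensor of IDs returns exactly the corresponding rows of its weight matrix. The sizes below are arbitrary and only for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
emb = nn.Embedding(num_embeddings=5, embedding_dim=3)  # 5 rows, each a 3-dim vector

ids = torch.LongTensor([4, 0, 4])
looked_up = emb(ids)  # one row per ID; repeated IDs return the same row

# Identical to indexing the weight matrix directly
print(torch.equal(looked_up, emb.weight[ids]))  # True
print(looked_up.shape)                          # torch.Size([3, 3])
```

Training simply nudges the values in emb.weight, so each user or artist keeps its own learnable row.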
```python
class MatrixFactorization(nn.Module):
    def __init__(self, num_users, num_artists, embedding_dim=500):
        super(MatrixFactorization, self).__init__()
        # One learnable vector per user and one per artist
        self.user_embedding = nn.Embedding(num_users, embedding_dim)
        self.artist_embedding = nn.Embedding(num_artists, embedding_dim)
        # Start from small random values in [0, 0.05)
        self.user_embedding.weight.data.uniform_(0, 0.05)
        self.artist_embedding.weight.data.uniform_(0, 0.05)
```
Now we define the forward method for the class. This is what PyTorch runs when the model is called. It takes a batch of user indices and artist indices, looks up their embedding vectors, and computes the dot product for each user-artist pair, as described in our overview of collaborative filtering systems. Multiplying the two batches element-wise and then calling .sum(dim=1) sums over the embedding dimension, which yields one dot product per row: a batched dot product. The resulting "score" is our model's prediction of how much the user likes the artist.
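To see why (user_vector * artist_vector).sum(dim=1) is a batched dot product, compare it with computing each pair's dot product one at a time using torch.dot. The random tensors here are small stand-ins for a batch of looked-up embeddings.

```python
import torch

torch.manual_seed(0)
batch, dim = 4, 3
user_vectors = torch.randn(batch, dim)    # one row per user in the batch
artist_vectors = torch.randn(batch, dim)  # one row per artist in the batch

# Vectorized: multiply element-wise, then sum across the embedding dimension
batched = (user_vectors * artist_vectors).sum(dim=1)

# Reference: one dot product per (user, artist) row pair
looped = torch.stack([torch.dot(u, a) for u, a in zip(user_vectors, artist_vectors)])

print(batched.shape)                    # torch.Size([4])
print(torch.allclose(batched, looped))  # True
```

The vectorized form does the same arithmetic without a Python loop, which is why it's the idiomatic choice inside forward.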
```python
    def forward(self, user, artist):
        user_vector = self.user_embedding(user)        # (batch, embedding_dim)
        artist_vector = self.artist_embedding(artist)  # (batch, embedding_dim)
        # Element-wise product summed over the embedding dimension = one dot product per pair
        score = (user_vector * artist_vector).sum(dim=1)  # (batch,)
        return score
```
Again, we add a test block at the end of model.py. Running python model.py is a quick way to make sure the model's input and output shapes are correct.
```python
if __name__ == "__main__":
    print("Testing model.py")
    test_num_users = 100
    test_num_artists = 50
    test_emb_size = 10
    model = MatrixFactorization(test_num_users, test_num_artists, test_emb_size)
    print("Model created.")
    test_user_ids = torch.LongTensor([1, 5, 20, 99])
    test_artist_ids = torch.LongTensor([4, 10, 30, 45])
    predictions = model(test_user_ids, test_artist_ids)
    print(f"\nInput user tensor shape: {test_user_ids.shape}")
    print(f"Input artist tensor shape: {test_artist_ids.shape}")
    print(f"Output predictions shape: {predictions.shape}")
    assert predictions.shape == (4,)
    print("\nModel test passed!")
    print("Example predictions (randomly initialized):")
    print(predictions)
```
Step 3: Train the Model
Create your third file, train.py. This script will use the LastFmLoader and MatrixFactorization classes to train and save our model.
Start with all the necessary imports. Notice that we're importing our custom classes here.
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from sklearn.model_selection import train_test_split
import numpy as np
import os
from last_fm_loader import LastFmLoader
from model import MatrixFactorization
```
We define a custom LastFmDataset class, which DataLoader will use to retrieve our training data. __init__ takes our data (as NumPy arrays) and stores it as tensors. __len__ returns the total number of samples, and __getitem__ returns a single sample (one user, artist, and weight) at a given index. We'll use these further down.
```python
class LastFmDataset(Dataset):
    def __init__(self, users, artists, weights):
        # Embedding lookups need integer (long) indices; targets are floats
        self.users = torch.LongTensor(users)
        self.artists = torch.LongTensor(artists)
        self.weights = torch.FloatTensor(weights)

    def __len__(self):
        return len(self.weights)

    def __getitem__(self, idx):
        return self.users[idx], self.artists[idx], self.weights[idx]
```
When training a model, it's possible to overfit: the model "memorizes" the training data and gets worse at handling new, unseen data. A telltale sign is that training loss keeps going down while validation loss stops improving or starts to rise.
Early stopping is a technique to guard against this. We monitor the validation loss after each epoch. If it fails to improve for a set number of epochs (our patience), we stop training, since further epochs are more likely to make the model worse on unseen data than better.
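Here is the patience rule stripped down to plain Python and applied to a made-up sequence of validation losses. The helper stopping_epoch is hypothetical (it isn't part of train.py). It's a sketch of the rule described here: an epoch counts as an improvement only if its loss beats the best loss so far by at least delta, and training stops after patience consecutive epochs without one.

```python
def stopping_epoch(val_losses, patience=3, delta=0.0):
    """Hypothetical helper: return the 1-based epoch where training would stop, or None."""
    best = None
    counter = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if best is None or loss <= best - delta:
            best = loss   # improvement: remember it and reset the counter
            counter = 0
        else:
            counter += 1  # no (sufficient) improvement this epoch
            if counter >= patience:
                return epoch
    return None


# Made-up curve: improves for 3 epochs, then rises for 3 epochs in a row
losses = [1.00, 0.80, 0.70, 0.71, 0.72, 0.74, 0.69]
print(stopping_epoch(losses, patience=3))  # 6
```

With patience=3 the sketch stops at epoch 6 and never sees the lower loss at epoch 7. A larger patience tolerates noisy plateaus at the cost of extra epochs.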
We'll build this logic into a class. The __init__ method sets up our tracking parameters:
patience: How many epochs to wait for an improvement before stopping.
delta: The minimum amount the loss must improve by to count as an "improvement".
path: Where to save the best model's weights.
The other variables (counter, best_score, etc.) are for internal tracking.
```python
class EarlyStopping:
    def __init__(self, patience=5, verbose=False, delta=0, path="checkpoint.pt"):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
```
The __call__ method makes the class instance callable like a function. We'll call it at the end of each epoch, passing in the current val_loss.
It negates the loss so that a higher score means a better model, then checks whether this is the best score it has seen.
If not, it increments counter. Once counter reaches patience, it sets the early_stop flag to True.
If the score is better, it calls save_checkpoint and resets the counter.
```python
    def __call__(self, val_loss, model):
        # Negate the loss so that a higher score is better
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            # Not enough improvement this epoch
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
```
The save_checkpoint method is a helper called by __call__, triggered only when a new best validation loss is found. It saves the model's current weights to the specified path, so when training stops, the file at path contains the weights from the best-performing epoch.
```python
    def save_checkpoint(self, val_loss, model):
        if self.verbose:
            print(f"Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...")
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss
```
PyTorch nn.Embedding layers expect contiguous integer indices from 0 to num_embeddings - 1, but the user and artist IDs in our data aren't contiguous. So we write a helper function that creates, for both users and artists, a dictionary mapping each original ID to a sequential index and an inverse mapping to go back.
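Two small checks illustrate the point, using made-up IDs. First, an nn.Embedding sized for three rows rejects a raw ID like 1002 with an IndexError. Second, pandas.factorize is another way to build the same contiguous, first-appearance-order indices; it's shown only for comparison, since the helper here uses plain dictionaries.

```python
import pandas as pd
import torch
import torch.nn as nn

raw_ids = pd.Series([1002, 7, 1002, 31])  # made-up, non-contiguous user IDs

# An embedding table with 3 rows only accepts indices 0, 1, 2
emb = nn.Embedding(num_embeddings=3, embedding_dim=2)
rejected = False
try:
    emb(torch.LongTensor([1002]))
except IndexError as e:
    rejected = True
    print("Raw ID rejected:", e)

# Dictionary mapping in the style of create_id_mapping (first-appearance order)
id_to_idx = {orig: i for i, orig in enumerate(raw_ids.unique())}
indices = raw_ids.map(id_to_idx)
print(indices.tolist())  # [0, 1, 0, 2]

# pandas.factorize produces the same codes in one call
codes, uniques = pd.factorize(raw_ids)
print(codes.tolist(), uniques.tolist())  # [0, 1, 0, 2] [1002, 7, 31]

# Mapped indices are valid embedding lookups
print(emb(torch.tensor(codes)).shape)  # torch.Size([4, 2])
```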
```python
def create_id_mapping(df):
    # Original ID -> contiguous index (0..n-1), in order of first appearance
    user_id_mapping = {original_id: i for i, original_id in enumerate(df["userID"].unique())}
    artist_id_mapping = {original_id: i for i, original_id in enumerate(df["artistID"].unique())}
    # Contiguous index -> original ID
    user_inv_map = {i: original_id for original_id, i in user_id_mapping.items()}
    artist_inv_map = {i: original_id for original_id, i in artist_id_mapping.items()}
    return user_id_mapping, artist_id_mapping, user_inv_map, artist_inv_map
```
Now we define the main train_model function. This first part sets the hyperparameters, creates the model_store directory, and loads our data using LastFmLoader. In a production system, we would run experiments to tune the hyperparameters. That can be a lengthy process, so we're sticking with reasonable guesses and pushing forward.
```python
def train_model(epochs=20, batch_size=1024, emb_size=50, learning_rate=0.001, model_save_path="model_store/model.pt"):
    os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
    loader = LastFmLoader()
    loader.load_data()
    df = loader.interactions
    if df is None:
        print("Failed to load data.")
        return
```
Still inside train_model, we preprocess our data. First, we call create_id_mapping to get our dictionaries. Then we use .map() to replace the original userID and artistID values with their new sequential indices.
```python
    print("Creating ID mappings...")
    user_id_mapping, artist_id_mapping, user_inv_map, artist_inv_map = create_id_mapping(df)
    df["userID"] = df["userID"].map(user_id_mapping)
    df["artistID"] = df["artistID"].map(artist_id_mapping)
```
Next, we apply np.log1p to the weight column. This log(1 + x) transform compresses large listen counts, so a user who listened to an artist 100,000 times doesn't dominate the loss function. We also get the number of unique users and artists for our model.
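A quick numeric check shows how strongly log1p compresses the range. The listen counts below are arbitrary examples: raw values spanning five orders of magnitude end up between 0 and about 11.5, and a 1,000x gap in raw counts shrinks to a difference of about 6.9.

```python
import numpy as np

counts = np.array([0, 1, 100, 100_000])  # example listen counts
logged = np.log1p(counts)                # log(1 + x); log1p(0) is exactly 0
print(np.round(logged, 3))               # roughly 0, 0.693, 4.615, 11.513

# Gap between 100,000 and 100 listens after the transform
gap = logged[3] - logged[2]
print(round(float(gap), 2))
```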
```python
    df["weight_log"] = np.log1p(df["weight"])
    num_users = len(user_id_mapping)
    num_artists = len(artist_id_mapping)
    print(f"Number of users: {num_users}")
    print(f"Number of artists: {num_artists}")
```
We split our data into an 80% training set and a 20% validation set using train_test_split.
```python
    train_df, valid_df = train_test_split(df, test_size=0.2, random_state=42)
```
We create LastFmDataset instances for both the training and validation dataframes.
```python
    train_dataset = LastFmDataset(train_df["userID"].values, train_df["artistID"].values, train_df["weight_log"].values)
    valid_dataset = LastFmDataset(valid_df["userID"].values, valid_df["artistID"].values, valid_df["weight_log"].values)
```
Then we wrap our Dataset instances in DataLoader, a PyTorch utility that handles batching, shuffling, and multi-process data loading for us.
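To see what a DataLoader hands to the training loop, here is a tiny made-up dataset of 10 interactions batched 4 at a time. It uses TensorDataset as a stand-in for LastFmDataset so the snippet runs on its own; both yield (user, artist, weight) tuples.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 10 made-up (user, artist, weight) samples
users = torch.arange(10)
artists = torch.arange(10) % 3
weights = torch.rand(10)
dataset = TensorDataset(users, artists, weights)

loader = DataLoader(dataset, batch_size=4, shuffle=True)
print(len(loader))  # 3 batches: 4 + 4 + 2

batch_sizes = []
for user, artist, weight in loader:
    # Each batch is three aligned tensors, unpacked the same way train_model does
    batch_sizes.append(user.shape[0])
print(batch_sizes)  # [4, 4, 2]
```

len(loader) counts batches, not samples, so dividing a summed per-batch loss by len(train_loader) gives an average over batches.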
```python
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
```
We initialize our MatrixFactorization model with the number of users and artists.
```python
    print("Initializing model...")
    model = MatrixFactorization(num_users, num_artists, embedding_dim=emb_size)
```
We check for a GPU (CUDA for NVIDIA, MPS for Apple silicon) and, if one is available, move the model there for faster training. The .to(device) call moves all of the model's parameters (the embedding matrices) into the GPU's memory. The model and its input data must be on the same device.
```python
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print(f"Using device: {device}")
    model.to(device)
```
Then we define our loss function. Mean squared error (MSE) is a standard regression loss: it averages the squared differences between the model's predictions and the actual weight_log values. Because errors are squared, large errors are penalized much more heavily than small ones, which works well for this kind of system.
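As a quick check of what nn.MSELoss computes, the snippet below compares it with the mean of squared differences written out by hand, using made-up predictions and targets. The single large error dominates the result because it is squared.

```python
import torch
import torch.nn as nn

predictions = torch.tensor([2.0, 3.0, 5.0])
targets = torch.tensor([2.5, 3.0, 1.0])  # the last prediction is off by 4

loss = nn.MSELoss()(predictions, targets)
manual = ((predictions - targets) ** 2).mean()

# Both are (0.25 + 0 + 16) / 3, about 5.417
print(loss.item(), manual.item())
```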
For the optimizer, we choose optim.Adam (Adaptive Moment Estimation). Adam is a popular optimizer that works well out of the box for many problems. It combines momentum with a learning rate that adapts to each model parameter individually, which often leads to faster convergence than plain stochastic gradient descent (SGD).
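We also pass Adam a small weight_decay, which adds an L2 penalty that discourages overly large embedding values, and initialize our EarlyStopping class with a patience of 3, telling it to save the best model to model_save_path.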
```python
    loss_fn = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)
    early_stopper = EarlyStopping(patience=3, verbose=True, path=model_save_path)
```
This is the core of the training. We loop epochs times, first putting the model in training mode with model.train().
```python
    print("Training model...")
    for epoch in range(epochs):
        model.train()
        total_train_loss = 0.0
```
Inside the epoch loop, we iterate over every batch in train_loader and move each batch to our device. This is the second half of the device equation: the model lives on the GPU, so every batch of data we feed it must be moved there too. The user.to(device), artist.to(device), and weight.to(device) calls do that.
```python
        for user, artist, weight in train_loader:
            user, artist, weight = user.to(device), artist.to(device), weight.to(device)
```
For each batch, we then follow the standard five-step PyTorch training process:
1. optimizer.zero_grad(): Clear old gradients.
2. prediction = model(...): Get the model's prediction.
3. loss = loss_fn(...): Calculate the loss.
4. loss.backward(): Compute new gradients.
5. optimizer.step(): Update the model's weights.
We also accumulate the total training loss as we go.
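The zero_grad() step matters because PyTorch accumulates gradients: each backward() call adds to .grad rather than overwriting it. This toy example with a single parameter (the name toy_loss is just for illustration) shows two backward passes without clearing doubling the gradient, and zero_grad() resetting it.

```python
import torch

w = torch.nn.Parameter(torch.tensor(3.0))
optimizer = torch.optim.SGD([w], lr=0.1)

def toy_loss():
    return w ** 2  # d/dw (w^2) = 2w = 6 at w = 3

toy_loss().backward()
first = w.grad.item()   # 6.0

toy_loss().backward()   # no zero_grad: the new gradient is added to the old one
second = w.grad.item()  # 12.0

optimizer.zero_grad()   # recent PyTorch versions reset grads to None by default
print(first, second, w.grad)
```

Without step 1, every batch would update the weights with the sum of all previous batches' gradients.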
```python
            optimizer.zero_grad()
            prediction = model(user, artist)
            loss = loss_fn(prediction, weight)
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()
```
After training on all batches, we switch the model to evaluation mode with model.eval() and wrap validation in with torch.no_grad(), which turns off gradient tracking.
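torch.no_grad() is what makes validation cheaper: outputs computed inside it carry no autograd history, so no graph is built for a backward pass. The small linear layer below is only a stand-in for our model. model.eval() is a separate switch that changes the behavior of layers like dropout and batch norm; our embedding-only model has neither, so here it's mostly good practice.

```python
import torch
import torch.nn as nn

model = nn.Linear(3, 1)  # stand-in for any nn.Module
x = torch.randn(2, 3)

out_train = model(x)
print(out_train.requires_grad)  # True: autograd is tracking this computation

model.eval()
with torch.no_grad():
    out_eval = model(x)
print(out_eval.requires_grad)   # False: no graph was built
print(model.training)           # False after model.eval()
```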
```python
        model.eval()
        total_val_loss = 0.0
        with torch.no_grad():
```
Inside that block, we loop over valid_loader to get predictions and accumulate the total validation loss.
```python
            for users, artists, weights in valid_loader:
                users, artists, weights = users.to(device), artists.to(device), weights.to(device)
                predictions = model(users, artists)
                val_loss = loss_fn(predictions, weights)
                total_val_loss += val_loss.item()
```
At the end of each epoch, we calculate and print the average training and validation losses.
```python
        avg_train_loss = total_train_loss / len(train_loader)
        avg_val_loss = total_val_loss / len(valid_loader)
        print(f"Epoch {epoch+1}/{epochs} - Train Loss: {avg_train_loss:.4f} - Val Loss: {avg_val_loss:.4f}")
```
Finally, we pass the validation loss to early_stopper. It runs its internal logic, and if it has set the early_stop flag to True, we break out of the training loop.
```python
        early_stopper(avg_val_loss, model)
        if early_stopper.early_stop:
            print("Early stopping triggered.")
            break
```
After the loop, the file at model_save_path holds the best version of our model, thanks to our EarlyStopping class. We must also save our ID mappings. Without them, we have no way to connect, say, userID 1002 to model index 5.
```python
    print(f"\nTraining complete. Best model saved to {model_save_path}")
    # Save the mappings next to the model so app.py can translate IDs <-> indices
    mapping_path = os.path.join(os.path.dirname(model_save_path), "mappings.pth")
    torch.save({
        "user_id_mapping": user_id_mapping,
        "artist_id_mapping": artist_id_mapping,
        "user_inv_map": user_inv_map,
        "artist_inv_map": artist_inv_map
    }, mapping_path)
    print(f"Mappings saved to {mapping_path}")
```
Finally, add the if __name__ == "__main__": block to train.py so we can run it as a script with python train.py.
```python
if __name__ == "__main__":
    train_model()
```
You should now be able to run the full training loop. It will download the data, train the model, and save model.pt and mappings.pth in the model_store directory.
Step 4: Serve the Recommendations
Create the final file, app.py. Weâll use Streamlit to build a simple web UI.
Import Streamlit, PyTorch, pandas, NumPy, and our custom classes. We also import LastFmDataset from train.py because we'll need it for retraining. Finally, we define constants for our saved file paths and the number of listens to simulate.
```python
import streamlit as st
import torch
import pandas as pd
import os
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from model import MatrixFactorization
from last_fm_loader import LastFmLoader
from train import LastFmDataset

MODEL_PATH = os.path.join("model_store", "model.pt")
MAPPINGS_PATH = os.path.join("model_store", "mappings.pth")
SIMULATIONS = 5000
```
Create a function load_assets to load our model, mappings, and artist data. We use Streamlit's @st.cache_resource decorator, which tells Streamlit to run this function once and cache the result, so the app stays fast and doesn't reload the model on every interaction.
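Inside load_assets, we load the mappings. Passing weights_only=False to torch.load matters here. Since PyTorch 2.6, torch.load defaults to weights_only=True, which only unpickles tensors and a small set of safe built-in types. Our mappings file is a pickled dictionary whose IDs are NumPy integers (they come from df.unique()), which that restricted loader may reject. Only use weights_only=False on files you created yourself, because unpickling an untrusted file can execute arbitrary code.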
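If you'd rather keep torch.load's safer weights_only=True setting, one option (an alternative to what this app does, not the approach used here) is to convert the NumPy integer IDs to built-in Python ints before saving, because the restricted loader accepts dictionaries of built-in types. A minimal sketch using a temporary file and a made-up mapping:

```python
import os
import tempfile

import numpy as np
import torch

# Made-up mapping with NumPy integer keys, like the ones create_id_mapping produces
user_id_mapping = {np.int64(1002): 0, np.int64(7): 1}

# Convert keys and values to built-in ints so the restricted unpickler accepts them
plain_mapping = {int(k): int(v) for k, v in user_id_mapping.items()}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "mappings.pth")
    torch.save({"user_id_mapping": plain_mapping}, path)
    loaded = torch.load(path, weights_only=True)

print(loaded["user_id_mapping"])  # {1002: 0, 7: 1}
```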
```python
@st.cache_resource
def load_assets():
    try:
        # A pickled dict (with NumPy integer IDs), not a state_dict
        mappings = torch.load(MAPPINGS_PATH, weights_only=False)
        user_map = mappings["user_id_mapping"]
        artist_map = mappings["artist_id_mapping"]
        num_users = len(user_map)
        num_artists = len(artist_map)
```
Now we initialize a new MatrixFactorization model instance and load the saved weights from our model.pt file.
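```python
        # embedding_dim must match the emb_size used in train.py
        model = MatrixFactorization(num_users, num_artists, embedding_dim=50)
        # map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine
        model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
        model.eval()
```
Finally, we load the artist names with our LastFmLoader so we can display them later, and return all the loaded assets.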
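```python
        loader = LastFmLoader()
        loader.load_data()
        artists_df = loader.artists.set_index("id")
        return model, mappings, artists_df
    except FileNotFoundError as e:
        # Most likely train.py hasn't been run yet; show the error in the app itself
        st.error(f"Error loading assets: {e}")
        st.stop()
    except Exception as e:
        st.error(f"An error occurred during asset loading: {e}")
        st.stop()
```
Now we create the get_recommendations function, the core of our app's logic. We use @st.cache_data to cache the results for each user. Streamlit doesn't hash arguments whose names start with an underscore (_model, _mappings, _artists_df), so the cache is keyed on selected_user_id and num_recs only.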
It maps selected_user_id to its user_idx, then gets the user_vector for that user from the user embedding layer, along with all artist vectors from the artist embedding layer.
```python
@st.cache_data(show_spinner="Generating recommendations...")
def get_recommendations(selected_user_id, _model, _mappings, _artists_df, num_recs=10):
    # .get() returns None (instead of raising KeyError) for an unknown user
    user_idx = _mappings["user_id_mapping"].get(selected_user_id)
    if user_idx is None:
        st.error(f"User ID {selected_user_id} not found in the mapping.")
        return pd.DataFrame(columns=["Artist", "Score"])
    user_tensor = torch.LongTensor([user_idx])
    user_vector = _model.user_embedding(user_tensor)  # (1, embedding_dim)
    all_artist_vectors = _model.artist_embedding.weight  # (num_artists, embedding_dim)
```
We then perform a single matrix multiplication between the user vector and the transposed matrix of all artist vectors. This gives us the predicted score for every artist for this user at once.
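The single matrix multiplication gives the same numbers as scoring each artist separately with a dot product, as the model's forward method does. A small check with random stand-in embeddings (4 artists, 3 dimensions):

```python
import torch

torch.manual_seed(0)
user_vector = torch.randn(1, 3)         # (1, embedding_dim), like user_embedding(user_tensor)
all_artist_vectors = torch.randn(4, 3)  # (num_artists, embedding_dim), like artist_embedding.weight

# One matmul scores every artist: (1, 3) @ (3, 4) -> (1, 4), squeezed to (4,)
scores = torch.matmul(user_vector, all_artist_vectors.T).squeeze()

# Same as one dot product per artist
per_artist = torch.stack([torch.dot(user_vector[0], a) for a in all_artist_vectors])

print(scores.shape)                        # torch.Size([4])
print(torch.allclose(scores, per_artist))  # True
```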
```python
    with torch.no_grad():
        # (1, dim) @ (dim, num_artists) -> (1, num_artists), squeezed to (num_artists,)
        scores = torch.matmul(user_vector, all_artist_vectors.T).squeeze()
```
We sort the scores with torch.argsort to get the top N, then loop over them, mapping each artist_model_idx back to the original artist ID and then to the artist's name.
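As an aside, torch.topk returns the k largest scores and their indices in one call, and for distinct scores it picks the same artists as argsort(..., descending=True)[:k]. This is an optional alternative, not what the app uses; the scores below are made up.

```python
import torch

scores = torch.tensor([0.1, 0.9, 0.4, 0.7, 0.2])  # made-up predicted scores for 5 artists
k = 3

top_by_sort = torch.argsort(scores, descending=True)[:k]
values, top_by_topk = torch.topk(scores, k)  # largest first by default

print(top_by_sort.tolist())  # [1, 3, 2]
print(top_by_topk.tolist())  # [1, 3, 2]
print(values.tolist())
```

topk avoids fully sorting every artist's score when you only need the top few, which can matter as the catalog grows.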
```python
    top_indices = torch.argsort(scores, descending=True)[:num_recs]
    rec_data = []
    for idx in top_indices:
        artist_model_idx = idx.item()
        original_artist_id = _mappings["artist_inv_map"].get(artist_model_idx)
        if original_artist_id is not None:
            artist_name = _artists_df.loc[original_artist_id, "name"]
            rec_data.append((artist_name, scores[idx].item()))
    return pd.DataFrame(rec_data, columns=["Artist", "Score"])
```
To visualize how collaborative filtering reacts to new data, we'll add functions that simulate new listens and retrain the model live. simulate_new_listen creates a new dataframe of random user-artist interactions.
```python
def simulate_new_listen(_mappings, num_simulations=100):
    st.write(f"Simulating {num_simulations} new listens...")
    all_user_indices = list(_mappings["user_inv_map"].keys())
    all_artist_indices = list(_mappings["artist_inv_map"].keys())
    # Random (user, artist) pairs with random listen counts in [50, 500)
    sim_users = np.random.choice(all_user_indices, num_simulations)
    sim_artists = np.random.choice(all_artist_indices, num_simulations)
    sim_weights = np.random.randint(50, 500, num_simulations)
    sim_df = pd.DataFrame({
        "user_idx": sim_users,
        "artist_idx": sim_artists,
        "weight": sim_weights
    })
    return sim_df
```
Now we create the retrain_model function. It first checks Streamlit's st.session_state for an existing "retrained_model". If one exists, we keep training it. If not, this is the first retraining, so we start from the original model returned by load_assets(). This ensures that clicking the button multiple times keeps updating the same "live" model.
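```python
def retrain_model(new_data):
    st.sidebar.write("Retraining model...")
    if "retrained_model" in st.session_state:
        model_to_retrain = st.session_state.retrained_model
        st.sidebar.write("Starting from *previously* retrained model.")
    else:
        model, _, _ = load_assets()
        # load_assets() returns a shared cached object, so train a copy of it
        # instead of mutating the cached original in place
        model_to_retrain = MatrixFactorization(
            model.user_embedding.num_embeddings,
            model.artist_embedding.num_embeddings,
            embedding_dim=model.user_embedding.embedding_dim,
        )
        model_to_retrain.load_state_dict(model.state_dict())
        st.sidebar.write("Starting from *original* loaded model.")
```
Next, we prepare the new data. Just like in train.py, we apply the log1p transform and wrap the data in a LastFmDataset and a DataLoader.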
```python
    new_data["weight_log"] = np.log1p(new_data["weight"])
    new_dataset = LastFmDataset(new_data["user_idx"].values, new_data["artist_idx"].values, new_data["weight_log"].values)
    new_loader = DataLoader(new_dataset, batch_size=32, shuffle=True)
```
We also need to define our optimizer and loss function again, pointing the optimizer at model_to_retrain's parameters.
```python
    optimizer = optim.Adam(model_to_retrain.parameters(), lr=0.001)
    loss_fn = nn.MSELoss()
```
We run a smaller training loop on just the new data: one pass in train() mode over new_loader, applying the same five-step PyTorch training process as before.
```python
    model_to_retrain.train()
    for users, artists, weights in new_loader:
        optimizer.zero_grad()
        predictions = model_to_retrain(users, artists)
        loss = loss_fn(predictions, weights)
        loss.backward()
        optimizer.step()
```
Finally, we set the model back to eval() mode and store it in st.session_state["retrained_model"]. This replaces the old "live" model with the retrained one and its more up-to-date weights.
```python
    model_to_retrain.eval()
    st.session_state.retrained_model = model_to_retrain
    st.sidebar.success("Retraining complete!")
```
Now we create a simple app to visualize everything that's happening. We're using Streamlit to keep things simple and build it entirely in Python.
First, the UI loads our assets. Then it checks st.session_state to see if a retrained model exists. If so, we use it; otherwise, we use the original model we loaded.
```python
st.set_page_config(page_title="Music Recommender", layout="wide")
st.title("Interactive Music Recommender")

model, mappings, artists_df = load_assets()

if "retrained_model" in st.session_state:
    model_to_use = st.session_state.retrained_model
else:
    model_to_use = model
```
We create a st.selectbox dropdown for picking a user ID.
```python
original_user_ids = list(mappings["user_inv_map"].values())
st.subheader("Select a user to see their recommendations:")
selected_user_id = st.selectbox("Select a user", original_user_ids)
```
If a user is selected, we call get_recommendations and display the results in a st.table.
```python
if selected_user_id:
    st.write(f"Top Recommendations for user: **{selected_user_id}**")
    recs_df = get_recommendations(selected_user_id, model_to_use, mappings, artists_df)
    st.table(recs_df.set_index("Artist"))
```
Finally, we add a sidebar with a button that, when clicked, runs the simulation and retraining. It then clears the recommendation cache and calls st.rerun() to refresh the app and show the new recommendations.
```python
st.sidebar.title("Retraining Simulation")
st.sidebar.write("Simulate new user activity and retrain.")
if st.sidebar.button(f"Simulate {SIMULATIONS} listens and retrain"):
    new_data = simulate_new_listen(mappings, num_simulations=SIMULATIONS)
    retrain_model(new_data)
    # Cached recommendations were computed with the old weights, so drop them
    get_recommendations.clear()
    st.rerun()
```
And that's it! You've built a complete, end-to-end recommendation system in four files.
To see it in action, run the following command in your terminal:
```
streamlit run app.py
```
You'll now be able to select any user, see their initial recommendations, and use the sidebar to simulate new data and retrain the model live to watch how its predictions change over time.
Always be (machine) learning,
Logan


