Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add NCF_PyTorch models #536

Merged
merged 6 commits into from
Oct 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
tests/vocab.pkl
.idea/
.vscode/

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
1 change: 0 additions & 1 deletion cornac/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,3 @@
"FM model is only supported on Linux.\n"
+ "Windows executable can be found at http://www.libfm.org."
)

176 changes: 176 additions & 0 deletions cornac/models/ncf/backend_pt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
import torch
import torch.nn as nn


# Optimizer classes, looked up by name.
optimizer_dict = dict(
    sgd=torch.optim.SGD,
    adam=torch.optim.Adam,
    rmsprop=torch.optim.RMSprop,
    adagrad=torch.optim.Adagrad,
)

# Activation modules, looked up by name. Each value is a single module
# instance created once at import time, so every lookup of the same key
# returns the same object.
activation_functions = dict(
    sigmoid=nn.Sigmoid(),
    tanh=nn.Tanh(),
    elu=nn.ELU(),
    selu=nn.SELU(),
    relu=nn.ReLU(),
    relu6=nn.ReLU6(),
    leakyrelu=nn.LeakyReLU(),
)


class GMF(nn.Module):
    """Score user-item pairs from the element-wise product of their embeddings.

    The product of the user and item embedding vectors is passed through a
    single linear unit and squashed with a sigmoid.

    Parameters
    ----------
    num_users: int
        Number of rows of the user embedding table.

    num_items: int
        Number of rows of the item embedding table.

    num_factors: int, optional, default: 8
        Width of both embedding tables.
    """

    def __init__(self, num_users: int, num_items: int, num_factors: int = 8):
        super().__init__()

        self.num_users = num_users
        self.num_items = num_items

        # Parameters are created in a fixed order (user table, item table,
        # logit) so that seeded runs draw the same random values.
        self.user_embedding = nn.Embedding(num_users, num_factors)
        self.item_embedding = nn.Embedding(num_items, num_factors)
        self.logit = nn.Linear(num_factors, 1)
        self.Sigmoid = nn.Sigmoid()

        self._init_weight()

    def _init_weight(self):
        """Re-draw embedding and logit weights from a normal with std 0.01.

        The logit bias keeps whatever ``nn.Linear`` initialised it to.
        """
        weights = (
            self.user_embedding.weight,
            self.item_embedding.weight,
            self.logit.weight,
        )
        for weight in weights:
            nn.init.normal_(weight, std=1e-2)

    def from_pretrained(self, pretrained_gmf):
        """Overwrite this model's parameters in place with those of `pretrained_gmf`."""
        transfers = (
            (self.user_embedding.weight, pretrained_gmf.user_embedding.weight),
            (self.item_embedding.weight, pretrained_gmf.item_embedding.weight),
            (self.logit.weight, pretrained_gmf.logit.weight),
            (self.logit.bias, pretrained_gmf.logit.bias),
        )
        for target, source in transfers:
            target.data.copy_(source)

    def h(self, users, items):
        """Return the element-wise product of the looked-up user and item vectors."""
        user_vecs = self.user_embedding(users)
        item_vecs = self.item_embedding(items)
        return user_vecs * item_vecs

    def forward(self, users, items):
        """Return sigmoid scores for the given index tensors, flattened to 1-D."""
        raw_scores = self.logit(self.h(users, items))
        return self.Sigmoid(raw_scores).view(-1)


class MLP(nn.Module):
    """Score user-item pairs with a feed-forward network over concatenated embeddings.

    User and item vectors (each ``layers[0] // 2`` wide) are concatenated,
    passed through a stack of ``Linear`` + activation layers, then through a
    single linear unit and a sigmoid.

    Instances are invoked through the standard ``nn.Module.__call__`` so that
    registered hooks and keyword arguments work as they do for any other
    module; ``forward`` is not meant to be called directly.

    Parameters
    ----------
    num_users: int
        Number of rows of the user embedding table.

    num_items: int
        Number of rows of the item embedding table.

    layers: sequence of int, optional, default: (64, 32, 16, 8)
        Layer widths. ``layers[0]`` is the width of the concatenated
        embeddings; each consecutive pair defines one ``Linear`` layer.

    act_fn: str, optional, default: 'relu'
        Case-insensitive key into ``activation_functions``.
    """

    def __init__(
        self,
        num_users: int,
        num_items: int,
        layers=(64, 32, 16, 8),
        act_fn="relu",
    ):
        super().__init__()

        self.num_users = num_users
        self.num_items = num_items
        # Each side contributes half of the first layer's input width.
        self.user_embedding = nn.Embedding(num_users, layers[0] // 2)
        self.item_embedding = nn.Embedding(num_items, layers[0] // 2)

        mlp_layers = []
        for in_width, out_width in zip(layers[:-1], layers[1:]):
            mlp_layers.append(nn.Linear(in_width, out_width))
            mlp_layers.append(activation_functions[act_fn.lower()])

        # unpack the layer list into torch.nn.Sequential
        self.mlp_model = nn.Sequential(*mlp_layers)

        self.logit = nn.Linear(layers[-1], 1)
        self.Sigmoid = nn.Sigmoid()

        self._init_weight()

    def _init_weight(self):
        """Initialise embeddings/logit from N(0, 0.01^2) and hidden layers with Xavier uniform."""
        nn.init.normal_(self.user_embedding.weight, std=1e-2)
        nn.init.normal_(self.item_embedding.weight, std=1e-2)
        for layer in self.mlp_model:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
        nn.init.normal_(self.logit.weight, std=1e-2)

    def from_pretrained(self, pretrained_mlp):
        """Overwrite this model's parameters in place with those of `pretrained_mlp`.

        Hidden layers are paired positionally; only positions where both
        sides hold an ``nn.Linear`` are copied.
        """
        self.user_embedding.weight.data.copy_(pretrained_mlp.user_embedding.weight)
        self.item_embedding.weight.data.copy_(pretrained_mlp.item_embedding.weight)
        for layer, pretrained_layer in zip(self.mlp_model, pretrained_mlp.mlp_model):
            if isinstance(layer, nn.Linear) and isinstance(pretrained_layer, nn.Linear):
                layer.weight.data.copy_(pretrained_layer.weight)
                layer.bias.data.copy_(pretrained_layer.bias)
        self.logit.weight.data.copy_(pretrained_mlp.logit.weight)
        self.logit.bias.data.copy_(pretrained_mlp.logit.bias)

    def h(self, users, items):
        """Return the hidden representation produced by the layer stack."""
        embed_user = self.user_embedding(users)
        embed_item = self.item_embedding(items)
        embed_input = torch.cat((embed_user, embed_item), dim=-1)
        return self.mlp_model(embed_input)

    def forward(self, users, items):
        """Return sigmoid scores for the given index tensors, flattened to 1-D."""
        h = self.h(users, items)
        output = self.Sigmoid(self.logit(h)).view(-1)
        return output


class NeuMF(nn.Module):
    """Combine a GMF branch and an MLP branch under one sigmoid output unit.

    The hidden representations of both branches are concatenated and passed
    through a single linear unit followed by a sigmoid.

    Parameters
    ----------
    num_users: int
        Number of users, forwarded to both branches.

    num_items: int
        Number of items, forwarded to both branches.

    num_factors: int, optional, default: 8
        Embedding width of the GMF branch. ``None`` means "use ``layers[-1]``".
        Must equal ``layers[-1]``.

    layers: sequence of int, optional, default: (64, 32, 16, 8)
        Layer widths of the MLP branch. ``None`` means ``[64, 32, 16, 8]``.

    act_fn: str, optional, default: 'relu'
        Activation name forwarded to the MLP branch.
    """

    def __init__(
        self,
        num_users: int,
        num_items: int,
        num_factors: int = 8,
        layers=(64, 32, 16, 8),
        act_fn="relu",
    ):
        super().__init__()

        # Resolve explicit ``None`` arguments to their fallbacks.
        layers = [64, 32, 16, 8] if layers is None else layers
        mlp_width = layers[-1]
        num_factors = mlp_width if num_factors is None else num_factors

        assert layers[-1] == num_factors

        # The output unit is created before the branches; keeping this order
        # keeps seeded initialisation reproducible.
        self.logit = nn.Linear(num_factors + mlp_width, 1)
        self.Sigmoid = nn.Sigmoid()

        self.gmf = GMF(num_users, num_items, num_factors)
        self.mlp = MLP(
            num_users=num_users,
            num_items=num_items,
            layers=layers,
            act_fn=act_fn,
        )

        nn.init.normal_(self.logit.weight, std=1e-2)

    def from_pretrained(self, pretrained_gmf, pretrained_mlp, alpha):
        """Load both branches, then blend their output units into this model's.

        The GMF logit parameters are weighted by ``alpha`` and the MLP logit
        parameters by ``1 - alpha``; weights are concatenated, biases summed.
        """
        self.gmf.from_pretrained(pretrained_gmf)
        self.mlp.from_pretrained(pretrained_mlp)

        gmf_head = self.gmf.logit
        mlp_head = self.mlp.logit
        blended_weight = torch.cat(
            [alpha * gmf_head.weight, (1.0 - alpha) * mlp_head.weight],
            dim=1,
        )
        blended_bias = alpha * gmf_head.bias + (1.0 - alpha) * mlp_head.bias

        self.logit.weight.data.copy_(blended_weight)
        self.logit.bias.data.copy_(blended_bias)

    def forward(self, users, items, gmf_users=None):
        """Return sigmoid scores, flattened to 1-D.

        ``gmf_users``, when given, replaces ``users`` for the GMF branch only
        (the original author notes it exists to take advantage of
        broadcasting); the MLP branch always receives ``users``.
        """
        gmf_side_users = users if gmf_users is None else gmf_users
        h_gmf = self.gmf.h(gmf_side_users, items)
        h_mlp = self.mlp.h(users, items)
        fused = torch.cat([h_gmf, h_mlp], dim=-1)
        return self.Sigmoid(self.logit(fused)).view(-1)
File renamed without changes.
106 changes: 48 additions & 58 deletions cornac/models/ncf/recom_gmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ class GMF(NCFBase):
----------
num_factors: int, optional, default: 8
Embedding size of MF model.

regs: float, optional, default: 0.
Regularization for user and item embeddings.
reg: float, optional, default: 0.
Regularization (weight_decay).

num_epochs: int, optional, default: 20
Number of epochs.
Expand All @@ -45,7 +45,10 @@ class GMF(NCFBase):

learner: str, optional, default: 'adam'
Specify an optimizer: adagrad, adam, rmsprop, sgd


backend: str, optional, default: 'tensorflow'
Backend used for model training: tensorflow, pytorch

early_stopping: {min_delta: float, patience: int}, optional, default: None
If `None`, no early stopping. Meaning of the arguments:

Expand Down Expand Up @@ -77,12 +80,13 @@ def __init__(
self,
name="GMF",
num_factors=8,
regs=(0.0, 0.0),
reg=0.0,
num_epochs=20,
batch_size=256,
num_neg=4,
lr=0.001,
learner="adam",
backend="tensorflow",
early_stopping=None,
trainable=True,
verbose=True,
Expand All @@ -97,17 +101,21 @@ def __init__(
num_neg=num_neg,
lr=lr,
learner=learner,
backend=backend,
early_stopping=early_stopping,
seed=seed,
)
self.num_factors = num_factors
self.regs = regs
self.reg = reg

def _build_graph(self):
########################
## TensorFlow backend ##
########################
def _build_graph_tf(self):
import tensorflow.compat.v1 as tf
from .ops import gmf, loss_fn, train_fn
from .backend_tf import gmf, loss_fn, train_fn

super()._build_graph()
self.graph = tf.Graph()
with self.graph.as_default():
tf.set_random_seed(self.seed)

Expand All @@ -123,8 +131,8 @@ def _build_graph(self):
num_users=self.num_users,
num_items=self.num_items,
emb_size=self.num_factors,
reg_user=self.regs[0],
reg_item=self.regs[1],
reg_user=self.reg,
reg_item=self.reg,
seed=self.seed,
)

Expand All @@ -144,50 +152,32 @@ def _build_graph(self):
self.initializer = tf.global_variables_initializer()
self.saver = tf.train.Saver()

self._sess_init()

def score(self, user_idx, item_idx=None):
    """Predict the scores/ratings of a user for an item.

    Parameters
    ----------
    user_idx: int, required
        The index of the user for whom to perform score prediction.

    item_idx: int, optional, default: None
        The index of the item for which to perform score prediction.
        If None, scores for all known items will be returned.

    Returns
    -------
    res : A scalar or a Numpy array
        Relative scores that the user gives to the item or to all known items

    Raises
    ------
    ScoreException
        If the user (or, when given, the item) is unknown to the training set.
    """
    train_set = self.train_set

    # Single (user, item) pair.
    if item_idx is not None:
        if train_set.is_unk_user(user_idx) or train_set.is_unk_item(item_idx):
            raise ScoreException(
                "Can't make score prediction for (user_id=%d, item_id=%d)"
                % (user_idx, item_idx)
            )
        pair_feed = {self.user_id: [user_idx], self.item_id: [item_idx]}
        return self.sess.run(self.prediction, feed_dict=pair_feed).ravel()

    # One user against every item of the training set.
    if train_set.is_unk_user(user_idx):
        raise ScoreException(
            "Can't make score prediction for (user_id=%d)" % user_idx
        )
    all_items_feed = {
        self.user_id: [user_idx],
        self.item_id: np.arange(train_set.num_items),
    }
    return self.sess.run(self.prediction, feed_dict=all_items_feed).ravel()
self._sess_init_tf()

def _score_tf(self, user_idx, item_idx):
    """Run the TensorFlow prediction op for one user.

    When `item_idx` is None the user is paired with every item index in
    ``range(self.num_items)``; otherwise only with `item_idx`. The raw
    result of ``self.sess.run`` is returned unchanged.
    """
    if item_idx is None:
        item_ids = np.arange(self.num_items)
    else:
        item_ids = [item_idx]

    return self.sess.run(
        self.prediction,
        feed_dict={self.user_id: [user_idx], self.item_id: item_ids},
    )

#####################
## PyTorch backend ##
#####################
def _build_model_pt(self):
    """Create the PyTorch GMF network from this recommender's user/item counts and factor size."""
    # Imported lazily so that PyTorch is only loaded when this backend is used.
    from .backend_pt import GMF as TorchGMF

    return TorchGMF(
        num_users=self.num_users,
        num_items=self.num_items,
        num_factors=self.num_factors,
    )

def _score_pt(self, user_idx, item_idx):
    """Score one user with the PyTorch model, without tracking gradients.

    When `item_idx` is None the user is paired with every item index in
    ``range(self.num_items)``; otherwise only with `item_idx`. The model
    output is squeezed, moved to the CPU and returned as a NumPy array.
    """
    import torch

    with torch.no_grad():
        user_tensor = torch.tensor(user_idx).unsqueeze(0).to(self.device)

        if item_idx is None:
            item_tensor = torch.from_numpy(np.arange(self.num_items))
        else:
            item_tensor = torch.tensor(item_idx).unsqueeze(0)
        item_tensor = item_tensor.to(self.device)

        scores = self.model(user_tensor, item_tensor)

    return scores.squeeze().cpu().numpy()
Loading
Loading