Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions .github/workflows/deploy_docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ concurrency:

jobs:
deploy_docs:
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
runs-on: ubuntu-latest
steps:
- name: Checkout
Expand All @@ -41,9 +44,9 @@ jobs:
- name: Setup Pages
uses: actions/configure-pages@v4
- name: Upload artifact
uses: actions/upload-pages-artifact@v3
uses: actions/upload-pages-artifact@v2
with:
path: 'docs/_build/html/'
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v4
uses: actions/deploy-pages@v3
6 changes: 3 additions & 3 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ jobs:
toml-sort --check pyproject.toml
./poetry_wrapper.sh --experimental --generate
toml-sort --check pyproject.toml

pytest_core:
needs: static_tests
runs-on: ubuntu-20.04
Expand Down Expand Up @@ -98,7 +98,7 @@ jobs:
python -m venv venv
. ./venv/bin/activate
pip install --upgrade pip wheel poetry==1.5.1
./poetry_wrapper.sh install -E torch-openvino
./poetry_wrapper.sh install -E torch
- name: pytest
run: |
. ./venv/bin/activate
Expand Down Expand Up @@ -299,4 +299,4 @@ jobs:
run: |
. ./venv/bin/activate
export PACKAGE_SUFFIX=.preview
./poetry_wrapper.sh --experimental build
./poetry_wrapper.sh --experimental build
5 changes: 2 additions & 3 deletions docs/pages/modules/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ Neural Networks recommenders are Lightning-compatible. They can be trained using
Bert4Rec
````````
.. autoclass:: replay.models.nn.Bert4Rec
:members: __init__, predict
:members: __init__, predict_step

SasRec
``````
Expand All @@ -372,8 +372,7 @@ SasRecCompiled

Bert4RecCompiled
~~~~~~~~~~~~~~~~
.. autoclass:: replay.models.nn.sequential.compiled.Bert4RecCompiled
:members: compile, predict
TODO

Features for easy training and validation with Lightning
________________________________________________________
Expand Down
827 changes: 209 additions & 618 deletions examples/10_bert4rec_example.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions replay/metrics/offline_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,13 +156,13 @@ def __init__(
):
"""
:param metrics: (list of metrics): List of metrics to be calculated.
:param query_column:: (str): The name of the query column.
:param user_column: (str): The name of the user column.
Note that you do not need to specify the value of this parameter for each metric separately.
It is enough to specify the value of this parameter here once.
:param item_column: (str): The name of the item column.
Note that you do not need to specify the value of this parameter for each metric separately.
It is enough to specify the value of this parameter here once.
:param rating_column: (str): The name of the rating column.
:param score_column: (str): The name of the score column.
Note that you do not need to specify the value of this parameter for each metric separately.
It is enough to specify the value of this parameter here once.
:param category_column: (str): The name of the category column.
Expand Down
3 changes: 1 addition & 2 deletions replay/models/base_rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -1162,11 +1162,10 @@ def get_features(
) -> Optional[Tuple[SparkDataFrame, int]]:
"""
Returns query or item feature vectors as a Column with type ArrayType
If a model does not have a vector for some ids they are not present in the final result.

:param ids: Spark DataFrame with unique ids
:param features: Spark DataFrame with features for provided ids
:return: feature vectors
If a model does not have a vector for some ids they are not present in the final result.
"""
return self._get_features_wrap(ids, features)

Expand Down
1 change: 1 addition & 0 deletions replay/models/nn/optimizer_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@

if TORCH_AVAILABLE:
from .optimizer_factory import FatLRSchedulerFactory, FatOptimizerFactory, LRSchedulerFactory, OptimizerFactory
from .fused_linear_ce_loss import LigerFusedLinearCrossEntropyFunction
Loading
Loading