
mit-wu-lab/SD-MBTL


Structure Detection for Contextual Reinforcement Learning

Codebase for "Structure Detection for Contextual Reinforcement Learning (AAAI 2026)"


Figure: Overall SD-MBTL process

Directory

SD-MBTL/
├── data/                   # Data
├── figures/                # Figures
├── tables/                 # Tables
├── algs.py                 # Algorithms
├── utils.py                # Utility functions
├── main.py                 # Main function
├── environment.yml         # Environment file
├── README.md               # This file
├── plot_figure.ipynb       # Jupyter notebook for plotting figures
├── s-run-main-synt.sh      # Shell script for running main.py with 3d synthetic data
├── s-run-main-synt5d.sh    # Shell script for running main.py with 5d synthetic data
├── s-run-main-synt7d.sh    # Shell script for running main.py with 7d synthetic data
├── s-run-main.sh           # Shell script for running main.py with real-world data
└── LICENSE                 # License

Installation

conda env create -f environment.yml
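The command above creates the environment but does not activate it. The environment's name comes from the name: field of environment.yml, which this README does not state, so the minimal sketch below reads it from the file. The helper name env_name_from_yml is hypothetical (not part of this repository), and it assumes an unquoted name: entry at the top level of the file.

```bash
# Hypothetical helper (not part of this repository): print the value of the
# top-level "name:" key in a conda environment file. Assumes the value is unquoted.
env_name_from_yml() {
  local file="${1:-environment.yml}"
  # Match only lines starting with "name:" (no indentation), strip the key,
  # and keep the first match in case the key appears more than once.
  sed -n 's/^name:[[:space:]]*//p' "$file" | head -n 1
}
```

Usage, assuming conda has been initialized for your shell: conda activate "$(env_name_from_yml environment.yml)"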

Training RL policies and evaluating them on CMDPs

We consider two types of tasks: synthetic and real-world. Synthetic tasks are generated with the gen_synthetic_data.ipynb notebook, and real-world tasks are generated with four environments under envs/.

Synthetic tasks

Running the gen_synthetic_data.ipynb notebook generates the synthetic data and saves it to the data/ directory.
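To run the notebook without opening the Jupyter UI, one option is Jupyter's nbconvert, which can execute a notebook and write the outputs back in place. This sketch assumes jupyter and nbconvert are installed in the conda environment (this README does not say whether environment.yml includes them); run_notebook is a hypothetical helper name.

```bash
# Hypothetical helper: execute a notebook top to bottom and save the outputs
# back into the same file. Assumes jupyter/nbconvert are installed.
run_notebook() {
  local nb="$1"
  # --ExecutePreprocessor.timeout=-1 disables the per-cell timeout so that
  # long-running cells (e.g. data generation) are not interrupted.
  jupyter nbconvert --to notebook --execute --inplace \
    --ExecutePreprocessor.timeout=-1 "$nb"
}
```

Usage: run_notebook gen_synthetic_data.ipynb. The same helper could be pointed at plot_figure.ipynb.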

Real-world tasks

The real-world tasks are generated using four environments: CartPole, BipedalWalker, IntersectionZoo, and CyclesGym.

Download the data

You can download the synthetic and real-world datasets from the following links:

After downloading the data, unzip it and place the contents in the data/ directory.
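The download links and archive names are not listed here, so the sketch below takes the archive path as an argument. extract_dataset is a hypothetical helper that uses Python's built-in zipfile command-line interface (python3 -m zipfile -e), which avoids depending on a separate unzip binary; it assumes the download is a .zip file.

```bash
# Hypothetical helper: extract a downloaded dataset archive into data/.
# The archive name is not given in this README; pass the file you downloaded.
extract_dataset() {
  local archive="$1"
  local dest="${2:-data}"
  if [ ! -f "$archive" ]; then
    echo "archive not found: $archive" >&2
    return 1
  fi
  mkdir -p "$dest"
  # Python's zipfile CLI: -e <zipfile> <output_dir> extracts all members.
  python3 -m zipfile -e "$archive" "$dest"
}
```

If the archive wraps its files in a top-level folder, move that folder's contents up so they sit directly under data/.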

Run SD-MBTL and baselines

Run SD-MBTL and baselines on synthetic tasks

3-dimensional tasks:

bash s-run-main-synt.sh

5-dimensional tasks:

bash s-run-main-synt5d.sh

7-dimensional tasks:

bash s-run-main-synt7d.sh

Run SD-MBTL and baselines on real-world tasks (CartPole, BipedalWalker, IntersectionZoo, and CyclesGym)

bash s-run-main.sh

Results

Results are saved in the tables/ and figures/ directories; they are produced by running the plot_figure.ipynb notebook.

Synthetic tasks

Figure: 3-dimensional synthetic tasks results (3D; Any K)

Figure: 5-dimensional synthetic tasks results (5D)

Figure: 7-dimensional synthetic tasks results (7D)

Real-world tasks

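The aggregated performance rescales each MBTL-based algorithm's performance to the range [0, 1], where 0 corresponds to the Random baseline and 1 to the Myopic Oracle, so the score reflects how much the algorithm outperforms Random and how closely it approaches the Myopic Oracle. The rescaled scores are then averaged across the four benchmarks.

Figures: real-world tasks results (aggregated table and figure), with per-benchmark panels for CartPole, Walker, IntersectionZoo, and Crop (including Any K views).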
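One way to write the aggregated score described above is a min-max rescaling against the two reference methods, followed by an average over benchmarks. This is a sketch under assumptions: the symbols are introduced here for illustration only (J_b for performance on benchmark b, \mathcal{B} for the set of four benchmarks), and the exact definition, including how results over different budgets K are combined, is given in the paper.

```latex
\tilde{J}_{\mathrm{alg}}
  = \frac{1}{|\mathcal{B}|} \sum_{b \in \mathcal{B}}
    \frac{J_b^{\mathrm{alg}} - J_b^{\mathrm{Random}}}
         {J_b^{\mathrm{Myopic\ Oracle}} - J_b^{\mathrm{Random}}},
  \qquad |\mathcal{B}| = 4
```

Under this form, an algorithm that matches Random scores 0 and one that matches the Myopic Oracle scores 1; an algorithm that did worse than Random would receive a negative score unless the result were clipped.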

License

This project is licensed under the MIT License; see the LICENSE file for details. Each benchmark environment is distributed under its own license, so refer to the respective repositories for more information.

Acknowledgements

This codebase is built upon the following repositories:

Cite

Coming soon

Contact

This repository contains the code for the paper "Structure Detection for Contextual Reinforcement Learning (AAAI 2026)". For questions about the code, please contact the authors.
