diff --git a/.gitignore b/.gitignore
index 4472b89..64d49ae 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,37 +1,216 @@
-# Ignore Python bytecode files
-*.pyc
+# Byte-compiled / optimized / DLL files
__pycache__/
+*.py[codz]
+*$py.class
-# ignore data
-*.pkl
-*.csv
-*.csv.bak
-*.egg-info
+# C extensions
+*.so
-# Ignore Jupyter Notebook checkpoints
-.ipynb_checkpoints/
-# Ignore virtual environment directories
-venv/
-.env
-# Ignore logs
-logs/
-*.log
-
-# Ignore coverage reports
-.coverage
-htmlcov/
-# Ignore build directories
+# Distribution / packaging
+.Python
build/
+develop-eggs/
dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
*.egg
-# Ignore IDE specific files
-.idea/
-.vscode/
-# Ignore system files
-.DS_Store
-Thumbs.db
-# Ignore configuration files
-*.cfg
-*.ini
-
-docs/
\ No newline at end of file
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py.cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+# Pipfile.lock
+
+# UV
+# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# uv.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+# poetry.lock
+# poetry.toml
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
+# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
+# pdm.lock
+# pdm.toml
+.pdm-python
+.pdm-build/
+
+# pixi
+# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
+# pixi.lock
+# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
+# in the .venv directory. It is recommended not to include this directory in version control.
+.pixi
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# Redis
+*.rdb
+*.aof
+*.pid
+
+# RabbitMQ
+mnesia/
+rabbitmq/
+rabbitmq-data/
+
+# ActiveMQ
+activemq-data/
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.envrc
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+# .idea/
+
+# Abstra
+# Abstra is an AI-powered process automation framework.
+# Ignore directories containing user credentials, local state, and settings.
+# Learn more at https://abstra.io/docs
+.abstra/
+
+# Visual Studio Code
+# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
+# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
+# and can be added to the global gitignore or merged into this file. However, if you prefer,
+# you could uncomment the following to ignore the entire vscode folder
+# .vscode/
+
+# Ruff stuff:
+.ruff_cache/
+
+# PyPI configuration file
+.pypirc
+
+# Marimo
+marimo/_static/
+marimo/_lsp/
+__marimo__/
+
+# Streamlit
+.streamlit/secrets.toml
\ No newline at end of file
diff --git a/Doxyfile b/Doxyfile
deleted file mode 100644
index 540f64c..0000000
--- a/Doxyfile
+++ /dev/null
@@ -1,7 +0,0 @@
-INPUT = ark
-RECURSIVE = YES
-FILE_PATTERNS = *.py
-EXTRACT_ALL = YES
-OUTPUT_DIRECTORY = docs # or docs, build/doc, /absolute/path
-CREATE_SUBDIRS = YES # keeps html, xml, latex in sub-folders
-PROJECT_NAME = "Ark"
\ No newline at end of file
diff --git a/LICENCE b/LICENCE
index 1ae131d..630bf13 100644
--- a/LICENCE
+++ b/LICENCE
@@ -1,82 +1,9 @@
-
-
-THIRD PARTY OPEN SOURCE SOFTWARE NOTICE
-
-Please note we provide an open source software notice for the third party open source software along with this software
- and/or this software component contributed by Huawei (in the following just “this SOFTWARE”).
- The open source software licenses are granted by the respective right holders.
-
-Warranty Disclaimer
-THE OPEN SOURCE SOFTWARE IN THIS SOFTWARE IS DISTRIBUTED IN THE HOPE THAT IT WILL BE USEFUL,
-BUT WITHOUT ANY WARRANTY, WITHOUT EVEN THE IMPLIED WARRANTY OF MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE.
-SEE THE APPLICABLE LICENSES FOR MORE DETAILS.
-
-------------------------------------------------------------------------------------------------------------------------
-
-Copyright Notice and License Texts
-
-Software: Mockturtle (https://mockturtle.readthedocs.io/)
-Copyright notice: Copyright (c) 2018-2020
-License: MIT License
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
-
----
-
-Software: ABC: System for Sequential Synthesis and Verification (http://www.eecs.berkeley.edu/~alanmi/abc/)
-Copyright notice: Copyright (c) 2018-2020
-License: MIT License
-Permission is hereby granted, without written agreement and without license or
-royalty fees, to use, copy, modify, and distribute this software and its
-documentation for any purpose, provided that the above copyright notice and
-the following two paragraphs appear in all copies of this software.
-
-IN NO EVENT SHALL THE UNIVERSITY OF CALIFORNIA BE LIABLE TO ANY PARTY FOR
-DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
-THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF THE UNIVERSITY OF
-CALIFORNIA HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-THE UNIVERSITY OF CALIFORNIA SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING,
-BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED HEREUNDER IS ON AN "AS IS" BASIS,
-AND THE UNIVERSITY OF CALIFORNIA HAS NO OBLIGATION TO PROVIDE MAINTENANCE,
-SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS.
-
----
MIT License
-Copyright (c) 2020 Eric Han
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
+Copyright 2025 TU Darmstadt, Huawei
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
+Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
+The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
\ No newline at end of file
diff --git a/README.md b/README.md
index 4426d7e..fe13cf0 100644
--- a/README.md
+++ b/README.md
@@ -1,11 +1,11 @@
-
+
A Python framework for robotics research and development.
- Lightweight, flexible, and designed for researchers and developers in robotics.
+ Lightweight, flexible, and designed for researchers and developers in robot learning.
@@ -51,85 +51,34 @@
## What is this about?
-Ark is a Python-first playground for robot learning. Instead of wrestling with C++ and fragmented tools, you can collect data, train policies, and switch between simulation and real robots with just a few lines of code. Think of it as the PyTorch + Gym for robotics — simple, modular, and built for rapid prototyping of intelligent robots.
+Ark is a Python-first playground for robot learning.
+Instead of wrestling with C++ and fragmented tools, you can collect data, train policies, and switch between simulation and real robots with just a few lines of code.
+Think of it as the PyTorch + Gym for robotics — simple, modular, and built for rapid prototyping of intelligent robots.
📚 **Learn more:**
-- [📖 Tutorials](https://arkrobotics.notion.site/ARK-Home-22be053d9c6f8096bcdbefd6276aba61)
-- [⚙️ Documentation](https://robotics-ark.github.io/ark_robotics.github.io/docs/html/index.html)
-- [📄 Research Paper](https://robotics-ark.github.io/ark_robotics.github.io/static/images/2506.21628v2.pdf)
+- [📖 Tutorials]()
+- [⚙️ Documentation]()
+- [📄 Research Paper]()
-## Installation
+# Installation
-The framework depends on [Ark Types](https://github.com/Robotics-Ark/ark_types) and
-requires a Python environment managed with Conda. The steps below describe how
-to set up the repositories on **Ubuntu** and **macOS**.
-
-### Ubuntu
-
-```bash
-# create a workspace and enter it
-mkdir Ark
-cd Ark
-
-# create and activate the environment
-conda create -n ark_env python=3.10
-conda activate ark_env
-
-# clone and install the framework
-git clone https://github.com/Robotics-Ark/ark_framework.git
-cd ark_framework
-pip install -e .
-cd ..
-
-# clone and install ark_types
-git clone https://github.com/Robotics-Ark/ark_types.git
-cd ark_types
-pip install -e .
-```
-
-### macOS
-
-```bash
-# create a workspace and enter it
-mkdir Ark
-cd Ark
-
-# create and activate the environment
-conda create -n ark_env python=3.11
-conda activate ark_env
-
-# clone and install the framework
-git clone https://github.com/Robotics-Ark/ark_framework.git
-cd ark_framework
-pip install -e .
-
-# pybullet must be installed via conda on macOS
-conda install -c conda-forge pybullet
-cd ..
-
-# clone and install ark_types
-git clone https://github.com/Robotics-Ark/ark_types.git
-cd ark_types
-pip install -e .
-```
-
-After installation, verify the command-line tool is available:
-
-```bash
-ark --help
-```
+1. Create and activate a conda environment.
+ - Python 3.12 is recommended.
+   - E.g., `conda create -n ark python=3.12`
+2. Clone this repository and change directory `cd ark_framework`.
+3. Install [zenoh-python](https://github.com/eclipse-zenoh/zenoh-python)
+ - Installation instructions are found [here](https://github.com/eclipse-zenoh/zenoh-python#how-to-install-it).
+ - It is recommended to [enable zenoh features](https://github.com/eclipse-zenoh/zenoh-python#enable-zenoh-features).
+4. Install: `pip install -e .`
## Cite
-If you find Ark useful for your work please cite:
-
```bibtex
- @misc{robotark2025,
- title = {Ark: An Open-source Python-based Framework for Robot Learning},
- author = {Magnus Dierking, Christopher E. Mower, Sarthak Das, Huang Helong, Jiacheng Qiu, Cody Reading,
- Wei Chen, Huidong Liang, Huang Guowei, Jan Peters, Quan Xingyue, Jun Wang, Haitham Bou-Ammar},
- year = {2025},
- howpublished = {\url{https://ark_robotics.github.io/}},
- note = {Technical report}
- }
+@misc{robotark2025,
+ title = {An Open-source Python-based Framework for Embodied AI},
+  author = {Magnus Dierking and Christopher E. Mower and Refinath S N and Abhineet Kumar and Huang Helong and Jiacheng Qiu and Wei Chen and Huidong Liang and Huang Guowei and Jan Peters and Quan Xingyue and Jun Wang and Haitham Bou-Ammar},
+ year = {2025},
+ howpublished = {\url{https://robotics-ark.github.io/ark_robotics.github.io/}},
+ note = {Technical report}
+}
```
diff --git a/ark/__init__.py b/ark/__init__.py
deleted file mode 100644
index 8b13789..0000000
--- a/ark/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/ark/cli.py b/ark/cli.py
deleted file mode 100644
index edbeb68..0000000
--- a/ark/cli.py
+++ /dev/null
@@ -1,35 +0,0 @@
-import typer
-
-from ark.client.comm_infrastructure import registry
-from ark.decoders import list_decoders
-from ark.tools.ark_graph import ark_graph
-from ark.tools import launcher
-from ark.tools import network
-from ark.tools.visualization import image_viewer
-
-app = typer.Typer()
-decoders_app = typer.Typer(help="Decoder registry utilities.")
-
-
-# Core tooling
-app.add_typer(registry.app, name="registry")
-app.add_typer(ark_graph.app, name="graph")
-app.add_typer(launcher.app, name="launcher")
-
-# Network inspection utilities
-app.add_typer(network.node, name="node")
-app.add_typer(network.channel, name="channel")
-app.add_typer(network.service, name="service")
-app.add_typer(image_viewer.app, name="view")
-
-# Decoder registry utilities
-app.add_typer(list_decoders.app, name="decoders")
-
-
-def main() -> None:
- """Main CLI entry point."""
- app()
-
-
-if __name__ == "__main__":
- main()
diff --git a/ark/client/comm_handler/__init__.py b/ark/client/comm_handler/__init__.py
deleted file mode 100644
index 8b13789..0000000
--- a/ark/client/comm_handler/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/ark/client/comm_handler/comm_handler.py b/ark/client/comm_handler/comm_handler.py
deleted file mode 100644
index 8f02f42..0000000
--- a/ark/client/comm_handler/comm_handler.py
+++ /dev/null
@@ -1,62 +0,0 @@
-from abc import ABC, abstractmethod
-from lcm import LCM
-from typing import Optional
-
-
-class CommHandler(ABC):
- """!
- Base class for communication handlers, used for managing communication
- between different nodes.
-
- This class holds common attributes like the LCM instance, channel name,
- and channel type for communication handlers, and provides an interface
- for shutting down the communication.
- """
-
- def __init__(self, lcm: LCM, channel_name: str, channel_type: type):
- """!
- Initializes the communication handler with an LCM instance, a channel name,
- and a channel type.
-
- @param lcm: The LCM instance used for communication.
- @param channel_name: The name of the communication channel.
- @param channel_type: The type of the message expected for this communication channel.
- """
- self._lcm: LCM = lcm
- self.channel_name: str = channel_name
- self.channel_type: type = channel_type
- self.active = True
-
- def __repr__(self) -> str:
- """!
- Returns a string representation of the communication handler, showing the
- channel name and the type of message it handles.
-
- @return: A string representation of the handler in the format
- "channel_name[channel_type]".
- """
- return f"{self.channel_name}[{self.channel_type.__name__}]"
-
- @abstractmethod
- def get_info(self) -> dict:
- """!
- Should return a dictionary containing all information about the comms
-
- This method is abstract and should be implemented in subclasses.
- """
-
- @abstractmethod
- def suspend(self) -> None:
- """!
- Suspends the comms handler
-
- TODO
- """
-
- @abstractmethod
- def restart(self) -> None:
- """!
- Reactivates the comms handler
-
- TODO
- """
diff --git a/ark/client/comm_handler/listener.py b/ark/client/comm_handler/listener.py
deleted file mode 100644
index 5502b67..0000000
--- a/ark/client/comm_handler/listener.py
+++ /dev/null
@@ -1,98 +0,0 @@
-import threading
-from copy import deepcopy
-from ark.client.comm_handler.subscriber import Subscriber
-from typing import Any
-from lcm import LCM
-
-
-class Listener(Subscriber):
- """!
- A class for receiving and processing messages from a specific channel using an LCM subscriber.
-
- This class listens for messages on a specified channel, and saves the latest message
- and provides methods to retrieve it in a thread-safe manner. It inherits from the `Subscriber` class,
- which handles subscribing to the LCM channel and receiving messages. The `Listener` class ensures
- thread-safety using a `Lock` to protect access to the message.
-
- @param lcm: The LCM instance used for communication.
- @param channel_name: The name of the channel to subscribe to.
- @param channel_type: The type of the message expected from the channel.
- """
-
- def __init__(self, lcm: LCM, channel_name: str, channel_type: type) -> None:
- """!
- Initializes the Listener instance, subscribing to the specified channel and setting up a mutex
- to protect access to the received message.
-
- @param lcm: The LCM instance used for communication.
- @param channel_name: The name of the channel to subscribe to.
- @param channel_type: The type of the message expected from the channel.
- """
- self.mutex: threading.Lock = threading.Lock()
- self._msg: Any = None
- self.channel_name = channel_name
- self.channel_type = channel_type
- self.comm_type = "Listener"
- super().__init__(lcm, channel_name, channel_type, self.callback)
-
- def received(self) -> bool:
- """!
- Checks whether a message has been received.
-
- @return: True if a message has been received, False otherwise.
- """
- return self._msg is not None
-
- def callback(self, t: int, channel_name: str, msg: Any) -> None:
- """!
- Callback function that is called when a new message is received on the subscribed channel.
-
- This method is invoked by the parent `Subscriber` class when a new message is received.
- It locks the mutex to safely store the received message in the instance.
-
- @param t: The time stamp when the message was received in nanoseconds.
- @param channel_name: The name of the channel to subscribe to.
- @param msg: The received message.
- """
- with self.mutex:
- self._msg = msg
-
- def get(self) -> Any:
- """!
- Retrieves the latest received message in a thread-safe manner.
-
- The method locks the mutex to ensure thread-safe access to the message. It creates and returns
- a deep copy of the message to avoid any unintended modifications to the internal state.
-
- @return: A deep copy of the latest received message.
- """
- with self.mutex:
- msg = deepcopy(self._msg)
- return msg
-
- def suspend(self):
- """!
- Suspend the listener and clear any cached message.
-
- @return: ``None``
- """
- self.empty_data()
- return super().suspend()
-
- def empty_data(self):
- """!
- Clear the stored message.
- """
- self._msg = None
-
- def get_info(self):
- """!
- Return a dictionary describing this listener.
- """
- info = {
- "comms_type": "Listener",
- "channel_name": self.channel_name,
- "channel_type": self.channel_type.__name__,
- "channel_status": self.active,
- }
- return info
diff --git a/ark/client/comm_handler/multi_channel_listener.py b/ark/client/comm_handler/multi_channel_listener.py
deleted file mode 100644
index bcf55a7..0000000
--- a/ark/client/comm_handler/multi_channel_listener.py
+++ /dev/null
@@ -1,56 +0,0 @@
-import lcm
-from lcm import LCM
-import time
-import threading
-from ark.client.frequencies.stepper import Stepper
-from ark.client.comm_handler.listener import Listener
-from ark.client.comm_handler.subscriber import Subscriber
-from ark.client.comm_handler.multi_comm_handler import MultiCommHandler
-from typing import List
-import copy
-
-
-class MultiChannelListener(MultiCommHandler):
- def __init__(self, channels: dict[str, type], lcm_instance: LCM) -> None:
- """!
- Initialize listeners for multiple channels.
-
- @param channels: List of ``(channel_name, channel_type)`` tuples.
- @param lcm_instance: LCM instance used for communication.
- """
-
- super().__init__()
-
- self.data = {}
- self.blank_data = {}
- self.comm_type = "Multi Channel Listener"
-
- for channel_name, channel_type in channels.items():
- listener = Listener(lcm_instance, channel_name, channel_type)
- self.channel_data[channel_name] = None
- self.blank_data[channel_name] = None
- self._comm_handlers.append(listener)
-
- def get(self):
- """!
- Retrieves the current observation from the space.
-
- @return: The current observation.
- @rtype: Any
- """
-
- # get all the data
- for listener in self._comm_handlers:
- listener_message = listener.get()
- self.data[listener.channel_name] = listener_message
-
- # return it
- return self.data
-
- def empty_data(self):
- """!
- Empties the data dictionary.
- """
- self.data = copy.deepcopy(self.blank_data)
- for listener in self._comm_handlers:
- listener.empty_data()
diff --git a/ark/client/comm_handler/multi_channel_publisher.py b/ark/client/comm_handler/multi_channel_publisher.py
deleted file mode 100644
index b5c4a04..0000000
--- a/ark/client/comm_handler/multi_channel_publisher.py
+++ /dev/null
@@ -1,60 +0,0 @@
-from lcm import LCM
-import time
-import threading
-from ark.client.frequencies.stepper import Stepper
-from ark.client.comm_handler.publisher import Publisher
-from ark.client.comm_handler.multi_comm_handler import MultiCommHandler
-from abc import ABC, abstractmethod
-from typing import Any, Optional, List, Dict
-from ark.tools.log import log
-
-
-class MultiChannelPublisher(MultiCommHandler):
- """!
- Publisher that manages multiple communication channels.
-
- @note Internally creates one :class:`Publisher` per channel.
- """
-
- def __init__(self, channels: Dict[str, type], lcm_instance: LCM) -> None:
- """!
- Initialize the publisher with a list of channels.
-
- @param channels: Dictionary mapping channel names to their types.
- @type channels: Dict[str, type]
- @param lcm_instance: LCM instance used for publishing.
- """
-
- super().__init__()
-
- self.comm_type = "Multi Channel Publisher"
- # iterate through the channels dictionary
- for channel_name, channel_type in channels.items():
- publisher = Publisher(lcm_instance, channel_name, channel_type)
- self._comm_handlers.append(publisher)
-
- def publish(self, messages_to_publish: Dict[str, Any]) -> None:
- """!
- Publish messages to their respective channels.
-
- @param messages_to_publish: Mapping of channel names to messages.
- """
- for publisher in self._comm_handlers:
- channel_name = publisher.channel_name
- channel_type = publisher.channel_type
- try:
- if channel_name not in messages_to_publish:
- continue
- message = messages_to_publish[channel_name]
-
- if not isinstance(message, channel_type):
- raise TypeError(
- f"Incorrect message type for channel '{channel_name}'. "
- f"Expected {channel_type}, got {type(message)}."
- )
-
- publisher.publish(message)
- except:
- log.warning(
- f"Error Occured when publishing on channel '{channel_name}'."
- )
diff --git a/ark/client/comm_handler/multi_comm_handler.py b/ark/client/comm_handler/multi_comm_handler.py
deleted file mode 100644
index 7e6b97a..0000000
--- a/ark/client/comm_handler/multi_comm_handler.py
+++ /dev/null
@@ -1,36 +0,0 @@
-from abc import ABC, abstractmethod
-from lcm import LCM
-from typing import Optional
-from ark.client.comm_handler.comm_handler import CommHandler
-
-
-class MultiCommHandler(ABC):
- def __init__(self):
- self.channel_data = {}
- self._comm_handlers: list[CommHandler] = []
-
- def get_info(self) -> dict:
- """!
- Should return a dictionary containing all information about the comms
- """
- info = []
- for ch in self._comm_handlers:
- ch_info = ch.get_info()
- info.append(ch_info)
-
- print(info)
- return info
-
- def suspend(self) -> None:
- """!
- Suspends the comms handler
- """
- for ch in self._comm_handlers:
- ch.suspend()
-
- def restart(self) -> None:
- """!
- Reactivates the comms handler
- """
- for ch in self._comm_handlers:
- ch.restart()
diff --git a/ark/client/comm_handler/publisher.py b/ark/client/comm_handler/publisher.py
deleted file mode 100644
index 1bdbd1b..0000000
--- a/ark/client/comm_handler/publisher.py
+++ /dev/null
@@ -1,70 +0,0 @@
-from ark.tools.log import log
-from ark.client.comm_handler.comm_handler import CommHandler
-from lcm import LCM
-
-
-class Publisher(CommHandler):
- """!
- A Publisher class that extends the CommHandler base class. This class handles
- the publishing of messages to a specified communication channel using LCM.
-
- Attributes:
- _enabled (bool): A flag indicating whether the publisher is enabled.
- """
-
- def __init__(self, lcm: LCM, channel_name: str, channel_type: type) -> None:
- """!
- Initializes the Publisher instance with the LCM instance, channel name,
- and message type. Also sets the publisher as enabled and logs the setup.
-
- @param lcm: The LCM instance used for communication.
- @param channel_name: The name of the channel for publishing messages.
- @param channel_type: The type of message expected for the channel.
- """
- self.channel_name = channel_name
- self.channel_type = channel_type
- self.comm_type = "Publisher"
- super().__init__(lcm, channel_name, channel_type)
- log.ok(f"setup publisher {self}")
-
- def publish(self, msg: object) -> None:
- """!
- Publishes a message to the specified channel if the publisher is enabled.
-
- @param msg: The message object to publish. This should match the expected
- type for the channel.
- """
- assert (
- type(msg) == self.channel_type
- ), f"Wrong Message Type send to Channel {self.channel_name}"
-
- if self.active:
- self._lcm.publish(self.channel_name, self.channel_type.encode(msg))
- else:
- log.warning(f"publisher {self} is not enabled, cannot publish messages")
-
- def restart(self) -> None:
- """!
- Restarts the publisher by enabling it again and logging the action.
- """
- self.active = True
- log.ok(f"enabled {self}")
-
- def suspend(self) -> None:
- """!
- Shuts down the publisher by disabling it and logging the shutdown action.
- """
- self.active = False
- log.ok(f"suspended publisher {self}")
-
- def get_info(self):
- """!
- Return a dictionary describing this publisher.
- """
- info = {
- "comms_type": "Publisher",
- "channel_name": self.channel_name,
- "channel_type": self.channel_type.__name__,
- "channel_status": self.active,
- }
- return info
diff --git a/ark/client/comm_handler/service.py b/ark/client/comm_handler/service.py
deleted file mode 100644
index 919606e..0000000
--- a/ark/client/comm_handler/service.py
+++ /dev/null
@@ -1,436 +0,0 @@
-import socket
-import struct
-import threading
-from typing import Callable, Type
-from ark.client.comm_handler.comm_handler import CommHandler
-import json
-import time
-from ark.tools.log import log
-from typing import Any
-
-
-class Service(CommHandler):
- def __init__(
- self,
- service_name: str,
- req_type: Type,
- resp_type: Type,
- callback: Callable[[str, object], object],
- registry_host: str,
- registry_port: int,
- host: str = None,
- port: int = None,
- is_default=False,
- ):
- """!
- Initialize the service.
-
- :param name: Name of the service.
- :param req_type: Request message class with encode/decode methods.
- :param resp_type: Response message class with encode/decode methods.
- :param callback: Function to handle the request and return a response.
- :param registry_host: Host of the registry server.
- :param registry_port: Port of the registry server.
- :param host: Host to bind the service. If None, binds to the local network interface.
- :param port: Port to bind the service. If None, a random free port is chosen.
- """
- self.service_name = service_name
- self.comm_type = "Service"
- self.req_type = req_type
- self.resp_type = resp_type
- self.callback = callback
- self.host = host if host is not None else self._get_local_ip()
- self.port = port if port is not None else self._find_free_port()
- self.registry_host = registry_host
- self.registry_port = registry_port
- self._stop_event = threading.Event()
- self.thread = threading.Thread(target=self._serve)
- self.is_default_service = is_default
- self.thread.daemon = True
- self.thread.start()
- self.registered = self.register_with_registry()
-
- def _get_local_ip(self) -> str:
- """!
- Get the local IP address of the machine.
-
- @return: Detected local IP address.
- """
- s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
- s.settimeout(0)
- try:
- s.connect(("10.254.254.254", 1)) # Connect to a non-local address
- local_ip = s.getsockname()[0]
- except Exception:
- local_ip = "0.0.0.0" # If it fails, use a fallback IP
- finally:
- s.close()
- return local_ip
-
- def _find_free_port(self) -> int:
- """!
- Find a free port to bind the service.
-
- @return: Available port number.
- """
- with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
- s.bind((self.host, 0))
- return s.getsockname()[1]
-
- def register_with_registry(self):
- """!
- Register the service with the registry server.
-
- @return: ``True`` on success, ``False`` otherwise.
- """
- registration = {
- "type": "REGISTER",
- "service_name": self.service_name,
- "host": self.host, # Use the local IP address for registration
- "port": self.port,
- }
- try:
- with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-
- s.connect((self.registry_host, self.registry_port))
- encoded_req = json.dumps(registration).encode("utf-8")
-
- s.sendall(struct.pack("!I", len(encoded_req)))
-
- s.sendall(encoded_req)
- # Receive response
- raw_resp_len = self._recvall(s, 4)
-
- if not raw_resp_len:
- log.error(
- "Service: Failed to receive registration response length."
- )
- return False
- resp_len = struct.unpack("!I", raw_resp_len)[0]
- data = self._recvall(s, resp_len)
- if not data:
- log.error("Service: Failed to receive registration response data.")
- return False
- response = json.loads(data.decode("utf-8"))
- if response.get("status") == "OK":
- log.info(
- f"Service: Successfully registered '{self.service_name}' with registry."
- )
- else:
- log.error(
- f"Service: Registration failed - {response.get('message')}"
- )
- return False
- except Exception as e:
- # log.error(f"Service: Error registering with registry - {e}")
- return
- return True
-
- def _serve(self):
- """!
- Serve incoming service requests.
- """
- with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
- s.bind((self.host, self.port))
- s.listen()
- while not self._stop_event.is_set():
- try:
- s.settimeout(1.0)
- conn, addr = s.accept()
-
- except socket.timeout:
- continue
- with conn:
-
- try:
- # Receive message length
- raw_msglen = self._recvall(conn, 4)
- if not raw_msglen:
- print("Service: No message length received.")
- continue
- msglen = struct.unpack("!I", raw_msglen)[0]
-
- # Receive the actual message
- data = self._recvall(conn, msglen)
- if not data:
- print("Service: No data received.")
- continue
- # Decode the request
- request = self.req_type.decode(data)
-
- # Process the request
- response = self.callback(self.service_name, request)
- # Encode the response
- encoded_resp = response.encode()
-
- # Send the length of the response first
- conn.sendall(struct.pack("!I", len(encoded_resp)))
- # Then send the actual response
- conn.sendall(encoded_resp)
- except Exception as e:
- log.error(f"Service: Error handling request: {e}")
-
- def _recvall(self, conn, n):
- """!
- Helper function to receive ``n`` bytes from a socket.
-
- @param conn: Socket connection.
- @param n: Number of bytes to read.
- @return: Received bytes or ``None`` on EOF.
- """
- data = bytearray()
- while len(data) < n:
- packet = conn.recv(n - len(data))
- if not packet:
- return None
- data.extend(packet)
- return bytes(data)
-
- def __repr__(self):
- """
- Returns a string representation of the communication handler, including
- the channel name and the types of messages it handles.
-
- The string is formatted as:
- "channel_name[request_type, response_type]".
-
- @return: A string representation of the handler, formatted as
- "channel_name[request_type,response_type]".
- """
- return f"{self.service_name}[{self.req_type},{self.resp_type}]"
-
- def restart(self):
- """!
- Restart the service communication handlers.
- """
- return super().restart()
-
- def suspend(self):
- """!
- Shut down the service and deregister from the registry.
- """
- if self.deregister_from_registry():
- self._stop_event.set() # Stop the serving thread
- self.thread.join() # Wait for the serving thread to terminate
- print(f"Service '{self.service_name}' stopped.")
- else:
- print("Service shutdown un-gracefully.")
-
- def deregister_from_registry(self) -> bool:
- """!
- Deregister the service from the registry server and validate the response.
-
- @return: ``True`` if deregistration succeeded.
- """
- deregistration = {
- "type": "DEREGISTER",
- "name": self.service_name,
- "host": self.host,
- "port": self.port,
- }
- try:
- with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
- s.connect((self.registry_host, self.registry_port))
- encoded_req = json.dumps(deregistration).encode("utf-8")
- s.sendall(struct.pack("!I", len(encoded_req)))
- s.sendall(encoded_req)
- log.info(
- f"Service: Sending deregistration request for '{self.service_name}'."
- )
-
- # Receive response length
- raw_resp_len = self._recvall(s, 4)
- if not raw_resp_len:
- log.error(
- "Service: Failed to receive deregistration response length."
- )
- return False
- resp_len = struct.unpack("!I", raw_resp_len)[0]
-
- # Receive the actual response data
- data = self._recvall(s, resp_len)
- if not data:
- log.error(
- "Service: Failed to receive deregistration response data."
- )
- return False
-
- # Parse the response
- response = json.loads(data.decode("utf-8"))
- if response.get("status") == "OK":
- log.info(
- f"Service: Successfully deregistered '{self.service_name}' from registry."
- )
- return True
- else:
- log.error(
- f"Service: Deregistration failed - {response.get('message')}"
- )
- return False
- except Exception as e:
- log.error(f"Service: Error deregistering from registry - {e}")
- return False
-
- def get_info(self):
- """!
- Return a dictionary describing this service instance.
- """
- info = {
- "comms_type": "Service",
- "service_name": self.service_name,
- "service_host": self.host,
- "service_port": self.port,
- "registry_host": self.registry_host,
- "registry_port": self.registry_port,
- "request_type": self.req_type.__name__,
- "response_type": self.resp_type.__name__,
- "default_service": self.is_default_service,
- }
-
- return info
-
-
-def send_service_request(
- registry_host,
- registry_port,
- service_name: str,
- request: object,
- response_type: type,
- timeout: int = 1,
-) -> Any:
- """!
- Send a request to a service discovered from a registry.
-
- @param registry_host: Host address of the service registry.
- @param registry_port: Port of the service registry.
- @param service_name: Name of the service to discover.
- @param request: Request object to send.
- @param response_type: Expected response type.
- @param timeout: Timeout in seconds.
- @return: The response from the service.
- """
- # TODO timeout addition
- try:
- # Discover the host and port of the service from the registry
- host, port = __discover_service(registry_host, registry_port, service_name)
- # Call the discovered service with the provided request
- response = __call_service(host, port, request, response_type)
- return response
- except Exception as e:
- log.error(f"Client Error: {e}")
- pass
-
-
-def __call_service(
- service_host: str, service_port: int, request, response_type: type
-) -> Any:
- """!
- Call a specific service with the given request and return the response.
-
- @param service_host: Host address of the service.
- @param service_port: Port of the service.
- @param request: Request to send to the service.
- @param response_type: Expected response type.
- @return: The decoded response object.
- @raises RuntimeError: If communication fails.
- """
- with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
- # Connect to the service
- s.connect((service_host, service_port))
-
- # Encode the request into bytes
- encoded_req = request.encode()
-
- # Send the length of the request first
- s.sendall(struct.pack("!I", len(encoded_req)))
-
- # Then send the actual request data
- s.sendall(encoded_req)
-
- # Receive the length of the response (first 4 bytes)
- raw_resp_len = __recvall(s, 4)
- if not raw_resp_len:
- raise RuntimeError("Client: Failed to receive response length.")
- resp_len = struct.unpack("!I", raw_resp_len)[0]
-
- # Receive the actual response data
- data = __recvall(s, resp_len)
- if not data:
- raise RuntimeError("Client: Failed to receive response data.")
-
- # Decode the response into the specified response type
- response = response_type.decode(data)
- return response
-
-
-def __recvall(conn: socket.socket, n: int) -> bytes:
- """!
- Receive ``n`` bytes from a socket connection.
-
- @param conn: Socket connection to read from.
- @param n: Number of bytes to receive.
- @return: Bytes received or ``None`` if EOF is reached.
- """
- data = bytearray()
- while len(data) < n:
- # Receive the remaining bytes
- packet = conn.recv(n - len(data))
- if not packet:
- return None # EOF hit
- data.extend(packet)
- return bytes(data)
-
-
-def __discover_service(registry_host: str, registry_port: int, service_name: str):
- """!
- Discover the host and port of a service by querying the registry.
-
- @param registry_host: Host address of the registry.
- @param registry_port: Port of the registry.
- @param service_name: Name of the service to discover.
- @return: ``(host, port)`` tuple of the discovered service.
- @raises RuntimeError: If discovery fails.
- """
- discovery_request = {"type": "DISCOVER", "service_name": service_name}
- try:
- # Create a socket connection to the registry
- with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
- s.connect((registry_host, registry_port))
-
- # Encode the discovery request to send it over the socket
- encoded_req = json.dumps(discovery_request).encode("utf-8")
-
- # Send the length of the request first
- s.sendall(struct.pack("!I", len(encoded_req)))
-
- # Send the actual discovery request
- s.sendall(encoded_req)
-
- # Receive the length of the response (first 4 bytes)
- raw_resp_len = __recvall(s, 4)
- if not raw_resp_len:
- raise RuntimeError(
- "Client: Failed to receive discovery response length."
- )
- resp_len = struct.unpack("!I", raw_resp_len)[0]
-
- # Receive the actual response data
- data = __recvall(s, resp_len)
- if not data:
- raise RuntimeError("Client: Failed to receive discovery response data.")
-
- # Decode the response
- response = json.loads(data.decode("utf-8"))
-
- # If the service was successfully discovered, return the host and port
- if response.get("status") == "OK":
- host = response.get("host")
- port = response.get("port")
- return host, port
- else:
- raise RuntimeError(
- f"Client: Service discovery of {service_name} failed - {response.get('message')}"
- )
- except Exception as e:
- log.error(f"Client: Error during service discovery - {e}")
- raise
diff --git a/ark/client/comm_handler/subscriber.py b/ark/client/comm_handler/subscriber.py
deleted file mode 100644
index 1021741..0000000
--- a/ark/client/comm_handler/subscriber.py
+++ /dev/null
@@ -1,100 +0,0 @@
-import time
-from lcm import LCM
-from ark.tools.log import log
-from ark.client.comm_handler.comm_handler import CommHandler
-from typing import Callable
-
-
-class Subscriber(CommHandler):
- """!
- A subscriber for listening to messages on a communication channel.
-
- This class subscribes to a specified communication channel and calls a user-defined
- callback function whenever a new message is received. The message data is passed to the
- callback along with the timestamp and channel name.
-
- @note: This class is a subclass of `CommHandler` and requires an LCM instance,
- a channel name, and a channel type to function correctly.
- """
-
- def __init__(
- self,
- lcm: LCM,
- channel_name: str,
- channel_type: type,
- callback: Callable[[int, str, object], None],
- callback_args: list[object] = [],
- ) -> None:
- """!
- Initializes the subscriber with the necessary parameters for subscribing
- to a communication channel and setting up the callback function.
-
- @param lcm: The LCM instance used for communication.
- @param channel_name: The name of the communication channel.
- @param channel_type: The type of the message expected for this communication channel.
- @param callback: The user-defined callback function to be called with the message data.
- @param callback_args: Additional arguments to be passed to the callback function.
- """
- super().__init__(lcm, channel_name, channel_type)
- self._user_callback: Callable[[int, str, object], None] = callback
- self._callback_args: list[object] = callback_args
- self.comm_type = "Subscriber"
- self.subscribe()
-
- def subscriber_callback(self, channel_name: str, data: bytes) -> None:
- """!
- Callback function to handle incoming messages on the subscribed channel.
-
- This method decodes the message data, records the timestamp, and calls
- the user-defined callback function with the timestamp, channel name, and message.
-
- @param channel_name: The name of the communication channel the message was received from.
- @param data: The raw byte data of the message.
- """
- t: int = time.time_ns()
- try:
- msg: object = self.channel_type.decode(data)
- self._user_callback(t, channel_name, msg, *self._callback_args)
- except ValueError as e:
- log.warning(f"failed to decode message on channel '{channel_name}': {e}")
-
- def subscribe(self):
- """!
- Subscribe to the configured channel.
-
- @return: ``None``
- """
- self._sub = self._lcm.subscribe(self.channel_name, self.subscriber_callback)
- self._sub.set_queue_capacity(1) # TODO: configurable
- log.ok(f"subscribed to {self}")
- self.active = True
-
- def restart(self):
- """!
- Reconnect the subscriber to its channel.
- """
- self.subscribe()
- self.active = True
-
- def suspend(self) -> None:
- """!
- Suspends the subscriber by unsubscribing from the communication channel.
-
- This method releases the subscription and logs that the subscriber has been unsubscribed.
- """
- if self.active == True:
- self._lcm.unsubscribe(self._sub)
- log.ok(f"unsubscribed from {self}")
- self.active = False
-
- def get_info(self):
- """!
- Return a dictionary describing this subscriber.
- """
- info = {
- "comms_type": "Subscriber",
- "channel_name": self.channel_name,
- "channel_type": self.channel_type.__name__,
- "channel_status": self.active,
- }
- return info
diff --git a/ark/client/comm_infrastructure/__init__.py b/ark/client/comm_infrastructure/__init__.py
deleted file mode 100644
index 8b13789..0000000
--- a/ark/client/comm_infrastructure/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/ark/client/comm_infrastructure/base_node.py b/ark/client/comm_infrastructure/base_node.py
deleted file mode 100644
index 9e25489..0000000
--- a/ark/client/comm_infrastructure/base_node.py
+++ /dev/null
@@ -1,83 +0,0 @@
-import sys
-import traceback
-from typing import Type
-
-from ark.client.comm_infrastructure.comm_endpoint import CommEndpoint
-from ark.tools.log import log
-
-
-class BaseNode(CommEndpoint):
- """!
- Base class for nodes that interact with the LCM system. Handles the subscription,
- publishing, and communication processes for the node.
-
- The `BaseNode` class manages the LCM instance and communication handlers, and provides
- methods for creating publishers, subscribers, listeners, and steppers. It also provides
- functionality for handling command-line arguments and the graceful shutdown of the node.
-
- @param lcm: The LCM instance used for communication.
- @param channel_name: The name of the channel to subscribe to.
- @param channel_type: The type of the message expected for the channel.
- """
-
- def __init__(self, name: str, global_config=None) -> None:
- """!
- Initializes a BaseNode object with the specified node name and registry host and port.
-
- @param node_name: The name of the node.
- @param global_config: Contains IP Address and Port
- """
- super().__init__(name, global_config)
- self.config = self._load_config_section(
- global_config=global_config, name=name, type="other"
- )
- self._done = False
-
- def spin(self) -> None:
- """!
- Runs the node’s main loop, handling LCM messages continuously until the node is finished.
-
- The loop calls `self._lcm.handle()` to process incoming messages. If an OSError is encountered,
- the loop will stop and the node will shut down.
- """
- while not self._done:
- try:
- self._lcm.handle_timeout(0)
- except OSError as e:
- log.warning(f"LCM threw OSError {e}")
- self._done = True
-
-
-def main(node_cls: type[BaseNode], *args, **kwargs) -> None:
- """!
- Initializes and runs a node.
-
- This function creates an instance of the specified `node_cls`, spins the node to handle messages,
- and handles exceptions that occur during the node's execution.
-
- @param node_cls: The class of the node to run.
- @type node_cls: Type[BaseNode]
- """
-
- if "--help" in sys.argv or "-h" in sys.argv:
- print(node_cls.get_cli_doc())
- sys.exit(0)
-
- node = None
- log.ok(f"Initializing {node_cls.__name__} type node")
- try:
- node = node_cls(*args, **kwargs)
- log.ok(f"Initialized {node.name}")
- node.spin()
- except KeyboardInterrupt:
- log.warning(f"User killed node {node_cls.__name__}")
- except Exception:
- tb = traceback.format_exc()
- div = "=" * 30
- log.error(f"Exception thrown during node execution:\n{div}\n{tb}\n{div}")
- finally:
- if node is not None:
- node.kill_node()
- log.ok(f"Finished running node {node_cls.__name__}")
- else:
- log.warning(f"Node {node_cls.__name__} failed during initialization")
diff --git a/ark/client/comm_infrastructure/comm_endpoint.py b/ark/client/comm_infrastructure/comm_endpoint.py
deleted file mode 100644
index 56676de..0000000
--- a/ark/client/comm_infrastructure/comm_endpoint.py
+++ /dev/null
@@ -1,654 +0,0 @@
-import ast
-import os
-import time
-import signal
-import uuid
-import sys
-from abc import ABC, abstractmethod
-from typing import Any, Generator, Dict, Union
-from pathlib import Path
-import yaml
-
-import lcm
-from lcm import LCM
-
-from ark.client.comm_infrastructure.endpoint import EndPoint
-from ark.client.comm_handler.comm_handler import CommHandler
-from ark.client.comm_handler.multi_comm_handler import MultiCommHandler
-from ark.client.comm_handler.publisher import Publisher
-from ark.client.comm_handler.subscriber import Subscriber
-from ark.client.comm_handler.listener import Listener
-from ark.client.comm_handler.service import Service, send_service_request
-from ark.client.comm_handler.multi_channel_publisher import MultiChannelPublisher
-from ark.client.comm_handler.multi_channel_listener import MultiChannelListener
-from ark.client.frequencies.stepper import Stepper
-from ark.tools.log import log
-from ark.utils.utils import ConfigPath
-
-from arktypes import (
- flag_t,
- node_info_t,
- comms_info_t,
- service_info_t,
- listener_info_t,
- subscriber_info_t,
- publisher_info_t,
-)
-
-
-DEFAULT_SERVICE_DECORATOR = "__DEFAULT_SERVICE"
-
-
-class CommEndpoint(EndPoint):
- """!
- Base class for nodes that interact with the LCM system. Handles the subscription,
- publishing, and communication processes for the node.
-
- The `BaseNode` class manages the LCM instance and communication handlers, and provides
- methods for creating publishers, subscribers, listeners, and steppers. It also provides
- functionality for handling command-line arguments and the graceful shutdown of the node.
-
- @param lcm: The LCM instance used for communication.
- @param channel_name: The name of the channel to subscribe to.
- @param channel_type: The type of the message expected for the channel.
- """
-
- # def __init__(self, node_name: str, registry_host: str = "127.0.0.1", registry_port: int = 1234, lcm_network_bounces: int = 1) -> None:
- def __init__(
- self, node_name: str, global_config: Union[str, Dict[str, Any], Path]
- ) -> None:
- """!
- Initialize a communication endpoint for a node.
-
- @param node_name: The name for the node.
- param unique: True when the node is unique, False otherwise. When True the node name is appended with a unique stamp.
-
- @raises SystemExit: If "--help" or "-h" is passed in command-line arguments.
- """
-
- # system_config = self.load_system_config(global_config, node_name)
-
- super().__init__(global_config)
-
- self.name = node_name
- self.node_id = str(uuid.uuid4())
-
- self._done: bool = False
- self._comm_handlers: list[CommHandler] = []
- self._multi_comm_handlers: list[MultiCommHandler] = []
- self._steppers: list[Stepper] = []
-
- # Create default service for get info of the node
- get_info_service_name = (
- f"{DEFAULT_SERVICE_DECORATOR}/GetInfo/{self.name}_{self.node_id}"
- )
- get_info_service = self.create_service(
- get_info_service_name, flag_t, node_info_t, self._callback_get_info, True
- )
-
- suspend_node_service_name = (
- f"{DEFAULT_SERVICE_DECORATOR}/SuspendNode/{self.name}_{self.node_id}"
- )
- suspend_node_service = self.create_service(
- suspend_node_service_name, flag_t, flag_t, self._callback_suspend_node, True
- )
-
- restart_node_service_name = (
- f"{DEFAULT_SERVICE_DECORATOR}/RestartNode/{self.name}_{self.node_id}"
- )
- restart_node_service = self.create_service(
- restart_node_service_name, flag_t, flag_t, self._callback_restart_node, True
- )
-
- self.registered = self.check_registration()
-
- if self.registered == False:
- log.error(
- "Unable to connect to Registry. Please check network configuration setting / start a registry"
- )
- sys.exit(1)
-
- def check_registration(self):
- """!
- Check whether all default services have been registered.
-
- @return: ``True`` if all services are registered, ``False`` otherwise.
- """
- # check if default services are registered
- n_service_channels = 0
- n_service_channels_registered = 0
- for ch in self._comm_handlers:
- # check if type is a service
- if ch.comm_type == "Service":
- n_service_channels += 1
- if ch.registered == True:
- n_service_channels_registered += 1
-
- if (
- n_service_channels != n_service_channels_registered
- and n_service_channels_registered == 0
- ):
- # log.error("ARK Registry has not been started, use 'ark registry start', to start ")
- return False
- elif (
- n_service_channels != n_service_channels_registered
- and n_service_channels_registered > 0
- ):
- log.error(
- "FATAL: Some services are not registered, please check the registry and network settings"
- )
- return False
- else:
- return True
-
- def _load_config_section(
- self, global_config: Union[str, Dict[str, Any], Path], name: str, type: str
- ) -> Dict:
- """!
- Load the configuration section for a component.
-
- @param global_config: Global configuration source.
- @param name: Name of the component.
- @param type: Section type within the configuration file.
- @return: Dictionary containing the configuration for the component.
- """
-
- if isinstance(global_config, str):
- global_config = ConfigPath(global_config)
- if not global_config.exists():
- log.error("Given configuration file path does not exist.")
- if not global_config.is_absolute():
- global_config = global_config.resolve()
- if isinstance(global_config, Path):
- global_config = ConfigPath(str(global_config))
- if isinstance(global_config, ConfigPath):
- cfg = global_config.read_yaml()
- for item in cfg.get(type, []):
- if isinstance(item, dict): # If it's an inline configuration
- config = item["config"]
- return config
- # Make sure the yaml config has the same name with "name"
- elif isinstance(item, str) and item.endswith(
- ".yaml"
- ): # If it's a path to an external file
- if item.split(".")[0] == type + "/" + name:
- if os.path.isabs(item): # Check if the path is absolute
- external_path = ConfigPath(item)
- else: # Relative path, use the directory of the main config file
- external_path = global_config.parent / item
- # Load the YAML file and return its content
- item_config = external_path.read_yaml()
- config = item_config["config"]
- return config
- else:
- log.error(
- f"Invalid entry in '{type}': {self.name}. Please provide either a config or a path to another (.yaml) config."
- )
- return # Skip invalid entries
- if isinstance(global_config, dict):
- config = {}
- for component, component_config in global_config[type].items():
- if component == self.name:
- if not component_config:
- log.error(
- f"Please provide a config for the {type}: {self.name}"
- )
- return component_config
- if not config:
- log.error(f"Couldn't find type '{self.name}' in config.")
- return config
- else:
- log.error(f"Couldn't load config for {type} '{self.name}'")
-
- def get_info(self) -> dict:
- """!
- Gather information about all registered communication handlers.
-
- @return: Dictionary describing listeners, publishers, subscribers and services.
- """
- listener_info = []
- subscriber_info = []
- publisher_info = []
- service_info = []
- for ch in self._comm_handlers:
- ch_info = ch.get_info()
- if ch_info["comms_type"] == "Listener":
- listener_info.append(ch_info)
- elif ch_info["comms_type"] == "Subscriber":
- subscriber_info.append(ch_info)
- elif ch_info["comms_type"] == "Publisher":
- publisher_info.append(ch_info)
- elif ch_info["comms_type"] == "Service":
- service_info.append(ch_info)
- else:
- raise NameError
-
- for m_ch in self._multi_comm_handlers:
- m_ch_info = m_ch.get_info()
- for ch_info in m_ch_info:
- if ch_info["comms_type"] == "Listener":
- listener_info.append(ch_info)
- elif ch_info["comms_type"] == "Subscriber":
- subscriber_info.append(ch_info)
- elif ch_info["comms_type"] == "Publisher":
- publisher_info.append(ch_info)
- elif ch_info["comms_type"] == "Service":
- service_info.append(ch_info)
- else:
- raise NameError
-
- info = {
- "node_name": self.name,
- "node_id": self.node_id,
- "comms": {
- "listeners": listener_info,
- "subscribers": subscriber_info,
- "publishers": publisher_info,
- "services": service_info,
- },
- }
-
- return info
-
- def _callback_get_info(self, channel, msg):
- """!
- Callback for the default GetInfo service.
-
- @param channel: Unused service channel name.
- @param msg: Service request message.
- @return: Node information message.
- """
-
- print("Get info service called")
-
- # Create an instance of node_info_t
- node_info = node_info_t()
-
- data = self.get_info()
- # Populate node_info
- node_info.node_name = data["node_name"]
- node_info.node_id = data["node_id"]
-
- # Create comms_info_t
- comms_info = comms_info_t()
-
- # Populate listeners
- comms_info.n_listeners = len(data["comms"]["listeners"])
- comms_info.listeners = [
- listener_info_t() for _ in range(comms_info.n_listeners)
- ]
- for i, listener in enumerate(data["comms"]["listeners"]):
- comms_info.listeners[i].comms_type = listener["comms_type"]
- comms_info.listeners[i].channel_name = listener["channel_name"]
- comms_info.listeners[i].channel_type = listener["channel_type"]
- comms_info.listeners[i].channel_status = listener["channel_status"]
-
- # Populate subscribers
- comms_info.n_subscribers = len(data["comms"]["subscribers"])
- comms_info.subscribers = [
- subscriber_info_t() for _ in range(comms_info.n_subscribers)
- ]
- for i, subscriber in enumerate(data["comms"]["subscribers"]):
- comms_info.subscribers[i].comms_type = subscriber["comms_type"]
- comms_info.subscribers[i].channel_name = subscriber["channel_name"]
- comms_info.subscribers[i].channel_type = subscriber["channel_type"]
- comms_info.subscribers[i].channel_status = subscriber["channel_status"]
-
- # Populate publishers
- comms_info.n_publishers = len(data["comms"]["publishers"])
- comms_info.publishers = [
- publisher_info_t() for _ in range(comms_info.n_publishers)
- ]
- for i, publisher in enumerate(data["comms"]["publishers"]):
- comms_info.publishers[i].comms_type = publisher["comms_type"]
- comms_info.publishers[i].channel_name = publisher["channel_name"]
- comms_info.publishers[i].channel_type = publisher["channel_type"]
- comms_info.publishers[i].channel_status = publisher["channel_status"]
-
- # Populate services
- comms_info.n_services = len(data["comms"]["services"])
- comms_info.services = [service_info_t() for _ in range(comms_info.n_services)]
- for i, service in enumerate(data["comms"]["services"]):
- comms_info.services[i].comms_type = service["comms_type"]
- comms_info.services[i].service_name = service["service_name"]
- comms_info.services[i].service_host = service["service_host"]
- comms_info.services[i].service_port = service["service_port"]
- comms_info.services[i].registry_host = service["registry_host"]
- comms_info.services[i].registry_port = service["registry_port"]
- comms_info.services[i].request_type = service["request_type"]
- comms_info.services[i].response_type = service["response_type"]
-
- # Assign comms_info to node_info
- node_info.comms = comms_info
-
- return node_info
-
- def create_publisher(self, channel_name: str, channel_type: type) -> Publisher:
- """!
- Creates and returns a publisher for the specified channel.
-
- @param channel_name: The name of the channel to publish to.
- @type channel_name: str
- @param channel_type: The type of the message to publish.
- @type channel_type: type
- @return: The created Publisher instance.
- @rtype: Publisher
- """
- pub = Publisher(self._lcm, channel_name, channel_type)
- self._comm_handlers.append(pub)
- return pub
-
- def create_multi_channel_publisher(self, channels):
- """!
- Create a publisher that manages multiple channels.
-
- @param channels: List of ``(channel_name, channel_type)`` tuples.
- @return: The created :class:`MultiChannelPublisher` instance.
- """
- multi_pub = MultiChannelPublisher(channels, self._lcm)
- self._multi_comm_handlers.append(multi_pub)
- return multi_pub
-
- def create_multi_channel_listener(self, channels):
- """!
- Create listeners for multiple channels.
-
- @param channels: List of ``(channel_name, channel_type)`` tuples.
- @return: The created :class:`MultiChannelListener` instance.
- """
- multi_listeners = MultiChannelListener(channels, lcm_instance=self._lcm)
- self._multi_comm_handlers.append(multi_listeners)
- return multi_listeners
-
- def create_subscriber(
- self,
- channel_name: str,
- channel_type: type,
- callback: callable,
- callback_args: list = [],
- ) -> Subscriber:
- """!
- Creates and returns a subscriber for the specified channel.
-
- @param channel_name: The name of the channel to subscribe to.
- @type channel_name: str
- @param channel_type: The type of the message expected from the channel.
- @type channel_type: type
- @param callback: The callback function to be invoked when a message is received.
- @type callback: callable
- @param callback_args: Additional arguments to pass to the callback.
- @type callback_args: list
- @return: The created Subscriber instance.
- @rtype: Subscriber
- """
- sub = Subscriber(
- self._lcm,
- channel_name,
- channel_type,
- callback,
- callback_args=callback_args,
- )
- self._comm_handlers.append(sub)
- return sub
-
- def create_service(
- self,
- service_name: str,
- request_type: type,
- response_type: type,
- callback: callable,
- is_default_service=False,
- ):
- """!
- Create and register a service.
-
- @param service_name: Name of the service.
- @param request_type: Message type of the request.
- @param response_type: Message type of the response.
- @param callback: Callback invoked to handle the request.
- @param is_default_service: Mark service as an internal default.
- @return: The created :class:`Service` instance.
- """
- service = Service(
- service_name=service_name,
- req_type=request_type,
- resp_type=response_type,
- callback=callback,
- registry_host=self.registry_host,
- registry_port=self.registry_port,
- is_default=is_default_service,
- )
- self._comm_handlers.append(service)
- return service
-
- def create_listener(self, channel_name: str, channel_type: type) -> Listener:
- """!
- Creates and returns a listener for the specified channel.
-
- @param channel_name: The name of the channel to listen to.
- @type channel_name: str
- @param channel_type: The type of the message expected from the channel.
- @type channel_type: type
- @return: The created Listener instance.
- @rtype: Listener
- """
- listener = Listener(self._lcm, channel_name, channel_type)
- self._comm_handlers.append(listener)
- return listener
-
- def wait_for_message(
- self, channel_name: str, channel_type: type, timeout: int = 10
- ) -> Any:
- """!
- Waits for a single message on the specified channel within a timeout period.
-
- @param channel_name: The name of the channel to listen for messages.
- @type channel_name: str
- @param channel_type: The type of the message to expect.
- @type channel_type: type
- @param timeout: The number of seconds to wait before timing out.
- @type timeout: int
- @return: The received message, or None if the timeout was reached.
- @rtype: Any
- @raises TimeoutError: If the timeout is reached before receiving a message.
- """
-
- def timeout_handler(signum, frame):
- raise TimeoutError(
- f"Timeout reached while waiting for a message on channel '{channel_name}'."
- )
-
- signal.signal(signal.SIGALRM, timeout_handler)
- signal.alarm(timeout)
-
- msg = None
- listener = Listener(self._lcm, channel_name, channel_type)
-
- try:
- while not listener.received():
- self._lcm.handle() # Blocking call.
- msg = listener.get()
- except TimeoutError:
- log.warning(
- f"Listener {listener} did not receive a message within the specified timeout."
- )
- finally:
- signal.alarm(0) # Cancel the alarm.
- listener.shutdown()
-
- return msg
-
- def create_stepper(
- self,
- hz: float,
- callback: callable,
- oneshot: bool = False,
- reset: bool = True,
- callback_args: list = [],
- ) -> Stepper:
- """!
- Creates and returns a stepper that calls the specified callback at the specified rate.
-
- @param hz: The frequency (in Hz) at which the callback will be invoked.
- @type hz: float
- @param callback: The callback function to be called.
- @type callback: callable
- @param oneshot: If True, the callback is fired only once. Otherwise, it fires continuously.
- @type oneshot: bool
- @param reset: If True, the timer is reset when time moves backward.
- @type reset: bool
- @param callback_args: Additional arguments to pass to the callback.
- @type callback_args: list
- @return: The created Stepper instance.
- @rtype: Stepper
- """
- stepper = Stepper(
- hz, callback, oneshot=oneshot, reset=reset, callback_args=callback_args
- )
- self._steppers.append(stepper)
- return stepper
-
- def now(self) -> float:
- """!
- Returns the current time in seconds since the epoch.
-
- @return: The current time.
- @rtype: float
- """
- return time.time()
-
- def now_ns(self) -> int:
- """!
- Returns the current time in nanoseconds since the epoch.
-
- @return: The current time in nanoseconds.
- @rtype: int
- """
- return time.time_ns()
-
- def spin(self) -> None:
- """!
- Runs the node’s main loop, handling LCM messages continuously until the node is finished.
-
- The loop calls `self._lcm.handle()` to process incoming messages. If an OSError is encountered,
- the loop will stop and the node will shut down.
- """
- while not self._done:
- try:
- self._lcm.handle_timeout(0)
- except OSError as e:
- log.warning(f"LCM threw OSError {e}")
- self._done = True
-
- def _callback_suspend_node(self, channel, msg):
- """!
- Callback that suspends the node when triggered.
-
- @param channel: Unused service channel.
- @param msg: Service request message.
- @return: Empty :class:`flag_t` response.
- """
- self.suspend_node()
- return flag_t()
-
- def _callback_restart_node(self, channel, msg):
- """!
- Callback that restarts the node when triggered.
-
- @param channel: Unused service channel.
- @param msg: Service request message.
- @return: Empty :class:`flag_t` response.
- """
- self.restart_node()
- return flag_t()
-
- def suspend_communications(self, services=True) -> None:
- """!
- Suspends the node stopping comms handellers
-
- """
- # Unsubscribe all comm handlers
- for ch in self._comm_handlers:
- if ch.comm_type != "Service":
- ch.suspend()
- elif ch.comm_type == "Service" and services == True:
- ch.suspend()
-
- for m_ch in self._multi_comm_handlers:
- m_ch.suspend()
-
- def resume_communications(self, services=True) -> None:
- """!
- Resumes the node's communication handlers.
- """
- for ch in self._comm_handlers:
- if ch.comm_type != "Service":
- ch.restart()
- elif ch.comm_type == "Service" and services == True:
- ch.restart()
-
- for m_ch in self._multi_comm_handlers:
- m_ch.restart()
-
- def kill_node(self) -> None:
- """!
- Terminate the node process immediately.
-
- This method suspends the node and exits the program.
- """
-
- self.suspend_node()
- log.ok(f"Killing {self.name} Node")
- sys.exit(0)
-
- def suspend_node(self) -> None:
- """!
- Shuts down the node by stopping all communication handlers and steppers.
-
- Iterates through all registered communication handlers and steppers, shutting them down.
- """
- for ch in self._comm_handlers:
- if ch.comm_type != "Service":
- ch.suspend()
- elif ch.comm_type == "Service" and ch.register_with_registry == True:
- ch.suspend()
-
- for m_ch in self._multi_comm_handlers:
- m_ch.suspend()
-
- for s in self._steppers:
- s.suspend()
-
- def restart_node(self) -> None:
- """!
- Restart all communication handlers and steppers for the node.
- """
- for ch in self._comm_handlers:
- ch.restart()
-
- for m_ch in self._multi_comm_handlers:
- m_ch.restart()
-
- for s in self._steppers:
- s.restart()
-
- def send_service_request(
- self, service_name: str, request: object, response_type: type, timeout: int = 30
- ) -> Any:
- """!
- Convenience wrapper around :func:`send_service_request`.
-
- @param service_name: Name of the service to call.
- @param request: Request object to send.
- @param response_type: Expected response type.
- @param timeout: Timeout in seconds.
- @return: The decoded response from the service.
- """
- return send_service_request(
- self.registry_host,
- self.registry_port,
- service_name,
- request,
- response_type,
- timeout,
- )
diff --git a/ark/client/comm_infrastructure/endpoint.py b/ark/client/comm_infrastructure/endpoint.py
deleted file mode 100644
index 470f925..0000000
--- a/ark/client/comm_infrastructure/endpoint.py
+++ /dev/null
@@ -1,115 +0,0 @@
-# from typing import Any, Optional, Dict, Tuple, List, Union
-from pathlib import Path
-
-import lcm
-
-# import os
-from ark.tools.log import log
-
-# import socket
-from ark.utils.utils import ConfigPath
-from lcm import LCM
-
-
-class EndPoint:
-
- def __init__(self, global_config) -> None:
- """!
- Initialize an Endpoint object for interacting with the registry and
- setting up LCM communication.
-
- @param global_config: Global configuration containing network settings.
- """
-
- # self.network_config = {
- # "registry_host": "127.0.0.1",#"10.206.165.77",
- # "registry_port": 1234,
- # "lcm_network_bounces": 1 #was 1
- # }
- self._load_network_config(global_config)
- self.registry_host = self.network_config.get("registry_host", "127.0.0.1")
- self.registry_port = self.network_config.get("registry_port", 1234)
- self.lcm_network_bounces = self.network_config.get("lcm_network_bounces", 1)
- udpm = f"udpm://239.255.76.67:7667?ttl={self.lcm_network_bounces}"
- self._lcm: LCM = lcm.LCM(udpm)
-
- def _load_network_config(self, global_config: str | Path | dict | None) -> None:
- """!
- Load and update the network configuration from the given input.
-
- This method accepts a string path, a :class:`Path` object, a dictionary
- or ``None``. The resulting configuration is stored in
- ``self.network_config``.
-
- @param global_config: Path to a YAML file, a dictionary containing the
- network configuration, or ``None`` to use defaults.
- @return: ``None``. ``self.network_config`` is updated in place.
- """
- self.network_config = {}
- # extract network part of the global config
- if isinstance(global_config, str):
- global_config = ConfigPath(global_config) # Convert string to a Path object
-
- # Check if the given path exists
- if not global_config.exists():
- log.error(
- "Given configuration file path does not exist. Using default system configuration."
- )
- return # Exit the function if the file does not exist
-
- # Resolve relative paths to absolute paths
- elif not global_config.is_absolute():
- global_config = global_config.resolve()
-
- # If global_config is now a Path object, treat it as a configuration file
- if isinstance(global_config, Path):
- global_config = ConfigPath(str(global_config))
- if isinstance(global_config, ConfigPath):
- cfg = global_config.read_yaml(raise_fnf_error=False)
- if not cfg:
- log.error(
- f"Error reading config file {global_config.str}. Using default system configuration."
- )
- return cfg # Exit on failure to read file
-
- try:
- # Extract and update the 'system' configuration if it exists in the loaded YAML
- if "network" in cfg:
- self.network_config.update(cfg.get("network", self.network_config))
- else:
- log.warn(
- f"Couldn't find system in config. Using default system configuration."
- )
- return # Successfully updated configuration
- except Exception as e:
- log.error(
- f"Invalid entry in 'system' for. Using default system configuration."
- )
- return # Exit if there's an error updating the config
-
- # If global_config is a dictionary, assume it directly contains configuration values
- elif isinstance(global_config, dict):
- try:
- # check if system exists in the global_config
- if "network" in global_config:
- self.network_config.update(global_config.get("network"))
- else:
- log.warn(
- f"Couldn't find system in config. Using default system configuration."
- )
- except Exception as e:
- log.warn(
- f"Couldn't find system in config. Using default system configuration."
- )
-
- # If no configuration is provided (None), log a warning and use the default config
- elif global_config is None:
- log.warn(
- f"No global configuration provided. Using default system configuration."
- )
-
- # If global_config is of an unsupported type, log an error and use the default config
- else:
- log.error(
- f"Invalid global configuration type: {type(global_config)}. Using default system configuration."
- )
diff --git a/ark/client/comm_infrastructure/hybrid_node.py b/ark/client/comm_infrastructure/hybrid_node.py
deleted file mode 100644
index c686cd3..0000000
--- a/ark/client/comm_infrastructure/hybrid_node.py
+++ /dev/null
@@ -1,99 +0,0 @@
-import sys
-from pathlib import Path
-from abc import ABC, abstractmethod
-from typing import Any, Generator, Dict, Type
-import traceback
-
-import lcm
-from lcm import LCM
-import yaml
-import os
-
-from ark.client.comm_infrastructure.comm_endpoint import CommEndpoint
-from ark.tools.log import log
-
-
-class HybridNode(CommEndpoint):
- """!
- Base class for nodes that interact with the LCM system. Handles the subscription,
- publishing, and communication processes for the node.
-
- The `BaseNode` class manages the LCM instance and communication handlers, and provides
- methods for creating publishers, subscribers, listeners, and steppers. It also provides
- functionality for handling command-line arguments and the graceful shutdown of the node.
-
- @param lcm: The LCM instance used for communication.
- @param channel_name: The name of the channel to subscribe to.
- @param channel_type: The type of the message expected for the channel.
- """
-
- def __init__(self, node_name: str, global_config=None) -> None:
- """!
- Initializes a BaseNode object with the specified node name and registry host and port.
-
- @param node_name: The name of the node.
- @param global_config: Contains IP Address and Port
- """
- super().__init__(node_name, global_config)
-
- def manual_spin(self) -> None:
- """!
- Process pending LCM messages once.
-
- This method calls ``handle_timeout`` a single time and updates the
- done flag if an error occurs.
- """
- try:
- self._lcm.handle_timeout(0)
- except OSError as e:
- log.warning(f"LCM threw OSError {e}")
- self._done = True
-
- def spin(self) -> None:
- """!
- Runs the node’s main loop, handling LCM messages continuously until the node is finished.
-
- The loop calls `self._lcm.handle()` to process incoming messages. If an OSError is encountered,
- the loop will stop and the node will shut down.
- """
- while not self._done:
- try:
- self._lcm.handle_timeout(0)
- except OSError as e:
- log.warning(f"LCM threw OSError {e}")
- self._done = True
-
-
-def main(node_cls: type[HybridNode], *args) -> None:
- """!
- Initializes and runs a node.
-
- This function creates an instance of the specified `node_cls`, spins the node to handle messages,
- and handles exceptions that occur during the node's execution.
-
- @param node_cls: The class of the node to run.
- @type node_cls: Type[BaseNode]
- """
-
- if "--help" in sys.argv or "-h" in sys.argv:
- print(node_cls.get_cli_doc())
- sys.exit(0)
-
- node = None
- log.ok(f"Initializing {node_cls.__name__} type node")
- try:
- node = node_cls(*args)
- log.ok(f"Initialized {node.node_name}")
- node.spin()
- except KeyboardInterrupt:
- log.warning(f"User killed node {node_cls.__name__}")
- except Exception:
- tb = traceback.format_exc()
- div = "=" * 30
- log.error(f"Exception thrown during node execution:\n{div}\n{tb}\n{div}")
- finally:
- if node is not None:
- node.shutdown()
- log.ok(f"Finished running node {node_cls.__name__}")
- else:
- log.warning(f"Node {node_cls.__name__} failed during initialization")
diff --git a/ark/client/comm_infrastructure/instance_node.py b/ark/client/comm_infrastructure/instance_node.py
deleted file mode 100644
index a8bbc16..0000000
--- a/ark/client/comm_infrastructure/instance_node.py
+++ /dev/null
@@ -1,61 +0,0 @@
-import sys
-from abc import ABC, abstractmethod
-from typing import Any, Generator, Dict, Type
-import traceback
-
-import lcm
-from lcm import LCM
-import yaml
-import os
-import threading
-import signal
-
-from ark.client.comm_infrastructure.comm_endpoint import CommEndpoint
-from ark.tools.log import log
-
-
-class InstanceNode(CommEndpoint):
- """!
- Base class for nodes that interact with the LCM system. Handles the subscription,
- publishing, and communication processes for the node.
-
- The `BaseNode` class manages the LCM instance and communication handlers, and provides
- methods for creating publishers, subscribers, listeners, and steppers. It also provides
- functionality for handling command-line arguments and the graceful shutdown of the node.
-
- @param lcm: The LCM instance used for communication.
- @param channel_name: The name of the channel to subscribe to.
- @param channel_type: The type of the message expected for the channel.
- """
-
- def __init__(self, node_name: str, global_config=None) -> None:
- """!
- Initializes a BaseNode object with the specified node name and registry host and port.
-
- @param node_name: The name of the node.
- @param global_config: Contains IP Address and Port
- """
- print(global_config)
- super().__init__(node_name, global_config)
- self.config = self._load_config_section(
- global_config=global_config, name=node_name, type="other"
- )
-
- self._done = False
-
- self.spin_thread = threading.Thread(target=self.spin, daemon=True)
- self.spin_thread.start()
-
- def spin(self) -> None:
- """!
- Runs the node’s main loop, handling LCM messages continuously until the node is finished.
-
- The loop calls `self._lcm.handle()` to process incoming messages. If an OSError is encountered,
- the loop will stop and the node will shut down.
- """
- while not self._done:
- try:
- self._lcm.handle_timeout(0)
- except OSError as e:
- log.warning(f"LCM threw OSError {e}")
- self._done = True
diff --git a/ark/client/comm_infrastructure/registry.py b/ark/client/comm_infrastructure/registry.py
deleted file mode 100644
index a83ae16..0000000
--- a/ark/client/comm_infrastructure/registry.py
+++ /dev/null
@@ -1,284 +0,0 @@
-import argparse
-import socket
-import threading
-import struct
-import json
-import sys
-import typer
-
-from ark.client.comm_handler.service import Service, send_service_request
-from ark.client.comm_infrastructure.endpoint import EndPoint
-from ark.tools.log import log
-from ark.global_constants import *
-from arktypes import flag_t, network_info_t, node_info_t
-
-app = typer.Typer()
-
-
-class Registry(EndPoint):
- def __init__(
- self,
- registry_host: str = "127.0.0.1",
- registry_port: int = 1234,
- lcm_network_bounces: int = 1,
- ):
- """!
- Initialize the Registry server instance.
-
- @param registry_host: Host address for the registry.
- @param registry_port: Port on which the registry listens.
- @param lcm_network_bounces: TTL for LCM multicast messages.
- """
- global_config = {
- "network": {
- "registry_host": registry_host,
- "registry_port": registry_port,
- "lcm_network_bounces": lcm_network_bounces,
- }
- }
-
- super().__init__(global_config)
-
- self.services = {} # Maps service_name to (host, port)
- self.lock = threading.Lock()
- self._stop_event = threading.Event()
- self.error_flag = False
- self.thread = None
-
- self.get_info_service = None # Placeholder for service
-
- def _callback_get_network_info(self, channel, msg):
- """!
- Aggregate information about all nodes in the network.
-
- @param channel: Unused service channel name.
- @param msg: Service request message.
- @return: Populated :class:`network_info_t` message.
- """
- nodes_info = []
- req = flag_t()
- for service in self.services:
- if service.startswith(f"{DEFAULT_SERVICE_DECORATOR}/GetInfo"):
- node_info = send_service_request(
- self.registry_host, self.registry_port, service, req, node_info_t
- )
- if node_info is not None:
- nodes_info.append(node_info)
-
- res = network_info_t()
- res.n_nodes = len(nodes_info)
- for node in nodes_info:
- res.nodes.append(node)
- return res
-
- def _serve(self):
- """!
- Main loop handling incoming registry requests.
-
- This method listens on the configured host and port, processing
- registration, deregistration and discovery requests from clients.
- """
- self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- self.s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-
- # with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
- try:
- self.s.bind((self.registry_host, self.registry_port))
- log.ok(
- f"Registry Server started on {self.registry_host} : {self.registry_port}"
- )
- except OSError as e:
- log.error(f"Error: {e}")
- self.error_flag = True # Set the error flag to true on error
- self._stop_event.set() # Trigger shutdown
- self.s.close() # Close the socket
- return # Exit the method to stop the server
-
- self.s.listen()
- self.s.settimeout(1.0) # Allow periodic check for stop event
-
- while not self._stop_event.is_set():
- try:
- conn, addr = self.s.accept()
- except socket.timeout:
- continue
- with conn:
- log.info(f"Registry: Connected via client (ip, port): {addr}")
- try:
- # Receive message length
- raw_msglen = self._recvall(conn, 4)
- if not raw_msglen:
- log.error("Registry: No message length received.")
- continue
- msglen = struct.unpack("!I", raw_msglen)[0]
- # Receive the actual message
- data = self._recvall(conn, msglen)
- if not data:
- log.error("Registry: No data received.")
- continue
- # Parse the request
- request = json.loads(data.decode("utf-8"))
- response = self._handle_request(request)
- # Send response
- encoded_resp = json.dumps(response).encode("utf-8")
- conn.sendall(struct.pack("!I", len(encoded_resp)))
- conn.sendall(encoded_resp)
- except Exception as e:
- log.error(f"Registry: Error handling request: {e}")
- continue # Continue with the next request
-
- def _recvall(self, conn, n):
- """!
- Receive ``n`` bytes from a connection.
-
- @param conn: Socket connection.
- @param n: Number of bytes to receive.
- @return: The received bytes or ``None`` if EOF is hit.
- """
- data = bytearray()
- while len(data) < n:
- packet = conn.recv(n - len(data))
- if not packet:
- return None
- data.extend(packet)
- return bytes(data)
-
- def _handle_request(self, request):
- """!
- Handle an incoming registry request.
-
- @param request: Parsed request dictionary.
- @return: Response dictionary to send back to the client.
- """
- req_type = request.get("type")
- if req_type == "REGISTER":
- service_name = request.get("service_name")
- host = request.get("host")
- port = request.get("port")
- if not all([service_name, host, port]):
- return {"status": "ERROR", "message": "Missing fields in REGISTER"}
- with self.lock:
- self.services[service_name] = (host, port)
- log.info(f"Registry: Registered service '{service_name}' at {host}:{port}")
- return {"status": "OK", "message": "Service registered successfully"}
- elif req_type == "DISCOVER":
- service_name = request.get("service_name")
- if not service_name:
- return {
- "status": "ERROR",
- "message": "Missing service_name in DISCOVER",
- }
- with self.lock:
- service = self.services.get(service_name)
- if service:
- host, port = service
- log.info(f"Registry: Service '{service_name}' found at {host}:{port}")
- return {"status": "OK", "host": host, "port": port}
- else:
- log.warning(f"Registry: Service '{service_name}' not found")
- return {"status": "ERROR", "message": "Service not found"}
- elif req_type == "DEREGISTER":
- service_name = request.get("service_name")
- if not service_name:
- return {
- "status": "ERROR",
- "message": "Missing service_name in DEREGISTER",
- }
- with self.lock:
- if service_name in self.services:
- del self.services[service_name] # Remove service from registry
- log.info(f"Registry: Deregistered service '{service_name}'")
- return {
- "status": "OK",
- "message": "Service deregistered successfully",
- }
- else:
- log.warning(f"Registry: Service '{service_name}' not found")
- return {"status": "ERROR", "message": "Service not found"}
- else:
- return {"status": "ERROR", "message": "Unknown request type"}
-
- def _stop(self):
- """!
- Stop the server and wait for the serving thread to finish.
-
- @return: ``None``
- """
- log.info("Shutting down server...")
- # Shutdown the info service
- if self.get_info_service:
- self.get_info_service.suspend()
-
- if self.thread and self.thread.is_alive():
- self._stop_event.set()
- self.thread.join() # Ensure the server thread is stopped
- log.info("Server thread stopped.")
- self.s.close()
-
- log.info("Registry Server stopped.")
-
- def start(self):
- """!
- Start the server and monitor for errors.
-
- This method blocks until the server stops or encounters a fatal error.
- """
-
- try:
- # Initialize thread to serve requests
- self.thread = threading.Thread(target=self._serve, daemon=True)
- self.thread.start()
-
- self.get_info_service = Service(
- f"{DEFAULT_SERVICE_DECORATOR}/GetNetworkInfo",
- flag_t,
- network_info_t,
- self._callback_get_network_info,
- self.registry_host,
- self.registry_port,
- is_default=True,
- )
-
- while not self.error_flag:
- if not self.thread.is_alive():
- log.error("Server thread terminated unexpectedly.")
- self.error_flag = True
- self._stop()
- sys.exit(1)
- self.thread.join(
- 1
- ) # Wait for the thread to finish (or periodically check for errors)
-
- self._stop()
- except KeyboardInterrupt:
- log.error("Program interrupted by user.")
- self._stop() # Gracefully stop the server
- sys.exit(0) # Exit gracefully
-
-
-@app.command()
-def start(
- registry_host: str = typer.Option(
- "127.0.0.1", "--host", help="The host address for the registry server."
- ),
- registry_port: int = typer.Option(
- 1234, "--port", help="The port for the registry server."
- ),
-):
- """!
- Start the Registry server with the specified host and port.
-
- @param registry_host: Host address for the registry server.
- @param registry_port: Port for the registry server.
- """
- server = Registry(registry_host=registry_host, registry_port=registry_port)
- server.start()
-
-
-def main():
- """! Entry point for the CLI."""
- app() # Initializes the Typer CLI
-
-
-if __name__ == "__main__":
- main()
diff --git a/ark/client/comm_infrastructure/script_node.py b/ark/client/comm_infrastructure/script_node.py
deleted file mode 100644
index 6b39b1b..0000000
--- a/ark/client/comm_infrastructure/script_node.py
+++ /dev/null
@@ -1,93 +0,0 @@
-
-import sys
-from abc import ABC, abstractmethod
-from typing import Any, Generator, Dict, Type
-import traceback
-
-import lcm
-from lcm import LCM
-import yaml
-import os
-
-from ark.client.comm_infrastructure.comm_endpoint import CommEndpoint
-from ark.tools.log import log
-
-class ScriptNode(CommEndpoint, ABC):
- """!
- Base class for nodes that interact with the LCM system. Handles the subscription,
- publishing, and communication processes for the node.
-
- The `BaseNode` class manages the LCM instance and communication handlers, and provides
- methods for creating publishers, subscribers, listeners, and steppers. It also provides
- functionality for handling command-line arguments and the graceful shutdown of the node.
-
- @param lcm: The LCM instance used for communication.
- @param channel_name: The name of the channel to subscribe to.
- @param channel_type: The type of the message expected for the channel.
- """
-
- def __init__(self, node_name: str, global_config=None) -> None:
- """!
- Initializes a BaseNode object with the specified node name and registry host and port.
-
- @param node_name: The name of the node.
- @param global_config: Contains IP Address and Port
- """
- super().__init__(node_name, global_config)
- self.config = self._load_config_section(global_config=global_config, name=node_name, "other")
- self.script_stepper = self.create_stepper(1, self.script, oneshot=True)
-
- @abstractmethod
- def script(self) -> None:
- raise NotImplementedError
-
-
- def single_spin(self) -> None:
- """!
- Runs the node’s main loop, handling LCM messages continuously until the node is finished.
-
- The loop calls `self._lcm.handle()` to process incoming messages. If an OSError is encountered,
- the loop will stop and the node will shut down.
- """
- while not self._done and not self.script_stepper._shutdown:
- try:
- self._lcm.handle_timeout(0)
- except OSError as e:
- log.warning(f"LCM threw OSError {e}")
- self._done = True
-
-
-
-def main(node_cls: type[ScriptNode], config_path=None,*args, **kwargs) -> None:
- """!
- Initializes and runs a node.
-
- This function creates an instance of the specified `node_cls`, spins the node to handle messages,
- and handles exceptions that occur during the node's execution.
-
- @param node_cls: The class of the node to run.
- @type node_cls: Type[BaseNode]
- """
-
- if "--help" in sys.argv or "-h" in sys.argv:
- print(node_cls.get_cli_doc())
- sys.exit(0)
-
- node = None
- log.ok(f"Initializing {node_cls.__name__} type node")
- try:
- node = node_cls(config_path, *args, **kwargs)
- log.ok(f"Initialized {node.node_name}")
- node.single_spin()
- except KeyboardInterrupt:
- log.warning(f"User killed node {node_cls.__name__}")
- except Exception:
- tb = traceback.format_exc()
- div = "=" * 30
- log.error(f"Exception thrown during node execution:\n{div}\n{tb}\n{div}")
- finally:
- if node is not None:
- node.kill_node()
- log.ok(f"Finished running node {node_cls.__name__}")
- else:
- log.warning(f"Node {node_cls.__name__} failed during initialization")
diff --git a/ark/client/frequencies/rate.py b/ark/client/frequencies/rate.py
deleted file mode 100644
index c114637..0000000
--- a/ark/client/frequencies/rate.py
+++ /dev/null
@@ -1,83 +0,0 @@
-import time
-from ark.tools.log import log
-
-
-class Rate:
- """!
- A convenience class for sleeping in a loop at a specified rate using `perf_counter_ns` for high precision.
-
- This class calculates the required sleep duration between loop iterations based on a specified rate (in Hz),
- and attempts to maintain that rate by sleeping for the appropriate amount of time. The rate is measured using
- nanoseconds for high precision.
- """
-
- def __init__(self, hz: float, reset: bool = False) -> None:
- """!
- Initializes the Rate object with a specified rate in Hertz (Hz) and an optional reset flag.
-
- @param hz: The target rate in Hertz (loops per second) to determine the sleep duration.
- @param reset: If True, resets the timer if the system time moves backward. Defaults to False.
- """
- self.last_time_ns: int = time.perf_counter_ns()
- self.sleep_dur_ns: int = int(
- 1e9 / hz
- ) # Duration in nanoseconds for the given Hz rate
- self._reset: bool = reset
-
- def _remaining(self, curr_time_ns: int) -> int:
- """!
- Calculates the remaining time (in nanoseconds) before the next sleep interval.
-
- @param curr_time_ns: The current time in nanoseconds.
- @return: The remaining time to sleep in nanoseconds.
- @raises RuntimeError: If time moved backward and `reset` is False.
- """
- if self.last_time_ns > curr_time_ns:
- if self._reset:
- # Reset the last_time_ns if time moved backward and reset is True
- self.last_time_ns = curr_time_ns
- else:
- # Raise an error if time moved backwards and reset is False
- raise RuntimeError("Time moved backwards and reset is not allowed.")
-
- elapsed_ns = curr_time_ns - self.last_time_ns
- return self.sleep_dur_ns - elapsed_ns
-
- def remaining(self) -> float:
- """!
- Returns the time remaining (in seconds) before the next sleep interval.
-
- @return: The remaining sleep time in seconds.
- """
- curr_time_ns = time.perf_counter_ns()
- remaining_ns = self._remaining(curr_time_ns)
- return remaining_ns / 1e9 # Convert nanoseconds to seconds
-
- def sleep(self) -> None:
- """!
- Attempts to sleep at the specified rate.
-
- This method calculates the remaining time for the current cycle and sleeps for that duration to maintain
- the target rate. If the system time moved backward, a warning is printed. If the `reset` flag is set to False,
- a RuntimeError is raised.
-
- @raises RuntimeError: If the system time moved backward and `reset` is False.
- """
- curr_time_ns = time.perf_counter_ns()
- try:
- remaining_ns = self._remaining(curr_time_ns)
- if remaining_ns > 0:
- # Convert nanoseconds to seconds for time.sleep()
- time.sleep(remaining_ns / 1e9)
- except RuntimeError as e:
- # Handle time moving backward if reset is False
- log.warning(str(e))
- if not self._reset:
- raise
-
- # Update last_time_ns after sleeping
- self.last_time_ns = time.perf_counter_ns()
-
- # Check if the loop is too slow or if time jumped forward (greater than 2x sleep duration)
- if curr_time_ns - self.last_time_ns > self.sleep_dur_ns * 2:
- self.last_time_ns = curr_time_ns
diff --git a/ark/client/frequencies/stepper.py b/ark/client/frequencies/stepper.py
deleted file mode 100644
index 236d471..0000000
--- a/ark/client/frequencies/stepper.py
+++ /dev/null
@@ -1,82 +0,0 @@
-import threading
-from ark.client.frequencies.rate import Rate
-from ark.tools.log import log
-
-
-class Stepper(threading.Thread):
- """!
- Convenience class for stepping a callback at a specified rate.
-
- This class runs a callback function at a specified rate (in Hertz) on a separate thread.
- The callback can be executed continuously or only once, depending on the `oneshot` parameter.
- The rate is controlled using the `Rate` class, which ensures the callback is called at a consistent interval.
- """
-
- def __init__(
- self,
- hz: float,
- callback: callable,
- oneshot: bool = False,
- reset: bool = True,
- callback_args: list = [],
- ) -> None:
- """!
- Initializes the Stepper thread.
-
- @param hz: The rate in Hertz (loops per second) at which to call the callback.
- @param callback: The callback function to be called at the specified rate.
- @param oneshot: If True, the callback is called only once; if False, it is called continuously until `shutdown` is called. Defaults to False.
- @param reset: If True, the timer is reset if the system time moves backward. Defaults to True.
- @param callback_args: Arguments to pass to the callback function when called. Defaults to an empty list.
- """
- super().__init__()
- self._hz: float = hz
- self._period_ns: float = 1e9 / float(hz) # Period in nanoseconds
- self._callback: callable = callback
- self._oneshot: bool = oneshot
- self._reset: bool = reset
- self._shutdown: bool = False
- self.daemon: bool = True
- self._callback_args: list = callback_args
- self.start()
- log.ok("started stepper")
-
- def suspend(self) -> None:
- """!
- Signal the stepper thread to stop running.
-
- @return: ``None``
- """
- self._shutdown = True
- log.ok("stepper suspended")
-
- def run(self) -> None:
- """!
- Runs the callback at the specified rate until the thread is shut down.
-
- The `Rate` class is used to ensure the callback is executed at the correct intervals.
- If `oneshot` is set to True, the callback will only be executed once.
- """
- r = Rate(self._hz, reset=self._reset)
- while not self._shutdown:
- r.sleep() # Sleep for the specified rate duration
- if self._shutdown:
- break
-
- # Call the callback with the provided arguments
- self._callback(*self._callback_args)
-
- if self._oneshot:
- self.suspend()
- break
-
- def restart(self):
- """!
- Restart the stepper if it was previously suspended.
- """
- if self._shutdown:
- self._shutdown = False
- self.start()
- log.ok("restarted stepper")
- else:
- log.error("stepper is already running")
diff --git a/ark/configs/franka_panda.yaml b/ark/configs/franka_panda.yaml
deleted file mode 100644
index 3971700..0000000
--- a/ark/configs/franka_panda.yaml
+++ /dev/null
@@ -1,43 +0,0 @@
-env:
- flatten_obs_space: true
- flatten_action_space: false
- normalise_obs_space: True
- num_envs: 5
-
-robot:
- num_joints : 9
-
-observation_space:
- proprio:
- - from: franka/joint_states/sim
- using: joint_state
- - from: franka/ee_state/sim
- using: pose
-
- sensors:
- - sensor_type: VisionSensor
- name: top_camera
- from: IntelRealSense/rgbd/sim
- using: rgbd
- image_height: 480
- image_width: 640
-
- objects:
- - from: cube/ground_truth/sim
- using: rigid_body_state
- name: cube
- - from: target/ground_truth/sim
- using: rigid_body_state
- name: target
-
-
-action_space:
- action:
- - from: franka/cartesian_command/sim
- dim: 8
- using: task_space_command
- select:
- name: "all" # All channels
- position: [ 0, 1, 2 ] # x, y, z
- orientation: [ 3, 4, 5, 6 ] # quaternion indices
- gripper: 7 # single index
diff --git a/ark/decoders/__init__.py b/ark/decoders/__init__.py
deleted file mode 100644
index fcc60ab..0000000
--- a/ark/decoders/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from . import builtin_decoders
diff --git a/ark/decoders/builtin_decoders.py b/ark/decoders/builtin_decoders.py
deleted file mode 100644
index b5e80c4..0000000
--- a/ark/decoders/builtin_decoders.py
+++ /dev/null
@@ -1,100 +0,0 @@
-from typing import Any
-
-import numpy as np
-from ark.decoders.registry import register_decoder
-
-from arktypes.utils.unpack import image, depth
-
-OBS_SCHEMA = {
- "joint_state": ["position", "velocity", "effort"],
- "pose": ["position", "orientation"],
- "rgbd": ["rgb", "depth"],
- "rigid_body_state": [
- "position",
- "orientation",
- "linear_velocity",
- "angular_velocity",
- ],
-}
-
-
-@register_decoder("joint_state")
-def decode_joint_state(msg) -> dict[str, Any]:
- """
- Decode a joint_state message into a dictionary of numpy arrays.
- Args:
- msg: A message object with attributes `position`, `velocity`, and `effort`.
-
- Returns:
- A dictionary containing:
- - "position": numpy array of joint positions (float32)
- - "velocity": numpy array of joint velocities (float32)
- - "effort": numpy array of joint efforts/torques (float32)
-
- """
- return {
- "position": np.asarray(msg.position, dtype=np.float32),
- "velocity": np.asarray(msg.velocity, dtype=np.float32),
- "effort": np.asarray(msg.effort, dtype=np.float32),
- }
-
-
-@register_decoder("pose")
-def decode_pose(msg) -> dict[str, Any]:
- """
- Decode a pose message into a dictionary of numpy arrays.
- Args:
- msg: A message object with attributes `position` and `orientation`.
-
- Returns:
- A dictionary containing:
- - "position": numpy array of position coordinates (float32)
- - "orientation": numpy array of orientation quaternion (float32)
-
- """
- return {
- "position": np.asarray(msg.position, dtype=np.float32),
- "orientation": np.asarray(msg.orientation, dtype=np.float32),
- }
-
-
-@register_decoder("rgbd")
-def decode_rgbd(msg) -> dict[str, Any]:
- """
- Decode a rgbd message into a dictionary of numpy arrays.
- Args:
- msg: A message object with attributes `image` and `depth`.
-
- Returns:
- A dictionary containing:
- - "rgb": image data
- - "depth": depth data
-
- """
- return {
- "rgb": image(msg.image),
- "depth": depth(msg.depth), # TODO check is this optional
- }
-
-
-@register_decoder("rigid_body_state")
-def decode_rgbd(msg) -> dict[str, Any]:
- """
- Decode a rigid_body_state message into a dictionary of numpy arrays.
- Args:
- msg: A message object with attributes `position`, `orientation`, `lin_velocity`, and `ang_velocity`.
-
- Returns:
- A dictionary containing:
- - "position": numpy array of body position (float32)
- - "orientation": numpy array of body orientation quaternion (float32)
- - "linear_velocity": numpy array of linear velocity (float32)
- - "angular_velocity": numpy array of angular velocity (float32)
-
- """
- return {
- "position": np.asarray(msg.position, dtype=np.float32),
- "orientation": np.asarray(msg.orientation, dtype=np.float32),
- "linear_velocity": np.asarray(msg.lin_velocity, dtype=np.float32),
- "angular_velocity": np.asarray(msg.ang_velocity, dtype=np.float32),
- }
diff --git a/ark/decoders/list_decoders.py b/ark/decoders/list_decoders.py
deleted file mode 100644
index a46f38f..0000000
--- a/ark/decoders/list_decoders.py
+++ /dev/null
@@ -1,28 +0,0 @@
-import typer
-from ark.decoders.registry import DECODER_REGISTRY
-
-app = typer.Typer()
-
-
-@app.command("list")
-def list_decoders() -> None:
- """
- Display all registered observation and action decoders.
- Returns:
- None
- """
- if not DECODER_REGISTRY:
- typer.echo("No decoders registered.")
- return
-
- typer.echo("Available decoders:")
- for name in sorted(DECODER_REGISTRY.keys()):
- typer.echo(f"- {name}")
-
-
-def main():
- app()
-
-
-if __name__ == "__main__":
- main()
diff --git a/ark/decoders/registry.py b/ark/decoders/registry.py
deleted file mode 100644
index 2ac42a3..0000000
--- a/ark/decoders/registry.py
+++ /dev/null
@@ -1,38 +0,0 @@
-DECODER_REGISTRY: dict[str, callable] = {}
-
-
-def register_decoder(name: str):
- """
- Register a decoder function under a specified name.
- Args:
- name: The name under which to register the decoder function.
-
- Returns:
- A decorator function that registers the decoder when applied to a function.
- """
-
- def wrapper(fn):
- if name in DECODER_REGISTRY:
- raise ValueError(f"Decoder '{name}' already registered")
- DECODER_REGISTRY[name] = fn
- return fn
-
- return wrapper
-
-
-def get_decoder(name: str):
- """
- Retrieve a registered decoder function by name.
- Args:
- name: The name of the decoder to retrieve.
-
- Returns:
- The decoder function associated with the given name.
- """
- try:
- return DECODER_REGISTRY[name]
- except KeyError:
- raise KeyError(
- f"Decoder '{name}' not found. "
- f"Available decoders: {sorted(DECODER_REGISTRY.keys())}"
- )
diff --git a/ark/env/ark_env.py b/ark/env/ark_env.py
deleted file mode 100644
index 9f36eee..0000000
--- a/ark/env/ark_env.py
+++ /dev/null
@@ -1,461 +0,0 @@
-import os
-from abc import ABC, abstractmethod
-from pathlib import Path
-from typing import Any
-
-from ark.client.comm_infrastructure.instance_node import InstanceNode
-from ark.env.spaces import ActionSpace, ObservationSpace
-from ark.tools.log import log
-from ark.utils.communication_utils import (
- build_action_space,
- build_observation_space,
- get_channel_types,
- _dynamic_observation_unpacker,
- _dynamic_action_packer,
- namespace_channels,
-)
-from ark.utils.data_utils import generate_flat_dict
-from ark.utils.utils import ConfigPath
-from arktypes import robot_init_t, flag_t, rigid_body_state_t
-from gymnasium import Env
-
-
-class ArkEnv(Env, InstanceNode, ABC):
- """!ArkEnv base class.
-
- This environment integrates the Ark system with the :mod:`gymnasium` API. It
- handles action publishing, observation retrieval and exposes helper utilities
- for resetting parts of the system. Sub‑classes are expected to implement the
- packing/unpacking logic for messages as well as the reward and termination
- functions.
-
- @param environment_name Name of the environment (also the node name).
- @type environment_name str
- @param action_channels Channels on which actions will be published.
- @type action_channels list[tuple[str, type]]
- @param observation_channels Channels on which observations will be received.
- @type observation_channels list[tuple[str, type]]
- @param global_config Path or dictionary describing the complete Noah system
- configuration. If ``None`` a warning is emitted and only minimal
- functionality is available.
- @type global_config Union[str, dict[str, Any], Path]
- @param sim Set ``True`` when running in simulation mode.
- @type sim bool
- """
-
- def __init__(
- self,
- environment_name: str,
- channel_schema: str,
- global_config: str,
- sim=True,
- namespace: str = "ark",
- ) -> None:
- """!Construct the environment.
-
- The constructor sets up the internal communication channels and creates
- the action and observation spaces. The configuration can either be
- provided as a path to a YAML file or as a dictionary already loaded in
- memory.
-
- @param environment_name Name of the environment node.
- @param action_channels Dictionary mapping channel names to LCM
- types for actions.
- @type action_channels dict[str, type]
- @param observation_channels Dictionary mapping channel names to LCM
- types for observations.
- @type observation_channels dict[str, type]
- @param global_config Optional path or dictionary describing the system.
- @param sim If ``True`` the environment interacts with the simulator.
- """
- super().__init__(
- environment_name, global_config
- ) # TODO check why global config needed here
-
- schema = ConfigPath(channel_schema).read_yaml()
-
- # Derive observation and action channel types from schema
- obs_chans = get_channel_types(schema=schema, channel_type="observation_space")
- act_chans = get_channel_types(schema=schema, channel_type="action_space")
-
- # Namespace channels with rank
- observation_channels = namespace_channels(
- channels=obs_chans, namespace=namespace
- )
- action_channels = namespace_channels(channels=act_chans, namespace=namespace)
-
- self._flatten_action_space = schema["env"]["flatten_action_space"]
- self._flatten_obs_space = schema["env"]["flatten_obs_space"]
-
- self.sim = sim
- self.namespace = namespace
- self.prev_state = None
-
- # Create the action space
- self.ark_action_space = ActionSpace(
- action_channels, self.action_packing, self._lcm
- )
- self.ark_observation_space = ObservationSpace(
- observation_channels, self.observation_unpacking, self._lcm
- )
-
- self._multi_comm_handlers.append(self.ark_action_space.action_space_publisher)
- self._multi_comm_handlers.append(
- self.ark_observation_space.observation_space_listener
- )
-
- self._load_config(global_config) # creates self.global_config
-
- # Build Gym-style observation / action spaces from schema
- self.observation_space = build_observation_space(
- schema=schema, flatten_obs_space=self._flatten_obs_space
- )
- self.action_space = build_action_space(schema=schema)
-
- self._obs_unpacker = _dynamic_observation_unpacker(
- schema, namespace=self.namespace
- )
- self._action_packer = _dynamic_action_packer(schema, namespace=self.namespace)
-
- # Reward and Termination Conditions
- self._termination_conditions = self._create_termination_conditions()
- self._reward_functions = self._create_reward_functions()
-
- def action_packing(self, action):
- """
- Packs the action into a task_space_command_t format.
-
- Expected layout:
- [EE_X, EE_Y, EE_Z, EE_QX, EE_QY, EE_QZ, EE_QW, Gripper]
- """
- return self._action_packer(action)
-
- def observation_unpacking(self, observation_dict):
- """
- Unpack raw LCM observations into a compact dict used by the agent.
-
- """
- obs = self._obs_unpacker(observation_dict)
- if self._flatten_obs_space:
- obs = generate_flat_dict(obs)
- return obs
-
- @abstractmethod
- def _create_termination_conditions(self): ...
-
- @abstractmethod
- def _create_reward_functions(self): ...
-
- @abstractmethod
- def reset_objects(self):
- """!Reset all objects in the environment."""
- raise NotImplementedError
-
- def reset(self, **kwargs) -> tuple[Any, Any]:
- """!Reset the environment.
-
- This method resets all user defined objects by calling
- :func:`reset_objects` and waits until fresh observations are available.
- The returned information tuple contains the termination and truncation
- flags as produced by :func:`terminated_truncated_info`.
-
- @return Observation after reset and information tuple.
- @rtype tuple[Any, Any]
- """
- if not self.ark_observation_space.is_ready:
- self.ark_observation_space.wait_until_observation_space_is_ready()
- self.reset_objects()
- self.ark_observation_space.is_ready = False
- self.ark_observation_space.wait_until_observation_space_is_ready()
- obs = self.ark_observation_space.get_observation()
- # Reset per-episode state for reward / termination functions
- for termination in self._termination_conditions.values():
- termination.reset()
- for reward_fn in self._reward_functions.values():
- reward_fn.reset()
-
- self.prev_state = obs
-
- return obs, {}
-
- def reset_backend(self):
- """!Reset the simulation backend."""
- raise NotImplementedError("This feature is to be added soon.")
-
- def reset_component(self, name: str, **kwargs):
- """!Reset a single component.
-
- Depending on ``name`` this method sends a reset service request to a
- robot or object defined in the configuration.
-
- @param name Identifier of the component to reset.
- @param kwargs Optional parameters such as ``base_position`` or
- ``initial_configuration`` used to override the configuration.
- """
- if self.global_config is None:
- log.error(
- "No configuration file provided, so no objects can be found. Please provide a valid configuration file."
- )
- return
- # search through config
- # if name in [robot["name"] for robot in self.global_config["robots"]]:
- if name in self.global_config["robots"]:
-
- service_name = f"{self.namespace}/" + name + "/reset/"
- if self.sim:
- service_name = service_name + "sim"
-
- request = robot_init_t()
- request.name = name
- request.position = kwargs.get(
- "base_position", self.global_config["robots"][name]["base_position"]
- )
- request.orientation = kwargs.get(
- "base_orientation",
- self.global_config["robots"][name]["base_orientation"],
- )
- q_init = kwargs.get(
- "initial_configuration",
- self.global_config["robots"][name]["initial_configuration"],
- )
- request.n = len(q_init)
- request.q_init = q_init
-
- elif name in self.global_config["sensors"]:
- log.error(f"Can't reset a sensor (called for {name}).")
-
- # elif name in [obj["name"] for obj in self.global_config["objects"]]:
- elif name in self.global_config["objects"]:
- service_name = f"{self.namespace}/" + name + "/reset/"
- if self.sim:
- service_name = service_name + "sim"
-
- request = rigid_body_state_t()
- request.name = name
- request.position = kwargs.get(
- "base_position", self.global_config["objects"][name]["base_position"]
- )
- request.orientation = kwargs.get(
- "base_orientation",
- self.global_config["objects"][name]["base_orientation"],
- )
-
- # TODO for now we only work with position init, may add velocity in the future
- request.lin_velocity = kwargs.get("base_velocity", [0.0, 0.0, 0.0])
- request.ang_velocity = kwargs.get("base_angular_velocity", [0.0, 0.0, 0.0])
-
- else:
- log.error(f"Component {name} not part of the system.")
-
- _ = self.send_service_request(
- service_name=service_name, request=request, response_type=flag_t
- )
-
- def _step_termination(self, obs, info=None):
- """
- Step and aggregate termination conditions
-
- Args:
- env (Environment): Environment instance
- info (None or dict): Any info to return
-
- Returns:
- 2-tuple:
- - float: aggregated termination at the current timestep
- - dict: any information passed through this function or generated by this function
- """
- # Get all dones and successes from individual termination conditions
- dones = []
- successes = []
- info = dict() if info is None else info
- if "termination_conditions" not in info:
- info["termination_conditions"] = dict()
- for name, termination_condition in self._termination_conditions.items():
- d, s = termination_condition.step(obs=obs)
- dones.append(d)
- successes.append(s)
- info["termination_conditions"][name] = {
- "done": d,
- "success": s,
- }
- # Any True found corresponds to a done / success
- done = sum(dones) > 0
- success = sum(successes) > 0
-
- # Populate info
- info["success"] = success
- return done, info
-
- def _step_reward(self, obs, info=None):
- """
- Step and aggregate reward functions
-
- Args:
- env (Environment): Environment instance
- info (None or dict): Any info to return
-
- Returns:
- 2-tuple:
- - float: aggregated reward at the current timestep
- - dict: any information passed through this function or generated by this function
- """
- # Make sure info is a dict
- total_info = dict() if info is None else info
- # We'll also store individual reward split as well
- breakdown_dict = dict()
- # Aggregate rewards over all reward functions
- total_reward = 0.0
- for reward_name, reward_function in self._reward_functions.items():
- reward, reward_info = reward_function.step(obs=obs)
- total_reward += reward
- breakdown_dict[reward_name] = reward
- total_info[reward_name] = reward_info
-
- # Store breakdown dict
- total_info["reward_breakdown"] = breakdown_dict
-
- return total_reward, total_info
-
- def step(self, action: Any) -> tuple[Any, float, bool, bool, Any]:
- """!Advance the environment by one step.
-
- The provided ``action`` is packed and published. The function then
- waits for a new observation, computes the reward and termination flags
- and returns all gathered information.
-
- @param action Action provided by the agent.
- @return tuple of observation, reward, termination flag, truncation flag
- and an optional info object.
- @rtype tuple[Any, float, bool, bool, Any]
- """
- if self.prev_state == None:
- raise ValueError("Please call reset() before calling step().")
-
- self.ark_action_space.pack_and_publish(action)
-
- # Wait for the observation space to be ready
- self.ark_observation_space.wait_until_observation_space_is_ready()
-
- # Get the observation
- obs = self.ark_observation_space.get_observation()
-
- # Calculate reward
- done, done_info = self._step_termination(obs=obs)
- reward, reward_info = self._step_reward(obs=obs)
- truncated = True if done and not done_info["success"] else False
-
- info = {
- "reward": reward_info,
- "done": done_info,
- }
-
- self.prev_state = obs
-
- return obs, reward, done, truncated, info
-
- def _load_config(self, global_config: str | ConfigPath) -> None:
- """!Load and merge the environment configuration.
-
- The configuration can be provided as a path to a YAML file or as an
- already parsed dictionary. Sections describing robots, sensors and
- objects may themselves reference additional YAML files which are loaded
- and merged.
-
- @param global_config Path or dictionary to parse.
- """
- if global_config is None:
- log.warning("No configuration file provided. Using default configuration.")
- self.global_config = None
- return
- if isinstance(global_config, str):
- global_config = ConfigPath(global_config)
- if isinstance(global_config, Path):
- global_config = ConfigPath(str(global_config))
-
- if isinstance(global_config, ConfigPath) and not global_config.exists():
- log.error(
- f"Given configuration file path does not exist: {global_config.str}"
- )
- return
-
- if isinstance(global_config, ConfigPath) and not global_config.is_absolute():
- global_config = global_config.resolve()
-
- cfg = global_config.read_yaml()
-
- config = {
- "network": cfg.get("network", None) if isinstance(cfg, dict) else None,
- "simulator": cfg.get("simulator", None) if isinstance(cfg, dict) else None,
- "robots": (
- self._load_section(cfg, global_config, "robots")
- if cfg.get("robots")
- else {}
- ),
- "sensors": (
- self._load_section(cfg, global_config, "sensors")
- if cfg.get("sensors")
- else {}
- ),
- "objects": (
- self._load_section(cfg, global_config, "objects")
- if cfg.get("objects")
- else {}
- ),
- }
-
- if not config["simulator"]:
- log.error(
- "Please provide at least name and backend_type under 'simulator' in your config file."
- )
-
- log.info(
- f"Config file under {global_config.str if global_config else 'default configuration'} loaded successfully."
- )
- self.global_config = config
-
- def _load_section(
- self, cfg: [str, Any], config_path: ConfigPath, section_name: str
- ) -> dict[str, Any]:
- """!Load a sub-section from the configuration.
-
- Sections can either be provided inline in ``cfg`` or as a path to an
- additional YAML file. This helper returns a dictionary mapping component
- names to their configuration dictionaries.
-
- @param cfg Parsed configuration dictionary.
- @param config_path Path to the root configuration file, used to resolve
- relative includes.
- @param section_name Section within ``cfg`` to load.
- @return Dictionary with component names as keys and their configurations
- as values.
- """
- section_config: dict[str, Any] = {}
-
- for item in cfg.get(section_name, []):
- if isinstance(item, dict): # If it's an inline configuration
- subconfig = item
- elif isinstance(item, str) and item.endswith(
- ".yaml"
- ): # If it's a path to an external file
- if os.path.isabs(item): # Check if the path is absolute
- external_path = ConfigPath(item)
- else: # Relative path, use the directory of the main config file
- external_path = config_path.parent / item
- # Load the YAML file and return its content
- subconfig = external_path.read_yaml()
- else:
- log.error(
- f"Invalid entry in '{section_name}': {item}. Please provide either a config or a path to another config."
- )
- continue # Skip invalid entries
-
- section_config[subconfig["name"]] = subconfig["config"]
-
- return section_config
-
- def close(self):
- """!Gracefully shut down communications and background threads."""
- self.suspend_communications(services=True)
- spin_thread = getattr(self, "spin_thread", None)
- if spin_thread and spin_thread.is_alive():
- spin_thread.join(timeout=1.0)
diff --git a/ark/env/spaces.py b/ark/env/spaces.py
deleted file mode 100644
index 3177ac6..0000000
--- a/ark/env/spaces.py
+++ /dev/null
@@ -1,162 +0,0 @@
-"""!Utility classes defining action and observation spaces.
-
-These classes encapsulate the LCM based communication used by the environment
-to publish actions and receive observations."""
-
-import os
-import time
-from abc import ABC
-from typing import Any, List, Dict, Callable
-
-from ark.client.comm_handler.multi_channel_listener import MultiChannelListener
-from ark.client.comm_handler.multi_channel_publisher import MultiChannelPublisher
-from ark.tools.log import log
-from lcm import LCM
-
-
-class Space(ABC):
- """!
- An abstract base class for different types of spaces. This is used as a generic
- interface for both action and observation spaces in the system.
-
- The `Space` class provides the general structure for subclasses to define space-specific
- shutdown behavior.
- """
-
- # @abstractmethod
- def shutdown(self) -> None:
- """!
- Abstract method to shut down the space, ensuring any resources are released.
-
- Subclasses must implement this method to cleanly shut down their specific space.
- """
-
-
-class ActionSpace(Space):
- """!
- A class representing a space where actions are taken. This space handles publishing
- of actions to a given LCM channel.
-
- @param action channels: Channel names where actions will be published.
- @type action_channels: List
- @param action_packing: A function that converts any types of action into dictionary
- @type action_packing: Callable
- @param lcm_instance: Communication variable
- @type lcm_instance: LCM
- """
-
- def __init__(
- self,
- action_channels: Dict[str, type],
- action_packing: Callable,
- lcm_instance: LCM,
- ):
- """!Create an action space.
-
- @param action_channels Channels to publish actions on.
- @param action_packing Callback used to serialize actions.
- @param lcm_instance LCM instance used for communication.
- """
-
- self.action_space_publisher = MultiChannelPublisher(
- action_channels, lcm_instance
- )
- self.action_packing = action_packing
- self.messages_to_publish = None
-
- def pack_and_publish(self, action: Any):
- """!Pack an action and publish it."""
-
- messages_to_publish = self.pack_message(action)
- self.action_space_publisher.publish(messages_to_publish)
-
- def pack_message(self, action: Any) -> Dict[str, Any]:
- """!
- Abstract method to pack the action into a message format suitable for LCM.
-
- @param action: The action to be packed into a message.
- @type action: Any
- @return: The packed LCM message.
- @type: Dict
- """
-
- return self.action_packing(action)
-
-
-class ObservationSpace(Space):
- """!
- A class representing an observation space that listens for observations over LCM
- and processes them.
-
- @param observation channels: Channel names where observations will be listened
- @type observation_channels: List
- @param observation_unpacking: A function that converts observation dictionary into any types
- @type observation_unpacking: Callable
- @param lcm_instance: Communication variable
- @type lcm_instance: LCM
- """
-
- def __init__(
- self,
- observation_channels: Dict[str, type],
- observation_unpacking: Callable,
- lcm_instance: LCM,
- ):
- """!Create an observation space.
-
- @param observation_channels Channels to listen for observations.
- @param observation_unpacking Callback used to deserialize messages.
- @param lcm_instance LCM instance used for communication.
- """
-
- self.observation_space_listener = MultiChannelListener(
- observation_channels, lcm_instance
- )
- self.observation_unpacking = observation_unpacking
- self.is_ready = False
- self.debug = os.getenv("ARK_DEBUG", "").lower() in ("1", "true")
-
- def unpack_message(self, observation_dict: Dict) -> Any:
- """!Unpack a raw observation dictionary.
-
- @param observation_dict Dictionary mapping channel names to raw LCM messages.
- @return The processed observation.
- @rtype Any
- """
- obs = self.observation_unpacking(observation_dict)
- return obs
-
- def check_readiness(self):
- """!Check whether fresh observations are available."""
-
- lcm_dictionary = self.observation_space_listener.get()
- self.is_ready = not any(value is None for value in lcm_dictionary.values())
- if self.debug:
- log.info(f"Observation space {lcm_dictionary.values()}.")
-
- def wait_until_observation_space_is_ready(self):
- """!Block until a complete observation has been received."""
-
- while not self.is_ready:
- if self.debug:
- log.warning("Observation space is getting checked")
- self.check_readiness()
- time.sleep(0.05)
- if not self.is_ready and self.debug:
- log.warning("Observation space is still not ready. Retrying...")
-
- def empty_data(self):
- """!Clear cached observation data."""
- self.observation_space_listener.empty_data()
-
- def get_observation(self) -> Any:
- """!Return the latest processed observation."""
- assert (
- self.is_ready
- ), "Observation space is not ready. Call wait_until_observation_space_is_ready() first."
-
- self.data = self.observation_space_listener.get()
-
- processed_observation = self.unpack_message(self.data)
-
- return processed_observation
diff --git a/ark/env/vector_env.py b/ark/env/vector_env.py
deleted file mode 100644
index 08865d0..0000000
--- a/ark/env/vector_env.py
+++ /dev/null
@@ -1,211 +0,0 @@
-from __future__ import annotations
-
-import uuid
-from multiprocessing import Process
-from typing import Callable, Type, Any
-
-from ark.system.simulation.simulator_node import SimulatorNode
-from ark.utils.communication_utils import (
- get_channel_types,
- namespace_channels,
-)
-from ark.utils.utils import ConfigPath
-from gymnasium import Env
-from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv, VectorEnv
-
-
-def run_simulator_proc(
- global_config: str,
- observation_channels: dict[str, Any],
- action_channels: dict[str, Any],
- namespace: str,
-) -> None:
- """
- Launch a simulator node as a blocking process.
- Args:
- global_config: Path to the global configuration file for the simulator.
- observation_channels: Dictionary defining observation channels and their types.
- action_channels: Dictionary defining action channels and their types.
- namespace: Unique namespace for the simulator instance to avoid channel conflicts.
-
- Returns:
- None
-
- """
- node = SimulatorNode(
- global_config=global_config,
- observation_channels=observation_channels,
- action_channels=action_channels,
- namespace=namespace,
- )
- node.spin()
-
-
-def _make_env_thunk(
- env_cls: Type[Env],
- namespace: str,
- channel_schema: str,
- global_config: str,
- sim: bool,
- env_kwargs: dict[str, Any] | None = None,
-) -> Callable[[], Env]:
- """
- Create a thunk (callable) that initializes a new environment instance.
- Args:
- env_cls: The environment class to instantiate.
- namespace: Unique namespace for the environment instance.
- channel_schema: Schema defining observation and action channels.
- global_config: Path to the global configuration file.
- sim: Whether to connect the environment to a simulator process.
-
- Returns:
-
- """
-
- def _init() -> Env:
- kwargs = env_kwargs or {}
- return env_cls(
- namespace=namespace,
- channel_schema=channel_schema,
- global_config=global_config,
- sim=sim,
- **kwargs,
- )
-
- return _init
-
-
-def make_sim(channel_schema: str, global_config: str, namespace: str) -> Process:
- """
- Spawn a simulator process for a specific namespace and channel configuration.
- Args:
- channel_schema: Path to the YAML schema defining observation and action channels.
- global_config: Path to the global configuration for the simulator.
- namespace: Unique namespace for the simulator instance.
-
- Returns:
- A daemon Process object running the simulator node.
-
- """
- schema = ConfigPath(channel_schema).read_yaml()
-
- # Derive observation and action channel types from schema
- obs_chans = get_channel_types(schema=schema, channel_type="observation_space")
- act_chans = get_channel_types(schema=schema, channel_type="action_space")
- # Namespace channels
- observation_channels = namespace_channels(channels=obs_chans, namespace=namespace)
- action_channels = namespace_channels(channels=act_chans, namespace=namespace)
-
- sim_proc = Process(
- target=run_simulator_proc,
- args=(global_config, observation_channels, action_channels, namespace),
- daemon=True,
- )
- sim_proc.start()
- return sim_proc
-
-
-def _cleanup_sim_procs(sim_procs: list[Process]) -> None:
- """
- Terminate and join any running simulator processes.
- Args:
- sim_procs: List of simulator Process objects to clean up.
-
- Returns:
- None.
-
- """
- for proc in sim_procs:
- if proc.is_alive():
- proc.terminate()
- for proc in sim_procs:
- if proc.is_alive():
- proc.join(timeout=1.0)
-
-
-def _attach_cleanup(env: VectorEnv, sim_procs: list[Process]) -> VectorEnv:
- """
- Wrap an environment's close method to also terminate simulator processes.
- Args:
- env: The vectorized environment to wrap.
- sim_procs: List of simulator Process objects associated with this environment.
-
- Returns:
- The same environment instance with a modified close method that ensures
- simulator processes are cleaned up.
-
- """
- original_close = getattr(env, "close", None)
-
- def _close():
- try:
- if callable(original_close):
- original_close()
- finally:
- _cleanup_sim_procs(sim_procs)
-
- env.close = _close
- env._sim_procs = sim_procs
- return env
-
-
-def make_vector_env(
- env_cls: Type[Env],
- num_envs: int,
- channel_schema: str,
- global_config: str,
- sim: bool = True,
- asynchronous: bool = True,
- env_kwargs: dict[str, Any] | None = None,
- namespace:str="ark"
-) -> VectorEnv:
- """
- Create a vectorized environment with optional simulator processes.
- Args:
- env_cls: The environment class to instantiate.
- num_envs: Number of environment instances to create (must be >= 1).
- channel_schema: Path to the YAML schema defining observation and action channels.
- global_config: Path to the global configuration file.
- sim: Whether to connect the environment to a simulator process.
- asynchronous: Whether to use asynchronous (AsyncVectorEnv) or synchronous (SyncVectorEnv).
- namespace: Given name for Env namespace.
-
- Returns:
- A vectorized environment instance.
- """
-
- if num_envs <= 0:
- raise ValueError("num_envs must be >= 1")
-
- thunks = []
- sim_procs = []
- for rank in range(num_envs):
- namespace = namespace if num_envs==1 else uuid.uuid4().hex[:8]
-
- if sim:
- sim_proc = make_sim(
- channel_schema=channel_schema,
- global_config=global_config,
- namespace=namespace,
- )
- sim_procs.append(sim_proc)
-
- thunks.append(
- _make_env_thunk(
- env_cls,
- namespace,
- channel_schema,
- global_config,
- sim,
- env_kwargs=env_kwargs,
- )
- )
-
- env: VectorEnv = (
- AsyncVectorEnv(thunks) if asynchronous else SyncVectorEnv(thunks)
- )
-
- if sim_procs:
- env = _attach_cleanup(env, sim_procs)
-
- return env
diff --git a/ark/global_constants.py b/ark/global_constants.py
deleted file mode 100644
index 5d47ab5..0000000
--- a/ark/global_constants.py
+++ /dev/null
@@ -1 +0,0 @@
-DEFAULT_SERVICE_DECORATOR = "__DEFAULT_SERVICE"
diff --git a/ark/system/__init__.py b/ark/system/__init__.py
deleted file mode 100644
index 8b13789..0000000
--- a/ark/system/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/ark/system/component/__init__.py b/ark/system/component/__init__.py
deleted file mode 100644
index 26ff8f5..0000000
--- a/ark/system/component/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-"""System component definitions used by ARK.
-
-This package bundles the core component abstractions that are shared
-between the robotics backends. Every component is implemented as a
-node that can send and receive data through the ARK communication
-infrastructure.
-"""
diff --git a/ark/system/component/base_component.py b/ark/system/component/base_component.py
deleted file mode 100644
index e3b45b3..0000000
--- a/ark/system/component/base_component.py
+++ /dev/null
@@ -1,116 +0,0 @@
-"""Base classes for ARK system components.
-
-This module defines the :class:`BaseComponent` and
-:class:`SimToRealComponent` classes which form the foundation of all
-robot, sensor and simulation objects in the framework.
-"""
-
-from abc import ABC, abstractmethod
-from typing import Any, Dict, Union
-from pathlib import Path
-import os
-import yaml
-
-from arktypes import flag_t
-from ark.client.comm_infrastructure.hybrid_node import HybridNode
-from ark.system.driver.component_driver import ComponentDriver
-from ark.tools.log import log
-
-
-class BaseComponent(HybridNode, ABC):
- """Base class for all components in the ARK system.
-
- @brief Provides common functionality for robots, sensors and
- simulated objects.
-
- Concrete implementations must provide methods for packing data,
- stepping the component and resetting it to a well defined state.
- """
-
- def __init__(
- self,
- name: str,
- global_config: Union[str, Dict[str, Any], Path],
- ) -> None:
- """Construct a new component.
-
- @param name Unique name of the component.
- @param global_config Global configuration or path to a YAML
- configuration file.
- @throws ValueError if ``name`` is empty.
- """
- if not name:
- raise ValueError("Name must be a non-empty string (unique in your system).")
- super().__init__(name, global_config)
- self.name = name # node_name and name are the same
- self._is_suspended = False # TODO do we still need this ?
-
- @abstractmethod
- def pack_data(self) -> None:
- """Pack the component data for publishing.
-
- Subclasses use this hook to convert raw state information into
- the LCM types that are sent to the client.
- """
-
- def component_channels_init(self, channels: Dict[str, type]) -> None:
- """Create publishers for the specified channels.
-
- @param channels Iterable of channel names that the component
- should publish on.
- """
- self.component_multi_publisher = self.create_multi_channel_publisher(channels)
-
- @abstractmethod
- def step_component(self) -> None:
- """Perform a single update of the component state.
-
- Called periodically by the communication layer, this method is
- responsible for gathering data from the driver and publishing it
- to the appropriate channels.
- """
- ...
-
- @abstractmethod
- def reset_component(self, channel, msg) -> None:
- """Reset the component to a known state.
-
- @param channel Channel name for the reset request.
- @param msg Optional message containing reset parameters.
- """
- ...
-
-
-class SimToRealComponent(BaseComponent, ABC):
- """Component with a driver that may run in simulation or on real hardware."""
-
- def __init__(
- self,
- name: str,
- global_config: Union[str, Dict[str, Any], Path],
- driver: ComponentDriver = None,
- ) -> None:
- """Create a component connected to a driver.
-
- @param name Name of the component.
- @param global_config Configuration dictionary or path.
- @param driver Driver instance that provides low level access.
- """
-
- super().__init__(name, global_config)
- self._driver = driver
- self.sim = self._driver.is_sim()
- self.namespace = global_config["namespace"]
-
- # initialize service for reset of any component
- self.reset_service_name = f"{self.namespace }/" + self.name + "/reset"
- if self.sim:
- self.reset_service_name = self.reset_service_name + "/sim"
-
- # Override killing the node to also shutdown the driver, freeing up ports etc.
- def kill_node(self) -> None:
- """Shut down the node and associated driver."""
- # kill driver (close ports, ...)
- self._driver.shutdown_driver()
- # kill all communication
- super().kill_node()
diff --git a/ark/system/component/robot.py b/ark/system/component/robot.py
deleted file mode 100644
index 83a1ba9..0000000
--- a/ark/system/component/robot.py
+++ /dev/null
@@ -1,476 +0,0 @@
-"""Robot component abstractions used by the ARK framework."""
-
-from abc import ABC, abstractmethod
-from enum import Enum
-from typing import Any, Optional, Dict, Tuple, List, Union
-from pathlib import Path
-import xml.etree.ElementTree as ET
-import time
-
-import numpy as np
-
-from ark.tools.log import log
-from ark.system.component.base_component import SimToRealComponent
-from ark.system.driver.robot_driver import RobotDriver
-from arktypes import flag_t, robot_init_t
-
-
-def robot_joint_control(func):
- """Decorator applying joint group lookup before calling ``func``."""
-
- def wrapper(self, group_name: str, target: Dict[str, float], **kwargs):
- # Ensure the class instance has 'joint_groups' and the necessary structure
- if not hasattr(self, "joint_groups"):
- raise AttributeError("The class must have 'joint_groups' attribute.")
-
- if group_name not in self.joint_groups:
- raise ValueError(f"Group name '{group_name}' not found in joint_groups.")
-
- control_mode = self.joint_groups[group_name]["control_mode"]
- actuated_joints = self.joint_groups[group_name]["actuated_joints"]
-
- if len(target) != len(actuated_joints):
- log.warning(
- f"Number of targets ({len(target)}) does not equal number of actuated joints ({len(actuated_joints)}) in group '{group_name}'!"
- )
- kwargs["group_name"] = group_name
- # Call the original function with modified arguments
- return func(self, control_mode, list(actuated_joints), target, **kwargs)
-
- return wrapper
-
-
-def robot_control(func):
- """Decorator forwarding commands for a joint group to ``func``."""
-
- def wrapper(self, group_name: str, target: Dict[str, float], **kwargs):
- # Ensure the class instance has 'joint_groups' and the necessary structure
- if not hasattr(self, "joint_groups"):
- raise AttributeError("The class must have 'joint_groups' attribute.")
-
- if group_name not in self.joint_groups:
- raise ValueError(f"Group name '{group_name}' not found in joint_groups.")
-
- control_mode = self.joint_groups[group_name]["control_mode"]
- actuated_joints = self.joint_groups[group_name]["actuated_joints"]
- kwargs["group_name"] = group_name
-
- # Call the original function with modified arguments
- return func(self, control_mode, list(actuated_joints), target, **kwargs)
-
- return wrapper
-
-
-class ControlType(Enum):
- POSITION = "position"
- VELOCITY = "velocity"
- TORQUE = "torque"
- FIXED = "fixed"
-
-
-class Robot(SimToRealComponent):
- """High level representation of a robot in ARK."""
-
- def __init__(
- self,
- name: str,
- global_config: Union[str, Dict[str, Any], Path],
- driver: RobotDriver,
- ) -> None:
- """Create a robot component.
-
- @param name Unique name of the robot.
- @param global_config Configuration dictionary or file path.
- @param driver Concrete :class:`RobotDriver` controlling the hardware or simulator.
- """
- super().__init__(name=name, global_config=global_config, driver=driver)
- self.robot_config = self._load_config_section(
- global_config=global_config, name=name, type="robots"
- )
-
- self.joint_infos: Dict[str, Any] = {} # from urdf
- # {"name" : {"index" : ... ,
- # "type" : ... ,
- # "actuated" : ... ,
- # "parent_link" : ... ,
- # "child_link" : ... ,
- # "lower_limit" : ... ,
- # "upper_limit" : ... ,
- # "effort_limit" : ... ,
- # "velocity_limit" : ... ,
- # }
- # }
-
- self.joint_groups: Dict[str, Any] = {} # from urdf
- # { "name" : { "control_mode" : ... ,
- # "joints" : { "name" : idx,
- # "name" : idx,
- # ...
- # },
- # "actuated_joints" : { "name" : idx
- # },
- # "end_effector" : { "ee_name" : idx_ee},
- # }
- # }
-
- self._all_actuated_joints = []
- # { "name" : idx,
- # ...
- # }
-
- self.initial_configuration = (
- {}
- ) # from self.robot_config["initial_configuration"]
- # { "name" : float,
- # ...
- # }
- if self.robot_config.get("urdf_path", None) and self.robot_config.get(
- "mjcf_path", None
- ):
- log.warning(
- f"Both 'urdf_path' and 'mjcf_path' are provided for robot '{self.name}'. Defaulting to URDF."
- )
-
- if self.robot_config.get("urdf_path", None):
- class_path = self.robot_config.get("class_dir", None)
- urdf_path = self.robot_config["urdf_path"]
- if class_path is None:
- urdf_path = Path(class_path) / urdf_path
- else:
- urdf_path = Path(urdf_path)
-
- # Make the URDF path absolute if it is not already
- if not urdf_path.is_absolute():
- urdf_path = Path(class_path) / urdf_path
-
- # Check if the URDF path exists
- if not urdf_path.exists():
- log.error(f"The URDF path '{urdf_path}' does not exist.")
- return
-
- tree = ET.parse(urdf_path)
- root = tree.getroot()
- elements = root.findall("joint")
- elif self.robot_config.get("mjcf_path", None):
- class_path = self.robot_config.get("class_dir", None)
- mjcf_path = self.robot_config["mjcf_path"]
- if class_path is None:
- mjcf_path = Path(class_path) / mjcf_path
- else:
- mjcf_path = Path(mjcf_path)
-
- # Make the MJCF path absolute if it is not already
- if not mjcf_path.is_absolute():
- mjcf_path = Path(class_path) / mjcf_path
-
- # Check if the MJCF path exists
- if not mjcf_path.exists():
- log.error(f"The MJCF path '{mjcf_path}' does not exist.")
- return
-
- tree = ET.parse(mjcf_path)
- root = tree.getroot()
- elements = root.findall(".//joint")
-
- print(
- f"Robot '{self.name}' has the following elements (Total: {len(elements)}):"
- )
-
- for i, joint in enumerate(elements):
- name = joint.get("name")
- print(name)
- joint_info = {}
- joint_info["index"] = i
- joint_info["type"] = joint.get("type")
- if not joint_info["type"] == "fixed":
- self._all_actuated_joints.append(name)
- joint_info["actuated"] = True
- else:
- joint_info["actuated"] = False
-
- try:
- # Iterate over all joints and collect relevant info
- for i, joint in enumerate(elements):
- name = joint.get("name")
- print(name)
- joint_info = {}
- joint_info["index"] = i
- joint_info["type"] = joint.get("type")
- if not joint_info["type"] == "fixed":
- self._all_actuated_joints.append(name)
- joint_info["actuated"] = True
- else:
- joint_info["actuated"] = False
-
- # Get the parent and child link names (URDF and MJCF have different structures here)
- if self.robot_config.get("urdf_path", None): # URDF case
- joint_info["parent Link"] = joint.find("parent").get("link")
- joint_info["child Link"] = joint.find("child").get("link")
- elif self.robot_config.get("mjcf_path", None): # MJCF case
- # Build a parent map once (ideally outside your joint loop for efficiency)
- parent_map = {
- child: parent for parent in root.iter() for child in parent
- }
-
- # Find the owning
of this joint (walk up until we hit a body)
- body_el = joint
- while body_el is not None and body_el.tag != "body":
- body_el = parent_map.get(body_el)
-
- child_link = body_el.get("name") if body_el is not None else None
-
- # The parent link is the parent of the owning body.
- parent_el = parent_map.get(body_el) if body_el is not None else None
- if parent_el is not None and parent_el.tag == "body":
- parent_link = parent_el.get("name")
- else:
- # owning body is directly under (i.e., world is the parent)
- parent_link = "__WORLD__"
-
- joint_info["parent Link"] = parent_link
- joint_info["child Link"] = child_link
-
- # If joint has limits (revolute or prismatic joints), get the limits
- if joint_info["type"] in ["revolute", "prismatic"]:
- limit = joint.find("limit")
- if limit is not None:
- joint_info["lower_limit"] = limit.get("lower", None)
- joint_info["upper_limit"] = limit.get("upper", None)
- joint_info["effort_limit"] = limit.get("effort", None)
- joint_info["velocity_limit"] = limit.get("velocity", None)
- else:
- joint_info["limits"] = "No limits defined"
- self.joint_infos[joint.get("name")] = joint_info
- # save dict of iniital cofngiruation of joints
- self.initial_configuration[joint.get("name")] = self.robot_config[
- "initial_configuration"
- ][i]
- # print(f"{i:<8} {name:<20} {joint_info['type']:<10}")
- # Print the joint summary
- print(f"{i:<8} {name:<20} {joint_info['type']:<10}")
- print(f" Parent Link: {joint_info['parent Link']}")
- print(f" Child Link: {joint_info['child Link']}")
- if "lower_limit" in joint_info:
- print(
- f" Limits: {joint_info['lower_limit']} to {joint_info['upper_limit']}, "
- f"Effort: {joint_info['effort_limit']}, Velocity: {joint_info['velocity_limit']}"
- )
- else:
- print(f" Limits: {joint_info.get('limits', 'None')}")
- print("-" * 40) # Divider for each joint summary
- except:
- log.error("Error prasing MJCF/URDF file: Using fallback joint_info")
-
- # check if joint group is defined:
- if "joint_groups" in self.robot_config:
- for group_name, group_config in self.robot_config.get(
- "joint_groups", {}
- ).items():
- # add control type from enum to internal config dict
- control_mode = group_config.get("control_mode", {})
- if control_mode == "position":
- group_config["control_type"] = ControlType.POSITION
- elif control_mode == "velocity":
- group_config["control_type"] = ControlType.VELOCITY
- elif control_mode == "torque":
- group_config["control_type"] = ControlType.TORQUE
- elif control_mode == "fixed":
- group_config["control_type"] = ControlType.FIXED
- # TODO
- raise NotImplementedError("TODO - how to manually fix a joint")
- else:
- raise ValueError(f"control mode '{control_mode}' is not supported")
- joints = {}
- actuated_joints = {}
- for joint in group_config["joints"]:
- joint_idx = self.joint_infos[joint]["index"]
- joints[joint] = joint_idx # {"joint name": joint index}
- if self.joint_infos[joint]["actuated"]:
- actuated_joints[joint] = joint_idx
- group_config["joints"] = joints
- group_config["actuated_joints"] = actuated_joints
-
- # # same for end effector
- # ee = group_config.get("end_effector", "None")# if not provided
- # assert isinstance(ee, str), "end_effector must be either None or a single value joint name."
- # if ee == "None": # if explicitly set to None
- # ee = None
-
- # group_config["end_effector"] = None
- # if ee is not None:
- # ee_idx = self.joint_infos[ee]["index"] # idx
- # group_config["end_effector"] = {ee: ee_idx}
- self.joint_groups[group_name] = group_config
- else:
- log.warning(
- f"Using Default Joint Group all in Position Control '{self.name}' !"
- )
- group_config = {}
- group_config["control_mode"] = "position"
- group_config["control_type"] = ControlType.POSITION
- joints = {}
- actuated_joints = {}
- for joint in self.joint_infos.keys():
- joint_idx = self.joint_infos[joint]["index"]
- joints[joint] = joint_idx
- if self.joint_infos[joint]["actuated"]:
- actuated_joints[joint] = joint_idx
- group_config["joints"] = joints
- group_config["actuated_joints"] = actuated_joints
- self.joint_groups["all"] = group_config
-
- self.create_service(
- self.reset_service_name, robot_init_t, flag_t, self.reset_component
- )
-
- if not self.sim:
- # runs if the robot is real
- try:
- self.freq = self.robot_config["frequency"]
- except:
- log.warning(
- f"No frequency provided for robot '{self.name}', using default 240Hz !"
- )
- self.freq = 240
- self.create_stepper(self.freq, self.step_component)
- self.create_stepper(self.freq, self.control_robot)
-
- print(self._all_actuated_joints)
-
- @abstractmethod
- def control_robot(self) -> None:
- """Send the currently stored command to the robot driver."""
- print("No call")
-
- @abstractmethod
- def pack_data(self) -> None:
- """Pack state information for publishing to the client."""
-
- @abstractmethod
- def get_state(self) -> Any:
- """Retrieve the current robot state from the driver."""
-
- #####################
- ## get infos ##
- #####################
-
- def get_joint_limits(self) -> Dict[str, Dict[str, float]]:
- """Return joint limits for all actuated joints."""
- actuated_info = {
- joint: info for joint, info in self.joint_infos.items() if info["actuated"]
- }
- return {
- joint: {
- "lower_limit": float(info.get("lower_limit", "inf")),
- "upper_limit": float(info.get("upper_limit", "inf")),
- "effort_limit": float(info.get("effort_limit", "inf")),
- "velocity_limit": float(info.get("velocity_limit", "inf")),
- }
- for joint, info in actuated_info.items()
- }
-
- def _get_joint_group_indices(
- self, joint_group: str
- ) -> Tuple[List[float], List[float]]:
- """Return joint and actuated joint indices for ``joint_group``."""
- return list(self.joint_groups[joint_group]["joints"].values()), list(
- self.joint_groups[joint_group]["actuated_joints"].values()
- )
-
- def is_torqued(self) -> bool:
- """Check if the driver currently outputs torque."""
- return self._driver.check_torque_status()
-
- def get_joint_positions(self) -> Dict[str, float]:
- """Get positions of all actuated joints."""
- return self._driver.pass_joint_positions(self._all_actuated_joints)
- # return self._driver.pass_joint_positions(self._all_actuated_joints)
-
- def get_joint_velocities(self) -> Dict[str, float]:
- """Get velocities of all actuated joints."""
- return self._driver.pass_joint_velocities(self._all_actuated_joints)
-
- def get_joint_efforts(self) -> Dict[str, float]:
- """Get efforts of all actuated joints."""
- return self._driver.pass_joint_efforts(self._all_actuated_joints)
-
- def get_joint_group_positions(self, joint_group: str) -> Dict[str, float]:
- """Get joint positions for a specific group."""
- actuated_joints = self.joint_groups[joint_group]["actuated_joints"]
- return self._driver.pass_joint_positions(actuated_joints)
-
- def get_joint_group_velocities(self, joint_group: str) -> Dict[str, float]:
- """Get joint velocities for a specific group."""
- actuated_joints = self.joint_groups[joint_group]["actuated_joints"]
- return self._driver.pass_joint_velocities(actuated_joints)
-
- def get_joint_group_efforts(self, joint_group: str) -> Dict[str, float]:
- """Get joint efforts for a specific group."""
- actuated_joints = self.joint_groups[joint_group]["actuated_joints"]
- return self._driver.pass_joint_efforts(actuated_joints)
-
- #####################
- ## control ##
- #####################
- def control_joint_group(
- self, control_mode: str, cmd: Dict[str, float], **kwargs
- ) -> None:
- """Forward a joint group command to the driver."""
- self._driver.pass_joint_group_control_cmd(control_mode, cmd, **kwargs)
-
- #####################
- ## misc. ##
- #####################
-
- def reset_component(self, channel=None, msg=None) -> flag_t:
- """Reset the robot to its initial configuration."""
- print("RESET HAS BEEN CALLED")
- self.suspend_communications(
- services=False
- ) # Suspend communications to avoid conflicts during reset
- self._is_suspended = True
-
- # # TODO
- # IDEA seperate reset iinto sim and real reset, make user implement real_reset for each robot ?
- if not msg:
- new_pos = self.robot_config["base_position"]
- new_orn = self.robot_config["base_orientation"]
- q_init = self.robot_config["initial_configuration"]
- else:
- new_pos = np.array(msg.position)
- new_orn = np.array(msg.orientation)
- q_init = np.array(msg.q_init)
-
- nbr_actuated = len(self._all_actuated_joints)
-
- nbr_init_pos = len(self.robot_config["initial_configuration"])
-
- if q_init.size == len(self._all_actuated_joints):
- idx = np.linspace(0, q_init.size - 1, q_init.size, dtype=np.uint8)
- temp = np.zeros(nbr_init_pos)
- temp[idx] = q_init
- q_init = temp.copy()
-
- if q_init.size != nbr_init_pos:
- log.error(
- f"Number of initial positions ({q_init.size}) does not match number of joints ({nbr_init_pos}) or actuated joints ({nbr_actuated}) for robot {self.name}!"
- )
-
- if not self.sim:
- usr_input = input("Start moving robot back to initial position [yes/no] ?")
- if usr_input == "yes":
- log.panda("if this appears big issue")
- raise NotImplementedError("TODO - how to reset a real robot ?")
- elif self.sim:
- self._driver.sim_reset(
- base_pos=new_pos, base_orn=new_orn, q_init=list(q_init)
- )
- self.resume_communications(services=False)
- self._is_suspended = False
- return flag_t()
-
- def step_component(self):
- """Query state and publish it to the configured channels."""
- data = self.get_state()
- packed = self.pack_data(data)
- self.component_multi_publisher.publish(packed)
diff --git a/ark/system/component/sensor.py b/ark/system/component/sensor.py
deleted file mode 100644
index 9f3b1ad..0000000
--- a/ark/system/component/sensor.py
+++ /dev/null
@@ -1,75 +0,0 @@
-"""Base classes for sensor components."""
-
-from abc import ABC, abstractmethod
-from typing import Any, Dict, Optional
-
-from ark.system.component.base_component import SimToRealComponent
-from ark.system.driver.sensor_driver import SensorDriver
-from ark.tools.log import log
-from typing import Any, Optional, Dict, Tuple, List, Union
-from pathlib import Path
-import os
-import yaml
-
-from arktypes import flag_t
-
-
-class Sensor(SimToRealComponent, ABC):
- """Base class for sensors used in the framework."""
-
- def __init__(
- self,
- name: str,
- global_config: Dict[str, Any] = None,
- driver: Optional[SensorDriver] = None,
- ) -> None:
- """Create a sensor component.
-
- @param name Name of the sensor.
- @param global_config Global configuration dictionary.
- @param driver Optional :class:`SensorDriver` implementation.
- """
-
- super().__init__(name, global_config, driver) # handles self.name, self.sim
- self.sensor_config = self._load_config_section(
- global_config=global_config, name=name, type="sensors"
- )
-
- # if runing a real system
- if not self.sim:
- try:
- self.freq = self.sensor_config["frequency"]
- except:
- log.warning(
- f"No frequency provided for sensor '{self.name}', using default !"
- )
- self.freq = 240
- self.create_stepper(self.freq, self.step_component)
-
- self.create_service(
- self.reset_service_name, flag_t, flag_t, self.reset_component
- )
-
- @abstractmethod
- def get_sensor_data(self) -> Any:
- """Acquire data from the sensor or its simulation."""
-
- @abstractmethod
- def pack_data(self, data: Any):
- """Transform raw sensor data into the message format."""
-
- # # OVERRIDE
- # def shutdown(self) -> None:
- # # kill driver (close ports, ...)
- # self._driver.shutdown_driver()
- # # kill all communication
- # super().shutdown()
-
- def reset_component(self) -> None:
- """Reset the sensor state if necessary."""
-
- def step_component(self):
- """Retrieve, pack and publish sensor data."""
- data = self.get_sensor_data()
- packed = self.pack_data(data)
- self.component_multi_publisher.publish(packed)
diff --git a/ark/system/component/sim_component.py b/ark/system/component/sim_component.py
deleted file mode 100644
index 0c236b6..0000000
--- a/ark/system/component/sim_component.py
+++ /dev/null
@@ -1,48 +0,0 @@
-"""Base classes for simulation objects."""
-
-from abc import ABC, abstractmethod
-from typing import Any, Dict
-
-from ark.tools.log import log
-from ark.system.component.base_component import BaseComponent
-from arktypes import flag_t, rigid_body_state_t
-
-
-class SimComponent(BaseComponent, ABC):
- """Base class for simulated rigid bodies."""
-
- def __init__(self, name: str, global_config: Dict[str, Any] = None) -> None:
- """Create a simulation component.
-
- @param name Name of the object.
- @param global_config Global configuration dictionary.
- """
- super().__init__(name=name, global_config=global_config)
- # extract this components configuration from the global configuration
- self.config = self._load_config_section(
- global_config=global_config, name=name, type="objects"
- )
- self.namespace = global_config["namespace"]
- # whether this should publish state information
- self.publish_ground_truth = self.config["publish_ground_truth"]
- # initialize service for reset of any component
- self.reset_service_name = f"{self.namespace}/" + self.name + "/reset/sim"
-
- self.create_service(
- self.reset_service_name, rigid_body_state_t, flag_t, self.reset_component
- )
-
- def step_component(self):
- """Gather object state and publish it if required."""
- if self.publish_ground_truth:
- data_dict = self.get_object_data()
- packed = self.pack_data(data_dict)
- self.component_multi_publisher.publish(packed)
-
- @abstractmethod
- def pack_data(self, data_dict) -> dict[str, Any]:
- """Pack object data into the message format."""
-
- @abstractmethod
- def get_object_data(self) -> Any:
- """Retrieve the current state of the simulated object."""
diff --git a/ark/system/driver/__init__.py b/ark/system/driver/__init__.py
deleted file mode 100644
index 8723fa6..0000000
--- a/ark/system/driver/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-"""! Ark system driver package.
-
-This package defines abstract driver interfaces which bridge ARK components
-with either simulation or hardware backends. Submodules implement drivers for
-robots, sensors and other components.
-"""
diff --git a/ark/system/driver/component_driver.py b/ark/system/driver/component_driver.py
deleted file mode 100644
index d59840b..0000000
--- a/ark/system/driver/component_driver.py
+++ /dev/null
@@ -1,119 +0,0 @@
-"""! Base component driver definitions.
-
-This module contains the :class:`ComponentDriver` abstract base class used by
-all ARK drivers. It includes helper functionality for loading configuration
-files and common attributes shared by concrete drivers.
-"""
-
-import os
-from abc import ABC, abstractmethod
-from pathlib import Path
-from typing import Any
-
-from ark.tools.log import log
-from ark.utils.utils import ConfigPath
-
-
-class ComponentDriver(ABC):
- """
- Abstract base class for a driver that facilitates communication between
- component classes and a backend (e.g., simulator or hardware). This class
- should handle backend-specific details.
-
- Attributes:
- component_name (str): The name of the component using this driver.
- component_config (Dict[str, Any], optional): Configuration settings
- for the component. Defaults to None.
- """
-
- def __init__(
- self,
- component_name: str,
- component_config: Any = None,
- sim: bool = True,
- ) -> None:
- """! Initialize the driver.
-
- @param component_name Name of the component using this driver.
- @param component_config Path or dictionary with configuration for the
- component using this driver.
- @param sim Set to ``True`` if running in simulation mode.
- """
- self.component_name = component_name
-
- if not isinstance(component_config, dict):
- self.config = self._load_single_section(component_config, component_name)
- else:
- self.config = component_config
- self.sim = sim
-
- def _load_single_section(self, component_config, component_name):
- """! Load the configuration of a single component from a YAML file.
-
- This helper parses a global configuration file and extracts the
- subsection corresponding to ``component_name``.
-
- @param component_config Path to a YAML file or a ``Path`` object
- pointing to the configuration file.
- @param component_name Name of the component whose configuration should
- be loaded.
- @return Dictionary containing the configuration for the component.
- """
-
- # handle path object vs string
- if isinstance(component_config, str):
- component_config = ConfigPath(component_config)
- elif isinstance(component_config, Path):
- component_config = ConfigPath(str(component_config))
-
- if not component_config.exists():
- raise FileNotFoundError("Given configuration file path does not exist.")
-
- if not component_config.is_absolute():
- component_config = component_config.resolve()
-
- cfg = component_config.read_yaml()
- section_config = {}
- for section_name in ["robots", "sensors", "objects"]:
- for item in cfg.get(section_name, []):
- if isinstance(item, dict): # If it's an inline configuration
- subconfig = item
- elif isinstance(item, str) and item.endswith(
- ".yaml"
- ): # If it's a path to an external file
- if os.path.isabs(item): # Check if the path is absolute
- external_path = ConfigPath(item)
- else: # Relative path, use the directory of the main config file
- external_path = component_config.parent / item
- # Load the YAML file and return its content
- subconfig = external_path.read_yaml()
- else:
- log.error(
- f"Invalid entry in '{section_name}': {item}. Please provide either a config or a path to another config."
- )
- continue # Skip invalid entries
-
- if subconfig["name"] == component_name:
- section_config = subconfig["config"]
- if not section_config:
- log.error(
- f"Could not find configuration for {component_name} in {component_config.str}"
- )
- return section_config
-
- def is_sim(self):
- """! Return whether this driver is running in simulation mode.
-
- @return ``True`` if the driver targets a simulator, ``False`` otherwise.
- """
-
- return self.sim
-
- @abstractmethod
- def shutdown_driver(self) -> None:
- """! Shut down the driver and release all resources.
-
- Concrete drivers should override this method to close connections or
- stop any background tasks started by the driver.
- """
- pass
diff --git a/ark/system/driver/robot_driver.py b/ark/system/driver/robot_driver.py
deleted file mode 100644
index 36fe8a1..0000000
--- a/ark/system/driver/robot_driver.py
+++ /dev/null
@@ -1,141 +0,0 @@
-"""! Robot driver base classes.
-
-This module defines abstract interfaces for robot drivers used by the ARK
-framework. Drivers act as the glue between high level robot components and the
-underlying backend (simulation or real hardware).
-"""
-
-from abc import ABC, abstractmethod
-from enum import Enum
-from typing import Any
-
-from ark.tools.log import log
-from ark.system.driver.component_driver import ComponentDriver
-
-
-class ControlType(Enum):
- """! Supported control modes for robot joints."""
-
- POSITION = "position"
- VELOCITY = "velocity"
- TORQUE = "torque"
- FIXED = "fixed"
-
-
-class RobotDriver(ComponentDriver):
- """! Abstract driver interface for robots.
-
- This class defines the common API that concrete robot drivers must
- implement in order to communicate joint states and control commands to a
- backend system.
- """
-
- def __init__(
- self,
- component_name: str,
- component_config: dict[str, Any] = None,
- sim: bool = True,
- ) -> None:
- """! Construct the driver.
-
- @param component_name Name of the robot component.
- @param component_config Configuration dictionary or path.
- @param sim True if the driver interfaces with a simulator.
- """
-
- super().__init__(
- component_name=component_name, component_config=component_config, sim=sim
- )
-
- #####################
- ## get infos ##
- #####################
-
- @abstractmethod
- def check_torque_status(self) -> bool:
- """! Check whether torque control is enabled on the robot.
-
- @return ``True`` if torque control is active, ``False`` otherwise.
- """
-
- pass
-
- @abstractmethod
- def pass_joint_positions(self, joints: list[str]) -> dict[str, float]:
- """! Retrieve the current joint positions.
-
- @param joints Names of the queried joints.
- @return Dictionary mapping each joint name to its position in radians.
- """
-
- pass
-
- @abstractmethod
- def pass_joint_velocities(self, joints: list[str]) -> dict[str, float]:
- """! Retrieve the current joint velocities.
-
- @param joints Names of the queried joints.
- @return Dictionary mapping each joint name to its velocity.
- """
-
- pass
-
- @abstractmethod
- def pass_joint_efforts(self, joints: list[str]) -> dict[str, float]:
- """! Retrieve the current joint efforts (torques or forces).
-
- @param joints Names of the queried joints.
- @return Dictionary mapping each joint name to its effort value.
- """
-
- pass
-
- #####################
- ## control ##
- #####################
-
- @abstractmethod
- def pass_joint_group_control_cmd(
- self, control_mode: str, cmd: dict[str, float], **kwargs
- ) -> None:
- """! Send a control command to a group of joints.
-
- @param control_mode One of :class:`ControlType` specifying the command type.
- @param cmd Dictionary of joint names to command values.
- @param kwargs Additional backend-specific parameters.
- """
-
- pass
-
-
-class SimRobotDriver(RobotDriver, ABC):
- """! Base class for drivers controlling simulated robots."""
-
- def __init__(
- self,
- component_name: str,
- component_config: dict[str, Any] = None,
- sim: bool = True,
- ) -> None:
- """! Initialize the simulation driver.
-
- @param component_name Name of the robot component.
- @param component_config Configuration dictionary or path.
- @param sim Unused for simulated robots (always ``True``).
- """
-
- super().__init__(component_name, component_config, True)
-
- @abstractmethod
- def sim_reset(
- self, base_pos: list[float], base_orn: list[float], init_pos: list[float]
- ) -> None:
- """! Reset the robot's state in the simulator."""
-
- ...
-
- def shutdown_driver(self) -> None:
- """! Shut down the simulation driver."""
-
- # Nothing to handle here
- pass
diff --git a/ark/system/driver/sensor_driver.py b/ark/system/driver/sensor_driver.py
deleted file mode 100644
index 5320cc0..0000000
--- a/ark/system/driver/sensor_driver.py
+++ /dev/null
@@ -1,101 +0,0 @@
-"""! Sensor driver base definitions.
-
-This module contains abstract base classes for sensor drivers used throughout
-the ARK framework. Drivers handle backend-specific details for sensors such as
-cameras or LiDARs.
-"""
-
-from abc import ABC, abstractmethod
-from enum import Enum
-from typing import Any, Optional, Dict, List
-
-from ark.tools.log import log
-from ark.system.driver.component_driver import ComponentDriver
-
-import numpy as np
-
-
-class SensorType(Enum):
- """! Enumeration of supported sensor types."""
-
- CAMERA = "camera"
- FORCE_TORQUE = "force_torque"
-
-
-class SensorDriver(ComponentDriver, ABC):
- """! Abstract driver interface for sensors.
-
- Concrete sensor drivers inherit from this class and implement the required
- methods to acquire data from a simulator or hardware backend.
- """
-
- def __init__(
- self,
- component_name: str,
- component_config: Dict[str, Any] = None,
- sim: bool = True,
- ) -> None:
- """! Initialize the sensor driver.
-
- @param component_name Name of the sensor component.
- @param component_config Configuration dictionary or path.
- @param sim True if running in simulation mode.
- """
-
- super().__init__(component_name, component_config, sim)
-
-
-class CameraDriver(SensorDriver, ABC):
- """! Base class for camera sensor drivers."""
-
- def __init__(
- self,
- component_name: str,
- component_config: Dict[str, Any] = None,
- sim: bool = True,
- ) -> None:
- """! Initialize the camera driver."""
-
- super().__init__(component_name, component_config, sim)
-
- @abstractmethod
- def get_images(self) -> Dict[str, np.ndarray]:
- """! Retrieve images from the camera."""
-
- ...
-
-
-class LiDARDriver(SensorDriver, ABC):
- """!
- Abstract base class for LiDAR sensor drivers.
-
- Defines the required interface for retrieving LiDAR scan data.
- """
-
- def __init__(
- self,
- component_name: str,
- component_config: Dict[str, Any] = None,
- sim: bool = True,
- ) -> None:
- """!
- Initialize the LiDAR driver.
-
- @param component_name Name of the LiDAR component.
- @param component_config Configuration dictionary.
- @param sim True if running in simulation mode.
- """
- super().__init__(component_name, component_config, sim)
-
- @abstractmethod
- def get_scan(self) -> Dict[str, np.ndarray]:
- """!
- Retrieve a LiDAR scan.
-
- @return Dictionary containing:
- - "angles": 1D NumPy array of angles (in radians) in the LiDAR's reference frame.
- - "ranges": 1D NumPy array of range values (in meters).
-
- Angles and ranges must be aligned such that each angle corresponds to the respective range index.
- """
- ...
diff --git a/ark/system/genesis/genesis_backend.py b/ark/system/genesis/genesis_backend.py
deleted file mode 100644
index 91178f7..0000000
--- a/ark/system/genesis/genesis_backend.py
+++ /dev/null
@@ -1,343 +0,0 @@
-from __future__ import annotations
-
-import importlib.util
-import sys
-from pathlib import Path
-from typing import Any
-
-import cv2
-import genesis as gs
-import numpy as np
-
-from ark.tools.log import log
-from ark.system.simulation.simulator_backend import SimulatorBackend
-
-from ark.system.genesis.genesis_multibody import GenesisMultiBody
-
-
-def import_class_from_directory(path: Path) -> tuple[type[Any], Any | None]:
- """Load and return a class (and optional driver) from ``path``.
-
- The helper searches for ``.py`` inside ``path`` and imports the
- class with the same name. When the module exposes a ``Drivers`` class a
- ``GENESIS_DRIVER`` attribute is returned alongside the main class.
- """
-
- class_name = path.name
- file_path = (path / f"{class_name}.py").resolve()
- if not file_path.exists():
- raise FileNotFoundError(f"The file {file_path} does not exist.")
-
- module_dir = str(file_path.parent)
- sys.path.insert(0, module_dir)
-
- try:
- spec = importlib.util.spec_from_file_location(class_name, file_path)
- if spec is None or spec.loader is None:
- raise ImportError(f"Could not load module from {file_path}")
-
- module = importlib.util.module_from_spec(spec)
- sys.modules[class_name] = module
- spec.loader.exec_module(module)
- finally:
- sys.modules.pop(class_name, None)
- sys.path.pop(0)
-
- drivers_attr: Any | None = None
- drivers_cls = getattr(module, "Drivers", None)
- if isinstance(drivers_cls, type):
- drivers_attr = getattr(drivers_cls, "GENESIS_DRIVER", None)
-
- class_candidates = [
- obj
- for obj in vars(module).values()
- if isinstance(obj, type) and obj.__module__ == module.__name__
- ]
-
- target_class = next(
- (cls for cls in class_candidates if cls.__name__ == class_name), None
- )
- if target_class is None:
- non_driver_classes = [
- cls for cls in class_candidates if cls.__name__ != "Drivers"
- ]
- if len(non_driver_classes) != 1:
- raise ValueError(
- f"Expected a single class definition in {file_path}, found {len(non_driver_classes)}."
- )
- target_class = non_driver_classes[0]
-
- return target_class, drivers_attr
-
-
-class GenesisBackend(SimulatorBackend):
- """Backend wrapper around the Genesis client.
-
- This class handles scene creation, stepping the simulation and managing
- simulated components such as robots, objects and sensors.
- """
-
- def initialize(self) -> None:
- """!Initialize the Genesis world.
-
- The method creates the Genesis client, configures gravity and time step
- and loads all robots, objects and sensors defined in
- ``self.global_config``. Optional frame capture settings are applied as
- well.
- """
- self.ready = False
- self._is_initialised = False
- self.scene: gs.Scene | None = None
- self.scene_ready: bool = False
-
- connection_mode = (
- self.global_config["simulator"]["config"]["connection_mode"]
- )
- show_viewer = connection_mode.upper() == "GUI"
-
- gs.init(backend=gs.cpu)
-
- gravity = self.global_config["simulator"]["config"].get(
- "gravity", [0.0, 0.0, -9.81]
- )
- timestep = 1.0 / self.global_config["simulator"]["config"].get(
- "sim_frequency", 100
- )
-
- self.scene = gs.Scene(
- sim_options=gs.options.SimOptions(dt=timestep, gravity=gravity),
- show_viewer=show_viewer,
- )
-
- self.scene.add_entity(gs.morphs.Plane())
-
- # Optional off-screen rendering
- self.save_render_config: dict[str, Any] | None = self.global_config[
- "simulator"
- ].get("save_render")
- if self.save_render_config:
- self.render_cam = self.scene.add_camera(
- res=(640, 480),
- pos=(3.5, 0.0, 2.5),
- lookat=(0, 0, 0.5),
- fov=30,
- GUI=False,
- )
- self.save_path = Path(
- self.save_render_config.get("save_path", "output/save_render")
- )
- self.save_path.mkdir(parents=True, exist_ok=True)
-
- remove_existing = self.save_render_config.get("remove_existing", True)
- if remove_existing:
- for child in self.save_path.iterdir():
- if child.is_file():
- child.unlink()
- self.save_interval = float(
- self.save_render_config.get("save_interval", 1 / 30)
- )
- self.overwrite_file = bool(
- self.save_render_config.get("overwrite_file", False)
- )
- else:
- self.render_cam = None
- self.save_path = None
- self.save_interval = 0.0
- self.overwrite_file = False
-
- # Setup robots
- if self.global_config.get("robots", None):
- for robot_name, robot_config in self.global_config["robots"].items():
- self.add_robot(robot_name, robot_config)
-
- # Setup objects
- if self.global_config.get("objects", None):
- for obj_name, obj_config in self.global_config["objects"].items():
- self.add_sim_component(obj_name, obj_config)
-
- # Sensors have to be set up last, as e.g. cameras might need
- # a parent to attach to
- if self.global_config.get("sensors", None):
- for sensor_name, sensor_config in self.global_config["sensors"].items():
- self.add_sensor(sensor_name, sensor_config)
-
- self.ready = True
-
- def is_ready(self) -> bool:
- """!Check whether the backend has finished initialization.
-
- @return ``True`` once all components were created and the simulator is
- ready for stepping.
- @rtype bool
- """
- return self.ready
-
- def set_gravity(self, gravity: tuple[float, float, float]) -> None:
- """!Set the world gravity.
-
- @param gravity Tuple ``(gx, gy, gz)`` specifying gravity in m/s^2.
- """
- raise NotImplementedError("Not required for Genesis")
-
- def set_time_step(self, time_step: float) -> None:
- """!Set the simulation timestep.
-
- @param time_step Length of a single simulation step in seconds.
- """
- raise NotImplementedError("Not required for Genesis")
-
- ##########################################################
- #### ROBOTS, SENSORS AND OBJECTS ####
- ##########################################################
-
- def add_robot(self, name: str, robot_config: dict[str, Any]) -> None:
- """!Instantiate and register a robot in the simulation.
-
- @param name Identifier for the robot.
- @param robot_config Robot specific configuration dictionary.
- """
- class_path = Path(robot_config["class_dir"])
- if class_path.is_file():
- class_path = class_path.parent
-
- robot_class, driver_entry = import_class_from_directory(class_path)
-
- if driver_entry is None:
- raise ValueError(
- f"Genesis driver not defined for robot '{name}' at {class_path}."
- )
-
- driver_cls = getattr(driver_entry, "value", driver_entry)
- if self.scene is None:
- raise RuntimeError("Genesis scene is not initialized.")
-
- driver = driver_cls(name, robot_config, self.scene)
- robot = robot_class(name=name, global_config=self.global_config, driver=driver)
-
- self.robot_ref[name] = robot
-
- def add_sim_component(
- self,
- name: str,
- obj_config: dict[str, Any],
- ) -> None:
- """!Add a generic simulated object.
-
- @param name Name of the object.
- @param obj_config Object specific configuration dictionary.
- """
- if self.scene is None:
- raise RuntimeError("Genesis scene is not initialized.")
-
- sim_component = GenesisMultiBody(
- name=name, client=self.scene, global_config=self.global_config
- )
- self.object_ref[name] = sim_component
-
- def add_sensor(self, name: str, sensor_config: dict[str, Any]) -> None:
- """!Instantiate and register a sensor.
-
- @param name Name of the sensor component.
- @param sensor_config Sensor configuration dictionary.
- """
- raise NotImplementedError("Sensors are not compatible with Genesis yet.")
- # Cameras are not supported on MacOS, Ubuntu Cameras are not working
- # Genesis-Embodied-AI/Genesis#1739
-
- def remove(self, name: str) -> None:
- """!Remove a component from the simulator.
-
- @param name Name of the robot, object or sensor to remove.
- """
- raise NotImplementedError("Genesis does not support removing components.")
-
- #######################################
- #### SIMULATION ####
- #######################################
-
- def _all_available(self) -> bool:
- """Return ``True`` when all registered components are active."""
-
- robots_ready = all(not robot._is_suspended for robot in self.robot_ref.values())
- objects_ready = all(not obj._is_suspended for obj in self.object_ref.values())
- sensors_ready = all(
- not sensor._is_suspended for sensor in self.sensor_ref.values()
- )
- return robots_ready and objects_ready and sensors_ready
-
- def save_render(self) -> None:
- """Add the latest render to save folder if rendering is configured."""
-
- if self.render_cam is None or self.save_path is None:
- return
-
- rgba = self.render_cam.render()
- time_us = int(1e6 * self._simulation_time)
- if self.overwrite_file:
- save_path = self.save_path / "render.png"
- else:
- save_path = self.save_path / f"{time_us}.png"
- # Convert renderer output to uint8 BGR image for OpenCV
- img = np.asarray(rgba)
- # Drop alpha channel if present
- if img.ndim == 3 and img.shape[-1] == 4:
- img = img[..., :3]
- # Normalize to uint8 if needed (assume float in [0,1])
- if img.dtype != np.uint8:
- img = np.clip(img, 0.0, 1.0)
- img = (img * 255.0).astype(np.uint8)
- # Convert RGB -> BGR for OpenCV
- img_bgr = img[..., ::-1]
- cv2.imwrite(str(save_path), img_bgr)
-
-
- def step(self) -> None:
- """!Advance the simulation by one timestep.
-
- The method updates all registered components, advances the physics
- engine and optionally saves renders when enabled.
- """
- if self.scene is None:
- raise RuntimeError("Genesis scene is not initialized.")
-
- if not self.scene_ready:
- self.scene.build()
- self.scene_ready = True
-
- if not self._all_available():
- log.warn("Skipping simulation step because a component is suspended.")
- return
-
- self._step_sim_components()
- self.scene.step()
- if self.save_render_config:
- self.save_render()
-
- def reset_simulator(self) -> None:
- """!Reset the entire simulator state.
-
- All robots, objects and sensors are destroyed and the backend is
- re-initialized using ``self.global_config``.
- """
- raise NotImplementedError("Reset simulator not implemented yet.")
-
- def get_current_time(self) -> float:
- """!Return the current simulation time.
-
- @return Elapsed simulation time in seconds.
- @rtype float
- """
- return self.scene.t
-
- def shutdown_backend(self) -> None:
- """!Disconnect all components and shut down the backend.
-
- This should be called at program termination to cleanly close the
- simulator and free all resources.
- """
- for robot in self.robot_ref.values():
- robot.kill_node()
- for obj in self.object_ref.values():
- obj.kill_node()
- for sensor in self.sensor_ref.values():
- sensor.kill_node()
diff --git a/ark/system/genesis/genesis_multibody.py b/ark/system/genesis/genesis_multibody.py
deleted file mode 100644
index 5e8aeff..0000000
--- a/ark/system/genesis/genesis_multibody.py
+++ /dev/null
@@ -1,153 +0,0 @@
-from enum import Enum
-from typing import Any
-
-import genesis as gs
-
-from ark.tools.log import log
-from ark.system.component.sim_component import SimComponent
-from arktypes import flag_t, rigid_body_state_t
-
-
-class SourceType(Enum):
- """Supported source types for object creation."""
-
- URDF = "urdf"
- PRIMITIVE = "primitive"
- SDF = "sdf"
- MJCF = "mjcf"
-
-
-class GenesisMultiBody(SimComponent):
- """Utility class for creating Genesis multi-body objects."""
-
- def __init__(
- self,
- name: str,
- client: Any,
- global_config: dict[str, Any] | None = None,
- ) -> None:
- """Instantiate a GenesisMultiBody object.
-
- @param name Name of the object.
- @param client Genesis client used for creation.
- @param global_config Global configuration dictionary.
- @return ``None``
- """
-
- super().__init__(name, global_config)
- self.client = client
- self.body: Any | None = None
- source_str = self.config["source"]
- source_type = getattr(SourceType, source_str.upper())
-
- if source_type == SourceType.PRIMITIVE:
- # Fall back to the original primitive creation if no URDF path is provided
- vis = self.config.get("visual", {})
- vis_shape_type = str(vis.get("shape_type", "GEOM_BOX")).upper()
- vis_opts = vis.get("visual_shape", {})
-
- col = self.config.get("collision", {})
-
- mass = self.config.get("multi_body", {}).get("baseMass", 1.0)
- color = vis_opts.get("rgbaColor", [1, 0, 0, 1]) # Default to red if not provided
-
- if vis_shape_type == "GEOM_SPHERE":
- radius = vis_opts.get("radius", 0.5)
- self.body = self.client.add_entity(
- gs.morphs.Sphere(
- pos=self.config.get("base_position", [0, 0, 0]),
- quat=self.config.get("base_orientation", [0, 0, 0, 1]),
- radius=radius,
- fixed=True if mass == 0 else False,
- ),
- )
- elif vis_shape_type == "GEOM_BOX":
- size = vis_opts.get("halfExtents", [0.5, 0.5, 0.5])
- # Convert half extents to full size
- size = [2 * s for s in size]
- self.body = self.client.add_entity(
- gs.morphs.Box(
- pos=self.config.get("base_position", [0, 0, 0]),
- quat=self.config.get("base_orientation", [0, 0, 0, 1]),
- size=size,
- fixed=True if mass == 0 else False,
- ),
- )
- else:
- log.warn(
- f"Unsupported primitive type '{vis_shape_type}' for Genesis multi-body; no entity created."
- )
-
- # Set mass for dynamic objects (mass > 0)
- if mass != 0 and self.body:
- self.body.set_mass(mass)
-
- elif source_type == SourceType.SDF:
- raise ValueError("Not Supported for Genesis")
- elif source_type == SourceType.MJCF:
- raise ValueError("Please use Robot for MJCF files in Genesis")
- else:
- log.error("Unknown source specification. Check your config file.")
-
- # setup communication
- self.publisher_name = self.name + "/ground_truth/sim"
- if self.publish_ground_truth:
- self.state_publisher = self.component_channels_init(
- {self.publisher_name: rigid_body_state_t}
- )
-
- def get_object_data(self) -> dict[str, Any]:
- """!Return the current state of the simulated object.
-
- @return Dictionary with position, orientation and velocities of the
- object.
- @rtype Dict[str, Any]
- """
- if self.body is None:
- raise RuntimeError("Genesis body has not been created yet.")
-
- position = self.body.get_pos()
- orientation = self.body.get_quat()
- lin_vel = self.body.get_vel()
- ang_vel = self.body.get_ang()
- return {
- "name": self.name,
- "position": position,
- "orientation": orientation,
- "lin_velocity": lin_vel,
- "ang_velocity": ang_vel,
- }
-
- def pack_data(self, data_dict: dict[str, Any]) -> dict[str, rigid_body_state_t]:
- """!Convert a state dictionary to a ``rigid_body_state_t`` message.
-
- @param data_dict Dictionary as returned by :func:`get_object_data`.
- @return Mapping suitable for :class:`MultiChannelPublisher`.
- @rtype Dict[str, rigid_body_state_t]
- """
- msg = rigid_body_state_t()
- msg.name = data_dict["name"]
- msg.position = data_dict["position"]
- msg.orientation = data_dict["orientation"]
- msg.lin_velocity = data_dict["lin_velocity"]
- msg.ang_velocity = data_dict["ang_velocity"]
- return {self.publisher_name: msg}
-
- def reset_component(self, channel: str, msg: rigid_body_state_t) -> flag_t:
- """!Reset the object pose using a message.
-
- @param channel LCM channel on which the reset request was received.
- @param msg ``rigid_body_state_t`` containing the desired pose.
- @return ``flag_t`` acknowledging the reset.
- """
- new_pos = msg.position
- new_orn = msg.orientation
- log.info(f"Resetting object {self.name} to position: {new_pos}")
- if self.body is None:
- raise RuntimeError("Cannot reset object before it has been created.")
-
- self.body.set_pos(new_pos)
- self.body.set_quat(new_orn)
- log.ok(f"Reset object {self.name} completed at: {new_pos}")
-
- return flag_t()
diff --git a/ark/system/genesis/genesis_robot_driver.py b/ark/system/genesis/genesis_robot_driver.py
deleted file mode 100644
index ba34a41..0000000
--- a/ark/system/genesis/genesis_robot_driver.py
+++ /dev/null
@@ -1,166 +0,0 @@
-"""Genesis robot driver implementation."""
-
-from __future__ import annotations
-
-import math
-from collections.abc import Mapping, Sequence
-from pathlib import Path
-from typing import Any
-
-import genesis as gs
-import numpy as np
-
-from ark.tools.log import log
-from ark.system.driver.robot_driver import ControlType, SimRobotDriver
-
-
-class GenesisRobotDriver(SimRobotDriver):
- """Robot driver that interfaces with the Genesis simulation."""
-
- def __init__(
- self,
- component_name: str,
- component_config: dict[str, Any] | None = None,
- client: gs.Scene | None = None,
- ) -> None:
- """Create a Genesis robot driver."""
-
- if component_config is None:
- raise ValueError("GenesisRobotDriver requires a component configuration.")
- if client is None:
- raise ValueError("GenesisRobotDriver requires an initialized Genesis scene.")
-
- super().__init__(component_name, component_config, True)
-
- self.client: gs.Scene = client
- self.ref_body_id: Any | None = None
-
- base_position_cfg = self.config.get("base_position", [0.0, 0.0, 0.0])
- self.base_position: list[float] = list(base_position_cfg)
-
- base_orientation_cfg = self.config.get(
- "base_orientation", [0.0, 0.0, 0.0, 1.0]
- )
- self.base_orientation = list(base_orientation_cfg)
-
- self.joint_names: list[str] = list(self.config.get("joint_names", []))
- if not self.joint_names:
- raise ValueError(
- f"Robot '{component_name}' configuration must define 'joint_names'."
- )
-
- self.load_robot(self.base_position, self.base_orientation, None)
-
- if self.ref_body_id is None:
- raise RuntimeError("Robot entity has not been created in Genesis.")
-
- self.dofs_idx = [
- self.ref_body_id.get_joint(name).dof_idx_local for name in self.joint_names
- ]
-
- def load_robot(
- self,
- base_position: Sequence[float] | None = None,
- base_orientation: Sequence[float] | None = None,
- q_init: Sequence[float] | None = None,
- ) -> None:
- """Load the robot model into the Genesis simulator."""
-
- if base_position is None:
- base_position = self.base_position
- if base_orientation is None:
- base_orientation = self.base_orientation
-
- mjcf_path_cfg = self.config.get("mjcf_path")
- if mjcf_path_cfg is None:
- raise ValueError(
- f"Robot '{self.component_name}' configuration requires 'mjcf_path'."
- )
-
- mjcf_path = Path(mjcf_path_cfg)
- self.ref_body_id = self.client.add_entity(
- gs.morphs.MJCF(file=str(mjcf_path))
- )
-
- log.ok(
- f"Initialized robot specified by MJCF '{mjcf_path}' in Genesis simulator."
- )
-
- #####################
- ## get infos ##
- #####################
-
- def check_torque_status(self) -> bool:
- """!Return ``True`` as simulated robots are always torqued.
-
- @return Always ``True`` in simulation.
- @rtype bool
- """
- return True # simulated robot is always torqued in genesis
-
- def pass_joint_positions(self, joints: list[str]) -> dict[str, float]:
- """Return the current joint positions."""
-
- pos = {}
- joint_pos = self.ref_body_id.get_dofs_position()
- for idx, name in enumerate(self.joint_names):
- pos[name] = joint_pos[idx]
- return pos
-
- def pass_joint_velocities(self, joints: list[str]) -> dict[str, float]:
- """Return the current joint velocities."""
- vel = {}
- joint_vel = self.ref_body_id.get_dofs_velocity()
- for idx, name in enumerate(self.joint_names):
- vel[name] = joint_vel[idx]
- return vel
-
- def pass_joint_efforts(self, joints: list[str]) -> dict[str, float]:
- """Return the current joint efforts."""
- eff = {}
- joint_vel = self.ref_body_id.get_dofs_force()
- for idx, name in enumerate(self.joint_names):
- eff[name] = joint_vel[idx]
- return eff
-
- #####################
- ## control ##
- #####################
-
- def pass_joint_group_control_cmd(
- self,
- control_mode: ControlType | str,
- cmd: Mapping[str, float],
- **kwargs: Any,
- ) -> None:
- """Send a control command to a group of joints."""
-
-
- self.ref_body_id.control_dofs_position(
- np.array(list(cmd.values())),
- self.dofs_idx
- )
-
-
- #####################
- ## misc. ##
- #####################
-
- def sim_reset(
- self,
- base_pos: Sequence[float],
- base_orn: Sequence[float],
- q_init: Sequence[float] | None,
- ) -> None:
- """Reset the robot in the simulator."""
-
- if self.ref_body_id is None:
- raise RuntimeError("Robot entity has not been initialized in Genesis.")
-
- self.ref_body_id.set_pos(list(base_pos))
- self.ref_body_id.set_quat(list(base_orn))
-
- if q_init is not None:
- self.ref_body_id.control_dofs_position(np.array(q_init, dtype=float), self.dofs_idx)
-
- log.ok(f"Reset robot {self.component_name} completed.")
diff --git a/ark/system/isaac/README.md b/ark/system/isaac/README.md
deleted file mode 100644
index 953ab92..0000000
--- a/ark/system/isaac/README.md
+++ /dev/null
@@ -1,18 +0,0 @@
-# Isaac Sim integration (ARK)
-
-Run the ARK Isaac Simulator with the standalone Isaac Sim 5.0.0 build.
-
-## Prerequisites
-- Python environment set up for this repository (install ARK deps first).
-- NVIDIA driver that meets Isaac Sim requirements.
-
-## Setup & run
-1) Download Isaac Sim 5.0.0 standalone
- `https://download.isaacsim.omniverse.nvidia.com/isaac-sim-standalone-5.0.0-linux-x86_64.zip`
-2) Extract the archive anywhere on your machine.
-3) Point ARK to the extracted Isaac Sim root by setting `ARK_ISSAC_PATH` (note: path must point to the root containing `python.sh`/`python.exe` and `kit`):
- - Linux/macOS: `export ARK_ISSAC_PATH=/path/to/isaac-sim-5.0.0-linux-x86_64`
- - Windows (PowerShell): `$env:ARK_ISSAC_PATH="C:\\path\\to\\isaac-sim-5.0.0-windows-x86_64"`
-4) From Ark, launch the simulator node from where you have it:
- `python sim_node.py`
-
diff --git a/ark/system/isaac/isaac_backend.py b/ark/system/isaac/isaac_backend.py
deleted file mode 100644
index eb9fcb2..0000000
--- a/ark/system/isaac/isaac_backend.py
+++ /dev/null
@@ -1,302 +0,0 @@
-from __future__ import annotations
-
-import ast
-import importlib.util
-import os
-import sys
-from pathlib import Path
-from typing import Any, Optional
-
-from ark.system.isaac.isaac_object import IsaacSimObject
-from ark.system.simulation.simulator_backend import SimulatorBackend
-from ark.tools.log import log
-from ark.utils import lazy
-
-from ark.utils.isaac_utils import configure_isaac_setup
-
-configure_isaac_setup()
-try:
- from isaacsim import SimulationApp
-except ImportError as exc:
- raise ImportError(
- "Isaac Sim Python packages are required for the IsaacSim backend. "
- "Install and source Isaac Sim before selecting backend_type=isaacsim."
- ) from exc
-
-
-def import_class_from_directory(path: Path) -> tuple[type, Optional[type]]:
- """Load the component class and its optional Isaac Sim driver declaration.
-
- Components that want to provide a custom Isaac Sim driver should expose a
- ``Drivers`` enum with an ``ISAAC_DRIVER`` entry (mirroring the existing
- ``PYBULLET_DRIVER`` / ``MUJOCO_DRIVER`` pattern). If none is provided the
- backend falls back to :class:`IsaacSimRobotDriver`.
- """
-
- class_name = path.name
- file_path = path / f"{class_name}.py"
- file_path = file_path.resolve()
- if not file_path.exists():
- raise FileNotFoundError(f"The file {file_path} does not exist.")
-
- with open(file_path, "r", encoding="utf-8") as file:
- tree = ast.parse(file.read(), filename=file_path)
-
- module_dir = os.path.dirname(file_path)
- sys.path.insert(0, module_dir)
-
- class_names = [
- node.name for node in ast.walk(tree) if isinstance(node, ast.ClassDef)
- ]
-
- driver_cls = None
- if "Drivers" in class_names:
- spec = importlib.util.spec_from_file_location(class_names[0], file_path)
- module = importlib.util.module_from_spec(spec)
- sys.modules[class_names[0]] = module
- spec.loader.exec_module(module)
-
- class_ = getattr(module, class_names[0])
- sys.path.pop(0)
- driver_cls = class_.ISAAC_DRIVER
- class_names.remove("Drivers")
-
- spec = importlib.util.spec_from_file_location(class_name, file_path)
- module = importlib.util.module_from_spec(spec)
- sys.modules[class_name] = module
- spec.loader.exec_module(module)
-
- class_ = getattr(module, class_names[0])
- sys.path.pop(0)
-
- if driver_cls is not None and hasattr(driver_cls, "value"):
- driver_cls = driver_cls.value
-
- return class_, driver_cls
-
-
-class IsaacSimBackend(SimulatorBackend):
- """Backend wrapper for running ARK simulations inside Isaac Sim."""
-
- def __init__(self, global_config: dict[str, Any]) -> None:
- """Initialize the backend and parse simulator configuration.
-
- Determines connection mode (GUI or headless), sets defaults, and
- registers the backend's custom event loop.
-
- Args:
- global_config (dict[str, Any]): Global ARK simulation configuration
- containing the simulator and component specifications.
- """
- self._app = None
- self.world = None
- self._headless = True
- self.timestep = 0.0
-
- super().__init__(global_config)
-
- sim_cfg = self.global_config["simulator"]["config"]
- connection_mode = sim_cfg.get("connection_mode", "headless").lower()
- self._headless = connection_mode != "gui"
- self.custom_event_loop = self.run
-
- def initialize(self) -> None:
- """Initialize the Isaac Sim application, stage, and scene components."""
-
- # Create a simulator
- self._app = SimulationApp({"headless": self._headless})
-
- sim_cfg = self.global_config["simulator"]["config"]
- physics_dt = 1 / sim_cfg.get("sim_frequency", 120.0)
- render_dt = 1 / sim_cfg.get("render_frequency", 60)
-
- # Creates scene
- self.world = lazy.isaacsim.core.api.World(
- stage_units_in_meters=1.0, physics_dt=physics_dt, rendering_dt=render_dt
- )
- self.world.scene.add_default_ground_plane()
-
- # Set gravity
- gravity = sim_cfg.get("gravity", [0.0, 0.0, -9.81])
- self.set_gravity(gravity)
-
- self.timestep = physics_dt
-
- # Add components to the simulator
- if self.global_config.get("objects", None):
- for object_name, object_config in self.global_config["objects"].items():
- self.add_sim_component(object_name, object_config)
-
- if self.global_config.get("robots", None):
- for robot_name, robot_config in self.global_config["robots"].items():
- self.add_robot(robot_name, robot_config)
-
- if self.global_config.get("sensors", None):
- for sensor_name, sensor_config in self.global_config["sensors"].items():
- self.add_sensor(sensor_name, sensor_config)
-
- # Allow simulator to settle
- self._app.update()
- for _ in range(250):
- self.world.step(render=True)
-
- def set_gravity(self, gravity: tuple[float, float, float]) -> None:
- """Set gravity for the simulation.
-
- Args:
- gravity (tuple[float, float, float]): Gravity vector in XYZ format.
-
- """
- if self.world is None:
- return
-
- # Isaac's physics_context.set_gravity expects a scalar
- gravity_scalar = float(gravity[2])
- self.world._physics_context.set_gravity(gravity_scalar)
-
- def reset_simulator(self) -> None:
- """Reset the Isaac Sim world and all components."""
-
- if self.world is None:
- return
-
- self.world.reset()
- for robot in self.robot_ref:
- robot._driver.sim_reset()
-
- for obj in self.object_ref:
- self.object_ref[obj].reset_component()
-
- # TODO check sensor reset
-
- for _ in range(10):
- self.world.step(render=True)
-
- def add_robot(
- self,
- name: str,
- robot_config: dict[str, Any],
- ) -> None:
- """Dynamically load a robot class and driver into the simulation.
-
- Args:
- name (str): Name of the robot component.
- robot_config (dict[str, Any]): Robot configuration, including:
- - class_dir (str): Directory containing the robot + driver classes.
- - Additional robot-specific fields.
- """
- class_path = Path(robot_config["class_dir"])
- if class_path.is_file():
- class_path = class_path.parent
-
- RobotClass, DriverClass = import_class_from_directory(class_path)
-
- driver = DriverClass(
- component_name=name,
- component_config=robot_config,
- sim_app=self._app,
- world=self.world,
- )
- robot = RobotClass(
- name=name,
- driver=driver,
- global_config=self.global_config,
- )
- self.robot_ref[name] = robot
-
- def add_sensor(
- self,
- name: str,
- sensor_config: dict[str, Any],
- ) -> None:
- """Load and register a sensor class and its driver.
-
- Args:
- name (str): Sensor name.
- sensor_config (dict[str, Any]): Sensor configuration parameters.
- """
- class_path = Path(sensor_config["class_dir"])
- if class_path.is_file():
- class_path = class_path.parent
-
- SensorClass, DriverClass = import_class_from_directory(class_path)
-
- driver = DriverClass(
- component_name=name,
- component_config=sensor_config,
- world=self.world,
- )
- sensor = SensorClass(
- name=name,
- driver=driver,
- global_config=self.global_config,
- )
- self.sensor_ref[name] = sensor
-
- def add_sim_component(
- self,
- name: str,
- obj_config: dict[str, Any],
- ) -> None:
- """Add a static object to the simulation via `IsaacSimObject`.
-
- Args:
- name (str): Name of the object.
- obj_config (dict[str, Any]): Object configuration.
- """
- obj = IsaacSimObject(
- name=name, world=self.world, global_config=self.global_config
- )
- self.object_ref[name] = obj
- log.ok(f"Loaded '{name}' into Isaac Sim stage.")
-
- @staticmethod
- def remove(name: str) -> None:
- log.warn("Dynamic removal is not supported in the IsaacSim backend yet.")
-
- def run(self, sim_node) -> None:
- """Main simulation loop for Isaac Sim integration.
-
- Handles LCM message processing for the simulator node, robots,
- and sensors, then advances physics and rendering each cycle.
-
- Args:
- sim_node: The main ARK simulation node controlling global stepping.
-
- Loop Behavior:
- - Runs while the Isaac Sim app window is open.
- - Processes LCM messages with zero timeout.
- - Calls `sim_node.step()` for ARK updates.
- - Calls backend `_step()` for physics/render stepping.
- """
- lcms = (
- [sim_node._lcm]
- + [r._lcm for r in self.robot_ref.values()]
- + [s._lcm for s in self.sensor_ref.values()]
- )
- while self._app.is_running():
- for lc in lcms:
- lc.handle_timeout(0)
- sim_node.step()
- self._step()
-
- def _step(self) -> None:
- """Execute a single backend simulation step."""
- self._step_sim_components()
- self.world.step(render=True)
- self._simulation_time += self.timestep
-
- def step(self) -> None:
- """Unused ARK interface override."""
- pass
-
- def shutdown_backend(self) -> None:
- """Shutdown Isaac Sim backend and all components."""
- for robot in self.robot_ref:
- self.robot_ref[robot].kill_node()
- for obj in self.object_ref:
- self.object_ref[obj].kill_node()
- for sensor in self.sensor_ref:
- self.sensor_ref[sensor].kill_node()
- if self._app is not None:
- self._app.close()
diff --git a/ark/system/isaac/isaac_camera_driver.py b/ark/system/isaac/isaac_camera_driver.py
deleted file mode 100644
index c7a014c..0000000
--- a/ark/system/isaac/isaac_camera_driver.py
+++ /dev/null
@@ -1,153 +0,0 @@
-"""Isaac Sim camera driver for ARK sensors (e.g., IntelRealSense)."""
-
-from __future__ import annotations
-
-from typing import Any
-
-import numpy as np
-from ark.system.driver.sensor_driver import CameraDriver
-from ark.tools.log import log
-from ark.utils import lazy
-from ark.utils.camera_utils import CameraType
-from isaacsim.sensors.camera import Camera
-from math import cos, sin, radians
-
-
-def _yaw_pitch_roll_to_pose(
- target: tuple[float, float, float],
- distance: float,
- yaw_deg: float,
- pitch_deg: float,
- roll_deg: float,
-) -> tuple[list[float], list[float]]:
- """
- Compute a camera pose (position + quaternion orientation) in the Isaac Sim world frame
- from a target point, a viewing distance, and yaw–pitch–roll Euler angles.
- Args:
- target: The 3D point the camera should look toward, expressed in world coordinates.
- distance: The radius of the spherical shell around the target on which the camera is placed.
- yaw_deg: Yaw angle (in degrees) around the Z-axis.
- pitch_deg: Pitch angle (in degrees) around the Y-axis.
- roll_deg: Roll angle (in degrees) around the X-axis.
-
- Returns:
- position: camera position in world coordinates.
- orientation: Quaternion compatible with Isaac Sim conventions.
-
- """
-
- yaw, pitch, roll = map(radians, (yaw_deg, pitch_deg, roll_deg))
- # Camera looks toward target; position offset in spherical coords
- pos = [
- target[0] + distance * cos(pitch) * cos(yaw),
- target[1] + distance * cos(pitch) * sin(yaw),
- target[2] + distance * sin(pitch),
- ]
-
- quat = lazy.omni.isaac.core.utils.rotations.euler_angles_to_quat(
- (roll, pitch, yaw), degrees=False
- )
- return pos, [quat[1], quat[2], quat[3], quat[0]]
-
-
-class IsaacCameraDriver(CameraDriver):
- """Camera driver that creates and manages a simulated RGB-D camera in Isaac Sim.."""
-
- def __init__(
- self,
- component_name: str,
- component_config: dict[str, Any],
- world: Any,
- ) -> None:
- """
- Initialize the camera driver and create the underlying Isaac Sim camera.
- Args:
- component_name: ARK component name for this camera.
- component_config: User-provided configuration.
- world: The active Isaac Sim world used for simulation stepping and asset creation.
- """
- self.world = world
- self._camera = None
- self._resolution = (640, 480)
- super().__init__(
- component_name=component_name, component_config=component_config
- )
- self._create_camera_prim()
-
- def _create_camera_prim(self) -> None:
- """
- Create and configure the camera prim in the Isaac Sim stage.
-
- Returns:
- None
- """
- sim_cfg = self.config["sim_config"]
- self._resolution = (
- int(self.config.get("width", 640)),
- int(self.config.get("height", 480)),
- )
- prim_path = sim_cfg.get("prim_path", f"/World/{self.component_name}")
- camera_type = self.config.get("camera_type", "fixed").lower()
-
- if camera_type == CameraType.FIXED:
- fix_cfg = sim_cfg.get("fix", {})
- target = fix_cfg.get("camera_target_position", [0.0, 0.0, 0.0])
- distance = fix_cfg.get("distance", 1.0)
- yaw = fix_cfg.get("yaw", 0.0)
- pitch = fix_cfg.get("pitch", 0.0)
- roll = fix_cfg.get("roll", 0.0)
- position, orientation = _yaw_pitch_roll_to_pose(
- target, distance, yaw, pitch, roll
- )
- elif camera_type == CameraType.ATTACHED:
- attach_cfg = sim_cfg.get("attach", {})
- position = attach_cfg.get("offset_translation", [0.0, 0.0, 0.0])
- orientation = attach_cfg.get("offset_rotation", [0.0, 0.0, 0.0, 1.0])
- parent_prim = attach_cfg.get("parent_prim")
- if parent_prim:
- prim_path = f"{parent_prim}/{self.component_name}"
- else:
- log.warn(
- f"Unsupported camera_type '{camera_type}' for Isaac; falling back to fixed."
- )
- position = sim_cfg.get("position", [0.0, 0.0, 1.0])
- orientation = sim_cfg.get("orientation", [0.0, 0.0, 0.0, 1.0])
-
- self._camera = Camera(
- prim_path=prim_path,
- position=position,
- frequency=self.config.get("frequency", 20),
- resolution=self._resolution,
- orientation=orientation,
- )
-
- self._camera.initialize()
- self._camera.add_motion_vectors_to_frame()
-
- # TODO add depth image as well
-
- def get_images(self) -> dict[str, np.ndarray]:
- """
- Retrieve the latest RGB and depth frames from the simulated camera.
-
- Returns:
- dict containing RGB and depth frames.
-
- """
- # Trigger a render; Isaac sensor API returns RGBA + depth
-
- rgb = self._camera.get_rgba()[:, :, :3] # drop alpha
- depth = self._camera.get_depth()
- image_out = dict(color=np.asarray(rgb), depth=np.asarray(depth))
-
- if rgb is None:
- log.warn(f"Camera {self.component_name} has no rgb frames yet.")
- image_out["color"] = np.zeros((*self._resolution[::-1], 3))
- if depth is None:
- log.warn(f"Camera {self.component_name} has no depth frames yet.")
- image_out["depth"] = np.zeros(self._resolution[::-1])
-
- return image_out
-
- def shutdown_driver(self) -> None:
- pass
diff --git a/ark/system/isaac/isaac_object.py b/ark/system/isaac/isaac_object.py
deleted file mode 100644
index a4e45fe..0000000
--- a/ark/system/isaac/isaac_object.py
+++ /dev/null
@@ -1,173 +0,0 @@
-from __future__ import annotations
-
-from typing import Any
-
-import numpy as np
-from ark.system.component.sim_component import SimComponent
-from ark.tools.log import log
-from ark.utils import lazy
-from ark.utils.source_type_utils import SourceType
-from arktypes import flag_t, rigid_body_state_t
-
-
-class IsaacSimObject(SimComponent):
- """Generic Isaac Sim object loader and pose publisher.
-
- This component abstracts loading and managing objects in Isaac Sim using USD, URDF, or simple primitives.
- It automatically loads the asset, creates the corresponding prim hierarchy, attaches a transform handle
- (`XFormPrim`) for pose access, and optionally publishes ground-truth simulation state.
-
- Supported types are:
- - `USD`: load a USD asset via reference.
- - `URDF`: import a URDF model into the stage.
- - `PRIMITIVE`: create a DynamicCuboid representing a simple shape.
-
- Attributes:
- world (Any): Reference to the simulation world.
- _prim_path (str): Path to the Isaac Sim prim representing this object.
- _xform (XFormPrim): Transform handle used for reading and setting poses.
- publisher_name (str): LCM topic name for ground-truth publishing.
- state_publisher (dict[str, Publisher]): Optional publisher for state output.
- """
-
- def __init__(self, name: str, world: Any, global_config: dict[str, Any]) -> None:
- """Initialize and load the object into the Isaac Sim stage.
-
- Depending on the configured source type, the object is created by:
- - Adding a USD reference.
- - Importing a URDF.
- - Constructing a primitive (DynamicCuboid) and optionally disabling physics if the mass is zero.
- TODO - Other shapes needs to be added based on config
-
- Args:
- name (str): Unique component name.
- world (Any): Simulation world or scene container.
- global_config (dict[str, Any]): Global configuration dict.
-
- """
- self.world = world
- self._prim_path = None
- self._xform = None
-
- super().__init__(name=name, global_config=global_config)
-
- self._prim_path = self.config.get("prim_path", f"/World/{name}")
- source_str = self.config["source"]
- source_type = getattr(SourceType, source_str.upper())
-
- if source_type == SourceType.USD:
- usd_path = self.config.get("usd_path")
- if not usd_path:
- raise ValueError(
- f"USD source selected for '{name}' but no usd_path provided."
- )
-
- lazy.omni.isaac.core.utils.stage.add_reference_to_stage(
- str(usd_path), self._prim_path
- )
-
- elif source_type == SourceType.URDF:
- urdf_path = self.config.get("urdf_path")
- if not urdf_path:
- raise ValueError(
- f"URDF source selected for '{name}' but no urdf_path provided."
- )
-
- lazy.isaacsim.asset.importer.urdf.import_urdf(
- str(urdf_path), prim_path=self._prim_path
- )
-
- elif source_type == SourceType.PRIMITIVE:
- object = lazy.isaacsim.core.api.objects.DynamicCuboid(
- name=name,
- position=np.array(self.config["base_position"]),
- prim_path=self._prim_path,
- scale=np.array(self.config["visual"]["visual_shape"]["halfExtents"]),
- size=1.0,
- color=np.array(self.config["visual"]["visual_shape"]["rgbaColor"][:3]),
- )
- if self.config["multi_body"]["baseMass"] == 0:
- object.disable_rigid_body_physics()
-
- self.world.scene.add(object)
-
- else:
- raise RuntimeError(f"Unsupported object source type '{source_type}'.")
-
- # Wrap prim for pose access
- self._xform = lazy.isaacsim.core.prims.XFormPrim(
- prim_paths_expr=self._prim_path, name=name
- )
-
- # Setup publisher for ground truth
- self.publisher_name = f"{self.namespace}/" + self.name + "/ground_truth/sim"
- if self.publish_ground_truth:
- self.state_publisher = self.component_channels_init(
- {self.publisher_name: rigid_body_state_t}
- )
-
- def get_object_data(self) -> dict[str, Any]:
- """Retrieve the object's current pose in world coordinates.
-
- Returns:
- dict[str, Any]: Information describing the object's state:
- {
- "name": str,
- "position": ndarray,
- "orientation": ndarray,
- "lin_velocity": list[float],
- "ang_velocity": list[float],
- }
- """
- position, orientation = self._xform.get_world_poses()
- return {
- "name": self.name,
- "position": position.flatten(),
- "orientation": orientation.flatten(),
- "lin_velocity": [0.0, 0.0, 0.0],
- "ang_velocity": [0.0, 0.0, 0.0],
- }
-
- def pack_data(self, data_dict: dict[str, Any]) -> dict[str, rigid_body_state_t]:
- """Convert pose data into a rigid-body LCM message.
-
- Takes the raw pose data generated by :meth:`get_object_data`,
- constructs a `rigid_body_state_t` message, and returns it mapped
- to the component's publisher name.
-
- Args:
- data_dict (dict[str, Any]): Pose information from
- :meth:`get_object_data`.
-
- Returns:
- dict[str, rigid_body_state_t]: Mapping from publisher topic
- to LCM message ready for transmission.
- """
- msg = rigid_body_state_t()
- msg.name = data_dict["name"]
- msg.position = data_dict["position"]
- msg.orientation = data_dict["orientation"]
- msg.lin_velocity = data_dict["lin_velocity"]
- msg.ang_velocity = data_dict["ang_velocity"]
- return {self.publisher_name: msg}
-
- def reset_component(self, channel, msg) -> flag_t:
- """Reset the object's world pose via an incoming LCM message.
-
- This method updates the underlying prim's world position and
- orientation using data from the received state message.
-
- Args:
- channel (str): LCM channel the message arrived at.
- msg (rigid_body_state_t): Message containing new position and orientation values.
-
- Returns:
- flag_t: Status flag indicating completion.
-
- """
- if self._xform:
- self._xform.set_world_pose(msg.position, msg.orientation)
- log.ok(f"Reset object {self.name} to position {msg.position}")
- else:
- log.warn(f"No XFormPrim available to reset object {self.name}.")
- return flag_t()
diff --git a/ark/system/isaac/isaac_robot_driver.py b/ark/system/isaac/isaac_robot_driver.py
deleted file mode 100644
index a2429cf..0000000
--- a/ark/system/isaac/isaac_robot_driver.py
+++ /dev/null
@@ -1,270 +0,0 @@
-from __future__ import annotations
-
-from abc import abstractmethod
-from pathlib import Path
-from typing import Any
-
-import numpy as np
-from ark.system.driver.robot_driver import SimRobotDriver
-from ark.tools.log import log
-from ark.utils import lazy
-from pxr import Gf, PhysxSchema, Sdf, UsdPhysics
-
-
-class IsaacSimRobotDriver(SimRobotDriver):
- """Isaac Sim robot driver connecting ARK robot commands to an articulation.
-
- The driver acts as the low-level bridge between ARK control messages and
- Isaac Sim physics, ensuring that all robot joints map correctly to
- articulation DOFs.
-
- """
-
- def __init__(
- self,
- component_name: str,
- component_config: dict[str, Any],
- sim_app: Any,
- world: Any,
- ) -> None:
- """Initialize the Isaac Sim robot driver.
-
- Loads and imports the robot asset, creates the articulation, and
- performs an initial reset.
-
- Args:
- component_name (str): Robot name in ARK.
- component_config (dict[str, Any]): Configuration containing:.
- sim_app (SimulationApp): Isaac Sim application instance.
- world (World): World instance to which the articulation is added.
- """
-
- self.sim_app = sim_app
- self.world = world
- self._articulation = None
- self._joint_name_to_index: dict[str, int] = {}
- self.component_name = component_name
-
- super().__init__(
- component_name=component_name, component_config=component_config
- )
- self._load_robot()
- self.sim_reset()
-
- def _load_robot(self) -> None:
- """Import the robot asset and construct the robot articulation."""
-
- self.prim_path = self.config.get("prim_path", f"/World/{self.component_name}")
- urdf_path_cfg = self.config.get("urdf_path")
-
- base_dir = Path(self.config.get("class_dir", ".")).resolve()
-
- def _resolve(path: str | None) -> Path | None:
- if not path:
- return None
- candidate = Path(path)
- return (
- candidate
- if candidate.is_absolute()
- else (base_dir / candidate).resolve()
- )
-
- urdf_path = _resolve(urdf_path_cfg)
- usd_path = None # For future extensions, to load usd file
- self.urdf_path = urdf_path
-
- if urdf_path is None:
- raise ValueError(f"Robot '{self.component_name}' needs a URDF file.")
-
- if usd_path:
- # Load robot from USD file
- lazy.isaacsim.core.utils.stage.add_reference_to_stage(
- str(usd_path), self.prim_path
- )
- self._articulation = lazy.isaacsim.core.prims.Articulation(
- prim_paths_expr=self.prim_path, name=self.component_name
- )
- elif urdf_path:
- # Load robot from URDF file
-
- # Setting up import configuration:
- status, import_config = lazy.omni.kit.commands.execute(
- "URDFCreateImportConfig"
- )
- import_config.merge_fixed_joints = False
- import_config.convex_decomp = False
- import_config.import_inertia_tensor = True
- import_config.fix_base = True
-
- status, self.prim_path = lazy.omni.kit.commands.execute(
- "URDFParseAndImportFile",
- urdf_path=urdf_path,
- import_config=import_config,
- get_articulation_root=True,
- )
-
- # Get stage handle
- stage = lazy.omni.usd.get_context().get_stage()
-
- # Enable physics
- scene = UsdPhysics.Scene.Define(stage, Sdf.Path("/physicsScene"))
- # Set gravity
- scene.CreateGravityDirectionAttr().Set(Gf.Vec3f(0.0, 0.0, -1.0))
- scene.CreateGravityMagnitudeAttr().Set(9.81)
- # Set solver settings
- PhysxSchema.PhysxSceneAPI.Apply(stage.GetPrimAtPath("/physicsScene"))
- physxSceneAPI = PhysxSchema.PhysxSceneAPI.Get(stage, "/physicsScene")
- physxSceneAPI.CreateEnableCCDAttr(True)
- physxSceneAPI.CreateEnableStabilizationAttr(True)
- physxSceneAPI.CreateEnableGPUDynamicsAttr(False)
- physxSceneAPI.CreateBroadphaseTypeAttr("MBP")
- physxSceneAPI.CreateSolverTypeAttr("TGS")
-
- lazy.omni.timeline.get_timeline_interface().play()
- self.sim_app.update()
- self._articulation = lazy.isaacsim.core.prims.Articulation(self.prim_path)
- self.world.scene.add(self._articulation)
- self._articulation.initialize()
-
- # Set initial Position and Orientation
- self.base_position = self.config.get("base_position", [0.0, 0.0, 0.0])
- self.base_orientation = self.config.get(
- "base_orientation", [0.0, 0.0, 0.0, 1.0]
- )
- self.initial_configuration = self.config.get(
- "initial_configuration", [0.0] * len(self._articulation.joint_names)
- )
-
- self._joint_name_to_index = {
- name: self._articulation.get_dof_index(name)
- for name in self._articulation.dof_names
- }
-
- def check_torque_status(self) -> bool:
- """Check whether torque control is enabled.
-
- Returns:
- bool: Always True for this minimal driver implementation.
- """
- return True
-
- def pass_joint_positions(self, joints: list[str]) -> dict[str, float]:
- """Retrieve joint positions for the requested joints.
-
- Args:
- joints (list[str]): Joint names to query.
-
- Returns:
- dict[str, float]: Mapping joint_name → position_value.
- """
- positions = self._articulation.get_joint_positions().flatten()
- return {
- name: float(positions[self._joint_name_to_index[name]])
- for name in joints
- if name in self._joint_name_to_index
- }
-
- def pass_joint_velocities(self, joints: list[str]) -> dict[str, float]:
- """Retrieve joint velocities for the requested joints.
-
- Args:
- joints (list[str]): Joint names to query.
-
- Returns:
- dict[str, float]: Mapping joint_name → velocity_value.
- """
- velocities = self._articulation.get_joint_velocities().flatten()
- return {
- name: float(velocities[self._joint_name_to_index[name]])
- for name in joints
- if name in self._joint_name_to_index
- }
-
- @staticmethod
- def pass_joint_efforts(joints: list[str]) -> dict[str, float]:
- """Retrieve joint efforts.
-
- Notes:
- Efforts are currently not simulated and always return 0.0.
-
- Args:
- joints (list[str]): Joint names to query.
-
- Returns:
- dict[str, float]: Mapping joint_name → effort_value (0.0).
- """
- # TODO check for the implementation
- return {name: 0.0 for name in joints}
-
- def pass_joint_group_control_cmd(
- self, control_mode: str, cmd: dict[str, float], **kwargs
- ) -> None:
- """Send a group joint control command to the robot.
-
- Supported control modes:
- - "position": Sets joint target positions.
- - "velocity": Sets joint target velocities.
- - "torque": Applies joint efforts.
-
- Args:
- control_mode (str): Control mode type.
- cmd (dict[str, float]): Mapping joint_name → target_value.
- **kwargs: Additional unused parameters for compatibility.
- """
-
- positions = self._articulation.get_joint_positions()
- velocities = self._articulation.get_joint_velocities()
- efforts = np.zeros_like(positions)
-
- for joint_name, target in cmd.items():
- if joint_name not in self._joint_name_to_index:
- continue
- idx = self._joint_name_to_index[joint_name]
- if control_mode == "position":
- positions[idx] = target
- elif control_mode == "velocity":
- velocities[idx] = target
- elif control_mode == "torque":
- efforts[idx] = target
- else:
- log.warn(f"Unsupported control_mode '{control_mode}' for Isaac Sim.")
-
- action = lazy.omni.isaac.core.utils.types.ArticulationAction(
- joint_positions=positions,
- joint_velocities=velocities if control_mode == "velocity" else None,
- joint_efforts=efforts if control_mode == "torque" else None,
- )
- self._articulation.apply_action(action)
-
- def sim_reset(self, *kargs, **kwargs) -> None:
- """Reset the robot articulation.
-
- Args:
- *kargs: Ignored.
- **kwargs: Ignored.
- """
- self._articulation.set_world_poses(
- positions=np.array([self.base_position]),
- orientations=np.array([self.base_orientation]),
- )
- if len(self.initial_configuration) > 9:
- q_init = self.initial_configuration[:9]
- else:
- q_init = self.initial_configuration
-
- self._articulation.set_joint_positions([q_init])
- self._articulation.set_joint_velocities(
- np.zeros_like(self._articulation.get_joint_positions())
- )
-
- @abstractmethod
- def pass_cartesian_control_cmd(self, *kargs, **kwargs) -> None:
- """Send a Cartesian-space control command.
-
- Abstract method must be implemented by subclasses.
-
- Args:
- *kargs: Implementation-specific arguments.
- **kwargs: Implementation-specific arguments.
- """
- ...
diff --git a/ark/system/mujoco/mjcf_builder.py b/ark/system/mujoco/mjcf_builder.py
deleted file mode 100644
index f10f697..0000000
--- a/ark/system/mujoco/mjcf_builder.py
+++ /dev/null
@@ -1,577 +0,0 @@
-# mjcf_builder.py
-from __future__ import annotations
-
-from dataclasses import dataclass, field
-import copy
-import math
-import os
-from scipy.spatial.transform import Rotation as R
-import xml.etree.ElementTree as ET
-
-# ----------------------------- Utilities -----------------------------
-
-
-def _attrs(el: ET.Element, **kwargs):
- """Set attributes on an XML element, converting lists/tuples to space-separated strings."""
- for k, v in kwargs.items():
- if v is None:
- continue
- if isinstance(v, (list, tuple)):
- v = " ".join(str(x) for x in v)
- else:
- v = str(v)
- el.set(k, v)
- return el
-
-
-
-# ------------------------------ Data ---------------------------------
-
-
-@dataclass
-class BodySpec:
- name: str
- pos: Optional[List[float]] = None
- quat: Optional[List[float]] = None
- euler: Optional[List[float]] = None
- child_bodies: List["BodySpec"] = field(default_factory=list)
- geoms: List[Dict] = field(default_factory=list)
- joints: List[Dict] = field(default_factory=list)
- sites: List[Dict] = field(default_factory=list)
- cameras: List[Dict] = field(default_factory=list)
-
-
-# --------------------------- MJCF Builder ----------------------------
-
-
-class MJCFBuilder:
- """
- Minimal MJCF builder that supports:
- - assets, bodies, geoms, joints, sites, cameras, actuators, tendons, equalities, contacts
- - wrapping includes in a poseable body (optional free root joint)
- - tracking joint order and body poses
- - creating consolidated keyframes (e.g., 'spawn') with initial positions
- - specifying initial robot base pose and internal joint configuration at spawn
- """
-
- def __init__(self, model_name: str = "world"):
- self.model_name = model_name
-
- # Root & top-level sections
- self.root = ET.Element("mujoco")
- self.root.set("model", model_name)
-
- self.compiler = ET.SubElement(self.root, "compiler")
- _attrs(self.compiler, angle="degree", coordinate="local")
-
- self.option = ET.SubElement(self.root, "option")
- _attrs(self.option, timestep="0.002")
-
- self.size = ET.SubElement(self.root, "size")
- self.asset = ET.SubElement(self.root, "asset")
- self.default = ET.SubElement(self.root, "default")
- self.worldbody = ET.SubElement(self.root, "worldbody")
- self.actuator = ET.SubElement(self.root, "actuator")
- self.sensor = ET.SubElement(self.root, "sensor")
- self.tendon = ET.SubElement(self.root, "tendon")
- self.equality = ET.SubElement(self.root, "equality")
- self.contact = ET.SubElement(self.root, "contact")
- self.keyframe = ET.SubElement(self.root, "keyframe")
-
- # Asset bookkeeping
- self._materials: Dict[str, Dict] = {}
- self._textures: Dict[str, Dict] = {}
- self._meshes: Dict[str, Dict] = {}
- self._robots: Dict[str, Dict] = {}
-
- # For quick lookup to attach children under an existing body
- self._bodies: Dict[str, ET.Element] = {"__WORLD__": self.worldbody}
-
- # Track joint order and body pose for keyframe synthesis
- self._joint_order: List[Dict] = [] # each: {name, body, type, size}
- self._body_pose: Dict[str, Dict] = {} # body -> {pos:[3], quat:[4]}
-
- # Global defaults for make_spawn_keyframe (merged with per-call overrides)
- self._joint_defaults: Dict[str, List[float]] = {}
-
- # ---------- Global configuration ----------
- def set_compiler(self, **kwargs) -> "MJCFBuilder":
- _attrs(self.compiler, **kwargs)
- return self
-
- def set_option(self, **kwargs) -> "MJCFBuilder":
- _attrs(self.option, **kwargs)
- return self
-
- # ---------- Assets ----------
- def add_texture(self, name: str, **kwargs) -> "MJCFBuilder":
- if name not in self._textures:
- tex = ET.SubElement(self.asset, "texture")
- _attrs(tex, name=name, **kwargs)
- self._textures[name] = kwargs
- return self
-
- def add_material(self, name: str, **kwargs) -> "MJCFBuilder":
- if name not in self._materials:
- mat = ET.SubElement(self.asset, "material")
- _attrs(mat, name=name, **kwargs)
- self._materials[name] = kwargs
- return self
-
- def add_mesh(
- self, name: str, file: Optional[str] = None, **kwargs
- ) -> "MJCFBuilder":
- if name not in self._meshes:
- m = ET.SubElement(self.asset, "mesh")
- _attrs(m, name=name, file=file, **kwargs)
- self._meshes[name] = {"file": file, **kwargs}
- return self
-
- # ---------- Bodies / Robots / Objects ----------
- @staticmethod
- def _joint_qpos_size(jtype: Optional[str]) -> int:
- if jtype == "free":
- return 7
- if jtype == "ball":
- return 4
- return 1 # revolute/slide/hinge/etc.
-
- def add_body(
- self,
- name: str,
- parent: str = "__WORLD__",
- pos: Optional[List[float]] = None,
- quat: Optional[List[float]] = None,
- euler: Optional[List[float]] = None,
- ) -> "MJCFBuilder":
- parent_el = self._bodies[parent]
- b = ET.SubElement(parent_el, "body")
- _attrs(b, name=name, pos=pos, quat=quat, euler=euler)
- self._bodies[name] = b
-
- # Record pose for keyframe building (priority: quat > euler > identity)
- if quat is not None:
- q = quat
- elif euler is not None:
- deg = self.compiler.get("angle", "radian") == "degree"
- r = R.from_euler('xyz', euler, degrees=deg)
- q = r.as_quat(scalar_first=True).tolist()
- else:
- q = [1, 0, 0, 0]
- p = pos if pos is not None else [0, 0, 0]
- self._body_pose[name] = {"pos": p, "quat": q}
- return self
-
- def add_geom(self, body: str, **kwargs) -> "MJCFBuilder":
- b = self._bodies[body]
- g = ET.SubElement(b, "geom")
- _attrs(g, **kwargs)
- return self
-
- def add_joint(self, body: str, **kwargs) -> "MJCFBuilder":
- """
- Adds a joint under `body`. If no name is provided, auto-generate one.
- Records joint order and qpos size to help building keyframes later.
- """
- b = self._bodies[body]
- j = ET.SubElement(b, "joint")
-
- jtype = kwargs.get("type")
- jname = kwargs.get("name")
- if jname is None:
- count = sum(1 for jinfo in self._joint_order if jinfo.get("body") == body)
- jname = f"{body}_joint_{count}"
- kwargs["name"] = jname
-
- _attrs(j, **kwargs)
-
- if jtype is not None:
- self._joint_order.append(
- {
- "name": jname,
- "body": body,
- "type": jtype,
- "size": self._joint_qpos_size(jtype),
- }
- )
- return self
-
- def add_site(self, body: str, **kwargs) -> "MJCFBuilder":
- b = self._bodies[body]
- s = ET.SubElement(b, "site")
- _attrs(s, **kwargs)
- return self
-
- # ---------- Cameras ----------
- def add_camera(
- self, parent: str = "__WORLD__", name: Optional[str] = None, **kwargs
- ) -> "MJCFBuilder":
- p = self._bodies[parent]
- c = ET.SubElement(p, "camera")
- _attrs(c, name=name, **kwargs)
- return self
-
- # ---------- Sensors ----------
- def add_sensor(self, stype: str, **kwargs) -> "MJCFBuilder":
- s = ET.SubElement(self.sensor, stype)
- _attrs(s, **kwargs)
- return self
-
- # ---------- High-level loaders ----------
- def load_object(
- self,
- name: str,
- shape: str,
- size: Union[List[float], float],
- pos=(0, 0, 0),
- quat: Optional[List[float]] = None,
- density: Optional[float] = None,
- mass: Optional[float] = None,
- rgba: Optional[List[float]] = None,
- free: bool = True,
- **geom_kwargs,
- ) -> "MJCFBuilder":
- """Convenience: create a body + (optional) free joint + geom."""
- self.add_body(name=name, pos=pos, quat=quat)
- if free:
- self.add_joint(name, type="free", name=f"{name}_root")
- self.add_geom(
- name,
- type=shape,
- size=size,
- density=density,
- mass=mass,
- rgba=rgba,
- **geom_kwargs,
- )
- return self
-
- def load_robot_from_spec(
- self, root: BodySpec, parent: str = "__WORLD__"
- ) -> "MJCFBuilder":
- """Recursively create a robot from a BodySpec tree."""
-
- def _recurse(spec: BodySpec, parent_body: str):
- self.add_body(
- spec.name,
- parent=parent_body,
- pos=spec.pos,
- quat=spec.quat,
- euler=spec.euler,
- )
- for j in spec.joints:
- self.add_joint(spec.name, **j)
- for g in spec.geoms:
- self.add_geom(spec.name, **g)
- for s in spec.sites:
- self.add_site(spec.name, **s)
- for cam in spec.cameras:
- self.add_camera(parent=spec.name, **cam)
- for child in spec.child_bodies:
- _recurse(child, spec.name)
-
- _recurse(root, parent)
- return self
-
- def include(self, file: str, parent: str | None = None) -> "MJCFBuilder":
- """
- Insert either at the root (parent=None),
- inside worldbody (parent="__WORLD__"), or inside a specific body.
- """
- if parent is None:
- inc_parent = self.root
- elif parent == "__WORLD__":
- inc_parent = self.worldbody
- else:
- inc_parent = self._bodies[parent]
- inc = ET.SubElement(inc_parent, "include")
- inc.set("file", file)
- return self
-
- def include_robot(
- self,
- name: str,
- file: str,
- parent: str = "__WORLD__",
- pos: Optional[List[float]] = None,
- quat: Optional[List[float]] = None,
- euler: Optional[List[float]] = None,
- fixed_base: bool = False,
- root_joint_name: Optional[str] = None,
- *,
- qpos: Optional[List[float]] = None,
- ) -> "MJCFBuilder":
- """
- Wrap an in a named body so you can position the whole robot and
- (optionally) give it a free root joint to control base pose via keyframes.
-
- fixed_base:
- If True, the robot's base is fixed to the world (no root joint).
- If False, a free joint named ``root_joint_name`` (or ``f"{name}_root"``) is added.
-
- qpos:
- Flat list of the robot's *internal* generalized coordinates (excluding any free base).
- Stored as a single packed block for make_spawn_keyframe().
- """
-
- # Ensure internal bookkeeping exists
- if not hasattr(self, "_joint_order"):
- self._joint_order = []
- if not hasattr(self, "_joint_defaults"):
- self._joint_defaults = {}
-
- # 1) Create a wrapper body for the robot so we can position the entire model
- self.add_body(name, parent=parent, pos=pos, quat=quat, euler=euler)
-
- # 2) Optionally give the wrapper body a free joint to control base pose
- if not fixed_base:
- jname = root_joint_name if root_joint_name is not None else f"{name}_root"
-
- # Be defensive about builder signatures to avoid 'name' collisions.
- # Preferred: add_joint(body=..., ...attrs)
- try:
- self.add_joint(body=name, type="free", name=jname)
- except TypeError:
- # Fallback: some builders use (parent_or_body_name, **attrs)
- self.add_joint(name, **{"type": "free", "name": jname})
-
- # 3) Merge the referenced MJCF file into this model
- tree = ET.parse(file)
- root = tree.getroot()
-
- # Merge compiler attributes (e.g., meshdir) relative to the robot file
- comp = root.find("compiler")
- if comp is not None:
- attrs = dict(comp.attrib)
- if "meshdir" in attrs:
- # Normalize to the directory of the robot file to avoid double prefixes
- robot_dir = os.path.dirname(os.path.abspath(file))
- attrs["meshdir"] = robot_dir
- # Ensure self.compiler exists (your builder likely created it at init)
- _attrs(self.compiler, **attrs)
-
- # Merge assets, defaults, tendons, equalities, actuators, contacts
- for sec_name, target in [
- ("asset", self.asset),
- ("default", self.default),
- ("tendon", self.tendon),
- ("equality", self.equality),
- ("actuator", self.actuator),
- ("contact", self.contact),
- ]:
- src = root.find(sec_name)
- if src is not None:
- for child in src:
- target.append(copy.deepcopy(child))
-
- # Insert robot bodies under the wrapper body
- wb = root.find("worldbody")
- if wb is not None:
- wrapper = self._bodies[name]
- for body in wb:
- wrapper.append(copy.deepcopy(body))
-
- # 4) Record this robot's internal qpos as a single packed block
- if qpos is not None:
- block_name = f"{name}/*" # synthetic ID for packed internal qpos
- block_vals = [float(x) for x in qpos]
- self._joint_order.append(
- {
- "name": block_name,
- "body": name,
- "type": "packed",
- "size": len(block_vals),
- }
- )
- self._joint_defaults[block_name] = block_vals
-
- return self
-
- # ---------- Keyframes ----------
- def add_keyframe(
- self,
- name: str,
- qpos: Optional[List[float]] = None,
- qvel: Optional[List[float]] = None,
- ctrl: Optional[List[float]] = None,
- act: Optional[List[float]] = None,
- mocap_pos: Optional[List[float]] = None,
- mocap_quat: Optional[List[float]] = None,
- ) -> "MJCFBuilder":
- """
- Adds a single inside .
- Lists are serialized as space-separated strings, e.g. qpos="0 0 0 1 0 0 0 ..."
- """
- k = ET.SubElement(self.keyframe, "key")
- _attrs(
- k,
- name=name,
- qpos=qpos,
- qvel=qvel,
- ctrl=ctrl,
- act=act,
- mocap_pos=mocap_pos,
- mocap_quat=mocap_quat,
- )
- return self
-
- def add_keyframes(self, keys: List[Dict]) -> "MJCFBuilder":
- """Bulk add keyframes: each dict can include any fields accepted by add_keyframe()."""
- for k in keys:
- self.add_keyframe(**k)
- return self
-
- def joint_order(self) -> List[str]:
- """Returns the joint names in qpos order (useful for building qpos vectors)."""
- return [j["name"] for j in self._joint_order]
-
- def make_spawn_keyframe(
- self,
- name: str = "spawn",
- joint_defaults: Optional[Dict[str, Union[List[float], float]]] = None,
- ) -> "MJCFBuilder":
- """
- Assemble a single keyframe named `name` that sets:
- - free joints to the current body pose (pos + quat)
- - other joints to 0, unless specified in defaults
- 'joint_defaults' merges over any defaults previously provided via include_robot(..., joint_qpos=...).
-
- joint_defaults can map joint_name -> scalar (1-dof) or list (4/7-dof).
- """
- # Merge global defaults with per-call overrides
- merged_defaults: Dict[str, List[float]] = dict(self._joint_defaults)
- if joint_defaults:
- for k, v in joint_defaults.items():
- if isinstance(v, (int, float)):
- merged_defaults[k] = [float(v)]
- else:
- merged_defaults[k] = [float(x) for x in v]
-
- qpos: List[float] = []
-
- for j in self._joint_order:
- jn, jb, jt, sz = j["name"], j["body"], j["type"], j["size"]
-
- # If user provided explicit values for this joint
- if jn in merged_defaults:
- val = merged_defaults[jn]
- # allow scalar for 1-dof
- if len(val) == 1 and sz > 1:
- raise ValueError(
- f"Default for joint '{jn}' must have length {sz}, got 1"
- )
- if len(val) != sz:
- raise ValueError(
- f"joint_defaults[{jn}] length {len(val)} != expected {sz}"
- )
- qpos.extend(val)
- continue
-
- # Otherwise derive a sensible default
- if jt == "free":
- pose = self._body_pose.get(jb, {"pos": [0, 0, 0], "quat": [1, 0, 0, 0]})
- p = pose["pos"]
- q = pose["quat"] # [w, x, y, z]
- qpos.extend([p[0], p[1], p[2], q[0], q[1], q[2], q[3]])
- elif jt == "ball":
- qpos.extend([1.0, 0.0, 0.0, 0.0]) # identity quaternion
- else:
- qpos.append(0.0) # revolute/slide default
-
- self.add_keyframe(name=name, qpos=qpos)
- return self
-
- # ---------- Serialization ----------
- def to_string(self, pretty: bool = True) -> str:
- """Return the MJCF XML as a string."""
-
- def indent(elem, level=0):
- i = "\n" + level * " "
- if len(elem):
- if not elem.text or not elem.text.strip():
- elem.text = i + " "
- for e in elem:
- indent(e, level + 1)
- if not e.tail or not e.tail.strip():
- e.tail = i
- if level and (not elem.tail or not elem.tail.strip()):
- elem.tail = i
-
- if pretty:
- indent(self.root)
- return ET.tostring(self.root, encoding="unicode")
-
- def update_on_load(self, model):
- for robots in self.robots:
- root_joint_id = mujoco.mj_name2id(
- model, mujoco.mjtObj.mjOBJ_JOINT, "panda_root"
- )
- pass
-
-
-# --------------------------- Example usage ---------------------------
-
-if __name__ == "__main__":
- # Minimal example demonstrating include + spawn keyframe with base pose + joint config
- builder = (
- MJCFBuilder("demo_world")
- .set_compiler(angle="radian", meshdir="franka_emika_panda")
- .set_option(timestep="0.002")
- )
-
- # Ground
- builder.add_body("floor", pos=[0, 0, 0]).add_geom(
- "floor", type="plane", size=[10, 10, 0.1], rgba=[0.8, 0.8, 0.8, 1]
- )
-
- # Include a Franka Panda, wrapped in a poseable body with a free root joint
- # Provide its internal joint list in the correct qpos order and initial joint angles.
- builder.include_robot(
- name="panda",
- file="franka_emika_panda/panda.xml",
- pos=[0.3, -0.2, 0.0], # initial XYZ
- euler=[0, 0, 0], # initial orientation (XYZ Euler)
- fixed_base=False, # add a free joint controlling base pose
- root_joint_name="panda_root",
- internal_joints=[
- # order MUST match the MuJoCo qpos order inside the included file
- "joint1",
- "joint2",
- "joint3",
- "joint4",
- "joint5",
- "joint6",
- "joint7",
- "finger_joint1",
- "finger_joint2",
- ],
- joint_qpos={
- # radians for hinges; these will be baked into the 'spawn' keyframe
- "joint1": 0.0,
- "joint2": -0.6,
- "joint3": 0.0,
- "joint4": -2.2,
- "joint5": 0.0,
- "joint6": 1.6,
- "joint7": 0.8,
- "finger_joint1": 0.04,
- "finger_joint2": 0.04,
- },
- )
-
- # A free object with its own root free joint
- builder.load_object(
- name="cube",
- shape="box",
- size=[0.05, 0.05, 0.05],
- pos=[0.6, 0.0, 0.15],
- quat=[1, 0, 0, 0],
- rgba=[0.2, 0.6, 1.0, 1.0],
- free=True,
- )
-
- # Create a consolidated keyframe capturing initial positions & robot joint angles
- builder.make_spawn_keyframe(name="spawn")
-
- xml_text = builder.to_string(pretty=True)
- print(xml_text)
diff --git a/ark/system/mujoco/mujoco_backend.py b/ark/system/mujoco/mujoco_backend.py
deleted file mode 100644
index 7b1cf35..0000000
--- a/ark/system/mujoco/mujoco_backend.py
+++ /dev/null
@@ -1,282 +0,0 @@
-import importlib.util
-import sys, ast, os
-from pathlib import Path
-from typing import Any, Optional
-
-import mujoco
-import mujoco.viewer
-
-from ark.system.mujoco.mjcf_builder import MJCFBuilder
-
-from ark.tools.log import log
-from ark.system.simulation.simulator_backend import SimulatorBackend
-from ark.system.mujoco.mujoco_multibody import MujocoMultiBody
-from arktypes import *
-
-import textwrap
-
-
-def import_class_from_directory(path: Path) -> tuple[type, Optional[type]]:
- """!Load a class from ``path``.
-
- The helper searches for ``.py`` inside ``path`` and imports the
- class with the same name. If a ``Drivers`` class is present in the module
- its ``MUJOCO_DRIVER`` attribute is returned alongside the main class.
-
- @param path Path to the directory containing the module.
- @return Tuple ``(cls, driver_cls)`` where ``driver_cls`` is ``None`` when no
- driver is defined.
- @rtype Tuple[type, Optional[type]]
- """
- # Extract the class name from the last part of the directory path (last directory name)
- class_name = path.name
- file_path = path / f"{class_name}.py"
-
- # get the full absolute path
- file_path = file_path.resolve()
- if not file_path.exists():
- raise FileNotFoundError(f"The file {file_path} does not exist.")
-
- with open(file_path, "r", encoding="utf-8") as file:
- tree = ast.parse(file.read(), filename=file_path)
-
- # for imports
- module_dir = os.path.dirname(file_path)
- sys.path.insert(0, module_dir)
-
- # Extract class names from the AST
- class_names = [
- node.name for node in ast.walk(tree) if isinstance(node, ast.ClassDef)
- ]
-
- # check if Sensor_Drivers is in the class_names
- if "Drivers" in class_names:
- # Load the module dynamically
- spec = importlib.util.spec_from_file_location(class_names[0], file_path)
- module = importlib.util.module_from_spec(spec)
- sys.modules[class_names[0]] = module
- spec.loader.exec_module(module)
-
- class_ = getattr(module, class_names[0])
- sys.path.pop(0)
-
- drivers = class_.MUJOCO_DRIVER
- class_names.remove("Drivers")
-
- # Retrieve the class from the module (has to be list of one)
- class_ = getattr(module, class_names[0])
-
- if len(class_names) != 1:
- raise ValueError(
- f"Expected exactly two class definition in {file_path}, but found {len(class_names)}."
- )
-
- # Load the module dynamically
- spec = importlib.util.spec_from_file_location(class_name, file_path)
- module = importlib.util.module_from_spec(spec)
- sys.modules[class_name] = module
- spec.loader.exec_module(module)
-
- # Retrieve the class from the module (has to be list of one)
- class_ = getattr(module, class_names[0])
- sys.path.pop(0)
-
- # Return the class
- return class_, drivers
-
-
-class MujocoBackend(SimulatorBackend):
-
- def initialize(self) -> None:
- """!Initialize the MuJoCo simulation backend."""
- self.builder = MJCFBuilder("ARK Mujoco").set_compiler(
- angle="radian", meshdir="ark_mujoco_assets"
- )
-
- gravity = self.global_config["simulator"]["config"].get(
- "gravity", [0, 0, -9.81]
- )
- self.set_gravity(gravity)
-
- if self.global_config.get("objects", None):
- for object_name, object_config in self.global_config["objects"].items():
- self.add_sim_component(object_name, object_config)
-
- if self.global_config.get("robots", None):
- for robot_name, robot_config in self.global_config["robots"].items():
- self.add_robot(robot_name, robot_config)
-
- if self.global_config.get("sensors", None):
- for sensor_name, sensor_config in self.global_config["sensors"].items():
- self.add_sensor(sensor_name, sensor_config)
-
- self.builder.make_spawn_keyframe(name="spawn")
- xml_string = self.builder.to_string(pretty=True)
-
- self.model = mujoco.MjModel.from_xml_string(xml_string)
- self.data = mujoco.MjData(self.model)
-
- self.camera_id = mujoco.mj_name2id(
- self.model, mujoco.mjtObj.mjOBJ_CAMERA, "overview"
- )
- self.renderer = mujoco.Renderer(self.model, 100, 100)
-
- if (
- self.global_config["simulator"]["config"]["connection_mode"].upper()
- == "GUI"
- ):
- self.headless = False
- self.viewer = mujoco.viewer.launch_passive(
- self.model, self.data, show_left_ui=False, show_right_ui=False
- )
- else:
- self.headless = True
-
- for object_name in self.object_ref:
- self.object_ref[object_name].update_ids(self.model, self.data)
-
- for sensor_name in self.sensor_ref:
- self.sensor_ref[sensor_name]._driver.update_ids(self.model, self.data)
-
- for robot_name in self.robot_ref:
- self.robot_ref[robot_name]._driver.update_ids(self.model, self.data)
-
- self.timestep = 1 / self.global_config["simulator"]["config"].get(
- "sim_frequency", 500.0
- )
-
- def set_gravity(self, gravity: tuple[float, float, float]) -> None:
- """!Set the gravity vector for the simulation.
-
- @param gravity Gravity vector ``(x, y, z)`` in ``m/s^2``.
- @return ``None``
- """
- self.builder.set_option(gravity=gravity)
-
- def reset_simulator(self) -> None:
- """!Reset the simulator to the initial keyframe."""
- key_id = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_KEY, "spawn")
- if key_id < 0:
- raise ValueError("Keyframe 'spawn' not found")
-
- mujoco.mj_resetDataKeyframe(self.model, self.data, key_id)
-
- self.data.qvel[:] = 0.0
- self.data.qacc[:] = 0.0
- if self.model.nu > 0:
- self.data.ctrl[:] = 0.0
- if self.model.na > 0:
- self.data.act[:] = 0.0
-
- mujoco.mj_forward(self.model, self.data)
-
- def add_robot(
- self,
- name: str,
- robot_config: dict[str, Any],
- ) -> None:
- """!Add a robot to the simulation.
-
- @param name Name of the robot.
- @param robot_config Configuration dictionary for the robot.
- @return ``None``
- """
- class_path = Path(robot_config["class_dir"])
- if class_path.is_file():
- class_path = class_path.parent
- RobotClass, DriverClass = import_class_from_directory(class_path)
- DriverClass = DriverClass.value
-
- driver = DriverClass(name, component_config=robot_config, builder=self.builder)
- robot = RobotClass(
- name=name,
- driver=driver,
- global_config=self.global_config,
- )
- self.robot_ref[name] = robot
-
- def add_sensor(
- self,
- name: str,
- sensor_config: dict[str, Any],
- ) -> None:
- """!Add a sensor to the simulation.
-
- @param name Name of the sensor.
- @param sensor_config Configuration dictionary for the sensor.
- @return ``None``
- """
- class_path = Path(sensor_config["class_dir"])
- if class_path.is_file():
- class_path = class_path.parent
-
- SensorClass, DriverClass = import_class_from_directory(class_path)
- DriverClass = DriverClass.value
-
- attached_body_id = None
- if sensor_config["sim_config"].get("attach", None):
- raise NotImplementedError(
- "Attaching sensors to bodies is not implemented for MuJoCo yet."
- )
-
- driver = DriverClass(
- name, sensor_config, attached_body_id, builder=self.builder
- )
- sensor = SensorClass(
- name=name,
- driver=driver,
- global_config=self.global_config,
- )
- self.sensor_ref[name] = sensor
-
- def add_sim_component(
- self,
- name: str,
- obj_config: dict[str, Any],
- ) -> None:
- """!Register a static simulation component.
-
- @param name Name of the component.
- @param obj_config Configuration dictionary for the component.
- @return ``None``
- """
- sim_component = MujocoMultiBody(
- name=name, builder=self.builder, global_config=self.global_config
- )
- self.object_ref[name] = sim_component
-
- def _all_available(self) -> bool:
- """!Check whether all registered components are active.
-
- @return ``True`` if no component is suspended.
- @rtype bool
- """
- for robot_name in self.robot_ref:
- if self.robot_ref[robot_name]._is_suspended:
- return False
- for object_name in self.object_ref:
- if self.object_ref[object_name]._is_suspended:
- return False
- return True
-
- def remove(self, name: str) -> None:
- """!Remove a component from the simulation."""
- raise NotImplementedError("Mujoco does not support removing objects once loaded into XML.")
-
- def step(self) -> None:
- """!Step the simulator forward by one time step."""
- if self._all_available():
- self._step_sim_components()
- mujoco.mj_step(self.model, self.data)
-
- if not self.headless:
- self.viewer.sync()
-
- self._simulation_time = self.data.time
- else:
- log.panda("Did not step")
-
- def shutdown_backend(self) -> None:
- """!Shut down the simulation backend."""
- if not self.headless:
- self.viewer.close()
diff --git a/ark/system/mujoco/mujoco_camera_driver.py b/ark/system/mujoco/mujoco_camera_driver.py
deleted file mode 100644
index e771465..0000000
--- a/ark/system/mujoco/mujoco_camera_driver.py
+++ /dev/null
@@ -1,101 +0,0 @@
-from enum import Enum
-from typing import Any, Dict, List, Optional
-
-import mujoco
-import numpy as np
-from scipy.spatial.transform import Rotation as R
-
-from ark.system.driver.sensor_driver import CameraDriver
-
-class CameraType(Enum):
- """Supported camera models."""
-
- FIXED = "fixed"
- ATTACHED = "attached"
-
-
-class MujocoCameraDriver(CameraDriver):
- """Camera driver implementation for MuJoCo."""
-
- def __init__(
- self,
- component_name: str,
- component_config: Dict[str, Any],
- attached_body_id: Optional[int] = None,
- builder: Any | None = None,
- ) -> None:
- """!Create a new camera driver.
-
- @param component_name Name of the camera component.
- @param component_config Configuration dictionary for the camera.
- @param attached_body_id ID of the body to attach the camera to.
- @param builder Optional MJCF builder instance.
- @return ``None``
- """
- super().__init__(component_name, component_config, True)
-
- self.name = component_name
- self.parent = "__WORLD__" # All cameras are fixed to the world
- self.width = component_config.get("width", 100)
- self.height = component_config.get("height", 100)
-
- sim_config = component_config.get("sim_config", {})
- self.fov = sim_config.get("fov", 45) # Default field of view
-
- if "quaternion" in sim_config:
- quaternion = sim_config.get("quaternion")
- self.quaternion = [
- quaternion[3],
- quaternion[0],
- quaternion[1],
- quaternion[2],
- ]
- else:
- self.quaternion = [1.0, 0.0, 0.0, 0.0]
-
- self.position = sim_config.get("position", [0.0, 0.0, 1.5])
-
- if builder is not None:
- builder.load_camera(
- name=self.name,
- parent=self.parent,
- pos=self.position,
- quat=self.quaternion,
- fov=self.fov,
- )
-
- def update_ids(self, model: mujoco.MjModel, data: mujoco.MjData) -> None:
- """!Update IDs and create a renderer.
-
- @param model MuJoCo model instance.
- @param data MuJoCo data instance.
- @return ``None``
- """
- self.model = model
- self.data = data
- self.camera_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_CAMERA, self.name)
- self.renderer = mujoco.Renderer(self.model, self.width, self.height)
-
- def get_xml_config(self) -> tuple[str, str, Optional[str]]:
- """!Return the XML configuration snippet for this camera."""
- return self.xml_config
-
- def get_images(self) -> Dict[str, np.ndarray]:
- """!Capture the current color and depth images.
-
- @return Dictionary containing ``color`` and ``depth`` arrays.
- @rtype Dict[str, np.ndarray]
- """
- self.renderer.update_scene(self.data, camera=self.camera_id)
- rgb_image = self.renderer.render()
-
- # Flip the RGB image (MuJoCo uses bottom-left as the origin)
- rgb_image = np.flipud(rgb_image.copy())
- return {
- "color": rgb_image,
- "depth": np.zeros(rgb_image.shape[:2], dtype=np.float32),
- }
-
- def shutdown_driver(self) -> None:
- """!Shut down the camera driver."""
- super().shutdown_driver()
diff --git a/ark/system/mujoco/mujoco_multibody.py b/ark/system/mujoco/mujoco_multibody.py
deleted file mode 100644
index 507f9c6..0000000
--- a/ark/system/mujoco/mujoco_multibody.py
+++ /dev/null
@@ -1,159 +0,0 @@
-from enum import Enum
-from typing import Any, Optional
-
-import mujoco
-
-from ark.tools.log import log
-from ark.system.component.sim_component import SimComponent
-from arktypes import rigid_body_state_t
-
-
-class SourceType(Enum):
- """Supported source types for object creation."""
-
- URDF = "urdf"
- PRIMITIVE = "primitive"
- SDF = "sdf"
- MJCF = "mjcf"
-
-
-SHAPE_MAP = {
- "GEOM_BOX": "box",
- "GEOM_SPHERE": "sphere",
- "GEOM_CAPSULE": "capsule",
- "GEOM_CYLINDER": "cylinder",
- # add more if you need
-}
-
-
-class MujocoMultiBody(SimComponent):
- """MuJoCo multi-body simulation component."""
-
- def __init__(
- self,
- name: str,
- builder: Any,
- global_config: dict[str, Any] | None = None,
- ) -> None:
- """!Initialize a multi-body object.
-
- @param name Name of the component.
- @param builder MJCF builder used to generate the object.
- @param global_config Global configuration dictionary.
- @return ``None``
- """
- super().__init__(name, global_config)
-
- source_str = self.config["source"]
- source_type = getattr(SourceType, source_str.upper())
-
- if source_type == SourceType.URDF:
- raise NotImplementedError(
- "Loading from URDF is not implemented for MujocoMultiBody."
- )
- elif source_type == SourceType.PRIMITIVE:
- visual_config = self.config.get("visual")
- if visual_config:
- visual_shape_type = SHAPE_MAP[visual_config["shape_type"].upper()]
-
- visual_shape = visual_config["visual_shape"]
-
- if visual_shape_type == "box":
- extents_size = [s * 1 for s in visual_shape["halfExtents"]]
-
- if visual_shape_type == "sphere":
- extents_size = [visual_shape["radius"]]
-
- rgba = visual_shape.get("rgbaColor", [1, 1, 1, 1])
- else:
- raise ValueError(
- "Visual configuration is required for primitive shapes."
- )
-
- collision_config = self.config.get("collision")
- if collision_config:
- log.warning(
- "Collision shapes are not supported in MujocoMultiBody yet, it is defaulted to visual sizes"
- )
-
- multibody_config = self.config["multi_body"]
- base_position = self.config["base_position"]
- base_orientation = self.config["base_orientation"]
- base_orientation = [
- base_orientation[1],
- base_orientation[2],
- base_orientation[3],
- base_orientation[0],
- ]
-
- if multibody_config["baseMass"] == 0:
- free = False
- mass = 0.001
- else:
- free = True
- mass = multibody_config["baseMass"]
-
- builder.load_object(
- name=name,
- shape=visual_shape_type,
- size=extents_size,
- pos=base_position,
- quat=base_orientation,
- rgba=rgba,
- free=free,
- mass=mass,
- )
-
- self.publisher_name = self.name + "/ground_truth/sim"
- if self.publish_ground_truth:
- self.state_publisher = self.component_channels_init(
- {self.publisher_name: rigid_body_state_t}
- )
-
- def update_ids(self, model: mujoco.MjModel, data: mujoco.MjData) -> None:
- """!Update internal identifiers from MuJoCo.
-
- @param model MuJoCo model instance.
- @param data MuJoCo data instance.
- @return ``None``
- """
- self.model = model
- self.data = data
- self.body_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_BODY, self.name)
-
- def get_xml_config(self) -> tuple[str, str, Optional[str]]:
- """!Return the XML configuration snippet for this object."""
- return self.xml_config
-
- def pack_data(self, data: dict[str, Any]) -> dict[str, Any]:
- """!Pack object data into the message format."""
- msg = rigid_body_state_t()
- msg.name = data["name"]
- msg.position = data["position"]
- msg.orientation = data["orientation"]
- msg.lin_velocity = data["lin_velocity"]
- msg.ang_velocity = data["ang_velocity"]
- return {self.publisher_name: msg}
-
- def get_object_data(self) -> Any:
- """!Retrieve the current state of the simulated object."""
- position = self.data.xpos[self.body_id]
- orientation = self.data.xquat[self.body_id]
- orientation = [orientation[1], orientation[2], orientation[3], orientation[0]]
-
- velocity = self.data.cvel[self.body_id]
- linear_velocity = velocity[:3]
- angular_velocity = velocity[3:]
- return {
- "name": self.name,
- "position": position.tolist(),
- "orientation": orientation,
- "lin_velocity": linear_velocity.tolist(),
- "ang_velocity": angular_velocity.tolist(),
- }
-
- def reset_component(self, channel: str, msg: Any) -> None:
- """!Reset the component (not implemented)."""
- raise NotImplementedError(
- "Resetting components is not implemented for MujocoMultiBody."
- )
diff --git a/ark/system/mujoco/mujoco_robot_driver.py b/ark/system/mujoco/mujoco_robot_driver.py
deleted file mode 100644
index 6ccb9cf..0000000
--- a/ark/system/mujoco/mujoco_robot_driver.py
+++ /dev/null
@@ -1,279 +0,0 @@
-from pathlib import Path
-from typing import Any
-
-import mujoco
-import numpy as np
-
-from ark.tools.log import log
-from ark.system.driver.robot_driver import SimRobotDriver
-
-
-class MujocoRobotDriver(SimRobotDriver):
- """Robot driver for MuJoCo simulations."""
-
- @staticmethod
- def body_subtree(model: mujoco.MjModel, root_body_id: int) -> list[int]:
- """Return IDs of bodies in the subtree rooted at ``root_body_id``."""
- descendants: list[int] = []
- stack = [root_body_id]
- visited = {root_body_id}
- while stack:
- current_id = stack.pop()
- descendants.append(current_id)
- for child in range(model.nbody):
- if model.body_parentid[child] == current_id and child not in visited:
- visited.add(child)
- stack.append(child)
- return descendants
-
- @classmethod
- def joints_for_body(
- cls, model: mujoco.MjModel, body_id: int
- ) -> list[tuple[int, str]]:
- """Return ``(joint_id, joint_name)`` for ``body_id`` and its descendants."""
- joint_ids: list[int] = []
- for current_id in cls.body_subtree(model, body_id):
- start = model.body_jntadr[current_id]
- count = model.body_jntnum[current_id]
- for offset in range(count):
- joint_ids.append(start + offset)
- return [
- (jid, mujoco.mj_id2name(model, mujoco.mjtObj.mjOBJ_JOINT, jid) or "")
- for jid in joint_ids
- ]
-
- @staticmethod
- def joint_qpos_slice(model: mujoco.MjModel, joint_id: int) -> slice:
- """Return slice in ``data.qpos`` corresponding to ``joint_id``."""
- address = model.jnt_qposadr[joint_id]
- joint_type = model.jnt_type[joint_id]
- if joint_type == mujoco.mjtJoint.mjJNT_FREE:
- width = 7
- elif joint_type == mujoco.mjtJoint.mjJNT_BALL:
- width = 4
- else:
- width = 1
- return slice(address, address + width)
-
- @staticmethod
- def actuators_for_joint(model: mujoco.MjModel, joint_id: int) -> list[int]:
- """Return actuator indices driving any degree of freedom of ``joint_id``."""
- actuator_indices: list[int] = []
- for actuator_index in range(model.nu):
- dof = model.actuator_trnid[actuator_index][0]
- if dof >= 0 and model.dof_jntid[dof] == joint_id:
- actuator_indices.append(actuator_index)
- return actuator_indices
-
- @staticmethod
- def _joint_widths(model: mujoco.MjModel, joint_index: int) -> tuple[int, int]:
- """Return ``( list[int]:
- """List joint IDs under the body subtree in ``qpos`` order."""
- joint_ids: list[int] = []
- for body_id in cls.body_subtree(model, root_body_id):
- start = model.body_jntadr[body_id]
- num = model.body_jntnum[body_id]
- for offset in range(num):
- joint_ids.append(start + offset)
- return sorted(joint_ids, key=lambda jid: model.jnt_qposadr[jid])
-
- @classmethod
- def get_robot_state(
- cls,
- model: mujoco.MjModel,
- data: mujoco.MjData,
- root_body_id: int,
- ):
- """Return joint positions, velocities and accelerations.
-
- Parameters
- ----------
- model
- MuJoCo model instance.
- data
- MuJoCo data instance.
- root_body_id
- Root body identifier for the robot.
- as_dict
- If ``True``, return a dictionary representation.
-
- Returns
- -------
- dict | tuple
- Either a dictionary with per-joint information or a tuple of
- concatenated ``(qpos, qvel, qacc)`` arrays.
- """
- per_joint: list[dict[str, Any]] = []
- qpos_chunks: list[np.ndarray] = []
- qvel_chunks: list[np.ndarray] = []
- qacc_chunks: list[np.ndarray] = []
-
- for joint_index in cls._joints_in_subtree(model, root_body_id):
- name = (
- mujoco.mj_id2name(model, mujoco.mjtObj.mjOBJ_JOINT, joint_index)
- or f"joint_{joint_index}"
- )
- qpos_w, qvel_w = cls._joint_widths(model, joint_index)
-
- qpos_addr = model.jnt_qposadr[joint_index]
- qpos = (
- data.qpos[qpos_addr : qpos_addr + qpos_w]
- if qpos_w > 1
- else np.array([data.qpos[qpos_addr]])
- )
-
- dof_addr = model.jnt_dofadr[joint_index]
- qvel = (
- data.qvel[dof_addr : dof_addr + qvel_w]
- if qvel_w > 1
- else np.array([data.qvel[dof_addr]])
- )
- qacc = (
- data.qacc[dof_addr : dof_addr + qvel_w]
- if qvel_w > 1
- else np.array([data.qacc[dof_addr]])
- )
-
- per_joint.append(
- {
- "id": joint_index,
- "name": name,
- "qpos": qpos.copy(),
- "qvel": qvel.copy(),
- "qacc": qacc.copy(),
- }
- )
- qpos_chunks.append(qpos)
- qvel_chunks.append(qvel)
- qacc_chunks.append(qacc)
-
- qpos_cat = np.concatenate(qpos_chunks) if qpos_chunks else np.array([])
- qvel_cat = np.concatenate(qvel_chunks) if qvel_chunks else np.array([])
- qacc_cat = np.concatenate(qacc_chunks) if qacc_chunks else np.array([])
-
- return {
- "per_joint": per_joint,
- "qpos": qpos_cat,
- "qvel": qvel_cat,
- "qacc": qacc_cat,
- }
-
- def __init__(
- self,
- component_name: str,
- component_config: dict[str, Any] | None = None,
- builder: Any | None = None,
- ) -> None:
- """Create a robot driver for MuJoCo."""
- super().__init__(component_name, component_config, True)
-
- self.name = component_name
-
- class_path = self.config.get("class_path")
- mjcf_path = self.config.get("mjcf_path")
- if mjcf_path:
- mjcf_path = (
- Path(class_path) / mjcf_path
- if class_path is not None
- else Path(mjcf_path)
- )
-
- if not mjcf_path.is_absolute():
- mjcf_path = Path(self.config["class_dir"]) / mjcf_path
-
- if not mjcf_path.exists():
- log.error(f"The URDF path '{mjcf_path}' does not exist.")
- log.error(f"Full path: {mjcf_path.resolve()}")
- return
-
- position = self.config.get("base_position", [0.0, 0.0, 0.0])
- orientation = self.config.get("base_orientation", [0.0, 0.0, 0.0, 1.0])
- orientation = [orientation[3], orientation[0], orientation[1], orientation[2]]
-
- fixed_base = self.config.get("use_fixed_base", False)
- root_joint_name = f"{self.name}_root"
- self.initial_positions = self.config.get("initial_configuration", None)
-
- if builder is not None:
- builder.include_robot(
- name=self.name,
- file=mjcf_path,
- pos=position,
- quat=orientation,
- fixed_base=fixed_base,
- root_joint_name=root_joint_name,
- qpos=self.initial_positions,
- )
-
- def update_ids(self, model: mujoco.MjModel, data: mujoco.MjData) -> None:
- """Update internal IDs from the MuJoCo model."""
- self.model = model
- self.data = data
-
- self.body_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_BODY, self.name)
- if self.body_id < 0:
- raise ValueError(f"Body '{self.name}' not found in model.")
-
- self.actuated_joints = {
- f"joint{i}": mujoco.mj_name2id(
- model, mujoco.mjtObj.mjOBJ_ACTUATOR, f"actuator{i}"
- )
- for i in range(1, 8)
- }
- self.gripper_id = mujoco.mj_name2id(
- model, mujoco.mjtObj.mjOBJ_ACTUATOR, "actuator8"
- )
-
- self.actuated_joints = {**self.actuated_joints, "gripper": self.gripper_id}
- for i, (_, actuator_id) in enumerate(self.actuated_joints.items()):
- data.ctrl[actuator_id] = self.initial_positions[i]
-
- data.ctrl[self.gripper_id] = 180
-
- def check_torque_status(self) -> bool:
- """Check the torque status of the robot."""
- raise NotImplementedError(
- "MujocoRobotDriver.check_torque_status is not implemented yet."
- )
-
- def pass_joint_efforts(self, joints: list[str]) -> dict[str, float]:
- """Retrieve joint efforts (not implemented)."""
- raise NotImplementedError(
- "MujocoRobotDriver.pass_joint_efforts is not implemented yet."
- )
-
- def pass_joint_group_control_cmd(
- self, control_mode: str, cmd: dict[str, float], **kwargs
- ) -> None:
- """Send a group control command to the robot."""
- for value, actuator_id in zip(cmd.values(), self.actuated_joints.values()):
- self.data.ctrl[actuator_id] = value
-
- def pass_joint_positions(self, positions: dict[str, float]) -> dict[str, float]:
- """Return current joint positions for all actuated joints."""
- state = self.get_robot_state(self.model, self.data, self.body_id)
- positions_dict: dict[str, float] = {}
- for i, joint in enumerate(self.actuated_joints):
- positions_dict[joint] = state["qpos"][i]
- return positions_dict
-
- def pass_joint_velocities(self, joints: list[str]) -> dict[str, float]:
- """Return joint velocities (not implemented)."""
- raise NotImplementedError(
- "MujocoRobotDriver.pass_joint_velocities is not implemented yet."
- )
-
- def sim_reset(
- self, base_pos: list[float], base_orn: list[float], init_pos: list[float]
- ) -> None:
- """Reset the robot simulation (not implemented)."""
- raise NotImplementedError("MujocoRobotDriver.sim_reset is not implemented yet.")
diff --git a/ark/system/pybullet/pybullet_backend.py b/ark/system/pybullet/pybullet_backend.py
deleted file mode 100644
index fa6e030..0000000
--- a/ark/system/pybullet/pybullet_backend.py
+++ /dev/null
@@ -1,461 +0,0 @@
-"""@file pybullet_backend.py
-@brief Backend implementation for running simulations in PyBullet.
-"""
-
-import importlib.util
-import sys, ast, os
-import math
-import cv2
-from pathlib import Path
-from typing import Any, Optional, Dict
-
-import pybullet as p
-import pybullet_data
-from pybullet_utils.bullet_client import BulletClient
-
-from ark.tools.log import log
-from ark.system.simulation.simulator_backend import SimulatorBackend
-from ark.system.pybullet.pybullet_robot_driver import BulletRobotDriver
-from ark.system.pybullet.pybullet_camera_driver import BulletCameraDriver
-from ark.system.pybullet.pybullet_multibody import PyBulletMultiBody
-from arktypes import *
-
-
-def import_class_from_directory(path: Path) -> tuple[type, Optional[type]]:
- """!Load a class from ``path``.
-
- The helper searches for ``.py`` inside ``path`` and imports the
- class with the same name. If a ``Drivers`` class is present in the module
- its ``PYBULLET_DRIVER`` attribute is returned alongside the main class.
-
- @param path Path to the directory containing the module.
- @return Tuple ``(cls, driver_cls)`` where ``driver_cls`` is ``None`` when no
- driver is defined.
- @rtype Tuple[type, Optional[type]]
- """
- # Extract the class name from the last part of the directory path (last directory name)
- class_name = path.name
- file_path = path / f"{class_name}.py"
- # get the full absolute path
- file_path = file_path.resolve()
- if not file_path.exists():
- raise FileNotFoundError(f"The file {file_path} does not exist.")
-
- with open(file_path, "r", encoding="utf-8") as file:
- tree = ast.parse(file.read(), filename=file_path)
- # for imports
- module_dir = os.path.dirname(file_path)
- sys.path.insert(0, module_dir)
- # Extract class names from the AST
- class_names = [
- node.name for node in ast.walk(tree) if isinstance(node, ast.ClassDef)
- ]
- # check if Sensor_Drivers is in the class_names
- if "Drivers" in class_names:
- # Load the module dynamically
- spec = importlib.util.spec_from_file_location(class_names[0], file_path)
- module = importlib.util.module_from_spec(spec)
- sys.modules[class_names[0]] = module
- spec.loader.exec_module(module)
-
- class_ = getattr(module, class_names[0])
- sys.path.pop(0)
-
- drivers = class_.PYBULLET_DRIVER
- class_names.remove("Drivers")
-
- # Retrieve the class from the module (has to be list of one)
- class_ = getattr(module, class_names[0])
-
- if len(class_names) != 1:
- raise ValueError(
- f"Expected exactly two class definition in {file_path}, but found {len(class_names)}."
- )
-
- # Load the module dynamically
- spec = importlib.util.spec_from_file_location(class_name, file_path)
- module = importlib.util.module_from_spec(spec)
- sys.modules[class_name] = module
- spec.loader.exec_module(module)
-
- # Retrieve the class from the module (has to be list of one)
- class_ = getattr(module, class_names[0])
- sys.path.pop(0)
-
- # Return the class
- return class_, drivers
-
-
-class PyBulletBackend(SimulatorBackend):
- """Backend wrapper around the PyBullet client.
-
- This class handles scene creation, stepping the simulation and managing
- simulated components such as robots, objects and sensors.
- """
-
- def initialize(self) -> None:
- """!Initialize the PyBullet world.
-
- The method creates the Bullet client, configures gravity and time step
- and loads all robots, objects and sensors defined in
- ``self.global_config``. Optional frame capture settings are applied as
- well.
- """
- self.ready = False
- self.client = self._connect_pybullet(self.global_config)
- self.client.setAdditionalSearchPath(pybullet_data.getDataPath())
-
- # Render images from Pybullet and save
- self.save_render_config = self.global_config["simulator"].get(
- "save_render", None
- )
- if self.save_render_config is not None:
- self._rendered_time = -1.0
- self.save_path = Path(
- self.save_render_config.get("save_path", "output/save_render")
- )
- self.save_path.mkdir(parents=True, exist_ok=True)
-
- # Remove existing files
- remove_existing = self.save_render_config.get("remove_existing", True)
- if remove_existing:
- for child in self.save_path.iterdir():
- if child.is_file():
- child.unlink()
-
- # Get config
- default_extrinsics = {
- "look_at": [0, 0, 1.0],
- "distance": 3,
- "azimuth": 0,
- "elevation": 0,
- }
- default_intrinsics = {
- "width": 640,
- "height": 480,
- "field_of_view": 60,
- "near_plane": 0.1,
- "far_plane": 100.0,
- }
- self.save_interval = self.save_render_config.get("save_interval", 1 / 30)
- self.overwrite_file = self.save_render_config.get("overwrite_file", False)
- self.extrinsics = self.save_render_config.get(
- "extrinsics", default_extrinsics
- )
- self.intrinsics = self.save_render_config.get(
- "intrinsics", default_intrinsics
- )
-
- for additional_urdf_dir in self.global_config["simulator"]["config"].get(
- "urdf_dirs", []
- ):
- self.client.setAdditionalSearchPath(additional_urdf_dir)
-
- gravity = self.global_config["simulator"]["config"].get(
- "gravity", [0, 0, -9.81]
- )
- self.set_gravity(gravity)
-
- timestep = 1 / self.global_config["simulator"]["config"].get(
- "sim_frequency", 240.0
- )
- self.set_time_step(timestep)
-
- # Setup robots
- if self.global_config.get("robots", None):
- for robot_name, robot_config in self.global_config["robots"].items():
- self.add_robot(robot_name, robot_config)
-
- # Setup objects
- if self.global_config.get("objects", None):
- for obj_name, obj_config in self.global_config["objects"].items():
- self.add_sim_component(obj_name, obj_config)
-
- # Sensors have to be set up last, as e.g. cameras might need
- # a parent to attach to
- if self.global_config.get("sensors", None):
- for sensor_name, sensor_config in self.global_config["sensors"].items():
- self.add_sensor(sensor_name, sensor_config)
- self.ready = True
-
- def is_ready(self) -> bool:
- """!Check whether the backend has finished initialization.
-
- @return ``True`` once all components were created and the simulator is
- ready for stepping.
- @rtype bool
- """
- return self.ready
-
- def _connect_pybullet(self, config: dict[str, Any]):
- """!Create and return the Bullet client.
-
- ``config`` must contain the ``connection_mode`` under the ``simulator``
- section. Optionally ``mp4`` can be provided to enable video
- recording.
-
- @param config Global configuration dictionary.
- @return Initialized :class:`BulletClient` instance.
- @rtype BulletClient
- """
- kwargs = dict(options="")
- mp4 = config.get("mp4")
- if mp4:
- kwargs["options"] = f"--mp4={mp4}"
- connection_mode_str = config["simulator"]["config"]["connection_mode"].upper()
- connection_mode = getattr(p, connection_mode_str)
- return BulletClient(connection_mode, **kwargs)
-
- def set_gravity(self, gravity: tuple[float]) -> None:
- """!Set the world gravity.
-
- @param gravity Tuple ``(gx, gy, gz)`` specifying gravity in m/s^2.
- """
- self.client.setGravity(gravity[0], gravity[1], gravity[2])
-
- def set_time_step(self, time_step: float) -> None:
- """!Set the simulation timestep.
-
- @param time_step Length of a single simulation step in seconds.
- """
- self.client.setTimeStep(time_step)
- self._time_step = time_step
-
- ##########################################################
- #### ROBOTS, SENSORS AND OBJECTS ####
- ##########################################################
-
- def add_robot(self, name: str, robot_config: Dict[str, Any]):
- """!Instantiate and register a robot in the simulation.
-
- @param name Identifier for the robot.
- @param robot_config Robot specific configuration dictionary.
- """
- class_path = Path(robot_config["class_dir"])
- if class_path.is_file():
- class_path = class_path.parent
- RobotClass, DriverClass = import_class_from_directory(class_path)
- DriverClass = DriverClass.value
- driver = DriverClass(name, robot_config, self.client)
- robot = RobotClass(name=name, global_config=self.global_config, driver=driver)
-
- self.robot_ref[name] = robot
-
- def add_sim_component(
- self,
- name: str,
- obj_config: Dict[str, Any],
- ) -> None:
- """!Add a generic simulated object.
-
- @param name Name of the object.
- @param obj_config Object specific configuration dictionary.
- """
- sim_component = PyBulletMultiBody(
- name=name, client=self.client, global_config=self.global_config
- )
- self.object_ref[name] = sim_component
-
- def add_sensor(self, name: str, sensor_config: Dict[str, Any]) -> None:
- """!Instantiate and register a sensor.
-
- @param name Name of the sensor component.
- @param sensor_config Sensor configuration dictionary.
- """
- sensor_type = sensor_config["type"]
- class_path = Path(sensor_config["class_dir"])
- if class_path.is_file():
- class_path = class_path.parent
-
- SensorClass, DriverClass = import_class_from_directory(class_path)
- DriverClass = DriverClass.value
-
- attached_body_id = None
- if sensor_config["sim_config"].get("attach", None):
-
- print(self.global_config["objects"].keys())
- # search through robots and objects to find attach link if needed
- if (
- sensor_config["sim_config"]["attach"]["parent_name"]
- in self.global_config["robots"].keys()
- ):
- attached_body_id = self.robot_ref[
- sensor_config["sim_config"]["attach"]["parent_name"]
- ]._driver.ref_body_id
- elif (
- sensor_config["sim_config"]["attach"]["parent_name"]
- in self.global_config["objects"].keys()
- ):
- attached_body_id = self.object_ref[
- sensor_config["sim_config"]["attach"]["parent_name"]
- ].ref_body_id
- else:
- log.error(f"Parent to attach sensor " + name + " to does not exist !")
- driver = DriverClass(name, sensor_config, attached_body_id, self.client)
- sensor = SensorClass(
- name=name,
- driver=driver,
- global_config=self.global_config,
- )
-
- self.sensor_ref[name] = sensor
-
- def remove(self, name: str) -> None:
- """!Remove a component from the simulator.
-
- @param name Name of the robot, object or sensor to remove.
- """
- if name in self.robot_ref:
- self.robot_ref[name].shutdown()
- del self.robot_ref[name]
- elif name in self.sensor_ref:
- self.sensor_ref[name].shutdown()
- del self.obsensor_refject_ref[name]
- elif name in self.object_ref:
- self.object_ref[name].shutdown()
- del self.object_ref[name]
- else:
- log.warning("Could not remove " + name + ", it does not exist.")
- return
- log.ok("Deleted " + name + " !")
-
- #######################################
- #### SIMULATION ####
- #######################################
-
- def _all_available(self):
- """!Check whether all registered components are active.
-
- @return ``True`` if no component is suspended.
- @rtype bool
- """
- for robot in self.robot_ref:
- if self.robot_ref[robot]._is_suspended:
- return False
- for obj in self.object_ref:
- if self.object_ref[obj]._is_suspended:
- return False
- return True
-
- def step(self) -> None:
- """!Advance the simulation by one timestep.
-
- The method updates all registered components, advances the physics
- engine and optionally saves renders when enabled.
- """
- if self._all_available():
- self._step_sim_components()
- self.client.stepSimulation()
- self._simulation_time += self._time_step
-
- if self.save_render_config is not None:
- if (self._simulation_time - self._rendered_time) > self.save_interval:
- self.save_render()
- self._rendered_time = self._simulation_time
-
- else:
- log.panda("Did not step")
- pass
-
- def save_render(self):
- """!Render the scene and write the image to disk.
-
- The image is saved either as ``render.png`` when overwriting or with the
- current simulation time as filename when not.
- """
- # Calculate camera extrinsic matrix
- look_at = self.extrinsics["look_at"]
- azimuth = math.radians(self.extrinsics["azimuth"])
- distance = self.extrinsics["distance"]
-
- x = look_at[0] + distance * math.cos(azimuth)
- y = look_at[1] + distance * math.sin(azimuth)
- z = look_at[2] + self.extrinsics["elevation"]
-
- view_matrix = p.computeViewMatrix(
- cameraEyePosition=[x, y, z],
- cameraTargetPosition=look_at,
- cameraUpVector=[0, 0, 1],
- )
-
- # Calculate intrinsic matrix
- width = self.intrinsics["width"]
- height = self.intrinsics["height"]
- aspect = width / height
- projection_matrix = p.computeProjectionMatrixFOV(
- fov=self.intrinsics["field_of_view"],
- aspect=aspect,
- nearVal=self.intrinsics["near_plane"],
- farVal=self.intrinsics["far_plane"],
- )
-
- # Render the image
- image = p.getCameraImage(
- width, height, viewMatrix=view_matrix, projectionMatrix=projection_matrix
- )
- print(
- f"width:{width}, height:{height}, viewMatrix:{view_matrix}, projectionMatrix:{projection_matrix}"
- )
- print("DEBUG")
- print(f"DEBUG1: {type(image)}")
-
- # image[2] contains the color image (RGBA) as a numpy array
- rgba = image[2]
-
- # Save image
- bgra = cv2.cvtColor(rgba, cv2.COLOR_RGB2BGR)
- time_us = int(1e6 * self._simulation_time)
-
- if self.overwrite_file:
- save_path = self.save_path / "render.png"
- else:
- save_path = self.save_path / f"{time_us}.png"
- cv2.imwrite(str(save_path), bgra)
-
- def reset_simulator(self) -> None:
- """!Reset the entire simulator state.
-
- All robots, objects and sensors are destroyed and the backend is
- re-initialized using ``self.global_config``.
- """
- log.error("Reset Simulator function is not ready yet !")
- for robot in self.robot_ref:
- self.robot_ref[robot].kill_node()
-
- for obj in self.object_ref:
- self.object_ref[obj].kill_node()
-
- for sensor in self.sensor_ref:
- self.sensor_ref[sensor].kill_node()
-
- self.client.disconnect()
- self._simulation_time = 0.0
- self.initialize()
-
- if self.save_render_config is not None:
- self._rendered_time = -1.0
-
- log.ok("Simulator reset complete.")
-
- def get_current_time(self) -> float:
- """!Return the current simulation time.
-
- @return Elapsed simulation time in seconds.
- @rtype float
- """
- # https://pybullet.org/Bullet/phpBB3/viewtopic.php?t=12438
- return self._simulation_time
-
- def shutdown_backend(self):
- """!Disconnect all components and shut down the backend.
-
- This should be called at program termination to cleanly close the
- simulator and free all resources.
- """
- self.client.disconnect()
- for robot in self.robot_ref:
- self.robot_ref[robot].kill_node()
- for obj in self.object_ref:
- self.object_ref[obj].kill_node()
- for sensor in self.sensor_ref:
- self.sensor_ref[sensor].kill_node()
diff --git a/ark/system/pybullet/pybullet_camera_driver.py b/ark/system/pybullet/pybullet_camera_driver.py
deleted file mode 100644
index ddd3690..0000000
--- a/ark/system/pybullet/pybullet_camera_driver.py
+++ /dev/null
@@ -1,306 +0,0 @@
-"""@file pybullet_camera_driver.py
-@brief Camera driver for the PyBullet simulator.
-"""
-
-from abc import ABC, abstractmethod
-from enum import Enum
-from typing import Any, Optional, Dict, List
-
-from ark.tools.log import log
-from ark.system.driver.sensor_driver import CameraDriver
-
-import numpy as np
-import pybullet as p
-from scipy.spatial.transform import Rotation as R
-
-from ark.utils.camera_utils import CameraType
-
-
-def rotation_matrix_to_euler(R_world):
- """!Convert a rotation matrix to Euler angles.
-
- @param R_world ``3x3`` rotation matrix in row-major order.
- @return Euler angles ``[roll, pitch, yaw]`` in degrees.
- @rtype List[float]
- """
- r = R.from_matrix(R_world)
- euler_angles = r.as_euler("xyz", degrees=True)
- return euler_angles
-
-
-class BulletCameraDriver(CameraDriver):
- """Camera driver implementation for PyBullet."""
-
- def __init__(
- self,
- component_name: str,
- component_config: Dict[str, Any],
- attached_body_id: int = None,
- client: Any = None,
- ) -> None:
- """!Create a new camera driver.
-
- @param component_name Name of the camera component.
- @param component_config Configuration dictionary for the camera.
- @param attached_body_id ID of the body to attach the camera to.
- @param client Optional PyBullet client.
- @return ``None``
- """
- super().__init__(
- component_name, component_config, True
- ) # simulation is always True
- self.client = client
- self.attached_body_id = attached_body_id
-
- try:
- self.camera_type = CameraType(self.config["camera_type"])
- except ValueError:
- raise ValueError(f"Invalid camera type for {self.component_name} !")
-
- self.visual_body_id = None
- self.attached_body_id = attached_body_id
-
- self.visualize = self.config["sim_config"].get("visualize", False)
- self.urdf_path = self.config["sim_config"].get("urdf_path", None)
- self.fov = self.config["sim_config"].get("fov", 60)
- self.near_val = self.config["sim_config"].get("near_val", 0.1)
- self.far_val = self.config["sim_config"].get("far_val", 100.0)
-
- if self.camera_type == CameraType.FIXED:
- self.camera_target_position = self.config["sim_config"]["fix"][
- "camera_target_position"
- ]
- self.distance = self.config["sim_config"]["fix"]["distance"]
- self.yaw = self.config["sim_config"]["fix"]["yaw"]
- self.pitch = self.config["sim_config"]["fix"]["pitch"]
- self.roll = self.config["sim_config"]["fix"]["roll"]
- self.up_axis_index = self.config["sim_config"]["fix"]["up_axis_index"]
-
- view_matrix = self.client.computeViewMatrixFromYawPitchRoll(
- cameraTargetPosition=self.camera_target_position,
- distance=self.distance,
- yaw=self.yaw,
- pitch=self.pitch,
- roll=self.roll,
- upAxisIndex=self.up_axis_index,
- )
- # for visualization of the camera
- view_matrix_np = np.array(view_matrix).reshape(4, 4).T
- self.current_position = -view_matrix_np[:3, :3].T @ view_matrix_np[:3, 3]
- self.current_orientation = self.client.getQuaternionFromEuler(
- rotation_matrix_to_euler(view_matrix_np[:3, :3].T)
- )
-
- elif self.camera_type == CameraType.ATTACHED:
- # assert attached body exists
- assert self.attached_body_id is not None
-
- self.parent_name = self.config["sim_config"]["attach"]["parent_name"]
- self.parent_link = self.config["sim_config"]["attach"].get(
- "parent_link", None
- )
- self.offset_translation = self.config["sim_config"]["attach"].get(
- "offset_translation", [0, 0, 0]
- )
- self.offset_rotation = self.config["sim_config"]["attach"].get(
- "offset_rotation", [0, 0, 0]
- )
- self.rel_camera_target = self.config["sim_config"]["attach"].get(
- "rel_camera_target", [1, 0, 0]
- )
-
- # Get all link names and indices
- num_joints = p.getNumJoints(self.attached_body_id)
- self.link_info = {}
- for i in range(num_joints):
- joint_info = p.getJointInfo(self.attached_body_id, i)
- link_name = joint_info[12].decode(
- "utf-8"
- ) # joint_info[12] is the link name
- self.link_info[link_name] = i
-
- # Get the parent link ID
- self.parent_link_id = self.link_info.get(self.parent_link, None)
-
- # extract position and orientation of link
- try:
- if self.parent_link is None or self.parent_link_id is None:
- position, orientation = p.getBasePositionAndOrientation(
- self.attached_body_id
- )
- else:
- link_state = p.getLinkState(
- bodyUniqueId=self.attached_body_id,
- linkIndex=self.parent_link_id,
- computeForwardKinematics=True,
- )
- position = link_state[0]
- orientation = link_state[1]
- except:
- log.error(
- "Could not find link to attach "
- + self.component_name
- + " to "
- + self.parent_name
- + " !"
- )
- if len(self.offset_rotation) == 3: # euler
- offset_rot = self.client.getQuaternionFromEuler(self.offset_rotation)
- else: # quaternion
- offset_rot = self.offset_rotation
- position, orientation = self.client.multiplyTransforms(
- position, orientation, self.offset_translation, offset_rot
- )
- # update position and orientation
- self.current_position = position
- self.current_orientation = orientation
-
- if self.visualize:
- visual_shape = self.client.createVisualShape(
- shapeType=p.GEOM_BOX,
- halfExtents=[0.005, 0.02, 0.01], # x,y,z
- rgbaColor=[1, 0, 0, 1], # Red color
- )
- self.visual_body_id = self.client.createMultiBody(
- baseVisualShapeIndex=visual_shape,
- basePosition=self.current_position,
- baseOrientation=self.current_orientation,
- )
-
- self.width = self.config.get("width", 640)
- self.height = self.config.get("height", 480)
- self.aspect = self.width / self.height
-
- # check if color stream is enabled
- if self.config["streams"].get("color"):
- if self.config["streams"]["color"]["enable"]:
- self.color_stream = True
- else:
- self.color_stream = False
-
- # check if depth stream is enabled
- if self.config["streams"].get("depth"):
- if self.config["streams"]["depth"]["enable"]:
- self.depth_stream = True
- else:
- self.depth_stream = False
-
- # check if infrared stream is enabled
- if self.config["streams"].get("infrared"):
- if self.config["streams"]["infrared"]["enable"]:
- log.warn("Infrared stream is not supported in pybullet !")
- self.infrared_stream = False
-
- # check if segmentation stream is enabled
- if self.config["streams"].get("segmentation"):
- if self.config["streams"]["segmentation"]["enable"]:
- self.segmentation_stream = True
- else:
- self.segmentation_stream = False
-
- def _update_position(self) -> Any:
- """!Update internal pose information.
-
- When the camera is attached to a body this queries PyBullet for the
- current link pose and updates ``self.current_position`` and
- ``self.current_orientation``.
- """
- if self.camera_type == CameraType.ATTACHED:
- if self.parent_link is None or self.parent_link_id is None:
- position, orientation = p.getBasePositionAndOrientation(
- self.attached_body_id
- )
- else:
- link_state = p.getLinkState(
- bodyUniqueId=self.attached_body_id,
- linkIndex=self.parent_link_id,
- computeForwardKinematics=True,
- )
- position = link_state[0]
- orientation = link_state[1]
- if len(self.offset_rotation) == 3: # euler
- offset_rot = self.client.getQuaternionFromEuler(self.offset_rotation)
- else: # quaternion
- offset_rot = self.offset_rotation
- self.current_position, self.current_orientation = (
- self.client.multiplyTransforms(
- position, orientation, self.offset_translation, offset_rot
- )
- )
- # update visualization
- if self.visualize:
- self.client.resetBasePositionAndOrientation(
- self.visual_body_id, self.current_position, self.current_orientation
- )
-
- def get_images(self):
- """!Capture camera images from the simulator.
-
- Depending on the enabled streams the returned dictionary can contain the
- keys ``color``, ``depth`` and ``segmentation``.
-
- @return Dictionary mapping stream names to ``numpy.ndarray`` images.
- @rtype Dict[str, np.ndarray]
- """
- if self.camera_type == CameraType.ATTACHED:
- self._update_position()
-
- cam_target = tuple(
- a + b
- for a, b in zip(
- tuple(self.current_position),
- p.rotateVector(self.current_orientation, self.rel_camera_target),
- )
- )
- cam_up_vector = p.rotateVector(self.current_orientation, [0, 0, 1])
- view_matrix = self.client.computeViewMatrix(
- cameraEyePosition=self.current_position,
- cameraTargetPosition=cam_target,
- cameraUpVector=cam_up_vector,
- )
- elif self.camera_type == CameraType.FIXED:
- view_matrix = self.client.computeViewMatrixFromYawPitchRoll(
- cameraTargetPosition=self.camera_target_position,
- distance=self.distance,
- yaw=self.yaw,
- pitch=self.pitch,
- roll=self.roll,
- upAxisIndex=self.up_axis_index,
- )
-
- projection_matrix = self.client.computeProjectionMatrixFOV(
- fov=self.fov, aspect=self.aspect, nearVal=self.near_val, farVal=self.far_val
- )
-
- _, _, rgb_img, depth_img, segmentation_img = self.client.getCameraImage(
- width=self.width,
- height=self.height,
- viewMatrix=view_matrix,
- projectionMatrix=projection_matrix,
- )
-
- # pack image into dictionary
- images = {}
- if self.color_stream:
- # convert to rgb to bgr
- bgr_image = rgb_img[..., :3][:, :, ::-1]
- images["color"] = bgr_image
- if self.depth_stream:
- # Convert to meters
- depth_img = (self.far_val * self.near_val) / (
- self.far_val - (self.far_val - self.near_val) * depth_img
- )
- images["depth"] = depth_img
- if self.segmentation_stream:
- images["segmentation"] = segmentation_img
-
- return images
-
- def shutdown_driver(self) -> None:
- """!Clean up any resources used by the driver.
-
- Called when the simulator is shutting down. The PyBullet camera driver
- currently does not allocate additional resources so the method is empty.
- """
- # nothing to worry about here
- pass
diff --git a/ark/system/pybullet/pybullet_lidar_driver.py b/ark/system/pybullet/pybullet_lidar_driver.py
deleted file mode 100644
index 41b57ae..0000000
--- a/ark/system/pybullet/pybullet_lidar_driver.py
+++ /dev/null
@@ -1,259 +0,0 @@
-"""@file pybullet_lidar_driver.py
-@brief LiDAR driver implementation for PyBullet.
-"""
-
-from abc import ABC, abstractmethod
-from enum import Enum
-from typing import Any, Optional, Dict, List
-
-from ark.tools.log import log
-from ark.system.driver.sensor_driver import LiDARDriver
-
-import numpy as np
-import pybullet as p
-from scipy.spatial.transform import Rotation as R
-
-"""
-Example Config:
-class_dir: "examples/sensors/lidar" # Directory where the class is located
-type: "LiDAR" # Type of sensor
- sim_config:
- lidar_type: "attached" # Fixed or attached to another body
- num_rays: 360 # Number of rays
- linear_range: 10.0 # Maximum range in meters
- angular_range: 360.0 # Field of view in degrees
- fix:
- position: [0.0, 0.0, 1.0] # Position in meters
- yaw: 0.0 # Yaw angle (rotation about Z-axis) in degrees
- attach:
- parent_name: ""SimpleTwoWheelCa" # Name of the parent body to attach to
- parent_link: lidar_link # Link name of the parent body to attach to. Remove this config param to attach it to the base
- offset_translation: [0.0, 0.0, 0.02] # Offset translation from the parent body in meters
- offset_yaw: 0.0 # Offset yaw angle (rotation about Z-axis) in degrees
-"""
-
-
-class LiDARType(Enum):
- """Types of LiDAR supported in the simulation."""
-
- FIXED = "fixed"
- ATTACHED = "attached"
-
-
-class BulletLiDARDriver(LiDARDriver):
- """LiDAR driver for the PyBullet simulator."""
-
- def __init__(
- self,
- component_name: str,
- component_config: Dict[str, Any],
- attached_body_id: int = None,
- client: Any = None,
- ) -> None:
- """Initialize the BulletLiDARDriver.
-
- @param component_name Name of the LiDAR component.
- @param component_config Dictionary containing LiDAR configuration (e.g., number of rays, range).
- @param attached_body_id Optional PyBullet body ID to which the LiDAR is attached.
- @param client PyBullet client ID for multi-client simulations.
- """
- super().__init__(
- component_name, component_config, True
- ) # sim is always True for pybullet
- self.client = client
- self.attached_body_id = attached_body_id
-
- sim_config = self.config.get("sim_config", {})
-
- self.num_rays = sim_config.get("num_rays", 360)
- self.linear_range = sim_config.get("linear_range", 10.0)
- self.angular_range = sim_config.get("angular_range", 360.0)
- self.lidar_type = sim_config.get("lidar_type", "fixed")
-
- # Check config
- assert (
- self.num_rays > 0
- ), f"num_rays should be greater than 0 for {self.component_name}"
- assert (
- self.linear_range > 0
- ), f"linear_range should be greater than 0 for {self.component_name}"
- assert (
- self.angular_range > 0 and self.angular_range <= 360
- ), f"angular_range should >0 and <= 360 for {self.component_name}"
- assert self.lidar_type in [
- "fixed",
- "attached",
- ], f"lidar_type should be either 'fixed' or 'attached' for {self.component_name}"
-
- try:
- self.lidar_type = LiDARType(self.lidar_type)
- except ValueError:
- raise ValueError(f"Invalid lidar type for {self.component_name} !")
-
- if self.lidar_type == LiDARType.FIXED:
- fix_config = sim_config.get("fix", {})
- self.current_position = fix_config.get("position", [0, 0, 0])
-
- yaw = fix_config.get("yaw", 0)
- yaw = np.deg2rad(yaw)
- self.current_orientation = self.client.getQuaternionFromEuler([0, 0, yaw])
-
- elif self.lidar_type == LiDARType.ATTACHED:
- # assert attached body exists
- assert self.attached_body_id is not None
-
- attach_config = sim_config.get("attach", {})
- self.parent_name = attach_config.get("parent_name", "SimpleTwoWheelCar")
- self.parent_link = attach_config.get("parent_link", None)
- self.offset_translation = attach_config.get("offset_translation", [0, 0, 0])
- self.offset_yaw = np.deg2rad(attach_config.get("offset_yaw", 0))
-
- # Get all link names and indices
- num_joints = p.getNumJoints(self.attached_body_id)
- self.link_info = {}
- for i in range(num_joints):
- joint_info = p.getJointInfo(self.attached_body_id, i)
- link_name = joint_info[12].decode(
- "utf-8"
- ) # joint_info[12] is the link name
- self.link_info[link_name] = i
-
- # Get the parent link ID
- self.parent_link_id = self.link_info.get(self.parent_link, None)
-
- # extract position and orientation of link
- try:
- if (
- p.getNumJoints(self.attached_body_id) == 0
- or self.parent_link_id is None
- ):
- position, orientation = p.getBasePositionAndOrientation(
- self.attached_body_id
- )
- else:
- link_state = p.getLinkState(
- bodyUniqueId=self.attached_body_id,
- linkIndex=self.parent_link_id,
- computeForwardKinematics=True,
- )
- position = link_state[0]
- orientation = link_state[1]
- except:
- log.error(
- "Could not find link to attach "
- + self.component_name
- + " to "
- + self.parent_name
- + " !"
- )
-
- self.offset_rot = self.client.getQuaternionFromEuler(
- [0, 0, self.offset_yaw]
- )
- position, orientation = self.client.multiplyTransforms(
- position, orientation, self.offset_translation, self.offset_rot
- )
- # update position and orientation
- self.current_position = position
- self.current_orientation = orientation
-
- def _update_position(self) -> Any:
- """!Update the LiDAR pose when attached to another body.
-
- This queries the pose of the attachment link and applies the configured
- offset. ``self.current_position`` and ``self.current_orientation`` are
- updated accordingly.
- """
- if self.lidar_type == LiDARType.ATTACHED:
- if (
- p.getNumJoints(self.attached_body_id) == 0
- or self.parent_link_id is None
- ):
- position, orientation = p.getBasePositionAndOrientation(
- self.attached_body_id
- )
- else:
- link_state = p.getLinkState(
- bodyUniqueId=self.attached_body_id,
- linkIndex=self.parent_link_id,
- computeForwardKinematics=True,
- )
-
- position = link_state[0]
- orientation = link_state[1]
- self.current_position, self.current_orientation = (
- self.client.multiplyTransforms(
- position, orientation, self.offset_translation, self.offset_rot
- )
- )
-
- def get_scan(self) -> Dict[str, np.ndarray]:
- """!Retrieve a simulated LiDAR scan from PyBullet.
-
- The returned dictionary contains the keys ``angles`` and ``ranges``. A
- range value of ``-1`` indicates that no hit was recorded for the
- corresponding angle.
-
- @return Dictionary with keys ``angles`` and ``ranges``.
- @rtype Dict[str, np.ndarray]
- """
- if self.lidar_type == LiDARType.ATTACHED:
- self._update_position()
-
- # Get current yaw
- euler = p.getEulerFromQuaternion(self.current_orientation)
- yaw = euler[2]
-
- # Set angular range
- angular_range = np.deg2rad(self.angular_range)
- min_angle = yaw - angular_range / 2
- max_angle = yaw + angular_range / 2
-
- # Don't repeat the same angle if the range is 360 degrees
- if self.angular_range == 360:
- endpoint = False
- else:
- endpoint = True
-
- # Generate angles
- angles = np.linspace(min_angle, max_angle, self.num_rays, endpoint=endpoint)
-
- # Ray directions (2D plane, xy only)
- dx = np.cos(angles)
- dy = np.sin(angles)
- directions = np.stack(
- [dx, dy, np.zeros_like(dx)], axis=1
- ) # shape (num_rays, 3)
-
- # Ray start and end positions
- ray_starts = np.array(self.current_position).reshape(1, 3) # shape (1, 3)
- ray_starts = ray_starts.repeat(self.num_rays, axis=0) # shape (num_rays, 3)
- ray_ends = ray_starts + directions * self.linear_range
-
- # Perform ray casting
- results = p.rayTestBatch(ray_starts.tolist(), ray_ends.tolist())
-
- # Extract distances
- ranges = []
- for i, result in enumerate(results):
- hit = result[0]
- hit_position = result[3]
- if hit != -1:
- dist = np.linalg.norm(np.array(hit_position) - np.array(ray_starts[i]))
- else:
- dist = -1
- ranges.append(dist)
- ranges = np.array(ranges)
-
- # Convert angles to the LiDAR's reference frame
- angles = angles - yaw
- scan = {"angles": angles, "ranges": ranges}
- return scan
-
- def shutdown_driver(self) -> None:
- """!Shutdown the LiDAR driver.
-
- Currently no additional resources are allocated, so this is a no-op.
- """
- # nothing to worry about here
- pass
diff --git a/ark/system/pybullet/pybullet_multibody.py b/ark/system/pybullet/pybullet_multibody.py
deleted file mode 100644
index ea1c4f7..0000000
--- a/ark/system/pybullet/pybullet_multibody.py
+++ /dev/null
@@ -1,200 +0,0 @@
-"""@file pybullet_multibody.py
-@brief Abstractions for multi-body objects in PyBullet.
-"""
-
-from abc import ABC, abstractmethod
-from typing import Any, Dict, Optional, Union
-from enum import Enum
-from pathlib import Path
-import yaml
-import os
-
-from ark.tools.log import log
-from ark.system.component.sim_component import SimComponent
-from arktypes import flag_t, rigid_body_state_t
-
-
-class SourceType(Enum):
- """Supported source types for object creation."""
-
- URDF = "urdf"
- PRIMITIVE = "primitive"
- SDF = "sdf"
- MJCF = "mjcf"
-
-
-class PyBulletMultiBody(SimComponent):
- """Utility class for creating PyBullet multi-body objects."""
-
- def __init__(
- self,
- name: str,
- client: Any,
- global_config: Dict[str, Any] = None,
- ) -> None:
- """Instantiate a PyBulletMultiBody object.
-
- @param name Name of the object.
- @param client Bullet client used for creation.
- @param global_config Global configuration dictionary.
- @return ``None``
- """
-
- super().__init__(name, global_config)
- self.client = client
- self.namespace = global_config["namespace"]
- source_str = self.config["source"]
- source_type = getattr(SourceType, source_str.upper())
-
- if source_type == SourceType.URDF:
- urdf_path = self.config["urdf_path"]
- if not urdf_path:
- log.error(
- "Selected loading object "
- + name
- + " from URDF, but no URDF was provided. Check your config again."
- )
-
- # If URDF path is provided, load the URDF
- base_position = self.config.get(
- "base_position", [0, 0, 0]
- ) # Default to (0, 0, 0) if not provided
- base_orientation = self.config.get(
- "base_orientation", [0, 0, 0, 1]
- ) # Default to identity quaternion if not provided
- if (
- len(base_orientation) == 3
- ): # Convert euler angles to quaternion if provided
- base_orientation = self.client.getQuaternionFromEuler(base_orientation)
-
- global_scaling = self.config.get("global_scaling", 1.0) # Default is 1.0
-
- # Load the URDF into the PyBullet simulation
- self.ref_body_id = client.loadURDF(
- fileName=urdf_path,
- basePosition=base_position,
- baseOrientation=base_orientation,
- globalScaling=global_scaling,
- useMaximalCoordinates=1,
- )
-
- # If there is any additional configuration for visual, collision, or dynamics, apply them
- vis = self.config.get("visual")
- if vis:
- vis_shape_type = getattr(client, vis["shape_type"].upper())
- vis_opts = vis["visual_shape"]
- vid = client.createVisualShape(vis_shape_type, **vis_opts)
- client.changeVisualShape(
- self.ref_body_id, -1, visualShapeIndex=vid
- ) # Change the visual shape
-
- col = self.config.get("collision")
- if col:
- col_shape_type = getattr(client, col["shape_type"].upper())
- col_opts = col["collision_shape"]
- cid = client.createCollisionShape(col_shape_type, **col_opts)
- client.changeCollisionShape(
- self.ref_body_id, -1, collisionShapeIndex=cid
- ) # Change the collision shape
-
- dynamics = self.config.get("dynamics")
- if dynamics:
- client.changeDynamics(
- self.ref_body_id, -1, **dynamics
- ) # Apply dynamics settings if present
- elif source_type == SourceType.PRIMITIVE:
- # Fall back to the original primitive creation if no URDF path is provided
- vis = self.config.get("visual")
- if vis:
- vis_shape_type = getattr(client, vis["shape_type"].upper())
- vis_opts = vis["visual_shape"]
- vid = client.createVisualShape(vis_shape_type, **vis_opts)
- else:
- vid = -1
- col = self.config.get("collision")
- if col:
- col_shape_type = getattr(client, col["shape_type"].upper())
- col_opts = col["collision_shape"]
- cid = client.createCollisionShape(col_shape_type, **col_opts)
- else:
- cid = -1
- kwargs = dict(
- baseCollisionShapeIndex=cid,
- baseVisualShapeIndex=vid,
- )
- # pybullet format
- multi_body = self.config["multi_body"]
- multi_body["basePosition"] = self.config["base_position"]
- multi_body["baseOrientation"] = self.config["base_orientation"]
- kwargs = {**kwargs, **multi_body}
- self.ref_body_id = client.createMultiBody(**kwargs)
-
- dynamics = self.config.get("dynamics")
- if dynamics:
- client.changeDynamics(self.ref_body_id, -1, **dynamics)
- elif source_type == SourceType.SDF:
- raise NotImplementedError
- elif source_type == SourceType.MJCF:
- raise NotImplementedError
- else:
- log.error("Unknown source specification. Check your config file.")
-
- # setup communication
- self.publisher_name = f"{self.namespace}/" + self.name + "/ground_truth/sim"
-
- if self.publish_ground_truth:
- self.state_publisher = self.component_channels_init(
- {self.publisher_name: rigid_body_state_t}
- )
-
- def get_object_data(self):
- """!Return the current state of the simulated object.
-
- @return Dictionary with position, orientation and velocities of the
- object.
- @rtype Dict[str, Any]
- """
- position, orientation = self.client.getBasePositionAndOrientation(
- self.ref_body_id
- )
- lin_vel, ang_vel = self.client.getBaseVelocity(self.ref_body_id)
- return {
- "name": self.name,
- "position": position,
- "orientation": orientation,
- "lin_velocity": lin_vel,
- "ang_velocity": ang_vel,
- }
-
- def pack_data(self, data_dict):
- """!Convert a state dictionary to a ``rigid_body_state_t`` message.
-
- @param data_dict Dictionary as returned by :func:`get_object_data`.
- @return Mapping suitable for :class:`MultiChannelPublisher`.
- @rtype Dict[str, rigid_body_state_t]
- """
- msg = rigid_body_state_t()
- msg.name = data_dict["name"]
- msg.position = data_dict["position"]
- msg.orientation = data_dict["orientation"]
- msg.lin_velocity = data_dict["lin_velocity"]
- msg.ang_velocity = data_dict["ang_velocity"]
- return {self.publisher_name: msg}
-
- def reset_component(self, channel, msg) -> None:
- """!Reset the object pose using a message.
-
- @param channel LCM channel on which the reset request was received.
- @param msg ``rigid_body_state_t`` containing the desired pose.
- @return ``flag_t`` acknowledging the reset.
- """
- new_pos = msg.position
- new_orn = msg.orientation
- log.info(f"Resetting object {self.name} to position: {new_pos}")
- log.info(
- "PyBullet does not support resetting with velocities, Only using positions."
- )
- self.client.resetBasePositionAndOrientation(self.ref_body_id, new_pos, new_orn)
- log.ok(f"Reset object {self.name} completed at: {new_pos}")
-
- return flag_t()
diff --git a/ark/system/pybullet/pybullet_robot_driver.py b/ark/system/pybullet/pybullet_robot_driver.py
deleted file mode 100644
index 74dbf9f..0000000
--- a/ark/system/pybullet/pybullet_robot_driver.py
+++ /dev/null
@@ -1,315 +0,0 @@
-"""@file pybullet_robot_driver.py
-@brief Robot driver handling PyBullet specific commands.
-"""
-
-from abc import ABC, abstractmethod
-from enum import Enum
-from typing import Any, Optional, Dict, List
-import os
-import pybullet as p
-from pathlib import Path
-
-from ark.tools.log import log
-from ark.system.driver.robot_driver import SimRobotDriver, ControlType
-
-# for pybullet setJointMotorControlArray optional arguments
-motor_control_kwarg = {
- "position": "targetPositions",
- "velocity": "targetVelocities",
- "torque": "forces",
-}
-
-
-class BulletRobotDriver(SimRobotDriver):
- """Robot driver that interfaces with the PyBullet simulation."""
-
- def __init__(
- self,
- component_name=str,
- component_config: Dict[str, Any] = None,
- client: Any = None,
- ) -> None:
- """!Create a robot driver for PyBullet.
-
- @param component_name Name of the robot component.
- @param component_config Configuration dictionary for the robot.
- @param client Bullet client instance.
- @return ``None``
- """
- super().__init__(component_name, component_config, True)
-
- self.client = client
-
- self.base_position = self.config.get("base_position", [0.0, 0.0, 0.0])
- self.base_orientation = self.config.get(
- "base_orientation", [0.0, 0.0, 0.0, 1.0]
- )
- if len(self.base_orientation) == 3:
- self.base_orientation = p.getQuaternionFromEuler(self.base_orientation)
-
- self.load_robot(self.base_position, self.base_orientation, None)
- self.initial_configuration = self.config.get(
- "initial_configuration", [0.0] * self.client.getNumJoints(self.ref_body_id)
- )
-
- self.num_joints = self.client.getNumJoints(self.ref_body_id)
- self.bullet_joint_infos = {}
- # {"name" : {"index" : ... ,
- # "type" : ... ,
- # "actuated" : ... ,
- # "parent_link" : ... ,
- # "child_link" : ... ,
- # "lower_limit" : ... ,
- # "upper_limit" : ... ,
- # "effort_limit" : ... ,
- # "velocity_limit" : ... ,
- #
- # "joint_axis" : ... , # PyBullet specific
- # "joint_parent_index": ... , # PyBullet specific
- # "joint_child_index" : ... , # PyBullet specific
- # }
- # }
-
- self.actuated_joints = {}
- self.joints = {}
- # {"name" : index}
-
- for joint_index in range(self.num_joints):
- # extract joint information
- joint_info = self.client.getJointInfo(self.ref_body_id, joint_index)
- # (jointIndex, jointName, jointType, jointAxis, jointLowerLimit, jointUpperLimit,
- # jointMaxForce, jointMaxVelocity, linkName, jointType, jointParentindex, jointChildIndex)
- joint_name = joint_info[1].decode("utf-8")
- self.joints[joint_name] = joint_index
- self.bullet_joint_infos[joint_name] = {}
- self.bullet_joint_infos[joint_name]["index"] = joint_index
- self.bullet_joint_infos[joint_name]["type"] = joint_info[2]
- if self.bullet_joint_infos[joint_name]["type"] == 4:
- self.bullet_joint_infos[joint_name]["actuated"] = False
- else:
- self.bullet_joint_infos[joint_name]["actuated"] = True
- self.actuated_joints[joint_name] = joint_index
- self.bullet_joint_infos[joint_name]["parent_link"] = None
- self.bullet_joint_infos[joint_name]["child_link"] = None
- self.bullet_joint_infos[joint_name]["lower_limit"] = joint_info[4]
- self.bullet_joint_infos[joint_name]["upper_limit"] = joint_info[5]
- self.bullet_joint_infos[joint_name]["effort_limit"] = joint_info[6]
- self.bullet_joint_infos[joint_name]["velocity_limit"] = joint_info[7]
-
- self.bullet_joint_infos[joint_name]["joint_axis"] = joint_info[3]
- self.bullet_joint_infos[joint_name]["joint_parent_index"] = joint_info[10]
- self.bullet_joint_infos[joint_name]["joint_child_index"] = joint_info[11]
-
- self.sim_reset(
- base_pos=self.base_position,
- base_orn=self.base_orientation,
- q_init=self.initial_configuration,
- )
-
- # PyBullet specific : extract and save joint group information to handle torque control
- torque_control_groups = {}
-
- for group_name, group_config in self.config.get("joint_groups", {}).items():
- # add control type from enum to internal config dict
-
- if group_config["control_mode"] == self.client.TORQUE_CONTROL:
- force_limit = group_config.get("force_limit", 0.0)
- torque_control_groups[group_name] = {}
- torque_control_groups[group_name]["force_limit"] = force_limit
- torque_control_groups[group_name]["indices"] = []
- for joint in group_config["joints"]:
- joint_idx = self.bullet_joint_infos[joint]["index"]
- torque_control_groups[group_name]["indices"].append(joint_idx)
-
- # Setup torque control
- # https://pybullet.org/Bullet/phpBB3/viewtopic.php?t=12644
- for group_name, group_data in torque_control_groups.items():
- joint_indices = group_data["indices"]
- force_limit = group_data["force_limit"]
- self.client.setJointMotorControlArray(
- self.ref_body_id,
- joint_indices,
- self.client.VELOCITY_CONTROL,
- forces=[force_limit] * len(joint_indices),
- )
-
- def load_robot(
- self, base_position=None, base_orientation=None, q_init=None
- ) -> None:
- """!Load the robot model into the simulator.
-
- @param base_position Optional base position ``[x, y, z]``.
- @param base_orientation Optional base orientation as quaternion.
- @param q_init Optional list of initial joint positions.
- """
- kwargs = {}
-
- kwargs["useFixedBase"] = self.config.get("use_fixed_base", 1)
-
- if self.config.get("merge_fixed_links", True):
- kwargs["flags"] = p.URDF_MERGE_FIXED_LINKS
-
- if base_position is not None:
- kwargs["basePosition"] = base_position
- else:
- kwargs["basePosition"] = self.config.get("base_position", [0.0, 0.0, 0.0])
-
- if base_orientation is not None:
- kwargs["baseOrientation"] = base_orientation
- else:
- kwargs["baseOrientation"] = self.config.get(
- "base_orientation", [0.0, 0.0, 0.0, 1.0]
- )
-
- urdf_path = self.config.get("urdf_path", None)
- mjcf_path = self.config.get("mjcf_path", None)
- class_path = self.config.get("class_path", None)
- if mjcf_path and urdf_path:
- log.warning("Both urdf and mjcf paths are provided. Defaulting to URDF.")
- if urdf_path:
- # Append the URDF path to the class path if provided
- if class_path is not None:
- urdf_path = Path(class_path) / urdf_path
- else:
- urdf_path = Path(urdf_path)
-
- # Make the URDF path absolute if it is not already
- if not urdf_path.is_absolute():
- urdf_path = Path(self.config["class_dir"]) / urdf_path
-
- # Check if the URDF path exists
- if not urdf_path.exists():
- log.error(f"The URDF path '{urdf_path}' does not exist.")
- log.error(f"Full path: {urdf_path.resolve()}")
- # print the full path for debugging
-
- return
-
- # Load the URDF into the simulator
- self.ref_body_id = self.client.loadURDF(str(urdf_path), **kwargs)
- log.ok(
- f"Initialized robot specified by URDF '{urdf_path}' in PyBullet simulator."
- )
-
- if q_init is not None:
- for joint in range(self.client.getNumJoints(self.ref_body_id)):
- self.client.resetJointState(self.ref_body_id, joint, q_init[joint], 0.0)
-
- #####################
- ## get infos ##
- #####################
-
- def check_torque_status(self) -> bool:
- """!Return ``True`` as simulated robots are always torqued.
-
- @return Always ``True`` in simulation.
- @rtype bool
- """
- return True # simulated robot is always torqued in bullet
-
- def pass_joint_positions(self, joints: List[str]) -> Dict[str, float]:
- """!Return the current joint positions.
-
- @param joints List of joint names.
- @return Dictionary from joint name to position.
- @rtype Dict[str, float]
- """
- pos = {}
- idx = [self.actuated_joints[joint] for joint in joints]
- # Iterate over each joint index and corresponding joint state to fill dictionaries
- for name, idx in zip(joints, idx):
- state = self.client.getJointState(self.ref_body_id, idx)
- pos[name] = state[0] # Joint position
- return pos
-
- def pass_joint_velocities(self, joints: List[str]) -> Dict[str, float]:
- """!Return the current joint velocities.
-
- @param joints List of joint names.
- @return Dictionary from joint name to velocity.
- @rtype Dict[str, float]
- """
- vel = {}
- idx = [self.actuated_joints[joint] for joint in joints]
- # Iterate over each joint index and corresponding joint state to fill dictionaries
- for name, idx in zip(joints, idx):
- state = self.client.getJointState(self.ref_body_id, idx)
- vel[name] = state[1] # Joint velocity
- return vel
-
- def pass_joint_efforts(self, joints: List[str]) -> Dict[str, float]:
- """!Return the current joint efforts.
-
- @param joints List of joint names.
- @return Dictionary from joint name to effort.
- @rtype Dict[str, float]
- """
- eff = {}
- idx = [self.actuated_joints[joint] for joint in joints]
- # Iterate over each joint index and corresponding joint state to fill dictionaries
- for name, idx in zip(joints, idx):
- state = self.client.getJointState(self.ref_body_id, idx)
- eff[name] = state[3] # Joint applied force (effort)
- return eff
-
- #####################
- ## control ##
- #####################
-
- def pass_joint_group_control_cmd(
- self, control_mode: str, cmd: Dict[str, float], **kwargs
- ) -> None:
- """!Send a control command to a group of joints.
-
- @param control_mode One of ``position``, ``velocity`` or ``torque``.
- @param cmd Mapping from joint names to command values.
- @param kwargs Additional keyword arguments forwarded to PyBullet.
- @return ``None``
- """
- idx = [self.actuated_joints[joint] for joint in cmd.keys()]
-
- kwargs = {motor_control_kwarg[control_mode]: list(cmd.values())}
- if control_mode == ControlType.POSITION.value:
- control_mode = p.POSITION_CONTROL
- elif control_mode == ControlType.VELOCITY.value:
- control_mode = p.VELOCITY_CONTROL
- elif control_mode == ControlType.TORQUE.value:
- control_mode = p.TORQUE_CONTROL
- else:
- log.error(
- "Invalid control mode. Please use 'position', 'velocity', or 'torque', but received: "
- + control_mode
- )
-
- self.client.setJointMotorControlArray(
- bodyUniqueId=self.ref_body_id,
- jointIndices=idx,
- controlMode=control_mode,
- **kwargs,
- )
-
- #####################
- ## misc. ##
- #####################
-
- def sim_reset(
- self, base_pos: List[float], base_orn: List[float], q_init: List[float]
- ) -> None:
- """!Reset the robot in the simulator.
-
- @param base_pos New base position.
- @param base_orn New base orientation quaternion.
- @param q_init List of joint positions after the reset.
- """
- # delete the robot
- self.client.removeBody(self.ref_body_id)
- self.load_robot(
- base_position=base_pos, base_orientation=base_orn, q_init=q_init
- )
-
- log.ok("Reset robot " + self.component_name + " completed.")
-
- # print the joint positons after reset
- joint_positions = self.pass_joint_positions(list(self.actuated_joints.keys()))
- log.info("Joint positions after reset: " + str(joint_positions))
- return
diff --git a/ark/system/simulation/__init__.py b/ark/system/simulation/__init__.py
deleted file mode 100644
index a097d4d..0000000
--- a/ark/system/simulation/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-"""ARK system simulation package.
-
-This package contains abstractions for simulator backends and nodes used to
-run robotic simulations within the ARK framework.
-"""
diff --git a/ark/system/simulation/simulator_backend.py b/ark/system/simulation/simulator_backend.py
deleted file mode 100644
index 5f18287..0000000
--- a/ark/system/simulation/simulator_backend.py
+++ /dev/null
@@ -1,146 +0,0 @@
-"""Abstract interface for simulation backends."""
-
-from abc import ABC, abstractmethod
-from enum import Enum
-from typing import Any, Dict, Optional
-
-from ark.tools.log import log
-from ark.system.driver.sensor_driver import SensorType
-from ark.system.component.robot import Robot
-from ark.system.component.sensor import Sensor
-from ark.system.component.sim_component import SimComponent
-
-
-class SimulatorBackend(ABC):
- """Base class for all simulator backends.
-
- The backend manages robots, sensors and other simulated components and
- exposes a minimal interface for stepping and resetting the simulation
- environment.
- """
-
- def __init__(self, global_config: Dict[str, Any]) -> None:
- """!Create and initialize the backend.
-
- @param global_config Dictionary describing the complete simulator
- configuration.
- """
- self.robot_ref: Dict[str, Robot] = {} # Key is robot name, value is config dict
- self.object_ref: Dict[str, SimComponent] = (
- {}
- ) # Key is object name, value is config dict
- self.sensor_ref: Dict[str, Sensor] = (
- {}
- ) # Key is sensor name, value is config dict
- self.ready: bool = False
- self._simulation_time: float = 0.0
- self.global_config = global_config
- self.initialize()
- self.ready = True
-
- def is_ready(self) -> bool:
- """!Check if the backend finished initialization."""
- return self.ready
-
- #########################
- ## Initialization ##
- #########################
-
- @abstractmethod
- def initialize(self) -> None:
- """!Initialize the simulator implementation."""
- ...
-
- @abstractmethod
- def set_gravity(self, gravity: tuple[float, float, float]) -> None:
- """!Set the gravity vector used by the simulator.
-
- @param gravity Tuple ``(x, y, z)`` representing the gravity vector.
- """
- ...
-
- @abstractmethod
- def reset_simulator(self) -> None:
- """!Reset the entire simulator state."""
- ...
-
- @abstractmethod
- def add_robot(
- self,
- name: str,
- global_config: dict[str, Any],
- ) -> None:
- """!Add a robot to the simulation.
-
- @param name Name of the robot.
- @param global_config Configuration dictionary for the robot.
- """
- ...
-
- @abstractmethod
- def add_sensor(
- self,
- name: str,
- sensor_type: SensorType,
- global_config: dict[str, Any],
- ) -> None:
- """!Add a sensor to the simulation.
-
- @param name Name of the sensor.
- @param sensor_type Type of the sensor.
- @param global_config Configuration dictionary for the sensor.
- """
- ...
-
- @abstractmethod
- def add_sim_component(
- self,
- name: str,
- type: str,
- global_config: dict[str, Any],
- ) -> None:
- """!Add a generic simulation object.
-
- @param name Name of the object.
- @param type Type identifier (e.g. ``"cube"``).
- @param global_config Configuration dictionary for the object.
- """
- ...
-
- @abstractmethod
- def remove(self, name: str) -> None:
- """!Remove a robot, sensor or object by name.
-
- @param name Name of the component to remove.
- """
- ...
-
- @abstractmethod
- def step(self) -> None:
- """!Advance the simulator by one timestep."""
- ...
-
- @abstractmethod
- def shutdown_backend(self) -> None:
- """!Shut down the simulator and free resources."""
- pass
-
- def _step_sim_components(self) -> None:
- """!Step all registered components."""
- for robot in self.robot_ref:
- if not self.robot_ref[robot]._is_suspended:
- self.robot_ref[robot].step_component()
- self.robot_ref[robot].control_robot()
- for obj in self.object_ref:
- self.object_ref[obj].step_component()
- for sensor in self.sensor_ref:
- self.sensor_ref[sensor].step_component()
-
- def _spin_sim_components(self) -> None:
- """!Spin components in manual mode."""
- for robot in self.robot_ref:
- self.robot_ref[robot].manual_spin()
- for obj in self.object_ref:
- self.object_ref[obj].manual_spin()
- for sensor in self.sensor_ref:
- self.sensor_ref[sensor].manual_spin()
diff --git a/ark/system/simulation/simulator_node.py b/ark/system/simulation/simulator_node.py
deleted file mode 100644
index cf13943..0000000
--- a/ark/system/simulation/simulator_node.py
+++ /dev/null
@@ -1,257 +0,0 @@
-"""Simulation node base implementation.
-
-This module provides :class:`SimulatorNode` which serves as the entry point
-for launching and controlling a simulator instance. It loads a global
-configuration, instantiates the desired backend and offers utilities for
-managing the simulation lifecycle. Concrete simulations should derive from
-this class and implement :func:`initialize_scene` and :func:`step`.
-"""
-
-import os
-from abc import ABC, abstractmethod
-from pathlib import Path
-from typing import Any
-
-from ark.client.comm_infrastructure.base_node import BaseNode
-from ark.tools.log import log
-from ark.utils.utils import ConfigPath
-from arktypes import flag_t
-
-
-class SimulatorNode(BaseNode, ABC):
- """Base class for simulator nodes.
-
- A :class:`SimulatorNode` wraps a simulation backend and exposes LCM
- services for stepping and resetting the simulation. Subclasses are
- expected to implement :func:`initialize_scene` to construct the initial
- environment and :func:`step` to execute custom logic on every simulation
- tick.
- """
-
- def __init__(
- self,
- global_config,
- observation_channels: dict[str, type] | None = None,
- action_channels: dict[str, type] | None = None,
- namespace: str = "ark",
- ):
- """!Construct the simulator node.
-
- The constructor loads the global configuration, instantiates the
- backend and sets up basic services for stepping and resetting the
- simulator.
-
- @param global_config Path to the configuration YAML file or a loaded
- configuration dictionary.
- """
- self._load_config(global_config)
- self.name = self.global_config["simulator"].get("name", "simulator")
-
- self.global_config["observation_channels"] = observation_channels
- self.global_config["action_channels"] = action_channels
- self.global_config["namespace"] = namespace
-
- super().__init__(self.name, global_config=global_config)
-
- log.info(
- "Initializing SimulatorNode called "
- + self.name
- + " with id "
- + self.node_id
- + " ..."
- )
-
- # Setup backend
- self.backend_type = self.global_config["simulator"]["backend_type"]
- if self.backend_type == "pybullet":
- from ark.system.pybullet.pybullet_backend import PyBulletBackend
- self.backend = PyBulletBackend(self.global_config)
- elif self.backend_type == "mujoco":
- from ark.system.mujoco.mujoco_backend import MujocoBackend
- self.backend = MujocoBackend(self.global_config)
- elif self.backend_type == "genesis":
- from ark.system.genesis.genesis_backend import GenesisBackend
- self.backend = GenesisBackend(self.global_config)
- elif self.backend_type in ["isaacsim", "isaac_sim", "isaac"]:
- from ark.system.isaac.isaac_backend import IsaacSimBackend
- self.backend = IsaacSimBackend(self.global_config)
- else:
- raise ValueError(f"Unsupported backend '{self.backend_type}'")
-
- # to initialize a scene with objects that dont need to publish, e.g. for visuals
- self.initialize_scene()
- self.step_physics()
-
- ## Reset Backend Service
- reset_service_name = f"{namespace}/" + self.name + "/backend/reset/sim"
- self.create_service(reset_service_name, flag_t, flag_t, self._reset_backend)
-
- custom_loop = getattr(self.backend, "custom_event_loop", None)
- self.custom_loop = True if callable(custom_loop) else False
- if not self.custom_loop:
- freq = self.global_config["simulator"]["config"].get(
- "node_frequency", 240.0
- )
- self.create_stepper(freq, self._step_simulation)
-
- def _load_config(self, global_config) -> None:
- """!Load and merge the global configuration.
-
- The configuration may either be provided as a path to a YAML file or
- already loaded into a dictionary. Included sub-configurations for
- robots, sensors and objects are resolved and merged.
-
- @param global_config Path to the configuration file or configuration
- dictionary.
- """
-
- if not global_config:
- raise ValueError("Please provide a global configuration file.")
-
- if isinstance(global_config, str):
- global_config = ConfigPath(global_config)
- elif isinstance(global_config, Path):
- global_config = ConfigPath(str(global_config))
- if not global_config.exists():
- raise ValueError(
- "Given configuration file path does not exist, currently: "
- + global_config.str
- )
-
- if not global_config.is_absolute():
- global_config = global_config.resolve()
-
- cfg = global_config.read_yaml()
-
- # assert that the config is a dict
- if not isinstance(cfg, dict):
- raise ValueError("The configuration file must be a valid dictionary.")
-
- # merge with subconfigs
- config = {}
- try:
- config["network"] = cfg["network"]
- except KeyError as e:
- config["network"] = None
- try:
- config["simulator"] = cfg["simulator"]
- except KeyError as e:
- raise ValueError(
- "Please provide at least name and backend_type under simulation in your config file."
- )
-
- try:
- config["robots"] = self._load_section(cfg, global_config, "robots")
- except KeyError as e:
- config["robots"] = {}
- try:
- config["sensors"] = self._load_section(cfg, global_config, "sensors")
- except KeyError as e:
- config["sensors"] = {}
- try:
- config["objects"] = self._load_section(cfg, global_config, "objects")
- except KeyError as e:
- config["objects"] = {}
-
- log.ok("Config file under " + global_config.str + " loaded successfully.")
- self.global_config = config
-
- def _load_section(
- self, cfg: dict[str, Any], config_path: str | ConfigPath, section_name: str
- ) -> dict[str, Any]:
- """!Load a sub‑configuration section.
-
- Sections may either be specified inline within the main configuration
- file or given as paths to external YAML files. The returned dictionary
- maps component names to their configuration dictionaries.
-
- @param cfg The top level configuration dictionary.
- @param config_path Absolute path to the loaded configuration file.
- @param section_name Name of the section to load (``"robots"``,
- ``"sensors"`` or ``"objects"``).
- @return Dictionary containing the merged configuration for the section.
- """
- # { "name" : { ... } },
- # "name" : { ... } }
- section_config = {}
- for item in cfg.get(section_name) or []:
- if isinstance(item, dict): # If it's an inline configuration
- subconfig = item
- elif isinstance(item, str) and item.endswith(
- ".yaml"
- ): # If it's a path to an external file
- if os.path.isabs(item): # Check if the path is absolute
- external_path = ConfigPath(item)
- else: # Relative path, use the directory of the main config file
- external_path = config_path.parent / item
- # Load the YAML file and return its content
- subconfig = external_path.read_yaml()
- else:
- log.error(
- f"Invalid entry in '{section_name}': {item}. Please provide either a config or a path to another config."
- )
- continue # Skip invalid entries
-
- section_config[subconfig["name"]] = subconfig["config"]
-
- return section_config
-
- def _reset_backend(self, channel, msg):
- """!Service callback resetting the backend."""
- self.backend.reset_simulator()
- return flag_t()
-
- def _step_simulation(self) -> None:
- """!Advance the simulation by one step and call :func:`step`."""
- self.step()
- self.backend.step()
-
- def step_physics(self, num_steps: int = 25, call_step_hook: bool = False) -> None:
- """
- Advance the physics backend
- Args:
- num_steps: Number of physics ticks to run.
- call_step_hook: If True, also invoke the subclass step() each tick.
-
- Returns:
- None
- """
- for _ in range(max(0, num_steps)):
- if call_step_hook:
- self.step()
- self.backend.step()
-
- def initialize_scene(self) -> None:
- """!Create the initial simulation scene."""
- pass
-
- def step(self) -> None:
- """!Hook executed every simulation step."""
- pass
-
- # OVERRIDE
- def spin(self) -> None:
- """!Run the node's main loop.
-
- The loop processes incoming LCM messages and forwards control to the
- backend for spinning all components. It terminates when an
- ``OSError`` occurs or :attr:`_done` is set to ``True``.
- """
- # Allow backends to provide their own event loop (e.g., IsaacSim main thread)
- if self.custom_loop:
- self.backend.custom_event_loop(sim_node=self)
- return
-
- while not self._done:
- try:
- self._lcm.handle_timeout(0)
- self.backend._spin_sim_components()
- except OSError as e:
- log.warning(f"LCM threw OSError {e}")
- self._done = True
-
- # OVERRIDE
- def kill_node(self) -> None:
- """!Shut down the node and the underlying backend."""
- self.backend.shutdown_backend()
- super().kill_node()
diff --git a/ark/tests/test_isaac_urdf_sim.py b/ark/tests/test_isaac_urdf_sim.py
deleted file mode 100644
index 61aa17c..0000000
--- a/ark/tests/test_isaac_urdf_sim.py
+++ /dev/null
@@ -1,177 +0,0 @@
-import os
-from pathlib import Path
-
-import numpy as np
-import pytest
-from isaacsim import SimulationApp
-
-app = SimulationApp({"headless": False})
-
-from isaacsim.core.api.objects import DynamicCuboid
-from isaacsim.robot.manipulators import SingleManipulator
-from isaacsim.robot.manipulators.grippers import ParallelGripper
-
-import omni.kit.commands
-from isaacsim.core.api import World
-from isaacsim.core.prims import Articulation
-from pxr import Gf, PhysxSchema, Sdf, UsdPhysics
-from isaacsim.robot.manipulators.examples.franka import KinematicsSolver
-
-
-# Allow overriding USD path; defaults to the requested asset.
-URDF_ASSET_PATH = Path(
- os.environ.get(
- "ARK_USD_PATH",
- "/home/refinath/ark/ark_franka/franka_panda/panda_with_gripper.urdf",
- )
-)
-
-omni_kit = pytest.importorskip(
- "omni.isaac.kit",
- reason="Isaac Sim Python packages are required to start the simulator.",
-)
-
-
-@pytest.mark.skipif(
- not URDF_ASSET_PATH.exists(),
- reason=(
- f"USD asset not found at {URDF_ASSET_PATH}. "
- "Set ARK_USD_PATH to point at a local copy."
- ),
-)
-def test_isaac_sim_headless_loads_usd():
- """Boot Isaac Sim headless, reference the USD, and verify a prim appears on the stage."""
-
- try:
- world = World(physics_dt=1 / 60.0, rendering_dt=1 / 60.0)
- world.scene.add_default_ground_plane()
-
- # Setting up import configuration:
- status, import_config = omni.kit.commands.execute("URDFCreateImportConfig")
- import_config.merge_fixed_joints = False
- import_config.fix_base = True
- import_config.import_inertia_tensor = True
- import_config.convex_decomp = False
-
- import_config.distance_scale = 1.0
- import_config.density = 0.0
- import_config.self_collision = False
- import_config.make_default_prim = True
- import_config.create_physics_scene = True
-
- # Import URDF, prim_path contains the path to the usd prim in the stage.
- status, prim_path = omni.kit.commands.execute(
- "URDFParseAndImportFile",
- urdf_path=str(URDF_ASSET_PATH),
- import_config=import_config,
- get_articulation_root=True,
- )
-
- # Get stage handle
- stage = omni.usd.get_context().get_stage()
-
- # Enable physics
- scene = UsdPhysics.Scene.Define(stage, Sdf.Path("/physicsScene"))
- # Set gravity
- scene.CreateGravityDirectionAttr().Set(Gf.Vec3f(0.0, 0.0, -1.0))
- scene.CreateGravityMagnitudeAttr().Set(9.81)
- # Set solver settings
- PhysxSchema.PhysxSceneAPI.Apply(stage.GetPrimAtPath("/physicsScene"))
- physxSceneAPI = PhysxSchema.PhysxSceneAPI.Get(stage, "/physicsScene")
- physxSceneAPI.CreateEnableCCDAttr(True)
- physxSceneAPI.CreateEnableStabilizationAttr(True)
- physxSceneAPI.CreateEnableGPUDynamicsAttr(False)
- physxSceneAPI.CreateBroadphaseTypeAttr("MBP")
- physxSceneAPI.CreateSolverTypeAttr("TGS")
-
- robot = Articulation(prim_path)
- world.scene.add(robot)
-
- rigid_bodies = []
- for prim in stage.Traverse():
- if prim.HasAPI(UsdPhysics.RigidBodyAPI):
- rigid_bodies.append(str(prim.GetPath()))
-
- print("Rigid bodies:", rigid_bodies)
-
- world.scene.add_default_ground_plane()
-
- omni.timeline.get_timeline_interface().play()
- app.update()
- robot.initialize()
- world.step(render=False)
-
- base_position = [0, 0, 0.2]
- base_orientation = [0, 0, 0, 1]
- q_int = [-0.3, 0.1, 0.3, -1.4, 0.1, 1.8, 0, 0, 0]
- robot.set_world_poses(
- positions=np.array([base_position]),
- orientations=np.array([base_orientation]),
- )
- robot.set_joint_positions([q_int])
- world.step(render=False)
-
- gripper = ParallelGripper(
- end_effector_prim_path=f"/panda/panda_hand",
- joint_prim_names=["panda_finger_joint1", "panda_finger_joint2"],
- joint_opened_positions=np.array([0.05, 0.05]),
- joint_closed_positions=np.array([0.02, 0.02]),
- action_deltas=np.array([0.01, 0.01]),
- )
-
- franka = world.scene.add(
- SingleManipulator(
- prim_path=prim_path,
- name="franka",
- end_effector_prim_path="/panda/panda_hand",
- gripper=gripper,
- )
- )
- franka.initialize()
-
- for _ in range(10):
- world.step(render=False)
-
- cube = world.scene.add(
- DynamicCuboid(
- name="cube",
- position=np.array([0.3, 0.3, 0.3]),
- prim_path="/World/Cube",
- scale=np.array([0.0515, 0.0515, 0.0515]),
- size=1.0,
- color=np.array([0, 0, 1]),
- )
- )
-
- franka.gripper.set_default_state(franka.gripper.joint_opened_positions)
-
- controller = KinematicsSolver(franka)
-
- for _ in range(100):
- world.step(render=True)
-
- position = [0.35000014901161194, 0.28286391496658325, 0.3802158236503601]
- quaternion = [
- 0.9999999403953552,
- 1.9853024113558604e-08,
- 1.1971462754445383e-07,
- 8.532768447366834e-08,
- ]
- actions, succ = controller.compute_inverse_kinematics(
- target_position=np.asarray(position),
- target_orientation=np.asarray(quaternion),
- )
- if succ:
- franka.apply_action(actions)
- else:
- print("IK did not converge to a solution. No action is being taken.")
-
- for _ in range(1000):
- world.step(render=True)
-
- finally:
- app.close()
-
-
-if __name__ == "__main__":
- test_isaac_sim_headless_loads_usd()
diff --git a/ark/tests/test_isaac_usd_sim.py b/ark/tests/test_isaac_usd_sim.py
deleted file mode 100644
index 3ad005a..0000000
--- a/ark/tests/test_isaac_usd_sim.py
+++ /dev/null
@@ -1,75 +0,0 @@
-import os
-from pathlib import Path
-
-import numpy as np
-import pytest
-
-# Allow overriding USD path; defaults to the requested asset.
-USD_ASSET_PATH = Path(
- os.environ.get(
- "ARK_USD_PATH",
- "/home/refinath/ark/ark_franka/franka_panda/panda_with_gripper.usd",
- )
-)
-
-omni_kit = pytest.importorskip(
- "omni.isaac.kit",
- reason="Isaac Sim Python packages are required to start the simulator.",
-)
-
-
-@pytest.mark.skipif(
- not USD_ASSET_PATH.exists(),
- reason=(
- f"USD asset not found at {USD_ASSET_PATH}. "
- "Set ARK_USD_PATH to point at a local copy."
- ),
-)
-def test_isaac_sim_headless_loads_usd():
- """Boot Isaac Sim headless, reference the USD, and verify a prim appears on the stage."""
- from isaacsim import SimulationApp
-
- app = SimulationApp({"headless": False})
- try:
- from isaacsim.core.api import World
- from isaacsim.core.utils.stage import add_reference_to_stage, get_stage_units
- from isaacsim.core.prims import Articulation
-
- # Prepare scene
- world = World(stage_units_in_meters=1.0)
- world.scene.add_default_ground_plane()
-
- # Add robot
- prim_path = "/World/Robot"
- component_name = "franka"
- add_reference_to_stage(str(USD_ASSET_PATH), prim_path)
- robot = Articulation(prim_paths_expr=prim_path, name=component_name)
-
- robot.set_world_poses(positions=np.array([[0.0, 1.0, 0.0]]) / get_stage_units())
-
- world.reset()
-
- for i in range(4):
- print("running cycle: ", i)
- if i == 1:
- print("moving")
- # move the arm
- robot.set_joint_positions(
- [[-1.5, 0.0, 0.0, -1.5, 0.0, 1.5, 0.5, 0.04, 0.04]]
- )
- if i == 2:
- print("stopping")
- # reset the arm
- robot.set_joint_positions(
- [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]
- )
- for j in range(1000):
- # step the simulation, both rendering and physics
- world.step(render=True)
-
- finally:
- app.close()
-
-
-if __name__ == "__main__":
- test_isaac_sim_headless_loads_usd()
diff --git a/ark/tests/test_usd_verify.py b/ark/tests/test_usd_verify.py
deleted file mode 100644
index f278e5c..0000000
--- a/ark/tests/test_usd_verify.py
+++ /dev/null
@@ -1,48 +0,0 @@
-import os
-from pathlib import Path
-
-import pytest
-
-# Allow overriding the USD path via env var for different environments; default to the requested asset.
-USD_ASSET_PATH = Path(
- os.environ.get(
- "ARK_USD_PATH",
- "/home/refinath/ark/ark_franka/franka_panda/panda_with_gripper.usd",
- )
-)
-
-pxr = pytest.importorskip(
- "pxr",
- reason="USD validation requires the Pixar USD Python bindings (pxr).",
-)
-from pxr import Usd
-
-
-@pytest.mark.skipif(
- not USD_ASSET_PATH.exists(),
- reason=(
- f"USD asset not found at {USD_ASSET_PATH}. "
- "Set ARK_USD_PATH to point at a local copy."
- ),
-)
-def test_usd_file_opens_and_has_valid_prims():
- """Basic sanity check that the USD file loads and contains usable prims."""
- stage = Usd.Stage.Open(str(USD_ASSET_PATH))
- assert stage is not None, f"Failed to open USD stage at {USD_ASSET_PATH}"
-
- prims = list(stage.Traverse())
- for prim in prims:
- print(prim)
- assert prims, f"USD stage at {USD_ASSET_PATH} has no prims when traversed"
-
- any_valid_non_root = any(
- prim.IsValid() and str(prim.GetPath()) != "/" for prim in prims
- )
- assert any_valid_non_root or prims[0].IsValid(), "USD stage contains no valid prims"
-
- default_prim = stage.GetDefaultPrim()
- if default_prim:
- assert default_prim.IsValid(), "Default prim in USD stage is not valid"
-
-if __name__ == "__main__":
- test_usd_file_opens_and_has_valid_prims()
diff --git a/ark/tests/test_vector_env.py b/ark/tests/test_vector_env.py
deleted file mode 100644
index 76674e0..0000000
--- a/ark/tests/test_vector_env.py
+++ /dev/null
@@ -1,160 +0,0 @@
-import argparse
-import os
-from pathlib import Path
-
-import numpy as np
-
-from ark.env.ark_env import ArkEnv
-
-from ark.env.franka_env import FrankaEnv
-from ark.env.vector_env import make_vector_env, make_sim
-from ark.utils.communication_utils import (
- build_action_space,
- build_observation_space,
- get_channel_types,
- _dynamic_observation_unpacker,
-)
-from ark.utils.utils import ConfigPath
-
-
-class DemoEnv(ArkEnv):
- def __init__(self, channel_schema, global_config, namespace: str, sim: bool = True):
- super().__init__(
- environment_name="demo_env",
- channel_schema=channel_schema,
- global_config=global_config,
- namespace=namespace,
- sim=sim,
- )
-
- def reset(self, *, seed=None, options=None):
- obs, info = super().reset()
-
- print("\n[DEBUG] reset() obs shapes for env:", self.namespace)
- # self._print_obs_shapes(obs)
-
- return obs, info
-
- def _print_obs_shapes(self, obs):
- for k, v in obs.items():
- try:
- arr = np.asarray(v)
- print(f" {k:35s} shape={arr.shape} dtype={arr.dtype}")
- except Exception as e:
- print(f" {k:35s} ERROR converting: {e}")
-
- def step(self, action):
- obs, reward, terminated, truncated, info = super().step(action)
- return obs, reward, terminated, truncated, info
-
- def reset_objects(self):
- self.reset_component("cube")
- self.reset_component("target")
- self.reset_component("franka")
-
- @staticmethod
- def _create_reward_functions():
- return {}
-
- @staticmethod
- def _create_termination_conditions():
- return {}
-
-
-def run_franka_vector_demo(
- channel_schema: str, config_path: str, num_envs: int = 2, num_steps: int = 5
-) -> None:
- """
- Simple driver to instantiate a vectorized FrankaEnv batch via
- make_vector_env and print observations, rewards and done flags
- for a few steps to verify behavior.
-
- This is intended as an integration smoke-test rather than a pure
- unit test and assumes the Ark backend is available.
- """
- print(f"Creating {num_envs} FrankaEnv instances...")
-
- vec_env = make_vector_env(
- DemoEnv,
- num_envs=num_envs,
- channel_schema=channel_schema,
- global_config=config_path,
- sim=True,
- asynchronous=False,
- )
-
- obs, info = vec_env.reset()
- print("Initial obs:")
- for i in range(vec_env.num_envs):
- per_env_obs = {k: v[i] for k, v in obs.items()}
- print(f" env[{i}] initial obs:", per_env_obs)
-
- for step in range(num_steps):
- # Build distinct actions per env so we can trace them
- actions = vec_env.action_space.sample()
- if isinstance(actions, np.ndarray):
- # For Box spaces: ensure each env has a different first element
- for i in range(vec_env.num_envs):
- actions[i, ...] = i
-
- obs, rewards, terminated, truncated, infos = vec_env.step(actions)
-
- print(f"\nStep {step}:")
- for i in range(vec_env.num_envs):
- per_env_obs = {k: v[i] for k, v in obs.items()}
- print(f" env[{i}]:")
- print(" action :", actions[i])
- print(" obs :", per_env_obs)
- print(" reward :", rewards[i])
- print(" terminated:", terminated[i], "truncated:", truncated[i])
-
- print("Final reset")
- obs, info = vec_env.reset()
-
-
-def main() -> None:
- parser = argparse.ArgumentParser(
- description=(
- "Test for make_vector_env using FrankaEnv instances. "
- "Assumes Ark comms and simulation are running."
- )
- )
- parser.add_argument(
- "--channel-schema",
- type=str,
- default="ark_framework/ark/configs/franka_panda.yaml",
- help="Path to RL channel schema YAML (with observation_space/action_space).",
- )
- parser.add_argument(
- "--config-path",
- type=str,
- default="ark_diffusion_policies_on_franka/diffusion_policy/config/global_config.yaml",
- help="Path to Ark global_config.yaml.",
- )
- parser.add_argument(
- "--num-envs", type=int, default=1, help="Number of parallel Franka envs."
- )
- parser.add_argument(
- "--num-steps", type=int, default=5, help="Number of rollout steps to print."
- )
-
- args = parser.parse_args()
-
- channel_schema = os.path.abspath(args.channel_schema)
- config_path = os.path.abspath(args.config_path)
-
- if not Path(channel_schema).exists():
- raise FileNotFoundError(f"Channel schema not found: {channel_schema}")
- if not Path(config_path).exists():
- raise FileNotFoundError(f"Config path not found: {config_path}")
-
- run_franka_vector_demo(
- channel_schema=channel_schema,
- config_path=config_path,
- num_envs=args.num_envs,
- num_steps=args.num_steps,
- )
-
-
-if __name__ == "__main__":
- main()
diff --git a/ark/tools/__init__.py b/ark/tools/__init__.py
deleted file mode 100644
index 8b13789..0000000
--- a/ark/tools/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/ark/tools/ark_graph/__init__.py b/ark/tools/ark_graph/__init__.py
deleted file mode 100644
index 8b13789..0000000
--- a/ark/tools/ark_graph/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/ark/tools/ark_graph/ark_graph.py b/ark/tools/ark_graph/ark_graph.py
deleted file mode 100644
index 3dec6d2..0000000
--- a/ark/tools/ark_graph/ark_graph.py
+++ /dev/null
@@ -1,586 +0,0 @@
-import argparse
-import json
-import sys
-
-# --- Third-party/project-specific imports ---
-from ark.client.comm_handler.service import send_service_request
-from ark.client.comm_infrastructure.endpoint import EndPoint
-from ark.tools.log import log
-from ark.global_constants import *
-from arktypes import flag_t, network_info_t
-
-import matplotlib.pyplot as plt
-from graphviz import Digraph
-import typer
-
-from dataclasses import dataclass
-from pathlib import Path
-
-# Render the image with matplotlib
-import io
-from PIL import Image
-
-app = typer.Typer()
-
-DEFAULT_SERVICE_DECORATOR = "__DEFAULT_SERVICE"
-
-
-# ----------------------------------------------------------------------
-# DATA CLASSES
-# ----------------------------------------------------------------------
-@dataclass
-class BaseGraph:
- """
- Graph base class.
-
- This class serves as a base for other classes representing different
- types of diagrams like `Flowchart`, `ERDiagram`, etc.
-
- Attributes:
- title (str): The title of the diagram.
- script (str): The main script to create the diagram.
- """
-
- title: str
- script: str
-
- def save(self, path=None) -> None:
- """
- Save the diagram to a file.
-
- Args:
- path (Optional[Union[Path,str]]): The path to save the diagram. If not
- provided, the diagram will be saved in the current directory
- with the title as the filename.
-
- Raises:
- ValueError: If the file extension is not '.gv' or '.dot'.
- """
- if path is None:
- path = Path(f"./{self.title}.gv")
- if isinstance(path, str):
- path = Path(path)
-
- if path.suffix not in [".gv", ".dot"]:
- raise ValueError("File extension must be '.gv' or '.dot'")
-
- with open(path, "w") as file:
- file.write(self.script)
-
- def _build_script(self) -> None:
- """
- Internal helper to finalize the script content for the diagram.
- """
- script: str = f"---\ntitle: {self.title}\n---"
- script += self.script
- self.script = script
-
-
-class ServiceInfo:
- """
- Encapsulates service-related information for a node.
-
- Attributes:
- comms_type (str): The communications type (e.g., TCP, UDP, etc.).
- service_name (str): The name of the service.
- service_host (str): The hostname/IP of the service.
- service_port (int): The port used by the service.
- registry_host (str): The registry host for service discovery.
- registry_port (int): The registry port for service discovery.
- request_type (str): The request LCM type.
- response_type (str): The response LCM type.
- """
-
- def __init__(
- self,
- comms_type: str,
- service_name: str,
- service_host: str,
- service_port: int,
- registry_host: str,
- registry_port: int,
- request_type: str,
- response_type: str,
- ):
- self.comms_type = comms_type
- self.service_name = service_name
- self.service_host = service_host
- self.service_port = service_port
- self.registry_host = registry_host
- self.registry_port = registry_port
- self.request_type = request_type
- self.response_type = response_type
-
-
-class ListenerInfo:
- """
- Encapsulates listener-related information for a node.
-
- Attributes:
- comms_type (str): The communications type (e.g., LCM).
- channel_name (str): The name of the channel.
- channel_type (str): The message type on that channel.
- channel_status (str): The status (e.g., active/inactive).
- """
-
- def __init__(
- self, comms_type: str, channel_name: str, channel_type: str, channel_status: str
- ):
- self.comms_type = comms_type
- self.channel_name = channel_name
- self.channel_type = channel_type
- self.channel_status = channel_status
-
-
-class SubscriberInfo:
- """
- Encapsulates subscriber-related information for a node.
-
- Attributes:
- comms_type (str): The communications type (e.g., LCM).
- channel_name (str): The name of the channel.
- channel_type (str): The message type on that channel.
- channel_status (str): The status (e.g., active/inactive).
- """
-
- def __init__(
- self, comms_type: str, channel_name: str, channel_type: str, channel_status: str
- ):
- self.comms_type = comms_type
- self.channel_name = channel_name
- self.channel_type = channel_type
- self.channel_status = channel_status
-
-
-class PublisherInfo:
- """
- Encapsulates publisher-related information for a node.
-
- Attributes:
- comms_type (str): The communications type (e.g., LCM).
- channel_name (str): The name of the channel.
- channel_type (str): The message type on that channel.
- channel_status (str): The status (e.g., active/inactive).
- """
-
- def __init__(
- self, comms_type: str, channel_name: str, channel_type: str, channel_status: str
- ):
- self.comms_type = comms_type
- self.channel_name = channel_name
- self.channel_type = channel_type
- self.channel_status = channel_status
-
-
-class CommsInfo:
- """
- Encapsulates all communications (listeners/subscribers/publishers/services) for a node.
-
- Attributes:
- n_listeners (int): Number of listeners on this node.
- listeners (List[ListenerInfo]): A list of listener info objects.
- n_subscribers (int): Number of subscribers on this node.
- subscribers (List[SubscriberInfo]): A list of subscriber info objects.
- n_publishers (int): Number of publishers on this node.
- publishers (List[PublisherInfo]): A list of publisher info objects.
- n_services (int): Number of services on this node.
- services (List[ServiceInfo]): A list of service info objects.
- """
-
- def __init__(
- self,
- n_listeners: int,
- listeners: list,
- n_subscribers: int,
- subscribers: list,
- n_publishers: int,
- publishers: list,
- n_services: int,
- services: list,
- ):
- self.n_listeners = n_listeners
- self.listeners = listeners
- self.n_subscribers = n_subscribers
- self.subscribers = subscribers
- self.n_publishers = n_publishers
- self.publishers = publishers
- self.n_services = n_services
- self.services = services
-
-
-class NodeInfo:
- """
- Encapsulates information about a single node in the network.
-
- Attributes:
- node_name (str): The name of the node (e.g., "Camera").
- node_id (str): A unique identifier for the node.
- comms (CommsInfo): Communication details for the node.
- """
-
- def __init__(self, node_name: str, node_id: str, comms: CommsInfo):
- self.name = node_name
- self.node_id = node_id
- self.comms = comms
-
-
-class NetworkInfo:
- """
- Encapsulates network-level information for multiple nodes.
-
- Attributes:
- num_nodes (int): The number of nodes in the network.
- nodes (List[NodeInfo]): A list of NodeInfo objects.
- """
-
- def __init__(self, n_nodes: int, nodes: list):
- self.num_nodes = n_nodes
- self.nodes = nodes
-
-
-# ----------------------------------------------------------------------
-# DECODING & HELPER FUNCTIONS
-# ----------------------------------------------------------------------
-def decode_network_info(lcm_message) -> NetworkInfo:
- """
- Converts an LCM network info message into a NetworkInfo object.
-
- Args:
- lcm_message (network_info_t): The LCM message containing network information.
-
- Returns:
- NetworkInfo: A NetworkInfo object with detailed node and comms information.
- """
- return NetworkInfo(
- n_nodes=lcm_message.n_nodes,
- nodes=[
- NodeInfo(
- node_name=node.node_name,
- node_id=node.node_id,
- comms=CommsInfo(
- n_listeners=node.comms.n_listeners,
- listeners=[
- ListenerInfo(
- comms_type=listener.comms_type,
- channel_name=listener.channel_name,
- channel_type=listener.channel_type,
- channel_status=listener.channel_status,
- )
- for listener in node.comms.listeners
- ],
- n_subscribers=node.comms.n_subscribers,
- subscribers=[
- SubscriberInfo(
- comms_type=subscriber.comms_type,
- channel_name=subscriber.channel_name,
- channel_type=subscriber.channel_type,
- channel_status=subscriber.channel_status,
- )
- for subscriber in node.comms.subscribers
- ],
- n_publishers=node.comms.n_publishers,
- publishers=[
- PublisherInfo(
- comms_type=publisher.comms_type,
- channel_name=publisher.channel_name,
- channel_type=publisher.channel_type,
- channel_status=publisher.channel_status,
- )
- for publisher in node.comms.publishers
- ],
- n_services=node.comms.n_services,
- services=[
- ServiceInfo(
- comms_type=service.comms_type,
- service_name=service.service_name,
- service_host=service.service_host,
- service_port=service.service_port,
- registry_host=service.registry_host,
- registry_port=service.registry_port,
- request_type=service.request_type,
- response_type=service.response_type,
- )
- for service in node.comms.services
- ],
- ),
- )
- for node in lcm_message.nodes
- ],
- )
-
-
-def network_info_lcm_to_dict(lcm_message) -> dict:
- """
- Converts an LCM network info message into a Python dictionary,
- allowing easy serialization or manipulation.
-
- Args:
- lcm_message (network_info_t): The LCM message containing network information.
-
- Returns:
- dict: A dictionary representation of the network info.
- """
- network_info_obj = decode_network_info(lcm_message)
- return json.loads(json.dumps(network_info_obj, default=lambda o: o.__dict__))
-
-
-# ----------------------------------------------------------------------
-# GRAPHVIZ VISUALIZATION
-# ----------------------------------------------------------------------
-def graph_viz_plot(data: dict):
- """
- Generate a GraphViz diagram from the given network data and display it using Matplotlib.
-
- Args:
- data (dict): A dictionary containing the network information.
- Typically the output of `network_info_lcm_to_dict(...)`.
-
- Returns:
- Image: The generated PIL Image containing the graph visualisation.
-
- Notes:
- Service nodes are drawn with single borders on the top and bottom and
- double borders on the left and right to resemble ``|| service ||``.
- """
- dot = Digraph(format="png")
- dot.attr("graph", fontname="Helvetica")
- dot.attr("node", fontname="Helvetica")
- dot.attr("edge", fontname="Helvetica")
-
- channel_id_map = {}
- id_counter = 1
-
- def get_channel_id(channel_name: str) -> str:
- nonlocal id_counter
- if channel_name not in channel_id_map:
- channel_id_map[channel_name] = f"ch_{id_counter}"
- id_counter += 1
- return channel_id_map[channel_name]
-
- # Build the graph
- for node in data["nodes"]:
- node_id = node["node_id"]
- node_name = node["name"]
-
- dot.node(node_id, node_name, shape="box", style="filled", fillcolor="lightblue")
-
- publishers = [pub["channel_name"] for pub in node["comms"]["publishers"]]
- subscribers = [sub["channel_name"] for sub in node["comms"]["subscribers"]]
- listeners = [lis["channel_name"] for lis in node["comms"]["listeners"]]
- services = [ser["service_name"] for ser in node["comms"]["services"]]
-
- for pub in publishers:
- pub_id = get_channel_id(pub)
- dot.node(
- pub_id, pub, shape="box", style="rounded,filled", fillcolor="white"
- )
- dot.edge(node_id, pub_id)
-
- for sub in subscribers:
- sub_id = get_channel_id(sub)
- dot.node(
- sub_id, sub, shape="box", style="rounded,filled", fillcolor="white"
- )
- dot.edge(sub_id, node_id)
-
- for lis in listeners:
- lis_id = get_channel_id(lis)
- dot.node(
- lis_id, lis, shape="box", style="rounded,filled", fillcolor="white"
- )
- dot.edge(lis_id, node_id)
-
- for ser in services:
- if ser.startswith(DEFAULT_SERVICE_DECORATOR):
- continue
- ser_id = get_channel_id(ser)
- service_label = (
- "<"
- ""
- ">"
- )
- dot.node(ser_id, label=service_label, shape="plaintext")
- dot.edge(node_id, ser_id)
-
- graph_image = dot.pipe()
- image_stream = io.BytesIO(graph_image)
- image = Image.open(image_stream)
-
- return image
-
-
-# ----------------------------------------------------------------------
-# MAIN CLASS
-# ----------------------------------------------------------------------
-class ArkGraph(EndPoint):
- """Endpoint that retrieves network info and renders a GraphViz diagram.
-
- The diagram can either be displayed immediately or saved for later use.
-
- Attributes:
- registry_host (str): The host of the registry server.
- registry_port (int): The port of the registry server.
- lcm_network_bounces (int): LCM network bounces for deeper network queries.
- """
-
- def __init__(
- self,
- registry_host: str = "127.0.0.1",
- registry_port: int = 1234,
- lcm_network_bounces: int = 1,
- *,
- display: bool = True,
- ):
- """
- Initializes the ArkGraph endpoint with registry configuration.
-
- Args:
- registry_host (str): The host address for the registry server.
- registry_port (int): The port for the registry server.
- lcm_network_bounces (int): LCM network bounces for deeper network queries.
- display (bool, optional): Whether to immediately display the diagram.
- If ``False``, the image can still be saved via :meth:`save_image`.
- """
- config = {
- "network": {
- "registry_host": registry_host,
- "registry_port": registry_port,
- "lcm_network_bounces": lcm_network_bounces,
- }
- }
- super().__init__(config)
-
- # Query the registry for network information
- req = flag_t()
- response_lcm = send_service_request(
- self.registry_host,
- self.registry_port,
- f"{DEFAULT_SERVICE_DECORATOR}/GetNetworkInfo",
- req,
- network_info_t,
- )
-
- # Convert LCM response to a dictionary
- data = network_info_lcm_to_dict(response_lcm)
-
- # Generate the GraphViz diagram
- self.plot_image = graph_viz_plot(data)
- if display:
- self.display_image(self.plot_image)
-
- def save_image(self, file_path: str | Path) -> None:
- """Save the generated diagram image to ``file_path``.
-
- Only ``.png`` files are supported.
- """
- if isinstance(file_path, str):
- file_path = Path(file_path)
-
- if file_path.suffix.lower() != ".png":
- raise ValueError("File extension must be '.png'")
-
- file_path.parent.mkdir(parents=True, exist_ok=True)
- self.plot_image.save(file_path)
-
- @staticmethod
- def get_cli_doc() -> str:
- """
- Return CLI help documentation.
- """
- return __doc__
-
- def display_image(self, plot_image):
- """
- Display the GraphViz diagram image using Matplotlib.
-
- Args:
- plot_image (Image.Image): The PIL Image containing the diagram.
- """
- plt.imshow(plot_image)
- plt.axis("off")
- plt.show()
-
-
-# ----------------------------------------------------------------------
-# COMMAND-LINE INTERFACE
-# ----------------------------------------------------------------------
-def parse_args():
- """
- Parse command-line arguments for running ArkGraph as a script.
-
- Returns:
- argparse.Namespace: The parsed arguments with `registry_host` and `registry_port`.
- """
- parser = argparse.ArgumentParser(
- description="ArkGraph"
- )
- parser.add_argument(
- "--registry_host",
- type=str,
- default="127.0.0.1",
- help="The host address for the registry server.",
- )
- parser.add_argument(
- "--registry_port",
- type=int,
- default=1234,
- help="The port for the registry server.",
- )
- return parser.parse_args()
-
-
-# ----------------------------------------------------------------------
-# MAIN
-# ----------------------------------------------------------------------
-@app.command()
-def start(
- registry_host: str = typer.Option(
- "127.0.0.1", "--host", help="The host address for the registry server."
- ),
- registry_port: int = typer.Option(
- 1234, "--port", help="The port for the registry server."
- ),
-):
- """Starts the graph with specified host and port."""
- server = ArkGraph(registry_host=registry_host, registry_port=registry_port)
-
-
-@app.command()
-def save(
- file_path: str,
- registry_host: str = typer.Option(
- "127.0.0.1", "--host", help="The host address for the registry server."
- ),
- registry_port: int = typer.Option(
- 1234, "--port", help="The port for the registry server."
- ),
-):
- """Save the graph image to ``FILE_PATH`` without displaying it.
-
- ``FILE_PATH`` must end with ``.png``.
- """
- server = ArkGraph(
- registry_host=registry_host, registry_port=registry_port, display=False
- )
- try:
- server.save_image(file_path)
- except ValueError as exc:
- log.error(str(exc))
- raise typer.Exit(code=1)
- else:
- log.ok(f"Graph saved to {file_path}")
-
-
-def main():
- """Entry point for the CLI."""
- app() # Initializes the Typer CLI
-
-
-if __name__ == "__main__":
- main()
diff --git a/ark/tools/data_logging/__init__.py b/ark/tools/data_logging/__init__.py
deleted file mode 100644
index 8b13789..0000000
--- a/ark/tools/data_logging/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/ark/tools/data_logging/lcm_logger.py b/ark/tools/data_logging/lcm_logger.py
deleted file mode 100644
index e1baac8..0000000
--- a/ark/tools/data_logging/lcm_logger.py
+++ /dev/null
@@ -1,97 +0,0 @@
-from typing import Optional, Any
-
-from ark.client.comm_infrastructure.base_node import BaseNode, main
-from ark.tools.log import log
-from arktypes import string_t, flag_t, status_t
-import subprocess
-
-
-class LoggerNode(BaseNode):
-
- def __init__(self, name: str, config: Optional[dict[str, Any]] = None):
- """
- @brief Construct the logger node and register services.
-
- @param name Node name.
- @param config Optional configuration dictionary (unused).
- """
- super().__init__(name)
- self.proc: Optional[subprocess.Popen] = None
-
- self.create_service(
- "logger/start", string_t, status_t, self.start_logging
- )
- self.create_service(
- "logger/stop", flag_t, status_t, self.stop_logging
- )
-
- def start_logging(self, channel: str, msg: string_t) -> status_t:
- """
- @brief Start an LCM logging session if none is running.
-
- @param channel Service channel name (unused).
- @param msg Output file prefix/path (`string_t.data`) for `lcm-logger`.
-
- @return `status_t`
- """
- out = status_t()
-
- if self.proc is not None:
- log.warning(
- "lcm-logger already running; refusing to start a second session."
- )
- out.success = False
- out.message = "lcm-logger already running"
- return out
-
- try:
- log.info("Started logging")
- self.proc = subprocess.Popen(
- ["lcm-logger", msg.data],
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE,
- start_new_session=True,
- )
- out.success = True
- out.message = "lcm-logger started successfully"
- except Exception as e:
- log.error("Failed to start lcm-logger: %s", e)
- self.proc = None
- out.success = False
- out.message = str(e)
-
- return out
-
- def stop_logging(self, channel: str, msg: flag_t) -> status_t:
- """
- @brief Stop the current LCM logging session, if running.
-
- @param channel Service channel name.
- @param msg Input `flag_t` (unused).
- """
- out = status_t()
-
- if self.proc is None:
- log.warning("No lcm-logger session is running.")
- out.success = False
- out.message = "No lcm-logger session is running."
- return out
-
- try:
- self.proc.kill()
- self.proc.wait(timeout=5)
- log.info("Stopped logging")
- del self.proc
- self.proc = None
- out.success = True
- out.message = "lcm-logger stopped successfully"
- except Exception as e:
- log.error("Failed to stop lcm-logger: %s", e)
- out.success = False
- out.message = str(e)
-
- return out
-
-
-if __name__ == "__main__":
- main(LoggerNode, "Logger")
diff --git a/ark/tools/data_logging/lcm_to_csv.py b/ark/tools/data_logging/lcm_to_csv.py
deleted file mode 100644
index 45675fc..0000000
--- a/ark/tools/data_logging/lcm_to_csv.py
+++ /dev/null
@@ -1,133 +0,0 @@
-import struct
-import pandas as pd
-import argparse
-from arktypes import * # Assuming these are your generated LCM message types
-
-LCM_SYNC_WORD = 0xEDA1DA01 # The sync word for LCM log events
-
-
-class LCMLogParser:
- def __init__(self, input_filename, channel_config):
- """
- Initialize the LCM log parser.
-
- :param input_filename: Path to the LCM log file.
- :param channel_config: A list of tuples (channel_name, lcm_message_type) for decoding.
- """
- self.input_filename = input_filename
- self.channel_config = {
- channel: message_type for channel, message_type in channel_config
- }
- self.df = None
-
- def parse(self):
- """
- Parse the LCM log file into a Pandas DataFrame.
- :return: Pandas DataFrame containing the parsed and decoded data.
- """
- events = []
-
- with open(self.input_filename, "rb") as log_file:
- event_number = 0
-
- while True:
- header = log_file.read(28)
- if len(header) < 28:
- break # End of file
-
- (
- sync_word,
- event_number_upper,
- event_number_lower,
- timestamp_upper,
- timestamp_lower,
- channel_length,
- data_length,
- ) = struct.unpack(">I2I2I2I", header)
-
- if sync_word != LCM_SYNC_WORD:
- raise ValueError(
- f"Sync word mismatch. Expected {hex(LCM_SYNC_WORD)} but got {hex(sync_word)}."
- )
-
- event_number = (event_number_upper << 32) | event_number_lower
- timestamp = (timestamp_upper << 32) | timestamp_lower
- channel_name = log_file.read(channel_length).decode("utf-8")
- message_data = log_file.read(data_length)
-
- decoded_message_json = None
- if channel_name in self.channel_config:
- try:
- message_type = self.channel_config[channel_name]
- decoded_message = message_type.decode(message_data)
- decoded_message_json = self.decode_to_json(decoded_message)
- except Exception as e:
- print(f"Error decoding {channel_name} data: {e}")
- else:
- print(f"Unknown channel {channel_name} (data not decoded)")
-
- events.append(
- {
- "Event Number": event_number,
- "Timestamp": timestamp,
- "Channel": channel_name,
- "Data Length": data_length,
- "Message Data": message_data.hex(),
- "Decoded Message": decoded_message_json,
- }
- )
-
- return pd.DataFrame(events)
-
- @staticmethod
- def decode_to_json(decoded_message):
- """
- Convert an LCM decoded message to a JSON-like dictionary.
- :param decoded_message: The decoded LCM message.
- :return: A dictionary representing the message.
- """
- message_dict = {}
- for field in decoded_message.__slots__:
- value = getattr(decoded_message, field)
- if isinstance(value, (list, tuple)):
- message_dict[field] = [
- LCMLogParser.decode_to_json(v) if hasattr(v, "__slots__") else v
- for v in value
- ]
- elif hasattr(value, "__slots__"):
- message_dict[field] = LCMLogParser.decode_to_json(value)
- else:
- message_dict[field] = value
- return message_dict
-
- def save_to_csv(self, output_filepath):
- """
- Save the parsed LCM log data to a CSV file.
- :param output_filepath: Path to save the CSV file.
- """
- if self.df is None:
- self.df = self.parse()
- self.df.to_csv(output_filepath, index=False)
- print(f"Data has been saved to {output_filepath}")
-
- def get_dataframe(self):
- if self.df is None:
- self.df = self.parse()
- return self.df
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser(
- description="Parse LCM log files and output to CSV."
- )
- parser.add_argument("input_filepath", help="Path to the input LCM log file")
- parser.add_argument("output_filepath", help="Path to the output CSV file")
- args = parser.parse_args()
-
- channel_config = [
- ("viper/joint_states", joint_state_t),
- # ("transforms", ee_pos_t),
- ]
-
- parser = LCMLogParser(args.input_filepath, channel_config)
- parser.save_to_csv(args.output_filepath)
diff --git a/ark/tools/juypter_tools/__init__.py b/ark/tools/juypter_tools/__init__.py
deleted file mode 100644
index 8b13789..0000000
--- a/ark/tools/juypter_tools/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/ark/tools/juypter_tools/juypter_tools.py b/ark/tools/juypter_tools/juypter_tools.py
deleted file mode 100644
index df23b5f..0000000
--- a/ark/tools/juypter_tools/juypter_tools.py
+++ /dev/null
@@ -1,66 +0,0 @@
-import numpy as np
-import cv2
-from IPython.display import clear_output, display
-import matplotlib.pyplot as plt
-from arktypes import image_t
-
-
-# Define num_channels for different pixel formats
-num_channels = {
- image_t.PIXEL_FORMAT_RGB: 3, # RGB has 3 channels
- image_t.PIXEL_FORMAT_RGBA: 4, # RGBA has 4 channels
- image_t.PIXEL_FORMAT_GRAY: 1, # Grayscale has 1 channel
-}
-
-
-def process_and_display_image(image_msg):
- # Decode the image data
- img_data = np.frombuffer(image_msg.data, dtype=np.uint8)
-
- # Handle compression
- if image_msg.compression_method in (
- image_t.COMPRESSION_METHOD_JPEG,
- image_t.COMPRESSION_METHOD_PNG,
- ):
- # Decompress image
- img = cv2.imdecode(img_data, cv2.IMREAD_COLOR)
- if img is None:
- print("Failed to decompress image")
- return
- elif image_msg.compression_method == image_t.COMPRESSION_METHOD_NOT_COMPRESSED:
- # Determine the number of channels based on pixel_format
- try:
- nchannels = num_channels[image_msg.pixel_format]
- except KeyError:
- print("Unsupported pixel format")
- return
-
- # Reshape the data to the original image dimensions
- try:
- img = img_data.reshape((image_msg.height, image_msg.width, nchannels))
- except ValueError as e:
- print(f"Error reshaping image data: {e}")
- return
-
- # Handle pixel format conversion if necessary
- if image_msg.pixel_format == image_t.PIXEL_FORMAT_RGB:
- # Convert RGB to BGR for OpenCV
- img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
- elif image_msg.pixel_format == image_t.PIXEL_FORMAT_RGBA:
- # Convert RGBA to BGRA for OpenCV
- img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA)
- elif image_msg.pixel_format == image_t.PIXEL_FORMAT_GRAY:
- # No conversion needed for grayscale
- pass
- # For BGR and BGRA, no conversion is needed as OpenCV uses BGR format
- else:
- print("Unsupported compression method")
- return
-
- # Now display the image dynamically in Jupyter
- clear_output(wait=True) # Clear previous image
- plt.imshow(
- cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
- ) # Convert BGR to RGB for proper display
- plt.axis("off") # Hide axes
- plt.show()
diff --git a/ark/tools/language_input/text_repeater.py b/ark/tools/language_input/text_repeater.py
deleted file mode 100644
index 648df31..0000000
--- a/ark/tools/language_input/text_repeater.py
+++ /dev/null
@@ -1,73 +0,0 @@
-from ark.client.comm_infrastructure.base_node import BaseNode, main
-from arktypes import string_t
-from pathlib import Path
-import argparse
-
-from ark.utils.utils import ConfigPath
-
-from ark.tools.log import log
-
-
-class TextRepeaterNode(BaseNode):
- def __init__(self, node_name: str, global_config: str):
- super().__init__(node_name, global_config)
- # Required keys in the global config
- global_config_path = ConfigPath(global_config)
- self.config = global_config_path.read_yaml()
- text_path = self.config.get("text_path", "")
- text = self.config.get("text", "")
- channel = self.config.get("channel", "user_input")
- freq = self.config.get("freq", 1)
-
- # Exactly one of 'text' or 'text_path' should be provided; both default to "".
- if text and text_path:
- raise ValueError("Pass only one of 'text' or 'text_path' (not both).")
-
- if text_path:
- p = Path(text_path)
- if not p.is_file():
- raise FileNotFoundError(
- f"'text_path' does not exist or is not a file: {p}"
- )
- self.text = p.read_text()
- else:
- self.text = text
-
- if len(self.text) == 0:
- raise ValueError("Text is empty.")
-
- self.text_msg = string_t()
- self.text_msg.data = self.text
-
- # Publisher on requested channel
- self.pub = self.create_publisher(channel, string_t)
-
- # Stepper that publishes at the requested frequency
- self.publish_text = lambda: self.pub.publish(self.text_msg)
- self.create_stepper(freq, self.publish_text)
-
-
-def get_args():
- parser = argparse.ArgumentParser(
- description="Publishes a text file as a string at a given frequency, using a global config.",
- )
- parser.add_argument(
- "--node-name",
- type=str,
- required=True,
- help="Name of this node.",
- )
- parser.add_argument(
- "--config",
- type=str,
- required=True,
- help="Path to global config file (.json or .yaml/.yml) containing: text_path, channel, freq.",
- )
- args = parser.parse_args()
- return args.node_name, args.config
-
-
-if __name__ == "__main__":
- node_name, global_config = get_args()
- # Pass exactly (node_name, global_config) to match TextRepeaterNode.__init__
- main(TextRepeaterNode, node_name, global_config)
diff --git a/ark/tools/launcher.py b/ark/tools/launcher.py
deleted file mode 100644
index 67d6820..0000000
--- a/ark/tools/launcher.py
+++ /dev/null
@@ -1,332 +0,0 @@
-import sys
-import yaml
-import os
-import subprocess
-import enum
-import time
-import importlib.util
-from pathlib import Path
-from dataclasses import dataclass
-from typing import Optional, Dict, Set
-from ark.tools.log import log
-import typer
-
-app = typer.Typer()
-
-
-class TargetType(enum.Enum):
- """
- An enumeration to specify the type of target that will be executed:
-
- 1. SCRIPT: A Python file (e.g., my_script.py).
- 2. MODULE: A Python module (e.g., my_package.my_module).
- """
-
- SCRIPT = 0
- MODULE = 1
-
-
-@dataclass
-class NodeProcessInfo:
- """
- A data class that holds information about a running node process:
-
- Attributes:
- node_name (str): The unique name of the node.
- process (subprocess.Popen): The Popen object representing the node's process.
- log_path (Optional[Path]): The optional path to the log file (None if using terminal).
- log_file (Optional[object]): The optional file handle for the log file.
- """
-
- node_name: str
- process: subprocess.Popen
- log_path: Optional[Path]
- log_file: Optional[object]
-
-
-class NodeExecutor:
- """
- A class responsible for configuring and launching a node as described by a YAML configuration.
-
- Usage:
- node_exec = NodeExecutor("my_node", config_dictionary)
- node_info = node_exec.run()
-
- This will spawn the specified target (script or module) in a subprocess,
- optionally logging its output to a file or displaying it in the terminal.
- """
-
- def __init__(self, node_name: str, config: Dict):
- """
- Initialize the NodeExecutor with a node name and its corresponding configuration.
-
- Args:
- node_name (str): The unique name of this node.
- config (dict): The YAML-derived configuration dictionary for this node.
- """
- self.name = node_name
- self.config = config
-
- def get_target(self) -> str:
- """
- Retrieve the 'target' value from the configuration.
-
- Returns:
- str: The 'target' string (file path or module name).
-
- Raises:
- ValueError: If the 'target' key is missing from the configuration.
- """
- try:
- return self.config["target"]
- except KeyError:
- raise ValueError(f"You must provide a target for the node '{self.name}'")
-
- def get_target_type(self) -> TargetType:
- """
- Determine whether the target is a Python script or a Python module.
-
- Returns:
- TargetType: An enum indicating SCRIPT or MODULE.
-
- Raises:
- ValueError: If the target is neither a recognized file nor an importable module.
- """
- target = self.get_target()
-
- # Check if it's a file
- if os.path.isfile(target):
- return TargetType.SCRIPT
-
- # Otherwise, check if it is an importable module
- spec = importlib.util.find_spec(target)
- if spec is not None:
- return TargetType.MODULE
-
- raise ValueError(
- f"Target '{target}' is neither a valid script file nor an importable module."
- )
-
- def get_command(self) -> list:
- """
- Build the full command (list of strings) for the subprocess.
- This includes the Python executable and the target (script or module).
-
- Returns:
- list: The full command list for starting the subprocess.
- """
- python_cmd = [
- sys.executable,
- "-u", # unbuffered output
- ]
-
- target_type = self.get_target_type()
-
- if target_type == TargetType.SCRIPT:
- # e.g. python -u my_script.py
- cmd = python_cmd + [self.get_target()]
- else:
- # e.g. python -u -m my_module
- cmd = python_cmd + ["-m", self.get_target()]
-
- return cmd
-
- def setup_display(self):
- """
- Set up how the node's output (stdout, stderr) will be handled:
-
- 'terminal' -> output goes to the parent's stdout/stderr.
- 'logfile' -> output is directed to a dedicated log file under .noahrlogs/.
- Any other string -> considered a file path, output is directed to that file.
-
- Returns:
- tuple: (stdout, stderr, log_path, log_file)
- """
- display = self.config.get("display", "logfile") # default is 'logfile'
- log_file, log_path = None, None
-
- if display == "terminal":
- # Inherit stdout/stderr from parent
- stdout, stderr = None, None
- elif display == "logfile":
- # Store logs in the current working directory under .noahrlogs/
- logs_dir = Path.cwd() / ".arklogs"
- logs_dir.mkdir(parents=True, exist_ok=True)
-
- node_logs_dir = logs_dir / self.name
- node_logs_dir.mkdir(parents=True, exist_ok=True)
-
- stamp = time.time_ns()
- log_path = node_logs_dir / f"{stamp}.log"
-
- log_file = open(log_path, "w", buffering=1)
- stdout, stderr = log_file, subprocess.STDOUT
- else:
- # Treat 'display' as a path to a file
- log_path = Path(display)
- log_file = open(log_path, "w", buffering=1)
- stdout, stderr = log_file, subprocess.STDOUT
-
- return stdout, stderr, log_path, log_file
-
- def run(self) -> NodeProcessInfo:
- """
- Launch the node process according to the configuration.
-
- Returns:
- NodeProcessInfo: An object containing process details (PID, log paths, etc.)
- """
- cmd = self.get_command()
- stdout, stderr, log_path, log_file = self.setup_display()
-
- process = subprocess.Popen(cmd, stdout=stdout, stderr=stderr)
-
- return NodeProcessInfo(
- node_name=self.name, process=process, log_path=log_path, log_file=log_file
- )
-
-
-def load_launch_file(launch_path: Path, included_files: Set[Path]) -> Dict:
- """
- Recursively load a YAML launch file and return a dictionary of node configurations.
- Supports the 'include' key, allowing composition of multiple YAML files.
-
- Args:
- launch_path (Path): The filesystem path to the launch file.
- included_files (Set[Path]): A set of files that have already been included (to prevent loops).
-
- Returns:
- Dict: A dictionary mapping node names to their configuration dictionaries.
-
- Raises:
- ValueError: If there's a circular include or duplicate node name.
- """
- launch_path = launch_path.resolve()
-
- if launch_path in included_files:
- raise ValueError(f"Circular include detected for file: {launch_path}")
- included_files.add(launch_path)
-
- with open(launch_path, "r") as f:
- launch_content = yaml.load(f, Loader=yaml.SafeLoader)
-
- nodes = {}
-
- # The YAML should be a dict of node_name -> node_config
- for key, config in launch_content.items():
- if "include" in config:
- include_path = Path(config["include"])
- if not include_path.is_absolute():
- include_path = launch_path.parent / include_path
-
- included_nodes = load_launch_file(include_path, included_files)
-
- for included_node_name in included_nodes:
- if included_node_name in nodes:
- raise ValueError(
- f"Duplicate node name '{included_node_name}' found "
- f"when including '{include_path}'"
- )
- nodes[included_node_name] = included_nodes[included_node_name]
- else:
- if key in nodes:
- raise ValueError(
- f"Duplicate node name '{key}' found in '{launch_path}'"
- )
- nodes[key] = config
-
- return nodes
-
-
-def ark_launch(launch_file: str):
- """
- Main entry point for the launch script.
-
- Usage:
- python launch_script.py
-
- Steps:
- 1. Parse the provided launch file path from sys.argv.
- 2. Recursively load and merge all node configurations (including nested includes).
- 3. Create and start each node process using NodeExecutor.
- 4. Monitor the running processes, logging any failures or normal terminations.
- 5. Shut down gracefully on user interrupt (Ctrl+C).
- """
- launch_path = Path(launch_file)
-
- included_files = set()
- nodes_config = load_launch_file(launch_path, included_files)
-
- processes = []
- for node_name, config in nodes_config.items():
- executor = NodeExecutor(node_name, config)
- node_info = executor.run()
- processes.append(node_info)
-
- log.ok(f"Started node '{node_name}' with PID {node_info.process.pid}")
- if node_info.log_path:
- log.ok(f"Logs for '{node_name}' are being written to {node_info.log_path}")
-
- try:
- while processes:
- for node_info in processes[:]:
- retcode = node_info.process.poll()
- if retcode is not None:
- if retcode == 0:
- log.ok(f"Node '{node_info.node_name}' has exited successfully.")
- else:
- log.error(
- f"Node '{node_info.node_name}' exited with return code {retcode}."
- )
- if node_info.log_path:
- log.error(f"Check logs at {node_info.log_path}")
-
- if node_info.log_file:
- node_info.log_file.close()
-
- processes.remove(node_info)
-
- time.sleep(1)
- except KeyboardInterrupt:
- log.warn("KeyboardInterrupt received. Terminating all nodes.")
- for node_info in processes:
- node_info.process.terminate()
- for node_info in processes:
- node_info.process.wait()
- finally:
- for node_info in processes:
- if node_info.log_file:
- node_info.log_file.close()
- log.ok("All nodes have been terminated.")
-
-
-@app.command()
-def start(launch_file: str):
- """
- Start the launcher with the specified launch file.
-
- Args:
- launch_file (str): The path to the launch file.
- """
- ark_launch(launch_file)
-
-
-def main():
- app()
-
-
-if __name__ == "__main__":
- main()
-
-# TRIVIA: Side Oiled Slideway Launching or Chrstening are ways of launching an ARK(ship)
-
-# ====================================================================================================
-# Example Usage of the Launcher YAML Configuration
-# ====================================================================================================
-
-# talker:
-# target: /nfs/rlteam/sarthakdas/arkframework/examples/basics/talker_listener/talker.py
-# display: terminal
-# listener:
-# target: /nfs/rlteam/sarthakdas/arkframework/examples/basics/talker_listener/listener.py
-# display: logfile
diff --git a/ark/tools/log.py b/ark/tools/log.py
deleted file mode 100644
index 5b1408c..0000000
--- a/ark/tools/log.py
+++ /dev/null
@@ -1,182 +0,0 @@
-import logging
-from datetime import datetime
-from typing import Optional
-
-
-# Define custom colors for log levels using the bcolors class
-class bcolors:
- """! This class contains color codes to be used in the log messages to provide
- visual cues for different log levels."""
-
- HEADER = "\033[95m" # Purple
- OKBLUE = "\033[94m" # Blue
- OKCYAN = "\033[96m" # Cyan
- OKGREEN = "\033[92m" # Green
- WARNING = "\033[93m" # Yellow
- FAIL = "\033[91m" # Red
- ENDC = "\033[0m" # Reset color
- BOLD = "\033[1m" # Bold
- UNDERLINE = "\033[4m" # Underline
- WHITE = "\033[97m" # White
- GREY = "\033[90m" # Grey
-
-
-# Define a custom log level: OK (between INFO and WARNING)
-OK_LEVEL_NUM = 25
-logging.addLevelName(OK_LEVEL_NUM, "OK")
-
-
-def ok(self: logging.Logger, message: str, *args: object, **kwargs: object) -> None:
- """! Custom log method for the OK log level.
-
- This method adds a custom logging level between INFO and WARNING. It is used
- to log messages that indicate normal operations, but with higher importance
- than INFO.
-
- @param message The log message.
- @param args Additional arguments for formatting the message.
- @param kwargs Additional keyword arguments.
- """
- if self.isEnabledFor(OK_LEVEL_NUM):
- self._log(OK_LEVEL_NUM, message, args, **kwargs)
-
-
-logging.Logger.ok = ok # Add the `ok` method to the Logger class
-
-
-def apply_panda_style(text: str) -> str:
- styled_text = ""
- colors = [bcolors.WHITE, bcolors.GREY]
- for i, char in enumerate(text):
- styled_text += colors[i % 2] + char
- return styled_text + bcolors.ENDC
-
-
-def log_panda(
- self: logging.Logger, message: str, *args: object, **kwargs: object
-) -> None:
- if self.isEnabledFor(logging.INFO):
- styled_message = apply_panda_style(message)
- self._log(logging.INFO, styled_message, args, **kwargs)
-
-
-logging.Logger.panda = log_panda # Add `log_panda` method to Logger class
-
-
-class CustomFormatter(logging.Formatter):
- """! CustomFormatter for applying color coding to log levels and including timestamp."""
-
- COLORS: dict[str, str] = {
- "DEBUG": bcolors.OKBLUE, # Blue for DEBUG level
- "INFO": bcolors.OKCYAN, # Cyan for INFO level
- "OK": bcolors.OKGREEN + bcolors.BOLD, # Bold Green for OK level
- "WARNING": bcolors.WARNING, # Yellow for WARNING level
- "ERROR": bcolors.FAIL, # Red for ERROR level
- "CRITICAL": bcolors.FAIL + bcolors.BOLD, # Bold Red for CRITICAL level
- "NOTSET": bcolors.ENDC, # Reset for NOTSET level
- }
-
- def __init__(
- self,
- fmt: str = "%(levelname)s [%(asctime)s] - %(message)s",
- datefmt: str = "%H:%M:%S.%f",
- ) -> None:
- """! Initializes the CustomFormatter with the specified format for the log messages.
-
- @param fmt The format string for log messages.
- @param datefmt The format string for the timestamp in the log message.
- """
- super().__init__(fmt, datefmt)
-
- def formatTime(
- self, record: logging.LogRecord, datefmt: Optional[str] = None
- ) -> str:
- """! Overrides the `formatTime` method to include milliseconds in the timestamp.
-
- @param record The log record.
- @param datefmt Optional format string for the timestamp.
- @return The formatted timestamp, including milliseconds.
- """
- if datefmt:
- return datetime.fromtimestamp(record.created).strftime(datefmt)
- else:
- # Format time as HH:MM:SS.mmm (milliseconds)
- return (
- datetime.fromtimestamp(record.created).strftime("%H:%M:%S.")
- + f"{int(record.msecs):02d}"
- )
-
- def format(self, record: logging.LogRecord) -> str:
- """! Formats the log message with the appropriate color based on the log level.
-
- @param record The log record.
- @return The formatted log message with color coding.
- """
- log_message = super().format(record)
-
- # Get the color for the log level, apply the color if it exists
- color = self.COLORS.get(record.levelname, bcolors.ENDC)
-
- # Apply color to the log message
- log_message = color + log_message + bcolors.ENDC
- return log_message
-
-
-def setup_logger() -> logging.Logger:
- """! Configures and returns a global logger with a custom formatter for colorized output.
-
- This function sets up a logger that includes the custom colorized formatter, and
- configures it to output logs to the console. It uses the DEBUG level as the minimum level
- to capture logs.
-
- @return The configured logger.
- """
- logger = logging.getLogger("my_logger")
- logger.setLevel(logging.DEBUG) # Adjust to the level you need
-
- # Create a stream handler and set the formatter
- stream_handler = logging.StreamHandler()
- formatter = CustomFormatter(
- fmt="[%(levelname)s] [%(asctime)s] - %(message)s",
- datefmt="%H:%M:%S.%f", # Use time format with milliseconds
- )
- stream_handler.setFormatter(formatter)
- logger.addHandler(stream_handler)
-
- return logger
-
-
-# Initialize the logger
-log = setup_logger()
-
-
-def query(msg: str) -> str:
- """! Prompts the user for input after printing a message to the console.
-
- This function logs the action of querying the user, prints the provided message
- to the console, and waits for the user's input.
-
- @param msg The message to display to the user.
- @return The user's input as a string.
- """
- log.info("querying user")
- print(msg)
- usrin = input(">> ")
- return usrin
-
-
-# Attach the `query` function as a method to the `log` object for easier use.
-log.query = query
-
-# Ensure only the log object is exported
-__all__ = ["log"]
-
-if __name__ == "__main__":
- usrin = log.query("ready?")
- log.debug(f"user said '{usrin}'")
- log.ok("all good")
- log.info("hello world")
- log.warning("warn message")
- log.error("oh no")
- log.critical("bad times")
- log.panda("this is a panda log")
diff --git a/ark/tools/moveit2/moveit2_bridge.py b/ark/tools/moveit2/moveit2_bridge.py
deleted file mode 100644
index 0fa019d..0000000
--- a/ark/tools/moveit2/moveit2_bridge.py
+++ /dev/null
@@ -1,96 +0,0 @@
-import os
-from typing import Optional, Any
-
-from ark.tools.ros_bridge.ark_ros2_bridge import ArkRos2Bridge
-from ark.utils.utils import ConfigPath
-from arktypes import joint_group_command_t
-from control_msgs.msg import JointTrajectoryControllerState
-
-
-class MoveIt2Bridge(ArkRos2Bridge):
- """Bridge ROS2 JointTrajectoryControllerState -> Ark joint_group_command_t."""
-
- def __init__(
- self,
- ros_controller: str,
- ark_robot_name: str,
- mapping_table: Optional[dict[str, Any]] = None,
- global_config: Optional[dict[str, Any]] = None,
- ):
- sim = self.is_sim_enabled(global_config=global_config)
-
- # Build topic/channel names
- if sim:
- ros_topic = f"/{ros_controller}_controller/state"
- ark_channel = f"{ark_robot_name}/joint_group_command/sim"
- else:
- ros_topic = f"/{ros_controller}_controller/state"
- ark_channel = f"{ark_robot_name}/joint_group_command"
-
- # Base mapping (MoveIt2 state -> Ark command)
- moveit2_mapping_table = {
- "ros2_to_ark": [
- {
- "ros2_channel": ros_topic,
- "ros2_type": JointTrajectoryControllerState,
- "ark_channel": ark_channel,
- "ark_type": joint_group_command_t,
- "translator_callback": self.moveit2_translator,
- }
- ],
- "ark_to_ros": [],
- }
-
- # Merge in extra mappings if provided
- if mapping_table:
- moveit2_mapping_table["ros2_to_ark"].extend(
- mapping_table.get("ros2_to_ark", [])
- )
- moveit2_mapping_table["ark_to_ros"].extend(
- mapping_table.get("ark_to_ros", [])
- )
-
- # Init parent with the final mapping
- super().__init__(
- mapping_table=moveit2_mapping_table, global_config=global_config
- )
-
- def moveit2_translator(
- self,
- ros_msg: JointTrajectoryControllerState,
- ros_channel: str,
- ros_type: type[JointTrajectoryControllerState],
- ark_channel: str,
- ark_type: type[joint_group_command_t],
- ):
- """Convert joint state positions into Ark command."""
- msg = joint_group_command_t()
- msg.name = "arm"
- msg.n = len(ros_msg.actual.positions)
- msg.cmd = list(ros_msg.actual.positions)
- return msg
-
- def is_sim_enabled(self, global_config: Any) -> Optional[bool]:
- """
- Check if the key 'sim' is True or False in a dict or YAML file.
-
- Args:
- global_config (Any): Global configuration dictionary or YAML file path.
-
- Returns:
- bool | None: True/False if 'sim' key exists, None if missing.
- """
- if isinstance(global_config, str):
- global_config = ConfigPath(global_config)
-
- if isinstance(global_config, dict):
- data = global_config
- elif isinstance(global_config, ConfigPath) and ConfigPath.is_file():
- data = global_config.read_yaml()
- else:
- raise ValueError("Source must be a dict or a valid YAML file path.")
-
- if not isinstance(data, dict):
- raise ValueError("YAML/Dict must represent a dictionary at top-level.")
-
- return data.get("sim", None)
diff --git a/ark/tools/network.py b/ark/tools/network.py
deleted file mode 100644
index ef0b704..0000000
--- a/ark/tools/network.py
+++ /dev/null
@@ -1,172 +0,0 @@
-import typing
-import typer
-
-from ark.client.comm_handler.service import send_service_request
-from ark.global_constants import DEFAULT_SERVICE_DECORATOR
-from ark.tools.ark_graph.ark_graph import network_info_lcm_to_dict
-from arktypes import flag_t, network_info_t
-
-node = typer.Typer(help="Interact with nodes", invoke_without_command=True)
-channel = typer.Typer(help="Inspect channels")
-service = typer.Typer(help="Inspect services")
-
-
-def _fetch_network_info(host: str, port: int) -> dict:
- req = flag_t()
- lcm_msg = send_service_request(
- host,
- port,
- f"{DEFAULT_SERVICE_DECORATOR}/GetNetworkInfo",
- req,
- network_info_t,
- )
- return network_info_lcm_to_dict(lcm_msg)
-
-
-def _extract_channel_type(ch: dict) -> str:
- """
- Try a few common keys to determine the data type used on a channel.
- Falls back to '?' if nothing is present.
- """
- return (
- ch.get("channel_type")
- or ch.get("type")
- or ch.get("message_type")
- or ch.get("msg_type")
- or "?"
- )
-
-
-@node.callback(invoke_without_command=True)
-def show_node(
- ctx: typer.Context,
- name: typing.Optional[str] = typer.Option(
- None,
- "--name",
- "-n",
- help="Name of the node to inspect",
- ),
- host: str = "127.0.0.1",
- port: int = 1234,
- verbose: bool = typer.Option(
- False, "--verbose", "-v", help="Show default services"
- ),
-):
- """Show information about NODE if provided via ``-n/--name``."""
- if ctx.invoked_subcommand is not None:
- return
- if name is None:
- typer.echo(ctx.get_help())
- raise typer.Exit()
- data = _fetch_network_info(host, port)
- for node_info in data.get("nodes", []):
- if node_info.get("name") == name:
- comms = node_info.get("comms", {})
-
- def _print_section(title: str, items: list, key: str):
- print(f"{title}:")
- if not items:
- print(" ")
- else:
- for it in items:
- dtype = (
- it.get("channel_type")
- if "channel_type" in it
- else it.get("request_type")
- )
- extra = (
- f" -> {it.get('response_type')}"
- if "response_type" in it
- else ""
- )
- print(f" {it.get(key)} ({dtype}{extra})")
-
- _print_section("Listeners", comms.get("listeners", []), "channel_name")
- _print_section("Publishers", comms.get("publishers", []), "channel_name")
- _print_section("Subscribers", comms.get("subscribers", []), "channel_name")
- print("Services:")
- services = comms.get("services", [])
- if not services:
- print(" ")
- else:
- for srv in services:
- sname = srv.get("service_name")
- if not verbose and sname.startswith(DEFAULT_SERVICE_DECORATOR):
- continue
- print(
- f" {sname} ({srv.get('request_type')} -> {srv.get('response_type')})"
- )
- return
- typer.echo(f"Node '{name}' not found.")
-
-
-@node.command("list")
-def list_nodes(host: str = "127.0.0.1", port: int = 1234):
- """List active nodes."""
- data = _fetch_network_info(host, port)
- for node_info in data.get("nodes", []):
- print(node_info.get("name"))
-
-
-@channel.command("list")
-def list_channels(host: str = "127.0.0.1", port: int = 1234):
- """
- List active channels with their message types.
-
- Output format:
- ([, ...])
- """
- data = _fetch_network_info(host, port)
-
- # Map channel_name -> set of discovered types
- channel_types: dict[str, set[str]] = {}
-
- for node_info in data.get("nodes", []):
- comms = node_info.get("comms", {}) or {}
- for comp in ("listeners", "subscribers", "publishers"):
- for ch in comms.get(comp, []) or []:
- if not ch.get("channel_status"):
- continue
- name = ch.get("channel_name")
- if not name:
- continue
- dtype = _extract_channel_type(ch)
- if name not in channel_types:
- channel_types[name] = set()
- if dtype:
- channel_types[name].add(dtype)
-
- for ch_name in sorted(channel_types.keys()):
- types_str = ", ".join(sorted(t for t in channel_types[ch_name] if t))
- if not types_str:
- types_str = "?"
- print(f"{ch_name} ({types_str})")
-
-
-@service.command("list")
-def list_services(
- host: str = "127.0.0.1",
- port: int = 1234,
- verbose: bool = typer.Option(
- False, "--verbose", "-v", help="Show default services"
- ),
-):
- """List available services."""
- data = _fetch_network_info(host, port)
- services = set()
- for node_info in data.get("nodes", []):
- for srv in node_info.get("comms", {}).get("services", []):
- name = srv.get("service_name")
- if not verbose and name.startswith(DEFAULT_SERVICE_DECORATOR):
- continue
- services.add(name)
- for srv in sorted(services):
- print(srv)
-
-
-if __name__ == "__main__":
- app = typer.Typer()
- app.add_typer(node, name="node")
- app.add_typer(channel, name="channel")
- app.add_typer(service, name="service")
- app()
diff --git a/ark/tools/ros_bridge/__init__.py b/ark/tools/ros_bridge/__init__.py
deleted file mode 100644
index 8b13789..0000000
--- a/ark/tools/ros_bridge/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/ark/tools/ros_bridge/ark_ros2_bridge.py b/ark/tools/ros_bridge/ark_ros2_bridge.py
deleted file mode 100644
index 0425f9f..0000000
--- a/ark/tools/ros_bridge/ark_ros2_bridge.py
+++ /dev/null
@@ -1,270 +0,0 @@
-from functools import partial
-from typing import Any, Dict, Optional
-
-from ark.tools.log import log
-from ark.client.comm_infrastructure.base_node import BaseNode
-
-import rclpy
-from rclpy.node import Node as RclpyNode
-from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy, DurabilityPolicy
-
-
-__doc__ = """ARK ⟷ ROS 2 translator/bridge"""
-
-
-class ArkRos2Bridge(BaseNode):
- """
- Bridge Ark ⟷ ROS2 using `rclpy`.
-
- The bridge is bidirectional and driven by a user-supplied `mapping_table`
- that declares which topics/channels to connect and how to translate messages.
-
- Mapping table schema:
- mapping_table = {
- "ros2_to_ark": [
- {
- "ros2_channel": "/chatter", # str: ROS2 topic name
- "ros2_type": std_msgs.msg.String, # type: ROS2 msg class
- "ark_channel": "ark/chatter", # str: Ark channel name
- "ark_type": string_t, # type: Ark message type (class/struct)
- "translator_callback": callable, # (ros2_msg, ros2_channel, ros2_type, ark_channel, ark_type) -> ark_msg
- },
- ...
- ],
- "ark_to_ros2": [
- {
- "ark_channel": "ark/cmd", # str: Ark channel name
- "ark_type": string_t, # type: Ark message type (class/struct)
- "ros2_channel": "/cmd", # str: ROS2 topic name
- "ros2_type": std_msgs.msg.String, # type: ROS2 msg class
- "translator_callback": callable, # (t, ark_channel, ark_msg) -> ros2_msg
- },
- ...
- ],
- }
- """
-
- def __init__(
- self,
- mapping_table: Dict[str, Any],
- node_name: str = "ark_ros2_bridge",
- global_config: Optional[Dict[str, Any]] = None,
- qos_profile: Optional[QoSProfile] = None,
- ):
- """
- Initialize the bridge and wire up all declared mappings.
-
- mapping_table See class docs for full schema. Missing keys default to empty lists.
- node_name Name for both the Ark BaseNode and the underlying rclpy node.
- global_config Optional Ark node configuration passed to BaseNode.
- qos_profile Optional ROS2 QoS profile. If omitted, uses RELIABLE/KEEP_LAST(10)/VOLATILE.
- """
- super().__init__(node_name, global_config=global_config)
-
- # ---- ROS2 node setup ----
- if not rclpy.ok():
- rclpy.init(args=None)
-
- self._ros2_node: RclpyNode = rclpy.create_node(node_name)
-
- # Default QoS
- self._qos = qos_profile or QoSProfile(
- reliability=ReliabilityPolicy.RELIABLE,
- history=HistoryPolicy.KEEP_LAST,
- depth=10,
- durability=DurabilityPolicy.VOLATILE,
- )
-
- # Keep references so publishers/subscriptions don’t get GC’d
- self._ros2_publishers = []
- self._ros2_subscriptions = []
-
- # ---- Build mappings ----
- ros2_to_ark_table = mapping_table.get("ros2_to_ark", [])
- ark_to_ros2_table = mapping_table.get("ark_to_ros2", [])
-
- self.ros2_to_ark_mapping = []
- for mapping in ros2_to_ark_table:
- ros2_channel = mapping["ros2_channel"]
- ros2_type = mapping["ros2_type"]
- ark_channel = mapping["ark_channel"]
- ark_type = mapping["ark_type"]
- translator_callback = mapping["translator_callback"]
-
- # ARK publisher
- ark_pub = self.create_publisher(ark_channel, ark_type)
-
- # Subscriber callback (ROS2->ARK)
- sub_cb = partial(
- self._generic_ros2_to_ark_translator_callback,
- translator_callback=translator_callback,
- ros2_channel=ros2_channel,
- ros2_type=ros2_type,
- ark_channel=ark_channel,
- ark_type=ark_type,
- ark_publisher=ark_pub,
- )
-
- # ROS2 subscription
- sub = self._ros2_node.create_subscription(ros2_type, ros2_channel, sub_cb, self._qos)
- self._ros2_subscriptions.append(sub)
-
- self.ros2_to_ark_mapping.append(
- {
- "ros2_channel": ros2_channel,
- "ros2_type": ros2_type,
- "ark_channel": ark_channel,
- "ark_type": ark_type,
- "translator_callback": translator_callback,
- "publisher": ark_pub,
- }
- )
-
- self.ark_to_ros2_mapping = []
- for mapping in ark_to_ros2_table:
- ark_channel = mapping["ark_channel"]
- ark_type = mapping["ark_type"]
- ros2_channel = mapping["ros2_channel"]
- ros2_type = mapping["ros2_type"]
- translator_callback = mapping["translator_callback"]
-
- # ROS2 publisher
- ros2_pub = self._ros2_node.create_publisher(ros2_type, ros2_channel, self._qos)
- self._ros2_publishers.append(ros2_pub)
-
- # ARK subscriber (ARK->ROS2)
- ark_cb = partial(
- self._generic_ark_to_ros2_translator_callback,
- translator_callback=translator_callback,
- ark_channel=ark_channel,
- ark_type=ark_type,
- ros2_channel=ros2_channel,
- ros2_type=ros2_type,
- ros2_publisher=ros2_pub,
- )
- self.create_subscriber(ark_channel, ark_type, ark_cb)
-
- self.ark_to_ros2_mapping.append(
- {
- "ros2_channel": ros2_channel,
- "ros2_type": ros2_type,
- "ark_channel": ark_channel,
- "ark_type": ark_type,
- "translator_callback": translator_callback,
- "publisher": ros2_pub,
- }
- )
-
- # ---------- Callbacks ----------
-
- def _generic_ros2_to_ark_translator_callback(
- self,
- ros2_msg: Any,
- *,
- translator_callback,
- ros2_channel: str,
- ros2_type: Any,
- ark_channel: str,
- ark_type: Any,
- ark_publisher,
- ) -> None:
- """
- Translate ROS2 -> ARK.
- translator_callback: (ros2_msg, ros2_channel, ros2_type, ark_channel, ark_type) -> ark_msg
- """
- try:
- ark_msg = translator_callback(ros2_msg, ros2_channel, ros2_type, ark_channel, ark_type)
- ark_publisher.publish(ark_msg)
- except Exception as e:
- self._ros2_node.get_logger().warn(
- f"[ROS2→ARK] Failed translating {ros2_channel} -> {ark_channel}: {e}"
- )
-
- def _generic_ark_to_ros2_translator_callback(
- self,
- t: int,
- _channel: str,
- ark_msg: Any,
- *,
- translator_callback,
- ark_channel: str,
- ark_type: Any,
- ros2_channel: str,
- ros2_type: Any,
- ros2_publisher,
- ) -> None:
- """
- Translate ARK -> ROS2.
- translator_callback: (t, ark_channel, ark_msg) -> ros2_msg
- """
- try:
- ros2_msg = translator_callback(t, ark_channel, ark_msg)
- ros2_publisher.publish(ros2_msg)
- except Exception as e:
- self._ros2_node.get_logger().warn(
- f"[ARK→ROS2] Failed translating {ark_channel} -> {ros2_channel}: {e}"
- )
-
- # ---------- Lifecycle ----------
-
- def spin(self) -> None:
- """
- Drive both Ark (LCM) and ROS2 event loops without blocking either.
- """
- try:
- while not self._done and rclpy.ok():
- # Pump Ark
- try:
- self._lcm.handle_timeout(0)
- except OSError as e:
- log.warning(f"Ark threw OSError {e}")
- break
-
- # Pump ROS2 once (non-blocking)
- rclpy.spin_once(self._ros2_node, timeout_sec=0.0)
- finally:
- self.shutdown()
-
- @staticmethod
- def get_cli_doc():
- return __doc__
-
- def shutdown(self) -> None:
- """
- Cleanly stop Ark and ROS2 resources.
- """
- # Ark side
- for ch in self._comm_handlers:
- try:
- ch.shutdown()
- except Exception:
- pass
- for s in self._steppers:
- try:
- s.shutdown()
- except Exception:
- pass
-
- # ROS2 side
- try:
- # Destroy pubs/subs explicitly (optional but tidy)
- for sub in self._ros2_subscriptions:
- try:
- self._ros2_node.destroy_subscription(sub)
- except Exception:
- pass
- for pub in self._ros2_publishers:
- try:
- self._ros2_node.destroy_publisher(pub)
- except Exception:
- pass
- try:
- self._ros2_node.destroy_node()
- except Exception:
- pass
- finally:
- if rclpy.ok():
- try:
- rclpy.shutdown()
- except Exception:
- pass
diff --git a/ark/tools/ros_bridge/ark_ros_bridge.py b/ark/tools/ros_bridge/ark_ros_bridge.py
deleted file mode 100644
index 3f62f4d..0000000
--- a/ark/tools/ros_bridge/ark_ros_bridge.py
+++ /dev/null
@@ -1,155 +0,0 @@
-import yaml
-from ark.tools.log import log
-from functools import partial
-
-from typing import Dict, Any, Optional
-from ark.client.comm_infrastructure.base_node import BaseNode, main
-from arktypes import string_t
-
-import rospy
-
-__doc__ = """ARK to ROS translator"""
-
-
-class ArkRosBridge(BaseNode):
- def __init__(self, mapping_table, node_name="ARK_ROS_Bridge", global_config=None):
- super().__init__(node_name, global_config=global_config)
- self.ros_to_ark_mapping = []
-
- ros_to_ark_table = mapping_table["ros_to_ark"]
- ark_to_ros_table = mapping_table["ark_to_ros"]
-
- for mapping in ros_to_ark_table:
- ros_channel = mapping["ros_channel"]
- ros_type = mapping["ros_type"]
- ark_channel = mapping["ark_channel"]
- ark_type = mapping["ark_type"]
- translator_callback = mapping["translator_callback"]
-
- publisher = self.create_publisher(ark_channel, ark_type)
-
- modified_callback = partial(
- self._generic_ros_to_ark_translator_callback,
- translator_callback=translator_callback,
- ros_channel=ros_channel,
- ros_type=ros_type,
- ark_channel=ark_channel,
- ark_type=ark_type,
- publisher=publisher,
- )
-
- rospy.Subscriber(ros_channel, ros_type, modified_callback)
-
- ros_to_ark_map = {
- "ros_channel": ros_channel,
- "ros_type": ros_type,
- "ark_channel": ark_channel,
- "ark_type": ark_type,
- "translator_callback": translator_callback,
- "publisher": publisher,
- }
- self.ros_to_ark_mapping.append(ros_to_ark_map)
-
- self.ark_to_ros_mapping = []
-
- for mapping in ark_to_ros_table:
- ark_channel = mapping["ark_channel"]
- ark_type = mapping["ark_type"]
- ros_channel = mapping["ros_channel"]
- ros_type = mapping["ros_type"]
- translator_callback = mapping["translator_callback"]
-
- # Create a listener
- publisher = rospy.Publisher(ros_channel, ros_type, queue_size=10)
-
- modified_callback = partial(
- self._generic_ark_to_ros_translator_callback,
- translator_callback=translator_callback,
- ark_channel=ark_channel,
- ark_type=ark_type,
- ros_channel=ros_channel,
- ros_type=ros_type,
- publisher=publisher,
- )
-
- # Create a ROS publisher
- self.create_subscriber(ark_channel, ark_type, modified_callback)
-
- ros_to_ark_map = {
- "ros_channel": ros_channel,
- "ros_type": ros_type,
- "ark_channel": ark_channel,
- "ark_type": ark_type,
- "translator_callback": translator_callback,
- "publisher": publisher,
- }
- self.ark_to_ros_mapping.append(ros_to_ark_map)
-
- # Create a minimal ROS node
- rospy.init_node(node_name, anonymous=True)
-
- def _generic_ros_to_ark_translator_callback(
- self,
- ros_msg,
- translator_callback,
- ros_channel,
- ros_type,
- ark_channel,
- ark_type,
- publisher,
- ):
- """
- This is the modified callback that includes the ROS channel and ark publisher.
- """
- ark_msg = translator_callback(
- ros_msg, ros_channel, ros_type, ark_channel, ark_type
- )
- publisher.publish(ark_msg)
-
- def _generic_ark_to_ros_translator_callback(
- self,
- t,
- _,
- ark_msg,
- translator_callback,
- ark_channel,
- ark_type,
- ros_channel,
- ros_type,
- publisher,
- ):
- """
- This is the modified callback that includes the ark channel and ROS publisher.
- """
- ros_msg = translator_callback(t, ark_channel, ark_msg)
- publisher.publish(ros_msg)
-
- def spin(self) -> None:
- """!
- Runs the node’s main loop, handling ark messages continuously until the node is finished.
-
- The loop calls `self._ark.handle()` to process incoming messages. If an OSError is encountered,
- the loop will stop and the node will shut down.
- """
- while not self._done and not rospy.is_shutdown():
- try:
- self._lcm.handle_timeout(0)
- # rospy.spin()
- except OSError as e:
- log.warning(f"Ark or ROS threw OSError {e}")
- self._done = True
-
- @staticmethod
- def get_cli_doc():
- return __doc__
-
- def shutdown(self) -> None:
- """!
- Shuts down the node by stopping all communication handlers and steppers.
-
- Iterates through all registered communication handlers and steppers, shutting them down.
- """
- for ch in self._comm_handlers:
- ch.shutdown()
- for s in self._steppers:
- s.shutdown()
diff --git a/ark/tools/urdf_infos.py b/ark/tools/urdf_infos.py
deleted file mode 100644
index 2bf9e5e..0000000
--- a/ark/tools/urdf_infos.py
+++ /dev/null
@@ -1,78 +0,0 @@
-#!/usr/bin/env python3
-import argparse
-import xml.etree.ElementTree as ET
-
-
-def format_number(value):
- try:
- return f"{float(value):.4f}" # Limit to 4 decimal places
- except (ValueError, TypeError):
- return value # Return the value as-is if it's not a number
-
-
-def parse_urdf(path: str):
- # Load URDF file
- tree = ET.parse(path)
- root = tree.getroot()
-
- joint_details = []
-
- for i, joint in enumerate(root.findall('joint')):
- joint_info = {
- 'Index': i,
- 'Name': joint.get('name'),
- 'Type': joint.get('type'),
- 'Parent Link': joint.find('parent').get('link'),
- 'Child Link': joint.find('child').get('link')
- }
-
- # Limits (for revolute/prismatic joints)
- if joint_info['Type'] in ['revolute', 'prismatic']:
- limit = joint.find('limit')
- if limit is not None:
- joint_info['Lower Limit'] = limit.get('lower', 'N/A')
- joint_info['Upper Limit'] = limit.get('upper', 'N/A')
- joint_info['Effort Limit'] = limit.get('effort', 'N/A')
- joint_info['Velocity Limit'] = limit.get('velocity', 'N/A')
- else:
- joint_info['Limits'] = 'No limits defined'
-
- joint_details.append(joint_info)
-
- return joint_details
-
-
-def print_joint_table(joint_details):
- # Print header
- print(f"{'Index':<6}{'Joint Name':<20}{'Type':<18}{'Parent Link':<25}{'Child Link':<25}"
- f"{'Lower Limit':<15}{'Upper Limit':<15}{'Effort Limit':<15}{'Velocity Limit':<15}")
- print("-" * 145)
-
- # Print rows
- for joint in joint_details:
- print(f"{joint['Index']:<6}{joint['Name']:<20}{joint['Type']:<18}"
- f"{joint['Parent Link']:<30}{joint['Child Link']:<30}"
- f"{format_number(joint.get('Lower Limit', 'N/A')):<15}"
- f"{format_number(joint.get('Upper Limit', 'N/A')):<15}"
- f"{format_number(joint.get('Effort Limit', 'N/A')):<15}"
- f"{format_number(joint.get('Velocity Limit', 'N/A')):<15}")
-
-
-def main():
- parser = argparse.ArgumentParser(
- description="Parse and display joint info from a URDF file."
- )
- parser.add_argument(
- "urdf_path",
- type=str,
- help="Path to the URDF file"
- )
-
- args = parser.parse_args()
-
- joints = parse_urdf(args.urdf_path)
- print_joint_table(joints)
-
-
-if __name__ == "__main__":
- main()
diff --git a/ark/tools/visualization/image_viewer.py b/ark/tools/visualization/image_viewer.py
deleted file mode 100644
index 8405c6f..0000000
--- a/ark/tools/visualization/image_viewer.py
+++ /dev/null
@@ -1,105 +0,0 @@
-from ark.client.comm_infrastructure.base_node import BaseNode, main
-from arktypes import image_t
-from arktypes.utils import unpack
-import cv2
-import numpy as np
-import typer
-
-num_channels = {
- image_t.PIXEL_FORMAT_GRAY: 1,
- image_t.PIXEL_FORMAT_RGB: 3,
- image_t.PIXEL_FORMAT_BGR: 3,
- image_t.PIXEL_FORMAT_RGBA: 4,
- image_t.PIXEL_FORMAT_BGRA: 4,
-}
-
-app = typer.Typer()
-
-
-class ImageViewNode(BaseNode):
-
- def __init__(self, channel_name: str = "image/sim", image_type: str = "image"):
- super().__init__("image viewer")
- self.channel_name = channel_name
-
- # Select the message type based on the requested image_type
- if image_type == "rgbd":
- try:
- from arktypes import rgbd_t
-
- msg_type = rgbd_t
- self.image_type = "rgbd"
- self.create_subscriber(channel_name, rgbd_t, self._display_image)
- print("Subscribed to rgbd_t messages")
- except Exception:
- print("rgbd_t not available, falling back to image_t")
- elif image_type == "depth":
- try:
- msg_type = image_t
- self.image_type = "depth"
- self.create_subscriber(channel_name, image_t, self._display_image)
- except Exception:
- print("depth not available, falling back to image_t")
- elif image_type == "image":
- try:
- msg_type = image_t
- self.image_type = "image"
- self.create_subscriber(channel_name, image_t, self._display_image)
- except Exception:
- print("image_t not available")
- else:
- raise ValueError(f"Unsupported image type: {image_type}")
-
- def _display_image(self, channel_name: str, t, msg: image_t):
- print(f"Received message on channel {channel_name} at time {t}")
- if self.image_type == "rgbd":
- image, depth = unpack.rgbd(msg)
- elif self.image_type == "depth":
- image = unpack.image(msg)
- elif self.image_type == "image":
- image = unpack.image(msg)
-
- print(image.shape if image is not None else "No image data received")
- # display the image
- if image is not None:
- if isinstance(image, np.ndarray):
- if image.ndim == 3:
- # Color image
- if image.shape[2] == 3:
- cv2.imshow(self.channel_name, image)
- if image.ndim == 2:
- # Grayscale image
- cv2.imshow(self.channel_name, image)
- if depth is not None:
- if isinstance(depth, np.ndarray):
- # Display depth image
- depth_display = cv2.normalize(
- depth, None, 0, 255, cv2.NORM_MINMAX
- ).astype(np.uint8)
- cv2.imshow(f"{self.channel_name}_depth", depth_display)
-
- cv2.waitKey(1)
-
- def kill_node(self):
- cv2.destroyAllWindows()
- super().kill_node()
-
-
-@app.command()
-def start(
- channel: str = typer.Option("image/sim", help="Channel to listen to"),
- image_type: str = typer.Option(
- "image",
- help="Type of image message: image, depth, or rgbd",
- ),
-):
- """Start the image viewer node."""
- main(ImageViewNode, channel, image_type)
-
-
-def cli_main():
- app()
-
-
-if __name__ == "__main__":
- cli_main()
diff --git a/ark/utils/camera_utils.py b/ark/utils/camera_utils.py
deleted file mode 100644
index febf43e..0000000
--- a/ark/utils/camera_utils.py
+++ /dev/null
@@ -1,8 +0,0 @@
-from enum import Enum
-
-
-class CameraType(Enum):
- """Supported camera models."""
-
- FIXED = "fixed"
- ATTACHED = "attached"
diff --git a/ark/utils/communication_utils.py b/ark/utils/communication_utils.py
deleted file mode 100644
index 74411e0..0000000
--- a/ark/utils/communication_utils.py
+++ /dev/null
@@ -1,276 +0,0 @@
-import importlib
-from typing import Any, Callable
-
-import numpy as np
-import gymnasium as gym
-from gymnasium import spaces
-
-from arktypes.utils import unpack, pack
-from ark.decoders.registry import get_decoder
-from ark.utils.data_utils import generate_flat_dict
-
-from ark.decoders.builtin_decoders import OBS_SCHEMA
-
-
-def get_ark_fn_type(ark_module: unpack, name: str):
- """
- Retrieve both an unpacking function and its corresponding type from Ark module.
- Args:
- ark_module: The module (e.g., ``arktypes.utils.unpack``) containing the
- unpack functions and optional type definitions.
- name: The base name of the function/type pair to retrieve.
-
- Returns:
- A tuple (fn, dtype) where:
- - fn is the unpacking function corresponding to ``name``.
- - dtype is the associated type object if defined, otherwise None.
-
- """
- fn = getattr(ark_module, name)
- dtype = getattr(ark_module, f"{name}_t")
- return fn, dtype
-
-
-def _resolve_channel_types(mapping: dict[str, Any]) -> dict[str, type]:
- """
- Resolve a mapping of channel names to Ark types.
- Accepts either already-imported classes or string names present in the
- ``arktypes`` package. Returns a mapping of channel name to type.
- Args:
- A dictionary mapping channel names to resolved Ark type objects.
-
- Returns:
-
- """
- if not mapping:
- return {}
- resolved: dict[str, type] = {}
- arktypes_mod = importlib.import_module("arktypes")
- for ch_name, t in mapping.items():
- if isinstance(t, str):
- resolved[ch_name] = getattr(arktypes_mod, t)
- else:
- resolved[ch_name] = t
- return resolved
-
-
-def get_channel_types(schema: dict, channel_type: str | None) -> dict[str, type]:
- """
- Generate a mapping of observation channel names to Python/Ark types
- based on the observation schema.
-
- Args:
- schema: Observation schema dictionary (from YAML or Python dict).
- Each channel entry can optionally include a 'type' key, which can be
- a string corresponding to a class in `arktypes` or a Python type.
- channel_type: channel type
-
- Returns:
- Dict[str, type]: Dictionary mapping channel name to resolved type.
- """
- channels: dict[str, Any] = {}
-
- if channel_type is not None:
- obs_schema = schema.get(channel_type, {})
- else:
- obs_schema = schema
-
- for key, entries in obs_schema.items():
- for item in entries:
- ch_name = item["from"]
- using = item["using"]
- _, ch_type = get_ark_fn_type(ark_module=unpack, name=using)
- if ch_name not in channels:
- channels[ch_name] = ch_type
-
- # Resolve type strings to actual type objects using _resolve_channel_types
- resolved_channels = _resolve_channel_types(channels)
- return resolved_channels
-
-
-def _dynamic_observation_unpacker(schema: dict, namespace: str) -> Callable:
- """
- Create a dynamic observation unpacker based on a schema.
-
- The schema should be in the format:
- observation:
- state:
- - from: channel_name
- using: callable
- image_top:
- - from: channel_name
- using: callable
- wrap: True # optional
-
- Returns a function:
- _unpack(observation_dict) -> dict
- """
-
- obs_schema = schema["observation_space"]
- obs_schema = namespace_channels(channels=obs_schema, namespace=namespace)
-
- def _unpack(observation_dict: dict[str, Any]) -> dict[str, Any]:
- if not observation_dict:
- return {}
-
- result: dict[str, Any] = {}
-
- for key, entries in obs_schema.items():
- parts = {}
- for item in entries:
- ch_name = item["from"]
- msg = observation_dict.get(ch_name)
- decoder = get_decoder(item["using"])
- decoded = decoder(msg)
- if "name" in item:
- parts[item["name"]] = decoded
- else:
- parts[item["using"]] = decoded
-
- result[key] = parts
- return result
-
- return _unpack
-
-
-def _dynamic_action_packer(
- schema: dict, namespace: str
-) -> Callable[..., dict[str, Any]]:
- """
- Create a dynamic action packer from schema.
-
- Returns a function:
- _pack(observation_dict) -> dict
-
- """
-
- act_schema = schema["action_space"]
- act_schema = namespace_channels(channels=act_schema, namespace=namespace)
-
- def _pack(action: list[float] | np.ndarray) -> dict[str, Any]:
- a = np.asarray(action).tolist()
- result: dict[str, Any] = {}
-
- for key, entries in act_schema.items():
- for item in entries:
- channel = item["from"]
- using = item["using"]
- select = item.get("select", {})
-
- # resolve packer dynamically
- fn, dtype = get_ark_fn_type(ark_module=pack, name=using)
-
- # build args from config
- args = []
-
- for field_name, idx in select.items():
- if isinstance(idx, list):
- args.append(np.array([a[i] for i in idx]))
- elif isinstance(idx, str):
- args.append(idx)
- else:
- args.append(a[idx])
- msg = fn(*args)
- result[channel] = msg
- return result
-
- return _pack
-
-
-def build_action_space(schema):
- """
- Build a Gym-style action space from the action-space configuration.
-
- The expected config format is:
-
- action_space:
- action:
- - using: task_space_command
- dim: 8
-
- Which results in a single Box of shape (sum(dim),).
- """
- schema = schema["action_space"]
- proprio_dim = 0
- for items in schema.values():
- for item in items:
- proprio_dim += item["dim"]
-
- return spaces.Box(low=-1, high=1, shape=(proprio_dim,), dtype=np.float32)
-
-
-def build_observation_space(schema: dict, flatten_obs_space: bool) -> gym.Space:
- """
- Convert observation_space schema into a Gym Dict space.
- """
- gym_dict = {}
- num_joints = schema["robot"]["num_joints"]
- schema = schema["observation_space"]
-
- for key, entries in schema.items():
- inner_dict = {}
- for item in entries:
- decoder = item["using"]
- if (
- decoder == "rgbd"
- ): # check is there any other sensors/camera type is available
- component_dict = {}
- h = item.get("image_height")
- w = item.get("image_width")
- component_dict["rgb"] = gym.spaces.Box(
- low=0, high=255, shape=(h, w, 3), dtype=np.float32
- )
- # component_dict["depth"] = gym.spaces.Box(
- # low=0.0, high=5.0, shape=(h, w), dtype=np.float32
- # ) # TODO handle depth image in proper way
- else:
- components = OBS_SCHEMA[decoder]
- component_dict = {}
- for component in components:
- if decoder == "pose":
- dim = 3 if component == "position" else 4
- elif decoder == "rigid_body_state":
- dim = 4 if component == "orientation" else 3
- else:
- dim = num_joints
- component_dict[component] = gym.spaces.Box(
- low=-np.inf, high=np.inf, shape=(dim,), dtype=np.float32
- )
- if "name" in item:
- inner_dict[item["name"]] = gym.spaces.Dict(component_dict)
- else:
- inner_dict[decoder] = gym.spaces.Dict(component_dict)
- gym_dict[key] = gym.spaces.Dict(inner_dict)
-
- if flatten_obs_space:
- gym_dict = generate_flat_dict(gym_dict)
- return gym.spaces.Dict(gym_dict)
-
-
-def namespace_channels(channels: dict, namespace: str):
-
- prefix = f"{namespace}/"
-
- # Flat mapping {channel_name: type}
- if all(isinstance(v, type) for v in channels.values()):
- out = {}
- for ch_name, ch_type in channels.items():
- out[f"{prefix}{ch_name}"] = ch_type
- return out
-
- # Structured schema with lists of dicts containing "from"
- out = {}
-
- for key, items in channels.items():
- new_items = []
- for entry in items:
- entry = entry.copy() # avoid modifying original
-
- if "from" in entry:
- entry["from"] = f"{prefix}{entry['from']}"
-
- new_items.append(entry)
-
- out[key] = new_items
-
- return out
diff --git a/ark/utils/config_utils.py b/ark/utils/config_utils.py
deleted file mode 100644
index 9e45cb0..0000000
--- a/ark/utils/config_utils.py
+++ /dev/null
@@ -1,11 +0,0 @@
-def resolve_class(path: str) -> type:
- """Resolve a fully-qualified class path like 'module.submodule:Class'."""
-
- module_name, _, class_name = path.rpartition(".")
- if not module_name or not class_name:
- raise ValueError(f"Invalid class path '{path}'")
- module = __import__(module_name, fromlist=[class_name])
- cls = getattr(module, class_name, None)
- if cls is None:
- raise ImportError(f"Class '{class_name}' not found in module '{module_name}'")
- return cls
diff --git a/ark/utils/data_utils.py b/ark/utils/data_utils.py
deleted file mode 100644
index 77316b0..0000000
--- a/ark/utils/data_utils.py
+++ /dev/null
@@ -1,59 +0,0 @@
-import gymnasium as gym
-import torch
-from collections.abc import Iterable
-
-def generate_flat_dict(dic, prefix=None):
- """
- Helper function to recursively iterate through dictionary / gym.spaces.Dict @dic and flatten any nested elements,
- such that the result is a flat dictionary mapping keys to values
-
- Args:
- dic (dict or gym.spaces.Dict): (Potentially nested) dictionary to convert into a flattened dictionary
- prefix (None or str): Prefix to append to the beginning of all strings in the flattened dictionary. None results
- in no prefix being applied
-
- Returns:
- dict: Flattened version of @dic
- """
- out = dict()
- prefix = "" if prefix is None else f"{prefix}::"
- for k, v in dic.items():
- if isinstance(v, gym.spaces.Dict) or isinstance(v, dict):
- out.update(generate_flat_dict(dic=v, prefix=f"{prefix}{k}"))
- elif isinstance(v, gym.spaces.Tuple) or isinstance(v, tuple):
- for i, vv in enumerate(v):
- # Assume no dicts are nested within tuples
- out[f"{prefix}{k}::{i}"] = vv
- else:
- # Add to out dict
- out[f"{prefix}{k}"] = v
-
- return out
-
-
-def generate_compatible_dict(dic):
- """
- Helper function to recursively iterate through dictionary and cast values to necessary types to be compatible with
- Gym spaces -- in particular, the Sequence and Tuple types for th.tensor values in @dic
-
- Args:
- dic (dict or gym.spaces.Dict): (Potentially nested) dictionary to convert into a flattened dictionary
-
- Returns:
- dict: Gym-compatible version of @dic
- """
- out = dict()
- for k, v in dic.items():
- if isinstance(v, dict):
- out[k] = generate_compatible_dict(dic=v)
- elif isinstance(v, torch.Tensor) and v.dim() > 1:
- # Map to list of tuples
- out[k] = tuple(tuple(row.tolist()) for row in v)
- elif isinstance(v, Iterable):
- # bounding box modalities give a list of tuples
- out[k] = tuple(v)
- else:
- # Preserve the key-value pair
- out[k] = v
-
- return out
diff --git a/ark/utils/isaac_utils.py b/ark/utils/isaac_utils.py
deleted file mode 100644
index 576e65b..0000000
--- a/ark/utils/isaac_utils.py
+++ /dev/null
@@ -1,27 +0,0 @@
-import os
-import subprocess
-import sys
-from pathlib import Path
-
-
-def configure_isaac_setup():
- # Get ARK_ISSAC_PATH from environment
- ark_path = os.environ.get("ARK_ISSAC_PATH")
-
- # Terminate if not defined or empty
- if not ark_path:
- print(
- "ERROR: ARK_ISSAC_PATH is not defined. Please set it before running the isaac simulator."
- )
- sys.exit(1)
-
- # Resolve the expected setup script path
- setup_script = Path(ark_path) / "setup_conda_env.sh"
-
- if not setup_script.is_file():
- print(
- f"ERROR: Could not Configure isaac environment from the provided path: {setup_script}"
- )
- sys.exit(1)
-
- subprocess.run(f"source {setup_script}", shell=True, executable="/bin/bash", check=True)
diff --git a/ark/utils/lazy.py b/ark/utils/lazy.py
deleted file mode 100644
index 7e573d7..0000000
--- a/ark/utils/lazy.py
+++ /dev/null
@@ -1,5 +0,0 @@
-import sys
-
-from ark.utils.lazy_import_utils import LazyImporter
-
-sys.modules[__name__] = LazyImporter("", None)
diff --git a/ark/utils/lazy_import_utils.py b/ark/utils/lazy_import_utils.py
deleted file mode 100644
index fcdef8b..0000000
--- a/ark/utils/lazy_import_utils.py
+++ /dev/null
@@ -1,77 +0,0 @@
-import importlib
-from types import ModuleType
-
-
-class LazyImporter(ModuleType):
- """A lazily-loading proxy for a module and its submodules.
-
- This class replaces a module's global namespace to support lazy loading
- of submodules and member attributes. When an attribute is accessed, the
- importer first attempts to treat the name as a submodule. If that fails,
- it then attempts to retrieve the attribute from the wrapped module.
-
- """
-
- def __init__(self, module_name: str, module: ModuleType):
- """Initializes the lazy importer.
-
- Args:
- module_name (str): Name of the module being wrapped.
- module (ModuleType): The imported module instance.
- """
- super().__init__("lazy_" + module_name)
- self._module_path = module_name
- self._module = module
- self._not_module = set()
- self._submodules = {}
-
- def __getattr__(self, name: str):
- """Resolves an attribute access lazily.
-
- The attribute name is first checked as a possible submodule. If importing fails, the name is
- assumed to be a regular attribute of the wrapped module.
-
- Args:
- name : Name of attribute or submodule.
-
- Returns:
- Any: The resolved submodule or attribute.
-
- """
- if name not in self._not_module:
- submodule = self._try_load_submodule(name)
- if submodule:
- return submodule
-
- self._not_module.add(name)
-
- try:
- return getattr(self._module, name)
- except:
- raise AttributeError(
- f"module {self.__name__} has no attribute {name}"
- ) from None
-
- def _try_load_submodule(self, module_name: str):
- """Attempts to load a submodule lazily.
-
- Args:
- module_name: Submodule name relative to this module.
-
- Returns:
- LazyImporter | None: A LazyImporter for the submodule if found,
- otherwise None.
- """
-
- if self._module_path:
- module_name = f"{self._module_path}.{module_name}"
-
- if module_name in self._submodules:
- return self._submodules[module_name]
-
- try:
- wrapper = LazyImporter(module_name, importlib.import_module(module_name))
- self._submodules[module_name] = wrapper
- return wrapper
- except ModuleNotFoundError:
- return None
diff --git a/ark/utils/scene_status_utils.py b/ark/utils/scene_status_utils.py
deleted file mode 100644
index a2d9634..0000000
--- a/ark/utils/scene_status_utils.py
+++ /dev/null
@@ -1,151 +0,0 @@
-from __future__ import annotations
-
-from dataclasses import dataclass
-from typing import Any
-
-import numpy as np
-
-
-@dataclass
-class ObjectState:
- """
- Lightweight carrier for object pose used by reward/termination logic.
-
- Works with both nested and flattened observation dictionaries by looking
- for common field names.
- """
-
- name: str
- position: np.ndarray
- orientation: np.ndarray | None = None
-
- def distance_to(self, other: "ObjectState") -> float:
- return float(np.linalg.norm(self.position - other.position))
-
- @staticmethod
- def from_observation(obs: dict[str, Any], name: str):
- """
- Try to extract an object's position/orientation from the observation dict.
- Supports:
- - nested: obs[name] = {"position": ..., "orientation": ...}
- - flattened: keys like f"{name}::position", f"{name}::orientation"
- """
-
- pos = ori = None
-
- # Flattened observation dict (most common in this repo)
- if f"objects::{name}::position" in obs:
- pos = obs.get(f"objects::{name}::position")
- ori = obs.get(f"objects::{name}::orientation")
- # Nested observation dict
- elif "objects" in obs and name in obs["objects"]:
- obj = obs["objects"][name]
- pos = obj.get("position")
- ori = obj.get("orientation")
-
- if pos is None:
- return None
- pos_arr = np.asarray(pos, dtype=np.float32).reshape(-1)
- ori_arr = None
- if ori is not None:
- ori_arr = np.asarray(ori, dtype=np.float32).reshape(-1)
- return ObjectState(name=name, position=pos_arr, orientation=ori_arr)
-
-
-@dataclass
-class RobotState:
- """
- Lightweight carrier for object pose used by reward/termination logic.
-
- Works with both nested and flattened observation dictionaries by looking
- for common field names.
- """
-
- position: np.ndarray
- orientation: np.ndarray
- joint_positions: np.ndarray
-
- @staticmethod
- def from_observation(obs: dict[str, Any]):
- """
- Extract the robot pose from either flattened or nested observations.
- """
- pos = ori = joints = None
-
- # Flattened keys (after generate_flat_dict)
- if "proprio::pose::position" in obs:
- pos = obs.get("proprio::pose::position")
- ori = obs.get("proprio::pose::orientation")
-
- if "proprio::joint_state::position" in obs:
- joints = obs.get("proprio::joint_state::position")
-
- if "proprio" in obs:
- proprio = obs["proprio"]
-
- # Nested pose
- if "pose" in proprio:
- pose = proprio["pose"]
- pos = pos or pose.get("position")
- ori = ori or pose.get("orientation")
-
- # Nested joint state
- if "joint_state" in proprio:
- js = proprio["joint_state"]
- joints = joints or js.get("position")
-
- if pos is None or ori is None:
- return None
- pos_arr = np.asarray(pos, dtype=np.float32).reshape(-1)
- ori_arr = np.asarray(ori, dtype=np.float32).reshape(-1)
- joints_arr = np.asarray(joints, dtype=np.float32).reshape(-1)
-
- return RobotState(
- position=pos_arr, orientation=ori_arr, joint_positions=joints_arr
- )
-
- def get_position_orientation(self):
- return self.position, self.orientation
-
- def get_position(self):
- return self.position
-
- def get_current_joint_states(self):
- return self.joint_positions.copy()
-
-
-def task_space_action_from_obs(
- obs: dict[str, Any], action_dim: int, num_envs: int = 1
-) -> np.ndarray:
- """
- Build an initial task-space action (xyz + quat + gripper) from the reset observation.
- Falls back to zeros if position/orientation are missing.
- """
- init = np.zeros((num_envs, action_dim), dtype=np.float32)
- if not isinstance(obs, dict):
- return init
-
- pos = obs.get("proprio::pose::position")
- ori = obs.get("proprio::pose::orientation")
-
- if (pos is None or ori is None) and "proprio" in obs:
- proprio = obs["proprio"]
- pose = proprio.get("pose", {})
- pos = pos if pos is not None else pose.get("position")
- ori = ori if ori is not None else pose.get("orientation")
-
- if pos is None or ori is None:
- return init
-
- pos_arr = np.asarray(pos, dtype=np.float32).reshape(num_envs, -1)
- ori_arr = np.asarray(ori, dtype=np.float32).reshape(num_envs, -1)
-
- pos_len = min(3, pos_arr.shape[1], action_dim)
- init[:, :pos_len] = pos_arr[:, :pos_len]
-
- ori_start = pos_len
- ori_len = min(4, ori_arr.shape[1], action_dim - ori_start)
- if ori_len > 0:
- init[:, ori_start : ori_start + ori_len] = ori_arr[:, :ori_len]
-
- return init
diff --git a/ark/utils/source_type_utils.py b/ark/utils/source_type_utils.py
deleted file mode 100644
index 4dfc4f6..0000000
--- a/ark/utils/source_type_utils.py
+++ /dev/null
@@ -1,11 +0,0 @@
-from enum import Enum
-
-
-class SourceType(Enum):
- """Supported source types for object creation."""
-
- URDF = "urdf"
- PRIMITIVE = "primitive"
- SDF = "sdf"
- MJCF = "mjcf"
- USD = "usd"
diff --git a/ark/utils/utils.py b/ark/utils/utils.py
deleted file mode 100644
index 248ec63..0000000
--- a/ark/utils/utils.py
+++ /dev/null
@@ -1,72 +0,0 @@
-import json
-import os
-from pathlib import Path, PosixPath, WindowsPath
-from typing import Type
-
-import yaml
-from ark.tools.log import log
-
-# Explicit mapping for known platforms
-_OS_NAME_TO_PATH_CLS: dict[str, Type[Path]] = {
- "posix": PosixPath, # Linux, macOS
- "nt": WindowsPath, # Windows
-}
-
-# Pick base class
-BasePathClass: Type[Path] = _OS_NAME_TO_PATH_CLS.get(os.name, type(Path()))
-
-
-class ConfigPath(BasePathClass):
- """
- A Path subclass with convenience methods for reading configuration files.
- Works cross-platform (inherits PosixPath on Linux/macOS or WindowsPath on Windows).
- """
-
- @property
- def str(self) -> str:
- """
- Return the string representation of this path.
- Equivalent to calling str(self).
- """
- return str(self)
-
- def read_yaml(self, raise_fnf_error: bool = True) -> dict:
- """
- Load a YAML configuration schema from this path.
-
- Args:
- raise_fnf_error: If True, raise FileNotFoundError when the file is missing.
-
- Returns:
- Parsed YAML as a dictionary. Returns {} if empty.
- """
- if self.exists():
- with self.open("r", encoding="utf-8") as f:
- return yaml.safe_load(f) or {}
- else:
- if raise_fnf_error:
- raise FileNotFoundError(f"Config file not found: {self}")
- log.error(f"Config file {self} does not exist.")
- return {}
-
- def read_json(self, raise_fnf_error: bool = True) -> dict:
- """
- Load a JSON configuration schema from this path.
-
- Args:
- raise_fnf_error: If True, raise FileNotFoundError when the file is missing.
-
- Returns:
- Parsed JSON as a dictionary.
- """
- if self.exists():
- with self.open("r", encoding="utf-8") as f:
- return json.load(f)
- else:
- if raise_fnf_error:
- raise FileNotFoundError(f"Config file not found: {self}")
- log.error(f"Config file {self} does not exist.")
- return {}
-
- def __repr__(self) -> str:
- return f""
diff --git a/ark/utils/video_recorder.py b/ark/utils/video_recorder.py
deleted file mode 100644
index 6384ec6..0000000
--- a/ark/utils/video_recorder.py
+++ /dev/null
@@ -1,62 +0,0 @@
-from __future__ import annotations
-
-from pathlib import Path
-from typing import Any
-
-import imageio.v2 as imageio
-import numpy as np
-
-
-class VideoRecorder:
- """Video recorder for Ark environments."""
-
- def __init__(
- self,
- out_path: str | Path,
- fps: int = 20,
- obs_rgb_key: str = "rgb",
- ) -> None:
- self.out_path = Path(out_path)
- self.fps = fps
- self.obs_rgb_key = obs_rgb_key
- self._writer = None
-
- def start(self) -> None:
- if self._writer is None:
- self.out_path.parent.mkdir(parents=True, exist_ok=True)
- self._writer = imageio.get_writer(self.out_path, fps=self.fps)
-
- def add_frame(self, obs: dict[str, Any]) -> None:
- """
- Extract an RGB frame from observation dict and append to the video.
- """
- if self._writer is None:
- self.start()
-
- frame = obs.get(self.obs_rgb_key)
- if frame is None:
- # print("empty frame")
- return
-
- arr = np.asarray(frame)
- # Handle batched obs by taking first element
- if arr.ndim == 4:
- arr = arr[0]
-
- # Normalize if float images are in [0, 1]
- if arr.dtype != np.uint8:
- if arr.max() <= 1.0:
- arr = (arr * 255.0).clip(0, 255).astype(np.uint8)
- else:
- arr = arr.astype(np.uint8)
-
- # imageio expects HWC
- if arr.shape[-1] != 3 and arr.shape[0] == 3:
- arr = np.transpose(arr, (1, 2, 0))
-
- self._writer.append_data(arr)
-
- def close(self) -> None:
- if self._writer is not None:
- self._writer.close()
- self._writer = None
diff --git a/Logo.png b/docs/logo.png
similarity index 100%
rename from Logo.png
rename to docs/logo.png
diff --git a/pyproject.toml b/pyproject.toml
index f4589d4..bf56fba 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,30 +1,34 @@
[build-system]
-requires = ["setuptools>=64"]
+requires = ["setuptools>=68", "wheel"]
build-backend = "setuptools.build_meta"
[project]
name = "ark"
-version = "0.1"
+version = "1.0.0"
+requires-python = ">=3.12"
+description = "The Ark framework."
+
+authors = [
+ { name = "Christopher E. Mower" },
+ { name = "Refinath S N" }
+]
+
+maintainers = [
+ { name = "Christopher E. Mower" },
+ { name = "Refinath S N" }
+]
+
dependencies = [
- "lcm",
- "colorlog",
- "opencv-python",
- "gymnasium",
- "matplotlib",
- "pandas",
- "numpy==1.24.3",
- "PyYAML",
- "typer",
- "graphviz",
- "scipy",
- "pybullet; sys_platform == 'linux'"
+ "protobuf>=4.21",
+ "numpy",
+ "ark_msgs @ git+https://github.com/Robotics-Ark/ark_msgs.git"
]
-[project.optional-dependencies]
-test = ["pytest"]
+[tool.setuptools]
+package-dir = { "" = "src" }
-[tool.setuptools.packages]
-find = {include = ["ark", "arktypes", "arktypes.*"]}
+[tool.setuptools.packages.find]
+where = ["src"]
[project.scripts]
-ark = "ark.cli:main"
\ No newline at end of file
+ark = "ark.scripts.core:main"
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
deleted file mode 100644
index b294654..0000000
--- a/requirements.txt
+++ /dev/null
@@ -1,13 +0,0 @@
-lcm
-colorlog
-opencv-python
-gymnasium
-matplotlib
-pandas
-numpy==1.24.3
-pybullet
-PyYAML
-typer
-graphviz
-networkx
-scipy
\ No newline at end of file
diff --git a/ark/tests/__init__.py b/src/ark/__init__.py
similarity index 100%
rename from ark/tests/__init__.py
rename to src/ark/__init__.py
diff --git a/src/ark/comm/end_point.py b/src/ark/comm/end_point.py
new file mode 100644
index 0000000..71d7303
--- /dev/null
+++ b/src/ark/comm/end_point.py
@@ -0,0 +1,32 @@
+import zenoh
+from ark.time.clock import Clock
+from ark.core.registerable import Registerable
+from ark.data.data_collector import DataCollector
+
+
+class EndPoint(Registerable):
+
+ def __init__(
+ self,
+ node_name: str,
+ session: zenoh.Session,
+ clock: Clock,
+ channel: str,
+ data_collector: DataCollector | None,
+ ):
+ self._node_name = node_name
+ self._session = session
+ self._clock = clock
+ self._channel = channel
+ self._data_collector = data_collector
+ self._active = True
+ self._seq_index = 0
+
+ def is_active(self) -> bool:
+ return self._active
+
+ def reset(self):
+ self._active = True
+
+ def close(self):
+ self._active = False
diff --git a/src/ark/comm/publisher.py b/src/ark/comm/publisher.py
new file mode 100644
index 0000000..38b61de
--- /dev/null
+++ b/src/ark/comm/publisher.py
@@ -0,0 +1,58 @@
+import zenoh
+from ark.time.clock import Clock
+from .end_point import EndPoint
+from ark_msgs import Envelope
+from google.protobuf.message import Message
+from ark.data.data_collector import DataCollector
+
+
+class Publisher(EndPoint):
+
+ def __init__(
+ self,
+ node_name: str,
+ session: zenoh.Session,
+ clock: Clock,
+ channel: str,
+ data_collector: DataCollector | None,
+ ):
+ super().__init__(node_name, session, clock, channel, data_collector)
+ self._pub = self._session.declare_publisher(self._channel)
+
+ def core_registration(self):
+ print("..todo: register with ark core..")
+
+ def publish(self, msg: Message | bytes):
+ if self._active:
+
+ # Create Envelope
+ env = Envelope(
+ endpoint_type=Envelope.EndpointType.PUBLISH,
+ channel=self._channel,
+ src_node_name=self._node_name,
+ sent_seq_index=self._seq_index,
+ sent_timestamp=self._clock.now(),
+ )
+ if isinstance(msg, Message):
+ env.msg_type = msg.DESCRIPTOR.full_name
+ env.payload = msg.SerializeToString()
+ elif isinstance(msg, bytes):
+ env.msg_type = "__bytes__"
+ env.payload = bytes(msg)
+ else:
+ raise TypeError("msg must be a protobuf Message or bytes")
+ env_bytes = env.SerializeToString()
+
+ # Publish envelope
+ self._pub.put(env_bytes)
+
+ # Collect data if enabled
+ if self._data_collector is not None:
+ self._data_collector.append(env_bytes)
+
+ # Increment sequence index
+ self._seq_index += 1
+
+ def close(self):
+ super().close()
+ self._pub.undeclare()
diff --git a/src/ark/comm/queriable.py b/src/ark/comm/queriable.py
new file mode 100644
index 0000000..294346a
--- /dev/null
+++ b/src/ark/comm/queriable.py
@@ -0,0 +1,76 @@
+import zenoh
+from google.protobuf.message import Message
+from ark_msgs import Envelope
+from ark.time.clock import Clock
+from ark.comm.end_point import EndPoint
+from ark.data.data_collector import DataCollector
+from ark_msgs.registry import msgs
+from typing import Callable
+
+
+class Queryable(EndPoint):
+
+ def __init__(
+ self,
+ node_name: str,
+ session: zenoh.Session,
+ clock: Clock,
+ channel: str,
+ handler: Callable[[Message], Message],
+ data_collector: DataCollector | None = None,
+ ):
+ super().__init__(node_name, session, clock, channel, data_collector)
+ self._handler = handler
+ self._queryable = self._session.declare_queryable(self._channel, self._on_query)
+
+ def core_registration(self):
+ print("..todo: register with ark core..")
+
+ def _on_query(self, query: zenoh.Query) -> None:
+ # If we were closed, ignore queries
+ if not self._active:
+ return
+
+ try:
+ # Zenoh query may or may not include a payload.
+ # For your use-case, the request is always in query.value (bytes)
+ raw = bytes(query.value) if query.value is not None else b""
+ if not raw:
+ return # nothing to do
+
+ req_env = Envelope()
+ req_env.ParseFromString(raw)
+
+ # Decode request protobuf
+ req_type = msgs.get(req_env.payload_msg_type)
+ if req_type is None:
+ # Unknown message type: ignore (or reply error later)
+ return
+
+ req_msg = req_type()
+ req_msg.ParseFromString(req_env.payload)
+
+ # Call user handler
+ resp_msg: Message = self._handler(req_msg)
+
+ # Pack envelope for response
+ resp_env = Envelope()
+ resp_env.endpoint_type = Envelope.EndpointType.RESPONSE
+ resp_env.sent_timestamp = self._clock.now()
+ resp_env.sent_seq_index = self._seq_index
+ resp_env.src_node_name = self._node_name
+ resp_env.channel = self._channel
+
+ self._seq_index += 1
+
+ resp_env = Envelope.pack(self._node_name, self._clock, resp_msg)
+ query.reply(resp_env.SerializeToString())
+
+ if self._data_collector:
+ self._data_collector.append(req_env.SerializeToString())
+ self._data_collector.append(resp_env.SerializeToString())
+
+ except Exception:
+ # Swallow all exceptions so the zenoh callback thread stays alive.
+ # TODO(review): log the exception instead of discarding it silently.
+ return
diff --git a/src/ark/comm/querier.py b/src/ark/comm/querier.py
new file mode 100644
index 0000000..c6d4586
--- /dev/null
+++ b/src/ark/comm/querier.py
@@ -0,0 +1,79 @@
+import zenoh
+from ark_msgs import Envelope
+from google.protobuf.message import Message
+from ark.data.data_collector import DataCollector
+from ark.comm.end_point import EndPoint
+
+
+class Querier(EndPoint):
+
+ def __init__(
+ self,
+ node_name: str,
+ session: zenoh.Session,
+ clock,
+ channel: str,
+ data_collector: DataCollector | None,
+ ):
+ super().__init__(node_name, session, clock, channel, data_collector)
+ self._querier = self._session.declare_querier(self._channel)
+
+ def core_registration(self):
+ print("..todo: register with ark core..")
+
+ def query(
+ self,
+ req: Message | bytes,
+ timeout: float = 10.0,
+ ) -> Message:
+ """Send a query message and wait for the first OK response."""
+ if not self._active:
+ raise RuntimeError("Querier is not active")
+
+ # Create Envelope for the request
+ req_env = Envelope(
+ endpoint_type=Envelope.EndpointType.REQUEST,
+ channel=self._channel,
+ src_node_name=self._node_name,
+ sent_seq_index=self._seq_index,
+ sent_timestamp=self._clock.now(),
+ )
+
+ if isinstance(req, Message):
+ req_env.msg_type = req.DESCRIPTOR.full_name
+ req_env.payload = req.SerializeToString()
+ elif isinstance(req, bytes):
+ req_env.msg_type = "__bytes__"
+ req_env.payload = bytes(req)
+ else:
+ raise TypeError("req must be a protobuf Message or bytes")
+
+ replies = self._querier.get(value=req_env.SerializeToString(), timeout=timeout)
+
+ for reply in replies:
+ if reply.ok is None:
+ continue
+
+ resp_env = Envelope()
+ resp_env.ParseFromString(bytes(reply.ok))
+ resp_env.dst_node_name = self._node_name
+ resp_env.recv_timestamp = self._clock.now()
+
+ resp = resp_env.extract_message()
+
+ self._seq_index += 1
+
+ if self._data_collector:
+ self._data_collector.append(req_env.SerializeToString())
+ self._data_collector.append(resp_env.SerializeToString())
+
+ return resp
+
+ else:
+ raise TimeoutError(
+ f"No OK reply received for query on '{self._channel}' within {timeout}s"
+ )
+
+ def close(self):
+ super().close()
+ self._querier.undeclare()
diff --git a/src/ark/comm/subscriber.py b/src/ark/comm/subscriber.py
new file mode 100644
index 0000000..f579b2c
--- /dev/null
+++ b/src/ark/comm/subscriber.py
@@ -0,0 +1,51 @@
+import zenoh
+from ark.time.clock import Clock
+from ark_msgs import Envelope
+from collections.abc import Callable
+from ark.comm.end_point import EndPoint
+from google.protobuf.message import Message
+from ark.data.data_collector import DataCollector
+
+
+class Subscriber(EndPoint):
+
+ def __init__(
+ self,
+ node_name: str,
+ session: zenoh.Session,
+ clock: Clock,
+ channel: str,
+ data_collector: DataCollector | None,
+ callback: Callable[[Message | bytes], None],
+ ):
+ super().__init__(node_name, session, clock, channel, data_collector)
+ self._callback = callback
+ self._sub = self._session.declare_subscriber(self._channel, self._on_sample)
+
+ def core_registration(self):
+ print("..todo: register with ark core..")
+
+ def _on_sample(self, sample: zenoh.Sample):
+ if self._active:
+
+ # Retrieve Envelope from sample and mark as RECEIVE
+ env = Envelope()
+ env.ParseFromString(bytes(sample.payload))
+ env.endpoint_type = Envelope.EndpointType.RECEIVE
+ env.dst_node_name = self._node_name
+ env.recv_timestamp = self._clock.now()
+ env.recv_seq_index = self._seq_index
+
+ # Collect data if enabled
+ if self._data_collector:
+ self._data_collector.append(env.SerializeToString())
+
+ # Invoke user callback
+ self._callback(env.extract_message())
+
+ # Increment sequence index
+ self._seq_index += 1
+
+ def close(self):
+ super().close()
+ self._sub.undeclare()
diff --git a/src/ark/core/registerable.py b/src/ark/core/registerable.py
new file mode 100644
index 0000000..8e00e75
--- /dev/null
+++ b/src/ark/core/registerable.py
@@ -0,0 +1,12 @@
+from abc import ABC, abstractmethod
+
+
+class Registerable(ABC):
+
+ @abstractmethod
+ def core_registration(self):
+ """Register the object with ark core."""
+
+ @abstractmethod
+ def close(self):
+ """Close the object and release any resources."""
diff --git a/src/ark/data/data_collector.py b/src/ark/data/data_collector.py
new file mode 100644
index 0000000..b43595a
--- /dev/null
+++ b/src/ark/data/data_collector.py
@@ -0,0 +1,48 @@
+import struct
+from pathlib import Path
+from datetime import datetime
+from threading import Thread
+from queue import Queue
+from ark.core.registerable import Registerable
+
+
+class DataCollector(Registerable):
+ __slots__ = ("_path", "_file_path", "_queue", "_thread")
+
+ _SENTINEL = object()
+
+ def __init__(self, node_name: str, queue_maxsize: int = 1000):
+ self._path = Path.home() / ".ark" / "data" / node_name
+ self._path.mkdir(parents=True, exist_ok=True)
+
+ stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ self._file_path = self._path / f"data_{stamp}.bin"
+
+ self._queue: Queue[object] = Queue(maxsize=queue_maxsize)
+
+ self.core_registration()
+
+ self._thread = Thread(target=self._save_data, daemon=True)
+ self._thread.start()
+
+ def append(self, data: bytes):
+ self._queue.put(data)
+
+ def core_registration(self):
+ print(".. todo: register data collector with ark core..")
+
+ def _save_data(self):
+ with open(self._file_path, "ab", buffering=1024 * 1024) as f:
+ while True:
+ b = self._queue.get() # serialized message
+ if b is self._SENTINEL:
+ self._queue.task_done()
+ break
+
+ f.write(struct.pack(" list[Envelope]:
+ records: list[Envelope] = []
+ path = Path(path)
+
+ with open(path, "rb") as f:
+ while True:
+ len_bytes = f.read(4)
+ if not len_bytes:
+ break # EOF
+ if len(len_bytes) != 4:
+ raise IOError("Corrupted data file (incomplete length prefix)")
+
+ (msg_len,) = struct.unpack(" Publisher:
+ pub = Publisher(
+ self._node_name,
+ self._session,
+ self._clock,
+ channel,
+ self._data_collector,
+ )
+ pub.core_registration()
+ self._pubs[channel] = pub
+ return pub
+
+ def create_subscriber(self, channel, callback) -> Subscriber:
+ sub = Subscriber(
+ self._node_name,
+ self._session,
+ self._clock,
+ channel,
+ self._data_collector,
+ callback,
+ )
+ sub.core_registration()
+ self._subs[channel] = sub
+ return sub
+
+ def create_querier(self, channel, timeout=10.0) -> Querier:
+ querier = Querier(
+ self._node_name,
+ self._session,
+ self._clock,
+ channel,
+ self._data_collector,
+ timeout,
+ )
+ querier.core_registration()
+ self._queriers[channel] = querier
+ return querier
+
+ def create_queryable(self, channel, handler) -> Queryable:
+ queryable = Queryable(
+ self._node_name,
+ self._session,
+ self._clock,
+ channel,
+ handler,
+ self._data_collector,
+ )
+ queryable.core_registration()
+ self._queriables[channel] = queryable
+ return queryable
+
+ def create_rate(self, hz: float):
+ rate = Rate(self._clock, hz)
+ self._rates.append(rate)
+ return rate
+
+ def create_stepper(self, hz: float, callback) -> Stepper:
+ stepper = Stepper(self._clock, hz, callback)
+ self._steppers.append(stepper)
+ return stepper
+
+ def spin(self):
+ while True:
+ time.sleep(1.0)
+
+ def close(self):
+ closable_objs = (
+ self._steppers
+ + list(self._pubs.values())
+ + list(self._subs.values())
+ + list(self._queriers.values())
+ + list(self._queriables.values())
+ )
+ for obj in closable_objs:
+ obj.close()
+
+ self._session.close()
+
+ if self._data_collector:
+ self._data_collector.close()
diff --git a/src/ark/registry.py b/src/ark/registry.py
new file mode 100644
index 0000000..68b749a
--- /dev/null
+++ b/src/ark/registry.py
@@ -0,0 +1,38 @@
+from copy import deepcopy
+from typing import Any
+
+
+class Registry:
+
+ def __init__(self):
+ self._registry: dict[str, Any] = {}
+ self._counter = 0
+
+ def register_item(self, name: str, item: Any) -> None:
+ index = deepcopy(self._counter)
+ self._registry[name] = (index, item)
+ self._counter += 1
+
+ def register(self, name: str):
+ """Register a new item in the registry using decorator syntax."""
+
+ def register_item(item):
+ self.register_item(name, item)
+ return item
+
+ return register_item
+
+ def get(self, name: str) -> Any:
+ _, item = self._registry[name]
+ return item
+
+ def get_index(self, name: str) -> int:
+ index, _ = self._registry[name]
+ return index
+
+ def get_name(self, index: int) -> str:
+ for name, (idx, _) in self._registry.items():
+ if idx == index:
+ return name
+ else:
+ raise KeyError(f"Index {index} not found in registry.")
diff --git a/src/ark/scripts/core.py b/src/ark/scripts/core.py
new file mode 100644
index 0000000..8760b94
--- /dev/null
+++ b/src/ark/scripts/core.py
@@ -0,0 +1,6 @@
+import sys
+
+
+def main():
+ print(">>Ark core<<")
+ print(sys.executable)
diff --git a/src/ark/time/clock.py b/src/ark/time/clock.py
new file mode 100644
index 0000000..da9567b
--- /dev/null
+++ b/src/ark/time/clock.py
@@ -0,0 +1,61 @@
+import struct
+import threading
+import zenoh
+from time import time_ns
+
+
+class Clock:
+ __slots__ = ("_sim", "_t", "_started", "_sim_time_cv", "_sub", "now")
+
+ _FMT = " int:
+ with self._sim_time_cv:
+ while not self._started:
+ self._sim_time_cv.wait()
+ return self._t
+
+ def wait_until(self, target: int) -> None:
+ with self._sim_time_cv:
+ while not self._started:
+ self._sim_time_cv.wait()
+ while self._t < target:
+ self._sim_time_cv.wait()
+
+ def notify(self) -> None:
+ if not self._sim:
+ return
+ with self._sim_time_cv:
+ self._sim_time_cv.notify_all()
diff --git a/src/ark/time/rate.py b/src/ark/time/rate.py
new file mode 100644
index 0000000..53eb15c
--- /dev/null
+++ b/src/ark/time/rate.py
@@ -0,0 +1,26 @@
+from .clock import Clock
+from .sleeper import Sleeper
+
+
+class Rate:
+
+ __slots__ = ("_clock", "_sleep", "_time_step", "_next")
+
+ def __init__(self, clock: Clock, hz: float):
+ self._clock = clock
+ self._sleep = Sleeper(clock)
+ self._time_step = int(1e9 / hz) # Convert hz to nanoseconds period
+ self.reset()
+
+ def reset(self) -> None:
+ self._next = self._clock.now() + self._time_step
+
+ def sleep(self):
+ now = self._clock.now()
+ remaining = self._next - now
+ if remaining > 0:
+ self._sleep(remaining)
+ self._next += self._time_step
+ else:
+ # We are late, skip to next period
+ self._next = now + self._time_step
diff --git a/src/ark/time/simtime.py b/src/ark/time/simtime.py
new file mode 100644
index 0000000..1879ff8
--- /dev/null
+++ b/src/ark/time/simtime.py
@@ -0,0 +1,21 @@
+import zenoh
+import struct
+
+
+class SimTime:
+ def __init__(
+ self, session: zenoh.Session, clock_channel_name: str, time_step_ns: int
+ ):
+ self._pub = session.declare_publisher(clock_channel_name)
+ self._sim_time_ns = None
+ self._time_step_ns = int(time_step_ns)
+
+ def reset(self):
+ self._sim_time_ns = 0
+ self._pub.put(struct.pack(" None:
+ if dur <= 0:
+ return
+ sleep(dur / 1e9) # ns -> s
+
+
+class Sleeper:
+ __slots__ = ("_clock", "_sleep")
+
+ def __init__(self, clock: Clock):
+ self._clock = clock
+ self._sleep = self._sim_sleep if clock._sim else _wall_sleep
+
+ def _sim_sleep(self, dur: int) -> None:
+ if dur <= 0:
+ return
+ now = self._clock.now()
+ self._clock.wait_until(now + dur)
+
+ def __call__(self, dur: int) -> None:
+ self._sleep(dur)
diff --git a/src/ark/time/stepper.py b/src/ark/time/stepper.py
new file mode 100644
index 0000000..2d4509e
--- /dev/null
+++ b/src/ark/time/stepper.py
@@ -0,0 +1,32 @@
+import threading
+from typing import Callable
+from ark.time.rate import Rate
+from ark.time.clock import Clock
+
+
+class Stepper(threading.Thread):
+
+ def __init__(
+ self,
+ clock: Clock,
+ hz: float,
+ callback: Callable[[int], None],
+ ):
+ super().__init__(daemon=True)
+ self._clock = clock
+ self._rate = Rate(clock, hz)
+ self.reset = self._rate.reset
+ self._callback = callback
+ self._closed = False
+ self.start()
+
+ def run(self) -> None:
+ while not self._closed:
+ self._rate.sleep()
+ if self._closed:
+ break
+ self._callback(self._clock.now())
+
+ def close(self) -> None:
+ self._closed = True
+ self._clock.notify()
diff --git a/test/common.py b/test/common.py
new file mode 100644
index 0000000..b26b213
--- /dev/null
+++ b/test/common.py
@@ -0,0 +1 @@
+z_cfg = {"mode": "peer", "connect": {"endpoints": ["udp/127.0.0.1:7447"]}}
diff --git a/test/diff_publisher.py b/test/diff_publisher.py
new file mode 100644
index 0000000..5bb1a04
--- /dev/null
+++ b/test/diff_publisher.py
@@ -0,0 +1,35 @@
+import math
+import time
+from ark.node import BaseNode
+from ark_msgs import Translation, dTranslation
+from common import z_cfg
+# Lissajous parameters
+A, B = 1.0, 1.0
+a, b = 3.0, 2.0
+delta = math.pi / 2
+HZ = 50
+DT = 1.0 / HZ
+class DiffPublisherNode(BaseNode):
+ def __init__(self):
+ super().__init__("env", "diff_pub", z_cfg, sim=True)
+ self.pos_pub = self.create_publisher("position")
+ self.vel_pub = self.create_publisher("velocity")
+ self.rate = self.create_rate(HZ)
+ def spin(self):
+ t = 0.0
+ while True:
+ x = A * math.sin(a * t + delta)
+ y = B * math.sin(b * t)
+ dx = A * a * math.cos(a * t + delta)
+ dy = B * b * math.cos(b * t)
+ self.pos_pub.publish(Translation(x=x, y=y, z=0.0))
+ self.vel_pub.publish(dTranslation(x=dx, y=dy, z=0.0))
+ t += DT
+ self.rate.sleep()
+if __name__ == "__main__":
+ try:
+ node = DiffPublisherNode()
+ node.spin()
+ except KeyboardInterrupt:
+ print("Shutting down diff publisher.")
+ node.close()
diff --git a/test/plotter_subscriber.py b/test/plotter_subscriber.py
new file mode 100644
index 0000000..2477962
--- /dev/null
+++ b/test/plotter_subscriber.py
@@ -0,0 +1,47 @@
+import threading
+import matplotlib.pyplot as plt
+import matplotlib.animation as animation
+from ark.node import BaseNode
+from ark_msgs import Translation, dTranslation
+from common import z_cfg
+class SubscriberPlotterNode(BaseNode):
+ def __init__(self):
+ super().__init__("env", "plotter", z_cfg, sim=True)
+ self.pos_x, self.pos_y = [], []
+ self.vel_x, self.vel_y = [], []
+ self.create_subscriber("position", self.on_position)
+ self.create_subscriber("velocity", self.on_velocity)
+ def on_position(self, msg: Translation):
+ self.pos_x.append(msg.x)
+ self.pos_y.append(msg.y)
+ def on_velocity(self, msg: dTranslation):
+ self.vel_x.append(msg.x)
+ self.vel_y.append(msg.y)
+def main():
+ node = SubscriberPlotterNode()
+ threading.Thread(target=node.spin, daemon=True).start()
+ fig, (ax_pos, ax_vel) = plt.subplots(1, 2, figsize=(10, 5))
+ ax_pos.set_title("Position (Translation)")
+ ax_pos.set_xlabel("x")
+ ax_pos.set_ylabel("y")
+ ax_pos.set_xlim(-1.5, 1.5)
+ ax_pos.set_ylim(-1.5, 1.5)
+ ax_pos.set_aspect("equal")
+ (line_pos,) = ax_pos.plot([], [], "b-")
+ ax_vel.set_title("Velocity (dTranslation)")
+ ax_vel.set_xlabel("dx")
+ ax_vel.set_ylabel("dy")
+ ax_vel.set_xlim(-5, 5)
+ ax_vel.set_ylim(-5, 5)
+ ax_vel.set_aspect("equal")
+ (line_vel,) = ax_vel.plot([], [], "r-")
+ def update(frame):
+ line_pos.set_data(node.pos_x, node.pos_y)
+ line_vel.set_data(node.vel_x, node.vel_y)
+ return line_pos, line_vel
+ ani = animation.FuncAnimation(fig, update, interval=50, blit=True)
+ plt.tight_layout()
+ plt.show()
+ node.close()
+if __name__ == "__main__":
+ main()
diff --git a/test/pub.py b/test/pub.py
new file mode 100644
index 0000000..d595a3b
--- /dev/null
+++ b/test/pub.py
@@ -0,0 +1,27 @@
+from ark.node import BaseNode
+from itertools import count
+from common import z_cfg
+
+
+class PublisherNode(BaseNode):
+
+ def __init__(self):
+ super().__init__("env", "pub", z_cfg, sim=True)
+ self.pub = self.create_publisher("chatter")
+ self.rate = self.create_rate(1) # 1 Hz
+
+ def spin(self):
+ for c in count():
+ msg = f"Hello World {c}"
+ self.pub.publish(msg.encode("utf-8"))
+ print(f"Published: {msg}")
+ self.rate.sleep()
+
+
+if __name__ == "__main__":
+ try:
+ node = PublisherNode()
+ node.spin()
+ except KeyboardInterrupt:
+ print("Shutting down publisher node.")
+ node.close()
diff --git a/test/simstep.py b/test/simstep.py
new file mode 100644
index 0000000..06da237
--- /dev/null
+++ b/test/simstep.py
@@ -0,0 +1,18 @@
+from ark.time.simtime import SimTime
+from common import z_cfg
+import json
+import zenoh
+import time
+
+def main():
+ z_config = zenoh.Config.from_json5(json.dumps(z_cfg))
+ with zenoh.open(z_config) as z:
+ sim_time = SimTime(z, "clock", 1000)
+ sim_time.reset()
+ while True:
+ current_time = time.time()
+ print(f"Simulated Time: {current_time:.2f} seconds")
+ sim_time.tick()
+
+if __name__ == "__main__":
+ main()
diff --git a/test/sub.py b/test/sub.py
new file mode 100644
index 0000000..3101551
--- /dev/null
+++ b/test/sub.py
@@ -0,0 +1,20 @@
+from common import z_cfg
+from ark.node import BaseNode
+
+
+class SubscriberNode(BaseNode):
+ def __init__(self):
+ super().__init__("env", "sub", z_cfg, sim=True, collect_data=True)
+ self.sub = self.create_subscriber("chatter", self.callback)
+
+ def callback(self, msg: bytes):
+ print(f"Received: {msg.decode('utf-8')}")
+
+
+if __name__ == "__main__":
+ try:
+ node = SubscriberNode()
+ node.spin()
+ except KeyboardInterrupt:
+ print("Shutting down subscriber node.")
+ node.close()