Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 76 additions & 11 deletions gridappsd-python-lib/gridappsd/goss.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import logging
import os
import random
import re
import threading
from collections import defaultdict
from datetime import datetime
Expand All @@ -76,6 +77,47 @@
_STOMP_V8 = _stomp_major >= 8


def _is_wildcard_topic(topic):
"""Return True if the topic contains ActiveMQ wildcard characters (* or >)."""
return '*' in topic or '>' in topic


def _wildcard_to_regex(topic):
"""Compile an ActiveMQ wildcard topic pattern into a regex Pattern.

ActiveMQ STOMP wildcard rules:
- '.' is the segment separator
- '*' matches exactly one segment (between dots)
- '>' matches one or more trailing segments (must be last token)

The topic may start with /topic/, /queue/, or /temp-queue/ which are
matched literally.
"""
prefix = ''
body = topic
for stomp_prefix in ('/topic/', '/queue/', '/temp-queue/'):
if topic.startswith(stomp_prefix):
prefix = stomp_prefix
body = topic[len(stomp_prefix):]
break

segments = body.split('.')
regex_parts = []
for seg in segments:
if seg == '>':
# '>' matches one or more trailing segments
regex_parts.append(r'[^/]+(?:\.[^/]+)*')
break
elif seg == '*':
# '*' matches exactly one segment
regex_parts.append(r'[^./]+')
else:
regex_parts.append(re.escape(seg))

pattern = re.escape(prefix) + r'\.'.join(regex_parts)
return re.compile('^' + pattern + '$')


class GRIDAPPSD_ENV_ENUM(Enum):
GRIDAPPSD_USER = "GRIDAPPSD_USER"
GRIDAPPSD_PASSWORD = "GRIDAPPSD_PASSWORD"
Expand Down Expand Up @@ -388,6 +430,8 @@ class CallbackRouter(object):
def __init__(self):
self.callbacks = {}
self._topics_callback_map = defaultdict(list)
self._wildcard_patterns = [] # list of (compiled_regex, topic_key)
self._lock = threading.Lock()
self._queue_callerback = Queue()
self._thread = threading.Thread(target=self.run_callbacks)
self._thread.daemon = True
Expand All @@ -410,18 +454,27 @@ def run_callbacks(self):
def add_callback(self, topic, callback):
if not topic.startswith("/topic/") and not topic.startswith("/temp-queue/"):
topic = "/queue/{topic}".format(topic=topic)
if callback in self._topics_callback_map[topic]:
raise ValueError("Callbacks can only be used one time per topic")
_log.debug("Added callbac using topic {topic}".format(topic=topic))
self._topics_callback_map[topic].append(callback)
with self._lock:
if callback in self._topics_callback_map[topic]:
raise ValueError("Callbacks can only be used one time per topic")
_log.debug("Added callback using topic {topic}".format(topic=topic))
self._topics_callback_map[topic].append(callback)
if _is_wildcard_topic(topic) and not any(key == topic for _, key in self._wildcard_patterns):
self._wildcard_patterns.append((_wildcard_to_regex(topic), topic))

def remove_callback(self, topic, callback):
if topic in self._topics_callback_map:
callbacks = self._topics_callback_map[topic]
try:
callbacks.remove(callback)
except ValueError:
pass
with self._lock:
if topic in self._topics_callback_map:
callbacks = self._topics_callback_map[topic]
try:
callbacks.remove(callback)
except ValueError:
pass
if not callbacks:
del self._topics_callback_map[topic]
self._wildcard_patterns = [
(pat, key) for pat, key in self._wildcard_patterns if key != topic
]

def on_message(self, *args):
if _STOMP_V8:
Expand All @@ -430,9 +483,21 @@ def on_message(self, *args):
else:
headers, message = args[0], args[1]
destination = headers["destination"]
# _log.debug("Topic map keys are: {keys}".format(keys=self._topics_callback_map.keys()))

# Fast path: exact match
if destination in self._topics_callback_map:
self._queue_callerback.put((self._topics_callback_map[destination], headers, message))
return

# Slow path: check wildcard patterns
matched_callbacks = []
with self._lock:
for pattern, topic_key in self._wildcard_patterns:
if pattern.fullmatch(destination):
matched_callbacks.extend(self._topics_callback_map[topic_key])

if matched_callbacks:
self._queue_callerback.put((matched_callbacks, headers, message))
else:
_log.error("INVALID DESTINATION {destination}".format(destination=destination))

Expand Down
Loading
Loading