diff --git a/kazoo/client.py b/kazoo/client.py index 3029d1c5..9b8384c8 100644 --- a/kazoo/client.py +++ b/kazoo/client.py @@ -25,11 +25,14 @@ from kazoo.protocol.connection import ConnectionHandler from kazoo.protocol.paths import _prefix_root, normpath from kazoo.protocol.serialization import ( + AddWatch, Auth, CheckVersion, CloseInstance, Create, Create2, + CreateContainer, + CreateTTL, Delete, Exists, GetChildren, @@ -38,6 +41,7 @@ SetACL, GetData, Reconfig, + RemoveWatches, SetData, Sync, Transaction, @@ -48,6 +52,8 @@ KazooState, KeeperState, WatchedEvent, + AddWatchMode, + WatcherType ) from kazoo.retry import KazooRetry from kazoo.security import ACL, OPEN_ACL_UNSAFE @@ -252,6 +258,8 @@ def __init__( self.state_listeners = set() self._child_watchers = defaultdict(set) self._data_watchers = defaultdict(set) + self._persistent_watchers = defaultdict(set) + self._persistent_recursive_watchers = defaultdict(set) self._reset() self.read_only = read_only @@ -420,8 +428,15 @@ def _reset_watchers(self): for data_watchers in self._data_watchers.values(): watchers.extend(data_watchers) + for persistent_watchers in self._persistent_watchers.values(): + watchers.extend(persistent_watchers) + + for pr_watchers in self._persistent_recursive_watchers.values(): + watchers.extend(pr_watchers) self._child_watchers = defaultdict(set) self._data_watchers = defaultdict(set) + self._persistent_watchers = defaultdict(set) + self._persistent_recursive_watchers = defaultdict(set) ev = WatchedEvent(EventType.NONE, self._state, None) for watch in watchers: @@ -924,6 +939,8 @@ def create( sequence=False, makepath=False, include_data=False, + container=False, + ttl=0, ): """Create a node with the given value as its data. Optionally set an ACL on the node. @@ -1001,6 +1018,8 @@ def create( The `makepath` option. .. versionadded:: 2.7 The `include_data` option. + .. versionadded:: 2.9 + The `container` and `ttl` options. """ acl = acl or self.default_acl return self.create_async( @@ -1011,6 +1030,8 @@ def create( sequence=sequence, makepath=makepath, include_data=include_data, + container=container, + ttl=ttl, ).get() def create_async( @@ -1022,6 +1043,8 @@ def create_async( sequence=False, makepath=False, include_data=False, + container=False, + ttl=0, ): """Asynchronously create a ZNode. Takes the same arguments as :meth:`create`. @@ -1032,6 +1055,8 @@ def create_async( The makepath option. .. versionadded:: 2.7 The `include_data` option. + .. versionadded:: 2.9 + The `container` and `ttl` options. """ if acl is None and self.default_acl: acl = self.default_acl @@ -1054,27 +1079,37 @@ def create_async( raise TypeError("Invalid type for 'makepath' (bool expected)") if not isinstance(include_data, bool): raise TypeError("Invalid type for 'include_data' (bool expected)") + if not isinstance(container, bool): + raise TypeError("Invalid type for 'container' (bool expected)") + if not isinstance(ttl, int) or ttl < 0: + raise TypeError("Invalid 'ttl' (integer >= 0 expected)") + if ttl and ephemeral: + raise TypeError("Invalid node creation: ephemeral & ttl") + if container and (ephemeral or sequence or ttl): + raise TypeError( + "Invalid node creation: container & ephemeral/sequence/ttl" + ) - flags = 0 - if ephemeral: - flags |= 1 - if sequence: - flags |= 2 if acl is None: acl = OPEN_ACL_UNSAFE + opcode = _create_opcode( + path, + value, + acl, + self.chroot, + ephemeral, + sequence, + include_data, + container, + ttl, + ) + async_result = self.handler.async_result() @capture_exceptions(async_result) def do_create(): - result = self._create_async_inner( - path, - value, - acl, - flags, - trailing=sequence, - include_data=include_data, - ) + result = self._create_async_inner(opcode) result.rawlink(create_completion) @capture_exceptions(async_result) @@ -1085,7 +1120,7 @@ def retry_completion(result): @wrap(async_result) def create_completion(result): try: - if include_data: + if opcode.type == Create2.type or opcode.type == CreateContainer.type or opcode.type == CreateTTL.type: new_path, stat = result.get() return self.unchroot(new_path), stat else: @@ -1102,24 +1137,10 @@ def create_completion(result): do_create() return async_result - def _create_async_inner( - self, path, value, acl, flags, trailing=False, include_data=False - ): + def _create_async_inner(self, opcode): async_result = self.handler.async_result() - if include_data: - opcode = Create2 - else: - opcode = Create - - call_result = self._call( - opcode( - _prefix_root(self.chroot, path, trailing=trailing), - value, - acl, - flags, - ), - async_result, - ) + + call_result = self._call(opcode, async_result) if call_result is False: # We hit a short-circuit exit on the _call. Because we are # not using the original async_result here, we bubble the @@ -1651,6 +1672,97 @@ def reconfig_async(self, joining, leaving, new_members, from_config): return async_result + def add_watch(self, path, watch, mode): + """Add a watch. + This method adds persistent watches. Unlike the data and + child watches which may be set by calls to + :meth:`KazooClient.exists`, :meth:`KazooClient.get`, and + :meth:`KazooClient.get_children`, persistent watches are not + removed after being triggered. + To remove a persistent watch, use + :meth:`KazooClient.remove_all_watches` with an argument of + :attr:`~kazoo.protocol.states.WatcherType.ANY`. + The `mode` argument determines whether or not the watch is + recursive. To set a persistent watch, use + :class:`~kazoo.protocol.states.AddWatchMode.PERSISTENT`. To set a + persistent recursive watch, use + :class:`~kazoo.protocol.states.AddWatchMode.PERSISTENT_RECURSIVE`. + :param path: Path of node to watch. + :param watch: Watch callback to set for future changes + to this path. + :param mode: The mode to use. + :type mode: int + :raises: + :exc:`~kazoo.exceptions.MarshallingError` if mode is + unknown. + :exc:`~kazoo.exceptions.ZookeeperError` if the server + returns a non-zero error code. + """ + return self.add_watch_async(path, watch, mode).get() + + def add_watch_async(self, path, watch, mode): + """Asynchronously add a watch. Takes the same arguments as + :meth:`add_watch`. + """ + if not isinstance(path, str): + raise TypeError("Invalid type for 'path' (string expected)") + if not callable(watch): + raise TypeError("Invalid type for 'watch' (must be a callable)") + if not isinstance(mode, int): + raise TypeError("Invalid type for 'mode' (int expected)") + if mode not in ( + AddWatchMode.PERSISTENT, + AddWatchMode.PERSISTENT_RECURSIVE, + ): + raise ValueError("Invalid value for 'mode'") + + async_result = self.handler.async_result() + self._call( + AddWatch(_prefix_root(self.chroot, path), watch, mode), + async_result, + ) + return async_result + + def remove_all_watches(self, path, watcher_type): + """Remove watches from a path. + This removes all watches of a specified type (data, child, + any) from a given path. + The `watcher_type` argument specifies which type to use. It + may be one of: + * :attr:`~kazoo.protocol.states.WatcherType.DATA` + * :attr:`~kazoo.protocol.states.WatcherType.CHILDREN` + * :attr:`~kazoo.protocol.states.WatcherType.ANY` + To remove persistent watches, specify a watcher type of + :attr:`~kazoo.protocol.states.WatcherType.ANY`. + :param path: Path of watch to remove. + :param watcher_type: The type of watch to remove. + :type watcher_type: int + """ + + return self.remove_all_watches_async(path, watcher_type).get() + + def remove_all_watches_async(self, path, watcher_type): + """Asynchronously remove watches. Takes the same arguments as + :meth:`remove_all_watches`. + """ + if not isinstance(path, str): + raise TypeError("Invalid type for 'path' (string expected)") + if not isinstance(watcher_type, int): + raise TypeError("Invalid type for 'watcher_type' (int expected)") + if watcher_type not in ( + WatcherType.ANY, + WatcherType.CHILDREN, + WatcherType.DATA, + ): + raise ValueError("Invalid value for 'watcher_type'") + + async_result = self.handler.async_result() + self._call( + RemoveWatches(_prefix_root(self.chroot, path), watcher_type), + async_result, + ) + return async_result + class TransactionRequest(object): """A Zookeeper Transaction Request @@ -1680,7 +1792,15 @@ def __init__(self, client): self.committed = False def create( - self, path, value=b"", acl=None, ephemeral=False, sequence=False + self, + path, + value=b"", + acl=None, + ephemeral=False, + sequence=False, + include_data=False, + container=False, + ttl=0, ): """Add a create ZNode to the transaction. Takes the same arguments as :meth:`KazooClient.create`, with the exception @@ -1688,6 +1808,8 @@ def create( :returns: None + .. versionadded:: 2.9 + The `include_data`, `container` and `ttl` options. """ if acl is None and self.client.default_acl: acl = self.client.default_acl @@ -1704,19 +1826,34 @@ def create( raise TypeError("Invalid type for 'ephemeral' (bool expected)") if not isinstance(sequence, bool): raise TypeError("Invalid type for 'sequence' (bool expected)") + if not isinstance(include_data, bool): + raise TypeError("Invalid type for 'include_data' (bool expected)") + if not isinstance(container, bool): + raise TypeError("Invalid type for 'container' (bool expected)") + if not isinstance(ttl, int) or ttl < 0: + raise TypeError("Invalid 'ttl' (integer >= 0 expected)") + if ttl and ephemeral: + raise TypeError("Invalid node creation: ephemeral & ttl") + if container and (ephemeral or sequence or ttl): + raise TypeError( + "Invalid node creation: container & ephemeral/sequence/ttl" + ) - flags = 0 - if ephemeral: - flags |= 1 - if sequence: - flags |= 2 if acl is None: acl = OPEN_ACL_UNSAFE - self._add( - Create(_prefix_root(self.client.chroot, path), value, acl, flags), - None, + opcode = _create_opcode( + path, + value, + acl, + self.client.chroot, + ephemeral, + sequence, + include_data, + container, + ttl, ) + self._add(opcode, None) def delete(self, path, version=-1): """Add a delete ZNode to the transaction. Takes the same @@ -1797,3 +1934,85 @@ def _add(self, request, post_processor=None): self._check_tx_state() self.client.logger.log(BLATHER, "Added %r to %r", request, self) self.operations.append(request) + + +def _create_opcode( + path, + value, + acl, + chroot, + ephemeral, + sequence, + include_data, + container, + ttl, +): + """Helper function. + Creates the create OpCode for regular `client.create()` operations as + well as in a `client.transaction()` context. + """ + if not isinstance(path, str): + raise TypeError("Invalid type for 'path' (string expected)") + if acl and (isinstance(acl, ACL) or not isinstance(acl, (tuple, list))): + raise TypeError( + "Invalid type for 'acl' (acl must be a tuple/list" " of ACL's" + ) + if value is not None and not isinstance(value, bytes): + raise TypeError("Invalid type for 'value' (must be a byte string)") + if not isinstance(ephemeral, bool): + raise TypeError("Invalid type for 'ephemeral' (bool expected)") + if not isinstance(sequence, bool): + raise TypeError("Invalid type for 'sequence' (bool expected)") + if not isinstance(include_data, bool): + raise TypeError("Invalid type for 'include_data' (bool expected)") + if not isinstance(container, bool): + raise TypeError("Invalid type for 'container' (bool expected)") + if not isinstance(ttl, int) or ttl < 0: + raise TypeError("Invalid 'ttl' (integer >= 0 expected)") + if ttl and ephemeral: + raise TypeError("Invalid node creation: ephemeral & ttl") + if container and (ephemeral or sequence or ttl): + raise TypeError( + "Invalid node creation: container & ephemeral/sequence/ttl" + ) + + # Should match Zookeeper's CreateMode fromFlag + # https://github.com/apache/zookeeper/blob/master/zookeeper-server/ + # src/main/java/org/apache/zookeeper/CreateMode.java#L112 + flags = 0 + if ephemeral: + flags |= 1 + if sequence: + flags |= 2 + if container: + flags = 4 + if ttl: + if sequence: + flags = 6 + else: + flags = 5 + + if acl is None: + acl = OPEN_ACL_UNSAFE + + # Figure out the OpCode we are going to send + if include_data: + return Create2( + _prefix_root(chroot, path, trailing=sequence), value, acl, flags + ) + elif container: + return CreateContainer( + _prefix_root(chroot, path, trailing=False), value, acl, flags + ) + elif ttl: + return CreateTTL( + _prefix_root(chroot, path, trailing=sequence), + value, + acl, + flags, + ttl, + ) + else: + return Create( + _prefix_root(chroot, path, trailing=sequence), value, acl, flags + ) diff --git a/kazoo/protocol/connection.py b/kazoo/protocol/connection.py index 3df7b162..343eb9bd 100644 --- a/kazoo/protocol/connection.py +++ b/kazoo/protocol/connection.py @@ -19,6 +19,7 @@ ) from kazoo.loggingsupport import BLATHER from kazoo.protocol.serialization import ( + AddWatch, Auth, Close, Connect, @@ -27,6 +28,7 @@ GetChildren2, Ping, PingInstance, + RemoveWatches, ReplyHeader, SASL, Transaction, @@ -34,10 +36,12 @@ int_struct, ) from kazoo.protocol.states import ( + AddWatchMode, Callback, KeeperState, WatchedEvent, EVENT_TYPE_MAP, + WatcherType, ) from kazoo.retry import ( ForceRetryError, @@ -356,6 +360,18 @@ def _write(self, msg, timeout): raise ConnectionDropped("socket connection broken") sent += bytes_sent + def _find_persistent_recursive_watchers(self, path): + parts = path.split("/") + watchers = [] + for count in range(len(parts)): + candidate = "/".join(parts[: count + 1]) + if not candidate: + continue + watchers.extend( + self.client._persistent_recursive_watchers.get(candidate, []) + ) + return watchers + def _read_watch_event(self, buffer, offset): client = self.client watch, offset = Watch.deserialize(buffer, offset) @@ -367,9 +383,13 @@ def _read_watch_event(self, buffer, offset): if watch.type in (CREATED_EVENT, CHANGED_EVENT): watchers.extend(client._data_watchers.pop(path, [])) + watchers.extend(client._persistent_watchers.get(path, [])) + watchers.extend(self._find_persistent_recursive_watchers(path)) elif watch.type == DELETED_EVENT: watchers.extend(client._data_watchers.pop(path, [])) watchers.extend(client._child_watchers.pop(path, [])) + watchers.extend(client._persistent_watchers.get(path, [])) + watchers.extend(self._find_persistent_recursive_watchers(path)) elif watch.type == CHILD_EVENT: watchers.extend(client._child_watchers.pop(path, [])) else: @@ -441,13 +461,35 @@ def _read_response(self, header, buffer, offset): async_object.set(response) - # Determine if watchers should be registered - watcher = getattr(request, "watcher", None) - if not client._stopped.is_set() and watcher: - if isinstance(request, (GetChildren, GetChildren2)): - client._child_watchers[request.path].add(watcher) - else: - client._data_watchers[request.path].add(watcher) + # Determine if watchers should be registered or unregistered + if not client._stopped.is_set(): + watcher = getattr(request, "watcher", None) + if watcher: + if isinstance(request, AddWatch): + if request.mode == AddWatchMode.PERSISTENT: + client._persistent_watchers[request.path].add( + watcher + ) + elif request.mode == AddWatchMode.PERSISTENT_RECURSIVE: + client._persistent_recursive_watchers[ + request.path + ].add(watcher) + elif isinstance(request, (GetChildren, GetChildren2)): + client._child_watchers[request.path].add(watcher) + else: + client._data_watchers[request.path].add(watcher) + if isinstance(request, RemoveWatches): + if request.watcher_type == WatcherType.CHILDREN: + client._child_watchers.pop(request.path, None) + elif request.watcher_type == WatcherType.DATA: + client._data_watchers.pop(request.path, None) + elif request.watcher_type == WatcherType.ANY: + client._child_watchers.pop(request.path, None) + client._data_watchers.pop(request.path, None) + client._persistent_watchers.pop(request.path, None) + client._persistent_recursive_watchers.pop( + request.path, None + ) if isinstance(request, Close): self.logger.log(BLATHER, "Read close response") diff --git a/kazoo/protocol/serialization.py b/kazoo/protocol/serialization.py index 40e6360c..2cee2d01 100644 --- a/kazoo/protocol/serialization.py +++ b/kazoo/protocol/serialization.py @@ -343,6 +343,11 @@ def deserialize(cls, bytes, offset): while not header.done: if header.type == Create.type: response, offset = read_string(bytes, offset) + elif header.type in (Create2.type, CreateContainer.type, CreateTTL.type): + path, offset = read_string(bytes, offset) + stat = ZnodeStat._make(stat_struct.unpack_from(bytes, offset)) + offset += stat_struct.size + response = (path, stat) elif header.type == Delete.type: response = True elif header.type == SetData.type: @@ -367,6 +372,10 @@ def unchroot(client, response): for result in response: if isinstance(result, str): resp.append(client.unchroot(result)) + elif isinstance(result, ZnodeStat): + resp.append(result) + elif isinstance(result, tuple): + resp.append((client.unchroot(result[0]), result[1])) else: resp.append(result) return resp @@ -416,6 +425,69 @@ def deserialize(cls, bytes, offset): return data, stat +class CreateContainer(namedtuple("CreateContainer", "path data acl flags")): + type = 19 + + def serialize(self): + b = bytearray() + b.extend(write_string(self.path)) + b.extend(write_buffer(self.data)) + b.extend(int_struct.pack(len(self.acl))) + for acl in self.acl: + b.extend( + int_struct.pack(acl.perms) + + write_string(acl.id.scheme) + + write_string(acl.id.id) + ) + b.extend(int_struct.pack(self.flags)) + return b + + @classmethod + def deserialize(cls, bytes, offset): + path, offset = read_string(bytes, offset) + stat = ZnodeStat._make(stat_struct.unpack_from(bytes, offset)) + return path, stat + + +class CreateTTL(namedtuple("CreateTTL", "path data acl flags ttl")): + type = 21 + + def serialize(self): + b = bytearray() + b.extend(write_string(self.path)) + b.extend(write_buffer(self.data)) + b.extend(int_struct.pack(len(self.acl))) + for acl in self.acl: + b.extend( + int_struct.pack(acl.perms) + + write_string(acl.id.scheme) + + write_string(acl.id.id) + ) + b.extend(int_struct.pack(self.flags)) + b.extend(long_struct.pack(self.ttl)) + return b + + @classmethod + def deserialize(cls, bytes, offset): + path, offset = read_string(bytes, offset) + stat = ZnodeStat._make(stat_struct.unpack_from(bytes, offset)) + return path, stat + + +class RemoveWatches(namedtuple("RemoveWatches", "path watcher_type")): + type = 18 + + def serialize(self): + b = bytearray() + b.extend(write_string(self.path)) + b.extend(int_struct.pack(self.watcher_type)) + return b + + @classmethod + def deserialize(cls, bytes, offset): + return None + + class Auth(namedtuple("Auth", "auth_type scheme auth")): type = 100 @@ -441,6 +513,20 @@ def deserialize(cls, bytes, offset): return challenge, offset +class AddWatch(namedtuple("AddWatch", "path watcher mode")): + type = 106 + + def serialize(self): + b = bytearray() + b.extend(write_string(self.path)) + b.extend(int_struct.pack(self.mode)) + return b + + @classmethod + def deserialize(cls, bytes, offset): + return None + + class Watch(namedtuple("Watch", "type state path")): @classmethod def deserialize(cls, bytes, offset): diff --git a/kazoo/protocol/states.py b/kazoo/protocol/states.py index 480a586e..b9d07bee 100644 --- a/kazoo/protocol/states.py +++ b/kazoo/protocol/states.py @@ -251,3 +251,43 @@ def data_length(self): @property def children_count(self): return self.numChildren + +class AddWatchMode(object): + """Modes for use with :meth:`~kazoo.client.KazooClient.add_watch` + + .. attribute:: PERSISTENT + + The watch is not removed when trigged. + + .. attribute:: PERSISTENT_RECURSIVE + + The watch is not removed when trigged, and applies to all + paths underneath the supplied path as well. + """ + + PERSISTENT = 0 + PERSISTENT_RECURSIVE = 1 + + +class WatcherType(object): + """Watcher types for use with + :meth:`~kazoo.client.KazooClient.remove_all_watches` + + .. attribute:: CHILDREN + + Child watches. + + .. attribute:: DATA + + Data watches. + + .. attribute:: ANY + + Any type of watch (child, data, persistent, or persistent + recursive). + + """ + + CHILDREN = 1 + DATA = 2 + ANY = 3 diff --git a/kazoo/testing/harness.py b/kazoo/testing/harness.py index bb77f071..fa248bed 100644 --- a/kazoo/testing/harness.py +++ b/kazoo/testing/harness.py @@ -75,6 +75,8 @@ def get_global_cluster(): # in read only test "localSessionsEnabled=" + ZOOKEEPER_LOCAL_SESSION_RO, "localSessionsUpgradingEnabled=" + ZOOKEEPER_LOCAL_SESSION_RO, + # enable container and TTL node types (ZK 3.5+) + "extendedTypesEnabled=true", ] # If defined, this sets the superuser password to "test" additional_java_system_properties = [ diff --git a/kazoo/tests/test_client.py b/kazoo/tests/test_client.py index 3f1748c4..57813a71 100644 --- a/kazoo/tests/test_client.py +++ b/kazoo/tests/test_client.py @@ -27,6 +27,13 @@ from kazoo.protocol.connection import _CONNECTION_DROP from kazoo.protocol.states import KeeperState, KazooState from kazoo.tests.util import CI_ZK_VERSION +from kazoo.protocol.states import ( + AddWatchMode, + KazooState, + KeeperState, + WatcherType, + EventType, +) class TestClientTransitions(KazooTestCase): @@ -60,6 +67,51 @@ def listener(state): assert states == req_states +class TestCreateOpcode(unittest.TestCase): + """Unit tests for _create_opcode (no ZK required).""" + + def test_create_opcode_returns_create(self): + from kazoo.client import _create_opcode + from kazoo.protocol.serialization import Create + from kazoo.security import OPEN_ACL_UNSAFE + + op = _create_opcode( + "/test", b"data", OPEN_ACL_UNSAFE, "", False, False, False, False, 0 + ) + assert op.__class__ is Create + + def test_create_opcode_returns_create2_for_include_data(self): + from kazoo.client import _create_opcode + from kazoo.protocol.serialization import Create2 + from kazoo.security import OPEN_ACL_UNSAFE + + op = _create_opcode( + "/test", b"data", OPEN_ACL_UNSAFE, "", False, False, True, False, 0 + ) + assert op.__class__ is Create2 + + def test_create_opcode_returns_create_container(self): + from kazoo.client import _create_opcode + from kazoo.protocol.serialization import CreateContainer + from kazoo.security import OPEN_ACL_UNSAFE + + op = _create_opcode( + "/test", b"data", OPEN_ACL_UNSAFE, "", False, False, False, True, 0 + ) + assert op.__class__ is CreateContainer + + def test_create_opcode_returns_create_ttl(self): + from kazoo.client import _create_opcode + from kazoo.protocol.serialization import CreateTTL + from kazoo.security import OPEN_ACL_UNSAFE + + op = _create_opcode( + "/test", b"data", OPEN_ACL_UNSAFE, "", False, False, False, False, 1000 + ) + assert op.__class__ is CreateTTL + assert op.ttl == 1000 + + class TestClientConstructor(unittest.TestCase): def _makeOne(self, *args, **kw): from kazoo.client import KazooClient @@ -571,6 +623,16 @@ def test_create_invalid_arguments(self): client.create("a", sequence="yes") with pytest.raises(TypeError): client.create("a", makepath="yes") + with pytest.raises(TypeError): + client.create("a", container="yes") + with pytest.raises(TypeError): + client.create("a", ttl="1000") + with pytest.raises(TypeError): + client.create("a", ttl=-1) + with pytest.raises(TypeError): + client.create("a", ephemeral=True, ttl=1000) + with pytest.raises(TypeError): + client.create("a", container=True, ephemeral=True) def test_create_value(self): client = self.client @@ -713,6 +775,42 @@ def test_create_exists(self): with pytest.raises(NodeExistsError): client.create(path) + def test_create_container(self): + """Create container node (ZK 3.5+ with extendedTypesEnabled).""" + if CI_ZK_VERSION and CI_ZK_VERSION < (3, 5): + pytest.skip("Container nodes require Zookeeper 3.5+") + elif CI_ZK_VERSION and CI_ZK_VERSION >= (3, 5): + pass + else: + ver = self.client.server_version() + if ver < (3, 5): + pytest.skip("Container nodes require Zookeeper 3.5+") + + client = self.client + path = client.create("/container_test", b"data", container=True) + assert path == "/container_test" + assert client.exists("/container_test") + data, stat = client.get("/container_test") + assert data == b"data" + + def test_create_ttl(self): + """Create TTL node (ZK 3.5.5+ with extendedTypesEnabled).""" + if CI_ZK_VERSION and CI_ZK_VERSION < (3, 5): + pytest.skip("TTL nodes require Zookeeper 3.5+") + elif CI_ZK_VERSION and CI_ZK_VERSION >= (3, 5): + pass + else: + ver = self.client.server_version() + if ver < (3, 5): + pytest.skip("TTL nodes require Zookeeper 3.5+") + + client = self.client + path = client.create("/ttl_test", b"data", ttl=60000) + assert path == "/ttl_test" + assert client.exists("/ttl_test") + data, stat = client.get("/ttl_test") + assert data == b"data" + def test_create_stat(self): if CI_ZK_VERSION: version = CI_ZK_VERSION @@ -1146,6 +1244,147 @@ def test_request_queuing_session_expired(self): finally: client.stop() + def _require_zk_version(self, major, minor): + skip = False + if CI_ZK_VERSION and CI_ZK_VERSION < (major, minor): + skip = True + elif CI_ZK_VERSION and CI_ZK_VERSION >= (major, minor): + skip = False + else: + ver = self.client.server_version() + if ver[1] < minor: + skip = True + if skip: + pytest.skip("Must use Zookeeper %s.%s or above" % (major, minor)) + + def test_persistent_recursive_watch(self): + # This tests adding and removing a persistent recursive watch. + self._require_zk_version(3, 6) + events = [] + + def callback(event): + events.append(dict(type=event.type, path=event.path)) + + client = self.client + client.add_watch("/a", callback, AddWatchMode.PERSISTENT_RECURSIVE) + full_path = client.chroot + "/a" + assert len(client._persistent_recursive_watchers[full_path]) == 1 + client.create("/a") + client.create("/a/b") + client.create("/a/b/c", value=b"1") + client.create("/a/b/d", value=b"1") + client.set("/a/b/c", value=b"2") + client.set("/a/b/d", value=b"2") + client.delete("/a", recursive=True) + # Remove the watch + client.remove_all_watches("/a", WatcherType.ANY) + # Perform one more call that we don't expect to see + client.create("/a") + # Wait in case the callback does arrive + time.sleep(3) + assert client._persistent_recursive_watchers[full_path] == set() + assert events == [ + dict(type=EventType.CREATED, path="/a"), + dict(type=EventType.CREATED, path="/a/b"), + dict(type=EventType.CREATED, path="/a/b/c"), + dict(type=EventType.CREATED, path="/a/b/d"), + dict(type=EventType.CHANGED, path="/a/b/c"), + dict(type=EventType.CHANGED, path="/a/b/d"), + dict(type=EventType.DELETED, path="/a/b/c"), + dict(type=EventType.DELETED, path="/a/b/d"), + dict(type=EventType.DELETED, path="/a/b"), + dict(type=EventType.DELETED, path="/a"), + ] + + def test_persistent_watch(self): + # This tests adding and removing a persistent watch. + self._require_zk_version(3, 6) + events = [] + + def callback(event): + events.append(dict(type=event.type, path=event.path)) + + client = self.client + client.add_watch("/a", callback, AddWatchMode.PERSISTENT) + full_path = client.chroot + "/a" + assert len(client._persistent_watchers[full_path]) == 1 + client.create("/a") + # This shouldn't appear since the watch is not recursive + client.create("/a/b") + client.delete("/a", recursive=True) + # Remove the watch + client.remove_all_watches("/a", WatcherType.ANY) + # Perform one more call that we don't expect to see + client.create("/a") + # Wait in case the callback does arrive + time.sleep(3) + assert client._persistent_watchers[full_path] == set() + assert events == [ + dict(type=EventType.CREATED, path="/a"), + dict(type=EventType.DELETED, path="/a"), + ] + + def test_remove_data_watch(self): + # Test that removing a data watch leaves a child watch in place. + self._require_zk_version(3, 6) + callback_event = threading.Event() + + def child_callback(event): + callback_event.set() + + def data_callback(event): + pass + + client = self.client + client.create("/a") + client.get("/a", watch=data_callback) + client.get_children("/a", watch=child_callback) + client.remove_all_watches("/a", WatcherType.DATA) + client.create("/a/b") + callback_event.wait(30) + + def test_remove_children_watch(self): + # Test that removing a children watch leaves a data watch in place. + self._require_zk_version(3, 6) + callback_event = threading.Event() + + def data_callback(event): + callback_event.set() + + def child_callback(event): + pass + + client = self.client + client.create("/a") + client.get("/a", watch=data_callback) + client.get_children("/a", watch=child_callback) + client.remove_all_watches("/a", WatcherType.CHILDREN) + client.set("/a", b"1") + callback_event.wait(30) + + def test_invalid_add_watch_values(self): + def callback(event): + return + + client = self.client + with pytest.raises(TypeError): + client.add_watch(b"/a", callback, AddWatchMode.PERSISTENT) + with pytest.raises(TypeError): + client.add_watch("/a", "test", AddWatchMode.PERSISTENT) + with pytest.raises(TypeError): + client.add_watch("/a", callback, "1") + with pytest.raises(ValueError): + client.add_watch("/a", callback, 42) + + def test_invalid_remove_all_watch_values(self): + client = self.client + with pytest.raises(TypeError): + client.remove_all_watches(b"/a", WatcherType.ANY) + with pytest.raises(TypeError): + client.remove_all_watches("/a", "test") + with pytest.raises(ValueError): + client.remove_all_watches("/a", 42) + class TestSSLClient(KazooTestCase): def setUp(self): @@ -1225,6 +1464,8 @@ def test_bad_creates(self): ("/smith", b"", "bleh"), ("/smith", b"", None, "fred"), ("/smith", b"", None, True, "fred"), + ("/smith", b"", None, False, False, False, "yes"), + ("/smith", b"", None, False, False, False, False, False, "ttl"), ] for args in args_list: @@ -1232,6 +1473,42 @@ def test_bad_creates(self): t = self.client.transaction() t.create(*args) + def test_transaction_create_container(self): + """Transaction with container node (ZK 3.5+).""" + if CI_ZK_VERSION and CI_ZK_VERSION < (3, 5): + pytest.skip("Container nodes require Zookeeper 3.5+") + elif CI_ZK_VERSION and CI_ZK_VERSION >= (3, 5): + pass + else: + ver = self.client.server_version() + if ver < (3, 5): + pytest.skip("Container nodes require Zookeeper 3.5+") + + t = self.client.transaction() + t.create("/tx_container", b"data", container=True) + results = t.commit() + assert len(results) == 1 + assert results[0][0] == "/tx_container" + assert self.client.exists("/tx_container") + + def test_transaction_create_ttl(self): + """Transaction with TTL node (ZK 3.5+).""" + if CI_ZK_VERSION and CI_ZK_VERSION < (3, 5): + pytest.skip("TTL nodes require Zookeeper 3.5+") + elif CI_ZK_VERSION and CI_ZK_VERSION >= (3, 5): + pass + else: + ver = self.client.server_version() + if ver < (3, 5): + pytest.skip("TTL nodes require Zookeeper 3.5+") + + t = self.client.transaction() + t.create("/tx_ttl", b"data", ttl=60000) + results = t.commit() + assert len(results) == 1 + assert results[0][0] == "/tx_ttl" + assert self.client.exists("/tx_ttl") + def test_default_acl(self): from kazoo.security import make_digest_acl