diff --git a/pyhilo/api.py b/pyhilo/api.py index 553743e..cdbc452 100755 --- a/pyhilo/api.py +++ b/pyhilo/api.py @@ -291,7 +291,7 @@ async def _async_request( try: data = await resp.json(content_type=None) except json.decoder.JSONDecodeError: - LOG.warning(f"JSON Decode error: {resp.__dict__}") + LOG.warning("JSON Decode error: %s", resp.__dict__) message = await resp.text() data = {"error": message} else: @@ -353,7 +353,7 @@ async def _async_handle_on_backoff(self, _: dict[str, Any]) -> None: err: ClientResponseError = err_info[1].with_traceback(err_info[2]) # type: ignore if err.status in (401, 403): - LOG.warning(f"Refreshing websocket token {err.request_info.url}") + LOG.warning("Refreshing websocket token %s", err.request_info.url) if ( "client/negotiate" in str(err.request_info.url) and err.request_info.method == "POST" @@ -361,7 +361,7 @@ async def _async_handle_on_backoff(self, _: dict[str, Any]) -> None: LOG.info( "401 detected on websocket, refreshing websocket token. Old url: {self.ws_url} Old Token: {self.ws_token}" ) - LOG.info(f"401 detected on {err.request_info.url}") + LOG.info("401 detected on %s", err.request_info.url) async with self._backoff_refresh_lock_ws: await self.refresh_ws_token() await self.get_websocket_params() @@ -480,7 +480,7 @@ async def fb_install(self, fb_id: str) -> None: json=body, ) except ClientResponseError as err: - LOG.error(f"ClientResponseError: {err}") + LOG.error("ClientResponseError: %s", err) if err.status in (401, 403): raise InvalidCredentialsError("Invalid credentials") from err raise RequestError(err) from err @@ -518,14 +518,14 @@ async def android_register(self) -> None: data=parsed_body, ) except ClientResponseError as err: - LOG.error(f"ClientResponseError: {err}") + LOG.error("ClientResponseError: %s", err) if err.status in (401, 403): raise InvalidCredentialsError("Invalid credentials") from err raise RequestError(err) from err LOG.debug("Android client register: %s", resp) msg: str = resp.get("message", "") if msg.startswith("Error="): - LOG.error(f"Android registration error: {msg}") + LOG.error("Android registration error: %s", msg) raise RequestError token = msg.split("=")[-1] LOG.debug("Calling set_state android_register") diff --git a/pyhilo/device/__init__.py b/pyhilo/device/__init__.py index c92095f..026b5b2 100644 --- a/pyhilo/device/__init__.py +++ b/pyhilo/device/__init__.py @@ -51,7 +51,7 @@ def __init__( def update(self, **kwargs: Dict[str, Union[str, int, Dict]]) -> None: # TODO(dvd): This has to be re-written, this is not dynamic at all. if self._api.log_traces: - LOG.debug(f"[TRACE] Adding device {kwargs}") + LOG.debug("[TRACE] Adding device %s", kwargs) for orig_att, val in kwargs.items(): att = camel_to_snake(orig_att) if reading_att := HILO_READING_TYPES.get(orig_att): @@ -70,7 +70,7 @@ def update(self, **kwargs: Dict[str, Union[str, int, Dict]]) -> None: self.update_readings(DeviceReading(**reading)) # type: ignore if att not in HILO_DEVICE_ATTRIBUTES: - LOG.warning(f"Unknown device attribute {att}: {val}") + LOG.warning("Unknown device attribute %s: %s", att, val) continue elif att in HILO_LIST_ATTRIBUTES: # This is where we generated the supported_attributes and settable_attributes @@ -108,7 +108,7 @@ def update(self, **kwargs: Dict[str, Union[str, int, Dict]]) -> None: async def set_attribute(self, attribute: str, value: Union[str, int, None]) -> None: if dev_attribute := cast(DeviceAttribute, self._api.dev_atts(attribute)): - LOG.debug(f"{self._tag} Setting {dev_attribute} to {value}") + LOG.debug("%s Setting %s to %s", self._tag, dev_attribute, value) await self._set_attribute(dev_attribute, value) return LOG.warning( @@ -134,7 +134,7 @@ async def _set_attribute( ) ) else: - LOG.warning(f"{self._tag} Invalid attribute {attribute} for device") + LOG.warning("%s Invalid attribute %s for device", self._tag, attribute) def get_attribute(self, attribute: str) -> Union[DeviceReading, None]: if dev_attribute := cast(DeviceAttribute, self._api.dev_atts(attribute)): @@ -245,7 +245,7 @@ def __init__(self, **kwargs: Dict[str, Any]): else "" ) if not self.device_attribute: - LOG.warning(f"Received invalid reading for {self.device_id}: {kwargs}") + LOG.warning("Received invalid reading for %s: %s", self.device_id, kwargs) def __repr__(self) -> str: return f"" diff --git a/pyhilo/devices.py b/pyhilo/devices.py index c9de57c..e927bb4 100644 --- a/pyhilo/devices.py +++ b/pyhilo/devices.py @@ -87,7 +87,7 @@ def generate_device(self, device: dict) -> HiloDevice: try: device_type = HILO_DEVICE_TYPES[dev.type] except KeyError: - LOG.warning(f"Unknown device type {dev.type}, adding as Sensor") + LOG.warning("Unknown device type %s, adding as Sensor", dev.type) device_type = "Sensor" dev.__class__ = globals()[device_type] return dev diff --git a/pyhilo/util/state.py b/pyhilo/util/state.py index 87a00bf..0f180dd 100644 --- a/pyhilo/util/state.py +++ b/pyhilo/util/state.py @@ -4,7 +4,9 @@ import asyncio from datetime import datetime +import os from os.path import isfile +import tempfile from typing import Any, ForwardRef, TypedDict, TypeVar, get_type_hints import aiofiles @@ -77,7 +79,7 @@ class StateDict(TypedDict, total=False): T = TypeVar("T", bound="StateDict") -def _get_defaults(cls: type[T]) -> dict[str, Any]: +def _get_defaults(cls: type[T]) -> T: """Generate a default dict based on typed dict This function recursively creates a nested dictionary structure that mirrors @@ -117,22 +119,71 @@ def _get_defaults(cls: type[T]) -> dict[str, Any]: return new_dict # type: ignore[return-value] -async def get_state(state_yaml: str) -> StateDict: +def _write_state(state_yaml: str, state: dict[str, Any] | StateDict) -> None: + "Write state atomically to a temp file, this prevents reading a file being written to" + + dir_name = os.path.dirname(os.path.abspath(state_yaml)) + content = yaml.dump(state) + with tempfile.NamedTemporaryFile( + mode="w", dir=dir_name, delete=False, suffix=".tmp" + ) as tmp: + tmp.write(content) + tmp_path = tmp.name + os.chmod(tmp_path, 0o644) + os.replace(tmp_path, state_yaml) + + +async def get_state(state_yaml: str, _already_locked: bool = False) -> StateDict: """Read in state yaml. :param state_yaml: filename where to read the state :type state_yaml: ``str`` + :param _already_locked: Whether the lock is already held by the caller (e.g. set_state). + Prevents deadlock when corruption recovery needs to write defaults. + :type _already_locked: ``bool`` :rtype: ``StateDict`` """ if not isfile( state_yaml ): # noqa: PTH113 - isfile is fine and simpler in this case. - return _get_defaults(StateDict) # type: ignore - async with aiofiles.open(state_yaml, mode="r") as yaml_file: - LOG.debug("Loading state from yaml") - content = await yaml_file.read() - state_yaml_payload: StateDict = yaml.safe_load(content) - return state_yaml_payload + return _get_defaults(StateDict) + + try: + async with aiofiles.open(state_yaml, mode="r") as yaml_file: + LOG.debug("Loading state from yaml") + content = await yaml_file.read() + + state_yaml_payload: StateDict | None = yaml.safe_load(content) + + # Handle corrupted/empty YAML files + if state_yaml_payload is None or not isinstance(state_yaml_payload, dict): + LOG.warning( + "State file %s is corrupted or empty, reinitializing with defaults", + state_yaml, + ) + defaults = _get_defaults(StateDict) + if _already_locked: + _write_state(state_yaml, defaults) + else: + async with lock: + _write_state(state_yaml, defaults) + return defaults + + return state_yaml_payload + + except yaml.YAMLError as e: + LOG.error( + "Failed to parse state file %s: %s. Reinitializing with defaults.", + state_yaml, + e, + ) + defaults = _get_defaults(StateDict) + if _already_locked: + _write_state(state_yaml, defaults) + else: + async with lock: + _write_state(state_yaml, defaults) + return defaults async def set_state( @@ -143,6 +194,7 @@ async def set_state( ), ) -> None: """Save state yaml. + :param state_yaml: filename where to read the state :type state_yaml: ``str`` :param key: Key name @@ -152,14 +204,11 @@ async def set_state( :rtype: ``StateDict`` """ async with lock: # note ic-dev21: on lock le fichier pour être sûr de finir la job - current_state = await get_state(state_yaml) or {} + current_state = await get_state(state_yaml, _already_locked=True) or {} merged_state: dict[str, Any] = {key: {**current_state.get(key, {}), **state}} # type: ignore[dict-item] new_state: dict[str, Any] = {**current_state, **merged_state} - async with aiofiles.open(state_yaml, mode="w") as yaml_file: - LOG.debug("Saving state to yaml file") - # TODO: Use asyncio.get_running_loop() and run_in_executor to write - # to the file in a non blocking manner. Currently, the file writes - # are properly async but the yaml dump is done synchronously on the - # main event loop. - content = yaml.dump(new_state) - await yaml_file.write(content) + LOG.debug("Saving state to yaml file") + # TODO: Use asyncio.get_running_loop() and run_in_executor to write + # to the file in a non blocking manner. Currently, yaml.dump is + # synchronous on the main event loop. + _write_state(state_yaml, new_state) diff --git a/pyhilo/websocket.py b/pyhilo/websocket.py index 4522f73..b6d69c9 100755 --- a/pyhilo/websocket.py +++ b/pyhilo/websocket.py @@ -173,7 +173,9 @@ async def _async_receive_json(self) -> list[Dict[str, Any]]: response = await self._client.receive(300) if response.type in (WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING): - LOG.error(f"Websocket: Received event to close connection: {response.type}") + LOG.error( + "Websocket: Received event to close connection: %s", response.type + ) raise ConnectionClosedError("Connection was closed.") if response.type == WSMsgType.ERROR: @@ -183,7 +185,7 @@ async def _async_receive_json(self) -> list[Dict[str, Any]]: raise ConnectionFailedError if response.type != WSMsgType.TEXT: - LOG.error(f"Websocket: Received invalid message: {response}") + LOG.error("Websocket: Received invalid message: %s", response) raise InvalidMessageError(f"Received non-text message: {response.type}") messages: list[Dict[str, Any]] = [] @@ -196,7 +198,7 @@ async def _async_receive_json(self) -> list[Dict[str, Any]]: except ValueError as v_exc: raise InvalidMessageError("Received invalid JSON") from v_exc except json.decoder.JSONDecodeError as j_exc: - LOG.error(f"Received invalid JSON: {msg}") + LOG.error("Received invalid JSON: %s", msg) LOG.exception(j_exc) data = {} @@ -307,14 +309,14 @@ async def async_connect(self) -> None: **proxy_env, ) except (ClientError, ServerDisconnectedError, WSServerHandshakeError) as err: - LOG.error(f"Unable to connect to WS server {err}") + LOG.error("Unable to connect to WS server %s", err) if hasattr(err, "status") and err.status in (401, 403, 404, 409): raise InvalidCredentialsError("Invalid credentials") from err except Exception as err: - LOG.error(f"Unable to connect to WS server {err}") + LOG.error("Unable to connect to WS server %s", err) raise CannotConnectError(err) from err - LOG.info(f"Connected to websocket server {self._api.endpoint}") + LOG.info("Connected to websocket server %s", self._api.endpoint) # Quick pause to prevent race condition await asyncio.sleep(0.05) @@ -353,11 +355,11 @@ async def async_listen(self) -> None: LOG.info("Websocket: Listen cancelled.") raise except ConnectionClosedError as err: - LOG.error(f"Websocket: Closed while listening: {err}") + LOG.error("Websocket: Closed while listening: %s", err) LOG.exception(err) pass except InvalidMessageError as err: - LOG.warning(f"Websocket: Received invalid json : {err}") + LOG.warning("Websocket: Received invalid json : %s", err) pass finally: LOG.info("Websocket: Listen completed; cleaning up")