diff --git a/dp3/api/internal/models.py b/dp3/api/internal/models.py
index b60133b3..65918730 100644
--- a/dp3/api/internal/models.py
+++ b/dp3/api/internal/models.py
@@ -1,11 +1,10 @@
-from datetime import datetime
 from typing import Annotated, Any, Literal, Optional, Union

 from pydantic import BaseModel, Field, TypeAdapter, create_model, model_validator

 from dp3.api.internal.config import MODEL_SPEC
 from dp3.api.internal.helpers import api_to_dp3_datapoint
-from dp3.common.types import T2Datetime
+from dp3.common.types import AwareDatetime, T2Datetime


 class DataPoint(BaseModel):
@@ -27,7 +26,7 @@ class DataPoint(BaseModel):
     id: Any
     attr: str
     v: Any
-    t1: Optional[datetime] = None
+    t1: Optional[AwareDatetime] = None
     t2: Optional[T2Datetime] = Field(None, validate_default=True)
     c: Annotated[float, Field(ge=0.0, le=1.0)] = 1.0
     src: Optional[str] = None
diff --git a/dp3/api/routers/entity.py b/dp3/api/routers/entity.py
index 9b5988ff..ba81cc4f 100644
--- a/dp3/api/routers/entity.py
+++ b/dp3/api/routers/entity.py
@@ -1,4 +1,4 @@
-from datetime import datetime, timezone
+from datetime import datetime
 from typing import Annotated, Any, Optional

 from fastapi import APIRouter, Depends, HTTPException, Request
@@ -20,6 +20,7 @@
 from dp3.api.internal.response_models import ErrorResponse, RequestValidationError, SuccessResponse
 from dp3.common.attrspec import AttrType
 from dp3.common.task import DataPointTask, task_context
+from dp3.common.types import UTC, AwareDatetime
 from dp3.database.database import DatabaseError


@@ -42,7 +43,7 @@ async def parse_eid(etype: str, eid: str):


 def get_eid_master_record_handler(
-    e: EntityId, date_from: Optional[datetime] = None, date_to: Optional[datetime] = None
+    e: EntityId, date_from: Optional[AwareDatetime] = None, date_to: Optional[AwareDatetime] = None
 ):
     """Handler for getting master record of EID"""
     # TODO: This is probably not the most efficient way. Maybe gather only
@@ -66,7 +67,7 @@ def get_eid_master_record_handler(


 def get_eid_snapshots_handler(
-    e: EntityId, date_from: Optional[datetime] = None, date_to: Optional[datetime] = None
+    e: EntityId, date_from: Optional[AwareDatetime] = None, date_to: Optional[AwareDatetime] = None
 ) -> list[dict[str, Any]]:
     """Handler for getting snapshots of EID"""
     snapshots = list(DB.snapshots.get_by_eid(e.type, e.id, t1=date_from, t2=date_to))
@@ -274,7 +275,7 @@ async def count_entity_type_eids(

 @router.get("/{etype}/{eid}")
 async def get_eid_data(
-    e: ParsedEid, date_from: Optional[datetime] = None, date_to: Optional[datetime] = None
+    e: ParsedEid, date_from: Optional[AwareDatetime] = None, date_to: Optional[AwareDatetime] = None
 ) -> EntityEidData:
     """Get data of `etype`'s `eid`.
@@ -294,7 +295,7 @@ async def get_eid_data(

 @router.get("/{etype}/{eid}/master")
 async def get_eid_master_record(
-    e: ParsedEid, date_from: Optional[datetime] = None, date_to: Optional[datetime] = None
+    e: ParsedEid, date_from: Optional[AwareDatetime] = None, date_to: Optional[AwareDatetime] = None
 ) -> EntityEidMasterRecord:
     """Get master record of `etype`'s `eid`."""
     return get_eid_master_record_handler(e, date_from, date_to)
@@ -302,7 +303,7 @@ async def get_eid_master_record(

 @router.get("/{etype}/{eid}/snapshots")
 async def get_eid_snapshots(
-    e: ParsedEid, date_from: Optional[datetime] = None, date_to: Optional[datetime] = None
+    e: ParsedEid, date_from: Optional[AwareDatetime] = None, date_to: Optional[AwareDatetime] = None
 ) -> EntityEidSnapshots:
     """Get snapshots of `etype`'s `eid`."""
     return get_eid_snapshots_handler(e, date_from, date_to)
@@ -312,8 +313,8 @@ async def get_eid_snapshots(
 async def get_eid_attr_value(
     e: ParsedEid,
     attr: str,
-    date_from: Optional[datetime] = None,
-    date_to: Optional[datetime] = None,
+    date_from: Optional[AwareDatetime] = None,
+    date_to: Optional[AwareDatetime] = None,
 ) -> EntityEidAttrValueOrHistory:
     """Get attribute value
@@ -354,7 +355,7 @@ async def set_eid_attr_value(
         id=eid,
         attr=attr,
         v=body.value,
-        t1=datetime.now(timezone.utc),
+        t1=datetime.now(UTC),
         src=f"{request.client.host} via API",
     )
     dp3_dp = api_to_dp3_datapoint(dp.dict())
@@ -393,7 +394,7 @@ async def get_distinct_attribute_values(etype: str, attr: str) -> dict[JsonVal,

 @router.post("/{etype}/{eid}/ttl")
-async def extend_eid_ttls(e: ParsedEid, body: dict[str, datetime]) -> SuccessResponse:
+async def extend_eid_ttls(e: ParsedEid, body: dict[str, AwareDatetime]) -> SuccessResponse:
     """Extend TTLs of the specified entity"""
     # Construct task
     with task_context(MODEL_SPEC):
diff --git a/dp3/common/datapoint.py b/dp3/common/datapoint.py
index 1753e427..ee8850c0 100644
--- a/dp3/common/datapoint.py
+++ b/dp3/common/datapoint.py
@@ -1,11 +1,10 @@
-from datetime import datetime
 from ipaddress import IPv4Address, IPv6Address
 from typing import Annotated, Any, Optional, Union

 from pydantic import BaseModel, BeforeValidator, Field, PlainSerializer

 from dp3.common.mac_address import MACAddress
-from dp3.common.types import T2Datetime
+from dp3.common.types import AwareDatetime, T2Datetime


 def to_json_friendly(v):
@@ -60,7 +59,7 @@ class DataPointObservationsBase(DataPointBase):
     Contains single raw data value received on API for observations attribute.
     """

-    t1: datetime
+    t1: AwareDatetime
     t2: T2Datetime = Field(None, validate_default=True)
     c: Annotated[float, Field(ge=0.0, le=1.0)] = 1.0

@@ -71,7 +70,7 @@ class DataPointTimeseriesBase(DataPointBase):
     """
     Contains single raw data value received on API for observations attribute.
     """

-    t1: datetime
+    t1: AwareDatetime
     t2: T2Datetime = Field(None, validate_default=True)
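Note on the `t1`/`t2` pair above: with `AwareDatetime` and `T2Datetime`, a naive `t1` is interpreted as UTC and a missing `t2` falls back to `t1` (this is exactly what the new `tests/test_common/test_types.py` at the end of this patch asserts). A minimal sketch of the behavior, assuming the patched `dp3.common.types` is importable; `_Obs` is an illustrative stand-in for `DataPointObservationsBase`, not the real class:

```python
from pydantic import BaseModel, Field

from dp3.common.types import AwareDatetime, T2Datetime


class _Obs(BaseModel):
    t1: AwareDatetime
    t2: T2Datetime = Field(None, validate_default=True)


dp = _Obs(t1="2024-01-01T10:00:00")  # naive timestamp, no t2
print(dp.t1.tzinfo)    # UTC -- ensure_timezone_aware labeled it
print(dp.t2 == dp.t1)  # True -- t2_implicity_t1 filled t2 from t1
```

Because `AwareDatetime` is the inner type of `T2Datetime`, its normalization runs before `t2_after_t1` compares `t2` to `t1`, so the comparison never mixes naive and aware values.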
""" - t1: datetime + t1: AwareDatetime t2: T2Datetime = Field(None, validate_default=True) diff --git a/dp3/common/types.py b/dp3/common/types.py index 296d0a7b..922963da 100644 --- a/dp3/common/types.py +++ b/dp3/common/types.py @@ -1,7 +1,7 @@ -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from ipaddress import IPv4Address, IPv6Address from json import JSONEncoder -from typing import Annotated, Any, Union +from typing import Annotated, Any, Optional, Union from event_count_logger import DummyEventGroup, EventGroup from pydantic import AfterValidator, BeforeValidator @@ -9,6 +9,8 @@ from dp3.common.utils import parse_time_duration, time_duration_pattern +UTC = timezone.utc + def parse_timedelta_or_passthrough(v): """ @@ -22,6 +24,18 @@ def parse_timedelta_or_passthrough(v): ParsedTimedelta = Annotated[timedelta, BeforeValidator(parse_timedelta_or_passthrough)] +def ensure_timezone_aware(v: Optional[datetime]): + """Ensure datetime is timezone-aware by defaulting to UTC.""" + if v is None: + return v + if v.tzinfo is None: + return v.replace(tzinfo=UTC) + return v + + +AwareDatetime = Annotated[datetime, AfterValidator(ensure_timezone_aware)] + + def t2_implicity_t1(v, info: FieldValidationInfo): """If t2 is not specified, it is set to t1.""" v = v or info.data.get("t1") @@ -35,7 +49,11 @@ def t2_after_t1(v, info: FieldValidationInfo): return v -T2Datetime = Annotated[datetime, BeforeValidator(t2_implicity_t1), AfterValidator(t2_after_t1)] +T2Datetime = Annotated[ + AwareDatetime, + BeforeValidator(t2_implicity_t1), + AfterValidator(t2_after_t1), +] EventGroupType = Union[EventGroup, DummyEventGroup] diff --git a/dp3/common/utils.py b/dp3/common/utils.py index 574c9afc..742bb258 100644 --- a/dp3/common/utils.py +++ b/dp3/common/utils.py @@ -2,9 +2,9 @@ auxiliary/utility functions and classes """ -import datetime import re from collections.abc import Iterable, Iterator +from datetime import datetime, timedelta from functools import partial from itertools import islice from typing import Union @@ -65,8 +65,8 @@ def parse_rfc_time(time_str): us = int(us_str) zonestr = res.group(8) zoneoffset = 0 if zonestr in (None, "z", "Z") else int(zonestr[:3]) * 60 + int(zonestr[4:6]) - zonediff = datetime.timedelta(minutes=zoneoffset) - return datetime.datetime(year, month, day, hour, minute, second, us) - zonediff + zonediff = timedelta(minutes=zoneoffset) + return datetime(year, month, day, hour, minute, second, us) - zonediff else: raise ValueError("Wrong timestamp format") @@ -74,18 +74,18 @@ def parse_rfc_time(time_str): time_duration_pattern = re.compile(r"^\s*(\d+)([smhd])?$") -def parse_time_duration(duration_string: Union[str, int, datetime.timedelta]) -> datetime.timedelta: +def parse_time_duration(duration_string: Union[str, int, timedelta]) -> timedelta: """ Parse duration in format (or just "0"). 
diff --git a/dp3/common/utils.py b/dp3/common/utils.py
index 574c9afc..742bb258 100644
--- a/dp3/common/utils.py
+++ b/dp3/common/utils.py
@@ -2,9 +2,9 @@
 auxiliary/utility functions and classes
 """

-import datetime
 import re
 from collections.abc import Iterable, Iterator
+from datetime import datetime, timedelta
 from functools import partial
 from itertools import islice
 from typing import Union
@@ -65,8 +65,8 @@ def parse_rfc_time(time_str):
         us = int(us_str)
         zonestr = res.group(8)
         zoneoffset = 0 if zonestr in (None, "z", "Z") else int(zonestr[:3]) * 60 + int(zonestr[4:6])
-        zonediff = datetime.timedelta(minutes=zoneoffset)
-        return datetime.datetime(year, month, day, hour, minute, second, us) - zonediff
+        zonediff = timedelta(minutes=zoneoffset)
+        return datetime(year, month, day, hour, minute, second, us) - zonediff
     else:
         raise ValueError("Wrong timestamp format")
@@ -74,18 +74,18 @@
 time_duration_pattern = re.compile(r"^\s*(\d+)([smhd])?$")


-def parse_time_duration(duration_string: Union[str, int, datetime.timedelta]) -> datetime.timedelta:
+def parse_time_duration(duration_string: Union[str, int, timedelta]) -> timedelta:
     """
     Parse duration in format <num>[s/m/h/d] (or just "0").

     Return datetime.timedelta
     """
     # if it's already timedelta, just return it unchanged
-    if isinstance(duration_string, datetime.timedelta):
+    if isinstance(duration_string, timedelta):
         return duration_string

     # if number is passed, consider it number of seconds
     if isinstance(duration_string, (int, float)):
-        return datetime.timedelta(seconds=duration_string)
+        return timedelta(seconds=duration_string)

     d = 0
     h = 0
@@ -105,7 +105,7 @@ def parse_time_duration(duration_string: Union[str, int, datetime.timedelta]) ->
     else:
         raise ValueError("Invalid time duration string")

-    return datetime.timedelta(days=d, hours=h, minutes=m, seconds=s)
+    return timedelta(days=d, hours=h, minutes=m, seconds=s)


 # *** object (de)serialization ***
@@ -118,14 +118,14 @@ def conv_to_json(obj):
     - datetime
     - timedelta
     """
-    if isinstance(obj, datetime.datetime):
+    if isinstance(obj, datetime):
         if obj.tzinfo:
             raise NotImplementedError(
                 "Can't serialize timezone-aware datetime object "
                 "(DP3 policy is to use naive datetimes in UTC everywhere)"
             )
         return {"$datetime": obj.strftime("%Y-%m-%dT%H:%M:%S.%f")}
-    if isinstance(obj, datetime.timedelta):
+    if isinstance(obj, timedelta):
         return {"$timedelta": f"{obj.days},{obj.seconds},{obj.microseconds}"}
     raise TypeError(f"{repr(obj)} is not JSON serializable")
@@ -140,10 +140,10 @@ def conv_from_json(dct):
     """
     if "$datetime" in dct:
         val = dct["$datetime"]
-        return datetime.datetime.strptime(val, "%Y-%m-%dT%H:%M:%S.%f")
+        return datetime.strptime(val, "%Y-%m-%dT%H:%M:%S.%f")
     if "$timedelta" in dct:
         days, seconds, microseconds = dct["$timedelta"].split(",")
-        return datetime.timedelta(int(days), int(seconds), int(microseconds))
+        return timedelta(int(days), int(seconds), int(microseconds))
     return dct
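For reference, the inputs `parse_time_duration` accepts follow the `time_duration_pattern` regex visible above, and a bare number is taken as seconds. A quick usage sketch, assuming the patched module is importable:

```python
from dp3.common.utils import parse_time_duration

print(parse_time_duration("30s"))  # 0:00:30
print(parse_time_duration("2h"))   # 2:00:00
print(parse_time_duration("1d"))   # 1 day, 0:00:00
print(parse_time_duration(90))     # 0:01:30 -- plain numbers mean seconds
```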
"entity": etype} ) records_cursor = self.db.get_worker_master_records( @@ -237,12 +237,12 @@ def collect_ttl(self, etype: str): if "#ttl" not in master_document: continue # TTL not set, ignore for now - if all(ttl < utc_now for ttl in master_document["#ttl"].values()): + if all(ttl < now for ttl in master_document["#ttl"].values()): deleted += 1 to_delete.append(master_document["_id"]) else: eid_expired_ttls = [ - name for name, ttl in master_document["#ttl"].items() if ttl < start + name for name, ttl in master_document["#ttl"].items() if ttl < now ] if eid_expired_ttls: expired_ttls[master_document["_id"]] = eid_expired_ttls @@ -265,8 +265,8 @@ def collect_ttl(self, etype: str): records_cursor.close() self.db.update_metadata( - start, - metadata={"ttl_collect_end": datetime.now()}, + now, + metadata={"ttl_collect_end": datetime.now(UTC)}, increase={"entities": entities, "deleted": deleted}, ) self.log.info( @@ -280,7 +280,7 @@ def extend_plain_ttl( self, eid: AnyEidT, dp: DataPointBase, extend_by: timedelta ) -> list[DataPointTask]: """Extends the TTL of the entity by the specified timedelta.""" - now = datetime.utcnow() + now = datetime.now(UTC) task = DataPointTask( etype=dp.etype, eid=eid, diff --git a/dp3/core/link_manager.py b/dp3/core/link_manager.py index 4750b467..94aedc09 100644 --- a/dp3/core/link_manager.py +++ b/dp3/core/link_manager.py @@ -14,6 +14,7 @@ from dp3.common.datapoint import DataPointBase, DataPointObservationsBase from dp3.common.datatype import AnyEidT from dp3.common.task import parse_eids_from_cache +from dp3.common.types import UTC from dp3.database.database import EntityDatabase @@ -35,7 +36,7 @@ def __init__( self.db.register_on_entity_delete( self.remove_link_cache_of_deleted, self.remove_link_cache_of_many_deleted ) - self.max_date = datetime.max.replace(tzinfo=None) + self.max_date = datetime.max.replace(tzinfo=UTC) for (entity, attr), spec in self.model_spec.relations.items(): if spec.t == AttrType.PLAIN: if spec.is_iterable: diff --git a/dp3/core/updater.py b/dp3/core/updater.py index 5097470d..1bd507b8 100644 --- a/dp3/core/updater.py +++ b/dp3/core/updater.py @@ -14,7 +14,7 @@ from dp3.common.config import CronExpression, PlatformConfig from dp3.common.scheduler import Scheduler from dp3.common.task import DataPointTask, task_context -from dp3.common.types import EventGroupType, ParsedTimedelta +from dp3.common.types import UTC, EventGroupType, ParsedTimedelta from dp3.database.database import EntityDatabase from dp3.task_processing.task_queue import TaskQueueWriter @@ -79,7 +79,7 @@ class UpdateThreadState(BaseModel, validate_assignment=True): @classmethod def new(cls, hooks: dict, period: float, entity_type: str, eid_only: bool = False): """Create a new instance initialized with hooks and thread_id components.""" - now = datetime.now() + now = datetime.now(UTC) return cls( t_created=now, t_last_update=now, @@ -102,7 +102,7 @@ def thread_id(self) -> tuple[float, str, bool]: def reset(self): """Resets counters and timestamps.""" - now = datetime.now() + now = datetime.now(UTC) self.t_created = now self.t_last_update = now self.t_end = now + timedelta(seconds=self.period) @@ -399,7 +399,7 @@ def _process_batch( state.period, state.eid_only, ) - start = datetime.now() + start = datetime.now(UTC) iteration_cnt = state.total_iterations iteration = state.iteration @@ -418,7 +418,7 @@ def _process_batch( hook_runner(hooks, entity_type, record) state.processed += 1 - state.t_last_update = datetime.now() + state.t_last_update = datetime.now(UTC) duration = 
             duration = state.t_last_update - start
             state.runtime_secs += duration.total_seconds()
diff --git a/dp3/database/database.py b/dp3/database/database.py
index d6df921c..adb89bec 100644
--- a/dp3/database/database.py
+++ b/dp3/database/database.py
@@ -24,7 +24,7 @@
 from dp3.common.datatype import AnyEidT
 from dp3.common.scheduler import Scheduler
 from dp3.common.task import HASH
-from dp3.common.types import EventGroupType
+from dp3.common.types import UTC, EventGroupType
 from dp3.database.config import MongoConfig, MongoReplicaConfig, MongoStandaloneConfig
 from dp3.database.encodings import get_codec_options
 from dp3.database.exceptions import DatabaseError
@@ -202,7 +202,7 @@ def _raw_col(self, entity: str, **kwargs) -> Collection:
     @staticmethod
     def _get_new_archive_col_name(entity: str) -> str:
         """Returns name of new archive collection for `entity`."""
-        return f"{entity}#archive_{datetime.utcnow().strftime('%Y_%m_%d_%H%M%S')}"
+        return f"{entity}#archive_{datetime.now(UTC).strftime('%Y_%m_%d_%H%M%S')}"

     def _archive_col_names(self, entity: str) -> list[str]:
         """Returns names of archive collections for `entity`."""
@@ -319,7 +319,7 @@ def _init_database_schema(self) -> None:
                 continue

             # Create wildcard index for attribute histories
-            index_name = f"wildcard_{datetime.now().strftime('%Y%m%d%H%M%S')}"
+            index_name = f"wildcard_{datetime.now(UTC).strftime('%Y%m%d%H%M%S')}"
             try:
                 master_col.create_index(
                     [("$**", 1)],
@@ -395,6 +395,7 @@ def insert_datapoints(

         # Update master document
         master_changes = {"pushes": defaultdict(list), "$set": {}}
+        dt_now = datetime.now(UTC)
         for dp in dps:
             attr_spec = self._db_schema_config.attr(etype, dp.attr)
@@ -405,7 +406,7 @@ def insert_datapoints(

             # Rewrite value of plain attribute
             if attr_spec.t == AttrType.PLAIN:
-                master_changes["$set"][dp.attr] = {"v": v, "ts_last_update": datetime.now()}
+                master_changes["$set"][dp.attr] = {"v": v, "ts_last_update": dt_now}

             # Push new data of observation
             if attr_spec.t == AttrType.OBSERVATIONS:
@@ -419,7 +420,7 @@ def insert_datapoints(

         if new_entity:
             master_changes["$set"]["#hash"] = HASH(f"{etype}:{eid}")
-            master_changes["$set"]["#time_created"] = datetime.now()
+            master_changes["$set"]["#time_created"] = dt_now

         with self._master_buffer_locks[etype]:
             if eid in self._master_buffers[etype]:
@@ -541,7 +542,7 @@ def extend_ttl(self, etype: str, eid: AnyEidT, ttl_tokens: dict[str, datetime]):
             if eid in buf:
                 if "$max" in buf[eid]:
                     for ttl_name, this_val in extensions.items():
-                        curr_val = buf[eid]["$max"].get(ttl_name, datetime.min)
+                        curr_val = buf[eid]["$max"].get(ttl_name, datetime.min.replace(tzinfo=UTC))
                         buf[eid]["$max"][ttl_name] = max(curr_val, this_val)
                 else:
                     self._master_buffers[etype][eid]["$max"] = extensions
@@ -851,7 +852,7 @@ def save_metadata(self, time: datetime, metadata: dict, worker_id: Optional[int]
         metadata["_id"] = self._get_metadata_id(module, time, worker_id)
         metadata["#module"] = module
         metadata["#time_created"] = time
-        metadata["#last_update"] = datetime.now()
+        metadata["#last_update"] = datetime.now(UTC)
         try:
             self._db["#metadata"].insert_one(metadata)
             self.log.debug("Inserted metadata %s: %s", metadata["_id"], metadata)
@@ -864,7 +865,7 @@ def update_metadata(
         """Updates existing metadata of caller module and passed timestamp."""
         module = get_caller_id()
         metadata_id = self._get_metadata_id(module, time, worker_id)
-        metadata["#last_update"] = datetime.now()
+        metadata["#last_update"] = datetime.now(UTC)

         changes = {"$set": metadata} if increase is None else {"$set": metadata, "$inc": increase}
@@ -901,8 +902,8 @@ get_observation_history(
         Returns:
             list of dicts (reduced datapoints)
         """
-        t1 = datetime.fromtimestamp(0) if t1 is None else t1.replace(tzinfo=None)
-        t2 = datetime.now() if t2 is None else t2.replace(tzinfo=None)
+        t1 = datetime.fromtimestamp(0, UTC) if t1 is None else t1.astimezone(UTC)
+        t2 = datetime.now(UTC) if t2 is None else t2.astimezone(UTC)

         # Get attribute history
         mr = self.get_master_record(etype, eid)
@@ -956,8 +957,8 @@ get_timeseries_history(
         Returns:
             list of dicts (reduced datapoints) - each represents just one point at time
         """
-        t1 = datetime.fromtimestamp(0) if t1 is None else t1.replace(tzinfo=None)
-        t2 = datetime.now() if t2 is None else t2.replace(tzinfo=None)
+        t1 = datetime.fromtimestamp(0, UTC) if t1 is None else t1.astimezone(UTC)
+        t2 = datetime.now(UTC) if t2 is None else t2.astimezone(UTC)

         attr_history = self.get_observation_history(etype, attr_name, eid, t1, t2, sort)

         if not attr_history:
diff --git a/dp3/database/encodings.py b/dp3/database/encodings.py
index 36e79bbd..228c3143 100644
--- a/dp3/database/encodings.py
+++ b/dp3/database/encodings.py
@@ -1,4 +1,5 @@
 from ipaddress import IPv4Address, IPv6Address
+from zoneinfo import ZoneInfo

 from bson import Binary
 from bson.binary import USER_DEFINED_SUBTYPE
@@ -38,4 +39,4 @@ def transform_bson(self, value):

 def get_codec_options():
     tr = TypeRegistry([DP3BinaryDecoder()], fallback_encoder=fallback_encoder)
-    return CodecOptions(type_registry=tr)
+    return CodecOptions(type_registry=tr, tz_aware=True, tzinfo=ZoneInfo("UTC"))
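The `tz_aware=True` / `tzinfo` codec options are what make every datetime read back from MongoDB timezone-aware (by default pymongo returns naive UTC values), so comparisons like `ttl < now` in the collector keep working against `datetime.now(UTC)`. A minimal sketch of the effect, assuming a reachable MongoDB instance; database and collection names are illustrative:

```python
from datetime import datetime, timezone
from zoneinfo import ZoneInfo

from bson.codec_options import CodecOptions
from pymongo import MongoClient

client = MongoClient()
opts = CodecOptions(tz_aware=True, tzinfo=ZoneInfo("UTC"))
col = client["demo_db"].get_collection("demo", codec_options=opts)

col.insert_one({"ts": datetime.now(timezone.utc)})
doc = col.find_one()
print(doc["ts"].tzinfo)  # ZoneInfo("UTC") -- without tz_aware=True this would be naive
```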
""" - t_old = datetime.utcnow() - self.keep_raw_delta + t_old = datetime.now(UTC) - self.keep_raw_delta self.log.debug("Archiving all records before %s ...", t_old) for etype in self.model_spec.entities: @@ -258,10 +258,9 @@ def _ensure_log_dir(log_dir_path: str): def aggregate_master_docs(self): self.log.debug("Starting master documents aggregation.") - start = datetime.now() - utcnow = datetime.utcnow() + ts = datetime.now(UTC) entities = 0 - self.db.save_metadata(start, {"entities": 0, "aggregation_start": start}) + self.db.save_metadata(ts, {"entities": 0, "aggregation_start": ts}) for entity in self.model_spec.entities: entity_attr_specs = self.model_spec.entity_attributes[entity] @@ -272,7 +271,7 @@ def aggregate_master_docs(self): ) try: for master_document in records_cursor: - if entity_expired(utcnow, master_document): + if entity_expired(ts, master_document): continue # Avoid expired entities to avoid conflict with garbage collector entities += 1 @@ -293,7 +292,7 @@ def aggregate_master_docs(self): records_cursor.close() self.db.update_metadata( - start, metadata={"aggregation_end": datetime.now()}, increase={"entities": entities} + ts, metadata={"aggregation_end": datetime.now(UTC)}, increase={"entities": entities} ) self.log.debug("Master documents aggregation end.") diff --git a/dp3/scripts/dummy_sender.py b/dp3/scripts/dummy_sender.py index 4495e526..b6d8cb2f 100755 --- a/dp3/scripts/dummy_sender.py +++ b/dp3/scripts/dummy_sender.py @@ -6,7 +6,7 @@ import os import time from argparse import ArgumentParser -from datetime import datetime +from datetime import datetime, timezone from itertools import islice from queue import Queue from threading import Event, Thread @@ -14,6 +14,8 @@ import pandas as pd import requests +UTC = timezone.utc + def get_valid_path(parser, arg): if not os.path.exists(arg): @@ -43,7 +45,7 @@ def get_shifted_datapoint_from_row(row): del dp["t1"] del dp["t2"] return dp - now = datetime.utcnow() + now = datetime.now(UTC) shift = now - dp["t1"] dp["t1"] = now.strftime("%Y-%m-%dT%H:%M:%S.%fZ")[:-4] diff --git a/dp3/snapshots/snapshooter.py b/dp3/snapshots/snapshooter.py index 46d70770..717524ee 100644 --- a/dp3/snapshots/snapshooter.py +++ b/dp3/snapshots/snapshooter.py @@ -42,7 +42,7 @@ parse_eids_from_cache, task_context, ) -from dp3.common.types import EventGroupType +from dp3.common.types import UTC, EventGroupType from dp3.common.utils import get_func_name from dp3.database.database import EntityDatabase from dp3.snapshots.snapshot_hooks import ( @@ -238,7 +238,7 @@ def register_run_finalize_hook(self, hook: Callable[[], list[DataPointTask]]): def make_snapshots(self): """Creates snapshots for all entities currently active in database.""" - time = datetime.utcnow() + time = datetime.now(UTC) self.db.save_metadata( time, { @@ -270,7 +270,7 @@ def make_snapshots(self): counts = {"entities": 0, "components": 0} try: linked_entities = self.get_linked_entities(time, cached) - times["components_loaded"] = datetime.utcnow() + times["components_loaded"] = datetime.now(UTC) for i, linked_entities_component in enumerate(linked_entities): counts["entities"] += len(linked_entities_component) @@ -291,7 +291,7 @@ def make_snapshots(self): except pymongo.errors.CursorNotFound as err: self.log.exception(err) finally: - times["task_creation_end"] = datetime.utcnow() + times["task_creation_end"] = datetime.now(UTC) self.db.update_metadata( time, metadata=times, diff --git a/dp3/task_processing/task_distributor.py b/dp3/task_processing/task_distributor.py index 
diff --git a/dp3/task_processing/task_distributor.py b/dp3/task_processing/task_distributor.py
index e7f86f48..42f765af 100644
--- a/dp3/task_processing/task_distributor.py
+++ b/dp3/task_processing/task_distributor.py
@@ -6,7 +6,6 @@
 import sys
 import threading
 import time
-from datetime import datetime
 from functools import partial

 from dp3.common.config import PlatformConfig
@@ -204,14 +203,14 @@ def _worker_func(self, thread_index):
                 continue

             # Process the task
-            start_time = datetime.now()
+            start_time = time.time()
             try:
                 created, new_tasks = self.task_executor.process_task(task)
             except Exception:
                 self.log.error(f"Error has occurred during processing task: {task}")
                 raise
-            duration = (datetime.now() - start_time).total_seconds()
+            duration = time.time() - start_time
             # self.log.debug("Task {} finished in {:.3f} seconds.".format(msg_id, duration))
             if duration > 1.0:
                 self.log.debug(
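Measuring the task duration with `time.time()` sidesteps timezone handling entirely, since only the difference matters. As an aside (not something this patch does), `time.monotonic()` would additionally be immune to wall-clock adjustments:

```python
import time

start = time.monotonic()
time.sleep(0.1)  # stands in for task processing
duration = time.monotonic() - start
print(f"{duration:.3f}s")
```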
diff --git a/tests/test_api/common.py b/tests/test_api/common.py
index 848383db..9f766fb0 100644
--- a/tests/test_api/common.py
+++ b/tests/test_api/common.py
@@ -25,7 +25,7 @@ class ConfigEnv(BaseModel):
     CONF_DIR: str


-conf_env = ConfigEnv.parse_obj(os.environ)
+conf_env = ConfigEnv.model_validate(os.environ)
 CONFIG = read_config_dir(conf_env.CONF_DIR, recursive=True)
 MODEL_SPEC = ModelSpec(CONFIG.get("db_entities"))
@@ -39,7 +39,7 @@
     "ipv4": ["127.0.0.1"],
     "ipv6": ["2001:0db8:85a3:0000:0000:8a2e:0370:7334", "::1"],
     "mac": ["de:ad:be:ef:ba:be", "11:22:33:44:55:66"],
-    "time": ["2020-01-01T00:00:00"],
+    "time": ["2020-01-01T00:00:00Z"],
     "json": [{"test": "test"}],
     "category": ["cat1"],
     "array": [[1, 2, 3]],
diff --git a/tests/test_api/test_01_datapoints.py b/tests/test_api/test_01_datapoints.py
index f79a59da..c7e73c36 100644
--- a/tests/test_api/test_01_datapoints.py
+++ b/tests/test_api/test_01_datapoints.py
@@ -6,6 +6,8 @@
 import common
 from common import ACCEPTED_ERROR_CODES

+from dp3.common.types import UTC
+

 class PushDatapoints(common.APITest):
     def test_invalid_payload(self):
@@ -68,7 +70,7 @@ def make_datapoint(data_type: str, value: Any) -> dict[str, Any]:

     def make_observation_datapoint(self, data_type: str, value: Any) -> dict[str, Any]:
         dp = self.make_datapoint(data_type, value)
-        dp["t1"] = datetime.utcnow().isoformat()
+        dp["t1"] = datetime.now(UTC).isoformat()
         return dp

     def helper_test_datatype_value(self, datapoint: dict, expected_codes: set[int]):
diff --git a/tests/test_api/test_get_entity_eid_data.py b/tests/test_api/test_get_entity_eid_data.py
index 13a7c34c..b7e1ac9e 100644
--- a/tests/test_api/test_get_entity_eid_data.py
+++ b/tests/test_api/test_get_entity_eid_data.py
@@ -5,6 +5,7 @@
 from pydantic import RootModel

 from dp3.api.internal.entity_response_models import EntityEidData, EntityEidMasterRecord
+from dp3.common.types import UTC

 DATAPOINT_COUNT = 6

@@ -14,9 +15,9 @@ class GetEntityEidData(common.APITest):
     def setUp(self) -> None:
         super().setUpClass()
-        t1 = datetime.now() - timedelta(minutes=30)
+        t1 = datetime.now(UTC) - timedelta(minutes=30)
         t2 = t1 + timedelta(minutes=10)
-        self.eid = f"test_get_data__{datetime.now()}"
+        self.eid = f"test_get_data__{datetime.now(UTC)}"
         dp_base = {
             "src": "setup@test",
             "attr": "test_attr_history",
diff --git a/tests/test_api/test_snapshots.py b/tests/test_api/test_snapshots.py
index ffae83ff..ee6d8828 100644
--- a/tests/test_api/test_snapshots.py
+++ b/tests/test_api/test_snapshots.py
@@ -6,12 +6,13 @@
 import common

 from dp3.api.internal.entity_response_models import EntityEidData
+from dp3.common.types import UTC


 class SnapshotIntegration(common.APITest):
     @classmethod
     def setUpClass(cls) -> None:
-        now = datetime.datetime.now()
+        now = datetime.datetime.now(UTC)
         t1 = (now - datetime.timedelta(minutes=30)).strftime("%Y-%m-%dT%H:%M:%S.%fZ")[:-4]
         t2 = (now + datetime.timedelta(minutes=30)).strftime("%Y-%m-%dT%H:%M:%S.%fZ")[:-4]
diff --git a/tests/test_common/test_snapshots.py b/tests/test_common/test_snapshots.py
index 9adff033..ae39eed8 100644
--- a/tests/test_common/test_snapshots.py
+++ b/tests/test_common/test_snapshots.py
@@ -10,6 +10,7 @@
 from dp3.common.config import ModelSpec, PlatformConfig, read_config_dir
 from dp3.common.task import Task
+from dp3.common.types import UTC
 from dp3.snapshots.snapshooter import SnapShooter
 from dp3.snapshots.snapshot_hooks import SnapshotCorrelationHookContainer

@@ -153,7 +154,7 @@ def setUp(self) -> None:
         config = read_config_dir(config_base_path, recursive=True)
         self.model_spec = ModelSpec(config.get("db_entities"))

-        self.now = datetime.datetime.now()
+        self.now = datetime.datetime.now(UTC)
         self.t1 = self.now - datetime.timedelta(minutes=30)
         self.t2 = self.now + datetime.timedelta(minutes=30)
diff --git a/tests/test_common/test_types.py b/tests/test_common/test_types.py
new file mode 100644
index 00000000..1c0e0fb0
--- /dev/null
+++ b/tests/test_common/test_types.py
@@ -0,0 +1,34 @@
+import unittest
+from datetime import datetime, timedelta, timezone
+
+from pydantic import BaseModel, Field
+
+from dp3.common.types import AwareDatetime, T2Datetime
+
+
+class _AwareModel(BaseModel):
+    dt: AwareDatetime
+
+
+class _T2Model(BaseModel):
+    t1: AwareDatetime
+    t2: T2Datetime = Field(None, validate_default=True)
+
+
+class TestAwareDatetime(unittest.TestCase):
+    def test_naive_datetime_defaults_to_utc(self):
+        model = _AwareModel(dt="2024-01-01T10:00:00")
+        self.assertEqual(model.dt.tzinfo, timezone.utc)
+
+    def test_existing_timezone_is_preserved(self):
+        cest_timezone = timezone(timedelta(hours=2), "CEST")
+        aware = datetime(2024, 1, 1, 10, 0, tzinfo=cest_timezone)
+        model = _AwareModel(dt=aware)
+        self.assertEqual(model.dt.tzinfo, cest_timezone)
+
+    def test_t2_datetime_inherits_timezone_when_missing(self):
+        model = _T2Model(t1="2024-01-01T00:00:00")
+        self.assertIsNotNone(model.t2)
+        self.assertEqual(model.t1.tzinfo, timezone.utc)
+        self.assertEqual(model.t2.tzinfo, timezone.utc)
+        self.assertEqual(model.t2, model.t1)
diff --git a/tests/test_example/dps_gen.py b/tests/test_example/dps_gen.py
index 2ee00735..9001a162 100644
--- a/tests/test_example/dps_gen.py
+++ b/tests/test_example/dps_gen.py
@@ -1,20 +1,20 @@
 # Very simple datapoint generator for bus example config
-import datetime
 import json
 import random
+from datetime import datetime, timedelta, timezone


 class TimeContainer:
     def __init__(self):
-        self.time = datetime.datetime.utcnow() - datetime.timedelta(days=4)
+        self.time = datetime.now(timezone.utc) - timedelta(days=4)

     def add_minutes(self, minutes: int):
-        self.time += datetime.timedelta(minutes=minutes)
+        self.time += timedelta(minutes=minutes)
         return self.time

     def add_minutes_no_modify(self, minutes: int):
-        return self.time + datetime.timedelta(minutes=minutes)
+        return self.time + timedelta(minutes=minutes)


 time = TimeContainer()
@@ -129,7 +129,7 @@ def random_passenger_counts_3():
             "back_out": random_passenger_counts_3(),
         },
         "t1": random_t1_local.isoformat(),
-        "t2": (random_t1_local + datetime.timedelta(minutes=30)).isoformat(),
+        "t2": (random_t1_local + timedelta(minutes=30)).isoformat(),
         "src": "Bus counter",
     }
 )
diff --git a/tests/test_example/dps_gen_realtime.py b/tests/test_example/dps_gen_realtime.py
index a0266600..9f701e87 100644
--- a/tests/test_example/dps_gen_realtime.py
+++ b/tests/test_example/dps_gen_realtime.py
@@ -2,12 +2,14 @@
 import random
 from argparse import ArgumentParser
-from datetime import datetime, timedelta
+from datetime import datetime, timedelta, timezone
 from sys import stderr
 from time import sleep

 import requests

+UTC = timezone.utc
+

 def random_initial_location():
     latitude = random.uniform(39.0, 41.0)
@@ -22,11 +24,11 @@ def do_random_location_increment(current_location):


 def t1():
-    return datetime.utcnow()
+    return datetime.now(UTC)


 def random_t2():
-    return datetime.utcnow() + timedelta(minutes=random.randint(5, 15))
+    return datetime.now(UTC) + timedelta(minutes=random.randint(5, 15))


 def random_passenger_counts_3():
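The recurring motivation behind all of the `datetime.utcnow()` replacements in this patch: `utcnow()` returns a *naive* datetime (and is deprecated as of Python 3.12), and naive and aware values cannot be compared or subtracted:

```python
from datetime import datetime, timezone

naive = datetime.utcnow()           # tzinfo is None
aware = datetime.now(timezone.utc)  # tzinfo is timezone.utc

try:
    print(naive < aware)
except TypeError as e:
    print(e)  # can't compare offset-naive and offset-aware datetimes
```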