diff --git a/.gitignore b/.gitignore index 10754c5..46780ee 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ build/ *.dump .vscode/ app/algorithms/health/model/my_survival_forest_model_quxi.joblib +inp/ diff --git a/app/core/security.py b/app/core/security.py index 802e837..a99e69f 100644 --- a/app/core/security.py +++ b/app/core/security.py @@ -1,4 +1,4 @@ -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from typing import Optional, Union, Any from jose import jwt @@ -8,6 +8,10 @@ from app.core.config import settings pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") +def _utc_now() -> datetime: + return datetime.now(timezone.utc) + + def create_access_token( subject: Union[str, Any], expires_delta: Optional[timedelta] = None ) -> str: @@ -22,9 +26,9 @@ def create_access_token( JWT token 字符串 """ if expires_delta: - expire = datetime.now() + expires_delta + expire = _utc_now() + expires_delta else: - expire = datetime.now() + timedelta( + expire = _utc_now() + timedelta( minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES ) @@ -32,7 +36,7 @@ def create_access_token( "exp": expire, "sub": str(subject), "type": "access", - "iat": datetime.now(), + "iat": _utc_now(), } encoded_jwt = jwt.encode( to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM @@ -50,13 +54,13 @@ def create_refresh_token(subject: Union[str, Any]) -> str: Returns: JWT refresh token 字符串 """ - expire = datetime.now() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS) + expire = _utc_now() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS) to_encode = { "exp": expire, "sub": str(subject), "type": "refresh", - "iat": datetime.now(), + "iat": _utc_now(), } encoded_jwt = jwt.encode( to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM diff --git a/app/infra/db/timescaledb/internal_queries.py b/app/infra/db/timescaledb/internal_queries.py index 3de4db7..fda46f3 100644 --- a/app/infra/db/timescaledb/internal_queries.py +++ b/app/infra/db/timescaledb/internal_queries.py @@ -10,6 +10,7 @@ from app.core.config import get_timescaledb_pgconn_string from app.infra.db.timescaledb.repositories.scheme import SchemeRepository from app.infra.db.timescaledb.repositories.realtime import RealtimeRepository from app.infra.db.timescaledb.repositories.scada import ScadaRepository +from app.services.time_api import parse_utc_time class InternalStorage: @@ -89,10 +90,9 @@ class InternalQueries: ) -> dict: """查询指定时间点的 SCADA 数据""" - # 解析时间,假设是北京时间 - beijing_time = datetime.fromisoformat(query_time) - start_time = beijing_time - timedelta(seconds=1) - end_time = beijing_time + timedelta(seconds=1) + target_time = parse_utc_time(query_time, field_name="query_time") + start_time = target_time - timedelta(seconds=1) + end_time = target_time + timedelta(seconds=1) for attempt in range(max_retries): try: @@ -132,14 +132,8 @@ class InternalQueries: max_retries: int = 3, ) -> dict[str, list[dict]]: """查询指定时间窗的 SCADA 数据,返回 {device_id: [{time, value}, ...]}。""" - start_dt = ( - datetime.fromisoformat(start_time) - if isinstance(start_time, str) - else start_time - ) - end_dt = ( - datetime.fromisoformat(end_time) if isinstance(end_time, str) else end_time - ) + start_dt = parse_utc_time(start_time, field_name="start_time") + end_dt = parse_utc_time(end_time, field_name="end_time") for attempt in range(max_retries): try: @@ -238,14 +232,8 @@ class InternalQueries: if not element_ids: return {} - start_dt = ( - datetime.fromisoformat(start_time) - if isinstance(start_time, str) - else start_time - ) - end_dt = ( - datetime.fromisoformat(end_time) if isinstance(end_time, str) else end_time - ) + start_dt = parse_utc_time(start_time, field_name="start_time") + end_dt = parse_utc_time(end_time, field_name="end_time") table_name, valid_fields = InternalQueries._resolve_simulation_table(element_type) if field not in valid_fields: raise ValueError(f"Invalid field for {element_type}: {field}") diff --git a/app/infra/db/timescaledb/repositories/realtime.py b/app/infra/db/timescaledb/repositories/realtime.py index 06a32de..7c0facb 100644 --- a/app/infra/db/timescaledb/repositories/realtime.py +++ b/app/infra/db/timescaledb/repositories/realtime.py @@ -1,10 +1,8 @@ from typing import List, Any, Dict -from datetime import datetime, timedelta, timezone +from datetime import datetime, timedelta from collections import defaultdict from psycopg import AsyncConnection, Connection, sql - -# 定义UTC+8时区 -UTC_8 = timezone(timedelta(hours=8)) +from app.services.time_api import parse_utc_time class RealtimeRepository: @@ -397,24 +395,9 @@ class RealtimeRepository: link_result_list: List of link simulation results result_start_time: Start time for the results (ISO format string) """ - # Convert result_start_time string to datetime if needed - if isinstance(result_start_time, str): - # 如果是ISO格式字符串,解析并转换为UTC+8 - if result_start_time.endswith("Z"): - # UTC时间,转换为UTC+8 - utc_time = datetime.fromisoformat( - result_start_time.replace("Z", "+00:00") - ) - simulation_time = utc_time.astimezone(UTC_8) - else: - # 假设已经是UTC+8时间 - simulation_time = datetime.fromisoformat(result_start_time) - if simulation_time.tzinfo is None: - simulation_time = simulation_time.replace(tzinfo=UTC_8) - else: - simulation_time = result_start_time - if simulation_time.tzinfo is None: - simulation_time = simulation_time.replace(tzinfo=UTC_8) + simulation_time = parse_utc_time( + result_start_time, field_name="result_start_time" + ) # Prepare node data for batch insert node_data = [] @@ -475,24 +458,9 @@ class RealtimeRepository: link_result_list: List of link simulation results result_start_time: Start time for the results (ISO format string) """ - # Convert result_start_time string to datetime if needed - if isinstance(result_start_time, str): - # 如果是ISO格式字符串,解析并转换为UTC+8 - if result_start_time.endswith("Z"): - # UTC时间,转换为UTC+8 - utc_time = datetime.fromisoformat( - result_start_time.replace("Z", "+00:00") - ) - simulation_time = utc_time.astimezone(UTC_8) - else: - # 假设已经是UTC+8时间 - simulation_time = datetime.fromisoformat(result_start_time) - if simulation_time.tzinfo is None: - simulation_time = simulation_time.replace(tzinfo=UTC_8) - else: - simulation_time = result_start_time - if simulation_time.tzinfo is None: - simulation_time = simulation_time.replace(tzinfo=UTC_8) + simulation_time = parse_utc_time( + result_start_time, field_name="result_start_time" + ) # Prepare node data for batch insert node_data = [] @@ -556,21 +524,7 @@ class RealtimeRepository: Returns: List of records matching the criteria """ - # Convert query_time string to datetime - if isinstance(query_time, str): - if query_time.endswith("Z"): - # UTC时间,转换为UTC+8 - utc_time = datetime.fromisoformat(query_time.replace("Z", "+00:00")) - target_time = utc_time.astimezone(UTC_8) - else: - # 假设已经是UTC+8时间 - target_time = datetime.fromisoformat(query_time) - if target_time.tzinfo is None: - target_time = target_time.replace(tzinfo=UTC_8) - else: - target_time = query_time - if target_time.tzinfo is None: - target_time = target_time.replace(tzinfo=UTC_8) + target_time = parse_utc_time(query_time, field_name="query_time") # Create time range: query_time ± 1 second start_time = target_time - timedelta(seconds=1) @@ -614,21 +568,7 @@ class RealtimeRepository: Returns: List of records matching the criteria """ - # Convert query_time string to datetime - if isinstance(query_time, str): - if query_time.endswith("Z"): - # UTC时间,转换为UTC+8 - utc_time = datetime.fromisoformat(query_time.replace("Z", "+00:00")) - target_time = utc_time.astimezone(UTC_8) - else: - # 假设已经是UTC+8时间 - target_time = datetime.fromisoformat(query_time) - if target_time.tzinfo is None: - target_time = target_time.replace(tzinfo=UTC_8) - else: - target_time = query_time - if target_time.tzinfo is None: - target_time = target_time.replace(tzinfo=UTC_8) + target_time = parse_utc_time(query_time, field_name="query_time") # Create time range: query_time ± 1 second start_time = target_time - timedelta(seconds=1) diff --git a/app/infra/db/timescaledb/repositories/scheme.py b/app/infra/db/timescaledb/repositories/scheme.py index bfa09ca..f0960b3 100644 --- a/app/infra/db/timescaledb/repositories/scheme.py +++ b/app/infra/db/timescaledb/repositories/scheme.py @@ -1,11 +1,9 @@ from typing import List, Any, Dict -from datetime import datetime, timedelta, timezone +from datetime import datetime, timedelta from collections import defaultdict from psycopg import AsyncConnection, Connection, sql import app.services.globals as globals - -# 定义UTC+8时区 -UTC_8 = timezone(timedelta(hours=8)) +from app.services.time_api import parse_utc_time class SchemeRepository: @@ -466,24 +464,9 @@ class SchemeRepository: link_result_list: List of link simulation results result_start_time: Start time for the results (ISO format string) """ - # Convert result_start_time string to datetime if needed - if isinstance(result_start_time, str): - # 如果是ISO格式字符串,解析并转换为UTC+8 - if result_start_time.endswith("Z"): - # UTC时间,转换为UTC+8 - utc_time = datetime.fromisoformat( - result_start_time.replace("Z", "+00:00") - ) - simulation_time = utc_time.astimezone(UTC_8) - else: - # 假设已经是UTC+8时间 - simulation_time = datetime.fromisoformat(result_start_time) - if simulation_time.tzinfo is None: - simulation_time = simulation_time.replace(tzinfo=UTC_8) - else: - simulation_time = result_start_time - if simulation_time.tzinfo is None: - simulation_time = simulation_time.replace(tzinfo=UTC_8) + simulation_time = parse_utc_time( + result_start_time, field_name="result_start_time" + ) timestep_parts = globals.hydraulic_timestep.split(":") timestep = timedelta( @@ -564,24 +547,9 @@ class SchemeRepository: link_result_list: List of link simulation results result_start_time: Start time for the results (ISO format string) """ - # Convert result_start_time string to datetime if needed - if isinstance(result_start_time, str): - # 如果是ISO格式字符串,解析并转换为UTC+8 - if result_start_time.endswith("Z"): - # UTC时间,转换为UTC+8 - utc_time = datetime.fromisoformat( - result_start_time.replace("Z", "+00:00") - ) - simulation_time = utc_time.astimezone(UTC_8) - else: - # 假设已经是UTC+8时间 - simulation_time = datetime.fromisoformat(result_start_time) - if simulation_time.tzinfo is None: - simulation_time = simulation_time.replace(tzinfo=UTC_8) - else: - simulation_time = result_start_time - if simulation_time.tzinfo is None: - simulation_time = simulation_time.replace(tzinfo=UTC_8) + simulation_time = parse_utc_time( + result_start_time, field_name="result_start_time" + ) timestep_parts = globals.hydraulic_timestep.split(":") timestep = timedelta( @@ -664,21 +632,7 @@ class SchemeRepository: Returns: List of records matching the criteria """ - # Convert query_time string to datetime - if isinstance(query_time, str): - if query_time.endswith("Z"): - # UTC时间,转换为UTC+8 - utc_time = datetime.fromisoformat(query_time.replace("Z", "+00:00")) - target_time = utc_time.astimezone(UTC_8) - else: - # 假设已经是UTC+8时间 - target_time = datetime.fromisoformat(query_time) - if target_time.tzinfo is None: - target_time = target_time.replace(tzinfo=UTC_8) - else: - target_time = query_time - if target_time.tzinfo is None: - target_time = target_time.replace(tzinfo=UTC_8) + target_time = parse_utc_time(query_time, field_name="query_time") # Create time range: query_time ± 1 second start_time = target_time - timedelta(seconds=1) @@ -727,21 +681,7 @@ class SchemeRepository: Returns: List of records matching the criteria """ - # Convert query_time string to datetime - if isinstance(query_time, str): - if query_time.endswith("Z"): - # UTC时间,转换为UTC+8 - utc_time = datetime.fromisoformat(query_time.replace("Z", "+00:00")) - target_time = utc_time.astimezone(UTC_8) - else: - # 假设已经是UTC+8时间 - target_time = datetime.fromisoformat(query_time) - if target_time.tzinfo is None: - target_time = target_time.replace(tzinfo=UTC_8) - else: - target_time = query_time - if target_time.tzinfo is None: - target_time = target_time.replace(tzinfo=UTC_8) + target_time = parse_utc_time(query_time, field_name="query_time") # Create time range: query_time ± 1 second start_time = target_time - timedelta(seconds=1) diff --git a/app/services/burst_detection.py b/app/services/burst_detection.py index 59baf32..9665934 100644 --- a/app/services/burst_detection.py +++ b/app/services/burst_detection.py @@ -14,6 +14,7 @@ from app.services.scheme_management import ( store_scheme_info, ) from app.services.tjnetwork import get_all_scada_info +from app.services.time_api import extract_date, parse_utc_time, utc_now def run_burst_detection( @@ -241,7 +242,7 @@ def list_burst_detection_schemes( network: str, query_date: datetime | str | None = None, ) -> list[dict[str, Any]]: - parsed_date = _to_datetime(query_date).date() if query_date is not None else None + parsed_date = extract_date(query_date, field_name="query_date") if query_date is not None else None return query_burst_detection_schemes( name=network, network=network, @@ -269,7 +270,7 @@ def _store_burst_detection_scheme( if scheme_name_exists(network, scheme_name): raise ValueError(f"方案名称已存在: {scheme_name}") - now_iso = datetime.now().isoformat() + now_iso = utc_now().isoformat() scheme_detail = { "network": network, "sensor_nodes": payload.get("sensor_nodes", []), @@ -426,6 +427,4 @@ def _build_observed_pressure_from_scada( def _to_datetime(value: datetime | str) -> datetime: - if isinstance(value, datetime): - return value - return datetime.fromisoformat(value) + return parse_utc_time(value) diff --git a/app/services/burst_location.py b/app/services/burst_location.py index 3892ca2..5a6b52b 100644 --- a/app/services/burst_location.py +++ b/app/services/burst_location.py @@ -15,6 +15,7 @@ from app.services.scheme_management import ( store_scheme_info, ) from app.services.tjnetwork import dump_inp, get_all_scada_info +from app.services.time_api import extract_date, parse_utc_time, utc_now SeriesInput = pd.Series | dict[str, Any] | list[dict[str, Any]] FLOW_SCADA_TYPES = {"pipe_flow", "flow", "demand"} @@ -301,7 +302,7 @@ def run_burst_location_by_network( def list_burst_location_schemes( network: str, query_date: datetime | str | None = None ) -> list[dict[str, Any]]: - parsed_date = _to_datetime(query_date).date() if query_date is not None else None + parsed_date = extract_date(query_date, field_name="query_date") if query_date is not None else None return query_burst_location_schemes( name=network, network=network, query_date=parsed_date ) @@ -327,7 +328,7 @@ def _store_burst_scheme( if scheme_name_exists(network, scheme_name): raise ValueError(f"方案名称已存在: {scheme_name}") - now_iso = datetime.now().isoformat() + now_iso = utc_now().isoformat() scheme_detail = { "network": network, "pressure_scada_ids": payload.get("pressure_scada_ids", []), @@ -641,9 +642,7 @@ def _dedupe_ids(ids: list[str] | None) -> list[str]: def _to_datetime(value: datetime | str) -> datetime: - if isinstance(value, datetime): - return value - return datetime.fromisoformat(value) + return parse_utc_time(value) def _prepare_burst_inp(network: str) -> str: diff --git a/app/services/leakage_identifier.py b/app/services/leakage_identifier.py index e90cb24..a85d653 100644 --- a/app/services/leakage_identifier.py +++ b/app/services/leakage_identifier.py @@ -23,6 +23,7 @@ from app.services.tjnetwork import ( get_network_link_nodes, get_network_node_coords, ) +from app.services.time_api import extract_date, parse_utc_time, utc_now DEFAULT_N_WORKERS = max(1, min((os.cpu_count() or 1) - 1, 4)) @@ -119,7 +120,7 @@ def run_leakage_identification( scheme_start_time = ( _to_datetime(scada_start).isoformat() if scada_start is not None - else datetime.now().isoformat() + else utc_now().isoformat() ) scheme_detail = { "network": network, @@ -177,7 +178,7 @@ def run_leakage_identification( def list_leakage_identify_schemes( network: str, query_date: datetime | str | None = None ) -> list[dict[str, Any]]: - parsed_date = _to_datetime(query_date).date() if query_date is not None else None + parsed_date = extract_date(query_date, field_name="query_date") if query_date is not None else None return query_leakage_identify_schemes( name=network, network=network, query_date=parsed_date ) @@ -509,9 +510,7 @@ def _build_observed_pressure_from_scada( def _to_datetime(value: datetime | str) -> datetime: - if isinstance(value, datetime): - return value - return datetime.fromisoformat(value) + return parse_utc_time(value) def _prepare_leakage_inp(network: str) -> str: diff --git a/app/services/scheme_management.py b/app/services/scheme_management.py index a86a9bd..0bb1f11 100644 --- a/app/services/scheme_management.py +++ b/app/services/scheme_management.py @@ -1,6 +1,6 @@ import ast import json -from datetime import date +from datetime import date, datetime import geopandas as gpd import pandas as pd @@ -8,6 +8,7 @@ import psycopg from sqlalchemy import create_engine from app.core.config import get_pgconn_string +from app.services.time_api import parse_utc_time # 2025/03/23 @@ -89,7 +90,7 @@ def store_scheme_info( scheme_name: str, scheme_type: str, username: str, - scheme_start_time: str, + scheme_start_time: datetime | str, scheme_detail: dict, ): """ @@ -112,13 +113,16 @@ def store_scheme_info( """ # 将字典转换为 JSON 字符串 scheme_detail_json = json.dumps(scheme_detail) + normalized_scheme_start_time = parse_utc_time( + scheme_start_time, field_name="scheme_start_time" + ) cur.execute( sql, ( scheme_name, scheme_type, username, - scheme_start_time, + normalized_scheme_start_time, scheme_detail_json, ), ) diff --git a/app/services/time_api.py b/app/services/time_api.py index 85cc2f4..00904e2 100644 --- a/app/services/time_api.py +++ b/app/services/time_api.py @@ -1,5 +1,6 @@ -from datetime import datetime, timezone, timedelta -from dateutil import parser, tz +from datetime import date, datetime, time, timedelta, timezone + +from dateutil import parser, tz ''' 2025-02-09T15:45:00+00:00 采用的是 ISO 8601 国际标准日期时间格式,具体特点如下: @@ -13,57 +14,67 @@ from dateutil import parser, tz 2025-02-09T15:45:00+08:00 ''' -BG_TZ = tz.gettz('Asia/Shanghai') -UTC_TZ = tz.gettz('UTC') +BG_TZ = tz.gettz("Asia/Shanghai") +UTC_TZ = timezone.utc -def parse_utc_time(query_time: str) -> datetime: - ''' - 接受 任意格式的字符串,如果解析出来不带时区,则用 replace 添加 +00:00 时区 - 如果解析出来已经有时区,则用 astimezone 转换成UTC时间 - ''' +TIMEZONE_REQUIRED_MESSAGE = ( + "Datetime values must include an explicit timezone offset, for example " + "'2025-02-09T15:45:00Z' or '2025-02-09T23:45:00+08:00'." +) - # 解析时间字符串 - dt: datetime = parser.parse(query_time) + +def parse_aware_time(query_time: datetime | str, field_name: str = "datetime") -> datetime: + """ + 解析时间并确保结果带有时区信息。 + """ + dt = parser.parse(query_time) if isinstance(query_time, str) else query_time if dt.tzinfo is None: - dt = dt.replace(tzinfo=UTC_TZ) - else: - dt = dt.astimezone(UTC_TZ) - + raise ValueError(f"{field_name} is missing timezone information. {TIMEZONE_REQUIRED_MESSAGE}") return dt + + +def extract_date(value: date | datetime | str, field_name: str = "date") -> date: + """ + 提取日期部分,但保留调用方原始时区语义,不强制转换到 UTC。 + """ + if isinstance(value, date) and not isinstance(value, datetime): + return value + return parse_aware_time(value, field_name=field_name).date() + + +def utc_now() -> datetime: + """ + 返回带 UTC 时区的当前时间。 + """ + return datetime.now(UTC_TZ) + + +def parse_utc_time(query_time: datetime | str, field_name: str = "datetime") -> datetime: + ''' + 接受带时区的时间字符串/对象,并统一转换成 UTC 时间。 + ''' + return parse_aware_time(query_time, field_name=field_name).astimezone(UTC_TZ) -def parse_beijing_time(query_time: str) -> datetime: + +def parse_beijing_time(query_time: datetime | str, field_name: str = "datetime") -> datetime: ''' - 接受 任意格式的字符串,如果解析出来不带时区,则用 replace 添加 +08:00 时区 - 如果解析出来已经有时区,则用 astimezone 转换成北京时间 - - 也就是任意合法的时间字符串,最后都解析成 北京 时间 - + 接受带时区的时间字符串/对象,并统一转换成北京时间。 ''' - - # 解析时间字符串 - dt: datetime = parser.parse(query_time) - if dt.tzinfo is None: - dt = dt.replace(tzinfo=BG_TZ) - else: - dt = dt.astimezone(tz=BG_TZ) - - return dt + return parse_aware_time(query_time, field_name=field_name).astimezone(tz=BG_TZ) -def to_utc_time(dt: datetime) -> datetime: +def to_utc_time(dt: datetime | str, field_name: str = "datetime") -> datetime: ''' - 将一个北京时间的时间点,转换成utc + 将一个带时区的时间点,转换成 UTC。 ''' - utc_time = dt.astimezone(UTC_TZ) - return utc_time + return parse_aware_time(dt, field_name=field_name).astimezone(UTC_TZ) -def to_beijing_time(dt: datetime) -> datetime: +def to_beijing_time(dt: datetime | str, field_name: str = "datetime") -> datetime: ''' - 将一个 utc 的时间点,转换成北京时间 + 将一个带时区的时间点,转换成北京时间。 ''' - beijing_time = dt.astimezone(tz=BG_TZ) - return beijing_time + return parse_aware_time(dt, field_name=field_name).astimezone(tz=BG_TZ) def to_time_range(dt: datetime, delta: float) -> tuple[datetime, datetime]: @@ -83,7 +94,8 @@ def parse_beijing_date_range(query_date: str) -> tuple[datetime, datetime]: 将一个日期字符串,转换成 start/end 时间段,传进来的日期被认为是北京时间 日期字符串格式:YYYY-MM-DD ''' - start_time = parse_beijing_time(query_date) + target_date = date.fromisoformat(query_date) + start_time = datetime.combine(target_date, time.min, BG_TZ) end_time = start_time + timedelta(days=1) return (start_time, end_time) @@ -108,7 +120,7 @@ def get_date_from_time(time: str) -> str: ''' 将一个时间点,转换成日期 ''' - dt = parse_beijing_time(time) + dt = parse_beijing_time(time, field_name="time") return str(dt.date()) @@ -116,28 +128,27 @@ def is_today(query_date: str) -> bool: ''' 判断一个日期是否是今天 ''' - dt = parse_beijing_time(query_date) - return dt.date() == datetime.now().date() + dt = parse_beijing_time(query_date, field_name="query_date") + return dt.date() == datetime.now(BG_TZ).date() def is_yesterday(query_date: str) -> bool: ''' 判断一个日期是否是昨天 ''' - dt = parse_beijing_time(query_date) - return dt.date() == (datetime.now().date() - timedelta(days=1)) + dt = parse_beijing_time(query_date, field_name="query_date") + return dt.date() == (datetime.now(BG_TZ).date() - timedelta(days=1)) def is_tomorrow(query_date: str) -> bool: ''' 判断一个日期是否是明天 ''' - dt = parse_beijing_time(query_date) - return dt.date() == (datetime.now().date() + timedelta(days=1)) + dt = parse_beijing_time(query_date, field_name="query_date") + return dt.date() == (datetime.now(BG_TZ).date() + timedelta(days=1)) def is_today_or_future(query_date: str) -> bool: ''' 判断一个日期是否是今天或未来 ''' - dt = parse_beijing_time(query_date) - return dt.date() >= datetime.now().date() - + dt = parse_beijing_time(query_date, field_name="query_date") + return dt.date() >= datetime.now(BG_TZ).date() diff --git a/resources/sql/001_create_users_table.sql b/resources/sql/001_create_users_table.sql index d0eb301..5caed32 100644 --- a/resources/sql/001_create_users_table.sql +++ b/resources/sql/001_create_users_table.sql @@ -11,8 +11,8 @@ CREATE TABLE IF NOT EXISTS users ( role VARCHAR(20) DEFAULT 'USER' NOT NULL, is_active BOOLEAN DEFAULT TRUE NOT NULL, is_superuser BOOLEAN DEFAULT FALSE NOT NULL, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL, - updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL, + updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL, CONSTRAINT users_role_check CHECK (role IN ('ADMIN', 'OPERATOR', 'USER', 'VIEWER')) ); diff --git a/resources/sql/002_create_audit_logs_table.sql b/resources/sql/002_create_audit_logs_table.sql index 5fdc1c1..6f0d9fe 100644 --- a/resources/sql/002_create_audit_logs_table.sql +++ b/resources/sql/002_create_audit_logs_table.sql @@ -17,7 +17,7 @@ CREATE TABLE IF NOT EXISTS audit_logs ( request_data JSONB, response_status INTEGER, error_message TEXT, - timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL + timestamp TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL ); -- 创建索引以提高查询性能 diff --git a/resources/sql/003_normalize_timestamp_columns.sql b/resources/sql/003_normalize_timestamp_columns.sql new file mode 100644 index 0000000..bd99af4 --- /dev/null +++ b/resources/sql/003_normalize_timestamp_columns.sql @@ -0,0 +1,63 @@ +-- ============================================ +-- TJWater Server 时区统一迁移脚本 +-- 将历史无时区时间列升级为 TIMESTAMP WITH TIME ZONE +-- 约定:历史无时区值按 UTC 解释 +-- ============================================ + +DO $$ +BEGIN + IF EXISTS ( + SELECT 1 + FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = 'users' + AND column_name = 'created_at' + AND data_type = 'timestamp without time zone' + ) THEN + EXECUTE 'ALTER TABLE public.users + ALTER COLUMN created_at TYPE TIMESTAMP WITH TIME ZONE + USING created_at AT TIME ZONE ''UTC'''; + END IF; + + IF EXISTS ( + SELECT 1 + FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = 'users' + AND column_name = 'updated_at' + AND data_type = 'timestamp without time zone' + ) THEN + EXECUTE 'ALTER TABLE public.users + ALTER COLUMN updated_at TYPE TIMESTAMP WITH TIME ZONE + USING updated_at AT TIME ZONE ''UTC'''; + END IF; + + IF EXISTS ( + SELECT 1 + FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = 'audit_logs' + AND column_name = 'timestamp' + AND data_type = 'timestamp without time zone' + ) THEN + EXECUTE 'ALTER TABLE public.audit_logs + ALTER COLUMN timestamp TYPE TIMESTAMP WITH TIME ZONE + USING "timestamp" AT TIME ZONE ''UTC'''; + END IF; + + IF EXISTS ( + SELECT 1 + FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = 'scheme_list' + AND column_name = 'scheme_start_time' + AND data_type IN ('character varying', 'text') + ) THEN + EXECUTE 'ALTER TABLE public.scheme_list + ALTER COLUMN scheme_start_time TYPE TIMESTAMP WITH TIME ZONE + USING CASE + WHEN scheme_start_time ~ ''(Z|[+-][0-9]{2}:[0-9]{2})$'' THEN scheme_start_time::timestamptz + ELSE scheme_start_time::timestamp AT TIME ZONE ''UTC'' + END'; + END IF; +END $$; diff --git a/resources/sql/create/40.scheme_list.sql b/resources/sql/create/40.scheme_list.sql index 4dbb7f9..0b8d007 100644 --- a/resources/sql/create/40.scheme_list.sql +++ b/resources/sql/create/40.scheme_list.sql @@ -9,6 +9,6 @@ create table scheme_list ( scheme_type varchar(32) not null, username varchar(32) not null REFERENCES "users"(username) ON UPDATE CASCADE ON DELETE RESTRICT, create_time TIMESTAMP WITH TIME ZONE not null DEFAULT date_trunc('minute', CURRENT_TIMESTAMP), - scheme_start_time varchar(50) not null, + scheme_start_time TIMESTAMP WITH TIME ZONE not null, scheme_detail JSON -) \ No newline at end of file +) diff --git a/tests/unit/test_burst_location_service.py b/tests/unit/test_burst_location_service.py index abe38a1..e5132fb 100644 --- a/tests/unit/test_burst_location_service.py +++ b/tests/unit/test_burst_location_service.py @@ -1,7 +1,7 @@ import importlib.util import sys import types -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from pathlib import Path import pytest @@ -30,6 +30,24 @@ def _load_burst_location_module(): ]: ensure_package(package_name) + time_api_module = types.ModuleType("app.services.time_api") + time_api_module.parse_utc_time = ( + lambda value, field_name="datetime": ( + value.astimezone(timezone.utc) + if isinstance(value, datetime) and value.tzinfo is not None + else datetime.fromisoformat(value).astimezone(timezone.utc) + ) + ) + time_api_module.extract_date = ( + lambda value, field_name="date": ( + value.date() + if isinstance(value, datetime) + else datetime.fromisoformat(value).date() + ) + ) + time_api_module.utc_now = lambda: datetime.now(timezone.utc) + sys.modules["app.services.time_api"] = time_api_module + algorithms_module = types.ModuleType("app.algorithms.burst_location") algorithms_module.run_burst_location = lambda **kwargs: {} sys.modules["app.algorithms.burst_location"] = algorithms_module @@ -125,16 +143,16 @@ def test_run_burst_location_uses_single_timerange_with_burst_source_split(monkey def fake_scheme_query(**kwargs): scheme_calls.append(kwargs) + start_hour = datetime.fromisoformat(kwargs["start_time"]).astimezone( + timezone(timedelta(hours=8)) + ).hour if kwargs["element_type"] == "node" and kwargs["field"] == "pressure": - start_hour = datetime.fromisoformat(kwargs["start_time"]).hour values = [12.0, 14.0, 16.0, 18.0] if start_hour == 8 else [8.0, 10.0, 12.0, 14.0] return {"J1": _build_series(kwargs["start_time"], values)} if kwargs["element_type"] == "link" and kwargs["field"] == "flow": - start_hour = datetime.fromisoformat(kwargs["start_time"]).hour values = [5.0, 7.0, 9.0, 11.0] if start_hour == 8 else [2.0, 4.0, 6.0, 8.0] return {"P1": _build_series(kwargs["start_time"], values)} if kwargs["element_type"] == "node" and kwargs["field"] == "actual_demand": - start_hour = datetime.fromisoformat(kwargs["start_time"]).hour values = [3.0, 5.0, 7.0, 9.0] if start_hour == 8 else [1.0, 3.0, 5.0, 7.0] return {"J2": _build_series(kwargs["start_time"], values)} raise AssertionError(f"Unexpected scheme query: {kwargs}") @@ -167,8 +185,8 @@ def test_run_burst_location_uses_single_timerange_with_burst_source_split(monkey simulation_scheme_name="BurstSchemeA", simulation_scheme_type="burst_analysis", burst_leakage=10.0, - scada_burst_start=datetime(2025, 1, 1, 8, 0, 0), - scada_burst_end=datetime(2025, 1, 1, 9, 0, 0), + scada_burst_start=datetime(2025, 1, 1, 8, 0, 0, tzinfo=timezone(timedelta(hours=8))), + scada_burst_end=datetime(2025, 1, 1, 9, 0, 0, tzinfo=timezone(timedelta(hours=8))), use_scada_flow=True, ) @@ -192,14 +210,14 @@ def test_run_burst_location_uses_single_timerange_with_burst_source_split(monkey assert any(call["element_type"] == "link" and call["field"] == "flow" for call in scheme_calls) assert any(call["element_type"] == "node" and call["field"] == "actual_demand" for call in scheme_calls) assert len(realtime_calls) == 3 - assert all(datetime.fromisoformat(call["start_time"]).hour == 8 for call in realtime_calls) - assert all(datetime.fromisoformat(call["end_time"]).hour == 9 for call in realtime_calls) + assert all(datetime.fromisoformat(call["start_time"]).hour == 0 for call in realtime_calls) + assert all(datetime.fromisoformat(call["end_time"]).hour == 1 for call in realtime_calls) assert any(call["element_type"] == "node" and call["field"] == "pressure" for call in realtime_calls) assert any(call["element_type"] == "link" and call["field"] == "flow" for call in realtime_calls) assert any(call["element_type"] == "node" and call["field"] == "actual_demand" for call in realtime_calls) assert result["scada_window"] == { - "burst_start": "2025-01-01T08:00:00", - "burst_end": "2025-01-01T09:00:00", + "burst_start": "2025-01-01T00:00:00+00:00", + "burst_end": "2025-01-01T01:00:00+00:00", } @@ -225,8 +243,8 @@ def test_run_burst_location_requires_simulation_scheme_name(monkeypatch, tmp_pat username="testuser", data_source="simulation", burst_leakage=1.0, - scada_burst_start=datetime(2025, 1, 1, 8, 0, 0), - scada_burst_end=datetime(2025, 1, 1, 9, 0, 0), + scada_burst_start=datetime(2025, 1, 1, 8, 0, 0, tzinfo=timezone(timedelta(hours=8))), + scada_burst_end=datetime(2025, 1, 1, 9, 0, 0, tzinfo=timezone(timedelta(hours=8))), ) @@ -290,8 +308,8 @@ def test_run_burst_location_monitoring_uses_scada_for_burst_and_realtime_for_nor username="testuser", data_source="monitoring", burst_leakage=1.0, - scada_burst_start=datetime(2025, 1, 1, 8, 0, 0), - scada_burst_end=datetime(2025, 1, 1, 9, 0, 0), + scada_burst_start=datetime(2025, 1, 1, 8, 0, 0, tzinfo=timezone(timedelta(hours=8))), + scada_burst_end=datetime(2025, 1, 1, 9, 0, 0, tzinfo=timezone(timedelta(hours=8))), ) assert result["observed_source"] == "scada_burst_realtime_normal_timerange" diff --git a/tests/unit/test_time_api.py b/tests/unit/test_time_api.py new file mode 100644 index 0000000..ef1bbfb --- /dev/null +++ b/tests/unit/test_time_api.py @@ -0,0 +1,45 @@ +import importlib.util +from datetime import date, datetime, timedelta, timezone +from pathlib import Path + +import pytest + + +def _load_time_api_module(): + module_path = ( + Path(__file__).resolve().parents[2] / "app" / "services" / "time_api.py" + ) + spec = importlib.util.spec_from_file_location("tests_time_api_under_test", module_path) + module = importlib.util.module_from_spec(spec) + assert spec and spec.loader + spec.loader.exec_module(module) + return module + + +def test_parse_utc_time_rejects_naive_datetimes(): + module = _load_time_api_module() + + with pytest.raises(ValueError, match="timezone information"): + module.parse_utc_time("2025-01-01T08:00:00") + + +def test_parse_utc_time_normalizes_offset_datetime_to_utc(): + module = _load_time_api_module() + result = module.parse_utc_time("2025-01-01T08:00:00+08:00") + + assert result == datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc) + + +def test_extract_date_keeps_original_offset_calendar_day(): + module = _load_time_api_module() + result = module.extract_date("2025-01-01T00:30:00+08:00") + + assert result == date(2025, 1, 1) + + +def test_utc_now_returns_timezone_aware_utc_datetime(): + module = _load_time_api_module() + result = module.utc_now() + + assert result.tzinfo == timezone.utc + assert result.utcoffset() == timedelta(0)