from __future__ import annotations import json import os import sys import time import uuid from dataclasses import dataclass, field from datetime import datetime, timezone from pathlib import Path from typing import Any, Mapping import requests import typer SCHEMA_VERSION = "tjwater-cli/v1" CLI_NAME = "tjwater-cli" DEFAULT_TIMEOUT = 60 DEFAULT_SERVER = "http://192.168.1.114:8000" class CLIError(Exception): def __init__( self, summary: str, *, code: str, message: str, exit_code: int, retryable: bool = False, next_commands: list[str] | None = None, data: Any = None, ) -> None: super().__init__(message) self.summary = summary self.code = code self.message = message self.exit_code = exit_code self.retryable = retryable self.next_commands = next_commands or [] self.data = data @dataclass(frozen=True) class AuthContext: server: str | None = None access_token: str | None = None project_id: str | None = None user_id: str | None = None username: str | None = None network: str | None = None headers: dict[str, str] = field(default_factory=dict) @dataclass(frozen=True) class RuntimeContext: server: str | None auth: AuthContext scheme: str | None timeout: int request_id: str @dataclass(frozen=True) class CommandOptionDoc: name: str description: str required: bool = False repeated: bool = False default: Any = None @dataclass(frozen=True) class CommandDoc: path: tuple[str, ...] summary: str description: str options: tuple[CommandOptionDoc, ...] = () examples: tuple[str, ...] = () next_commands: tuple[str, ...] = () output: str = "标准 JSON 输出" def _pick(mapping: Mapping[str, Any], *keys: str) -> Any: for key in keys: value = mapping.get(key) if value not in (None, ""): return value return None def load_auth_context(auth_stdin: bool = False) -> AuthContext: if auth_stdin: raw = json.loads(sys.stdin.read()) else: extra_headers = os.getenv("TJWATER_EXTRA_HEADERS") raw = { "server": os.getenv("TJWATER_SERVER"), "access_token": os.getenv("TJWATER_ACCESS_TOKEN"), "project_id": os.getenv("TJWATER_PROJECT_ID"), "user_id": os.getenv("TJWATER_USER_ID"), "username": os.getenv("TJWATER_USERNAME"), "network": os.getenv("TJWATER_NETWORK"), "headers": json.loads(extra_headers) if extra_headers else {}, } headers = raw.get("headers") or {} if not isinstance(headers, dict): raise CLIError( "认证失败", code="AUTH_CONTEXT_INVALID", message="auth context headers must be a JSON object", exit_code=3, ) return AuthContext( server=_pick(raw, "server", "base_url"), access_token=_pick(raw, "access_token", "token", "accessToken"), project_id=_pick(raw, "project_id", "projectId", "x_project_id"), user_id=_pick(raw, "user_id", "userId", "x_user_id"), username=_pick(raw, "username", "preferred_username"), network="tjwater", headers={str(key): str(value) for key, value in headers.items()}, ) def build_runtime_context( *, server: str | None, auth_stdin: bool = False, scheme: str | None, timeout: int, request_id: str | None, ) -> RuntimeContext: auth = load_auth_context(auth_stdin=auth_stdin) resolved_request_id = request_id or str(uuid.uuid4()) return RuntimeContext( server=server or auth.server or DEFAULT_SERVER, auth=auth, scheme=scheme, timeout=timeout, request_id=resolved_request_id, ) def require_server(ctx: RuntimeContext) -> str: if ctx.server: return ctx.server.rstrip("/") raise CLIError( "认证失败", code="SERVER_REQUIRED", message="missing server URL; use --server or include server in auth context", exit_code=3, ) def require_access_token(ctx: RuntimeContext) -> str: if ctx.auth.access_token: return ctx.auth.access_token raise CLIError( "认证失败", code="UNAUTHENTICATED", message="missing access token for agent context", exit_code=3, next_commands=["provide access_token via --auth-stdin or TJWATER_ACCESS_TOKEN env var"], ) def require_project_id(ctx: RuntimeContext) -> str: if ctx.auth.project_id: return ctx.auth.project_id raise CLIError( "认证失败", code="PROJECT_CONTEXT_REQUIRED", message="missing project_id for agent context", exit_code=3, next_commands=["add project_id to auth context"], ) def require_network(ctx: RuntimeContext) -> str: if ctx.auth.network: return ctx.auth.network raise CLIError( "认证失败", code="NETWORK_CONTEXT_REQUIRED", message="missing network in auth context for legacy network-based endpoints", exit_code=3, next_commands=["add network to auth context"], ) def require_username(ctx: RuntimeContext) -> str: if ctx.auth.username: return ctx.auth.username raise CLIError( "认证失败", code="USERNAME_CONTEXT_REQUIRED", message="missing username in auth context", exit_code=3, next_commands=["add username to auth context"], ) def resolve_scheme(ctx: RuntimeContext, explicit_scheme: str | None, *, required: bool = False) -> str | None: scheme = explicit_scheme or ctx.scheme if required and not scheme: raise CLIError( "CLI 参数错误", code="SCHEME_REQUIRED", message="missing scheme; use --scheme", exit_code=2, ) return scheme def parse_time_with_timezone(value: str, *, option_name: str) -> datetime: try: parsed = datetime.fromisoformat(value) except ValueError as exc: raise CLIError( "CLI 参数错误", code="INVALID_TIME", message=f"{option_name} must be a valid ISO 8601 / RFC 3339 timestamp", exit_code=2, ) from exc if parsed.tzinfo is None: raise CLIError( "CLI 参数错误", code="TIMEZONE_REQUIRED", message=f"{option_name} must include an explicit timezone offset", exit_code=2, ) return parsed def read_json_input(path: Path, *, label: str) -> Any: try: return json.loads(path.read_text(encoding="utf-8")) except FileNotFoundError as exc: raise CLIError( "CLI 参数错误", code="INPUT_NOT_FOUND", message=f"{label} file not found: {path}", exit_code=2, ) from exc except json.JSONDecodeError as exc: raise CLIError( "CLI 参数错误", code="INPUT_INVALID_JSON", message=f"{label} file must be valid JSON: {path}", exit_code=2, ) from exc def parse_burst_file(path: Path) -> tuple[list[str], list[float]]: raw = read_json_input(path, label="burst") if isinstance(raw, dict) and "bursts" in raw: raw = raw["bursts"] if isinstance(raw, dict) and "burst_ID" in raw and "burst_size" in raw: ids = [str(item) for item in raw["burst_ID"]] sizes = [float(item) for item in raw["burst_size"]] if len(ids) != len(sizes): raise CLIError( "CLI 参数错误", code="BURST_FILE_INVALID", message="burst file burst_ID and burst_size must have the same length", exit_code=2, ) return ids, sizes if isinstance(raw, list): ids: list[str] = [] sizes: list[float] = [] for item in raw: if not isinstance(item, dict) or "id" not in item or "size" not in item: raise CLIError( "CLI 参数错误", code="BURST_FILE_INVALID", message="burst file items must contain id and size", exit_code=2, ) ids.append(str(item["id"])) sizes.append(float(item["size"])) return ids, sizes raise CLIError( "CLI 参数错误", code="BURST_FILE_INVALID", message="burst file must be a JSON array or object with burst_ID/burst_size", exit_code=2, ) def parse_valve_setting_file(path: Path) -> tuple[list[str], list[float]]: raw = read_json_input(path, label="valve-setting") if isinstance(raw, dict) and "valves" in raw and "valves_k" in raw: valves = [str(item) for item in raw["valves"]] openings = [float(item) for item in raw["valves_k"]] if len(valves) != len(openings): raise CLIError( "CLI 参数错误", code="VALVE_SETTING_INVALID", message="valves and valves_k must have the same length", exit_code=2, ) return valves, openings if isinstance(raw, list): valves: list[str] = [] openings: list[float] = [] for item in raw: if not isinstance(item, dict) or "valve" not in item or "opening" not in item: raise CLIError( "CLI 参数错误", code="VALVE_SETTING_INVALID", message="valve-setting items must contain valve and opening", exit_code=2, ) valves.append(str(item["valve"])) openings.append(float(item["opening"])) return valves, openings raise CLIError( "CLI 参数错误", code="VALVE_SETTING_INVALID", message="valve-setting file must be a JSON array or object with valves/valves_k", exit_code=2, ) def parse_optional_dataset_file(path: Path | None, *, label: str) -> Any: if path is None: return None return read_json_input(path, label=label) def build_headers( ctx: RuntimeContext, *, require_auth: bool, require_project: bool, ) -> dict[str, str]: headers = { "Accept": "application/json, text/plain, */*", "X-Request-Id": ctx.request_id, } headers.update(ctx.auth.headers) if require_auth: headers["Authorization"] = f"Bearer {require_access_token(ctx)}" elif ctx.auth.access_token: headers["Authorization"] = f"Bearer {ctx.auth.access_token}" if require_project: headers["X-Project-Id"] = require_project_id(ctx) elif ctx.auth.project_id: headers["X-Project-Id"] = ctx.auth.project_id if ctx.auth.user_id: headers["X-User-Id"] = ctx.auth.user_id return headers def _extract_error_message(response: requests.Response) -> str: try: payload = response.json() except ValueError: text = response.text.strip() return text or f"http {response.status_code}" if isinstance(payload, dict): detail = payload.get("detail") if isinstance(detail, str): return detail if isinstance(detail, list): return "; ".join(json.dumps(item, ensure_ascii=False) for item in detail) message = payload.get("message") if isinstance(message, str): return message return json.dumps(payload, ensure_ascii=False) def map_http_status_to_exit_code(status_code: int) -> int: if status_code in (400, 422): return 2 if status_code == 401: return 3 if status_code == 403: return 4 if status_code == 404: return 5 if status_code in (409, 412): return 6 return 7 def _parse_response_body(response: requests.Response) -> Any: if response.status_code == 204 or not response.content: return {} content_type = response.headers.get("content-type", "").lower() if "application/json" in content_type: payload = response.json() if isinstance(payload, dict) and payload.get("status") == "error": raise CLIError( "服务端错误", code="SERVER_ERROR", message=str(payload.get("message") or "server returned error status"), exit_code=7, data=payload, ) return payload text = response.text if text: return {"report": text} return {} def request_json( ctx: RuntimeContext, *, method: str, path: str, params: dict[str, Any] | None = None, json_body: Any = None, require_auth: bool = True, require_project: bool = False, require_network_ctx: bool = False, require_username_ctx: bool = False, ) -> tuple[Any, int]: require_server(ctx) if require_network_ctx: require_network(ctx) if require_username_ctx: require_username(ctx) url = f"{require_server(ctx)}/api/v1{path}" headers = build_headers(ctx, require_auth=require_auth, require_project=require_project) started = time.monotonic() try: response = requests.request( method=method.upper(), url=url, params=params, json=json_body, headers=headers, timeout=ctx.timeout, ) except requests.Timeout as exc: raise CLIError( "请求超时", code="REQUEST_TIMEOUT", message=f"request timed out after {ctx.timeout} seconds", exit_code=7, retryable=True, ) from exc except requests.RequestException as exc: raise CLIError( "连接失败", code="REQUEST_FAILED", message=str(exc), exit_code=7, retryable=True, ) from exc duration_ms = int((time.monotonic() - started) * 1000) if not response.ok: raise CLIError( "请求失败", code=f"HTTP_{response.status_code}", message=_extract_error_message(response), exit_code=map_http_status_to_exit_code(response.status_code), retryable=response.status_code >= 500, ) return _parse_response_body(response), duration_ms def request_bytes( ctx: RuntimeContext, *, method: str, path: str, params: dict[str, Any] | None = None, require_auth: bool = True, require_project: bool = False, require_network_ctx: bool = False, ) -> tuple[bytes, int]: require_server(ctx) if require_network_ctx: require_network(ctx) url = f"{require_server(ctx)}/api/v1{path}" headers = build_headers(ctx, require_auth=require_auth, require_project=require_project) started = time.monotonic() try: response = requests.request( method=method.upper(), url=url, params=params, headers=headers, timeout=ctx.timeout, ) except requests.Timeout as exc: raise CLIError( "请求超时", code="REQUEST_TIMEOUT", message=f"request timed out after {ctx.timeout} seconds", exit_code=7, retryable=True, ) from exc except requests.RequestException as exc: raise CLIError( "连接失败", code="REQUEST_FAILED", message=str(exc), exit_code=7, retryable=True, ) from exc duration_ms = int((time.monotonic() - started) * 1000) if not response.ok: raise CLIError( "请求失败", code=f"HTTP_{response.status_code}", message=_extract_error_message(response), exit_code=map_http_status_to_exit_code(response.status_code), retryable=response.status_code >= 500, ) return response.content, duration_ms def build_success_payload( *, summary: str, data: Any, server: str | None, request_id: str, duration_ms: int, next_commands: list[str] | None = None, ) -> dict[str, Any]: return { "ok": True, "schema_version": SCHEMA_VERSION, "summary": summary, "data": data, "metadata": { "request_id": request_id, "server": server, "duration_ms": duration_ms, "generated_at": datetime.now(timezone.utc).isoformat(timespec="seconds").replace("+00:00", "Z"), }, "next_commands": next_commands or [], } def build_failure_payload( *, summary: str, code: str, message: str, retryable: bool, server: str | None, request_id: str | None, next_commands: list[str] | None = None, data: Any = None, ) -> dict[str, Any]: return { "ok": False, "schema_version": SCHEMA_VERSION, "summary": summary, "error": { "code": code, "message": message, "retryable": retryable, }, "data": data, "metadata": { "request_id": request_id, "server": server, "generated_at": datetime.now(timezone.utc).isoformat(timespec="seconds").replace("+00:00", "Z"), }, "next_commands": next_commands or [], } def emit_success( *, summary: str, data: Any, ctx: RuntimeContext, duration_ms: int, next_commands: list[str] | None = None, ) -> None: typer.echo( json.dumps( build_success_payload( summary=summary, data=data, server=ctx.server, request_id=ctx.request_id, duration_ms=duration_ms, next_commands=next_commands, ), ensure_ascii=False, ) ) def emit_failure( *, summary: str, code: str, message: str, exit_code: int, retryable: bool, server: str | None, request_id: str | None, next_commands: list[str] | None = None, data: Any = None, ) -> int: typer.echo( json.dumps( build_failure_payload( summary=summary, code=code, message=message, retryable=retryable, server=server, request_id=request_id, next_commands=next_commands, data=data, ), ensure_ascii=False, ) ) return exit_code