Files
TJWaterServerBinary/cli/tjwater_agent_cli/core.py
T

648 lines
19 KiB
Python

from __future__ import annotations
import json
import os
import time
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Mapping
import requests
import typer
# Version tag written into the "schema_version" field of every JSON envelope.
SCHEMA_VERSION: str = "tjwater-cli/v1"
# Default HTTP timeout in seconds (not referenced in this module; presumably
# the default of a --timeout option elsewhere — TODO confirm).
DEFAULT_TIMEOUT: int = 60
# Server used when neither --server nor the auth context provides one.
DEFAULT_SERVER: str = "http://192.168.1.114:8000"
class CLIError(Exception):
    """Structured CLI failure carrying everything needed for a JSON error envelope.

    ``str(exc)`` is ``message``. The other attributes mirror the parameters
    of ``emit_failure`` (summary, code, message, exit code, retryable flag,
    suggested follow-up commands and optional payload).
    """

    def __init__(
        self,
        summary: str,
        *,
        code: str,
        message: str,
        exit_code: int,
        retryable: bool = False,
        next_commands: list[str] | None = None,
        data: Any = None,
    ) -> None:
        super().__init__(message)
        # Short human-readable headline, e.g. "认证失败".
        self.summary = summary
        # Machine-readable identifier plus detailed explanation.
        self.code, self.message = code, message
        # Process exit status associated with this failure.
        self.exit_code = exit_code
        self.retryable = retryable
        # Normalised to a list so consumers never have to check for None.
        self.next_commands = next_commands or []
        self.data = data
@dataclass(frozen=True)
class AuthContext:
    """Immutable identity/credential bundle read from an auth-context source.

    Every field is optional at this level; the ``require_*`` helpers enforce
    presence where a value is actually needed.
    """

    server: str | None = field(default=None)  # API base URL
    access_token: str | None = field(default=None)  # sent as "Authorization: Bearer ..."
    project_id: str | None = field(default=None)  # sent as X-Project-Id
    user_id: str | None = field(default=None)  # sent as X-User-Id
    username: str | None = field(default=None)
    network: str | None = field(default=None)  # used by legacy network-based endpoints
    # Extra HTTP headers; build_headers merges them before its own headers.
    headers: dict[str, str] = field(default_factory=dict)
@dataclass(frozen=True)
class RuntimeContext:
    """Per-invocation settings shared by every request helper."""

    # Resolved server URL (may still carry a trailing slash).
    server: str | None
    # Credentials and identity loaded from the auth context.
    auth: AuthContext
    # Default scheme used by resolve_scheme when none is passed explicitly.
    scheme: str | None
    # Request timeout in seconds, passed straight to requests.
    timeout: int
    # Correlation id sent as X-Request-Id and echoed in output metadata.
    request_id: str
@dataclass(frozen=True)
class CommandOptionDoc:
    """Documentation record for a single command-line option."""

    name: str
    description: str
    # Whether the option must be supplied.
    required: bool = field(default=False)
    # Whether the option may be given more than once.
    repeated: bool = field(default=False)
    # Documented default value, if any.
    default: Any = field(default=None)
@dataclass(frozen=True)
class CommandDoc:
    """Documentation record describing one CLI command."""

    # Command words, e.g. ("scheme", "list").
    path: tuple[str, ...]
    summary: str
    description: str
    options: tuple[CommandOptionDoc, ...] = field(default=())
    examples: tuple[str, ...] = field(default=())
    # Commands suggested as follow-ups.
    next_commands: tuple[str, ...] = field(default=())
    # Human-readable description of the command's output.
    output: str = field(default="标准 JSON 输出")
def _read_json_file(path: Path) -> dict[str, Any]:
try:
return json.loads(path.read_text(encoding="utf-8"))
except FileNotFoundError as exc:
raise CLIError(
"认证失败",
code="AUTH_CONTEXT_NOT_FOUND",
message=f"auth context file not found: {path}",
exit_code=3,
) from exc
except json.JSONDecodeError as exc:
raise CLIError(
"认证失败",
code="AUTH_CONTEXT_INVALID",
message=f"auth context file is not valid JSON: {path}",
exit_code=3,
) from exc
def _pick(mapping: Mapping[str, Any], *keys: str) -> Any:
for key in keys:
value = mapping.get(key)
if value not in (None, ""):
return value
return None
def _optional_str(value: Any) -> str | None:
    """Return *value* as ``str``, passing ``None`` through unchanged.

    JSON auth-context files may hold numeric IDs (``"project_id": 42``).
    Those values are later used as HTTP header values, which ``requests``
    only accepts as ``str``/``bytes``.
    """
    if value is None or isinstance(value, str):
        return value
    return str(value)


def load_auth_context(auth_context_path: Path | None) -> AuthContext:
    """Build the agent's AuthContext from a JSON file or from the environment.

    Sources:
        * With *auth_context_path*, the JSON file at that path.
        * Otherwise the ``TJWATER_*`` environment variables.
          ``TJWATER_EXTRA_HEADERS`` holds a JSON object of extra headers.

    Several key spellings are accepted per field (for example ``project_id``,
    ``projectId`` or ``x_project_id``). Non-string scalar values are converted
    to ``str``.

    Raises:
        CLIError: ``AUTH_CONTEXT_INVALID`` (exit code 3) when the context is
            not a JSON object, ``TJWATER_EXTRA_HEADERS`` is not valid JSON, or
            ``headers`` is not an object. Also propagates whatever
            ``_read_json_file`` raises for the file source.
    """
    raw: Any
    if auth_context_path is not None:
        raw = _read_json_file(auth_context_path)
    else:
        extra_headers = os.getenv("TJWATER_EXTRA_HEADERS")
        try:
            env_headers = json.loads(extra_headers) if extra_headers else {}
        except json.JSONDecodeError as exc:
            raise CLIError(
                "认证失败",
                code="AUTH_CONTEXT_INVALID",
                message="TJWATER_EXTRA_HEADERS must be valid JSON",
                exit_code=3,
            ) from exc
        raw = {
            "server": os.getenv("TJWATER_SERVER"),
            "access_token": os.getenv("TJWATER_ACCESS_TOKEN"),
            "project_id": os.getenv("TJWATER_PROJECT_ID"),
            "user_id": os.getenv("TJWATER_USER_ID"),
            "username": os.getenv("TJWATER_USERNAME"),
            "network": os.getenv("TJWATER_NETWORK"),
            "headers": env_headers,
        }
    # A JSON array/scalar in the file would otherwise fail on .get() below.
    if not isinstance(raw, dict):
        raise CLIError(
            "认证失败",
            code="AUTH_CONTEXT_INVALID",
            message="auth context must be a JSON object",
            exit_code=3,
        )
    headers = raw.get("headers") or {}
    if not isinstance(headers, dict):
        raise CLIError(
            "认证失败",
            code="AUTH_CONTEXT_INVALID",
            message="auth context headers must be a JSON object",
            exit_code=3,
        )
    return AuthContext(
        server=_optional_str(_pick(raw, "server", "base_url")),
        access_token=_optional_str(_pick(raw, "access_token", "token", "accessToken")),
        project_id=_optional_str(_pick(raw, "project_id", "projectId", "x_project_id")),
        user_id=_optional_str(_pick(raw, "user_id", "userId", "x_user_id")),
        username=_optional_str(_pick(raw, "username", "preferred_username")),
        network=_optional_str(_pick(raw, "network", "project_code", "projectCode", "project")),
        headers={str(key): str(value) for key, value in headers.items()},
    )
def build_runtime_context(
    *,
    server: str | None,
    auth_context_path: Path | None,
    scheme: str | None,
    timeout: int,
    request_id: str | None,
) -> RuntimeContext:
    """Assemble the RuntimeContext for one CLI invocation.

    The server is chosen in this order: the explicit *server* argument, then
    the auth context's server, then ``DEFAULT_SERVER``. A random UUID4 string
    is used when no *request_id* is given.
    """
    auth = load_auth_context(auth_context_path)
    chosen_server = server or auth.server or DEFAULT_SERVER
    return RuntimeContext(
        server=chosen_server,
        auth=auth,
        scheme=scheme,
        timeout=timeout,
        request_id=request_id or str(uuid.uuid4()),
    )
def require_server(ctx: RuntimeContext) -> str:
    """Return the context's server URL with trailing slashes removed.

    Raises:
        CLIError: ``SERVER_REQUIRED`` (exit code 3) when no server is set.
    """
    server = ctx.server
    if not server:
        raise CLIError(
            "认证失败",
            code="SERVER_REQUIRED",
            message="missing server URL; use --server or include server in auth context",
            exit_code=3,
        )
    return server.rstrip("/")
def require_access_token(ctx: RuntimeContext) -> str:
    """Return the access token from the auth context.

    Raises:
        CLIError: ``UNAUTHENTICATED`` (exit code 3) when the token is missing.
    """
    token = ctx.auth.access_token
    if not token:
        raise CLIError(
            "认证失败",
            code="UNAUTHENTICATED",
            message="missing access token for agent context",
            exit_code=3,
            next_commands=["tjwater <command> --auth-context /path/to/auth-context.json"],
        )
    return token
def require_project_id(ctx: RuntimeContext) -> str:
    """Return the project id from the auth context.

    Raises:
        CLIError: ``PROJECT_CONTEXT_REQUIRED`` (exit code 3) when it is missing.
    """
    project_id = ctx.auth.project_id
    if not project_id:
        raise CLIError(
            "认证失败",
            code="PROJECT_CONTEXT_REQUIRED",
            message="missing project_id for agent context",
            exit_code=3,
            next_commands=["add project_id to the auth context file"],
        )
    return project_id
def require_network(ctx: RuntimeContext) -> str:
    """Return the network name that legacy network-based endpoints need.

    Raises:
        CLIError: ``NETWORK_CONTEXT_REQUIRED`` (exit code 3) when it is missing.
    """
    network = ctx.auth.network
    if not network:
        raise CLIError(
            "认证失败",
            code="NETWORK_CONTEXT_REQUIRED",
            message="missing network in auth context for legacy network-based endpoints",
            exit_code=3,
            next_commands=["add network to the auth context file"],
        )
    return network
def require_username(ctx: RuntimeContext) -> str:
    """Return the username from the auth context.

    Raises:
        CLIError: ``USERNAME_CONTEXT_REQUIRED`` (exit code 3) when it is missing.
    """
    username = ctx.auth.username
    if not username:
        raise CLIError(
            "认证失败",
            code="USERNAME_CONTEXT_REQUIRED",
            message="missing username in auth context",
            exit_code=3,
            next_commands=["add username to the auth context file"],
        )
    return username
def resolve_scheme(ctx: RuntimeContext, explicit_scheme: str | None, *, required: bool = False) -> str | None:
    """Choose the scheme for a command.

    A truthy *explicit_scheme* wins; otherwise the context default is used.
    If no truthy value is found and *required* is false, that value (possibly
    ``None``) is returned.

    Raises:
        CLIError: ``SCHEME_REQUIRED`` (exit code 2) when *required* is true
            and no scheme is available.
    """
    chosen = explicit_scheme if explicit_scheme else ctx.scheme
    if chosen or not required:
        return chosen
    raise CLIError(
        "CLI 参数错误",
        code="SCHEME_REQUIRED",
        message="missing scheme; use --scheme",
        exit_code=2,
    )
def parse_time_with_timezone(value: str, *, option_name: str) -> datetime:
    """Parse an ISO 8601 / RFC 3339 timestamp that includes an explicit offset.

    A trailing ``Z``/``z`` (the RFC 3339 UTC designator) is rewritten to
    ``+00:00`` before parsing. ``datetime.fromisoformat`` only accepts ``Z``
    natively from Python 3.11. Surrounding whitespace is ignored.

    Args:
        value: Raw timestamp text from the command line.
        option_name: Option name used in error messages.

    Returns:
        A timezone-aware ``datetime``.

    Raises:
        CLIError: ``INVALID_TIME`` (exit code 2) when *value* cannot be parsed;
            ``TIMEZONE_REQUIRED`` (exit code 2) when it has no offset.
    """
    text = value.strip()
    if text[-1:] in ("Z", "z"):
        text = text[:-1] + "+00:00"
    try:
        parsed = datetime.fromisoformat(text)
    except ValueError as exc:
        raise CLIError(
            "CLI 参数错误",
            code="INVALID_TIME",
            message=f"{option_name} must be a valid ISO 8601 / RFC 3339 timestamp",
            exit_code=2,
        ) from exc
    if parsed.tzinfo is None:
        raise CLIError(
            "CLI 参数错误",
            code="TIMEZONE_REQUIRED",
            message=f"{option_name} must include an explicit timezone offset",
            exit_code=2,
        )
    return parsed
def read_json_input(path: Path, *, label: str) -> Any:
    """Load and decode a user-supplied JSON input file.

    Args:
        path: File to read (decoded as UTF-8).
        label: Short input name (e.g. ``"burst"``) used in error messages.

    Returns:
        The decoded JSON value, which may be of any JSON type.

    Raises:
        CLIError: exit code 2 with one of these codes:
            ``INPUT_NOT_FOUND`` if the file is missing;
            ``INPUT_UNREADABLE`` if it cannot be read (permissions,
            directory, ...);
            ``INPUT_INVALID_JSON`` if it is not UTF-8 encoded JSON.
    """
    try:
        text = path.read_text(encoding="utf-8")
    except FileNotFoundError as exc:
        raise CLIError(
            "CLI 参数错误",
            code="INPUT_NOT_FOUND",
            message=f"{label} file not found: {path}",
            exit_code=2,
        ) from exc
    except OSError as exc:
        # Permission denied, path is a directory, ...: previously escaped as a
        # raw traceback instead of a structured CLI error.
        raise CLIError(
            "CLI 参数错误",
            code="INPUT_UNREADABLE",
            message=f"{label} file cannot be read: {path} ({exc.strerror or exc})",
            exit_code=2,
        ) from exc
    except UnicodeDecodeError as exc:
        # JSON must be UTF-8; undecodable bytes mean the input is not valid JSON.
        raise CLIError(
            "CLI 参数错误",
            code="INPUT_INVALID_JSON",
            message=f"{label} file must be valid JSON: {path}",
            exit_code=2,
        ) from exc
    try:
        return json.loads(text)
    except json.JSONDecodeError as exc:
        raise CLIError(
            "CLI 参数错误",
            code="INPUT_INVALID_JSON",
            message=f"{label} file must be valid JSON: {path}",
            exit_code=2,
        ) from exc
def parse_burst_file(path: Path) -> tuple[list[str], list[float]]:
    """Read burst definitions from a JSON file as parallel id / size lists.

    Accepted shapes, optionally wrapped as ``{"bursts": <shape>}``:

    * ``{"burst_ID": [...], "burst_size": [...]}``: parallel arrays.
    * ``[{"id": ..., "size": ...}, ...]``: one object per burst.

    IDs are converted with ``str``. Sizes are converted with ``float``, so
    numeric strings are accepted.

    Raises:
        CLIError: ``BURST_FILE_INVALID`` (exit code 2) for any other shape,
            non-array ``burst_ID``/``burst_size``, mismatched lengths, or a
            non-numeric size. Also propagates whatever ``read_json_input``
            raises for missing or malformed files.
    """

    def invalid(message: str) -> CLIError:
        # All shape/content errors share the same summary, code and exit code.
        return CLIError("CLI 参数错误", code="BURST_FILE_INVALID", message=message, exit_code=2)

    raw = read_json_input(path, label="burst")
    if isinstance(raw, dict) and "bursts" in raw:
        raw = raw["bursts"]
    if isinstance(raw, dict) and "burst_ID" in raw and "burst_size" in raw:
        raw_ids = raw["burst_ID"]
        raw_sizes = raw["burst_size"]
        # Without this check a JSON string would be iterated character by character.
        if not isinstance(raw_ids, list) or not isinstance(raw_sizes, list):
            raise invalid("burst file burst_ID and burst_size must be JSON arrays")
        if len(raw_ids) != len(raw_sizes):
            raise invalid("burst file burst_ID and burst_size must have the same length")
        try:
            size_values = [float(item) for item in raw_sizes]
        except (TypeError, ValueError) as exc:
            raise invalid("burst file burst_size values must be numbers") from exc
        return [str(item) for item in raw_ids], size_values
    if isinstance(raw, list):
        ids: list[str] = []
        sizes: list[float] = []
        for item in raw:
            if not isinstance(item, dict) or "id" not in item or "size" not in item:
                raise invalid("burst file items must contain id and size")
            try:
                size = float(item["size"])
            except (TypeError, ValueError) as exc:
                raise invalid("burst file item size must be a number") from exc
            ids.append(str(item["id"]))
            sizes.append(size)
        return ids, sizes
    raise invalid("burst file must be a JSON array or object with burst_ID/burst_size")
def parse_valve_setting_file(path: Path) -> tuple[list[str], list[float]]:
    """Read valve settings from a JSON file as parallel valve / opening lists.

    Accepted shapes:

    * ``{"valves": [...], "valves_k": [...]}``: parallel arrays.
    * ``[{"valve": ..., "opening": ...}, ...]``: one object per valve.

    Valve ids are converted with ``str``. Openings are converted with
    ``float``, so numeric strings are accepted.

    Raises:
        CLIError: ``VALVE_SETTING_INVALID`` (exit code 2) for any other shape,
            non-array ``valves``/``valves_k``, mismatched lengths, or a
            non-numeric opening. Also propagates whatever ``read_json_input``
            raises for missing or malformed files.
    """

    def invalid(message: str) -> CLIError:
        # All shape/content errors share the same summary, code and exit code.
        return CLIError("CLI 参数错误", code="VALVE_SETTING_INVALID", message=message, exit_code=2)

    raw = read_json_input(path, label="valve-setting")
    if isinstance(raw, dict) and "valves" in raw and "valves_k" in raw:
        raw_valves = raw["valves"]
        raw_openings = raw["valves_k"]
        # Without this check a JSON string would be iterated character by character.
        if not isinstance(raw_valves, list) or not isinstance(raw_openings, list):
            raise invalid("valves and valves_k must be JSON arrays")
        if len(raw_valves) != len(raw_openings):
            raise invalid("valves and valves_k must have the same length")
        try:
            opening_values = [float(item) for item in raw_openings]
        except (TypeError, ValueError) as exc:
            raise invalid("valves_k values must be numbers") from exc
        return [str(item) for item in raw_valves], opening_values
    if isinstance(raw, list):
        valves: list[str] = []
        openings: list[float] = []
        for item in raw:
            if not isinstance(item, dict) or "valve" not in item or "opening" not in item:
                raise invalid("valve-setting items must contain valve and opening")
            try:
                opening = float(item["opening"])
            except (TypeError, ValueError) as exc:
                raise invalid("valve-setting item opening must be a number") from exc
            valves.append(str(item["valve"]))
            openings.append(opening)
        return valves, openings
    raise invalid("valve-setting file must be a JSON array or object with valves/valves_k")
def parse_optional_dataset_file(path: Path | None, *, label: str) -> Any:
    """Return the parsed JSON at *path*, or ``None`` when no path was given.

    Errors from ``read_json_input`` (missing or malformed files) propagate.
    """
    return None if path is None else read_json_input(path, label=label)
def build_headers(
    ctx: RuntimeContext,
    *,
    require_auth: bool,
    require_project: bool,
) -> dict[str, str]:
    """Build the HTTP headers for an API request.

    Headers are applied in three layers:

    1. ``Accept`` and ``X-Request-Id``.
    2. The auth context's extra headers, which may override layer 1.
    3. ``Authorization``, ``X-Project-Id`` and ``X-User-Id`` from the context.

    When *require_auth* / *require_project* is true, a missing token or
    project id raises (via ``require_access_token`` / ``require_project_id``).
    Otherwise the header is added only if a value is present.
    """
    auth = ctx.auth
    token = require_access_token(ctx) if require_auth else auth.access_token
    project = require_project_id(ctx) if require_project else auth.project_id

    merged = {
        "Accept": "application/json, text/plain, */*",
        "X-Request-Id": ctx.request_id,
        **auth.headers,
    }
    if token:
        merged["Authorization"] = f"Bearer {token}"
    if project:
        merged["X-Project-Id"] = project
    if auth.user_id:
        merged["X-User-Id"] = auth.user_id
    return merged
def _extract_error_message(response: requests.Response) -> str:
try:
payload = response.json()
except ValueError:
text = response.text.strip()
return text or f"http {response.status_code}"
if isinstance(payload, dict):
detail = payload.get("detail")
if isinstance(detail, str):
return detail
if isinstance(detail, list):
return "; ".join(json.dumps(item, ensure_ascii=False) for item in detail)
message = payload.get("message")
if isinstance(message, str):
return message
return json.dumps(payload, ensure_ascii=False)
def map_http_status_to_exit_code(status_code: int) -> int:
    """Translate an HTTP error status into the CLI process exit code.

    Mapping:

    * 2: bad request / validation (400, 422)
    * 3: unauthenticated (401)
    * 4: forbidden (403)
    * 5: not found (404)
    * 6: conflict / precondition failed (409, 412)
    * 7: any other status
    """
    exit_codes = {400: 2, 422: 2, 401: 3, 403: 4, 404: 5, 409: 6, 412: 6}
    return exit_codes.get(status_code, 7)
def _parse_response_body(response: requests.Response) -> Any:
if response.status_code == 204 or not response.content:
return {}
content_type = response.headers.get("content-type", "").lower()
if "application/json" in content_type:
payload = response.json()
if isinstance(payload, dict) and payload.get("status") == "error":
raise CLIError(
"服务端错误",
code="SERVER_ERROR",
message=str(payload.get("message") or "server returned error status"),
exit_code=7,
data=payload,
)
return payload
text = response.text
if text:
return {"report": text}
return {}
def _send_request(
    ctx: RuntimeContext,
    *,
    method: str,
    url: str,
    headers: dict[str, str],
    params: dict[str, Any] | None = None,
    json_body: Any = None,
) -> tuple[requests.Response, int]:
    """Perform one HTTP call and translate every failure into ``CLIError``.

    Shared by ``request_json`` and ``request_bytes`` so both report failures
    identically.

    Returns:
        The response (``response.ok``, i.e. status < 400) and the elapsed
        wall-clock time in milliseconds.

    Raises:
        CLIError: exit code 7 with one of these codes:
            ``REQUEST_TIMEOUT`` (retryable);
            ``REQUEST_FAILED`` (retryable unless the configured URL itself is
            malformed).
            ``HTTP_<status>`` for error statuses, with the exit code from
            ``map_http_status_to_exit_code``; 5xx, 408 and 429 are marked
            retryable.
    """
    started = time.monotonic()
    try:
        response = requests.request(
            method=method.upper(),
            url=url,
            params=params,
            json=json_body,
            headers=headers,
            timeout=ctx.timeout,
        )
    except requests.Timeout as exc:
        raise CLIError(
            "请求超时",
            code="REQUEST_TIMEOUT",
            message=f"request timed out after {ctx.timeout} seconds",
            exit_code=7,
            retryable=True,
        ) from exc
    except requests.RequestException as exc:
        # A malformed server URL fails the same way on every attempt.
        url_is_bad = isinstance(
            exc,
            (
                requests.exceptions.InvalidURL,
                requests.exceptions.MissingSchema,
                requests.exceptions.InvalidSchema,
            ),
        )
        raise CLIError(
            "连接失败",
            code="REQUEST_FAILED",
            message=str(exc),
            exit_code=7,
            retryable=not url_is_bad,
        ) from exc
    duration_ms = int((time.monotonic() - started) * 1000)
    if not response.ok:
        status = response.status_code
        raise CLIError(
            "请求失败",
            code=f"HTTP_{status}",
            message=_extract_error_message(response),
            exit_code=map_http_status_to_exit_code(status),
            # 408 Request Timeout and 429 Too Many Requests are transient as well.
            retryable=status >= 500 or status in (408, 429),
        )
    return response, duration_ms


def request_json(
    ctx: RuntimeContext,
    *,
    method: str,
    path: str,
    params: dict[str, Any] | None = None,
    json_body: Any = None,
    require_auth: bool = True,
    require_project: bool = False,
    require_network_ctx: bool = False,
    require_username_ctx: bool = False,
) -> tuple[Any, int]:
    """Call ``<server>/api/v1<path>`` and return the decoded body and duration (ms).

    The ``require_*`` flags enforce the matching context fields before any
    network I/O. The body is decoded by ``_parse_response_body``.
    Transport and HTTP errors raise ``CLIError`` (see ``_send_request``).
    """
    base_url = require_server(ctx)
    if require_network_ctx:
        require_network(ctx)
    if require_username_ctx:
        require_username(ctx)
    headers = build_headers(ctx, require_auth=require_auth, require_project=require_project)
    response, duration_ms = _send_request(
        ctx,
        method=method,
        url=f"{base_url}/api/v1{path}",
        headers=headers,
        params=params,
        json_body=json_body,
    )
    return _parse_response_body(response), duration_ms


def request_bytes(
    ctx: RuntimeContext,
    *,
    method: str,
    path: str,
    params: dict[str, Any] | None = None,
    require_auth: bool = True,
    require_project: bool = False,
    require_network_ctx: bool = False,
) -> tuple[bytes, int]:
    """Call ``<server>/api/v1<path>`` and return the raw body bytes and duration (ms).

    Same preconditions and error mapping as ``request_json``, but no request
    body is sent and the response is not decoded.
    """
    base_url = require_server(ctx)
    if require_network_ctx:
        require_network(ctx)
    headers = build_headers(ctx, require_auth=require_auth, require_project=require_project)
    response, duration_ms = _send_request(
        ctx,
        method=method,
        url=f"{base_url}/api/v1{path}",
        headers=headers,
        params=params,
    )
    return response.content, duration_ms
def build_success_payload(
    *,
    summary: str,
    data: Any,
    server: str | None,
    request_id: str,
    duration_ms: int,
    next_commands: list[str] | None = None,
) -> dict[str, Any]:
    """Build the ``ok: true`` output envelope.

    Key order is ok, schema_version, summary, data, metadata, next_commands.
    ``metadata.generated_at`` is the current UTC time to the second, with a
    ``Z`` suffix.
    """
    generated_at = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
    metadata = {
        "request_id": request_id,
        "server": server,
        "duration_ms": duration_ms,
        "generated_at": generated_at,
    }
    payload: dict[str, Any] = {
        "ok": True,
        "schema_version": SCHEMA_VERSION,
        "summary": summary,
        "data": data,
        "metadata": metadata,
    }
    payload["next_commands"] = next_commands or []
    return payload
def build_failure_payload(
    *,
    summary: str,
    code: str,
    message: str,
    retryable: bool,
    server: str | None,
    request_id: str | None,
    next_commands: list[str] | None = None,
    data: Any = None,
) -> dict[str, Any]:
    """Build the ``ok: false`` output envelope.

    Key order is ok, schema_version, summary, error, data, metadata,
    next_commands. Unlike the success envelope, ``metadata`` has no
    ``duration_ms``.
    """
    generated_at = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
    error_block = {"code": code, "message": message, "retryable": retryable}
    metadata = {
        "request_id": request_id,
        "server": server,
        "generated_at": generated_at,
    }
    payload: dict[str, Any] = {
        "ok": False,
        "schema_version": SCHEMA_VERSION,
        "summary": summary,
        "error": error_block,
        "data": data,
        "metadata": metadata,
    }
    payload["next_commands"] = next_commands or []
    return payload
def emit_success(
    *,
    summary: str,
    data: Any,
    ctx: RuntimeContext,
    duration_ms: int,
    next_commands: list[str] | None = None,
) -> None:
    """Print the success envelope for *ctx* as a single JSON line via ``typer.echo``.

    Non-ASCII text is written as-is (``ensure_ascii=False``).
    """
    envelope = build_success_payload(
        summary=summary,
        data=data,
        server=ctx.server,
        request_id=ctx.request_id,
        duration_ms=duration_ms,
        next_commands=next_commands,
    )
    rendered = json.dumps(envelope, ensure_ascii=False)
    typer.echo(rendered)
def emit_failure(
    *,
    summary: str,
    code: str,
    message: str,
    exit_code: int,
    retryable: bool,
    server: str | None,
    request_id: str | None,
    next_commands: list[str] | None = None,
    data: Any = None,
) -> int:
    """Print the failure envelope as a single JSON line via ``typer.echo``.

    Returns:
        *exit_code*, unchanged, so callers can use it as the process status.
    """
    envelope = build_failure_payload(
        summary=summary,
        code=code,
        message=message,
        retryable=retryable,
        server=server,
        request_id=request_id,
        next_commands=next_commands,
        data=data,
    )
    rendered = json.dumps(envelope, ensure_ascii=False)
    typer.echo(rendered)
    return exit_code