649 lines
19 KiB
Python
649 lines
19 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import os
|
|
import time
|
|
import uuid
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from typing import Any, Mapping
|
|
|
|
import requests
|
|
import typer
|
|
|
|
# Version tag written into every JSON envelope printed by this CLI.
SCHEMA_VERSION: str = "tjwater-cli/v1"
# Program name of this CLI.
CLI_NAME: str = "tjwater-cli"
# Presumably the default for the request timeout option, in seconds
# (timeouts are reported in seconds by the request helpers) — confirm with the CLI entry point.
DEFAULT_TIMEOUT: int = 60
# Fallback API server used when neither --server nor the auth context provides one.
# NOTE(review): this is a private LAN address baked into the binary; confirm it is an
# intended shipped default rather than a developer setting.
DEFAULT_SERVER: str = "http://192.168.1.114:8000"
|
|
|
|
|
|
class CLIError(Exception):
    """Exception carrying everything needed to report a structured CLI failure.

    Args:
        summary: short human-readable headline for the failure (positional).
        code: machine-readable error code, e.g. ``"HTTP_404"``.
        message: detailed description; also becomes the exception's ``str()``.
        exit_code: process exit status associated with this failure.
        retryable: whether re-running the same command might succeed.
        next_commands: suggested follow-up commands; ``None`` becomes ``[]``.
        data: optional structured payload (for example a server error body).
    """

    def __init__(
        self,
        summary: str,
        *,
        code: str,
        message: str,
        exit_code: int,
        retryable: bool = False,
        next_commands: list[str] | None = None,
        data: Any = None,
    ) -> None:
        super().__init__(message)
        self.summary, self.code, self.message = summary, code, message
        self.exit_code = exit_code
        self.retryable = retryable
        # Any falsy value (None or an empty list) is normalised to a fresh empty list.
        self.next_commands = next_commands if next_commands else []
        self.data = data
|
|
|
|
|
|
@dataclass(frozen=True)
class AuthContext:
    """Immutable identity and credential values for calling the API.

    Every field is optional; individual commands decide which ones they need.
    """

    # Base URL of the API server.
    server: str | None = None
    # Bearer token sent in the Authorization header.
    access_token: str | None = None
    # Project identifier sent as the X-Project-Id header.
    project_id: str | None = None
    # User identifier sent as the X-User-Id header.
    user_id: str | None = None
    # Login name of the user.
    username: str | None = None
    # Network name needed by legacy network-based endpoints.
    network: str | None = None
    # Extra HTTP headers merged into every request; each instance gets its own dict.
    headers: dict[str, str] = field(default_factory=dict)
|
|
|
|
|
|
@dataclass(frozen=True)
class RuntimeContext:
    """Immutable per-invocation settings shared by every command."""

    # Resolved API server URL (may still carry a trailing slash).
    server: str | None
    # Credentials and identity loaded for this invocation.
    auth: AuthContext
    # Default scheme name used when a command is not given one explicitly.
    scheme: str | None
    # HTTP timeout in seconds.
    timeout: int
    # Correlation id sent as X-Request-Id and echoed in output metadata.
    request_id: str
|
|
|
|
|
|
@dataclass(frozen=True)
class CommandOptionDoc:
    """Static description of a single command-line option."""

    # Option name as shown to users.
    name: str
    # Human-readable explanation of the option.
    description: str
    # True when the option must be supplied.
    required: bool = False
    # True when the option may be given more than once.
    repeated: bool = False
    # Documented default value; None when no default is documented.
    default: Any = None
|
|
|
|
|
|
@dataclass(frozen=True)
class CommandDoc:
    """Static description of one CLI command."""

    # Command path as a tuple of tokens — presumably the sub-command words; confirm with callers.
    path: tuple[str, ...]
    # One-line summary of the command.
    summary: str
    # Longer description of the command.
    description: str
    # Options accepted by the command.
    options: tuple[CommandOptionDoc, ...] = ()
    # Example invocations.
    examples: tuple[str, ...] = ()
    # Suggested follow-up commands.
    next_commands: tuple[str, ...] = ()
    # Description of the output format; the default text means "standard JSON output".
    output: str = "标准 JSON 输出"
|
|
|
|
|
|
def _read_json_file(path: Path) -> dict[str, Any]:
|
|
try:
|
|
return json.loads(path.read_text(encoding="utf-8"))
|
|
except FileNotFoundError as exc:
|
|
raise CLIError(
|
|
"认证失败",
|
|
code="AUTH_CONTEXT_NOT_FOUND",
|
|
message=f"auth context file not found: {path}",
|
|
exit_code=3,
|
|
) from exc
|
|
except json.JSONDecodeError as exc:
|
|
raise CLIError(
|
|
"认证失败",
|
|
code="AUTH_CONTEXT_INVALID",
|
|
message=f"auth context file is not valid JSON: {path}",
|
|
exit_code=3,
|
|
) from exc
|
|
|
|
|
|
def _pick(mapping: Mapping[str, Any], *keys: str) -> Any:
|
|
for key in keys:
|
|
value = mapping.get(key)
|
|
if value not in (None, ""):
|
|
return value
|
|
return None
|
|
|
|
|
|
def _invalid_auth_context(message: str) -> CLIError:
    """Build the AUTH_CONTEXT_INVALID error (exit code 3) for malformed auth input."""
    return CLIError("认证失败", code="AUTH_CONTEXT_INVALID", message=message, exit_code=3)


def _load_extra_headers_env() -> Any:
    """Decode TJWATER_EXTRA_HEADERS as JSON; an unset or empty variable yields ``{}``."""
    extra_headers = os.getenv("TJWATER_EXTRA_HEADERS")
    if not extra_headers:
        return {}
    try:
        return json.loads(extra_headers)
    except json.JSONDecodeError as exc:
        raise _invalid_auth_context("TJWATER_EXTRA_HEADERS must be valid JSON") from exc


def _as_optional_str(value: Any) -> str | None:
    """Coerce a picked value to ``str`` (e.g. JSON numeric ids), keeping ``None`` as is."""
    return None if value is None else str(value)


def load_auth_context(auth_context_path: Path | None) -> AuthContext:
    """Build an AuthContext from a JSON file or, without a file, from environment variables.

    When ``auth_context_path`` is given only that file is consulted. Otherwise values
    come from the ``TJWATER_*`` environment variables, and ``TJWATER_EXTRA_HEADERS``
    may hold a JSON object of additional headers.

    Several alias keys are accepted per field (e.g. ``access_token`` / ``token`` /
    ``accessToken``); the first non-empty one wins. Picked values are coerced to
    ``str`` so they can be sent as HTTP header values.

    Raises:
        CLIError: exit code 3 (``AUTH_CONTEXT_INVALID``) when the file content is not
            a JSON object, ``TJWATER_EXTRA_HEADERS`` is not valid JSON, or the headers
            entry is not a JSON object; plus anything ``_read_json_file`` raises.
    """
    if auth_context_path is not None:
        raw: Any = _read_json_file(auth_context_path)
        if not isinstance(raw, dict):
            raise _invalid_auth_context("auth context file must contain a JSON object")
    else:
        raw = {
            "server": os.getenv("TJWATER_SERVER"),
            "access_token": os.getenv("TJWATER_ACCESS_TOKEN"),
            "project_id": os.getenv("TJWATER_PROJECT_ID"),
            "user_id": os.getenv("TJWATER_USER_ID"),
            "username": os.getenv("TJWATER_USERNAME"),
            "network": os.getenv("TJWATER_NETWORK"),
            "headers": _load_extra_headers_env(),
        }

    headers = raw.get("headers") or {}
    if not isinstance(headers, dict):
        raise _invalid_auth_context("auth context headers must be a JSON object")

    return AuthContext(
        server=_as_optional_str(_pick(raw, "server", "base_url")),
        access_token=_as_optional_str(_pick(raw, "access_token", "token", "accessToken")),
        project_id=_as_optional_str(_pick(raw, "project_id", "projectId", "x_project_id")),
        user_id=_as_optional_str(_pick(raw, "user_id", "userId", "x_user_id")),
        username=_as_optional_str(_pick(raw, "username", "preferred_username")),
        network=_as_optional_str(_pick(raw, "network", "project_code", "projectCode", "project")),
        headers={str(key): str(value) for key, value in headers.items()},
    )
|
|
|
|
|
|
def build_runtime_context(
    *,
    server: str | None,
    auth_context_path: Path | None,
    scheme: str | None,
    timeout: int,
    request_id: str | None,
) -> RuntimeContext:
    """Create the immutable per-invocation context shared by all commands.

    Resolution rules (empty strings count as missing):
      * server: explicit ``server`` argument, else the auth context's server,
        else ``DEFAULT_SERVER``;
      * request_id: the given value, else a freshly generated UUID4 string.
    """
    auth = load_auth_context(auth_context_path)

    if request_id:
        effective_request_id = request_id
    else:
        effective_request_id = str(uuid.uuid4())

    if server:
        effective_server = server
    elif auth.server:
        effective_server = auth.server
    else:
        effective_server = DEFAULT_SERVER

    return RuntimeContext(
        server=effective_server,
        auth=auth,
        scheme=scheme,
        timeout=timeout,
        request_id=effective_request_id,
    )
|
|
|
|
|
|
def require_server(ctx: RuntimeContext) -> str:
    """Return the context's server URL with any trailing slashes removed.

    Raises:
        CLIError: exit code 3 (``SERVER_REQUIRED``) when no server URL is set.
    """
    base_url = ctx.server
    if not base_url:
        raise CLIError(
            "认证失败",
            code="SERVER_REQUIRED",
            message="missing server URL; use --server or include server in auth context",
            exit_code=3,
        )
    return base_url.rstrip("/")
|
|
|
|
|
|
def require_access_token(ctx: RuntimeContext) -> str:
    """Return the bearer token from the auth context.

    Raises:
        CLIError: exit code 3 (``UNAUTHENTICATED``) when the token is missing or empty.
    """
    token = ctx.auth.access_token
    if not token:
        raise CLIError(
            "认证失败",
            code="UNAUTHENTICATED",
            message="missing access token for agent context",
            exit_code=3,
            next_commands=["tjwater-cli <command> --auth-context /path/to/auth-context.json"],
        )
    return token
|
|
|
|
|
|
def require_project_id(ctx: RuntimeContext) -> str:
    """Return the project id from the auth context.

    Raises:
        CLIError: exit code 3 (``PROJECT_CONTEXT_REQUIRED``) when it is missing or empty.
    """
    project = ctx.auth.project_id
    if not project:
        raise CLIError(
            "认证失败",
            code="PROJECT_CONTEXT_REQUIRED",
            message="missing project_id for agent context",
            exit_code=3,
            next_commands=["add project_id to the auth context file"],
        )
    return project
|
|
|
|
|
|
def require_network(ctx: RuntimeContext) -> str:
    """Return the network name needed by legacy network-based endpoints.

    Raises:
        CLIError: exit code 3 (``NETWORK_CONTEXT_REQUIRED``) when it is missing or empty.
    """
    network_name = ctx.auth.network
    if not network_name:
        raise CLIError(
            "认证失败",
            code="NETWORK_CONTEXT_REQUIRED",
            message="missing network in auth context for legacy network-based endpoints",
            exit_code=3,
            next_commands=["add network to the auth context file"],
        )
    return network_name
|
|
|
|
|
|
def require_username(ctx: RuntimeContext) -> str:
    """Return the username from the auth context.

    Raises:
        CLIError: exit code 3 (``USERNAME_CONTEXT_REQUIRED``) when it is missing or empty.
    """
    user = ctx.auth.username
    if not user:
        raise CLIError(
            "认证失败",
            code="USERNAME_CONTEXT_REQUIRED",
            message="missing username in auth context",
            exit_code=3,
            next_commands=["add username to the auth context file"],
        )
    return user
|
|
|
|
|
|
def resolve_scheme(ctx: RuntimeContext, explicit_scheme: str | None, *, required: bool = False) -> str | None:
    """Pick the scheme for a command: the explicit value, else the context default.

    An empty ``explicit_scheme`` falls back to ``ctx.scheme``.

    Raises:
        CLIError: exit code 2 (``SCHEME_REQUIRED``) when ``required`` is true and no
            scheme could be resolved.
    """
    resolved = explicit_scheme if explicit_scheme else ctx.scheme
    if resolved or not required:
        return resolved
    raise CLIError(
        "CLI 参数错误",
        code="SCHEME_REQUIRED",
        message="missing scheme; use --scheme",
        exit_code=2,
    )
|
|
|
|
|
|
def parse_time_with_timezone(value: str, *, option_name: str) -> datetime:
    """Parse an ISO 8601 / RFC 3339 timestamp that must carry a timezone offset.

    Surrounding whitespace is ignored. A trailing ``Z``/``z`` UTC designator, which
    is standard in RFC 3339, is rewritten to ``+00:00`` before parsing, because
    ``datetime.fromisoformat`` only accepts it from Python 3.11 onwards.

    Args:
        value: timestamp text, e.g. ``2024-05-01T08:00:00+08:00`` or
            ``2024-05-01T00:00:00Z``.
        option_name: CLI option name quoted in error messages.

    Returns:
        A timezone-aware ``datetime``.

    Raises:
        CLIError: exit code 2 with ``INVALID_TIME`` when the text cannot be parsed,
            or ``TIMEZONE_REQUIRED`` when it has no explicit offset.
    """
    text = value.strip()
    if text[-1:] in ("Z", "z"):
        text = text[:-1] + "+00:00"
    try:
        parsed = datetime.fromisoformat(text)
    except ValueError as exc:
        raise CLIError(
            "CLI 参数错误",
            code="INVALID_TIME",
            message=f"{option_name} must be a valid ISO 8601 / RFC 3339 timestamp",
            exit_code=2,
        ) from exc
    if parsed.tzinfo is None:
        raise CLIError(
            "CLI 参数错误",
            code="TIMEZONE_REQUIRED",
            message=f"{option_name} must include an explicit timezone offset",
            exit_code=2,
        )
    return parsed
|
|
|
|
|
|
def read_json_input(path: Path, *, label: str) -> Any:
    """Read and decode a user-supplied UTF-8 JSON input file.

    Args:
        path: file to read.
        label: short name of the input (e.g. ``"burst"``) used in error messages.

    Returns:
        The decoded JSON value (any JSON type).

    Raises:
        CLIError: exit code 2 with ``INPUT_NOT_FOUND`` when the file is missing,
            ``INPUT_UNREADABLE`` when it cannot be read (directory, permissions, …),
            or ``INPUT_INVALID_JSON`` when it is not UTF-8 encoded valid JSON.
    """
    try:
        text = path.read_text(encoding="utf-8")
    except FileNotFoundError as exc:
        raise CLIError(
            "CLI 参数错误",
            code="INPUT_NOT_FOUND",
            message=f"{label} file not found: {path}",
            exit_code=2,
        ) from exc
    except UnicodeDecodeError as exc:
        raise CLIError(
            "CLI 参数错误",
            code="INPUT_INVALID_JSON",
            message=f"{label} file must be valid UTF-8 encoded JSON: {path}",
            exit_code=2,
        ) from exc
    except OSError as exc:
        # Previously escaped as a raw traceback instead of a structured error.
        raise CLIError(
            "CLI 参数错误",
            code="INPUT_UNREADABLE",
            message=f"{label} file could not be read: {path} ({exc})",
            exit_code=2,
        ) from exc

    try:
        return json.loads(text)
    except json.JSONDecodeError as exc:
        raise CLIError(
            "CLI 参数错误",
            code="INPUT_INVALID_JSON",
            message=f"{label} file must be valid JSON: {path}",
            exit_code=2,
        ) from exc
|
|
|
|
|
|
def _invalid_burst_file(message: str) -> CLIError:
    """Build the BURST_FILE_INVALID argument error (exit code 2)."""
    return CLIError("CLI 参数错误", code="BURST_FILE_INVALID", message=message, exit_code=2)


def parse_burst_file(path: Path) -> tuple[list[str], list[float]]:
    """Read a burst definition file and return parallel ``(ids, sizes)`` lists.

    Accepted shapes, optionally wrapped as ``{"bursts": <shape>}``:

    * columnar: ``{"burst_ID": [...], "burst_size": [...]}`` with equal lengths;
    * row list: ``[{"id": ..., "size": ...}, ...]``.

    IDs are converted with ``str`` and sizes with ``float``.

    Raises:
        CLIError: exit code 2 (``BURST_FILE_INVALID``) for any other shape,
            non-array columns, mismatched lengths, or non-numeric sizes; plus
            whatever ``read_json_input`` raises.
    """
    raw = read_json_input(path, label="burst")
    if isinstance(raw, dict) and "bursts" in raw:
        raw = raw["bursts"]

    if isinstance(raw, dict) and "burst_ID" in raw and "burst_size" in raw:
        raw_ids, raw_sizes = raw["burst_ID"], raw["burst_size"]
        if not isinstance(raw_ids, list) or not isinstance(raw_sizes, list):
            # A JSON string would otherwise be iterated character by character.
            raise _invalid_burst_file("burst file burst_ID and burst_size must be JSON arrays")
        if len(raw_ids) != len(raw_sizes):
            raise _invalid_burst_file("burst file burst_ID and burst_size must have the same length")
        try:
            column_sizes = [float(item) for item in raw_sizes]
        except (TypeError, ValueError) as exc:
            raise _invalid_burst_file("burst file burst_size values must be numbers") from exc
        return [str(item) for item in raw_ids], column_sizes

    if isinstance(raw, list):
        row_ids: list[str] = []
        row_sizes: list[float] = []
        for item in raw:
            if not isinstance(item, dict) or "id" not in item or "size" not in item:
                raise _invalid_burst_file("burst file items must contain id and size")
            try:
                size = float(item["size"])
            except (TypeError, ValueError) as exc:
                raise _invalid_burst_file(f"burst file size must be a number: {item['size']!r}") from exc
            row_ids.append(str(item["id"]))
            row_sizes.append(size)
        return row_ids, row_sizes

    raise _invalid_burst_file("burst file must be a JSON array or object with burst_ID/burst_size")
|
|
|
|
|
|
def _invalid_valve_setting(message: str) -> CLIError:
    """Build the VALVE_SETTING_INVALID argument error (exit code 2)."""
    return CLIError("CLI 参数错误", code="VALVE_SETTING_INVALID", message=message, exit_code=2)


def parse_valve_setting_file(path: Path) -> tuple[list[str], list[float]]:
    """Read a valve-setting file and return parallel ``(valves, openings)`` lists.

    Accepted shapes:

    * columnar: ``{"valves": [...], "valves_k": [...]}`` with equal lengths;
    * row list: ``[{"valve": ..., "opening": ...}, ...]``.

    Valve ids are converted with ``str`` and openings with ``float``. No range
    check is applied to openings.

    Raises:
        CLIError: exit code 2 (``VALVE_SETTING_INVALID``) for any other shape,
            non-array columns, mismatched lengths, or non-numeric openings; plus
            whatever ``read_json_input`` raises.
    """
    raw = read_json_input(path, label="valve-setting")

    if isinstance(raw, dict) and "valves" in raw and "valves_k" in raw:
        raw_valves, raw_openings = raw["valves"], raw["valves_k"]
        if not isinstance(raw_valves, list) or not isinstance(raw_openings, list):
            # A JSON string would otherwise be iterated character by character.
            raise _invalid_valve_setting("valves and valves_k must be JSON arrays")
        if len(raw_valves) != len(raw_openings):
            raise _invalid_valve_setting("valves and valves_k must have the same length")
        try:
            column_openings = [float(item) for item in raw_openings]
        except (TypeError, ValueError) as exc:
            raise _invalid_valve_setting("valves_k values must be numbers") from exc
        return [str(item) for item in raw_valves], column_openings

    if isinstance(raw, list):
        row_valves: list[str] = []
        row_openings: list[float] = []
        for item in raw:
            if not isinstance(item, dict) or "valve" not in item or "opening" not in item:
                raise _invalid_valve_setting("valve-setting items must contain valve and opening")
            try:
                opening = float(item["opening"])
            except (TypeError, ValueError) as exc:
                raise _invalid_valve_setting(
                    f"valve-setting opening must be a number: {item['opening']!r}"
                ) from exc
            row_valves.append(str(item["valve"]))
            row_openings.append(opening)
        return row_valves, row_openings

    raise _invalid_valve_setting("valve-setting file must be a JSON array or object with valves/valves_k")
|
|
|
|
|
|
def parse_optional_dataset_file(path: Path | None, *, label: str) -> Any:
    """Return the decoded JSON content of ``path``, or ``None`` when no path was given.

    ``label`` names the input in error messages raised by the JSON reader.
    """
    return None if path is None else read_json_input(path, label=label)
|
|
|
|
|
|
def build_headers(
    ctx: RuntimeContext,
    *,
    require_auth: bool,
    require_project: bool,
) -> dict[str, str]:
    """Compose the HTTP headers for an API call.

    Custom headers from the auth context can override ``Accept`` and
    ``X-Request-Id``. ``Authorization`` / ``X-Project-Id`` / ``X-User-Id`` are
    written afterwards whenever a value is available, so they take precedence.
    With ``require_auth`` / ``require_project`` set, a missing token / project id
    raises ``CLIError`` (token is checked first) instead of being skipped.
    """
    auth = ctx.auth
    merged: dict[str, str] = {
        "Accept": "application/json, text/plain, */*",
        "X-Request-Id": ctx.request_id,
        **auth.headers,
    }

    token = require_access_token(ctx) if require_auth else auth.access_token
    if token:
        merged["Authorization"] = f"Bearer {token}"

    project = require_project_id(ctx) if require_project else auth.project_id
    if project:
        merged["X-Project-Id"] = project

    if auth.user_id:
        merged["X-User-Id"] = auth.user_id
    return merged
|
|
|
|
|
|
def _extract_error_message(response: requests.Response) -> str:
|
|
try:
|
|
payload = response.json()
|
|
except ValueError:
|
|
text = response.text.strip()
|
|
return text or f"http {response.status_code}"
|
|
|
|
if isinstance(payload, dict):
|
|
detail = payload.get("detail")
|
|
if isinstance(detail, str):
|
|
return detail
|
|
if isinstance(detail, list):
|
|
return "; ".join(json.dumps(item, ensure_ascii=False) for item in detail)
|
|
message = payload.get("message")
|
|
if isinstance(message, str):
|
|
return message
|
|
return json.dumps(payload, ensure_ascii=False)
|
|
|
|
|
|
def map_http_status_to_exit_code(status_code: int) -> int:
    """Translate an HTTP status code into the CLI's process exit code.

    2 = bad request / validation, 3 = unauthenticated, 4 = forbidden,
    5 = not found, 6 = conflict / precondition failed, 7 = anything else.
    """
    exit_codes = {400: 2, 422: 2, 401: 3, 403: 4, 404: 5, 409: 6, 412: 6}
    return exit_codes.get(status_code, 7)
|
|
|
|
|
|
def _parse_response_body(response: requests.Response) -> Any:
|
|
if response.status_code == 204 or not response.content:
|
|
return {}
|
|
content_type = response.headers.get("content-type", "").lower()
|
|
if "application/json" in content_type:
|
|
payload = response.json()
|
|
if isinstance(payload, dict) and payload.get("status") == "error":
|
|
raise CLIError(
|
|
"服务端错误",
|
|
code="SERVER_ERROR",
|
|
message=str(payload.get("message") or "server returned error status"),
|
|
exit_code=7,
|
|
data=payload,
|
|
)
|
|
return payload
|
|
text = response.text
|
|
if text:
|
|
return {"report": text}
|
|
return {}
|
|
|
|
|
|
def request_json(
    ctx: RuntimeContext,
    *,
    method: str,
    path: str,
    params: dict[str, Any] | None = None,
    json_body: Any = None,
    require_auth: bool = True,
    require_project: bool = False,
    require_network_ctx: bool = False,
    require_username_ctx: bool = False,
) -> tuple[Any, int]:
    """Call ``<server>/api/v1<path>`` and return ``(decoded_body, duration_ms)``.

    Context prerequisites are checked before any network traffic: the server URL
    always, network and username only when the matching flag is set.
    ``duration_ms`` covers only the HTTP round trip.

    Raises:
        CLIError: for missing context values, timeouts (``REQUEST_TIMEOUT``),
            transport failures (``REQUEST_FAILED``), non-2xx responses
            (``HTTP_<status>``, retryable for 5xx), and errors raised while
            decoding the body.
    """
    base_url = require_server(ctx)
    if require_network_ctx:
        require_network(ctx)
    if require_username_ctx:
        require_username(ctx)

    request_headers = build_headers(ctx, require_auth=require_auth, require_project=require_project)
    t0 = time.monotonic()
    try:
        response = requests.request(
            method=method.upper(),
            url=f"{base_url}/api/v1{path}",
            params=params,
            json=json_body,
            headers=request_headers,
            timeout=ctx.timeout,
        )
    except requests.Timeout as exc:
        raise CLIError(
            "请求超时",
            code="REQUEST_TIMEOUT",
            message=f"request timed out after {ctx.timeout} seconds",
            exit_code=7,
            retryable=True,
        ) from exc
    except requests.RequestException as exc:
        raise CLIError(
            "连接失败",
            code="REQUEST_FAILED",
            message=str(exc),
            exit_code=7,
            retryable=True,
        ) from exc
    elapsed_ms = int((time.monotonic() - t0) * 1000)

    if response.ok:
        return _parse_response_body(response), elapsed_ms

    status = response.status_code
    raise CLIError(
        "请求失败",
        code=f"HTTP_{status}",
        message=_extract_error_message(response),
        exit_code=map_http_status_to_exit_code(status),
        retryable=status >= 500,
    )
|
|
|
|
|
|
def request_bytes(
    ctx: RuntimeContext,
    *,
    method: str,
    path: str,
    params: dict[str, Any] | None = None,
    require_auth: bool = True,
    require_project: bool = False,
    require_network_ctx: bool = False,
) -> tuple[bytes, int]:
    """Call ``<server>/api/v1<path>`` and return ``(raw_body_bytes, duration_ms)``.

    The body is returned untouched, without content-type inspection. The server
    URL is always validated, and the network only when ``require_network_ctx`` is
    set. ``duration_ms`` covers only the HTTP round trip.

    Raises:
        CLIError: for missing context values, timeouts (``REQUEST_TIMEOUT``),
            transport failures (``REQUEST_FAILED``), and non-2xx responses
            (``HTTP_<status>``, retryable for 5xx).
    """
    base_url = require_server(ctx)
    if require_network_ctx:
        require_network(ctx)

    request_headers = build_headers(ctx, require_auth=require_auth, require_project=require_project)
    t0 = time.monotonic()
    try:
        response = requests.request(
            method=method.upper(),
            url=f"{base_url}/api/v1{path}",
            params=params,
            headers=request_headers,
            timeout=ctx.timeout,
        )
    except requests.Timeout as exc:
        raise CLIError(
            "请求超时",
            code="REQUEST_TIMEOUT",
            message=f"request timed out after {ctx.timeout} seconds",
            exit_code=7,
            retryable=True,
        ) from exc
    except requests.RequestException as exc:
        raise CLIError(
            "连接失败",
            code="REQUEST_FAILED",
            message=str(exc),
            exit_code=7,
            retryable=True,
        ) from exc
    elapsed_ms = int((time.monotonic() - t0) * 1000)

    if response.ok:
        return response.content, elapsed_ms

    status = response.status_code
    raise CLIError(
        "请求失败",
        code=f"HTTP_{status}",
        message=_extract_error_message(response),
        exit_code=map_http_status_to_exit_code(status),
        retryable=status >= 500,
    )
|
|
|
|
|
|
def build_success_payload(
    *,
    summary: str,
    data: Any,
    server: str | None,
    request_id: str,
    duration_ms: int,
    next_commands: list[str] | None = None,
) -> dict[str, Any]:
    """Assemble the ``ok: true`` JSON envelope for a successful command.

    ``metadata.generated_at`` is the current UTC time at second precision with a
    ``Z`` suffix. A missing or empty ``next_commands`` becomes ``[]``.
    """
    metadata = {
        "request_id": request_id,
        "server": server,
        "duration_ms": duration_ms,
        "generated_at": datetime.now(timezone.utc).isoformat(timespec="seconds").replace("+00:00", "Z"),
    }
    return {
        "ok": True,
        "schema_version": SCHEMA_VERSION,
        "summary": summary,
        "data": data,
        "metadata": metadata,
        "next_commands": next_commands or [],
    }
|
|
|
|
|
|
def build_failure_payload(
    *,
    summary: str,
    code: str,
    message: str,
    retryable: bool,
    server: str | None,
    request_id: str | None,
    next_commands: list[str] | None = None,
    data: Any = None,
) -> dict[str, Any]:
    """Assemble the ``ok: false`` JSON envelope describing a failed command.

    ``metadata`` carries the request id, server and the current UTC time at
    second precision with a ``Z`` suffix; it has no duration. A missing or empty
    ``next_commands`` becomes ``[]``.
    """
    error_info = {"code": code, "message": message, "retryable": retryable}
    metadata = {
        "request_id": request_id,
        "server": server,
        "generated_at": datetime.now(timezone.utc).isoformat(timespec="seconds").replace("+00:00", "Z"),
    }
    return {
        "ok": False,
        "schema_version": SCHEMA_VERSION,
        "summary": summary,
        "error": error_info,
        "data": data,
        "metadata": metadata,
        "next_commands": next_commands or [],
    }
|
|
|
|
|
|
def emit_success(
    *,
    summary: str,
    data: Any,
    ctx: RuntimeContext,
    duration_ms: int,
    next_commands: list[str] | None = None,
) -> None:
    """Print the success envelope for ``ctx`` as a single JSON document via ``typer.echo``.

    Non-ASCII text (e.g. Chinese summaries) is emitted as-is rather than escaped.
    """
    envelope = build_success_payload(
        summary=summary,
        data=data,
        server=ctx.server,
        request_id=ctx.request_id,
        duration_ms=duration_ms,
        next_commands=next_commands,
    )
    typer.echo(json.dumps(envelope, ensure_ascii=False))
|
|
|
|
|
|
def emit_failure(
    *,
    summary: str,
    code: str,
    message: str,
    exit_code: int,
    retryable: bool,
    server: str | None,
    request_id: str | None,
    next_commands: list[str] | None = None,
    data: Any = None,
) -> int:
    """Print the failure envelope as a single JSON document and hand back ``exit_code``.

    Non-ASCII text is emitted as-is rather than escaped. Presumably the caller uses
    the returned value as the process exit status — confirm at the call sites.
    """
    envelope = build_failure_payload(
        summary=summary,
        code=code,
        message=message,
        retryable=retryable,
        server=server,
        request_id=request_id,
        next_commands=next_commands,
        data=data,
    )
    typer.echo(json.dumps(envelope, ensure_ascii=False))
    return exit_code
|