import pandas as pd import numpy as np import matplotlib.pyplot as plt import os ID_LIKE_COLUMNS = { "id", "device_id", "node_id", "sensor_id", "monitor_id", "junction_id", } def _normalize_time_frame(data: pd.DataFrame) -> pd.DataFrame: """返回按时间排序的副本,并尽量将 time 列解析为时间类型。""" data = data.copy() if "time" in data.columns: data["time"] = pd.to_datetime(data["time"], errors="coerce") data = data.sort_values(["time"]).reset_index(drop=True) return data def _select_pressure_columns(data: pd.DataFrame) -> tuple[list[str], list[str]]: """区分需要清洗的数值列与需要原样保留的列。""" value_cols: list[str] = [] keep_cols: list[str] = [] for col in data.columns: if col == "time": continue col_key = col.lower() if col_key in ID_LIKE_COLUMNS or col_key.endswith("_id"): keep_cols.append(col) continue numeric = pd.to_numeric(data[col], errors="coerce") if numeric.notna().sum() == 0 or numeric.nunique(dropna=True) <= 1: keep_cols.append(col) else: value_cols.append(col) return value_cols, keep_cols def _robust_scale(values: pd.Series) -> float: """基于 MAD 计算稳健尺度。""" series = pd.to_numeric(values, errors="coerce").dropna() if series.empty: return 1.0 median = series.median() mad = (series - median).abs().median() if pd.notna(mad) and mad > 0: return float(1.4826 * mad) iqr = series.quantile(0.75) - series.quantile(0.25) if pd.notna(iqr) and iqr > 0: return float(iqr / 1.349) std = series.std() if pd.notna(std) and std > 0: return float(std) return 1.0 def _shrink_toward_baseline(observed: float, baseline: float, scale: float) -> float: """把观测值向基线值收缩,scale 越小,修复越强。""" if pd.isna(observed): return baseline if pd.isna(baseline): return observed diff = observed - baseline weight = scale / (abs(diff) + scale) return float(baseline + diff * weight) def _infer_time_frequency(time_values: pd.Series | pd.Index) -> pd.Timedelta: """从时间序列中推断采样频率,失败时默认 15 分钟。""" parsed = pd.to_datetime(pd.Series(time_values), errors="coerce").dropna().sort_values() if len(parsed) < 2: return pd.Timedelta(minutes=15) diffs = parsed.diff().dropna() diffs = diffs[diffs > pd.Timedelta(0)] if diffs.empty: return pd.Timedelta(minutes=15) mode = diffs.mode() return mode.iloc[0] if not mode.empty else diffs.median() def _build_local_pressure_baseline(series: pd.Series) -> pd.Series: """基于局部插值与中值滤波构造平滑基线。""" baseline = _safe_time_interpolate(series) baseline = baseline.rolling(window=5, center=True, min_periods=1).median() baseline = _safe_time_interpolate(baseline) return baseline.ffill().bfill() def _build_seasonal_pressure_baseline(series: pd.Series) -> pd.Series: """按一天内的同一时刻构造季节性基线,适合日周期压力数据。""" if not isinstance(series.index, pd.DatetimeIndex): return pd.Series(np.nan, index=series.index, dtype=float) slot_labels = pd.Series(series.index.strftime("%H:%M:%S"), index=series.index) return series.groupby(slot_labels).transform("median") def _detect_pressure_spikes(series: pd.Series, local_baseline: pd.Series) -> pd.Series: """识别单点异常上升/下降尖峰,避免过度修正正常波动。""" residual = series - local_baseline neighbor_center = (series.shift(1) + series.shift(-1)) / 2 curvature = series - neighbor_center residual_scale = max(_robust_scale(residual), 1e-6) curvature_scale = max(_robust_scale(curvature), 1e-6) direction_flip = ((series - series.shift(1)) * (series.shift(-1) - series) < 0).fillna(False) return ( residual.abs() > 3.5 * residual_scale ) & ( curvature.abs() > 3.0 * curvature_scale ) & direction_flip def _fill_pressure_gaps( original: pd.Series, repaired: pd.Series, local_baseline: pd.Series, seasonal_baseline: pd.Series, ) -> pd.Series: """短缺口用局部插值,长缺口优先使用同一时刻的季节性轨迹。""" missing_mask = original.isna() if not missing_mask.any(): return repaired gap_groups = (missing_mask != missing_mask.shift(fill_value=False)).cumsum() gap_lengths = missing_mask.groupby(gap_groups).transform("sum").where(missing_mask, 0) filled = repaired.copy() short_gap_mask = missing_mask & (gap_lengths < 4) long_gap_mask = missing_mask & ~short_gap_mask filled[short_gap_mask] = local_baseline[short_gap_mask] long_gap_fill = seasonal_baseline.where(seasonal_baseline.notna(), local_baseline) filled[long_gap_mask] = long_gap_fill[long_gap_mask] return filled def _clean_pressure_series(series: pd.Series) -> pd.Series: """清洗单个压力时间序列。""" series = pd.to_numeric(series, errors="coerce").astype(float) local_baseline = _build_local_pressure_baseline(series) spike_mask = _detect_pressure_spikes(series, local_baseline) repaired = series.copy() repaired[spike_mask] = local_baseline[spike_mask] seasonal_baseline = _build_seasonal_pressure_baseline(repaired) repaired = _fill_pressure_gaps(series, repaired, local_baseline, seasonal_baseline) if repaired.isna().any(): repaired = repaired.where(repaired.notna(), local_baseline) return repaired.ffill().bfill() def _format_time_column(data: pd.DataFrame) -> pd.DataFrame: """统一输出时间格式,方便下游直接按 ISO 字符串解析。""" if "time" not in data.columns: return data formatted = data.copy() time_values = pd.to_datetime(formatted["time"], errors="coerce") if time_values.isna().all(): return formatted if time_values.dt.tz is not None: time_strings = time_values.dt.strftime("%Y-%m-%dT%H:%M:%S%z") time_strings = time_strings.str.replace( r"([+-]\d{2})(\d{2})$", r"\1:\2", regex=True, ) else: time_strings = time_values.dt.strftime("%Y-%m-%dT%H:%M:%S") formatted["time"] = time_strings.where(time_values.notna(), formatted["time"]) return formatted def _expand_snapshot_time_grid(data: pd.DataFrame, freq: pd.Timedelta) -> pd.DataFrame: """仅补齐时间轴,不提前填充值,避免长缺口丢失原始形状特征。""" expanded = data.copy() expanded["time"] = pd.to_datetime(expanded["time"], errors="coerce") expanded = expanded.dropna(subset=["time"]).sort_values("time") if expanded.empty: return data indexed = expanded.set_index("time") full_index = pd.date_range(indexed.index.min(), indexed.index.max(), freq=freq) indexed = indexed.reindex(full_index) indexed.index.name = "time" return indexed.reset_index() def _safe_datetime_index(values: pd.Series | pd.Index | list[object]) -> pd.DatetimeIndex | None: """尽量把时间值标准化为 DatetimeIndex;失败则返回 None。""" parsed = pd.to_datetime(values, errors="coerce") try: datetime_index = pd.DatetimeIndex(parsed) except (TypeError, ValueError): return None if datetime_index.isna().all(): return None return datetime_index def _safe_time_interpolate(series: pd.Series) -> pd.Series: """仅在索引确实是 DatetimeIndex 时使用 time interpolation。""" if isinstance(series.index, pd.DatetimeIndex): return series.interpolate(method="time", limit_direction="both") return series.interpolate(limit_direction="both") def _detect_long_form_identifier(data: pd.DataFrame, value_cols: list[str], keep_cols: list[str]) -> str | None: """识别 time/id/value 长表结构。""" if "time" not in data.columns or len(value_cols) != 1: return None identifier_candidates = [ col for col in keep_cols if col.lower() in ID_LIKE_COLUMNS or col.lower().endswith("_id") ] if len(identifier_candidates) != 1: return None if not data["time"].duplicated().any(): return None return identifier_candidates[0] def _clean_long_form_pressure( data: pd.DataFrame, value_col: str, identifier_col: str, keep_cols: list[str], fill_gaps: bool, ) -> pd.DataFrame: """按测点拆分 long-form 压力数据,再逐列清洗后恢复原结构。""" data = _normalize_time_frame(data) wide_df = ( data[[identifier_col, "time", value_col]] .pivot(index="time", columns=identifier_col, values=value_col) .reset_index() ) sensor_cols = [col for col in wide_df.columns if col != "time"] cleaned_wide = _clean_snapshot_pressure(wide_df, sensor_cols, keep_cols=[], fill_gaps=fill_gaps) cleaned_long = cleaned_wide.melt( id_vars="time", var_name=identifier_col, value_name=value_col, ) passthrough_cols = [col for col in keep_cols if col != identifier_col] if passthrough_cols: metadata = data[[identifier_col] + passthrough_cols].drop_duplicates(subset=[identifier_col]) cleaned_long = cleaned_long.merge(metadata, on=identifier_col, how="left") try: cleaned_long[identifier_col] = cleaned_long[identifier_col].astype(data[identifier_col].dtype) except (TypeError, ValueError): pass cleaned_long = cleaned_long.sort_values(["time", identifier_col]).reset_index(drop=True) ordered_cols = ["time", identifier_col] + passthrough_cols + [value_col] cleaned_long = cleaned_long[[col for col in ordered_cols if col in cleaned_long.columns]] return cleaned_long def _build_time_slot_frame( data: pd.DataFrame, value_col: str, expected_slots: int ) -> pd.DataFrame: """把重复时间点整理成 time x slot 的矩阵。""" grouped = data.groupby("time", sort=True) times = list(grouped.groups.keys()) slot_frame = pd.DataFrame(index=pd.Index(times, name="time"), columns=range(expected_slots), dtype=float) for time_value, group in grouped: values = pd.to_numeric(group[value_col], errors="coerce").tolist() for slot_idx, value in enumerate(values[:expected_slots]): slot_frame.loc[time_value, slot_idx] = value return slot_frame def _slot_baseline(slot_frame: pd.DataFrame) -> pd.DataFrame: """对每个槽位做时间插值和平滑,得到基线轨迹。""" baseline = pd.DataFrame(index=slot_frame.index, columns=slot_frame.columns, dtype=float) for col in slot_frame.columns: series = slot_frame[col].astype(float) series = _safe_time_interpolate(series) series = series.rolling(window=5, center=True, min_periods=1).median() series = _safe_time_interpolate(series).ffill().bfill() baseline[col] = series return baseline def _choose_insertion_position( observed: list[float], baseline_row: pd.Series, expected_slots: int ) -> int: """为少一个观测值的时间组选择最合理的插入位置。""" missing_count = expected_slots - len(observed) if missing_count <= 0: return 0 best_pos = 0 best_cost = float("inf") for insert_pos in range(expected_slots): cost = 0.0 obs_idx = 0 for slot_idx in range(expected_slots): if slot_idx == insert_pos: continue obs_value = observed[obs_idx] base_value = float(baseline_row.iloc[slot_idx]) if pd.notna(obs_value) and pd.notna(base_value): cost += abs(obs_value - base_value) obs_idx += 1 if cost < best_cost: best_cost = cost best_pos = insert_pos return best_pos def _clean_repeated_timestamp_pressure( data: pd.DataFrame, value_col: str, keep_cols: list[str] ) -> pd.DataFrame: """针对同一时间点重复采样的压力数据进行修复。""" data = _normalize_time_frame(data) grouped_sizes = data.groupby("time").size() if grouped_sizes.empty: return data expected_slots = int(grouped_sizes.mode().iloc[0]) if not grouped_sizes.mode().empty else int(grouped_sizes.max()) expected_slots = max(expected_slots, int(grouped_sizes.max())) slot_frame = _build_time_slot_frame(data, value_col, expected_slots) baseline_frame = _slot_baseline(slot_frame) residuals = slot_frame - baseline_frame slot_scales = { col: max(_robust_scale(residuals[col]), 1e-6) for col in residuals.columns } cleaned_rows: list[dict[str, object]] = [] grouped = data.groupby("time", sort=True) for time_value, group in grouped: observed_values = pd.to_numeric(group[value_col], errors="coerce").tolist() baseline_row = baseline_frame.loc[time_value] insert_pos = _choose_insertion_position(observed_values, baseline_row, expected_slots) cleaned_values: list[float] = [] obs_idx = 0 for slot_idx in range(expected_slots): if slot_idx == insert_pos and len(observed_values) < expected_slots: cleaned_values.append(float(baseline_row.iloc[slot_idx])) continue if obs_idx >= len(observed_values): cleaned_values.append(float(baseline_row.iloc[slot_idx])) continue observed = observed_values[obs_idx] baseline = float(baseline_row.iloc[slot_idx]) cleaned_values.append( _shrink_toward_baseline(observed, baseline, slot_scales.get(slot_idx, 1.0)) ) obs_idx += 1 # 其余字段原样保留;常量列(如 id)直接复制第一条记录即可 template_row = group.iloc[0].to_dict() for slot_idx, cleaned_value in enumerate(cleaned_values): row = dict(template_row) row["time"] = time_value row[value_col] = cleaned_value cleaned_rows.append(row) cleaned_df = pd.DataFrame(cleaned_rows) cleaned_df = cleaned_df.sort_values(["time"]).reset_index(drop=True) ordered_cols = ["time"] + keep_cols + [value_col] ordered_cols = [col for col in ordered_cols if col in cleaned_df.columns] remaining_cols = [col for col in cleaned_df.columns if col not in ordered_cols] cleaned_df = cleaned_df[ordered_cols + remaining_cols] return _format_time_column(cleaned_df) def _clean_snapshot_pressure( data: pd.DataFrame, value_cols: list[str], keep_cols: list[str], fill_gaps: bool ) -> pd.DataFrame: """针对单条时间序列或多列快照数据进行稳健修复。""" data = _normalize_time_frame(data) if fill_gaps and "time" in data.columns: freq = _infer_time_frequency(data["time"]) data = _expand_snapshot_time_grid(data, freq) data["time"] = pd.to_datetime(data["time"], errors="coerce") data = data.sort_values(["time"]).reset_index(drop=True) cleaned_df = data.copy() time_index = ( _safe_datetime_index(cleaned_df["time"]) if "time" in cleaned_df.columns else None ) if time_index is None: time_index = pd.RangeIndex(start=0, stop=len(cleaned_df)) for col in value_cols: series = pd.Series( pd.to_numeric(cleaned_df[col], errors="coerce").to_numpy(), index=time_index, dtype=float, ) cleaned_df[col] = _clean_pressure_series(series).to_numpy() ordered_cols = ["time"] + keep_cols + value_cols ordered_cols = [col for col in ordered_cols if col in cleaned_df.columns] remaining_cols = [col for col in cleaned_df.columns if col not in ordered_cols] cleaned_df = cleaned_df[ordered_cols + remaining_cols] return _format_time_column(cleaned_df) def clean_pressure_data_km( input_csv_path: str, show_plot: bool = False, fill_gaps: bool = True ) -> str: """ 读取输入 CSV,基于时间结构进行稳健修复。输出为 _cleaned.xlsx(同目录)。 原始数据在 sheet 'raw_pressure_data',处理后数据在 sheet 'cleaned_pressusre_data'。 返回输出文件的绝对路径。 Args: input_csv_path: CSV 文件路径 show_plot: 是否显示可视化 fill_gaps: 是否先补齐时间缺口(默认 True) """ # 读取 CSV input_csv_path = os.path.abspath(input_csv_path) data = pd.read_csv(input_csv_path, header=0, index_col=None, encoding="utf-8") data = _normalize_time_frame(data) value_cols, keep_cols = _select_pressure_columns(data) has_repeated_time = "time" in data.columns and data["time"].duplicated().any() identifier_col = _detect_long_form_identifier(data, value_cols, keep_cols) if identifier_col is not None: data_repaired = _clean_long_form_pressure( data, value_cols[0], identifier_col, keep_cols, fill_gaps, ) elif has_repeated_time and len(value_cols) == 1: data_repaired = _clean_repeated_timestamp_pressure(data, value_cols[0], keep_cols) else: data_repaired = _clean_snapshot_pressure(data, value_cols, keep_cols, fill_gaps) # 可选可视化(只展示首个数值列) plt.rcParams["font.sans-serif"] = ["SimHei"] plt.rcParams["axes.unicode_minus"] = False if show_plot and value_cols: plot_col = value_cols[0] if "time" in data_repaired.columns: x = pd.to_datetime(data_repaired["time"], errors="coerce") else: x = np.arange(len(data_repaired)) plt.figure(figsize=(12, 6)) plt.plot(x, pd.to_numeric(data_repaired[plot_col], errors="coerce"), label="cleaned") plt.xlabel("时间" if "time" in data_repaired.columns else "序号") plt.ylabel("压力监测值") plt.title(f"{plot_col} 清洗结果") plt.legend() plt.show() # 保存到 Excel:两个 sheet input_dir = os.path.dirname(os.path.abspath(input_csv_path)) input_base = os.path.splitext(os.path.basename(input_csv_path))[0] output_filename = f"{input_base}_cleaned.xlsx" output_path = os.path.join(input_dir, output_filename) # 如果原始数据包含时间列,将其添加回结果 data_for_save = data.copy() data_repaired_for_save = data_repaired.copy() if os.path.exists(output_path): os.remove(output_path) # 覆盖同名文件 with pd.ExcelWriter(output_path, engine="openpyxl") as writer: data_for_save.to_excel(writer, sheet_name="raw_pressure_data", index=False) data_repaired_for_save.to_excel( writer, sheet_name="cleaned_pressusre_data", index=False ) # 返回输出文件的绝对路径 return os.path.abspath(output_path) def clean_pressure_data_df_km(data: pd.DataFrame, show_plot: bool = False) -> pd.DataFrame: """ 接收一个 DataFrame 数据结构,使用时间感知的稳健修复方法清洗压力数据。 返回清洗后的 DataFrame。 Args: data: 输入 DataFrame(可包含 time 列) show_plot: 是否显示可视化 """ # 使用传入的 DataFrame data = data.copy() data = _normalize_time_frame(data) value_cols, keep_cols = _select_pressure_columns(data) has_repeated_time = "time" in data.columns and data["time"].duplicated().any() identifier_col = _detect_long_form_identifier(data, value_cols, keep_cols) if identifier_col is not None: data_repaired = _clean_long_form_pressure( data, value_cols[0], identifier_col, keep_cols, fill_gaps=True, ) elif has_repeated_time and len(value_cols) == 1: data_repaired = _clean_repeated_timestamp_pressure(data, value_cols[0], keep_cols) else: data_repaired = _clean_snapshot_pressure(data, value_cols, keep_cols, fill_gaps=True) if show_plot and value_cols: plt.rcParams["font.sans-serif"] = ["SimHei"] plt.rcParams["axes.unicode_minus"] = False plot_col = value_cols[0] x = pd.to_datetime(data_repaired["time"], errors="coerce") if "time" in data_repaired.columns else np.arange(len(data_repaired)) plt.figure(figsize=(12, 6)) plt.plot(x, pd.to_numeric(data_repaired[plot_col], errors="coerce"), label="cleaned") plt.xlabel("时间" if "time" in data_repaired.columns else "序号") plt.ylabel("压力监测值") plt.title(f"{plot_col} 清洗结果") plt.legend() plt.show() return data_repaired # 测试 # if __name__ == "__main__": # # 默认使用脚本目录下的 pressure_raw_data.csv # script_dir = os.path.dirname(os.path.abspath(__file__)) # default_csv = os.path.join(script_dir, "pressure_raw_data.csv") # out_path = clean_pressure_data_km(default_csv, show_plot=False) # print("保存路径:", out_path) # 测试 clean_pressure_data_dict_km 函数 if __name__ == "__main__": import random # 读取 szh_pressure_scada.csv 文件 script_dir = os.path.dirname(os.path.abspath(__file__)) csv_path = os.path.join(script_dir, "szh_pressure_scada.csv") data = pd.read_csv(csv_path, header=0, index_col=None, encoding="utf-8") # 排除 Time 列,随机选择 5 列 columns_to_exclude = ["Time"] available_columns = [col for col in data.columns if col not in columns_to_exclude] selected_columns = random.sample(available_columns, 5) # 将选中的列转换为字典 data_dict = {col: data[col].tolist() for col in selected_columns} print("选中的列:", selected_columns) print("原始数据长度:", len(data_dict[selected_columns[0]])) # 调用函数进行清洗 cleaned_dict = clean_pressure_data_df_km(data_dict, show_plot=True) print("清洗后的字典键:", list(cleaned_dict.keys())) print("清洗后的数据长度:", len(cleaned_dict[selected_columns[0]])) print("测试完成:函数运行正常")