# preprocess.py
  1. from database import MySqlDao, CustConfig, OrderConfig
  2. import os
  3. import pandas as pd
  4. class DataProcess():
  5. def __init__(self, city_uuid, save_dir):
  6. self._mysql_dao = MySqlDao()
  7. self.save_dir = save_dir
  8. print("正在加载cust_info...")
  9. self._cust_data = self._mysql_dao.cust_table_dao.load_data(CustConfig.FEATURES_COLUMNS, city_uuid)
  10. print("正在加载analysis_info...")
  11. self._order_data = self._mysql_dao.order_table_dao.load_data(OrderConfig.FEATURE_COLUMNS, city_uuid)
  12. def data_process(self):
  13. """数据预处理"""
  14. # train_data_save_path = os.path.join(save_dir, "train.csv")
  15. # if os.path.exists(train_data_save_path):
  16. # os.remove(train_data_save_path)
  17. self._clean_cust_data()
  18. self._clean_order_data()
  19. train_data = self._generate_train_data()
  20. # train_data.to_csv(train_data_save_path, index=False, encoding="utf-8")
  21. return train_data
  22. def _clean_cust_data(self):
  23. """用户数据清洗"""
  24. self._cust_data["cust_code"] = self._cust_data["cust_code"].astype(str)
  25. # 根据配置规则清洗数据
  26. for feature, rules, in CustConfig.CLEANING_RULES.items():
  27. if rules["type"] == "num":
  28. # 先将数值型字符串转换为数值
  29. self._cust_data[feature] = pd.to_numeric(self._cust_data[feature], errors="coerce")
  30. if rules["method"] == "fillna":
  31. if rules["opt"] == "fill":
  32. self._cust_data[feature] = self._cust_data[feature].fillna(rules["value"]).infer_objects(copy=False)
  33. elif rules["opt"] == "replace":
  34. self._cust_data[feature] = self._cust_data[feature].fillna(self._cust_data[rules["value"]]).infer_objects(copy=False)
  35. elif rules["opt"] == "mean":
  36. self._cust_data[feature] = self._cust_data[feature].fillna(self._cust_data[feature].mean()).infer_objects(copy=False)
  37. self._cust_data[feature] = self._cust_data[feature].infer_objects(copy=False)
  38. def _clean_order_data(self):
  39. self._order_data["cust_code"] = self._order_data["cust_code"].astype(str)
  40. self._order_data["product_code"] = self._order_data["product_code"].astype(str)
  41. # self._order_data[order_cols.drop(col_all_missing)] = self._order_data[order_cols.drop(col_all_missing)].fillna(0)
  42. self._order_data["order_number_stability"] = self._order_data["order_number_stability"].fillna(0)
  43. self._order_data["order_quantity_stability"] = self._order_data["order_quantity_stability"].fillna(0)
  44. self._order_data["order_ratio_stability"] = self._order_data["order_ratio_stability"].fillna(0)
  45. self._order_data["real_demand_stability"] = self._order_data["real_demand_stability"].fillna(0)
  46. self._order_data = self._order_data.infer_objects(copy=False)
  47. def _generate_train_data(self):
  48. """生成训练数据"""
  49. union_data = self._order_data.merge(self._cust_data, on="cust_code", how="inner")
  50. return union_data
  51. if __name__ == '__main__':
  52. city_uuid = "00000000000000000000000011440601"
  53. save_dir = os.path.join("./data", city_uuid)
  54. dataprocess = DataProcess(city_uuid, save_dir)
  55. train_data = dataprocess.data_process()
  56. grouped = train_data.groupby('price_tier')
  57. os.makedirs(save_dir, exist_ok=True)
  58. for price_tier, group_df in grouped:
  59. tier_str = str(price_tier)
  60. file_name = f"价位段_{tier_str}.csv"
  61. save_data = group_df.drop('price_tier', axis=1)
  62. save_data.to_csv(os.path.join(save_dir, file_name), index=False)