Browse Source

feat: 使用registry方式添加处理类

许家凯 1 năm trước cách đây
mục cha
commit
d893c19ffe

+ 1 - 1
data_clean/api/mongo_api.py

@@ -17,7 +17,7 @@ async def insert_one(collection, data):
 
 
 async def insert_many(collection, data: list):
+    """Insert several documents into ``collection`` as one unordered batch.
+
+    NOTE(review): ``db`` is defined outside this hunk; this assumes it is a
+    Motor/PyMongo database handle - confirm.
+
+    :param collection: name of the target collection.
+    :param data: list of documents to insert.
+    :return: the ``_id`` values of the inserted documents.
+    """
-    result = await db[collection].insert_many(data)
+    result = await db[collection].insert_many(data, ordered=False)
-    return result.inserted_id
-    pass
+    # The result of insert_many exposes ``inserted_ids`` (plural);
+    # ``inserted_id`` exists only on the result of insert_one.
+    return result.inserted_ids
 

+ 107 - 0
data_clean/dim_handle_registry.py

@@ -0,0 +1,107 @@
+# -*- coding: utf-8 -*-
+# @Time : 2023/7/24 11:03
+# @Author : XuJiakai
+# @File : dim_handle_registry
+# @Software: PyCharm
+
+class DimHandleRegistry:
+    """Registry of the cleaning callbacks that belong to one dimension.
+
+    A handler module creates one instance and registers callbacks on it with
+    the ``registry_*`` methods. ``execute_dim`` then runs the optional prefix
+    hook, every row function over every row, and the optional postfix hook.
+
+    Every ``registry_*`` method can be used as ``@reg.method``,
+    ``@reg.method()`` or ``reg.method(func)`` and always hands the registered
+    object back, so a decorated name is never rebound to ``None``.
+    """
+
+    def __init__(self, name=None):
+        """Create an empty registry.
+
+        :param name: label used in error messages, typically the handler file
+            name; a trailing ``.py`` is stripped. Defaults to
+            ``"DefaultRegistry"`` when omitted.
+        """
+        if name is None:
+            name = "DefaultRegistry"
+        elif name.endswith(".py"):
+            name = name[:-3]
+        self._name = name
+        # Named callbacks, including the "prefix_func" / "postfix_func" slots.
+        self._obj_list = {}
+        # Row-level coroutines, applied in registration order.
+        self._row_func = []
+
+    def __registry(self, obj, name=None):
+        """Store ``obj`` in the named-callback table.
+
+        :param obj: function or class to register.
+        :param name: key to store it under; defaults to ``obj.__name__``.
+        :raises AssertionError: if the key is already taken.
+        """
+        if name is None:
+            name = obj.__name__
+
+        # Check the key that is actually written, and raise explicitly so the
+        # guard is not stripped when Python runs with -O.
+        if name in self._obj_list:
+            raise AssertionError("{} already exists in {}".format(name, self._name))
+        self._obj_list[name] = obj
+
+    @staticmethod
+    def _as_decorator(register, obj):
+        """Apply ``register`` now, or return a decorator that applies it later.
+
+        :param register: one-argument callable that performs the registration.
+        :param obj: object to register, or ``None`` to get a decorator back.
+        :return: ``obj`` itself, or the decorator when ``obj`` is ``None``.
+        """
+        if obj is None:
+            def _no_obj_registry(func__or__class, *args, **kwargs):
+                register(func__or__class)
+                # The decorated name is rebound to this return value.
+                return func__or__class
+
+            return _no_obj_registry
+
+        register(obj)
+        return obj
+
+    def registry(self, obj=None):
+        """Register a function or class under its own ``__name__``.
+
+        :param obj: the function or class, or ``None`` to obtain a decorator.
+        :return: ``obj``, or a decorator when ``obj`` is ``None``.
+        :raises AssertionError: if that name is already registered.
+        """
+        return self._as_decorator(self.__registry, obj)
+
+    def registry_prefix_func(self, obj=None):
+        """Register the coroutine awaited once before the rows are processed.
+
+        :param obj: the coroutine function, or ``None`` to obtain a decorator.
+        :return: ``obj``, or a decorator when ``obj`` is ``None``.
+        :raises AssertionError: if a prefix function is already registered.
+        """
+        return self._as_decorator(
+            lambda func: self.__registry(func, name="prefix_func"), obj
+        )
+
+    def registry_postfix_func(self, obj=None):
+        """Register the coroutine awaited once after the rows are processed.
+
+        :param obj: the coroutine function, or ``None`` to obtain a decorator.
+        :return: ``obj``, or a decorator when ``obj`` is ``None``.
+        :raises AssertionError: if a postfix function is already registered.
+        """
+        return self._as_decorator(
+            lambda func: self.__registry(func, name="postfix_func"), obj
+        )
+
+    def registry_row_func(self, obj=None):
+        """Append a coroutine to the per-row pipeline.
+
+        :param obj: the coroutine function, or ``None`` to obtain a decorator.
+        :return: ``obj``, or a decorator when ``obj`` is ``None``.
+        """
+        return self._as_decorator(self._row_func.append, obj)
+
+    async def execute_dim(self, dim_data: list):
+        """Run the registered callbacks over one batch of rows.
+
+        The prefix function, if any, is awaited with the whole input list.
+        Each row then goes through ``_exec_row``; rows that come back as
+        ``None`` are dropped. The postfix function, if any, is awaited with
+        the surviving rows, but only when at least one row survived.
+
+        :param dim_data: rows of the dimension.
+        :return: list of rows that survived the row pipeline.
+        """
+        if "prefix_func" in self._obj_list:
+            await self._obj_list["prefix_func"](dim_data)
+
+        result_list = []
+        for row in dim_data:
+            row_data = await self._exec_row(row)
+            if row_data is not None:
+                result_list.append(row_data)
+
+        if "postfix_func" in self._obj_list and result_list:
+            await self._obj_list["postfix_func"](result_list)
+
+        return result_list
+
+    async def _exec_row(self, row_data):
+        """Pass one row through every row function in registration order.
+
+        :param row_data: the row to process.
+        :return: the row returned by the last function, or ``None`` as soon
+            as one function returns ``None``.
+        """
+        for func in self._row_func:
+            row_data = await func(row_data)
+            if row_data is None:
+                # A step rejected the row; later steps must not receive None.
+                return None
+
+        return row_data

+ 0 - 39
data_clean/dim_template_class.py

@@ -1,39 +0,0 @@
-# -*- coding: utf-8 -*-
-# @Time : 2023/7/20 17:41
-# @Author : XuJiakai
-# @File : dim_template_class
-# @Software: PyCharm
-from abc import abstractmethod
-
-
-class DimTemplateClass:
-    def __init__(self):
-        pass
-
-    @abstractmethod
-    async def _exec_row(self, row_data: dict):
-        raise NotImplementedError
-
-    async def _prefix_func(self, dim_data: list):
-        pass
-
-    async def _postfix_func(self, dim_data: list):
-        pass
-
-    async def execute_dim(self, dim_data: list):
-        await self._prefix_func(dim_data)
-
-        result_list = []
-        for row in dim_data:
-            row_data = await self._exec_row(row)
-            if row_data is not None:
-                result_list.append(row_data)
-
-        if len(result_list) > 0:
-            await self._prefix_func(result_list)
-
-        return result_list
-
-
-if __name__ == '__main__':
-    pass

+ 2 - 2
data_clean/exception/exception_handle.py

@@ -12,11 +12,11 @@ log = get_log("exception_handler")
 
 
 def exception_handle(func):
-    async def wrapper(self, *args):
+    async def wrapper(self, *args, **kwargs):
         tn = pascal_case_to_snake_case(self.__class__.__name__)
         result = None
         try:
-            result = await func(self, *args)
+            result = await func(self, *args, **kwargs)
         except (FetchException, RulerValidationException) as ex:
             log.warn("%s", ex)
             pass

+ 21 - 22
data_clean/handle/company_court_open_announcement.py

@@ -3,13 +3,33 @@
 # @Author : XuJiakai
 # @File : company_court_open_announcement
 # @Software: PyCharm
+import os
 
-from data_clean.dim_template_class import DimTemplateClass
+from data_clean.dim_handle_registry import DimHandleRegistry
 from data_clean.exception.exception_handle import exception_handle
 from data_clean.exception.ruler_validation_exception import RulerValidationException
 from data_clean.utils.str_utils import json_str_2_list
 
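+# This module-level object must be named dim_handle: task_distributor reads tmp.dim_handle from every handler module.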
+dim_handle = DimHandleRegistry(os.path.basename(__file__))
 
+
+# Registered through the called form ``registry_prefix_func()``, the same form
+# post_func uses, which hands the decorated function back to this module.
+@dim_handle.registry_prefix_func()
+@exception_handle
+async def prefix_func(dim_data: list):
+    """Prefix hook for this dimension; it only prints the incoming batch.
+
+    :param dim_data: the rows handed to ``dim_handle.execute_dim``.
+    """
+    print("前置程序:", dim_data)
+
+
+@dim_handle.registry_postfix_func()
+@exception_handle
+async def post_func(dim_data: list):
+    """Postfix hook for this dimension; it only prints a marker line.
+
+    :param dim_data: the rows passed in by the registry's postfix call.
+    """
+    print("后置程序")
+
+
+@dim_handle.registry_row_func
+@exception_handle
 async def party_intersect(row_data: dict) -> dict:
     plaintiff_info = json_str_2_list(row_data['plaintiff_info'], "name")
     defendant_info = json_str_2_list(row_data['defendant_info'], "name")
@@ -22,24 +42,3 @@ async def party_intersect(row_data: dict) -> dict:
         raise RulerValidationException("ccoa_001", "当事人有交叉:%s" % inter)
 
     pass
-
-
-validate_func = [
-    party_intersect
-]
-
-
-class CompanyCourtOpenAnnouncement(DimTemplateClass):
-
-    @exception_handle
-    async def _exec_row(self, row_data: dict):
-        for func in validate_func:
-            row_data = await func(row_data)
-            pass
-
-        return row_data
-        pass
-
-
-if __name__ == '__main__':
-    pass

+ 11 - 6
data_clean/task_distributor.py

@@ -3,12 +3,16 @@
 # @Author : XuJiakai
 # @File : task_distributor
 # @Software: PyCharm
+import os
 
-import data_clean.handle
+scan_path = os.path.join(os.path.dirname(__file__), 'handle')
 
-func_dict = {
-    "company_court_open_announcement": data_clean.handle.company_court_open_announcement.CompanyCourtOpenAnnouncement()
-}
+# Handler modules are the .py files directly inside handle/, excluding dunder
+# files such as __init__.py. Other entries (sub-directories, .pyc files, notes)
+# are skipped instead of being turned into bogus module names.
+file_name_list = [
+    os.path.splitext(file_name)[0]
+    for file_name in os.listdir(scan_path)
+    if file_name.endswith(".py") and not file_name.startswith("__")
+]
+
+# Maps handler module name -> the module's ``dim_handle`` registry object.
+class_dict = {}
+for tn in file_name_list:
+    # A non-empty fromlist makes __import__ return the leaf module rather
+    # than the top-level ``data_clean`` package.
+    tmp = __import__(f"data_clean.handle.{tn}", fromlist=[tn])
+    # Every handler module must expose ``dim_handle``; fail loudly if not.
+    class_dict[tn] = tmp.dim_handle
 
 
 async def task_distribute(data: dict):
@@ -20,8 +24,8 @@ async def task_distribute(data: dict):
     tmp_data = data['data']
 
     for key in set(tmp_data.keys()):
-        if key in func_dict:
-            result_data = await func_dict[key].execute_dim(tmp_data[key])
+        if key in class_dict:
+            result_data = await class_dict[key].execute_dim(tmp_data[key])
             if len(result_data) == 0:
                 del tmp_data[key]
             else:
@@ -37,4 +41,5 @@ async def task_distribute(data: dict):
 
 
 if __name__ == '__main__':
+    print(class_dict)
     pass

+ 2 - 2
tests/TestMain.py

@@ -49,6 +49,6 @@ async def test_for_url():
 
 
 if __name__ == '__main__':
-    asyncio.run(test_send_kafka())
-    # asyncio.run(test_for_url())
+    # asyncio.run(test_send_kafka())
+    asyncio.run(test_for_url())
     pass