# -*- coding: utf-8 -*-
# @Time : 2023/7/20 16:10
# @Author : XuJiakai
# @File : async_pool
# @Software: PyCharm
"""Concurrency-limited helpers for scheduling asyncio tasks.

``AsyncPool`` caps how many tasks created through it run at the same time.
``AsyncPoolListenShut`` does the same, additionally remembers a payload per
task, and installs SIGINT/SIGTERM handlers that stop consumption and cancel
every task on the running event loop.
"""
import asyncio
import functools
import signal
import sys
from typing import Coroutine

from loguru import logger as log


class AsyncPool(object):
    """Run coroutines as tasks with at most ``max_concurrency`` alive at once."""

    def __init__(self, max_concurrency: int):
        """
        :param max_concurrency: maximum number of tasks from this pool that may
            run simultaneously; ``create_task`` waits for a free slot.
        """
        self._semaphore: asyncio.Semaphore = asyncio.Semaphore(max_concurrency)
        # Number of tasks created through this pool that have not finished yet.
        self._running_task_num = 0

    async def create_task(self, coro: Coroutine) -> asyncio.Task:
        """Wait for a free slot, then schedule ``coro`` as a task.

        The slot is released by a done-callback when the task finishes
        (normally, with an exception, or by cancellation).

        :param coro: coroutine object to run.
        :return: the scheduled task.
        """
        await self._semaphore.acquire()
        try:
            task: asyncio.Task = asyncio.create_task(coro)
        except BaseException:
            # Scheduling failed (e.g. ``coro`` is not a coroutine): hand the
            # slot back, otherwise it would be leaked forever.
            self._semaphore.release()
            raise
        self._running_task_num += 1
        task.add_done_callback(self._release)
        return task

    def is_done(self):
        """Return True when no task created through this pool is still running."""
        return self._running_task_num == 0

    async def wait_all_done(self, poll_interval: float = 1):
        """Block until every task created through this pool has finished.

        The number of running tasks is logged on every poll.

        :param poll_interval: seconds to sleep between checks.
        """
        while self._running_task_num > 0:
            log.info("running task num : {} , wait done...", self._running_task_num)
            await asyncio.sleep(poll_interval)

    def _release(self, t):
        """Done-callback: free the finished task's slot (``t`` is unused)."""
        self._running_task_num -= 1
        self._semaphore.release()


class GracefulExit(SystemExit):
    """``SystemExit`` variant carrying exit status 1."""
    code = 1


class AsyncPoolListenShut:
    """Concurrency-limited task pool that reacts to SIGINT/SIGTERM.

    Each live task keeps the ``data`` payload it was created with, so the
    payloads of unfinished work can be reported on shutdown. After a signal
    has been received, ``create_task`` no longer starts new work.

    NOTE: the handlers are installed with ``signal.signal`` in ``__init__``,
    which Python only allows from the main thread.
    """

    def __init__(self, max_concurrency: int):
        """
        :param max_concurrency: maximum number of tasks from this pool that may
            run simultaneously.
        """
        self._semaphore: asyncio.Semaphore = asyncio.Semaphore(max_concurrency)
        # Maps each live task to the payload passed to create_task.
        self._data: dict = {}
        # Set by listen_shutdown; once True, create_task stops consuming.
        self._look = False
        # Currently unused within this module.
        self._is_windows = sys.platform == 'win32'
        self._register_shutdown_by_signal()
        # Payload of the most recent create_task call; it may still be waiting
        # for a slot and therefore not yet be present in ``_data``.
        self._current_data = None

    def _register_shutdown_by_signal(self):
        """Route SIGINT and SIGTERM to ``listen_shutdown``."""
        print("注册shutdown")
        signal.signal(signal.SIGINT, functools.partial(self.listen_shutdown))
        signal.signal(signal.SIGTERM, functools.partial(self.listen_shutdown))

    async def create_task(self, coro: Coroutine, data) -> asyncio.Task:
        """Wait for a free slot, then schedule ``coro`` and remember ``data``.

        If shutdown has already been requested, the call never starts ``coro``:
        it waits until it is cancelled and then closes the coroutine.

        :param coro: coroutine object to run.
        :param data: arbitrary payload associated with the task until it ends.
        :return: the scheduled task.
        """
        self._current_data = data
        if self._look:
            print("停止消费...")
            # Park until cancelled (listen_shutdown cancels every task). Close
            # the never-started coroutine to avoid a "never awaited" warning.
            try:
                await asyncio.get_running_loop().create_future()
            finally:
                coro.close()
        await self._semaphore.acquire()
        try:
            task: asyncio.Task = asyncio.create_task(coro)
        except BaseException:
            # Scheduling failed: return the slot so it is not leaked.
            self._semaphore.release()
            raise
        print('创建task,id: ', id(task))
        self._data[task] = data
        task.add_done_callback(self._release)
        return task

    def _release(self, t):
        """Done-callback: forget the finished task ``t`` and free its slot."""
        print("释放task,id:", id(t))
        del self._data[t]
        self._semaphore.release()

    async def _shutdown(self, poll_interval: float = 3):
        """Wait until every tracked task has finished, printing progress.

        :param poll_interval: seconds to sleep between progress reports.
        """
        print("检测到停止信号...")
        print('当前数量%s' % len(self._data))
        num = 0
        # Finished tasks are removed from _data by _release, so an empty
        # mapping means everything is done.
        while self._data:
            num += 1
            print("\n第%s" % num)
            # Iterate a snapshot: _release mutates _data while we sleep.
            for t in list(self._data):
                print(id(t), "是否结束", t.done())
            await asyncio.sleep(poll_interval)
        print("所有任务已经结束!")

    def listen_shutdown(self, *args, **kwargs):
        """Signal handler: stop consuming, cancel all tasks, stop the loop.

        The ``(signum, frame)`` arguments supplied by ``signal.signal`` are
        accepted and ignored. If no event loop is running in this thread,
        raises ``GracefulExit`` instead.
        """
        all_data = list(self._data.values()) + [self._current_data]
        print("all_data: ", all_data)
        self._look = True
        # NOTE(review): tasks are only cancelled below; this message is premature.
        print("所有任务已经结束!")
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running, so there is nothing to cancel: just exit.
            raise GracefulExit() from None
        for t in asyncio.all_tasks(loop):
            t.cancel()
        loop.stop()


async def run1(tt=None):
    """Demo coroutine: sleep 3 seconds, then print ``tt`` if it is truthy."""
    print("sleeping")
    await asyncio.sleep(3)
    print("slept")
    if tt:
        print(tt)


async def callback(msg):
    """Demo coroutine: print a marker and sleep 3 seconds (``msg`` is unused)."""
    print("callback")
    await asyncio.sleep(3)


async def main():
    """Demo: run ten ``run1`` coroutines with at most five at a time."""
    pool = AsyncPool(5)
    for _ in range(10):
        await pool.create_task(run1())
    await pool.wait_all_done()


if __name__ == '__main__':
    asyncio.run(main())