asyncio_pool.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. # -*- coding: utf-8 -*-
  2. # @Time : 2023/7/20 16:10
  3. # @Author : XuJiakai
  4. # @File : async_pool
  5. # @Software: PyCharm
  6. import asyncio
  7. import functools
  8. import signal
  9. import sys
  10. from typing import Coroutine
  11. class AsyncPool(object):
  12. def __init__(self, max_concurrency: int):
  13. self._semaphore: asyncio.Semaphore = asyncio.Semaphore(max_concurrency)
  14. async def create_task(self, coro: Coroutine) -> asyncio.Task:
  15. await self._semaphore.acquire()
  16. task: asyncio.Task = asyncio.create_task(coro)
  17. task.add_done_callback(lambda t: self._semaphore.release())
  18. return task
  19. class GracefulExit(SystemExit):
  20. code = 1
  21. class AsyncPoolListenShut:
  22. def __init__(self, max_concurrency: int):
  23. self._semaphore: asyncio.Semaphore = asyncio.Semaphore(max_concurrency)
  24. self._data: dict = {}
  25. self._look = False
  26. self._is_windows = sys.platform == 'win32'
  27. self._register_shutdown_by_signal()
  28. self._current_data = None
  29. def _register_shutdown_by_signal(self):
  30. print("注册shutdown")
  31. signal.signal(signal.SIGINT, functools.partial(self.listen_shutdown))
  32. signal.signal(signal.SIGTERM, functools.partial(self.listen_shutdown))
  33. pass
  34. async def create_task(self, coro: Coroutine, data) -> asyncio.Task:
  35. self._current_data = data
  36. if self._look:
  37. print("停止消费...")
  38. await asyncio.sleep(10000)
  39. pass
  40. await self._semaphore.acquire()
  41. task: asyncio.Task = asyncio.create_task(coro)
  42. print('创建task,id: ', id(task))
  43. self._data[task] = data
  44. task.add_done_callback(lambda t: self._release(t))
  45. return task
  46. pass
  47. def _release(self, t):
  48. print("释放task,id:", id(t))
  49. del self._data[t]
  50. self._semaphore.release()
  51. async def _shutdown(self):
  52. print("检测到停止信号...")
  53. # await self._look.acquire()
  54. count = len(self._data)
  55. print('当前数量%s' % count)
  56. num = 0
  57. while count > 0:
  58. num += 1
  59. print("\n第%s" % num)
  60. for t in self._data.keys():
  61. print(id(t), "是否结束", t.done())
  62. count -= 1 if t.done() else 0
  63. await asyncio.sleep(3)
  64. await asyncio.sleep(3)
  65. print("所有任务已经结束!")
  66. def listen_shutdown(self, *args, **kwargs):
  67. all_data = list(self._data.values()) + [self._current_data]
  68. print("all_data: ", all_data)
  69. self._look = True
  70. print("所有任务已经结束!")
  71. loop = asyncio.get_running_loop()
  72. tasks = asyncio.tasks.all_tasks(loop)
  73. for t in tasks:
  74. t.cancel()
  75. loop.stop()
  76. # raise GracefulExit()
  77. pass
  78. pass
  79. async def run1(tt=None):
  80. print("sleeping")
  81. await asyncio.sleep(3)
  82. print("slept")
  83. if tt:
  84. print(tt)
  85. pass
  86. async def callback(msg):
  87. print("callback")
  88. await asyncio.sleep(3)
  89. pass
  90. async def main():
  91. pool = AsyncPoolTest(5, callback)
  92. for i in range(10):
  93. await pool.create_task(run1(), str(i))
  94. pass
  95. if __name__ == '__main__':
  96. asyncio.run(main())
  97. pass