No puede seleccionar más de 25 temas. Los temas deben comenzar con una letra o número, pueden incluir guiones ('-') y pueden tener hasta 35 caracteres de largo.
 
 
 

198 líneas
6.7 KiB

  1. from importlib.metadata import metadata
  2. import logging
  3. import warnings
  4. import uuid
  5. from sanic import Blueprint, response
  6. from sanic.request import Request
  7. from sanic.response import HTTPResponse
  8. from socketio import AsyncServer
  9. from typing import Optional, Text, Any, List, Dict, Iterable, Callable, Awaitable
  10. from rasa.core.channels.channel import InputChannel
  11. from rasa.core.channels.channel import UserMessage, OutputChannel
# Module-level logger; the socket.io event handlers below write their debug
# messages (connect / disconnect / session events) through it.
logger = logging.getLogger(__name__)
  13. class SocketBlueprint(Blueprint):
  14. def __init__(self, sio: AsyncServer, socketio_path, *args, **kwargs):
  15. self.sio = sio
  16. self.socketio_path = socketio_path
  17. super().__init__(*args, **kwargs)
  18. def register(self, app, options):
  19. self.sio.attach(app, self.socketio_path)
  20. super().register(app, options)
  21. class SocketIOOutput(OutputChannel):
  22. @classmethod
  23. def name(cls):
  24. return "socketio"
  25. def __init__(self, sio, sid, bot_message_evt):
  26. self.sio = sio
  27. self.sid = sid
  28. self.bot_message_evt = bot_message_evt
  29. async def _send_message(self, socket_id: Text, response: Any) -> None:
  30. """Sends a message to the recipient using the bot event."""
  31. await self.sio.emit(self.bot_message_evt, response, room=socket_id)
  32. async def send_text_message(
  33. self, recipient_id: Text, text: Text, **kwargs: Any
  34. ) -> None:
  35. """Send a message through this channel."""
  36. await self._send_message(self.sid, {"text": text})
  37. async def send_image_url(
  38. self, recipient_id: Text, image: Text, **kwargs: Any
  39. ) -> None:
  40. """Sends an image to the output"""
  41. message = {"attachment": {"type": "image", "payload": {"src": image}}}
  42. await self._send_message(self.sid, message)
  43. async def send_text_with_buttons(
  44. self,
  45. recipient_id: Text,
  46. text: Text,
  47. buttons: List[Dict[Text, Any]],
  48. **kwargs: Any,
  49. ) -> None:
  50. """Sends buttons to the output."""
  51. message = {"text": text, "quick_replies": []}
  52. for button in buttons:
  53. message["quick_replies"].append(
  54. {
  55. "content_type": "text",
  56. "title": button["title"],
  57. "payload": button["payload"],
  58. }
  59. )
  60. await self._send_message(self.sid, message)
  61. async def send_elements(
  62. self, recipient_id: Text, elements: Iterable[Dict[Text, Any]], **kwargs: Any
  63. ) -> None:
  64. """Sends elements to the output."""
  65. for element in elements:
  66. message = {
  67. "attachment": {
  68. "type": "template",
  69. "payload": {"template_type": "generic", "elements": element},
  70. }
  71. }
  72. await self._send_message(self.sid, message)
  73. async def send_custom_json(
  74. self, recipient_id: Text, json_message: Dict[Text, Any], **kwargs: Any
  75. ) -> None:
  76. """Sends custom json to the output"""
  77. json_message.setdefault("room", self.sid)
  78. await self.sio.emit(self.bot_message_evt, **json_message)
  79. async def send_attachment(
  80. self, recipient_id: Text, attachment: Dict[Text, Any], **kwargs: Any
  81. ) -> None:
  82. """Sends an attachment to the user."""
  83. await self._send_message(self.sid, {"attachment": attachment})
  84. class SocketIOInput(InputChannel):
  85. """A socket.io input channel."""
  86. @classmethod
  87. def name(cls) -> Text:
  88. return "socketio"
  89. @classmethod
  90. def from_credentials(cls, credentials: Optional[Dict[Text, Any]]) -> InputChannel:
  91. credentials = credentials or {}
  92. return cls(
  93. credentials.get("user_message_evt", "user_uttered"),
  94. credentials.get("bot_message_evt", "bot_uttered"),
  95. credentials.get("namespace"),
  96. credentials.get("session_persistence", False),
  97. credentials.get("socketio_path", "/socket.io"),
  98. )
  99. def __init__(
  100. self,
  101. user_message_evt: Text = "user_uttered",
  102. bot_message_evt: Text = "bot_uttered",
  103. namespace: Optional[Text] = None,
  104. session_persistence: bool = False,
  105. socketio_path: Optional[Text] = "/socket.io",
  106. ):
  107. self.bot_message_evt = bot_message_evt
  108. self.session_persistence = session_persistence
  109. self.user_message_evt = user_message_evt
  110. self.namespace = namespace
  111. self.socketio_path = socketio_path
  112. def blueprint(
  113. self, on_new_message: Callable[[UserMessage], Awaitable[Any]]
  114. ) -> Blueprint:
  115. # Workaround so that socketio works with requests from other origins.
  116. sio = AsyncServer(async_mode="sanic", cors_allowed_origins=[])
  117. socketio_webhook = SocketBlueprint(
  118. sio, self.socketio_path, "socketio_webhook", __name__
  119. )
  120. @socketio_webhook.route("/", methods=["GET"])
  121. async def health(_: Request) -> HTTPResponse:
  122. return response.json({"status": "ok"})
  123. @sio.on("connect", namespace=self.namespace)
  124. async def connect(sid: Text, _) -> None:
  125. logger.debug(f"User {sid} connected to socketIO endpoint.")
  126. @sio.on("disconnect", namespace=self.namespace)
  127. async def disconnect(sid: Text) -> None:
  128. logger.debug(f"User {sid} disconnected from socketIO endpoint.")
  129. @sio.on("session_request", namespace=self.namespace)
  130. async def session_request(sid: Text, data: Optional[Dict]):
  131. if data is None:
  132. data = {}
  133. if "session_id" not in data or data["session_id"] is None:
  134. data["session_id"] = uuid.uuid4().hex
  135. await sio.emit("session_confirm", data["session_id"], room=sid)
  136. logger.debug(f"User {sid} connected to socketIO endpoint.")
  137. @sio.on(self.user_message_evt, namespace=self.namespace)
  138. async def handle_message(sid: Text, data: Dict) -> Any:
  139. output_channel = SocketIOOutput(sio, sid, self.bot_message_evt)
  140. if self.session_persistence:
  141. if not data.get("session_id"):
  142. warnings.warn(
  143. "A message without a valid sender_id "
  144. "was received. This message will be "
  145. "ignored. Make sure to set a proper "
  146. "session id using the "
  147. "`session_request` socketIO event."
  148. )
  149. return
  150. sender_id = data["session_id"]
  151. else:
  152. sender_id = sid
  153. message = UserMessage(
  154. data["message"], output_channel, sender_id, input_channel=self.name(),
  155. metadata = data['customData']
  156. )
  157. await on_new_message(message)
  158. return socketio_webhook