|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197 |
- from importlib.metadata import metadata
- import logging
- import warnings
- import uuid
- from sanic import Blueprint, response
- from sanic.request import Request
- from sanic.response import HTTPResponse
- from socketio import AsyncServer
- from typing import Optional, Text, Any, List, Dict, Iterable, Callable, Awaitable
-
- from rasa.core.channels.channel import InputChannel
- from rasa.core.channels.channel import UserMessage, OutputChannel
-
# Module-level logger, named after this module so its output can be filtered per channel.
logger = logging.getLogger(name=__name__)
-
-
class SocketBlueprint(Blueprint):
    """Sanic blueprint that carries a socket.io server with it.

    When the blueprint is registered on an application, the wrapped
    ``AsyncServer`` is first attached to that application at
    ``socketio_path``, and then the normal blueprint registration runs.
    """

    def __init__(
        self, sio: AsyncServer, socketio_path: Text, *args, **kwargs
    ) -> None:
        """Store the socket.io server and its mount path.

        Any remaining positional/keyword arguments are forwarded unchanged to
        the ``Blueprint`` initializer.
        """
        # Keep the original ordering: attributes are set before Blueprint.__init__.
        self.sio = sio
        self.socketio_path = socketio_path
        super().__init__(*args, **kwargs)

    def register(self, app, options) -> None:
        """Attach the socket.io server to ``app`` and register the blueprint."""
        server, mount_path = self.sio, self.socketio_path
        server.attach(app, mount_path)
        super().register(app, options)
-
-
class SocketIOOutput(OutputChannel):
    """Output channel that emits bot responses to one socket.io client.

    Every message is emitted under ``bot_message_evt`` to the room of the socket
    id given at construction time (``self.sid``); the ``recipient_id`` argument
    of the send methods is accepted for interface compatibility but not used
    for routing.
    """

    @classmethod
    def name(cls) -> Text:
        """Return the identifier of this channel."""
        return "socketio"

    def __init__(self, sio: AsyncServer, sid: Text, bot_message_evt: Text) -> None:
        """Bind the output channel to a single socket.

        Args:
            sio: socket.io server used to emit events.
            sid: Socket id of the connected client; used as the target room.
            bot_message_evt: Event name under which bot messages are emitted.
        """
        self.sio = sio
        self.sid = sid
        self.bot_message_evt = bot_message_evt

    async def _send_message(self, socket_id: Text, response: Any) -> None:
        """Sends a message to the recipient using the bot event."""
        await self.sio.emit(self.bot_message_evt, response, room=socket_id)

    async def send_text_message(
        self, recipient_id: Text, text: Text, **kwargs: Any
    ) -> None:
        """Send a plain text message as ``{"text": text}``."""
        await self._send_message(self.sid, {"text": text})

    async def send_image_url(
        self, recipient_id: Text, image: Text, **kwargs: Any
    ) -> None:
        """Send an image attachment pointing at ``image``."""
        message = {"attachment": {"type": "image", "payload": {"src": image}}}
        await self._send_message(self.sid, message)

    async def send_text_with_buttons(
        self,
        recipient_id: Text,
        text: Text,
        buttons: List[Dict[Text, Any]],
        **kwargs: Any,
    ) -> None:
        """Send ``text`` with one text quick reply per button.

        Each button must provide ``"title"`` and ``"payload"`` keys.
        """
        message = {
            "text": text,
            "quick_replies": [
                {
                    "content_type": "text",
                    "title": button["title"],
                    "payload": button["payload"],
                }
                for button in buttons
            ],
        }
        await self._send_message(self.sid, message)

    async def send_elements(
        self, recipient_id: Text, elements: Iterable[Dict[Text, Any]], **kwargs: Any
    ) -> None:
        """Send each element as its own generic template attachment."""
        for element in elements:
            message = {
                "attachment": {
                    "type": "template",
                    "payload": {"template_type": "generic", "elements": element},
                }
            }
            await self._send_message(self.sid, message)

    async def send_custom_json(
        self, recipient_id: Text, json_message: Dict[Text, Any], **kwargs: Any
    ) -> None:
        """Sends custom json to the output.

        If ``json_message`` contains a ``"data"`` key, its entries are passed as
        keyword arguments to ``AsyncServer.emit`` (e.g. ``data``, ``room``), with
        ``room`` defaulting to this channel's socket. Otherwise the whole dict is
        emitted as the event payload to this channel's socket; spreading arbitrary
        keys into ``emit`` would raise ``TypeError``.
        The caller's dictionary is never modified.
        """
        if "data" in json_message:
            # Copy instead of setdefault() so the caller's dict is left untouched;
            # an explicit "room" in json_message still takes precedence.
            emit_kwargs = {"room": self.sid, **json_message}
            await self.sio.emit(self.bot_message_evt, **emit_kwargs)
        else:
            await self.sio.emit(self.bot_message_evt, json_message, room=self.sid)

    async def send_attachment(
        self, recipient_id: Text, attachment: Dict[Text, Any], **kwargs: Any
    ) -> None:
        """Sends an attachment to the user."""
        await self._send_message(self.sid, {"attachment": attachment})
-
-
class SocketIOInput(InputChannel):
    """A socket.io input channel."""

    @classmethod
    def name(cls) -> Text:
        """Return the identifier of this channel."""
        return "socketio"

    @classmethod
    def from_credentials(cls, credentials: Optional[Dict[Text, Any]]) -> InputChannel:
        """Build the channel from a credentials dictionary.

        ``None`` is treated as an empty dictionary; missing keys fall back to the
        same defaults used by ``__init__``.
        """
        credentials = credentials or {}
        return cls(
            credentials.get("user_message_evt", "user_uttered"),
            credentials.get("bot_message_evt", "bot_uttered"),
            credentials.get("namespace"),
            credentials.get("session_persistence", False),
            credentials.get("socketio_path", "/socket.io"),
        )

    def __init__(
        self,
        user_message_evt: Text = "user_uttered",
        bot_message_evt: Text = "bot_uttered",
        namespace: Optional[Text] = None,
        session_persistence: bool = False,
        socketio_path: Optional[Text] = "/socket.io",
    ):
        """Store the channel configuration.

        Args:
            user_message_evt: Event name on which user messages are received.
            bot_message_evt: Event name under which bot messages are emitted.
            namespace: socket.io namespace the handlers are registered on.
            session_persistence: If True, the sender id is taken from the
                ``session_id`` field of each message instead of the socket id.
            socketio_path: Path at which the socket.io server is attached.
        """
        self.bot_message_evt = bot_message_evt
        self.session_persistence = session_persistence
        self.user_message_evt = user_message_evt
        self.namespace = namespace
        self.socketio_path = socketio_path

    def blueprint(
        self, on_new_message: Callable[[UserMessage], Awaitable[Any]]
    ) -> Blueprint:
        """Create the Sanic blueprint serving this channel.

        A new ``AsyncServer`` is created and wrapped in a ``SocketBlueprint``
        mounted at ``self.socketio_path``; the socket.io event handlers are
        registered on ``self.namespace``. The blueprint also serves a
        ``GET /`` health check.

        Args:
            on_new_message: Coroutine function called with every ``UserMessage``
                built from an incoming ``user_message_evt`` event.

        Returns:
            The blueprint to register on the Sanic app.
        """
        # Workaround so that socketio works with requests from other origins.
        # Per the python-socketio docs, an empty list disables its CORS handling.
        sio = AsyncServer(async_mode="sanic", cors_allowed_origins=[])
        socketio_webhook = SocketBlueprint(
            sio, self.socketio_path, "socketio_webhook", __name__
        )

        @socketio_webhook.route("/", methods=["GET"])
        async def health(_: Request) -> HTTPResponse:
            return response.json({"status": "ok"})

        @sio.on("connect", namespace=self.namespace)
        async def connect(sid: Text, _) -> None:
            logger.debug(f"User {sid} connected to socketIO endpoint.")

        @sio.on("disconnect", namespace=self.namespace)
        async def disconnect(sid: Text) -> None:
            logger.debug(f"User {sid} disconnected from socketIO endpoint.")

        @sio.on("session_request", namespace=self.namespace)
        async def session_request(sid: Text, data: Optional[Dict]):
            """Confirm the client's session id, generating one if none is given."""
            if data is None:
                data = {}
            if "session_id" not in data or data["session_id"] is None:
                data["session_id"] = uuid.uuid4().hex
            # Reply on the namespace the request arrived on; without it the
            # confirmation went to the default "/" namespace whenever a custom
            # namespace was configured. (namespace=None still means "/".)
            await sio.emit(
                "session_confirm",
                data["session_id"],
                room=sid,
                namespace=self.namespace,
            )
            logger.debug(f"User {sid} connected to socketIO endpoint.")

        @sio.on(self.user_message_evt, namespace=self.namespace)
        async def handle_message(sid: Text, data: Dict) -> Any:
            """Turn an incoming socket event into a ``UserMessage``.

            With session persistence enabled, messages lacking a truthy
            ``session_id`` are dropped with a warning.
            """
            output_channel = SocketIOOutput(sio, sid, self.bot_message_evt)

            if self.session_persistence:
                if not data.get("session_id"):
                    warnings.warn(
                        "A message without a valid sender_id "
                        "was received. This message will be "
                        "ignored. Make sure to set a proper "
                        "session id using the "
                        "`session_request` socketIO event."
                    )
                    return
                sender_id = data["session_id"]
            else:
                sender_id = sid

            message = UserMessage(
                data["message"],
                output_channel,
                sender_id,
                input_channel=self.name(),
                # "customData" is optional for clients; a missing key used to
                # abort the handler with KeyError.
                metadata=data.get("customData"),
            )
            await on_new_message(message)

        return socketio_webhook
|