/ chat_workflow / storage_client.py
storage_client.py
  1  import boto3
  2  import hashlib
  3  import uuid
  4  from typing import Any, Dict, Union
  5  from chainlit.logger import logger
  6  from chainlit.data.storage_clients.base import BaseStorageClient
  7  from sqlalchemy import Column, String, Boolean, Integer, ForeignKey, Text, JSON
  8  from sqlalchemy.dialects.postgresql import UUID as PG_UUID, JSONB, ARRAY
  9  from sqlalchemy.ext.declarative import declarative_base
 10  from sqlalchemy.orm import relationship
 11  
 12  
 13  class MinIOStorageClient(BaseStorageClient):
 14      """
 15      Class to enable MinIO storage provider using the S3 compatible API
 16      """
 17  
 18      def __init__(self, bucket: str, endpoint_url: str, access_key: str, secret_key: str):
 19          try:
 20              self.bucket = bucket
 21              # Initialize boto3 client with MinIO-specific configurations
 22              self.client = boto3.client(
 23                  "s3",
 24                  endpoint_url=endpoint_url,
 25                  aws_access_key_id=access_key,
 26                  aws_secret_access_key=secret_key,
 27              )
 28              response = self.client.list_buckets()
 29              existing_buckets = [bucket['Name']
 30                                  for bucket in response['Buckets']]
 31              logger.info(
 32                  f"Successfully connected. Available buckets: {existing_buckets}")
 33  
 34              if self.bucket not in existing_buckets:
 35                  logger.info(
 36                      f"Bucket '{self.bucket}' does not exist. Creating it now.")
 37                  self.client.create_bucket(Bucket=self.bucket)
 38                  logger.info(f"Bucket '{self.bucket}' created successfully.")
 39              else:
 40                  logger.info(f"Bucket '{self.bucket}' already exists.")
 41  
 42              logger.info("MinIOStorageClient initialized")
 43          except Exception as e:
 44              logger.warning(f"MinIOStorageClient initialization error: {e}")
 45  
 46      async def upload_file(
 47          self,
 48          object_key: str,
 49          data: Union[bytes, str],
 50          mime: str = "application/octet-stream",
 51          overwrite: bool = True,
 52          content_md5: bool = False  # Optionally send content-md5
 53      ) -> Dict[str, Any]:
 54          try:
 55              # Optionally calculate and send content-md5
 56              extra_args = {"ContentType": mime}
 57              if content_md5:
 58                  md5_hash = hashlib.md5(data if isinstance(
 59                      data, bytes) else data.encode('utf-8')).digest()
 60                  extra_args["ContentMD5"] = md5_hash
 61              self.client.put_object(
 62                  Bucket=self.bucket, Key=object_key, Body=data, **extra_args
 63              )
 64              url = f"{self.client.meta.endpoint_url}/{self.bucket}/{object_key}"
 65              return {"object_key": object_key, "url": url}
 66          except Exception as e:
 67              logger.warning(f"MinIOStorageClient, upload_file error: {e}")
 68              return {}
 69  
 70      async def get_read_url(self, object_key: str) -> str:
 71          """
 72          Generate a presigned URL for reading/downloading a file from MinIO
 73          """
 74          try:
 75              url = self.client.generate_presigned_url(
 76                  'get_object',
 77                  Params={'Bucket': self.bucket, 'Key': object_key},
 78                  ExpiresIn=3600  # URL valid for 1 hour
 79              )
 80              return url
 81          except Exception as e:
 82              logger.warning(f"MinIOStorageClient, get_read_url error: {e}")
 83              return f"{self.client.meta.endpoint_url}/{self.bucket}/{object_key}"
 84  
 85      async def delete_file(self, object_key: str) -> bool:
 86          """
 87          Delete a file from MinIO storage
 88          """
 89          try:
 90              self.client.delete_object(Bucket=self.bucket, Key=object_key)
 91              logger.info(f"Deleted file: {object_key}")
 92              return True
 93          except Exception as e:
 94              logger.warning(f"MinIOStorageClient, delete_file error: {e}")
 95              return False
 96  
 97      async def close(self) -> None:
 98          """
 99          Close the storage client connection (cleanup if needed)
100          """
101          # boto3 client doesn't require explicit closing
102          logger.info("MinIOStorageClient closed")
103          pass
104  
105  
# Shared declarative base class; every ORM model in this module subclasses it.
# NOTE(review): SQLAlchemy 1.4+ also provides this factory as
# sqlalchemy.orm.declarative_base — confirm which import the pinned version expects.
Base = declarative_base()
107  
108  
class User(Base):
    """ORM mapping for the ``users`` table."""

    __tablename__ = "users"

    # Primary key; uuid.uuid4 supplies a value when none is given.
    id = Column(
        PG_UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
    )
    # Required and unique across users.
    identifier = Column(String, unique=True, nullable=False)
    # The attribute has a trailing underscore because ``metadata`` is reserved
    # on declarative classes; the database column is still named "metadata".
    metadata_ = Column("metadata", JSONB, nullable=False)
    # Kept as a string column.
    createdAt = Column(String)
115  
116  
class Thread(Base):
    """ORM mapping for the ``threads`` table; each thread may point at a user."""

    __tablename__ = "threads"

    # Primary key; uuid.uuid4 supplies a value when none is given.
    id = Column(
        PG_UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
    )
    createdAt = Column(String)
    name = Column(String)
    # Owning user; the foreign key is declared with ON DELETE CASCADE.
    userId = Column(
        PG_UUID(as_uuid=True),
        ForeignKey("users.id", ondelete="CASCADE"),
    )
    userIdentifier = Column(String)
    tags = Column(ARRAY(String))
    # Database column is named "metadata"; the attribute avoids the reserved name.
    metadata_ = Column("metadata", JSONB)

    # Also exposes ``User.threads`` through the backref.
    user = relationship("User", backref="threads")
129  
130  
class Step(Base):
    """ORM mapping for the ``steps`` table; every step belongs to a thread."""

    __tablename__ = "steps"

    # Primary key; uuid.uuid4 supplies a value when none is given.
    id = Column(
        PG_UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
    )
    name = Column(String, nullable=False)
    type = Column(String, nullable=False)
    # NOTE(review): unlike Thread.userId, this foreign key declares no
    # ondelete rule — confirm that is intended.
    threadId = Column(
        PG_UUID(as_uuid=True),
        ForeignKey("threads.id"),
        nullable=False,
    )
    # Plain UUID column; no foreign-key constraint is declared for it.
    parentId = Column(PG_UUID(as_uuid=True))

    # Boolean flags; only ``streaming`` is required.
    disableFeedback = Column(Boolean, nullable=True)
    streaming = Column(Boolean, nullable=False)
    waitForAnswer = Column(Boolean)
    isError = Column(Boolean)

    # Database column is named "metadata"; the attribute avoids the reserved name.
    metadata_ = Column("metadata", JSONB)
    tags = Column(ARRAY(String))
    input = Column(Text)
    output = Column(Text)

    # Timestamps are kept in string columns.
    createdAt = Column(String)
    start = Column(String)
    end = Column(String)

    generation = Column(JSONB)
    showInput = Column(Text)
    language = Column(String)
    indent = Column(Integer)
154  
155  
class Element(Base):
    """ORM mapping for the ``elements`` table."""

    __tablename__ = "elements"

    # Primary key; uuid.uuid4 supplies a value when none is given.
    id = Column(
        PG_UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
    )
    # Optional reference to the thread the element is attached to.
    threadId = Column(
        PG_UUID(as_uuid=True),
        ForeignKey("threads.id"),
    )
    type = Column(String)
    url = Column(String)
    chainlitKey = Column(String)
    # The only required column besides the primary key.
    name = Column(String, nullable=False)
    display = Column(String)
    objectKey = Column(String)
    # Declared as a string column, not an integer.
    size = Column(String)
    page = Column(Integer)
    language = Column(String)
    # Plain UUID column; no foreign-key constraint is declared for it.
    forId = Column(PG_UUID(as_uuid=True))
    mime = Column(String)
171  
172  
class Feedback(Base):
    """ORM mapping for the ``feedbacks`` table; each row belongs to a thread."""

    __tablename__ = "feedbacks"

    # Primary key; uuid.uuid4 supplies a value when none is given.
    id = Column(
        PG_UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
    )
    # Required UUID; no foreign-key constraint is declared for it.
    forId = Column(PG_UUID(as_uuid=True), nullable=False)
    threadId = Column(
        PG_UUID(as_uuid=True),
        ForeignKey("threads.id"),
        nullable=False,
    )
    value = Column(Integer, nullable=False)
    comment = Column(Text)

    # Also exposes ``Thread.feedbacks`` through the backref.
    thread = relationship("Thread", backref="feedbacks")
183  
184  
class LangGraph(Base):
    """ORM mapping for the ``langgraphs`` table, keyed by a string thread id."""

    __tablename__ = "langgraphs"

    # String primary key; no foreign key to ``threads`` is declared.
    thread_id = Column(String, primary_key=True)
    # Required; uses the generic JSON type rather than JSONB.
    state = Column(JSON, nullable=False)
    workflow = Column(String, nullable=False)