# chat_workflow/storage_client.py
import base64
import hashlib
import uuid
from typing import Any, Dict, Union

import boto3
from chainlit.data.storage_clients.base import BaseStorageClient
from chainlit.logger import logger
from sqlalchemy import Column, String, Boolean, Integer, ForeignKey, Text, JSON
from sqlalchemy.dialects.postgresql import UUID as PG_UUID, JSONB, ARRAY
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship


class MinIOStorageClient(BaseStorageClient):
    """
    Class to enable MinIO storage provider using the S3 compatible API.

    Every S3 operation goes through a synchronous boto3 client. The ``async``
    methods call it directly, so each call blocks the event loop while it runs.
    """

    def __init__(self, bucket: str, endpoint_url: str, access_key: str, secret_key: str):
        """Create the boto3 S3 client and make sure ``bucket`` exists.

        Args:
            bucket: Bucket that all objects are stored in; created if missing.
            endpoint_url: Base URL of the MinIO server.
            access_key: MinIO access key (sent as the AWS access key id).
            secret_key: MinIO secret key (sent as the AWS secret access key).

        The constructor never raises. Any failure (listing buckets, creating
        the bucket, building the client) is logged as a warning. If
        ``boto3.client`` itself fails, ``self.client`` is left unset and later
        method calls will fail.
        """
        self.bucket = bucket
        try:
            # Initialize boto3 client with MinIO-specific configurations
            self.client = boto3.client(
                "s3",
                endpoint_url=endpoint_url,
                aws_access_key_id=access_key,
                aws_secret_access_key=secret_key,
            )
            response = self.client.list_buckets()
            # Distinct loop name: the original reused ``bucket``, shadowing the parameter.
            existing_buckets = [entry["Name"] for entry in response["Buckets"]]
            logger.info(
                f"Successfully connected. Available buckets: {existing_buckets}")

            if self.bucket not in existing_buckets:
                logger.info(
                    f"Bucket '{self.bucket}' does not exist. Creating it now.")
                self.client.create_bucket(Bucket=self.bucket)
                logger.info(f"Bucket '{self.bucket}' created successfully.")
            else:
                logger.info(f"Bucket '{self.bucket}' already exists.")

            logger.info("MinIOStorageClient initialized")
        except Exception as e:
            # Best-effort init: the app keeps running even if MinIO is unreachable.
            logger.warning(f"MinIOStorageClient initialization error: {e}")

    async def upload_file(
        self,
        object_key: str,
        data: Union[bytes, str],
        mime: str = "application/octet-stream",
        overwrite: bool = True,
        content_md5: bool = False,  # Optionally send content-md5
    ) -> Dict[str, Any]:
        """Upload ``data`` to ``object_key`` in the configured bucket.

        Args:
            object_key: Key of the object inside the bucket.
            data: Payload. A ``str`` is encoded as UTF-8 before upload.
            mime: Sent as the object's ``ContentType``.
            overwrite: Accepted for interface compatibility but ignored.
                ``put_object`` is always called, so (per S3 semantics) an
                existing object with the same key is replaced.
            content_md5: If true, send a ``Content-MD5`` value computed over
                the uploaded bytes so the server can check integrity.

        Returns:
            ``{"object_key": ..., "url": ...}`` on success. Here ``url`` is the
            unsigned ``<endpoint>/<bucket>/<key>`` address. Returns ``{}`` on
            any error; the error is logged as a warning.
        """
        try:
            # Encode once so the MD5 is computed over exactly the bytes that are sent.
            body = data.encode("utf-8") if isinstance(data, str) else data
            extra_args: Dict[str, Any] = {"ContentType": mime}
            if content_md5:
                # S3 wants the *base64-encoded* MD5 digest as a str (RFC 1864).
                # The raw digest bytes are rejected by botocore's parameter validation.
                digest = hashlib.md5(body).digest()
                extra_args["ContentMD5"] = base64.b64encode(digest).decode("ascii")
            self.client.put_object(
                Bucket=self.bucket, Key=object_key, Body=body, **extra_args
            )
            url = f"{self.client.meta.endpoint_url}/{self.bucket}/{object_key}"
            return {"object_key": object_key, "url": url}
        except Exception as e:
            logger.warning(f"MinIOStorageClient, upload_file error: {e}")
            return {}

    async def get_read_url(self, object_key: str) -> str:
        """
        Generate a presigned URL for reading/downloading a file from MinIO.

        The presigned URL is valid for one hour. If presigning fails, the error
        is logged and the plain ``<endpoint>/<bucket>/<key>`` URL is returned
        instead. NOTE(review): that fallback presumably only works when the
        bucket allows anonymous reads; confirm against the deployment.
        """
        try:
            url = self.client.generate_presigned_url(
                'get_object',
                Params={'Bucket': self.bucket, 'Key': object_key},
                ExpiresIn=3600  # URL valid for 1 hour
            )
            return url
        except Exception as e:
            logger.warning(f"MinIOStorageClient, get_read_url error: {e}")
            return f"{self.client.meta.endpoint_url}/{self.bucket}/{object_key}"

    async def delete_file(self, object_key: str) -> bool:
        """
        Delete a file from MinIO storage.

        Returns ``True`` when the delete call succeeds. Returns ``False`` if it
        raises; the error is logged as a warning.
        """
        try:
            self.client.delete_object(Bucket=self.bucket, Key=object_key)
            logger.info(f"Deleted file: {object_key}")
            return True
        except Exception as e:
            logger.warning(f"MinIOStorageClient, delete_file error: {e}")
            return False

    async def close(self) -> None:
        """
        Close the storage client connection (only logs; nothing to release).
        """
        # boto3 client doesn't require explicit closing
        logger.info("MinIOStorageClient closed")


# Shared declarative base for the ORM models below.
Base = declarative_base()


class User(Base):
    """Row of the ``users`` table."""

    __tablename__ = 'users'
    id = Column(PG_UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    identifier = Column(String, nullable=False, unique=True)
    # ``metadata`` is reserved on declarative classes (it holds the MetaData),
    # so the attribute is ``metadata_`` while the DB column is still "metadata".
    metadata_ = Column("metadata", JSONB, nullable=False)
    # Stored as a string; presumably an ISO-8601 timestamp (confirm with callers).
    createdAt = Column(String)


class Thread(Base):
    """Row of the ``threads`` table; rows are deleted with their owning user."""

    __tablename__ = 'threads'
    id = Column(PG_UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    createdAt = Column(String)
    name = Column(String)
    userId = Column(PG_UUID(as_uuid=True), ForeignKey(
        'users.id', ondelete='CASCADE'))
    userIdentifier = Column(String)
    tags = Column(ARRAY(String))
    metadata_ = Column("metadata", JSONB)

    # Also adds a ``User.threads`` collection via the backref.
    user = relationship("User", backref="threads")


class Step(Base):
    """Row of the ``steps`` table; every step belongs to a thread.

    Camel-case column names presumably mirror Chainlit's data-layer schema.
    """

    __tablename__ = 'steps'
    id = Column(PG_UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    name = Column(String, nullable=False)
    type = Column(String, nullable=False)
    threadId = Column(PG_UUID(as_uuid=True), ForeignKey(
        'threads.id'), nullable=False)
    # Not a ForeignKey: the parent step is referenced by id only.
    parentId = Column(PG_UUID(as_uuid=True))
    disableFeedback = Column(Boolean, nullable=True)
    streaming = Column(Boolean, nullable=False)
    waitForAnswer = Column(Boolean)
    isError = Column(Boolean)
    metadata_ = Column("metadata", JSONB)
    tags = Column(ARRAY(String))
    input = Column(Text)
    output = Column(Text)
    createdAt = Column(String)
    start = Column(String)
    end = Column(String)
    generation = Column(JSONB)
    showInput = Column(Text)
    language = Column(String)
    indent = Column(Integer)


class Element(Base):
    """Row of the ``elements`` table, optionally attached to a thread.

    ``objectKey`` presumably holds the key under which the element's payload
    was stored via a storage client such as ``MinIOStorageClient``; verify.
    """

    __tablename__ = 'elements'
    id = Column(PG_UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    threadId = Column(PG_UUID(as_uuid=True), ForeignKey('threads.id'))
    type = Column(String)
    url = Column(String)
    chainlitKey = Column(String)
    name = Column(String, nullable=False)
    display = Column(String)
    objectKey = Column(String)
    size = Column(String)
    page = Column(Integer)
    language = Column(String)
    # Not a ForeignKey; presumably the id of the step this element belongs to.
    forId = Column(PG_UUID(as_uuid=True))
    mime = Column(String)


class Feedback(Base):
    """Row of the ``feedbacks`` table; every feedback belongs to a thread."""

    __tablename__ = 'feedbacks'
    id = Column(PG_UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    # Not a ForeignKey; presumably the id of the step being rated.
    forId = Column(PG_UUID(as_uuid=True), nullable=False)
    threadId = Column(PG_UUID(as_uuid=True), ForeignKey(
        'threads.id'), nullable=False)
    value = Column(Integer, nullable=False)
    comment = Column(Text)

    # Also adds a ``Thread.feedbacks`` collection via the backref.
    thread = relationship("Thread", backref="feedbacks")


class LangGraph(Base):
    """Row of the ``langgraphs`` table, keyed by a string thread id.

    Presumably persists a LangGraph workflow's state per chat thread; confirm
    with the code that reads and writes it.
    """

    __tablename__ = 'langgraphs'
    thread_id = Column(String, primary_key=True)
    state = Column(JSON, nullable=False)
    workflow = Column(String, nullable=False)