/ pkg / openai / openai.go
openai.go
  1  package openai
  2  
  3  import (
  4  	"context"
  5  	"encoding/json"
  6  	"fmt"
  7  	openai "github.com/sashabaranov/go-openai"
  8  	"go.uber.org/zap"
  9  	"io"
 10  	"krillin-ai/config"
 11  	"krillin-ai/log"
 12  	"net/http"
 13  	"os"
 14  	"strings"
 15  )
 16  
 17  func (c *Client) ChatCompletion(query string) (string, error) {
 18  	var responseFormat *openai.ChatCompletionResponseFormat
 19  
 20  	req := openai.ChatCompletionRequest{
 21  		Model: config.Conf.Llm.Model,
 22  		Messages: []openai.ChatCompletionMessage{
 23  			{
 24  				Role:    openai.ChatMessageRoleSystem,
 25  				Content: "You are an assistant that helps with subtitle translation.",
 26  			},
 27  			{
 28  				Role:    openai.ChatMessageRoleUser,
 29  				Content: query,
 30  			},
 31  		},
 32  		Temperature:    0.9,
 33  		Stream:         true,
 34  		MaxTokens:      8192,
 35  		ResponseFormat: responseFormat,
 36  	}
 37  
 38  	stream, err := c.client.CreateChatCompletionStream(context.Background(), req)
 39  	if err != nil {
 40  		log.GetLogger().Error("openai create chat completion stream failed", zap.Error(err))
 41  		return "", err
 42  	}
 43  	defer stream.Close()
 44  
 45  	var resContent string
 46  	for {
 47  		response, err := stream.Recv()
 48  		if err == io.EOF {
 49  			break
 50  		}
 51  		if err != nil {
 52  			log.GetLogger().Error("openai stream receive failed", zap.Error(err))
 53  			return "", err
 54  		}
 55  		if len(response.Choices) == 0 {
 56  			log.GetLogger().Info("openai stream receive no choices", zap.Any("response", response))
 57  			continue
 58  		}
 59  
 60  		resContent += response.Choices[0].Delta.Content
 61  	}
 62  
 63  	return resContent, nil
 64  }
 65  
 66  func (c *Client) Text2Speech(text, voice string, outputFile string) error {
 67  	baseUrl := config.Conf.Tts.Openai.BaseUrl
 68  	if baseUrl == "" {
 69  		baseUrl = "https://api.openai.com/v1"
 70  	}
 71  	url := baseUrl + "/audio/speech"
 72  
 73  	// 创建HTTP请求
 74  	reqBody := fmt.Sprintf(`{
 75  		"model": "tts-1",
 76  		"input": "%s",
 77  		"voice":"%s",
 78  		"response_format": "wav"
 79  	}`, text, voice)
 80  	req, err := http.NewRequest("POST", url, strings.NewReader(reqBody))
 81  	if err != nil {
 82  		return err
 83  	}
 84  
 85  	req.Header.Set("Content-Type", "application/json")
 86  	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", config.Conf.Tts.Openai.ApiKey))
 87  
 88  	// 发送HTTP请求
 89  	client := &http.Client{}
 90  	resp, err := client.Do(req)
 91  	if err != nil {
 92  		return err
 93  	}
 94  	defer resp.Body.Close()
 95  
 96  	if resp.StatusCode != http.StatusOK {
 97  		body, _ := io.ReadAll(resp.Body)
 98  		log.GetLogger().Error("openai tts failed", zap.Int("status_code", resp.StatusCode), zap.String("body", string(body)))
 99  		return fmt.Errorf("openai tts none-200 status code: %d", resp.StatusCode)
100  	}
101  
102  	file, err := os.Create(outputFile)
103  	if err != nil {
104  		return err
105  	}
106  	defer file.Close()
107  
108  	_, err = io.Copy(file, resp.Body)
109  	if err != nil {
110  		return err
111  	}
112  
113  	return nil
114  }
115  
// parseJSONResponse converts a translation payload of the form
//
//	{"translations": [{"original_sentence": "...", "translated_sentence": "..."}, ...]}
//
// into numbered plain-text entries, one per translation, in input order:
//
//	<n>
//	<translated sentence>
//	<original sentence>
//	<blank line>
//
// Numbering starts at 1. An empty or missing "translations" array yields "".
// A malformed payload returns an error wrapping the encoding/json error, so
// callers can inspect it with errors.Is / errors.As.
func parseJSONResponse(jsonStr string) (string, error) {
	var response struct {
		Translations []struct {
			Original   string `json:"original_sentence"`
			Translated string `json:"translated_sentence"`
		} `json:"translations"`
	}

	if err := json.Unmarshal([]byte(jsonStr), &response); err != nil {
		// %w keeps the underlying *json.SyntaxError etc. reachable for callers.
		return "", fmt.Errorf("failed to parse JSON: %w", err)
	}

	var result strings.Builder
	for i, item := range response.Translations {
		// Fprintf writes directly into the builder instead of allocating an
		// intermediate string per entry.
		fmt.Fprintf(&result, "%d\n%s\n%s\n\n", i+1, item.Translated, item.Original)
	}
	return result.String(), nil
}