/ pkg / whispercpp / transcription.go
transcription.go
  1  package whispercpp
  2  
  3  import (
  4  	"encoding/json"
  5  	"fmt"
  6  	"krillin-ai/internal/storage"
  7  	"krillin-ai/internal/types"
  8  	"krillin-ai/log"
  9  	"krillin-ai/pkg/util"
 10  	"os"
 11  	"os/exec"
 12  	"regexp"
 13  	"strconv"
 14  	"strings"
 15  
 16  	"go.uber.org/zap"
 17  )
 18  
 19  func (c *WhispercppProcessor) Transcription(audioFile, language, workDir string) (*types.TranscriptionData, error) {
 20  	name := util.ChangeFileExtension(audioFile, "")
 21  	cmdArgs := []string{
 22  		"-m", fmt.Sprintf("./models/whispercpp/ggml-%s.bin", c.Model),
 23  		"--output-json-full",
 24  		"--flash-attn",
 25  		"--split-on-word",
 26  		"--language", language,
 27  		"--output-file", name,
 28  		"--file", audioFile,
 29  	}
 30  	cmd := exec.Command(storage.WhispercppPath, cmdArgs...)
 31  	log.GetLogger().Info("WhispercppProcessor转录开始", zap.String("cmd", cmd.String()))
 32  	output, err := cmd.CombinedOutput()
 33  	if err != nil && !strings.Contains(string(output), "output_json: saving output to") {
 34  		log.GetLogger().Error("WhispercppProcessor  cmd 执行失败", zap.String("output", string(output)), zap.Error(err))
 35  		return nil, err
 36  	}
 37  	log.GetLogger().Info("WhispercppProcessor转录json生成完毕", zap.String("audio file", audioFile))
 38  
 39  	var result types.WhispercppOutput
 40  	fileData, err := os.Open(util.ChangeFileExtension(audioFile, ".json"))
 41  	if err != nil {
 42  		log.GetLogger().Error("WhispercppProcessor 打开json文件失败", zap.Error(err))
 43  		return nil, err
 44  	}
 45  	defer fileData.Close()
 46  	decoder := json.NewDecoder(fileData)
 47  	if err = decoder.Decode(&result); err != nil {
 48  		log.GetLogger().Error("WhispercppProcessor 解析json文件失败", zap.Error(err))
 49  		return nil, err
 50  	}
 51  
 52  	var (
 53  		transcriptionData types.TranscriptionData
 54  		num               int
 55  	)
 56  	for _, segment := range result.Transcription {
 57  		transcriptionData.Text += strings.ReplaceAll(segment.Text, "—", " ") // 连字符处理,因为模型存在很多错误添加到连字符
 58  		for _, word := range segment.Tokens {
 59  			fromSec, err := parseTimestampToSeconds(word.Timestamps.From)
 60  			if err != nil {
 61  				log.GetLogger().Error("解析开始时间失败", zap.Error(err))
 62  				return nil, err
 63  			}
 64  
 65  			toSec, err := parseTimestampToSeconds(word.Timestamps.To)
 66  			if err != nil {
 67  				log.GetLogger().Error("解析结束时间失败", zap.Error(err))
 68  				return nil, err
 69  			}
 70  			regex := regexp.MustCompile(`^\[.*\]$`)
 71  			if regex.MatchString(word.Text) {
 72  				continue
 73  			} else if strings.Contains(word.Text, "—") {
 74  				// 对称切分
 75  				mid := (fromSec + toSec) / 2
 76  				seperatedWords := strings.Split(word.Text, "—")
 77  				transcriptionData.Words = append(transcriptionData.Words, []types.Word{
 78  					{
 79  						Num:   num,
 80  						Text:  util.CleanPunction(strings.TrimSpace(seperatedWords[0])),
 81  						Start: fromSec,
 82  						End:   mid,
 83  					},
 84  					{
 85  						Num:   num + 1,
 86  						Text:  util.CleanPunction(strings.TrimSpace(seperatedWords[1])),
 87  						Start: mid,
 88  						End:   toSec,
 89  					},
 90  				}...)
 91  				num += 2
 92  			} else {
 93  				transcriptionData.Words = append(transcriptionData.Words, types.Word{
 94  					Num:   num,
 95  					Text:  util.CleanPunction(strings.TrimSpace(word.Text)),
 96  					Start: fromSec,
 97  					End:   toSec,
 98  				})
 99  				num++
100  			}
101  		}
102  	}
103  	log.GetLogger().Info("WhispercppProcessor转录成功")
104  	return &transcriptionData, nil
105  }
106  
// parseTimestampToSeconds converts a timestamp of the form "HH:MM:SS,mmm"
// (for example "00:01:02,500") into seconds (62.5).
//
// It returns an error when the string does not contain exactly one ",", when
// the part before it does not have exactly three ":"-separated fields, or when
// any field is not a non-negative integer. The milliseconds field is taken as
// an integer count of milliseconds.
func parseTimestampToSeconds(timeStr string) (float64, error) {
	parts := strings.Split(timeStr, ",")
	if len(parts) != 2 {
		return 0, fmt.Errorf("invalid timestamp format: %s", timeStr)
	}

	timePart := strings.Split(parts[0], ":")
	if len(timePart) != 3 {
		return 0, fmt.Errorf("invalid time format: %s", parts[0])
	}

	// hours, minutes, seconds, milliseconds — in that order.
	fields := [4]string{timePart[0], timePart[1], timePart[2], parts[1]}
	var values [4]int
	for i, field := range fields {
		v, err := strconv.Atoi(field)
		if err != nil {
			return 0, fmt.Errorf("invalid timestamp %q: %w", timeStr, err)
		}
		if v < 0 {
			return 0, fmt.Errorf("invalid timestamp %q: negative field %q", timeStr, field)
		}
		values[i] = v
	}

	hours, minutes, seconds, milliseconds := values[0], values[1], values[2], values[3]
	return float64(hours*3600+minutes*60+seconds) + float64(milliseconds)/1000, nil
}