Skip to content

Commit

Permalink
Merge pull request #199 from shiguredo/feature/add-ogg-file-output
Browse files Browse the repository at this point in the history
音声データを Ogg ファイルで出力できる機能の追加
  • Loading branch information
Hexa authored Jan 8, 2025
2 parents a346123 + dffcb32 commit 1800950
Show file tree
Hide file tree
Showing 10 changed files with 237 additions and 20 deletions.
9 changes: 9 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,15 @@

## develop

- [ADD] 受信した音声データを Ogg ファイルで保存するかを指定する enable_ogg_file_output を追加する
- 保存するファイル名は、sora-session-id ヘッダーと sora-connection-id ヘッダーの値を使用して作成する
- ${sora-session-id}-${sora-connection-id}.ogg
- デフォルト値: false
- @Hexa
- [ADD] 受信した音声データを Ogg ファイルで保存する場合の保存先ディレクトリを指定する ogg_dir を追加する
- デフォルト値: .
- @Hexa

### misc

- [CHANGE] GitHub Actions の ubuntu-latest を ubuntu-24.04 に変更する
Expand Down
7 changes: 5 additions & 2 deletions amazon_transcribe_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,13 @@ func (h *AmazonTranscribeHandler) ResetRetryCount() int {
return h.RetryCount
}

func (h *AmazonTranscribeHandler) Handle(ctx context.Context, opusCh chan opusChannel) (*io.PipeReader, error) {
func (h *AmazonTranscribeHandler) Handle(ctx context.Context, opusCh chan opusChannel, header soraHeader) (*io.PipeReader, error) {
at := NewAmazonTranscribe(h.Config, h.LanguageCode, int64(h.SampleRate), int64(h.ChannelCount))

packetReader := opus2ogg(ctx, opusCh, h.SampleRate, h.ChannelCount, h.Config)
packetReader, err := opus2ogg(ctx, opusCh, h.SampleRate, h.ChannelCount, h.Config, header)
if err != nil {
return nil, err
}

stream, err := at.Start(ctx, packetReader)
if err != nil {
Expand Down
7 changes: 7 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ type Config struct {
SampleRate int `ini:"audio_sample_rate"`
ChannelCount int `ini:"audio_channel_count"`

EnableOggFileOutput bool `ini:"enable_ogg_file_output"`
OggDir string `ini:"ogg_dir"`

DumpFile string `ini:"dump_file"`

LogDir string `ini:"log_dir"`
Expand Down Expand Up @@ -173,6 +176,10 @@ func setDefaultsConfig(config *Config) {
if config.RetryIntervalMs == 0 {
config.RetryIntervalMs = DefaultRetryIntervalMs
}

if config.OggDir == "" {
config.OggDir = "."
}
}

func validateConfig(config *Config) error {
Expand Down
4 changes: 4 additions & 0 deletions config_example.ini
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ retry_interval_ms = 100
# aws の場合は IsPartial が false, gcp の場合は IsFinal が true の場合の最終的な結果のみを返す指定
final_result_only = true

# 受信した音声データを Ogg ファイルで保存するかどうかです
enable_ogg_file_output = false
# Ogg ファイルの保存先ディレクトリです
ogg_dir = "."

# 採用する結果の信頼スコアの最小値です(aws 指定時のみ有効)
# minimum_confidence_score が 0.0 の場合は信頼スコアによるフィルタリングは無効です
Expand Down
52 changes: 39 additions & 13 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"fmt"
"io"
"net/http"
"os"
"path"
"strings"
"time"

Expand Down Expand Up @@ -41,6 +43,16 @@ func NewSuzuErrorResponse(err error) TranscriptionResult {
}
}

type soraHeader struct {
SoraChannelID string `header:"sora-channel-id"`
SoraSessionID string `header:"sora-session-id"`
// SoraClientID string `header:"sora-client-id"`
SoraConnectionID string `header:"sora-connection-id"`
// SoraAudioCodecType string `header:"sora-audio-codec-type"`
// SoraAudioSampleRate int64 `header:"sora-audio-sample-rate"`
SoraAudioStreamingLanguageCode string `header:"sora-audio-streaming-language-code"`
}

func getServiceHandler(serviceType string, config Config, channelID, connectionID string, sampleRate uint32, channelCount uint16, languageCode string, onResultFunc any) (serviceHandlerInterface, error) {
newHandlerFunc, err := NewServiceHandlerFuncs.get(serviceType)
if err != nil {
Expand All @@ -65,15 +77,7 @@ func (s *Server) createSpeechHandler(serviceType string, onResultFunc func(conte
return echo.NewHTTPError(http.StatusBadRequest)
}

h := struct {
SoraChannelID string `header:"Sora-Channel-Id"`
// SoraSessionID string `header:"sora-session-id"`
// SoraClientID string `header:"sora-client-id"`
SoraConnectionID string `header:"sora-connection-id"`
// SoraAudioCodecType string `header:"sora-audio-codec-type"`
// SoraAudioSampleRate int64 `header:"sora-audio-sample-rate"`
SoraAudioStreamingLanguageCode string `header:"sora-audio-streaming-language-code"`
}{}
h := soraHeader{}
if err := (&echo.DefaultBinder{}).BindHeaders(c, &h); err != nil {
zlog.Error().
Err(err).
Expand Down Expand Up @@ -153,7 +157,7 @@ func (s *Server) createSpeechHandler(serviceType string, onResultFunc func(conte
serviceHandlerCtx, cancelServiceHandler := context.WithCancel(ctx)
defer cancelServiceHandler()

reader, err := serviceHandler.Handle(serviceHandlerCtx, opusCh)
reader, err := serviceHandler.Handle(serviceHandlerCtx, opusCh, h)
if err != nil {
zlog.Error().
Err(err).
Expand Down Expand Up @@ -459,17 +463,39 @@ func readOpus(ctx context.Context, reader io.Reader) chan opusChannel {
return opusCh
}

func opus2ogg(ctx context.Context, opusCh chan opusChannel, sampleRate uint32, channelCount uint16, c Config) io.ReadCloser {
func opus2ogg(ctx context.Context, opusCh chan opusChannel, sampleRate uint32, channelCount uint16, c Config, header soraHeader) (io.ReadCloser, error) {
oggReader, oggWriter := io.Pipe()

writers := []io.Writer{}

var f *os.File
if c.EnableOggFileOutput {
fileName := fmt.Sprintf("%s-%s.ogg", header.SoraSessionID, header.SoraConnectionID)
filePath := path.Join(c.OggDir, fileName)

var err error
f, err = os.Create(filePath)
if err != nil {
return nil, err
}
writers = append(writers, f)
}
writers = append(writers, oggWriter)

multiWriter := io.MultiWriter(writers...)

go func() {
o, err := NewWith(oggWriter, sampleRate, channelCount)
o, err := NewWith(multiWriter, sampleRate, channelCount)
if err != nil {
oggWriter.CloseWithError(err)
return
}
defer o.Close()

if c.EnableOggFileOutput {
o.fd = f
}

for {
select {
case <-ctx.Done():
Expand Down Expand Up @@ -501,7 +527,7 @@ func opus2ogg(ctx context.Context, opusCh chan opusChannel, sampleRate uint32, c
}
}()

return oggReader
return oggReader, nil
}

type opusRequest struct {
Expand Down
165 changes: 165 additions & 0 deletions handler_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
package suzu

import (
"context"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"testing"
"time"

Expand Down Expand Up @@ -310,3 +314,164 @@ func TestReadPacketWithHeader(t *testing.T) {
})
}
}

func TestOggFileWriting(t *testing.T) {
t.Run("success", func(t *testing.T) {
oggDir, err := os.MkdirTemp("", "ogg-")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(oggDir)

c := Config{
EnableOggFileOutput: true,
OggDir: oggDir,
}

header := soraHeader{
SoraChannelID: "ogg-test",
SoraSessionID: "C2TFB1QBDS4WD5SX317SWMJ6FM",
SoraConnectionID: "1X0Z8JXZAD5A93X68M2S9NTC4G",
}

opusCh := make(chan opusChannel)
defer close(opusCh)

sampleRate := uint32(48000)
channelCount := uint16(1)

ctx := context.Background()
reader, err := opus2ogg(ctx, opusCh, sampleRate, channelCount, c, header)
if assert.NoError(t, err) {
assert.NotNil(t, reader)
}
defer reader.Close()

// ファイルへの書き込み待ち
time.Sleep(100 * time.Millisecond)

filename := fmt.Sprintf("%s-%s.ogg", header.SoraSessionID, header.SoraConnectionID)
filePath := filepath.Join(oggDir, filename)
_, err = os.Stat(filePath)
assert.NoError(t, err)

// Ogg ファイルのヘッダーを確認
f, err := os.Open(filePath)
if err != nil {
t.Fatal(err)
}
defer f.Close()

buf := make([]byte, 4)
n, err := f.Read(buf)
assert.NoError(t, err)
assert.Equal(t, []byte(`OggS`), buf[:n])
})

t.Run("disable_ogg_file_output", func(t *testing.T) {
oggDir, err := os.MkdirTemp("", "ogg-")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(oggDir)

c := Config{
EnableOggFileOutput: false,
OggDir: oggDir,
}

header := soraHeader{
SoraChannelID: "ogg-test",
SoraSessionID: "C2TFB1QBDS4WD5SX317SWMJ6FM",
SoraConnectionID: "1X0Z8JXZAD5A93X68M2S9NTC4G",
}

opusCh := make(chan opusChannel)
defer close(opusCh)

sampleRate := uint32(48000)
channelCount := uint16(1)

ctx := context.Background()
reader, err := opus2ogg(ctx, opusCh, sampleRate, channelCount, c, header)
assert.NoError(t, err)
assert.NotNil(t, reader)
defer reader.Close()

filename := fmt.Sprintf("%s-%s.ogg", header.SoraSessionID, header.SoraConnectionID)
filePath := filepath.Join(oggDir, filename)
_, err = os.Stat(filePath)
assert.ErrorIs(t, err, os.ErrNotExist)
})

t.Run("no permission", func(t *testing.T) {
oggDir, err := os.MkdirTemp("", "ogg-")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(oggDir)

// 書き込み権限を剥奪
if err := os.Chmod(oggDir, 0000); err != nil {
t.Fatal(err)
}
defer func() {
if err := os.Chmod(oggDir, 0700); err != nil {
t.Fatal(err)
}
}()

c := Config{
EnableOggFileOutput: true,
OggDir: oggDir,
}

header := soraHeader{
SoraChannelID: "ogg-test",
SoraSessionID: "C2TFB1QBDS4WD5SX317SWMJ6FM",
SoraConnectionID: "1X0Z8JXZAD5A93X68M2S9NTC4G",
}

opusCh := make(chan opusChannel)
defer close(opusCh)

sampleRate := uint32(48000)
channelCount := uint16(1)

ctx := context.Background()
reader, err := opus2ogg(ctx, opusCh, sampleRate, channelCount, c, header)
assert.ErrorIs(t, err, os.ErrPermission)
assert.Nil(t, reader)
})

t.Run("directory does not exist", func(t *testing.T) {
oggDir, err := os.MkdirTemp("", "ogg-")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(oggDir)

c := Config{
EnableOggFileOutput: true,
// 既存のディレクトリ名に 0 を付与して存在しないディレクトリを指定する
OggDir: oggDir + "0",
}

header := soraHeader{
SoraChannelID: "ogg-test",
SoraSessionID: "C2TFB1QBDS4WD5SX317SWMJ6FM",
SoraConnectionID: "1X0Z8JXZAD5A93X68M2S9NTC4G",
}

opusCh := make(chan opusChannel)
defer close(opusCh)

sampleRate := uint32(48000)
channelCount := uint16(1)

ctx := context.Background()
reader, err := opus2ogg(ctx, opusCh, sampleRate, channelCount, c, header)
assert.ErrorIs(t, err, os.ErrNotExist)
assert.Nil(t, reader)
})
}
2 changes: 1 addition & 1 deletion packet_dump_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func (h *PacketDumpHandler) ResetRetryCount() int {
return h.RetryCount
}

func (h *PacketDumpHandler) Handle(ctx context.Context, opusCh chan opusChannel) (*io.PipeReader, error) {
func (h *PacketDumpHandler) Handle(ctx context.Context, opusCh chan opusChannel, header soraHeader) (*io.PipeReader, error) {
c := h.Config
filename := c.DumpFile
channelID := h.ChannelID
Expand Down
2 changes: 1 addition & 1 deletion service_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ var (
)

type serviceHandlerInterface interface {
Handle(context.Context, chan opusChannel) (*io.PipeReader, error)
Handle(context.Context, chan opusChannel, soraHeader) (*io.PipeReader, error)
UpdateRetryCount() int
GetRetryCount() int
ResetRetryCount() int
Expand Down
7 changes: 5 additions & 2 deletions speech_to_text_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,13 @@ func (h *SpeechToTextHandler) ResetRetryCount() int {
return h.RetryCount
}

func (h *SpeechToTextHandler) Handle(ctx context.Context, opusCh chan opusChannel) (*io.PipeReader, error) {
func (h *SpeechToTextHandler) Handle(ctx context.Context, opusCh chan opusChannel, header soraHeader) (*io.PipeReader, error) {
stt := NewSpeechToText(h.Config, h.LanguageCode, int32(h.SampleRate), int32(h.ChannelCount))

packetReader := opus2ogg(ctx, opusCh, h.SampleRate, h.ChannelCount, h.Config)
packetReader, err := opus2ogg(ctx, opusCh, h.SampleRate, h.ChannelCount, h.Config, header)
if err != nil {
return nil, err
}

stream, err := stt.Start(ctx, packetReader)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion test_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func (h *TestHandler) ResetRetryCount() int {
return h.RetryCount
}

func (h *TestHandler) Handle(ctx context.Context, opusCh chan opusChannel) (*io.PipeReader, error) {
func (h *TestHandler) Handle(ctx context.Context, opusCh chan opusChannel, header soraHeader) (*io.PipeReader, error) {
r, w := io.Pipe()

reader := opusChannelToIOReadCloser(ctx, opusCh)
Expand Down

0 comments on commit 1800950

Please sign in to comment.