Merge pull request #411 from tphakala/refactor-file.go
refactor: reduce complexity of processAudioFile() function
tphakala authored Jan 25, 2025
2 parents b18548f + 6306001 commit 61dd198
Showing 1 changed file with 152 additions and 126 deletions.
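
At a high level, the change extracts the progress display, per-chunk analysis, and worker startup out of processAudioFile() into monitorProgress(), processChunk(), and startWorkers(), promotes the audioChunk type to package level, and moves the chunk counter from a plain int to an int64 accessed through sync/atomic. The standalone sketch below mirrors that coordination pattern in miniature: an atomically updated counter, a ticker-driven monitor goroutine, and a sync.Once-guarded done channel. Names, timings, and counts in the sketch are illustrative and are not taken from internal/analysis/file.go.

// Standalone sketch of the coordination pattern used after the refactor:
// a worker loop bumps a shared int64 counter with sync/atomic, a monitor
// goroutine reads it on a ticker, and a sync.Once-guarded shutdown closes
// the done channel exactly once. All names, timings, and counts here are
// illustrative; they are not taken from internal/analysis/file.go.
package main

import (
	"context"
	"fmt"
	"sync"
	"sync/atomic"
	"time"
)

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	const totalChunks = 20
	var processed int64 // read by the monitor via atomic.LoadInt64

	doneChan := make(chan struct{})
	var doneOnce sync.Once
	shutdown := func() { doneOnce.Do(func() { close(doneChan) }) }
	defer shutdown()

	// Monitor goroutine: report progress every 100ms until done or cancelled.
	go func() {
		ticker := time.NewTicker(100 * time.Millisecond)
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-doneChan:
				return
			case <-ticker.C:
				fmt.Printf("\rprocessed %d/%d chunks", atomic.LoadInt64(&processed), totalChunks)
			}
		}
	}()

	// Stand-in for the analysis workers: bump the shared counter atomically per chunk.
	for i := 0; i < totalChunks; i++ {
		time.Sleep(50 * time.Millisecond) // placeholder for the real per-chunk analysis
		atomic.AddInt64(&processed, 1)
	}

	shutdown()
	fmt.Printf("\rprocessed %d/%d chunks\n", atomic.LoadInt64(&processed), totalChunks)
}
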
278 changes: 152 additions & 126 deletions internal/analysis/file.go
@@ -8,6 +8,7 @@ import (
"path/filepath"
"strings"
"sync"
"sync/atomic"
"time"

"golang.org/x/term"
@@ -141,61 +142,116 @@ func formatProgressLine(filename string, duration time.Duration, chunkCount, tot
baseFormat)
}

// processAudioFile processes the audio file and returns the notes.
func processAudioFile(settings *conf.Settings, audioInfo *myaudio.AudioInfo, ctx context.Context) ([]datastore.Note, error) {
// Calculate total chunks
totalChunks := myaudio.GetTotalChunks(
audioInfo.SampleRate,
audioInfo.TotalSamples,
settings.BirdNET.Overlap,
)

// Define a type for audio chunks with file position
type audioChunk struct {
Data []float32
FilePosition time.Time
}

// Calculate audio duration
duration := time.Duration(float64(audioInfo.TotalSamples) / float64(audioInfo.SampleRate) * float64(time.Second))
// monitorProgress periodically displays analysis progress; processAudioFile runs it in its own goroutine
func monitorProgress(ctx context.Context, doneChan chan struct{}, filename string, duration time.Duration,
totalChunks int, chunkCount *int64, startTime time.Time) {

// Get filename and truncate if necessary
filename := filepath.Base(settings.Input.Path)

startTime := time.Now()
chunkCount := 1
lastChunkCount := 0
lastChunkCount := int64(0)
lastProgressUpdate := startTime

// Moving average window for chunks/sec calculation
const windowSize = 10 // Number of samples to average
chunkRates := make([]float64, 0, windowSize)

// Set number of workers to 1
numWorkers := 1
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()

if settings.Debug {
fmt.Printf("DEBUG: Starting analysis with %d total chunks and %d workers\n", totalChunks, numWorkers)
for {
select {
case <-ctx.Done():
return
case <-doneChan:
return
case <-ticker.C:
currentTime := time.Now()
timeSinceLastUpdate := currentTime.Sub(lastProgressUpdate)

// Get current chunk count atomically
currentCount := atomic.LoadInt64(chunkCount)

// Calculate current chunk rate
chunksProcessed := currentCount - lastChunkCount
currentRate := float64(chunksProcessed) / timeSinceLastUpdate.Seconds()

// Update moving average
if len(chunkRates) >= windowSize {
// Remove oldest value
chunkRates = chunkRates[1:]
}
chunkRates = append(chunkRates, currentRate)

// Calculate average rate
var avgRate float64
if len(chunkRates) > 0 {
sum := 0.0
for _, rate := range chunkRates {
sum += rate
}
avgRate = sum / float64(len(chunkRates))
}

// Update counters for next iteration
lastChunkCount = currentCount
lastProgressUpdate = currentTime

// Get terminal width
width, _, err := term.GetSize(int(os.Stdout.Fd()))
if err != nil {
width = 80 // Default to 80 columns if we can't get terminal width
}

// Format and print the progress line
fmt.Print(formatProgressLine(
filename,
duration,
int(currentCount),
totalChunks,
avgRate,
birdnet.EstimateTimeRemaining(startTime, int(currentCount), totalChunks),
width,
))
}
}
}

// Create buffered channels for processing
chunkChan := make(chan audioChunk, 4)
resultChan := make(chan []datastore.Note, 4)
errorChan := make(chan error, 1)
doneChan := make(chan struct{})
// processChunk handles the processing of a single audio chunk
func processChunk(ctx context.Context, chunk audioChunk, settings *conf.Settings,
resultChan chan<- []datastore.Note, errorChan chan<- error) error {

var allNotes []datastore.Note
notes, err := bn.ProcessChunk(chunk.Data, chunk.FilePosition)
if err != nil {
// Block until we can send the error or context is cancelled
select {
case errorChan <- err:
// Error successfully sent
case <-ctx.Done():
// If context is done while trying to send error, prioritize context error
return ctx.Err()
}
return err
}

// Create a single cancel function to coordinate shutdown
var doneChanClosed sync.Once
shutdown := func() {
doneChanClosed.Do(func() {
close(doneChan)
})
// Filter notes based on included species list
var filteredNotes []datastore.Note
for i := range notes {
if settings.IsSpeciesIncluded(notes[i].ScientificName) {
filteredNotes = append(filteredNotes, notes[i])
}
}
defer shutdown()

// Start worker goroutines for BirdNET analysis
// Block until we can send results or context is cancelled
select {
case <-ctx.Done():
return ctx.Err()
case resultChan <- filteredNotes:
return nil
}
}

// startWorkers initializes and starts the worker goroutines for audio analysis
func startWorkers(ctx context.Context, numWorkers int, chunkChan chan audioChunk,
resultChan chan []datastore.Note, errorChan chan error, settings *conf.Settings) {

for i := 0; i < numWorkers; i++ {
go func(workerID int) {
if settings.Debug {
@@ -213,105 +269,74 @@ select {
select {
case errorChan <- ctx.Err():
default:
// Another goroutine already sent the error
}
return
default:
}

notes, err := bn.ProcessChunk(chunk.Data, chunk.FilePosition)
if err != nil {
if err := processChunk(ctx, chunk, settings, resultChan, errorChan); err != nil {
if settings.Debug {
fmt.Printf("DEBUG: Worker %d encountered error: %v\n", workerID, err)
}
select {
case errorChan <- err:
default:
// Another goroutine already sent an error
}
return
}

// Filter notes based on included species list
var filteredNotes []datastore.Note
for _, note := range notes {
if settings.IsSpeciesIncluded(note.ScientificName) {
filteredNotes = append(filteredNotes, note)
}
}

select {
case <-ctx.Done():
select {
case errorChan <- ctx.Err():
default:
}
return
case resultChan <- filteredNotes:
}
}
}(i)
}
}

// Start progress monitoring goroutine
go func() {
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
// Define audioChunk type at package level since it's used by multiple functions
type audioChunk struct {
Data []float32
FilePosition time.Time
}

for {
select {
case <-ctx.Done():
return
case <-doneChan:
return
case <-ticker.C:
currentTime := time.Now()
timeSinceLastUpdate := currentTime.Sub(lastProgressUpdate)

// Calculate current chunk rate
chunksProcessed := chunkCount - lastChunkCount
currentRate := float64(chunksProcessed) / timeSinceLastUpdate.Seconds()

// Update moving average
if len(chunkRates) >= windowSize {
// Remove oldest value
chunkRates = chunkRates[1:]
}
chunkRates = append(chunkRates, currentRate)

// Calculate average rate
var avgRate float64
if len(chunkRates) > 0 {
sum := 0.0
for _, rate := range chunkRates {
sum += rate
}
avgRate = sum / float64(len(chunkRates))
}
func processAudioFile(settings *conf.Settings, audioInfo *myaudio.AudioInfo, ctx context.Context) ([]datastore.Note, error) {
// Calculate total chunks
totalChunks := myaudio.GetTotalChunks(
audioInfo.SampleRate,
audioInfo.TotalSamples,
settings.BirdNET.Overlap,
)

// Update counters for next iteration
lastChunkCount = chunkCount
lastProgressUpdate = currentTime
// Calculate audio duration
duration := time.Duration(float64(audioInfo.TotalSamples) / float64(audioInfo.SampleRate) * float64(time.Second))

// Get terminal width
width, _, err := term.GetSize(int(os.Stdout.Fd()))
if err != nil {
width = 80 // Default to 80 columns if we can't get terminal width
}
// Get filename and truncate if necessary
filename := filepath.Base(settings.Input.Path)

// Format and print the progress line
fmt.Print(formatProgressLine(
filename,
duration,
chunkCount,
totalChunks,
avgRate,
birdnet.EstimateTimeRemaining(startTime, chunkCount, totalChunks),
width,
))
}
}
}()
startTime := time.Now()
var chunkCount int64 = 1

// Set number of workers to 1
numWorkers := 1

if settings.Debug {
fmt.Printf("DEBUG: Starting analysis with %d total chunks and %d workers\n", totalChunks, numWorkers)
}

// Create buffered channels for processing
chunkChan := make(chan audioChunk, 4)
resultChan := make(chan []datastore.Note, 4)
errorChan := make(chan error, 1)
doneChan := make(chan struct{})

var allNotes []datastore.Note

// Create a single cancel function to coordinate shutdown
var doneChanClosed sync.Once
shutdown := func() {
doneChanClosed.Do(func() {
close(doneChan)
})
}
defer shutdown()

// Start worker goroutines
startWorkers(ctx, numWorkers, chunkChan, resultChan, errorChan, settings)

// Start progress monitoring goroutine
go monitorProgress(ctx, doneChan, filename, duration, totalChunks, &chunkCount, startTime)

// Start result collector goroutine
var processingError error
@@ -332,11 +357,11 @@ func processAudioFile(settings *conf.Settings, audioInfo *myaudio.AudioInfo, ctx
return
case notes := <-resultChan:
if settings.Debug {
fmt.Printf("DEBUG: Received results for chunk #%d\n", chunkCount)
fmt.Printf("DEBUG: Received results for chunk #%d\n", atomic.LoadInt64(&chunkCount))
}
allNotes = append(allNotes, notes...)
chunkCount++
if chunkCount > totalChunks {
atomic.AddInt64(&chunkCount, 1)
if atomic.LoadInt64(&chunkCount) > int64(totalChunks) {
return
}
case err := <-errorChan:
Expand All @@ -352,7 +377,8 @@ func processAudioFile(settings *conf.Settings, audioInfo *myaudio.AudioInfo, ctx
fmt.Printf("DEBUG: Timeout waiting for chunk %d results\n", i)
}
processingErrorMutex.Lock()
processingError = fmt.Errorf("timeout waiting for analysis results (processed %d/%d chunks)", chunkCount, totalChunks)
currentCount := atomic.LoadInt64(&chunkCount)
processingError = fmt.Errorf("timeout waiting for analysis results (processed %d/%d chunks)", currentCount, totalChunks)
processingErrorMutex.Unlock()
return
}
