Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speaker identification WIP #46

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions alembic/versions/b6aff0a993d7_add_person_and_voicesamples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""Add person and voicesamples

Revision ID: b6aff0a993d7
Revises: 33bddba74d25
Create Date: 2024-03-01 08:56:55.205553

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
import sqlmodel


# revision identifiers, used by Alembic.
# `revision` is this migration's own ID and `down_revision` is the migration
# it follows (see "Revision ID" / "Revises" in the module docstring).
revision: str = 'b6aff0a993d7'
down_revision: Union[str, None] = '33bddba74d25'
# No branch labels or cross-branch dependencies are declared for this migration.
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Create the ``person`` and ``voicesample`` tables and link utterances to persons.

    The tables are created *before* ``utterance`` is altered so that the
    ``fk_utterance_person`` foreign key always points at a table that already
    exists. This also makes the function the exact inverse of ``downgrade()``,
    which removes the foreign key first and drops the tables last.
    """
    # `person` must exist first: both `voicesample` and `utterance` reference it.
    op.create_table(
        'person',
        sa.Column('created_at', sa.DateTime(), nullable=False),
        sa.Column('updated_at', sa.DateTime(), nullable=False),
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('first_name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
        sa.Column('last_name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
        sa.PrimaryKeyConstraint('id'),
    )
    op.create_table(
        'voicesample',
        sa.Column('created_at', sa.DateTime(), nullable=False),
        sa.Column('updated_at', sa.DateTime(), nullable=False),
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('filepath', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
        sa.Column('speaker_embeddings', sa.JSON(), nullable=True),
        sa.Column('person_id', sa.Integer(), nullable=True),
        sa.ForeignKeyConstraint(['person_id'], ['person.id'], name='fk_voicesample_person'),
        sa.PrimaryKeyConstraint('id'),
    )

    # Use batch operations to support SQLite ALTER TABLE for adding constraints
    with op.batch_alter_table('utterance', schema=None) as batch_op:
        batch_op.add_column(sa.Column('person_id', sa.Integer(), nullable=True))
        batch_op.create_foreign_key('fk_utterance_person', 'person', ['person_id'], ['id'])
    # NOTE(review): open question from code review - would it make sense to
    # also store the vector embedding of the voice on `utterance`? That would
    # allow (1) showing distinct speakers in the UI even without having the
    # Persons in the DB, and (2) on creating a new person, finding all past
    # instances when that person spoke by fetching the utterances with
    # similar enough voice embeddings.

def downgrade() -> None:
    """Undo this migration: unlink utterances from persons, then drop the new tables."""
    # Batch mode is used so that removing the constraint and column also works on SQLite.
    with op.batch_alter_table('utterance', schema=None) as utterance_batch:
        utterance_batch.drop_constraint('fk_utterance_person', type_='foreignkey')
        utterance_batch.drop_column('person_id')

    # `voicesample` references `person`, so it is dropped first.
    for table_name in ('voicesample', 'person'):
        op.drop_table(table_name)
37 changes: 37 additions & 0 deletions owl/core/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
import subprocess
from alembic import command
from alembic.config import Config
from ..database.database import Database
from ..database.crud import create_person, create_voice_sample
from ..models.schemas import Person, VoiceSample

import click
from rich.console import Console
Expand Down Expand Up @@ -202,6 +205,40 @@ def create_migration(config: Configuration, message: str):

console.log(f"[bold green]Migration script generated with message: '{message}'")

####################################################################################################
# Persons
####################################################################################################

@cli.command()
@add_options(_config_options)
@click.option('--first-name', required=True, help='First name of the person')
@click.option('--last-name', required=True, help='Last name of the person')
@click.option('--voice-sample-path', required=True, help='Path to the voice sample file')
def enroll_speaker(config: Configuration, first_name: str, last_name: str, voice_sample_path: str):
    """Enroll a new person with a voice sample."""
    # The docstring above doubles as the CLI help text, so details live in comments.
    #
    # Steps: validate inputs -> create the Person row -> copy the sample into
    # <voice_sample_directory>/<person id>/ -> create the VoiceSample row.
    # Raises click.ClickException when speaker identification is not configured
    # or the voice sample file does not exist.
    import shutil  # local import: only this command needs it

    console = Console()

    # Both the `speaker_identification` section and its `voice_sample_directory`
    # are optional in the configuration, so check them before dereferencing.
    speaker_config = config.speaker_identification
    if speaker_config is None or not speaker_config.voice_sample_directory:
        raise click.ClickException(
            "speaker_identification.voice_sample_directory must be set in the configuration"
        )

    # Fail before touching the database so a bad path does not leave a
    # Person row behind without any voice sample.
    if not os.path.isfile(voice_sample_path):
        raise click.ClickException(f"Voice sample file not found: '{voice_sample_path}'")

    console.log("[bold green]Enrolling speaker...")

    database = Database(config.database)
    with next(database.get_db()) as db:
        person = create_person(db, Person(first_name=first_name, last_name=last_name))
        # Read the ID while the session is still open.
        person_id = person.id

    # Each person gets a dedicated sub-directory named after their ID.
    sample_directory = os.path.join(speaker_config.voice_sample_directory, str(person_id))
    os.makedirs(sample_directory, exist_ok=True)

    # Keep the original extension (including its leading dot, if any) so a
    # file without an extension does not end up with a trailing ".".
    extension = os.path.splitext(os.path.basename(voice_sample_path))[1]
    sample_file_path = os.path.join(sample_directory, f"{uuid.uuid1().hex}{extension}")

    # Copy first, record second: a failed copy must not leave a VoiceSample
    # row pointing at a file that does not exist.
    shutil.copyfile(voice_sample_path, sample_file_path)

    with next(database.get_db()) as db:
        voice_sample = create_voice_sample(db, VoiceSample(person_id=person_id, filepath=sample_file_path))
        voice_sample_id = voice_sample.id

    console.log(f"[bold green]Enrolled new person: '{person_id} ({voice_sample_id})'")

####################################################################################################
# Server
####################################################################################################
Expand Down
7 changes: 6 additions & 1 deletion owl/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ class StreamingTranscriptionConfiguration(BaseModel):
class AsyncTranscriptionConfiguration(BaseModel):
provider: str

class SpeakerIdentificationConfiguration(BaseModel):
    """Configuration section for speaker identification."""

    # Name of the speaker identification provider (required).
    provider: str
    # Directory for enrolled voice samples; optional, None when not configured.
    voice_sample_directory: str | None = None

class DatabaseConfiguration(BaseModel):
url: str

Expand Down Expand Up @@ -104,4 +108,5 @@ def load_config_yaml(cls, config_filepath: str) -> 'Configuration':
conversation_endpointing: ConversationEndpointingConfiguration
notification: NotificationConfiguration
udp: UDPConfiguration
bing: BingConfiguration | None = None
bing: BingConfiguration | None = None
speaker_identification: SpeakerIdentificationConfiguration | None = None
14 changes: 13 additions & 1 deletion owl/database/crud.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from sqlmodel import SQLModel, Session, select
from ..models.schemas import Transcription, Conversation, Utterance, Location, CaptureSegment, Capture, ConversationState
from ..models.schemas import Transcription, Conversation, Utterance, Location, CaptureSegment, Capture, ConversationState, Person, VoiceSample
from typing import List, Optional
from sqlalchemy.orm import joinedload, selectinload
from sqlalchemy import desc, func, or_
Expand All @@ -8,6 +8,18 @@

logger = logging.getLogger(__name__)

def create_person(db: Session, person: Person) -> Person:
    """Add ``person`` to the session, commit it, and return the refreshed instance."""
    db.add(person)
    db.commit()
    # Reload the instance from the database after the commit.
    db.refresh(person)
    return person

def create_voice_sample(db: Session, voice_sample: VoiceSample) -> VoiceSample:
    """Add ``voice_sample`` to the session, commit it, and return the refreshed instance."""
    db.add(voice_sample)
    db.commit()
    # Reload the instance from the database after the commit.
    db.refresh(voice_sample)
    return voice_sample

def create_utterance(db: Session, utterance: Utterance) -> Utterance:
db.add(utterance)
db.commit()
Expand Down
17 changes: 16 additions & 1 deletion owl/models/schemas.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import List, Optional
from sqlmodel import SQLModel, Field, Relationship
from sqlmodel import SQLModel, Field, Relationship, Column, JSON
from datetime import datetime, timezone
from pydantic import BaseModel
from enum import Enum
Expand Down Expand Up @@ -36,6 +36,8 @@ class Utterance(CreatedAtMixin, table=True):
transcription: "Transcription" = Relationship(back_populates="utterances")

words: List[Word] = Relationship(back_populates="utterance", sa_relationship_kwargs={"cascade": "all, delete-orphan"})
person_id: Optional[int] = Field(default=None, foreign_key="person.id")
person: Optional["Person"] = Relationship(back_populates="utterances")

class Transcription(CreatedAtMixin, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
Expand Down Expand Up @@ -106,6 +108,19 @@ class CaptureSegment(CreatedAtMixin, table=True):

conversation: Optional[Conversation] = Relationship(back_populates="capture_segment_file")

class Person(CreatedAtMixin, table=True):
    """Database table for a person that voice samples and utterances can be linked to."""
    # Primary key; None until a value is assigned.
    id: Optional[int] = Field(default=None, primary_key=True)
    first_name: str
    last_name: str
    # Reverse side of VoiceSample.person.
    voice_samples: List["VoiceSample"] = Relationship(back_populates="person")
    # Reverse side of Utterance.person.
    utterances: List[Utterance] = Relationship(back_populates="person")

class VoiceSample(CreatedAtMixin, table=True):
    """Database table for a voice sample file that may be linked to a person."""
    # Primary key; None until a value is assigned.
    id: Optional[int] = Field(default=None, primary_key=True)
    # Path of the stored sample file (required).
    filepath: str = Field(...)
    # Stored in a JSON column. `default_factory` is used instead of a literal
    # `{}` so that no mutable default object is shared between instances.
    # NOTE(review): a reviewer asked whether, if the dict key is the embedding
    # model, this should be named explicitly `speaker_embeddings_by_model`.
    # Not renamed here because the column name is part of the migration.
    speaker_embeddings: dict = Field(default_factory=dict, sa_column=Column(JSON))
    # Optional link to the person this sample belongs to.
    person_id: Optional[int] = Field(default=None, foreign_key="person.id")
    person: Optional["Person"] = Relationship(back_populates="voice_samples")

# API Response Models
# https://sqlmodel.tiangolo.com/tutorial/fastapi/relationships/#dont-include-all-the-data
Expand Down
4 changes: 4 additions & 0 deletions owl/sample_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ udp:
host: '0.0.0.0'
port: 8001

speaker_identification:
provider: speech_brain
voice_sample_directory: voice_samples

# To enable web search
# bing:
# subscription_key: your_bing_subscription_service_key
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from abc import ABC, abstractmethod
from ....models.schemas import Transcript

class AbstractSpeakerIdentificationService(ABC):
    """Abstract base class that speaker identification services derive from."""

    @abstractmethod
    async def identifiy_speakers(self, transcript: Transcript, persons) -> Transcript:
        """Take a transcript and ``persons`` and return a transcript.

        Subclasses must override this coroutine; the base implementation does
        nothing.

        NOTE(review): the method name is misspelled ("identifiy"). It is kept
        as-is because subclasses override it by name; rename it everywhere in
        one change if it is corrected. ``persons`` is unannotated - presumably
        the known Person records; confirm against the caller.
        """
        ...