Services

ASR Service module that handles all AI interactions.

ASRAsyncService

Bases: ASRService

ASR Service module for async endpoints.

Source code in src/wordcab_transcribe/services/asr_service.py
class ASRAsyncService(ASRService):
    """ASR Service module for async endpoints."""

    def __init__(
        self,
        whisper_model: str,
        compute_type: str,
        window_lengths: List[float],
        shift_lengths: List[float],
        multiscale_weights: List[float],
        extra_languages: Union[List[str], None],
        extra_languages_model_paths: Union[List[str], None],
        transcribe_server_urls: Union[List[str], None],
        diarize_server_urls: Union[List[str], None],
        debug_mode: bool,
    ) -> None:
        """
        Initialize the ASRAsyncService class.

        Args:
            whisper_model (str):
                The path to the whisper model.
            compute_type (str):
                The compute type to use for inference.
            window_lengths (List[float]):
                The window lengths to use for diarization.
            shift_lengths (List[float]):
                The shift lengths to use for diarization.
            multiscale_weights (List[float]):
                The multiscale weights to use for diarization.
            extra_languages (Union[List[str], None]):
                The list of extra languages to support.
            extra_languages_model_paths (Union[List[str], None]):
                The list of paths to the extra language models.
            transcribe_server_urls (Union[List[str], None]):
                The list of URLs to the remote transcription servers.
            diarize_server_urls (Union[List[str], None]):
                The list of URLs to the remote diarization servers.
            debug_mode (bool):
                Whether to run in debug mode.
        """
        super().__init__()

        self.whisper_model: str = whisper_model
        self.compute_type: str = compute_type
        self.window_lengths: List[float] = window_lengths
        self.shift_lengths: List[float] = shift_lengths
        self.multiscale_weights: List[float] = multiscale_weights
        self.extra_languages: Union[List[str], None] = extra_languages
        self.extra_languages_model_paths: Union[List[str], None] = (
            extra_languages_model_paths
        )

        self.local_services: LocalServiceRegistry = LocalServiceRegistry()
        self.remote_services: RemoteServiceRegistry = RemoteServiceRegistry()
        self.dual_channel_transcribe_options: dict = {
            "beam_size": 5,
            "patience": 1,
            "length_penalty": 1,
            "suppress_blank": False,
            "word_timestamps": True,
            "temperature": 0.0,
        }

        if transcribe_server_urls is not None:
            logger.info(
                "You provided URLs for remote transcription server, no local model will"
                " be used."
            )
            self.remote_services.transcription = RemoteServiceConfig(
                use_remote=True,
                url_handler=URLService(remote_urls=transcribe_server_urls),
            )
        else:
            logger.info(
                "You did not provide URLs for remote transcription server, local model"
                " will be used."
            )
            self.create_transcription_local_service()

        if diarize_server_urls is not None:
            logger.info(
                "You provided URLs for remote diarization server, no local model will"
                " be used."
            )
            self.remote_services.diarization = RemoteServiceConfig(
                use_remote=True,
                url_handler=URLService(remote_urls=diarize_server_urls),
            )
        else:
            logger.info(
                "You did not provide URLs for remote diarization server, local model"
                " will be used."
            )
            self.create_diarization_local_service()

        self.debug_mode = debug_mode

    def create_transcription_local_service(self) -> None:
        """Create a local transcription service."""
        self.local_services.transcription = TranscribeService(
            model_path=self.whisper_model,
            compute_type=self.compute_type,
            device=self.device,
            device_index=self.device_index,
            extra_languages=self.extra_languages,
            extra_languages_model_paths=self.extra_languages_model_paths,
        )

    def create_diarization_local_service(self) -> None:
        """Create a local diarization service."""
        self.local_services.diarization = DiarizeService(
            device=self.device,
            device_index=self.device_index,
            window_lengths=self.window_lengths,
            shift_lengths=self.shift_lengths,
            multiscale_weights=self.multiscale_weights,
        )

    def create_local_service(
        self, task: Literal["transcription", "diarization"]
    ) -> None:
        """Create a local service."""
        if task == "transcription":
            self.create_transcription_local_service()
        elif task == "diarization":
            self.create_diarization_local_service()
        else:
            raise NotImplementedError("No task specified.")

    async def inference_warmup(self) -> None:
        """Warmup the GPU by loading the models."""
        sample_path = Path(__file__).parent.parent / "assets/warmup_sample.wav"

        for gpu_index in self.gpu_handler.device_index:
            logger.info(f"Warmup GPU {gpu_index}.")
            await self.process_input(
                filepath=str(sample_path),
                offset_start=None,
                offset_end=None,
                num_speakers=1,
                diarization=True,
                multi_channel=False,
                source_lang="en",
                timestamps_format="s",
                vocab=None,
                word_timestamps=False,
                internal_vad=False,
                repetition_penalty=1.0,
                compression_ratio_threshold=2.4,
                log_prob_threshold=-1.0,
                no_speech_threshold=0.6,
                condition_on_previous_text=True,
            )

    async def process_input(  # noqa: C901
        self,
        filepath: Union[str, List[str]],
        offset_start: Union[float, None],
        offset_end: Union[float, None],
        num_speakers: int,
        diarization: bool,
        multi_channel: bool,
        source_lang: str,
        timestamps_format: str,
        vocab: Union[List[str], None],
        word_timestamps: bool,
        internal_vad: bool,
        repetition_penalty: float,
        compression_ratio_threshold: float,
        log_prob_threshold: float,
        no_speech_threshold: float,
        condition_on_previous_text: bool,
    ) -> Union[Tuple[List[dict], ProcessTimes, float], Exception]:
        """Process the input request and return the results.

        This method will create a task and add it to the appropriate queues.
        All tasks are added to the transcription queue, but will be added to the
        diarization queues only if the user requested it.
        Each step will be processed asynchronously and the results will be returned
        and stored in separated keys in the task dictionary.

        Args:
            filepath (Union[str, List[str]]):
                Path to the audio file or list of paths to the audio files to process.
            offset_start (Union[float, None]):
                The start time of the audio file to process.
            offset_end (Union[float, None]):
                The end time of the audio file to process.
            num_speakers (int):
                The number of oracle speakers.
            diarization (bool):
                Whether to do diarization or not.
            multi_channel (bool):
                Whether to do multi-channel diarization or not.
            source_lang (str):
                Source language of the audio file.
            timestamps_format (str):
                Timestamps format to use.
            vocab (Union[List[str], None]):
                List of words to use for the vocabulary.
            word_timestamps (bool):
                Whether to return word timestamps or not.
            internal_vad (bool):
                Whether to use faster-whisper's VAD or not.
            repetition_penalty (float):
                The repetition penalty to use for the beam search.
            compression_ratio_threshold (float):
                If the gzip compression ratio is above this value, treat as failed.
            log_prob_threshold (float):
                If the average log probability over sampled tokens is below this value, treat as failed.
            no_speech_threshold (float):
                If the no_speech probability is higher than this value AND the average log probability
                over sampled tokens is below `log_prob_threshold`, consider the segment as silent.
            condition_on_previous_text (bool):
                If True, the previous output of the model is provided as a prompt for the next window;
                disabling may make the text inconsistent across windows, but the model becomes less prone
                to getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.

        Returns:
            Union[Tuple[List[dict], ProcessTimes, float], Exception]:
                The results of the ASR pipeline or an exception if something went wrong.
                Results are returned as a tuple of the following:
                    * List[dict]: The final results of the ASR pipeline.
                    * ProcessTimes: The process times of each step of the ASR pipeline.
                    * float: The audio duration
        """
        if isinstance(filepath, list):
            audio, durations = [], []
            for path in filepath:
                _audio, _duration = read_audio(
                    path, offset_start=offset_start, offset_end=offset_end
                )

                audio.append(_audio)
                durations.append(_duration)

            duration = sum(durations) / len(durations)

        else:
            audio, duration = read_audio(
                filepath, offset_start=offset_start, offset_end=offset_end
            )

        gpu_index = None
        if self.remote_services.transcription.use_remote is True:
            _url = await self.remote_services.transcription.next_url()
            transcription_execution = RemoteExecution(url=_url)
        else:
            gpu_index = await self.gpu_handler.get_device()
            transcription_execution = LocalExecution(index=gpu_index)

        if diarization and multi_channel is False:
            if self.remote_services.diarization.use_remote is True:
                _url = await self.remote_services.diarization.next_url()
                diarization_execution = RemoteExecution(url=_url)
            else:
                if gpu_index is None:
                    gpu_index = await self.gpu_handler.get_device()

                diarization_execution = LocalExecution(index=gpu_index)
        else:
            diarization_execution = None

        task = ASRTask(
            audio=audio,
            diarization=DiarizationTask(
                execution=diarization_execution, num_speakers=num_speakers
            ),
            duration=duration,
            multi_channel=multi_channel,
            offset_start=offset_start,
            post_processing=PostProcessingTask(),
            process_times=ProcessTimes(),
            timestamps_format=timestamps_format,
            transcription=TranscriptionTask(
                execution=transcription_execution,
                options=TranscriptionOptions(
                    compression_ratio_threshold=compression_ratio_threshold,
                    condition_on_previous_text=condition_on_previous_text,
                    internal_vad=internal_vad,
                    log_prob_threshold=log_prob_threshold,
                    no_speech_threshold=no_speech_threshold,
                    repetition_penalty=repetition_penalty,
                    source_lang=source_lang,
                    vocab=vocab,
                ),
            ),
            word_timestamps=word_timestamps,
        )

        try:
            start_process_time = time.time()

            transcription_task = self.process_transcription(task, self.debug_mode)
            diarization_task = self.process_diarization(task, self.debug_mode)

            await asyncio.gather(transcription_task, diarization_task)

            if isinstance(task.diarization.result, ProcessException):
                return task.diarization.result

            if (
                diarization
                and task.diarization.result is None
                and multi_channel is False
            ):
                # Empty audio early return
                return early_return(duration=duration)

            if isinstance(task.transcription.result, ProcessException):
                return task.transcription.result

            await asyncio.get_event_loop().run_in_executor(
                None,
                self.process_post_processing,
                task,
            )

            if isinstance(task.post_processing.result, ProcessException):
                return task.post_processing.result

            task.process_times.total = time.time() - start_process_time

            return task.post_processing.result, task.process_times, duration

        except Exception as e:
            return e

        finally:
            del task

            if gpu_index is not None:
                self.gpu_handler.release_device(gpu_index)

    async def process_transcription(self, task: ASRTask, debug_mode: bool) -> None:
        """
        Process a task of transcription and update the task with the result.

        Args:
            task (ASRTask): The task and its parameters.
            debug_mode (bool): Whether to run in debug mode or not.

        Returns:
            None: The task is updated with the result.
        """
        try:
            if isinstance(task.transcription.execution, LocalExecution):
                out = await time_and_tell_async(
                    lambda: self.local_services.transcription(
                        task.audio,
                        model_index=task.transcription.execution.index,
                        suppress_blank=False,
                        word_timestamps=True,
                        **task.transcription.options.model_dump(),
                    ),
                    func_name="transcription",
                    debug_mode=debug_mode,
                )
                result, process_time = out

            elif isinstance(task.transcription.execution, RemoteExecution):
                if isinstance(task.audio, list):
                    ts = [
                        TensorShare.from_dict({"audio": a}, backend=Backend.TORCH)
                        for a in task.audio
                    ]
                else:
                    ts = TensorShare.from_dict(
                        {"audio": task.audio}, backend=Backend.TORCH
                    )

                data = TranscribeRequest(
                    audio=ts,
                    **task.transcription.options.model_dump(),
                )
                out = await time_and_tell_async(
                    self.remote_transcription(
                        url=task.transcription.execution.url,
                        data=data,
                    ),
                    func_name="transcription",
                    debug_mode=debug_mode,
                )
                result, process_time = out

            else:
                raise NotImplementedError("No execution method specified.")

        except Exception as e:
            result = ProcessException(
                source=ExceptionSource.transcription,
                message=f"Error in transcription: {e}\n{traceback.format_exc()}",
            )
            process_time = None

        finally:
            task.process_times.transcription = process_time
            task.transcription.result = result

        return None

    async def process_diarization(self, task: ASRTask, debug_mode: bool) -> None:
        """
        Process a task of diarization.

        Args:
            task (ASRTask): The task and its parameters.
            debug_mode (bool): Whether to run in debug mode or not.

        Returns:
            None: The task is updated with the result.
        """
        try:
            if isinstance(task.diarization.execution, LocalExecution):
                out = await time_and_tell_async(
                    lambda: self.local_services.diarization(
                        waveform=task.audio,
                        audio_duration=task.duration,
                        oracle_num_speakers=task.diarization.num_speakers,
                        model_index=task.diarization.execution.index,
                        vad_service=self.local_services.vad,
                    ),
                    func_name="diarization",
                    debug_mode=debug_mode,
                )
                result, process_time = out

            elif isinstance(task.diarization.execution, RemoteExecution):
                ts = TensorShare.from_dict({"audio": task.audio}, backend=Backend.TORCH)

                data = DiarizationRequest(
                    audio=ts,
                    duration=task.duration,
                    num_speakers=task.diarization.num_speakers,
                )
                out = await time_and_tell_async(
                    self.remote_diarization(
                        url=task.diarization.execution.url,
                        data=data,
                    ),
                    func_name="diarization",
                    debug_mode=debug_mode,
                )
                result, process_time = out

            elif task.diarization.execution is None:
                result = None
                process_time = None

            else:
                raise NotImplementedError("No execution method specified.")

        except Exception as e:
            result = ProcessException(
                source=ExceptionSource.diarization,
                message=f"Error in diarization: {e}\n{traceback.format_exc()}",
            )
            process_time = None

        finally:
            task.process_times.diarization = process_time
            task.diarization.result = result

        return None

    def process_post_processing(self, task: ASRTask) -> None:
        """
        Process a task of post-processing.

        Args:
            task (ASRTask): The task and its parameters.

        Returns:
            None: The task is updated with the result.
        """
        try:
            total_post_process_time = 0

            if task.multi_channel:
                utterances, process_time = time_and_tell(
                    self.local_services.post_processing.multi_channel_speaker_mapping(
                        task.transcription.result
                    ),
                    func_name="multi_channel_speaker_mapping",
                    debug_mode=self.debug_mode,
                )
                total_post_process_time += process_time

            else:
                formatted_segments, process_time = time_and_tell(
                    format_segments(
                        transcription_output=task.transcription.result,
                    ),
                    func_name="format_segments",
                    debug_mode=self.debug_mode,
                )
                total_post_process_time += process_time

                if task.diarization.execution is not None:
                    utterances, process_time = time_and_tell(
                        self.local_services.post_processing.single_channel_speaker_mapping(
                            transcript_segments=formatted_segments,
                            speaker_timestamps=task.diarization.result,
                            word_timestamps=task.word_timestamps,
                        ),
                        func_name="single_channel_speaker_mapping",
                        debug_mode=self.debug_mode,
                    )
                    total_post_process_time += process_time
                else:
                    utterances = formatted_segments

            final_utterances, process_time = time_and_tell(
                self.local_services.post_processing.final_processing_before_returning(
                    utterances=utterances,
                    offset_start=task.offset_start,
                    timestamps_format=task.timestamps_format,
                    word_timestamps=task.word_timestamps,
                ),
                func_name="final_processing_before_returning",
                debug_mode=self.debug_mode,
            )
            total_post_process_time += process_time

        except Exception as e:
            final_utterances = ProcessException(
                source=ExceptionSource.post_processing,
                message=f"Error in post-processing: {e}\n{traceback.format_exc()}",
            )
            total_post_process_time = None

        finally:
            task.process_times.post_processing = total_post_process_time
            task.post_processing.result = final_utterances

        return None

    async def remote_transcription(
        self,
        url: str,
        data: TranscribeRequest,
    ) -> TranscriptionOutput:
        """Remote transcription method."""
        async with aiohttp.ClientSession() as session:
            async with session.post(
                url=f"{url}/api/v1/transcribe",
                data=data.model_dump_json(),
                headers={"Content-Type": "application/json"},
            ) as response:
                if response.status != 200:
                    raise Exception(response.status)
                else:
                    return TranscriptionOutput(**await response.json())

    async def remote_diarization(
        self,
        url: str,
        data: DiarizationRequest,
    ) -> DiarizationOutput:
        """Remote diarization method."""
        async with aiohttp.ClientSession() as session:
            async with session.post(
                url=f"{url}/api/v1/diarize",
                data=data.model_dump_json(),
                headers={"Content-Type": "application/json"},
            ) as response:
                if response.status != 200:
                    r = await response.json()
                    raise Exception(r["detail"])
                else:
                    return DiarizationOutput(**await response.json())

    async def get_url(
        self, task: Literal["transcription", "diarization"]
    ) -> Union[List[str], ProcessException]:
        """Get the list of remote URLs."""
        logger.info(self.remote_services.transcription)
        logger.info(self.remote_services.diarization)
        try:
            selected_task = getattr(self.remote_services, task)
            logger.info(selected_task)
            # Case 1: We are not using remote task
            if selected_task.use_remote is False:
                return ProcessException(
                    source=ExceptionSource.get_url,
                    message=f"You are not using remote {task}.",
                )
            # Case 2: We are using remote task
            else:
                return selected_task.get_urls()

        except Exception as e:
            return ProcessException(
                source=ExceptionSource.get_url,
                message=f"Error in getting URL: {e}\n{traceback.format_exc()}",
            )

    async def add_url(self, data: UrlSchema) -> Union[UrlSchema, ProcessException]:
        """Add a remote URL to the list of URLs."""
        try:
            selected_task = getattr(self.remote_services, data.task)
            # Case 1: We are not using remote task yet
            if selected_task.use_remote is False:
                setattr(
                    self.remote_services,
                    data.task,
                    RemoteServiceConfig(
                        use_remote=True,
                        url_handler=URLService(remote_urls=[str(data.url)]),
                    ),
                )
                setattr(self.local_services, data.task, None)
            # Case 2: We are already using remote task
            else:
                await selected_task.add_url(str(data.url))

        except Exception as e:
            return ProcessException(
                source=ExceptionSource.add_url,
                message=f"Error in adding URL: {e}\n{traceback.format_exc()}",
            )

        return data

    async def remove_url(self, data: UrlSchema) -> Union[UrlSchema, ProcessException]:
        """Remove a remote URL from the list of URLs."""
        try:
            selected_task = getattr(self.remote_services, data.task)
            # Case 1: We are not using remote task
            if selected_task.use_remote is False:
                raise ValueError(f"You are not using remote {data.task}.")
            # Case 2: We are using remote task
            else:
                await selected_task.remove_url(str(data.url))
                if selected_task.get_queue_size() == 0:
                    # No more remote URLs, switch to local service
                    self.create_local_service(task=data.task)
                    setattr(self.remote_services, data.task, RemoteServiceConfig())

            return data

        except Exception as e:
            return ProcessException(
                source=ExceptionSource.remove_url,
                message=f"Error in removing URL: {e}\n{traceback.format_exc()}",
            )

__init__(whisper_model, compute_type, window_lengths, shift_lengths, multiscale_weights, extra_languages, extra_languages_model_paths, transcribe_server_urls, diarize_server_urls, debug_mode)

Initialize the ASRAsyncService class.

Parameters:

whisper_model (str): The path to the whisper model. [required]
compute_type (str): The compute type to use for inference. [required]
window_lengths (List[float]): The window lengths to use for diarization. [required]
shift_lengths (List[float]): The shift lengths to use for diarization. [required]
multiscale_weights (List[float]): The multiscale weights to use for diarization. [required]
extra_languages (Union[List[str], None]): The list of extra languages to support. [required]
extra_languages_model_paths (Union[List[str], None]): The list of paths to the extra language models. [required]
transcribe_server_urls (Union[List[str], None]): The list of URLs to the remote transcription servers. [required]
diarize_server_urls (Union[List[str], None]): The list of URLs to the remote diarization servers. [required]
debug_mode (bool): Whether to run in debug mode. [required]
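
A minimal instantiation sketch, assuming a local-only setup; the model path, compute type, and diarization scales below are illustrative values, not project defaults.

from wordcab_transcribe.services.asr_service import ASRAsyncService

# Illustrative values only; consult the project settings for real defaults.
asr = ASRAsyncService(
    whisper_model="large-v3",            # whisper model name or path (assumed value)
    compute_type="float16",              # compute type for inference (assumed value)
    window_lengths=[1.5, 1.25, 1.0],     # diarization window lengths (assumed values)
    shift_lengths=[0.75, 0.625, 0.5],    # diarization shift lengths (assumed values)
    multiscale_weights=[1.0, 1.0, 1.0],  # diarization multiscale weights (assumed values)
    extra_languages=None,                # no extra language models
    extra_languages_model_paths=None,
    transcribe_server_urls=None,         # None -> a local transcription service is created
    diarize_server_urls=None,            # None -> a local diarization service is created
    debug_mode=False,
)
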
Source code in src/wordcab_transcribe/services/asr_service.py
def __init__(
    self,
    whisper_model: str,
    compute_type: str,
    window_lengths: List[float],
    shift_lengths: List[float],
    multiscale_weights: List[float],
    extra_languages: Union[List[str], None],
    extra_languages_model_paths: Union[List[str], None],
    transcribe_server_urls: Union[List[str], None],
    diarize_server_urls: Union[List[str], None],
    debug_mode: bool,
) -> None:
    """
    Initialize the ASRAsyncService class.

    Args:
        whisper_model (str):
            The path to the whisper model.
        compute_type (str):
            The compute type to use for inference.
        window_lengths (List[float]):
            The window lengths to use for diarization.
        shift_lengths (List[float]):
            The shift lengths to use for diarization.
        multiscale_weights (List[float]):
            The multiscale weights to use for diarization.
        extra_languages (Union[List[str], None]):
            The list of extra languages to support.
        extra_languages_model_paths (Union[List[str], None]):
            The list of paths to the extra language models.
        transcribe_server_urls (Union[List[str], None]):
            The list of URLs to the remote transcription servers.
        diarize_server_urls (Union[List[str], None]):
            The list of URLs to the remote diarization servers.
        debug_mode (bool):
            Whether to run in debug mode.
    """
    super().__init__()

    self.whisper_model: str = whisper_model
    self.compute_type: str = compute_type
    self.window_lengths: List[float] = window_lengths
    self.shift_lengths: List[float] = shift_lengths
    self.multiscale_weights: List[float] = multiscale_weights
    self.extra_languages: Union[List[str], None] = extra_languages
    self.extra_languages_model_paths: Union[List[str], None] = (
        extra_languages_model_paths
    )

    self.local_services: LocalServiceRegistry = LocalServiceRegistry()
    self.remote_services: RemoteServiceRegistry = RemoteServiceRegistry()
    self.dual_channel_transcribe_options: dict = {
        "beam_size": 5,
        "patience": 1,
        "length_penalty": 1,
        "suppress_blank": False,
        "word_timestamps": True,
        "temperature": 0.0,
    }

    if transcribe_server_urls is not None:
        logger.info(
            "You provided URLs for remote transcription server, no local model will"
            " be used."
        )
        self.remote_services.transcription = RemoteServiceConfig(
            use_remote=True,
            url_handler=URLService(remote_urls=transcribe_server_urls),
        )
    else:
        logger.info(
            "You did not provide URLs for remote transcription server, local model"
            " will be used."
        )
        self.create_transcription_local_service()

    if diarize_server_urls is not None:
        logger.info(
            "You provided URLs for remote diarization server, no local model will"
            " be used."
        )
        self.remote_services.diarization = RemoteServiceConfig(
            use_remote=True,
            url_handler=URLService(remote_urls=diarize_server_urls),
        )
    else:
        logger.info(
            "You did not provide URLs for remote diarization server, local model"
            " will be used."
        )
        self.create_diarization_local_service()

    self.debug_mode = debug_mode

add_url(data) async

Add a remote URL to the list of URLs.

Source code in src/wordcab_transcribe/services/asr_service.py
async def add_url(self, data: UrlSchema) -> Union[UrlSchema, ProcessException]:
    """Add a remote URL to the list of URLs."""
    try:
        selected_task = getattr(self.remote_services, data.task)
        # Case 1: We are not using remote task yet
        if selected_task.use_remote is False:
            setattr(
                self.remote_services,
                data.task,
                RemoteServiceConfig(
                    use_remote=True,
                    url_handler=URLService(remote_urls=[str(data.url)]),
                ),
            )
            setattr(self.local_services, data.task, None)
        # Case 2: We are already using remote task
        else:
            await selected_task.add_url(str(data.url))

    except Exception as e:
        return ProcessException(
            source=ExceptionSource.add_url,
            message=f"Error in adding URL: {e}\n{traceback.format_exc()}",
        )

    return data
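
A hedged usage sketch: registering a remote transcription server at runtime. It assumes UrlSchema exposes task and url fields, as the attribute access in add_url suggests; the import paths and the URL are illustrative.

# `asr` is an ASRAsyncService instance; UrlSchema/ProcessException import paths
# are not shown on this page and are assumed here.
schema = UrlSchema(task="transcription", url="http://10.0.0.12:5001")

result = await asr.add_url(schema)
if isinstance(result, ProcessException):
    logger.error(result.message)  # registration failed
else:
    # Transcription now routes to the remote server; the local model slot was
    # cleared (local_services.transcription set to None) by add_url above.
    urls = await asr.get_url("transcription")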

create_diarization_local_service()

Create a local diarization service.

Source code in src/wordcab_transcribe/services/asr_service.py
def create_diarization_local_service(self) -> None:
    """Create a local diarization service."""
    self.local_services.diarization = DiarizeService(
        device=self.device,
        device_index=self.device_index,
        window_lengths=self.window_lengths,
        shift_lengths=self.shift_lengths,
        multiscale_weights=self.multiscale_weights,
    )

create_local_service(task)

Create a local service.

Source code in src/wordcab_transcribe/services/asr_service.py
def create_local_service(
    self, task: Literal["transcription", "diarization"]
) -> None:
    """Create a local service."""
    if task == "transcription":
        self.create_transcription_local_service()
    elif task == "diarization":
        self.create_diarization_local_service()
    else:
        raise NotImplementedError("No task specified.")

create_transcription_local_service()

Create a local transcription service.

Source code in src/wordcab_transcribe/services/asr_service.py
def create_transcription_local_service(self) -> None:
    """Create a local transcription service."""
    self.local_services.transcription = TranscribeService(
        model_path=self.whisper_model,
        compute_type=self.compute_type,
        device=self.device,
        device_index=self.device_index,
        extra_languages=self.extra_languages,
        extra_languages_model_paths=self.extra_languages_model_paths,
    )

get_url(task) async

Get the list of remote URLs.

Source code in src/wordcab_transcribe/services/asr_service.py
async def get_url(
    self, task: Literal["transcription", "diarization"]
) -> Union[List[str], ProcessException]:
    """Get the list of remote URLs."""
    logger.info(self.remote_services.transcription)
    logger.info(self.remote_services.diarization)
    try:
        selected_task = getattr(self.remote_services, task)
        logger.info(selected_task)
        # Case 1: We are not using remote task
        if selected_task.use_remote is False:
            return ProcessException(
                source=ExceptionSource.get_url,
                message=f"You are not using remote {task}.",
            )
        # Case 2: We are using remote task
        else:
            return selected_task.get_urls()

    except Exception as e:
        return ProcessException(
            source=ExceptionSource.get_url,
            message=f"Error in getting URL: {e}\n{traceback.format_exc()}",
        )

inference_warmup() async

Warmup the GPU by loading the models.

Source code in src/wordcab_transcribe/services/asr_service.py
async def inference_warmup(self) -> None:
    """Warmup the GPU by loading the models."""
    sample_path = Path(__file__).parent.parent / "assets/warmup_sample.wav"

    for gpu_index in self.gpu_handler.device_index:
        logger.info(f"Warmup GPU {gpu_index}.")
        await self.process_input(
            filepath=str(sample_path),
            offset_start=None,
            offset_end=None,
            num_speakers=1,
            diarization=True,
            multi_channel=False,
            source_lang="en",
            timestamps_format="s",
            vocab=None,
            word_timestamps=False,
            internal_vad=False,
            repetition_penalty=1.0,
            compression_ratio_threshold=2.4,
            log_prob_threshold=-1.0,
            no_speech_threshold=0.6,
            condition_on_previous_text=True,
        )

process_diarization(task, debug_mode) async

Process a task of diarization.

Parameters:

task (ASRTask): The task and its parameters. [required]
debug_mode (bool): Whether to run in debug mode or not. [required]

Returns:

None: The task is updated with the result.
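
The method never raises to its caller: the outcome (a diarization result, None, or a ProcessException) is written back onto the task. A hedged sketch of how a caller consumes that contract, mirroring process_input below:

# Sketch only: errors surface on the task object instead of being raised.
await asr.process_diarization(task, debug_mode=False)

if isinstance(task.diarization.result, ProcessException):
    handle_error(task.diarization.result)   # hypothetical error handler
# task.process_times.diarization holds the measured time, or None on error.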

Source code in src/wordcab_transcribe/services/asr_service.py
async def process_diarization(self, task: ASRTask, debug_mode: bool) -> None:
    """
    Process a task of diarization.

    Args:
        task (ASRTask): The task and its parameters.
        debug_mode (bool): Whether to run in debug mode or not.

    Returns:
        None: The task is updated with the result.
    """
    try:
        if isinstance(task.diarization.execution, LocalExecution):
            out = await time_and_tell_async(
                lambda: self.local_services.diarization(
                    waveform=task.audio,
                    audio_duration=task.duration,
                    oracle_num_speakers=task.diarization.num_speakers,
                    model_index=task.diarization.execution.index,
                    vad_service=self.local_services.vad,
                ),
                func_name="diarization",
                debug_mode=debug_mode,
            )
            result, process_time = out

        elif isinstance(task.diarization.execution, RemoteExecution):
            ts = TensorShare.from_dict({"audio": task.audio}, backend=Backend.TORCH)

            data = DiarizationRequest(
                audio=ts,
                duration=task.duration,
                num_speakers=task.diarization.num_speakers,
            )
            out = await time_and_tell_async(
                self.remote_diarization(
                    url=task.diarization.execution.url,
                    data=data,
                ),
                func_name="diarization",
                debug_mode=debug_mode,
            )
            result, process_time = out

        elif task.diarization.execution is None:
            result = None
            process_time = None

        else:
            raise NotImplementedError("No execution method specified.")

    except Exception as e:
        result = ProcessException(
            source=ExceptionSource.diarization,
            message=f"Error in diarization: {e}\n{traceback.format_exc()}",
        )
        process_time = None

    finally:
        task.process_times.diarization = process_time
        task.diarization.result = result

    return None

process_input(filepath, offset_start, offset_end, num_speakers, diarization, multi_channel, source_lang, timestamps_format, vocab, word_timestamps, internal_vad, repetition_penalty, compression_ratio_threshold, log_prob_threshold, no_speech_threshold, condition_on_previous_text) async

Process the input request and return the results.

This method creates a task and adds it to the appropriate queues. All tasks are added to the transcription queue, but they are added to the diarization queue only if the user requests diarization. Each step is processed asynchronously, and the results are stored under separate keys in the task dictionary.

Parameters:

filepath (Union[str, List[str]]): Path to the audio file or list of paths to the audio files to process. [required]
offset_start (Union[float, None]): The start time of the audio file to process. [required]
offset_end (Union[float, None]): The end time of the audio file to process. [required]
num_speakers (int): The number of oracle speakers. [required]
diarization (bool): Whether to do diarization or not. [required]
multi_channel (bool): Whether to do multi-channel diarization or not. [required]
source_lang (str): Source language of the audio file. [required]
timestamps_format (str): Timestamps format to use. [required]
vocab (Union[List[str], None]): List of words to use for the vocabulary. [required]
word_timestamps (bool): Whether to return word timestamps or not. [required]
internal_vad (bool): Whether to use faster-whisper's VAD or not. [required]
repetition_penalty (float): The repetition penalty to use for the beam search. [required]
compression_ratio_threshold (float): If the gzip compression ratio is above this value, treat as failed. [required]
log_prob_threshold (float): If the average log probability over sampled tokens is below this value, treat as failed. [required]
no_speech_threshold (float): If the no_speech probability is higher than this value AND the average log probability over sampled tokens is below log_prob_threshold, consider the segment as silent. [required]
condition_on_previous_text (bool): If True, the previous output of the model is provided as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop, such as repetition looping or timestamps going out of sync. [required]

Returns:

Union[Tuple[List[dict], ProcessTimes, float], Exception]: The results of the ASR pipeline or an exception if something went wrong. Results are returned as a tuple of the following:
    * List[dict]: The final results of the ASR pipeline.
    * ProcessTimes: The process times of each step of the ASR pipeline.
    * float: The audio duration.
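
A hedged end-to-end sketch of calling process_input and unpacking the result; the file path and option values are illustrative, not defaults.

# Illustrative call; "audio.wav" and the option values are placeholders.
result = await asr.process_input(
    filepath="audio.wav",
    offset_start=None,
    offset_end=None,
    num_speakers=1,                 # oracle speaker count (illustrative)
    diarization=True,
    multi_channel=False,
    source_lang="en",
    timestamps_format="s",
    vocab=None,
    word_timestamps=True,
    internal_vad=False,
    repetition_penalty=1.0,
    compression_ratio_threshold=2.4,
    log_prob_threshold=-1.0,
    no_speech_threshold=0.6,
    condition_on_previous_text=True,
)

if isinstance(result, Exception):
    raise result                    # transcription, diarization, or post-processing failed
utterances, process_times, audio_duration = result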

Source code in src/wordcab_transcribe/services/asr_service.py
async def process_input(  # noqa: C901
    self,
    filepath: Union[str, List[str]],
    offset_start: Union[float, None],
    offset_end: Union[float, None],
    num_speakers: int,
    diarization: bool,
    multi_channel: bool,
    source_lang: str,
    timestamps_format: str,
    vocab: Union[List[str], None],
    word_timestamps: bool,
    internal_vad: bool,
    repetition_penalty: float,
    compression_ratio_threshold: float,
    log_prob_threshold: float,
    no_speech_threshold: float,
    condition_on_previous_text: bool,
) -> Union[Tuple[List[dict], ProcessTimes, float], Exception]:
    """Process the input request and return the results.

    This method will create a task and add it to the appropriate queues.
    All tasks are added to the transcription queue, but will be added to the
    diarization queues only if the user requested it.
    Each step will be processed asynchronously and the results will be returned
    and stored in separated keys in the task dictionary.

    Args:
        filepath (Union[str, List[str]]):
            Path to the audio file or list of paths to the audio files to process.
        offset_start (Union[float, None]):
            The start time of the audio file to process.
        offset_end (Union[float, None]):
            The end time of the audio file to process.
        num_speakers (int):
            The number of oracle speakers.
        diarization (bool):
            Whether to do diarization or not.
        multi_channel (bool):
            Whether to do multi-channel diarization or not.
        source_lang (str):
            Source language of the audio file.
        timestamps_format (str):
            Timestamps format to use.
        vocab (Union[List[str], None]):
            List of words to use for the vocabulary.
        word_timestamps (bool):
            Whether to return word timestamps or not.
        internal_vad (bool):
            Whether to use faster-whisper's VAD or not.
        repetition_penalty (float):
            The repetition penalty to use for the beam search.
        compression_ratio_threshold (float):
            If the gzip compression ratio is above this value, treat as failed.
        log_prob_threshold (float):
            If the average log probability over sampled tokens is below this value, treat as failed.
        no_speech_threshold (float):
            If the no_speech probability is higher than this value AND the average log probability
            over sampled tokens is below `log_prob_threshold`, consider the segment as silent.
        condition_on_previous_text (bool):
            If True, the previous output of the model is provided as a prompt for the next window;
            disabling may make the text inconsistent across windows, but the model becomes less prone
            to getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.

    Returns:
        Union[Tuple[List[dict], ProcessTimes, float], Exception]:
            The results of the ASR pipeline or an exception if something went wrong.
            Results are returned as a tuple of the following:
                * List[dict]: The final results of the ASR pipeline.
                * ProcessTimes: The process times of each step of the ASR pipeline.
                * float: The audio duration
    """
    if isinstance(filepath, list):
        audio, durations = [], []
        for path in filepath:
            _audio, _duration = read_audio(
                path, offset_start=offset_start, offset_end=offset_end
            )

            audio.append(_audio)
            durations.append(_duration)

        duration = sum(durations) / len(durations)

    else:
        audio, duration = read_audio(
            filepath, offset_start=offset_start, offset_end=offset_end
        )

    gpu_index = None
    if self.remote_services.transcription.use_remote is True:
        _url = await self.remote_services.transcription.next_url()
        transcription_execution = RemoteExecution(url=_url)
    else:
        gpu_index = await self.gpu_handler.get_device()
        transcription_execution = LocalExecution(index=gpu_index)

    if diarization and multi_channel is False:
        if self.remote_services.diarization.use_remote is True:
            _url = await self.remote_services.diarization.next_url()
            diarization_execution = RemoteExecution(url=_url)
        else:
            if gpu_index is None:
                gpu_index = await self.gpu_handler.get_device()

            diarization_execution = LocalExecution(index=gpu_index)
    else:
        diarization_execution = None

    task = ASRTask(
        audio=audio,
        diarization=DiarizationTask(
            execution=diarization_execution, num_speakers=num_speakers
        ),
        duration=duration,
        multi_channel=multi_channel,
        offset_start=offset_start,
        post_processing=PostProcessingTask(),
        process_times=ProcessTimes(),
        timestamps_format=timestamps_format,
        transcription=TranscriptionTask(
            execution=transcription_execution,
            options=TranscriptionOptions(
                compression_ratio_threshold=compression_ratio_threshold,
                condition_on_previous_text=condition_on_previous_text,
                internal_vad=internal_vad,
                log_prob_threshold=log_prob_threshold,
                no_speech_threshold=no_speech_threshold,
                repetition_penalty=repetition_penalty,
                source_lang=source_lang,
                vocab=vocab,
            ),
        ),
        word_timestamps=word_timestamps,
    )

    try:
        start_process_time = time.time()

        transcription_task = self.process_transcription(task, self.debug_mode)
        diarization_task = self.process_diarization(task, self.debug_mode)

        await asyncio.gather(transcription_task, diarization_task)

        if isinstance(task.diarization.result, ProcessException):
            return task.diarization.result

        if (
            diarization
            and task.diarization.result is None
            and multi_channel is False
        ):
            # Empty audio early return
            return early_return(duration=duration)

        if isinstance(task.transcription.result, ProcessException):
            return task.transcription.result

        await asyncio.get_event_loop().run_in_executor(
            None,
            self.process_post_processing,
            task,
        )

        if isinstance(task.post_processing.result, ProcessException):
            return task.post_processing.result

        task.process_times.total = time.time() - start_process_time

        return task.post_processing.result, task.process_times, duration

    except Exception as e:
        return e

    finally:
        del task

        if gpu_index is not None:
            self.gpu_handler.release_device(gpu_index)

process_post_processing(task)

Process a task of post-processing.

Parameters:

task (ASRTask): The task and its parameters. [required]

Returns:

None: The task is updated with the result.

Source code in src/wordcab_transcribe/services/asr_service.py
def process_post_processing(self, task: ASRTask) -> None:
    """
    Process a task of post-processing.

    Args:
        task (ASRTask): The task and its parameters.

    Returns:
        None: The task is updated with the result.
    """
    try:
        total_post_process_time = 0

        if task.multi_channel:
            utterances, process_time = time_and_tell(
                self.local_services.post_processing.multi_channel_speaker_mapping(
                    task.transcription.result
                ),
                func_name="multi_channel_speaker_mapping",
                debug_mode=self.debug_mode,
            )
            total_post_process_time += process_time

        else:
            formatted_segments, process_time = time_and_tell(
                format_segments(
                    transcription_output=task.transcription.result,
                ),
                func_name="format_segments",
                debug_mode=self.debug_mode,
            )
            total_post_process_time += process_time

            if task.diarization.execution is not None:
                utterances, process_time = time_and_tell(
                    self.local_services.post_processing.single_channel_speaker_mapping(
                        transcript_segments=formatted_segments,
                        speaker_timestamps=task.diarization.result,
                        word_timestamps=task.word_timestamps,
                    ),
                    func_name="single_channel_speaker_mapping",
                    debug_mode=self.debug_mode,
                )
                total_post_process_time += process_time
            else:
                utterances = formatted_segments

        final_utterances, process_time = time_and_tell(
            self.local_services.post_processing.final_processing_before_returning(
                utterances=utterances,
                offset_start=task.offset_start,
                timestamps_format=task.timestamps_format,
                word_timestamps=task.word_timestamps,
            ),
            func_name="final_processing_before_returning",
            debug_mode=self.debug_mode,
        )
        total_post_process_time += process_time

    except Exception as e:
        final_utterances = ProcessException(
            source=ExceptionSource.post_processing,
            message=f"Error in post-processing: {e}\n{traceback.format_exc()}",
        )
        total_post_process_time = None

    finally:
        task.process_times.post_processing = total_post_process_time
        task.post_processing.result = final_utterances

    return None

process_transcription(task, debug_mode) async

Process a task of transcription and update the task with the result.

Parameters:

task (ASRTask): The task and its parameters. [required]
debug_mode (bool): Whether to run in debug mode or not. [required]

Returns:

None: The task is updated with the result.

Source code in src/wordcab_transcribe/services/asr_service.py
async def process_transcription(self, task: ASRTask, debug_mode: bool) -> None:
    """
    Process a task of transcription and update the task with the result.

    Args:
        task (ASRTask): The task and its parameters.
        debug_mode (bool): Whether to run in debug mode or not.

    Returns:
        None: The task is updated with the result.
    """
    try:
        if isinstance(task.transcription.execution, LocalExecution):
            out = await time_and_tell_async(
                lambda: self.local_services.transcription(
                    task.audio,
                    model_index=task.transcription.execution.index,
                    suppress_blank=False,
                    word_timestamps=True,
                    **task.transcription.options.model_dump(),
                ),
                func_name="transcription",
                debug_mode=debug_mode,
            )
            result, process_time = out

        elif isinstance(task.transcription.execution, RemoteExecution):
            if isinstance(task.audio, list):
                ts = [
                    TensorShare.from_dict({"audio": a}, backend=Backend.TORCH)
                    for a in task.audio
                ]
            else:
                ts = TensorShare.from_dict(
                    {"audio": task.audio}, backend=Backend.TORCH
                )

            data = TranscribeRequest(
                audio=ts,
                **task.transcription.options.model_dump(),
            )
            out = await time_and_tell_async(
                self.remote_transcription(
                    url=task.transcription.execution.url,
                    data=data,
                ),
                func_name="transcription",
                debug_mode=debug_mode,
            )
            result, process_time = out

        else:
            raise NotImplementedError("No execution method specified.")

    except Exception as e:
        result = ProcessException(
            source=ExceptionSource.transcription,
            message=f"Error in transcription: {e}\n{traceback.format_exc()}",
        )
        process_time = None

    finally:
        task.process_times.transcription = process_time
        task.transcription.result = result

    return None

remote_diarization(url, data) async

Remote diarization method.

Source code in src/wordcab_transcribe/services/asr_service.py
async def remote_diarization(
    self,
    url: str,
    data: DiarizationRequest,
) -> DiarizationOutput:
    """Remote diarization method."""
    async with aiohttp.ClientSession() as session:
        async with session.post(
            url=f"{url}/api/v1/diarize",
            data=data.model_dump_json(),
            headers={"Content-Type": "application/json"},
        ) as response:
            if response.status != 200:
                r = await response.json()
                raise Exception(r["detail"])
            else:
                return DiarizationOutput(**await response.json())

remote_transcription(url, data) async

Remote transcription method.

Source code in src/wordcab_transcribe/services/asr_service.py
async def remote_transcription(
    self,
    url: str,
    data: TranscribeRequest,
) -> TranscriptionOutput:
    """Remote transcription method."""
    async with aiohttp.ClientSession() as session:
        async with session.post(
            url=f"{url}/api/v1/transcribe",
            data=data.model_dump_json(),
            headers={"Content-Type": "application/json"},
        ) as response:
            if response.status != 200:
                raise Exception(response.status)
            else:
                return TranscriptionOutput(**await response.json())

remove_url(data) async

Remove a remote URL from the list of URLs.

Source code in src/wordcab_transcribe/services/asr_service.py
async def remove_url(self, data: UrlSchema) -> Union[UrlSchema, ProcessException]:
    """Remove a remote URL from the list of URLs."""
    try:
        selected_task = getattr(self.remote_services, data.task)
        # Case 1: We are not using remote task
        if selected_task.use_remote is False:
            raise ValueError(f"You are not using remote {data.task}.")
        # Case 2: We are using remote task
        else:
            await selected_task.remove_url(str(data.url))
            if selected_task.get_queue_size() == 0:
                # No more remote URLs, switch to local service
                self.create_local_service(task=data.task)
                setattr(self.remote_services, data.task, RemoteServiceConfig())

        return data

    except Exception as e:
        return ProcessException(
            source=ExceptionSource.remove_url,
            message=f"Error in removing URL: {e}\n{traceback.format_exc()}",
        )

ASRDiarizationOnly

Bases: ASRService

ASR Service module for diarization-only endpoint.

Source code in src/wordcab_transcribe/services/asr_service.py
class ASRDiarizationOnly(ASRService):
    """ASR Service module for diarization-only endpoint."""

    def __init__(
        self,
        window_lengths: List[int],
        shift_lengths: List[int],
        multiscale_weights: List[float],
        debug_mode: bool,
    ) -> None:
        """Initialize the ASRDiarizationOnly class."""
        super().__init__()

        self.diarization_service = DiarizeService(
            device=self.device,
            device_index=self.device_index,
            window_lengths=window_lengths,
            shift_lengths=shift_lengths,
            multiscale_weights=multiscale_weights,
        )
        self.vad_service = VadService()
        self.debug_mode = debug_mode

    async def inference_warmup(self) -> None:
        """Warmup the GPU by doing one inference."""
        sample_audio = Path(__file__).parent.parent / "assets/warmup_sample.wav"

        audio, duration = read_audio(str(sample_audio))
        ts = TensorShare.from_dict({"audio": audio}, backend=Backend.TORCH)

        data = DiarizationRequest(
            audio=ts,
            duration=duration,
            num_speakers=1,
        )

        for gpu_index in self.gpu_handler.device_index:
            logger.info(f"Warmup GPU {gpu_index}.")
            await self.process_input(data=data)

    async def process_input(self, data: DiarizationRequest) -> DiarizationOutput:
        """
        Process the input data and return the results as a list of segments.

        Args:
            data (DiarizationRequest):
                The input data to process.

        Returns:
            DiarizationOutput:
                The results of the ASR pipeline.
        """
        gpu_index = await self.gpu_handler.get_device()

        try:
            result = self.diarization_service(
                waveform=data.audio,
                audio_duration=data.duration,
                oracle_num_speakers=data.num_speakers,
                model_index=gpu_index,
                vad_service=self.vad_service,
            )

        except Exception as e:
            result = ProcessException(
                source=ExceptionSource.diarization,
                message=f"Error in diarization: {e}\n{traceback.format_exc()}",
            )

        finally:
            self.gpu_handler.release_device(gpu_index)

        return result

__init__(window_lengths, shift_lengths, multiscale_weights, debug_mode)

Initialize the ASRDiarizationOnly class.

Source code in src/wordcab_transcribe/services/asr_service.py
def __init__(
    self,
    window_lengths: List[int],
    shift_lengths: List[int],
    multiscale_weights: List[float],
    debug_mode: bool,
) -> None:
    """Initialize the ASRDiarizationOnly class."""
    super().__init__()

    self.diarization_service = DiarizeService(
        device=self.device,
        device_index=self.device_index,
        window_lengths=window_lengths,
        shift_lengths=shift_lengths,
        multiscale_weights=multiscale_weights,
    )
    self.vad_service = VadService()
    self.debug_mode = debug_mode

inference_warmup() async

Warmup the GPU by doing one inference.

Source code in src/wordcab_transcribe/services/asr_service.py
async def inference_warmup(self) -> None:
    """Warmup the GPU by doing one inference."""
    sample_audio = Path(__file__).parent.parent / "assets/warmup_sample.wav"

    audio, duration = read_audio(str(sample_audio))
    ts = TensorShare.from_dict({"audio": audio}, backend=Backend.TORCH)

    data = DiarizationRequest(
        audio=ts,
        duration=duration,
        num_speakers=1,
    )

    for gpu_index in self.gpu_handler.device_index:
        logger.info(f"Warmup GPU {gpu_index}.")
        await self.process_input(data=data)

process_input(data) async

Process the input data and return the results as a list of segments.

Parameters:

    data (DiarizationRequest): The input data to process. Required.

Returns:

    DiarizationOutput: The results of the ASR pipeline.

Source code in src/wordcab_transcribe/services/asr_service.py
async def process_input(self, data: DiarizationRequest) -> DiarizationOutput:
    """
    Process the input data and return the results as a list of segments.

    Args:
        data (DiarizationRequest):
            The input data to process.

    Returns:
        DiarizationOutput:
            The results of the ASR pipeline.
    """
    gpu_index = await self.gpu_handler.get_device()

    try:
        result = self.diarization_service(
            waveform=data.audio,
            audio_duration=data.duration,
            oracle_num_speakers=data.num_speakers,
            model_index=gpu_index,
            vad_service=self.vad_service,
        )

    except Exception as e:
        result = ProcessException(
            source=ExceptionSource.diarization,
            message=f"Error in diarization: {e}\n{traceback.format_exc()}",
        )

    finally:
        self.gpu_handler.release_device(gpu_index)

    return result

ASRLiveService

Bases: ASRService

ASR Service module for live endpoints.

Source code in src/wordcab_transcribe/services/asr_service.py
class ASRLiveService(ASRService):
    """ASR Service module for live endpoints."""

    def __init__(self, whisper_model: str, compute_type: str, debug_mode: bool) -> None:
        """Initialize the ASRLiveService class."""
        super().__init__()

        self.transcription_service = TranscribeService(
            model_path=whisper_model,
            compute_type=compute_type,
            device=self.device,
            device_index=self.device_index,
        )
        self.debug_mode = debug_mode

    async def inference_warmup(self) -> None:
        """Warmup the GPU by loading the models."""
        sample_audio = Path(__file__).parent.parent / "assets/warmup_sample.wav"
        with open(sample_audio, "rb") as audio_file:
            async for _ in self.process_input(
                data=audio_file.read(),
                source_lang="en",
            ):
                pass

    async def process_input(self, data: bytes, source_lang: str) -> Iterable[dict]:
        """
        Process the input data and yield the results as dictionaries.

        Args:
            data (bytes):
                The raw audio bytes to process.
            source_lang (str):
                The source language of the audio data.

        Yields:
            Iterable[dict]: The results of the ASR pipeline.
        """
        gpu_index = await self.gpu_handler.get_device()

        try:
            waveform, _ = read_audio(data)

            async for result in self.transcription_service.async_live_transcribe(
                audio=waveform, source_lang=source_lang, model_index=gpu_index
            ):
                yield result

        except Exception as e:
            logger.error(
                f"Error in transcription gpu {gpu_index}: {e}\n{traceback.format_exc()}"
            )

        finally:
            self.gpu_handler.release_device(gpu_index)

__init__(whisper_model, compute_type, debug_mode)

Initialize the ASRLiveService class.

Source code in src/wordcab_transcribe/services/asr_service.py
def __init__(self, whisper_model: str, compute_type: str, debug_mode: bool) -> None:
    """Initialize the ASRLiveService class."""
    super().__init__()

    self.transcription_service = TranscribeService(
        model_path=whisper_model,
        compute_type=compute_type,
        device=self.device,
        device_index=self.device_index,
    )
    self.debug_mode = debug_mode

inference_warmup() async

Warmup the GPU by loading the models.

Source code in src/wordcab_transcribe/services/asr_service.py
async def inference_warmup(self) -> None:
    """Warmup the GPU by loading the models."""
    sample_audio = Path(__file__).parent.parent / "assets/warmup_sample.wav"
    with open(sample_audio, "rb") as audio_file:
        async for _ in self.process_input(
            data=audio_file.read(),
            source_lang="en",
        ):
            pass

process_input(data, source_lang) async

Process the input data and yield the results as dictionaries.

Parameters:

    data (bytes): The raw audio bytes to process. Required.
    source_lang (str): The source language of the audio data. Required.

Yields:

    Iterable[dict]: The results of the ASR pipeline.

Source code in src/wordcab_transcribe/services/asr_service.py
async def process_input(self, data: bytes, source_lang: str) -> Iterable[dict]:
    """
    Process the input data and yield the results as dictionaries.

    Args:
        data (bytes):
            The raw audio bytes to process.
        source_lang (str):
            The source language of the audio data.

    Yields:
        Iterable[dict]: The results of the ASR pipeline.
    """
    gpu_index = await self.gpu_handler.get_device()

    try:
        waveform, _ = read_audio(data)

        async for result in self.transcription_service.async_live_transcribe(
            audio=waveform, source_lang=source_lang, model_index=gpu_index
        ):
            yield result

    except Exception as e:
        logger.error(
            f"Error in transcription gpu {gpu_index}: {e}\n{traceback.format_exc()}"
        )

    finally:
        self.gpu_handler.release_device(gpu_index)
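
Since process_input is an async generator, callers consume partial results with async for, mirroring inference_warmup above. The snippet below is only a sketch: the model name, compute type, and audio file are placeholder assumptions, and it requires the transcription models to be available locally.

import asyncio

from wordcab_transcribe.services.asr_service import ASRLiveService


async def main() -> None:
    # Placeholder model/compute values; use whatever your deployment is configured with.
    service = ASRLiveService(
        whisper_model="large-v2",
        compute_type="float16",
        debug_mode=False,
    )

    with open("sample.wav", "rb") as audio_file:  # hypothetical audio file
        async for result in service.process_input(
            data=audio_file.read(),
            source_lang="en",
        ):
            print(result)


asyncio.run(main())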

ASRService

Bases: ABC

Base ASR Service module that handles all AI interactions and batch processing.

Source code in src/wordcab_transcribe/services/asr_service.py
class ASRService(ABC):
    """Base ASR Service module that handle all AI interactions and batch processing."""

    def __init__(self) -> None:
        """Initialize the ASR Service.

        This class is not meant to be instantiated. Use the subclasses instead.
        """
        self.device = (
            "cuda" if torch.cuda.is_available() else "cpu"
        )  # Do we have a GPU? If so, use it!
        self.num_gpus = torch.cuda.device_count() if self.device == "cuda" else 0
        logger.info(f"NVIDIA GPUs available: {self.num_gpus}")

        if self.num_gpus > 1 and self.device == "cuda":
            self.device_index = list(range(self.num_gpus))
        else:
            self.device_index = [0]

        self.gpu_handler = GPUService(
            device=self.device, device_index=self.device_index
        )

    @abstractmethod
    async def process_input(self) -> None:
        """Process the input request by creating a task and adding it to the appropriate queues."""
        raise NotImplementedError("This method should be implemented in subclasses.")

__init__()

Initialize the ASR Service.

This class is not meant to be instantiated. Use the subclasses instead.

Source code in src/wordcab_transcribe/services/asr_service.py
def __init__(self) -> None:
    """Initialize the ASR Service.

    This class is not meant to be instantiated. Use the subclasses instead.
    """
    self.device = (
        "cuda" if torch.cuda.is_available() else "cpu"
    )  # Do we have a GPU? If so, use it!
    self.num_gpus = torch.cuda.device_count() if self.device == "cuda" else 0
    logger.info(f"NVIDIA GPUs available: {self.num_gpus}")

    if self.num_gpus > 1 and self.device == "cuda":
        self.device_index = list(range(self.num_gpus))
    else:
        self.device_index = [0]

    self.gpu_handler = GPUService(
        device=self.device, device_index=self.device_index
    )

process_input() abstractmethod async

Process the input request by creating a task and adding it to the appropriate queues.

Source code in src/wordcab_transcribe/services/asr_service.py
@abstractmethod
async def process_input(self) -> None:
    """Process the input request by creating a task and adding it to the appropriate queues."""
    raise NotImplementedError("This method should be implemented in subclasses.")

ASRTask

Bases: BaseModel

ASR Task model.

Source code in src/wordcab_transcribe/services/asr_service.py
class ASRTask(BaseModel):
    """ASR Task model."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    audio: Union[torch.Tensor, List[torch.Tensor]]
    diarization: "DiarizationTask"
    duration: float
    multi_channel: bool
    offset_start: Union[float, None]
    post_processing: "PostProcessingTask"
    process_times: ProcessTimes
    timestamps_format: Timestamps
    transcription: "TranscriptionTask"
    word_timestamps: bool

ASRTranscriptionOnly

Bases: ASRService

ASR Service module for transcription-only endpoint.

Source code in src/wordcab_transcribe/services/asr_service.py
class ASRTranscriptionOnly(ASRService):
    """ASR Service module for transcription-only endpoint."""

    def __init__(
        self,
        whisper_model: str,
        compute_type: str,
        extra_languages: Union[List[str], None],
        extra_languages_model_paths: Union[List[str], None],
        debug_mode: bool,
    ) -> None:
        """Initialize the ASRTranscriptionOnly class."""
        super().__init__()

        self.transcription_service = TranscribeService(
            model_path=whisper_model,
            compute_type=compute_type,
            device=self.device,
            device_index=self.device_index,
            extra_languages=extra_languages,
            extra_languages_model_paths=extra_languages_model_paths,
        )
        self.debug_mode = debug_mode

    async def inference_warmup(self) -> None:
        """Warmup the GPU by doing one inference."""
        sample_audio = Path(__file__).parent.parent / "assets/warmup_sample.wav"

        audio, _ = read_audio(str(sample_audio))
        ts = TensorShare.from_dict({"audio": audio}, backend=Backend.TORCH)

        data = TranscribeRequest(
            audio=ts,
            source_lang="en",
            compression_ratio_threshold=2.4,
            condition_on_previous_text=True,
            internal_vad=False,
            log_prob_threshold=-1.0,
            no_speech_threshold=0.6,
            repetition_penalty=1.0,
            vocab=None,
        )

        for gpu_index in self.gpu_handler.device_index:
            logger.info(f"Warmup GPU {gpu_index}.")
            await self.process_input(data=data)

    async def process_input(
        self, data: TranscribeRequest
    ) -> Union[TranscriptionOutput, List[TranscriptionOutput]]:
        """
        Process the input data and return the results as a list of segments.

        Args:
            data (TranscribeRequest):
                The input data to process.

        Returns:
            Union[TranscriptionOutput, List[TranscriptionOutput]]:
                The results of the ASR pipeline.
        """
        gpu_index = await self.gpu_handler.get_device()

        try:
            result = self.transcription_service(
                audio=data.audio,
                source_lang=data.source_lang,
                model_index=gpu_index,
                suppress_blank=False,
                word_timestamps=True,
                compression_ratio_threshold=data.compression_ratio_threshold,
                condition_on_previous_text=data.condition_on_previous_text,
                internal_vad=data.internal_vad,
                log_prob_threshold=data.log_prob_threshold,
                repetition_penalty=data.repetition_penalty,
                no_speech_threshold=data.no_speech_threshold,
                vocab=data.vocab,
            )

        except Exception as e:
            result = ProcessException(
                source=ExceptionSource.transcription,
                message=f"Error in transcription: {e}\n{traceback.format_exc()}",
            )

        finally:
            self.gpu_handler.release_device(gpu_index)

        return result

__init__(whisper_model, compute_type, extra_languages, extra_languages_model_paths, debug_mode)

Initialize the ASRTranscriptionOnly class.

Source code in src/wordcab_transcribe/services/asr_service.py
def __init__(
    self,
    whisper_model: str,
    compute_type: str,
    extra_languages: Union[List[str], None],
    extra_languages_model_paths: Union[List[str], None],
    debug_mode: bool,
) -> None:
    """Initialize the ASRTranscriptionOnly class."""
    super().__init__()

    self.transcription_service = TranscribeService(
        model_path=whisper_model,
        compute_type=compute_type,
        device=self.device,
        device_index=self.device_index,
        extra_languages=extra_languages,
        extra_languages_model_paths=extra_languages_model_paths,
    )
    self.debug_mode = debug_mode

inference_warmup() async

Warmup the GPU by doing one inference.

Source code in src/wordcab_transcribe/services/asr_service.py
async def inference_warmup(self) -> None:
    """Warmup the GPU by doing one inference."""
    sample_audio = Path(__file__).parent.parent / "assets/warmup_sample.wav"

    audio, _ = read_audio(str(sample_audio))
    ts = TensorShare.from_dict({"audio": audio}, backend=Backend.TORCH)

    data = TranscribeRequest(
        audio=ts,
        source_lang="en",
        compression_ratio_threshold=2.4,
        condition_on_previous_text=True,
        internal_vad=False,
        log_prob_threshold=-1.0,
        no_speech_threshold=0.6,
        repetition_penalty=1.0,
        vocab=None,
    )

    for gpu_index in self.gpu_handler.device_index:
        logger.info(f"Warmup GPU {gpu_index}.")
        await self.process_input(data=data)

process_input(data) async

Process the input data and return the results as a list of segments.

Parameters:

    data (TranscribeRequest): The input data to process. Required.

Returns:

    Union[TranscriptionOutput, List[TranscriptionOutput]]: The results of the ASR pipeline.

Source code in src/wordcab_transcribe/services/asr_service.py
async def process_input(
    self, data: TranscribeRequest
) -> Union[TranscriptionOutput, List[TranscriptionOutput]]:
    """
    Process the input data and return the results as a list of segments.

    Args:
        data (TranscribeRequest):
            The input data to process.

    Returns:
        Union[TranscriptionOutput, List[TranscriptionOutput]]:
            The results of the ASR pipeline.
    """
    gpu_index = await self.gpu_handler.get_device()

    try:
        result = self.transcription_service(
            audio=data.audio,
            source_lang=data.source_lang,
            model_index=gpu_index,
            suppress_blank=False,
            word_timestamps=True,
            compression_ratio_threshold=data.compression_ratio_threshold,
            condition_on_previous_text=data.condition_on_previous_text,
            internal_vad=data.internal_vad,
            log_prob_threshold=data.log_prob_threshold,
            repetition_penalty=data.repetition_penalty,
            no_speech_threshold=data.no_speech_threshold,
            vocab=data.vocab,
        )

    except Exception as e:
        result = ProcessException(
            source=ExceptionSource.transcription,
            message=f"Error in transcription: {e}\n{traceback.format_exc()}",
        )

    finally:
        self.gpu_handler.release_device(gpu_index)

    return result

DiarizationTask

Bases: BaseModel

Diarization Task model.

Source code in src/wordcab_transcribe/services/asr_service.py
class DiarizationTask(BaseModel):
    """Diarization Task model."""

    execution: Union[LocalExecution, RemoteExecution, None]
    num_speakers: int
    result: Union[ProcessException, DiarizationOutput, None] = None
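
A minimal construction sketch (the values are placeholders): a task starts with result set to None, and the diarization step later fills it in.

from wordcab_transcribe.services.asr_service import DiarizationTask, LocalExecution

# Placeholder values: local GPU index 0, one expected speaker.
task = DiarizationTask(execution=LocalExecution(index=0), num_speakers=1)

print(task.result)  # None until the diarization step stores its output or a ProcessException here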

ExceptionSource

Bases: str, Enum

Exception source enum.

Source code in src/wordcab_transcribe/services/asr_service.py
class ExceptionSource(str, Enum):
    """Exception source enum."""

    add_url = "add_url"
    diarization = "diarization"
    get_url = "get_url"
    post_processing = "post_processing"
    remove_url = "remove_url"
    transcription = "transcription"

LocalExecution

Bases: BaseModel

Local execution model.

Source code in src/wordcab_transcribe/services/asr_service.py
class LocalExecution(BaseModel):
    """Local execution model."""

    index: Union[int, None]

LocalServiceRegistry dataclass

Registry for local services.

Source code in src/wordcab_transcribe/services/asr_service.py
@dataclass
class LocalServiceRegistry:
    """Registry for local services."""

    diarization: Union[DiarizeService, None] = None
    post_processing: PostProcessingService = PostProcessingService()
    transcription: Union[TranscribeService, None] = None
    vad: VadService = VadService()

PostProcessingTask

Bases: BaseModel

Post Processing Task model.

Source code in src/wordcab_transcribe/services/asr_service.py
class PostProcessingTask(BaseModel):
    """Post Processing Task model."""

    result: Union[ProcessException, List[Utterance], None] = None

ProcessException

Bases: BaseModel

Process exception model.

Source code in src/wordcab_transcribe/services/asr_service.py
class ProcessException(BaseModel):
    """Process exception model."""

    source: ExceptionSource
    message: str
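
Pipeline stages report failures by storing a ProcessException on the task instead of raising, as process_transcription above shows. A small sketch (the message is invented) of what such an error object looks like:

from wordcab_transcribe.services.asr_service import ExceptionSource, ProcessException

error = ProcessException(
    source=ExceptionSource.transcription,
    message="Error in transcription: <traceback omitted>",
)

print(error.source.value)  # "transcription"
print(error.message)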

RemoteExecution

Bases: BaseModel

Remote execution model.

Source code in src/wordcab_transcribe/services/asr_service.py
class RemoteExecution(BaseModel):
    """Remote execution model."""

    url: str

RemoteServiceConfig dataclass

Remote service config.

Source code in src/wordcab_transcribe/services/asr_service.py
@dataclass
class RemoteServiceConfig:
    """Remote service config."""

    url_handler: Union[URLService, None] = None
    use_remote: bool = False

    def get_urls(self) -> List[str]:
        """Get the list of URLs."""
        return self.url_handler.get_urls()

    def get_queue_size(self) -> int:
        """Get the queue size."""
        return self.url_handler.get_queue_size()

    async def add_url(self, url: str) -> None:
        """Add a URL to the list of URLs."""
        await self.url_handler.add_url(url)

    async def next_url(self) -> str:
        """Get the next URL."""
        return await self.url_handler.next_url()

    async def remove_url(self, url: str) -> None:
        """Remove a URL from the list of URLs."""
        await self.url_handler.remove_url(url)
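
A short sketch (the URL is hypothetical) of wiring a RemoteServiceConfig to a URLService from concurrency_services, which is how remote transcription or diarization servers are tracked:

from wordcab_transcribe.services.asr_service import RemoteServiceConfig
from wordcab_transcribe.services.concurrency_services import URLService

remote_transcription = RemoteServiceConfig(
    url_handler=URLService(remote_urls=["http://transcribe-host:5001"]),
    use_remote=True,
)

print(remote_transcription.get_urls())        # ['http://transcribe-host:5001']
print(remote_transcription.get_queue_size())  # 1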

add_url(url) async

Add a URL to the list of URLs.

Source code in src/wordcab_transcribe/services/asr_service.py
async def add_url(self, url: str) -> None:
    """Add a URL to the list of URLs."""
    await self.url_handler.add_url(url)

get_queue_size()

Get the queue size.

Source code in src/wordcab_transcribe/services/asr_service.py
def get_queue_size(self) -> int:
    """Get the queue size."""
    return self.url_handler.get_queue_size()

get_urls()

Get the list of URLs.

Source code in src/wordcab_transcribe/services/asr_service.py
def get_urls(self) -> List[str]:
    """Get the list of URLs."""
    return self.url_handler.get_urls()

next_url() async

Get the next URL.

Source code in src/wordcab_transcribe/services/asr_service.py
async def next_url(self) -> str:
    """Get the next URL."""
    return await self.url_handler.next_url()

remove_url(url) async

Remove a URL from the list of URLs.

Source code in src/wordcab_transcribe/services/asr_service.py
async def remove_url(self, url: str) -> None:
    """Remove a URL from the list of URLs."""
    await self.url_handler.remove_url(url)

RemoteServiceRegistry dataclass

Registry for remote services.

Source code in src/wordcab_transcribe/services/asr_service.py
@dataclass
class RemoteServiceRegistry:
    """Registry for remote services."""

    diarization: RemoteServiceConfig = RemoteServiceConfig()
    transcription: RemoteServiceConfig = RemoteServiceConfig()

TranscriptionOptions

Bases: BaseModel

Transcription options model.

Source code in src/wordcab_transcribe/services/asr_service.py
class TranscriptionOptions(BaseModel):
    """Transcription options model."""

    compression_ratio_threshold: float
    condition_on_previous_text: bool
    internal_vad: bool
    log_prob_threshold: float
    no_speech_threshold: float
    repetition_penalty: float
    source_lang: str
    vocab: Union[List[str], None]
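
For reference, here is a TranscriptionOptions instance built with the same values the transcription-only warmup above uses; model_dump() is what process_transcription unpacks into the transcription call.

from wordcab_transcribe.services.asr_service import TranscriptionOptions

options = TranscriptionOptions(
    compression_ratio_threshold=2.4,
    condition_on_previous_text=True,
    internal_vad=False,
    log_prob_threshold=-1.0,
    no_speech_threshold=0.6,
    repetition_penalty=1.0,
    source_lang="en",
    vocab=None,
)

print(options.model_dump())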

TranscriptionTask

Bases: BaseModel

Transcription Task model.

Source code in src/wordcab_transcribe/services/asr_service.py
class TranscriptionTask(BaseModel):
    """Transcription Task model."""

    execution: Union[LocalExecution, RemoteExecution]
    options: TranscriptionOptions
    result: Union[
        ProcessException, TranscriptionOutput, List[TranscriptionOutput], None
    ] = None

GPU service class to handle gpu availability for models.

GPUService

GPU service class to handle gpu availability for models.

Source code in src/wordcab_transcribe/services/concurrency_services.py
class GPUService:
    """GPU service class to handle gpu availability for models."""

    def __init__(self, device: str, device_index: List[int]) -> None:
        """
        Initialize the GPU service.

        Args:
            device (str): Device to use for inference. Can be "cpu" or "cuda".
            device_index (List[int]): Index of the device to use for inference.
        """
        self.device: str = device
        self.device_index: List[int] = device_index

        self.queue = asyncio.Queue(maxsize=len(self.device_index))
        for idx in self.device_index:
            self.queue.put_nowait(idx)

    async def get_device(self) -> int:
        """
        Get the next available device.

        Returns:
            int: Index of the next available device.
        """
        while True:
            try:
                device_index = self.queue.get_nowait()
                return device_index
            except asyncio.QueueEmpty:
                await asyncio.sleep(1.0)

    def release_device(self, device_index: int) -> None:
        """
        Return a device to the available devices list.

        Args:
            device_index (int): Index of the device to add to the available devices list.
        """
        if not any(item == device_index for item in self.queue._queue):
            self.queue.put_nowait(device_index)
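
The acquire/release pattern used throughout the ASR services looks roughly like the sketch below; the device choice and the workload are placeholders.

import asyncio

from wordcab_transcribe.services.concurrency_services import GPUService


async def main() -> None:
    gpu_handler = GPUService(device="cuda", device_index=[0, 1])

    gpu_index = await gpu_handler.get_device()  # waits until an index is free
    try:
        print(f"running inference on cuda:{gpu_index}")  # placeholder workload
    finally:
        gpu_handler.release_device(gpu_index)  # hand the device back to the pool


asyncio.run(main())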

__init__(device, device_index)

Initialize the GPU service.

Parameters:

    device (str): Device to use for inference. Can be "cpu" or "cuda". Required.
    device_index (List[int]): Index of the device to use for inference. Required.
Source code in src/wordcab_transcribe/services/concurrency_services.py
def __init__(self, device: str, device_index: List[int]) -> None:
    """
    Initialize the GPU service.

    Args:
        device (str): Device to use for inference. Can be "cpu" or "cuda".
        device_index (List[int]): Index of the device to use for inference.
    """
    self.device: str = device
    self.device_index: List[int] = device_index

    self.queue = asyncio.Queue(maxsize=len(self.device_index))
    for idx in self.device_index:
        self.queue.put_nowait(idx)

get_device() async

Get the next available device.

Returns:

    int: Index of the next available device.

Source code in src/wordcab_transcribe/services/concurrency_services.py
async def get_device(self) -> int:
    """
    Get the next available device.

    Returns:
        int: Index of the next available device.
    """
    while True:
        try:
            device_index = self.queue.get_nowait()
            return device_index
        except asyncio.QueueEmpty:
            await asyncio.sleep(1.0)

release_device(device_index)

Return a device to the available devices list.

Parameters:

    device_index (int): Index of the device to add to the available devices list. Required.
Source code in src/wordcab_transcribe/services/concurrency_services.py
def release_device(self, device_index: int) -> None:
    """
    Return a device to the available devices list.

    Args:
        device_index (int): Index of the device to add to the available devices list.
    """
    if not any(item == device_index for item in self.queue._queue):
        self.queue.put_nowait(device_index)

URLService

URL service class to handle multiple remote URLs.

Source code in src/wordcab_transcribe/services/concurrency_services.py
class URLService:
    """URL service class to handle multiple remote URLs."""

    def __init__(self, remote_urls: List[str]) -> None:
        """
        Initialize the URL service.

        Args:
            remote_urls (List[str]): List of remote URLs to use.
        """
        self.remote_urls: List[str] = remote_urls
        self._init_queue()

    def _init_queue(self) -> None:
        """Initialize the queue with the available URLs."""
        self.queue = asyncio.Queue(maxsize=len(self.remote_urls))
        for url in self.remote_urls:
            self.queue.put_nowait(url)

    def get_queue_size(self) -> int:
        """
        Get the current queue size.

        Returns:
            int: Current queue size.
        """
        return self.queue.qsize()

    def get_urls(self) -> List[str]:
        """
        Get the list of available URLs.

        Returns:
            List[str]: List of available URLs.
        """
        return self.remote_urls

    async def next_url(self) -> str:
        """
        We use this to iterate equally over the available URLs.

        Returns:
            str: Next available URL.
        """
        url = self.queue.get_nowait()
        # Unlike GPU we don't want to block remote ASR requests.
        # So we re-insert the URL back into the queue after getting it.
        self.queue.put_nowait(url)

        return url

    async def add_url(self, url: str) -> None:
        """
        Add a URL to the pool of available URLs.

        Args:
            url (str): URL to add to the queue.
        """
        if url not in self.remote_urls:
            self.remote_urls.append(url)

            # Re-initialize the queue with the new URL.
            self._init_queue()

    async def remove_url(self, url: str) -> None:
        """
        Remove a URL from the pool of available URLs.

        Args:
            url (str): URL to remove from the queue.
        """
        if url in self.remote_urls:
            self.remote_urls.remove(url)

            # Re-initialize the queue without the removed URL.
            self._init_queue()
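
A minimal round-robin sketch with hypothetical URLs: next_url pops the head of the queue and immediately re-inserts it, so successive calls alternate across the registered servers.

import asyncio

from wordcab_transcribe.services.concurrency_services import URLService


async def main() -> None:
    urls = URLService(remote_urls=["http://host-a:5001", "http://host-b:5001"])

    print(await urls.next_url())  # http://host-a:5001
    print(await urls.next_url())  # http://host-b:5001
    print(await urls.next_url())  # http://host-a:5001 again


asyncio.run(main())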

__init__(remote_urls)

Initialize the URL service.

Parameters:

    remote_urls (List[str]): List of remote URLs to use. Required.
Source code in src/wordcab_transcribe/services/concurrency_services.py
def __init__(self, remote_urls: List[str]) -> None:
    """
    Initialize the URL service.

    Args:
        remote_urls (List[str]): List of remote URLs to use.
    """
    self.remote_urls: List[str] = remote_urls
    self._init_queue()

add_url(url) async

Add a URL to the pool of available URLs.

Parameters:

    url (str): URL to add to the queue. Required.
Source code in src/wordcab_transcribe/services/concurrency_services.py
async def add_url(self, url: str) -> None:
    """
    Add a URL to the pool of available URLs.

    Args:
        url (str): URL to add to the queue.
    """
    if url not in self.remote_urls:
        self.remote_urls.append(url)

        # Re-initialize the queue with the new URL.
        self._init_queue()

get_queue_size()

Get the current queue size.

Returns:

    int: Current queue size.

Source code in src/wordcab_transcribe/services/concurrency_services.py
def get_queue_size(self) -> int:
    """
    Get the current queue size.

    Returns:
        int: Current queue size.
    """
    return self.queue.qsize()

get_urls()

Get the list of available URLs.

Returns:

    List[str]: List of available URLs.

Source code in src/wordcab_transcribe/services/concurrency_services.py
def get_urls(self) -> List[str]:
    """
    Get the list of available URLs.

    Returns:
        List[str]: List of available URLs.
    """
    return self.remote_urls

next_url() async

Iterate over the available URLs in a round-robin fashion.

Returns:

    str: Next available URL.

Source code in src/wordcab_transcribe/services/concurrency_services.py
async def next_url(self) -> str:
    """
    We use this to iterate equally over the available URLs.

    Returns:
        str: Next available URL.
    """
    url = self.queue.get_nowait()
    # Unlike GPU we don't want to block remote ASR requests.
    # So we re-insert the URL back into the queue after getting it.
    self.queue.put_nowait(url)

    return url

remove_url(url) async

Remove a URL from the pool of available URLs.

Parameters:

    url (str): URL to remove from the queue. Required.
Source code in src/wordcab_transcribe/services/concurrency_services.py
async def remove_url(self, url: str) -> None:
    """
    Remove a URL from the pool of available URLs.

    Args:
        url (str): URL to remove from the queue.
    """
    if url in self.remote_urls:
        self.remote_urls.remove(url)

        # Re-initialize the queue without the removed URL.
        self._init_queue()

Diarization Service for audio files.

DiarizationModels

Bases: NamedTuple

Diarization Models.

Source code in src/wordcab_transcribe/services/diarization/diarize_service.py
class DiarizationModels(NamedTuple):
    """Diarization Models."""

    segmentation: SegmentationModule
    clustering: ClusteringModule
    device: str

DiarizeService

Diarize Service for audio files.

Source code in src/wordcab_transcribe/services/diarization/diarize_service.py
class DiarizeService:
    """Diarize Service for audio files."""

    def __init__(
        self,
        device: str,
        device_index: List[int],
        window_lengths: List[float],
        shift_lengths: List[float],
        multiscale_weights: List[int],
        max_num_speakers: int = 8,
    ) -> None:
        """Initialize the Diarize Service.

        This service uses the NVIDIA NeMo diarization models.

        Args:
            device (str): Device to use for inference. Can be "cpu" or "cuda".
            device_index (Union[int, List[int]]): Index of the device to use for inference.
            window_lengths (List[float]): List of window lengths.
            shift_lengths (List[float]): List of shift lengths.
            multiscale_weights (List[int]): List of weights for each scale.
            max_num_speakers (int): Maximum number of speakers. Defaults to 8.
        """
        self.device = device
        self.models = {}

        self.max_num_speakers = max_num_speakers
        self.default_window_lengths = window_lengths
        self.default_shift_lengths = shift_lengths
        self.default_multiscale_weights = multiscale_weights

        if len(self.default_multiscale_weights) > 3:
            self.default_segmentation_batch_size = 64
        elif len(self.default_multiscale_weights) > 1:
            self.default_segmentation_batch_size = 128
        else:
            self.default_segmentation_batch_size = 256

        self.default_scale_dict = dict(enumerate(zip(window_lengths, shift_lengths)))

        for idx in device_index:
            _device = f"cuda:{idx}" if self.device == "cuda" else "cpu"

            segmentation_module = SegmentationModule(_device)
            clustering_module = ClusteringModule(_device, self.max_num_speakers)

            self.models[idx] = DiarizationModels(
                segmentation=segmentation_module,
                clustering=clustering_module,
                device=_device,
            )

    def __call__(
        self,
        waveform: Union[torch.Tensor, TensorShare],
        audio_duration: float,
        oracle_num_speakers: int,
        model_index: int,
        vad_service: VadService,
    ) -> DiarizationOutput:
        """
        Run inference with the diarization model.

        Args:
            waveform (Union[torch.Tensor, TensorShare]):
                Waveform to run inference on.
            audio_duration (float):
                Duration of the audio file in seconds.
            oracle_num_speakers (int):
                Number of speakers in the audio file.
            model_index (int):
                Index of the model to use for inference.
            vad_service (VadService):
                VAD service instance to use for Voice Activity Detection.

        Returns:
            DiarizationOutput:
                List of segments with the following keys: "start", "end", "speaker".
        """
        if isinstance(waveform, TensorShare):
            ts = waveform.to_tensors(backend=Backend.TORCH)
            waveform = ts["audio"]

        vad_outputs, _ = vad_service(waveform, group_timestamps=False)

        if len(vad_outputs) == 0:  # Empty audio
            return None

        if audio_duration < 3600:
            scale_dict = self.default_scale_dict
            segmentation_batch_size = self.default_segmentation_batch_size
            multiscale_weights = self.default_multiscale_weights
        elif audio_duration < 10800:
            scale_dict = dict(
                enumerate(
                    zip(
                        [3.0, 2.5, 2.0, 1.5, 1.0],
                        self.default_shift_lengths,
                    )
                )
            )
            segmentation_batch_size = 64
            multiscale_weights = self.default_multiscale_weights
        else:
            scale_dict = dict(enumerate(zip([3.0, 2.0, 1.0], [0.75, 0.5, 0.25])))
            segmentation_batch_size = 32
            multiscale_weights = [1.0, 1.0, 1.0]

        ms_emb_ts: MultiscaleEmbeddingsAndTimestamps = self.models[
            model_index
        ].segmentation(
            waveform=waveform,
            batch_size=segmentation_batch_size,
            vad_outputs=vad_outputs,
            scale_dict=scale_dict,
            multiscale_weights=multiscale_weights,
        )

        clustering_outputs = self.models[model_index].clustering(
            ms_emb_ts, oracle_num_speakers
        )

        _outputs = self.get_contiguous_stamps(clustering_outputs)
        outputs = self.merge_stamps(_outputs)

        return DiarizationOutput(segments=outputs)

    @staticmethod
    def get_contiguous_stamps(
        stamps: List[Tuple[float, float, int]]
    ) -> List[Tuple[float, float, int]]:
        """
        Return contiguous timestamps.

        Args:
            stamps (List[Tuple[float, float, int]]): List of segments containing the start time, end time and speaker.

        Returns:
            List[Tuple[float, float, int]]: List of segments containing the start time, end time and speaker.
        """
        contiguous_stamps = []
        for i in range(len(stamps) - 1):
            start, end, speaker = stamps[i]
            next_start, next_end, next_speaker = stamps[i + 1]

            if end > next_start:
                avg = (next_start + end) / 2.0
                stamps[i + 1] = (avg, next_end, next_speaker)
                contiguous_stamps.append((start, avg, speaker))
            else:
                contiguous_stamps.append((start, end, speaker))

        start, end, speaker = stamps[-1]
        contiguous_stamps.append((start, end, speaker))

        return contiguous_stamps

    @staticmethod
    def merge_stamps(
        stamps: List[Tuple[float, float, int]]
    ) -> List[Tuple[float, float, int]]:
        """
        Merge timestamps of the same speaker.

        Args:
            stamps (List[Tuple[float, float, int]]): List of segments containing the start time, end time and speaker.

        Returns:
            List[Tuple[float, float, int]]: List of segments containing the start time, end time and speaker.
        """
        overlap_stamps = []
        for i in range(len(stamps) - 1):
            start, end, speaker = stamps[i]
            next_start, next_end, next_speaker = stamps[i + 1]

            if end == next_start and speaker == next_speaker:
                stamps[i + 1] = (start, next_end, next_speaker)
            else:
                overlap_stamps.append((start, end, speaker))

        start, end, speaker = stamps[-1]
        overlap_stamps.append((start, end, speaker))

        return overlap_stamps
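
The multiscale configuration is just a mapping from scale index to a (window length, shift length) pair, and __call__ swaps in coarser scales and smaller segmentation batch sizes for audio longer than one hour (3600 s) or three hours (10800 s). A pure-Python illustration with made-up values:

# Illustrative values only; the real window/shift lengths come from the
# configuration passed to DiarizeService.
window_lengths = [1.5, 1.25, 1.0]
shift_lengths = [0.75, 0.625, 0.5]

scale_dict = dict(enumerate(zip(window_lengths, shift_lengths)))
print(scale_dict)  # {0: (1.5, 0.75), 1: (1.25, 0.625), 2: (1.0, 0.5)}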

__call__(waveform, audio_duration, oracle_num_speakers, model_index, vad_service)

Run inference with the diarization model.

Parameters:

    waveform (Union[Tensor, TensorShare]): Waveform to run inference on. Required.
    audio_duration (float): Duration of the audio file in seconds. Required.
    oracle_num_speakers (int): Number of speakers in the audio file. Required.
    model_index (int): Index of the model to use for inference. Required.
    vad_service (VadService): VAD service instance to use for Voice Activity Detection. Required.

Returns:

    DiarizationOutput: List of segments with the following keys: "start", "end", "speaker".

Source code in src/wordcab_transcribe/services/diarization/diarize_service.py
def __call__(
    self,
    waveform: Union[torch.Tensor, TensorShare],
    audio_duration: float,
    oracle_num_speakers: int,
    model_index: int,
    vad_service: VadService,
) -> DiarizationOutput:
    """
    Run inference with the diarization model.

    Args:
        waveform (Union[torch.Tensor, TensorShare]):
            Waveform to run inference on.
        audio_duration (float):
            Duration of the audio file in seconds.
        oracle_num_speakers (int):
            Number of speakers in the audio file.
        model_index (int):
            Index of the model to use for inference.
        vad_service (VadService):
            VAD service instance to use for Voice Activity Detection.

    Returns:
        DiarizationOutput:
            List of segments with the following keys: "start", "end", "speaker".
    """
    if isinstance(waveform, TensorShare):
        ts = waveform.to_tensors(backend=Backend.TORCH)
        waveform = ts["audio"]

    vad_outputs, _ = vad_service(waveform, group_timestamps=False)

    if len(vad_outputs) == 0:  # Empty audio
        return None

    if audio_duration < 3600:
        scale_dict = self.default_scale_dict
        segmentation_batch_size = self.default_segmentation_batch_size
        multiscale_weights = self.default_multiscale_weights
    elif audio_duration < 10800:
        scale_dict = dict(
            enumerate(
                zip(
                    [3.0, 2.5, 2.0, 1.5, 1.0],
                    self.default_shift_lengths,
                )
            )
        )
        segmentation_batch_size = 64
        multiscale_weights = self.default_multiscale_weights
    else:
        scale_dict = dict(enumerate(zip([3.0, 2.0, 1.0], [0.75, 0.5, 0.25])))
        segmentation_batch_size = 32
        multiscale_weights = [1.0, 1.0, 1.0]

    ms_emb_ts: MultiscaleEmbeddingsAndTimestamps = self.models[
        model_index
    ].segmentation(
        waveform=waveform,
        batch_size=segmentation_batch_size,
        vad_outputs=vad_outputs,
        scale_dict=scale_dict,
        multiscale_weights=multiscale_weights,
    )

    clustering_outputs = self.models[model_index].clustering(
        ms_emb_ts, oracle_num_speakers
    )

    _outputs = self.get_contiguous_stamps(clustering_outputs)
    outputs = self.merge_stamps(_outputs)

    return DiarizationOutput(segments=outputs)

__init__(device, device_index, window_lengths, shift_lengths, multiscale_weights, max_num_speakers=8)

Initialize the Diarize Service.

This service uses the NVIDIA NeMo diarization models.

Parameters:

    device (str): Device to use for inference. Can be "cpu" or "cuda". Required.
    device_index (Union[int, List[int]]): Index of the device to use for inference. Required.
    window_lengths (List[float]): List of window lengths. Required.
    shift_lengths (List[float]): List of shift lengths. Required.
    multiscale_weights (List[int]): List of weights for each scale. Required.
    max_num_speakers (int): Maximum number of speakers. Defaults to 8.
Source code in src/wordcab_transcribe/services/diarization/diarize_service.py
def __init__(
    self,
    device: str,
    device_index: List[int],
    window_lengths: List[float],
    shift_lengths: List[float],
    multiscale_weights: List[int],
    max_num_speakers: int = 8,
) -> None:
    """Initialize the Diarize Service.

    This service uses the NVIDIA NeMo diarization models.

    Args:
        device (str): Device to use for inference. Can be "cpu" or "cuda".
        device_index (Union[int, List[int]]): Index of the device to use for inference.
        window_lengths (List[float]): List of window lengths.
        shift_lengths (List[float]): List of shift lengths.
        multiscale_weights (List[int]): List of weights for each scale.
        max_num_speakers (int): Maximum number of speakers. Defaults to 8.
    """
    self.device = device
    self.models = {}

    self.max_num_speakers = max_num_speakers
    self.default_window_lengths = window_lengths
    self.default_shift_lengths = shift_lengths
    self.default_multiscale_weights = multiscale_weights

    if len(self.default_multiscale_weights) > 3:
        self.default_segmentation_batch_size = 64
    elif len(self.default_multiscale_weights) > 1:
        self.default_segmentation_batch_size = 128
    else:
        self.default_segmentation_batch_size = 256

    self.default_scale_dict = dict(enumerate(zip(window_lengths, shift_lengths)))

    for idx in device_index:
        _device = f"cuda:{idx}" if self.device == "cuda" else "cpu"

        segmentation_module = SegmentationModule(_device)
        clustering_module = ClusteringModule(_device, self.max_num_speakers)

        self.models[idx] = DiarizationModels(
            segmentation=segmentation_module,
            clustering=clustering_module,
            device=_device,
        )

get_contiguous_stamps(stamps) staticmethod

Return contiguous timestamps.

Parameters:

    stamps (List[Tuple[float, float, int]]): List of segments containing the start time, end time and speaker. Required.

Returns:

    List[Tuple[float, float, int]]: List of segments containing the start time, end time and speaker.

Source code in src/wordcab_transcribe/services/diarization/diarize_service.py
@staticmethod
def get_contiguous_stamps(
    stamps: List[Tuple[float, float, int]]
) -> List[Tuple[float, float, int]]:
    """
    Return contiguous timestamps.

    Args:
        stamps (List[Tuple[float, float, int]]): List of segments containing the start time, end time and speaker.

    Returns:
        List[Tuple[float, float, int]]: List of segments containing the start time, end time and speaker.
    """
    contiguous_stamps = []
    for i in range(len(stamps) - 1):
        start, end, speaker = stamps[i]
        next_start, next_end, next_speaker = stamps[i + 1]

        if end > next_start:
            avg = (next_start + end) / 2.0
            stamps[i + 1] = (avg, next_end, next_speaker)
            contiguous_stamps.append((start, avg, speaker))
        else:
            contiguous_stamps.append((start, end, speaker))

    start, end, speaker = stamps[-1]
    contiguous_stamps.append((start, end, speaker))

    return contiguous_stamps
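
A worked example with invented timestamps, runnable if the project and its diarization dependencies are installed: when two consecutive turns overlap, the boundary is moved to the midpoint of the overlap.

from wordcab_transcribe.services.diarization.diarize_service import DiarizeService

# Speaker 0 and speaker 1 overlap between 4.8 s and 5.2 s.
stamps = [(0.0, 5.2, 0), (4.8, 9.0, 1)]

print(DiarizeService.get_contiguous_stamps(stamps))
# [(0.0, 5.0, 0), (5.0, 9.0, 1)] -> boundary placed at (4.8 + 5.2) / 2 = 5.0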

merge_stamps(stamps) staticmethod

Merge timestamps of the same speaker.

Parameters:

    stamps (List[Tuple[float, float, int]]): List of segments containing the start time, end time and speaker. Required.

Returns:

    List[Tuple[float, float, int]]: List of segments containing the start time, end time and speaker.

Source code in src/wordcab_transcribe/services/diarization/diarize_service.py
@staticmethod
def merge_stamps(
    stamps: List[Tuple[float, float, int]]
) -> List[Tuple[float, float, int]]:
    """
    Merge timestamps of the same speaker.

    Args:
        stamps (List[Tuple[float, float, int]]): List of segments containing the start time, end time and speaker.

    Returns:
        List[Tuple[float, float, int]]: List of segments containing the start time, end time and speaker.
    """
    overlap_stamps = []
    for i in range(len(stamps) - 1):
        start, end, speaker = stamps[i]
        next_start, next_end, next_speaker = stamps[i + 1]

        if end == next_start and speaker == next_speaker:
            stamps[i + 1] = (start, next_end, next_speaker)
        else:
            overlap_stamps.append((start, end, speaker))

    start, end, speaker = stamps[-1]
    overlap_stamps.append((start, end, speaker))

    return overlap_stamps
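
Continuing the same invented example: adjacent turns from the same speaker that touch exactly are merged into a single turn.

from wordcab_transcribe.services.diarization.diarize_service import DiarizeService

stamps = [(0.0, 5.0, 0), (5.0, 9.0, 0), (9.0, 12.0, 1)]

print(DiarizeService.merge_stamps(stamps))
# [(0.0, 9.0, 0), (9.0, 12.0, 1)] -> speaker 0's contiguous turns collapsed into one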

Post-Processing Service for audio files.

PostProcessingService

Post-Processing Service for audio files.

Source code in src/wordcab_transcribe/services/post_processing_service.py
class PostProcessingService:
    """Post-Processing Service for audio files."""

    def __init__(self) -> None:
        """Initialize the PostProcessingService."""
        self.sample_rate = 16000

    def single_channel_speaker_mapping(
        self,
        transcript_segments: List[Utterance],
        speaker_timestamps: DiarizationOutput,
        word_timestamps: bool,
    ) -> List[Utterance]:
        """Run the post-processing functions on the inputs.

        The postprocessing pipeline is as follows:
        1. Map each transcript segment to its corresponding speaker.
        2. Group utterances of the same speaker together.

        Args:
            transcript_segments (List[Utterance]):
                List of transcript utterances.
            speaker_timestamps (DiarizationOutput):
                List of speaker timestamps.
            word_timestamps (bool):
                Whether to include word timestamps.

        Returns:
            List[Utterance]:
                List of utterances with speaker mapping.
        """
        segments_with_speaker_mapping = self.segments_speaker_mapping(
            transcript_segments,
            speaker_timestamps.segments,
        )

        utterances = self.reconstruct_utterances(
            segments_with_speaker_mapping, word_timestamps
        )

        return utterances

    def multi_channel_speaker_mapping(
        self, multi_channel_segments: List[MultiChannelTranscriptionOutput]
    ) -> TranscriptionOutput:
        """
        Run the multi-channel post-processing functions on the inputs by merging the segments based on the timestamps.

        Args:
            multi_channel_segments (List[MultiChannelTranscriptionOutput]):
                List of segments from multi speakers.

        Returns:
            TranscriptionOutput: List of sentences with speaker mapping.
        """
        words_with_speaker_mapping = [
            (segment.speaker, word)
            for output in multi_channel_segments
            for segment in output.segments
            for word in segment.words
        ]
        words_with_speaker_mapping.sort(key=lambda x: x[1].start)

        utterances: List[Utterance] = self.reconstruct_multi_channel_utterances(
            words_with_speaker_mapping
        )

        return utterances

    def segments_speaker_mapping(
        self,
        transcript_segments: List[Utterance],
        speaker_timestamps: List[DiarizationSegment],
    ) -> List[dict]:
        """Function to map transcription and diarization results.

        Map each segment to its corresponding speaker based on the speaker timestamps and reconstruct the utterances
        when the speaker changes in the middle of a segment.

        Args:
            transcript_segments (List[Utterance]): List of transcript segments.
            speaker_timestamps (List[DiarizationSegment]): List of speaker timestamps.

        Returns:
            List[dict]: List of sentences with speaker mapping.
        """

        def _assign_speaker(
            mapping: list,
            seg_index: int,
            split: bool,
            current_speaker: str,
            current_split_len: int,
        ):
            """Assign speaker to the segment."""
            if split and len(mapping) > 1:
                last_split_len = len(mapping[seg_index - 1].text)
                if last_split_len > current_split_len:
                    current_speaker = mapping[seg_index - 1].speaker
                elif last_split_len < current_split_len:
                    mapping[seg_index - 1].speaker = current_speaker
            return current_speaker

        threshold = 0.3
        turn_idx = 0
        was_split = False
        _, end, speaker = speaker_timestamps[turn_idx]

        segment_index = 0
        segment_speaker_mapping = []
        while segment_index < len(transcript_segments):
            segment: Utterance = transcript_segments[segment_index]
            segment_start, segment_end, segment_text = (
                segment.start,
                segment.end,
                segment.text,
            )
            while (
                segment_start > float(end)
                or abs(segment_start - float(end)) < threshold
            ):
                turn_idx += 1
                turn_idx = min(turn_idx, len(speaker_timestamps) - 1)
                _, end, speaker = speaker_timestamps[turn_idx]
                if turn_idx == len(speaker_timestamps) - 1:
                    end = segment_end
                    break

            if segment_end > float(end) and abs(segment_end - float(end)) > threshold:
                words = segment.words
                word_index = next(
                    (
                        i
                        for i, word in enumerate(words)
                        if word.start > float(end)
                        or abs(word.start - float(end)) < threshold
                    ),
                    None,
                )

                if word_index is not None:
                    _split_segment = segment_text.split()

                    if word_index > 0:
                        text = " ".join(_split_segment[:word_index])
                        speaker = _assign_speaker(
                            segment_speaker_mapping,
                            segment_index,
                            was_split,
                            speaker,
                            len(text),
                        )

                        _segment_to_add = Utterance(
                            start=words[0].start,
                            end=words[word_index - 1].end,
                            text=text,
                            speaker=speaker,
                            words=words[:word_index],
                        )
                    else:
                        text = _split_segment[0]
                        speaker = _assign_speaker(
                            segment_speaker_mapping,
                            segment_index,
                            was_split,
                            speaker,
                            len(text),
                        )

                        _segment_to_add = Utterance(
                            start=words[0].start,
                            end=words[0].end,
                            text=_split_segment[0],
                            speaker=speaker,
                            words=words[:1],
                        )
                    segment_speaker_mapping.append(_segment_to_add)
                    transcript_segments.insert(
                        segment_index + 1,
                        Utterance(
                            start=words[word_index].start,
                            end=segment_end,
                            text=" ".join(_split_segment[word_index:]),
                            words=words[word_index:],
                        ),
                    )
                    was_split = True
                else:
                    speaker = _assign_speaker(
                        segment_speaker_mapping,
                        segment_index,
                        was_split,
                        speaker,
                        len(segment_text),
                    )
                    was_split = False

                    segment_speaker_mapping.append(
                        Utterance(
                            start=segment_start,
                            end=segment_end,
                            text=segment_text,
                            speaker=speaker,
                            words=words,
                        )
                    )
            else:
                speaker = _assign_speaker(
                    segment_speaker_mapping,
                    segment_index,
                    was_split,
                    speaker,
                    len(segment_text),
                )
                was_split = False

                segment_speaker_mapping.append(
                    Utterance(
                        start=segment_start,
                        end=segment_end,
                        text=segment_text,
                        speaker=speaker,
                        words=segment.words,
                    )
                )
            segment_index += 1

        return segment_speaker_mapping

    def reconstruct_utterances(
        self,
        transcript_segments: List[Utterance],
        word_timestamps: bool,
    ) -> List[Utterance]:
        """
        Reconstruct the utterances based on the speaker mapping.

        Args:
            transcript_segments (List[Utterance]):
                List of transcript segments.
            word_timestamps (bool):
                Whether to include word timestamps.

        Returns:
            List[Utterance]:
                List of sentences with speaker mapping.
        """
        start_t0, end_t0, speaker_t0 = (
            transcript_segments[0].start,
            transcript_segments[0].end,
            transcript_segments[0].speaker,
        )

        previous_speaker = speaker_t0
        current_sentence = {
            "speaker": speaker_t0,
            "start": start_t0,
            "end": end_t0,
            "text": "",
        }
        if word_timestamps:
            current_sentence["words"] = []

        sentences = []
        for segment in transcript_segments:
            text, speaker = segment.text, segment.speaker
            start_t, end_t = segment.start, segment.end

            if speaker != previous_speaker:
                sentences.append(Utterance(**current_sentence))
                current_sentence = {
                    "speaker": speaker,
                    "start": start_t,
                    "end": end_t,
                    "text": "",
                }
                if word_timestamps:
                    current_sentence["words"] = []
            else:
                current_sentence["end"] = end_t

            current_sentence["text"] += text + " "
            previous_speaker = speaker
            if word_timestamps:
                current_sentence["words"].extend(segment.words)

        # Catch the last sentence
        sentences.append(Utterance(**current_sentence))

        return sentences

    def reconstruct_multi_channel_utterances(
        self,
        transcript_words: List[Tuple[int, Word]],
    ) -> List[Utterance]:
        """
        Reconstruct multi-channel utterances based on the speaker mapping.

        Args:
            transcript_words (List[Tuple[int, Word]]):
                List of tuples containing the speaker and the word.

        Returns:
            List[Utterance]: List of sentences with speaker mapping.
        """
        speaker_t0, word = transcript_words[0]
        start_t0, end_t0 = word.start, word.end

        previous_speaker = speaker_t0
        current_sentence = {
            "speaker": speaker_t0,
            "start": start_t0,
            "end": end_t0,
            "text": "",
            "words": [],
        }

        sentences = []
        for speaker, word in transcript_words:
            start_t, end_t, text = word.start, word.end, word.word

            if speaker != previous_speaker:
                sentences.append(current_sentence)
                current_sentence = {
                    "speaker": speaker,
                    "start": start_t,
                    "end": end_t,
                    "text": "",
                }
                current_sentence["words"] = []
            else:
                current_sentence["end"] = end_t

            current_sentence["text"] += text
            previous_speaker = speaker
            current_sentence["words"].append(word)

        # Catch the last sentence
        sentences.append(current_sentence)

        for sentence in sentences:
            sentence["text"] = sentence["text"].strip()

        return [Utterance(**sentence) for sentence in sentences]

    def final_processing_before_returning(
        self,
        utterances: List[Utterance],
        offset_start: Union[float, None],
        timestamps_format: Timestamps,
        word_timestamps: bool,
    ) -> List[Utterance]:
        """
        Do final processing before returning the utterances to the API.

        Args:
            utterances (List[Utterance]):
                List of utterances.
            offset_start (Union[float, None]):
                Offset start.
            timestamps_format (Timestamps):
                Timestamps format. Can be `s`, `ms`, or `hms`.
            word_timestamps (bool):
                Whether to include word timestamps.

        Returns:
            List[Utterance]:
                List of utterances after final processing.
        """
        if offset_start is not None:
            offset_start = float(offset_start)
        else:
            offset_start = 0.0

        final_utterances = []
        for utterance in utterances:
            # Check if the utterance is not empty
            if utterance.text.strip():
                utterance.text = format_punct(utterance.text)
                utterance.start = convert_timestamp(
                    (utterance.start + offset_start), timestamps_format
                )
                utterance.end = convert_timestamp(
                    (utterance.end + offset_start), timestamps_format
                )
                utterance.words = utterance.words if word_timestamps else None

                final_utterances.append(utterance)

        return final_utterances

__init__()

Initialize the PostProcessingService.

Source code in src/wordcab_transcribe/services/post_processing_service.py
def __init__(self) -> None:
    """Initialize the PostProcessingService."""
    self.sample_rate = 16000

final_processing_before_returning(utterances, offset_start, timestamps_format, word_timestamps)

Do final processing before returning the utterances to the API.

Parameters:

    utterances (List[Utterance]): List of utterances. Required.
    offset_start (Union[float, None]): Offset start. Required.
    timestamps_format (Timestamps): Timestamps format. Can be s, ms, or hms. Required.
    word_timestamps (bool): Whether to include word timestamps. Required.

Returns:

    List[Utterance]: List of utterances after final processing.

Source code in src/wordcab_transcribe/services/post_processing_service.py
def final_processing_before_returning(
    self,
    utterances: List[Utterance],
    offset_start: Union[float, None],
    timestamps_format: Timestamps,
    word_timestamps: bool,
) -> List[Utterance]:
    """
    Do final processing before returning the utterances to the API.

    Args:
        utterances (List[Utterance]):
            List of utterances.
        offset_start (Union[float, None]):
            Offset start.
        timestamps_format (Timestamps):
            Timestamps format. Can be `s`, `ms`, or `hms`.
        word_timestamps (bool):
            Whether to include word timestamps.

    Returns:
        List[Utterance]:
            List of utterances after final processing.
    """
    if offset_start is not None:
        offset_start = float(offset_start)
    else:
        offset_start = 0.0

    final_utterances = []
    for utterance in utterances:
        # Check if the utterance is not empty
        if utterance.text.strip():
            utterance.text = format_punct(utterance.text)
            utterance.start = convert_timestamp(
                (utterance.start + offset_start), timestamps_format
            )
            utterance.end = convert_timestamp(
                (utterance.end + offset_start), timestamps_format
            )
            utterance.words = utterance.words if word_timestamps else None

            final_utterances.append(utterance)

    return final_utterances
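
A minimal usage sketch follows. The import paths for PostProcessingService, Utterance, and Timestamps, as well as the assumption that Timestamps can be built from one of its "s"/"ms"/"hms" values, are not confirmed by this page and should be checked against the package.

# Hypothetical sketch; import paths and the Timestamps("s") construction are assumptions.
from wordcab_transcribe.models import Timestamps, Utterance
from wordcab_transcribe.services.post_processing_service import PostProcessingService

service = PostProcessingService()
utterances = [
    Utterance(start=0.0, end=2.5, text=" hello there ", speaker=0),
    Utterance(start=2.5, end=4.0, text="   ", speaker=1),  # empty text, will be dropped
]

# Shift every timestamp by 10 seconds, render it in the requested format,
# and drop word-level detail since word_timestamps=False.
cleaned = service.final_processing_before_returning(
    utterances=utterances,
    offset_start=10.0,
    timestamps_format=Timestamps("s"),  # assumed: enum built from its "s"/"ms"/"hms" value
    word_timestamps=False,
)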

multi_channel_speaker_mapping(multi_channel_segments)

Run the multi-channel post-processing functions on the inputs by merging the segments based on the timestamps.

Parameters:

    multi_channel_segments (List[MultiChannelTranscriptionOutput]): List of segments from multi speakers. Required.

Returns:

    TranscriptionOutput: List of sentences with speaker mapping.

Source code in src/wordcab_transcribe/services/post_processing_service.py
def multi_channel_speaker_mapping(
    self, multi_channel_segments: List[MultiChannelTranscriptionOutput]
) -> TranscriptionOutput:
    """
    Run the multi-channel post-processing functions on the inputs by merging the segments based on the timestamps.

    Args:
        multi_channel_segments (List[MultiChannelTranscriptionOutput]):
            List of segments from multi speakers.

    Returns:
        TranscriptionOutput: List of sentences with speaker mapping.
    """
    words_with_speaker_mapping = [
        (segment.speaker, word)
        for output in multi_channel_segments
        for segment in output.segments
        for word in segment.words
    ]
    words_with_speaker_mapping.sort(key=lambda x: x[1].start)

    utterances: List[Utterance] = self.reconstruct_multi_channel_utterances(
        words_with_speaker_mapping
    )

    return utterances
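
A minimal sketch of merging per-channel outputs, assuming the model classes are importable from wordcab_transcribe.models and that Word takes a probability field (neither is confirmed on this page).

# Hypothetical sketch; import paths and the Word(probability=...) field are assumptions.
from wordcab_transcribe.models import MultiChannelSegment, MultiChannelTranscriptionOutput, Word
from wordcab_transcribe.services.post_processing_service import PostProcessingService

service = PostProcessingService()

# One MultiChannelTranscriptionOutput per channel, as produced by TranscribeService.multi_channel.
channel_0 = MultiChannelTranscriptionOutput(
    segments=[
        MultiChannelSegment(
            start=0.0,
            end=1.0,
            text=" Hello.",
            speaker=0,
            words=[Word(start=0.0, end=1.0, word=" Hello.", probability=0.9)],
        )
    ]
)
channel_1 = MultiChannelTranscriptionOutput(segments=[])

# Words from all channels are flattened, sorted by start time, and regrouped
# into utterances whenever the channel (speaker) changes.
utterances = service.multi_channel_speaker_mapping([channel_0, channel_1])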

reconstruct_multi_channel_utterances(transcript_words)

Reconstruct multi-channel utterances based on the speaker mapping.

Parameters:

    transcript_words (List[Tuple[int, Word]]): List of tuples containing the speaker and the word. Required.

Returns:

    List[Utterance]: List of sentences with speaker mapping.

Source code in src/wordcab_transcribe/services/post_processing_service.py
def reconstruct_multi_channel_utterances(
    self,
    transcript_words: List[Tuple[int, Word]],
) -> List[Utterance]:
    """
    Reconstruct multi-channel utterances based on the speaker mapping.

    Args:
        transcript_words (List[Tuple[int, Word]]):
            List of tuples containing the speaker and the word.

    Returns:
        List[Utterance]: List of sentences with speaker mapping.
    """
    speaker_t0, word = transcript_words[0]
    start_t0, end_t0 = word.start, word.end

    previous_speaker = speaker_t0
    current_sentence = {
        "speaker": speaker_t0,
        "start": start_t0,
        "end": end_t0,
        "text": "",
        "words": [],
    }

    sentences = []
    for speaker, word in transcript_words:
        start_t, end_t, text = word.start, word.end, word.word

        if speaker != previous_speaker:
            sentences.append(current_sentence)
            current_sentence = {
                "speaker": speaker,
                "start": start_t,
                "end": end_t,
                "text": "",
            }
            current_sentence["words"] = []
        else:
            current_sentence["end"] = end_t

        current_sentence["text"] += text
        previous_speaker = speaker
        current_sentence["words"].append(word)

    # Catch the last sentence
    sentences.append(current_sentence)

    for sentence in sentences:
        sentence["text"] = sentence["text"].strip()

    return [Utterance(**sentence) for sentence in sentences]
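
A minimal sketch of the word-level grouping, under the same assumptions about the Word model as above.

# Hypothetical sketch; the Word import path and probability field are assumptions.
from wordcab_transcribe.models import Word
from wordcab_transcribe.services.post_processing_service import PostProcessingService

service = PostProcessingService()
transcript_words = [
    (0, Word(start=0.0, end=0.4, word=" Hi", probability=0.95)),
    (0, Word(start=0.4, end=0.8, word=" there.", probability=0.93)),
    (1, Word(start=0.9, end=1.3, word=" Hello!", probability=0.97)),
]

# Consecutive words from the same channel are merged into one utterance and a
# change of channel starts a new one, so this produces two utterances.
utterances = service.reconstruct_multi_channel_utterances(transcript_words)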

reconstruct_utterances(transcript_segments, word_timestamps)

Reconstruct the utterances based on the speaker mapping.

Parameters:

    transcript_segments (List[Utterance]): List of transcript segments. Required.
    word_timestamps (bool): Whether to include word timestamps. Required.

Returns:

    List[Utterance]: List of sentences with speaker mapping.

Source code in src/wordcab_transcribe/services/post_processing_service.py
def reconstruct_utterances(
    self,
    transcript_segments: List[Utterance],
    word_timestamps: bool,
) -> List[Utterance]:
    """
    Reconstruct the utterances based on the speaker mapping.

    Args:
        transcript_segments (List[Utterance]):
            List of transcript segments.
        word_timestamps (bool):
            Whether to include word timestamps.

    Returns:
        List[Utterance]:
            List of sentences with speaker mapping.
    """
    start_t0, end_t0, speaker_t0 = (
        transcript_segments[0].start,
        transcript_segments[0].end,
        transcript_segments[0].speaker,
    )

    previous_speaker = speaker_t0
    current_sentence = {
        "speaker": speaker_t0,
        "start": start_t0,
        "end": end_t0,
        "text": "",
    }
    if word_timestamps:
        current_sentence["words"] = []

    sentences = []
    for segment in transcript_segments:
        text, speaker = segment.text, segment.speaker
        start_t, end_t = segment.start, segment.end

        if speaker != previous_speaker:
            sentences.append(Utterance(**current_sentence))
            current_sentence = {
                "speaker": speaker,
                "start": start_t,
                "end": end_t,
                "text": "",
            }
            if word_timestamps:
                current_sentence["words"] = []
        else:
            current_sentence["end"] = end_t

        current_sentence["text"] += text + " "
        previous_speaker = speaker
        if word_timestamps:
            current_sentence["words"].extend(segment.words)

    # Catch the last sentence
    sentences.append(Utterance(**current_sentence))

    return sentences
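
A minimal sketch of grouping consecutive same-speaker segments; the Utterance import path is an assumption.

# Hypothetical sketch; the Utterance import path is an assumption.
from wordcab_transcribe.models import Utterance
from wordcab_transcribe.services.post_processing_service import PostProcessingService

service = PostProcessingService()
segments = [
    Utterance(start=0.0, end=1.2, text="Good morning,", speaker=0),
    Utterance(start=1.2, end=2.0, text="everyone.", speaker=0),
    Utterance(start=2.1, end=3.0, text="Thanks for joining.", speaker=1),
]

# The two consecutive speaker-0 segments are merged into a single utterance,
# so this returns two utterances instead of three.
utterances = service.reconstruct_utterances(segments, word_timestamps=False)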

segments_speaker_mapping(transcript_segments, speaker_timestamps)

Function to map transcription and diarization results.

Map each segment to its corresponding speaker based on the speaker timestamps and reconstruct the utterances when the speaker changes in the middle of a segment.

Parameters:

    transcript_segments (List[Utterance]): List of transcript segments. Required.
    speaker_timestamps (List[DiarizationSegment]): List of speaker timestamps. Required.

Returns:

    List[Utterance]: List of sentences with speaker mapping.

Source code in src/wordcab_transcribe/services/post_processing_service.py
def segments_speaker_mapping(
    self,
    transcript_segments: List[Utterance],
    speaker_timestamps: List[DiarizationSegment],
) -> List[Utterance]:
    """Function to map transcription and diarization results.

    Map each segment to its corresponding speaker based on the speaker timestamps and reconstruct the utterances
    when the speaker changes in the middle of a segment.

    Args:
        transcript_segments (List[Utterance]): List of transcript segments.
        speaker_timestamps (List[DiarizationSegment]): List of speaker timestamps.

    Returns:
        List[Utterance]: List of sentences with speaker mapping.
    """

    def _assign_speaker(
        mapping: list,
        seg_index: int,
        split: bool,
        current_speaker: str,
        current_split_len: int,
    ):
        """Assign speaker to the segment."""
        if split and len(mapping) > 1:
            last_split_len = len(mapping[seg_index - 1].text)
            if last_split_len > current_split_len:
                current_speaker = mapping[seg_index - 1].speaker
            elif last_split_len < current_split_len:
                mapping[seg_index - 1].speaker = current_speaker
        return current_speaker

    threshold = 0.3
    turn_idx = 0
    was_split = False
    _, end, speaker = speaker_timestamps[turn_idx]

    segment_index = 0
    segment_speaker_mapping = []
    while segment_index < len(transcript_segments):
        segment: Utterance = transcript_segments[segment_index]
        segment_start, segment_end, segment_text = (
            segment.start,
            segment.end,
            segment.text,
        )
        while (
            segment_start > float(end)
            or abs(segment_start - float(end)) < threshold
        ):
            turn_idx += 1
            turn_idx = min(turn_idx, len(speaker_timestamps) - 1)
            _, end, speaker = speaker_timestamps[turn_idx]
            if turn_idx == len(speaker_timestamps) - 1:
                end = segment_end
                break

        if segment_end > float(end) and abs(segment_end - float(end)) > threshold:
            words = segment.words
            word_index = next(
                (
                    i
                    for i, word in enumerate(words)
                    if word.start > float(end)
                    or abs(word.start - float(end)) < threshold
                ),
                None,
            )

            if word_index is not None:
                _split_segment = segment_text.split()

                if word_index > 0:
                    text = " ".join(_split_segment[:word_index])
                    speaker = _assign_speaker(
                        segment_speaker_mapping,
                        segment_index,
                        was_split,
                        speaker,
                        len(text),
                    )

                    _segment_to_add = Utterance(
                        start=words[0].start,
                        end=words[word_index - 1].end,
                        text=text,
                        speaker=speaker,
                        words=words[:word_index],
                    )
                else:
                    text = _split_segment[0]
                    speaker = _assign_speaker(
                        segment_speaker_mapping,
                        segment_index,
                        was_split,
                        speaker,
                        len(text),
                    )

                    _segment_to_add = Utterance(
                        start=words[0].start,
                        end=words[0].end,
                        text=_split_segment[0],
                        speaker=speaker,
                        words=words[:1],
                    )
                segment_speaker_mapping.append(_segment_to_add)
                transcript_segments.insert(
                    segment_index + 1,
                    Utterance(
                        start=words[word_index].start,
                        end=segment_end,
                        text=" ".join(_split_segment[word_index:]),
                        words=words[word_index:],
                    ),
                )
                was_split = True
            else:
                speaker = _assign_speaker(
                    segment_speaker_mapping,
                    segment_index,
                    was_split,
                    speaker,
                    len(segment_text),
                )
                was_split = False

                segment_speaker_mapping.append(
                    Utterance(
                        start=segment_start,
                        end=segment_end,
                        text=segment_text,
                        speaker=speaker,
                        words=words,
                    )
                )
        else:
            speaker = _assign_speaker(
                segment_speaker_mapping,
                segment_index,
                was_split,
                speaker,
                len(segment_text),
            )
            was_split = False

            segment_speaker_mapping.append(
                Utterance(
                    start=segment_start,
                    end=segment_end,
                    text=segment_text,
                    speaker=speaker,
                    words=segment.words,
                )
            )
        segment_index += 1

    return segment_speaker_mapping
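
A minimal sketch of the splitting behaviour. The model import paths and the Word probability field are assumptions, and each speaker timestamp unpacks as (start, end, speaker), so plain tuples stand in for DiarizationSegment entries here.

# Hypothetical sketch; import paths, the Word probability field, and the use of
# plain tuples in place of DiarizationSegment entries are assumptions.
from wordcab_transcribe.models import Utterance, Word
from wordcab_transcribe.services.post_processing_service import PostProcessingService

service = PostProcessingService()
segments = [
    Utterance(
        start=0.0,
        end=2.0,
        text="Hello how are you",
        words=[
            Word(start=0.0, end=0.5, word="Hello", probability=0.9),
            Word(start=0.5, end=1.0, word="how", probability=0.9),
            Word(start=1.0, end=1.5, word="are", probability=0.9),
            Word(start=1.5, end=2.0, word="you", probability=0.9),
        ],
    )
]
speaker_timestamps = [(0.0, 1.0, 0), (1.0, 2.0, 1)]

# The single segment spans a speaker turn, so it is split at the first word that
# starts at or after the turn boundary: "Hello how" keeps speaker 0 and the
# remainder "are you" is re-queued and assigned speaker 1.
mapped = service.segments_speaker_mapping(segments, speaker_timestamps)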

single_channel_speaker_mapping(transcript_segments, speaker_timestamps, word_timestamps)

Run the post-processing functions on the inputs.

The postprocessing pipeline is as follows:

1. Map each transcript segment to its corresponding speaker.
2. Group utterances of the same speaker together.

Parameters:

    transcript_segments (List[Utterance]): List of transcript utterances. Required.
    speaker_timestamps (DiarizationOutput): List of speaker timestamps. Required.
    word_timestamps (bool): Whether to include word timestamps. Required.

Returns:

    List[Utterance]: List of utterances with speaker mapping.

Source code in src/wordcab_transcribe/services/post_processing_service.py
def single_channel_speaker_mapping(
    self,
    transcript_segments: List[Utterance],
    speaker_timestamps: DiarizationOutput,
    word_timestamps: bool,
) -> List[Utterance]:
    """Run the post-processing functions on the inputs.

    The postprocessing pipeline is as follows:
    1. Map each transcript segment to its corresponding speaker.
    2. Group utterances of the same speaker together.

    Args:
        transcript_segments (List[Utterance]):
            List of transcript utterances.
        speaker_timestamps (DiarizationOutput):
            List of speaker timestamps.
        word_timestamps (bool):
            Whether to include word timestamps.

    Returns:
        List[Utterance]:
            List of utterances with speaker mapping.
    """
    segments_with_speaker_mapping = self.segments_speaker_mapping(
        transcript_segments,
        speaker_timestamps.segments,
    )

    utterances = self.reconstruct_utterances(
        segments_with_speaker_mapping, word_timestamps
    )

    return utterances
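
A minimal end-to-end sketch of the single-channel pipeline. SimpleNamespace stands in for the DiarizationOutput object, since this call only reads its segments attribute; the model import paths and Word probability field are assumptions.

# Hypothetical sketch; model import paths and the Word probability field are assumptions.
from types import SimpleNamespace

from wordcab_transcribe.models import Utterance, Word
from wordcab_transcribe.services.post_processing_service import PostProcessingService

service = PostProcessingService()

transcript_segments = [
    Utterance(
        start=0.0,
        end=1.0,
        text="Hi there",
        words=[
            Word(start=0.0, end=0.5, word="Hi", probability=0.9),
            Word(start=0.5, end=1.0, word="there", probability=0.9),
        ],
    ),
]
diarization = SimpleNamespace(segments=[(0.0, 1.0, 0)])  # stand-in for DiarizationOutput

# Step 1 assigns a speaker to every segment (splitting them when a turn changes
# mid-segment); step 2 merges consecutive segments from the same speaker.
utterances = service.single_channel_speaker_mapping(
    transcript_segments=transcript_segments,
    speaker_timestamps=diarization,
    word_timestamps=True,
)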

Transcribe Service for audio files.

FasterWhisperModel

Bases: NamedTuple

Faster Whisper Model.

Source code in src/wordcab_transcribe/services/transcribe_service.py
class FasterWhisperModel(NamedTuple):
    """Faster Whisper Model."""

    model: WhisperModel
    lang: str

TranscribeService

Transcribe Service for audio files.

Source code in src/wordcab_transcribe/services/transcribe_service.py
class TranscribeService:
    """Transcribe Service for audio files."""

    def __init__(
        self,
        model_path: str,
        compute_type: str,
        device: str,
        device_index: Union[int, List[int]],
        extra_languages: Union[List[str], None] = None,
        extra_languages_model_paths: Union[List[str], None] = None,
    ) -> None:
        """Initialize the Transcribe Service.

        This service uses the WhisperModel from faster-whisper to transcribe audio files.

        Args:
            model_path (str):
                Path to the model checkpoint. This can be a local path or a URL.
            compute_type (str):
                Compute type to use for inference. Can be "int8", "int8_float16", "int16" or "float16".
            device (str):
                Device to use for inference. Can be "cpu" or "cuda".
            device_index (Union[int, List[int]]):
                Index of the device to use for inference.
            extra_languages (Union[List[str], None]):
                List of extra languages to transcribe. Defaults to None.
            extra_languages_model_paths (Union[List[str], None]):
                List of paths to the extra language models. Defaults to None.
        """
        self.device = device
        self.compute_type = compute_type
        self.model_path = model_path

        self.model = WhisperModel(
            self.model_path,
            device=self.device,
            device_index=device_index,
            compute_type=self.compute_type,
        )

        self.extra_lang = extra_languages
        self.extra_lang_models = extra_languages_model_paths

    def __call__(
        self,
        audio: Union[
            str,
            torch.Tensor,
            TensorShare,
            List[str],
            List[torch.Tensor],
            List[TensorShare],
        ],
        source_lang: str,
        model_index: int,
        suppress_blank: bool = False,
        vocab: Union[List[str], None] = None,
        word_timestamps: bool = True,
        internal_vad: bool = False,
        repetition_penalty: float = 1.0,
        compression_ratio_threshold: float = 2.4,
        log_prob_threshold: float = -1.0,
        no_speech_threshold: float = 0.6,
        condition_on_previous_text: bool = True,
    ) -> Union[TranscriptionOutput, List[TranscriptionOutput]]:
        """
        Run inference with the transcribe model.

        Args:
            audio (Union[str, torch.Tensor, TensorShare, List[str], List[torch.Tensor], List[TensorShare]]):
                Audio file path or audio tensor. If a list is passed, the task is assumed
                to be a multi_channel task and the list of audio files or tensors is passed.
            source_lang (str):
                Language of the audio file.
            model_index (int):
                Index of the model to use.
            suppress_blank (bool):
                Whether to suppress blank at the beginning of the sampling.
            vocab (Union[List[str], None]):
                Vocabulary to use during generation if not None. Defaults to None.
            word_timestamps (bool):
                Whether to return word timestamps.
            internal_vad (bool):
                Whether to use faster-whisper's VAD or not.
            repetition_penalty (float):
                Repetition penalty to use during generation beamed search.
            compression_ratio_threshold (float):
                If the gzip compression ratio is above this value, treat as failed.
            log_prob_threshold (float):
                If the average log probability over sampled tokens is below this value, treat as failed.
            no_speech_threshold (float):
                If the no_speech probability is higher than this value AND the average log probability
                over sampled tokens is below `log_prob_threshold`, consider the segment as silent.
            condition_on_previous_text (bool):
                If True, the previous output of the model is provided as a prompt for the next window;
                disabling may make the text inconsistent across windows, but the model becomes less prone
                to getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.

        Returns:
            Union[TranscriptionOutput, List[TranscriptionOutput]]:
                Transcription output. If the task is a multi_channel task, a list of TranscriptionOutput is returned.
        """
        # Extra language models are disabled until we can handle an index mapping
        # if (
        #     source_lang in self.extra_lang
        #     and self.models[model_index].lang != source_lang
        # ):
        #     logger.debug(f"Loading model for language {source_lang} on GPU {model_index}.")
        #     self.models[model_index] = FasterWhisperModel(
        #         model=WhisperModel(
        #             self.extra_lang_models[source_lang],
        #             device=self.device,
        #             device_index=model_index,
        #             compute_type=self.compute_type,
        #         ),
        #         lang=source_lang,
        #     )
        #     self.loaded_model_lang = source_lang

        # elif source_lang not in self.extra_lang and self.models[model_index].lang != "multi":
        #     logger.debug(f"Re-loading multi-language model on GPU {model_index}.")
        #     self.models[model_index] = FasterWhisperModel(
        #         model=WhisperModel(
        #             self.model_path,
        #             device=self.device,
        #             device_index=model_index,
        #             compute_type=self.compute_type,
        #         ),
        #         lang=source_lang,
        #     )

        if (
            vocab is not None
            and isinstance(vocab, list)
            and len(vocab) > 0
            and vocab[0].strip()
        ):
            words = ", ".join(vocab)
            prompt = f"Vocab: {words.strip()}"
        else:
            prompt = None

        if not isinstance(audio, list):
            if isinstance(audio, torch.Tensor):
                audio = audio.numpy()
            elif isinstance(audio, TensorShare):
                ts = audio.to_tensors(backend=Backend.NUMPY)
                audio = ts["audio"]

            segments, _ = self.model.transcribe(
                audio,
                language=source_lang,
                initial_prompt=prompt,
                repetition_penalty=repetition_penalty,
                compression_ratio_threshold=compression_ratio_threshold,
                log_prob_threshold=log_prob_threshold,
                no_speech_threshold=no_speech_threshold,
                condition_on_previous_text=condition_on_previous_text,
                suppress_blank=suppress_blank,
                word_timestamps=word_timestamps,
                vad_filter=internal_vad,
                vad_parameters={
                    "threshold": 0.5,
                    "min_speech_duration_ms": 250,
                    "min_silence_duration_ms": 100,
                    "speech_pad_ms": 30,
                    "window_size_samples": 512,
                },
            )

            segments = list(segments)
            if not segments:
                logger.warning(
                    "Empty transcription result. Trying with vad_filter=True."
                )
                segments, _ = self.model.transcribe(
                    audio,
                    language=source_lang,
                    initial_prompt=prompt,
                    repetition_penalty=repetition_penalty,
                    compression_ratio_threshold=compression_ratio_threshold,
                    log_prob_threshold=log_prob_threshold,
                    no_speech_threshold=no_speech_threshold,
                    condition_on_previous_text=condition_on_previous_text,
                    suppress_blank=False,
                    word_timestamps=True,
                    vad_filter=False if internal_vad else True,
                )

            _outputs = [segment._asdict() for segment in segments]
            outputs = TranscriptionOutput(segments=_outputs)

        else:
            outputs = []
            for audio_index, audio_file in enumerate(audio):
                outputs.append(
                    self.multi_channel(
                        audio_file,
                        source_lang=source_lang,
                        speaker_id=audio_index,
                        suppress_blank=suppress_blank,
                        word_timestamps=word_timestamps,
                        internal_vad=internal_vad,
                        repetition_penalty=repetition_penalty,
                        compression_ratio_threshold=compression_ratio_threshold,
                        log_prob_threshold=log_prob_threshold,
                        no_speech_threshold=no_speech_threshold,
                        prompt=prompt,
                    )
                )

        return outputs

    async def async_live_transcribe(
        self,
        audio: torch.Tensor,
        source_lang: str,
        model_index: int,
    ) -> Iterable[dict]:
        """Async generator for live transcriptions.

        This method wraps the live_transcribe method to make it async.

        Args:
            audio (torch.Tensor): Audio tensor.
            source_lang (str): Language of the audio file.
            model_index (int): Index of the model to use.

        Yields:
            Iterable[dict]: Iterable of transcribed segments.
        """
        for result in self.live_transcribe(audio, source_lang, model_index):
            yield result

    def live_transcribe(
        self,
        audio: torch.Tensor,
        source_lang: str,
        model_index: int,
    ) -> Iterable[dict]:
        """
        Transcribe audio from a WebSocket connection.

        Args:
            audio (torch.Tensor): Audio tensor.
            source_lang (str): Language of the audio file.
            model_index (int): Index of the model to use.

        Yields:
            Iterable[dict]: Iterable of transcribed segments.
        """
        segments, _ = self.model.transcribe(
            audio.numpy(),
            language=source_lang,
            suppress_blank=True,
            word_timestamps=False,
        )

        for segment in segments:
            yield segment._asdict()

    def multi_channel(
        self,
        audio: Union[str, torch.Tensor, TensorShare],
        source_lang: str,
        speaker_id: int,
        suppress_blank: bool = False,
        word_timestamps: bool = True,
        internal_vad: bool = True,
        repetition_penalty: float = 1.0,
        compression_ratio_threshold: float = 2.4,
        log_prob_threshold: float = -1.0,
        no_speech_threshold: float = 0.6,
        condition_on_previous_text: bool = False,
        prompt: Optional[str] = None,
    ) -> MultiChannelTranscriptionOutput:
        """
        Transcribe an audio file using the faster-whisper original pipeline.

        Args:
            audio (Union[str, torch.Tensor, TensorShare]): Audio file path or loaded audio.
            source_lang (str): Language of the audio file.
            speaker_id (int): Speaker ID used in the diarization.
            suppress_blank (bool):
                Whether to suppress blank at the beginning of the sampling.
            word_timestamps (bool):
                Whether to return word timestamps.
            internal_vad (bool):
                Whether to use faster-whisper's VAD or not.
            repetition_penalty (float):
                Repetition penalty to use during generation beamed search.
            compression_ratio_threshold (float):
                If the gzip compression ratio is above this value, treat as failed.
            log_prob_threshold (float):
                If the average log probability over sampled tokens is below this value, treat as failed.
            no_speech_threshold (float):
                If the no_speech probability is higher than this value AND the average log probability
                over sampled tokens is below `log_prob_threshold`, consider the segment as silent.
            condition_on_previous_text (bool):
                If True, the previous output of the model is provided as a prompt for the next window;
                disabling may make the text inconsistent across windows, but the model becomes less prone
                to getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
            prompt (Optional[str]): Initial prompt to use for the generation.

        Returns:
            MultiChannelTranscriptionOutput: Multi-channel transcription segments in a list.
        """
        if isinstance(audio, torch.Tensor):
            _audio = audio.numpy()
        elif isinstance(audio, TensorShare):
            ts = audio.to_tensors(backend=Backend.NUMPY)
            _audio = ts["audio"]
        else:
            # File paths (or already-decoded arrays) are passed to faster-whisper as-is.
            _audio = audio

        final_segments = []

        segments, _ = self.model.transcribe(
            _audio,
            language=source_lang,
            initial_prompt=prompt,
            repetition_penalty=repetition_penalty,
            compression_ratio_threshold=compression_ratio_threshold,
            log_prob_threshold=log_prob_threshold,
            no_speech_threshold=no_speech_threshold,
            condition_on_previous_text=condition_on_previous_text,
            suppress_blank=suppress_blank,
            word_timestamps=word_timestamps,
            vad_filter=internal_vad,
            vad_parameters={
                "threshold": 0.5,
                "min_speech_duration_ms": 250,
                "min_silence_duration_ms": 100,
                "speech_pad_ms": 30,
                "window_size_samples": 512,
            },
        )

        for segment in segments:
            _segment = MultiChannelSegment(
                start=segment.start,
                end=segment.end,
                text=segment.text,
                words=[Word(**word._asdict()) for word in segment.words],
                speaker=speaker_id,
            )
            final_segments.append(_segment)

        return MultiChannelTranscriptionOutput(segments=final_segments)

__call__(audio, source_lang, model_index, suppress_blank=False, vocab=None, word_timestamps=True, internal_vad=False, repetition_penalty=1.0, compression_ratio_threshold=2.4, log_prob_threshold=-1.0, no_speech_threshold=0.6, condition_on_previous_text=True)

Run inference with the transcribe model.

Parameters:

    audio (Union[str, Tensor, TensorShare, List[str], List[Tensor], List[TensorShare]]): Audio file path or audio tensor. If a list is passed, the task is assumed to be a multi_channel task and the list of audio files or tensors is passed. Required.
    source_lang (str): Language of the audio file. Required.
    model_index (int): Index of the model to use. Required.
    suppress_blank (bool): Whether to suppress blank at the beginning of the sampling. Default: False.
    vocab (Union[List[str], None]): Vocabulary to use during generation if not None. Default: None.
    word_timestamps (bool): Whether to return word timestamps. Default: True.
    internal_vad (bool): Whether to use faster-whisper's VAD or not. Default: False.
    repetition_penalty (float): Repetition penalty to use during beam search generation. Default: 1.0.
    compression_ratio_threshold (float): If the gzip compression ratio is above this value, treat as failed. Default: 2.4.
    log_prob_threshold (float): If the average log probability over sampled tokens is below this value, treat as failed. Default: -1.0.
    no_speech_threshold (float): If the no_speech probability is higher than this value AND the average log probability over sampled tokens is below log_prob_threshold, consider the segment as silent. Default: 0.6.
    condition_on_previous_text (bool): If True, the previous output of the model is provided as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop, such as repetition looping or timestamps going out of sync. Default: True.

Returns:

    Union[TranscriptionOutput, List[TranscriptionOutput]]: Transcription output. If the task is a multi_channel task, a list of TranscriptionOutput is returned.

Source code in src/wordcab_transcribe/services/transcribe_service.py
def __call__(
    self,
    audio: Union[
        str,
        torch.Tensor,
        TensorShare,
        List[str],
        List[torch.Tensor],
        List[TensorShare],
    ],
    source_lang: str,
    model_index: int,
    suppress_blank: bool = False,
    vocab: Union[List[str], None] = None,
    word_timestamps: bool = True,
    internal_vad: bool = False,
    repetition_penalty: float = 1.0,
    compression_ratio_threshold: float = 2.4,
    log_prob_threshold: float = -1.0,
    no_speech_threshold: float = 0.6,
    condition_on_previous_text: bool = True,
) -> Union[TranscriptionOutput, List[TranscriptionOutput]]:
    """
    Run inference with the transcribe model.

    Args:
        audio (Union[str, torch.Tensor, TensorShare, List[str], List[torch.Tensor], List[TensorShare]]):
            Audio file path or audio tensor. If a list is passed, the task is assumed
            to be a multi_channel task and the list of audio files or tensors is passed.
        source_lang (str):
            Language of the audio file.
        model_index (int):
            Index of the model to use.
        suppress_blank (bool):
            Whether to suppress blank at the beginning of the sampling.
        vocab (Union[List[str], None]):
            Vocabulary to use during generation if not None. Defaults to None.
        word_timestamps (bool):
            Whether to return word timestamps.
        internal_vad (bool):
            Whether to use faster-whisper's VAD or not.
        repetition_penalty (float):
            Repetition penalty to use during generation beamed search.
        compression_ratio_threshold (float):
            If the gzip compression ratio is above this value, treat as failed.
        log_prob_threshold (float):
            If the average log probability over sampled tokens is below this value, treat as failed.
        no_speech_threshold (float):
            If the no_speech probability is higher than this value AND the average log probability
            over sampled tokens is below `log_prob_threshold`, consider the segment as silent.
        condition_on_previous_text (bool):
            If True, the previous output of the model is provided as a prompt for the next window;
            disabling may make the text inconsistent across windows, but the model becomes less prone
            to getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.

    Returns:
        Union[TranscriptionOutput, List[TranscriptionOutput]]:
            Transcription output. If the task is a multi_channel task, a list of TranscriptionOutput is returned.
    """
    # Extra language models are disabled until we can handle an index mapping
    # if (
    #     source_lang in self.extra_lang
    #     and self.models[model_index].lang != source_lang
    # ):
    #     logger.debug(f"Loading model for language {source_lang} on GPU {model_index}.")
    #     self.models[model_index] = FasterWhisperModel(
    #         model=WhisperModel(
    #             self.extra_lang_models[source_lang],
    #             device=self.device,
    #             device_index=model_index,
    #             compute_type=self.compute_type,
    #         ),
    #         lang=source_lang,
    #     )
    #     self.loaded_model_lang = source_lang

    # elif source_lang not in self.extra_lang and self.models[model_index].lang != "multi":
    #     logger.debug(f"Re-loading multi-language model on GPU {model_index}.")
    #     self.models[model_index] = FasterWhisperModel(
    #         model=WhisperModel(
    #             self.model_path,
    #             device=self.device,
    #             device_index=model_index,
    #             compute_type=self.compute_type,
    #         ),
    #         lang=source_lang,
    #     )

    if (
        vocab is not None
        and isinstance(vocab, list)
        and len(vocab) > 0
        and vocab[0].strip()
    ):
        words = ", ".join(vocab)
        prompt = f"Vocab: {words.strip()}"
    else:
        prompt = None

    if not isinstance(audio, list):
        if isinstance(audio, torch.Tensor):
            audio = audio.numpy()
        elif isinstance(audio, TensorShare):
            ts = audio.to_tensors(backend=Backend.NUMPY)
            audio = ts["audio"]

        segments, _ = self.model.transcribe(
            audio,
            language=source_lang,
            initial_prompt=prompt,
            repetition_penalty=repetition_penalty,
            compression_ratio_threshold=compression_ratio_threshold,
            log_prob_threshold=log_prob_threshold,
            no_speech_threshold=no_speech_threshold,
            condition_on_previous_text=condition_on_previous_text,
            suppress_blank=suppress_blank,
            word_timestamps=word_timestamps,
            vad_filter=internal_vad,
            vad_parameters={
                "threshold": 0.5,
                "min_speech_duration_ms": 250,
                "min_silence_duration_ms": 100,
                "speech_pad_ms": 30,
                "window_size_samples": 512,
            },
        )

        segments = list(segments)
        if not segments:
            logger.warning(
                "Empty transcription result. Trying with vad_filter=True."
            )
            segments, _ = self.model.transcribe(
                audio,
                language=source_lang,
                initial_prompt=prompt,
                repetition_penalty=repetition_penalty,
                compression_ratio_threshold=compression_ratio_threshold,
                log_prob_threshold=log_prob_threshold,
                no_speech_threshold=no_speech_threshold,
                condition_on_previous_text=condition_on_previous_text,
                suppress_blank=False,
                word_timestamps=True,
                vad_filter=False if internal_vad else True,
            )

        _outputs = [segment._asdict() for segment in segments]
        outputs = TranscriptionOutput(segments=_outputs)

    else:
        outputs = []
        for audio_index, audio_file in enumerate(audio):
            outputs.append(
                self.multi_channel(
                    audio_file,
                    source_lang=source_lang,
                    speaker_id=audio_index,
                    suppress_blank=suppress_blank,
                    word_timestamps=word_timestamps,
                    internal_vad=internal_vad,
                    repetition_penalty=repetition_penalty,
                    compression_ratio_threshold=compression_ratio_threshold,
                    log_prob_threshold=log_prob_threshold,
                    no_speech_threshold=no_speech_threshold,
                    prompt=prompt,
                )
            )

    return outputs
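
A minimal single-channel usage sketch. The checkpoint name, the audio file path, and the module import path are assumptions.

# Hypothetical sketch; checkpoint name, file path, and import path are assumptions.
from wordcab_transcribe.services.transcribe_service import TranscribeService

service = TranscribeService(
    model_path="large-v3",
    compute_type="float16",
    device="cuda",
    device_index=0,
)

# A single (non-list) input returns one TranscriptionOutput; passing a list of
# inputs instead would be treated as a multi_channel task and return a list.
output = service(
    "audio.wav",
    source_lang="en",
    model_index=0,
    vocab=["Wordcab", "faster-whisper"],  # turned into the initial prompt "Vocab: ..."
    word_timestamps=True,
)
# output.segments holds the transcribed segments, here with word timestamps.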

__init__(model_path, compute_type, device, device_index, extra_languages=None, extra_languages_model_paths=None)

Initialize the Transcribe Service.

This service uses the WhisperModel from faster-whisper to transcribe audio files.

Parameters:

    model_path (str): Path to the model checkpoint. This can be a local path or a URL. Required.
    compute_type (str): Compute type to use for inference. Can be "int8", "int8_float16", "int16" or "float16". Required.
    device (str): Device to use for inference. Can be "cpu" or "cuda". Required.
    device_index (Union[int, List[int]]): Index of the device to use for inference. Required.
    extra_languages (Union[List[str], None]): List of extra languages to transcribe. Default: None.
    extra_languages_model_paths (Union[List[str], None]): List of paths to the extra language models. Default: None.
Source code in src/wordcab_transcribe/services/transcribe_service.py
def __init__(
    self,
    model_path: str,
    compute_type: str,
    device: str,
    device_index: Union[int, List[int]],
    extra_languages: Union[List[str], None] = None,
    extra_languages_model_paths: Union[List[str], None] = None,
) -> None:
    """Initialize the Transcribe Service.

    This service uses the WhisperModel from faster-whisper to transcribe audio files.

    Args:
        model_path (str):
            Path to the model checkpoint. This can be a local path or a URL.
        compute_type (str):
            Compute type to use for inference. Can be "int8", "int8_float16", "int16" or "float16".
        device (str):
            Device to use for inference. Can be "cpu" or "cuda".
        device_index (Union[int, List[int]]):
            Index of the device to use for inference.
        extra_languages (Union[List[str], None]):
            List of extra languages to transcribe. Defaults to None.
        extra_languages_model_paths (Union[List[str], None]):
            List of paths to the extra language models. Defaults to None.
    """
    self.device = device
    self.compute_type = compute_type
    self.model_path = model_path

    self.model = WhisperModel(
        self.model_path,
        device=self.device,
        device_index=device_index,
        compute_type=self.compute_type,
    )

    self.extra_lang = extra_languages
    self.extra_lang_models = extra_languages_model_paths
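
A minimal CPU-only construction sketch; the checkpoint name and import path are assumptions, and int8 is simply a sensible compute type on machines without a GPU.

# Hypothetical sketch; checkpoint name and import path are assumptions.
from wordcab_transcribe.services.transcribe_service import TranscribeService

cpu_service = TranscribeService(
    model_path="medium",
    compute_type="int8",
    device="cpu",
    device_index=0,
    extra_languages=None,
    extra_languages_model_paths=None,
)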

async_live_transcribe(audio, source_lang, model_index) async

Async generator for live transcriptions.

This method wraps the live_transcribe method to make it async.

Parameters:

    audio (Tensor): Audio tensor. Required.
    source_lang (str): Language of the audio file. Required.
    model_index (int): Index of the model to use. Required.

Yields:

    Iterable[dict]: Iterable of transcribed segments.

Source code in src/wordcab_transcribe/services/transcribe_service.py
async def async_live_transcribe(
    self,
    audio: torch.Tensor,
    source_lang: str,
    model_index: int,
) -> Iterable[dict]:
    """Async generator for live transcriptions.

    This method wraps the live_transcribe method to make it async.

    Args:
        audio (torch.Tensor): Audio tensor.
        source_lang (str): Language of the audio file.
        model_index (int): Index of the model to use.

    Yields:
        Iterable[dict]: Iterable of transcribed segments.
    """
    for result in self.live_transcribe(audio, source_lang, model_index):
        yield result
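
A minimal sketch of consuming the async generator; the tiny checkpoint and the silent stand-in tensor are assumptions used only to keep the example self-contained.

# Hypothetical sketch; the checkpoint name and stand-in audio are assumptions.
import asyncio

import torch

from wordcab_transcribe.services.transcribe_service import TranscribeService

async def main() -> None:
    service = TranscribeService(
        model_path="tiny", compute_type="int8", device="cpu", device_index=0
    )
    audio_tensor = torch.zeros(16000)  # one second of 16 kHz silence as a stand-in

    # async_live_transcribe exposes the synchronous generator as an async one,
    # yielding one segment dict at a time.
    async for segment in service.async_live_transcribe(
        audio_tensor, source_lang="en", model_index=0
    ):
        print(segment["text"])

# asyncio.run(main())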

live_transcribe(audio, source_lang, model_index)

Transcribe audio from a WebSocket connection.

Parameters:

    audio (Tensor): Audio tensor. Required.
    source_lang (str): Language of the audio file. Required.
    model_index (int): Index of the model to use. Required.

Yields:

    Iterable[dict]: Iterable of transcribed segments.

Source code in src/wordcab_transcribe/services/transcribe_service.py
def live_transcribe(
    self,
    audio: torch.Tensor,
    source_lang: str,
    model_index: int,
) -> Iterable[dict]:
    """
    Transcribe audio from a WebSocket connection.

    Args:
        audio (torch.Tensor): Audio tensor.
        source_lang (str): Language of the audio file.
        model_index (int): Index of the model to use.

    Yields:
        Iterable[dict]: Iterable of transcribed segments.
    """
    segments, _ = self.model.transcribe(
        audio.numpy(),
        language=source_lang,
        suppress_blank=True,
        word_timestamps=False,
    )

    for segment in segments:
        yield segment._asdict()
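
The synchronous generator can be iterated directly; the checkpoint name and stand-in audio below are assumptions.

# Hypothetical sketch; the checkpoint name and stand-in audio are assumptions.
import torch

from wordcab_transcribe.services.transcribe_service import TranscribeService

service = TranscribeService(
    model_path="tiny", compute_type="int8", device="cpu", device_index=0
)
audio_tensor = torch.zeros(16000)  # one second of 16 kHz silence as a stand-in

# suppress_blank=True and word_timestamps=False are fixed inside live_transcribe.
for segment in service.live_transcribe(audio_tensor, source_lang="en", model_index=0):
    print(segment["start"], segment["end"], segment["text"])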

multi_channel(audio, source_lang, speaker_id, suppress_blank=False, word_timestamps=True, internal_vad=True, repetition_penalty=1.0, compression_ratio_threshold=2.4, log_prob_threshold=-1.0, no_speech_threshold=0.6, condition_on_previous_text=False, prompt=None)

Transcribe an audio file using the faster-whisper original pipeline.

Parameters:

    audio (Union[str, Tensor, TensorShare]): Audio file path or loaded audio. Required.
    source_lang (str): Language of the audio file. Required.
    speaker_id (int): Speaker ID used in the diarization. Required.
    suppress_blank (bool): Whether to suppress blank at the beginning of the sampling. Default: False.
    word_timestamps (bool): Whether to return word timestamps. Default: True.
    internal_vad (bool): Whether to use faster-whisper's VAD or not. Default: True.
    repetition_penalty (float): Repetition penalty to use during beam search generation. Default: 1.0.
    compression_ratio_threshold (float): If the gzip compression ratio is above this value, treat as failed. Default: 2.4.
    log_prob_threshold (float): If the average log probability over sampled tokens is below this value, treat as failed. Default: -1.0.
    no_speech_threshold (float): If the no_speech probability is higher than this value AND the average log probability over sampled tokens is below log_prob_threshold, consider the segment as silent. Default: 0.6.
    condition_on_previous_text (bool): If True, the previous output of the model is provided as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop, such as repetition looping or timestamps going out of sync. Default: False.
    prompt (Optional[str]): Initial prompt to use for the generation. Default: None.

Returns:

    MultiChannelTranscriptionOutput: Multi-channel transcription segments in a list.

Source code in src/wordcab_transcribe/services/transcribe_service.py
def multi_channel(
    self,
    audio: Union[str, torch.Tensor, TensorShare],
    source_lang: str,
    speaker_id: int,
    suppress_blank: bool = False,
    word_timestamps: bool = True,
    internal_vad: bool = True,
    repetition_penalty: float = 1.0,
    compression_ratio_threshold: float = 2.4,
    log_prob_threshold: float = -1.0,
    no_speech_threshold: float = 0.6,
    condition_on_previous_text: bool = False,
    prompt: Optional[str] = None,
) -> MultiChannelTranscriptionOutput:
    """
    Transcribe an audio file using the faster-whisper original pipeline.

    Args:
        audio (Union[str, torch.Tensor, TensorShare]): Audio file path or loaded audio.
        source_lang (str): Language of the audio file.
        speaker_id (int): Speaker ID used in the diarization.
        suppress_blank (bool):
            Whether to suppress blank at the beginning of the sampling.
        word_timestamps (bool):
            Whether to return word timestamps.
        internal_vad (bool):
            Whether to use faster-whisper's VAD or not.
        repetition_penalty (float):
            Repetition penalty to use during generation beamed search.
        compression_ratio_threshold (float):
            If the gzip compression ratio is above this value, treat as failed.
        log_prob_threshold (float):
            If the average log probability over sampled tokens is below this value, treat as failed.
        no_speech_threshold (float):
            If the no_speech probability is higher than this value AND the average log probability
            over sampled tokens is below `log_prob_threshold`, consider the segment as silent.
        condition_on_previous_text (bool):
            If True, the previous output of the model is provided as a prompt for the next window;
            disabling may make the text inconsistent across windows, but the model becomes less prone
            to getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
        prompt (Optional[str]): Initial prompt to use for the generation.

    Returns:
        MultiChannelTranscriptionOutput: Multi-channel transcription segments in a list.
    """
    if isinstance(audio, torch.Tensor):
        _audio = audio.numpy()
    elif isinstance(audio, TensorShare):
        ts = audio.to_tensors(backend=Backend.NUMPY)
        _audio = ts["audio"]
    else:
        # File paths (or already-decoded arrays) are passed to faster-whisper as-is.
        _audio = audio

    final_segments = []

    segments, _ = self.model.transcribe(
        _audio,
        language=source_lang,
        initial_prompt=prompt,
        repetition_penalty=repetition_penalty,
        compression_ratio_threshold=compression_ratio_threshold,
        log_prob_threshold=log_prob_threshold,
        no_speech_threshold=no_speech_threshold,
        condition_on_previous_text=condition_on_previous_text,
        suppress_blank=suppress_blank,
        word_timestamps=word_timestamps,
        vad_filter=internal_vad,
        vad_parameters={
            "threshold": 0.5,
            "min_speech_duration_ms": 250,
            "min_silence_duration_ms": 100,
            "speech_pad_ms": 30,
            "window_size_samples": 512,
        },
    )

    for segment in segments:
        _segment = MultiChannelSegment(
            start=segment.start,
            end=segment.end,
            text=segment.text,
            words=[Word(**word._asdict()) for word in segment.words],
            speaker=speaker_id,
        )
        final_segments.append(_segment)

    return MultiChannelTranscriptionOutput(segments=final_segments)
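
For orientation, a minimal usage sketch (illustrative only, not part of the generated reference). It assumes an already-initialized transcribe service exposing the multi_channel method above, and a stereo 16 kHz recording where each channel carries one speaker; the file name and the service variable are placeholders.

import torchaudio

# `service` is assumed to be an initialized transcribe service exposing the
# `multi_channel` method shown above; its construction is omitted here.
waveform, sample_rate = torchaudio.load("call.wav")  # hypothetical stereo file, assumed 16 kHz

outputs = []
for channel_index in range(waveform.size(0)):
    output = service.multi_channel(
        audio=waveform[channel_index],  # one mono channel per speaker
        source_lang="en",
        speaker_id=channel_index,
    )
    outputs.append(output)

for output in outputs:
    for segment in output.segments:
        print(f"[speaker {segment.speaker}] {segment.start:.2f}s-{segment.end:.2f}s {segment.text}")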

Voice Activity Detection (VAD) Service for audio files.

VadService

VAD Service for audio files.

Source code in src/wordcab_transcribe/services/vad_service.py
class VadService:
    """VAD Service for audio files."""

    def __init__(self) -> None:
        """Initialize the VAD Service."""
        self.sample_rate = 16000
        self.options = VadOptions(
            threshold=0.5,
            min_speech_duration_ms=250,
            max_speech_duration_s=30,
            min_silence_duration_ms=100,
            window_size_samples=512,
            speech_pad_ms=400,
        )

    def __call__(
        self, waveform: torch.Tensor, group_timestamps: Optional[bool] = True
    ) -> Tuple[Union[List[dict], List[List[dict]]], torch.Tensor]:
        """
        Use the VAD model to get the speech timestamps. Multi-channel pipeline.

        Args:
            waveform (torch.Tensor): Audio tensor.
            group_timestamps (Optional[bool], optional): Group timestamps. Defaults to True.

        Returns:
            Tuple[Union[List[dict], List[List[dict]]], torch.Tensor]: Speech timestamps and audio tensor.
        """
        if waveform.size(0) == 1:
            waveform = waveform.squeeze(0)

        speech_timestamps = get_speech_timestamps(
            audio=waveform, vad_options=self.options
        )

        _speech_timestamps_list = [
            {"start": ts["start"], "end": ts["end"]} for ts in speech_timestamps
        ]

        if group_timestamps:
            speech_timestamps_list = self.group_timestamps(_speech_timestamps_list)
        else:
            speech_timestamps_list = _speech_timestamps_list

        return speech_timestamps_list, waveform

    def group_timestamps(
        self, timestamps: List[dict], threshold: Optional[float] = 3.0
    ) -> List[List[dict]]:
        """
        Group timestamps based on a threshold.

        Args:
            timestamps (List[dict]): List of timestamps.
            threshold (float, optional): Threshold to use for grouping. Defaults to 3.0.

        Returns:
            List[List[dict]]: List of grouped timestamps.
        """
        grouped_segments = [[]]

        for i in range(len(timestamps)):
            if (
                i > 0
                and (timestamps[i]["start"] - timestamps[i - 1]["end"]) > threshold
            ):
                grouped_segments.append([])

            grouped_segments[-1].append(timestamps[i])

        return grouped_segments

    def save_audio(self, filepath: str, audio: torch.Tensor) -> None:
        """
        Save audio tensor to file.

        Args:
            filepath (str): Path to save the audio file.
            audio (torch.Tensor): Audio tensor.
        """
        torchaudio.save(
            filepath, audio.unsqueeze(0), self.sample_rate, bits_per_sample=16
        )
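
For illustration only (not generated from the source): a hedged sketch of running the service end to end. It loads an audio file with torchaudio, downmixes and resamples to the 16 kHz mono format the service assumes, then retrieves grouped speech timestamps; the input path is a placeholder.

import torchaudio

from wordcab_transcribe.services.vad_service import VadService

vad_service = VadService()

waveform, sample_rate = torchaudio.load("meeting.wav")  # hypothetical input file

# The service assumes 16 kHz mono audio, so downmix and resample first.
if waveform.size(0) > 1:
    waveform = waveform.mean(dim=0, keepdim=True)
if sample_rate != vad_service.sample_rate:
    waveform = torchaudio.functional.resample(
        waveform, orig_freq=sample_rate, new_freq=vad_service.sample_rate
    )

groups, mono_waveform = vad_service(waveform, group_timestamps=True)
print(f"{sum(len(group) for group in groups)} speech segments in {len(groups)} groups")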

__call__(waveform, group_timestamps=True)

Use the VAD model to get the speech timestamps. Multi-channel pipeline.

Parameters:

Name Type Description Default
waveform Tensor

Audio tensor.

required
group_timestamps Optional[bool]

Group timestamps. Defaults to True.

True

Returns:

Type Description
Tuple[Union[List[dict], List[List[dict]]], Tensor]

Tuple[Union[List[dict], List[List[dict]]], torch.Tensor]: Speech timestamps and audio tensor.

Source code in src/wordcab_transcribe/services/vad_service.py
def __call__(
    self, waveform: torch.Tensor, group_timestamps: Optional[bool] = True
) -> Tuple[Union[List[dict], List[List[dict]]], torch.Tensor]:
    """
    Use the VAD model to get the speech timestamps. Multi-channel pipeline.

    Args:
        waveform (torch.Tensor): Audio tensor.
        group_timestamps (Optional[bool], optional): Group timestamps. Defaults to True.

    Returns:
        Tuple[Union[List[dict], List[List[dict]]], torch.Tensor]: Speech timestamps and audio tensor.
    """
    if waveform.size(0) == 1:
        waveform = waveform.squeeze(0)

    speech_timestamps = get_speech_timestamps(
        audio=waveform, vad_options=self.options
    )

    _speech_timestamps_list = [
        {"start": ts["start"], "end": ts["end"]} for ts in speech_timestamps
    ]

    if group_timestamps:
        speech_timestamps_list = self.group_timestamps(_speech_timestamps_list)
    else:
        speech_timestamps_list = _speech_timestamps_list

    return speech_timestamps_list, waveform

__init__()

Initialize the VAD Service.

Source code in src/wordcab_transcribe/services/vad_service.py
def __init__(self) -> None:
    """Initialize the VAD Service."""
    self.sample_rate = 16000
    self.options = VadOptions(
        threshold=0.5,
        min_speech_duration_ms=250,
        max_speech_duration_s=30,
        min_silence_duration_ms=100,
        window_size_samples=512,
        speech_pad_ms=400,
    )
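
Note that the defaults mix units (samples and milliseconds). A small illustrative conversion at the service's 16 kHz sample rate, not part of the source:

SAMPLE_RATE = 16000  # matches VadService.sample_rate

window_size_samples = 512
speech_pad_ms = 400

window_seconds = window_size_samples / SAMPLE_RATE  # 0.032 s of audio per VAD window
pad_seconds = speech_pad_ms / 1000  # 0.4 s of padding applied around detected speech

print(f"Each VAD window covers {window_seconds * 1000:.0f} ms")
print(f"Speech regions are padded by {pad_seconds:.1f} s")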

group_timestamps(timestamps, threshold=3.0)

Group timestamps based on a threshold.

Parameters:

Name Type Description Default
timestamps List[dict]

List of timestamps.

required
threshold float

Threshold to use for grouping. Defaults to 3.0.

3.0

Returns:

Type Description
List[List[dict]]

List[List[dict]]: List of grouped timestamps.

Source code in src/wordcab_transcribe/services/vad_service.py
def group_timestamps(
    self, timestamps: List[dict], threshold: Optional[float] = 3.0
) -> List[List[dict]]:
    """
    Group timestamps based on a threshold.

    Args:
        timestamps (List[dict]): List of timestamps.
        threshold (float, optional): Threshold to use for grouping. Defaults to 3.0.

    Returns:
        List[List[dict]]: List of grouped timestamps.
    """
    grouped_segments = [[]]

    for i in range(len(timestamps)):
        if (
            i > 0
            and (timestamps[i]["start"] - timestamps[i - 1]["end"]) > threshold
        ):
            grouped_segments.append([])

        grouped_segments[-1].append(timestamps[i])

    return grouped_segments
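
A small worked example of the grouping rule (a gap larger than the threshold starts a new group). Units follow whatever get_speech_timestamps produces; seconds are assumed here purely for readability:

from wordcab_transcribe.services.vad_service import VadService

timestamps = [
    {"start": 0.0, "end": 1.2},
    {"start": 1.5, "end": 2.8},  # gap of 0.3 <= 3.0 -> stays in the same group
    {"start": 7.0, "end": 8.1},  # gap of 4.2 > 3.0 -> opens a new group
]

groups = VadService().group_timestamps(timestamps, threshold=3.0)
# groups == [
#     [{"start": 0.0, "end": 1.2}, {"start": 1.5, "end": 2.8}],
#     [{"start": 7.0, "end": 8.1}],
# ]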

save_audio(filepath, audio)

Save audio tensor to file.

Parameters:

Name Type Description Default
filepath str

Path to save the audio file.

required
audio Tensor

Audio tensor.

required
Source code in src/wordcab_transcribe/services/vad_service.py
def save_audio(self, filepath: str, audio: torch.Tensor) -> None:
    """
    Save audio tensor to file.

    Args:
        filepath (str): Path to save the audio file.
        audio (torch.Tensor): Audio tensor.
    """
    torchaudio.save(
        filepath, audio.unsqueeze(0), self.sample_rate, bits_per_sample=16
    )
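
A brief hedged sketch of writing audio back to disk with save_audio. The input is assumed to already be 16 kHz mono, since the file is written at the service's fixed sample rate; both paths are placeholders.

import torchaudio

from wordcab_transcribe.services.vad_service import VadService

vad_service = VadService()

waveform, _ = torchaudio.load("meeting_16k_mono.wav")  # hypothetical 16 kHz mono input

_, mono_waveform = vad_service(waveform)  # 1-D tensor after squeezing
vad_service.save_audio("vad_checked.wav", mono_waveform)  # written as 16-bit PCM at 16 kHz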

Last update: 2023-10-12
Created: 2023-10-12