Source code for spokestack.wakeword.tflite

"""
This module contains the class for detecting
the presence of keywords in an audio stream
"""
import logging
import os
from typing import Any

import numpy as np

from spokestack.context import SpeechContext
from spokestack.models.tensorflow import TFLiteModel
from spokestack.ring_buffer import RingBuffer

_LOG = logging.getLogger(__name__)


class WakewordTrigger:
    """Detects the presence of a wakeword in the audio input

    Args:
        pre_emphasis (float): The value of the pre-emphasis filter
        sample_rate (int): The number of audio samples per second of audio (Hz)
        fft_window_type (str): The type of FFT window (only "hann" is supported)
        fft_hop_length (int): Audio sliding window for STFT calculation (ms)
        model_dir (str): Path to the directory containing .tflite models
        posterior_threshold (float): Probability threshold above which a
                                     wakeword is considered detected
    """

    def __init__(
        self,
        pre_emphasis: float = 0.0,
        sample_rate: int = 16000,
        fft_window_type: str = "hann",
        fft_hop_length: int = 10,
        model_dir: str = "",
        posterior_threshold: float = 0.5,
        **kwargs: Any,
    ) -> None:

        self.pre_emphasis: float = pre_emphasis
        self.hop_length: int = int(fft_hop_length * sample_rate / 1000)

        if fft_window_type != "hann":
            raise ValueError("Invalid fft_window_type")

        self.filter_model: TFLiteModel = TFLiteModel(
            model_path=os.path.join(model_dir, "filter.tflite")
        )
        self.encode_model: TFLiteModel = TFLiteModel(
            model_path=os.path.join(model_dir, "encode.tflite")
        )
        self.detect_model: TFLiteModel = TFLiteModel(
            model_path=os.path.join(model_dir, "detect.tflite")
        )

        # the window size is derived from the FFT size: the filter model's
        # input is the (fft_size / 2) + 1 magnitude bins of an rfft,
        # which makes the window size (n_bins - 1) * 2
        self._window_size = (self.filter_model.input_details[0]["shape"][-1] - 1) * 2
        self._fft_window = np.hanning(self._window_size)
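
        # NOTE (illustrative numbers, not in the original source): with the
        # default 16 kHz sample rate and a filter model that takes 257 FFT
        # magnitude bins, the window size works out to (257 - 1) * 2 = 512
        # samples (32 ms) and the hop length to 10 * 16000 / 1000 = 160
        # samples (10 ms), i.e. about 69% overlap between windows.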
        # retrieve the mel_length and mel_width from the encode model metadata
        # these allocate the buffers to the correct size
        self.mel_length: int = self.encode_model.input_details[0]["shape"][1]
        self.mel_width: int = self.encode_model.input_details[0]["shape"][-1]

        # initialize the first state input for the autoregressive encode model.
        # encode_length and encode_width come from the detect_model metadata:
        # the encode model runs autoregressively and outputs a single encoded
        # sample per step, while the detect model's input is a collection of
        # these samples.
        self.state = np.zeros(self.encode_model.input_details[1]["shape"], np.float32)
        self.encode_length: int = self.detect_model.input_details[0]["shape"][1]
        self.encode_width: int = self.detect_model.input_details[0]["shape"][-1]

        self.sample_window: RingBuffer = RingBuffer(shape=[self._window_size])
        self.frame_window: RingBuffer = RingBuffer(
            shape=[self.mel_length, self.mel_width]
        )
        self.encode_window: RingBuffer = RingBuffer(
            shape=[self.encode_length, self.encode_width]
        )

        # prefill the frame and encode windows
        # this minimizes the delay caused by filling the buffers
        self.frame_window.fill(0.0)
        self.encode_window.fill(-1.0)

        self._posterior_threshold: float = posterior_threshold
        self._posterior_max: float = 0.0
        self._prev_sample: float = 0.0
        self._is_speech: bool = False

    def __call__(self, context: SpeechContext, frame: np.ndarray) -> None:
        """Entry point of the trigger

        Args:
            context (SpeechContext): current state of the speech pipeline
            frame (np.ndarray): a single frame of an audio signal

        Returns: None
        """
        # detect VAD edges for wakeword deactivation
        vad_fall = self._is_speech and not context.is_speech
        self._is_speech = context.is_speech

        # sample the frame to detect the presence of the wakeword
        if not context.is_active:
            self._sample(context, frame)

        # reset on VAD fall deactivation
        if vad_fall:
            if not context.is_active:
                _LOG.info(f"wake: {self._posterior_max}")
            self.reset()

    def _sample(self, context: SpeechContext, frame: np.ndarray) -> None:
        # convert the PCM-16 audio to float32 in (-1.0, 1.0)
        frame = frame.astype(np.float32) / (2 ** 15 - 1)
        frame = np.clip(frame, -1.0, 1.0)

        # apply pre-emphasis against the previous sample, then cache the last
        # sample of this frame to be used in the next iteration
        prev_sample = frame[-1]
        frame -= self.pre_emphasis * np.append(self._prev_sample, frame[:-1])
        self._prev_sample = prev_sample

        # fill the sample window to analyze speech-containing samples;
        # after each fill the buffer advances by the hop length
        # to produce overlapping windows
        for sample in frame:
            self.sample_window.write(sample)
            if self.sample_window.is_full:
                if context.is_speech:
                    self._analyze(context)
                self.sample_window.rewind().seek(self.hop_length)
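
    # NOTE (explanatory comment, not in the original source): the pre-emphasis
    # step above computes y[t] = x[t] - pre_emphasis * x[t - 1], a first-order
    # high-pass filter that boosts the high-frequency content of the signal
    # before the STFT; with the default pre_emphasis of 0.0 it is a no-op.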
    def _analyze(self, context: SpeechContext) -> None:
        # read the full contents of the sample window to calculate a single
        # frame of the STFT by applying the DFT to the real-valued input and
        # taking the magnitude of the complex result
        frame = self.sample_window.read_all()
        frame = np.fft.rfft(frame * self._fft_window, n=self._window_size)
        frame = np.abs(frame).astype(np.float32)

        # compute the mel spectrogram
        self._filter(context, frame)

    def _filter(self, context: SpeechContext, frame: np.ndarray) -> None:
        # add the batch dimension and compute the mel spectrogram with the
        # filter model
        frame = np.expand_dims(frame, 0)
        frame = self.filter_model(frame)[0]

        # advance the window by 1 and write the mel frame to the frame buffer
        self.frame_window.rewind().seek(1)
        self.frame_window.write(frame)

        # encode the mel spectrogram
        self._encode(context)

    def _encode(self, context: SpeechContext) -> None:
        # read the full contents of the frame window, add the batch dimension,
        # then run the encoder and save the output state for autoregression
        frame = self.frame_window.read_all()
        frame = np.expand_dims(frame, 0)
        frame, self.state = self.encode_model(frame, self.state)

        # accumulate encoded samples until the detection window is full
        self.encode_window.rewind().seek(1)
        self.encode_window.write(frame)

        self._detect(context)

    def _detect(self, context: SpeechContext) -> None:
        # read the full contents of the encode window, add the batch dimension,
        # and calculate a scalar probability that the frame contains the
        # wakeword with the detect model
        frame = self.encode_window.read_all()
        frame = np.expand_dims(frame, 0)

        posterior = self.detect_model(frame)[0][0][0]
        if posterior > self._posterior_max:
            self._posterior_max = posterior
        if posterior > self._posterior_threshold:
            context.is_active = True
            _LOG.info(f"wake: {self._posterior_max}")
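
    # NOTE (summary comment, not in the original source): each 10 ms hop flows
    # through the three TFLite models in sequence:
    #   sample_window -> STFT magnitude -> filter.tflite (mel spectrogram)
    #   -> frame_window -> encode.tflite (autoregressive encoding)
    #   -> encode_window -> detect.tflite (wakeword posterior)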
    def reset(self) -> None:
        """Resets the current WakewordTrigger state"""
        self.sample_window.reset()
        self.frame_window.reset().fill(0.0)
        self.encode_window.reset().fill(-1.0)
        self.state[:] = 0.0
        self._posterior_max = 0.0
    def close(self) -> None:
        """Close interface for use in the pipeline"""
        self.reset()
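

if __name__ == "__main__":  # pragma: no cover
    # Illustrative usage sketch, not part of the original module. It assumes
    # that filter.tflite, encode.tflite, and detect.tflite exist under
    # "model_dir", that SpeechContext takes no constructor arguments, and
    # that a VAD stage would normally set context.is_speech.
    context = SpeechContext()
    trigger = WakewordTrigger(model_dir="model_dir")

    # feed 10 ms frames (160 samples at 16 kHz) of silence through the
    # trigger; a real pipeline would pass microphone audio instead
    for _ in range(100):
        frame = np.zeros(160, dtype=np.int16)
        context.is_speech = True
        trigger(context, frame)

    if context.is_active:
        print("wakeword detected")
    trigger.close()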