"""
This module contains the Spokestack KeywordRecognizer which identifies multiple keywords
from an audio stream.
"""
import os
from typing import Any, List
import numpy as np
from spokestack.context import SpeechContext
from spokestack.models.tensorflow import TFLiteModel
from spokestack.ring_buffer import RingBuffer
class KeywordRecognizer:
    """Recognizes keywords in an audio stream.

    Audio frames are pre-emphasized, windowed into overlapping STFT frames,
    converted to mel features by the filter model, encoded autoregressively
    by the encode model, and finally classified by the detect model once
    the speech context transitions from active to inactive.

    Args:
        classes (List[str]): Keyword labels
        pre_emphasis (float): The value of the pre-emphasis filter
        sample_rate (int): The number of audio samples per second of audio (Hz)
        fft_window_type (str): The type of fft window. (only support for hann)
        fft_hop_length (int): Audio sliding window for STFT calculation (ms)
        model_dir (str): Path to the directory containing .tflite models
        posterior_threshold (float): Probability threshold for detection
        **kwargs (Any): Additional keyword arguments; accepted and ignored

    Raises:
        ValueError: If ``fft_window_type`` is not ``"hann"`` or if the number
            of classes does not match the last dimension of the detect
            model's output shape.
    """

    def __init__(
        self,
        classes: List[str],
        pre_emphasis: float = 0.97,
        sample_rate: int = 16000,
        fft_window_type: str = "hann",
        fft_hop_length: int = 10,
        model_dir: str = "",
        posterior_threshold: float = 0.5,
        **kwargs: Any
    ) -> None:
        # validate the configuration before loading any models
        if fft_window_type != "hann":
            raise ValueError("Invalid fft_window_type")

        self.classes = classes
        self.pre_emphasis: float = pre_emphasis
        # convert the hop length from milliseconds to a number of samples
        self.hop_length: int = int(fft_hop_length * sample_rate / 1000)

        self.filter_model: TFLiteModel = TFLiteModel(
            model_path=os.path.join(model_dir, "filter.tflite")
        )
        self.encode_model: TFLiteModel = TFLiteModel(
            model_path=os.path.join(model_dir, "encode.tflite")
        )
        self.detect_model: TFLiteModel = TFLiteModel(
            model_path=os.path.join(model_dir, "detect.tflite")
        )

        # the detect model emits one posterior per class, so the label list
        # has to line up with the last dimension of its output
        if len(classes) != self.detect_model.output_details[0]["shape"][-1]:
            raise ValueError("Invalid number of classes")

        # window size calculated based on fft
        # the filter inputs are (fft_size - 1) / 2
        # which makes the window size (post_fft_size - 1) * 2
        self._window_size = (self.filter_model.input_details[0]["shape"][-1] - 1) * 2
        self._fft_window = np.hanning(self._window_size)

        # retrieve the mel_length and mel_width based on the encoder model metadata
        # these allocate the buffer to the correct size
        self.mel_length: int = self.encode_model.input_details[0]["shape"][1]
        self.mel_width: int = self.encode_model.input_details[0]["shape"][-1]

        # initialize the first state input for autoregressive encoder model
        # retrieve the encode_length and encode_width from the model detect_model
        # metadata. We get the dimensions from the detect_model inputs because the
        # encode_model runs autoregressive and outputs a single encoded sample.
        # the detect_model input is a collection of these samples.
        self.state = np.zeros(self.encode_model.input_details[1]["shape"], np.float32)
        self.encode_length: int = self.detect_model.input_details[0]["shape"][1]
        self.encode_width: int = self.detect_model.input_details[0]["shape"][-1]

        self.sample_window: RingBuffer = RingBuffer(
            shape=[self._window_size],
        )
        self.frame_window: RingBuffer = RingBuffer(
            shape=[self.mel_length, self.mel_width]
        )
        self.encode_window: RingBuffer = RingBuffer(
            shape=[self.encode_length, self.encode_width]
        )

        # initialize the frame and encode windows with zeros
        # this minimizes the delay caused by filling the buffer
        self.frame_window.fill(0.0)
        self.encode_window.fill(-1.0)

        self._posterior_threshold: float = posterior_threshold
        # last raw sample of the previous frame, used by the pre-emphasis filter
        self._prev_sample: float = 0.0
        # value of context.is_active observed on the previous call
        self._is_active: bool = False

    def __call__(self, context: SpeechContext, frame: np.ndarray) -> None:
        """Processes a single frame of audio.

        Args:
            context (SpeechContext): State based information that needs to be
                shared between pieces of the pipeline
            frame (np.ndarray): A single frame of PCM-16 audio
        """
        self._sample(context, frame)

        # classification runs once, on the transition from active to inactive
        if not context.is_active and self._is_active:
            self._detect(context)

        self._is_active = context.is_active

    def _sample(self, context: SpeechContext, frame: np.ndarray) -> None:
        """Pre-emphasizes a frame and feeds it into the sample window."""
        # convert the PCM-16 audio to float32 in (-1.0, 1.0)
        frame = frame.astype(np.float32) / (2 ** 15 - 1)
        frame = np.clip(frame, -1.0, 1.0)

        # nothing to process: an empty frame has no last sample to cache
        # and would otherwise fail on frame[-1]
        if frame.size == 0:
            return

        # pull out a single value from the frame and apply pre-emphasis
        # with the previous sample then cache the previous sample
        # to be use in the next iteration
        prev_sample = frame[-1]
        frame -= self.pre_emphasis * np.append(self._prev_sample, frame[:-1])
        self._prev_sample = prev_sample

        # fill the sample window to analyze speech containing samples
        # after each window fill the buffer advances by the hop length
        # to produce an overlapping window
        for sample in frame:
            self.sample_window.write(sample)
            if self.sample_window.is_full:
                # features are only computed while the context is active,
                # but the window keeps sliding either way
                if context.is_active:
                    self._analyze(context)
                self.sample_window.rewind().seek(self.hop_length)

    def _analyze(self, context: SpeechContext) -> None:
        """Computes one STFT magnitude frame from the sample window."""
        # read the full contents of the sample window to calculate a single frame
        # of the STFT by applying the DFT to a real-valued input and
        # taking the magnitude of the complex DFT
        frame = self.sample_window.read_all()
        frame = np.fft.rfft(frame * self._fft_window, n=self._window_size)
        frame = np.abs(frame).astype(np.float32)

        # compute mel spectrogram
        self._filter(context, frame)

    def _filter(self, context: SpeechContext, frame: np.ndarray) -> None:
        """Runs the filter model on an STFT frame and buffers the result."""
        # add the batch dimension and compute the mel spectrogram with filter model
        frame = np.expand_dims(frame, 0)
        frame = self.filter_model(frame)[0]

        # advance the window by 1 and write mel frame to the frame buffer
        self.frame_window.rewind().seek(1)
        self.frame_window.write(frame)

        # encode the mel spectrogram
        self._encode(context)

    def _encode(self, context: SpeechContext) -> None:
        """Runs the encode model on the frame window and buffers the result."""
        # read the full contents of the frame window and add the batch dimension
        # run the encoder and save the output state for autoregression
        frame = self.frame_window.read_all()
        frame = np.expand_dims(frame, 0)
        frame, self.state = self.encode_model(frame, self.state)

        # accumulate encoded samples until size of detection window
        self.encode_window.rewind().seek(1)
        self.encode_window.write(frame)

    def _detect(self, context: SpeechContext) -> None:
        """Classifies the encode window and raises the matching context event.

        Sets ``context.transcript`` and ``context.confidence`` and fires
        ``"recognize"`` when the best posterior reaches the threshold,
        otherwise fires ``"timeout"``. The recognizer is reset in both cases.
        """
        # read the full contents of the encode window and add the batch dimension
        # calculate a scalar likelihood that the frame contains a keyword
        # with the detect model
        frame = self.encode_window.read_all()
        frame = np.expand_dims(frame, 0)
        posterior = self.detect_model(frame)[0][0]

        # pick the most likely class and its probability
        class_index = np.argmax(posterior)
        confidence = posterior[class_index]

        if confidence >= self._posterior_threshold:
            context.transcript = self.classes[class_index]
            context.confidence = confidence
            context.event("recognize")
        else:
            context.event("timeout")

        self.reset()

    def reset(self) -> None:
        """ Resets the current KeywordRecognizer state """
        # clear the buffers and restore the same fill values used in __init__;
        # _prev_sample and _is_active are intentionally left untouched here
        self.sample_window.reset()
        self.frame_window.reset().fill(0.0)
        self.encode_window.reset().fill(-1.0)
        self.state[:] = 0.0

    def close(self) -> None:
        """ Close interface for use in the SpeechPipeline """
        self.reset()