Commit 9cf623d

Commit relevant files

2 parents: b0cdaa1 + d3a8c4b

File tree

6 files changed: +521 −159 lines changed

.github/workflows/security-scan.yml

+17
@@ -0,0 +1,17 @@
name: Security Scan

on:
  pull_request:
  push:
    branches:
      - main

jobs:
  security:
    name: OSS Security SAST
    uses: Roblox/security-workflows/.github/workflows/oss-security-sast.yaml@main
    with:
      skip-ossf: true
    secrets:
      GITLEAKS_LICENSE: ${{ secrets.GITLEAKS_KEY }}
      ROBLOX_SEMGREP_GHC_POC_APP_TOKEN: ${{ secrets.ROBLOX_SEMGREP_GHC_POC_APP_TOKEN }}

LICENSE.md

+352 −159
Large diffs are not rendered by default.

README.md

+38
@@ -0,0 +1,38 @@
## Model Description
The model is fine-tuned from the [WavLM base plus](https://arxiv.org/abs/2110.13900) checkpoint on 2,374 hours of audio clips from
voice chat for multilabel classification.
The audio clips are automatically labeled using a synthetic data pipeline described in [our blog post](link to blog post here).
A single output can have multiple labels.
The model outputs an n-by-6 tensor where the inferred labels are `Profanity`, `DatingAndSexting`, `Racist`,
`Bullying`, `Other`, `NoViolation`. `Other` combines low-prevalence policy violation categories, such as drugs
and alcohol or self-harm, into a single category.

We evaluated this model on a dataset with human-annotated labels that contained a total of 9,795 samples with the class
distribution shown below. Note that we did not include the `Other` category in this evaluation dataset.

| Class | Number of examples | Duration (hours) | % of dataset |
|---|---|---|---|
| Profanity | 4893 | 15.38 | 49.95% |
| DatingAndSexting | 688 | 2.52 | 7.02% |
| Racist | 889 | 3.10 | 9.08% |
| Bullying | 1256 | 4.25 | 12.82% |
| NoViolation | 4185 | 9.93 | 42.73% |

If we set the same threshold across all classes and treat the model as a binary classifier over the 4 toxicity classes (`Profanity`, `DatingAndSexting`, `Racist`, `Bullying`), we get a binarized average precision of 94.48%. The precision-recall curve is shown below.
<p align="center">
<img src="images/human_eval_pr_curve.png" alt="PR Curve" width="500"/>
</p>
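As an illustration, here is a minimal sketch of that binarization step, assuming `probs` is the model's sigmoid output with shape `[n, 6]` in the label order above; the 0.5 threshold is only a placeholder, not the operating point behind the reported average precision:

```
import torch

# Column indices of the four toxicity classes in the output tensor:
# Profanity, DatingAndSexting, Racist, Bullying.
TOXICITY_CLASSES = [0, 1, 2, 3]

def binarize(probs: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
    """Flag a window as toxic if any toxicity class score reaches the threshold."""
    return (probs[:, TOXICITY_CLASSES] >= threshold).any(dim=1)
```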
## Usage
The dependencies for the inference file can be installed as follows:
```
pip install -r requirements.txt
```
The inference file contains helper functions that preprocess the audio file for inference.
To run inference, use the following command:
```
python inference.py --audio_file <your audio file path> --model_path <path to Huggingface model>
```
You can get the model weights either by downloading them from the model releases page [here](https://github.com/Roblox/voice-safety-classifier/releases/tag/vs-classifier-v1), or from HuggingFace under `roblox/voice-safety-classifier`.
If `model_path` isn't specified, the model will be loaded directly from HuggingFace.
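For programmatic use, here is a minimal sketch of the same flow as `inference.py`, loading the weights directly from HuggingFace; `example.wav` is a placeholder path:

```
import torch
import librosa
from transformers import WavLMForSequenceClassification
from inference import feature_extract_simple

labels = ["Profanity", "DatingAndSexting", "Racist", "Bullying", "Other", "NoViolation"]

# The model expects 16 kHz audio; longer clips are split into 15 s windows.
audio, _ = librosa.core.load("example.wav", sr=16000)
inputs = torch.Tensor(feature_extract_simple(audio, sr=16000))

model = WavLMForSequenceClassification.from_pretrained(
    "roblox/voice-safety-classifier", num_labels=len(labels)
)
with torch.no_grad():
    probs = torch.sigmoid(model(inputs).logits)  # one row of 6 scores per window
print(dict(zip(labels, probs[0].tolist())))
```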

images/human_eval_pr_curve.png

28.1 KB

inference.py

+110
@@ -0,0 +1,110 @@
# Copyright © 2024 Roblox Corporation

"""
This file gives a sample demonstration of how to use the given functions in Python, for the Voice Safety Classifier model.
"""

import torch
import librosa
import numpy as np
import argparse
from transformers import WavLMForSequenceClassification


def feature_extract_simple(
    wav,
    sr=16_000,
    win_len=15.0,
    win_stride=15.0,
    do_normalize=False,
):
    """Simple feature extraction for WavLM.

    Parameters
    ----------
    wav : str or array-like
        path to the wav file, or array-like
    sr : int, optional
        sample rate, by default 16_000
    win_len : float, optional
        window length in seconds, by default 15.0
    win_stride : float, optional
        window stride in seconds, by default 15.0
    do_normalize : bool, optional
        whether to normalize the input, by default False

    Returns
    -------
    np.ndarray
        batched input to WavLM
    """
    if isinstance(wav, str):
        signal, _ = librosa.core.load(wav, sr=sr)
    else:
        try:
            signal = np.array(wav).squeeze()
        except Exception as e:
            print(e)
            raise RuntimeError("Input audio could not be converted to an array") from e
    batched_input = []
    stride = int(win_stride * sr)
    win_size = int(win_len * sr)
    if len(signal) / sr > win_len:
        # Slice the signal into win_len-second windows every win_stride seconds;
        # the final window is zero-padded so every chunk has the same length.
        for i in range(0, len(signal), stride):
            if i + win_size > len(signal):
                chunked = np.pad(signal[i:], (0, win_size - len(signal[i:])))
            else:
                chunked = signal[i : i + win_size]
            if do_normalize:
                chunked = (chunked - np.mean(chunked)) / (np.std(chunked) + 1e-7)
            batched_input.append(chunked)
            if i + win_size > len(signal):
                break
    else:
        # Clip is shorter than one window: use it as a single chunk.
        if do_normalize:
            signal = (signal - np.mean(signal)) / (np.std(signal) + 1e-7)
        batched_input.append(signal)
    return np.stack(batched_input)  # [N, T]


def infer(model, inputs):
    # Multilabel head: apply a per-class sigmoid to the logits rather than a softmax.
    output = model(inputs)
    probs = torch.sigmoid(output.logits)
    return probs


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--audio_file",
        type=str,
        help="Audio file to run inference on",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default="roblox/voice-safety-classifier",
        help="Checkpoint file or HuggingFace model name",
    )
    args = parser.parse_args()
    labels_name_list = [
        "Profanity",
        "DatingAndSexting",
        "Racist",
        "Bullying",
        "Other",
        "NoViolation",
    ]
    # The model is trained on 16 kHz audio only.
    audio, _ = librosa.core.load(args.audio_file, sr=16000)
    input_np = feature_extract_simple(audio, sr=16000)
    input_pt = torch.Tensor(input_np)
    model = WavLMForSequenceClassification.from_pretrained(
        args.model_path, num_labels=len(labels_name_list)
    )
    probs = infer(model, input_pt)
    probs = probs.reshape(-1, 6).detach().tolist()
    print(f"Probabilities for {args.audio_file} are:")
    for chunk_idx in range(len(probs)):
        print(f"\nSegment {chunk_idx}:")
        for label_idx, label in enumerate(labels_name_list):
            print(f"{label} : {probs[chunk_idx][label_idx]}")
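As a quick sanity check of the windowing logic above, here is a minimal sketch (assuming a hypothetical 40-second clip of silence) of how `feature_extract_simple` batches audio into equal-length windows:

```
import numpy as np
from inference import feature_extract_simple

sr = 16_000
audio = np.zeros(40 * sr, dtype=np.float32)  # hypothetical 40 s clip at 16 kHz

batch = feature_extract_simple(audio, sr=sr, win_len=15.0, win_stride=15.0)
# Three 15-second windows; the last one is zero-padded to full length.
print(batch.shape)  # (3, 240000)
```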

requirements.txt

+4
@@ -0,0 +1,4 @@
torch
transformers
librosa
numpy
