-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Open
Labels
Description
What happened?
I am trying to use the YAMNet model, but I am running into a library issue.
Relevant code
import tensorflow as tf
import numpy as np
import tensorflow_hub as hub
import tensorflow_io as tfio
import random
# Evaluation metrics and class-count helpers
from sklearn.metrics import classification_report
from collections import Counter
import sys
# Allow deeper recursion than the interpreter's default limit.
sys.setrecursionlimit(3000)

# One shared seed for TensorFlow, NumPy and the stdlib RNG so runs repeat.
SEED = 42
for _seed_fn in (tf.random.set_seed, np.random.seed, random.seed):
    _seed_fn(SEED)

# Class label -> path of the WAV recording for that class.
audio_files = {
    'Rov': "/content/drive/MyDrive/rovdata/guassian16kfolder/rov_16k.wav",
}

# TF Hub handle of the YAMNet model; the model is loaded once when the
# script runs and reused by every later call.
yamnet_model_handle = 'https://tfhub.dev/google/yamnet/1'
yamnet_model = hub.load(yamnet_model_handle)
# Decode a WAV file and resample it to 16 kHz mono (the whole clip is kept).
def load_wav_16k_mono(filename):
    """Read a WAV file and return it as a 1-D 16 kHz waveform tensor.

    Args:
        filename: Path of the WAV file to read.

    Returns:
        The single-channel waveform after resampling to 16 000 Hz.
    """
    decoded, native_rate = tf.audio.decode_wav(
        tf.io.read_file(filename),
        desired_channels=1,
    )
    # Only one channel was requested, so the trailing channel axis is dropped.
    mono = tf.squeeze(decoded, axis=-1)
    # The sample rate is cast to int64 before being handed to the resampler.
    return tfio.audio.resample(
        mono,
        rate_in=tf.cast(native_rate, dtype=tf.int64),
        rate_out=16000,
    )
# Class names; a name's position in this list is its integer label.
my_classes = ['Rov', 'ship', 'Mammal', 'Ambient']
class_to_idx = {name: position for position, name in enumerate(my_classes)}

# Every file listed in audio_files is decoded up front and kept in memory
# as a (waveform, label index) pair.
data = []
for label, file_path in audio_files.items():
    print(f"🔹 Loading {label} audio...")
    wav_data = load_wav_16k_mono(file_path)
    if wav_data is None:  # Skip entries that produced no waveform
        continue
    data.append((wav_data, class_to_idx[label]))
#print(f"✅ Loaded {len(data)} audio files before hitting RAM limit.")
# Python generator that feeds tf.data.Dataset.from_generator
def generator():
    """Yield each (waveform, label index) pair stored in ``data``, in order."""
    yield from ((waveform, class_idx) for waveform, class_idx in data)
# Wrap the generator in a tf.data pipeline; each element is a
# (variable-length float32 waveform, scalar int32 label) pair.
waveform_spec = tf.TensorSpec(shape=(None,), dtype=tf.float32)
label_spec = tf.TensorSpec(shape=(), dtype=tf.int32)
main_ds = tf.data.Dataset.from_generator(
    generator,
    output_signature=(waveform_spec, label_spec),
)
def extract_embedding(wav_data, label):
    """Run YAMNet on one waveform and pair each embedding with its label.

    Args:
        wav_data: Audio waveform tensor; it is flattened to 1-D first.
        label: Class label to attach to every embedding.

    Returns:
        A tuple ``(embeddings, labels)`` where ``labels`` repeats ``label``
        once per row of ``embeddings``.
    """
    flat_wav = tf.reshape(wav_data, [-1])
    # The model returns three outputs; only the embeddings are used here.
    _scores, embeddings, _spectrogram = yamnet_model(flat_wav)
    frame_count = tf.shape(embeddings)[0]
    labels = tf.repeat(label, frame_count)
    return embeddings, labels
# Extract embeddings using YAMNet
#def extract_embedding(wav_data, label):
"""Run YAMNet to extract embeddings from the audio waveform."""
# Ensure YAMNet gets a 1D tensor (No extra batch dimension)
#wav_data = tf.reshape(wav_data, [-1])
# Run YAMNet model
#scores, embeddings, spectrogram = yamnet_model(wav_data)
# Get number of embeddings
#num_embeddings = tf.shape(embeddings)[0]
#return embeddings, tf.repeat(label, num_embeddings)
# Map every clip to its YAMNet embeddings, then flatten the per-clip
# batches so each dataset element is a single (embedding, label) pair.
embedded_ds = main_ds.map(extract_embedding)
main_ds = embedded_ds.unbatch()
print("✅ Dataset prepared successfully.")
Relevant log output
mai_ds_file has found
tensorflow_hub Version
0.13.0.dev (unstable development build)
TensorFlow Version
2.8 (latest stable release)
Other libraries
!pip install tensorflow-io
!pip install scikit-learn
!pip install tensorflow==2.17.1
Python Version
3.x
OS
Linux