#!/usr/bin/env python
"""HTTP prediction server for a pickled audio classification model.

Exposes two Flask routes: ``GET /ping`` (health check) and
``POST /invocations`` (classify the WAV bytes sent in the request body).
The model is loaded once at module level, then the server is started.
"""
import json
import logging
import os
import time
import uuid
from argparse import ArgumentParser
from pprint import pformat

import flask
import librosa
import numpy as np
import pandas as pd
import soundfile as sf
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from config import ParameterSetting
from dataset import SoundDataset_test
from metrics import cfm, classification_report, roc_auc
from models import VGGish

logger = logging.getLogger(__file__)
app = flask.Flask("predict-server")


def run_predict_server():
    """Start the Flask server on all interfaces, port 8080 (blocking)."""
    app.run(host='0.0.0.0', port=8080)


def wav_read(wav_file):
    """Read an audio file as int16 samples.

    Args:
        wav_file: Path (or file-like object) accepted by ``soundfile.read``.

    Return:
        A ``(wav_data, sr)`` tuple: the samples as an int16 np.array and the
        sampling rate reported by the file.
    """
    wav_data, sr = sf.read(wav_file, dtype='int16')
    return wav_data, sr


def inference_single_audio(path):
    """Load the audio file at ``path``; thin wrapper around ``wav_read``.

    Return:
        The same ``(wav_data, sr)`` tuple as ``wav_read``.
    """
    wav_data, sr = wav_read(path)
    return wav_data, sr


def preprocessing(params, wav_data, sr, n_frames=500):
    """Convert wav_data to log mel spectrogram.

    1. normalize the wav_data
    2. convert the wav_data into mono-channel
    3. resample the wav_data to the sampling rate we want
    4. compute the log mel spectrogram with librosa function
    5. zero-pad or truncate along the time axis to ``n_frames`` frames

    Args:
        params: Settings object providing ``normalize``, ``sr``, ``nfft``,
            ``hop`` and ``mel``.
        wav_data: An np.array indicating wav data in np.int16 datatype
        sr: An integer specifying the sampling rate of this wav data
        n_frames: Number of spectrogram frames in the output (default 500,
            the value that used to be hard-coded).

    Return:
        inpt: A float32 np.array of shape ``(1, n_frames, params.mel)``
            holding the log mel spectrogram of data
    """
    # normalize wav_data
    if params.normalize == 'peak':
        # NOTE(review): divides by the signed maximum, not max(|x|); kept
        # as-is so features stay identical -- confirm against training code.
        samples = wav_data / np.max(wav_data)
    elif params.normalize == 'rms':
        rms_level = 0
        r = 10 ** (rms_level / 10.0)
        # Square in float64: squaring int16 samples directly wraps around
        # and corrupts the energy estimate.
        energy = np.sum(np.asarray(wav_data, dtype=np.float64) ** 2)
        a = np.sqrt((len(wav_data) * r ** 2) / energy)
        samples = wav_data * a
    else:
        # Scale int16 samples into [-1.0, 1.0).
        samples = wav_data / 32768.0

    # convert samples to mono-channel by averaging over the channel axis
    if len(samples.shape) > 1:
        samples = np.mean(samples, axis=1)

    # resample samples to the target rate
    if sr != params.sr:
        # The original referenced `resampy`, which this file never imports
        # (NameError at runtime); librosa is already imported and provides
        # resampling.
        # NOTE(review): librosa's default resampling filter is used here --
        # confirm it matches the filter used when the model was trained.
        samples = librosa.resample(samples, orig_sr=sr, target_sr=params.sr)

    # transform samples to mel spectrogram; audio must be passed as `y=`
    # because recent librosa releases make it keyword-only.
    spec = librosa.feature.melspectrogram(
        y=samples,
        sr=params.sr,
        n_fft=params.nfft,
        hop_length=params.hop,
        n_mels=params.mel,
    )
    spec_db = librosa.power_to_db(spec).T  # -> (frames, mel)

    # Force a fixed number of frames: zero-pad short clips, cut long ones.
    missing = n_frames - spec_db.shape[0]
    if missing > 0:
        padding = np.zeros((missing, params.mel))
        spec_db = np.concatenate((spec_db, padding), axis=0)
    else:
        spec_db = spec_db[:n_frames]

    inpt = np.reshape(spec_db, (1, spec_db.shape[0], spec_db.shape[1]))
    return inpt.astype('float32')


def load_model():
    """Load the pickled model from ``/opt/ml/model/final_model.pkl``.

    Reads the module-level ``device``. The model is put in eval mode and
    moved to ``device``.

    Return:
        The loaded model, ready for inference.
    """
    prefix = '/opt/ml/'
    model_name = "final_model.pkl"
    model_path = os.path.join(prefix, 'model', model_name)
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    # NOTE(review): this unpickles a whole module object; recent PyTorch
    # releases default torch.load to weights_only=True, which rejects that --
    # confirm the pinned torch version. Only load trusted model files.
    model = torch.load(model_path, map_location=device)
    model.eval()
    model = model.to(device)
    return model


@app.route('/ping', methods=['GET'])
def ping():
    """Health check: always answers 200 with a newline body.

    No per-request model check is performed; the model is loaded at module
    level before the server starts, so a failed load prevents start-up.
    """
    status = 200
    return flask.Response(response='\n', status=status,
                          mimetype='application/json')


@app.route('/invocations', methods=['POST'])
def predict():
    """Classify the audio bytes sent in the request body.

    The body is written to a temporary WAV file, decoded, converted to a
    log mel spectrogram and passed through the model.

    Return:
        A 200 response whose JSON body has ``label`` (index of the most
        probable class) and ``probability`` (softmax scores per class).
    """
    data = flask.request.data
    path = '/tmp/{}.wav'.format(uuid.uuid4())
    try:
        with open(path, 'wb') as tmp_file:
            tmp_file.write(data)
        wav, sr = inference_single_audio(path)
    finally:
        # Always remove the temp file so requests do not fill up /tmp.
        try:
            os.remove(path)
        except FileNotFoundError:
            pass  # the file was never created (open() itself failed)

    features = preprocessing(params, wav, sr)
    # Add the leading batch axis: (1, frames, mel) -> (1, 1, frames, mel).
    spec = torch.from_numpy(features).unsqueeze(0).to(device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = model(spec)
        outputs = torch.nn.functional.softmax(outputs, dim=1)
        _, preds = torch.max(outputs, 1)

    pred_label = preds.cpu().numpy()
    outputs = outputs.cpu().numpy()
    print(pred_label, outputs)

    results = {
        'label': int(pred_label[0]),
        'probability': outputs[0].tolist(),
    }
    return flask.Response(response=json.dumps(results), status=200,
                          mimetype='text/json')


# Module-level start-up (order matters: load_model reads `device`, and the
# request handlers read `params`, `device` and `model`).
params = ParameterSetting(sr=8000, nfft=200, hop=80, mel=64, normalize=None,
                          preload=False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = load_model()
run_predict_server()