Last active
February 26, 2026 20:17
-
-
Save MahmoudAshraf97/504eb60dd19ea352728665ae74a51d05 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3
import argparse
import asyncio
import json
import os
import sys
import time
import uuid
import wave

from websockets import connect, WebSocketException
# Configuration constants
# Audio format required of the input WAV file; send_audio() rejects anything else.
SAMPLE_RATE = 16000  # frames per second (Hz)
CHANNELS = 1  # mono
PORT = 4001  # WebSocket server port
# Seconds to wait after the last audio chunk before END_STREAM is sent and the
# socket is closed (a fixed delay, not an inactivity detector).
EXIT_TIMEOUT = 2
# Event and mode enums
class ClientEvents:
    """Event names this client places in the ``event`` field of control messages."""

    START_STREAM = "START_STREAM"
    END_STREAM = "END_STREAM"
class RecitationMode:
    """Values for the ``recitationMode`` field of the client configuration."""

    DETECTION = "DETECTION"
    FOLLOW_ALONG = "FOLLOW_ALONG"
async def send_audio(ws, audio_file, start_stream_data, shared):
    """Stream a WAV file over the WebSocket at real-time pace.

    Sends ``start_stream_data`` as a JSON text message, then the audio frames
    as binary messages of 80 ms each.

    Args:
        ws: Open WebSocket connection exposing an awaitable ``send``.
        audio_file: Path to the WAV file to stream.
        start_stream_data: JSON-serialisable START_STREAM message.
        shared: Dict shared with the receiver; ``shared["stream_start_time"]``
            is set to the ``time.perf_counter()`` value taken just before the
            first audio chunk is sent.

    Raises:
        ValueError: If the file's sample rate or channel count does not match
            ``SAMPLE_RATE`` / ``CHANNELS``.
    """
    chunk_seconds = 0.08
    frames_per_chunk = int(chunk_seconds * SAMPLE_RATE)

    with wave.open(audio_file, "rb") as wf:
        # Validate WAV properties before anything is sent to the server.
        if wf.getframerate() != SAMPLE_RATE:
            raise ValueError(f"Invalid sample rate: {wf.getframerate()}, expected {SAMPLE_RATE}")
        if wf.getnchannels() != CHANNELS:
            raise ValueError(f"Invalid channel count: {wf.getnchannels()}, expected {CHANNELS}")

        await ws.send(json.dumps(start_stream_data))

        stream_start = time.perf_counter()
        shared["stream_start_time"] = stream_start
        chunks_sent = 0
        while chunk := wf.readframes(frames_per_chunk):
            await ws.send(chunk)
            chunks_sent += 1
            # Pace against an absolute deadline rather than sleeping a full
            # chunk after every send: otherwise send time and sleep overshoot
            # accumulate, the stream falls behind real time, and the latencies
            # measured from stream_start_time are inflated.
            deadline = stream_start + chunks_sent * chunk_seconds
            await asyncio.sleep(max(0.0, deadline - time.perf_counter()))
def _format_position(pos):
    """Render a position mapping as ``surah:ayah:word``."""
    return f"{pos['surahNumber']}:{pos['ayahNumber']}:{pos['wordNumber']}"


async def receive_messages(ws, shared):
    """Consume server messages until the connection ends and collect results.

    Only ``STATES_UPDATE`` events are processed; every other event is ignored.
    Each non-null mistake update is printed as it arrives.

    Args:
        ws: WebSocket connection, iterated with ``async for``.
        shared: Dict shared with the sender; ``shared["stream_start_time"]``
            must be set before the first state with a position arrives.

    Returns:
        A 3-tuple ``(mistakes, first_identified_position, absolute_latencies)``:
        the latest non-null update per mistake id, the first reported position
        as ``"surah:ayah:word"`` (or ``None``), and a dict mapping each
        position to the milliseconds elapsed since the stream start when it
        was first seen.
    """
    mistakes = {}
    first_identified_position = None
    absolute_latencies = {}
    try:
        async for message in ws:
            try:
                result = json.loads(message)
            except (json.JSONDecodeError, UnicodeDecodeError):
                print(f"[UNKNOWN] {message}")
                continue
            if not isinstance(result, dict):
                # Valid JSON, but not an event object.
                print(f"[UNKNOWN] {message}")
                continue
            if result.get("event") != "STATES_UPDATE":
                continue

            # ``or`` also covers keys that are present with a JSON null,
            # which a ``.get(key, default)`` would pass through as None.
            data = result.get("data") or {}
            mistake_updates = data.get("mistakeUpdates") or {}
            new_states = data.get("newStates") or []

            # A null update is stored too, so a later retraction overrides an
            # earlier mistake; nulls are filtered out of the final result.
            mistakes.update(mistake_updates)
            for m in mistake_updates.values():
                if m is not None:
                    print(
                        f"[{m['mistakeType']}] E: {m['expectedTranscript']} / R: {m['receivedTranscript']}"
                        f" / {m['startTimeMs']}-{m['endTimeMs']}"
                    )

            for state in new_states:
                pos = state.get("position")
                if not pos:
                    continue
                position = _format_position(pos)
                if first_identified_position is None:
                    first_identified_position = position
                # Record latency only the first time a position is reported.
                if position not in absolute_latencies:
                    absolute_latencies[position] = (
                        time.perf_counter() - shared["stream_start_time"]
                    ) * 1000
    except WebSocketException as e:
        print(f"WebSocket error: {e}")
    return (
        {k: v for k, v in mistakes.items() if v is not None},
        first_identified_position,
        absolute_latencies,
    )
async def inactivity_watch(ws, end_stream_data):
    """Wait ``EXIT_TIMEOUT`` seconds, send END_STREAM, then close the socket.

    Despite the name, this is a fixed delay rather than an inactivity
    detector.

    Args:
        ws: Open WebSocket connection.
        end_stream_data: JSON-serialisable END_STREAM message.
    """
    await asyncio.sleep(EXIT_TIMEOUT)
    try:
        await ws.send(json.dumps(end_stream_data))
    except WebSocketException as e:
        # The server may already have closed the connection; report it the
        # same way receive_messages() does instead of aborting the run.
        print(f"WebSocket error: {e}")
    finally:
        await ws.close()
async def evaluate(audio_file: str, verbose: bool = True, new_stt_server: bool = False):
    """Stream one WAV file to the local server and measure position latencies.

    Opens a WebSocket to ``ws://localhost:PORT``, streams the audio with
    ``send_audio`` while ``receive_messages`` collects responses, then ends
    the stream via ``inactivity_watch``.

    Args:
        audio_file: Path to the WAV file to stream.
        verbose: Accepted for call compatibility; per-run detail output is
            currently disabled, so this flag has no effect.
        new_stt_server: Sent as ``isNewSttServer``; also selects a
            ``mistakeReportingTimeLag`` of 0 (True) or 800 (False).

    Returns:
        Dict mapping each ``"surah:ayah:word"`` position to the milliseconds
        between the start of audio streaming and the position's first report.
    """
    client_config = {
        "appVersion": "dev",
        "audioConfig": {
            "fileFormat": "WAV",
            "channels": CHANNELS,
            "sampleRate": SAMPLE_RATE,
            "modelName": None,
        },
        # NOTE(review): credential hard-coded in source; consider loading it
        # from the environment or a secrets store and rotating this value.
        "authToken": "4b70b75dc4d77118cd63adb3acbbc5d7eeca65bb",
        "deviceId": "123",
        "devicePlatform": "WEB",
        "recitationMode": RecitationMode.DETECTION,
        "sessionId": str(uuid.uuid4()),
        "isDiacritized": True,
        "isMemorization": False,
        "shouldCollectAudio": False,
        "shouldLabelAudio": False,
        "mistakeReportingTimeLag": 0 if new_stt_server else 800,
        "isNewSttServer": new_stt_server,
        "isDualModel": True,
    }
    start_stream_data = {"event": ClientEvents.START_STREAM, "data": client_config}
    end_stream_data = {"event": ClientEvents.END_STREAM, "data": {}}

    # Alternative remote endpoint: wss://voice-v2-dev.tarteel.io
    uri = f"ws://localhost:{PORT}"

    shared = {}
    async with connect(uri, close_timeout=200, open_timeout=200) as ws:
        # Send and receive concurrently so responses are timestamped as they arrive.
        send_task = asyncio.create_task(send_audio(ws, audio_file, start_stream_data, shared))
        recv_task = asyncio.create_task(receive_messages(ws, shared))
        try:
            await send_task
            # Give the server EXIT_TIMEOUT seconds, then end the stream and close.
            await inactivity_watch(ws, end_stream_data)
            # The receiver finishes once the connection is closed.
            _mistakes, _first_position, absolute_latencies = await recv_task
        finally:
            # If anything above failed, do not leave a task running in the background.
            for task in (send_task, recv_task):
                if not task.done():
                    task.cancel()

    return absolute_latencies
if __name__ == "__main__":
    # Benchmark driver: for each run, evaluate the old and new STT server
    # configurations concurrently on the same file, dump both latency maps to
    # JSON, and report the average per-position latency difference.
    parser = argparse.ArgumentParser(description="Async WAV-to-WebSocket streamer")
    parser.add_argument("file", help="Path to WAV audio file (16kHz, mono)")
    parser.add_argument(
        "--runs",
        type=int,
        default=10,
        help="Number of paired old/new benchmark runs (default: 10)",
    )
    args = parser.parse_args()

    async def main():
        """Run the paired benchmark ``args.runs`` times and print the deltas."""
        run_averages = []
        for run_index in range(args.runs):
            old_latencies, new_latencies = await asyncio.gather(
                evaluate(args.file, new_stt_server=False),
                evaluate(args.file, new_stt_server=True),
            )

            # The PID keeps output files from concurrent invocations apart.
            with open(f"old_latencies_{os.getpid()}_{run_index}.json", "w", encoding="utf-8") as f:
                json.dump(old_latencies, f, indent=2)
            with open(f"new_latencies_{os.getpid()}_{run_index}.json", "w", encoding="utf-8") as f:
                json.dump(new_latencies, f, indent=2)

            # Compare only positions reported by both configurations.
            deltas = []
            for position, new_latency in new_latencies.items():
                if position not in old_latencies:
                    print(f"Key {position} missing in old latencies")
                    continue
                deltas.append(new_latency - old_latencies[position])

            if deltas:
                run_average = sum(deltas) / len(deltas)
                print(f"\nAvg latency delta (new - old): {run_average:.1f}ms ({len(deltas)} positions)")
                run_averages.append(run_average)

        if run_averages:
            # Report the number of runs that actually contributed an average;
            # runs without any common position are skipped above.
            overall = sum(run_averages) / len(run_averages)
            print(f"\nOverall avg latency delta across {len(run_averages)} runs: {overall:.1f}ms")

    try:
        asyncio.run(main())
    except Exception as e:
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment