Skip to content

Instantly share code, notes, and snippets.

@MahmoudAshraf97
Last active February 26, 2026 20:17
Show Gist options
  • Select an option

  • Save MahmoudAshraf97/504eb60dd19ea352728665ae74a51d05 to your computer and use it in GitHub Desktop.

Select an option

Save MahmoudAshraf97/504eb60dd19ea352728665ae74a51d05 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
import argparse
import asyncio
import json
import time
import uuid
import wave
from websockets import connect, WebSocketException
import os
# Streaming configuration shared by every function in this script.
SAMPLE_RATE: int = 16_000  # frame rate (Hz) the input WAV file must have
CHANNELS: int = 1  # channel count the input WAV file must have
PORT: int = 4001  # port of the local WebSocket server
EXIT_TIMEOUT: int = 2  # seconds to wait after the last audio chunk before END_STREAM
# Event names sent by this client.
class ClientEvents:
    """String constants used as the ``"event"`` field of outgoing messages."""

    START_STREAM: str = "START_STREAM"
    END_STREAM: str = "END_STREAM"
class RecitationMode:
    """String constants used as the ``"recitationMode"`` config value."""

    DETECTION: str = "DETECTION"
    FOLLOW_ALONG: str = "FOLLOW_ALONG"
async def send_audio(ws, audio_file, start_stream_data, shared):
    """Stream a WAV file over the WebSocket at real-time speed.

    The file is validated, the START_STREAM event is sent as a JSON text
    message, and the raw frames are then sent as binary messages of 80 ms
    of audio each.

    Args:
        ws: Open WebSocket connection; only its ``send`` coroutine is used.
        audio_file: WAV file to stream, as accepted by ``wave.open``.
        start_stream_data: JSON-serialisable START_STREAM payload.
        shared: Dict shared with the receiver. ``"stream_start_time"`` is set
            to ``time.perf_counter()`` right before the first audio chunk.

    Raises:
        ValueError: If the file's sample rate or channel count differs from
            ``SAMPLE_RATE`` / ``CHANNELS``.
    """
    chunk_seconds = 0.08
    frames_per_chunk = int(chunk_seconds * SAMPLE_RATE)

    with wave.open(audio_file, "rb") as wf:
        # Validate WAV properties before anything is sent.
        if wf.getframerate() != SAMPLE_RATE:
            raise ValueError(f"Invalid sample rate: {wf.getframerate()}, expected {SAMPLE_RATE}")
        if wf.getnchannels() != CHANNELS:
            raise ValueError(f"Invalid channel count: {wf.getnchannels()}, expected {CHANNELS}")

        await ws.send(json.dumps(start_stream_data))

        stream_start = time.perf_counter()
        shared["stream_start_time"] = stream_start
        next_deadline = stream_start
        while chunk := wf.readframes(frames_per_chunk):
            await ws.send(chunk)
            # Pace against an absolute schedule instead of sleeping a fixed
            # amount after each send: the time spent inside ``ws.send`` would
            # otherwise accumulate and make the stream slower than real time,
            # which skews latencies measured from ``stream_start_time``.
            next_deadline += chunk_seconds
            await asyncio.sleep(max(0.0, next_deadline - time.perf_counter()))
def _position_key(pos):
    """Format a position mapping as the string ``"surah:ayah:word"``."""
    return f"{pos['surahNumber']}:{pos['ayahNumber']}:{pos['wordNumber']}"


async def receive_messages(ws, shared):
    """Consume server messages until the connection ends and collect results.

    Only ``STATES_UPDATE`` events are processed; every other event is
    ignored. Messages that are not JSON objects are printed with an
    ``[UNKNOWN]`` prefix and skipped.

    Args:
        ws: WebSocket connection, iterated with ``async for``.
        shared: Dict shared with the sender; ``"stream_start_time"`` (a
            ``time.perf_counter()`` value) must be set before the first
            position arrives.

    Returns:
        A 3-tuple ``(mistakes, first_identified_position, absolute_latencies)``:
        the latest non-null mistake per mistake id, the first position seen
        as ``"surah:ayah:word"`` (or ``None``), and a dict mapping each
        distinct position to the milliseconds elapsed since
        ``shared["stream_start_time"]`` when it was first reported.
    """
    mistakes = {}
    first_identified_position = None
    absolute_latencies = {}
    try:
        async for message in ws:
            try:
                result = json.loads(message)
            except ValueError:
                # Covers json.JSONDecodeError and, for binary frames,
                # UnicodeDecodeError (both are ValueError subclasses).
                print(f"[UNKNOWN] {message}")
                continue
            if not isinstance(result, dict):
                # Valid JSON, but not an event object.
                print(f"[UNKNOWN] {message}")
                continue
            if result.get("event") != "STATES_UPDATE":
                continue
            data = result.get("data")
            if not isinstance(data, dict):
                continue

            # ``or`` guards against keys that are present but null.
            mistake_updates = data.get("mistakeUpdates") or {}
            new_states = data.get("newStates") or []

            # A later null update for an id overwrites the earlier mistake;
            # null entries are dropped from the result at the end.
            mistakes.update(mistake_updates)
            for m in mistake_updates.values():
                if m is not None:
                    print(
                        f"[{m['mistakeType']}] E: {m['expectedTranscript']} / R: {m['receivedTranscript']}"
                        f" / {m['startTimeMs']}-{m['endTimeMs']}"
                    )

            for state in new_states:
                pos = state.get("position")
                if not pos:
                    continue
                position = _position_key(pos)
                if first_identified_position is None:
                    first_identified_position = position
                # Record latency only the first time a position is reported.
                if position not in absolute_latencies:
                    absolute_latencies[position] = (
                        time.perf_counter() - shared["stream_start_time"]
                    ) * 1000
    except WebSocketException as e:
        print(f"WebSocket error: {e}")
    return (
        {k: v for k, v in mistakes.items() if v is not None},
        first_identified_position,
        absolute_latencies,
    )
async def inactivity_watch(ws, end_stream_data):
    """Pause for ``EXIT_TIMEOUT`` seconds, send END_STREAM, then close ``ws``.

    Despite the name this does not monitor traffic: it is a fixed delay
    followed by the END_STREAM event (serialised as JSON) and a close.

    Args:
        ws: Open WebSocket connection exposing ``send`` and ``close``.
        end_stream_data: JSON-serialisable END_STREAM payload.
    """
    await asyncio.sleep(EXIT_TIMEOUT)
    end_stream_message = json.dumps(end_stream_data)
    await ws.send(end_stream_message)
    await ws.close()
async def evaluate(audio_file: str, verbose: bool = True, new_stt_server: bool = False):
    """Stream one WAV file to the local server and measure position latencies.

    Opens a WebSocket to ``ws://localhost:PORT``, streams the file with
    ``send_audio`` while ``receive_messages`` collects server events, then
    ends the stream via ``inactivity_watch``.

    Args:
        audio_file: Path to the WAV file to stream.
        verbose: Accepted for backward compatibility. All verbose output is
            currently disabled, so the flag has no effect.
        new_stt_server: Sent as ``isNewSttServer``; also selects a
            ``mistakeReportingTimeLag`` of 0 (new) or 800 (old).

    Returns:
        Dict mapping each ``"surah:ayah:word"`` position to the milliseconds
        between the start of audio streaming and its first report.
    """
    client_config = {
        "appVersion": "dev",
        "audioConfig": {
            "fileFormat": "WAV",
            "channels": CHANNELS,
            "sampleRate": SAMPLE_RATE,
            "modelName": None,
        },
        # NOTE(review): credential hard-coded in source; consider loading it
        # from the environment and rotating this value.
        "authToken": "4b70b75dc4d77118cd63adb3acbbc5d7eeca65bb",
        "deviceId": "123",
        "devicePlatform": "WEB",
        "recitationMode": RecitationMode.DETECTION,
        "sessionId": str(uuid.uuid4()),
        "isDiacritized": True,
        "isMemorization": False,
        "shouldCollectAudio": False,
        "shouldLabelAudio": False,
        "mistakeReportingTimeLag": 0 if new_stt_server else 800,
        "isNewSttServer": new_stt_server,
        "isDualModel": True,
    }
    start_stream_data = {"event": ClientEvents.START_STREAM, "data": client_config}
    end_stream_data = {"event": ClientEvents.END_STREAM, "data": {}}
    # Alternative remote endpoint: "wss://voice-v2-dev.tarteel.io"
    uri = f"ws://localhost:{PORT}"
    shared = {}

    async with connect(uri, close_timeout=200, open_timeout=200) as ws:
        # Sender and receiver run concurrently on the same connection.
        send_task = asyncio.create_task(send_audio(ws, audio_file, start_stream_data, shared))
        recv_task = asyncio.create_task(receive_messages(ws, shared))
        try:
            await send_task
            # Once all audio is sent, wait, send END_STREAM and close.
            await inactivity_watch(ws, end_stream_data)
            # The receiver returns once the connection has ended.
            _mistakes, _first_position, absolute_latencies = await recv_task
        finally:
            # If sending or closing failed, do not leave a task running in
            # the background on a connection that is being torn down.
            unfinished = [task for task in (send_task, recv_task) if not task.done()]
            for task in unfinished:
                task.cancel()
            if unfinished:
                await asyncio.gather(*unfinished, return_exceptions=True)

    return absolute_latencies
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Async WAV-to-WebSocket streamer")
    parser.add_argument("file", help="Path to WAV audio file (16kHz, mono)")
    args = parser.parse_args()

    async def main():
        """Run the old/new comparison repeatedly and print latency deltas.

        Each run evaluates the same file concurrently with
        ``new_stt_server`` off and on, writes both latency maps to JSON files
        named after the process id and run index, and prints the mean
        per-position delta (new - old). A final line prints the mean of the
        per-run means.
        """
        run_averages = []
        num_runs = 10
        pid = os.getpid()
        for i in range(num_runs):
            old_latencies, new_latencies = await asyncio.gather(
                evaluate(args.file, new_stt_server=False),
                evaluate(args.file, new_stt_server=True),
            )
            with open(f"old_latencies_{pid}_{i}.json", "w") as f:
                json.dump(old_latencies, f, indent=2)
            with open(f"new_latencies_{pid}_{i}.json", "w") as f:
                json.dump(new_latencies, f, indent=2)

            # Compare only positions reported in both runs.
            deltas = []
            for k in new_latencies:
                if k not in old_latencies:
                    print(f"Key {k} missing in old latencies")
                    continue
                deltas.append(new_latencies[k] - old_latencies[k])
            if deltas:
                run_avg = sum(deltas) / len(deltas)
                print(f"\nAvg latency delta (new - old): {run_avg:.1f}ms ({len(deltas)} positions)")
                run_averages.append(run_avg)

        if run_averages:
            # Report how many runs were actually averaged; runs without any
            # common position contribute nothing.
            print(
                f"\nOverall avg latency delta across {len(run_averages)} runs: "
                f"{sum(run_averages) / len(run_averages):.1f}ms"
            )

    try:
        asyncio.run(main())
    except Exception as e:
        print(f"Error: {e}")
        # ``exit`` is injected by the ``site`` module and may be absent;
        # SystemExit is always available.
        raise SystemExit(1) from e
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment