Wan 8bit Compile crash
import os
import torch
import gc
from diffusers import WanImageToVideoPipeline, WanTransformer3DModel
from diffusers.utils import export_to_video, load_image
from transformers import BitsAndBytesConfig
from PIL import Image
# --- FIX FOR COMPILE CRASHES ON WINDOWS/BNB ---
# This prevents the "UserDefinedObjectVariable" error by falling back
# to eager execution for layers the compiler doesn't understand.
import torch._dynamo
torch._dynamo.config.suppress_errors = True
# ----------------------------------------------
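# (Optional, assumption) If the two experts trigger frequent recompiles, raising the
# dynamo cache limit can help; uncomment to try it.
# torch._dynamo.config.cache_size_limit = 64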
# ==========================================
# 1. USER SETTINGS
# ==========================================
PROMPT = "A cinematic close-up of a cat wearing a black fedora hat, looking around, subtle movement, high quality, 4k, indoors"
NEGATIVE_PROMPT = "low quality, bad hands, distorted, blur, motion artifacts"
# INPUT_IMAGE_PATH = "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/wan_i2v_input.JPG"
INPUT_IMAGE_PATH = r"C:\Users\peter\AppData\Roaming\Blender Foundation\Blender\4.5\datafiles\Pallaidium_Media\2025-12-09\1772654657_a_cat_with_a_hat.png"
OUTPUT_PATH = "C:/Tmp/wan_result_8bit.mp4"
# ==========================================
# 2. SETUP & CLEANUP
# ==========================================
print("--- Initializing ---")
gc.collect()
torch.cuda.empty_cache()
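# (Assumption) Make sure the output folder exists so export_to_video() can write there.
os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True)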
MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
# 8-Bit Configuration
# Note: 8-bit uses more VRAM (~15GB for weights) than 4-bit.
int8_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_enable_fp32_cpu_offload=True,  # Helps prevent conflicts during offload
)
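# (Alternative, not used below) A 4-bit NF4 config needs roughly half the weight VRAM
# of 8-bit; this is a hedged sketch using the same BitsAndBytesConfig class as above.
# int4_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.bfloat16,
# )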
# ==========================================
# 3. LOAD MODELS (8-BIT QUANTIZED)
# ==========================================
print("--- Loading Models (8-Bit) ---")
# Load Transformer 1 (High Noise)
transformer_high = WanTransformer3DModel.from_pretrained(
    MODEL_ID,
    subfolder="transformer",
    quantization_config=int8_config,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
)
# Load Transformer 2 (Low Noise)
transformer_low = WanTransformer3DModel.from_pretrained(
    MODEL_ID,
    subfolder="transformer_2",
    quantization_config=int8_config,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
)
# Create Pipeline
pipe = WanImageToVideoPipeline.from_pretrained(
    MODEL_ID,
    transformer=transformer_high,
    transformer_2=transformer_low,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
)
# Enable CPU Offload (Crucial for 8-bit 14B models on 24GB cards)
pipe.enable_model_cpu_offload()
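# (Assumption) If model offload still runs out of VRAM on smaller cards, sequential
# offload is a slower but much lighter fallback; uncomment it instead of
# enable_model_cpu_offload() above.
# pipe.enable_sequential_cpu_offload()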
# --- COMPILE ENABLED ---
print("--- Compiling Model (This may take a few minutes) ---")
# fullgraph=True is aggressive; if compilation still errors, drop it and let
# dynamo fall back to partial graphs.
try:
    pipe.transformer.compile_repeated_blocks(fullgraph=True)
except Exception as e:
    print(f"Warning: Compilation failed ({e}). Proceeding with uncompiled model.")
# ==========================================
# 4. LOAD LORA (4-Step Turbo)
# ==========================================
print("--- Loading LoRAs ---")
try:
    pipe.load_lora_weights(
        "Kijai/WanVideo_comfy",
        weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
        adapter_name="lightx2v",
    )
    pipe.load_lora_weights(
        "Kijai/WanVideo_comfy",
        weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
        adapter_name="lightx2v_2",
        load_into_transformer_2=True,
    )
    pipe.set_adapters(["lightx2v", "lightx2v_2"], adapter_weights=[1.0, 1.0])
    print("LoRAs loaded successfully.")
except Exception as e:
    print(f"LoRA Load Failed: {e}")
# ==========================================
# 5. PREPARE IMAGE
# ==========================================
print("--- Preparing Image ---")
if INPUT_IMAGE_PATH.startswith("http"):
    input_image = load_image(INPUT_IMAGE_PATH)
else:
    input_image = Image.open(INPUT_IMAGE_PATH)
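# (Assumption) Normalize to RGB in case the local PNG carries an alpha or palette mode.
input_image = input_image.convert("RGB")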
def resize_for_wan(image, max_dim=832):
    # Scale so the longest side is max_dim, then round both sides down to
    # multiples of 16 so the dimensions are valid for the Wan pipeline.
    w, h = image.size
    scale = max_dim / max(w, h)
    new_w = int(w * scale)
    new_h = int(h * scale)
    new_w = (new_w // 16) * 16
    new_h = (new_h // 16) * 16
    return image.resize((new_w, new_h), Image.LANCZOS)
input_image = resize_for_wan(input_image)
# ==========================================
# 6. GENERATE
# ==========================================
print("--- Generating Video ---")
gc.collect()
torch.cuda.empty_cache()
seed = 42
generator = torch.Generator(device="cuda").manual_seed(seed)
output_frames = pipe(
    image=input_image,
    prompt=PROMPT,
    negative_prompt=NEGATIVE_PROMPT,
    height=input_image.height,
    width=input_image.width,
    num_frames=81,
    num_inference_steps=8,
    guidance_scale=1.0,
    guidance_scale_2=1.0,
    generator=generator,
).frames[0]
# ==========================================
# 7. SAVE RESULT
# ==========================================
export_to_video(output_frames, OUTPUT_PATH, fps=16)
print(f"DONE! Video saved to: {OUTPUT_PATH}")