Speeding up DALL-E Mega inference with parallel null prompt evaluation
Created June 14, 2022 02:57 · Gist drdaxxy/deaeddf9672aa76b72752c3719d5c370
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Speeding up DALL-E Mega inference with parallel null prompt evaluation\n", | |
| "\n", | |
| "* `DalleBart.generate` currently *(June 13, 2022, `borisdayma/dalle-mini` revision `a72705f`)* runs `model.decode` twice per position, sequentially, when applying \"super conditioning\".\n", | |
| "* **Parallelizing this with `jax.vmap` dramatically speeds up inference in my tests on an RTX 3090.**\n", | |
| "* At small batch sizes there's a negligible increase in VRAM use; for larger batches it actually saves some memory.\n", | |
| "* `mini-1:v0` on 8 TPUv2 cores in Colab Free also gets significant speedups, as does `mega-1-fp16` with batch size 1 (larger ones run out of memory). **If you have easy access to better TPUs, I'd be grateful if you tested this.**\n", | |
| "\n", | |
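The core trick can be sketched in miniature: instead of evaluating the conditioned and the null-prompt batch in two sequential calls, stack them along a new leading axis and `jax.vmap` a single call over it. A toy sketch, where `decode_step` is a hypothetical stand-in rather than the real `DalleBart.decode`:

```python
import jax
import jax.numpy as jnp

# Stand-in for one decoder step: (batch, d) -> (batch, d) logits.
# Hypothetical toy function, not the real DalleBart.decode.
def decode_step(tokens, weights):
    return jnp.tanh(tokens @ weights)

weights = 0.1 * jnp.ones((4, 4))
cond = jnp.ones((2, 4))     # conditioned inputs
uncond = jnp.zeros((2, 4))  # null-prompt inputs

# Serial: two separate evaluations per position.
serial = jnp.stack([decode_step(cond, weights), decode_step(uncond, weights)])

# Parallel: stack on a new leading axis and vmap once over it;
# in_axes=None shares the weights instead of duplicating them.
both = jnp.stack([cond, uncond])
parallel = jax.vmap(decode_step, in_axes=(0, None))(both, weights)

assert jnp.allclose(serial, parallel)
```

On an accelerator the vmapped call can dispatch as one larger batched operation instead of two kernel launches per position, which is where the speedup comes from.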
| "## ⏱️ Performance stats\n", | |
| "\n", | |
| "My results of running this notebook with a few different settings:\n", | |
| "\n", | |
| "* In each case, `n_predictions` was 8.\n", | |
| "* I restarted the notebook between experiments, commented out the definition and warmup of the unused `p_generate_fn` version, and ran the cells in order.\n", | |
| "* Local test runs performed on a *Windows 11* system (native, not WSL) with an RTX 3090, Ryzen 7 3700X (on PCIe 3.0), 32 GB host RAM, driver 512.95, CUDA 11.3.1, cuDNN 8.4.0, Flax 0.5.0, JAX/jaxlib installed from source (commit `d43cb36dae7e2f4cf734de29431cc371a5efeac5`).\n", | |
| "* Process GPU memory measured with the highly sophisticated method of watching the *Dedicated GPU memory* process details column in Windows Task Manager during the loops 🙂.\n", | |
| "\n", | |
| "### half precision, mega-1-fp16:v14\n", | |
| "\n", | |
| "| Experiment | cell run, n_predictions=8 | Time per image | | Process VRAM usage | |\n", | |
| "|---------------------------:|-------------------------------:|---------------:|------------:|-------------------:|-------:|\n", | |
| "| 2 prompts/batch, serial | `8/8 [02:28<00:00, 18.58s/it]` | 9.25 s | | 7,118,364k | |\n", | |
| "| 2 prompts/batch, parallel | `8/8 [01:25<00:00, 10.64s/it]` | 5.31 s | **-42.60%** | 7,136,796k | +0.26% |\n", | |
| "| 16 prompts/batch, serial | `8/8 [03:33<00:00, 26.68s/it]` | 1.66 s | | 10,222,268k | |\n", | |
| "| 16 prompts/batch, parallel | `8/8 [02:29<00:00, 18.70s/it]` | 1.16 s | **-30.12%** | 9,935,800k | -2.80% |\n", | |
| "\n", | |
| "### single precision, mega-1:v16\n", | |
| "\n", | |
| "| Experiment | cell run, n_predictions=8 | Time per image | | Process VRAM usage | |\n", | |
| "|---------------------------:|-------------------------------:|---------------:|------------:|-------------------:|-------:|\n", | |
| "| 2 prompts/batch, serial | `8/8 [03:17<00:00, 24.60s/it]` | 12.31 s | | 12,628,312k | |\n", | |
| "| 2 prompts/batch, parallel | `8/8 [01:52<00:00, 13.95s/it]` | 7.00 s | **-43.14%** | 12,661,012k | +0.26% |\n", | |
| "| 16 prompts/batch, serial | `8/8 [04:59<00:00, 37.51s/it]` | 2.34 s | | 18,835,972k | |\n", | |
| "| 16 prompts/batch, parallel | `8/8 [03:48<00:00, 28.64s/it]` | 1.78 s | **-23.93%** | 18,263,812k | -3.04% |\n", | |
| "\n", | |
| "### mini-1:v0 on Colab TPUv2-8\n", | |
| "\n", | |
| "*Cell runs* here are the output of `%%timeit run_pred_no_decode(...)`. I preemptively skipped loading the VQGAN, since I had run out of memory trying to load `mega-1-fp16` before.\n", | |
| "\n", | |
| "(I loaded `mini-1:v0` with `float16` dtype, not `bfloat16`; I haven't checked whether that's correct, or whether the weights should be converted, etc.)\n", | |
| "\n", | |
| "| Experiment | cell run, n_predictions=8 | Time per image | |\n", | |
| "|-------------------------------:|--------------------------:|---------------:|------------:|\n", | |
| "| 2 prompts * 8 cores, serial | 25.6 s | 1.6 s | |\n", | |
| "| 2 prompts * 8 cores, parallel | 9.27 s | 0.579 s | **-63.81%** |\n", | |
| "| 4 prompts * 8 cores, serial | 26.4 s | 0.825 s | |\n", | |
| "| 4 prompts * 8 cores, parallel | 13.9 s | 0.434 s | **-47.39%** |\n", | |
| "| 8 prompts * 8 cores, serial | 28.1 s | 0.439 s | |\n", | |
| "| 8 prompts * 8 cores, parallel | 23 s | 0.359 s | **-18.22%** |\n", | |
| "| 16 prompts * 8 cores, serial | 48.5 s | 0.379 s | |\n", | |
| "| 16 prompts * 8 cores, parallel | 41.7 s | 0.326 s | **-13.98%** |\n", | |
| "| 32 prompts * 8 cores, serial | 89 s | 0.348 s | |\n", | |
| "| 32 prompts * 8 cores, parallel | 79 s | 0.309 s | **-11.20%** |\n", | |
| "\n", | |
| "### Colab TPUv2 mega-1-fp16\n", | |
| "\n", | |
| "With 1 prompt * 8 cores, serial took 73 seconds, parallel only 35.9 seconds (-50.82%). Unfortunately, that was all I was able to test.\n", | |
| "\n", | |
| "## 📝 Notes\n", | |
| "\n", | |
| "* Using the same PRNG seeds, the methods produce slightly different output. I think it's just numerical instability. (In other experiments, I've tried to reproduce predictions across different batch sizes by giving each item its own PRNG sequence, but some outputs diverged midway through, to varying degrees.)\n", | |
| "* Besides the technique I'm demonstrating, the code here is not optimized. I *think* I've improved efficiency further in my working copy with more reuse and parallelization. It's very disorganized right now, and I'll have to clean it up and test things properly, but I think I got turnaround time for 24-result batches in fp32 under 30 seconds on my GPU, with some VRAM to spare.\n", | |
| "* Even without tweaks, an RTX 3090 can handle over 24 images per batch; available memory is the throughput bottleneck. I only picked batch size 16 for demonstration, not because it's a sweet spot. If latency doesn't matter, go as high as you can.\n", | |
| "\n", | |
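The super-conditioning arithmetic itself is unchanged by the patch; only where the two logit sets come from differs. A quick sanity check that indexing the stacked logits reproduces the serial formula:

```python
import jax.numpy as jnp

condition_scale = 10.0  # same value the notebook uses for cond_scale
logits_cond = jnp.array([1.0, 2.0, 3.0])
logits_uncond = jnp.array([0.5, 2.5, 1.0])

# Serial form (original code path):
serial = logits_uncond + condition_scale * (logits_cond - logits_uncond)

# Patched form: logits arrive stacked as [cond, uncond] along axis 0.
logits_both = jnp.stack([logits_cond, logits_uncond])
parallel = logits_both[1] + condition_scale * (logits_both[0] - logits_both[1])

assert jnp.allclose(serial, parallel)  # both give [5.5, -2.5, 21.0]
```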
| "The full code below is mostly copy-pasted from the original `dalle-mini/src/model/modeling.py`. The tl;dr is this change in `sample_search_body_fn` inside `DalleBart._sample`, which adds an outermost axis to all `model.decode` arguments and stacks the input and null sequence data along it:\n", | |
| "\n", | |
| "```diff\n", | |
| " def sample_search_body_fn(state):\n", | |
| " \"\"\"state update fn.\"\"\"\n", | |
| " prng_key, prng_key_next = jax.random.split(state.prng_key)\n", | |
| "\n", | |
| "- model_outputs = model(\n", | |
| "- state.running_token, params=params, **state.model_kwargs\n", | |
| "- )\n", | |
| "\n", | |
| "+ model_outputs_parallel_cond = model_decode_explicit_parallel_cond(\n", | |
| "+ model,\n", | |
| "+ state.running_token,\n", | |
| "+ params,\n", | |
| "+ state.model_kwargs_parallel_cond['decoder_attention_mask'],\n", | |
| "+ state.model_kwargs_parallel_cond['decoder_position_ids'],\n", | |
| "+ state.model_kwargs_parallel_cond['encoder_attention_mask'],\n", | |
| "+ state.model_kwargs_parallel_cond['encoder_outputs'],\n", | |
| "+ state.model_kwargs_parallel_cond['past_key_values'],\n", | |
| "+ )\n", | |
| "\n", | |
| "- logits = model_outputs.logits[:, -1]\n", | |
| "+ logits_both = model_outputs_parallel_cond.logits[:, :, -1]\n", | |
| "\n", | |
| " # perform super conditioning\n", | |
| " # Source: @RiversHaveWings - https://twitter.com/RiversHaveWings/status/1478093658716966912?s=20&t=xdm-wZ61Wf7OLnE_NJHZ1w\n", | |
| "- if condition_scale != 1.0:\n", | |
| "- model_outputs_uncond = model(\n", | |
| "- state.running_token, params=params, **state.model_kwargs_uncond\n", | |
| "- )\n", | |
| "- logits_uncond = model_outputs_uncond.logits[:, -1]\n", | |
| "- logits = logits_uncond + condition_scale * (logits - logits_uncond)\n", | |
| "- else:\n", | |
| "- model_outputs_uncond = None\n", | |
| "+ logits = logits_both[1] + condition_scale * (logits_both[0] - logits_both[1])\n", | |
| "```" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "dS8LbaonYm3a", | |
| "tags": [] | |
| }, | |
| "source": [ | |
| "## 🛠️ Installation and set-up" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "uzjAM2GBYpZX" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# Install required libraries\n", | |
| "!pip install -q dalle-mini\n", | |
| "!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "If you don't have a WandB session, you may want to run this:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import wandb\n", | |
| "wandb.init(anonymous='must')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "If you want to monitor GPU memory use with an external tool, run the next cell **before loading JAX** (this applies at least to GPUs; I'm unfamiliar with TPU operation)." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import os\n", | |
| "# for externally monitoring memory usage\n", | |
| "os.environ['XLA_PYTHON_CLIENT_ALLOCATOR']='platform'" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": { | |
| "id": "Yv-aR3t4Oe5v" | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "1" | |
| ] | |
| }, | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "import jax\n", | |
| "import jax.numpy as jnp\n", | |
| "\n", | |
| "# check how many devices are available\n", | |
| "jax.local_device_count()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Note that I've set a fixed model version below for reproducibility, just in case." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": { | |
| "id": "K6CxW2o42f-w" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# Model references\n", | |
| "\n", | |
| "# dalle-mega\n", | |
| "DALLE_MODEL = \"dalle-mini/dalle-mini/mega-1-fp16:v14\" # can be wandb artifact or 🤗 Hub or local folder or google bucket\n", | |
| "dtype = jnp.float16\n", | |
| "DALLE_COMMIT_ID = None\n", | |
| "\n", | |
| "# if the notebook crashes too often, you can use dalle-mini instead by uncommenting the line below\n", | |
| "# DALLE_MODEL = \"dalle-mini/dalle-mini/mini-1:v0\"\n", | |
| "\n", | |
| "# VQGAN model\n", | |
| "VQGAN_REPO = \"dalle-mini/vqgan_imagenet_f16_16384\"\n", | |
| "VQGAN_COMMIT_ID = \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\"" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "tags": [] | |
| }, | |
| "source": [ | |
| "### Patched methods for parallel null prompt evaluation\n", | |
| "\n", | |
| "**⚠ Do not merge as-is:** This code makes more assumptions about the model configuration than `dalle-mini@main`.\n", | |
| "\n", | |
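One detail worth noting before the code: `jax.vmap` maps over every leaf of a returned pytree, so the dict-valued model kwargs produced by the vmapped helpers pick up the conditioned/unconditioned leading axis wholesale. A toy sketch (the `make_kwargs` function is hypothetical, not part of the patch):

```python
import jax
import jax.numpy as jnp

# vmap adds the leading cond/uncond axis to every array in the returned dict.
def make_kwargs(ids):
    return {"mask": ids > 0, "pos": jnp.cumsum(ids, axis=-1)}

stacked_ids = jnp.stack([
    jnp.array([[1, 2, 0]]),  # conditioned prompt ids
    jnp.array([[3, 0, 0]]),  # null prompt ids
])
out = jax.vmap(make_kwargs)(stacked_ids)

# Each leaf gains a leading axis of size 2.
assert out["mask"].shape == (2, 1, 3)
assert out["pos"].shape == (2, 1, 3)
```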
| "To apply the patch after already obtaining a `model` instance, uncomment the last two lines." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": { | |
| "tags": [] | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "from functools import partial\n", | |
| "import jax\n", | |
| "import flax\n", | |
| "from jax import lax\n", | |
| "from jax import numpy as jnp\n", | |
| "from typing import Dict, Optional\n", | |
| "from transformers.generation_flax_utils import FlaxSampleOutput\n", | |
| "from dalle_mini import DalleBart\n", | |
| "\n", | |
| "@flax.struct.dataclass\n", | |
| "class ParallelCondSampleState:\n", | |
| " cur_len: jnp.ndarray\n", | |
| " sequences: jnp.ndarray\n", | |
| " running_token: jnp.ndarray\n", | |
| " is_sent_finished: jnp.ndarray\n", | |
| " prng_key: jnp.ndarray\n", | |
| " model_kwargs_parallel_cond: Dict[str, jnp.ndarray]\n", | |
| "\n", | |
| "@partial(jax.vmap, in_axes=(None, 0, 0, None), out_axes=0)\n", | |
| "def prepare_encoder_decoder_kwargs_parallel_cond(self, input_ids, attention_mask, params):\n", | |
| " return self._prepare_encoder_decoder_kwargs_for_generation(\n", | |
| " input_ids, params, {\"attention_mask\": attention_mask}\n", | |
| " )\n", | |
| "\n", | |
| "@partial(jax.vmap, in_axes=(None, None, None, 0, 0), out_axes=0)\n", | |
| "def prepare_inputs_for_generation_parallel_cond(self, input_ids, max_length, attention_mask, encoder_outputs):\n", | |
| " return self.prepare_inputs_for_generation(\n", | |
| " input_ids, max_length, attention_mask=attention_mask, encoder_outputs=encoder_outputs\n", | |
| " )\n", | |
| "\n", | |
| "@partial(jax.vmap, in_axes=(None, 0, 0), out_axes=0)\n", | |
| "def update_inputs_for_generation_parallel_cond(self, model_outputs, model_kwargs):\n", | |
| " return self.update_inputs_for_generation(model_outputs, model_kwargs)\n", | |
| "\n", | |
| "@partial(jax.vmap, in_axes=(None, None, None, 0, 0, 0, 0, 0), out_axes=0)\n", | |
| "def model_decode_explicit_parallel_cond(\n", | |
| " decode_fn,\n", | |
| " running_token,\n", | |
| " params,\n", | |
| " decoder_attention_mask,\n", | |
| " decoder_position_ids,\n", | |
| " encoder_attention_mask,\n", | |
| " encoder_outputs,\n", | |
| " past_key_values\n", | |
| "):\n", | |
| " return decode_fn(\n", | |
| " running_token,\n", | |
| " params=params,\n", | |
| " decoder_attention_mask=decoder_attention_mask,\n", | |
| " decoder_position_ids=decoder_position_ids,\n", | |
| " encoder_attention_mask=encoder_attention_mask,\n", | |
| " encoder_outputs=encoder_outputs,\n", | |
| " past_key_values=past_key_values\n", | |
| " )\n", | |
| "\n", | |
| "def parallel_cond_generate(\n", | |
| " self,\n", | |
| " input_ids: jnp.ndarray,\n", | |
| " attention_mask: Optional[jnp.ndarray] = None,\n", | |
| " max_length: Optional[int] = None,\n", | |
| " pad_token_id: Optional[int] = None,\n", | |
| " bos_token_id: Optional[int] = None,\n", | |
| " eos_token_id: Optional[int] = None,\n", | |
| " decoder_start_token_id: Optional[int] = None,\n", | |
| " do_sample: Optional[bool] = None,\n", | |
| " prng_key: Optional[jnp.ndarray] = None,\n", | |
| " top_k: Optional[int] = None,\n", | |
| " top_p: Optional[float] = None,\n", | |
| " temperature: Optional[float] = None,\n", | |
| " num_beams: Optional[int] = None,\n", | |
| " no_repeat_ngram_size: Optional[int] = None,\n", | |
| " min_length: Optional[int] = None,\n", | |
| " forced_bos_token_id: Optional[int] = None,\n", | |
| " forced_eos_token_id: Optional[int] = None,\n", | |
| " length_penalty: Optional[float] = None,\n", | |
| " early_stopping: Optional[bool] = None,\n", | |
| " trace: bool = True,\n", | |
| " params: Optional[Dict[str, jnp.ndarray]] = None,\n", | |
| " condition_scale: Optional[float] = 1.0,\n", | |
| " input_ids_uncond: Optional[jnp.ndarray] = None,\n", | |
| " attention_mask_uncond: Optional[jnp.ndarray] = None,\n", | |
| " **model_kwargs,\n", | |
| "):\n", | |
| " \"\"\"Edit: Allow super conditioning.\"\"\"\n", | |
| "\n", | |
| " # set init values\n", | |
| " max_length = max_length if max_length is not None else self.config.max_length\n", | |
| " bos_token_id = (\n", | |
| " bos_token_id if bos_token_id is not None else self.config.bos_token_id\n", | |
| " )\n", | |
| " pad_token_id = (\n", | |
| " pad_token_id if pad_token_id is not None else self.config.pad_token_id\n", | |
| " )\n", | |
| " eos_token_id = (\n", | |
| " eos_token_id if eos_token_id is not None else self.config.eos_token_id\n", | |
| " )\n", | |
| " decoder_start_token_id = (\n", | |
| " decoder_start_token_id\n", | |
| " if decoder_start_token_id\n", | |
| " else self.config.decoder_start_token_id\n", | |
| " )\n", | |
| " prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)\n", | |
| "\n", | |
| " if decoder_start_token_id is None and self.config.is_encoder_decoder:\n", | |
| " raise ValueError(\n", | |
| " \"`decoder_start_token_id` has to be defined for encoder-decoder generation.\"\n", | |
| " )\n", | |
| "\n", | |
| " do_sample = do_sample if do_sample is not None else self.config.do_sample\n", | |
| " num_beams = num_beams if num_beams is not None else self.config.num_beams\n", | |
| "\n", | |
| " assert (\n", | |
| " condition_scale != 1.0 and self.config.is_encoder_decoder\n", | |
| " ), \"this patched version of generate() assumes a mega-1-like config, with super conditioning\"\n", | |
| " assert (\n", | |
| " input_ids_uncond is not None\n", | |
| " ), \"`input_ids_uncond` has to be defined for super conditioning.\"\n", | |
| " assert (\n", | |
| " do_sample is True\n", | |
| " ), \"`do_sample` has to be True for super conditioning.\"\n", | |
| " assert (\n", | |
| " num_beams == 1\n", | |
| " ), \"`num_beams` has to be 1 for super conditioning.\"\n", | |
| "\n", | |
| " input_ids_parallel_cond = jnp.stack([input_ids, input_ids_uncond])\n", | |
| " attention_mask_parallel_cond = jnp.stack([attention_mask, attention_mask_uncond])\n", | |
| "\n", | |
| " model_kwargs_parallel_cond = prepare_encoder_decoder_kwargs_parallel_cond(\n", | |
| " self,\n", | |
| " input_ids_parallel_cond,\n", | |
| " attention_mask_parallel_cond,\n", | |
| " params\n", | |
| " )\n", | |
| "\n", | |
| " # prepare decoder_input_ids for generation\n", | |
| " input_ids = (\n", | |
| " jnp.ones((input_ids.shape[0], 1), dtype=\"i4\") * decoder_start_token_id\n", | |
| " )\n", | |
| "\n", | |
| " logits_warper = self._get_logits_warper(\n", | |
| " top_k=top_k, top_p=top_p, temperature=temperature\n", | |
| " )\n", | |
| " logits_processor = self._get_logits_processor(\n", | |
| " no_repeat_ngram_size,\n", | |
| " min_length,\n", | |
| " max_length,\n", | |
| " eos_token_id,\n", | |
| " forced_bos_token_id,\n", | |
| " forced_eos_token_id,\n", | |
| " )\n", | |
| " return self._parallel_cond_sample(\n", | |
| " input_ids,\n", | |
| " max_length,\n", | |
| " pad_token_id,\n", | |
| " eos_token_id,\n", | |
| " prng_key,\n", | |
| " logits_warper=logits_warper,\n", | |
| " logits_processor=logits_processor,\n", | |
| " trace=trace,\n", | |
| " params=params,\n", | |
| " condition_scale=condition_scale,\n", | |
| " model_kwargs_parallel_cond=model_kwargs_parallel_cond,\n", | |
| " )\n", | |
| "\n", | |
| "def _parallel_cond_sample(\n", | |
| " self,\n", | |
| " input_ids: jnp.ndarray,\n", | |
| " max_length: Optional[int] = None,\n", | |
| " pad_token_id: Optional[int] = None,\n", | |
| " eos_token_id: Optional[int] = None,\n", | |
| " prng_key: Optional[jnp.ndarray] = None,\n", | |
| " logits_processor=None,\n", | |
| " logits_warper=None,\n", | |
| " trace: bool = True,\n", | |
| " params: Optional[Dict[str, jnp.ndarray]] = None,\n", | |
| " model_kwargs_parallel_cond: Optional[Dict[str, jnp.ndarray]] = None,\n", | |
| " condition_scale: float = 1.0,\n", | |
| "):\n", | |
| " # init values\n", | |
| " max_length = max_length if max_length is not None else self.config.max_length\n", | |
| " pad_token_id = (\n", | |
| " pad_token_id if pad_token_id is not None else self.config.pad_token_id\n", | |
| " )\n", | |
| " eos_token_id = (\n", | |
| " eos_token_id if eos_token_id is not None else self.config.eos_token_id\n", | |
| " )\n", | |
| " prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)\n", | |
| "\n", | |
| " batch_size, cur_len = input_ids.shape\n", | |
| "\n", | |
| " eos_token_id = jnp.array(eos_token_id)\n", | |
| " pad_token_id = jnp.array(pad_token_id)\n", | |
| " cur_len = jnp.array(cur_len)\n", | |
| "\n", | |
| " # per batch-item holding current token in loop.\n", | |
| " sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32)\n", | |
| " sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))\n", | |
| "\n", | |
| " # per batch-item state bit indicating if sentence has finished.\n", | |
| " is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)\n", | |
| "\n", | |
| " # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop\n", | |
| " # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.\n", | |
| " model = self.decode if self.config.is_encoder_decoder else self\n", | |
| "\n", | |
| " # initialize model specific kwargs\n", | |
| " model_kwargs_parallel_cond = prepare_inputs_for_generation_parallel_cond(\n", | |
| " self,\n", | |
| " input_ids,\n", | |
| " max_length,\n", | |
| " model_kwargs_parallel_cond['attention_mask'],\n", | |
| " model_kwargs_parallel_cond['encoder_outputs'],\n", | |
| " )\n", | |
| "\n", | |
| " # initialize state\n", | |
| " state = ParallelCondSampleState(\n", | |
| " cur_len=cur_len,\n", | |
| " sequences=sequences,\n", | |
| " running_token=input_ids,\n", | |
| " is_sent_finished=is_sent_finished,\n", | |
| " prng_key=prng_key,\n", | |
| " model_kwargs_parallel_cond=model_kwargs_parallel_cond,\n", | |
| " )\n", | |
| "\n", | |
| " def sample_search_cond_fn(state):\n", | |
| " \"\"\"state termination condition fn.\"\"\"\n", | |
| " has_reached_max_length = state.cur_len == max_length\n", | |
| " all_sequence_finished = jnp.all(state.is_sent_finished)\n", | |
| " finish_generation = jnp.logical_or(\n", | |
| " has_reached_max_length, all_sequence_finished\n", | |
| " )\n", | |
| " return ~finish_generation\n", | |
| "\n", | |
| " def sample_search_body_fn(state):\n", | |
| " \"\"\"state update fn.\"\"\"\n", | |
| " prng_key, prng_key_next = jax.random.split(state.prng_key)\n", | |
| "\n", | |
| " model_outputs_parallel_cond = model_decode_explicit_parallel_cond(\n", | |
| " model,\n", | |
| " state.running_token,\n", | |
| " params,\n", | |
| " state.model_kwargs_parallel_cond['decoder_attention_mask'],\n", | |
| " state.model_kwargs_parallel_cond['decoder_position_ids'],\n", | |
| " state.model_kwargs_parallel_cond['encoder_attention_mask'],\n", | |
| " state.model_kwargs_parallel_cond['encoder_outputs'],\n", | |
| " state.model_kwargs_parallel_cond['past_key_values'],\n", | |
| " )\n", | |
| " \n", | |
| " logits_both = model_outputs_parallel_cond.logits[:, :, -1]\n", | |
| " \n", | |
| " # perform super conditioning\n", | |
| " # Source: @RiversHaveWings - https://twitter.com/RiversHaveWings/status/1478093658716966912?s=20&t=xdm-wZ61Wf7OLnE_NJHZ1w\n", | |
| " logits = logits_both[1] + condition_scale * (logits_both[0] - logits_both[1])\n", | |
| "\n", | |
| " # apply min_length, ...\n", | |
| " logits = logits_processor(state.sequences, logits, state.cur_len)\n", | |
| " # apply top_k, top_p, temperature\n", | |
| " logits = logits_warper(logits, logits, state.cur_len)\n", | |
| "\n", | |
| " next_token = jax.random.categorical(prng_key, logits, axis=-1)\n", | |
| "\n", | |
| " next_is_sent_finished = state.is_sent_finished | (\n", | |
| " next_token == eos_token_id\n", | |
| " )\n", | |
| " next_token = (\n", | |
| " next_token * ~next_is_sent_finished\n", | |
| " + pad_token_id * next_is_sent_finished\n", | |
| " )\n", | |
| " next_token = next_token[:, None]\n", | |
| "\n", | |
| " next_sequences = lax.dynamic_update_slice(\n", | |
| " state.sequences, next_token, (0, state.cur_len)\n", | |
| " )\n", | |
| " next_model_kwargs_parallel_cond = update_inputs_for_generation_parallel_cond(\n", | |
| " self, model_outputs_parallel_cond, state.model_kwargs_parallel_cond\n", | |
| " )\n", | |
| "\n", | |
| " return ParallelCondSampleState(\n", | |
| " cur_len=state.cur_len + 1,\n", | |
| " sequences=next_sequences,\n", | |
| " running_token=next_token,\n", | |
| " is_sent_finished=next_is_sent_finished,\n", | |
| " model_kwargs_parallel_cond=next_model_kwargs_parallel_cond,\n", | |
| " prng_key=prng_key_next,\n", | |
| " )\n", | |
| "\n", | |
| " # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU\n", | |
| " if input_ids.shape[1] > 1:\n", | |
| " state = sample_search_body_fn(state)\n", | |
| "\n", | |
| " if not trace:\n", | |
| " state = self._run_loop_in_debug(\n", | |
| " sample_search_cond_fn, sample_search_body_fn, state\n", | |
| " )\n", | |
| " else:\n", | |
| " state = lax.while_loop(sample_search_cond_fn, sample_search_body_fn, state)\n", | |
| "\n", | |
| " return FlaxSampleOutput(sequences=state.sequences)\n", | |
| "\n", | |
| "DalleBart.parallel_cond_generate = parallel_cond_generate\n", | |
| "DalleBart._parallel_cond_sample = _parallel_cond_sample\n", | |
| "# Or if you already have an instance:\n", | |
| "# model.parallel_cond_generate = parallel_cond_generate.__get__(model, DalleBart)\n", | |
| "# model._parallel_cond_sample = _parallel_cond_sample.__get__(model, DalleBart)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Model loading etc." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": { | |
| "id": "92zYmvsQ38vL" | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact mega-1-fp16:v14, 4938.53MB. 7 files... Done. 0:0:0\n", | |
| "Some of the weights of DalleBart were initialized in float16 precision from the model checkpoint at C:\\Users\\drdax\\AppData\\Local\\Temp\\tmpo7x5aom9:\n", | |
| "[('lm_head', 'kernel'), ('model', 'decoder', 'embed_positions', 'embedding'), ('model', 'decoder', 'embed_tokens', 'embedding'), ('model', 'decoder', 'final_ln', 'bias'), ('model', 'decoder', 'layernorm_embedding', 'bias'), ('model', 'decoder', 'layernorm_embedding', 'scale'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_0', 'k_proj', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_0', 'out_proj', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_0', 'q_proj', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_0', 'v_proj', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_1', 'k_proj', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_1', 'out_proj', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_1', 'q_proj', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_1', 'v_proj', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'GLU_0', 'Dense_0', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'GLU_0', 'Dense_1', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'GLU_0', 'Dense_2', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'GLU_0', 'LayerNorm_0', 'bias'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'GLU_0', 'LayerNorm_1', 'bias'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'LayerNorm_0', 'bias'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'LayerNorm_1', 'bias'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'LayerNorm_1', 'scale'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'LayerNorm_2', 'bias'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'LayerNorm_3', 'bias'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'LayerNorm_3', 
'scale'), ('model', 'encoder', 'embed_positions', 'embedding'), ('model', 'encoder', 'embed_tokens', 'embedding'), ('model', 'encoder', 'final_ln', 'bias'), ('model', 'encoder', 'layernorm_embedding', 'bias'), ('model', 'encoder', 'layernorm_embedding', 'scale'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'FlaxBartAttention_0', 'k_proj', 'kernel'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'FlaxBartAttention_0', 'out_proj', 'kernel'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'FlaxBartAttention_0', 'q_proj', 'kernel'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'FlaxBartAttention_0', 'v_proj', 'kernel'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'GLU_0', 'Dense_0', 'kernel'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'GLU_0', 'Dense_1', 'kernel'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'GLU_0', 'Dense_2', 'kernel'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'GLU_0', 'LayerNorm_0', 'bias'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'GLU_0', 'LayerNorm_1', 'bias'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'LayerNorm_0', 'bias'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'LayerNorm_1', 'bias'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'LayerNorm_1', 'scale')]\n", | |
| "You should probably UPCAST the model weights to float32 if this was not intended. See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this.\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Load models & tokenizer\n", | |
| "from dalle_mini import DalleBart, DalleBartProcessor\n", | |
| "from vqgan_jax.modeling_flax_vqgan import VQModel\n", | |
| "from transformers import CLIPProcessor, FlaxCLIPModel\n", | |
| "\n", | |
| "# Load dalle-mini\n", | |
| "model, params = DalleBart.from_pretrained(\n", | |
| " DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=dtype, _do_init=False\n", | |
| ")\n", | |
| "\n", | |
| "# Load VQGAN\n", | |
| "vqgan, vqgan_params = VQModel.from_pretrained(\n", | |
| " VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": { | |
| "id": "wtvLoM48EeVw" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "from flax.jax_utils import replicate\n", | |
| "\n", | |
| "params = replicate(params)\n", | |
| "vqgan_params = replicate(vqgan_params)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": { | |
| "id": "sOtoOmYsSYPz" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "from functools import partial\n", | |
| "\n", | |
| "def model_generate_factory(generate_fn):\n", | |
| " # model inference\n", | |
| " def _generate(\n", | |
| " tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale\n", | |
| " ):\n", | |
| " return generate_fn(\n", | |
| " **tokenized_prompt,\n", | |
| " prng_key=key,\n", | |
| " params=params,\n", | |
| " top_k=top_k,\n", | |
| " top_p=top_p,\n", | |
| " temperature=temperature,\n", | |
| " condition_scale=condition_scale,\n", | |
| " )\n", | |
| " return _generate\n", | |
| "\n", | |
| "partial_generate = model_generate_factory(model.generate)\n", | |
| "p_generate = jax.pmap(partial_generate, axis_name=\"batch\", static_broadcasted_argnums=(3, 4, 5, 6))\n", | |
| "partial_parallel_cond_generate = model_generate_factory(model.parallel_cond_generate)\n", | |
| "p_parallel_cond_generate = jax.pmap(partial_parallel_cond_generate, axis_name=\"batch\", static_broadcasted_argnums=(3, 4, 5, 6))\n", | |
| "\n", | |
| "# decode image\n", | |
| "@partial(jax.pmap, axis_name=\"batch\")\n", | |
| "def p_decode(indices, params):\n", | |
| " return vqgan.decode_code(indices, params=params)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": { | |
| "id": "4CTXmlUkThhX" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "import random\n", | |
| "\n", | |
| "# create a random key\n", | |
| "seed = random.randint(0, 2**32 - 1)\n", | |
| "key = jax.random.PRNGKey(seed)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "BrnVyCo81pij" | |
| }, | |
| "source": [ | |
| "## 🖍 Text Prompt" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "metadata": { | |
| "id": "YjjhUychOVxm" | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact mega-1-fp16:v14, 4938.53MB. 7 files... Done. 0:0:0\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "from dalle_mini import DalleBartProcessor\n", | |
| "\n", | |
| "processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "prompts = [\"avocado armchair\", \"avocado armchair dot svg\"]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "metadata": { | |
| "id": "VKjEZGjtO49k" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "tokenized_prompts = processor(prompts)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "metadata": { | |
| "id": "lQePgju5Oe5z" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "tokenized_prompt = replicate(tokenized_prompts)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "phQ9bhjRkgAZ" | |
| }, | |
| "source": [ | |
| "## 🎨 Generate images" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "metadata": { | |
| "id": "d0wVkXpKqnHA" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# number of predictions per prompt\n", | |
| "n_predictions = 8\n", | |
| "\n", | |
| "# We can customize generation parameters (see https://huggingface.co/blog/how-to-generate)\n", | |
| "gen_top_k = None\n", | |
| "gen_top_p = None\n", | |
| "temperature = None\n", | |
| "cond_scale = 10.0" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 14, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from flax.training.common_utils import shard_prng_key\n", | |
| "import numpy as np\n", | |
| "from PIL import Image\n", | |
| "from tqdm.notebook import trange\n", | |
| "import IPython.display\n", | |
| "\n", | |
| "def run_pred(p_generate_fn, tokenized_prompt, subkey, params, vqgan_params, top_k, top_p, temperature, cond_scale):\n", | |
| " # generate images\n", | |
| " encoded_images = p_generate_fn(\n", | |
| " tokenized_prompt,\n", | |
| " shard_prng_key(subkey),\n", | |
| " params,\n", | |
| "        top_k,\n", | |
| "        top_p,\n", | |
| " temperature,\n", | |
| " cond_scale,\n", | |
| " )\n", | |
| " # remove BOS\n", | |
| " encoded_images = encoded_images.sequences[..., 1:]\n", | |
| " # decode images\n", | |
| " decoded_images = p_decode(encoded_images, vqgan_params)\n", | |
| " decoded_images = np.asarray(decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3)) * 255, dtype=np.uint8)\n", | |
| " return decoded_images" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Warmup/compile\n", | |
| "\n", | |
| "To properly test an implementation's performance, we first need to run it at least once **with the same sampling arguments** (top-k, top-p, temperature, condition scale) **and number of prompts as during the benchmark** (not `n_predictions`, though): these parameters are static arguments, so each new combination of them triggers a fairly slow compilation step on its first use.\n", | |
| "\n", | |
| "To judge an implementation's memory use, the foolproof way to ensure a fair comparison is to restart the runtime, comment out the other implementation, and run only a single warmup and benchmark with a single set of parameters." | |
| ] | |
| }, | |
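| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "As a side note, this recompilation cost comes from the `static_broadcasted_argnums=(3, 4, 5, 6)` in the `jax.pmap` calls above: every new value of a static argument triggers a fresh trace and compile, while repeated values hit the compilation cache. A toy sketch of that behavior (illustrative only, not part of the benchmark):\n", | |
| "\n", | |
| "```python\n", | |
| "from functools import partial\n", | |
| "\n", | |
| "import jax\n", | |
| "import jax.numpy as jnp\n", | |
| "\n", | |
| "@partial(jax.jit, static_argnums=(1,))\n", | |
| "def scale(x, k):\n", | |
| "    return x * k\n", | |
| "\n", | |
| "scale(jnp.ones(3), 2)  # first call with k=2: traces and compiles\n", | |
| "scale(jnp.ones(3), 2)  # same static value: served from cache, fast\n", | |
| "scale(jnp.ones(3), 3)  # new static value: compiles again\n", | |
| "```" | |
| ] | |
| }, | |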
| { | |
| "cell_type": "code", | |
| "execution_count": 15, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "run_pred(p_generate, tokenized_prompt, key, params, vqgan_params, gen_top_k, gen_top_p, temperature, cond_scale);\n", | |
| "run_pred(p_parallel_cond_generate, tokenized_prompt, key, params, vqgan_params, gen_top_k, gen_top_p, temperature, cond_scale);" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## 🧪 Experiments" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 16, | |
| "metadata": { | |
| "id": "SDjEx9JxR3v8" | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<h3>With serial null prompt evaluation</h3>" | |
| ], | |
| "text/plain": [ | |
| "<IPython.core.display.HTML object>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "dalle-mini/dalle-mini/mega-1-fp16:v14 <class 'jax.numpy.float16'> None None None 10.0 1\n", | |
| "Prompts: ['avocado armchair', 'avocado armchair dot svg'] (displayed when all are done)\n", | |
| "\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "3ee4bc5f5be14dd2956ebc6bf6cc1775", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| " 0%| | 0/8 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "data": { |