
After spending a weekend with DeepSeek R1, I was impressed enough that I decided to explore running it locally. Rationale:

  • Reasoning Ability: Initial testing suggested that I could use its reasoning capabilities to assist with my research. Currently, I don’t have a local model that can do this.
  • Privacy: I’m not certain whether the officially hosted DeepSeek model is a covert data collection tool by the CCP or not, but as a general rule I prefer to run AI models locally so I know who has access to my data.
  • Cost: Local hosting is usually cheaper than renting a giant 8x H100 workstation in the cloud.

Currently, I am using vLLM for local inferencing and have been quite happy with its performance, fast evolution, and community. However, given the size of DeepSeek R1, my limited exposure to sub-4-bit quantization, and all the experimental warnings in the vLLM docs, I decided to give LLaMA.cpp a try. I realize that I am a little late to the game when it comes to LLaMA.cpp, but it seemed like a good opportunity to add it to my toolbox.

Use Cases

  1. Synthetic data generation with up to thousands of tasks, where each task typically uses ≈3000 request tokens and ≈2000 response tokens. The 2000-token response estimate accounts for the longer responses returned by reasoning models. This use case does not require multiple generations per prompt.
  2. General chat assistance for coding, prompt optimization, knowledge extraction, etc.

Goals

  • Determine the optimal configuration for DeepSeek R1 on my hardware that aligns with my use cases.
  • Familiarize myself with LLaMA.cpp.
  • Share my process and findings.

Hardware

  • 4x Nvidia RTX A6000 GPUs (48GB VRAM each)
  • With ECC enabled (I know, I know), each GPU has 44GB of usable VRAM, totaling ≈176GB for tensor-parallel processing. I may do a write-up on why I enabled this if there is interest.
  • AMD Ryzen Threadripper PRO 5975WX with 32 physical cores
  • 256GB system RAM
  • Lots of storage

Plan

  1. Build LLaMA.cpp from source and look for optimizations.
  2. Familiarize myself with the GGUF format and tooling.
  3. Quantize a DeepSeek R1-family model for testing.
  4. Smoke test LLaMA.cpp with a smaller, well-supported model.
  5. Benchmark the performance of DeepSeek R1 on my hardware and select optimal configurations.

TL;DR

  • Because LLaMA.cpp performed worse than vLLM in my tests, I will only use LLaMA.cpp when I don’t have the VRAM to run a model in vLLM.
  • DeepSeek-R1-UD-IQ1_S via LLaMA.cpp is good enough for chat / general assistance but not batch inferencing and synthetic data generation at the scale I need.
  • DeepSeek-R1-Distill-Llama-70B is my only usable choice for synthetic data generation.

Optimal Chat Configuration (LLaMA.cpp):

docker run -it -d \
    --name=eleanor-llamacpp \
    --restart=unless-stopped \
    --shm-size=15g \
    --ulimit memlock=-1 \
    --ipc=host \
    --gpus="device=0,1,2,3" \
    --publish=7850:8000 \
    --volume=/mnt/ai/theobjectivedad/llm/models/DeepSeek-R1:/models:ro \
    --health-cmd="timeout 5 bash -c 'cat < /dev/null > /dev/tcp/localhost/8000'" \
    --health-start-period=240s \
    --health-interval=15s \
    --health-timeout=8s \
    --health-retries=3 \
        ghcr.io/ggerganov/llama.cpp:server-cuda-b4580 \
            --alias DeepSeek-R1-UD-IQ1_S \
            --batch-size 2048 \
            --cache-type-k iq4_nl \
            --ctx-size 8192 \
            --log-colors \
            --log-timestamps \
            --main-gpu 2 \
            --metrics \
            --min-p 0.1 \
            --mlock \
            --model /models/DeepSeek-R1-UD-IQ1_S.gguf \
            --gpu-layers 62 \
            --parallel 1 \
            --escape \
            --no-mmap \
            --no-webui \
            --port 8000 \
            --split-mode layer \
            --temp 0.8 \
            --threads-batch 32 \
            --threads-http 4 \
            --threads 32 \
            --top-k 40 \
            --top-p 0.9 \
            --ubatch-size 512

Optimal Synthetic Data Generation Configuration (vLLM):

docker run -it -d \
    --name=eleanor-vLLM \
    --restart=unless-stopped \
    --shm-size=15g \
    --ulimit memlock=-1 \
    --ipc=host \
    --entrypoint=python3 \
    --gpus="device=0,1,2,3" \
    --publish=7800:8000 \
    --volume=/models:/models:ro \
    --health-cmd="timeout 5 bash -c 'cat < /dev/null > /dev/tcp/localhost/8000'" \
    --health-start-period=240s \
    --health-interval=15s \
    --health-timeout=8s \
    --health-retries=3 \
    --env=OMP_NUM_THREADS=1 \
    harbor.k8s.wm.k8slab/eleanor-ai/vllm-openai:0.6.6.post1 \
        -m vllm.entrypoints.openai.api_server \
        --model /models/DeepSeek-R1-Distill-Llama-70B \
        --served-model-name DeepSeek-R1-Distill-Llama-70B \
        --response-role auto \
        --load-format safetensors \
        --tokenizer-mode auto \
        --enable-chunked-prefill=True \
        --max-num-batched-tokens=4096 \
        --dtype bfloat16 \
        --kv-cache-dtype auto \
        --gpu-memory-utilization 0.90 \
        --enable-auto-tool-choice \
        --tool-call-parser llama3_json \
        --enable-prefix-caching \
        --device=cuda \
        --task=generate \
        --scheduler-delay-factor=0.25 \
        --uvicorn-log-level=debug \
        --distributed-executor-backend=mp \
        --guided-decoding-backend=outlines \
        --disable_custom_all_reduce \
        --max-model-len 60000 \
        --tensor-parallel-size 4 \
        --port 8000 \
        --host 0.0.0.0

Building LLaMA.cpp

Building LLaMA.cpp was fairly straightforward. These were the steps I used:

First, install ccache to speed up the (re)build process:

sudo apt-get install -y ccache

Next, I followed the official build instructions; these are the commands I used to build the binaries:

git clone https://github.com/ggerganov/llama.cpp
cd llama.cpp
git checkout b4585
cmake --fresh -B build \
    -DGGML_CUDA=ON \
    -DCMAKE_CUDA_ARCHITECTURES=86 \
    -DGGML_CUDA_F16=ON \
    -DLLAMA_CURL=ON \
    -DGGML_CUDA_PEER_MAX_BATCH_SIZE=256 \
    -DGGML_CUDA_FA_ALL_QUANTS=ON \
    -DGGML_BLAS=ON \
    -DGGML_BLAS_VENDOR=OpenBLAS \
    -DBUILD_SHARED_LIBS=OFF
cmake --build build --config Release -j 24

These build options are specific to my hardware (for example, CMAKE_CUDA_ARCHITECTURES=86 targets the compute capability 8.6 of the RTX A6000); yours will likely be different.
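If you are unsure which value to pass for CMAKE_CUDA_ARCHITECTURES, one option is to ask the driver for each GPU’s compute capability and drop the dot (8.6 becomes 86). This is a minimal sketch: the helper name cuda_arch_list is my own, and it assumes a driver whose nvidia-smi supports the compute_cap query field.

```bash
# Hypothetical helper: turn compute-capability lines such as "8.6" into the
# semicolon-separated form CMake expects for CMAKE_CUDA_ARCHITECTURES (e.g. "86").
cuda_arch_list() {
  tr -d '. ' |        # "8.6" -> "86" (also strips stray spaces)
    sort -u |         # identical GPUs collapse to a single entry
    paste -sd ';' -   # multiple architectures are joined as "86;89"
}

# Usage (assumes nvidia-smi supports --query-gpu=compute_cap):
#   arch=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader | cuda_arch_list)
#   cmake --fresh -B build -DGGML_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES="$arch"
```

With four identical A6000s this yields a single entry, 86, matching the value used in the build above.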

One build option that looked interesting but that I didn’t experiment with was GGML_CUDA_FORCE_MMQ=ON, which the CUDA build documentation describes as follows:

Force the use of custom matrix multiplication kernels for quantized models instead of FP16 cuBLAS even if there is no int8 tensor core implementation available (affects V100, RDNA3). MMQ kernels are enabled by default on GPUs with int8 tensor core support. With MMQ force enabled, speed for large batch sizes will be worse but VRAM consumption will be lower.

Although I plan to use the official LLaMA.cpp image, it will be very helpful to have a “from scratch” build with all the tools on my host workstation for testing.

GGUF Stuff

After building LLaMA.cpp, I needed to familiarize myself with the GPT-Generated Unified Format (GGUF) and its tooling. There are many articles on what GGUF is and its history; the minimum you need to know is that GGUF is a single-file binary format that stores a model’s tensors (the weights) together with key-value metadata such as the tokenizer and architecture hyperparameters, designed for fast loading and inference in GGML-based tools such as LLaMA.cpp.

These are the models I will be working with:

GGUF Merge / Split

Large models on HuggingFace are often distributed in GGUF format as a series of split files. While LLaMA.cpp can read these files directly, other inferencing servers such as vLLM need a single file. The llama-gguf-split tool built in the previous section can split and merge GGUF files.

Here are the commands I used to split and merge the test model:

# Split Test
mkdir Llama-3.2-3B-Instruct-Q4_K_M-split
llama-gguf-split --split-max-size 250M Llama-3.2-3B-Instruct-Q4_K_M.gguf Llama-3.2-3B-Instruct-Q4_K_M-split/Llama-3.2-3B-Instruct-Q4_K_M

# Merge test
mkdir Llama-3.2-3B-Instruct-Q4_K_M-merged
llama-gguf-split --merge \
    Llama-3.2-3B-Instruct-Q4_K_M-split/Llama-3.2-3B-Instruct-Q4_K_M-00001-of-00009.gguf \
    Llama-3.2-3B-Instruct-Q4_K_M-merged/Llama-3.2-3B-Instruct-Q4_K_M.gguf
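Because the merge only takes the first shard, a missing shard is easy to overlook. The hypothetical helper below expands a first-shard name into every shard path implied by the PREFIX-NNNNN-of-NNNNN.gguf naming seen in the commands above, so you can check them before merging. It is a sketch that assumes this naming pattern and five-digit shard counts.

```bash
# Hypothetical helper: list every shard implied by a first-shard path such as
# "dir/model-00001-of-00009.gguf".
gguf_shards() {
  local first="$1"
  local prefix="${first%-00001-of-*}"   # "dir/model"
  local total="${first##*-of-}"         # "00009.gguf"
  total="${total%.gguf}"                # "00009"
  local i
  for ((i = 1; i <= 10#$total; i++)); do  # 10# keeps "00009" from being parsed as octal
    printf '%s-%05d-of-%s.gguf\n' "$prefix" "$i" "$total"
  done
}

# Usage: report any missing shard before running llama-gguf-split --merge
#   gguf_shards Llama-3.2-3B-Instruct-Q4_K_M-split/Llama-3.2-3B-Instruct-Q4_K_M-00001-of-00009.gguf |
#     while read -r f; do [ -f "$f" ] || echo "missing: $f"; done
```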

HuggingFace Transformers to GGUF

This is a simple example of how a HuggingFace transformers model can be converted to GGUF. Notice that we are not quantizing the model here, just copying the BF16 weights into the GGUF format:

python -m venv venv
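source venv/bin/activate
# Install the converter's Python dependencies (run from the llama.cpp repository root)
pip install -r requirements.txt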
python convert_hf_to_gguf.py \
    --outtype bf16 \
    --outfile /models/DeepSeek-R1-Distill-Llama-70B-B16.gguf \
    /models/DeepSeek-R1-Distill-Llama-70B

Quantizing

Quantization is the process of reducing the precision of the numeric model weights into lower-bit representations. The benefit of this is usually faster inferencing and reduced VRAM requirements, at (possibly) the cost of accuracy. This is a giant topic and an active area of research beyond the scope of this article.

The example command below quantizes DeepSeek-R1-Distill-Llama-70B-B16 (16 bits per weight) down to Q8_0, which is roughly 8 bits per weight (about 8.5 once the per-block scales are counted):

llama-quantize \
    /models/DeepSeek-R1-Distill-Llama-70B-B16.gguf \
    /models/DeepSeek-R1-Distill-Llama-70B-Q8_0.gguf \
    Q8_0
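As a rough sanity check on whether a quantized model will fit in the ≈176GB of usable VRAM, the weight footprint is approximately parameter count × bits per weight / 8 bytes. This sketch uses a rounded 70e9 parameter count for the 70B distill (an approximation, not the exact figure) and ignores GGUF metadata and any tensors kept at higher precision, so treat the results as ballpark numbers in decimal GB. The function name is hypothetical.

```bash
# Rough weight footprint in decimal GB: params * bits_per_weight / 8 bytes.
weights_gb() {
  awk -v p="$1" -v bpw="$2" 'BEGIN { printf "%.1f\n", p * bpw / 8 / 1e9 }'
}

weights_gb 70e9 16    # BF16 -> 140.0
weights_gb 70e9 8.5   # Q8_0 (8.5 bits per weight including block scales) -> 74.4
```

The gap between the two results is why the B16 variant leaves much less room for KV cache than Q8_0.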

While playing around with llama-quantize I used the following references:

LLaMA.cpp Smoke Test

To ensure that my build of LLaMA.cpp was working properly, I wanted to test it with a smaller model that is known to work. When evaluating a new tool, this typically ends up being the most important step in the process. Basically, this is where I:

  • Test the LLaMA.cpp build.
  • Establish my first working configuration. As you can see below, I usually add all the CLI options I care about to the initial command (even if the setting is the same as the default) so I can quickly explore the effects of the various configuration options. Moreover, this helps me learn a tool’s capabilities quickly.
  • Fix any initial issues that show up in the “easy” case.

After some trial and error, the llama-cli command below represents my “hello world” configuration. Documentation for llama-cli can be found in its built-in --help output as well as the official example:

llama-cli \
    --model Llama-3.2-3B-Instruct-Q4_K_M-split/Llama-3.2-3B-Instruct-Q4_K_M-00001-of-00009.gguf \
    --prompt "You are a helpful assistant" \
    --conversation \
    --dump-kv-cache \
    --no-mmap \
    --mlock \
    --flash-attn \
    --n-gpu-layers 28 \
    --main-gpu 3 \
    --split-mode layer \
    --tensor-split 1,1,1,1 \
    --parallel 1 \
    --log-colors \
    --multiline-input \
    --log-timestamps \
    --temp 0.8 \
    --top-k 40 \
    --top-p 0.9 \
    --min-p 0.1 \
    --xtc-probability 0.0 \
    --xtc-threshold 0.0 \
    --typical 1.0 \
    --repeat-last-n 64 \
    --repeat-penalty 1.0 \
    --presence-penalty 0.0 \
    --frequency-penalty 0.0 \
    --dynatemp-range 0.0 \
    --no-kv-offload \
    --dynatemp-exp 1.0 \
    --ctx-size 0 \
    --color

There were two things that I found non-intuitive when using llama-cli:

  1. llama.cpp will not accept an equal sign between the flag and the value. For example, --model=... will not work. Instead, use --model ....
  2. Unusual (in my opinion) syntax for loading a sharded GGUF model: you pass the path of the first shard (the file whose name ends in -00001-of-0000N.gguf) to --model. After failing a few times, I found the correct syntax in a test script in the official repository. In my opinion, passing a base directory would be more intuitive, but this works too.
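If your scripts or muscle memory use the --flag=value form, a small wrapper can split those arguments before calling the real binary. This is a hypothetical convenience function of my own, not part of LLaMA.cpp, and it assumes no value you pass is itself meant to be a literal --something=... string (such a value would be split too).

```bash
# Hypothetical wrapper: rewrite "--flag=value" as "--flag" "value", then run
# the given binary with the rewritten arguments.
run_with_split_args() {
  local bin="$1"; shift
  local out=() a
  for a in "$@"; do
    if [[ $a == --*=* ]]; then
      out+=("${a%%=*}" "${a#*=}")   # split at the first "=" only
    else
      out+=("$a")
    fi
  done
  "$bin" "${out[@]}"
}

# Usage:
#   run_with_split_args llama-cli --model=/models/foo.gguf --ctx-size=4096
```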

Additionally, I was getting errors setting the KV cache data types, in particular --cache-type-k bf16 and --cache-type-v bf16. As of 2025-01-30, there is an open issue in the LLaMA.cpp repository that seems to be related. Moving this problem to the parking lot for now.

Now that I have a working baseline, I can move on to DeepSeek R1.

First Run of DeepSeek R1

For my first run, I used more-or-less the recommended configuration from the Unsloth blog post:

# Baseline
llama-cli \
    --model DeepSeek-R1-UD-IQ1_M/DeepSeek-R1-UD-IQ1_M-00001-of-00004.gguf \
    -no-cnv \
    --cache-type-k q4_0 \
    --temp 0.6 \
    --ctx-size 8192 \
    --seed 3407 \
    --prompt "<|User|>Create a Flappy Bird game in Python.<|Assistant|>" \
    --log-colors \
    --log-timestamps \
    --min-p 0.05 \
    --color

While this technically worked, there were two issues immediately apparent:

  1. A very slow load time, which caused me to double-check my setup, especially since CPU, RAM, and GPU were all idle while the model loaded.
  2. As expected for this configuration, which does not set --n-gpu-layers, generation ran at only 3.85 response tokens per second, way too slow for my use case.

Addressing Slow Startup Time

To address the slow load time, I decided to move the model from my magnetic NAS-grade storage to my faster NVMe. I need to do this frequently and use rsync, which gets around 350MB/s to 450MB/s for the transfer:

ionice -c 2 -n 0 rsync \
    --archive \
    --human-readable \
    --info=progress2 \
    --delete-before \
    --preallocate \
    --whole-file \
    --bwlimit=0 \
    --inplace \
    /archive/models/DeepSeek-R1-GGUF/DeepSeek-R1-UD-IQ1_S \
    /mnt/ai/theobjectivedad/llm/models/

The table below explains what each rsync option does. These choices are fairly specific to my storage setup, but I wanted to share them in case others find them useful.

| Rsync Option | Why It Was Chosen |
| --- | --- |
| --archive | Ensures an exact copy, preserving permissions, timestamps, and symlinks. |
| --human-readable | Prints sizes in human-readable units. |
| --info=progress2 | Displays real-time progress for the whole transfer, which is handy for large files. |
| --delete-before | Deletes files in the destination that no longer exist in the source before the transfer starts. |
| --preallocate | Reserves space on the SSD in advance, reducing fragmentation. |
| --whole-file | Copies full files instead of incremental deltas (faster for local transfers). |
| --inplace | Writes directly to the target file instead of to a temporary copy that is renamed afterwards, avoiding extra I/O. |
| --bwlimit=0 | Removes any artificial bandwidth limits to maximize speed. |
| --partial | Keeps partially transferred files if the transfer is interrupted. Not passed explicitly above because --inplace implies it. |
| ionice -c 2 -n 0 | Runs rsync in the best-effort I/O class at its highest priority level for faster reads from the HDD. |

The next test run took just over a minute to load, so I consider this resolved.

Addressing Slow Inference Time

This mostly comes down to finding the right balance between bits per weight (BPW) and context size such that the entire model fits into GPU VRAM. Since the entire KV cache also needs to stay in VRAM, on my hardware the configuration with the fastest response tokens per second will also be the least accurate one (lowest BPW) with the smallest context.
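To put numbers on the KV cache side of this trade-off, a common back-of-the-envelope estimate is 2 (K and V) × layers × context tokens × KV heads × head dimension × bytes per element. The sketch below implements only that formula and does not model LLaMA.cpp’s compute buffers. The example dimensions (80 layers, 8 KV heads, head dimension 128) are my assumption based on the Llama 3 70B architecture that the 70B distill builds on, so check them against the model’s config.json. DeepSeek R1’s deepseek2 attention uses different K and V head sizes (the same mismatch behind the flash attention message in the next paragraph), so its per-token cost must be computed from its own dimensions.

```bash
# Back-of-the-envelope KV cache size in GiB:
#   2 (K and V) * layers * ctx * kv_heads * head_dim * bytes_per_element
# bytes_per_element: 2 for f16/bf16, ~1.0625 for q8_0 (34 bytes per 32 values).
kv_cache_gib() {
  local layers="$1" ctx="$2" kv_heads="$3" head_dim="$4" bytes="$5"
  awk -v l="$layers" -v c="$ctx" -v h="$kv_heads" -v d="$head_dim" -v b="$bytes" \
    'BEGIN { printf "%.1f\n", 2 * l * c * h * d * b / (1024 ^ 3) }'
}

# Assumed Llama-3-70B-style dimensions with an f16 cache:
kv_cache_gib 80 131072 8 128 2   # -> 40.0
kv_cache_gib 80 95000 8 128 2    # -> 29.0
```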

Unfortunately, on model load I saw the message “llama_init_from_model: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off”, meaning that flash attention wasn’t going to be in the toolbox for me. Moreover, LLaMA.cpp requires flash attention for value (V) cache quantization, so that is out of the question as well.

Since this test will, by its nature, run entirely out of VRAM for a single inferencing request, I am not going to tune batch size. Additionally, --no-mmap and --mlock are not needed because the model and KV cache are completely in VRAM.

After some trial and error, I was able to get around 17.76 response tokens per second on average, a significant improvement over the 3.85 tokens per second of the initial run. The configuration I used is below:

llama-cli \
    --model DeepSeek-R1-UD-IQ1_S/DeepSeek-R1-UD-IQ1_S-00001-of-00003.gguf \
    --prompt "<|User|>What is the best college degree to get for a US graduate in 2025?<|Assistant|>" \
    --ctx-size 7100 \
    --temp 0.6 \
    --min-p 0.05 \
    --n-gpu-layers 62 \
    --split-mode layer \
    --seed 3407 \
    --no-conversation \
    --log-timestamps \
    --log-colors \
    --color

The table below details the configuration options I used to achieve this performance:

| Option | Description & Performance Impact |
| --- | --- |
| --n-gpu-layers 62 | Offloads all 62 model layers to the GPUs. |
| --ctx-size 7100 | Sets the context length (how many tokens the model can attend to); 7100 was near the maximum I could use without spilling into system RAM. |
| --seed 3407 | Uses a constant seed to make tests more deterministic. |
| --split-mode layer | Controls how the model is split across multiple GPUs; layer performed significantly better than row for me. |
| --temp 0.6 | Recommended value from the Unsloth README. |
| --min-p 0.05 | Recommended value from the Unsloth README. |

For completeness, this is the command I used to run a CLI chat with the working configuration above:

llama-cli \
    --model DeepSeek-R1-UD-IQ1_S/DeepSeek-R1-UD-IQ1_S-00001-of-00003.gguf \
    --prompt "You are a helpful assistant" \
    --conversation \
    --dump-kv-cache \
    --ctx-size 7100 \
    --temp 0.6 \
    --min-p 0.05 \
    --n-gpu-layers 62 \
    --split-mode layer \
    --log-timestamps \
    --log-colors \
    --color

Performance Benchmarking Strategy

While I’m not claiming to be an expert in language model benchmarking, I did want to briefly share the process I used to determine an optimal server configuration given my hardware and use case. Broadly, this involved optimizing both prompt processing speed (S_PP t/s) and text generation speed (S_TG t/s) for my hardware.

Optimizing Text Generation Speed

Since I already know that the entire context will fit into my VRAM, I just need to derive the maximum number of parallel requests:

$$ \begin{aligned} \text{max\_request\_tokens} &= 3000 \\ \text{max\_resp\_tokens} &= 2000 \\ \text{safety\_tokens} &= 500 \\ \text{max\_total\_tokens} &= \text{max\_request\_tokens} + \text{max\_resp\_tokens} + \text{safety\_tokens} = 3000 + 2000 + 500 = 5500 \\ \text{context\_length} &= 131072 \end{aligned} $$

Recall that \(\text{parallel\_requests}\) is limited by \(\text{context\_length}\) such that:

$$ \text{max\_total\_tokens} \times \text{parallel\_requests} \leq \text{context\_length} $$

Therefore,

$$ \text{parallel\_requests} = \left\lfloor \frac{\text{context\_length}}{\text{max\_total\_tokens}} \right\rfloor = \left\lfloor \frac{131072}{5500} \right\rfloor = 23 $$
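The same derivation in shell arithmetic, as a minimal sketch (the function name is my own). Bash’s integer division truncates, which equals the floor in the formula for these positive values.

```bash
# parallel_requests CONTEXT_LENGTH MAX_REQUEST_TOKENS MAX_RESP_TOKENS SAFETY_TOKENS
parallel_requests() {
  local ctx="$1" req="$2" resp="$3" safety="$4"
  echo $(( ctx / (req + resp + safety) ))   # floor(ctx / max_total_tokens)
}

parallel_requests 131072 3000 2000 500   # -> 23
parallel_requests 95000 3000 2000 500    # -> 17
```

The second call reproduces the -npl value used with the smaller 95000-token context in the B16 benchmark.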

Translating this to llama-batched-bench parameters, I set the following:

| Parameter | Value |
| --- | --- |
| -npl | 23 |
| --ctx-size | 131072 |
| --parallel | 1 (default) |

Optimizing Prompt Processing Speed

The first parameter to tune is the physical maximum batch size, \(\text{ubatch\_size}\), for which I used the following formula:

$$ \text{ubatch\_size} = \text{max\_request\_tokens} + \left\lfloor \frac{\text{safety\_tokens}}{2} \right\rfloor = 3000 + \left\lfloor \frac{500}{2} \right\rfloor = 3250 $$

Update: When testing DeepSeek-R1-Distill-Llama-70B-B16, I needed to lower --ubatch-size to 1024 because I was getting OOM errors. My calculation clearly misses something (it never takes available VRAM into account), but it got me close enough to continue the test after a few iterations.

Next, I derive the batch_size:

$$ \text{batch\_size} = \text{ubatch\_size} \times 4 = 3250 \times 4 = 13000 $$
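Both formulas as a small shell sketch (the function name is hypothetical). The ×4 multiplier is simply the one used in this write-up and is exposed as an optional third argument. Like the formulas, it ignores available VRAM, which is why the B16 run still needed a smaller --ubatch-size.

```bash
# batch_sizes MAX_REQUEST_TOKENS SAFETY_TOKENS [MULTIPLIER]
#   ubatch_size = max_request_tokens + floor(safety_tokens / 2)
#   batch_size  = ubatch_size * multiplier (default 4)
batch_sizes() {
  local req="$1" safety="$2" multiplier="${3:-4}"
  local ubatch=$(( req + safety / 2 ))
  echo "--ubatch-size $ubatch --batch-size $(( ubatch * multiplier ))"
}

batch_sizes 3000 500   # -> --ubatch-size 3250 --batch-size 13000
```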

Translating this to llama-batched-bench parameters, I set the following:

| Parameter | Value |
| --- | --- |
| --ubatch-size | 3250 |
| --batch-size | 13000 |

Additional optimizations

The table below details additional parameters I used to optimize performance:

| Parameter | Purpose |
| --- | --- |
| --n-gpu-layers | Offloads model layers to the GPUs; combined with --split-mode layer, whole layers (and their KV cache) are distributed across the four GPUs. |
| --flash-attn | The DeepSeek distill models are fine-tunes of other models and do not share the deepseek2 architecture, so I can use flash attention to increase inference speed and lower VRAM requirements. |
| --mlock | Locks the model in memory so its weights cannot be swapped out to disk. |
| --no-mmap | Loads the whole model into memory up front instead of memory-mapping the file from disk. |

Regarding --mlock and --no-mmap

Benchmarks

This section shares the benchmark commands I used and their results.

LLaMA.cpp: DeepSeek-R1-UD-IQ1_M

After much testing I decided that batch inferencing with DeepSeek-R1-UD-IQ1_M isn’t feasible. While I was able to attain acceptable performance for short chat sessions, extending the context for batch inferencing is beyond my hardware. Making room in VRAM for additional KV cache means moving some model weights into system RAM and tanking performance.

For this reason, I decided to focus benchmarking on the distill models.

LLaMA.cpp: DeepSeek-R1-Distill-Llama-70B-Q8_0

Benchmark command:

llama-batched-bench \
  --model /models/DeepSeek-R1-Distill-Llama-70B-Q8_0.gguf \
  -npp 3000 \
  -ntg 2000 \
  -npl 23 \
  --batch-size 13000 \
  --ubatch-size 3250 \
  --ctx-size 131072 \
  --flash-attn \
  --n-gpu-layers 82 \
  --split-mode layer \
  --mlock \
  --no-mmap \
  --temp 0.6 \
  --min-p 0.05 \
  --threads 32 \
  --threads-batch 32 \
  --seed 2738 \
  --output-format md

Under this configuration, I can handle 23 parallel requests in about 24 minutes at a combined rate (prompt plus generated tokens) of 79.62 tokens per second:

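| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
| ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: |
| 3000 | 2000 | 23 | 115000 | 306.572 | 225.07 | 1137.734 | 40.43 | 1444.306 | 79.62 |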

Not too bad overall, and it certainly meets the requirements of my use cases.
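For reference, the derived columns in these tables can be recomputed from the raw inputs and timings. The awk sketch below is my reading of how llama-batched-bench reports them in the default mode, where each sequence gets its own prompt: N_KV = (PP + TG) × B, S_PP = PP × B / T_PP, S_TG = TG × B / T_TG, and S = N_KV / (T_PP + T_TG). It is not taken from the tool’s source, and the function name is hypothetical.

```bash
# bench_rates PP TG B T_PP T_TG  (times in seconds)
# Recomputes the derived llama-batched-bench columns, assuming prompts are not
# shared across the B parallel sequences.
bench_rates() {
  awk -v pp="$1" -v tg="$2" -v b="$3" -v tpp="$4" -v ttg="$5" 'BEGIN {
    nkv = (pp + tg) * b   # tokens held in the KV cache across all sequences
    t = tpp + ttg         # total wall time
    printf "N_KV=%d S_PP=%.2f S_TG=%.2f T=%.3f S=%.2f\n", nkv, pp * b / tpp, tg * b / ttg, t, nkv / t
  }'
}

bench_rates 3000 2000 23 306.572 1137.734
# -> N_KV=115000 S_PP=225.07 S_TG=40.43 T=1444.306 S=79.62
```

Feeding in the timings from the table above reproduces its numbers and makes it clear that the 79.62 t/s headline counts prompt and generated tokens together, while generation alone ran at 40.43 t/s summed across all 23 sequences.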

LLaMA.cpp: DeepSeek-R1-Distill-Llama-70B-B16

Given that DeepSeek-R1-Distill-Llama-70B-B16 requires more VRAM for the model weights, I needed to reduce the context size to 95000 tokens to keep the model in VRAM. Using the previous formulas, I found an optimal value for -npl of \(17\):

llama-batched-bench \
  --model /models/DeepSeek-R1-Distill-Llama-70B-B16.gguf \
  -npp 3000 \
  -ntg 2000 \
  -npl 17 \
  --batch-size 13000 \
  --ubatch-size 1024 \
  --ctx-size 95000 \
  --flash-attn \
  --n-gpu-layers 82 \
  --split-mode layer \
  --mlock \
  --no-mmap \
  --temp 0.6 \
  --min-p 0.05 \
  --threads 32 \
  --threads-batch 32 \
  --seed 2738 \
  --output-format md

As shown in the table below, I was able to achieve about 26 tokens per second overall, which is expected given the larger BF16 weights and the smaller context (17 rather than 23 parallel requests).

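| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
| ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: |
| 3000 | 2000 | 17 | 85000 | 256.797 | 198.60 | 2953.181 | 11.51 | 3209.978 | 26.48 |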

vLLM: DeepSeek-R1-Distill-Llama-70B

For comparison, I decided to test DeepSeek-R1-Distill-Llama-70B in vLLM. Overall, vLLM was about 29% faster than LLaMA.cpp in my test, with 14.8 S_TG t/s vs 11.51 S_TG t/s from LLaMA.cpp (put the other way around, LLaMA.cpp was about 22% slower). Watching GPU utilization, I saw the typical full GPU utilization from vLLM vs the “spiky” utilization I observed with LLaMA.cpp.

docker run -it -d \
    --name=eleanor-vLLM \
    --restart=unless-stopped \
    --shm-size=15g \
    --ulimit memlock=-1 \
    --ipc=host \
    --entrypoint=python3 \
    --gpus="device=0,1,2,3" \
    --publish=7800:8000 \
    --volume=/models:/models:ro \
    --health-cmd="timeout 5 bash -c 'cat < /dev/null > /dev/tcp/localhost/8000'" \
    --health-start-period=240s \
    --health-interval=15s \
    --health-timeout=8s \
    --health-retries=3 \
    --env=OMP_NUM_THREADS=1 \
    harbor.k8s.wm.k8slab/eleanor-ai/vllm-openai:0.6.6.post1 \
        -m vllm.entrypoints.openai.api_server \
        --model /models/DeepSeek-R1-Distill-Llama-70B \
        --served-model-name DeepSeek-R1-Distill-Llama-70B \
        --response-role auto \
        --load-format safetensors \
        --tokenizer-mode auto \
        --enable-chunked-prefill=True \
        --max-num-batched-tokens=4096 \
        --dtype bfloat16 \
        --kv-cache-dtype auto \
        --gpu-memory-utilization 0.90 \
        --enable-auto-tool-choice \
        --tool-call-parser llama3_json \
        --enable-prefix-caching \
        --device=cuda \
        --task=generate \
        --scheduler-delay-factor=0.25 \
        --uvicorn-log-level=debug \
        --distributed-executor-backend=mp \
        --guided-decoding-backend=outlines \
        --disable_custom_all_reduce \
        --max-model-len 60000 \
        --tensor-parallel-size 4 \
        --port 8000 \
        --host 0.0.0.0
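Percentage comparisons like the S_TG one above depend on which result is treated as the baseline, and the two directions give different numbers. A minimal awk sketch (the function name is my own) showing both directions for 14.8 vs 11.51 S_TG t/s:

```bash
# pct_change NEW BASELINE -> percentage change of NEW relative to BASELINE
pct_change() {
  awk -v new="$1" -v base="$2" 'BEGIN { printf "%.1f\n", (new / base - 1) * 100 }'
}

pct_change 14.8 11.51   # vLLM relative to LLaMA.cpp -> 28.6 (vLLM ~29% faster)
pct_change 11.51 14.8   # LLaMA.cpp relative to vLLM -> -22.2 (LLaMA.cpp ~22% slower)
```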

Unfortunately, vLLM does not yet support the deepseek2 architecture for GGUF models. Digging a little deeper, vLLM relies on the GGUF support in transformers, which does not cover this architecture yet. I was able to find a branch and a draft PR adding support for the DeepSeekV2 architecture, but as of 2025-02-05 it has not been merged.

For a comparative study of LLaMA.cpp and other inferencing servers, take a look at Benchmarking LLaMA.cpp vs other inferencing servers.

Future Work & Final Thoughts

The benchmark here is just a starting point and only focuses on batch inferencing speed for my use case. The dimension I haven’t tested yet is accuracy and reasoning capability with respect to my use case, which is a significant factor in choosing between DeepSeek-R1-Distill-Llama-70B-B16 in vLLM (fast) and DeepSeek-R1-UD-IQ1_S in LLaMA.cpp (slow). I’ve been thinking about the best way to pull this off and came up with the following:

  • I already have a good amount of rendered prompt templates that are representative of my use case.
  • I can generate responses to these rendered prompts and calculate the average response perplexity to quantitatively measure the model’s confidence in its responses.
  • Manually verify the top responses (with the lowest perplexity) to ensure they meet my standard.
  • Automate as a custom task for EleutherAI/lm-evaluation-harness. It looks like the wikitext task is doing something similar.
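The perplexity step above boils down to exponentiating the negative mean log-probability of the generated tokens. A minimal sketch, assuming you have already extracted per-token natural-log probabilities into one value per line (for example from the logprobs an OpenAI-compatible server such as vLLM can return on request); that extraction step is not shown, and the function name is hypothetical.

```bash
# perplexity: read one natural-log token probability per line on stdin and
# print exp(-mean(logprob)); exits non-zero on empty input.
perplexity() {
  awk '{ sum += $1; n++ } END { if (n == 0) exit 1; printf "%.4f\n", exp(-sum / n) }'
}

# Two tokens, each with probability 0.5 -> perplexity 2
printf '%s\n' -0.6931471805599453 -0.6931471805599453 | perplexity   # -> 2.0000
```

Lower values mean the model assigned higher probability to its own output, which is the ranking used to pick responses for manual review.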

Additionally, when benchmarking, I found that my --ubatch-size calculation still led to running out of VRAM. I should revisit this at some point and see if I can get a better estimate for the optimal value.

Lastly, I am tempted to fork transformers and merge the deepseek2 architecture support into a version that would work with vLLM so I can do an “apples-to-apples” benchmark of vLLM and LLaMA.cpp on DeepSeek-R1-UD-IQ1_S.

Copyright © 2025, The Objective Dad