Mascot image.
#bico #cuda #gpu #optimisation #hpc #research

From Unmarked Thesis to a Speedup: A Practical Guide to BICO

I originally cooked up this blog to introduce my MSc thesis framework, BICO, and how it wrangles the insane trade-offs of training LLMs. However, since it seems my thesis is still moonlighting as a very expensive coaster on a professor’s desk, let’s put it to a different test.

This whole situation actually raises the classic question that divides theoretical and applied math: Can a purely abstract framework solve real-world problems?

Well, I pointed BICO at some GPU kernels, and it turns out the math works (sometimes).

(Disclaimer: My thesis has not been marked yet. So, this post will be a mathematical exploration, not a detailed research exposition. You get the idea, not the full story.)

(Maybe the full story later, lol, if I’m not lazy.)

What is a Budget-Indexed Closure Operator? (Do not Tell My Professor Edition)

BICO provides a language to talk about what one can guarantee when one spends resources. It is built on a few simple yet powerful mathematical objects.

The Core Mathematical Objects

  1. The Outcome Space (X, \leq_X): A partially ordered set (a poset) of all possible things that can happen. An “outcome” x \in X could be a final numerical result, a generated sentence, or a specific implementation of a piece of code. The relation x \leq_X y means outcome x is “at most as good as” outcome y.

  2. The Budget Scale (\Lambda, \leq): A complete lattice representing the resources you can spend. A budget \lambda \in \Lambda could be training time, compute FLOPs, power draw, or available on-chip memory.

  3. The Operator Family \{K_\lambda\}_{\lambda \in \Lambda}: For each budget \lambda there is an operator K_\lambda : \mathcal{P}(X) \to \mathcal{P}(X). It takes a set of initial outcomes A and returns the “closure” of that set.

This family must obey a few axioms, but the most important one is Scott-Continuity in \lambda (Axiom A4). It formally states that for any directed set of budgets D \subseteq \Lambda, the guarantees you get at the limit are the union of the guarantees you got along the way.

K_{\sup D}(A) = \bigcup_{\lambda \in D} K_\lambda(A)

More budget cannot yield worse results, and improvements happen continuously. This axiom gives us a formal basis for concepts like “diminishing returns” and “plateaus”. It ensures our exploration of the resource-performance space is predictable and well-behaved.
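
To make this concrete, here is a toy sketch in Python. Everything in it is invented for illustration (a five-element outcome chain and a made-up “reachable” rule); it is not the thesis construction, only a demonstration of the monotone-union behaviour Axiom A4 demands.

# Toy sketch of a budget-indexed operator family (illustrative only).
# Outcomes and budgets are plain integers ordered by the usual <=.

X = {0, 1, 2, 3, 4}  # a tiny outcome space

def K(budget, A):
    """K_budget(A): A plus everything 'reachable' at this budget.
    The rule 'x is reachable iff x <= budget' is made up for the demo."""
    return set(A) | {x for x in X if x <= budget}

A = {0}
directed_budgets = [1, 2, 3]  # a directed set of budgets; its sup is 3

union_along_the_way = set()
for lam in directed_budgets:
    union_along_the_way |= K(lam, A)

# Axiom A4 (Scott-continuity): the guarantee at the sup equals the union of
# the guarantees collected along the way.
assert K(max(directed_budgets), A) == union_along_the_way
print(sorted(K(3, A)))  # [0, 1, 2, 3]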

The Canonical Construction: Making it Real

To apply this, we need a concrete way to build a K_\lambda operator. We start with a probabilistic view.

Let \Pr : \Lambda \times X \to [0, 1] give the probability \Pr(\lambda, x) of achieving outcome x with budget \lambda. For a chosen reliability threshold p \in (0,1) we define the p-Guaranteed Region as:

\mathcal{G}_{\lambda, p}^{+} \coloneqq \downarrow\{\,x \in X \mid \Pr(\lambda,x) > p\,\}

  • \{\,x \in X \mid \Pr(\lambda,x) > p\,\}: The set of all outcomes whose success probability is strictly greater than our threshold p. The strict > is mathematically crucial; it prevents weird edge cases at the limit and ensures the Scott-continuity axiom holds.
  • \downarrow: The downward closure. This means that if we can guarantee a great outcome, we automatically guarantee all the “worse-or-equal” outcomes.

From this, we define the Canonical Operator:

K_{\lambda,p}(A) \coloneqq A \cup \mathcal{G}_{\lambda, p}^{+}

It takes an input set A and adds all the outcomes now guaranteed with budget \lambda and confidence p. This K operator bridges abstract algebra to measurable reality.
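
A minimal sketch of the canonical construction, assuming a hypothetical probability table in place of a real \Pr(\lambda, x) and a three-element chain as the outcome order:

# Sketch of the canonical operator K_{lambda, p}; the probability table is invented.
PR = {  # PR[(budget, outcome)] = probability of achieving that outcome at that budget
    (10, 0): 0.99, (10, 1): 0.70, (10, 2): 0.20,
    (20, 0): 0.99, (20, 1): 0.95, (20, 2): 0.60,
}
X = [0, 1, 2]  # outcomes, totally ordered by <= for simplicity

def down_closure(S):
    """Everything at most as good as some member of S (the 'worse-or-equal' outcomes)."""
    return {x for x in X if any(x <= y for y in S)}

def guaranteed_region(budget, p):
    """G+_{budget, p}: downward closure of the outcomes with success probability > p."""
    return down_closure({x for x in X if PR.get((budget, x), 0.0) > p})

def K(budget, p, A):
    """Canonical operator: K_{budget, p}(A) = A union G+_{budget, p}."""
    return set(A) | guaranteed_region(budget, p)

print(K(10, 0.5, set()))  # {0, 1}    -> only the easy outcomes are guaranteed
print(K(20, 0.5, set()))  # {0, 1, 2} -> a bigger budget guarantees strictly more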

First Steps: Kicking the Tyres on a CPU

Let us start with a simple CPU-based matrix multiplication (GEMM) benchmark written in Python. The goal was to see if BICO could find the best tile configurations under a memory budget.

  • Outcome Space X: The set of tile configurations (tM, tN, tK).
  • Budget Scale \Lambda: The on-chip memory footprint in bytes required by a tile:

\lambda = 4 \times (tM \times tK + tK \times tN + tM \times tN).

  • The Operator K_\lambda: A process of finding all configs that fit within budget \lambda and pass a performance test (e.g. latency < T); a minimal sketch of this construction follows.
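
Here is that sketch, assuming made-up tile sizes and a stubbed latency test in place of the real numpy-blocked benchmark:

# Sketch of the CPU experiment: count tile configs whose footprint fits each budget.
# The passes_latency_test stub stands in for timing a blocked numpy GEMM.
from itertools import product

def footprint_bytes(tm, tn, tk):
    # FP32 tiles of A (tm x tk), B (tk x tn) and C (tm x tn)
    return 4 * (tm * tk + tk * tn + tm * tn)

def passes_latency_test(tm, tn, tk):
    return True  # placeholder: the real script benchmarks the blocked kernel here

tiles = [8, 16, 32, 64]  # assumed tile sizes, not the ones from the actual demo
configs = list(product(tiles, repeat=3))
budgets = range(8192, 49152 + 1, 8192)

prev = -1
for lam in budgets:
    K = [c for c in configs if footprint_bytes(*c) <= lam and passes_latency_test(*c)]
    print(f"budget={lam:6d}  |K|={len(K)}")
    if len(K) == prev:
        print(f"Plateau detected at lambda={lam}")
        break
    prev = len(K)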

A typical result from my test script:

BICO CPU demo v2 (kernel: numpy-blocked). Budgets: [8192; 16384; ...; 49152].
Budget (bytes) | Number of valid configurations (|K|)
8192           | 9
16384          | 33
24576          | 45
32768          | 48
40960          | 48
49152          | 48

Plateau detected: |K| stopped increasing by λ=40960.

The size of the guaranteed set |K| grows as we increase the memory budget \lambda, exactly as the theory predicts. Then, after 32 KB, we stop getting better results. The framework automatically detected a plateau. Beyond a 32 KB tile footprint, memory size is no longer the bottleneck for this problem.

The Main Event

The CPU test was a reassuring sanity check: BICO hit its performance plateau exactly where the theory says it should. But the real test is the GPU: memory-bound, mysterious, and full of sweet spots that are often absurdly small.

For this blog, I targeted the Gated Linear Unit (GLU) (inspired by Joe Fioti). There are multiple ways to optimise it, including bypassing the crippling latency of moving data to and from DRAM. But it all starts with the General Matrix Multiply (GEMM). Early kernel design immediately highlighted the fundamental bottleneck: regardless of other optimisations, the kernel’s performance would be governed by the efficiency of its core GEMM operation.

Round One

My initial modelling of (\Lambda, \le) used direct hardware resources like shared memory (SMEM). It failed: results were messy and non-monotonic.

The key problem was interdependence. More SMEM for the data tiles meant fewer concurrent thread blocks (occupancy), leading to a Pareto frontier of incomparable trade-offs, not a simple ordered set where one budget is unambiguously larger. The axioms of BICO assume a clean lattice; my chosen budget scale was fundamentally chaotic. The answer, then, was to move from static kernel states to the dynamics of the kernel search.
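
To see the incomparability, take two hypothetical kernel states and order them componentwise on (SMEM used, resident blocks); the numbers below are invented, but the structural problem is real:

# Two hypothetical kernel states under the componentwise order on their resources.
a = {"smem_bytes": 32 * 1024, "active_blocks": 2}  # big tiles, low occupancy
b = {"smem_bytes": 8 * 1024,  "active_blocks": 6}  # small tiles, high occupancy

def leq(p, q):
    """Componentwise order: p <= q iff p is no larger than q in every resource."""
    return p["smem_bytes"] <= q["smem_bytes"] and p["active_blocks"] <= q["active_blocks"]

print(leq(a, b), leq(b, a))  # False False -> neither dominates: a Pareto-style trade-off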

Round Two: A BICO Framework for the Search Itself

This new approach required a more abstract formalisation.

  • The Exploration Space (\mathcal{C}): The messy, unordered set of all 134 valid, compilable kernel configurations for our problem. A point c \in \mathcal{C} is a tuple of the form (TILE_M, TILE_N, TILE_K, BLOCK_DIM_X, BLOCK_DIM_Y). It is our haystack.
  • The Budget Scale (\Lambda, \leq): The number of configurations we have the resources to compile and benchmark on the GPU, \Lambda = \mathbb{N}_0 = \{0, 1, 2, \dots\}. A budget of n = 20 means we can afford to test 20 needles. This is a proper, well-behaved ordered set.
  • The Outcome Space (X, \leq_X): The set of best-achieved kernel latencies. The order is L_1 \leq_X L_2 \iff L_1 \ge L_2: a better outcome is a faster one.
  • The Information Sink (S_n): The set of all tested configurations that are not the current champion. The size of the sink grows monotonically with the budget, representing the necessary cost of acquiring knowledge.

This new model enables a principled process of discovery, which can be structured into a two-phase methodology.

Phase 1: The Predictive (Cost: Milliseconds). Here, a “dry” exploration is performed. I defined a simple analytical heuristic, the Reuse Score H(c), to estimate a kernel’s theoretical efficiency based on its tiling parameters.

H(c) = \frac{\text{Operations per Tile}}{\text{Bytes Loaded per Tile}} = \frac{\text{TM} \cdot \text{TN}}{2 \cdot (\text{TM} + \text{TN})}

This model is intentionally naive and blind to tile depth (TK), block dimensions, and all the other weird realities of a GPU architecture. I evaluated this score for all 134 configurations, which took milliseconds, and produced a ranked list. This list serves as an imperfect but intelligent guide for the empirical search.
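
As a quick worked example (plain arithmetic, same formula as above), here is how a few of the tile shapes from the search space rank:

# Reuse Score H(c) = (TM * TN) / (2 * (TM + TN)); note that TK cancels out entirely.
def reuse_score(tm, tn):
    return (tm * tn) / (2.0 * (tm + tn))

for tm, tn in [(32, 32), (16, 64), (64, 16), (16, 32)]:
    print(f"TM={tm:3d} TN={tn:3d}  H = {reuse_score(tm, tn):.2f}")
# TM= 32 TN= 32  H = 8.00  <- the heuristic's top shape, tested first
# TM= 16 TN= 64  H = 6.40
# TM= 64 TN= 16  H = 6.40
# TM= 16 TN= 32  H = 5.33  <- the shape of the eventual empirical winner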

Phase 2: The Empirical (Cost: Real GPU Time). This is the main event. The testbed is a single NVIDIA RTX 4090 (Ada Lovelace architecture, sm_89), running a large FP32 GEMM of size (M=1024, K=4096, N=12288). Two explorers were deployed, each with a fixed budget of n=20 evaluations (for fairness):

  1. BICO-Guided Explorer: Systematically evaluates the top 20 configurations from the predictively ranked list.
  2. Random Explorer (Baseline): Evaluates 20 configurations drawn randomly from the same 134-candidate space.
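
In miniature, both explorers run the same budget-constrained loop; here is a sketch (the evaluate argument stands in for the real compile-and-benchmark step, which the C++ code at the end of the post performs on the GPU):

# Sketch of the Phase 2 loop: spend the budget, track the champion, grow the sink.
def explore(configs, budget, evaluate):
    best_latency, best_config, sink = float("inf"), None, []
    for config in configs[:budget]:
        latency = evaluate(config)        # real version: launch and time the kernel
        if latency < best_latency:
            if best_config is not None:
                sink.append(best_config)  # the dethroned champion joins the sink
            best_latency, best_config = latency, config
        else:
            sink.append(config)           # tested but not the champion: the cost of knowledge
    return best_config, best_latency, sink

# Usage sketch: the guided explorer gets the heuristic's ranking, the baseline a shuffle.
# best, latency, sink = explore(ranked_by_reuse_score, 20, benchmark_on_gpu)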

Results

Here are the complete logs from both explorers. In both tables the Best Latency column only ever decreases and the Sink Size column only ever grows, exactly the monotone bookkeeping the BICO framework prescribes. The difference in the quality of the outcomes is the primary result.

BICO-Guided Explorer Log. The search starts by testing the heuristic’s top picks. It plateaus, but then breaks through twice to find a much better solution.

Budget(n) | Latency (ms) | TFLOP/s | Best Latency (ms) | Sink Size | Configuration Tested
1  | 20.2604 | 5.09 | 20.2604 | 0  | TM=32, TN=32, TK=64, BX=32, BY=32
2  | 21.1290 | 4.88 | 20.2604 | 1  | TM=32, TN=32, TK=32, BX=32, BY=32
3  | 28.6786 | 3.59 | 20.2604 | 2  | TM=32, TN=32, TK=16, BX=32, BY=32
4  | 43.4086 | 2.37 | 20.2604 | 3  | TM=32, TN=32, TK=8, BX=32, BY=32
5  | 62.6052 | 1.65 | 20.2604 | 4  | TM=32, TN=32, TK=4, BX=32, BY=32
6  | 21.1246 | 4.88 | 20.2604 | 5  | TM=16, TN=64, TK=64, BX=64, BY=16
7  | 24.7587 | 4.16 | 20.2604 | 6  | TM=16, TN=64, TK=32, BX=64, BY=16
8  | 29.7114 | 3.47 | 20.2604 | 7  | TM=16, TN=64, TK=16, BX=64, BY=16
9  | 44.5528 | 2.31 | 20.2604 | 8  | TM=16, TN=64, TK=8, BX=64, BY=16
10 | 64.2652 | 1.60 | 20.2604 | 9  | TM=16, TN=64, TK=4, BX=64, BY=16
11 | 40.8650 | 2.52 | 20.2604 | 10 | TM=64, TN=16, TK=8, BX=16, BY=64
12 | 27.6688 | 3.73 | 20.2604 | 11 | TM=64, TN=16, TK=16, BX=16, BY=64
13 | 22.0818 | 4.67 | 20.2604 | 12 | TM=64, TN=16, TK=32, BX=16, BY=64
14 | 17.8671 | 5.77 | 17.8671 | 13 | TM=64, TN=16, TK=64, BX=16, BY=64
15 | 61.9322 | 1.66 | 17.8671 | 14 | TM=64, TN=16, TK=4, BX=16, BY=64
16 | 16.6037 | 6.21 | 16.6037 | 15 | TM=16, TN=32, TK=64, BX=32, BY=16
17 | 18.1453 | 5.68 | 16.6037 | 16 | TM=16, TN=32, TK=32, BX=32, BY=16
18 | 19.1458 | 5.38 | 16.6037 | 17 | TM=16, TN=32, TK=16, BX=32, BY=16
19 | 26.2551 | 3.93 | 16.6037 | 18 | TM=16, TN=32, TK=8, BX=32, BY=16
20 | 39.3225 | 2.62 | 16.6037 | 19 | TM=16, TN=32, TK=4, BX=32, BY=16

Random (Baseline) Explorer Log. The random search gets lucky early, finding a mediocre local optimum. It then spends the next 17 steps completely lost, failing to find any improvement.

Budget(n) | Latency (ms) | TFLOP/s | Best Latency (ms) | Sink Size | Configuration Tested
1  | 32.9312  | 3.13 | 32.9312 | 0  | TM=4, TN=128, TK=8, BX=128, BY=4
2  | 25.1416  | 4.10 | 25.1416 | 1  | TM=128, TN=4, TK=64, BX=4, BY=128
3  | 23.0727  | 4.47 | 23.0727 | 2  | TM=128, TN=4, TK=16, BX=4, BY=128
4  | 137.5320 | 0.75 | 23.0727 | 3  | TM=4, TN=4, TK=4, BX=4, BY=4
...| ...      | ...  | ...     | ...| ...
20 | 47.0917  | 2.19 | 23.0727 | 19 | TM=8, TN=4, TK=16, BX=4, BY=8

Summary of Findings

Metric             | Sequential Explorer | Random Explorer   | BICO-Guided Explorer
Best Latency Found | 65.10 ms            | 23.07 ms          | 16.60 ms
Best Performance   | 1.58 TFLOP/s        | 4.47 TFLOP/s      | 6.21 TFLOP/s
Relative Speedup   | 1.0x (baseline)     | 2.8x vs Sequential| 3.9x vs Sequential, 1.39x vs Random

Discussion

Is 6.21 TFLOP/s Actually Any Good?

To ground these results, the winning kernel’s performance was compared to established benchmarks.
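
First, a quick arithmetic sanity check that the headline figure follows from the measured latency, using the same operation count as the explorer’s logging code:

# TFLOP/s = (2 * M * N * K floating-point operations) / (latency in seconds) / 1e12
M, K, N = 1024, 4096, 12288
best_latency_ms = 16.6037
tflops = (2 * M * N * K) / (best_latency_ms * 1e-3) / 1e12
print(f"{tflops:.2f} TFLOP/s")  # ~6.21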

NVIDIA’s hand-tuned cuBLAS library achieves a staggering 25-30 TFLOP/s on this hardware for large matrices. However, it achieves this by leveraging specialised hardware units (Tensor Cores) with lower-precision formats like FP16 or BF16. My exploration space, by design, contained only pure CUDA kernels using standard FP32 arithmetic on the general-purpose CUDA cores. This is a fundamentally different and more constrained problem.

The correct peer group for comparison, therefore, is other academic and reference implementations of scalar FP32 tiled GEMM kernels. When compared within this proper context, the performance is right where it should be.

  • Public benchmarks and examples from libraries like CUTLASS show that basic, non-expert-tuned FP32 tiled kernels on this architecture typically achieve 4-7 TFLOP/s.
  • Research from auto-tuning frameworks like TVM shows similar performance for their baseline tiled kernels, before their own search algorithms are applied.

The BICO-found result of 6.21 TFLOP/s sits comfortably and competitively within this range. It is important to stress that this performance was achieved with a fundamentally basic setup. The kernel under test, matrix_multiply_kernel, is a straightforward tiled implementation with coalesced loads but no further optimisations like prefetching, software pipelining, or register tiling. Likewise, the predictive Reuse Score heuristic is a crude, first-order approximation of performance, completely blind to microarchitectural details.

This simplicity is the point. The BICO-guided search managed to find a verifiably competitive configuration without needing a highly optimised kernel or a sophisticated performance model. This demonstrates that the value is derived from the two-phase search methodology itself.

Where BICO Fits In (And Why It’s Not Just a Worse Auto-Tuner)

The point of this isn’t to build a better cuBLAS. The value of the BICO framework is distinct.

  • The Predictive Phase: The defining feature demonstrated here is the predictive phase. Unlike a generic auto-tuner that starts blind, BICO provides a formal mechanism to incorporate cheap, theoretical domain knowledge before the expensive search begins. The experiment showed that even a naive heuristic can provide enough guidance to avoid poor local optima.
  • Transparency: BICO is a “glass box”. The ranked list from the predictive phase and the monotonic logs from the empirical phase provide a complete, auditable trail of the search. One sees not just the result, but the reasoning and the cost.
  • Adaptability for Novel Kernels: BICO’s primary strength lies in optimising user-defined, novel kernels (like a fully fused GLU) where no hand-tuned library exists. For these complex, bespoke problems, an intelligent, guided search is invaluable.

What’s Next? Back to the GLU and Beyond

This entire exercise was a deep dive to solve the hardest part of the original problem. The discovered optimal GEMM configuration of (TM=16, TN=32, TK=64, BX=32, BY=16) is now the empirically-proven blueprint for building the final, high-performance fused GLU kernel.

The framework itself can also be extended. By expanding the Exploration Space to include fundamentally different kernel types (e.g., those using Tensor Cores for FP16 math) and updating the predictive heuristic, BICO could search across hardware capabilities to find solutions well beyond 10 TFLOP/s. Whether we can actually harness that power… you’ll have to find out on the next episode.

Conclusion

The initial question was whether an abstract framework could solve real-world problems. This investigation provides evidence that it can. The BICO framework was used to instantiate a concrete, two-phase auto-tuning programme. By separating a cheap predictive search from a budget-constrained empirical search, the methodology discovered a GEMM kernel configuration on an NVIDIA RTX 4090 that was 39% faster than an unguided random search and nearly 4x faster than a simple sequential search. The results are consistent with established performance benchmarks and validate this approach as a transparent and principled tool for navigating complex optimisation landscapes. It seems the expensive coaster on my professor’s desk might just be worth something after all.


Code

generate_dispatcher.py

import itertools


def generate_search_space():
    space = set()
    tile_m_opts = [4, 8, 16, 32, 64, 128]
    tile_n_opts = [4, 8, 16, 32, 64, 128, 256]
    tile_k_opts = [4, 8, 16, 32, 64]
    MAX_THREADS_PER_BLOCK = 1024
    MAX_SHARED_MEM_BYTES = 48 * 1024
    for tm, tn, tk in itertools.product(tile_m_opts, tile_n_opts, tile_k_opts):
        # One thread per output element: block dims mirror the output tile shape.
        bx = tn
        by = tm
        if bx * by > MAX_THREADS_PER_BLOCK:
            continue
        # FP32 shared-memory footprint of the A (tm x tk) and B (tk x tn) tiles.
        shared_mem_needed = (tm * tk + tk * tn) * 4
        if shared_mem_needed > MAX_SHARED_MEM_BYTES:
            continue
        space.add((tm, tn, tk, bx, by))
    return sorted(list(space))


CONFIGURATIONS = generate_search_space()


def write_configs_to_file():
    with open("configurations.txt", "w") as f:
        for config in CONFIGURATIONS:
            f.write(f"{config[0]} {config[1]} {config[2]} {config[3]} {config[4]}\n")


def generate_header():
    header_content = """
#pragma once
#include "KernelConfig.h"
#include "kernel.cuh"
#include <iostream>
// THIS FILE IS AUTO-GENERATED BY generate_dispatcher.py
// DO NOT EDIT MANUALLY
#define CUDA_CHECK(call) \\
  do { \\
    cudaError_t err = call; \\
    if (err != cudaSuccess) { \\
      fprintf(stderr, "CUDA Error at %s:%d: %s\\n", __FILE__, __LINE__, \\
              cudaGetErrorString(err)); \\
      exit(EXIT_FAILURE); \\
    } \\
  } while (0)
void launch_kernel_with_config(const KernelConfig &config, float *d_C,
                              const float *d_A, const float *d_B, int M, int N,
                              int K) {
    dim3 blockDim(config.BLOCK_DIM_X, config.BLOCK_DIM_Y);
    dim3 gridDim((N + config.TILE_N - 1) / config.TILE_N,
                 (M + config.TILE_M - 1) / config.TILE_M);
    size_t shared_size = (static_cast<size_t>(config.TILE_M) * config.TILE_K +
                          static_cast<size_t>(config.TILE_K) * config.TILE_N) *
                         sizeof(float);
"""
    first = True
    for tm, tn, tk, bx, by in CONFIGURATIONS:
        condition = f"config.TILE_M == {tm} && config.TILE_N == {tn} && config.TILE_K == {tk} && config.BLOCK_DIM_X == {bx} && config.BLOCK_DIM_Y == {by}"
        launch = f"    matrix_multiply_kernel<{tm}, {tn}, {tk}, {bx}, {by}><<<gridDim, blockDim, shared_size>>>(d_C, d_A, d_B, M, N, K);"
        if first:
            header_content += f"    if ({condition}) {{\n{launch}\n    }}"
            first = False
        else:
            header_content += f"    else if ({condition}) {{\n{launch}\n    }}"
    header_content += """
    else {
        std::cerr << "FATAL: Unsupported kernel configuration: " << config.toString() << std::endl;
        exit(1);
    }
    CUDA_CHECK(cudaGetLastError());
}
"""
    with open("Evaluator_dispatcher.h", "w") as f:
        f.write(header_content)


if __name__ == "__main__":
    print(f"Generated a search space of {len(CONFIGURATIONS)} configurations.")
    write_configs_to_file()
    generate_header()
    print("Generated configurations.txt and Evaluator_dispatcher.h successfully.")

KernelConfig.h

#pragma once
#include <sstream>
#include <string>

struct KernelConfig {
  int TILE_M      = 0;
  int TILE_N      = 0;
  int TILE_K      = 0;
  int BLOCK_DIM_X = 0;
  int BLOCK_DIM_Y = 0;

  std::string toString() const {
    std::stringstream ss;
    ss << "TM=" << TILE_M << ", TN=" << TILE_N << ", TK=" << TILE_K
       << ", BX=" << BLOCK_DIM_X << ", BY=" << BLOCK_DIM_Y;
    return ss.str();
  }

  bool isValid() const {
    return TILE_M > 0 && TILE_N > 0 && TILE_K > 0 && BLOCK_DIM_X > 0 &&
           BLOCK_DIM_Y > 0 && BLOCK_DIM_X * BLOCK_DIM_Y <= 1024;
  }
};

Heuristic.h

#pragma once
#include "KernelConfig.h"
#include <algorithm>
#include <vector>

double calculate_reuse_score(const KernelConfig &c) {
  if (c.TILE_M <= 0 || c.TILE_N <= 0 || c.TILE_K <= 0 || c.BLOCK_DIM_X <= 0 ||
      c.BLOCK_DIM_Y <= 0) 
    return 0.0;
  
  double ops   = 2.0 * c.TILE_M * c.TILE_N * c.TILE_K;
  double bytes = 4.0 * (c.TILE_M * c.TILE_K + c.TILE_K * c.TILE_N);
  return bytes == 0.0 ? 0.0 : ops / bytes;
}

std::vector<KernelConfig> predictive_search(std::vector<KernelConfig> space) {
  std::sort(space.begin(), space.end(),
            [](const KernelConfig &a, const KernelConfig &b) {
              return calculate_reuse_score(a) > calculate_reuse_score(b);
            });
  return space;
}

kernel.cuh

#pragma once
#include <cuda_runtime.h>

template <int TILE_M, int TILE_N, int TILE_K, int BLOCK_DIM_X, int BLOCK_DIM_Y>
__global__ void matrix_multiply_kernel(float *C, const float *A, const float *B,
                                       int M, int N, int K) {
  static_assert(BLOCK_DIM_X >= TILE_N, "BLOCK_DIM_X must be at least TILE_N");
  static_assert(BLOCK_DIM_Y >= TILE_M, "BLOCK_DIM_Y must be at least TILE_M");
  // Dynamic shared memory holds the A tile (TILE_M x TILE_K) followed by the B tile (TILE_K x TILE_N).
  extern __shared__ float smem[];
  float *As             = smem;
  float *Bs             = smem + TILE_M * TILE_K;
  const int thread_idx  = threadIdx.y * blockDim.x + threadIdx.x;
  const int block_row   = blockIdx.y;
  const int block_col   = blockIdx.x;
  const int num_threads = BLOCK_DIM_X * BLOCK_DIM_Y;
  // Each thread owns exactly one element of the output tile.
  const int c_row       = block_row * TILE_M + threadIdx.y;
  const int c_col       = block_col * TILE_N + threadIdx.x;
  float accumulator     = 0.0f;

  // March along the K dimension one TILE_K slab at a time.
  for (int k_tile_idx = 0; k_tile_idx < K; k_tile_idx += TILE_K) {
    // Cooperative, coalesced load of the A tile (zero-padded at the edges).
    for (int i = thread_idx; i < TILE_M * TILE_K; i += num_threads) {
      const int load_row = i / TILE_K;
      const int load_col = i % TILE_K;
      const int gmem_row = block_row * TILE_M + load_row;
      const int gmem_col = k_tile_idx + load_col;
      if (gmem_row < M && gmem_col < K)
        As[load_row * TILE_K + load_col] = A[gmem_row * K + gmem_col];
      else As[load_row * TILE_K + load_col] = 0.0f;
    }
    // Cooperative, coalesced load of the B tile (zero-padded at the edges).
    for (int i = thread_idx; i < TILE_K * TILE_N; i += num_threads) {
      const int load_row = i / TILE_N;
      const int load_col = i % TILE_N;
      const int gmem_row = k_tile_idx + load_row;
      const int gmem_col = block_col * TILE_N + load_col;

      if (gmem_row < K && gmem_col < N)
        Bs[load_row * TILE_N + load_col] = B[gmem_row * N + gmem_col];
      else Bs[load_row * TILE_N + load_col] = 0.0f;
    }
    __syncthreads();
    // Inner product over the slab, read entirely from shared memory.
    for (int k = 0; k < TILE_K; ++k)
      accumulator += As[threadIdx.y * TILE_K + k] * Bs[k * TILE_N + threadIdx.x];
    __syncthreads();
  }
  if (c_row < M && c_col < N) C[c_row * N + c_col] = accumulator;
}

Evaluator.h

#pragma once
#include "KernelConfig.h"
#include <cuda_runtime.h>
class Evaluator {
public:
  Evaluator(float *d_A, const float *d_B, float *d_C, int M, int N, int K);
  double evaluate(const KernelConfig &config);

private:
  const float *d_B;
  float *d_A, *d_C;
  int M, N, K;
};

Evaluator.cu

#include "Evaluator.h"
#include "Evaluator_dispatcher.h"
#include <iostream>

Evaluator::Evaluator(float *d_A_, const float *d_B_, float *d_C_, int M_,
                     int N_, int K_)
    : d_A(d_A_), d_B(d_B_), d_C(d_C_), M(M_), N(N_), K(K_) {}

double Evaluator::evaluate(const KernelConfig &config) {
  cudaEvent_t start, stop;
  CUDA_CHECK(cudaEventCreate(&start));
  CUDA_CHECK(cudaEventCreate(&stop));
  // Warm-up launch so the timed runs do not pay one-off start-up costs.
  launch_kernel_with_config(config, d_C, d_A, d_B, M, N, K);
  CUDA_CHECK(cudaDeviceSynchronize());
  CUDA_CHECK(cudaEventRecord(start));
  const int num_runs = 100;
  float milliseconds = 0;
  // Time num_runs back-to-back launches between the two events.
  for (int i = 0; i < num_runs; ++i)
    launch_kernel_with_config(config, d_C, d_A, d_B, M, N, K);

  CUDA_CHECK(cudaEventRecord(stop));
  CUDA_CHECK(cudaEventSynchronize(stop));
  CUDA_CHECK(cudaEventElapsedTime(&milliseconds, start, stop));
  CUDA_CHECK(cudaEventDestroy(start));
  CUDA_CHECK(cudaEventDestroy(stop));
  // Average latency per launch in milliseconds.
  return static_cast<double>(milliseconds) / num_runs;
}

BICOExplorer.h

#pragma once
#include "Evaluator.h"
#include "KernelConfig.h"
#include <string>
#include <vector>
class BICOExplorer {
private:
  std::vector<KernelConfig> exploration_space_;
  double best_latency_;
  KernelConfig best_config_;
  std::vector<KernelConfig> information_sink_;
  Evaluator evaluator_;
  std::string explorer_name_;
  bool shuffle_;

public:
  BICOExplorer(std::string name, std::vector<KernelConfig> search_space,
               Evaluator evaluator, bool shuffle = true);
  void explore(int max_budget);
};

BICOExplorer.cu

#include "BICOExplorer.h"
#include <algorithm>
#include <cstdio>
#include <iomanip>
#include <iostream>
#include <limits>
#include <random>
#include <vector>

BICOExplorer::BICOExplorer(std::string name,
                           std::vector<KernelConfig> search_space,
                           Evaluator evaluator, bool shuffle)
    : exploration_space_(std::move(search_space)),
      best_latency_(std::numeric_limits<double>::max()), best_config_(),
      evaluator_(std::move(evaluator)), explorer_name_(std::move(name)),
      shuffle_(shuffle) {}

void print_progress_bar(double current_latency, double best_latency) {
  std::cout << " Perf: [";
  int bar_width = 50;
  if (current_latency < best_latency) {
    std::cout << "\033[1;32m";
    for (int i = 0; i < bar_width; ++i)
      std::cout << "*";
    std::cout << "\033[0m] NEW BEST!";
  } else {
    std::cout << "\033[1;31m";
    bar_width = static_cast<int>(50.0 * best_latency / current_latency);
    for (int i = 0; i < bar_width; ++i)
      std::cout << "|";
    for (int i = bar_width; i < 50; ++i)
      std::cout << " ";
    std::cout << "\033[0m]";
  }
  std::cout << std::endl;
}

void BICOExplorer::explore(int max_budget) {
  const long long M = 1024, N = 12288, K = 4096;
  std::cout << "Starting " << explorer_name_
            << " exploration with budget: " << max_budget << std::endl;
  std::cout << "Exploration space size: " << exploration_space_.size()
            << std::endl;
  if (shuffle_) {
    std::random_device rd;
    std::mt19937 g(rd());
    std::shuffle(exploration_space_.begin(), exploration_space_.end(), g);
  }
  int budget =
      std::min(static_cast<int>(exploration_space_.size()), max_budget);

  bool first_run = true;
  printf("\n");
  printf("|===================================================================="
         "=======================|\n");
  printf("| BICO EXPLORATION LOG (%-12s) |\n", explorer_name_.c_str());
  printf("|===================================================================="
         "=======================|\n");
  printf("| Budget(n) | Latency (ms) | TFLOP/s | Best Latency | Sink Size | "
         "Configuration Tested        |\n");
  printf("|-----------|--------------|---------|--------------|-----------|----"
         "------------------------|\n");
  for (int n = 1; n <= budget; ++n) {
    KernelConfig current_config = exploration_space_[n - 1];
    double current_latency      = evaluator_.evaluate(current_config);
    double tflops               = (2.0 * M * N * K) / (current_latency * 1e-3) / 1e12;
    bool is_new_best            = false;

    if (current_latency < best_latency_) {
      if (!first_run) information_sink_.push_back(best_config_);
      best_latency_ = current_latency;
      best_config_  = current_config;
      is_new_best   = true;
      first_run     = false;
    } else information_sink_.push_back(current_config);

    printf("| %-9d | %-12.4f | %-7.2f | %-12.4f | %-9zu | %-26s |\n", n,
           current_latency, tflops, best_latency_, information_sink_.size(),
           current_config.toString().c_str());
    if (is_new_best) {
      size_t smem =
          (static_cast<size_t>(current_config.TILE_M) * current_config.TILE_K +
           static_cast<size_t>(current_config.TILE_K) * current_config.TILE_N) *
          4;
      printf("| %-9s | ==> NEW BEST! %.2f TFLOP/s, %zu bytes SMEM. \033[0m|\n",
             "", tflops, smem);
    }
    print_progress_bar(current_latency, best_latency_);
  }
  printf("|===================================================================="
         "=======================|\n\n");
  double best_tflops = (2.0 * M * N * K) / (best_latency_ * 1e-3) / 1e12;
  std::cout << "===== " << explorer_name_ << " Exploration Finished =====\n"
            << "Optimal configuration: " << best_config_.toString() << "\n"
            << "Latency: " << best_latency_ << " ms\n"
            << "Performance: " << std::fixed << std::setprecision(2)
            << best_tflops << " TFLOP/s\n";
}

main.cu

#include "BICOExplorer.h"
#include "Evaluator.h"
#include "Heuristic.h"
#include "KernelConfig.h"
#include "kernel.cuh"
#include <cuda_runtime.h>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <random>
#include <string>
#include <vector>

#define CUDA_CHECK(call)                                                       \
  do {                                                                         \
    cudaError_t err = call;                                                    \
    if (err != cudaSuccess) {                                                  \
      fprintf(stderr, "CUDA Error at %s:%d: %s\n", __FILE__, __LINE__,         \
              cudaGetErrorString(err));                                        \
      exit(EXIT_FAILURE);                                                      \
    }                                                                          \
  } while (0)

std::vector<KernelConfig>
load_search_space_from_file(const std::string &filename) {
  std::vector<KernelConfig> space;
  std::ifstream infile(filename);
  if (!infile.is_open()) {
    std::cerr << "FATAL: Could not open configuration file: " << filename
              << std::endl;
    exit(1);
  }
  KernelConfig config;
  while (infile >> config.TILE_M >> config.TILE_N >> config.TILE_K >>
         config.BLOCK_DIM_X >> config.BLOCK_DIM_Y) {
    if (config.isValid())
      space.push_back(config);
    else
      std::cerr << "Warning: Invalid configuration skipped: "
                << config.toString() << std::endl;
  }
  std::cout << "Loaded a search space of " << space.size()
            << " valid configurations from " << filename << "." << std::endl;
  return space;
}

int main() {
  const int M = 1024;
  const int K = 4096;
  const int N = 12288;
  int budget  = 20;

  float *d_A, *d_B, *d_C;

  std::cout << "Matrix dimensions: " << M << " x " << K << " * " << K << " x "
            << N << std::endl;

  std::vector<float> h_A(M * K);
  std::vector<float> h_B(K * N);
  std::mt19937 rng(1337);
  std::uniform_real_distribution<float> dist(-1.0f, 1.0f);

  for (auto &val : h_A) val = dist(rng);
  for (auto &val : h_B) val = dist(rng);

  CUDA_CHECK(cudaMalloc(&d_A, M * K * sizeof(float)));
  CUDA_CHECK(cudaMalloc(&d_B, K * N * sizeof(float)));
  CUDA_CHECK(cudaMalloc(&d_C, M * N * sizeof(float)));
  CUDA_CHECK(cudaMemcpy(d_A, h_A.data(), M * K * sizeof(float),
                        cudaMemcpyHostToDevice));
  CUDA_CHECK(cudaMemcpy(d_B, h_B.data(), K * N * sizeof(float),
                        cudaMemcpyHostToDevice));
  std::cout << "\n===== [Phase 1: Predictive BICO Exploration] =====\n";
  std::vector<KernelConfig> full_search_space =
      load_search_space_from_file("configurations.txt");
  std::vector<KernelConfig> guided_search_space =
      predictive_search(full_search_space);
  std::cout << "Predictive model has ranked " << guided_search_space.size()
            << " configurations.\n";

  std::ofstream ranked_file("ranked_configurations.txt");
  if (ranked_file.is_open()) {
    ranked_file << "Ranked Configurations by Reuse Score (Phase 1 Predictive "
                   "BICO Exploration):\n\n";
    for (size_t i = 0; i < guided_search_space.size(); ++i) {
      double score = calculate_reuse_score(guided_search_space[i]);
      ranked_file << "#" << (i + 1) << ": " << guided_search_space[i].toString()
                  << " (Score: " << std::fixed << std::setprecision(2) << score
                  << ")\n";
    }
    ranked_file.close();
    std::cout
        << "All ranked configurations saved to ranked_configurations.txt\n";
  } else 
    std::cerr
        << "Warning: Could not open ranked_configurations.txt for writing.\n";

  std::cout << "Top 5 most promising candidates:\n";
  for (int i = 0; i < 5 && i < guided_search_space.size(); ++i) 
    std::cout << " #" << i + 1 << ": " << guided_search_space[i].toString()
              << " (Reuse Score: " << std::fixed << std::setprecision(2)
              << calculate_reuse_score(guided_search_space[i]) << ")\n";
  
  std::vector<KernelConfig> random_search_space = full_search_space;
  std::random_device rd;
  std::mt19937 g(rd());
  std::shuffle(random_search_space.begin(), random_search_space.end(), g);
  std::cout << "\n===== [Phase 2: Empirical Head-to-Head Exploration] =====\n";

  Evaluator evaluator(d_A, d_B, d_C, M, N, K);
  std::cout << "\n--- [Running BICO-GUIDED Explorer] ---\n";
  BICOExplorer explorer_guided("GUIDED", guided_search_space, evaluator, false);
  explorer_guided.explore(budget);
  std::cout << "\n--- [Running RANDOM (Baseline) Explorer] ---\n";
  BICOExplorer explorer_random("RANDOM", random_search_space, evaluator, true);
  explorer_random.explore(budget);
  CUDA_CHECK(cudaFree(d_A));
  CUDA_CHECK(cudaFree(d_B));
  CUDA_CHECK(cudaFree(d_C));
  return 0;
}

owo what are u still doing here? Thanks for reading BYEE