<?xml version="1.0" encoding="UTF-8"?>
<rss  xmlns:atom="http://www.w3.org/2005/Atom" 
      xmlns:media="http://search.yahoo.com/mrss/" 
      xmlns:content="http://purl.org/rss/1.0/modules/content/" 
      xmlns:dc="http://purl.org/dc/elements/1.1/" 
      version="2.0">
<channel>
<title>Emilio Cantú</title>
<link>https://ecntu.com/posts/</link>
<atom:link href="https://ecntu.com/posts/index.xml" rel="self" type="application/rss+xml"/>
<description></description>
<generator>quarto-1.9.36</generator>
<lastBuildDate>Sat, 11 Apr 2026 04:00:00 GMT</lastBuildDate>
<item>
  <title>Tiny Recursive Models Pt. 1</title>
  <link>https://ecntu.com/posts/trm/</link>
  <description><![CDATA[ 





<section id="trms-in-a-nutshell" class="level2">
<h2 class="anchored" data-anchor-id="trms-in-a-nutshell">TRMs in a nutshell</h2>
<p>A few months ago, Hierarchical Reasoning Models (HRMs) showed remarkable ARC performance for their relatively tiny (27M) parameter count. While HRMs introduced a lot of tricks, ablations performed by ARC’s team showed that one of them (<em>“deep supervision”</em>) accounted for most of the gains. By focusing on <em>“deep supervision”</em>, TRMs greatly simplify HRMs and outperform them with a quarter of the parameters.</p>
<p>TRMs can get away with so few parameters because they emulate much bigger models by recursively applying the network to its own outputs. They maintain a pair of latent states, <img src="https://latex.codecogs.com/png.latex?z"> and <img src="https://latex.codecogs.com/png.latex?y">, and refine them until the answer to the input puzzle is predicted by decoding <img src="https://latex.codecogs.com/png.latex?y">. Hence we can think of <img src="https://latex.codecogs.com/png.latex?y"> as maintaining the embedded answer, which frees up <img src="https://latex.codecogs.com/png.latex?z"> to do the “latent reasoning”.</p>
<p>A naive way to train such a model is to backpropagate the loss after all recursions. However, you quickly run out of memory as the model and number of recursions grow. Instead, deep supervision performs a chunk of the recursions (one <code>deep_recursion</code> call in the paper), computes the loss and performs an optimizer step. Then, the latents are detached and used in the next chunk of recursions.</p>
<p>The whole algorithm can be described with 20 lines:</p>
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<p><a href="images/trm-algo.png" class="lightbox" data-gallery="quarto-lightbox-gallery-1" title="The full TRM training loop"><img src="https://ecntu.com/posts/trm/images/trm-algo.png" class="img-fluid quarto-figure quarto-figure-center figure-img" width="500" alt="The full TRM training loop"></a></p>
</figure>
</div>
<figcaption>The full TRM training loop</figcaption>
</figure>
</div>
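<p>As a rough sketch of that loop (not the paper’s actual code; <code>deep_recursion</code> and <code>loss_and_step</code> are placeholder names), a single training example is processed like this:</p>

```python
def deep_supervision(x, y, z, deep_recursion, loss_and_step, N_sup=16):
    """Deep supervision: N_sup chunks of recursion, each followed by its
    own loss/optimizer step, with latents detached between chunks."""
    for _ in range(N_sup):
        # deep_recursion returns detached latents plus the two heads' outputs
        (y, z), y_hat, q_halt = deep_recursion(x, y, z)
        loss_and_step(y_hat, q_halt)  # backprop through this chunk only
    return y, z
```

<p>Because the latents come back detached, memory stays constant in the number of supervision steps: gradients never flow across chunk boundaries.</p>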
<p>I also made a diagram which omits details but helped me internalize the different recursion levels:</p>
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<p><a href="images/trm-diagram.svg" class="lightbox" data-gallery="quarto-lightbox-gallery-2" title="A diagram of TRM’s recursion levels"><img src="https://ecntu.com/posts/trm/images/trm-diagram.svg" class="img-fluid figure-img" alt="A diagram of TRM’s recursion levels"></a></p>
<figcaption>A diagram of TRM’s recursion levels</figcaption>
</figure>
</div>
<p>Note that since deep supervision drives most of the performance, the algorithm could have been simplified even further by defining:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> deep_recursion(x, y, z, K):</span>
<span id="cb1-2">  <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> i <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(K):</span>
<span id="cb1-3">    y, z <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> net(x, y, z)</span>
<span id="cb1-4">  <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> (y.detach(), z.detach()), output_head(y), Q_head(y)</span></code></pre></div></div>
<p>However, there are a few reasons why the paper does not:</p>
<ul>
<li>Because the network is kept single-headed for simplicity, it cannot update <img src="https://latex.codecogs.com/png.latex?z"> and <img src="https://latex.codecogs.com/png.latex?y"> in a single forward pass.</li>
<li>Relatedly, <code>latent_recursion</code> allows <img src="https://latex.codecogs.com/png.latex?z"> to be updated more frequently than <img src="https://latex.codecogs.com/png.latex?y">. We can think of the network as getting a few “scratchpad” iterations before having to commit to a revised answer. (Note that while intuitive and probably correct, the paper doesn’t perform ablations directly changing the latents’ relative update frequency.)</li>
<li>Lastly, the paper recurses a few times (<img src="https://latex.codecogs.com/png.latex?T-1">) without gradients before the final call to <code>latent_recursion</code> with gradients. The argument here is that since deep supervision trains the model to be a “local improver”, it should benefit from running a few extra recursions – even without gradients. However, too many recursions without gradients seem to hurt, as Table 3 shows.</li>
</ul>
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<p><a href="images/trm-table-3.png" class="lightbox" data-gallery="quarto-lightbox-gallery-3" title="Table 3 from the TRM paper (my highlights)"><img src="https://ecntu.com/posts/trm/images/trm-table-3.png" class="img-fluid quarto-figure quarto-figure-center figure-img" width="400" alt="Table 3 from the TRM paper (my highlights)"></a></p>
</figure>
</div>
<figcaption>Table 3 from the TRM paper (my highlights)</figcaption>
</figure>
</div>
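<p>To make the recursion levels concrete, here is a sketch of the two inner functions (names follow the paper’s <code>latent_recursion</code>/<code>deep_recursion</code>; <code>net</code> is a stand-in for the actual two-layer network):</p>

```python
def latent_recursion(net, x, y, z, n=6):
    """n "scratchpad" updates of z, then one committed update of y."""
    for _ in range(n):
        z = net(x, y, z)    # refine the reasoning latent
    y = net(None, y, z)     # revise the answer (the paper omits x here)
    return y, z

def deep_recursion(net, x, y, z, n=6, T=3):
    """T recursion blocks: the first T - 1 run without gradients in
    practice, and only the final block is backpropagated through."""
    for _ in range(T - 1):
        y, z = latent_recursion(net, x, y, z, n)  # would be stop-gradient'ed
    return latent_recursion(net, x, y, z, n)      # gradients flow here
```

<p>So one <code>deep_recursion</code> call costs <img src="https://latex.codecogs.com/png.latex?T(n%20+%201)"> network evaluations, but gradients only touch the last <img src="https://latex.codecogs.com/png.latex?n%20+%201"> of them.</p>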
<p>The last important detail is that deep supervision stops when the network has &gt;50% confidence its predicted solution is correct. This “early-stopping” is turned off during testing for performance but makes training more efficient:</p>
<blockquote class="blockquote">
<p>“ACT greatly diminishes the time spent per example (on average spending less than 2 steps on the Sudoku-Extreme dataset rather than the full <img src="https://latex.codecogs.com/png.latex?N_%5Ctext%7Bsup%7D%20=%2016"> steps), allowing more coverage of the dataset given a fixed number of training iterations.”</p>
</blockquote>
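<p>The halting check itself is tiny (a sketch; <code>q_halt_logit</code> is my name for the halt head’s scalar output, not the paper’s):</p>

```python
import math

def should_halt(q_halt_logit):
    """Stop supervising this example once the model assigns more than
    50% probability that its current prediction is correct."""
    return 1.0 / (1.0 + math.exp(-q_halt_logit)) > 0.5

# equivalently: halt whenever the logit is positive
```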
<p>So, that is TRM in a nutshell. The paper does experiments on several tasks (sudokus, mazes, and ARC) and performs lots of ablations (number of latents, architecture, early-stopping, etc.) to show that the choices they landed on are optimal. It also builds up the final design by simplifying HRMs step-by-step and it’s a very pleasant read.</p>
<p>After its publication, others have found simple tweaks that improve performance:</p>
<ul>
<li><a href="https://x.com/ritteradam/status/1982190450711642300?s=20">These</a> <a href="https://x.com/huskydogewoof/status/1982503109042831472?s=20">posts</a> found that simply increasing <img src="https://latex.codecogs.com/png.latex?N_%5Ctext%7Bsup%7D"> at test time increases sudoku performance from ~87% to ~96%.</li>
<li><a href="https://arxiv.org/abs/2511.08653">This</a> paper trains the network about 2x faster by using a curriculum on the number of recursions. Instead of the paper’s <img src="https://latex.codecogs.com/png.latex?(n,T)%20=%20(6,3)">, they do <img src="https://latex.codecogs.com/png.latex?(2,1)%5Crightarrow(4,2)%5Crightarrow(6,3)">.</li>
<li><a href="https://arxiv.org/pdf/2511.16886v4">This</a> paper found that if your puzzle admits contractive “milestones” that progressively zero in on the solution, supervising sequentially on those milestones performs better than supervising only on the final solution.</li>
</ul>
<p>I’m sure there are many more, but those caught my eye. Now, let’s try a few simple experiments!</p>
</section>
<section id="experiments" class="level2">
<h2 class="anchored" data-anchor-id="experiments">Experiments</h2>
<p>We’ll use a very simple JAX <a href="https://github.com/ecntu/trm-jax">implementation</a> to take advantage of the compute provided by <a href="https://sites.research.google/trc/about/">TRC</a> (thank you!), and focus on the sudoku-extreme dataset for now.</p>
<section id="random-latent-inits-for-best-of-k-at-inference" class="level3">
<h3 class="anchored" data-anchor-id="random-latent-inits-for-best-of-k-at-inference">Random latent inits for best-of-k at inference</h3>
<p>Since increasing the reasoning “depth” (<img src="https://latex.codecogs.com/png.latex?N_%5Ctext%7Bsup%7D">) at test time worked so well, could increasing “breadth” help?</p>
<p>The paper’s implementation initializes <img src="https://latex.codecogs.com/png.latex?z"> and <img src="https://latex.codecogs.com/png.latex?y"> to random values that are chosen and fixed <em>at model initialization</em>. That is, every forward pass during training and testing starts with the same latents. A simple way to try to get “breadth” is to see if different starting latents are mapped to diverse predictions.</p>
<p>We could then combine or choose among the diverse predictions to make a final one. Normally you would take the majority vote (as <a href="https://arxiv.org/abs/1912.02781">test time augmentation (TTA)</a> in vision, or <a href="https://arxiv.org/abs/2203.11171">self-consistency</a> in LLMs), but using the model’s own confidence in its prediction (using the halting head) should make more sense here. The paper actually uses both for the TTA they do for ARC, taking the majority vote and breaking ties with model confidence.</p>
<p>Note that it’s not clear that different initializations <em>will</em> get mapped to diverse solutions. It could be that they are ignored because the model maps them to the correct (unique) solution. Or the latents could get numerically “washed away” in the recursion, especially the first <img src="https://latex.codecogs.com/png.latex?T-1"> latent recursions without gradient.</p>
<p>We train networks with random latent inits and reuse the same forward passes to track three prediction methods: the cell-wise majority vote, the whole puzzle the model is most confident in, and a single forward pass as a baseline. Recycling the forward passes makes this the least noisy experiment we’ll run, since training run-to-run variance is really high. However, we also separately train networks with static latent inits as another baseline.</p>
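<p>The two combination rules amount to something like this (a sketch; the <code>(k, cells)</code> shape for <code>preds</code> is my assumption):</p>

```python
import numpy as np

def most_confident(preds, halt_logits):
    """Keep the whole candidate puzzle the halt head is most confident in."""
    return preds[int(np.argmax(halt_logits))]

def cellwise_mode(preds):
    """Per-cell majority vote over the k candidate grids."""
    preds = np.asarray(preds)
    return np.array([np.bincount(col).argmax() for col in preds.T])
```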
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<p><a href="images/randomized-models-validation-curves.png" class="lightbox" data-gallery="quarto-lightbox-gallery-4"><img src="https://ecntu.com/posts/trm/images/randomized-models-validation-curves.png" class="img-fluid quarto-figure quarto-figure-center figure-img" width="500"></a></p>
</figure>
</div>
<p>After training we evaluate models with a chunk of the test set using each method’s best validation checkpoint and let <img src="https://latex.codecogs.com/png.latex?N_%5Ctext%7Bsup%7D"> and <img src="https://latex.codecogs.com/png.latex?k"> grow.</p>
<p><a href="images/randomized-models-test-curves.png" class="lightbox" data-gallery="quarto-lightbox-gallery-5"><img src="https://ecntu.com/posts/trm/images/randomized-models-test-curves.png" class="img-fluid"></a></p>
<p>Some observations:</p>
<ul>
<li>In retrospect, cell-wise mode was a silly thing to try, since cells in sudokus are not independent. As the paper did for ARC, I should have taken puzzle-wise modes. I’ll try to run these at some point.</li>
<li>Random inits paired with model confidence do yield gains and they are most pronounced around (or slightly below) the recursion budget used during training.</li>
<li>As with most test-time scaling methods, increasing <img src="https://latex.codecogs.com/png.latex?k"> (and <img src="https://latex.codecogs.com/png.latex?N_%5Ctext%7Bsup%7D">) has diminishing returns.</li>
<li>Methods converge in performance at large recursion budgets.</li>
</ul>
<p>However, bigger <img src="https://latex.codecogs.com/png.latex?k">’s require more compute and it seems that, at least for this problem and model size, it’s more efficient to scale <img src="https://latex.codecogs.com/png.latex?N_%5Ctext%7Bsup%7D"> than <img src="https://latex.codecogs.com/png.latex?k">.</p>
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<p><a href="images/randomized-models-test-curves-compute-normalized.png" class="lightbox" data-gallery="quarto-lightbox-gallery-6"><img src="https://ecntu.com/posts/trm/images/randomized-models-test-curves-compute-normalized.png" class="img-fluid quarto-figure quarto-figure-center figure-img" width="600"></a></p>
</figure>
</div>
<p>There might be other settings where random initial latents pay off: problems with multiple solutions, say, or models with underfit predictions but calibrated halt heads. Anyway, I was mostly surprised that initializations don’t collapse onto the same prediction (at least at low depths) and want to come back to this in the future.</p>
</section>
<section id="more-randomization-and-path-independence" class="level3">
<h3 class="anchored" data-anchor-id="more-randomization-and-path-independence">More randomization and path independence</h3>
<p>Randomizing initializations does not only enable test-time augmentation; it has also been shown to stabilize recurrence by helping the model converge to fixed points. <a href="https://arxiv.org/abs/2211.09961"><em>Path independent</em></a> models are those whose latents converge to the same fixed points regardless of the path taken to reach them. The argument is that path independent models are more likely to take advantage of additional iterations by converging to the solution, unlike path dependent models, which diverge.</p>
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<p><a href="images/path-independence.png" class="lightbox" data-gallery="quarto-lightbox-gallery-7" title="Figure 1 from Path Independent Equilibrium Models Can Better Exploit Test-Time Computation."><img src="https://ecntu.com/posts/trm/images/path-independence.png" class="img-fluid figure-img" alt="Figure 1 from Path Independent Equilibrium Models Can Better Exploit Test-Time Computation."></a></p>
<figcaption>Figure 1 from <em>Path Independent Equilibrium Models Can Better Exploit Test-Time Computation</em>.</figcaption>
</figure>
</div>
<p>That paper also randomized the number of iterations during training to increase path independence. In our case we could vary <img src="https://latex.codecogs.com/png.latex?n"> (the number of iterations in <code>latent_recursion</code> that we spend refining <img src="https://latex.codecogs.com/png.latex?z"> before updating <img src="https://latex.codecogs.com/png.latex?y">), <img src="https://latex.codecogs.com/png.latex?T"> (the total number of <code>latent_recursion</code> calls, all but the last without gradients), and <img src="https://latex.codecogs.com/png.latex?N_%5Ctext%7Bsup%7D"> (the number of full <img src="https://latex.codecogs.com/png.latex?T(n%20+%201)">-sized recursion blocks, i.e.&nbsp;the number of supervision steps).</p>
<p>We could also randomize two or more of these at a time. For example, varying both <img src="https://latex.codecogs.com/png.latex?T"> and <img src="https://latex.codecogs.com/png.latex?n"> slightly resembles the truncated BPTT with random start and end steps that <a href="https://arxiv.org/abs/2202.05826">this</a> paper proposes. However, we keep it simple for now.</p>
<p>In each training forward pass we sample from uniforms centered at the default values (<img src="https://latex.codecogs.com/png.latex?n=6">, <img src="https://latex.codecogs.com/png.latex?T=3">, <img src="https://latex.codecogs.com/png.latex?N_%5Ctext%7Bsup%7D=16">) to match the expected compute of deterministic baselines. We try different ranges for these uniforms and keep inference deterministic.</p>
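<p>Concretely, each training forward pass samples something like the following (a sketch; the clipping to keep values at least 1 is my choice):</p>

```python
import random

def sample_hyper(center, w, rng=random):
    """Integer uniform on [center - w, center + w], clipped to stay >= 1.
    Centring at the default keeps expected compute equal to the
    deterministic baseline."""
    return max(1, rng.randint(center - w, center + w))

# per training forward pass, with half-width w as in the plots:
# n, T, N_sup = (sample_hyper(c, w) for c in (6, 3, 16))
```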
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<p><a href="images/randomized-models-performance.png" class="lightbox" data-gallery="quarto-lightbox-gallery-8" title="We display mean ± SE and the uniforms’ half-width w."><img src="https://ecntu.com/posts/trm/images/randomized-models-performance.png" class="img-fluid figure-img" alt="We display mean ± SE and the uniforms’ half-width w."></a></p>
<figcaption>We display mean ± SE and the uniforms’ half-width <img src="https://latex.codecogs.com/png.latex?w">.</figcaption>
</figure>
</div>
<p>On the right we show the <em>Asymptotic Alignment Score</em> which tries to capture how path independent a network is. It measures the similarity between the final latents of two predictions made for the same data point. The first prediction is initialized normally, but the second uses the first’s final latents. If the similarity is high, the network maps different initializations to the same fixed points and is more path independent. Since we have both latents, we <code>roll</code> both <img src="https://latex.codecogs.com/png.latex?z"> and <img src="https://latex.codecogs.com/png.latex?y"> for the second forward pass.</p>
<p>Above we plotted the similarity of final <em>predictions</em> instead since we found those to be more consistent. Below we include scores for latents and plot against performance.</p>
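<p>A sketch of the score computation (assuming cosine similarity as the similarity measure; <code>forward</code> stands in for a full deterministic forward pass returning the final latents):</p>

```python
import numpy as np

def cos_sim(a, b):
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

def alignment_scores(forward, x, y0, z0):
    """Run once from the standard init, run again from ("rolling") the
    first pass's final latents, and compare the two final states."""
    y1, z1 = forward(x, y0, z0)
    y2, z2 = forward(x, y1, z1)   # second pass reuses the final latents
    return cos_sim(z1, z2), cos_sim(y1, y2)
```

<p>A perfectly path independent network returns the same fixed point both times, scoring 1 for each latent.</p>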
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<p><a href="images/randomized-models-correlation.png" class="lightbox" data-gallery="quarto-lightbox-gallery-9" title="Large dots are per-method means; small ones individual seeds. Spearman \rho above each panel."><img src="https://ecntu.com/posts/trm/images/randomized-models-correlation.png" class="img-fluid figure-img" alt="Large dots are per-method means; small ones individual seeds. Spearman \rho above each panel."></a></p>
<figcaption>Large dots are per-method means; small ones individual seeds. Spearman <img src="https://latex.codecogs.com/png.latex?%5Crho"> above each panel.</figcaption>
</figure>
</div>
<p>Even with all the noise from sensitive training and only 5 seeds, we get a few observations:</p>
<ul>
<li>Like the paper that introduced path independence, we find that it is correlated with performance in this setting.</li>
<li>Reasoning latents (<img src="https://latex.codecogs.com/png.latex?z">’s) saturate and cluster around 1 more so than latent predictions (<img src="https://latex.codecogs.com/png.latex?y">’s). This might be simply because <img src="https://latex.codecogs.com/png.latex?z">’s are updated much more than <img src="https://latex.codecogs.com/png.latex?y">’s.</li>
<li>Surprisingly, random initializations seem to <em>decrease</em> path independence and performance. This doesn’t contradict the path independence paper, since their randomization is different: they initialize with zero vectors, only add Gaussian noise to half the entries during training, and then use all zeros during testing. I reused the code from the previous experiments and randomize during testing too. I’ll try to explore this in a future post.</li>
<li>Even though we increase <img src="https://latex.codecogs.com/png.latex?N_%5Ctext%7Bsup%7D"> at test time, it’s randomizing <img src="https://latex.codecogs.com/png.latex?T"> during training that yields the biggest gains.</li>
<li>Increasing randomization during training (increasing the uniforms’ range) generally seems to improve performance, with a curious blip in <img src="https://latex.codecogs.com/png.latex?N_%5Ctext%7Bsup%7D%20(w=4)">. To investigate: is the blip real? do gains stack by combining methods? how do gains scale with model capacity?</li>
</ul>
</section>
</section>
<section id="to-be-continued" class="level2">
<h2 class="anchored" data-anchor-id="to-be-continued">To be continued</h2>
<p>There’s definitely a lot more to learn about and with TRMs. The random initialization idea likely belongs in the paper’s great “Ideas that failed” section and the experiments need polishing. But I thought releasing a first post and iterating would be more fun.</p>


</section>

 ]]></description>
  <guid>https://ecntu.com/posts/trm/</guid>
  <pubDate>Sat, 11 Apr 2026 04:00:00 GMT</pubDate>
  <media:content url="https://ecntu.com/posts/trm/images/trm-diagram.svg" medium="image" type="image/svg+xml"/>
</item>
<item>
  <title>Fuzzy matching professors to their reviews</title>
  <link>https://ecntu.com/posts/fuzzy-name-matching/</link>
  <description><![CDATA[ 





<section id="context" class="level2">
<h2 class="anchored" data-anchor-id="context">Context</h2>
<p>During my undergrad, I built a simple schedule-planning <a href="https://horariositam.com">site</a> for my university. By the third semester I’d grown tired of using Excel and had learned enough JavaScript to code up a site that ranked every possible schedule by my preferences. One of these preferences was to upweight schedules based on professors’ ratings extracted from a site like RateMyProfessor.</p>
<p>This is where I first encountered fuzzy profile matching (also called record linkage, entity resolution, etc.): deciding whether profiles in different databases refer to the same person. The problem arises because student reviewers create a professor’s profile, which means the same professor can appear under different names across sites:</p>
<table class="caption-top table">
<thead>
<tr class="header">
<th>Nombre (School)</th>
<th>Nombre (Review)</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>Juan Pérez</td>
<td>J. Perez</td>
</tr>
<tr class="even">
<td>María González</td>
<td>Maria Glez</td>
</tr>
<tr class="odd">
<td>Carlos Ramírez</td>
<td>C. Ramirez</td>
</tr>
</tbody>
</table>
<p>At the time I decided to link profiles using a very simple approach based on the normalized string similarity (<a href="https://www.digitalocean.com/community/tutorials/levenshtein-distance-python">levenshtein ratio</a>) between names:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> lev_ratio(a, b):</span>
<span id="cb1-2">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> edit_dist(a, b) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> (<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(a)<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(b))</span>
<span id="cb1-3"></span>
<span id="cb1-4"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> n_a <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> school_names:</span>
<span id="cb1-5">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> n_b <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> review_names:</span>
<span id="cb1-6">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> lev_ratio(n_a, n_b) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&gt;</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.9</span>:</span>
<span id="cb1-7">            link(n_a, n_b)</span>
<span id="cb1-8">            <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">break</span></span></code></pre></div></div>
<p>While this worked okay at the time, in retrospect it was a bad approach for several reasons. Since a lot of students still use the site, I thought I could do better after 6 years and a master’s in stats.</p>
</section>
<section id="how-to-do-better" class="level2">
<h2 class="anchored" data-anchor-id="how-to-do-better">How to do better</h2>
<p>I thought I would simply:</p>
<ol type="1">
<li>Gather data by manually matching profiles. I need data to tune and, more importantly, to evaluate the approach.</li>
<li>Engineer features and train a pairwise binary classifier to predict the probability of a true match given two profiles.</li>
<li>Use the pairwise classifier to predict matches (or ‘no match’) for every professor on the school’s site.</li>
</ol>
<p>Each of these steps involved a few subtleties. I’ll quickly go over them, along with what I would have done differently or will improve in the future.</p>
<section id="data" class="level3">
<h3 class="anchored" data-anchor-id="data">Data</h3>
<p>I spent a couple of hours manually matching 150 profiles: 100 for training and the rest for evaluation. Professors on the school site had one, multiple, or no matching profiles on the review site.</p>
<p>Note that since we only have manually annotated <em>matches</em>, our training set only contains positives. So how to generate negative <img src="https://latex.codecogs.com/png.latex?(%5Ctext%7Bprofile%7D_A,%20%5Ctext%7Bprofile%7D_B)"> pairs to train our classifier?</p>
<p>A naive option would be to label every pair that was not manually matched as negative, ending up with <img src="https://latex.codecogs.com/png.latex?N_%7Bschool%7D%20%5Ctimes%20(N_%7Breview%7D%20-%201)"> negative pairs. However, I did not want to deal with the high class imbalance and opted to randomly sample <img src="https://latex.codecogs.com/png.latex?k_%5Ctext%7Bneg%7D"> such pairs for every positive one.</p>
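<p>A sketch of that sampling (names are mine; <code>matches</code> holds the hand-labelled positive pairs):</p>

```python
import random

def build_training_pairs(matches, review_profiles, k_neg=10, rng=random):
    """Positives: the hand-labelled matches. Negatives: for each positive,
    k_neg randomly sampled non-matching (school, review) pairs."""
    matched = set(matches)
    pairs = [(a, b, 1) for a, b in matches]
    for a, _ in matches:
        sampled = 0
        while sampled < k_neg:
            b = rng.choice(review_profiles)
            if (a, b) not in matched:
                pairs.append((a, b, 0))
                sampled += 1
    return pairs
```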
<p>As we’ll see below, this worked alright. But most of the negatives were really “easy”, since two randomly sampled profiles are likely to have wildly different names (and other features). Thus, this first approach will likely not perform well on the rare but tough cases where truly distinct professors have similar features.</p>
<p>What I wish I had thought of earlier is active learning. I could have collected a few matches (say 20 instead of 100), trained a model, and used its predictions to find harder examples. To find them you can use low model confidence or a number of other heuristics (<a href="https://github.com/koaning/doubtlab">doubtlab</a> collects several). After labelling a few of the hardest examples, you retrain and repeat (making sure not to overfit).</p>
<p>I think this approach could have saved me some labelling and resulted in higher quality negatives. I’ll try to do the experiment sometime soon.</p>
</section>
<section id="features-model" class="level3">
<h3 class="anchored" data-anchor-id="features-model">Features &amp; model</h3>
<p>Besides the names from both platforms, I also included professors’ departments:</p>
<table class="caption-top table">
<colgroup>
<col style="width: 23%">
<col style="width: 24%">
<col style="width: 26%">
<col style="width: 25%">
</colgroup>
<thead>
<tr class="header">
<th>Nombre (S)</th>
<th>Nombre (R)</th>
<th>Depto (S)</th>
<th>Depto (R)</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>Juan Pérez</td>
<td>J. Perez</td>
<td>Matemáticas, Actuaria</td>
<td>Mate</td>
</tr>
<tr class="even">
<td>María González</td>
<td>Maria Glez</td>
<td>Física, Matemáticas</td>
<td>Mate</td>
</tr>
<tr class="odd">
<td>Carlos Ramírez</td>
<td>C. Ramirez</td>
<td>Computación, Actuaria</td>
<td>Compu</td>
</tr>
</tbody>
</table>
<p>After playing around a little, I settled on using only two features for simplicity:</p>
<ul>
<li>A string similarity between the profile names. Went for <a href="https://rapidfuzz.github.io/RapidFuzz/Usage/fuzz.html#token-sort-ratio"><code>token_sort_ratio</code></a> since it handles missing words and is not sensitive to order. For example, “Juan Perez” and “J. Perez” score a perfect 1.</li>
<li>A string similarity between the departments. Because a professor could have multiple school departments (extracted from the classes they teach that semester), I simply took the maximum of the similarities (<a href="https://rapidfuzz.github.io/RapidFuzz/Usage/fuzz.html#token-set-ratio"><code>token_set_ratio</code></a> here).</li>
</ul>
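<p>Putting the two features together (a sketch using <code>difflib</code> as a rough stand-in for rapidfuzz’s ratios, which behave similarly but are not identical):</p>

```python
from difflib import SequenceMatcher

def token_sort_sim(a, b):
    """Rough stand-in for token_sort_ratio: compare alphabetically
    sorted, lowercased tokens so word order doesn't matter."""
    key = lambda s: " ".join(sorted(s.lower().split()))
    return SequenceMatcher(None, key(a), key(b)).ratio()

def pair_features(name_s, name_r, depts_s, dept_r):
    """Name similarity, plus the max similarity over the (possibly
    multiple) school departments."""
    return (token_sort_sim(name_s, name_r),
            max(token_sort_sim(d, dept_r) for d in depts_s))
```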
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<p><img src="https://ecntu.com/posts/fuzzy-name-matching/images/feature_scatter.png" class="img-fluid figure-img"></p>
<figcaption>Features per class for the training set. There are <code>k=10</code> times as many negatives as positives. You can see how the negative examples are too easy because the string similarity for randomly sampled pairs is really low.</figcaption>
</figure>
</div>
<p>In keeping with the simplicity theme, the model was a logistic regressor. Later on it might be fun to train a character-level LLM or BERT model to see if they can beat these handcrafted features.</p>
</section>
<section id="final-matching" class="level3">
<h3 class="anchored" data-anchor-id="final-matching">Final matching</h3>
<p>Once the pairwise classifier is trained, how do we use it to make matches?</p>
<p>Should we do something similar to my previous approach from years ago?</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb2" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> prof_a <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> school_profiles:</span>
<span id="cb2-2">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> prof_b <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> review_profiles:</span>
<span id="cb2-3">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> model.predict_proba([prof_a, prof_b]) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&gt;</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.9</span>:</span>
<span id="cb2-4">            link(prof_a, prof_b)</span>
<span id="cb2-5">            <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">break</span></span></code></pre></div></div>
<p>No.&nbsp;In hindsight this was terrible for a couple of reasons. First, I tuned the threshold by skimming the output, since I didn’t have any manually matched profiles. Second, we link to the <em>first</em> review profile above the threshold, not the most probable one, which makes the result sensitive to ordering. A simple improvement, and what I ended up using, is to consider only the most probable predicted match:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb3" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> prof_a <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> school_profiles:</span>
<span id="cb3-2">    argmax_prof_b, max_prob <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> best_match(prof_a)</span>
<span id="cb3-3">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> max_prob <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&gt;</span> threshold:</span>
<span id="cb3-4">        link(prof_a, argmax_prof_b)</span></code></pre></div></div>
<p>This is still not perfect, since two school profiles can be mapped to the same review profile. A more principled approach (which I might get to later) is to frame this as a bipartite graph matching problem: school profiles and review profiles form the two node sets, and the classifier’s match probabilities are the weighted edges between them. You can then choose the edges that maximize the total edge weight while respecting one-to-one constraints.</p>
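<p>For illustration, here is a minimal sketch of that assignment formulation using <code>scipy.optimize.linear_sum_assignment</code>. The <code>prob_matrix</code> of pairwise match probabilities is a hypothetical stand-in for running the classifier over every profile pair; it is not part of the code above.</p>

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def match_profiles(prob_matrix, threshold=0.5):
    """One-to-one matching that maximizes the total match probability.

    prob_matrix[i, j] is the classifier's probability that school
    profile i and review profile j refer to the same person
    (hypothetical: built by scoring all pairs).
    """
    # maximize=True selects the assignment with the largest total weight
    rows, cols = linear_sum_assignment(prob_matrix, maximize=True)
    # Keep only the pairs the classifier is reasonably confident about
    return [(i, j) for i, j in zip(rows, cols) if prob_matrix[i, j] > threshold]

# Toy example: greedy linking could pair school profile 0 with review 0,
# but the optimal assignment pairs it with review 1 instead.
probs = np.array([[0.6, 0.95],
                  [0.9, 0.1]])
matches = match_profiles(probs)  # [(0, 1), (1, 0)]
```

<p>For datasets of this size (about 1k profiles each), the cubic-time assignment solve is effectively instant.</p>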
</section>
</section>
<section id="results" class="level2">
<h2 class="anchored" data-anchor-id="results">Results</h2>
<p>We still have to tune the <code>threshold</code> and <code>k</code> (the number of negatives to sample per positive during training) parameters. I initially set them to 0.5 and 10 respectively, but while writing this I thought it would be nice to cross-validate and see how good that guess was.</p>
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<p><img src="https://ecntu.com/posts/fuzzy-name-matching/images/cv_results.png" class="img-fluid figure-img"></p>
<figcaption>Cross-validation results (5 folds) showing the percentage of correct matches, incorrect matches, missed matches (abstained when a correct match existed), and the overall number of matches made.</figcaption>
</figure>
</div>
<p>You can see the tradeoffs. A bigger <code>k</code> means more negatives and hence a bigger training set, but also greater class imbalance. Seeing many more negatives than positives makes the model less confident when predicting positives, which is why the <code>k=100</code> model needs the threshold lowered to <code>0.25</code> to achieve good performance.</p>
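<p>The imbalance effect is easy to reproduce in isolation. The sketch below is not from the original pipeline: it trains a logistic regression on a toy two-Gaussian task with <code>k</code> negatives per positive and reports the mean probability assigned to held-out positives. Raising <code>k</code> pushes those probabilities down, which is why the threshold has to follow.</p>

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)

def mean_positive_prob(k, n_pos=500):
    """Train with k negatives per positive; return the mean predicted
    probability assigned to fresh positive examples."""
    X_pos = rng.normal(1.0, 1.0, size=(n_pos, 1))        # true matches
    X_neg = rng.normal(-1.0, 1.0, size=(k * n_pos, 1))   # sampled non-matches
    X = np.vstack([X_pos, X_neg])
    y = np.r_[np.ones(n_pos), np.zeros(k * n_pos)]
    clf = LogisticRegression().fit(X, y)
    fresh_pos = rng.normal(1.0, 1.0, size=(n_pos, 1))
    return clf.predict_proba(fresh_pos)[:, 1].mean()

p1, p100 = mean_positive_prob(k=1), mean_positive_prob(k=100)
# p100 is substantially lower than p1: the k=100 model sees far more
# negatives, so its predicted probabilities shrink and a fixed 0.5
# cutoff would reject many true matches.
```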
<p>The metrics we care about are inherently a trade-off, but I would say the initial guess (green) fared quite well. Two other configurations (blue and red) performed comparably, and the blue (0.75, 1) option could have saved a few seconds during training.</p>
<p>Anyway, the new method is better on our test set than what I implemented years ago:</p>
<table class="caption-top table">
<thead>
<tr class="header">
<th>Metric</th>
<th>New</th>
<th>Previous</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>Correct</td>
<td>0.9</td>
<td>0.82</td>
</tr>
<tr class="even">
<td>Incorrect match</td>
<td>0.04</td>
<td>0.04</td>
</tr>
<tr class="odd">
<td>Incorrect no match</td>
<td>0.06</td>
<td>0.14</td>
</tr>
<tr class="even">
<td>Has Match</td>
<td>0.76</td>
<td>0.66</td>
</tr>
</tbody>
</table>
<p>The deployed model now links about 74% of professors to a review profile, up from 60% before.</p>
</section>
<section id="final-thoughts" class="level2">
<h2 class="anchored" data-anchor-id="final-thoughts">Final thoughts</h2>
<p>I had fun coming back to this site and fiddling around with an ML problem that is not just “standard” supervised learning. I tried tackling the problem without a literature review for fun, but this is obviously a well-studied problem. I should mention that since my datasets were pretty small (about 1k profiles each), I didn’t run into the usual compute challenges. <a href="https://www.science.org/doi/10.1126/sciadv.abi8021">Here</a> is a good overview survey if you are interested in what they are and how to alleviate them.</p>


</section>

 ]]></description>
  <guid>https://ecntu.com/posts/fuzzy-name-matching/</guid>
  <pubDate>Tue, 09 Sep 2025 04:00:00 GMT</pubDate>
</item>
<item>
  <title>Engression</title>
  <link>https://ecntu.com/posts/engression/</link>
  <description><![CDATA[ 





<p>Traditional regression models predict the conditional mean <img src="https://latex.codecogs.com/png.latex?%5Cmathbb%20E%5BY%E2%88%A3X=x%5D">, or <a href="https://en.wikipedia.org/wiki/Quantile_regression">sometimes</a> a few quantiles. In contrast, distributional regression attempts to learn the <em>entire</em> conditional distribution <img src="https://latex.codecogs.com/png.latex?Y%7CX=x">. Having access to the full distribution gives us calibrated uncertainty estimates, probabilistic forecasts, etc.</p>
<p>In a recent seminar I learned about engression, a lightweight and principled approach to distributional regression. Instead of predicting a parametric distribution or optimizing a likelihood, engression trains models to transform noise into samples from <img src="https://latex.codecogs.com/png.latex?Y%7CX=x">, using a <a href="https://en.wikipedia.org/wiki/Scoring_rule">proper scoring rule</a> called the <a href="https://sites.stat.washington.edu/raftery/Research/PDF/Gneiting2007jasa.pdf">energy score</a>. It’s implicit, generative, and remarkably straightforward to implement.</p>
<p>Here, I try to explain the main idea, reproduce some of the <a href="https://arxiv.org/abs/2307.00835">paper</a>’s results, and discuss a few of its properties.</p>
<section id="in-a-nutshell" class="level2">
<h2 class="anchored" data-anchor-id="in-a-nutshell">In a nutshell</h2>
<p>We consider a general class of models <img src="https://latex.codecogs.com/png.latex?%5Cmathcal%7BM%7D%20=%20%5C%7Bg(x,%20%5Cvarepsilon)%5C%7D">, where each model takes covariates <img src="https://latex.codecogs.com/png.latex?x"> and a random vector <img src="https://latex.codecogs.com/png.latex?%5Cvarepsilon"> as input. The noise vector <img src="https://latex.codecogs.com/png.latex?%5Cvarepsilon"> is drawn from a fixed distribution independent of <img src="https://latex.codecogs.com/png.latex?x">. We imagine each function <img src="https://latex.codecogs.com/png.latex?g(x,%20%5Cvarepsilon)%20%5Cmapsto%20y"> in <img src="https://latex.codecogs.com/png.latex?%5Cmathcal%7BM%7D"> as defining a conditional distribution of <img src="https://latex.codecogs.com/png.latex?Y%20%5Cmid%20X%20=%20x">. The “best” <img src="https://latex.codecogs.com/png.latex?g"> is found by minimizing the <strong>engression loss</strong>: <img src="https://latex.codecogs.com/png.latex?%0A%5Cmathbb%20E%20%5Cleft%5B%20%5Cleft%20%5C%7C%20Y%20-%20g(X,%20%5Cvarepsilon)%20%5Cright%20%5C%7C%20-%20%5Cfrac%7B1%7D%7B2%7D%20%5Cleft%20%5C%7C%20g(X,%20%5Cvarepsilon)%20-%20g(X,%20%5Cvarepsilon')%20%5Cright%20%5C%7C%20%5Cright%5D%0A"></p>
<p>where <img src="https://latex.codecogs.com/png.latex?%5Cvarepsilon"> and <img src="https://latex.codecogs.com/png.latex?%5Cvarepsilon'"> are independent draws from the noise distribution. This is the negative of the energy score, a proper scoring rule. As a result, the authors show that the minimizer recovers the true conditional distribution <img src="https://latex.codecogs.com/png.latex?Y%20%5Cmid%20X"> — assuming the model class <img src="https://latex.codecogs.com/png.latex?%5Cmathcal%7BM%7D"> is expressive enough. In practice, this means using neural networks with sufficient capacity.</p>
<p>Intuitively, the loss encourages two things. The first term ensures that the generated samples are close to the observed target. It pulls the predicted distribution toward the actual data. The second term penalizes collapsing all samples to a single point. It forces the model to generate diverse samples, reflecting the variability in the data.</p>
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<p><img src="https://ecntu.com/posts/engression/images/loss.gif" class="img-fluid figure-img"></p>
<figcaption>If we only used the first term in the engression loss, our estimated distribution would collapse. Toy example where <img src="https://latex.codecogs.com/png.latex?Y%7CX%20%5Csim%20N(X,%200.1)"></figcaption>
</figure>
</div>
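<p>We can sanity-check this intuition numerically. The Monte Carlo sketch below (my own illustration, not from the paper) compares the empirical loss of a generator that collapses to the median of <img src="https://latex.codecogs.com/png.latex?Y%20%5Csim%20N(0,%201)"> against one that samples from the true distribution. The collapsed generator wins on the first term alone, yet the full loss, being a proper score, prefers the honest one.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
n = 200_000
y = rng.standard_normal(n)  # observations of Y ~ N(0, 1)

def energy_loss(y, g1, g2):
    """Empirical engression loss in 1-D: E|Y - g| - 0.5 * E|g - g'|,
    with g1, g2 two independent batches of generator samples."""
    return np.abs(y - g1).mean() - 0.5 * np.abs(g1 - g2).mean()

# Degenerate generator: always outputs the median, 0.0
collapsed = energy_loss(y, np.zeros(n), np.zeros(n))
# Generator that matches the true distribution N(0, 1)
honest = energy_loss(y, rng.standard_normal(n), rng.standard_normal(n))
# collapsed ≈ E|Y| = sqrt(2/pi) ≈ 0.80, honest ≈ 1/sqrt(pi) ≈ 0.56,
# so the loss correctly rewards matching the whole distribution.
```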
<p>To minimize the empirical version of the loss using a dataset <img src="https://latex.codecogs.com/png.latex?%5C%7B(X_i,%20Y_i):%20i=1,%20%5Cdots,%20n%5C%7D">, we sample <img src="https://latex.codecogs.com/png.latex?m"> noise vectors <img src="https://latex.codecogs.com/png.latex?%5Cvarepsilon"> <em>per</em> <img src="https://latex.codecogs.com/png.latex?i">-th observation and minimize</p>
<p><img src="https://latex.codecogs.com/png.latex?%0A%5Cfrac%7B1%7D%7Bn%7D%20%5Csum_%7Bi=1%7D%5En%20%5Cleft%20%5B%20%5Cfrac%7B1%7D%7Bm%7D%20%5Csum_%7Bj=1%7D%5Em%20%20%5C%7C%20Y_i%20-%20g(X_i,%20%5Cvarepsilon_%7Bi,j%7D)%20%5C%7C%20-%20%5Cfrac%7B1%7D%7B2m(m-1)%7D%20%5Csum_%7Bj=1%7D%5Em%20%5Csum_%7Bj'=1%7D%5Em%20%5C%7Cg(X_i,%20%5Cvarepsilon_%7Bi,%20j%7D)%20-%20g(X_i,%20%5Cvarepsilon_%7Bi,%20j'%7D)%20%20%5C%7C%5Cright%5D%0A"></p>
<p>Once trained, <img src="https://latex.codecogs.com/png.latex?g"> acts as an <em>implicit</em> and <em>generative</em> model for <img src="https://latex.codecogs.com/png.latex?Y%20%5Cmid%20X">. That is, <img src="https://latex.codecogs.com/png.latex?g"> won’t give us an explicit density <img src="https://latex.codecogs.com/png.latex?P(Y=y%20%5Cmid%20X=x)">, but we can use it to obtain samples from <img src="https://latex.codecogs.com/png.latex?Y%20%5Cmid%20X=x">. We independently draw as many <img src="https://latex.codecogs.com/png.latex?%5Cvarepsilon">’s as the number of samples we want and feed them into <img src="https://latex.codecogs.com/png.latex?g"> to get <img src="https://latex.codecogs.com/png.latex?%5C%7B%20g(x,%20%5Cvarepsilon_j)%20%5C%7D_%7Bj=1%7D%5EK">. With these samples, we can estimate means, medians, confidence intervals, etc. as usual.</p>
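<p>Sampling from the fitted model is then just repeated forward passes with fresh noise. Here is a self-contained sketch; instead of a trained network it uses a closed-form stand-in generator <code>g(x, eps) = x + 0.1 * eps</code>, whose conditional law is <img src="https://latex.codecogs.com/png.latex?N(x,%200.1%5E2)"> by construction, so we can check the estimates.</p>

```python
import numpy as np

def sample_conditional(g, x, n_samples=5000, seed=None):
    """Draw n_samples from the learned Y | X = x by pushing
    independent fresh noise draws through the generator g."""
    rng = np.random.default_rng(seed)
    eps = rng.standard_normal(n_samples)  # one eps per desired sample
    return g(x, eps)

# Stand-in generator with known conditional distribution N(x, 0.1^2)
g = lambda x, eps: x + 0.1 * eps

samples = sample_conditional(g, x=2.0, seed=0)
estimate_mean = samples.mean()                 # close to 2.0
lo, hi = np.quantile(samples, [0.025, 0.975])  # close to (1.80, 2.20)
```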
<p>Note: While we focused on the main loss explored in the paper, the authors mention several generalizations in Appendix D. Specifically, the energy score is one example of a broader class of kernel scores, and engression can in principle use any proper scoring rule that characterizes a distribution (see Section 2). I wouldn’t be surprised if future work develops variants that emphasize the tails, which could be especially useful in risk-sensitive applications.</p>
</section>
<section id="pre-additive-noise-and-extrapolation" class="level2">
<h2 class="anchored" data-anchor-id="pre-additive-noise-and-extrapolation">Pre-additive noise and extrapolation</h2>
<p>While minimizing the loss above is sufficient to learn the conditional distribution within the training range (assuming <img src="https://latex.codecogs.com/png.latex?g"> is expressive enough), the authors show that, under certain assumptions about the noise structure, engression can also support limited extrapolation.</p>
<p>Most regression and generative models assume that the noise is post-additive: the noise <img src="https://latex.codecogs.com/png.latex?%5Ceta"> is added after applying a nonlinear transformation to the covariates, so <img src="https://latex.codecogs.com/png.latex?Y%20=%20g(X)%20+%20%5Ceta">. Pre-additive noise instead assumes <img src="https://latex.codecogs.com/png.latex?Y%20=%20g(X%20+%20%5Ceta)">. This helps with extrapolation because, as the authors note:</p>
<blockquote class="blockquote">
<p>“As such, if the data are generated according to a post-ANM, the observations for the response variable are perturbed values of the true function evaluated at covariate values within the support. We hence generally have no data-driven information about the behaviour of the true function outside the support. In contrast, data generated from a pre-ANM contain response values that reveal some information beyond the support”.</p>
</blockquote>
<div class="quarto-figure quarto-figure-left">
<figure class="figure">
<p><img src="https://ecntu.com/posts/engression/images/noise.svg" class="img-fluid figure-img"></p>
<figcaption>With a pre-additive noise model (right) we gain information beyond the training support. Illustrated here with only two points. The blue and orange dots represent the possible values for <img src="https://latex.codecogs.com/png.latex?x_1"> and <img src="https://latex.codecogs.com/png.latex?x_2"> respectively due to noise from <img src="https://latex.codecogs.com/png.latex?%5Ceta">.</figcaption>
</figure>
</div>
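<p>To make the quote concrete: under post-additive noise the true function is only ever evaluated inside the covariate support, while under pre-additive noise the responses involve <img src="https://latex.codecogs.com/png.latex?g"> evaluated at <img src="https://latex.codecogs.com/png.latex?x%20+%20%5Ceta">, which spills outside it. A toy sketch of the two data-generating processes, using the softplus function from the paper's experiments:</p>

```python
import numpy as np

def softplus(t):
    return np.log1p(np.exp(t))

rng = np.random.default_rng(0)
x = rng.uniform(-2, 2, size=1000)    # training support is [-2, 2]
eta = rng.standard_normal(1000)

y_post = softplus(x) + eta           # post-ANM: Y = g(X) + eta
y_pre = softplus(x + eta)            # pre-ANM:  Y = g(X + eta)

# Post-additive responses only involve g evaluated inside [-2, 2].
# Pre-additive responses evaluate g at x + eta, well beyond the support:
inputs_seen = x + eta                # extends past both edges of [-2, 2]
```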
<p>The authors formalize this and show that under certain structural assumptions — like smoothness or monotonicity of <img src="https://latex.codecogs.com/png.latex?g">, and symmetric pre-additive noise — engression can recover aspects of <img src="https://latex.codecogs.com/png.latex?g"> beyond the training range. A key idea is that larger input noise gives you more indirect coverage of nearby regions, so the model can “see” a bit past the edge of the data. I won’t go into the technical details here, but the extrapolation results are interesting and worth checking out if you’re curious.</p>
</section>
<section id="implementation" class="level2">
<h2 class="anchored" data-anchor-id="implementation">Implementation</h2>
<p>We walk through how simple engression is to implement and attempt to reproduce Figure 4 from the paper, which highlights its extrapolation capabilities using synthetic data.</p>
<p>We implement <img src="https://latex.codecogs.com/png.latex?g(x,%20%5Cvarepsilon)%20%5Cmapsto%20y"> as an MLP. In the paper, the authors feed the concatenated vector <img src="https://latex.codecogs.com/png.latex?%5Bx,%20%5Cvarepsilon%5D"> into a standard (deterministic) network. To keep things flexible and modular — and since <img src="https://latex.codecogs.com/png.latex?%5Cvarepsilon"> must be sampled independently of <img src="https://latex.codecogs.com/png.latex?x"> — we can instead:</p>
<div id="485baf7e" class="cell">
<details open="" class="code-fold">
<summary>Defining g(x, eps)</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb1" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">class</span> gConcatenate(nn.Module):</span>
<span id="cb1-2">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">""" g(x, eps) = g([x, eps]), where eps ~ N(0, 1) or Unif(0, 1) """</span></span>
<span id="cb1-3"></span>
<span id="cb1-4">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, model, noise_dim, noise_type <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'normal'</span>, scale <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">1.0</span>):</span>
<span id="cb1-5">        <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">super</span>(gConcatenate, <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>).<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>()</span>
<span id="cb1-6">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> model</span>
<span id="cb1-7">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.noise_dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> noise_dim</span>
<span id="cb1-8">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.noise_f <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.randn <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> noise_type <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'normal'</span> <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">else</span> torch.rand</span>
<span id="cb1-9">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.scale <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> scale</span>
<span id="cb1-10"></span>
<span id="cb1-11">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> forward(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, x):</span>
<span id="cb1-12">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">assert</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(x.shape) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&gt;=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'x must have at least 2 dims, where batch is the first'</span></span>
<span id="cb1-13">        eps <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.noise_f(x.shape[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>], <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.noise_dim).to(x.device)</span>
<span id="cb1-14">        eps <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.scale <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> eps</span>
<span id="cb1-15">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.model(torch.cat([x, eps], dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>))</span>
<span id="cb1-16">    </span>
<span id="cb1-17"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># example</span></span>
<span id="cb1-18">model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> MLP(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, [<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">64</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">64</span>], <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)</span>
<span id="cb1-19">g <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> gConcatenate(model, noise_dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">100</span>, noise_type <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'uniform'</span>)</span></code></pre></div></div>
</details>
</div>
<p>Of course, there are many different ways one could define <img src="https://latex.codecogs.com/png.latex?g"> beyond concatenating the noise with the input. I’m sure future work will explore them.</p>
<p>Now, define the loss function:</p>
<div id="24ad86f1" class="cell" data-execution_count="87">
<details open="" class="code-fold">
<summary>engression loss</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb2" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> engression_loss(y, preds, return_terms <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>):</span>
<span id="cb2-2">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">"""</span></span>
<span id="cb2-3"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    Args:</span></span>
<span id="cb2-4"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">        y (torch.Tensor): True target values (batch_size, output_dim).</span></span>
<span id="cb2-5"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">        preds (torch.Tensor): Predicted target with independently sampled noise (batch_size, m_samples, output_dim).</span></span>
<span id="cb2-6"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">        return_terms (bool): If True, return the individual terms of the loss.</span></span>
<span id="cb2-7"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    """</span></span>
<span id="cb2-8">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">assert</span> y.shape[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> preds.shape[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>] <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">and</span> y.shape[<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> preds.shape[<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>]</span>
<span id="cb2-9">    b, d <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> y.shape</span>
<span id="cb2-10">    b, m, d <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> preds.shape</span>
<span id="cb2-11"></span>
<span id="cb2-12">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Term 1: the absolute error between the predicted and true values</span></span>
<span id="cb2-13">    term1 <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.linalg.vector_norm(preds <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> y[:, <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>, :], <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">ord</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>).mean()</span>
<span id="cb2-14"></span>
<span id="cb2-15">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Term 2: pairwise absolute differences between the predicted values</span></span>
<span id="cb2-16">    term2 <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.tensor(<span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.0</span>, device <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> preds.device, dtype <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> preds.dtype)</span>
<span id="cb2-17"></span>
<span id="cb2-18">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> m <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&gt;</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>:</span>
<span id="cb2-19">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># cdist is convenient. The result shape before mean is (b, m, m).</span></span>
<span id="cb2-20">        mean_pairwise_l2_dists <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.cdist(preds, preds, p <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>).mean()</span>
<span id="cb2-21">        term2 <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.5</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> mean_pairwise_l2_dists <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> m <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> (m <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)</span>
<span id="cb2-22">        </span>
<span id="cb2-23">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> (term1 <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> term2, term1, term2) <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> return_terms <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">else</span> term1 <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> term2</span></code></pre></div></div>
</details>
</div>
<p>And a simple training loop:</p>
<div id="f6ca290b" class="cell" data-execution_count="88">
<details open="" class="code-fold">
<summary>Training loop</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb3" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> train(g, dl, m, lr <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.001</span>, epochs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">100</span>, verbose <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>):</span>
<span id="cb3-2"></span>
<span id="cb3-3">    optimizer <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.optim.Adam(g.parameters(), lr <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> lr)</span>
<span id="cb3-4"></span>
<span id="cb3-5">    losses <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> []</span>
<span id="cb3-6">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> _ <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(epochs):</span>
<span id="cb3-7">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> x, y <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> dl:</span>
<span id="cb3-8"></span>
<span id="cb3-9">            g.zero_grad()</span>
<span id="cb3-10"></span>
<span id="cb3-11">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Generate m samples from the model</span></span>
<span id="cb3-12">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># shape: (batch_size, m, output_dim)</span></span>
<span id="cb3-13">            preds <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.stack([g(x) <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> _ <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(m)], dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)</span>
<span id="cb3-14"></span>
<span id="cb3-15">            loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> engression_loss(y, preds)</span>
<span id="cb3-16"></span>
<span id="cb3-17">            loss.backward()</span>
<span id="cb3-18">            optimizer.step()</span>
<span id="cb3-19"></span>
<span id="cb3-20">            losses.append(loss.item())</span>
<span id="cb3-21">            <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> verbose: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(loss.item())</span>
<span id="cb3-22">    </span>
<span id="cb3-23">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> losses</span></code></pre></div></div>
</details>
</div>
<p>Finally, we attempt to replicate Figure 4 from the paper using simulated data with pre-additive noise. For the <code>softplus</code> case, for example, we generate <img src="https://latex.codecogs.com/png.latex?X%20%5Csim%20%5Ctext%7BUnif%7D%5B-2,%202%5D">, <img src="https://latex.codecogs.com/png.latex?%5Ceta%20%5Csim%20%5Cmathcal%20N(0,%201)"> and set <img src="https://latex.codecogs.com/png.latex?Y%20=%20g%5E%5Cast(X%20+%20%5Ceta)"> with <img src="https://latex.codecogs.com/png.latex?g%5E%5Cast(t)%20=%20%5Clog(1+%5Cexp(t))">. We then train simple MLPs using the engression loss and evaluate their ability to recover the true median, given by <img src="https://latex.codecogs.com/png.latex?g%5E%5Cast(x)">, especially outside the range of the training data.</p>
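<p>As a quick sanity check (not part of the post's code below): because <code>softplus</code> is strictly increasing, the median commutes with it, so the conditional median of <code>softplus(x + eta)</code> at a fixed point <code>x</code> is exactly <code>softplus(x)</code>. A few lines of simulation confirm this:</p>

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

x = torch.tensor(1.0)        # a fixed evaluation point
eta = torch.randn(100_000)   # eta ~ N(0, 1)
y = F.softplus(x + eta)      # pre-additive noise: Y = g*(x + eta)

# softplus is strictly increasing, so the median passes through it:
# median(g*(x + eta)) = g*(x + median(eta)) = g*(x)
emp_median = y.median().item()
true_median = F.softplus(x).item()
```

<p>This is the target the engression-trained sampler is graded against: the empirical median of its samples at each <code>x</code> should track <code>g*(x)</code>.</p>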
<div id="63f80c15" class="cell" data-execution_count="90">
<details class="code-fold">
<summary>Synthetic data</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb4" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1">g_stars <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> {</span>
<span id="cb4-2">    <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'softplus'</span>: F.softplus,</span>
<span id="cb4-3">    <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'square'</span>: <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">lambda</span> x: (torch.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">max</span>(x, torch.tensor(<span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.0</span>)) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">**</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>,</span>
<span id="cb4-4">    <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'cubic'</span>: <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">lambda</span> x: (x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">**</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>,</span>
<span id="cb4-5">    <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'log'</span>: <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">lambda</span> x: (x<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> math.log(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>)<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span>(x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&lt;=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> (torch.log(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> x<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span>(x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&gt;</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>)))<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span>(x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&gt;</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>) </span>
<span id="cb4-6">}</span>
<span id="cb4-7"></span>
<span id="cb4-8">x_train_lims <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> {</span>
<span id="cb4-9">    <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'softplus'</span>: (<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>),</span>
<span id="cb4-10">    <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'square'</span>: (<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>),</span>
<span id="cb4-11">    <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'cubic'</span>: (<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>),</span>
<span id="cb4-12">    <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'log'</span>: (<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>)</span>
<span id="cb4-13">}</span>
<span id="cb4-14"></span>
<span id="cb4-15">x_test_lims <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> {</span>
<span id="cb4-16">    <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'softplus'</span>: (<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">8</span>),</span>
<span id="cb4-17">    <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'square'</span>: (<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">6</span>),</span>
<span id="cb4-18">    <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'cubic'</span>: (<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">6</span>),</span>
<span id="cb4-19">    <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'log'</span>: (<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span>)</span>
<span id="cb4-20">}</span>
<span id="cb4-21"></span>
<span id="cb4-22"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> train_data(n, input_dim, <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">type</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'softplus'</span>, pre_additive_noise <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>):</span>
<span id="cb4-23"></span>
<span id="cb4-24">    g_star <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> g_stars[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">type</span>]</span>
<span id="cb4-25"></span>
<span id="cb4-26">    a, b <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> x_train_lims[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">type</span>]</span>
<span id="cb4-27">    X <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.rand((n, input_dim)) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> (b <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> a) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> a <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Unif(a, b)</span></span>
<span id="cb4-28"></span>
<span id="cb4-29">    sd <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span> <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">type</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">!=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'cubic'</span> <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">else</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">1.1</span></span>
<span id="cb4-30">    eta <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.randn((n, input_dim)) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> sd</span>
<span id="cb4-31">    </span>
<span id="cb4-32">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> pre_additive_noise:</span>
<span id="cb4-33">        Y <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> g_star(X <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> eta)</span>
<span id="cb4-34">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">else</span>:</span>
<span id="cb4-35">        Y <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> g_star(X) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> eta</span>
<span id="cb4-36"></span>
<span id="cb4-37">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> X, Y</span>
<span id="cb4-38"></span>
<span id="cb4-39"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># example</span></span>
<span id="cb4-40">X, Y <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> train_data(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1000</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">type</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'softplus'</span>, pre_additive_noise <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span></code></pre></div></div>
</details>
</div>
<div id="e01519ce" class="cell" data-execution_count="94">
<details class="code-fold">
<summary>Figure 4</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb5" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1">input_dim, hidden_dim, output_dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">128</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span></span>
<span id="cb5-2">noise_dim, m_train, m_pred <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">100</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">512</span></span>
<span id="cb5-3"></span>
<span id="cb5-4">batch_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1024</span></span>
<span id="cb5-5">epochs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">20</span></span>
<span id="cb5-6"></span>
<span id="cb5-7">ds_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">lambda</span> name: <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">50_000</span> <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> name <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'softplus/square'</span> <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">else</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">100_000</span> <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># 50k, 100k</span></span>
<span id="cb5-8">depth <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">lambda</span> name: <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span> <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> name <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'softplus/square'</span> <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">else</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span></span>
<span id="cb5-9">hidden_dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">100</span></span>
<span id="cb5-10">lr <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">lambda</span> name: <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">5e-2</span></span>
<span id="cb5-11"></span>
<span id="cb5-12">n_runs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">20</span></span>
<span id="cb5-13">pre_additive_noise <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span></span>
<span id="cb5-14"></span>
<span id="cb5-15">saved_g_preds <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> {}</span>
<span id="cb5-16">saved_l1_preds <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> {}</span>
<span id="cb5-17"></span>
<span id="cb5-18">extra <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'_post'</span> <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">not</span> pre_additive_noise <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">else</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">''</span></span>
<span id="cb5-19"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># saved_g_preds = torch.load(f'logs/saved_g_preds{extra}.pt')</span></span>
<span id="cb5-20"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># saved_l1_preds = torch.load(f'logs/saved_l1_preds{extra}.pt')</span></span>
<span id="cb5-21"></span>
<span id="cb5-22"></span>
<span id="cb5-23">f, axs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> plt.subplots(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, figsize <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">7</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">7</span>))</span>
<span id="cb5-24"></span>
<span id="cb5-25"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> (name, g_star), ax <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">zip</span>(g_stars.items(), axs.flatten()):</span>
<span id="cb5-26"></span>
<span id="cb5-27">    t <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.linspace(x_test_lims[name][<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>], x_test_lims[name][<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>], <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">25</span>)[:, <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>]</span>
<span id="cb5-28"></span>
<span id="cb5-29">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Cache to fix plot, etc</span></span>
<span id="cb5-30">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> name <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> saved_g_preds:</span>
<span id="cb5-31">        g_preds <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> saved_g_preds[name]</span>
<span id="cb5-32"></span>
<span id="cb5-33">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">else</span>:</span>
<span id="cb5-34">        g_preds <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> []</span>
<span id="cb5-35">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> seed <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(n_runs):</span>
<span id="cb5-36"></span>
<span id="cb5-37">            set_seed(seed)</span>
<span id="cb5-38">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># model = MLP(input_dim + noise_dim, [hidden_dim] * depth(name), output_dim)</span></span>
<span id="cb5-39">            model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> ResMLP(input_dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> noise_dim, hidden_dim, depth(name), output_dim)</span>
<span id="cb5-40">            g <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> gConcatenate(model, noise_dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> noise_dim, noise_type <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'normal'</span>)</span>
<span id="cb5-41"></span>
<span id="cb5-42">            X, Y <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> train_data(ds_size(name), input_dim, <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">type</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> name, pre_additive_noise <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> pre_additive_noise)</span>
<span id="cb5-43">            dl <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.utils.data.DataLoader(<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">list</span>(<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">zip</span>(X, Y)), batch_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> batch_size <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> batch_size <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">else</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(Y), shuffle <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>, drop_last <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb5-44"></span>
<span id="cb5-45">            losses <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> train(g, dl, m_train, epochs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> epochs, verbose <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>, lr <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> lr(name))</span>
<span id="cb5-46"></span>
<span id="cb5-47">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Predict the median using g: sample m_pred per point and take the median</span></span>
<span id="cb5-48">            <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">with</span> torch.no_grad():</span>
<span id="cb5-49">                g.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">eval</span>()</span>
<span id="cb5-50">                g_pred <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.quantile(</span>
<span id="cb5-51">                    torch.stack([g(t) <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> _ <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(m_pred)], dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>), q <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.5</span>, dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span></span>
<span id="cb5-52">                )</span>
<span id="cb5-53">                g_preds.append(g_pred)</span>
<span id="cb5-54">        </span>
<span id="cb5-55">        g_preds <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.stack(g_preds)</span>
<span id="cb5-56">        saved_g_preds[name] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> g_preds</span>
<span id="cb5-57"></span>
<span id="cb5-58">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># L1 baseline</span></span>
<span id="cb5-59">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> name <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> saved_l1_preds:</span>
<span id="cb5-60">        l1_preds <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> saved_l1_preds[name]</span>
<span id="cb5-61">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">else</span>:</span>
<span id="cb5-62">        l1_preds <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> []</span>
<span id="cb5-63">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> seed <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(n_runs):</span>
<span id="cb5-64">            set_seed(seed)</span>
<span id="cb5-65">            model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> ResMLP(input_dim, hidden_dim, depth(name), output_dim)</span>
<span id="cb5-66">            X, Y <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> train_data(ds_size(name), input_dim, <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">type</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> name, pre_additive_noise <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb5-67">            dl <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.utils.data.DataLoader(<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">list</span>(<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">zip</span>(X, Y)), batch_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> batch_size <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> batch_size <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">else</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(Y), shuffle <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>, drop_last <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb5-68">            l1_losses <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> train_l1(model, dl, epochs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> epochs, verbose <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>, lr <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> lr(name))</span>
<span id="cb5-69"></span>
<span id="cb5-70">            <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">with</span> torch.no_grad():</span>
<span id="cb5-71">                model.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">eval</span>()</span>
<span id="cb5-72">                l1_preds.append(model(t))</span>
<span id="cb5-73"></span>
<span id="cb5-74">        l1_preds <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.stack(l1_preds)</span>
<span id="cb5-75">        saved_l1_preds[name] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> l1_preds</span>
<span id="cb5-76"></span>
<span id="cb5-77"></span>
<span id="cb5-78">    X, Y <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> train_data(ds_size(name), <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">type</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> name, pre_additive_noise <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb5-79">    ax.scatter(X[:<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">5000</span>], Y[:<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">5000</span>], color <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'gray'</span>,  alpha <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.1</span>, s <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)</span>
<span id="cb5-80">    ax.plot(t, g_star(t), label <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'True g'</span>, color <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'red'</span>)</span>
<span id="cb5-81"></span>
<span id="cb5-82">    ax.plot(t, l1_preds.mean(dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>), color <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'tab:blue'</span>, label <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'L1'</span>)</span>
<span id="cb5-83">    ax.fill_between(t.flatten(),</span>
<span id="cb5-84">        l1_preds.quantile(<span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.10</span>, dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>).flatten(),</span>
<span id="cb5-85">        l1_preds.quantile(<span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.90</span>, dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>).flatten(),</span>
<span id="cb5-86">        alpha <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.2</span>, color <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'tab:blue'</span></span>
<span id="cb5-87">    )</span>
<span id="cb5-88"></span>
<span id="cb5-89">    ax.plot(t, g_preds.mean(dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>), color <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'tab:orange'</span>, label <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'engression'</span>)</span>
<span id="cb5-90">    ax.fill_between(t.flatten(),</span>
<span id="cb5-91">        g_preds.quantile(<span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.10</span>, dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>).flatten(),</span>
<span id="cb5-92">        g_preds.quantile(<span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.90</span>, dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>).flatten(),</span>
<span id="cb5-93">        alpha <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.2</span>, color <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'tab:orange'</span></span>
<span id="cb5-94">    )</span>
<span id="cb5-95">    ax.set_title(name)</span>
<span id="cb5-96"></span>
<span id="cb5-97"></span>
<span id="cb5-98">axs[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>].legend()</span>
<span id="cb5-99">axs[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>].set_ylim((<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">5</span>))</span>
<span id="cb5-100">axs[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>].set_ylim((<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">80</span>))</span>
<span id="cb5-101"></span>
<span id="cb5-102">f.tight_layout()</span>
<span id="cb5-103">f.savefig(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'images/figure4</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>extra<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">.png'</span>, dpi <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">300</span>, bbox_inches <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'tight'</span>)</span>
<span id="cb5-104"></span>
<span id="cb5-105"></span>
<span id="cb5-106">torch.save(saved_g_preds, <span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'logs/saved_g_preds</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>extra<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">.pt'</span>)</span>
<span id="cb5-107">torch.save(saved_l1_preds, <span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'logs/saved_l1_preds</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>extra<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">.pt'</span>)</span></code></pre></div></div>
</details>
<div class="cell-output cell-output-display">
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<p><img src="https://ecntu.com/posts/engression/index_files/figure-html/cell-9-output-1.png" class="img-fluid figure-img"></p>
<figcaption>Comparing engression and L1 regression’s extrapolation performance on synthetic data. Lines are predicted and true (red) conditional medians. Bands are 10-90 percentiles over 20 runs.</figcaption>
</figure>
</div>
</div>
</div>
<p>We find, as the authors do, that while both methods perform similarly in-domain, engression extrapolates much better than L1 regression, at least for monotone functions and data generated with pre-additive noise.</p>
<div class="callout callout-style-default callout-note no-icon callout-titled" title="What if the noise is not pre-additive?">
<div class="callout-header d-flex align-content-center collapsed" data-bs-toggle="collapse" data-bs-target=".callout-1-contents" aria-controls="callout-1" aria-expanded="false" aria-label="Toggle callout">
<div class="callout-icon-container">
<i class="callout-icon no-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>What if the noise is not pre-additive?
</div>
<div class="callout-btn-toggle d-inline-block border-0 py-1 ps-1 pe-0 float-end"><i class="callout-toggle"></i></div>
</div>
<div id="callout-1" class="callout-1-contents callout-collapse collapse">
<div class="callout-body-container callout-body">
<p>If we generate the synthetic data with post-additive noise instead, we see that engression loses its extrapolation capability — in accordance with the theory presented in the paper.</p>
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<p><img src="https://ecntu.com/posts/engression/images/figure4_post.png" class="img-fluid figure-img"></p>
<figcaption>Figure 4 repeated with data generated with post-additive noise instead.</figcaption>
</figure>
</div>
<p>A natural question, then, is how to diagnose whether the noise in a real dataset is pre- or post-additive. The paper doesn’t address this directly, but it’s an important practical question.</p>
</div>
</div>
</div>
<p>While reproducing Figure 4, we also noted a few practical details that matter more than expected. Despite the simplicity of the functions, the authors used between 50k and 100k samples, depending on the function, and relatively large networks. In our experiments, the cubic and logarithmic scenarios struggled with extrapolation until we added residual connections. Also, the noise dimension was set to 100, which seems surprisingly high but turned out to be important.</p>
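<p>For concreteness, here is one possible <code>ResMLP</code> along the lines we used (a sketch; the exact block structure is our assumption): a linear input projection followed by residual fully-connected blocks.</p>

```python
import torch
import torch.nn as nn

class ResMLP(nn.Module):
    # Sketch of a residual MLP: input projection, residual FC blocks,
    # output projection. The exact block structure is our assumption.
    def __init__(self, in_dim, hidden_dim, n_blocks, out_dim):
        super().__init__()
        self.inp = nn.Linear(in_dim, hidden_dim)
        self.blocks = nn.ModuleList([
            nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU())
            for _ in range(n_blocks)
        ])
        self.out = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        h = self.inp(x)
        for block in self.blocks:
            h = h + block(h)  # residual connection
        return self.out(h)

model = ResMLP(in_dim=3, hidden_dim=16, n_blocks=2, out_dim=1)
out = model(torch.randn(4, 3))
```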
</section>
<section id="hyperparameters" class="level2">
<h2 class="anchored" data-anchor-id="hyperparameters">Hyperparameters</h2>
<p>Engression introduces a couple of hyperparameters related to the noise vectors (<img src="https://latex.codecogs.com/png.latex?%5Cvarepsilon">) sampled during training: namely, <img src="https://latex.codecogs.com/png.latex?m">, the number of noise samples drawn per example, and the distribution from which <img src="https://latex.codecogs.com/png.latex?%5Cvarepsilon"> is drawn. For example, in their synthetic data experiments, the authors set <img src="https://latex.codecogs.com/png.latex?m%20=%202"> and sample <img src="https://latex.codecogs.com/png.latex?%5Cvarepsilon%20%5Csim%20%5Ctext%7BUnif%7D%5B0,%201%5D%5E%7B100%7D">.</p>
<p>The parameter <img src="https://latex.codecogs.com/png.latex?m"> controls a compute-variance trade-off during training. Increasing <img src="https://latex.codecogs.com/png.latex?m"> means we obtain a cleaner (lower variance) Monte Carlo estimate of the population engression loss and hence gradients. However, the cost of computing the loss grows linearly in <img src="https://latex.codecogs.com/png.latex?m"> for the first term but quadratically for the second, since we compute pairwise distances.</p>
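<p>To make the two terms concrete, here is a minimal sketch of an energy-score engression loss (our own implementation, assuming <code>preds</code> stacks the <img src="https://latex.codecogs.com/png.latex?m"> model samples along dimension 1): the first term touches each sample once, while the second compares all pairs.</p>

```python
import torch

def engression_loss(y, preds, return_terms=False):
    # y: [batch, dim]; preds: [batch, m, dim] -- m model samples per example.
    # Loss = E||g(x, eps) - y|| - (1/2) E||g(x, eps) - g(x, eps')||.
    m = preds.shape[1]
    # First term: O(m) distances from each sample to the observed target.
    term1 = (preds - y[:, None, :]).norm(dim=-1).mean()
    # Second term: O(m^2) pairwise distances between samples (the diagonal
    # is zero, so dividing by m * (m - 1) averages the off-diagonal pairs).
    diffs = preds[:, :, None, :] - preds[:, None, :, :]
    term2 = diffs.norm(dim=-1).sum(dim=(1, 2)).div(m * (m - 1)).mean() / 2
    loss = term1 - term2
    return (loss, term1, term2) if return_terms else loss

# Degenerate check: identical samples make the pairwise term vanish.
loss, t1, t2 = engression_loss(
    torch.zeros(4, 1), torch.ones(4, 3, 1), return_terms=True
)
```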
<p>That said, since these are Monte Carlo estimates, we expect diminishing returns as <img src="https://latex.codecogs.com/png.latex?m"> increases. Another mitigating factor is that we can offset the quadratic cost by training for more epochs: the model then revisits each example with fresh noise vectors, effectively averaging over more noisy estimates of the loss.</p>
<p>To test this, we vary <img src="https://latex.codecogs.com/png.latex?m"> on a toy example. We observe that, given enough data and epochs, all runs eventually reach similar loss values. However, models with higher <img src="https://latex.codecogs.com/png.latex?m"> converge faster and more stably. Perhaps not surprisingly, the conditional median tends to be learned first, with the tails filling in later.</p>
<div id="b2df4995" class="cell">
<details class="code-fold">
<summary>Vary <img src="https://latex.codecogs.com/png.latex?m"> on toy data</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb6" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1">n <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">5_000</span></span>
<span id="cb6-2">X <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.rand(n) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>               <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># X ~ Unif(-1, 1)</span></span>
<span id="cb6-3">Y <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> X <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> (torch.randn_like(X) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)       <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Y|X ~ N(X, 1)</span></span>
<span id="cb6-4">X, Y <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> X[:, <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>], Y[:, <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>]</span>
<span id="cb6-5"></span>
<span id="cb6-6">t <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.linspace(<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">25</span>)[:, <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>]</span>
<span id="cb6-7"></span>
<span id="cb6-8"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># True Y|X quantiles if we assume Y|X ~ N(X, 1)</span></span>
<span id="cb6-9">qs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [<span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.1</span>, <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.5</span>, <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.9</span>]</span>
<span id="cb6-10">true_quantiles <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> {q: t <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> norm.ppf(q) <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> q <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> qs}</span>
<span id="cb6-11"></span>
<span id="cb6-12"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Estimated Y|X quantiles</span></span>
<span id="cb6-13"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># [m_train, quantile level] = list of [est quantile], one per batch</span></span>
<span id="cb6-14">quantiles <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> {}</span>
<span id="cb6-15"></span>
<span id="cb6-16">input_dim, hidden_dim, output_dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">64</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span></span>
<span id="cb6-17">noise_dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">50</span></span>
<span id="cb6-18">mini_batch_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">32</span></span>
<span id="cb6-19">m_pred <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">512</span></span>
<span id="cb6-20"></span>
<span id="cb6-21"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># noise_scale = 1</span></span>
<span id="cb6-22"></span>
<span id="cb6-23">stats <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> defaultdict(<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">list</span>)</span>
<span id="cb6-24"></span>
<span id="cb6-25">m_trains <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">8</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">16</span>]</span>
<span id="cb6-26">noise_scales <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [<span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>]</span>
<span id="cb6-27"></span>
<span id="cb6-28"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> m_train, noise_scale <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> product(m_trains, noise_scales):</span>
<span id="cb6-29"></span>
<span id="cb6-30">    set_seed(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">69</span>)</span>
<span id="cb6-31">    dl <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.utils.data.DataLoader(<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">list</span>(<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">zip</span>(X, Y)), batch_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> mini_batch_size, shuffle <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>)</span>
<span id="cb6-32">    model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> ResMLP(in_dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> input_dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> noise_dim, hidden_dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> hidden_dim, n_blocks <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, out_dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> output_dim)</span>
<span id="cb6-33">    g <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> gConcatenate(model, noise_dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> noise_dim, scale <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> noise_scale)</span>
<span id="cb6-34"></span>
<span id="cb6-35">    optimizer <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.optim.Adam(g.parameters(), lr <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">1e-4</span>)</span>
<span id="cb6-36"></span>
<span id="cb6-37"></span>
<span id="cb6-38">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> _ <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>):</span>
<span id="cb6-39">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> x, y <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> dl:</span>
<span id="cb6-40"></span>
<span id="cb6-41">            g.zero_grad()</span>
<span id="cb6-42"></span>
<span id="cb6-43">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Generate m samples from the model</span></span>
<span id="cb6-44">            preds <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.stack([g(x) <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> _ <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(m_train)], dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)</span>
<span id="cb6-45">            loss, loss_1, loss_2 <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> engression_loss(y, preds, return_terms <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb6-46">            loss.backward()</span>
<span id="cb6-47">            optimizer.step()</span>
<span id="cb6-48"></span>
<span id="cb6-49">            <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">with</span> torch.no_grad():</span>
<span id="cb6-50"></span>
<span id="cb6-51">                q_losses <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> []</span>
<span id="cb6-52">                preds <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.stack([g(t) <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> _ <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(m_pred)], dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)</span>
<span id="cb6-53">                <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> q <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> qs:</span>
<span id="cb6-54">                    pred_quantile <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.quantile(preds, q <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> q, dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)</span>
<span id="cb6-55">                    quantiles[m_train, noise_scale, q] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> quantiles.get((m_train, noise_scale, q), []) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> [pred_quantile]</span>
<span id="cb6-56">                    </span>
<span id="cb6-57">                    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> q <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> true_quantiles:</span>
<span id="cb6-58">                        stats[m_train, noise_scale, q].append(</span>
<span id="cb6-59">                            (true_quantiles[q] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> pred_quantile.mean(dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>)).flatten().<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">abs</span>().mean().item()</span>
<span id="cb6-60">                        )</span>
<span id="cb6-61"></span>
<span id="cb6-62">                stats[m_train, noise_scale, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'loss'</span>].append(loss.item())</span>
<span id="cb6-63">                stats[m_train, noise_scale, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'loss_1'</span>].append(loss_1.item())</span>
<span id="cb6-64">                stats[m_train, noise_scale, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'loss_2'</span>].append(loss_2.item())</span>
<span id="cb6-65"></span>
<span id="cb6-66">torch.save((quantiles, stats), <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'logs/m_noise_scale.pt'</span>)</span></code></pre></div></div>
</details>
</div>
<div id="fd9269f9" class="cell" data-execution_count="9">
<div class="cell-output cell-output-display">
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<p><img src="https://ecntu.com/posts/engression/index_files/figure-html/cell-14-output-1.png" class="img-fluid figure-img"></p>
<figcaption>Effects of varying <img src="https://latex.codecogs.com/png.latex?m"> on a simple example. Increasing <img src="https://latex.codecogs.com/png.latex?m"> yields more stable losses and faster reduction in the error estimating <img src="https://latex.codecogs.com/png.latex?Y%7CX"> quantiles.</figcaption>
</figure>
</div>
</div>
</div>
<p>While the only formal requirement on the noise distribution is that <img src="https://latex.codecogs.com/png.latex?%5Cvarepsilon"> be independent of <img src="https://latex.codecogs.com/png.latex?x">, its specific choice can have important practical implications.</p>
<p>For instance, in our imperfect toy setup <img src="https://latex.codecogs.com/png.latex?%5Csigma_%5Cvarepsilon"> seems to act as a smoothing or locality parameter: large values encourage distant <img src="https://latex.codecogs.com/png.latex?x">’s to map to the same <img src="https://latex.codecogs.com/png.latex?y">, producing smoother estimates. Imagining a one-dimensional <img src="https://latex.codecogs.com/png.latex?%5Cvarepsilon"> with <img src="https://latex.codecogs.com/png.latex?g(x,%20%5Cvarepsilon)%20=%20g(x%20+%20%5Cvarepsilon)"> should help build this intuition.</p>
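<p>A tiny numeric illustration of this intuition (ours, not from the paper): averaging a step function over additive gaussian noise shows how a larger <img src="https://latex.codecogs.com/png.latex?%5Csigma_%5Cvarepsilon"> acts like a smoothing bandwidth.</p>

```python
import torch

torch.manual_seed(0)

def smoothed(f, x, sigma, m=100_000):
    # Monte Carlo estimate of E[f(x + eps)] with eps ~ N(0, sigma^2).
    eps = torch.randn(m) * sigma
    return f(x + eps).mean().item()

step = lambda x: (x > 0).float()

# Just right of the discontinuity at 0: small sigma keeps the step sharp,
# large sigma blurs targets from both sides of the jump together.
sharp = smoothed(step, torch.tensor(0.1), sigma=0.01)   # near 1
blurry = smoothed(step, torch.tensor(0.1), sigma=1.0)   # near 0.54
```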
<div id="917c82b5" class="cell">
<div class="cell-output cell-output-display">
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<p><img src="https://ecntu.com/posts/engression/index_files/figure-html/cell-15-output-1.png" class="img-fluid figure-img"></p>
<figcaption>Conditional medians learned after training with gaussian noise with different <img src="https://latex.codecogs.com/png.latex?%5Csigma_%5Cvarepsilon">, with <img src="https://latex.codecogs.com/png.latex?m=2">.</figcaption>
</figure>
</div>
</div>
</div>
<p>So, while a sufficiently expressive model should in principle learn to correctly map any i.i.d. noise, in practice the choice of noise distribution may affect convergence speed and the quality of the learned conditional distribution.</p>
</section>
<section id="real-data" class="level2">
<h2 class="anchored" data-anchor-id="real-data">Real data</h2>
<p>So far, we’ve focused on toy problems and synthetic experiments, but the authors also evaluate the method on real datasets in Section 4. They benchmark both univariate and multivariate tasks — including point prediction, interval estimation, and full distributional modeling — and compare it against standard approaches like L2 and L1 regression, as well as <a href="https://www.jmlr.org/papers/volume7/meinshausen06a/meinshausen06a.pdf">quantile regression forests</a>. Engression consistently outperforms these baselines, especially when extrapolating.</p>
<p>In their experiments, the authors use an MLP trained directly on the tabular inputs. But it’s <a href="https://arxiv.org/abs/2207.08815">well known</a> that deep nets are not particularly strong on tabular data (although this might be <a href="https://www.nature.com/articles/s41586-024-08328-6">changing</a>). That said, there’s nothing inherent in engression that requires us to use a neural network end-to-end. We could just as easily treat the engression-trained MLP as a modular head and stack it on top of any strong tabular model.</p>
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<p><img src="https://ecntu.com/posts/engression/images/stacking_engression.svg" class="img-fluid figure-img"></p>
<figcaption>We can stack engression on top of any base model</figcaption>
</figure>
</div>
<p>I tried <a href="_temperature.ipynb">this</a> idea on a probabilistic forecasting <a href="https://www.kaggle.com/competitions/probabilistic-forecasting-i-temperature/overview">competition</a>, where the goal was to predict a set of conditional quantiles (0.05 through 0.95). The winning submission had used <a href="https://catboost.ai/">catboost</a> with a multi-quantile loss, followed by <a href="https://en.wikipedia.org/wiki/Conformal_prediction">conformal prediction</a>. Keeping the base model and replacing the conformalization step with a small engression-trained MLP improved on the winning model’s <a href="https://en.wikipedia.org/wiki/Scoring_rule#Conditional_continuous_ranked_probability_score">CRPS score</a> by about 0.04 on the private test set<sup>1</sup> – without tuning any of the engression hyperparameters. I thought this was impressive, as conformalization is the go-to calibration method right now. Of course, if you wanted finite-sample guarantees, you could still conformalize on top.</p>
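<p>A minimal sketch of what this stacking can look like (all names and sizes are illustrative, not my actual submission): the base model’s point prediction is concatenated with noise, fed through a small MLP trained with the engression loss, and the conditional quantiles are then read off empirically from samples.</p>

```python
import torch
import torch.nn as nn

class EngressionHead(nn.Module):
    """Stochastic head stacked on a frozen base model's predictions."""
    def __init__(self, in_dim=1, noise_dim=8, hidden=64):
        super().__init__()
        self.noise_dim = noise_dim
        self.net = nn.Sequential(
            nn.Linear(in_dim + noise_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def sample(self, base_pred, m=100):
        # Draw m samples of Y conditioned on the base model's output
        x = base_pred.unsqueeze(0).expand(m, -1, -1)
        eps = torch.randn(m, base_pred.shape[0], self.noise_dim)
        return self.net(torch.cat([x, eps], dim=-1)).squeeze(-1)  # (m, batch)

head = EngressionHead()
base_pred = torch.randn(16, 1)  # stand-in for the base model's point predictions
samples = head.sample(base_pred, m=200)
quantiles = torch.quantile(samples, torch.tensor([0.05, 0.5, 0.95]), dim=0)
```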
<p>The broader point is that the pre-additive noise assumption underlying engression appears to hold up on real-world datasets — and that stacking engression as a flexible head could make it a practical upgrade to existing tabular pipelines.</p>
</section>
<section id="final-thoughts" class="level2">
<h2 class="anchored" data-anchor-id="final-thoughts">Final thoughts</h2>
<p>Engression’s main strength lies in its simplicity and flexibility. It requires no parametric assumptions on the output distribution, no likelihood computations, no adversarial training, and no architectural constraints like invertibility. It scales naturally to multivariate <img src="https://latex.codecogs.com/png.latex?X"> and <img src="https://latex.codecogs.com/png.latex?Y">, and at test time, sampling is fast and easily parallelizable. These properties make it particularly appealing for tasks like forecasting, simulation, or structured prediction, where calibrated uncertainty is important but explicit density evaluation is not.</p>
<p>That said, engression comes with tradeoffs. Because it models distributions implicitly — without yielding closed-form densities — it’s less suitable for inference tasks that rely on likelihoods. Its one-shot sampling may also struggle with complex, multimodal distributions, as the energy score tends to cover modes rather than <a href="https://sander.ai/2020/03/24/audio-generation.html#mode-covering-vs-mode-seeking-behaviour">seek them</a>. In such cases, methods like diffusion models or normalizing flows might offer better performance, albeit at higher computational and implementation cost.</p>
<p>Encouragingly, the authors have begun to explore extensions. A recent <a href="https://arxiv.org/abs/2502.13747v1">paper</a> proposes a multi-step version of engression that improves performance on challenging tasks. Another work introduces <a href="https://arxiv.org/abs/2404.13649">distributional autoencoders</a>, combining engression with dimensionality reduction.</p>
<p>Overall, engression is a clever and lightweight approach to distributional regression. I’m excited to see where future research and applications take it. The <a href="https://arxiv.org/abs/2307.00835">paper</a> is very readable, and the authors have released their code <a href="https://github.com/xwshen51/engression/">here</a> if you want to play with it. I’m also learning about python packaging — here’s a <a href="https://github.com/emiliocantuc/engression-pytorch">small one</a> with the loss and a few wrappers for convenience.</p>
<hr>
<p>Thanks for reading! If you spot any errors, or have comments or suggestions, feel free to reach out.</p>


</section>


<div id="quarto-appendix" class="default"><section id="footnotes" class="footnotes footnotes-end-of-document"><h2 class="anchored quarto-appendix-heading">Footnotes</h2>

<ol>
<li id="fn1"><p>For context, the difference between first and second place was 0.06.↩︎</p></li>
</ol>
</section></div> ]]></description>
  <category>paper</category>
  <guid>https://ecntu.com/posts/engression/</guid>
  <pubDate>Tue, 06 May 2025 04:00:00 GMT</pubDate>
</item>
<item>
  <title>Distilling the Knowledge in a Neural Network</title>
  <link>https://ecntu.com/posts/distilling-knowledge/</link>
  <description><![CDATA[ 





<section id="idea" class="level1">
<h1>Idea</h1>
<p>This classic paper introduced <em>distillation</em> as a way of transferring knowledge from a big teacher network into a small one. The core observation is that we should use the big model’s output distribution as soft labels to train the small model.</p>
<p>Remember that in classification we measure the <a href="https://en.wikipedia.org/wiki/Cross-entropy">cross-entropy</a> loss, given the predicted <img src="https://latex.codecogs.com/png.latex?%5Chat%20y_c"> and correct <img src="https://latex.codecogs.com/png.latex?y_c"> class probabilities of an example, with:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0AL(%5Chat%20y,y)%20=%20-%5Csum_c%20y_c%20%5Clog%20%5Chat%20y_c%0A"></p>
<p>To use soft labels we just set <img src="https://latex.codecogs.com/png.latex?y%20=%20f_%7B%5Ctext%7Bbig%7D%7D(x)">.</p>
<p>These soft labels provide a much richer training signal for the smaller model, especially when the larger model distributes its probability mass across multiple classes (i.e.&nbsp;when the labels have high entropy). To force this high entropy, the authors propose increasing the temperature <img src="https://latex.codecogs.com/png.latex?T"> of the softmax layer in the larger model to produce the soft labels. The small model trains with this same temperature but then sets it to 1 during testing.</p>
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<p><img src="https://ecntu.com/posts/distilling-knowledge/images/demo.png" class="img-fluid figure-img"></p>
<figcaption>Increasing the temperature of the big model produces softer and more informative labels.</figcaption>
</figure>
</div>
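<p>A quick way to see this numerically (with made-up logits; the exact values are just for illustration) is to apply the softmax at different temperatures:</p>

```python
import torch
import torch.nn.functional as F

def soften(logits, T):
    # Higher T flattens the distribution, raising its entropy
    return F.softmax(logits / T, dim=-1)

logits = torch.tensor([5.0, 2.0, 1.0])
hard = soften(logits, T=1.0)  # sharply peaked on the first class
soft = soften(logits, T=4.0)  # mass spread over the other classes too
```

The softened output exposes the relative probabilities the teacher assigns to the wrong classes — information a one-hot label throws away.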
<p>They also had better results by adding a small term to the loss function with the regular hard-labeled cross-entropy. The reasoning is that the model may not have enough capacity to learn the soft targets, so <em>“erring in the direction of the correct answer turns out to be helpful”</em>. If we write the output of a model with temperature <img src="https://latex.codecogs.com/png.latex?T"> as <img src="https://latex.codecogs.com/png.latex?f(x;%20T)">, then the complete loss is</p>
<p><img src="https://latex.codecogs.com/png.latex?%0AL_%7B%5Ctext%7Bdistill%7D%7D(x,y)%20=%20a%20T%5E2%20%5Ccdot%20L%5Cleft%5B%20f_%7B%5Ctext%7Bsmall%7D%7D(x;%20T),%20f_%7B%5Ctext%7Bbig%7D%7D(x;%20T)%20%5Cright%5D%20+%20(1-a)%20%5Ccdot%20L%20%5Cleft%20%5B%20f_%7B%5Ctext%7Bsmall%7D%7D(x;%201),%20y%20%5Cright%20%5D%0A"></p>
<p>The first term is scaled by <img src="https://latex.codecogs.com/png.latex?T%5E2"> because the magnitudes of the gradients scale as <img src="https://latex.codecogs.com/png.latex?T%5E%7B-2%7D"> and we want to control the contribution of each term by changing only <img src="https://latex.codecogs.com/png.latex?a">.</p>
<div class="callout callout-style-default callout-note no-icon callout-titled" title="Why do the gradient magnitudes scale as $T^{-2}$?">
<div class="callout-header d-flex align-content-center collapsed" data-bs-toggle="collapse" data-bs-target=".callout-1-contents" aria-controls="callout-1" aria-expanded="false" aria-label="Toggle callout">
<div class="callout-icon-container">
<i class="callout-icon no-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Why do the gradient magnitudes scale as <img src="https://latex.codecogs.com/png.latex?T%5E%7B-2%7D">?
</div>
<div class="callout-btn-toggle d-inline-block border-0 py-1 ps-1 pe-0 float-end"><i class="callout-toggle"></i></div>
</div>
<div id="callout-1" class="callout-1-contents callout-collapse collapse">
<div class="callout-body-container callout-body">
<p>Let <img src="https://latex.codecogs.com/png.latex?z"> be the logits; then the <img src="https://latex.codecogs.com/png.latex?i">th output entry of the softmax layer with temperature <img src="https://latex.codecogs.com/png.latex?T"> is <img src="https://latex.codecogs.com/png.latex?%0A%5Csigma_T(z)_i%20=%20%5Cfrac%7Be%5E%7Bz_i/T%7D%7D%7B%5Csum_j%20e%5E%7Bz_j/T%7D%7D%20=%20%5Chat%20y_i%0A"> Plugging into the loss (and using <img src="https://latex.codecogs.com/png.latex?%5Csum_i%20y_i%20=%201">) <img src="https://latex.codecogs.com/png.latex?%0AL(%5Chat%20y,%20y)%20=%20-%5Csum_i%20y_i%20%5Clog%20%5Cleft(%5Cfrac%7Be%5E%7Bz_i/T%7D%7D%7B%5Csum_j%20e%5E%7Bz_j/T%7D%7D%20%5Cright)%20=%20-%5Cfrac%7B1%7D%7BT%7D%5Csum_i%20y_i%20z_i%20+%20%5Clog%20%5Cleft(%20%5Csum_j%20e%5E%7Bz_j/T%7D%20%5Cright)%0A"> and differentiating w.r.t. <img src="https://latex.codecogs.com/png.latex?z_i"> (don’t forget the chain rule), we get <img src="https://latex.codecogs.com/png.latex?%0A%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20z_i%7D%20=%20-%5Cfrac%7B1%7D%7BT%7D%20y_i%20+%20%5Cfrac%7B1%7D%7B%20%5Csum_j%20e%5E%7Bz_j/T%7D%7D%20%5Ctimes%20e%5E%7Bz_i/T%7D%20(1/T)%20=%20%5Cfrac%7B1%7D%7BT%7D(%5Csigma_T(z)_i%20-%20y_i)%0A"> This accounts for one factor of <img src="https://latex.codecogs.com/png.latex?1/T">. The soft targets <img src="https://latex.codecogs.com/png.latex?y"> are themselves produced at temperature <img src="https://latex.codecogs.com/png.latex?T">, so for large <img src="https://latex.codecogs.com/png.latex?T"> the differences <img src="https://latex.codecogs.com/png.latex?%5Csigma_T(z)_i%20-%20y_i"> also shrink like <img src="https://latex.codecogs.com/png.latex?1/T">, and the gradient magnitudes scale as <img src="https://latex.codecogs.com/png.latex?T%5E%7B-2%7D"> overall.</p>
</div>
</div>
</div>
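<p>The per-coordinate gradient formula is easy to sanity-check with autograd; here is a small sketch with toy logits of my own choosing:</p>

```python
import torch
import torch.nn.functional as F

T = 4.0
z = torch.tensor([2.0, 0.5, -1.0], requires_grad=True)    # student logits
y = F.softmax(torch.tensor([1.5, 0.8, -0.5]) / T, dim=0)  # teacher soft targets

# Cross-entropy at temperature T
loss = -(y * F.log_softmax(z / T, dim=0)).sum()
loss.backward()

# Analytic gradient: (1/T) * (sigma_T(z) - y)
analytic = (F.softmax(z.detach() / T, dim=0) - y) / T
assert torch.allclose(z.grad, analytic, atol=1e-6)
```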
</section>
<section id="mnist" class="level1">
<h1>MNIST</h1>
<p>We try out distillation on the small-scale MNIST experiment that the authors describe. They use a ReLU network with two hidden layers, regularized with dropout, a jitter image augmentation, and a max-norm constraint on the weights.</p>
<div id="cell-5" class="cell" data-execution_count="51">
<details class="code-fold">
<summary>Model definition</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb1" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">class</span> Model(nn.Module):</span>
<span id="cb1-2">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">'''</span></span>
<span id="cb1-3"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    Used in MNIST experiments.</span></span>
<span id="cb1-4"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    A two-layer linear ReLU network with dropout and max norm regularization.</span></span>
<span id="cb1-5"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    '''</span></span>
<span id="cb1-6">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, hidden_size, max_norm <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">2.0</span>, drop_rate <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.5</span>):</span>
<span id="cb1-7">        <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">super</span>(Model, <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>).<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>()</span>
<span id="cb1-8">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.max_norm <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> max_norm</span>
<span id="cb1-9">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.layers <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.Sequential(</span>
<span id="cb1-10">            nn.Flatten(),</span>
<span id="cb1-11">            nn.Linear(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">28</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">28</span>, hidden_size),</span>
<span id="cb1-12">            nn.ReLU(),</span>
<span id="cb1-13">            nn.Dropout(drop_rate),</span>
<span id="cb1-14">            nn.Linear(hidden_size, hidden_size),</span>
<span id="cb1-15">            nn.ReLU(),</span>
<span id="cb1-16">            nn.Dropout(drop_rate),</span>
<span id="cb1-17">            nn.Linear(hidden_size, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span>)</span>
<span id="cb1-18">        )</span>
<span id="cb1-19"></span>
<span id="cb1-20">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> forward(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, x):</span>
<span id="cb1-21">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Clip the weights to the maximum allowed norm</span></span>
<span id="cb1-22">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.max_norm <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">is</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">not</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>:</span>
<span id="cb1-23">            <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">with</span> torch.no_grad():</span>
<span id="cb1-24">                <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> layer <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.modules():</span>
<span id="cb1-25">                    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">isinstance</span>(layer, nn.Linear):</span>
<span id="cb1-26">                        norm <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> layer.weight.data.norm(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, dim<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, keepdim<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb1-27">                        desired <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.clamp(norm, <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">max</span><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.max_norm)</span>
<span id="cb1-28">                        layer.weight.data <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*=</span> (desired <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> norm)</span>
<span id="cb1-29">                    </span>
<span id="cb1-30">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.layers(x)</span></code></pre></div></div>
</details>
</div>
<p>We define the distillation loss:</p>
<div id="cell-8" class="cell" data-execution_count="329">
<details open="" class="code-fold">
<summary>Define training losses</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb2" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Regular cross-entropy loss</span></span>
<span id="cb2-2"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> hard_loss(outputs, labels, criterion, <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span>args):</span>
<span id="cb2-3">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> criterion(outputs, labels)</span>
<span id="cb2-4"></span>
<span id="cb2-5"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Distillation loss</span></span>
<span id="cb2-6"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> soft_loss(outputs, labels, criterion, examples, big_model, T, a):</span>
<span id="cb2-7">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">with</span> torch.no_grad():</span>
<span id="cb2-8">        big_model.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">eval</span>()</span>
<span id="cb2-9">        soft_labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> F.softmax(big_model(examples) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> T, dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)</span>
<span id="cb2-10"></span>
<span id="cb2-11">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> a <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> (T <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">**</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> criterion(outputs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> T, soft_labels) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> (<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> a) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> criterion(outputs, labels)</span></code></pre></div></div>
</details>
</div>
<p>The hidden dimensions of the big and small networks are 1200 and 800 respectively. To train the networks we use a validation set for early stopping, and choose <img src="https://latex.codecogs.com/png.latex?T%20=%204.0"> and <img src="https://latex.codecogs.com/png.latex?a%20=%200.5"> (since the authors don’t report the values they used).</p>
<div id="cell-11" class="cell" data-execution_count="291">
<details class="code-fold">
<summary>Train the big model</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb3" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Hyperparameters</span></span>
<span id="cb3-2">num_epochs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">100</span></span>
<span id="cb3-3">batch_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">128</span></span>
<span id="cb3-4">lr <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.01</span></span>
<span id="cb3-5">patience <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">7</span></span>
<span id="cb3-6">big_model_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1200</span></span>
<span id="cb3-7">small_model_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">800</span></span>
<span id="cb3-8"></span>
<span id="cb3-9">loader <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">lambda</span> ds, shuffle <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>: DataLoader(ds, batch_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> batch_size, shuffle <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> shuffle)</span>
<span id="cb3-10">val_loader   <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> loader(val_dataset)</span>
<span id="cb3-11">test_loader  <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> loader(test_dataset)</span>
<span id="cb3-12"></span>
<span id="cb3-13"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Train the big model</span></span>
<span id="cb3-14">train_dataset.dataset.transform <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> aug_transform</span>
<span id="cb3-15">train_loader <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> loader(train_dataset, shuffle <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb3-16"></span>
<span id="cb3-17">big_model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> Model(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1200</span>).to(device)</span>
<span id="cb3-18">big_train_history <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> train(big_model, hard_loss, train_loader, val_loader, num_epochs, lr, patience)</span>
<span id="cb3-19">test_loss, test_accuracy <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> evaluate_model(big_model, test_loader)</span>
<span id="cb3-20"></span>
<span id="cb3-21">save_results(big_model, big_train_history, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'big_model'</span>)</span>
<span id="cb3-22">test_loss, test_accuracy</span></code></pre></div></div>
</details>
</div>
<div id="cell-12" class="cell" data-execution_count="434">
<details class="code-fold">
<summary>Train the smaller model on hard labels</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb4" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> train_small_model(train_dataset, val_dataset, seed, loss, model_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> small_model_size):</span>
<span id="cb4-2">    set_seed(seed)</span>
<span id="cb4-3">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># no augmentation</span></span>
<span id="cb4-4">    train_dataset.dataset.transform <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> reg_transform</span>
<span id="cb4-5">    train_loader <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> loader(train_dataset, shuffle <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb4-6">    val_loader <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> loader(val_dataset)</span>
<span id="cb4-7"></span>
<span id="cb4-8">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># or regularization</span></span>
<span id="cb4-9">    small_model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> Model(model_size, max_norm <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>, drop_rate <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.0</span>).to(device)</span>
<span id="cb4-10">    small_train_history <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> train(small_model, loss, train_loader, val_loader, num_epochs, lr, patience)</span>
<span id="cb4-11">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> small_model, small_train_history</span>
<span id="cb4-12"></span>
<span id="cb4-13"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># small_model, small_train_history = train_small_model(train_dataset, val_dataset, seed = 42, loss = hard_loss)</span></span>
<span id="cb4-14"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># save_results(small_model, small_train_history, 'small_model')</span></span>
<span id="cb4-15"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># evaluate_model(small_model, test_loader)</span></span></code></pre></div></div>
</details>
</div>
<div id="cell-13" class="cell">
<details class="code-fold">
<summary>Train the distilled model</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb5" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1">temperature, a <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">4.0</span>, <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.5</span></span>
<span id="cb5-2"></span>
<span id="cb5-3">distilled_model, distilled_train_history <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> train_small_model(</span>
<span id="cb5-4">    train_dataset <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> train_dataset, val_dataset <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> val_dataset, seed <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">42</span>,</span>
<span id="cb5-5">    loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> functools.partial(soft_loss, big_model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> big_model, T <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> temperature, a <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> a)</span>
<span id="cb5-6">)</span>
<span id="cb5-7">save_results(distilled_model, distilled_train_history, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'distilled_model'</span>)</span>
<span id="cb5-8">evaluate_model(distilled_model, test_loader)</span></code></pre></div></div>
</details>
</div>
<p>We get the following test accuracies:</p>
<div id="cell-16" class="cell" data-execution_count="379">
<div class="cell-output cell-output-stdout">
<pre><code>big: 0.9901, small: 0.9833, distilled: 0.9891</code></pre>
</div>
</div>
</section>
<section id="mystical-3" class="level1">
<h1>Mythical 3</h1>
<p>The authors then remove all 3s from the transfer set the distilled model is trained on, to test its generalization to unseen classes: <em>“So from the perspective of the distilled model, 3 is a mythical digit that it has never seen.”</em> When we evaluate on the test set, which still contains 3s, the distilled model performs much better than a small model trained with hard labels:</p>
<div id="cell-18" class="cell">
<details class="code-fold">
<summary>Train without 3s in transfer set</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb7" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Remove all 3s from the dataset</span></span>
<span id="cb7-2">train_dataset <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> datasets.MNIST(root <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> DATA_DIR, train <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>,  download <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>, transform <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> aug_transform)</span>
<span id="cb7-3">train_dataset <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.utils.data.Subset(train_dataset, np.where(train_dataset.targets <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">!=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>)[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>])</span>
<span id="cb7-4"></span>
<span id="cb7-5"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Split training data into train and validation sets</span></span>
<span id="cb7-6">train_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>(<span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.9</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(train_dataset))</span>
<span id="cb7-7">val_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(train_dataset) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> train_size</span>
<span id="cb7-8">train_dataset, val_dataset <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.utils.data.random_split(</span>
<span id="cb7-9">    train_dataset, [train_size, val_size]</span>
<span id="cb7-10">)</span>
<span id="cb7-11"></span>
<span id="cb7-12"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Train small without distillation</span></span>
<span id="cb7-13">small_no_3, small_no_3_history <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> train_small_model(</span>
<span id="cb7-14">    train_dataset <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> train_dataset, val_dataset <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> val_dataset,</span>
<span id="cb7-15">    seed <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">42</span>, loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> hard_loss, model_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">800</span></span>
<span id="cb7-16">)</span>
<span id="cb7-17"></span>
<span id="cb7-18"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Train small with distillation</span></span>
<span id="cb7-19">distilled_no_3, distilled_no_3_history <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> train_small_model(</span>
<span id="cb7-20">    train_dataset <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> train_dataset, val_dataset <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> val_dataset,</span>
<span id="cb7-21">    seed <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">42</span>, loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> functools.partial(soft_loss, big_model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> big_model, T <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">4.0</span>, a <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.5</span>),</span>
<span id="cb7-22">    model_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">800</span></span>
<span id="cb7-23">)</span></code></pre></div></div>
</details>
</div>
<div id="cell-19" class="cell">
<div class="cell-output cell-output-stdout">
<pre><code>Not distilled: 0.8882, distilled: 0.9869</code></pre>
</div>
</div>
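<p>Why does removing 3s barely hurt the distilled model? The teacher’s softened outputs leak probability mass onto classes that never appear in the transfer set. Here is a quick sketch of that “dark knowledge” effect, with made-up logits for illustration (not taken from the trained models):</p>

```python
import torch

# Hypothetical teacher logits for an image of an 8; class 8 dominates,
# but class 3 (visually similar) gets a noticeably larger logit than the rest
logits = torch.tensor([0.5, 0.2, 1.0, 3.0, 0.1, 0.4, 0.3, 0.2, 6.0, 0.6])

hard = torch.softmax(logits, dim=0)        # T = 1: almost all mass on class 8
soft = torch.softmax(logits / 4.0, dim=0)  # T = 4: mass spreads to similar classes

# The softened target tells the student that 8s look a bit like 3s,
# even though the student never sees a single image of a 3
assert soft[3] > hard[3]
```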
<p>In the paper, the authors take this to the extreme and show that a distilled model trained only on 7s and 8s still achieves impressive performance. They also run experiments on a much larger speech recognition dataset, and discuss training specialist models on a computer vision dataset with distillation from a generalist model acting as a regularizer.</p>
</section>
<section id="final-thoughts" class="level1">
<h1>Final thoughts</h1>
<p>It was very fun to return to this classic paper. It introduced a simple yet powerful idea that is still widely used today. Like most papers of its era (circa 2015), it is very clear and readable. And, as is a Hinton staple, it is slightly bio-inspired, in this case by larvae.</p>
<p>Here are some pointers to papers that extend this idea: <a href="https://arxiv.org/abs/1805.04770">self-distillation</a> makes the teacher (“big”) and student (“small”) models the same size, and in <a href="https://arxiv.org/abs/1706.00384">mutual learning</a> two or more networks learn collaboratively. Most extensions, however, build on the paper’s central theme of training on a richer signal: you might train the student to imitate the teacher’s <a href="https://arxiv.org/abs/1412.6550">intermediate</a> (or <a href="https://openreview.net/forum?id=ZzwDy_wiWv">last</a>) representations, <a href="https://arxiv.org/abs/1612.03928">attention maps</a>, etc.</p>


</section>

 ]]></description>
  <category>deep learning</category>
  <category>paper</category>
  <guid>https://ecntu.com/posts/distilling-knowledge/</guid>
  <pubDate>Tue, 28 Jan 2025 05:00:00 GMT</pubDate>
</item>
<item>
  <title>TENT: Fully Test-Time Adaptation By Entropy Minimization</title>
  <link>https://ecntu.com/posts/tent/</link>
  <description><![CDATA[ 





<p>Once a model is deployed, the feature (covariate) distribution might shift from the one seen during training. These shifts push the model out of distribution and degrade its predictions. This paper proposes a simple method to help models adapt to such shifts: minimize the entropy of your predictions.</p>
<p>That is, before making test-time predictions for a batch, you nudge (SGD) the model to predict peakier (less entropic) class distributions.</p>
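<p>A single adaptation step, stripped to its essentials, looks roughly like this (a sketch, not the paper’s reference implementation; it assumes <code>model</code> and <code>optimizer</code> have already been restricted to the parameters that should adapt):</p>

```python
import torch
import torch.nn as nn

def tent_step(model, optimizer, x):
    # Minimize the mean entropy of the batch's predicted class distributions
    logits = model(x)
    entropy = -(logits.softmax(dim=1) * logits.log_softmax(dim=1)).sum(dim=1).mean()
    optimizer.zero_grad()
    entropy.backward()
    optimizer.step()
    return logits.detach()  # the pre-update logits double as the batch's predictions

# Toy usage with a stand-in model (not the post's ResNet)
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
x = torch.randn(8, 3, 32, 32)
preds = tent_step(model, optimizer, x).argmax(dim=1)
```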
<p>Why minimize entropy?</p>
<p>First, because it is convenient. In contrast to other methods, you neither need to modify the training procedure nor have access to test-time labels. Because labels are rarely available at test time, this makes TENT <em>“fully test-time”</em>.</p>
<p>Second, the authors argue that entropy is related to both error and shifts:</p>
<blockquote class="blockquote">
<p>“Entropy is related to error, as more confident predictions are all-in-all more correct (Figure 1). Entropy is related to shifts due to corruption, as more corruption results in more entropy, with a strong rank correlation to the loss for image classification as the level of corruption increases (Figure 2).”</p>
</blockquote>
<p>To reproduce Figures 1 &amp; 2 we train a ResNet on CIFAR-10 and evaluate its predictions on corrupted versions of the test set to simulate test-time shifts.</p>
<p>(Note: while the authors also show results for CIFAR-100 and ImageNet, we’ll only deal with this small dataset and model for convenience.)</p>
<div id="cell-4" class="cell" data-execution_count="2">
<details class="code-fold">
<summary>Datasets</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb1" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1">corruption_types <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [</span>
<span id="cb1-2">    <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'gaussian_noise'</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'shot_noise'</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'impulse_noise'</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'defocus_blur'</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'glass_blur'</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'motion_blur'</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'zoom_blur'</span>,</span>
<span id="cb1-3">    <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'snow'</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'frost'</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'fog'</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'brightness'</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'contrast'</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'elastic_transform'</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'pixelate'</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'jpeg_compression'</span></span>
<span id="cb1-4">]</span>
<span id="cb1-5"></span>
<span id="cb1-6"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># train_set = datasets.CIFAR10('../data', download = True, train = True,  transform = transforms.ToTensor())</span></span>
<span id="cb1-7">test_set <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>  datasets.CIFAR10(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'../data'</span>, download <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>, train <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>, transform <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> transforms.ToTensor())</span>
<span id="cb1-8">n_classes <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(test_set.classes)</span>
<span id="cb1-9"></span>
<span id="cb1-10"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> get_test_set(corr_type <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'brightness'</span>, data_path <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'data'</span>, severity <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">5</span>):</span>
<span id="cb1-11">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">assert</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&lt;=</span> severity <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&lt;=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">5</span></span>
<span id="cb1-12">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> corr_type <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> corruption_types:</span>
<span id="cb1-13">        X_test, y_test <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> load_cifar10c(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10_000</span>, severity, data_path, <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>, [corr_type])</span>
<span id="cb1-14">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> TensorDataset(X_test, y_test)</span>
<span id="cb1-15">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> test_set</span>
<span id="cb1-16"></span>
<span id="cb1-17"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> get_cifar10_model(model_path <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'models/cifar10_pretrained'</span>):</span>
<span id="cb1-18">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">try</span>:</span>
<span id="cb1-19">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> torch.load(model_path)</span>
<span id="cb1-20">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">except</span> <span class="pp" style="color: #AD0000; background-color: null; font-style: inherit;">Exception</span>:</span>
<span id="cb1-21">        m <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> load_model(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Standard'</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'models'</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'cifar10'</span>, ThreatModel.corruptions)</span>
<span id="cb1-22">        torch.save(m, model_path)</span>
<span id="cb1-23">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> torch.load(model_path)</span>
<span id="cb1-24"></span>
<span id="cb1-25"></span>
<span id="cb1-26">get_model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">lambda</span>: load_model(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Standard'</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'models'</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'cifar10'</span>, ThreatModel.corruptions)</span>
<span id="cb1-27">model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> get_model()</span></code></pre></div></div>
</details>
</div>
<div id="cell-5" class="cell" data-execution_count="19">
<div class="cell-output cell-output-display">
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<p><img src="https://ecntu.com/posts/tent/index_files/figure-html/cell-4-output-1.png" class="img-fluid figure-img"></p>
<figcaption>4 of 15 corruption types included in CIFAR-10-C, shown at the highest severity (5/5) level</figcaption>
</figure>
</div>
</div>
</div>
<div id="cell-6" class="cell" data-execution_count="24">
<details class="code-fold">
<summary>Reproduce Figs 1 &amp; 2</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb2" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1">model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> get_model()</span>
<span id="cb2-2"></span>
<span id="cb2-3">c, e, l <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [], [], []</span>
<span id="cb2-4">corruptions, severities <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [], []</span>
<span id="cb2-5"></span>
<span id="cb2-6"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> corr, severity <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> itertools.product([<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> corruption_types, <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">6</span>)): <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># 'gaussian_noise', 'gaussian_blur', 'jpeg_compression', 'snow'</span></span>
<span id="cb2-7"></span>
<span id="cb2-8">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> corr <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">is</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">and</span> severity <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&gt;</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>: <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">continue</span></span>
<span id="cb2-9">    corrupted_test_set <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> get_test_set(corr_type <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> corr, severity <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> severity)</span>
<span id="cb2-10">   </span>
<span id="cb2-11">    test <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> DataLoader(corrupted_test_set, batch_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">128</span>, shuffle <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>)</span>
<span id="cb2-12"></span>
<span id="cb2-13">    model.to(device)<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">;</span> model.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">eval</span>()</span>
<span id="cb2-14">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">with</span> torch.no_grad():</span>
<span id="cb2-15">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> images, labels <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> test:</span>
<span id="cb2-16">            images, labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> images.to(device), labels.to(device)</span>
<span id="cb2-17">            logits <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> model(images)</span>
<span id="cb2-18">            _, pred <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">max</span>(logits, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)</span>
<span id="cb2-19">            correct <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (pred <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> labels).<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span>()</span>
<span id="cb2-20">            entropy <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span>(logits.softmax(dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> logits.log_softmax(dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)).<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">sum</span>(dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)</span>
<span id="cb2-21">            loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.CrossEntropyLoss(reduction <span class="op" style="color: #5E5E5E; background-color: null; font-style: inherit;">=</span> <span class="st" style="color: #20794D; background-color: null; font-style: inherit;">'none'</span>)(logits, labels)</span>
<span id="cb2-22">            c.append(correct)</span>
<span id="cb2-23">            e.append(entropy)</span>
<span id="cb2-24">            l.append(loss)</span>
<span id="cb2-25">            corruptions.extend([corr] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(correct))</span>
<span id="cb2-26">            severities.extend([severity] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(correct))</span>
<span id="cb2-27"></span>
<span id="cb2-28">correct <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.cat(c).cpu().numpy()</span>
<span id="cb2-29">entropy <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.cat(e).cpu().numpy()</span>
<span id="cb2-30">loss    <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.cat(l).cpu().numpy()</span></code></pre></div></div>
</details>
</div>
<div id="cell-fig-1-2" class="cell quarto-layout-panel" data-layout-ncol="2" data-execution_count="30">
<div class="quarto-layout-row">
<div class="cell-output cell-output-display quarto-layout-cell" style="flex-basis: 50.0%;justify-content: flex-start;">
<div id="fig-1-2" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-fig figure">
<div aria-describedby="fig-1-2-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<img src="https://ecntu.com/posts/tent/index_files/figure-html/fig-1-2-output-1.png" class="img-fluid figure-img">
</div>
<figcaption class="quarto-float-caption-bottom quarto-float-caption quarto-float-fig" id="fig-1-2-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Figure&nbsp;1: Predictions with lower entropy have lower error rates.
</figcaption>
</figure>
</div>
</div>
<div class="cell-output cell-output-display quarto-layout-cell" style="flex-basis: 50.0%;justify-content: flex-start;">
<div id="fig-1-2" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-fig figure">
<div aria-describedby="fig-1-2-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<img src="https://ecntu.com/posts/tent/index_files/figure-html/fig-1-2-output-2.png" class="img-fluid figure-img">
</div>
<figcaption class="quarto-float-caption-bottom quarto-float-caption quarto-float-fig" id="fig-1-2-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Figure&nbsp;2: More corruption (shown as alpha) leads to higher loss and entropy.
</figcaption>
</figure>
</div>
</div>
</div>
</div>
<p>The intuition here, as far as I can tell, is that entropy encodes the model’s confidence. If the model’s prediction is confident, it is all-in-all more probable to be correct (it might have seen similar examples during training, the example might be “easy”, etc). Corruptions take the model OOD and decrease its confidence. Since cross-entropy is lowest when all probability mass is assigned to the correct label, increasing entropy (all-in-all) dilutes that mass and increases loss.</p>
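<p>To make the confidence-entropy link concrete, compare the entropy of a peaked prediction with a uniform one (illustrative logits, not actual model outputs):</p>

```python
import math
import torch

def prediction_entropy(logits):
    # Shannon entropy (in nats) of the softmax distribution
    p = logits.softmax(dim=-1)
    return -(p * logits.log_softmax(dim=-1)).sum(dim=-1)

confident = torch.tensor([12.0] + [0.0] * 9)  # nearly all mass on one class
uniform = torch.zeros(10)                     # equal mass on all 10 classes

# A confident prediction sits near zero entropy; the uniform one hits
# the maximum for 10 classes, ln(10) ≈ 2.303
assert prediction_entropy(confident) < 0.01
assert abs(prediction_entropy(uniform) - math.log(10)) < 1e-5
```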
<p>Two important notes on <em>how</em> entropy is minimized:</p>
<p>First, the authors note that once we switch the model to entropy minimization we run the risk of causing it to deviate from its training. While you could choose a sufficiently small learning rate or add KL regularization to alleviate this, the authors opt for freezing most of the model and only updating the learnable parameters in the batch norm layers.</p>
<p>Second, we must adapt on batches. If we minimized the entropy of single examples, the model might just learn to assign more mass to whichever class it currently predicts, instead of learning something more generalizable.</p>
<div id="cell-fig-demo" class="cell" data-execution_count="25">
<details class="code-fold">
<summary>TENT example</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb3" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Only update BN layers</span></span>
<span id="cb3-2"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> prepare_for_test_time(module, reset_stats <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>):</span>
<span id="cb3-3">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">isinstance</span>(module, nn.BatchNorm2d):</span>
<span id="cb3-4">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> reset_stats: module.reset_running_stats()</span>
<span id="cb3-5">        <span class="cf" style="color: #003B4F; background-color: null; font-weight: bold; font-style: inherit;">for</span> p <span class="kw" style="color: #003B4F; background-color: null; font-weight: bold; font-style: inherit;">in</span> module.parameters(recurse <span class="op" style="color: #5E5E5E; background-color: null; font-style: inherit;">=</span> <span class="va" style="color: #111111; background-color: null; font-style: inherit;">False</span>): p.requires_grad <span class="op" style="color: #5E5E5E; background-color: null; font-style: inherit;">=</span> <span class="va" style="color: #111111; background-color: null; font-style: inherit;">True</span></span>
<span id="cb3-6">    <span class="cf" style="color: #003B4F; background-color: null; font-weight: bold; font-style: inherit;">else</span>:</span>
<span id="cb3-6-1">        <span class="cf" style="color: #003B4F; background-color: null; font-weight: bold; font-style: inherit;">for</span> p <span class="kw" style="color: #003B4F; background-color: null; font-weight: bold; font-style: inherit;">in</span> module.parameters(recurse <span class="op" style="color: #5E5E5E; background-color: null; font-style: inherit;">=</span> <span class="va" style="color: #111111; background-color: null; font-style: inherit;">False</span>): p.requires_grad <span class="op" style="color: #5E5E5E; background-color: null; font-style: inherit;">=</span> <span class="va" style="color: #111111; background-color: null; font-style: inherit;">False</span></span>
<span id="cb3-7"></span>
<span id="cb3-8">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> m <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> module.children(): prepare_for_test_time(m, reset_stats)</span>
<span id="cb3-9"></span>
<span id="cb3-10"></span>
<span id="cb3-11"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Init the model &amp; optimizer</span></span>
<span id="cb3-12">model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> get_cifar10_model()<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">;</span> model.to(device)</span>
<span id="cb3-13">corr_test_set <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> get_test_set(corr_type <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'gaussian_noise'</span>, severity <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">5</span>)</span>
<span id="cb3-14">model.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">apply</span>(functools.partial(prepare_for_test_time, reset_stats <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>))</span>
<span id="cb3-15">optimizer <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> optim.AdamW(model.parameters(), lr <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.00001</span>)</span>
<span id="cb3-16"></span>
<span id="cb3-17"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Get a batch of corrupted images</span></span>
<span id="cb3-18">images, labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">next</span>(<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">iter</span>(DataLoader(corr_test_set, batch_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">128</span>)))</span>
<span id="cb3-19">images, labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> images.to(device), labels.to(device)</span>
<span id="cb3-20"></span>
<span id="cb3-21"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Minimize entropy</span></span>
<span id="cb3-22">model.train()</span>
<span id="cb3-23">optimizer.zero_grad()</span>
<span id="cb3-24">preds <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> model(images)</span>
<span id="cb3-25">entropy <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span>(preds.softmax(dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> preds.log_softmax(dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)).<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">sum</span>(dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)</span>
<span id="cb3-26">loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> entropy.mean()</span>
<span id="cb3-27">loss.backward()</span>
<span id="cb3-28">optimizer.step()</span>
<span id="cb3-29"></span>
<span id="cb3-30">new_preds <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> model(images)</span>
<span id="cb3-31"></span>
<span id="cb3-32"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Plot</span></span>
<span id="cb3-33">f, axs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> plt.subplots(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>, figsize <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">8</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>))</span>
<span id="cb3-34">ix <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">89</span></span>
<span id="cb3-35">axs[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>].imshow(test_set[ix][<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>].permute(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>).cpu().numpy())</span>
<span id="cb3-36">axs[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>].set_title(test_set.classes[test_set[ix][<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>]])</span>
<span id="cb3-37">axs[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>].set_xticks([])<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">;</span> axs[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>].set_yticks([])</span>
<span id="cb3-38"></span>
<span id="cb3-39">axs[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>].imshow(images[ix].permute(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>).cpu().numpy())</span>
<span id="cb3-40">axs[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>].set_title(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'corrupted'</span>)</span>
<span id="cb3-41">axs[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>].set_xticks([])<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">;</span> axs[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>].set_yticks([])</span>
<span id="cb3-42"></span>
<span id="cb3-43">order <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.argsort(preds.softmax(dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)[ix]).detach().cpu().numpy()[::<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>]</span>
<span id="cb3-44">rows <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [{<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'class'</span>: i, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'prob'</span>:p, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'type'</span>: <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'unadapted'</span>} <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> i, p <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">enumerate</span>(preds.softmax(dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)[ix].detach().cpu().numpy()[order])]</span>
<span id="cb3-45">rows.extend([{<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'class'</span>: i, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'prob'</span>:p, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'type'</span>: <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'TENT'</span>} <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> i, p <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">enumerate</span>(new_preds.softmax(dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)[ix].detach().cpu().numpy()[order])])</span>
<span id="cb3-46">sns.barplot(x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'class'</span>, y <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'prob'</span>, hue <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'type'</span>, data <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> pd.DataFrame(rows), ax <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> axs[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>])</span>
<span id="cb3-47">axs[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>].set_title(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'class distribution'</span>)</span>
<span id="cb3-48">axs[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>].legend()</span>
<span id="cb3-49">axs[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>].set_xticks([])<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">;</span> axs[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>].set_yticks([])</span>
<span id="cb3-50"></span>
<span id="cb3-51">f.tight_layout()</span></code></pre></div></div>
</details>
<div class="cell-output cell-output-display">
<div id="fig-demo" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-fig figure">
<div aria-describedby="fig-demo-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<img src="https://ecntu.com/posts/tent/index_files/figure-html/fig-demo-output-1.png" class="img-fluid figure-img">
</div>
<figcaption class="quarto-float-caption-bottom quarto-float-caption quarto-float-fig" id="fig-demo-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Figure&nbsp;3: TENT adapts the model to output lower-entropy class distributions, in the hope of performing better in out-of-distribution (here, corrupted) settings.
</figcaption>
</figure>
</div>
</div>
</div>
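<p>The quantity being minimized above is just the Shannon entropy <em>H(p) = −∑ pᵢ log pᵢ</em> of each predicted class distribution, so confident (peaked) predictions score lower. A minimal pure-Python sketch (the helper name is ours, not from the paper):</p>

```python
import math

def prediction_entropy(probs):
    """Shannon entropy H(p) = -sum_i p_i * log(p_i) of one class distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0)

# Uniform over 4 classes attains the maximum entropy log(4) ~ 1.386;
# a peaked prediction scores far lower, so minimizing entropy rewards confidence.
uniform = [0.25, 0.25, 0.25, 0.25]
peaked  = [0.97, 0.01, 0.01, 0.01]
print(prediction_entropy(uniform), prediction_entropy(peaked))
```

<p>This is the per-image quantity the <code>entropy</code> line in the cell above computes, batched via <code>softmax</code>/<code>log_softmax</code> for numerical stability.</p>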
<p>Now for evaluation. While the authors consider other baselines, for simplicity we compare TENT only against the unadapted source model and a test-time normalization <a href="https://arxiv.org/abs/2006.16971">method</a> (“Norm”), which simply re-estimates the BN statistics on the test data without updating any weights.</p>
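<p>Why is a bare training-mode forward pass enough for “Norm”? BatchNorm layers update their running statistics during any forward in train mode, with no gradient step required. A small sketch of this behaviour (PyTorch's default BN momentum of 0.1 assumed):</p>

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A fresh BN layer starts with running_mean = 0
bn = nn.BatchNorm2d(3)
before = bn.running_mean.clone()

# A "test" batch drawn from a shifted distribution
x = torch.randn(16, 3, 8, 8) + 5.0

# Forward in train mode: no loss, no backward, no optimizer step
bn.train()
with torch.no_grad():
    bn(x)

# running_mean moved towards the batch mean via the momentum rule:
# running = (1 - momentum) * running + momentum * batch_mean
after = bn.running_mean.clone()
print(before, after)
```

<p>Note the stats update happens even under <code>torch.no_grad()</code>; it is part of the forward pass, not of autograd.</p>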
<div id="cell-11" class="cell">
<details class="code-fold">
<summary>Eval unadapted model</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb4" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> eval_source(init_model_fn, severity <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">5</span>, batch_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">128</span>):</span>
<span id="cb4-2">    model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> init_model_fn()<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">;</span> model.to(device)</span>
<span id="cb4-3">    results <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> {}</span>
<span id="cb4-4">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> corr_type <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> corruption_types:</span>
<span id="cb4-5">        corr_test_set <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> get_test_set(corr_type <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> corr_type, severity <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> severity)</span>
<span id="cb4-6">        _, source_acc <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> eval_model(model, DataLoader(corr_test_set, batch_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> batch_size))</span>
<span id="cb4-7">        results[corr_type] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> source_acc</span>
<span id="cb4-8">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> results</span>
<span id="cb4-9"></span>
<span id="cb4-10">source_results <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> eval_source(get_cifar10_model)</span>
<span id="cb4-11">torch.save(source_results, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'logs/source_results'</span>)</span></code></pre></div></div>
</details>
</div>
<div id="cell-12" class="cell">
<details class="code-fold">
<summary>Eval Norm</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb5" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> reset_bn_stats(module):</span>
<span id="cb5-2">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">isinstance</span>(module, nn.BatchNorm2d):</span>
<span id="cb5-3">        module.reset_running_stats()</span>
<span id="cb5-4">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> m <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> module.children(): reset_bn_stats(m)</span>
<span id="cb5-5"></span>
<span id="cb5-6"></span>
<span id="cb5-7"><span class="at" style="color: #657422;
background-color: null;
font-style: inherit;">@torch.no_grad</span>()</span>
<span id="cb5-8"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> eval_norm(init_model_fn, severity <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">5</span>, batch_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">128</span>, reset_stats <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>, corr_types <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>):</span>
<span id="cb5-9"></span>
<span id="cb5-10">    results_acc <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> {}</span>
<span id="cb5-11">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> corr_types <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">is</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>: corr_types <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> corruption_types</span>
<span id="cb5-12">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> corr_type <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> corr_types:</span>
<span id="cb5-13">        <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(corr_type)</span>
<span id="cb5-14">        </span>
<span id="cb5-15">        model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> init_model_fn()<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">;</span> model.to(device)  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Re-init the model</span></span>
<span id="cb5-16">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> reset_stats: model.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">apply</span>(reset_bn_stats)</span>
<span id="cb5-17">        corr_test_set <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> get_test_set(corr_type <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> corr_type, severity <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> severity)</span>
<span id="cb5-18"></span>
<span id="cb5-19">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> i, (images, labels) <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">enumerate</span>(DataLoader(corr_test_set, batch_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> batch_size)):</span>
<span id="cb5-20">            images, labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> images.to(device), labels.to(device)</span>
<span id="cb5-21"></span>
<span id="cb5-22">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Update the BN stats</span></span>
<span id="cb5-23">            model.train()</span>
<span id="cb5-24">            preds <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> model(images)</span>
<span id="cb5-25"></span>
<span id="cb5-26">            err <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (torch.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">max</span>(preds, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">!=</span> labels).<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span>().<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">sum</span>().item() <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> labels.shape[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>]</span>
<span id="cb5-27">            results_acc[corr_type] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> results_acc.get(corr_type, []) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> [err]</span>
<span id="cb5-28">            <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> i <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">15</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(err)</span>
<span id="cb5-29">        </span>
<span id="cb5-30">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> results_acc</span>
<span id="cb5-31"></span>
<span id="cb5-32">norm_results <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> eval_norm(get_cifar10_model, reset_stats <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>)</span>
<span id="cb5-33">torch.save(norm_results, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'logs/norm_results_all'</span>)</span></code></pre></div></div>
</details>
</div>
<div id="cell-13" class="cell" data-execution_count="10">
<details class="code-fold">
<summary>Eval TENT</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb6" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Only update BN layers</span></span>
<span id="cb6-2"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> prepare_for_test_time(module, reset_stats <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>):</span>
<span id="cb6-3">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">isinstance</span>(module, nn.BatchNorm2d):</span>
<span id="cb6-4">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> reset_stats: module.reset_running_stats()</span>
<span id="cb6-5">        module.requires_grad_(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb6-6">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">else</span>: module.requires_grad_(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>)</span>
<span id="cb6-7"></span>
<span id="cb6-8">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> m <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> module.children(): prepare_for_test_time(m, reset_stats)</span>
<span id="cb6-9"></span>
<span id="cb6-10"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> eval_tent(init_model_fn, severity <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">5</span>, lr <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.001</span>, batch_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">128</span>, reset_stats <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>, corr_types <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>):</span>
<span id="cb6-11"></span>
<span id="cb6-12">    results_acc <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> {}</span>
<span id="cb6-13">    results_e <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> {}</span>
<span id="cb6-14">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> corr_types <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">is</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>: corr_types <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> corruption_types</span>
<span id="cb6-15"></span>
<span id="cb6-16">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> corr_type <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> corr_types:</span>
<span id="cb6-17">        <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(corr_type)</span>
<span id="cb6-18"></span>
<span id="cb6-19">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Re-init the model &amp; optimizer</span></span>
<span id="cb6-20">        model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> init_model_fn()<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">;</span> model.to(device)</span>
<span id="cb6-21">        corr_test_set <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> get_test_set(corr_type <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> corr_type, severity <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> severity)</span>
<span id="cb6-22">        model.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">apply</span>(functools.partial(prepare_for_test_time, reset_stats <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> reset_stats))</span>
<span id="cb6-23">        optimizer <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> optim.AdamW(model.parameters(), lr <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> lr)</span>
<span id="cb6-24"></span>
<span id="cb6-27">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> i, (images, labels) <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">enumerate</span>(DataLoader(corr_test_set, batch_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> batch_size)):</span>
<span id="cb6-28">            images, labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> images.to(device), labels.to(device)</span>
<span id="cb6-29"></span>
<span id="cb6-30">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Minimize entropy</span></span>
<span id="cb6-31">            model.train()</span>
<span id="cb6-32">            optimizer.zero_grad()</span>
<span id="cb6-33">            preds <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> model(images)</span>
<span id="cb6-34">            entropy <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span>(preds.softmax(dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> preds.log_softmax(dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)).<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">sum</span>(dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)</span>
<span id="cb6-35">            loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> entropy.mean()</span>
<span id="cb6-36">            loss.backward()</span>
<span id="cb6-37">            optimizer.step()</span>
<span id="cb6-38"></span>
<span id="cb6-39">            err <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (torch.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">max</span>(preds, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">!=</span> labels).<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span>().<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">sum</span>().item() <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> labels.shape[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>]</span>
<span id="cb6-40">            results_acc[corr_type] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> results_acc.get(corr_type, []) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> [err]</span>
<span id="cb6-41">            results_e[corr_type] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> results_e.get(corr_type, []) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> [loss.item()]</span>
<span id="cb6-42">            <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> i <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">15</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(err)</span>
<span id="cb6-43"></span>
<span id="cb6-44">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> results_acc, results_e</span>
<span id="cb6-45"></span>
<span id="cb6-46"></span>
<span id="cb6-47"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># tent_acc, tent_entropy = eval_tent(get_cifar10_model, reset_stats = False, lr = 0.00001)</span></span>
<span id="cb6-48"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># torch.save(tent_acc, 'logs/tent_acc'); torch.save(tent_entropy, 'logs/tent_entropy')</span></span>
<span id="cb6-49"></span>
<span id="cb6-50"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># tent_acc_r, tent_entropy_r = eval_tent(get_model, reset_stats = True, lr = 0.00001)</span></span>
<span id="cb6-51"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># torch.save(tent_acc_r, 'logs/tent_acc_r'); torch.save(tent_entropy_r, 'logs/tent_entropy_r')</span></span></code></pre></div></div>
</details>
</div>
<div id="cell-fig-results" class="cell" data-execution_count="26">
<div class="cell-output cell-output-display">
<div id="fig-results" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-fig figure">
<div aria-describedby="fig-results-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<img src="https://ecntu.com/posts/tent/index_files/figure-html/fig-results-output-1.png" class="img-fluid figure-img">
</div>
<figcaption class="quarto-float-caption-bottom quarto-float-caption quarto-float-fig" id="fig-results-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Figure&nbsp;4: TENT &amp; Norm consistently outperform the unadapted model, with TENT (lr = 1e-5, batch_size = 128) taking a slight lead.
</figcaption>
</figure>
</div>
</div>
</div>
<div id="cell-16" class="cell" data-execution_count="11">
<details class="code-fold">
<summary>Hyperparam grid</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb7" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1">grid_results <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> {}</span>
<span id="cb7-2"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> lr, b <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> itertools.product([<span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">1e-4</span>, <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">1e-5</span>, <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">1e-6</span>], <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">9</span>)):</span>
<span id="cb7-3">    batch_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">**</span> b)</span>
<span id="cb7-4">    <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(lr, batch_size)</span>
<span id="cb7-5">    tent_acc, tent_entropy <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> eval_tent(get_cifar10_model, reset_stats <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>, lr <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> lr, batch_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> batch_size, corr_types <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'gaussian_noise'</span>])</span>
<span id="cb7-6">    grid_results[(lr, batch_size)] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (tent_acc, tent_entropy)</span>
<span id="cb7-7"></span>
<span id="cb7-8">torch.save(grid_results, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'logs/grid_results_all'</span>)</span>
<span id="cb7-9"></span>
<span id="cb7-10">grid_results_norm <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> {}</span>
<span id="cb7-11"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> b <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">9</span>):</span>
<span id="cb7-12">    batch_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">**</span> b)</span>
<span id="cb7-13">    <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(batch_size)</span>
<span id="cb7-14">    acc <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> eval_norm(get_cifar10_model, reset_stats <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>, batch_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> batch_size, corr_types <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'gaussian_noise'</span>])</span>
<span id="cb7-15">    grid_results_norm[(batch_size)] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> acc</span>
<span id="cb7-16"></span>
<span id="cb7-17">torch.save(grid_results_norm, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'logs/grid_results_norm'</span>)</span></code></pre></div></div>
</details>
</div>
<p>The paper shows TENT having more of a lead on this dataset, but this is the best I could do.</p>
<p>How sensitive is it to hyperparameters? TENT has two: the test-time learning rate and batch size. We vary these and show results for the <code>gaussian_noise</code> corruption.</p>
<div id="cell-fig-grid" class="cell quarto-layout-panel" data-layout-ncol="2" data-execution_count="9">
<div class="quarto-layout-row">
<div class="cell-output cell-output-display quarto-layout-cell" style="flex-basis: 50.0%;justify-content: flex-start;">
<div id="fig-grid" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-fig figure">
<div aria-describedby="fig-grid-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<img src="https://ecntu.com/posts/tent/index_files/figure-html/fig-grid-output-1.png" class="img-fluid figure-img">
</div>
<figcaption class="quarto-float-caption-bottom quarto-float-caption quarto-float-fig" id="fig-grid-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Figure&nbsp;5: TENT is sensitive to learning rate and batch size. Note that in practice these are scaled together.
</figcaption>
</figure>
</div>
</div>
<div class="cell-output cell-output-display quarto-layout-cell" style="flex-basis: 50.0%;justify-content: flex-start;">
<div id="fig-grid" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-fig figure">
<div aria-describedby="fig-grid-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<img src="https://ecntu.com/posts/tent/index_files/figure-html/fig-grid-output-2.png" class="img-fluid figure-img">
</div>
<figcaption class="quarto-float-caption-bottom quarto-float-caption quarto-float-fig" id="fig-grid-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Figure&nbsp;6: TENT reaches lower entropy than Norm.
</figcaption>
</figure>
</div>
</div>
</div>
</div>
<p>You can see that TENT seems quite sensitive to hyperparameters, which is a <a href="https://arxiv.org/abs/2306.03536">common challenge</a> to all Test-Time Adaptation methods. There definitely seems to be an entropy sweet-spot – presumably specific to the dataset and shift – controlled by the learning rate and batch size.</p>
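<p>For intuition on the entropy axis in these figures: the objective is just the Shannon entropy of the softmax output, which is zero for a one-hot prediction and maximal for a uniform one. A quick numeric sanity check (toy probabilities, not values from the experiments):</p>

```python
import math

def entropy(probs):
    # Shannon entropy of a probability vector, in nats.
    return -sum(p * math.log(p) for p in probs if p)

confident = [0.97, 0.01, 0.01, 0.01]   # low entropy: the model is sure
uniform = [0.25, 0.25, 0.25, 0.25]     # maximal entropy: ln(4), about 1.386 nats
```

<p>TENT pushes predictions toward the confident end of this scale; the sweet-spot observation above says that pushing too hard (large learning rate) collapses predictions rather than adapting them.</p>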
<p>So? TENT has some drawbacks. It is limited to classification (as far as I can tell), can <a href="https://arxiv.org/pdf/2410.10894v1">make models overconfident</a>, cannot be applied online (we need batches), and is sensitive to hyperparameters. However, you do not need test-time labels and can use a pretrained model. Most importantly, the technique is simple, intuitive, and seems to work.</p>
<p><em>All in all</em>, it was an interesting paper that introduced me to the test-time adaptation literature, and it was worth the read.</p>



 ]]></description>
  <category>deep learning</category>
  <category>paper</category>
  <guid>https://ecntu.com/posts/tent/</guid>
  <pubDate>Sun, 29 Dec 2024 05:00:00 GMT</pubDate>
  <media:content url="https://ecntu.com/posts/tent/index_files/figure-html/fig-demo-output-1.png" medium="image" type="image/png" height="53" width="144"/>
</item>
<item>
  <title>A Closer Look at Memorization in Deep Networks</title>
  <link>https://ecntu.com/posts/nn-memorization/</link>
  <description><![CDATA[ 





<p>This paper argues that memorization is a behavior exhibited by networks trained on random data, as, in the absence of patterns, they can only rely on remembering examples. The authors investigate this phenomenon and make three key claims:</p>
<ol type="1">
<li>Networks do not exclusively memorize data.</li>
<li>Networks initially learn simple patterns before resorting to memorization.</li>
<li>Regularization prevents memorization and promotes generalization.</li>
</ol>
<p>Here we aim to reproduce Figures 1, 7, and 8 from the paper.</p>
<section id="fig-1" class="level2">
<h2 class="anchored" data-anchor-id="fig-1">Fig 1</h2>
<p>To support the first claim, the authors argue that if networks simply memorized inputs, they would perform equally well on all training examples. If, however, networks learn patterns, some examples should be easy to learn because they fit those patterns better than others. To test this, they train an MLP for a single epoch from 100 different initializations and data shufflings, and log the percentage of runs in which each example was correctly classified.</p>
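<p>The per-run agreement statistic described above reduces to averaging boolean correctness masks across runs. A toy sketch (the <code>example_difficulty</code> helper is hypothetical, not the paper's code):</p>

```python
import numpy as np

def example_difficulty(correct_masks):
    # correct_masks: one boolean array per training run, marking which
    # examples were classified correctly after a single epoch.
    stacked = np.stack(correct_masks)   # shape (runs, n_examples)
    return stacked.mean(axis=0)         # per-example fraction of correct runs

# Toy illustration: 3 runs over 4 examples.
masks = [np.array([1, 1, 0, 0], dtype=bool),
         np.array([1, 0, 0, 1], dtype=bool),
         np.array([1, 1, 0, 0], dtype=bool)]
freqs = example_difficulty(masks)   # [1.0, 2/3, 0.0, 1/3]
```

<p>Under pure memorization these frequencies should be roughly flat across examples; a spread-out distribution indicates some examples fit the learned patterns better than others.</p>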
<p>The experiment is performed with the CIFAR10 dataset, a noisy input version <em>RandX</em>, and a noisy label version <em>RandY</em>. We first define dataset wrappers to implement the noisy variants. Note that for epoch-to-epoch consistency we determine which examples to corrupt at initialization.</p>
<div id="cell-6" class="cell" data-execution_count="265">
<details class="code-fold">
<summary>Random dataset wrappers</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb1" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">class</span> RandX(Dataset):</span>
<span id="cb1-2">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">"""Injects input noise into the dataset by replacing a fraction x of inputs with random Gaussian N(0, 1) noise"""</span></span>
<span id="cb1-3">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, dataset, x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">1.0</span>):</span>
<span id="cb1-4">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.dataset <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> dataset</span>
<span id="cb1-5">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> x</span>
<span id="cb1-6"></span>
<span id="cb1-7">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.modified <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> {}</span>
<span id="cb1-8">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> idx, (img, _) <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">enumerate</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.dataset):</span>
<span id="cb1-9">            <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> np.random.rand() <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&lt;=</span> x:</span>
<span id="cb1-10">                <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.modified[idx] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.randn_like(img)</span>
<span id="cb1-11">        torch.save(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.modified, os.path.join(dataset.root, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'randX_modified'</span>))</span>
<span id="cb1-12"></span>
<span id="cb1-13">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__len__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>): <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.dataset)</span>
<span id="cb1-14"></span>
<span id="cb1-15">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__getitem__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, idx):</span>
<span id="cb1-16">        X, y <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.dataset[idx]</span>
<span id="cb1-17">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.modified.get(idx, X), y </span>
<span id="cb1-18"></span>
<span id="cb1-19"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">class</span> RandY(Dataset):</span>
<span id="cb1-20">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">"""Injects label noise into the dataset by replacing a fraction y of labels with random labels"""</span></span>
<span id="cb1-21">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, dataset, y <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">1.0</span>):</span>
<span id="cb1-22">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.dataset <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> dataset</span>
<span id="cb1-23">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.y <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> y</span>
<span id="cb1-24"></span>
<span id="cb1-25">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.modified <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> {}</span>
<span id="cb1-26">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> idx <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.dataset)):</span>
<span id="cb1-27">            <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> np.random.rand() <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&lt;=</span> y:</span>
<span id="cb1-28">                <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.modified[idx] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.random.randint(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.dataset.classes))</span>
<span id="cb1-29">        torch.save(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.modified, os.path.join(dataset.root, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'randY_modified'</span>))</span>
<span id="cb1-30"></span>
<span id="cb1-31">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__len__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>): <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.dataset)</span>
<span id="cb1-32"></span>
<span id="cb1-33">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__getitem__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, idx):</span>
<span id="cb1-34">        X, y <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.dataset[idx]</span>
<span id="cb1-35">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> X, <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.modified.get(idx, y)</span></code></pre></div></div>
</details>
</div>
<p>Now we define a standard training loop, initialization functions, and the MLP specified in the paper.</p>
<div id="cell-8" class="cell" data-execution_count="8">
<details class="code-fold">
<summary>Model initing and training functions</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb2" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> train(model, train, val, optimizer, criterion <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.CrossEntropyLoss(), epochs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">100</span>, batch_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">256</span>, save_path <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'models/tmp'</span>):</span>
<span id="cb2-2">    model.to(device)</span>
<span id="cb2-3">    train <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> DataLoader(train, batch_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> batch_size, shuffle <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb2-4">    val <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> DataLoader(val, batch_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> batch_size, shuffle <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>)</span>
<span id="cb2-5">    best_loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.inf</span>
<span id="cb2-6">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> epoch <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(epochs):</span>
<span id="cb2-7">        model.train()</span>
<span id="cb2-8">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> images, labels <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> train:</span>
<span id="cb2-9">            images, labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> images.to(device), labels.to(device)</span>
<span id="cb2-10">            optimizer.zero_grad()</span>
<span id="cb2-11">            loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> criterion(model(images), labels)</span>
<span id="cb2-12">            loss.backward()</span>
<span id="cb2-13">            optimizer.step()</span>
<span id="cb2-14">        val_loss, val_acc <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> eval_model(model, val)</span>
<span id="cb2-15">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> val_loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&lt;</span> best_loss:</span>
<span id="cb2-16">            best_loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> val_loss</span>
<span id="cb2-17">            torch.save(model.state_dict(), save_path)</span>
<span id="cb2-18">        <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'Epoch </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>epoch <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">/</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>epochs<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">, Loss: </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>val_loss<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">:.4f}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">, Accuracy: </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>val_acc<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">:.4f}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">'</span>)</span>
<span id="cb2-19"></span>
<span id="cb2-20"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> _init_weights(m):</span>
<span id="cb2-21">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Linear'</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">str</span>(<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">type</span>(m)):</span>
<span id="cb2-22">        nn.init.xavier_uniform_(m.weight)</span>
<span id="cb2-23">        nn.init.zeros_(m.bias)</span>
<span id="cb2-24">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">elif</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">isinstance</span>(m, nn.Conv2d):</span>
<span id="cb2-25">        nn.init.kaiming_normal_(m.weight, mode<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"fan_out"</span>, nonlinearity<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"relu"</span>)</span>
<span id="cb2-26">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">elif</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">isinstance</span>(m, (nn.BatchNorm2d, nn.GroupNorm)):</span>
<span id="cb2-27">        nn.init.constant_(m.weight, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)</span>
<span id="cb2-28">        nn.init.constant_(m.bias, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>)</span>
<span id="cb2-29"></span>
<span id="cb2-30"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> initialize_model(model, data_loader, in_device <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>):</span>
<span id="cb2-31">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">with</span> torch.no_grad():</span>
<span id="cb2-32">        imgs, _ <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">next</span>(<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">iter</span>(data_loader))</span>
<span id="cb2-33">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> in_device: model.to(device)<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">;</span> imgs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> imgs.to(device)</span>
<span id="cb2-34">        _ <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> model(imgs)</span>
<span id="cb2-35">        model.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">apply</span>(_init_weights)</span></code></pre></div></div>
</details>
</div>
<div id="cell-9" class="cell" data-execution_count="248">
<details class="code-fold">
<summary>MLP definition</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb3" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">class</span> MLP(nn.Module):</span>
<span id="cb3-2">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, n_classes <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span>):</span>
<span id="cb3-3">        <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">super</span>().<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>()</span>
<span id="cb3-4">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.Sequential(</span>
<span id="cb3-5">            nn.Flatten(),</span>
<span id="cb3-6">            nn.LazyLinear(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4096</span>),</span>
<span id="cb3-7">            nn.ReLU(),</span>
<span id="cb3-8">            nn.LazyLinear(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4096</span>),</span>
<span id="cb3-9">            nn.ReLU(),</span>
<span id="cb3-10">            nn.LazyLinear(n_classes)</span>
<span id="cb3-11">        )</span>
<span id="cb3-12"></span>
<span id="cb3-13">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> forward(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, x):</span>
<span id="cb3-14">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.model(x)</span></code></pre></div></div>
</details>
</div>
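<p>As an aside on <code>initialize_model</code> above: it runs a dummy forward pass before applying <code>_init_weights</code> because the MLP is built from <code>nn.LazyLinear</code> layers, whose weights do not exist until a first batch fixes their input sizes. A minimal standalone sketch (toy shapes, not the experiment's data):</p>

```python
import torch
import torch.nn as nn

mlp = nn.Sequential(nn.Flatten(), nn.LazyLinear(8), nn.ReLU(), nn.LazyLinear(2))

# Before any forward pass the lazy layers hold UninitializedParameters,
# so init functions like xavier_uniform_ cannot be applied yet.
with torch.no_grad():
    _ = mlp(torch.randn(4, 3, 5, 5))  # dummy batch materializes the weight shapes

# Now every Linear has concrete weights and can be re-initialized.
for m in mlp.modules():
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        nn.init.zeros_(m.bias)

print(mlp[1].weight.shape)  # input features inferred from the dummy batch: 3*5*5 = 75
```

<p>Calling <code>nn.init.xavier_uniform_</code> before the dummy forward would raise an error, since the lazy parameters have no shape yet.</p>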
<p>And run the experiment, training models for a single epoch as in the paper, but also for 10 epochs to investigate how the results vary.</p>
<div id="cell-11" class="cell">
<details class="code-fold">
<summary>Get estimated P(correct)</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb4" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> add_missclassified(missclassified, model, test_set, batch_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">256</span>):</span>
<span id="cb4-2">    model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> model.to(device)<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">;</span> model.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">eval</span>()</span>
<span id="cb4-3">    test <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> DataLoader(test_set, batch_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> batch_size, shuffle <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>)</span>
<span id="cb4-4">    i <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span></span>
<span id="cb4-5">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">with</span> torch.no_grad():</span>
<span id="cb4-6">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> images, labels <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> test:</span>
<span id="cb4-7">            images, labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> images.to(device), labels.to(device)</span>
<span id="cb4-8">            _, pred <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">max</span>(model(images), <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)</span>
<span id="cb4-9">            missclassified[i:i <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> test.batch_size] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+=</span> (pred <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">!=</span> labels).<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span>()</span>
<span id="cb4-10">            i <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+=</span> test.batch_size</span>
<span id="cb4-11"></span>
<span id="cb4-12"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> gen_fig_1(epochs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, n_inits <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">100</span>):</span>
<span id="cb4-13">    training_sets <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [train_set, RandX(train_set, x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">1.0</span>), RandY(train_set, y <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">1.0</span>)]</span>
<span id="cb4-14">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> training_set <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> training_sets:</span>
<span id="cb4-15">        missclassified <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.zeros(<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(test_set)).to(device)</span>
<span id="cb4-16">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> _ <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(n_inits):</span>
<span id="cb4-17">            m <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> MLP()</span>
<span id="cb4-18">            initialize_model(m, DataLoader(train_set, batch_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">256</span>))</span>
<span id="cb4-19">            train(m, training_set, test_set, optim.SGD(m.parameters(), lr <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.01</span>), epochs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> epochs)</span>
<span id="cb4-20">            add_missclassified(missclassified, m, test_set)</span>
<span id="cb4-21">        missclassified <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/=</span> n_inits</span>
<span id="cb4-22">        torch.save(missclassified, <span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'logs/missclassified_epochs=</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>epochs<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">_'</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> training_set.<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">__class__</span>.<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">__name__</span>)</span>
<span id="cb4-23"></span>
<span id="cb4-24">gen_fig_1(epochs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, n_inits <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">100</span>)</span>
<span id="cb4-25">gen_fig_1(epochs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span>, n_inits <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">100</span>)</span></code></pre></div></div>
</details>
</div>
<div id="cell-12" class="cell" data-execution_count="249">
<details class="code-fold">
<summary>Plot results</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb5" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1">f, (ax1, ax2) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> plt.subplots(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, figsize <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">9</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>), sharey <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb5-2"></span>
<span id="cb5-3"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> ax, epochs <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">zip</span>([ax1, ax2], [<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span>]):</span>
<span id="cb5-4"></span>
<span id="cb5-5">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> fname <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">sorted</span>([f <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> f <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> os.listdir(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'logs'</span>) <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'missclassified'</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> f <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">and</span> <span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'epochs=</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>epochs<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">_'</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> f]):</span>
<span id="cb5-6">        missclassified <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.load(os.path.join(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'logs'</span>, fname))</span>
<span id="cb5-7">        p <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> missclassified).sort().values.to(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'cpu'</span>).numpy()</span>
<span id="cb5-8">        ax.plot(p, label <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> fname.split(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'_'</span>)[<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>])</span>
<span id="cb5-9">    </span>
<span id="cb5-10">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Plot binomially sampled points</span></span>
<span id="cb5-11">    randX_mean <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> torch.load(os.path.join(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'logs'</span>, <span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'missclassified_epochs=</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>epochs<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">_RandX'</span>)).mean().item()</span>
<span id="cb5-12">    bin_data <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.random.binomial(n <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">100</span>, p <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> randX_mean, size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10000</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">100</span></span>
<span id="cb5-13">    bin_data.sort()</span>
<span id="cb5-14">    ax.plot(bin_data, label <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Binomial_X'</span>)</span>
<span id="cb5-15">    ax.set_title(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'Epoch = </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>epochs<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">'</span>)</span>
<span id="cb5-16"></span>
<span id="cb5-17">ax1.legend()</span>
<span id="cb5-18">f.supylabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'P(correct)'</span>)</span>
<span id="cb5-19">f.supxlabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Example(sorted by P(correct))'</span>)</span>
<span id="cb5-20">f.tight_layout()</span></code></pre></div></div>
</details>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://ecntu.com/posts/nn-memorization/index_files/figure-html/cell-8-output-1.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
</div>
<p>Observe that the left panel is very similar to the paper’s figure. Whereas real data has easy patterns that can be learned in a single epoch, random data does not, and networks must resort to memorization. After 10 epochs we observe that the networks trained on random data manage to improve performance on a few points at the expense of the rest, whose performance becomes worse than random.</p>
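<p>As a sanity check on the binomial baseline plotted above: if every test example were equally hard, per-example accuracy over independent runs would be pure binomial noise, with a much tighter spread than the curves on real data. A small sketch with illustrative numbers (<code>n_inits</code>, <code>p_mean</code> here are assumptions, not the experiment's values):</p>

```python
import numpy as np

rng = np.random.default_rng(0)
n_inits, p_mean, n_examples = 100, 0.1, 10000

# If every example had the same P(correct) = p_mean, per-example accuracy over
# n_inits runs would be distributed as Binomial(n_inits, p_mean) / n_inits.
per_example_acc = rng.binomial(n=n_inits, p=p_mean, size=n_examples) / n_inits

# The spread is tight: std = sqrt(p(1-p)/n_inits) ~ 0.03 here, so a wide sorted
# curve (as on real data) signals genuinely different per-example difficulty.
print(per_example_acc.mean(), per_example_acc.std())
```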
<p>Out of curiosity, here are the 10 hardest and easiest examples.</p>
<div id="cell-15" class="cell" data-execution_count="199">
<details class="code-fold">
<summary>Plot hardest and easiest examples</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb6" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1">n <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span></span>
<span id="cb6-2">f, ax <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> plt.subplots(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, n, figsize <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (n <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>))</span>
<span id="cb6-3"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> i, idx <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">enumerate</span>(torch.sort(p).indices[:n]):</span>
<span id="cb6-4">    img, label <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> test_set[idx]</span>
<span id="cb6-5">    ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>][i].imshow(img.permute(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>).numpy())</span>
<span id="cb6-6">    ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>][i].axis(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'off'</span>)</span>
<span id="cb6-7">    ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>][i].set_title(test_set.classes[label].replace(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'mobile'</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">''</span>))</span>
<span id="cb6-8"></span>
<span id="cb6-9"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> i, idx <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">enumerate</span>(torch.sort(p).indices[<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span>n:]):</span>
<span id="cb6-10">    img, label <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> test_set[idx]</span>
<span id="cb6-11">    ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>][i].imshow(img.permute(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>).numpy())</span>
<span id="cb6-12">    ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>][i].axis(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'off'</span>)</span>
<span id="cb6-13">    ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>][i].set_title(test_set.classes[label].replace(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'mobile'</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">''</span>))</span>
<span id="cb6-14"></span>
<span id="cb6-15">f.suptitle(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Hardest (top) and easiest (bottom) examples'</span>)</span>
<span id="cb6-16">f.tight_layout()</span></code></pre></div></div>
</details>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://ecntu.com/posts/nn-memorization/index_files/figure-html/cell-9-output-1.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
</div>
</section>
<section id="fig-2" class="level2">
<h2 class="anchored" data-anchor-id="fig-2">Fig 2</h2>
<p>The fact that networks learn patterns when trained on real data and don’t when trained on noise can also be visualized by plotting the first layer weights of a convolutional network. We show the weights for networks trained for 10 epochs on real and random data.</p>
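<p>The kind of plot meant here can be sketched as follows. <code>conv1</code> stands in for the first layer of a trained network (untrained here, just to keep the snippet self-contained), with 200 filters of kernel size 5 as in the ConvNet defined below:</p>

```python
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# Stand-in for the first conv layer of a trained network.
conv1 = nn.Conv2d(3, 200, kernel_size=5)

# Normalize each 3x5x5 filter to [0, 1] so it can be shown as an RGB patch.
w = conv1.weight.detach()
lo = w.amin(dim=(1, 2, 3), keepdim=True)
hi = w.amax(dim=(1, 2, 3), keepdim=True)
w = (w - lo) / (hi - lo)

f, axes = plt.subplots(10, 20, figsize=(10, 5))
for ax, filt in zip(axes.flat, w):
    ax.imshow(filt.permute(1, 2, 0).numpy())  # CHW -> HWC for imshow
    ax.axis('off')
f.suptitle('First-layer filters')
f.tight_layout()
```

<p>On real data the filters typically develop visible structure (edges, color blobs); on random labels they tend to stay noisy.</p>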
<div id="cell-18" class="cell" data-execution_count="4">
<details class="code-fold">
<summary>ConvNet definition</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb7" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">class</span> ConvNet(nn.Module):</span>
<span id="cb7-2">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, num_classes):</span>
<span id="cb7-3">        <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">super</span>(ConvNet, <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>).<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>()</span>
<span id="cb7-4"></span>
<span id="cb7-5">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.Sequential(</span>
<span id="cb7-6">            nn.LazyConv2d(out_channels<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">200</span>, kernel_size<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">5</span>),  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># LazyConv2d to infer input channels</span></span>
<span id="cb7-7">            nn.BatchNorm2d(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">200</span>),</span>
<span id="cb7-8">            nn.ReLU(),</span>
<span id="cb7-9">            nn.MaxPool2d(kernel_size<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>, stride<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>),</span>
<span id="cb7-10"></span>
<span id="cb7-11">            nn.LazyConv2d(out_channels<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">200</span>, kernel_size<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">5</span>),  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Another LazyConv2d</span></span>
<span id="cb7-12">            nn.BatchNorm2d(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">200</span>),</span>
<span id="cb7-13">            nn.ReLU(),</span>
<span id="cb7-14">            nn.MaxPool2d(kernel_size<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>, stride<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>),</span>
<span id="cb7-15"></span>
<span id="cb7-16">            nn.Flatten(),  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Flatten for the fully connected layer</span></span>
<span id="cb7-17">            nn.LazyLinear(out_features<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">384</span>),  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># LazyLinear to infer input features</span></span>
<span id="cb7-18">            nn.BatchNorm1d(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">384</span>),</span>
<span id="cb7-19">            nn.ReLU(),</span>
<span id="cb7-20"></span>
<span id="cb7-21">            nn.LazyLinear(out_features<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">192</span>),  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Another LazyLinear</span></span>
<span id="cb7-22">            nn.BatchNorm1d(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">192</span>),</span>
<span id="cb7-23">            nn.ReLU(),</span>
<span id="cb7-24"></span>
<span id="cb7-25">            nn.LazyLinear(out_features<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>num_classes)</span>
<span id="cb7-26">        )</span>
<span id="cb7-27"></span>
<span id="cb7-28">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> forward(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, x):</span>
<span id="cb7-29">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.model(x)</span>
<span id="cb7-30">    </span>
<span id="cb7-31"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> train_conv(model, train, val, optimizer, criterion <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.CrossEntropyLoss(), epochs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">100</span>, batch_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">256</span>):</span>
<span id="cb7-32">    model.to(device)</span>
<span id="cb7-33">    train <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> DataLoader(train, batch_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> batch_size, shuffle <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb7-34">    val <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> DataLoader(val, batch_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> batch_size, shuffle <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>)</span>
<span id="cb7-35">    scheduler <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> optim.lr_scheduler.StepLR(optimizer, step_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">15</span>, gamma <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.5</span>)</span>
<span id="cb7-36">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> epoch <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(epochs):</span>
<span id="cb7-37">        model.train()</span>
<span id="cb7-38">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> images, labels <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> train:</span>
<span id="cb7-39">            images, labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> images.to(device), labels.to(device)</span>
<span id="cb7-40">            optimizer.zero_grad()</span>
<span id="cb7-41">            loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> criterion(model(images), labels)</span>
<span id="cb7-42">            loss.backward()</span>
<span id="cb7-43">            optimizer.step()</span>
<span id="cb7-44">        scheduler.step()</span>
<span id="cb7-45">        val_loss, val_acc <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> eval_model(model, val)</span>
<span id="cb7-46">        <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'Epoch </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>epoch <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">/</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>epochs<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">, Loss: </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>val_loss<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">:.4f}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">, Accuracy: </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>val_acc<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">:.4f}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">'</span>)</span></code></pre></div></div>
</details>
</div>
<div id="cell-19" class="cell" data-execution_count="165">
<details class="code-fold">
<summary>Train models</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb8" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1">cifar10_model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> ConvNet(num_classes <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> n_classes)</span>
<span id="cb8-2">randY_model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> ConvNet(num_classes <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> n_classes)</span>
<span id="cb8-3"></span>
<span id="cb8-4"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> model, dataset <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">zip</span>([cifar10_model, randY_model], [train_set, RandY(train_set)]):</span>
<span id="cb8-5">    initialize_model(model, DataLoader(dataset, batch_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">256</span>))</span>
<span id="cb8-6">    train_conv(model, dataset, test_set, optim.SGD(model.parameters(), lr <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.01</span>, momentum <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.9</span>), epochs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span>)</span>
<span id="cb8-7">    torch.save(model.state_dict(), <span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'models/fig2_</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>model<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">.</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">__class__</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">.</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">__name__</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">'</span>)</span></code></pre></div></div>
</details>
</div>
<div id="cell-20" class="cell" data-execution_count="182">
<details class="code-fold">
<summary>Plot filters</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb9" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1">f, (ax1, ax2) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> plt.subplots(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, figsize <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">9</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>), sharey <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb9-2"></span>
<span id="cb9-3"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> name, model, ax <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">zip</span>([<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Cifar 10'</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'RandY'</span>], [cifar10_model, randY_model], [ax1, ax2]):</span>
<span id="cb9-4"></span>
<span id="cb9-5">    model.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">eval</span>()</span>
<span id="cb9-6">    kernel_weights <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> model.model[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>].weight[:<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">7</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">12</span>].detach().cpu().clone()</span>
<span id="cb9-7">    kernel_weights <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (kernel_weights <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> kernel_weights.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">min</span>()) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> (kernel_weights.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">max</span>() <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> kernel_weights.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">min</span>())</span>
<span id="cb9-8">    filter_img <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> utils.make_grid(kernel_weights, nrow <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">12</span>, padding <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)</span>
<span id="cb9-9">    ax.imshow(filter_img.permute(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>))</span>
<span id="cb9-10">    ax.set_title(name)</span>
<span id="cb9-11">    ax.axis(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'off'</span>)</span></code></pre></div></div>
</details>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://ecntu.com/posts/nn-memorization/index_files/figure-html/cell-12-output-1.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
</div>
<p>We can see that the filters learned by the network trained on real data are visibly structured and plausibly useful, in contrast to the noisy-looking filters learned by training on random labels.</p>
</section>
<section id="fig-9" class="level2">
<h2 class="anchored" data-anchor-id="fig-9">Fig 9</h2>
<p>To support the claim that networks trained on real data learn simpler hypotheses because they capture patterns, the authors introduce the <em>Critical Sample Ratio</em> (CSR) as a complexity measure. The idea is to</p>
<blockquote class="blockquote">
<p>“estimate the complexity by measuring how densely points on the data manifold are present around the model’s decision boundaries. Intuitively, if we were to randomly sample points from the data distribution, a smaller fraction of points in the proximity of a decision boundary suggests that the learned hypothesis is simpler.”</p>
</blockquote>
<p>A simple sketch illustrates:</p>
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<p><img src="https://ecntu.com/posts/nn-memorization/index_files/images/fig9_sketch.jpg" class="img-fluid figure-img" width="300"></p>
<figcaption>CSR intuition sketch</figcaption>
</figure>
</div>
<p>To estimate the density of points close to decision boundaries, we can perturb each data point within a box of size <img src="https://latex.codecogs.com/png.latex?r"> and check whether any perturbed copy crosses a decision boundary. If one does, we call the original point “critical”. The <em>Critical Sample Ratio</em> is then the proportion of critical points, and we expect simpler hypotheses to have lower CSRs.</p>
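This naive estimate can be sketched with purely random box perturbations. The toy model, data, and all names below (`toy_model`, `points`, `naive_csr`) are illustrative stand-ins, not from the paper, which uses a guided search instead of pure noise:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy classifier and stand-in "data manifold" samples, for illustration only
toy_model = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 2))
points = torch.randn(100, 2)

def naive_csr(model, x, r=0.3, n_samples=20):
    """Fraction of points whose prediction flips under random box perturbations."""
    model.eval()
    with torch.no_grad():
        base = model(x).argmax(dim=1)
        critical = torch.zeros(len(x), dtype=torch.bool)
        for _ in range(n_samples):
            noise = (torch.rand_like(x) * 2 - 1) * r  # uniform in [-r, r]
            critical |= model(x + noise).argmax(dim=1) != base
    return critical.float().mean().item()

csr = naive_csr(toy_model, points)
print(f"naive CSR estimate: {csr:.2f}")
```

Random perturbations rarely find nearby boundary crossings efficiently, which motivates the guided search described next.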
<p>The perturbation applied to the data points is not purely random. The paper’s technique, presented in Algorithm 1, borrows ideas from adversarial attacks and is called Langevin Adversarial Sample Search (LASS). Here is my implementation.</p>
<div id="cell-25" class="cell" data-execution_count="195">
<details class="code-fold">
<summary>LASS implementation</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb10" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> standard_normal(shape):</span>
<span id="cb10-2">    r <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.randn(shape)</span>
<span id="cb10-3">    r <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> r.to(device)</span>
<span id="cb10-4">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> r</span>
<span id="cb10-5"></span>
<span id="cb10-6"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> lass(model, x, alpha <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.25</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">255</span>, beta <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.2</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">255</span>, r <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.3</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">255</span>, eta <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> standard_normal, max_iter <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span>):</span>
<span id="cb10-7">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">"""</span></span>
<span id="cb10-8"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    Langevin Adversarial Sample Search (LASS).</span></span>
<span id="cb10-9"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    Finds a perturbation of x that changes the model's prediction.</span></span>
<span id="cb10-10"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    </span></span>
<span id="cb10-11"><span class="co" style="color: #5E5E5E;
background-color: null;
        labels: Tensor">
font-style: inherit;">        x: Input tensor to perturb.</span></span>
<span id="cb10-12"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">        alpha: Step size for the gradient sign method.</span></span>
<span id="cb10-13"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">        beta: Scaling factor for the noise.</span></span>
<span id="cb10-14"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">        r: Clipping radius for adversarial perturbations.</span></span>
<span id="cb10-15"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">        eta: Noise process.</span></span>
<span id="cb10-16"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    """</span></span>
<span id="cb10-17">    <span class="co" style="color: #5E5E5E;
background-color: null;
# Orignal prediction">
font-style: inherit;"># Original prediction</span></span>
<span id="cb10-18">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">with</span> torch.no_grad():</span>
<span id="cb10-19">        pred_on_x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> model(x).argmax(dim<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)</span>
<span id="cb10-20">    </span>
<span id="cb10-21"></span>
<span id="cb10-22">    x_adv <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> x.clone().detach().requires_grad_(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb10-23">    converged <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span></span>
<span id="cb10-24">    iter_count <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span></span>
<span id="cb10-25"></span>
<span id="cb10-26">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">while</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">not</span> converged <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">and</span> iter_count <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&lt;</span> max_iter:</span>
<span id="cb10-27">        iter_count <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span></span>
<span id="cb10-28"></span>
<span id="cb10-29">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Forward pass to get model output</span></span>
<span id="cb10-30">        x_adv.requires_grad_(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb10-31">        output <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> model(x_adv)</span>
<span id="cb10-32"></span>
<span id="cb10-33">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Compute gradient of the output with respect to input</span></span>
<span id="cb10-34">        loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> F.cross_entropy(output, pred_on_x)  <span class="co" style="color: #5E5E5E;
background-color: null;
# Use actual labels">
font-style: inherit;"># Target: the model's original prediction</span></span>
<span id="cb10-35">        loss.backward()</span>
<span id="cb10-36"></span>
<span id="cb10-37">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Compute the perturbation</span></span>
<span id="cb10-38">        gradient_sign <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> x_adv.grad.sign()</span>
<span id="cb10-39">        delta <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> alpha <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> gradient_sign <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> beta <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> eta(x_adv.shape)</span>
<span id="cb10-40"></span>
<span id="cb10-41">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">with</span> torch.no_grad():</span>
<span id="cb10-42">            </span>
<span id="cb10-43">            x_adv <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+=</span> delta</span>
<span id="cb10-44"></span>
<span id="cb10-45">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Apply the clipping to each dimension so that each pixel is in the range [x - r, x + r]</span></span>
<span id="cb10-46">            x_adv <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.clamp(x_adv, x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> r, x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> r)</span>
<span id="cb10-47"></span>
<span id="cb10-48">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Check if the adversarial example has changed the model's prediction</span></span>
<span id="cb10-49">            new_output <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> model(x_adv)</span>
<span id="cb10-50">            <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">not</span> torch.equal(output.argmax(dim<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>), new_output.argmax(dim<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)):</span>
<span id="cb10-51">                converged <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span></span>
<span id="cb10-52">                x_hat <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> x_adv.clone().detach()</span>
<span id="cb10-53"></span>
<span id="cb10-54">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Zero the gradients for the next iteration</span></span>
<span id="cb10-55">        model.zero_grad()</span>
<span id="cb10-56">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> x_adv.grad <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">is</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">not</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>: x_adv.grad.zero_()</span>
<span id="cb10-57"></span>
<span id="cb10-58">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> converged, x_hat <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> converged <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">else</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span></span>
<span id="cb10-59"></span>
<span id="cb10-60"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> compute_csr(model, test_set, n_examples <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>, shuffle <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>, <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">**</span>lass_kwargs):</span>
<span id="cb10-61">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> n_examples <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">is</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>: n_examples <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(test_set)</span>
<span id="cb10-62">    model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> model.to(device)</span>
<span id="cb10-63">    csr <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span></span>
<span id="cb10-64">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> i, (images, labels) <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">enumerate</span>(DataLoader(test_set, batch_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, shuffle <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> shuffle)):</span>
<span id="cb10-65">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> i <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> n_examples: <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">break</span></span>
<span id="cb10-66">        images, labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> images.to(device), labels.to(device)</span>
<span id="cb10-67">        converged, _ <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> lass(model, images, <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">**</span>lass_kwargs)</span>
<span id="cb10-68">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> converged: csr <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span></span>
<span id="cb10-69">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> csr <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> n_examples</span></code></pre></div></div>
</details>
</div>
<p>The paper sets the radius within which we search for adversarial examples to <img src="https://latex.codecogs.com/png.latex?r%20=%2030/255"> because perturbations of that size are small enough to go unnoticed by a human evaluator. Here is an example.</p>
<div id="cell-27" class="cell" data-execution_count="335">
<details class="code-fold">
<summary>Adversarial example</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb11" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1">model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> cifar10_model</span>
<span id="cb11-2">model.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">eval</span>()<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">;</span> model.to(device)</span>
<span id="cb11-3">set_seed(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">5</span>)</span>
<span id="cb11-4"></span>
<span id="cb11-5">f, (ax1, ax2) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> plt.subplots(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, figsize <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">9</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>), sharey <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb11-6"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> i, (x, y) <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">enumerate</span>(DataLoader(test_set, batch_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, shuffle <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)):</span>
<span id="cb11-7">    x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> x.to(device)</span>
<span id="cb11-8">    y <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> y.to(device)</span>
<span id="cb11-9">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">with</span> torch.no_grad():</span>
<span id="cb11-10">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> model(x).argmax() <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">!=</span> y: <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">continue</span></span>
<span id="cb11-11">    converged, adv <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> lass(model, x)</span>
<span id="cb11-12">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> converged:</span>
<span id="cb11-13">        ax1.imshow(x.squeeze().cpu().permute(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>))</span>
<span id="cb11-14">        ax2.imshow(adv.squeeze().cpu().permute(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>))</span>
<span id="cb11-15">        ax1.axis(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'off'</span>)<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">;</span> ax2.axis(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'off'</span>)</span>
<span id="cb11-16">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">with</span> torch.no_grad():</span>
<span id="cb11-17">            ax1.set_title(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'Original. Predicted: </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>test_set<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">.</span>classes[model(x).argmax().item()]<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">'</span>)</span>
<span id="cb11-18">            ax2.set_title(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'Adversarial. Predicted: </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>test_set<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">.</span>classes[model(adv).argmax().item()]<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">'</span>)</span>
<span id="cb11-19">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">break</span></span></code></pre></div></div>
</details>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://ecntu.com/posts/nn-memorization/index_files/figure-html/cell-14-output-1.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
</div>
<p>Now we compute the Critical Sample Ratio as we train models, aiming to reproduce Figure 9.</p>
<div id="cell-29" class="cell">
<details class="code-fold">
<summary>ConvNet training loop</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb12" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> train_fig9(model, train, val, optimizer, criterion <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.CrossEntropyLoss(), epochs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">100</span>, batch_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">256</span>):</span>
<span id="cb12-2">    val_accs, csrs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [], []</span>
<span id="cb12-3">    model.to(device)</span>
<span id="cb12-4">    train <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> DataLoader(train, batch_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> batch_size, shuffle <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb12-5">    val <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> DataLoader(val, batch_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> batch_size, shuffle <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>)</span>
<span id="cb12-6">    scheduler <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> optim.lr_scheduler.StepLR(optimizer, step_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">15</span>, gamma <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.5</span>)</span>
<span id="cb12-7">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> epoch <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(epochs):</span>
<span id="cb12-8">        model.train()</span>
<span id="cb12-9">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> images, labels <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> train:</span>
<span id="cb12-10">            images, labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> images.to(device), labels.to(device)</span>
<span id="cb12-11">            optimizer.zero_grad()</span>
<span id="cb12-12">            loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> criterion(model(images), labels)</span>
<span id="cb12-13">            loss.backward()</span>
<span id="cb12-14">            optimizer.step()</span>
<span id="cb12-15">        scheduler.step()</span>
<span id="cb12-16">        val_loss, val_acc <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> eval_model(model, val)</span>
<span id="cb12-17">        val_accs.append(val_acc)</span>
<span id="cb12-18">        <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'Epoch </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>epoch <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">/</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>epochs<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">, Loss: </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>val_loss<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">:.4f}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">, Accuracy: </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>val_acc<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">:.4f}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">'</span>)</span>
<span id="cb12-19"></span>
<span id="cb12-20">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> epoch <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>:</span>
<span id="cb12-21">            csr <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> compute_csr(model, val.dataset, n_examples <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">500</span>, r <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">40</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">255</span>)</span>
<span id="cb12-22">            csrs.append(csr)</span>
<span id="cb12-23">            <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'CSR: </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>csr<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">:.4f}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">'</span>)</span>
<span id="cb12-24">    </span>
<span id="cb12-25">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> val_accs, csrs</span>
<span id="cb12-26"></span>
<span id="cb12-27"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> dataset <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> [train_set, RandX(train_set, x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">1.0</span>), RandY(train_set, y <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">1.0</span>)]:</span>
<span id="cb12-28">    set_seed(seed  <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">42</span>)</span>
<span id="cb12-29">    <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(dataset.<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">__class__</span>.<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">__name__</span>)</span>
<span id="cb12-30">    model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> ConvNet(num_classes <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> n_classes)</span>
<span id="cb12-31">    initialize_model(model, DataLoader(dataset, batch_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">256</span>))</span>
<span id="cb12-32">    val_accs, csrs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> train_fig9(model, dataset, test_set, optim.SGD(model.parameters(), lr <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.01</span>, momentum <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.9</span>), epochs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">141</span>)</span>
<span id="cb12-33">    torch.save({<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'val_accs'</span>: val_accs, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'csrs'</span>: csrs}, <span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'logs/fig9_</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>dataset<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">.</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">__class__</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">.</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">__name__</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">'</span>)</span></code></pre></div></div>
</details>
</div>
<div id="cell-30" class="cell" data-execution_count="343">
<details class="code-fold">
<summary>Plot results</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb13" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1">plt.plot(np.arange(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">141</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span>), torch.load(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'logs/fig9_CIFAR10'</span>)[<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'csrs'</span>], <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'b'</span>, label <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'CIFAR10'</span>)</span>
<span id="cb13-2">plt.plot(np.arange(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">141</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span>), [<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">15</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'r--'</span>, label <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'RandX (?)'</span>)</span>
<span id="cb13-3">plt.plot(np.arange(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">141</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span>), torch.load(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'logs/fig9_RandY'</span>)[<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'csrs'</span>], <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'g'</span>, label <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'RandY'</span>)</span>
<span id="cb13-4"></span>
<span id="cb13-5">plt.legend()</span>
<span id="cb13-6">plt.xlabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Epoch'</span>)</span>
<span id="cb13-7">plt.ylabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Critical Sample Ratio (CSR)'</span>)</span></code></pre></div></div>
</details>
<div class="cell-output cell-output-display" data-execution_count="343">
<pre><code>Text(0, 0.5, 'Critical Sample Ratio (CSR)')</code></pre>
</div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://ecntu.com/posts/nn-memorization/index_files/figure-html/cell-16-output-2.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
</div>
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<p><img src="https://ecntu.com/posts/nn-memorization/index_files/images/fig9.png" class="img-fluid figure-img" width="300"></p>
<figcaption>Original Fig 9</figcaption>
</figure>
</div>
<p>We observe roughly the same trend as in the paper (shown above): while the network trained on real data maintains a roughly constant CSR, the one trained on random labels has an increasing CSR as training progresses. However, I could not reproduce RandX’s behavior and obtained a constant CSR of 0. I tried different seeds, values of <img src="https://latex.codecogs.com/png.latex?r">, and datasets (training and validation) without luck. My suspicion is that the model’s capacity, and therefore its performance, was too low (around 10% validation accuracy). I decided to stick with the paper’s architecture and move on.</p>
</section>
<section id="what-i-learned-practiced" class="level2">
<h2 class="anchored" data-anchor-id="what-i-learned-practiced">What I learned / practiced</h2>
<ul>
<li>How to visualize 1st layer kernel weights</li>
<li>A bit about adversarial attacks</li>
<li>A creative proxy for model complexity (CSR)</li>
</ul>


</section>

 ]]></description>
  <category>deep learning</category>
  <category>paper</category>
  <guid>https://ecntu.com/posts/nn-memorization/</guid>
  <pubDate>Sat, 07 Sep 2024 04:00:00 GMT</pubDate>
</item>
<item>
  <title>Approximate Nearest Cosine Neighbors</title>
  <link>https://ecntu.com/posts/lsh/</link>
  <description><![CDATA[ 





<p>Suppose you have some vectors and wish to find, for each point, the <img src="https://latex.codecogs.com/png.latex?k"> nearest points. While you could compute pairwise distances using a naïve quadratic algorithm for small datasets, this approach becomes infeasible with millions or billions of points. If the points are in a low-dimensional space, clever data structures like <a href="https://en.wikipedia.org/wiki/K-d_tree">kd-trees</a>, ball trees, and M-trees can achieve substantial speedups. However, in high dimensions, performance degrades, and you may need to sacrifice exactness and turn to Approximate Nearest Neighbors (<a href="https://en.wikipedia.org/wiki/Nearest_neighbor_search#Approximation_methods">ANN</a>) techniques.</p>
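<p>For concreteness, here is a minimal sketch of the naïve quadratic baseline under cosine similarity in PyTorch. This is my own illustration, not code from the post, and the function name <code>knn_cosine_bruteforce</code> is made up:</p>

```python
# Naive O(n^2 d) time, O(n^2) memory k-nearest-neighbor search
# under cosine similarity.
import torch

def knn_cosine_bruteforce(X, k):
    # Normalize rows so dot products equal cosine similarities
    Xn = X / X.norm(dim=1, keepdim=True)
    sims = Xn @ Xn.T                       # (n, n) pairwise cosine similarities
    sims.fill_diagonal_(float('-inf'))     # exclude each point itself
    return sims.topk(k, dim=1).indices     # (n, k) indices of nearest neighbors

X = torch.randn(1000, 64)
neighbors = knn_cosine_bruteforce(X, k=5)
```

<p>Both time and memory scale quadratically in the number of points, which is exactly the cost LSH is designed to avoid.</p>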
<p>Locality-sensitive hashing (<a href="https://en.wikipedia.org/wiki/Locality-sensitive_hashing">LSH</a>) is a family of ANN algorithms that aim to group similar points into the same (or nearby) buckets efficiently using specialized hash functions. Recall that traditional hashing tries to map items to a set of buckets uniformly and minimize collisions. Thus, traditionally, slightly changing a point results in a vastly different hash and assigned bucket. LSH uses different hashing functions that often depend on the distance metric employed. Here, we explore a simple approach using cosine distance based on <a href="https://en.wikipedia.org/wiki/Random_projection">Random Projection</a>.</p>
<section id="how-it-works" class="level3">
<h3 class="anchored" data-anchor-id="how-it-works">How it works</h3>
<p>The mechanics are not very complicated. We first generate <img src="https://latex.codecogs.com/png.latex?N_h"> random hyperplanes. Let’s visualize this in two dimensions with <img src="https://latex.codecogs.com/png.latex?n=5"> points and <img src="https://latex.codecogs.com/png.latex?N_h=2">:</p>
<div id="853827e1-6a27-4fff-bbb5-031c666bab0d" class="cell" data-execution_count="48">
<details class="code-fold">
<summary>Random data and visualisation</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb1" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Note that we sample from a standard normal distribution</span></span>
<span id="cb1-2">X <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.randn((<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">5</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>))</span>
<span id="cb1-3">rand_planes <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.randn((<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>)) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.5</span></span>
<span id="cb1-4"></span>
<span id="cb1-5"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> lims(X, ax, eps <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.1</span>):</span>
<span id="cb1-6">    m <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> X[:, ax].<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">abs</span>().<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">max</span>().item()</span>
<span id="cb1-7">    m <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+=</span> eps</span>
<span id="cb1-8">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> (<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span>m, m)</span>
<span id="cb1-9"></span>
<span id="cb1-10">plt.scatter(X[:, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>], X[:, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>])</span>
<span id="cb1-11">plt.xlim(lims(X, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>))<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">;</span> plt.ylim(lims(X, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>))</span>
<span id="cb1-12"></span>
<span id="cb1-13"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Axes</span></span>
<span id="cb1-14">plt.axline((<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>,<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>), (<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>,<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>), c <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'0'</span>, linewidth <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.8</span>)</span>
<span id="cb1-15">plt.axline((<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>,<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>), (<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>,<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>), c <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'0'</span>, linewidth <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.8</span>)</span>
<span id="cb1-16"></span>
<span id="cb1-17"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Random planes</span></span>
<span id="cb1-18"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> v <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> rand_planes:</span>
<span id="cb1-19">    x, y <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> v.tolist()</span>
<span id="cb1-20">    plt.axline((<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>,<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>), (y, <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span>x), c <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'g'</span>, linewidth <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.8</span>)</span></code></pre></div></div>
</details>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://ecntu.com/posts/lsh/index_files/figure-html/cell-3-output-1.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
</div>
<p>Note that the plane is effectively partitioned into four regions. The main idea is to treat the regions as buckets and assign all the points in a region to the same bucket. When a new point comes along, we simply find out which region it belongs to and then search for points in that region and nearby regions until we accumulate <img src="https://latex.codecogs.com/png.latex?k"> of them.</p>
<p>Great! How do we do it computationally? We mainly need to remember that a hyperplane is characterized by its normal vector <img src="https://latex.codecogs.com/png.latex?%5Cvec%7Bv%7D"> and that we can determine on which side of it a point <img src="https://latex.codecogs.com/png.latex?x"> lies by the sign of their dot product, <img src="https://latex.codecogs.com/png.latex?%5Ctext%7Bsign%7D(%5Cvec%7Bv%7D%20%5Ccdot%20%5Cvec%7Bx%7D)">. For example, points 0 and 2 are on opposite sides of hyperplane 0:</p>
<div id="72836dfd-a7c5-4551-83f0-bbb4926a60fb" class="cell" data-execution_count="49">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb2" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1">v <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> rand_planes[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>]</span>
<span id="cb2-2">v <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">@</span> X[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&gt;=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, v <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">@</span> X[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&gt;=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span></span></code></pre></div></div>
<div class="cell-output cell-output-display" data-execution_count="49">
<pre><code>(tensor(True), tensor(False))</code></pre>
</div>
</div>
<div id="fd0c59c0-cff9-4e4b-9854-6c42d3a719ae" class="cell" data-execution_count="50">
<details class="code-fold">
<summary>Dot product example</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb4" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1">plt.scatter(<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span>X[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>].tolist())</span>
<span id="cb4-2">plt.scatter(<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span>X[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>].tolist())</span>
<span id="cb4-3">plt.xlim(lims(X, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>))<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">;</span> plt.ylim(lims(X, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>))</span>
<span id="cb4-4">plt.axline((<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>,<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>), (<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>,<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>), c <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'0'</span>, linewidth <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.8</span>)</span>
<span id="cb4-5">plt.axline((<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>,<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>), (<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>,<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>), c <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'0'</span>, linewidth <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.8</span>)</span>
<span id="cb4-6"></span>
<span id="cb4-7">x, y <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> rand_planes[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>].tolist()</span>
<span id="cb4-8">plt.axline((<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>,<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>), (y, <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span>x), c <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'g'</span>, linewidth <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.8</span>, label <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'hyperplane'</span>)</span>
<span id="cb4-9">plt.quiver(<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span>([<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>], [<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>]), x, y, color <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'0.5'</span>, label <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'v'</span>)</span>
<span id="cb4-10">plt.legend()</span></code></pre></div></div>
</details>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://ecntu.com/posts/lsh/index_files/figure-html/cell-5-output-1.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
</div>
<p>Thus, to find a point’s region, we repeat this process for each of the hyperplanes (two, in our case). For example:</p>
<div id="7d3f3cf4-e33b-4209-b1ea-2d67f6df20fd" class="cell" data-execution_count="51">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb5" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1">X[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">@</span> rand_planes[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&gt;=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, X[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">@</span> rand_planes[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&gt;=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span></span></code></pre></div></div>
<div class="cell-output cell-output-display" data-execution_count="51">
<pre><code>(tensor(True), tensor(False))</code></pre>
</div>
</div>
<div id="2a916e1c-f547-4f30-92ce-fd57d84e6b9b" class="cell" data-execution_count="52">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb7" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1">X[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">@</span> rand_planes[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&gt;=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, X[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">@</span> rand_planes[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&gt;=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span></span></code></pre></div></div>
<div class="cell-output cell-output-display" data-execution_count="52">
<pre><code>(tensor(False), tensor(False))</code></pre>
</div>
</div>
<p>I.e., points 0 and 2 are in different regions. Using matrix notation, we can obtain every point’s region succinctly:</p>
<div id="68abe665-4d08-4867-a559-efe9025f822c" class="cell" data-execution_count="53">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb9" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1">regions <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> X <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">@</span> rand_planes.T <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&gt;=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span></span>
<span id="cb9-2">regions</span></code></pre></div></div>
<div class="cell-output cell-output-display" data-execution_count="53">
<pre><code>tensor([[ True, False],
        [ True, False],
        [False, False],
        [False,  True],
        [ True, False]])</code></pre>
</div>
</div>
<p>We can now place each point in a region into a bucket:</p>
<div id="d3c9c1a6-e5b2-422e-865d-f5468b87c3b5" class="cell" data-execution_count="54">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb11" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1">buckets <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> {}</span>
<span id="cb11-2"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> i, reg <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">enumerate</span>(regions):</span>
<span id="cb11-3">    reg <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">tuple</span>(reg.tolist())</span>
<span id="cb11-4">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> reg <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">not</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> buckets: buckets[reg] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> []</span>
<span id="cb11-5">    buckets[reg].append(i)</span>
<span id="cb11-6">buckets</span></code></pre></div></div>
<div class="cell-output cell-output-display" data-execution_count="54">
<pre><code>{(True, False): [0, 1, 4], (False, False): [2], (False, True): [3]}</code></pre>
</div>
</div>
<p>And that’s all the preprocessing we need to do. Now, to find the nearest neighbors of a query point <img src="https://latex.codecogs.com/png.latex?%5Cvec%7Bq%7D">, we find its region:</p>
<div id="0d860b8c-3afa-401c-ad3c-f4939975b2b2" class="cell" data-execution_count="55">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb13" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1">q <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.randn((<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>,))</span>
<span id="cb13-2">q_reg <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">tuple</span>((q <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">@</span> rand_planes.T <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&gt;=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>).tolist())</span>
<span id="cb13-3">q, q_reg, buckets.get(q_reg, [])</span></code></pre></div></div>
<div class="cell-output cell-output-display" data-execution_count="55">
<pre><code>(tensor([-0.9890,  0.9580]), (True, False), [0, 1, 4])</code></pre>
</div>
</div>
<p>In this case, we had three points in the query’s bucket. When the number of elements in the bucket is less than <img src="https://latex.codecogs.com/png.latex?k">, the common approach is to look to nearby buckets in terms of Hamming distance. I.e., we first add points in buckets one bit flip away, then two flips away, and so on until we accumulate <img src="https://latex.codecogs.com/png.latex?k"> (or slightly more) points.</p>
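<p>As a sketch, the bit-flip expansion can be written in a few lines of plain Python (hypothetical helper names, operating on a <code>buckets</code> dict like the one built above):</p>

```python
from itertools import combinations

def flips(region, d):
    # all regions exactly d bit flips away from `region`
    for idxs in combinations(range(len(region)), d):
        r = list(region)
        for i in idxs:
            r[i] = not r[i]
        yield tuple(r)

def candidates(q_reg, buckets, k):
    # expand outward by Hamming distance until we accumulate at least k candidates
    found = []
    for d in range(len(q_reg) + 1):
        for reg in flips(q_reg, d):
            found += buckets.get(reg, [])
        if len(found) >= k:
            break
    return found
```

<p>With the buckets from earlier, asking for four neighbors of our query first takes points 0, 1, and 4 from its own bucket and then pulls in point 2 from a bucket one flip away.</p>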
</section>
<section id="considerations" class="level3">
<h3 class="anchored" data-anchor-id="considerations">Considerations</h3>
<p>To help root out false positives (points that are not among the top <img src="https://latex.codecogs.com/png.latex?k"> nearest but happened to land in or near the query’s buckets), we can compute the actual cosine distance to each candidate point we retrieved and return only those below a specified threshold.</p>
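<p>A minimal sketch of this refinement step (plain Python with hypothetical names; with tensors you could instead use <code>torch.nn.functional.cosine_similarity</code>):</p>

```python
import math

def cosine_dist(u, v):
    # 1 - cosine similarity between u and v
    dot = sum(a * b for a, b in zip(u, v))
    norms = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return 1 - dot / norms

def refine(q, points, cand_ids, max_dist):
    # keep only the retrieved candidates within the distance threshold
    return [i for i in cand_ids if cosine_dist(q, points[i]) <= max_dist]
```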
<p>You can imagine that the number of hyperplanes presents an accuracy-speed tradeoff: more hyperplanes imply (exponentially) more buckets, which means fewer false positives but also more computation to calculate regions. You can find details in the references.</p>
<p>If you get unlucky with the random generation of the hyperplanes, you might have a high false negative rate: genuinely close points that end up in far-away buckets. In that case, we can generate several independent sets of hyperplanes, collect each set’s candidates, and take their union as the candidate pool. See the <a href="http://infolab.stanford.edu/~bawa/Pub/similarity.pdf">LSH Forest paper</a>.</p>
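<p>A sketch of the multi-table idea (plain Python, hypothetical names; just independent tables whose candidate sets we union, not the full LSH Forest algorithm):</p>

```python
import random

def region(p, planes):
    # which side of each hyperplane the point falls on
    return tuple(sum(a * b for a, b in zip(p, v)) >= 0 for v in planes)

def build_tables(points, n_planes, n_tables, seed=0):
    # each table gets its own random hyperplanes and resulting buckets
    rng = random.Random(seed)
    dim = len(points[0])
    tables = []
    for _ in range(n_tables):
        planes = [[rng.gauss(0, 1) for _ in range(dim)] for _ in range(n_planes)]
        buckets = {}
        for i, p in enumerate(points):
            buckets.setdefault(region(p, planes), []).append(i)
        tables.append((planes, buckets))
    return tables

def union_candidates(q, tables):
    # a close point missed by one table is likely caught by another
    out = set()
    for planes, buckets in tables:
        out |= set(buckets.get(region(q, planes), []))
    return out
```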
<p>As a final note, how we generate the hyperplanes matters somewhat. We generally want their directions evenly distributed on the unit sphere, which is why we sample from a normal distribution. We’d also like them to be spread out relative to one another, which is why some methods generate (expensive) random orthogonal matrices. Or you might not care in practice, remembering that in high dimensions, <a href="https://math.stackexchange.com/questions/995623/why-are-randomly-drawn-vectors-nearly-perpendicular-in-high-dimensions">pairs of sampled vectors are nearly orthogonal with high probability</a>.</p>
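<p>A quick numerical check of that last claim (a plain-Python sketch):</p>

```python
import math
import random

def avg_abs_cos(dim, trials=200, seed=0):
    # average |cosine| between independent pairs of Gaussian random vectors
    rng = random.Random(seed)
    total = 0.0
    for _ in range(trials):
        u = [rng.gauss(0, 1) for _ in range(dim)]
        v = [rng.gauss(0, 1) for _ in range(dim)]
        dot = sum(a * b for a, b in zip(u, v))
        norms = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
        total += abs(dot) / norms
    return total / trials
```

<p>In 2 dimensions the average is sizable (around 0.6), while in 1000 dimensions it drops to a few hundredths: random directions are close to perpendicular.</p>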
</section>
<section id="resources" class="level3">
<h3 class="anchored" data-anchor-id="resources">Resources</h3>
<ul>
<li><a href="https://towardsdatascience.com/similarity-search-part-6-random-projections-with-lsh-forest-f2e9b31dcc47">Blog post w/some probability details</a></li>
<li><a href="http://www.mmds.org/">Mining of Massive Datasets</a> (book and course, chapter 3)</li>
<li><a href="https://www.cs.princeton.edu/courses/archive/spr04/cos598B/bib/CharikarEstim.pdf">Similarity Estimation Techniques from Rounding Algorithms</a></li>
</ul>


</section>

 ]]></description>
  <category>cs</category>
  <category>quick intro</category>
  <guid>https://ecntu.com/posts/lsh/</guid>
  <pubDate>Fri, 09 Aug 2024 04:00:00 GMT</pubDate>
</item>
<item>
  <title>Understanding Batch Normalization</title>
  <link>https://ecntu.com/posts/understanding-bn/</link>
  <description><![CDATA[ 





<p>The paper investigates the cause of batch norm’s benefits experimentally. The authors show that its main benefit is allowing for larger learning rates during training. In particular:</p>
<blockquote class="blockquote">
<p>“We show that the activations and gradients in deep neural networks without BN tend to be heavy-tailed. In particular, during an early on-set of divergence, a small subset of activations (typically in deep layer) “explode”. The typical practice to avoid such divergence is to set the learning rate to be sufficiently small such that no steep gradient direction can lead to divergence. However, small learning rates yield little progress along flat directions of the optimization landscape and may be more prone to convergence to sharp local minima with possibly worse generalization performance.”</p>
</blockquote>
<p>We attempt to reproduce figures 1-3, 5, and 6.</p>
<section id="convolutional-bn-layer" class="level3">
<h3 class="anchored" data-anchor-id="convolutional-bn-layer">Convolutional BN Layer</h3>
<p>As a reminder, the input <img src="https://latex.codecogs.com/png.latex?I"> and output <img src="https://latex.codecogs.com/png.latex?O"> tensors of a batch norm layer are 4-dimensional. The dimensions <img src="https://latex.codecogs.com/png.latex?(b,%20c,%20x,%20y)"> correspond to the batch example, the channel, and the spatial <img src="https://latex.codecogs.com/png.latex?x"> and <img src="https://latex.codecogs.com/png.latex?y"> positions, respectively. Batch norm (BN) applies a channel-wise normalization:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0AO_%7Bb,%20c,%20x,%20y%7D%20%5Cleftarrow%20%5Cgamma_c%20%5Cfrac%7BI_%7Bb,%20c,%20x,%20y%7D%20-%20%5Chat%20%5Cmu_c%7D%7B%5Csqrt%7B%5Chat%20%5Csigma_c%5E2%20+%20%5Cepsilon%7D%7D%20+%20%5Cbeta_c%0A"></p>
<p>Where <img src="https://latex.codecogs.com/png.latex?%5Chat%20%5Cmu_c"> and <img src="https://latex.codecogs.com/png.latex?%5Chat%20%5Csigma_c%5E2"> are estimates of channel <img src="https://latex.codecogs.com/png.latex?c">’s mean and variance, computed on the minibatch <img src="https://latex.codecogs.com/png.latex?%5Cmathcal%20B">:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0A%5Chat%20%5Cmu_c%20=%20%5Cfrac%7B1%7D%7B%7C%5Cmathcal%20B%7C%7D%5Csum_%7Bb,%20x,%20y%7D%20I_%7Bb,%20c,%20x,%20y%7D%0A"></p>
<p><img src="https://latex.codecogs.com/png.latex?%0A%5Chat%20%5Csigma_c%5E2%20=%20%5Cfrac%7B1%7D%7B%5Cmathcal%20%7CB%7C%7D%20%5Csum_%7Bb,%20x,%20y%7D%20(I_%7Bb,%20c,%20x,%20y%7D%20-%20%5Chat%20%5Cmu_c)%20%5E%202%0A"></p>
<p>To make sure the layer does not lose expressive power, we introduce learned parameters <img src="https://latex.codecogs.com/png.latex?%5Cgamma_c"> and <img src="https://latex.codecogs.com/png.latex?%5Cbeta_c">. <img src="https://latex.codecogs.com/png.latex?%5Cepsilon"> is a small constant added for numerical stability. In PyTorch, we can simply use the <code>BatchNorm2d</code> layer.</p>
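<p>To make the formula concrete, here is a tiny pure-Python sketch of the channel-wise normalization above (with <img src="https://latex.codecogs.com/png.latex?%5Cgamma_c"> = 1 and <img src="https://latex.codecogs.com/png.latex?%5Cbeta_c"> = 0; <code>BatchNorm2d</code> additionally learns those parameters and tracks running statistics for inference):</p>

```python
def batch_norm(I, eps=1e-5):
    # channel-wise normalization of a (B, C, X, Y) nested-list "tensor"
    B, C, X, Y = len(I), len(I[0]), len(I[0][0]), len(I[0][0][0])
    n = B * X * Y
    O = [[[[0.0] * Y for _ in range(X)] for _ in range(C)] for _ in range(B)]
    for c in range(C):
        # mean and (biased) variance over all batch and spatial positions
        vals = [I[b][c][x][y] for b in range(B) for x in range(X) for y in range(Y)]
        mu = sum(vals) / n
        var = sum((v - mu) ** 2 for v in vals) / n
        for b in range(B):
            for x in range(X):
                for y in range(Y):
                    O[b][c][x][y] = (I[b][c][x][y] - mu) / (var + eps) ** 0.5
    return O
```

<p>For a single channel holding the values 1 and 3, this maps them to roughly −1 and 1, as expected.</p>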
</section>
<section id="experimental-setup" class="level3">
<h3 class="anchored" data-anchor-id="experimental-setup">Experimental setup</h3>
<p>Let’s set up our data loaders, model, and training loop as described in Appendix B of the paper.</p>
<div id="cell-6" class="cell" data-execution_count="1">
<details class="code-fold">
<summary>Imports and model evaluation function</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb1" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> torch</span>
<span id="cb1-2"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> torch.nn <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> nn</span>
<span id="cb1-3"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> torch.optim <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> optim</span>
<span id="cb1-4"></span>
<span id="cb1-5"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> torchvision <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> datasets, transforms, models</span>
<span id="cb1-6"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> torch.utils.data <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> DataLoader, Dataset</span>
<span id="cb1-7"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> PIL <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> Image</span>
<span id="cb1-8"></span>
<span id="cb1-9"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> numpy <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> np</span>
<span id="cb1-10"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> pandas <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> pd</span>
<span id="cb1-11"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> seaborn <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> sns</span>
<span id="cb1-12"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> matplotlib.pyplot <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> plt</span>
<span id="cb1-13"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> os, itertools, time</span>
<span id="cb1-14"></span>
<span id="cb1-15">os.makedirs(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'logs'</span>, exist_ok <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb1-16">os.makedirs(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'models'</span>, exist_ok <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb1-17"></span>
<span id="cb1-18">seed <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">42</span></span>
<span id="cb1-19">np.random.seed(seed)</span>
<span id="cb1-20">torch.manual_seed(seed)</span>
<span id="cb1-21"></span>
<span id="cb1-22">device <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.device(</span>
<span id="cb1-23">    <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'cuda'</span> <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> torch.cuda.is_available() <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">else</span></span>
<span id="cb1-24">    (<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'mps'</span> <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> torch.backends.mps.is_available() <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">else</span></span>
<span id="cb1-25">    <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'cpu'</span>)</span>
<span id="cb1-26">)</span>
<span id="cb1-27"></span>
<span id="cb1-28"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> eval_model(model, test, criterion <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.CrossEntropyLoss()):</span>
<span id="cb1-29">    model.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">eval</span>()</span>
<span id="cb1-30">    correct, loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.0</span></span>
<span id="cb1-31">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">with</span> torch.no_grad():</span>
<span id="cb1-32">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> images, labels <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> test:</span>
<span id="cb1-33">            images, labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> images.to(device), labels.to(device)</span>
<span id="cb1-34">            _, pred <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">max</span>(model(images), <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)</span>
<span id="cb1-35">            correct <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+=</span> (pred <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> labels).<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span>().<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">sum</span>().item()</span>
<span id="cb1-36">            loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+=</span> criterion(model(images), labels).item()</span>
<span id="cb1-37">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(test.dataset), correct <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(test.dataset)</span>
<span id="cb1-38"></span>
<span id="cb1-39">device</span></code></pre></div></div>
</details>
</div>
<p>The paper trains ResNet-110s on CIFAR-10, with channel-wise normalization, random horizontal flipping, and 32-by-32 cropping with 4-pixel zero padding. We’ll train the ResNet-101 included in torchvision but keep everything else the same.</p>
<p>We first get the datasets and compute the channel-wise means and standard deviations. Note: we normalize both the training and validation sets with the same (training-set) statistics.</p>
<div id="cell-8" class="cell" data-execution_count="3">
<details class="code-fold">
<summary>Datasets and channel-wise means and stds</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb2" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1">train_set <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> datasets.CIFAR10(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'./data'</span>, download <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>, train <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>, transform <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> transforms.ToTensor())</span>
<span id="cb2-2">val_set <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> datasets.CIFAR10(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'./data'</span>, download <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>, train <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>, transform <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> transforms.ToTensor())</span>
<span id="cb2-3"></span>
<span id="cb2-4"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> channel_means_stds(dataset):</span>
<span id="cb2-5">    imgs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.stack([img <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> img, _ <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> dataset])</span>
<span id="cb2-6">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> imgs.mean(dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>]), imgs.std(dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>])</span>
<span id="cb2-7"></span>
<span id="cb2-8">means, stds <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> channel_means_stds(train_set)</span>
<span id="cb2-9"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'Training channel-wise</span><span class="ch" style="color: #20794D;
background-color: null;
font-style: inherit;">\n\t</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">means: </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>means<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ch" style="color: #20794D;
background-color: null;
font-style: inherit;">\n\t</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">stds: </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>stds<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">'</span>)</span>
<span id="cb2-10"></span>
<span id="cb2-11">val_means, val_stds <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> channel_means_stds(val_set)</span>
<span id="cb2-12"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'Validation channel-wise</span><span class="ch" style="color: #20794D;
background-color: null;
font-style: inherit;">\n\t</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">means: </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>val_means<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ch" style="color: #20794D;
background-color: null;
font-style: inherit;">\n\t</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">stds: </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>val_stds<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">'</span>)</span></code></pre></div></div>
</details>
</div>
<p>We now define the transforms with data augmentation and data loaders with batch size <img src="https://latex.codecogs.com/png.latex?128">.</p>
<div id="cell-10" class="cell" data-execution_count="4">
<details class="code-fold">
<summary>Data transforms and data loaders</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb3" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1">train_transform <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> transforms.Compose([</span>
<span id="cb3-2">    transforms.RandomCrop(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">32</span>, padding <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>),</span>
<span id="cb3-3">    transforms.RandomHorizontalFlip(),</span>
<span id="cb3-4">    transforms.ToTensor(),</span>
<span id="cb3-5">    transforms.Normalize(means, stds),</span>
<span id="cb3-6">])</span>
<span id="cb3-7"></span>
<span id="cb3-8"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># We do not perform data augmentation on the validation set</span></span>
<span id="cb3-9">val_transform <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> transforms.Compose([</span>
<span id="cb3-10">    transforms.ToTensor(),</span>
<span id="cb3-11">    transforms.Normalize(means, stds),</span>
<span id="cb3-12">])</span>
<span id="cb3-13"></span>
<span id="cb3-14">train_set.transform <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> train_transform</span>
<span id="cb3-15">val_set.transform <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> val_transform</span>
<span id="cb3-16"></span>
<span id="cb3-17">train_loader <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> DataLoader(train_set, batch_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">128</span>, shuffle <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb3-18">val_loader <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> DataLoader(val_set, batch_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">128</span>, shuffle <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>)</span></code></pre></div></div>
</details>
</div>
<p>We’ll use <code>torchvision</code>’s implementation of ResNet-101, Xavier initialization, SGD with momentum <img src="https://latex.codecogs.com/png.latex?0.9"> and weight decay <img src="https://latex.codecogs.com/png.latex?5%5Ctimes%2010%5E%7B-4%7D">, and cross-entropy loss. We follow the paper’s training details and learning-rate schedule as closely as we can:</p>
<blockquote class="blockquote">
<p>“Initially, all models are trained for 165 epochs and as in [17] we divide the learning rate by 10 after epoch 50% and 75%, at which point learning has typically plateaued. If learning doesn’t plateau for some number of epochs, we roughly double the number of epochs until it does.”</p>
</blockquote>
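<p>50% and 75% of 165 epochs are epochs 82 and 123, which is exactly what the <code>MultiStepLR</code> milestones in the code below encode. As a quick sanity check, the resulting step schedule can be reproduced in a few lines of plain Python; the <code>lr_at</code> helper is ours, for illustration only:</p>

```python
# Milestones from the paper's "50% and 75%" of 165 epochs
init_lr, init_epochs = 0.1, 165
milestones = [int(0.5 * init_epochs), int(0.75 * init_epochs)]
assert milestones == [82, 123]

def lr_at(epoch, init_lr=init_lr, milestones=milestones, gamma=0.1):
    """Learning rate after `epoch` completed epochs under a step schedule:
    multiply by gamma once for every milestone already passed."""
    return init_lr * gamma ** sum(epoch >= m for m in milestones)

assert lr_at(0) == 0.1
assert lr_at(82) == 0.1 * 0.1           # first drop
assert abs(lr_at(123) - 0.001) < 1e-12  # second drop
```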
<div id="cell-12" class="cell" data-execution_count="9">
<details class="code-fold">
<summary>Init, and train functions</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb4" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> xavier_init(m):</span>
<span id="cb4-2">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">isinstance</span>(m, nn.Conv2d) <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">or</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">isinstance</span>(m, nn.Linear):</span>
<span id="cb4-3">        nn.init.xavier_uniform_(m.weight)</span>
<span id="cb4-4"></span>
<span id="cb4-5"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> train_epoch(model, train, optimizer, criterion):</span>
<span id="cb4-6">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Trains the model for one epoch</span></span>
<span id="cb4-7">    model.train()</span>
<span id="cb4-8">    train_loss, correct <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.0</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span></span>
<span id="cb4-9">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> images, labels <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> train:</span>
<span id="cb4-10">        images, labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> images.to(device), labels.to(device)</span>
<span id="cb4-11">        optimizer.zero_grad()</span>
<span id="cb4-12">        output <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> model(images)</span>
<span id="cb4-13">        loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> criterion(output, labels)</span>
<span id="cb4-14">        loss.backward()</span>
<span id="cb4-15">        optimizer.step()</span>
<span id="cb4-16">        train_loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+=</span> loss.item()</span>
<span id="cb4-17">        _, pred <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">max</span>(output, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)</span>
<span id="cb4-18">        correct <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+=</span> (pred <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> labels).<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span>().<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">sum</span>().item()</span>
<span id="cb4-19">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> train_loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(train.dataset), correct <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(train.dataset)</span>
<span id="cb4-20"></span>
<span id="cb4-21"></span>
<span id="cb4-22"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> train(model, train, val, init_lr, plateau_patience <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">20</span>):</span>
<span id="cb4-23"></span>
<span id="cb4-24">    optimizer <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> optim.SGD(model.parameters(), lr <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> init_lr, momentum <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.9</span>, weight_decay <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">5e-4</span>)</span>
<span id="cb4-25">    scheduler  <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> optim.lr_scheduler.MultiStepLR(optimizer, milestones <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">82</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">123</span>], gamma <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.1</span>)</span>
<span id="cb4-26">    criterion <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.CrossEntropyLoss()</span>
<span id="cb4-27"></span>
<span id="cb4-28">    model.to(device)</span>
<span id="cb4-29"></span>
<span id="cb4-30">    init_epochs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">165</span></span>
<span id="cb4-31"></span>
<span id="cb4-32">    epoch <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span></span>
<span id="cb4-33">    plateau_count <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span></span>
<span id="cb4-34">    best_loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span></span>
<span id="cb4-35"></span>
<span id="cb4-36">    train_losses, train_accs, val_losses, val_accs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [], [], [], []</span>
<span id="cb4-37"></span>
<span id="cb4-38">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">while</span> epoch <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&lt;</span> init_epochs <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">and</span> plateau_count <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&lt;</span> plateau_patience:</span>
<span id="cb4-39"></span>
<span id="cb4-40">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Train the model for an epoch</span></span>
<span id="cb4-41">        loss, acc <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> train_epoch(model, train, optimizer, criterion)</span>
<span id="cb4-42">        train_losses.append(loss)</span>
<span id="cb4-43">        train_accs.append(acc)</span>
<span id="cb4-44"></span>
<span id="cb4-45">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Evaluate the model on the validation set</span></span>
<span id="cb4-46">        val_loss, val_acc <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> eval_model(model, val, criterion)</span>
<span id="cb4-47">        val_losses.append(val_loss)</span>
<span id="cb4-48">        val_accs.append(val_acc)</span>
<span id="cb4-49"></span>
<span id="cb4-50">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Update the learning rate</span></span>
<span id="cb4-51">        scheduler.step()</span>
<span id="cb4-52"></span>
<span id="cb4-53">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Check for a plateau</span></span>
<span id="cb4-54">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> best_loss <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">is</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">or</span> val_loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&lt;</span> best_loss:</span>
<span id="cb4-55">            best_loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> val_loss</span>
<span id="cb4-56">            plateau_count <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span></span>
<span id="cb4-57">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">else</span>:</span>
<span id="cb4-58">            plateau_count <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span></span>
<span id="cb4-59">        </span>
<span id="cb4-60">        epoch <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span></span>
<span id="cb4-61"></span>
<span id="cb4-62">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># "If learning doesn’t plateau for some number of epochs,</span></span>
<span id="cb4-63">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># we roughly double the number of epochs until it does."</span></span>
<span id="cb4-64">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> epoch <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> init_epochs <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">and</span> plateau_count <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&lt;</span> plateau_patience:</span>
<span id="cb4-65">            init_epochs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span></span>
<span id="cb4-66"></span>
<span id="cb4-67">        <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'Epoch </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>epoch<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">/</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>init_epochs<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;"> | Learning Rate: </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>optimizer<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">.</span>param_groups[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>][<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"lr"</span>]<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;"> | '</span></span>
<span id="cb4-68">              <span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'Training loss: </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>train_losses[<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>]<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">:.4f}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;"> | '</span></span>
<span id="cb4-69">              <span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'Validation loss: </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>val_losses[<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>]<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">:.4f}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;"> | '</span></span>
<span id="cb4-70">              <span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'Validation accuracy: </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>val_accs[<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>]<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">:.4f}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">'</span>)</span>
<span id="cb4-71">        </span>
<span id="cb4-72">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> train_losses, train_accs, val_losses, val_accs</span></code></pre></div></div>
</details>
</div>
<p>And we define a function to disable batch norm layers in a model by replacing them with identity layers:</p>
<div id="cell-14" class="cell" data-execution_count="6">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb5" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> disable_bn(model):</span>
<span id="cb5-2">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> name, module <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> model.named_children():</span>
<span id="cb5-3">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">isinstance</span>(module, nn.BatchNorm2d):</span>
<span id="cb5-4">            <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">setattr</span>(model, name, nn.Identity())</span>
<span id="cb5-5">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">else</span>:</span>
<span id="cb5-6">            disable_bn(module)  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Recursively replace in child modules</span></span></code></pre></div></div>
</div>
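<p>Because <code>named_children</code> only yields direct children, the recursive call is what reaches the BatchNorm layers buried inside ResNet’s bottleneck blocks. A quick self-contained sanity check on a toy network (not the post’s ResNet):</p>

```python
import torch
from torch import nn

def disable_bn(model):
    # Same helper as above: swap BatchNorm2d layers for no-op Identity modules
    for name, module in model.named_children():
        if isinstance(module, nn.BatchNorm2d):
            setattr(model, name, nn.Identity())
        else:
            disable_bn(module)  # Recursively replace in child modules

net = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Sequential(nn.Conv2d(8, 8, 3, padding=1), nn.BatchNorm2d(8)),  # nested
)
disable_bn(net)

# No BatchNorm2d left anywhere, including nested submodules
assert not any(isinstance(m, nn.BatchNorm2d) for m in net.modules())
# The model still runs on a dummy batch
out = net(torch.randn(2, 3, 32, 32))
assert out.shape == (2, 8, 32, 32)
```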
</section>
<section id="fig-1" class="level3">
<h3 class="anchored" data-anchor-id="fig-1">Fig 1</h3>
<p>Figure 1 aims to demonstrate that batch norm’s primary benefit is enabling training with larger learning rates.</p>
<p>The authors find the highest (<em>initial</em>) learning rate at which an unnormalized model can be trained (<img src="https://latex.codecogs.com/png.latex?%5Calpha%20=%200.0001">) and compare its performance against normalized models trained with <img src="https://latex.codecogs.com/png.latex?%5Calpha%20%5Cin%20%5C%7B0.0001,%200.003,%200.1%5C%7D">. To save on compute, we train each model once instead of five times, and present train and test accuracy curves.</p>
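<p>The paper doesn’t spell out how that threshold is found; a simple approach is to probe candidate rates from largest to smallest and keep the first one whose short training run doesn’t diverge. A hypothetical sketch, where <code>largest_stable_lr</code> and the <code>diverges</code> probe are ours, for illustration only:</p>

```python
def largest_stable_lr(candidates, diverges):
    """Return the largest candidate learning rate whose probe run does not
    diverge, or None if all of them do. `diverges` stands in for a real
    short training run that checks whether the loss blows up."""
    for lr in sorted(candidates, reverse=True):
        if not diverges(lr):
            return lr
    return None

# Toy stand-in: pretend training diverges above 0.0001 without batch norm
assert largest_stable_lr([0.1, 0.003, 0.0001], lambda lr: lr > 0.0001) == 0.0001
```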
<div id="cell-16" class="cell" data-execution_count="7">
<details class="code-fold">
<summary>Train models</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb6" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1">MODELS_DIR <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'models'</span></span>
<span id="cb6-2">LOGS_DIR <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'logs'</span></span>
<span id="cb6-3"></span>
<span id="cb6-4"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> lr, bn <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> [(<span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.0001</span>, <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>), (<span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.0001</span>, <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>), (<span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.003</span>, <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>), (<span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.1</span>, <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)]:</span>
<span id="cb6-5"></span>
<span id="cb6-6">    s <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'lr=</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>lr<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">'</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> (<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'_bn'</span> <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> bn <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">else</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">''</span>)</span>
<span id="cb6-7">    <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(s)</span>
<span id="cb6-8"></span>
<span id="cb6-9">    model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> models.resnet101(num_classes <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span>)</span>
<span id="cb6-10">    model.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">apply</span>(xavier_init)</span>
<span id="cb6-11"></span>
<span id="cb6-12">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">not</span> bn: disable_bn(model)</span>
<span id="cb6-13"></span>
<span id="cb6-14">    torch.save(model, <span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>MODELS_DIR<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">/</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>s<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">_init.pth'</span>)</span>
<span id="cb6-15">    data <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> train(model, train_loader, val_loader, init_lr <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> lr)</span>
<span id="cb6-16">    torch.save(model, <span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>MODELS_DIR<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">/</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>s<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">_end.pth'</span>)</span>
<span id="cb6-17">    torch.save(data, <span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>LOGS_DIR<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">/</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>s<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">.pth'</span>)</span></code></pre></div></div>
</details>
</div>
<div id="cell-17" class="cell" data-execution_count="108">
<details class="code-fold">
<summary>Code</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb7" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># code-summary: Plot results</span></span>
<span id="cb7-2"></span>
<span id="cb7-3">get_x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">lambda</span> x: <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">100</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> np.arange(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(x) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(x)</span>
<span id="cb7-4"></span>
<span id="cb7-5">f, (ax1, ax2) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> plt.subplots(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, figsize <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">9</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>), sharey <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb7-6"></span>
<span id="cb7-7"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> fname <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> os.listdir(LOGS_DIR):</span>
<span id="cb7-8"></span>
<span id="cb7-9">    train_losses, train_accs, val_losses, val_accs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.load(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>LOGS_DIR<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">/</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>fname<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">'</span>)</span>
<span id="cb7-10">    ax1.plot(get_x(train_accs), train_accs, label <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> fname[:<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>])</span>
<span id="cb7-11">    ax2.plot(get_x(val_accs), val_accs, label <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> fname[:<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>])</span>
<span id="cb7-12">    <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>fname[:<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>]<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;"> took </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(train_accs)<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;"> epochs'</span>)</span>
<span id="cb7-13"></span>
<span id="cb7-14"></span>
<span id="cb7-15">ax1.legend()<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">;</span> ax2.legend()</span>
<span id="cb7-16">ax1.set_ylabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Training accuracy'</span>)<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">;</span> ax2.set_ylabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Validation accuracy'</span>)</span>
<span id="cb7-17">ax1.set_xlabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">% o</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">f training'</span>)<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">;</span> ax2.set_xlabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">% o</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">f training'</span>)</span>
<span id="cb7-18">f.tight_layout()</span></code></pre></div></div>
</details>
<div class="cell-output cell-output-stdout">
<pre><code>lr=0.1_bn took 83 epochs
lr=0.0001 took 263 epochs
lr=0.003_bn took 69 epochs
lr=0.0001_bn took 211 epochs</code></pre>
</div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://ecntu.com/posts/understanding-bn/index_files/figure-html/cell-8-output-2.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
</div>
<p>We observe the same general trends as the original paper: similar learning rates yield roughly the same performance (red and orange), while raising the rate improves performance for normalized networks (blue) and makes training diverge for non-normalized ones (not shown).</p>
</section>
<section id="fig-2" class="level3">
<h3 class="anchored" data-anchor-id="fig-2">Fig 2</h3>
<p>In Figure 2 the authors begin to investigate <em>“why BN facilitates training with higher learning rates in the first place”</em>. They claim that batch normalization (BN) prevents divergence during training, which usually stems from large gradients in the first few mini-batches.</p>
<p>To test this, they analyze the gradients at initialization of a midpoint layer (layer 55) with and without batch norm, and find that gradients in unnormalized networks are consistently larger, with heavier-tailed distributions.</p>
<p>I had trouble replicating this figure; I could not reproduce the general shape and scale of their histograms:</p>
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<p><img src="https://ecntu.com/posts/understanding-bn/images/fig2.png" class="img-fluid quarto-figure quarto-figure-center figure-img" width="500"></p>
</figure>
</div>
<p>At first, I thought it was because I was</p>
<ul>
<li>looking at the wrong layer, off by ±1 (I then found it made little difference)</li>
<li>logging the gradient magnitudes incorrectly, since the plot has negative values (I then found that the authors plot the raw gradients)</li>
<li>misunderstanding the whole process</li>
</ul>
<p>As I understood it, we initialize the model (using Xavier’s initialization), do a forward and backward pass on a single batch, and log the gradients at roughly the middle layer:</p>
<div id="cell-20" class="cell" data-execution_count="10">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb9" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1">f, (ax1, ax2) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> plt.subplots(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, figsize <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">8</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>))</span>
<span id="cb9-2">images, labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">next</span>(<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">iter</span>(val_loader))</span>
<span id="cb9-3">images, labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> images.to(device), labels.to(device)</span>
<span id="cb9-4"></span>
<span id="cb9-5"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> bn, ax <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">zip</span>([<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>, <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>], [ax1, ax2]):</span>
<span id="cb9-6"></span>
<span id="cb9-7">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Init</span></span>
<span id="cb9-8">    model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> models.resnet101(num_classes <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span>)</span>
<span id="cb9-9">    model.to(device)</span>
<span id="cb9-10">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">not</span> bn: disable_bn(model)</span>
<span id="cb9-11">    model.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">apply</span>(xavier_init)</span>
<span id="cb9-12"></span>
<span id="cb9-13">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Forward and backward pass</span></span>
<span id="cb9-14">    model.train()</span>
<span id="cb9-15">    output <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> model(images)</span>
<span id="cb9-16">    loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.CrossEntropyLoss()(output, labels)</span>
<span id="cb9-17">    loss.backward()</span>
<span id="cb9-18"></span>
<span id="cb9-19">    model.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">eval</span>()</span>
<span id="cb9-20">    grads <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> model.layer3[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">9</span>].conv1.weight.grad.view(<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>).cpu().detach().numpy()</span>
<span id="cb9-21">    sns.histplot(grads, ax <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> ax)</span>
<span id="cb9-22"></span>
<span id="cb9-23">ax1.set_title(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'With Batch Normalization'</span>)</span>
<span id="cb9-24">ax2.set_title(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Without Batch Normalization'</span>)</span>
<span id="cb9-25">ax2.set_ylim((<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2500</span>))<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">;</span> ax2.set_ylabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">''</span>)</span>
<span id="cb9-26">f.tight_layout()</span></code></pre></div></div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://ecntu.com/posts/understanding-bn/index_files/figure-html/cell-9-output-1.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
</div>
<p>Although the unnormalized gradients are heavy-tailed, in my replication they are still much smaller than the normalized ones, not larger as in the paper. I was stuck on this for a few days, until I experimented with different initializations:</p>
<div id="cell-22" class="cell" data-execution_count="24">
<details class="code-fold">
<summary>Trying other inits</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb10" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> init_func(f, also_linear <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>):</span>
<span id="cb10-2">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> init(m):</span>
<span id="cb10-3">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">isinstance</span>(m, nn.Conv2d):</span>
<span id="cb10-4">            f(m.weight)</span>
<span id="cb10-5">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> also_linear <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">and</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">isinstance</span>(m, nn.Linear):</span>
<span id="cb10-6">            f(m.weight)</span>
<span id="cb10-7">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> init</span>
<span id="cb10-8"></span>
<span id="cb10-9">criterion <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.CrossEntropyLoss()</span>
<span id="cb10-10">images, labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">next</span>(<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">iter</span>(val_loader))</span>
<span id="cb10-11">images, labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> images.to(device), labels.to(device)</span>
<span id="cb10-12"></span>
<span id="cb10-13">with_bn, without_bn <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> {}, {}</span>
<span id="cb10-14"></span>
<span id="cb10-15"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> init_f <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> [</span>
<span id="cb10-16">    nn.init.xavier_uniform_, nn.init.kaiming_uniform_,</span>
<span id="cb10-17">    nn.init.kaiming_normal_, <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">lambda</span> x: nn.init.kaiming_normal_(x, mode <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'fan_out'</span>, nonlinearity<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'relu'</span>)]:</span>
<span id="cb10-18"></span>
<span id="cb10-19">    init_name <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> init_f.<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">__name__</span>[:<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>]</span>
<span id="cb10-20">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'lambda'</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> init_name: init_name <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'(default) kaiming_normal fan_out'</span></span>
<span id="cb10-21">    </span>
<span id="cb10-22">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> linear <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> (<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>, <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>):</span>
<span id="cb10-23"></span>
<span id="cb10-24">        lin_name <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">' w/linear lyrs'</span> <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> linear <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">else</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">''</span></span>
<span id="cb10-25">        </span>
<span id="cb10-26">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> bn, d <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">zip</span>((<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>, <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>), (with_bn, without_bn)):</span>
<span id="cb10-27">                </span>
<span id="cb10-28">                model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> models.resnet101(num_classes <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span>)</span>
<span id="cb10-29">                model.to(device)</span>
<span id="cb10-30">                <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">not</span> bn: disable_bn(model)</span>
<span id="cb10-31">                model.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">apply</span>(init_func(init_f, linear))</span>
<span id="cb10-32"></span>
<span id="cb10-33">                model.train()</span>
<span id="cb10-34">                output <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> model(images)</span>
<span id="cb10-35">                loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> criterion(output, labels)</span>
<span id="cb10-36">                loss.backward()</span>
<span id="cb10-37"></span>
<span id="cb10-38">                model.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">eval</span>()</span>
<span id="cb10-39">                d[init_name <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> lin_name] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> model.layer3[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">9</span>].conv1.weight.grad.view(<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>).cpu().detach().numpy()</span></code></pre></div></div>
</details>
</div>
<div id="cell-23" class="cell" data-execution_count="25">
<details class="code-fold">
<summary>Plotting the results</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb11" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> init_name, v <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> with_bn.items():</span>
<span id="cb11-2">    f, (ax1, ax2) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> plt.subplots(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, figsize <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">8</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>))</span>
<span id="cb11-3">    sns.histplot(v, ax <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> ax1)</span>
<span id="cb11-4">    ax1.set_title(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'with BN'</span>)</span>
<span id="cb11-5">    sns.histplot(without_bn[init_name], ax <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> ax2)</span>
<span id="cb11-6">    ax2.set_title(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'without BN'</span>)</span>
<span id="cb11-7">    ax2.set_ylabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">''</span>)</span>
<span id="cb11-8">    ax2.set_ylim(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1500</span>)</span>
<span id="cb11-9">    f.suptitle(init_name)</span>
<span id="cb11-10">    f.tight_layout()</span></code></pre></div></div>
</details>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://ecntu.com/posts/understanding-bn/index_files/figure-html/cell-11-output-1.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://ecntu.com/posts/understanding-bn/index_files/figure-html/cell-11-output-2.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://ecntu.com/posts/understanding-bn/index_files/figure-html/cell-11-output-3.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://ecntu.com/posts/understanding-bn/index_files/figure-html/cell-11-output-4.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://ecntu.com/posts/understanding-bn/index_files/figure-html/cell-11-output-5.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://ecntu.com/posts/understanding-bn/index_files/figure-html/cell-11-output-6.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://ecntu.com/posts/understanding-bn/index_files/figure-html/cell-11-output-7.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://ecntu.com/posts/understanding-bn/index_files/figure-html/cell-11-output-8.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
</div>
<p>As you can see, the general shapes (roughly normal vs.&nbsp;heavy-tailed) don’t depend much on the initialization scheme, but the scales do. We could only match the gradient scales reported in the paper by using the <code>kaiming_normal</code> scheme with <code>mode='fan_out'</code> (which preserves the variance of the gradients in the backward pass rather than that of the activations in the forward pass) and applying it only to <code>Conv2d</code> layers. This is the default used by torchvision’s ResNets.</p>
<p>Note: <code>xavier_normal</code> produced shapes and scales very similar to <code>xavier_uniform</code>, so we don’t show it.</p>
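<p>To see concretely why the two Kaiming modes differ in scale, here is a small standalone check (my own sketch, not code from the paper or the experiments above): the target standard deviation is <code>sqrt(2 / fan)</code>, and for a conv layer with more input than output channels, <code>fan_out</code> is smaller than <code>fan_in</code>, so the initialized weights come out larger.</p>

```python
import math
import torch
import torch.nn as nn

# Hypothetical layer for illustration: a 3x3 conv going from 256 to 64
# channels, as might appear late in a ResNet stage.
conv = nn.Conv2d(256, 64, kernel_size=3, bias=False)
fan_in = 256 * 3 * 3   # in_channels * kernel area
fan_out = 64 * 3 * 3   # out_channels * kernel area

nn.init.kaiming_normal_(conv.weight, mode='fan_in', nonlinearity='relu')
std_in = conv.weight.std().item()

nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')
std_out = conv.weight.std().item()

# Kaiming targets std = sqrt(2 / fan): fan_in preserves activation
# variance in the forward pass, fan_out preserves gradient variance in
# the backward pass.
print(std_in, math.sqrt(2 / fan_in))    # ~0.029
print(std_out, math.sqrt(2 / fan_out))  # ~0.059
```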
<p>For the rest of the figures we’ll use the default init scheme:</p>
<div id="cell-25" class="cell" data-execution_count="26">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb12" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> init_net(m):</span>
<span id="cb12-2">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">isinstance</span>(m, nn.Conv2d):</span>
<span id="cb12-3">        nn.init.kaiming_normal_(m.weight, mode <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'fan_out'</span>, nonlinearity <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'relu'</span>)</span></code></pre></div></div>
</div>
</section>
<section id="fig-3" class="level3">
<h3 class="anchored" data-anchor-id="fig-3">Fig 3</h3>
<p>The authors then investigate the loss landscape along the gradient direction for the first few mini-batches for models with BN (trained with <img src="https://latex.codecogs.com/png.latex?%5Calpha%20=%200.1">) and without BN (<img src="https://latex.codecogs.com/png.latex?%5Calpha%20=%200.0001">). For each network and mini-batch they compute the gradient and plot the relative change in the loss (new_loss/old_loss).</p>
<p>We save the model’s and optimizer’s states (<code>state_dict</code>) before taking tentative steps to explore the landscape, and restore them before taking the actual optimization step for each batch.</p>
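<p>The save/restore pattern itself can be sketched on a toy setup (a simplified stand-in, not the <code>fig3</code> code below): compute the gradient once, snapshot both <code>state_dict</code>s with <code>copy.deepcopy</code> (a plain <code>state_dict()</code> holds references to the live tensors), try several step sizes along that gradient, and restore after each tentative step.</p>

```python
import copy
import torch
import torch.nn as nn
import torch.optim as optim

# Toy model and mini-batch standing in for the real network/data.
torch.manual_seed(0)
model = nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
criterion = nn.CrossEntropyLoss()
x, y = torch.randn(32, 10), torch.randint(0, 2, (32,))

# Compute gradients once for this mini-batch.
optimizer.zero_grad()
old_loss = criterion(model(x), y)
old_loss.backward()

# Snapshot: deepcopy so the tentative updates don't mutate the copies.
model_state = copy.deepcopy(model.state_dict())
opt_state = copy.deepcopy(optimizer.state_dict())

relative_losses = {}
for lr in (1e-3, 1e-2, 1e-1):
    for g in optimizer.param_groups:
        g['lr'] = lr
    optimizer.step()  # tentative step along the stored gradient
    with torch.no_grad():
        relative_losses[lr] = (criterion(model(x), y) / old_loss).item()
    # Restore before trying the next step size (gradients on the
    # parameters are untouched by load_state_dict, so they are reused).
    model.load_state_dict(model_state)
    optimizer.load_state_dict(opt_state)
```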
<div id="cell-28" class="cell" data-execution_count="57">
<details class="code-fold">
<summary>Explore loss landscape</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb13" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> fig3(model, init_lr, log_lrs, log_batches <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">7</span>]):</span>
<span id="cb13-2"></span>
<span id="cb13-3">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># batch -&gt; list of relative losses for each lr</span></span>
<span id="cb13-4">    out <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> {}</span>
<span id="cb13-5"></span>
<span id="cb13-6">    criterion <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.CrossEntropyLoss()</span>
<span id="cb13-7">    optimizer <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> optim.SGD(model.parameters(), lr <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> init_lr, momentum <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.9</span>, weight_decay <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">5e-4</span>)</span>
<span id="cb13-8"></span>
<span id="cb13-9">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> i, (images, labels) <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">enumerate</span>(train_loader):</span>
<span id="cb13-10"></span>
<span id="cb13-11">        images, labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> images.to(device), labels.to(device)</span>
<span id="cb13-12"></span>
<span id="cb13-13">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> i <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> log_batches:</span>
<span id="cb13-14">            </span>
<span id="cb13-15">            torch.save(model.state_dict(), <span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'models/model_state.tmp'</span>)</span>
<span id="cb13-16">            torch.save(optimizer.state_dict(), <span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'models/optimizer_state.tmp'</span>)</span>
<span id="cb13-17"></span>
<span id="cb13-18">            rel_losses <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> []</span>
<span id="cb13-19">            <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> lr <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> log_lrs:</span>
<span id="cb13-20"></span>
<span id="cb13-21">                <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> param_group <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> optimizer.param_groups: param_group[<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'lr'</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> lr</span>
<span id="cb13-22"></span>
<span id="cb13-23">                optimizer.zero_grad()</span>
<span id="cb13-24">                output <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> model(images)</span>
<span id="cb13-25">                current_loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> criterion(output, labels)</span>
<span id="cb13-26">                current_loss.backward()</span>
<span id="cb13-27">                optimizer.step()</span>
<span id="cb13-28"></span>
<span id="cb13-29">                <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">with</span> torch.no_grad():</span>
<span id="cb13-30">                    output <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> model(images)</span>
<span id="cb13-31">                    tmp_loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> criterion(output, labels)</span>
<span id="cb13-32">                    rel_loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (tmp_loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> current_loss).item()</span>
<span id="cb13-33">                </span>
<span id="cb13-34">                    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># print learning rate, current loss, tmp loss, relative loss (at 4 decimal places)</span></span>
<span id="cb13-35">                    <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>lr<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">:.5f}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;"> </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>current_loss<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">:.5f}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;"> </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>tmp_loss<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">:.5f}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;"> </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>rel_loss<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">:.5f}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">'</span>)</span>
<span id="cb13-36">                    rel_losses.append(rel_loss)</span>
<span id="cb13-37">            </span>
<span id="cb13-38">                model.load_state_dict(torch.load(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'models/model_state.tmp'</span>))</span>
<span id="cb13-39">                optimizer.load_state_dict(torch.load(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'models/optimizer_state.tmp'</span>))   </span>
<span id="cb13-40"></span>
<span id="cb13-41">                <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># If loss is nan or inf, break; unlikely to recover.</span></span>
<span id="cb13-42">                <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> torch.isnan(tmp_loss).item() <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">or</span> torch.isinf(tmp_loss).item():</span>
<span id="cb13-43">                    rel_losses.pop()</span>
<span id="cb13-44">                    <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'breaking'</span>)</span>
<span id="cb13-45">                    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">break</span></span>
<span id="cb13-46"></span>
<span id="cb13-47">            out[i] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> rel_losses</span>
<span id="cb13-48">        </span>
<span id="cb13-49">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> i <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">max</span>(log_batches): <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">break</span></span>
<span id="cb13-50"></span>
<span id="cb13-51">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># take the actual step</span></span>
<span id="cb13-52">        optimizer.zero_grad()</span>
<span id="cb13-53">        output <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> model(images)</span>
<span id="cb13-54">        loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> criterion(output, labels)</span>
<span id="cb13-55">        loss.backward()</span>
<span id="cb13-56">        optimizer.step()</span>
<span id="cb13-57">    </span>
<span id="cb13-58">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> out</span>
<span id="cb13-59"></span>
<span id="cb13-60"></span>
<span id="cb13-61">lrs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.logspace(<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">5.5</span>, <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.5</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">80</span>)</span>
<span id="cb13-62"></span>
<span id="cb13-63"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># With BN</span></span>
<span id="cb13-64">model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> models.resnet101(num_classes <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span>)</span>
<span id="cb13-65">model.to(device)</span>
<span id="cb13-66">model.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">apply</span>(init_net)</span>
<span id="cb13-67">with_bn <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> fig3(model, init_lr <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.1</span>, log_lrs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> lrs)</span>
<span id="cb13-68"></span>
<span id="cb13-69"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Without BN</span></span>
<span id="cb13-70">model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> models.resnet101(num_classes <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span>)</span>
<span id="cb13-71">model.to(device)</span>
<span id="cb13-72">disable_bn(model)</span>
<span id="cb13-73">model.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">apply</span>(init_net)</span>
<span id="cb13-74">without_bn <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> fig3(model, init_lr <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.0001</span>, log_lrs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> lrs)</span></code></pre></div></div>
</details>
</div>
<div id="cell-29" class="cell" data-execution_count="65">
<details class="code-fold">
<summary>Plotting the results</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb14" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb14-1">f, axs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> plt.subplots(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>, figsize <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">12</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">6</span>))</span>
<span id="cb14-2"></span>
<span id="cb14-3"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> (batch, rel_losses), ax <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">zip</span>(with_bn.items(), axs[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>]):</span>
<span id="cb14-4">    ax.plot(lrs, rel_losses)</span>
<span id="cb14-5">    ax.set_xscale(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'log'</span>)</span>
<span id="cb14-6">    ax.set_yscale(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'log'</span>)</span>
<span id="cb14-7">    ax.set_ylim(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">1.5</span>)</span>
<span id="cb14-8">    ax.set_xlim(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">**-</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">5.5</span>, <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.4</span>)</span>
<span id="cb14-9">    ax.set_title(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'Batch </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>batch<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;"> with BN'</span>)</span>
<span id="cb14-10">    ax.axhline(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, color <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'grey'</span>, linestyle <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'--'</span>)</span>
<span id="cb14-11"></span>
<span id="cb14-12"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> (batch, rel_losses), ax <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">zip</span>(without_bn.items(), axs[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>]):</span>
<span id="cb14-13">    ax.plot(lrs[:<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(rel_losses)], rel_losses)</span>
<span id="cb14-14">    ax.set_xscale(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'log'</span>)</span>
<span id="cb14-15">    ax.set_yscale(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'log'</span>)</span>
<span id="cb14-16">    ax.set_ylim(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">**-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">**</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>)</span>
<span id="cb14-17">    ax.set_xlim(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">**-</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">5.5</span>, <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.4</span>)</span>
<span id="cb14-18">    ax.set_title(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'Batch </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>batch<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;"> w/o BN'</span>)</span>
<span id="cb14-19">    ax.axhline(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, color <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'grey'</span>, linestyle <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'--'</span>)</span>
<span id="cb14-20"></span>
<span id="cb14-21">axs[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>].set_ylabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Relative loss'</span>)</span>
<span id="cb14-22">axs[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>].set_ylabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Relative loss'</span>)</span>
<span id="cb14-23">f.tight_layout()</span></code></pre></div></div>
</details>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://ecntu.com/posts/understanding-bn/index_files/figure-html/cell-14-output-1.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
</div>
<p>Although our scales differ somewhat from the paper’s, we observe the same qualitative behavior: unnormalized networks reduce the loss only for small steps, while normalized ones improve over a much wider range of learning rates, as in the paper.</p>
</section>
<section id="fig-5" class="level3">
<h3 class="anchored" data-anchor-id="fig-5">Fig 5</h3>
<p>Figures 5 and 6 explore the behavior of networks at initialization. Figure 5 displays the channel means and variances of the network as a function of depth. We initialize <img src="https://latex.codecogs.com/png.latex?10"> networks and use forward hooks to log each channel’s mean and standard deviation.</p>
<div id="cell-33" class="cell" data-execution_count="45">
<details class="code-fold">
<summary>Log activation stats</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb15" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb15-1">df <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> []</span>
<span id="cb15-2"></span>
<span id="cb15-3"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> log_activation_stats(layer_name, key):</span>
<span id="cb15-4">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> hook(module, <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">input</span>, output):</span>
<span id="cb15-5">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">with</span> torch.no_grad():</span>
<span id="cb15-6">            df.append({</span>
<span id="cb15-7">                <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'bn'</span>: key,</span>
<span id="cb15-8">                <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'layer'</span>: layer_name,</span>
<span id="cb15-9">                <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'mean'</span>: output[:, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, :, :].mean().<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">abs</span>().item(),</span>
<span id="cb15-10">                <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'std'</span>: output[:, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, :, :].std().<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">abs</span>().item()</span>
<span id="cb15-11">            })</span>
<span id="cb15-12">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> hook</span>
<span id="cb15-13"></span>
<span id="cb15-14"></span>
<span id="cb15-15"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> _ <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span>):</span>
<span id="cb15-16"></span>
<span id="cb15-17">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> bn <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> [<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>, <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>]:</span>
<span id="cb15-18"></span>
<span id="cb15-19">        model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> models.resnet101(num_classes <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span>)</span>
<span id="cb15-20">        model.to(device)</span>
<span id="cb15-21">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">not</span> bn: disable_bn(model)</span>
<span id="cb15-22">        model.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">apply</span>(init_net)</span>
<span id="cb15-23"></span>
<span id="cb15-24">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Layers to log activations from</span></span>
<span id="cb15-25">        log_layers <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [] <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># (n, name, layer)</span></span>
<span id="cb15-26">        n <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span></span>
<span id="cb15-27">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> name, layer <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> model.named_modules():</span>
<span id="cb15-28">            <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'conv'</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> name:</span>
<span id="cb15-29">                n <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span></span>
<span id="cb15-30">                <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> n <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">list</span>(<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">5</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">101</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">12</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">12</span>)):</span>
<span id="cb15-31">                    log_layers.append((n, name, layer))</span>
<span id="cb15-32"></span>
<span id="cb15-33">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> n, _, layer <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> log_layers: layer.register_forward_hook(log_activation_stats(n, <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">str</span>(bn)))</span>
<span id="cb15-34">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> images, _ <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> val_loader: model(images.to(device))</span>
<span id="cb15-35"></span>
<span id="cb15-36">df <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> pd.DataFrame(df)</span></code></pre></div></div>
</details>
</div>
<div id="cell-34" class="cell" data-execution_count="46">
<details class="code-fold">
<summary>Code</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb16" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb16-1">f, (ax1, ax2) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> plt.subplots(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, figsize <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">9</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>))</span>
<span id="cb16-2"></span>
<span id="cb16-3">sns.lineplot(data <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> df, x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'layer'</span>, y <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'mean'</span>, hue <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'bn'</span>, ax <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> ax1)</span>
<span id="cb16-4">ax1.set_yscale(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'log'</span>)</span>
<span id="cb16-5"></span>
<span id="cb16-6">sns.lineplot(data <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> df, x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'layer'</span>, y <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'std'</span>, hue <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'bn'</span>, ax <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> ax2)</span>
<span id="cb16-7">ax2.set_yscale(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'log'</span>)</span>
<span id="cb16-8"></span>
<span id="cb16-9">f.tight_layout()</span></code></pre></div></div>
</details>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://ecntu.com/posts/understanding-bn/index_files/figure-html/cell-16-output-1.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
</div>
<p>We observe, consistent with the findings in the paper, that activation means and standard deviations increase almost exponentially in non-normalized networks, whereas they remain nearly constant in normalized networks.</p>
</section>
<section id="fig-6" class="level3">
<h3 class="anchored" data-anchor-id="fig-6">Fig 6</h3>
<p>The large activations in the final layers of unnormalized networks in the previous figure suggest that the network is biased towards a particular class. The authors investigate whether this is the case by looking at the gradients of the final (output) layer across the images in a mini-batch and the classes.</p>
<p>Note: Don’t confuse this with the last fully connected layer of the network. We are looking at the gradients of the output logits themselves. We need to use <code>retain_grad</code> on the output (non-leaf node) to calculate its gradient on the backward pass.</p>
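As a minimal standalone sketch (separate from the notebook's model) of why `retain_grad` is needed, consider a tiny computation where the intermediate tensor is a non-leaf node:

```python
import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)  # leaf: .grad populated by default
out = x * 3                                       # non-leaf: .grad stays None unless retained
out.retain_grad()                                 # ask autograd to keep out's gradient
loss = out.sum()
loss.backward()

print(out.grad)  # tensor([1., 1.]) -- d(loss)/d(out)
print(x.grad)    # tensor([3., 3.]) -- d(loss)/d(x)
```

Without the `retain_grad()` call, `out.grad` would be `None` after the backward pass, since PyTorch frees intermediate gradients by default.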
<div id="cell-37" class="cell" data-execution_count="69">
<details class="code-fold">
<summary>Generate heatmaps</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb17" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb17-1">f, (ax1, ax2) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> plt.subplots(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, figsize <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">8</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>), sharey <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>, sharex <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb17-2"></span>
<span id="cb17-3">images, labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">next</span>(<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">iter</span>(val_loader))</span>
<span id="cb17-4">images, labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> images.to(device), labels.to(device)</span>
<span id="cb17-5"></span>
<span id="cb17-6"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> bn, ax <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">zip</span>([<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>, <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>], [ax1, ax2]):</span>
<span id="cb17-7"></span>
<span id="cb17-8">    model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> models.resnet101(num_classes <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span>)</span>
<span id="cb17-9">    model.to(device)</span>
<span id="cb17-10">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">not</span> bn: disable_bn(model)</span>
<span id="cb17-11">    model.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">apply</span>(init_net)</span>
<span id="cb17-12"></span>
<span id="cb17-13">    out <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> model(images)</span>
<span id="cb17-14">    out.retain_grad()</span>
<span id="cb17-15">    loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.CrossEntropyLoss()(out, labels)</span>
<span id="cb17-16">    loss.backward()</span>
<span id="cb17-17"></span>
<span id="cb17-18">    ax <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> sns.heatmap(out.grad.cpu().detach().numpy(), cmap <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'viridis'</span>, ax <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> ax)</span>
<span id="cb17-19">    ax.set_xticks([])<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">;</span> ax.set_yticks([<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">40</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">80</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">120</span>])<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">;</span> ax.set_yticklabels([<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">40</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">80</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">120</span>])</span>
<span id="cb17-20">    ax.set_xlabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Classes'</span>)</span>
<span id="cb17-21">    ax.set_ylabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Images in batch'</span> <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">not</span> bn <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">else</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">''</span>)</span>
<span id="cb17-22">    ax.set_title(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'With BN'</span> <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> bn <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">else</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Without BN'</span>)</span>
<span id="cb17-23"></span>
<span id="cb17-24">f.tight_layout()</span></code></pre></div></div>
</details>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://ecntu.com/posts/understanding-bn/index_files/figure-html/cell-17-output-1.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
</div>
<p>We observe essentially the same results as the paper:</p>
<blockquote class="blockquote">
<p>“A yellow entry indicates that the gradient is positive, and the step along the negative gradient would decrease the prediction strength of this class for this particular image. A dark blue entry indicates a negative gradient, indicating that this particular class prediction should be strengthened. Each row contains one dark blue entry, which corresponds to the true class of this particular image (as initially all predictions are arbitrary). A striking observation is the distinctly yellow column in the left heatmap (network without BN). This indicates that after initialization the network tends to almost always predict the same (typically wrong) class, which is then corrected with a strong gradient update. In contrast, the network with BN does not exhibit the same behavior, instead positive gradients are distributed throughout all classes.”</p>
</blockquote>
<p>Running the above code multiple times, however, sometimes yields two or three yellow columns. We suspect this is due to variation between mini-batches or to initialization randomness. Below, we log and average the gradients over a whole epoch and find much more consistent behavior.</p>
<div id="cell-40" class="cell" data-execution_count="370">
<details class="code-fold">
<summary>Code</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb18" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb18-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> fig6(init_func <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> init_net):</span>
<span id="cb18-2"></span>
<span id="cb18-3">    f, (ax1, ax2) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> plt.subplots(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, figsize <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">8</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>), sharex <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>, sharey <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb18-4"></span>
<span id="cb18-5">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> bn, ax <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">zip</span>([<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>, <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>], [ax1, ax2]):</span>
<span id="cb18-6"></span>
<span id="cb18-7">        model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> models.resnet101(num_classes <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span>)</span>
<span id="cb18-8">        model.to(device)</span>
<span id="cb18-9">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">not</span> bn: disable_bn(model)</span>
<span id="cb18-10">        model.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">apply</span>(init_func)</span>
<span id="cb18-11"></span>
<span id="cb18-12">        avg_grad <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.zeros((<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">128</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span>), device <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> device)</span>
<span id="cb18-13"></span>
<span id="cb18-14">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> images, labels <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> val_loader:</span>
<span id="cb18-15">            images, labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> images.to(device), labels.to(device)</span>
<span id="cb18-16">            out <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> model(images)</span>
<span id="cb18-17">            out.retain_grad()</span>
<span id="cb18-18">            loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.CrossEntropyLoss()(out, labels)</span>
<span id="cb18-19">            loss.backward()</span>
<span id="cb18-20">            <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> out.grad.shape <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> avg_grad.shape: avg_grad <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+=</span> out.grad</span>
<span id="cb18-21">        </span>
<span id="cb18-22">        avg_grad <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(val_loader)</span>
<span id="cb18-23"></span>
<span id="cb18-24">        ax <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> sns.heatmap(avg_grad.cpu().detach().numpy(), cmap <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'viridis'</span>, ax <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> ax)</span>
<span id="cb18-25">        ax.set_xlabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Classes'</span>)</span>
<span id="cb18-26">        ax.set_ylabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Images in batch'</span>)</span>
<span id="cb18-27">        ax.set_title(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'With BN'</span> <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> bn <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">else</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Without BN'</span>)</span>
<span id="cb18-28"></span>
<span id="cb18-29">    f.tight_layout()</span>
<span id="cb18-30"></span>
<span id="cb18-31">fig6()</span></code></pre></div></div>
</details>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://ecntu.com/posts/understanding-bn/index_files/figure-html/cell-18-output-1.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
</div>
<p>And that’s about it.</p>
</section>
<section id="what-i-learned-practiced" class="level3">
<h3 class="anchored" data-anchor-id="what-i-learned-practiced">What I learned / practiced</h3>
<p>I gained a better understanding of, and intuition for, why Batch Normalization (BN) works. More importantly, I got comfortable with PyTorch and with debugging training runs.</p>
<p>PyTorch-specific:</p>
<ul>
<li>Basics of image augmentation: basically use <a href="https://pytorch.org/vision/stable/transforms.html">transforms</a> and compose them.</li>
<li>Learning rate schedulers: they exist, are really useful, and PyTorch has a <a href="https://pytorch.org/docs/stable/optim.html">good assortment of them</a>.</li>
<li><code>state_dict</code> preserves not only the optimizer’s param groups and args (learning rates, etc.) but also its momentum buffers.</li>
<li><a href="https://blog.paperspace.com/pytorch-hooks-gradient-clipping-debugging/">hooks</a> as useful debugging and visualization tools.</li>
<li><code>retain_grad</code> is required to get gradients of non-leaf nodes like the output logits.</li>
</ul>
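A quick standalone sketch (not from the post's notebook) of the optimizer <code>state_dict</code> point: after a single step, SGD with momentum carries both the param-group args and per-parameter momentum buffers.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Take one step so the momentum buffers get created
loss = model(torch.randn(8, 4)).sum()
loss.backward()
opt.step()

sd = opt.state_dict()
print(sd.keys())                    # dict_keys(['state', 'param_groups'])
print(sd['param_groups'][0]['lr'])  # 0.1
# Per-parameter momentum buffers live under 'state'
print(any('momentum_buffer' in s for s in sd['state'].values()))  # True
```

This is why checkpointing the optimizer alongside the model matters: restoring only the model weights silently resets momentum.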
<p>For the large training runs, I also experimented with <a href="https://jarvislabs.ai">jarvislabs.ai</a> as a provider. The in-browser notebooks, VS Code, and direct SSH/FTP access were pretty nice, though I could not work out some funkiness with VS Code remote windows. I used</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb19" style="background: #f1f3f5;"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb19-1"><span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">nohup</span> jupyter nbconvert <span class="at" style="color: #657422;
background-color: null;
font-style: inherit;">--execute</span> <span class="at" style="color: #657422;
background-color: null;
font-style: inherit;">--to</span> notebook <span class="at" style="color: #657422;
background-color: null;
font-style: inherit;">--inplace</span> <span class="at" style="color: #657422;
background-color: null;
font-style: inherit;">--allow-errors</span> main.ipynb <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">&amp;</span></span></code></pre></div></div>
<p>to run the notebook, write results, and be able to close the Jupyter / VS Code tab.</p>


</section>

 ]]></description>
  <category>deep learning</category>
  <category>paper</category>
  <guid>https://ecntu.com/posts/understanding-bn/</guid>
  <pubDate>Wed, 17 Jul 2024 04:00:00 GMT</pubDate>
</item>
<item>
  <title>Deep Learning is Robust to Massive Label Noise</title>
  <link>https://ecntu.com/posts/dl-massive-label-noise/</link>
  <description><![CDATA[ 





<p>The paper shows that neural networks can keep generalizing when large numbers of (non-adversarially) incorrectly labeled examples are added to datasets (MNIST, CIFAR, and ImageNet). It also appears that larger networks are more robust and that higher noise levels lead to lower optimal (fixed) learning rates.</p>
<p>We’ll focus on the uniform label noise experiment and attempt to reproduce Figure 1:</p>
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<p><img src="https://ecntu.com/posts/dl-massive-label-noise/images/fig1.png" class="img-fluid figure-img" width="300"></p>
<figcaption>Figure 1. As we increase the amount of noise in the dataset the performance drops. However, note that even when there are 100 noisy labels <em>per</em> clean label performance is still acceptable. For example, the Convnet still achieves 91% accuracy.</figcaption>
</figure>
</div>
<p>Note: As far as I can tell the paper has no accompanying code, so I’ll be filling in the details to the best of my ability.</p>
<div id="cell-2" class="cell" data-execution_count="1">
<details class="code-fold">
<summary>Imports and model evaluation function</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb1" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> torch</span>
<span id="cb1-2"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> torch.nn <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> nn</span>
<span id="cb1-3"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> torch.optim <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> optim</span>
<span id="cb1-4"></span>
<span id="cb1-5"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> torchvision <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> datasets, transforms</span>
<span id="cb1-6"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> torch.utils.data <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> DataLoader, Dataset</span>
<span id="cb1-7"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> PIL <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> Image</span>
<span id="cb1-8"></span>
<span id="cb1-9"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> numpy <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> np</span>
<span id="cb1-10"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> pandas <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> pd</span>
<span id="cb1-11"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> seaborn <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> sns</span>
<span id="cb1-12"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> matplotlib.pyplot <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> plt</span>
<span id="cb1-13"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> os, itertools, time</span>
<span id="cb1-14"></span>
<span id="cb1-15">os.makedirs(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'logs'</span>, exist_ok <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb1-16">os.makedirs(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'models'</span>, exist_ok <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb1-17"></span>
<span id="cb1-18">seed <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">42</span></span>
<span id="cb1-19">np.random.seed(seed)</span>
<span id="cb1-20">torch.manual_seed(seed)</span>
<span id="cb1-21"></span>
<span id="cb1-22">device <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.device(</span>
<span id="cb1-23">    <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'cuda'</span> <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> torch.cuda.is_available() <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">else</span></span>
<span id="cb1-24">    (<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'mps'</span> <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> torch.backends.mps.is_available() <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">else</span></span>
<span id="cb1-25">    <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'cpu'</span>)</span>
<span id="cb1-26">)</span>
<span id="cb1-27">device <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'cpu'</span> <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># faster for the small models we are using</span></span>
<span id="cb1-28"></span>
<span id="cb1-29"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> eval_model(model, test, criterion <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.CrossEntropyLoss()):</span>
<span id="cb1-30">    model.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">eval</span>()</span>
<span id="cb1-31">    correct, loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.0</span></span>
<span id="cb1-32">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">with</span> torch.no_grad():</span>
<span id="cb1-33">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> images, labels <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> test:</span>
<span id="cb1-34">            images, labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> images.to(device), labels.to(device)</span>
<span id="cb1-35">            _, pred <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">max</span>(model(images), <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)</span>
<span id="cb1-36">            correct <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+=</span> (pred <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> labels).<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span>().<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">sum</span>().item()</span>
<span id="cb1-37">            loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+=</span> criterion(model(images), labels).item()</span>
<span id="cb1-38">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(test.dataset), correct <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(test.dataset)</span></code></pre></div></div>
</details>
</div>
<p>To generate uniform label noise, the paper augments the original dataset with <img src="https://latex.codecogs.com/png.latex?%5Calpha"> additional <img src="https://latex.codecogs.com/png.latex?(X_i,%20Y')"> pairs per example, where <img src="https://latex.codecogs.com/png.latex?Y'"> is a class sampled uniformly at random with replacement.</p>
<p>To minimize disk use I opted for a custom dataset that wraps the original. PyTorch only requires that we override <code>__len__</code> and <code>__getitem__</code>. The length is simply the original size plus the number of noisy labels. When queried for data, the noisy pairs for each original example appear immediately after it. For example, when <img src="https://latex.codecogs.com/png.latex?%5Calpha%20=%202">:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0A(X_1,%20Y_1),%20(X_1,%20Y'),%20(X_1,%20Y'),%20(X_2,%20Y_2),%20...%0A"></p>
<p>Note that to guarantee noisy labels are consistent between epochs, i.e.&nbsp;<code>data[1]</code> returns the same class when called again, we can’t sample the labels at query time. To avoid storing all the randomly sampled labels (<img src="https://latex.codecogs.com/png.latex?60,%20000%20%5Ctimes%20100"> in the worst case), we simply return a shifted index’s label. We can do this with MNIST because it’s reasonably class-balanced and shuffled.</p>
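The index arithmetic behind this layout can be sketched in isolation (the helper name and the flag are illustrative, not from the post): each original index <code>i</code> owns a block of <code>alpha + 1</code> consecutive positions, the first clean and the rest noisy.

```python
# For alpha = 2, original index i owns positions 3i, 3i+1, 3i+2:
# position 3i is the clean pair; the other two reuse X_i with a noisy label.
alpha = 2

def base_and_noisy(idx, alpha):
    """Map a wrapped-dataset index to (original index, is_noisy)."""
    return idx // (alpha + 1), idx % (alpha + 1) != 0

layout = [base_and_noisy(i, alpha) for i in range(6)]
print(layout)
# [(0, False), (0, True), (0, True), (1, False), (1, True), (1, True)]
```

Because the mapping is a pure function of the index, every epoch sees the same clean/noisy assignment without storing any extra labels.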
<div id="cell-4" class="cell" data-execution_count="3">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb2" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">class</span> NoisyLabelDataset(Dataset):</span>
<span id="cb2-2">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">"""Adds alpha noisy labels per original example"""</span></span>
<span id="cb2-3">    </span>
<span id="cb2-4">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, dataset, alpha):</span>
<span id="cb2-5">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.dataset <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> dataset</span>
<span id="cb2-6">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.alpha <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> alpha</span>
<span id="cb2-7">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.shift <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.random.randint(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(dataset))</span>
<span id="cb2-8"></span>
<span id="cb2-9">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__len__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>):</span>
<span id="cb2-10">        n <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.dataset)</span>
<span id="cb2-11">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> n <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> (<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.alpha <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> n)</span>
<span id="cb2-12">    </span>
<span id="cb2-13">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__getitem__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, idx):</span>
<span id="cb2-14">        x, y <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.dataset[idx <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">//</span> (<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.alpha <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)]</span>
<span id="cb2-15">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> idx <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span> (<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.alpha <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">!=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>:</span>
<span id="cb2-16">            y <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.dataset[(idx <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.shift) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.dataset)][<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>]</span>
<span id="cb2-17">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> x, y</span></code></pre></div></div>
</div>
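<p>To make the index arithmetic above concrete, here is a toy walkthrough of how each original example expands into one clean slot plus <code>alpha</code> noisy-label slots (a standalone sketch with hypothetical sizes, mirroring the <code>//</code> and <code>%</code> logic in <code>__getitem__</code>):</p>

```python
# Toy walkthrough of the NoisyLabelDataset index arithmetic
# (hypothetical dataset of n = 3 examples, alpha = 2 noisy copies each).
alpha, n = 2, 3
mapping = []
for idx in range(n * (1 + alpha)):          # matches __len__: n + alpha * n
    base = idx // (alpha + 1)               # original example this slot maps to
    is_noisy = idx % (alpha + 1) != 0       # True for the alpha shifted-label copies
    mapping.append((base, is_noisy))
print(mapping)
```

<p>Every original index thus appears once with its true label, followed by <code>alpha</code> copies whose labels come from a shifted example.</p>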
<p>Although the paper appears to use only a test set, we also split off a validation set for early stopping.</p>
<div id="cell-6" class="cell" data-execution_count="4">
<details class="code-fold">
<summary>Datasets and loaders</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb3" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1">batch_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">128</span></span>
<span id="cb3-2"></span>
<span id="cb3-3">transform <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> transforms.Compose([transforms.ToTensor(), transforms.Normalize((<span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.5</span>,), (<span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.5</span>,))])</span>
<span id="cb3-4">train_dataset <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> datasets.MNIST(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'data'</span>, download <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>, train <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>,   transform <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> transform)</span>
<span id="cb3-5">test_dataset <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>  datasets.MNIST(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'data'</span>, download <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>, train <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>,  transform <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> transform)</span>
<span id="cb3-6"></span>
<span id="cb3-7">noisy_train_dataset <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> NoisyLabelDataset(train_dataset, alpha <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">5</span>)</span>
<span id="cb3-8">val_dataset, test_dataset <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.utils.data.random_split(test_dataset, (<span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.2</span>, <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.8</span>), generator <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.Generator().manual_seed(seed))</span>
<span id="cb3-9"></span>
<span id="cb3-10">train_loader <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> DataLoader(noisy_train_dataset, batch_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> batch_size, shuffle <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb3-11">val_loader <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> DataLoader(val_dataset, batch_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> batch_size, shuffle <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>)</span>
<span id="cb3-12">test_loader <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> DataLoader(test_dataset, batch_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> batch_size, shuffle <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>)</span></code></pre></div></div>
</details>
</div>
<p>Our training loop is pretty standard. We use the <code>Adadelta</code> optimizer as per the paper. Although the paper does not mention it, we assume early stopping was used: stop training once the validation accuracy has not improved for <code>patience</code> epochs, and return the model with the highest validation accuracy.</p>
<div id="cell-8" class="cell" data-execution_count="5">
<details class="code-fold">
<summary>Training function</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb4" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> train(model, train_loader, val_loader, lr <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.01</span>, patience <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>, max_epochs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">100</span>, verbose <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>):</span>
<span id="cb4-2">    </span>
<span id="cb4-3">    model.to(device)</span>
<span id="cb4-4">    criterion <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.CrossEntropyLoss()</span>
<span id="cb4-5">    optimizer <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> optim.Adadelta(model.parameters(), lr <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> lr)</span>
<span id="cb4-6"></span>
<span id="cb4-7">    log <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> {<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'train_loss'</span>: [], <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'val_loss'</span>: [], <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'val_acc'</span>: []}</span>
<span id="cb4-8"></span>
<span id="cb4-9">    best_val_acc <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span>(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'-inf'</span>)</span>
<span id="cb4-10">    best_model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span></span>
<span id="cb4-11"></span>
<span id="cb4-12">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> epoch <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(max_epochs):</span>
<span id="cb4-13"></span>
<span id="cb4-14">        model.train()</span>
<span id="cb4-15">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> images, labels <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> train_loader:</span>
<span id="cb4-16">            images, labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> images.to(device), labels.to(device)</span>
<span id="cb4-17">            optimizer.zero_grad()</span>
<span id="cb4-18">            loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> criterion(model(images), labels)</span>
<span id="cb4-19">            loss.backward()</span>
<span id="cb4-20">            optimizer.step()</span>
<span id="cb4-21"></span>
<span id="cb4-22">        val_loss, val_acc <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> eval_model(model, val_loader)</span>
<span id="cb4-23">        log[<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'train_loss'</span>].append(loss.item())</span>
<span id="cb4-24">        log[<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'val_loss'</span>].append(val_loss)</span>
<span id="cb4-25">        log[<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'val_acc'</span>].append(val_acc)</span>
<span id="cb4-26"></span>
<span id="cb4-27">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> verbose: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">', '</span>.join([<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'Epoch </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>epoch <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">'</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> [<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>k<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">: </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>v[<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>]<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">:.4f}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">'</span> <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> k, v <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> log.items()]))</span>
<span id="cb4-28"></span>
<span id="cb4-29">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> val_acc <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&lt;</span> best_val_acc:</span>
<span id="cb4-30">            best_val_acc <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> val_acc</span>
<span id="cb4-31">            best_model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> model.state_dict()</span>
<span id="cb4-32"></span>
<span id="cb4-33">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Early stopping: stop if val acc has not increased in the last `patience` epochs</span></span>
<span id="cb4-34">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> epoch <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&gt;</span> patience <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">and</span> val_acc <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&lt;=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">max</span>(log[<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'val_acc'</span>][<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span>patience<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>:<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>]): <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">break</span> </span>
<span id="cb4-35">    </span>
<span id="cb4-36">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> best_model: model.load_state_dict(best_model)</span>
<span id="cb4-37">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> model, log</span></code></pre></div></div>
</details>
</div>
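<p>The slice in the patience check is easy to misread, so here is the same condition in isolation (a standalone sketch; <code>should_stop</code> is a name we introduce for illustration):</p>

```python
# Standalone sketch of the patience check from the training loop:
# stop once the latest accuracy fails to beat the best of the previous
# `patience` epochs (and only after more than `patience` epochs have run).
def should_stop(val_accs, patience=3):
    epoch = len(val_accs) - 1  # 0-indexed, as in the training loop
    return epoch > patience and val_accs[-1] <= max(val_accs[-patience - 1:-1])

print(should_stop([0.1, 0.2, 0.3, 0.4, 0.5]))       # accuracy still rising
print(should_stop([0.5, 0.9, 0.9, 0.9, 0.9, 0.9]))  # plateaued for 4 epochs
```

<p>Note that the slice <code>[-patience-1:-1]</code> deliberately excludes the current epoch's accuracy, which was appended to the log just before the check.</p>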
<p>We train with learning rates <img src="https://latex.codecogs.com/png.latex?%5C%7B0.01,%200.05,%200.1,%200.5%5C%7D"> as per the paper, and <img src="https://latex.codecogs.com/png.latex?%5Calpha%20%5Cin%20%5C%7B0,%2025,%2050%5C%7D"> to save some compute; the trend should still be clear. Below we define our perceptron, MLPs with 1, 2, and 4 layers, and a 4-layer ConvNet. Since the paper does not specify hidden dims, activations, or the ConvNet architecture, we choose them ourselves.</p>
<div id="cell-10" class="cell" data-execution_count="6">
<details class="code-fold">
<summary>Hyperparam and model definitions</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb5" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1">learning_rates <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [<span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.01</span>, <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.05</span>, <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.1</span>, <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.5</span>]</span>
<span id="cb5-2">alphas <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">75</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">25</span>)</span>
<span id="cb5-3"></span>
<span id="cb5-4">lin_relu <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">lambda</span> n_in, n_out: nn.Sequential(nn.Linear(n_in, n_out), nn.ReLU())</span>
<span id="cb5-5">models <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> {</span>
<span id="cb5-6">    <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'perceptron'</span>:nn.Sequential(nn.Flatten(), nn.Linear(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">28</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">28</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span>)),</span>
<span id="cb5-7">    <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'MLP1'</span>:nn.Sequential(nn.Flatten(), lin_relu(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">28</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">28</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">256</span>), nn.Linear(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">256</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span>)),</span>
<span id="cb5-8">    <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'MLP2'</span>:nn.Sequential(nn.Flatten(), lin_relu(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">28</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">28</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">256</span>), lin_relu(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">256</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">128</span>), nn.Linear(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">128</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span>)),</span>
<span id="cb5-9">    <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'MLP4'</span>:nn.Sequential(nn.Flatten(), lin_relu(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">28</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">28</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">256</span>), lin_relu(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">256</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">128</span>), lin_relu(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">128</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">64</span>), nn.Linear(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">64</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span>)),</span>
<span id="cb5-10">    <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Conv4'</span>:nn.Sequential(</span>
<span id="cb5-11">        nn.Conv2d(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">32</span>, kernel_size<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>, stride<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, padding<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>),</span>
<span id="cb5-12">        nn.ReLU(),</span>
<span id="cb5-13">        nn.MaxPool2d(kernel_size<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, stride<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, padding<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>),</span>
<span id="cb5-14">        </span>
<span id="cb5-15">        nn.Conv2d(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">32</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">64</span>, kernel_size<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>, stride<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, padding<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>),</span>
<span id="cb5-16">        nn.ReLU(),</span>
<span id="cb5-17">        nn.MaxPool2d(kernel_size<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, stride<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, padding<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>),</span>
<span id="cb5-18">        </span>
<span id="cb5-19">        nn.Conv2d(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">64</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">128</span>, kernel_size<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>, stride<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, padding<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>),</span>
<span id="cb5-20">        nn.ReLU(),</span>
<span id="cb5-21">        nn.MaxPool2d(kernel_size<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, stride<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, padding<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>),</span>
<span id="cb5-22">        </span>
<span id="cb5-23">        nn.Conv2d(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">128</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">256</span>, kernel_size<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>, stride<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, padding<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>),</span>
<span id="cb5-24">        nn.ReLU(),</span>
<span id="cb5-25">        nn.MaxPool2d(kernel_size<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, stride<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, padding<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>),</span>
<span id="cb5-26">        </span>
<span id="cb5-27">        nn.Flatten(),</span>
<span id="cb5-28">        nn.Linear(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">256</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1000</span>),</span>
<span id="cb5-29">        nn.ReLU(),</span>
<span id="cb5-30">        nn.Linear(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1000</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span>)</span>
<span id="cb5-31">    )</span>
<span id="cb5-32">}</span></code></pre></div></div>
</details>
</div>
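<p>The <code>256 * 1 * 1</code> input size of Conv4's first linear layer follows from four stride-2 max-pools applied to a 28&#215;28 MNIST image; a quick arithmetic check:</p>

```python
# Each MaxPool2d(kernel_size=2, stride=2, padding=0) floors the spatial side in half.
side = 28
for _ in range(4):
    side //= 2  # 28 -> 14 -> 7 -> 3 -> 1
print(side)
```

<p>With a final 1&#215;1 spatial map and 256 channels, <code>Flatten</code> yields exactly 256 features.</p>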
<div id="cell-11" class="cell" data-execution_count="7">
<details class="code-fold">
<summary>Train and save models</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb6" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> alpha, (name, model), lr <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> itertools.product(alphas, models.items(), learning_rates):</span>
<span id="cb6-2"></span>
<span id="cb6-3">    noisy_train_dataset <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> NoisyLabelDataset(train_dataset, alpha <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> alpha)</span>
<span id="cb6-4">    train_loader <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> DataLoader(noisy_train_dataset, batch_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> batch_size, shuffle <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb6-5">    </span>
<span id="cb6-6">    start <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> time.time()</span>
<span id="cb6-7">    model, log <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> train(model, train_loader, val_loader, lr <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> lr, verbose <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb6-8">    test_loss, test_acc <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> eval_model(model, test_loader)</span>
<span id="cb6-9">    log[<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'test_loss'</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> test_loss</span>
<span id="cb6-10">    log[<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'test_acc'</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> test_acc</span>
<span id="cb6-11"></span>
<span id="cb6-12">    <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>name<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;"> - alpha: </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>alpha<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">, lr: </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>lr<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">, test acc: </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>test_acc<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">:.4f}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">, took: </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>time<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">.</span>time() <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> start<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">:.2f}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">s'</span>)</span>
<span id="cb6-13">    torch.save(log, <span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'logs/</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>name<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">_</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>alpha<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">_</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>lr<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">.pt'</span>)</span>
<span id="cb6-14">    torch.save(model, <span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'models/</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>name<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">_</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>alpha<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">_</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>lr<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">.pt'</span>)</span></code></pre></div></div>
</details>
</div>
<p>Finally, we plot the accuracies on both the validation and test sets:</p>
<div id="cell-13" class="cell" data-execution_count="44">
<details class="code-fold">
<summary>Plot results</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb7" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Load results into dataframe</span></span>
<span id="cb7-2">df <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> []</span>
<span id="cb7-3"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> fname <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> os.listdir(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'logs'</span>):</span>
<span id="cb7-4">    name, alpha, lr <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> fname.split(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'_'</span>)</span>
<span id="cb7-5">    lr <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span>(lr.replace(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'.pt'</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">''</span>))</span>
<span id="cb7-6">    alpha <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>(alpha)</span>
<span id="cb7-7">    </span>
<span id="cb7-8">    log <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.load(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'logs/'</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> fname)</span>
<span id="cb7-9">    tmp <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> {}</span>
<span id="cb7-10">    tmp[<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Model'</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> name</span>
<span id="cb7-11">    tmp[<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Alpha'</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> alpha</span>
<span id="cb7-12">    tmp[<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'lr'</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> lr</span>
<span id="cb7-13">    tmp[<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Prediction Accuracy'</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> log[<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'test_acc'</span>]</span>
<span id="cb7-14">    tmp[<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Validation Accuracy'</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">max</span>(log[<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'val_acc'</span>])</span>
<span id="cb7-15">    df.append(tmp)</span>
<span id="cb7-16">df <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> pd.DataFrame(df)</span>
<span id="cb7-17"></span>
<span id="cb7-18">hue_order <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'perceptron'</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'MLP1'</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'MLP2'</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'MLP4'</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Conv4'</span>]</span>
<span id="cb7-19"></span>
<span id="cb7-20">f, (ax1, ax2) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> plt.subplots(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, figsize <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">8</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>), sharey <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>, sharex <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb7-21"></span>
<span id="cb7-22">sns.lineplot(</span>
<span id="cb7-23">    df.groupby([<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Model'</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Alpha'</span>]).<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">max</span>(), x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Alpha'</span>, y <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Prediction Accuracy'</span>,</span>
<span id="cb7-24">    hue <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Model'</span>, hue_order <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> hue_order, ax <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> ax1</span>
<span id="cb7-25">)</span>
<span id="cb7-26">ax1.set_title(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'On Test Set'</span>)</span>
<span id="cb7-27">ax1.grid()</span>
<span id="cb7-28"></span>
<span id="cb7-29">sns.lineplot(</span>
<span id="cb7-30">    df.groupby([<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Model'</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Alpha'</span>]).<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">max</span>(), x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Alpha'</span>, y <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Validation Accuracy'</span>,</span>
<span id="cb7-31">    hue <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Model'</span>, hue_order <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> hue_order, ax <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> ax2, legend <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span></span>
<span id="cb7-32">)</span>
<span id="cb7-33">ax2.set_title(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'On Validation Set'</span>)</span>
<span id="cb7-34">ax2.grid()</span>
<span id="cb7-35">plt.tight_layout()</span></code></pre></div></div>
</details>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://ecntu.com/posts/dl-massive-label-noise/index_files/figure-html/cell-8-output-1.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
</div>
<p>We observe that the general trends hold: as we add noise, performance drops, and larger models tend to be more robust.</p>
<p>However, our models generally perform worse than the paper’s. At <img src="https://latex.codecogs.com/png.latex?%5Calpha%20=%2050"> most of our models have accuracies below <img src="https://latex.codecogs.com/png.latex?60%5C%25">, whereas the paper’s are in the high <img src="https://latex.codecogs.com/png.latex?80%5C%25">s. In addition, our Conv4 model is already below 90% at <img src="https://latex.codecogs.com/png.latex?%5Calpha%20=%20100">, where the paper achieves 91%.</p>
<p>This might be due to differences in training (the use and implementation of early stopping), differences in architecture details, random seeds (we did not average over multiple seeds due to compute constraints), and so on.</p>
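<p>To make the early-stopping point concrete, here is a minimal sketch of the kind of patience-based criterion such a setup typically uses (the patience value and the choice of metric are illustrative assumptions, not necessarily what our <code>train</code> function does):</p>

```python
# Patience-based early stopping (illustrative sketch; the patience value
# and the accuracy metric are assumptions, not our exact implementation).
def should_stop(val_accs, patience=5):
    """Stop when the best validation accuracy is older than `patience` epochs."""
    if len(val_accs) <= patience:
        return False
    # Stop if no entry in the last `patience` epochs matches the best so far.
    return max(val_accs) not in val_accs[-patience:]
```

<p>Small differences here (the patience value, stopping on accuracy vs. loss, restoring the best checkpoint or not) could plausibly account for part of the gap.</p>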
<p>We also observe the trend the paper points out in Section 5, namely that higher noise levels lead to smaller effective batch sizes and thus lower optimal learning rates:</p>
<div id="cell-16" class="cell" data-execution_count="71">
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://ecntu.com/posts/dl-massive-label-noise/index_files/figure-html/cell-9-output-1.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
</div>
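<p>A rough back-of-the-envelope calculation (our own simplification, not the paper’s exact derivation, with a batch size of 256 assumed for illustration) helps see why: with <code>alpha</code> noisy labels per clean label, only about a <code>1 / (1 + alpha)</code> fraction of each batch carries a coherent gradient signal, so a batch behaves like a much smaller clean one:</p>

```python
# Back-of-the-envelope sketch (a simplification, not the paper's formula):
# with alpha noisy labels per clean label, a batch of size B contains on
# average B / (1 + alpha) clean examples, so the coherent gradient signal
# resembles that of a much smaller batch -- favoring lower learning rates.
def effective_batch_size(batch_size, alpha):
    return batch_size / (1 + alpha)

for alpha in (0, 10, 50, 100):
    n = effective_batch_size(256, alpha)
    print(f"alpha={alpha}: ~{n:.1f} clean samples per batch of 256")
```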
<p>A few closing thoughts: although our results did not fully match the paper’s, we still find the robustness to label noise impressive.</p>
<p>The paper studies the effect much further: under non-uniform noise, with different datasets, batch sizes, and more. It is worth a <a href="https://arxiv.org/pdf/1705.10694">read</a>.</p>
<p>I might revisit this notebook to perform further experiments and try to answer some lingering questions:</p>
<ul>
<li>Does within-epoch sample ordering matter? Intuitively, if we place all clean examples before the noisy ones, we would expect worse performance (catastrophic forgetting?).</li>
<li>What effect does early stopping have? We used a clean validation set to decide when to stop, which is not realistic. What happens if we use the loss instead, or no early stopping at all?</li>
</ul>
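<p>For the first question, the epoch ordering could be controlled with a sketch like the following (the boolean <code>is_clean</code> mask is an assumption: our <code>NoisyLabelDataset</code> would need to expose which labels were left uncorrupted):</p>

```python
import random

def clean_first_order(is_clean, seed=0):
    """Return an epoch's sample order with all clean examples first.

    is_clean: boolean mask marking samples whose label was not corrupted
    (hypothetical attribute; our NoisyLabelDataset would need to expose it).
    """
    rng = random.Random(seed)
    clean = [i for i, c in enumerate(is_clean) if c]
    noisy = [i for i, c in enumerate(is_clean) if not c]
    rng.shuffle(clean)  # shuffle within each group so only the
    rng.shuffle(noisy)  # clean/noisy boundary is deterministic
    return clean + noisy
```

<p>Feeding this order to a <code>DataLoader</code> via a custom sampler, and comparing against fully shuffled epochs, would let us test the forgetting hypothesis directly.</p>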



 ]]></description>
  <category>deep learning</category>
  <category>paper</category>
  <guid>https://ecntu.com/posts/dl-massive-label-noise/</guid>
  <pubDate>Tue, 18 Jun 2024 04:00:00 GMT</pubDate>
</item>
</channel>
</rss>
