{ "cells": [ { "cell_type": "markdown", "id": "9e934779-6b12-47ac-92f5-ce8838873651", "metadata": {}, "source": [ "# Getting Started with Flux" ] }, { "cell_type": "markdown", "id": "90d0c237-de52-4d1d-8483-6525467f847e", "metadata": {}, "source": [ "The code below is adapted from the Flux tutorial,\n", "with the updates needed for later versions of Flux." ] }, { "cell_type": "code", "execution_count": 1, "id": "bb462353-e270-419a-a8bd-e546e140a268", "metadata": {}, "outputs": [], "source": [ "using Flux" ] }, { "cell_type": "markdown", "id": "e43a9402-f692-4d0f-9e52-6d963e348d5a", "metadata": {}, "source": [ "## 0. a simple example" ] }, { "cell_type": "markdown", "id": "b44f84ed-7dda-407c-b82d-2a34a18fc0c0", "metadata": {}, "source": [ "The ground truth model is defined by a matrix ``W_truth`` and a vector ``b_truth``.\n", "We aim to recover ``W_truth`` and ``b_truth``\n", "from noisy evaluations of ``ground_truth``." ] }, { "cell_type": "code", "execution_count": 2, "id": "e6327cdf-5610-4298-8e2f-69c1ff1c396d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "2×5 Matrix{Int64}:\n", " 1 2 3 4 5\n", " 5 4 3 2 1" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "W_truth = [1 2 3 4 5;\n", "\t 5 4 3 2 1]" ] }, { "cell_type": "code", "execution_count": 3, "id": "3e8da2f9-ec39-4fc0-ac98-457c2760438b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "2-element Vector{Float64}:\n", " -1.0\n", " 2.0" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "b_truth = [-1.0; 2.0]" ] }, { "cell_type": "code", "execution_count": 4, "id": "75bc17e0-7bc2-485e-ac93-c95d40c0b6cb", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "ground_truth (generic function with 1 method)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ground_truth(x) = W_truth*x .+ b_truth" ] }, { "cell_type": "markdown", "id": "0e4cb87c-ef90-42b9-b2ce-0f3092e1ce90", 
"metadata": {}, "source": [ "## 1. training data and the model" ] }, { "cell_type": "markdown", "id": "11ad7a01-3e16-4452-9ce3-69a820392770", "metadata": {}, "source": [ "We generate the training data from the ground truth and store it as vectors of vectors." ] }, { "cell_type": "markdown", "id": "89c685e2-fd57-4660-8878-f8b5ca3af299", "metadata": {}, "source": [ "The training data consists of ``N`` random input vectors with entries drawn uniformly from [0, 5).\n", "Each output is the ground truth evaluated at the input vector plus Gaussian noise with standard deviation 0.2." ] }, { "cell_type": "code", "execution_count": 5, "id": "890d3539-b88f-4ee8-b16d-faf1591b9066", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "10000" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "N = 10_000" ] }, { "cell_type": "code", "execution_count": 6, "id": "4b669902-d194-4438-9502-8ccb0a039f5c", "metadata": {}, "outputs": [], "source": [ "x_train = [5 .* rand(5) for _ in 1:N];" ] }, { "cell_type": "code", "execution_count": 7, "id": "18682efa-3486-46e8-bb7a-5dab134394f5", "metadata": {}, "outputs": [], "source": [ "y_train = [ground_truth(x) + 0.2 .* randn(2) for x in x_train];" ] }, { "cell_type": "markdown", "id": "babfd9f0-8996-48d2-b36b-1cad49f16897", "metadata": {}, "source": [ "Next we define the model we want to train." ] }, { "cell_type": "code", "execution_count": 8, "id": "6bc589e8-bf50-4998-9766-33b3f4a4b71b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "model (generic function with 1 method)" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model(x) = W*x .+ b" ] }, { "cell_type": "markdown", "id": "67bc2f7b-b447-426e-b76f-a092b8f69f16", "metadata": {}, "source": [ "As a function of ``x``, the model depends on the weight matrix ``W`` and the bias vector ``b``, which are initialized at random in the code cell below.\n", "Because Julia looks up the global names ``W`` and ``b`` only when ``model`` is called, ``model`` can be defined before ``W`` and ``b`` exist."
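, "\n",
"\n",
"Because ``model`` is an affine map, it can also be applied to a whole batch of inputs stacked as the columns of a matrix: ``W*x`` then multiplies every column, and ``.+ b`` broadcasts the bias over the columns.\n",
"The following is a minimal sketch that checks this on four random inputs; the names ``W_demo``, ``b_demo``, ``affine``, ``xs``, ``X``, and ``Y`` are hypothetical, chosen so that running the sketch does not overwrite the notebook's ``W`` and ``b``.\n",
"\n",
"```julia\n",
"W_demo = rand(2, 5)                   # hypothetical weights, same shape as W\n",
"b_demo = rand(2)                      # hypothetical bias, same shape as b\n",
"affine(x) = W_demo*x .+ b_demo        # same form as model\n",
"xs = [5 .* rand(5) for _ in 1:4]      # four inputs, generated like x_train\n",
"X = reduce(hcat, xs)                  # 5×4 matrix, one input per column\n",
"Y = affine(X)                         # 2×4 matrix, one output per column\n",
"@assert size(Y) == (2, 4)\n",
"@assert Y ≈ reduce(hcat, map(affine, xs))   # batch result matches per-vector results\n",
"```\n",
"\n",
"Flux layers such as ``Dense`` use the same convention, with one example per column."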
] }, { "cell_type": "code", "execution_count": 9, "id": "10bd1ea8-d9ac-417f-9f54-460d7e112823", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "2-element Vector{Float64}:\n", " 0.6517795685475002\n", " 0.40564389170977555" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "W = rand(2, 5)\n", "b = rand(2)" ] }, { "cell_type": "markdown", "id": "3f14631e-f4f4-4ef6-8610-10229a7f3454", "metadata": {}, "source": [ "## 2. the loss function" ] }, { "cell_type": "markdown", "id": "3370267b-29df-46a3-96ff-2d4f400f68f0", "metadata": {}, "source": [ "To measure the performance, we define the loss as the sum of squared errors between the target ``y`` and the prediction ``M(x)``.\n", "For the current version of ``Flux.train!``, the loss function must take the model as its first argument." ] }, { "cell_type": "code", "execution_count": 10, "id": "7e2272e4-b03a-4261-aefd-34da60c35801", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "loss" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "\"\"\"\n", "    loss(M, x, y)\n", "\n", "Return the sum of squared errors of the model `M`\n", "on the input `x` with respect to the target output `y`.\n", "\"\"\"\n", "function loss(M, x, y)\n", "    yy = M(x)                 # prediction of the model\n", "    sum((y .- yy) .^ 2)       # sum of squared errors\n", "end" ] }, { "cell_type": "markdown", "id": "e2ccf0fd-145d-4a9f-bb9a-e119d82c9d3e", "metadata": {}, "source": [ "As a sanity check, we evaluate the loss at the first training example.\n", "The value is large because ``W`` and ``b`` are still random." ] }, { "cell_type": "code", "execution_count": 11, "id": "d1140277-6020-4f34-8f98-f0cf8d98e652", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "3761.4469837331026" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loss(model, x_train[1], y_train[1])" ] }, { "cell_type": "markdown", "id": "2a00b9ab-db36-45d1-97f2-74f99e70ec1e", "metadata": {}, "source": [ "## 3. selecting the optimizer" ] },
{ "cell_type": "markdown", "id": "3befe7f8-b886-4555-a5af-e102c6f71ada", "metadata": {}, "source": [ "For the optimizer, we choose plain gradient descent, ``Descent``, whose argument is the step size (learning rate)." ] }, { "cell_type": "code", "execution_count": 12, "id": "392bd86c-6b47-42cd-be41-5a5b7a4a78af", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Descent(0.01)" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "opt = Descent(0.01)" ] }, { "cell_type": "markdown", "id": "12986e3e-05ea-476f-a9e1-9b534b93230e", "metadata": {}, "source": [ "## 4. setup of the training data" ] }, { "cell_type": "code", "execution_count": 13, "id": "79823b5a-f1be-4b70-8f9e-ced9336baad2", "metadata": {}, "outputs": [], "source": [ "train_data = zip(x_train, y_train);" ] }, { "cell_type": "markdown", "id": "73893977-3316-4ea5-840d-d484ad3b5997", "metadata": {}, "source": [ "## 5. train" ] }, { "cell_type": "markdown", "id": "3c2a8ae8-47eb-45c9-9ae3-57975446205a", "metadata": {}, "source": [ "Before we can train, we must wrap ``W`` and ``b`` in a Flux model.\n", "Our network has a single ``Dense`` layer built from ``W`` and ``b``; the layer stores the array ``W`` itself rather than a copy, so training updates ``W`` in place."
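, "\n",
"\n",
"A minimal sketch of what ``Dense(W, b)`` computes, assuming the field name ``weight`` that appears in the output of ``Flux.setup`` in section 7: with its default identity activation, the layer maps ``x`` to ``W*x .+ b``, the same map as ``model``, and it keeps a reference to the weight array rather than a copy.\n",
"The names ``W_demo``, ``b_demo``, ``layer``, and ``x_demo`` are hypothetical.\n",
"\n",
"```julia\n",
"using Flux\n",
"W_demo = rand(2, 5)                      # hypothetical weights\n",
"b_demo = rand(2)                         # hypothetical bias\n",
"layer = Dense(W_demo, b_demo)            # identity activation by default\n",
"x_demo = 5 .* rand(5)\n",
"@assert layer(x_demo) ≈ W_demo*x_demo .+ b_demo   # same map as model\n",
"@assert layer.weight === W_demo         # the very same array, not a copy\n",
"```\n",
"\n",
"This sharing is why ``@show W`` after training displays the trained weights."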
] }, { "cell_type": "code", "execution_count": 14, "id": "3d64c71f-dc65-44d4-96c3-7422070a4d66", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Dense(5 => 2) \u001b[90m# 12 parameters\u001b[39m" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "L = Dense(W, b)" ] }, { "cell_type": "code", "execution_count": 15, "id": "34d82ba0-6804-49f6-8f72-d812fa2a521d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Chain(\n", " Dense(5 => 2), \u001b[90m# 12 parameters\u001b[39m\n", ") " ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "M = Chain(L)" ] }, { "cell_type": "markdown", "id": "2542858f-5260-46c1-a890-433c7d670e61", "metadata": {}, "source": [ "The trainable parts of the model are listed by ``Flux.trainable``, which returns the trainable children of ``M``, here its tuple of layers.\n", "The older ``Flux.params`` interface is deprecated." ] }, { "cell_type": "code", "execution_count": 16, "id": "92ec5ac0-e7a3-4938-9455-277b3140aee3", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(layers = (Dense(5 => 2),),)" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# ps = Flux.params(W, b) # deprecated\n", "ps = Flux.trainable(M)" ] }, { "cell_type": "markdown", "id": "649221d4-9801-4b9e-bb4e-39ff680d3ff5", "metadata": {}, "source": [ "One training epoch, that is, one pass over all pairs ``(x, y)`` in ``train_data``, is executed as follows." ] }, { "cell_type": "code", "execution_count": 17, "id": "b828734c-c8d0-4f72-b1ef-5e5d3ee327ef", "metadata": {}, "outputs": [], "source": [ "Flux.train!(loss, M, train_data, opt)" ] }, { "cell_type": "markdown", "id": "0ba4a0cf-d5ad-463c-b082-7a4c195451b1", "metadata": {}, "source": [ "The above command uses all 10,000 pairs in ``train_data``, which may be more than necessary.\n", "To reduce the computational effort, we can train on a slice of the data instead: for example, ``train_data = zip(x_train[1:100], y_train[1:100])`` uses only the first 100 examples." ] }, { "cell_type": "markdown", "id": "62235c2d-e6b0-4953-9eca-31fffa59f16f", "metadata": {}, "source": [ "## 6. evaluation of the training results" ] },
{ "cell_type": "markdown", "id": "9c5b72ca-4547-47c2-a40e-95ebe4faee84", "metadata": {}, "source": [ "How well did we do?" ] }, { "cell_type": "code", "execution_count": 18, "id": "934d42e3-b3cc-44ae-9fa8-b01820af7108", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "W = [1.0142472445975599 2.0137615108573783 3.07110784681812 4.002482164146511 5.0039563812985755; 4.989538524769695 4.016276579899637 2.9859959678369483 2.0334734749709873 1.0518794781416378]\n" ] }, { "data": { "text/plain": [ "2×5 Matrix{Float64}:\n", " 1.01425 2.01376 3.07111 4.00248 5.00396\n", " 4.98954 4.01628 2.986 2.03347 1.05188" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "@show W" ] }, { "cell_type": "code", "execution_count": 19, "id": "fe4bbdb4-4f58-4d25-9657-1c4ad958752e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "maximum(abs, W .- W_truth) = 0.0711078468181201\n" ] }, { "data": { "text/plain": [ "0.0711078468181201" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "@show maximum(abs, W .- W_truth)" ] }, { "cell_type": "markdown", "id": "bf735e02-02a2-4cb9-bf7d-6b9d2f94d42e", "metadata": {}, "source": [ "In this run, the largest entrywise error in ``W`` is about 0.07, so the trained weights agree with ``W_truth`` to roughly one or two decimal places." ] }, { "cell_type": "markdown", "id": "7b38fce4-73e5-406f-92db-f46649c54b54", "metadata": {}, "source": [ "## 7. monitoring the training progress" ] }, { "cell_type": "markdown", "id": "de31d932-3ea1-4951-8a34-9248df11ff2e", "metadata": {}, "source": [ "Let us reset the weight matrix ``W`` and the bias ``b`` to random values and start over, this time monitoring the loss during training.\n", "Since ``W = rand(2, 5)`` binds the name ``W`` to a new array while ``M`` still holds the old one, we also need a new model, called ``M2``."
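, "\n",
"\n",
"The difference between rebinding a name and overwriting an array in place can be checked directly.\n",
"In the minimal sketch below, ``A``, ``B``, ``model_A``, and ``model_B`` are hypothetical names, and ``model_A[1]`` indexes the first layer of a ``Chain``.\n",
"\n",
"```julia\n",
"using Flux\n",
"A = rand(2, 5)\n",
"model_A = Chain(Dense(A, rand(2)))\n",
"A = rand(2, 5)                        # rebinding: A now names a new array\n",
"@assert model_A[1].weight !== A       # model_A still holds the old array\n",
"B = rand(2, 5)\n",
"model_B = Chain(Dense(B, rand(2)))\n",
"B .= rand(2, 5)                       # in-place reset: same array, new values\n",
"@assert model_B[1].weight === B       # model_B sees the new values\n",
"```\n",
"\n",
"Resetting with ``W .= rand(2, 5)`` and ``b .= rand(2)`` would therefore let us reuse ``M``; creating the new model ``M2`` is the other option, taken here."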
] }, { "cell_type": "code", "execution_count": 20, "id": "b30c7009-2759-43ee-8e0b-ab93b22d38a0", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Chain(\n", " Dense(5 => 2), \u001b[90m# 12 parameters\u001b[39m\n", ") " ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "W = rand(2, 5)\n", "b = rand(2)\n", "M2 = Chain(Dense(W, b))" ] }, { "cell_type": "markdown", "id": "61f151d6-60fd-465d-bf86-e4087d8b9ec2", "metadata": {}, "source": [ "In a hand-written training loop, we create the optimizer state explicitly with ``Flux.setup``.\n", "The resulting ``state`` mirrors the structure of ``M2``, with one ``Leaf`` for each trainable array; for ``Descent``, the leaves carry no extra information (``nothing``)." ] }, { "cell_type": "code", "execution_count": 21, "id": "4f631018-0507-4da4-a762-9b7c65c05714", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(layers = ((weight = \u001b[32mLeaf(Descent(0.01), \u001b[39mnothing\u001b[32m)\u001b[39m, bias = \u001b[32mLeaf(Descent(0.01), \u001b[39mnothing\u001b[32m)\u001b[39m, σ = ()),),)" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "state = Flux.setup(opt, M2)" ] }, { "cell_type": "markdown", "id": "e5bd0c76-5899-4e7e-a024-5c6e0bf51c9e", "metadata": {}, "source": [ "We now train by hand on the first ``nbr = 100`` examples.\n", "At each step, we compute the gradient of the loss with respect to ``M2``, print the current loss, and take one gradient descent step with ``state``."
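, "\n",
"\n",
"For the squared-error ``loss`` used here, the gradient has a closed form. With the residual ``r = y .- (W*x .+ b)``, the gradient with respect to ``W`` is ``-2 .* r * x'`` and the gradient with respect to ``b`` is ``-2 .* r``; one ``Descent(η)`` step then replaces ``W`` by ``W .- η .* gW``, where ``gW`` is the weight gradient.\n",
"The minimal sketch below checks both statements against Flux on one random example; ``sqloss``, ``W0``, ``b0``, ``m``, ``gW``, and ``gb`` are hypothetical names, and ``sqloss`` repeats the notebook's ``loss`` so that the block runs on its own.\n",
"\n",
"```julia\n",
"using Flux\n",
"sqloss(m, x, y) = sum((y .- m(x)) .^ 2)          # same as loss above\n",
"W0 = rand(2, 5); b0 = rand(2)                    # hypothetical starting parameters\n",
"m = Dense(copy(W0), copy(b0))                    # copies keep W0 and b0 unchanged\n",
"x = 5 .* rand(5); y = rand(2)                    # one hypothetical training pair\n",
"gs = Flux.gradient(mm -> sqloss(mm, x, y), m)[1] # NamedTuple with fields weight and bias\n",
"gW = copy(gs.weight); gb = copy(gs.bias)\n",
"r = y .- (W0*x .+ b0)                            # residual\n",
"@assert gW ≈ -2 .* r * x'                        # closed-form weight gradient\n",
"@assert gb ≈ -2 .* r                             # closed-form bias gradient\n",
"state = Flux.setup(Descent(0.01), m)\n",
"Flux.update!(state, m, gs)                       # one gradient descent step, in place\n",
"@assert m.weight ≈ W0 .- 0.01 .* gW\n",
"@assert m.bias ≈ b0 .- 0.01 .* gb\n",
"```"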
] }, { "cell_type": "code", "execution_count": 22, "id": "45fbe95f-7bd8-4731-b7e4-c4003b09bf30", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "4567.357304361625\n", "553.245833285954\n", "483.495210214943\n", "12.449436308190595\n", "285.6257320848692\n", "206.76473841580028\n", "282.4342036989557\n", "35.935431394062434\n", "224.30710674030985\n", "48.27349434845312\n", "19.330243030236904\n", "42.10185758891386\n", "3.6958737236429444\n", "12.573886025896503\n", "2.0824167640130886\n", "34.4195301641794\n", "38.55261510067439\n", "6.005866000314168\n", "5.506868982119332\n", "26.832612337169476\n", "30.95632570874156\n", "16.046621288643045\n", "0.823195196270534\n", "1.2105549290506756\n", "21.47650651942783\n", "36.41300146343701\n", "12.110637753848968\n", "1.7286918960391102\n", "2.759415237770284\n", "9.166914904948518\n", "1.497946334433237\n", "6.748319003653388\n", "2.5204779502019026\n", "3.260907795339238\n", "10.656263799233962\n", "2.6453086761731655\n", "6.842529146203612\n", "7.089395986635256\n", "2.250954564150082\n", "12.41610177120226\n", "8.056877542810183\n", "4.575267190465548\n", "1.9493420824419512\n", "4.787602390511638\n", "0.9176100046569426\n", "0.36186173827002016\n", "0.640414329363628\n", "1.0837171435532458\n", "1.0361860533866787\n", "0.7080788619274915\n", "0.23898593871976204\n", "1.1999973245448856\n", "0.23682430983194908\n", "0.6132342393537717\n", "2.1263457874728853\n", "4.478589209046477\n", "1.6291465466975268\n", "7.042942674287948\n", "3.5249918723837537\n", "0.4360170918925239\n", "0.8245420615634346\n", "1.4979541090427702\n", "2.091869194555607\n", "0.1670829353105578\n", "1.5549775857466992\n", "0.008255920709321718\n", "1.7287520205223132\n", "1.014614675062218\n", "3.427540833312369\n", "0.7064610068913421\n", "0.6353752410489342\n", "1.1693887020622102\n", "0.3271534990047012\n", "0.5376893828670511\n", "2.211656996305679\n", "0.7808933522334682\n", "1.8835113345676275\n", 
"4.628439082564309\n", "0.7537458447691291\n", "0.6383605146944548\n", "0.7633576801924494\n", "0.7013660215393542\n", "1.139224648752619\n", "0.24118378972597626\n", "0.3716651916069884\n", "2.032644997884912\n", "1.392898023933264\n", "0.8556324676908309\n", "1.6398106212282169\n", "0.014650442796064785\n", "0.17825094223724552\n", "0.4496419024114372\n", "0.6554331056713464\n", "1.553170873629522\n", "1.0756438953748773\n", "1.551767784962852\n", "0.2713973701962646\n", "1.8631456665339887\n", "0.21362371164547814\n", "1.6771270879333304\n" ] } ], "source": [ "nbr = 100\n", "for (x, y) in zip(x_train[1:nbr], y_train[1:nbr])\n", "    gs = gradient(m -> loss(m, x, y), M2)   # gradient with respect to the model M2\n", "    println(loss(M2, x, y))                 # loss before this step\n", "    Flux.update!(state, M2, gs[1])          # one gradient descent step, updates the parameters of M2 in place\n", "end" ] }, { "cell_type": "code", "execution_count": 23, "id": "a28ff23b-cc25-4945-8d62-83c7b310ddca", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "W = [0.7651632087278496 1.9307594478285286 2.8159979332867633 3.945582632282975 4.817345349709885; 5.141966680191063 4.000216447628437 3.032893106736955 2.0786753375912923 1.1295692523306626]\n" ] }, { "data": { "text/plain": [ "2×5 Matrix{Float64}:\n", " 0.765163 1.93076 2.816 3.94558 4.81735\n", " 5.14197 4.00022 3.03289 2.07868 1.12957" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "@show W" ] }, { "cell_type": "code", "execution_count": 24, "id": "9eac467e-81a0-47ed-a06b-040de4801800", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "maximum(abs, W .- W_truth) = 0.2348367912721504\n" ] }, { "data": { "text/plain": [ "0.2348367912721504" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "@show maximum(abs, W .- W_truth)" ] }, { "cell_type": "markdown", "id": "719e7073-e764-48b8-b3b1-b60bff793e8c", "metadata": {}, "source": [ "With only the first 100 examples, the largest entrywise error in ``W`` is about 0.23 in this run, roughly three times larger than after the full epoch over all 10,000 examples in section 6."
] } ], "metadata": { "kernelspec": { "display_name": "Julia 1.12", "language": "julia", "name": "julia-1.12" }, "language_info": { "file_extension": ".jl", "mimetype": "application/julia", "name": "julia", "version": "1.12.4" } }, "nbformat": 4, "nbformat_minor": 5 }