Unraveling neural networks for artificial intelligence
Computers make tons of decisions that help us with our lives. From TV show recommendations to directions from a navigation app, most aspects of our modern lives are influenced by machines’ computations. Machine learning is a powerful tool for taking real-world data and then recognizing patterns, making predictions, and drawing conclusions.
Learning requires a brain, and the computer’s equivalent of that is called a neural network. But how do computer “brains” work? And if these machines can “learn”, who’s responsible for teaching them?
Dr. Yasaman Bahri is a research scientist at Google Brain where she has worked hard to understand the training processes of neural networks. You can think of Dr. Bahri as the machines’ personal trainer: her job is to understand how machine learning works and to discover ways to train machines more efficiently.
What is Machine Learning?
To understand how machines learn, we can start by thinking about how humans learn. When we learned how to write numbers in kindergarten, we probably started off by mimicking grown-ups’ handwriting. Eventually, we learned to recognize each digit, even if the handwritten digits looked slightly different each time. Just like humans learn the task of writing, we can ask computers to write or recognize handwritten digits. To the computer, a digit is an image: a 2D grid of pixels that together represent a single number. Each pixel has its own value (for example, how dark it is), so the image of a “7” or a “3” is really just a bunch of numbers organized into a grid. The computer’s goal is to take that grid of pixels and figure out which digit is written in the image. If the computer can learn to do this without being given “step-by-step” instructions, then we call that machine learning.
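To make “an image is just a grid of numbers” concrete, here is a small Python sketch. The 7×5 grid is a made-up, hand-drawn “7” (an assumption for illustration: real digit images are larger and use in-between values for gray pixels), with 1 meaning ink and 0 meaning blank paper.

```python
# A hypothetical 7x5 image of the digit "7": each number is one pixel's
# darkness, from 0 (blank paper) to 1 (dark ink).
seven = [
    [1, 1, 1, 1, 1],
    [0, 0, 0, 0, 1],
    [0, 0, 0, 1, 0],
    [0, 0, 1, 0, 0],
    [0, 0, 1, 0, 0],
    [0, 1, 0, 0, 0],
    [0, 1, 0, 0, 0],
]

# Draw the grid with "#" for ink so a human can see the digit.
for row in seven:
    print("".join("#" if pixel > 0.5 else "." for pixel in row))

# To the computer, the whole image is simply a list of 35 numbers.
pixels = [pixel for row in seven for pixel in row]
print(len(pixels))
```

A neural network that reads digits would take a list like `pixels` as its input.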
Why's it called a neural network?
As humans learn, the physical structure of our brains changes. Our brains are full of nerve cells, or neurons, that can pass information to other neurons by forming new connections, called synapses. Learning can also strengthen existing connections between these neurons, reinforcing paths that already exist.
Inspired by the biological brain, one way machines can learn is with a collection of artificial “neurons”, called a neural network. These artificial neurons, called nodes, each hold a piece of information and pass it along to neighboring nodes. As the network learns, the connections between nodes get stronger or weaker, which changes how information flows through the network.
Let’s get into the nitty-gritty of how a neural network actually works. Our goal is to teach a machine to recognize handwritten digits. The input to the computer is an image of a digit. The computer outputs a prediction: whichever digit (0-9) it thinks was written in the input image. This means we can think of the neural network as a function that takes images as inputs and produces digits as outputs.
So we can consider a neural network to be a function. But what should that function even look like to do something as complicated as taking a grid of pixels and returning an accurate single number? To figure that out, we usually let the neural network learn the function from the data itself – this is why the process is called “machine learning”!
This process of learning a function from data, and using that function to make predictions, is hugely important to almost every scientific discipline. So getting computers to learn functions for us can powerfully change the way scientists conduct their research. Functions are so powerful because if we can find a strong pattern in the past (input), then we can probably make a trustworthy prediction about the future (output).
The process of data fitting isn’t just something that’s unique to machine learning. There are tons of times when you look at a collection of points and want to figure out the rough line or shape that fits that data. The very first thing a lot of scientists try when they want to learn a function from data is linear regression. Basically, this means that our first guess at a pattern is that all our data fits on a line.
Imagine you planted some tomato seedlings in your garden at the beginning of summer. You measured their heights every week for the entire summer and then plotted their heights versus the time since planting. This data is represented by the blue dots in the graph below. We’re going to guess that this data fits on a line, meaning the tomato plants grew at a steady rate the entire summer. The equation of a line has the form
y = w₀ + w₁x,
where y is the output and x is the input data. You might have seen the equation for a line written a little differently before (maybe you know it as y = b + mx or y = b + ax), but remember that w₀ and w₁ are just constants: w₀ is the intercept (where the line crosses the y-axis) and w₁ is the slope. Finding the line that best fits our data points boils down to finding the best values for w₀ and w₁.
Luckily, our data actually does kinda look like a line, so we can find these values pretty easily. To do that, we define an error. We’ll define this error as how far away our prediction is from the real value, but we’ll square it so that the value is never negative. This gives an equation for the error as
(y_real - y_pred)²
where y_real is the real output data (measured from an experiment), and y_pred is the predicted output (from plugging the real input x into w₀ + w₁x , using our guess for w₀ and w₁).
If we predicted every y value correctly, the error would be exactly 0. But it’s impossible to predict every value exactly, because the seedlings’ heights don’t fall perfectly on a line: each plant grew a little differently over the summer. Instead, we can find the best values for w₀ and w₁ by minimizing the total error, added up over all of our data points.
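Here is a minimal Python sketch of that idea. The tomato heights are made-up numbers (not real measurements), and the fitting method is ordinary least squares, one standard way to find the w₀ and w₁ that make the total squared error as small as possible.

```python
# Made-up example data (NOT real measurements): weeks since planting
# and tomato plant height in centimeters.
weeks = [1, 2, 3, 4, 5, 6, 7, 8]
heights = [6.0, 11.5, 14.0, 20.5, 24.0, 29.5, 33.0, 38.5]


def total_error(w0, w1, xs, ys):
    """Sum of (y_real - y_pred)^2 over all data points, with y_pred = w0 + w1*x."""
    return sum((y - (w0 + w1 * x)) ** 2 for x, y in zip(xs, ys))


def fit_line(xs, ys):
    """Return the w0, w1 that minimize total_error (ordinary least squares)."""
    n = len(xs)
    x_mean = sum(xs) / n
    y_mean = sum(ys) / n
    # Slope: how much y tends to change when x goes up by 1.
    w1 = sum((x - x_mean) * (y - y_mean) for x, y in zip(xs, ys)) / sum(
        (x - x_mean) ** 2 for x in xs
    )
    # Intercept: makes the line pass through the average point (x_mean, y_mean).
    w0 = y_mean - w1 * x_mean
    return w0, w1


w0, w1 = fit_line(weeks, heights)
print(f"height ≈ {w0:.2f} + {w1:.2f} * weeks")
print("total error:", round(total_error(w0, w1, weeks, heights), 3))
```

Nudging either w₀ or w₁ away from the fitted values makes the total error bigger, which is exactly what “best fit” means here.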
Linear regression works beautifully if the underlying data is actually linear. Unfortunately, this often isn’t the case – nature is wonderfully messy and complicated. For an example of nonlinear data, let’s suppose you own a pizza shop. You want to be sure that enough pizza-makers are scheduled for busy times, but you don’t need as many pizza-makers for slow parts of the day. Based on your sales from previous weeks, you’d like to find the pattern of every day’s busiest times. The data looks something like the dataset below.
Our pizza sales have two peaks at different times, representing when the customers want pizza the most. We can pretty safely guess that the first peak of the day is lunch time, and the second peak is dinner time. This makes sense – humans typically eat pizza around meal times, with some variation where some people want to eat a little earlier and some people want to eat a little later. But is there a more mathematical way to describe the data?
We could guess that each peak is generated by a Gaussian distribution, which has the form
y_pred = exp(-(x - a)²/b)
where exp means the exponential function, a sets where the peak is centered, and b sets how wide the peak is. Describing two peaks next to each other is pretty easy: it’s just the sum of two Gaussians! Gaussian distributions describe many natural phenomena, so this is a good guess, and in this case it describes our pizza sales data pretty well! But if we had guessed a different function, like our straight line from before, our predictions wouldn’t have matched the data at all.
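Below is a small Python sketch of the “sum of two Gaussians” model. The peak positions, widths, and heights are hypothetical values picked by hand for illustration, not fitted to real sales. The height factors are an extra assumption beyond the formula above: without them, both peaks would have the same height of 1.

```python
import math


def gaussian(x, a, b):
    """One bell-shaped peak, exp(-(x - a)^2 / b): centered at a, wider when b is larger."""
    return math.exp(-((x - a) ** 2) / b)


def pizza_model(x, height1, a1, b1, height2, a2, b2):
    """Sum of two Gaussians. The height factors are an added assumption that
    lets the lunch and dinner peaks be different sizes."""
    return height1 * gaussian(x, a1, b1) + height2 * gaussian(x, a2, b2)


# Hypothetical constants: a lunch peak at 12:00 and a bigger dinner peak at 18:30.
params = dict(height1=30, a1=12.0, b1=2.0, height2=45, a2=18.5, b2=3.0)

hours = [h / 2 for h in range(20, 47)]  # 10:00 to 23:00 in half-hour steps
for hour in hours:
    sales = pizza_model(hour, **params)
    print(f"{hour:5.1f}h  {'*' * round(sales / 2)}")
```

The printed bars rise and fall twice, once around lunch and once around dinner, the same two-peak shape described above.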
Gaussian Distributions: Life on the Bell Curve
The Gaussian distribution is also called a “normal distribution” or “bell curve” because its shape looks a bit like a bell. This shape shows up in the statistics of many natural phenomena. For example, a histogram of a random sample of people’s heights starts to look like a bell curve as the number of samples increases. Another example is the carnival game where you drop balls through a wall of pegs (called a quincunx or Galton board!). Each ball bounces through the pegs and lands in a bin below. If enough balls are dropped, the number of balls in each bin traces out a bell curve. The normal distribution shows up all over biology, economics, and quantum mechanics. It describes the total number of heads when you flip many coins, the sum of many dice rolls, and the noise in sound or light signals. It’s hard to escape the bell curve!
The reason this distribution is so common lies in a statistics theorem: the central limit theorem. We won’t go deep into the central limit theorem here, but it basically states that if you add up (or average) many independent random samples, the total follows a normal distribution more and more closely as you include more samples, almost regardless of how each individual sample is distributed. This is a very strange and cool statistical fact.
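The Galton board is a nice way to see the central limit theorem in action, because each ball’s final bin is the sum of many random left-or-right bounces. Here is a minimal Python simulation; the number of peg rows (12) and balls (5,000) are arbitrary choices for illustration, and each bounce is assumed to go right with probability 1/2.

```python
import random
from collections import Counter

rng = random.Random(42)  # fixed seed so every run gives the same result


def drop_ball(rows):
    """Drop one ball through `rows` rows of pegs. At each peg it bounces right
    with probability 1/2; its bin is the total number of right bounces."""
    return sum(1 for _ in range(rows) if rng.random() < 0.5)


ROWS = 12
bins = Counter(drop_ball(ROWS) for _ in range(5000))

# Sideways histogram: the bars pile up in a bell shape around the middle bin.
for b in range(ROWS + 1):
    print(f"bin {b:2d} {'#' * (bins[b] // 20)}")

average_bin = sum(b * n for b, n in bins.items()) / sum(bins.values())
print(f"average bin: {average_bin:.2f}")  # should land close to ROWS / 2 = 6
```

A single bounce is just “left or right”, nothing like a bell, yet the sum of twelve of them already looks bell-shaped.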
Neural Networks to the Rescue
In our pizza sales example, we had to make a lucky guess that the data could be modeled as the sum of two Gaussians. But what if we had no idea what function would fit our data? This is where a neural network comes in handy! Neural networks don’t assume a specific form for the function: they can represent lots of very complicated functions. As we give a neural network more and more data, it learns to represent that data better and figures out what the function should look like on its own. It’s just like learning your numbers: the more fonts and handwriting styles you see a 0 written in, the better you get at recognizing it as a zero.
Neural networks are still input-output machines, just like a line or a Gaussian; they just look more complicated. Like we said earlier, they’re made up of nodes and connections between nodes. In the image below, the nodes are the gray bubbles, and the connections are the black lines between them. The connections are the network’s version of w₀ and w₁ from our linear regression example: numbers that we need to find. But now there can be many, many connections, each with its own variable, and to get a node’s value we have to add up the contributions from all the inputs and connections that feed into it: w₁₁x₁ + w₂₁x₂ + w₃₁x₃ + ...
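The “add up all the contributions” step is a weighted sum. Here is a tiny Python sketch; the input and connection values are made up for illustration. It also shows why this is “fundamentally the same as linear regression”: with one input x plus a constant input of 1, the weighted sum is exactly w₀ + w₁x.

```python
def weighted_sum(inputs, weights):
    """Add up each input times the strength of its connection:
    w1*x1 + w2*x2 + w3*x3 + ..."""
    return sum(w * x for w, x in zip(weights, inputs))


# Three inputs feeding into one node (made-up numbers).
print(weighted_sum([0.5, 0.2, 0.9], [0.4, -1.0, 2.0]))

# Linear regression as a special case: a constant input of 1 (weight w0)
# plus the real input x (weight w1) gives the line y = w0 + w1*x.
w0, w1, x = 2.0, 3.0, 4.0
print(weighted_sum([1.0, x], [w0, w1]))  # 2 + 3 * 4 = 14.0
```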
So far, everything we’ve talked about is fundamentally the same as linear regression, just with more connections and more inputs. But in order to represent functions that are more complicated than straight lines, we have to add one more element to the mix: an “activation function”.
The activation function takes in the sum of all the inputs times their connections and passes it through another function, usually one that isn’t a straight line. For example, if the activation function is a sine function, the outputs after the first layer would look like
y_pred = sin(w₁₁x₁ + w₂₁x₂ + w₃₁x₃ + ...)
So a node takes in the inputs (x’s) and connections (w’s) and produces an intermediate output (f(w₁₁x₁ + w₂₁x₂ + b₁ + b₂) in the gray bubble, where f is the activation function and the b’s are extra constants, called biases, that play a role like the intercept w₀). The values of the nodes in the final layer are the predicted outputs, y_pred. The final step is to find the right values for all the variables that describe the connections, just like we found the right slope and intercept in the linear regression example. We do this by minimizing the error, (y_real - y_pred)², over all of our data.
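Here is a minimal Python sketch of one node with a sine activation being trained. Some choices are assumptions for illustration: the node uses a single bias b for simplicity, the training data is generated from made-up “true” values so we can check the answer, and the error is minimized with gradient descent (one common method, which repeatedly nudges each variable in the direction that lowers the error).

```python
import math
import random

rng = random.Random(0)  # fixed seed so the run is repeatable


def node(x1, x2, w1, w2, b):
    """One artificial neuron: weighted sum plus bias, passed through sin()."""
    return math.sin(w1 * x1 + w2 * x2 + b)


# Hypothetical training data from made-up "true" connection values.
TRUE_W1, TRUE_W2, TRUE_B = 0.8, -0.5, 0.3
data = []
for _ in range(50):
    x1, x2 = rng.uniform(-1, 1), rng.uniform(-1, 1)
    data.append((x1, x2, node(x1, x2, TRUE_W1, TRUE_W2, TRUE_B)))


def mean_error(w1, w2, b):
    """Average of (y_real - y_pred)^2 over the data set."""
    return sum((y - node(x1, x2, w1, w2, b)) ** 2 for x1, x2, y in data) / len(data)


# Gradient descent: start from a guess of all zeros and take many small steps downhill.
w1, w2, b = 0.0, 0.0, 0.0
learning_rate = 0.5
start_error = mean_error(w1, w2, b)
for step in range(500):
    g1 = g2 = gb = 0.0
    for x1, x2, y in data:
        z = w1 * x1 + w2 * x2 + b
        # Chain rule: d/dz of (y - sin z)^2 is -2 * (y - sin z) * cos z.
        dz = -2 * (y - math.sin(z)) * math.cos(z) / len(data)
        g1 += dz * x1
        g2 += dz * x2
        gb += dz
    w1 -= learning_rate * g1
    w2 -= learning_rate * g2
    b -= learning_rate * gb

print(f"error before: {start_error:.4f}, after: {mean_error(w1, w2, b):.6f}")
print(f"learned w1={w1:.3f}, w2={w2:.3f}, b={b:.3f}")
```

Because the made-up data came from known values, we can see training recover them. Real networks have many more variables, but the idea of walking downhill on the error is the same.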
In the network we just described, there are very few variables to find, since the network is very simple: it has just one layer with only two nodes per layer. In other words, it has depth l = 1 and width d = 2. We can build a larger network that represents more complicated functions by adding more nodes and layers. For example, a network with depth l = 2 and width d = 3 would look like the figure below. Now the nodes in the second layer (l = 2) take the outputs from the previous layer (l = 1) as their inputs and do the same node operation that we just described.
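Stacking layers can be sketched in a few lines of Python. This is a hypothetical forward pass (no training) for a network with depth 2, width 3, and 2 inputs. It assumes sine activations, one bias per node, and randomly chosen starting connections, and, following the description above, treats the final layer’s node values as the outputs.

```python
import math
import random


def layer(inputs, weights, biases, activation=math.sin):
    """One layer of nodes. Node j takes the weighted sum of ALL the inputs,
    adds its bias, and applies the activation function.
    weights[j][i] is the connection from input i into node j."""
    return [
        activation(sum(w * x for w, x in zip(node_weights, inputs)) + bias)
        for node_weights, bias in zip(weights, biases)
    ]


def random_layer(num_inputs, width, rng):
    """Hypothetical helper: random starting values for one layer's connections."""
    weights = [[rng.gauss(0, 1) for _ in range(num_inputs)] for _ in range(width)]
    biases = [rng.gauss(0, 1) for _ in range(width)]
    return weights, biases


def build_network(num_inputs, depth, width, seed=0):
    rng = random.Random(seed)
    layers, n_in = [], num_inputs
    for _ in range(depth):
        layers.append(random_layer(n_in, width, rng))
        n_in = width  # the next layer's inputs are this layer's outputs
    return layers


def forward(layers, x):
    """Pass the input through every layer in order; layer l uses layer l-1's outputs."""
    values = x
    for weights, biases in layers:
        values = layer(values, weights, biases)
    return values  # the final layer's node values are the predicted outputs


net = build_network(num_inputs=2, depth=2, width=3)
print(forward(net, [0.5, -1.0]))  # three numbers, one per node in the last layer
```

Making the network wider or deeper is just a matter of changing `width` or `depth`; the node operation itself stays the same.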
The Power and Puzzle of Deep Learning
As we increase the width or depth of our network, we can fit more complicated functions. Deep learning is the branch of machine learning that uses large neural networks with many layers. It works very well on some challenging, real-world problems. For example, an artificially intelligent computer program called AlphaGo Zero (which is partly made up of neural networks) can beat world-class professional players of the board game Go!
The super cool (and also super crazy) thing about deep neural networks is that no one really knows why they work so well. They model really complicated functions, but how do you know or understand what the final function will look like in general? That’s one of the things that Dr. Yasaman Bahri is working to figure out.
Dr. Bahri is a physicist, and when faced with a very complicated problem like neural networks, physicists often try to come up with a much simpler version of the same problem. They hope that by learning from the simpler problem, they can learn lessons that also apply to the more complicated problem. This is exactly what Dr. Bahri did when approaching neural networks.
Using her knowledge of physics, she approached the deep learning problem from a different angle. Instead of thinking about a real neural network, she considered a special neural network that is infinitely wide. This means that each layer of the network has an infinite number of neurons. This sounds more complicated, but sometimes when you zoom out on systems like this, surprisingly, things get a lot simpler!
Remember how the bell curve emerges when we add up a large enough number of random samples? It turns out that when a layer has an infinite number of neurons (and its connections start out random), the output of the network follows the same bell curve. Using the properties of this distribution, we can make educated guesses about the neural network’s final function based on data. This system of educated guesses is called a Gaussian process. It’s a statistical model that estimates how likely each y_pred is for each x. For example, in the figure below, we could draw a lot of red curves that pass through the five data points. But with the special infinitely wide network, we know which functions the network is most likely to predict (the shaded red region in the example below). By looking at things this way, Dr. Bahri has been able to make progress on the puzzle of why deep neural networks are so successful.
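A toy Python experiment can hint at where the bell curve comes from. It builds many random, untrained one-layer networks and records each one’s output at the same input. This is only an illustration under stated assumptions (sine activations, Gaussian random input weights and biases, random ±1 output connections, and the output divided by √width, a common scaling choice); it is not Dr. Bahri’s method and does not do Gaussian process predictions. Because each output is a sum of many independent random node contributions, the central limit theorem pushes the outputs toward a bell curve as the layer gets wider.

```python
import math
import random
import statistics

rng = random.Random(7)  # fixed seed so the run is repeatable


def random_network_output(x, width):
    """Build a brand-new random one-layer network with `width` sin() nodes and
    return its output at input x. Each node connects to the output with a random
    +1 or -1, and dividing by sqrt(width) keeps the output from growing as the
    layer gets wider."""
    total = 0.0
    for _ in range(width):
        w, b = rng.gauss(0, 1), rng.gauss(0, 1)  # random input weight and bias
        v = rng.choice((-1.0, 1.0))              # random output connection
        total += v * math.sin(w * x + b)
    return total / math.sqrt(width)


# 2,000 different random networks, each asked for its output at x = 1.
outputs = [random_network_output(1.0, width=100) for _ in range(2000)]

mean = statistics.fmean(outputs)
std = statistics.pstdev(outputs)
within_one_std = sum(abs(o - mean) < std for o in outputs) / len(outputs)
# For a bell curve, about 68% of values land within one standard deviation.
print(f"mean {mean:.3f}, std {std:.3f}, within one std: {within_one_std:.1%}")
```

Trying a very small width (such as 1) instead of 100 gives outputs that are noticeably less bell-shaped, which is one way to see why the infinitely wide limit is simpler.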
We want to understand how a neural network is trained because it will help us interpret the learned results better. This is crucial as computation plays a bigger and bigger role in our daily lives, from approving credit card applications, to screening job applications, to delivering your packages and groceries. Such a big impact on human lives requires careful decision making. If we use these machine learning methods without really understanding what the results mean or how we got them, the machines may repeat the same mistakes (or worse) that humans make every day. It’s especially important to consider how biases and discrimination can be transferred into the code that we write; otherwise, artificial intelligence risks deepening the social, economic, and racial gaps that already exist in our society. With Dr. Bahri’s work, hopefully we’ll soon unlock the puzzle of deep learning, but with this great power will come great responsibility.
Written by Yanting Teng and Madelyn Leembruggen
Edited by Ella King and Lindsey Oberhelman
Illustrations by Lindsey Oberhelman
Computers are learning about you... Your turn to learn about computers!
Interact (20-30 minutes): Can a computer learn to read your handwriting? Test a pre-trained neural network's reading skills by providing it with numbers in your own handwriting.
Explore (30-45 minutes): Tinker in A Neural Network Playground. Choose a set of training data, control the input data, add or remove layers and neurons, and watch how this all affects the predicted output!