[MINI] Activation Functions


2017-06-16

Activation Functions

In a neural network, the output value of a neuron is almost always transformed in some way by a function. A trivial choice would be a linear transformation, which can only scale the data. Other transformations, like a step function, allow non-linear properties to be introduced.
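
One way to see why a purely linear choice is limited: composing linear layers just produces another linear layer, so the stack gains no expressive power. A minimal sketch in numpy (the weight matrices below are arbitrary values, chosen only for illustration):

import numpy as np

# Two purely linear "layers" (weight values are arbitrary, for illustration only).
W1 = np.array([[2.0, -1.0],
               [0.5, 3.0]])
W2 = np.array([[1.0, 4.0],
               [-2.0, 0.5]])

x = np.array([1.0, 2.0])

# Applying the layers one after the other...
two_layers = W2 @ (W1 @ x)

# ...collapses to a single linear map with matrix W2 @ W1.
one_layer = (W2 @ W1) @ x

print(np.allclose(two_layers, one_layer))  # True: stacking added nothing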

Activation functions can also help to standardize your data between layers. Some functions, such as the sigmoid, have the effect of "focusing" the area of interest in the data: extreme values are squashed close together, while values near the function's point of inflection change more quickly with respect to small changes in the input. Similarly, these functions can take any real number and map it to a finite range such as [0, 1], which can have many advantages for downstream calculation.
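
As a quick numerical illustration of that squashing behavior (the input values here are arbitrary, picked just to show the effect):

import numpy as np

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

# Near the inflection point (x = 0), a small input change moves the output noticeably.
print(sigmoid(0.0), sigmoid(0.5))     # 0.5 vs ~0.622

# Extreme inputs are squashed close together near the asymptotes.
print(sigmoid(8.0), sigmoid(12.0))    # ~0.99966 vs ~0.99999

# Any real input lands in the finite range (0, 1).
print(sigmoid(-40.0), sigmoid(40.0))  # ~4e-18 vs ~1.0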

In this episode we give an overview of the concept and discuss a few reasons why you might select one function over another.

Some examples are shown below.

%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

# Evaluate each activation over the same input range.
xmin = -4
xmax = 4
x = np.arange(start=xmin, stop=xmax, step=.1)

# Tanh: squashes any real input into (-1, 1), centered at 0.
y = np.tanh(x)
plt.figure(figsize=(15, 5))
plt.plot(x, y, linewidth=4)
plt.plot([xmin, xmax], [0, 0], linestyle='--', color='#888888')
plt.title('tanh', fontsize=22)
plt.show()
# Step: the classic perceptron activation; 0 below the threshold, 1 at or above it.
def step(x):
    return 0 if x < 0 else 1

y = list(map(step, x))
plt.figure(figsize=(15, 5))
plt.plot(x, y, linewidth=4)
plt.title('step', fontsize=22)
plt.show()
# Sigmoid: squashes any real input into (0, 1).
def sigmoid(x):
    return 1 / (1 + np.exp(-x))

y = sigmoid(x)  # np.exp is vectorized, so this applies elementwise
plt.figure(figsize=(15, 5))
plt.plot(x, y, linewidth=4)
plt.title('sigmoid', fontsize=22)
plt.show()
# ReLU: passes positive inputs through unchanged and zeroes out the rest.
def relu(x):
    return x if x >= 0 else 0

y = list(map(relu, x))
plt.figure(figsize=(15, 5))
plt.plot(x, y, linewidth=4)
plt.title('relu', fontsize=22)
plt.show()
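
One practical difference between these curves, relevant when choosing among them: near the extremes the sigmoid's slope flattens toward zero, while ReLU's slope stays at 1 for any positive input. The finite-difference check below is an illustrative sketch only; the sample points and step size are arbitrary choices.

import numpy as np

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def relu(x):
    return np.maximum(0, x)

# Estimate slopes with a centered finite difference at a few sample points.
x = np.array([-6.0, 0.0, 6.0])
h = 1e-5

sig_slope = (sigmoid(x + h) - sigmoid(x - h)) / (2 * h)
relu_slope = (relu(x + h) - relu(x - h)) / (2 * h)

print(sig_slope)   # ~[0.0025, 0.25, 0.0025]: the slope vanishes at the extremes
print(relu_slope)  # [0.0, 0.5, 1.0]: constant slope of 1 for positive inputs
                   # (0.5 at x = 0 is an artifact of the centered difference at the kink)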