numpy.expand_dims() in Python
The numpy.expand_dims() function expands the shape of an array by inserting a new axis of length 1. The new axis appears at the position given by axis in the expanded array's shape.
Syntax
numpy.expand_dims(a, axis)
Parameters
The numpy.expand_dims() function takes two parameters:
a: The input array (any array-like object).
axis: The position in the expanded shape at which the new axis is placed. It is normally an int; NumPy 1.18 and later also accept a tuple of ints to insert several axes at once.
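As a rough illustration of how axis controls where the new dimension lands, the sketch below applies numpy.expand_dims() to a 1-D array of shape (3,). The variable names (v, row, col, last, both) are made up for this example, and the tuple form assumes NumPy 1.18 or later.

```python
import numpy as np

v = np.array([10, 20, 30])             # shape (3,)

row = np.expand_dims(v, axis=0)        # new axis placed first
col = np.expand_dims(v, axis=1)        # new axis placed second
last = np.expand_dims(v, axis=-1)      # negative values count from the end
both = np.expand_dims(v, axis=(0, 1))  # tuple of axes (assumes NumPy >= 1.18)

print(row.shape)
print(col.shape)
print(last.shape)
print(both.shape)
```

With this input, the four calls should give the shapes (1, 3), (3, 1), (3, 1) and (1, 1, 3) respectively.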
Return
This function returns an array with one more dimension than the input array. The result is a view of the input, so no data is copied.
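The sketch below checks that the result has one more dimension than the input, and also that it shares memory with the input (NumPy documents the result as a view). The names x and y are illustrative, and np.shares_memory() is used here only as a convenient check.

```python
import numpy as np

x = np.array([[1, 2], [3, 4]])
y = np.expand_dims(x, axis=0)

print(x.ndim, y.ndim)          # number of dimensions before and after
print(np.shares_memory(x, y))  # True when y is a view of x

# Because y is a view, writing through it also changes x
y[0, 0, 0] = 99
print(x[0, 0])
```

If an independent array is needed, calling .copy() on the result is one way to detach it from the input.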
Example 1
# Python program demonstrating
# the numpy.expand_dims() function
import numpy as np

# Build a 2-D array of shape (2, 2)
x = np.array([[1, 2], [3, 4]])
print('Array x:')
print(x, "\n")

# Insert a new axis at position 0, giving shape (1, 2, 2)
y = np.expand_dims(x, axis=0)
print('Array y:')
print(y, "\n")

print('The shape of X and Y array:')
print(x.shape, y.shape, "\n")
Output
Array x:
[[1 2]
 [3 4]]

Array y:
[[[1 2]
  [3 4]]]

The shape of X and Y array:
(2, 2) (1, 2, 2)
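For comparison, the same 2x2 array can be expanded at other positions. The sketch below is not part of the original example; the names front, middle, back and front_idx are hypothetical. It also shows np.newaxis indexing, a commonly used alternative that inserts a length-1 axis in the same way.

```python
import numpy as np

x = np.array([[1, 2], [3, 4]])      # shape (2, 2)

front = np.expand_dims(x, axis=0)   # shape (1, 2, 2), as in Example 1
middle = np.expand_dims(x, axis=1)  # shape (2, 1, 2)
back = np.expand_dims(x, axis=2)    # shape (2, 2, 1)

# Indexing with np.newaxis inserts a length-1 axis at that position
front_idx = x[np.newaxis, :, :]

print(front.shape, middle.shape, back.shape)
print(np.array_equal(front, front_idx))
```

The total number of elements stays at four in every case; only the way they are indexed changes.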