Understanding Tensor Dimensions, Shapes, Sizes and Axes in PyTorch
Tensors are the fundamental data structures used in deep learning. They are multi-dimensional arrays that generalize scalars, vectors, and matrices to higher dimensions. To work effectively with tensors, you need to understand their dimensions, shapes, sizes and axes.
The Fundamentals of Tensors
At its core, a tensor is a generalization of vectors and matrices to potentially higher dimensions. Imagine starting with a simple number (a scalar), then extending it to a line of numbers (a vector), then to a grid of numbers (a matrix), and continuing this pattern into higher dimensions. Each step in this progression represents a tensor of increasing dimensionality.
The dimensionality of a tensor can be expressed mathematically as follows:
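For a tensor $\mathcal{T} \in \mathbb{R}^{n_1 \times n_2 \times \cdots \times n_d}$, its order (or rank) $d$ determines the number of indices needed to specify an element:
$$\mathcal{T}_{i_1, i_2, \ldots, i_d}, \qquad 1 \le i_k \le n_k$$
where each index $i_k$ ranges from 1 to $n_k$, the size of dimension $k$.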
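In PyTorch, the order of a tensor is available as `ndim`, and reading a single element takes exactly that many indices. Unlike the 1-based math convention, PyTorch counts each index from 0, so the last valid index along an axis is its size minus 1. A minimal sketch with an arbitrary example tensor `T`:

```python
import torch

# Arbitrary example tensor with sizes 2, 3 and 4, so its order is 3
T = torch.arange(24).reshape(2, 3, 4)

print(T.ndim)      # 3 -> three indices are needed to pick a single element
print(T.shape)     # torch.Size([2, 3, 4])
print(T[1, 2, 3])  # tensor(23) -> the last element: each index stops at size - 1
```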
Understanding Tensor Structure
Let's visualize how tensors progress through dimensions:
import torch
# 0D Tensor (Scalar)
scalar = torch.tensor(5)
print(f"Scalar shape: {scalar.shape}") # Output: torch.Size([])
# 1D Tensor (Vector)
vector = torch.tensor([1, 2, 3, 4, 5])
print(f"Vector shape: {vector.shape}") # Output: torch.Size([5])
# 2D Tensor (Matrix)
matrix = torch.tensor([[1, 2, 3],
                       [4, 5, 6]])
print(f"Matrix shape: {matrix.shape}") # Output: torch.Size([2, 3])
# 3D Tensor
tensor_3d = torch.tensor([
    [[1, 2], [3, 4]],
    [[5, 6], [7, 8]]
])
print(f"3D Tensor shape: {tensor_3d.shape}") # Output: torch.Size([2, 2, 2])
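The title also mentions sizes. In PyTorch, `.shape` and `.size()` return the same `torch.Size` object, `.size(dim)` gives the length of a single axis, `.ndim` counts the axes, and `.numel()` gives the total number of elements (the product of the sizes). A quick sketch, with the tensors above redefined so the snippet runs on its own:

```python
import torch

scalar = torch.tensor(5)
matrix = torch.tensor([[1, 2, 3],
                       [4, 5, 6]])
tensor_3d = torch.tensor([[[1, 2], [3, 4]],
                          [[5, 6], [7, 8]]])

for name, t in [("scalar", scalar), ("matrix", matrix), ("tensor_3d", tensor_3d)]:
    # .shape and .size() are interchangeable; .ndim is the number of axes
    print(name, t.ndim, t.size(), t.numel())

print(matrix.size(1))   # 3 -> length of axis 1 (the columns)
print(scalar.numel())   # 1 -> a scalar still holds one element
```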
Understanding axes is perhaps one of the most crucial aspects of working with tensors in deep learning and the main reason I wrote this blog. An axis represents a direction along which data is organized.
Think of axes as the coordinate system of your tensor. Just as we use coordinates to specify a point in space, we use axes to locate elements within a tensor. Axes are numbered starting from 0, and positions along each axis are also counted from 0, a convention shared with Python and most other programming languages.
In mathematical notation, we can use these axes to denote an element of a 3D tensor $\mathcal{T}$. The element at position $(i, j, k)$ is written as:
$$\mathcal{T}_{i,j,k}$$
where:
- i represents the depth (axis 0)
- j represents the row (axis 1)
- k represents the column (axis 2)
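To see which part of the tensor each index selects, you can fix one index at a time. A small sketch, where `t` is a hypothetical example tensor of shape (2, 2, 3); `:` keeps an axis whole:

```python
import torch

# Hypothetical example: 2 matrices (depth), each with 2 rows and 3 columns
t = torch.tensor([[[1, 2, 3],
                   [4, 5, 6]],
                  [[7, 8, 9],
                   [10, 11, 12]]])

print(t[1, 0, 2])   # tensor(9): depth i=1, row j=0, column k=2
print(t[1])         # fixing axis 0 selects the second matrix, shape (2, 3)
print(t[:, 0])      # fixing axis 1 selects row 0 of every matrix, shape (2, 3)
print(t[:, :, 2])   # fixing axis 2 selects column 2 of every matrix, shape (2, 2)
```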
The true significance of axes becomes apparent when we perform operations on tensors. Let's explore this with concrete examples:
import torch
# Create a 2D tensor (matrix)
matrix = torch.tensor([[1, 2, 3],
                       [4, 5, 6],
                       [7, 8, 9]])
# 1. Dimensional Operations
# Note: PyTorch's native keyword is `dim`; `axis` is accepted as a NumPy-style alias
# Sum along axis 0 (down the rows) - collapses axis 0, leaving one total per column
col_sum = torch.sum(matrix, axis=0)
# Result: tensor([12, 15, 18])
# Explanation: 1+4+7=12, 2+5+8=15, 3+6+9=18
# Sum along axis 1 (across each row) - collapses axis 1, leaving one total per row
row_sum = torch.sum(matrix, axis=1)
# Result: tensor([6, 15, 24])
# Explanation: 1+2+3=6, 4+5+6=15, 7+8+9=24
# 2. Broadcasting Rules
# Create tensors of different shapes
a = torch.tensor([[1, 2, 3]])  # Shape: (1, 3)
b = torch.tensor([[1],
                  [2],
                  [3]])        # Shape: (3, 1)
# Broadcasting multiplication
c = a * b
# Result: tensor([[1, 2, 3],
#                 [2, 4, 6],
#                 [3, 6, 9]])
# Explanation: PyTorch virtually expands both tensors to shape (3, 3)
# (without copying data) before multiplying them element-wise
# 3. Structural Transformations
# Transpose the matrix - swap axes 0 and 1
transposed = matrix.T
# Result: tensor([[1, 4, 7],
#                 [2, 5, 8],
#                 [3, 6, 9]])
# Explanation: Rows become columns and vice versa
# Reshape the matrix - rearrange elements while preserving total size
reshaped = matrix.reshape(1, 9)
# Result: tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9]])
# Explanation: Same data, restructured into a different shape (1 row, 9 columns)
These operations demonstrate why understanding axes is crucial. Don't worry if you don't fully grasp these operations yet. The goal here is simply to help you understand the importance of axes.
Without a solid grasp of axes, it is hard to predict what an operation will do to a tensor's shape, and that makes the rest of your deep learning journey considerably harder.
One of the main reasons I wrote this blog was to explain 3D tensors. I think understanding axes in 1D and 2D tensors is relatively straightforward. For 2D tensors, one axis represents the rows and the other represents the columns, while for 1D tensors, there is only one axis. However, understanding 3D tensors can be a bit more challenging, so I wanted to dedicate a small section specifically to explaining them.
Understanding 3D Tensors
Let's now look at one of the most common tensor structures in deep learning: the 3D tensor. Think of a 3D tensor as a stack of matrices or, if you prefer, a cube of numbers. This mental model is particularly useful for data such as a batch of grayscale images (batch size, height, width) or a batch of sequences (batch size, sequence length, features).
Anatomy of a 3D Tensor
A 3D tensor can be visualized as layers of 2D matrices stacked on top of each other. Each axis serves a specific purpose:
import torch
# Creating a 3D tensor for visualization
tensor_3d = torch.tensor([
    [[1, 2, 3, 4],       # First matrix (depth=0)
     [5, 6, 7, 8],
     [9, 10, 11, 12]],
    [[13, 14, 15, 16],   # Second matrix (depth=1)
     [17, 18, 19, 20],
     [21, 22, 23, 24]]
])
print(f"Tensor shape: {tensor_3d.shape}") # Output: torch.Size([2, 3, 4])
In this example:
- Axis 0 (depth): Controls how many matrices we have (2 matrices)
- Axis 1 (rows): Determines the number of rows in each matrix (3 rows)
- Axis 2 (columns): Specifies the number of columns in each matrix (4 columns)
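The depth/rows/columns naming is just one way to read the axes. With sequential data, a common convention (assumed here, not something PyTorch enforces) is (batch, time steps, features). A sketch with random data and made-up sizes:

```python
import torch

# Hypothetical sizes: 2 sequences, each with 3 time steps of 4 features
batch = torch.randn(2, 3, 4)

first_sequence = batch[0]    # axis 0 picks one sequence -> shape (3, 4)
last_steps = batch[:, -1]    # last time step of every sequence -> shape (2, 4)
feature_0 = batch[:, :, 0]   # feature 0 at every step of every sequence -> shape (2, 3)

print(first_sequence.shape, last_steps.shape, feature_0.shape)
```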
Operations on 3D Tensors
The real power of 3D tensors emerges when we perform operations along different axes. Let's explore some common operations:
import torch
# Define the 3D tensor
tensor_3d = torch.tensor([
    [[1, 2, 3, 4],       # First matrix (depth=0)
     [5, 6, 7, 8],
     [9, 10, 11, 12]],
    [[13, 14, 15, 16],   # Second matrix (depth=1)
     [17, 18, 19, 20],
     [21, 22, 23, 24]]
])
# Sum along different axes
depth_sum = torch.sum(tensor_3d, axis=0)  # Axis 0 (depth): adds the two matrices element by element
row_sum = torch.sum(tensor_3d, axis=1)    # Axis 1 (rows): adds the rows of each matrix, one total per column
col_sum = torch.sum(tensor_3d, axis=2)    # Axis 2 (columns): adds the columns of each matrix, one total per row
# Print shapes to understand the effect of summing along each axis
print("Original shape:", tensor_3d.shape)   # torch.Size([2, 3, 4]) -> (depth=2, rows=3, columns=4)
print("After depth sum:", depth_sum.shape)  # torch.Size([3, 4]): the depth axis is removed
print("After row sum:", row_sum.shape)      # torch.Size([2, 4]): the row axis is removed
print("After column sum:", col_sum.shape)   # torch.Size([2, 3]): the column axis is removed
Let me now explain how we get these results:
depth_sum: Summing along axis 0 (depth) combines the two matrices by adding corresponding elements, which is exactly ordinary matrix addition:
$$
\begin{bmatrix} 1 & 2 & 3 & 4 \\ 5 & 6 & 7 & 8 \\ 9 & 10 & 11 & 12 \end{bmatrix}
+
\begin{bmatrix} 13 & 14 & 15 & 16 \\ 17 & 18 & 19 & 20 \\ 21 & 22 & 23 & 24 \end{bmatrix}
=
\begin{bmatrix} 14 & 16 & 18 & 20 \\ 22 & 24 & 26 & 28 \\ 30 & 32 & 34 & 36 \end{bmatrix}
$$
row_sum: Summing along axis 1 (rows) collapses the rows within each matrix, leaving one total per column:
For depth=0: [1+5+9, 2+6+10, 3+7+11, 4+8+12] = [15, 18, 21, 24]
For depth=1: [13+17+21, 14+18+22, 15+19+23, 16+20+24] = [51, 54, 57, 60]
The result will be [[15, 18, 21, 24], [51, 54, 57, 60]]
col_sum: Summing along axis 2 (columns) collapses the columns within each matrix, leaving one total per row:
For depth=0: [1+2+3+4, 5+6+7+8, 9+10+11+12] = [10, 26, 42]
For depth=1: [13+14+15+16, 17+18+19+20, 21+22+23+24] = [58, 74, 90]
The result will be [[10, 26, 42], [58, 74, 90]]
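You can confirm these hand calculations directly. The sketch below also shows `keepdim=True`, which keeps the summed axis with size 1 instead of removing it; that is handy when you want to broadcast the result back against the original tensor:

```python
import torch

tensor_3d = torch.arange(1, 25).reshape(2, 3, 4)  # the same values 1..24 as above

depth_sum = torch.sum(tensor_3d, dim=0)
row_sum = torch.sum(tensor_3d, dim=1)
col_sum = torch.sum(tensor_3d, dim=2)

# Summing over depth is ordinary addition of the two matrices
print(torch.equal(depth_sum, tensor_3d[0] + tensor_3d[1]))  # True
print(row_sum.tolist())  # [[15, 18, 21, 24], [51, 54, 57, 60]]
print(col_sum.tolist())  # [[10, 26, 42], [58, 74, 90]]

# keepdim=True keeps the collapsed axis with size 1
print(torch.sum(tensor_3d, dim=1, keepdim=True).shape)  # torch.Size([2, 1, 4])
```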
More Tensor Operations
Reshaping and Transposing
Reshaping and transposing are fundamental operations that let us restructure our data while preserving the underlying values. Let's understand them with an example:
import torch
# Reshaping example
tensor = torch.tensor([1, 2, 3, 4, 5, 6])
reshaped = tensor.reshape(2, 3)
print("Reshaped tensor:\n", reshaped)
# [[1, 2, 3],
#  [4, 5, 6]]
# Transposing example
transposed = reshaped.T
print("\nTransposed tensor:\n", transposed)
# [[1, 4],
#  [2, 5],
#  [3, 6]]
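A common source of confusion is that transposing and reshaping to the "transposed" shape give different results: `reshape` keeps the elements in their original row-major order, while `.T` moves them. The sketch below also uses `-1`, which asks `reshape` to infer one size from the total number of elements, and `permute`, which reorders the axes of tensors with more than two dimensions:

```python
import torch

tensor = torch.tensor([1, 2, 3, 4, 5, 6])
reshaped = tensor.reshape(2, -1)         # -1 is inferred as 3 because 6 / 2 = 3

print(reshaped.T.tolist())               # [[1, 4], [2, 5], [3, 6]] -> elements move
print(reshaped.reshape(3, 2).tolist())   # [[1, 2], [3, 4], [5, 6]] -> order is kept

# permute takes the new order of the axes: here axis 2 first, then 0, then 1
t3 = torch.zeros(2, 3, 4)
print(t3.permute(2, 0, 1).shape)         # torch.Size([4, 2, 3])
```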
Broadcasting
Broadcasting is a powerful mechanism that lets element-wise operations combine tensors of different shapes. Here, I will limit myself to a simple example, but this concept can get tricky at times. Perhaps I should write a separate post dedicated to explaining broadcasting in more detail.
Let's consider this example:
import torch
# Create tensors of different shapes
a = torch.tensor([[1, 2, 3]])  # Shape: (1, 3)
b = torch.tensor([[1],
                  [2],
                  [3]])        # Shape: (3, 1)
# Broadcasting multiplication
c = a * b
print("Result shape:", c.shape)
print("Result:\n", c)
Conceptually, tensors a and b are both expanded to shape (3, 3) before the multiplication (PyTorch does this without actually copying any data):
Expanded a:
tensor([[1, 2, 3],
        [1, 2, 3],
        [1, 2, 3]])
Expanded b:
tensor([[1, 1, 1],
        [2, 2, 2],
        [3, 3, 3]])
Final result (c = a * b):
tensor([[1, 2, 3],
        [2, 4, 6],
        [3, 6, 9]])
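The rule behind this, which PyTorch shares with NumPy, compares shapes starting from the last axis: two sizes are compatible if they are equal or if one of them is 1, and a missing leading axis is treated as size 1. A short sketch checking the example with `torch.broadcast_shapes` and `Tensor.expand`, plus a pair of shapes that cannot be broadcast:

```python
import torch

a = torch.tensor([[1, 2, 3]])        # shape (1, 3)
b = torch.tensor([[1], [2], [3]])    # shape (3, 1)

# Axis by axis: (1 vs 3) -> 3 and (3 vs 1) -> 3
print(torch.broadcast_shapes(a.shape, b.shape))  # torch.Size([3, 3])

# expand builds the "expanded" views shown above without copying data
print(torch.equal(a.expand(3, 3) * b.expand(3, 3), a * b))  # True

# (2, 3) and (3, 2): the last sizes are 3 vs 2 and neither is 1 -> incompatible
try:
    torch.ones(2, 3) + torch.ones(3, 2)
except RuntimeError as err:
    print("Cannot broadcast:", err)
```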
Exercises
There’s no point in learning without practicing, so here are some exercises to test your understanding.
Exercise 1: Shape Manipulation
# Exercise: Determine the shapes of the following tensors
tensor_1 = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
tensor_2 = torch.tensor([[[1], [2]], [[3], [4]]])
tensor_3 = torch.tensor([[[1, 2, 3]], [[4, 5, 6]]])
# Your task: Write down the shape of each tensor
Exercise 2: Tensor Operations
# Exercise: Perform the following operations
x = torch.tensor([[[1, 2], [3, 4]],
                  [[5, 6], [7, 8]]])
# 1. Sum along axis 0
# 2. Calculate mean along axis 1
# 3. Concatenate with itself along axis 2
Exercise 3: Advanced Transformations
# Exercise: Transform the following tensor
tensor = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
# 1. Reshape it into a 2x4 matrix
# 2. Transpose the result
# 3. Sum along appropriate axis to get a vector
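Once you have tried the exercises yourself, here is one possible set of solutions to compare against. Note one catch in Exercise 2: `torch.mean` does not accept integer tensors, so `x` has to be converted to a floating-point type first. For Exercise 3, summing along either axis yields a vector; the sketch picks axis 1.

```python
import torch

# Exercise 1: shapes
tensor_1 = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
tensor_2 = torch.tensor([[[1], [2]], [[3], [4]]])
tensor_3 = torch.tensor([[[1, 2, 3]], [[4, 5, 6]]])
print(tensor_1.shape, tensor_2.shape, tensor_3.shape)
# torch.Size([2, 2, 2]) torch.Size([2, 2, 1]) torch.Size([2, 1, 3])

# Exercise 2
x = torch.tensor([[[1, 2], [3, 4]],
                  [[5, 6], [7, 8]]])
sum_0 = torch.sum(x, dim=0)            # [[6, 8], [10, 12]]
mean_1 = torch.mean(x.float(), dim=1)  # an integer input would raise an error
cat_2 = torch.cat((x, x), dim=2)       # shape (2, 2, 4)

# Exercise 3
tensor = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
matrix = tensor.reshape(2, 4)          # [[1, 2, 3, 4], [5, 6, 7, 8]]
transposed = matrix.T                  # [[1, 5], [2, 6], [3, 7], [4, 8]]
vector = torch.sum(transposed, dim=1)  # [6, 8, 10, 12]; dim=0 would give [10, 26]
```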
Conclusion
You have reached the end of this short blog post on tensors. Hopefully, you leave this page with a better understanding of what a tensor is. We've covered the fundamental concepts of tensor dimensions, shapes, sizes and axes, along with practical operations and transformations, and the exercises give you a chance to put them into practice. Thank you for taking the time to read my work.
If you have any questions or issues, feel free to reach out to me on X/Twitter