vishnusaireddy502/LinearRegression

PyTorch Linear Regression Project

This project demonstrates a simple linear regression model implemented using PyTorch. It covers the entire machine learning workflow, including data generation, model creation, training, evaluation, and saving/loading the model.
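The saving/loading step is not detailed further in this README. Below is a minimal sketch of the usual PyTorch approach, which saves only the model's `state_dict` and loads it into a fresh instance. The parameter names `weights` and `bias` and the file name `linear_regression_model.pth` are hypothetical and may differ from the project's code.

```python
import torch
from torch import nn

class LinearRegressionModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Hypothetical parameter names; randomly initialised
        self.weights = nn.Parameter(torch.randn(1))
        self.bias = nn.Parameter(torch.randn(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.weights * x + self.bias

model = LinearRegressionModel()

# Save only the learned parameters, not the whole pickled module
MODEL_PATH = "linear_regression_model.pth"  # hypothetical file name
torch.save(model.state_dict(), MODEL_PATH)

# Load the parameters into a new instance of the same class
loaded_model = LinearRegressionModel()
loaded_model.load_state_dict(torch.load(MODEL_PATH))
loaded_model.eval()  # switch to evaluation mode before making predictions
```

Saving the `state_dict` instead of the full module keeps the file independent of the exact class location, at the cost of needing the class definition available when loading.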

Data Generation

The project creates synthetic data using the equation y = wx + b, where:

  • w (weight) = 0.7

  • b (bias) = 0.3

  • It produces 50 data points, of which 40 (80%) are used for training and 10 (20%) for testing
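A minimal sketch of this data generation step. The README does not state the input range; here it is assumed that x runs from 0 to 0.98 in steps of 0.02, which yields exactly 50 points. The split takes the first 40 points for training and the last 10 for testing.

```python
import torch

# Ground-truth parameters from the README
weight = 0.7
bias = 0.3

# Assumption: inputs are evenly spaced on [0, 1) with step 0.02 (50 points)
X = torch.arange(0, 1, 0.02).unsqueeze(dim=1)  # shape (50, 1)
y = weight * X + bias                          # shape (50, 1)

# First 80% (40 points) for training, remaining 20% (10 points) for testing
train_split = int(0.8 * len(X))
X_train, y_train = X[:train_split], y[:train_split]
X_test, y_test = X[train_split:], y[train_split:]

print(len(X_train), len(X_test))  # 40 10
```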

Model Architecture

A custom LinearRegressionModel class is defined, inheriting from nn.Module. It includes:

  • Learnable parameters: weight (w) and bias (b)
  • A forward method that computes y = wx + b for the input x
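A sketch of what such a class could look like. The attribute names `weights` and `bias` and the random initialisation with `torch.randn` are assumptions, not taken from the project's code.

```python
import torch
from torch import nn

class LinearRegressionModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Learnable parameters, registered with nn.Parameter so the optimizer sees them
        self.weights = nn.Parameter(torch.randn(1, dtype=torch.float))
        self.bias = nn.Parameter(torch.randn(1, dtype=torch.float))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Linear regression: y = w * x + b (broadcasts over a batch of inputs)
        return self.weights * x + self.bias

torch.manual_seed(42)  # fixed seed so the random initial values are reproducible
model = LinearRegressionModel()
print(list(model.parameters()))
```

Writing the parameters out by hand makes the equation explicit; `nn.Linear(1, 1)` would give the same model with PyTorch's built-in layer.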

Training

The model is trained using:

  • Loss Function: L1Loss (Mean Absolute Error)
  • Optimizer: Stochastic Gradient Descent (SGD)
  • Number of epochs: 200
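A minimal training loop with these settings. L1Loss averages |ŷ − y| over the batch, and SGD moves each parameter against its gradient. The learning rate of 0.01 and the input range are assumptions, since the README does not state them. Training and test losses are recorded each epoch so they can be plotted as loss curves.

```python
import torch
from torch import nn

# Synthetic data (assumption: x = 0.00, 0.02, ..., 0.98), 40 train / 10 test
X = torch.arange(0, 1, 0.02).unsqueeze(dim=1)
y = 0.7 * X + 0.3
X_train, y_train = X[:40], y[:40]
X_test, y_test = X[40:], y[40:]

class LinearRegressionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.weights = nn.Parameter(torch.randn(1))
        self.bias = nn.Parameter(torch.randn(1))

    def forward(self, x):
        return self.weights * x + self.bias

torch.manual_seed(42)
model = LinearRegressionModel()
loss_fn = nn.L1Loss()                                     # mean absolute error
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # assumption: lr = 0.01

epochs = 200
train_losses, test_losses = [], []
for epoch in range(epochs):
    # Training step: forward pass, loss, backpropagation, parameter update
    model.train()
    y_pred = model(X_train)
    loss = loss_fn(y_pred, y_train)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Evaluation step on the held-out test set (no gradient tracking)
    model.eval()
    with torch.inference_mode():
        test_loss = loss_fn(model(X_test), y_test)

    train_losses.append(loss.item())
    test_losses.append(test_loss.item())

print(model.state_dict())
```

After training, the learned `weights` and `bias` should move toward the true values 0.7 and 0.3.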

Visualization

The project includes functions to visualize:

  • Training and testing data
  • Model predictions
  • Loss curves during training
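A sketch of two plotting helpers using matplotlib. The function names `plot_predictions` and `plot_loss_curves` and their signatures are hypothetical, and the input range of the data is assumed as before.

```python
import matplotlib.pyplot as plt
import torch

def plot_predictions(train_data, train_labels, test_data, test_labels, predictions=None):
    """Scatter-plot training data, test data and, optionally, model predictions."""
    fig, ax = plt.subplots(figsize=(10, 7))
    ax.scatter(train_data, train_labels, c="b", s=4, label="Training data")
    ax.scatter(test_data, test_labels, c="g", s=4, label="Testing data")
    if predictions is not None:
        ax.scatter(test_data, predictions, c="r", s=4, label="Predictions")
    ax.legend(prop={"size": 14})
    return fig

def plot_loss_curves(epoch_count, train_losses, test_losses):
    """Line-plot training and test loss against the epoch number."""
    fig, ax = plt.subplots()
    ax.plot(epoch_count, train_losses, label="Train loss")
    ax.plot(epoch_count, test_losses, label="Test loss")
    ax.set_xlabel("Epochs")
    ax.set_ylabel("Loss")
    ax.set_title("Training and test loss curves")
    ax.legend()
    return fig

# Usage with the synthetic data (assumption: x = 0.00, 0.02, ..., 0.98)
X = torch.arange(0, 1, 0.02).unsqueeze(dim=1)
y = 0.7 * X + 0.3
fig = plot_predictions(X[:40], y[:40], X[40:], y[40:])
plt.show()
```

To overlay model predictions, pass outputs computed under `torch.inference_mode()` (or call `.detach()`) so matplotlib can convert them to arrays.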

Results

Loss curves: [image]

Predictions before training: [image]

Predictions after training: [image]