This project uses machine learning to predict stock prices based on historical data. It employs a Long Short-Term Memory (LSTM) neural network model to forecast future stock prices for various companies.
- Data collection from Yahoo Finance using the yfinance library
- LSTM model for time series prediction
- Training on data from multiple stocks to improve generalization
- Prediction visualization with matplotlib
- GPU support for faster training (if available)
- Python 3.7+
- TensorFlow 2.x
- yfinance
- pandas
- numpy
- matplotlib
- scikit-learn
You can install the required packages using:
pip install -r requirements.txt
Pre-trained weights for the current version are already saved in the repository's weights folder, so training is optional.
To train the model on historical stock data:
python scripts/train.py
This script will:
- Download historical stock data for predefined tickers
- Prepare the data for training
- Build and train the LSTM model
- Save the trained model as model.h5 and the scaler as scaler.npy
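The "Prepare the data for training" step usually turns a price series into overlapping input windows and targets. Below is a minimal sketch, assuming the model reads the previous time_step prices and predicts the next future_days prices. The helper name make_sequences is hypothetical and is not taken from train.py.

```python
import numpy as np


def make_sequences(series, time_step, future_days):
    """Split a 1-D series into (input window, future targets) pairs.

    Hypothetical helper: each X row holds `time_step` consecutive values,
    and the matching y row holds the `future_days` values that follow it.
    """
    series = np.asarray(series, dtype=float)
    n_samples = len(series) - time_step - future_days + 1
    if n_samples <= 0:
        raise ValueError("series is too short for the requested window sizes")
    X = np.stack([series[i:i + time_step] for i in range(n_samples)])
    y = np.stack([series[i + time_step:i + time_step + future_days]
                  for i in range(n_samples)])
    # Keras LSTM layers expect inputs shaped (samples, timesteps, features).
    return X.reshape(n_samples, time_step, 1), y


if __name__ == "__main__":
    prices = np.arange(10.0)  # stand-in for scaled closing prices
    X, y = make_sequences(prices, time_step=3, future_days=2)
    print(X.shape, y.shape)     # (6, 3, 1) (6, 2)
    print(X[0].ravel(), y[0])   # [0. 1. 2.] [3. 4.]
```

Each sample slides forward by one day, so a series of length n yields n - time_step - future_days + 1 training pairs.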
To make predictions using the trained model:
python scripts/inference.py --stock <stock symbol> --output <output filename>
By default, this predicts stock prices for NVIDIA (NVDA) and saves the plot as predictions.png. To use a different ticker or output filename, pass the --stock and --output arguments as shown above.
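The command-line options described above can be sketched with Python's argparse module. This shows one way the --stock and --output flags and their documented defaults (NVDA, predictions.png) might be wired up. It is not the actual contents of scripts/inference.py, and the function name parse_args is hypothetical.

```python
import argparse


def parse_args(argv=None):
    """Parse inference options (hypothetical sketch of inference.py's CLI)."""
    parser = argparse.ArgumentParser(
        description="Predict stock prices with the trained LSTM model.")
    parser.add_argument("--stock", default="NVDA",
                        help="ticker symbol to predict (default: NVDA)")
    parser.add_argument("--output", default="predictions.png",
                        help="filename for the saved plot (default: predictions.png)")
    # argv=None makes argparse read sys.argv[1:]; passing a list is handy for testing.
    return parser.parse_args(argv)


if __name__ == "__main__":
    args = parse_args(["--stock", "AAPL"])
    print(args.stock, args.output)  # AAPL predictions.png
```

With this layout, `python scripts/inference.py --stock AAPL --output aapl.png` would override both defaults.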
- Data Collection: Historical stock data is fetched using the yfinance library.
- Data Preprocessing: The data is scaled using scikit-learn's MinMaxScaler to normalize the values.
- Model Architecture: An LSTM-based neural network is used for sequence prediction.
- Training: The model is trained on data from multiple stocks to capture general market trends.
- Prediction: The trained model predicts future stock prices based on recent data.
- Visualization: Predictions are plotted against actual prices, including a simple trading recommendation.
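With its default feature_range of (0, 1), MinMaxScaler maps each value x to (x - min) / (max - min), using the minimum and maximum seen during fitting. Its inverse_transform undoes that mapping so predictions can be plotted in price units. The NumPy sketch below reproduces this arithmetic for a single price column to show what the preprocessing step and the conversion back to prices do. It assumes the project uses the default range, and the class SimpleMinMaxScaler is a hypothetical stand-in, not the scikit-learn class.

```python
import numpy as np


class SimpleMinMaxScaler:
    """Hypothetical stand-in for sklearn's MinMaxScaler with feature_range=(0, 1)."""

    def fit(self, values):
        values = np.asarray(values, dtype=float)
        self.data_min_ = values.min()
        self.data_max_ = values.max()
        return self

    def _span(self):
        span = self.data_max_ - self.data_min_
        # A constant series has zero range; use 1 to avoid dividing by zero.
        return span if span != 0 else 1.0

    def transform(self, values):
        # Map prices into [0, 1] relative to the fitted min and max.
        return (np.asarray(values, dtype=float) - self.data_min_) / self._span()

    def inverse_transform(self, scaled):
        # Map scaled model outputs back to price units.
        return np.asarray(scaled, dtype=float) * self._span() + self.data_min_


if __name__ == "__main__":
    prices = np.array([100.0, 150.0, 200.0])
    scaler = SimpleMinMaxScaler().fit(prices)
    scaled = scaler.transform(prices)
    print(scaled.tolist())                             # [0.0, 0.5, 1.0]
    print(scaler.inverse_transform(scaled).tolist())   # [100.0, 150.0, 200.0]
```

Values outside the fitted range are not clipped, so a price above the historical maximum maps to a number greater than 1.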
- To train on different stocks, update the list of tickers in train.py.
- Adjust the time_step and future_days parameters in both scripts to change the input sequence length and the prediction horizon.
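A model trained on windows of time_step values expects inputs of that same length at inference time. This is one reason the parameter should be changed in both scripts together. The sketch below builds that inference input from the most recent scaled prices. It assumes a single-feature model that takes input of shape (1, time_step, 1), and the function name latest_window is hypothetical.

```python
import numpy as np


def latest_window(scaled_prices, time_step):
    """Return the last `time_step` values shaped (1, time_step, 1) for model.predict.

    Hypothetical helper: raises if there is not enough history to fill a
    window of the length the model was trained on.
    """
    scaled_prices = np.asarray(scaled_prices, dtype=float)
    if len(scaled_prices) < time_step:
        raise ValueError(f"need at least {time_step} values, got {len(scaled_prices)}")
    # One sample, time_step timesteps, one feature (the price).
    return scaled_prices[-time_step:].reshape(1, time_step, 1)


if __name__ == "__main__":
    window = latest_window(np.linspace(0.0, 1.0, 100), time_step=60)
    print(window.shape)  # (1, 60, 1)
```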
This project is released under the GPL-3.0. For more information, see the LICENSE file.
The model may produce inaccurate predictions. Don't rely on them too much until it is improved.