How to use SGDRegressor in scikit-learn


Solution 1

In machine learning, y represents the label or target of your data, that is, the correct answers for your training data (X).

If you want to learn values that correspond to years, then the years are your training data (X) and the values associated with them are your targets (y).

Notice that this matches the shapes you mentioned in your first paragraph. X will have shape (n_samples, n_features) because it has one entry per year, and each entry has size 1 (your only feature is the year). y will have length n_samples because you have one value for each year.
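A minimal NumPy sketch of that reshaping, assuming the table in the question has one row per series and one column per year starting in 1972. The table contents here are made up, and the names `years`, `data`, and `row` are hypothetical placeholders. Picking one row, the years become X (one sample per year, one feature) and that row's values become y:

```python
import numpy as np

# Hypothetical stand-in for the table in the question: 5 rows (one series each)
# and one column per year starting in 1972. The numbers are made up.
years = np.arange(1972, 1980)                      # 1972..1979, 8 years
data = np.arange(5 * years.size, dtype=float).reshape(5, years.size)

# Turn one row into a regression problem: each year becomes a sample,
# and the year itself is the single feature.
row = data[0]
X = years.reshape(-1, 1)  # shape (n_samples, n_features) == (8, 1)
y = row                   # shape (n_samples,) == (8,)

print(X.shape, y.shape)   # (8, 1) (8,)
```

Here n_samples is the number of years, not the number of rows in the table. `reshape(-1, 1)` is needed because scikit-learn expects a 2-D X even when there is only one feature.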

Solution 2

y is your target (what you want to predict). Once the model is trained, you can get predictions for new inputs this way:

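from sklearn import linear_model

# X_to_train: array of shape (n_samples, n_features)
# y_to_train: array of shape (n_samples,), the target for each sample
clf = linear_model.SGDRegressor()
clf.fit(X_to_train, y_to_train)

# clf is now a trained model; X_to_predict must have the same number of features
y_predicted = clf.predict(X_to_predict)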
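A self-contained sketch that fills in those placeholders for the year/value case, predicting 2008 and 2012. The training values are a made-up noiseless linear trend, not the asker's data. Two choices here go beyond the snippet above and are assumptions of this sketch: the year is standardized with `StandardScaler` (SGD is sensitive to feature scale, and raw years around 2000 make it hard to converge), and `tol=None` with a fixed `max_iter` makes the model run every epoch instead of stopping early.

```python
import numpy as np
from sklearn.linear_model import SGDRegressor
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Hypothetical training data: years 1972..1981 and a made-up linear trend.
years = np.arange(1972, 1982)
values = 2.0 * (years - 1972) + 5.0

X_to_train = years.reshape(-1, 1)  # (n_samples, 1): the year is the only feature
y_to_train = values                # (n_samples,): the value for each year

# Standardize the year before SGD; tol=None runs all max_iter epochs.
clf = make_pipeline(
    StandardScaler(),
    SGDRegressor(max_iter=5000, tol=None, random_state=0),
)
clf.fit(X_to_train, y_to_train)

# Future years to predict, also shaped (n_samples, 1).
X_to_predict = np.array([[2008], [2012]])
y_predicted = clf.predict(X_to_predict)
print(y_predicted)  # the underlying trend gives 77 and 85 for these years
```

Because the pipeline applies the same scaling at predict time, new years can be passed in their original units.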
Author: Jordan Bramble

Updated on June 04, 2022

Comments

  • Jordan Bramble

    I am trying to figure out how to properly use scikit-learn's SGDRegressor model. To fit it to a dataset, I need to call fit(X, y), where X is a NumPy array of shape (n_samples, n_features) and y is a 1-D NumPy array of length n_samples. I am trying to figure out what y is supposed to represent.

    For instance, my data looks like this:

    [Image: data table from the original question, not available]

    My features are years starting in 1972, and the values are the corresponding value for each year. I am trying to predict the values for future years, such as 2008 or 2012. I am assuming that each row in my data should be a row/sample in X, where each element is the value for one year. In that case, what would y be? I was thinking that y should just be the years, but then y would have length n_features instead of n_samples. If y must have length n_samples, what could y possibly be that has length 5 (the number of samples in the data shown above)? I am thinking I must transform this data somehow.