/ src / Test.java
Test.java
 1  import java.io.IOException;
 2  import java.util.Arrays;
 3  
 4  /**
 5   * Created by MichaelBick on 7/29/15.
 6   */
 7  public class Test {	
 8      public static void main(String[] args) throws IOException {
 9      	Symbol[] stocks = {new Symbol("AAL"), new Symbol("T"), new Symbol("ADSK"), new Symbol("BAC"), new Symbol("BA"), new Symbol("KO"), new Symbol("EBAY"), new Symbol("XOM"), new Symbol("F"), new Symbol("GWW"), new Symbol("HAS"), new Symbol("ORCL"), new Symbol("FDX")};
10      	
11      	test(stocks, stocks);
12  
13          /*
14      	Symbol[] promiseTest = {new Symbol("F"), new Symbol("APC"), new Symbol("CA"), new Symbol("C"), new Symbol("D"), new Symbol("GAS")};
15          getPromisingStocks(promiseTest, DAYS_BACK + 1, FUTURE_DAYS, 0.3);
16          */
17      }
18  
19      private static void test(Symbol[] trainStocks, Symbol[] testStocks) throws IOException {
20          int FUTURE_DAYS = 10;
21          int TRAIN_DAYS = 500;
22          int TEST_DAYS = 50;
23          int DAYS_BACK = TEST_DAYS + FUTURE_DAYS + 1;
24          
25          GradientDescent gd = new GradientDescent(trainStocks, TRAIN_DAYS, DAYS_BACK, FUTURE_DAYS);
26  
27          // Get test data
28          double[][] test = gd.getData(testStocks, TEST_DAYS, TEST_DAYS + 1);
29          double[] testActual = GradientDescent.getActual(testStocks, TEST_DAYS, TEST_DAYS + 1, FUTURE_DAYS);
30          
31          // Normalize test data using training mean and standard deviation
32          test = gd.normalize(test);
33  
34  
35          // Train
36          double[] theta = gd.train(2.0, 1000000);
37  
38          
39          // Show predictions
40          Features.print(theta);
41          System.out.println("Cost (Try to minimize): " + gd.getCost(theta));
42  
43          for (int i = 0; i < testActual.length; i++) {
44              System.out.println("Actual: " + testActual[i]);
45  			System.out.println("Prediction: " + gd.getPredictions(theta, test)[i]);
46          }
47      }
48      
49      public static Symbol[] getPromisingStocks(Symbol[] symbols, int startDaysAgo, int futureDays, double diversity) throws IOException {   	
50      	GradientDescent gd = new GradientDescent(symbols, 30, startDaysAgo, futureDays);
51      	
52      	double[] weights = gd.train(1.0, 1000000);
53      
54      	
55      	//Get price predictions for symbols
56      	double[][] data = gd.getData(symbols, 1, startDaysAgo);
57      	data = gd.normalize(data);
58      	double[] predictedValues = gd.getPredictions(weights, data);
59      	
60  		double[] priceRatios = new double[predictedValues.length];
61  		for (int i = 0; i < symbols.length; i++){
62  			//We wouldn't necessarily be selling the stock at the exact closing price, but its our best estimation to selling price
63  			priceRatios[i] = predictedValues[i] / symbols[i].getAdjClose(startDaysAgo).doubleValue();
64  		}
65  		
66  		//Rank sell/buy ratios
67  		int[] rankedRatioIndices = new int[priceRatios.length];
68  		double[] tempRatios = new double[priceRatios.length];
69  		
70  		//Java was passing by value, so I had to manually copy values into the temporary array
71  		for (int i = 0; i < priceRatios.length; i++){
72  			tempRatios[i] = priceRatios[i];
73  		}
74  		
75  		for (int i = 0; i < tempRatios.length; i++){
76  			double max = 0.0;
77  			int indexOfMax = 0;
78  			for (int a = 0; a < tempRatios.length; a++){
79  				if (tempRatios[a] > max){
80  					max = tempRatios[a];
81  					indexOfMax = a;
82  				}
83  			}
84  			
85  			tempRatios[indexOfMax] = 0.0;
86  			rankedRatioIndices[i] = indexOfMax;
87  		}
88  		
89  		//Now rankedRatioIndices stores the indices of promising stocks
90  		//Only take certain percentage of most promising stocks using diversity value
91  		int numPromising = (int)(diversity * (double)symbols.length + 0.5);
92  		Symbol[] promisingStocks = new Symbol[numPromising];
93  		for (int i = 0; i < numPromising; i++){
94  			promisingStocks[i] = symbols[rankedRatioIndices[i]];
95  		}
96  		
97  		return promisingStocks;
98  	}
99  }