package iterators;
import java.io.*;
import main.BetterTokenizer;
import main.GeneralInput;
import parameters.ParameterSet;
import java.util.Vector;
import java.util.Enumeration;
/**
Use this to find the surface around a particular instance of the network. There are currently four modes:
<p><b>FindEdges</b><br>
For each parameter that it's told to vary, this iterator changes that parameter in both directions until
it reaches a score at or above thresholdScore or until it goes outside the bounds of that parameter.
I've made the steps a bit crudely: the initial step in each direction is the parameter's total range
multiplied by Fraction, and each subsequent step doubles the distance from the original position
(i.e. the next point is original_pos + 2 * initialStep, then original_pos + 4 * initialStep, ...). If it
reaches a score at or above thresholdScore, it backs up a bit and does a couple more runs just before that
point to get a better idea of what the transition looks like. If it goes in one direction for NumMoves
steps without reaching the threshold score, it stops going in that direction.
<p><b>Range</b><br>
This just scans from the lower to the upper bound of each parameter. You can include bounds when you input parameters.
<p><b>RelativeRange</b><br>
This scans a window extending Fraction times the current value on either side of the current value.
<p><b>FractionRange</b><br>
This scans a window that is Fraction of the parameter's whole range, centered on the current position.
<p>In the three range modes, NumMoves sets the number of evenly spaced points per scan and the window is
clipped to the parameter's bounds. With NumDimensions 2 (not available for FindEdges), pairs of
parameters are scanned over a grid.
*/
public class SurfaceIterator extends ModelIterator {
int numMoves = 10;
float thresholdScore = 0.8f;
float startFraction = 0.01f; // The fraction of the parameter's range to move at the beginning of finding the surface around that parameter
// Later steps will multiply this starting move to make successive steps larger and larger
static final int FIND_EDGES = 1, SCAN_RANGE = 2, SCAN_RELATIVE_RANGE = 3, SCAN_FRACTION_RANGE = 4;
int mode = SCAN_RANGE;
int dimensions = 1;
Vector output = new Vector(10);
public SurfaceIterator() {
setPrint(true);
}
public ModelIterator copy() throws Exception {
SurfaceIterator newIter = new SurfaceIterator();
newIter.init(this.network, this.model);
newIter.nParsTV = this.nParsTV;
newIter.parsTV = this.parsTV.copy();
newIter.theFunction = this.theFunction.copy();
newIter.numMoves = this.numMoves;
newIter.thresholdScore = this.thresholdScore;
newIter.dimensions = this.dimensions;
newIter.mode = this.mode;
return newIter;
}
// Put any parameters specific to this iterator class in here, as if clauses.
protected void loadParameter(String info, BetterTokenizer tokenizer) throws Exception {
if(info.equals("NumMoves")) {GeneralInput.nextToken(tokenizer); numMoves = (int)tokenizer.nval; }
else if(info.equals("Threshold")) {GeneralInput.nextToken(tokenizer); thresholdScore = (float)tokenizer.nval; }
else if(info.equals("Fraction")) {GeneralInput.nextToken(tokenizer); startFraction = (float)tokenizer.nval; }
else if(info.equals("Mode")) {
GeneralInput.nextToken(tokenizer);
if(tokenizer.sval.equals("FindEdges")) mode = FIND_EDGES;
else if(tokenizer.sval.equals("Range")) mode = SCAN_RANGE;
else if(tokenizer.sval.equals("RelativeRange")) mode = SCAN_RELATIVE_RANGE;
else if(tokenizer.sval.equals("FractionRange")) mode = SCAN_FRACTION_RANGE;
}
else if(info.equals("NumDimensions")) { GeneralInput.nextToken(tokenizer); dimensions = (int)tokenizer.nval; }
else super.loadParameter(info, tokenizer);
}
//ELI 3/23 Changed from run to doRun
public void doRun() {
if(dimensions == 2 && mode == FIND_EDGES) {
System.out.println("Find Edges not implemented for 2 dimensions. Sorry. ");
super.stopRun();
return;
}
if(dimensions > 2) {
System.out.println("Dimensions > 2 not yet implemented. Sorry.");
super.stopRun();
return;
}
reset(); // reload p, nParsTV, delta because these may have changed
float cur_score = F(p);
int par_num1 = 0, par_num2 = 1;
while(par_num1 < nParsTV) {
output.removeAllElements();
if(mode == FIND_EDGES) {
System.out.println("Scanning " + parsTV.getName(par_num1));
output.addElement(new ScoreStorage(p[par_num1], cur_score));
scanSurfaceEdges(par_num1, 1, output);
scanSurfaceEdges(par_num1, -1, output);
}
else if(mode == SCAN_RANGE && dimensions == 1) {
System.out.println("Scanning " + parsTV.getName(par_num1));
scanSurfaceRange(par_num1, parsTV.getLowerBound(par_num1), parsTV.getUpperBound(par_num1), numMoves, -1, output);
}
else if(mode == SCAN_RANGE && dimensions == 2) {
System.out.println("Scanning " + parsTV.getName(par_num1) + "," + parsTV.getName(par_num2));
scanSurface2Range(par_num1, par_num2, parsTV.getLowerBound(par_num1), parsTV.getUpperBound(par_num1),
parsTV.getLowerBound(par_num2), parsTV.getUpperBound(par_num2), numMoves, output);
}
else if(mode == SCAN_RELATIVE_RANGE && dimensions == 1) {
System.out.println("Scanning " + parsTV.getName(par_num1));
float low = p[par_num1] - startFraction * p[par_num1];
if(low < parsTV.getLowerBound(par_num1)) low = parsTV.getLowerBound(par_num1);
float high = p[par_num1] + startFraction * p[par_num1];
if(high > parsTV.getUpperBound(par_num1)) high = parsTV.getUpperBound(par_num1);
scanSurfaceRange(par_num1, low, high, numMoves, -1, output);
}
else if(mode == SCAN_RELATIVE_RANGE && dimensions == 2) {
System.out.println("Scanning " + parsTV.getName(par_num1) + "," + parsTV.getName(par_num2));
float low1 = p[par_num1] - startFraction * p[par_num1];
if(low1 < parsTV.getLowerBound(par_num1)) low1 = parsTV.getLowerBound(par_num1);
float high1 = p[par_num1] + startFraction * p[par_num1];
if(high1 > parsTV.getUpperBound(par_num1)) high1 = parsTV.getUpperBound(par_num1);
float low2 = p[par_num2] - startFraction * p[par_num2];
if(low2 < parsTV.getLowerBound(par_num2)) low2 = parsTV.getLowerBound(par_num2);
float high2 = p[par_num2] + startFraction * p[par_num2];
if(high2 > parsTV.getUpperBound(par_num2)) high2 = parsTV.getUpperBound(par_num2);
scanSurface2Range(par_num1, par_num2, low1, high1, low2, high2, numMoves, output);
}
else if(mode == SCAN_FRACTION_RANGE && dimensions == 1) {
System.out.println("Scanning " + parsTV.getName(par_num1));
float range = (parsTV.getUpperBound(par_num1) - parsTV.getLowerBound(par_num1)) * startFraction;
float low = p[par_num1] - range / 2;
if(low < parsTV.getLowerBound(par_num1)) low = parsTV.getLowerBound(par_num1);
float high = p[par_num1] + range / 2;
if(high > parsTV.getUpperBound(par_num1)) high = parsTV.getUpperBound(par_num1);
scanSurfaceRange(par_num1, low, high, numMoves, -1, output);
}
else if(mode == SCAN_FRACTION_RANGE && dimensions == 2) {
System.out.println("Scanning " + parsTV.getName(par_num1) + "," + parsTV.getName(par_num2));
float range = (parsTV.getUpperBound(par_num1) - parsTV.getLowerBound(par_num1)) * startFraction;
float low1 = p[par_num1] - range / 2;
if(low1 < parsTV.getLowerBound(par_num1)) low1 = parsTV.getLowerBound(par_num1);
float high1 = p[par_num1] + range / 2;
if(high1 > parsTV.getUpperBound(par_num1)) high1 = parsTV.getUpperBound(par_num1);
range = (parsTV.getUpperBound(par_num2) - parsTV.getLowerBound(par_num2)) * startFraction;
float low2 = p[par_num2] - range / 2;
if(low2 < parsTV.getLowerBound(par_num2)) low2 = parsTV.getLowerBound(par_num2);
float high2 = p[par_num2] + range / 2;
if(high2 > parsTV.getUpperBound(par_num2)) high2 = parsTV.getUpperBound(par_num2);
scanSurface2Range(par_num1, par_num2, low1, high1, low2, high2, numMoves, output);
}
// Save the output to the file
if(dimensions == 1) ps.println("\r\r" + parsTV.getName(par_num1));
else ps.println("\r\r" + parsTV.getName(par_num1) + "\t" + parsTV.getName(par_num2));
Enumeration enum = output.elements();
while(enum.hasMoreElements()) {
ScoreStorage elem = (ScoreStorage)enum.nextElement();
ps.println(elem.toString());
}
par_num1++;
if(dimensions == 2 && par_num1 == par_num2) par_num1++;
if(par_num1 >= nParsTV && dimensions == 2 && par_num2 < nParsTV - 1) {
par_num1 = 0;
par_num2++;
}
}
finalScore = cur_score;
super.stopRun();
}
private void scanSurface2Range(int par_num1, int par_num2, float low1, float high1, float low2, float high2,
int num_moves, Vector output) {
float incr = (high2 - low2) / (num_moves - 1);
float orig_value = p[par_num2];
for(p[par_num2] = low2; p[par_num2] <= high2; p[par_num2] += incr) {
scanSurfaceRange(par_num1, low1, high1, num_moves, p[par_num2], output);
}
p[par_num2] = orig_value;
}
private void scanSurfaceRange(int par_num, float low, float high, int num_moves, float second_value, Vector output) {
float incr = (high - low) / (num_moves - 1);
float orig_value = p[par_num];
for(p[par_num] = low; p[par_num] <= high; p[par_num] += incr) {
output.addElement(new ScoreStorage(p[par_num], second_value, F(p)));
}
p[par_num] = orig_value;
}
private void scanSurfaceEdges(int par_num, int dir, Vector output) {
float score;
float orig_value = p[par_num];
float movelen = startFraction * (parsTV.getUpperBound(par_num) - parsTV.getLowerBound(par_num)) * dir;
int num_moves = 0;
score = 0;
while(score < thresholdScore && num_moves < numMoves && inBounds(orig_value + movelen, par_num)) {
p[par_num] = orig_value + movelen;
score = F(p);
output.addElement(new ScoreStorage(p[par_num], score));
movelen *= 2;
num_moves++;
}
// If the score went above the threshold, go backwards a little to get more
// detail on how it left the region
movelen /= 2;
if(inBounds(orig_value + movelen, par_num)) {
movelen /= 4;
p[par_num] -= movelen;
score = F(p);
output.addElement(new ScoreStorage(p[par_num], score));
movelen /= 2;
p[par_num] -= movelen;
// System.out.print(parsTV.getName(par_num) + " = " + p[par_num] + "\t"); // Print this here because F prints out score - so might as well see everything
score = F(p);
output.addElement(new ScoreStorage(p[par_num], score));
}
p[par_num] = orig_value;
}
private boolean inBounds(float val, int par_num) {
if(val >= parsTV.getLowerBound(par_num) && val <= parsTV.getUpperBound(par_num))
return true;
else return false;
}
class ScoreStorage {
float score, parValue1, parValue2;
ScoreStorage(float par_value, float score) {
this.score = score; this.parValue1 = par_value; this.parValue2 = -1;
}
ScoreStorage(float par_value1, float par_value2, float score) {
this.score = score; this.parValue1 = par_value1; this.parValue2 = par_value2;
}
public String toString() {
if(parValue2 < 0) return new String(parValue1 + "\t" + score);
else return new String(parValue1 + "\t" + parValue2 + "\t" + score);
}
}
}