44#include "EST_cutils.h"
53 return impurity.value();
54 else if (question.ask(d))
55 return left->predict(d);
57 return right->predict(d);
64 else if (question.ask(d))
65 return left->predict_node(d);
67 return right->predict_node(d);
74 if ((left == 0) && (right == 0))
76 else if (get_impurity().type() != wnim_class)
82void WNode::prune(
void)
90 if (left != 0) left->prune();
91 if (right != 0) right->prune();
95 if ((left->pure() == TRUE) && ((right->pure() == TRUE)) &&
96 (left->get_impurity().value() == right->get_impurity().value()))
98 delete left; left = 0;
99 delete right; right = 0;
105void WNode::held_out_prune()
116 wgn_score_question(question,get_data());
117 if (question.get_score() < get_impurity().measure())
119 wgn_find_split(question,get_data(),
122 left->held_out_prune();
123 right->held_out_prune();
127 delete left; left = 0;
128 delete right; right = 0;
138 for (i=0;i<
margin;i++) s <<
" ";
145 left->print_out(s,
margin+1);
146 right->print_out(s,
margin+1);
160void WDataSet::ignore_non_numbers()
165 for (i=0; i<dlength; i++)
167 if ((p_type[i] == wndt_binary) ||
168 (p_type[i] == wndt_float))
186 description = car(vload(
fname,1));
187 dlength = siod_llength(description);
193 if (wgn_predictee_name ==
"")
198 for (i=0,d=description; d != NIL; d=cdr(d),i++)
200 p_name[i] = get_c_string(car(car(d)));
201 tname = get_c_string(car(cdr(car(d))));
203 if ((wgn_predictee_name !=
"") && (wgn_predictee_name == p_name[i]))
205 if ((wgn_count_field_name !=
"") &&
206 (wgn_count_field_name == p_name[i]))
208 if ((
tname ==
"count") || (i == wgn_count_field))
211 p_type[i] = wndt_ignore;
215 else if ((
tname ==
"ignore") || (siod_member_str(p_name[i],
ignores)))
217 p_type[i] = wndt_ignore;
219 if (i == wgn_predictee)
220 wagon_error(
EST_String(
"predictee \"")+p_name[i]+
221 "\" can't be ignored \n");
223 else if (siod_llength(car(d)) > 2)
227 siod_list_to_strlist(
rest,
sl);
228 p_type[i] = wgn_discretes.def(
sl);
229 if (streq(get_c_string(car(
rest)),
"_other_"))
230 wgn_discretes[p_type[i]].def_val(
"_other_");
232 else if (
tname ==
"binary")
233 p_type[i] = wndt_binary;
234 else if (
tname ==
"cluster")
235 p_type[i] = wndt_cluster;
236 else if (
tname ==
"vector")
237 p_type[i] = wndt_vector;
238 else if (
tname ==
"trajectory")
239 p_type[i] = wndt_trajectory;
240 else if (
tname ==
"ols")
241 p_type[i] = wndt_ols;
242 else if (
tname ==
"matrix")
243 p_type[i] = wndt_matrix;
244 else if (
tname ==
"float")
245 p_type[i] = wndt_float;
249 "\" for field number "+itoString(i)+
250 "/"+p_name[i]+
" in description file \""+
fname+
"\"");
254 if (wgn_predictee == -1)
256 wagon_error(
EST_String(
"predictee field \"")+wgn_predictee_name+
257 "\" not found in description ");
261const int WQuestion::ask(
const WVector &w)
const
267 if (w.get_flt_val(feature_pos) == operand1.
Float())
272 if (w.get_int_val(feature_pos) == 1)
276 case wnop_greaterthan:
277 if (w.get_flt_val(feature_pos) > operand1.
Float())
282 if (w.get_flt_val(feature_pos) < operand1.
Float())
287 if (w.get_int_val(feature_pos) == operand1.
Int())
292 if (ilist_member(operandl,w.get_int_val(feature_pos)))
297 wagon_error(
"Unknown test operator");
308 s <<
"(" << wgn_dataset.feat_name(
q.get_fp());
312 s <<
" = " <<
q.get_operand1().string();
316 case wnop_greaterthan:
317 s <<
" > " <<
q.get_operand1().Float();
320 s <<
" < " <<
q.get_operand1().Float();
323 name = wgn_discretes[wgn_dataset.ftype(
q.get_fp())].
324 name(
q.get_operand1().Int());
327 s << quote_string(name,
"\"",
"\\",1);
332 name = wgn_discretes[wgn_dataset.ftype(
q.get_fp())].
333 name(
q.get_operand1().Int());
334 s <<
" matches " << quote_string(name,
"\"",
"\\",1);
338 for (
int l=0; l <
q.get_operandl().length(); l++)
340 name = wgn_discretes[wgn_dataset.ftype(
q.get_fp())].
341 name(
q.get_operandl().nth(l));
343 s << quote_string(name,
"\"",
"\\",1);
367 cerr <<
"WImpurity: no value currently set\n";
370 else if (t==wnim_class)
372 else if (t==wnim_cluster)
374 else if (t==wnim_ols)
376 else if (t==wnim_vector)
378 else if (t==wnim_trajectory)
384double WImpurity::samples(
void)
388 else if (t==wnim_class)
390 else if (t==wnim_cluster)
391 return members.length();
392 else if (t==wnim_ols)
393 return members.length();
394 else if (t==wnim_vector)
395 return members.length();
396 else if (t==wnim_trajectory)
397 return members.length();
407 a.
reset(); trajectory=0; l=0; width=0;
409 for (i=0; i <
ds.n(); i++)
413 else if (wgn_count_field == -1)
414 cumulate((*(
ds(i)))[wgn_predictee],1);
416 cumulate((*(
ds(i)))[wgn_predictee],
417 (*(
ds(i)))[wgn_count_field]);
421float WImpurity::measure(
void)
425 else if (t == wnim_vector)
426 return vector_impurity();
427 else if (t == wnim_trajectory)
428 return trajectory_impurity();
429 else if (t == wnim_matrix)
431 else if (t == wnim_class)
433 else if (t == wnim_cluster)
434 return cluster_impurity();
435 else if (t == wnim_ols)
436 return ols_impurity();
439 cerr <<
"WImpurity: can't measure unset object" <<
endl;
444float WImpurity::vector_impurity()
459 if (wgn_VertexFeats.
a(0,
j) > 0.0)
467 b.cumulate(wgn_VertexTrack.
a(i,
j), member_counts.
item(
countpp)) ;
482 if (wgn_VertexFeats.
a(0,
j) > 0.0)
485 for (
pp=members.head(),
countpp=member_counts.head();
pp != 0;
490 c[
j].cumulate(wgn_VertexTrack.
a(i,
j),member_counts.
item(
countpp));
497 for (
pp=members.head(),
countpp=member_counts.head();
pp != 0;
507 for (
q=-20;
q<=20;
q++)
510 for (
j=67+
q;
j<147+
q;
j++)
512 x = c[
j].
mean() - wgn_VertexTrack(i,
j);
536 if (wgn_VertexFeats.
a(0,
j) > 0.0)
538 for (
pp=members.head();
pp != 0;
pp=
pp->next())
545 if (wgn_VertexFeats.
a(0,
j) > 0.0)
547 for (
pp=members.head();
pp != 0;
pp=
pp->next())
550 cs[i][
j] += (wgn_VertexTrack.
a(
mmm,i)-
cs[
j][
j].mean())*
551 (wgn_VertexTrack.
a(
mmm,
j)-
cs[
j][
j].mean());
558 if (wgn_VertexFeats.
a(0,
j) > 0.0)
559 a +=
cs[i][
j].stddev();
561 count =
cs[0][0].samples();
570 for (
pp=members.head();
pp != 0;
pp=
pp->next())
578 if (wgn_VertexFeats.
a(0,
j) > 0.0)
580 d = wgn_VertexTrack(
x,
j)-wgn_VertexTrack(
y,
j);
590 return a.
mean() * count;
593WImpurity::~WImpurity()
600 delete [] trajectory[
j];
601 delete [] trajectory;
608float WImpurity::trajectory_impurity()
620 double n,
m,
m1,
m2, w;
633 for (
pp=members.head();
pp != 0;
pp=
pp->next())
636 for (
q=0;
q<wgn_UnitTrack.
a(i,1);
q++)
638 ni = (int)wgn_UnitTrack.
a(i,0)+
q;
639 if (wgn_VertexTrack.
a(
ni,0) == -1.0)
646 if (
q==wgn_UnitTrack.
a(i,1))
652 l2ss += wgn_UnitTrack.
a(i,1) - (
q+1) - 1;
653 lss += wgn_UnitTrack.
a(i,1);
654 if (wgn_UnitTrack.
a(i,1) > l)
655 l = (
int)wgn_UnitTrack.
a(i,1);
660 l = ((int)
lss.mean() < 7) ? 7 : (int)
lss.mean();
668 for (
pp=members.head();
pp != 0;
pp=
pp->next())
672 s = (int)wgn_UnitTrack.
a(i,0);
673 for (
ti=0,n=0.0;
ti<l;
ti++,n+=
m)
678 if (wgn_VertexFeats.
a(0,
j) > 0.0)
679 trajectory[
ti][
j] += wgn_VertexTrack.
a(s+
ni,
j);
689 if (wgn_VertexFeats.
a(0,
j) > 0.0)
690 stdss += trajectory[
ti][
j].stddev();
694 score =
stdss.mean() * members.length();
698 l1 = (
l1ss.mean() < 10.0) ? 10 : (int)
l1ss.mean();
699 l2 = (
l2ss.mean() < 10.0) ? 10 : (int)
l2ss.mean();
707 for (
pp=members.head();
pp != 0;
pp=
pp->next())
711 s = (int)wgn_UnitTrack.
a(i,0);
712 for (
q=0;
q<wgn_UnitTrack.
a(i,1);
q++)
713 if (wgn_VertexTrack.
a(s+
q,0) == -1.0)
718 s2l = (int)wgn_UnitTrack.
a(i,1) - (
s1l + 2);
724 ni = s + (((int)n <
s1l) ? (int)n :
s1l - 1);
726 if (wgn_VertexFeats.
a(0,
j) > 0.0)
727 trajectory[
ti][
j] += wgn_VertexTrack.
a(
ni,
j);
731 if (wgn_VertexFeats.
a(0,
j) > 0.0)
732 trajectory[
ti][
j] += -1;
737 ni = s + (((int)n <
s2l) ? (int)n :
s2l - 1);
739 if (wgn_VertexFeats.
a(0,
j) > 0.0)
740 trajectory[
ti][
j] += wgn_VertexTrack.
a(
ni,
j);
743 if (wgn_VertexFeats.
a(0,
j) > 0.0)
744 trajectory[
ti][
j] += -2;
753 if (wgn_VertexFeats.
a(0,
j) > 0.0)
754 stdss += trajectory[
ti][
j].stddev() * w;
756 for (w=1.0,
ti++;
ti<l-1;
ti++,w-=
m)
758 if (wgn_VertexFeats.
a(0,
j) > 0.0)
759 stdss += trajectory[
ti][
j].stddev() * w;
762 score =
stdss.mean() * members.length();
778 w = wgn_dataset.width();
780 X.resize(members.length(),w);
781 Y.resize(members.length(),1);
785 for (p=0,
pp=members.head();
pp; p++,
pp=
pp->next())
794 Y.a_no_check(p,0) = (*wv)[0];
795 X.a_no_check(p,0) = 1;
796 for (
m=1,
xm=1;
m < w;
m++)
798 if (wgn_dataset.ftype(
m) == wndt_float)
804 X.a_no_check(p,
xm) = (*wv)[
m];
817float WImpurity::ols_impurity()
849 printf(
"Impurity OLS X(%d,%d) Y(%d,%d) %f, %f, %f\n",
850 X.num_rows(),X.num_columns(),
Y.num_rows(),
Y.num_columns(),
862float WImpurity::cluster_impurity()
873 for (
pp=members.head();
pp != 0;
pp=
pp->next())
876 for (
q=
pp->next();
q != 0;
q=
q->next())
893float WImpurity::cluster_distance(
int i)
897 float dist = cluster_member_mean(i);
907int WImpurity::in_cluster(
int i)
911 float dist = cluster_member_mean(i);
914 for (
pp=members.head();
pp != 0;
pp=
pp->next())
916 if (
dist < cluster_member_mean(members.
item(
pp)))
922float WImpurity::cluster_ranking(
int i)
925 float dist = cluster_distance(i);
929 for (
pp=members.head();
pp != 0;
pp=
pp->next())
931 if (
dist >= cluster_distance(members.
item(
pp)))
938float WImpurity::cluster_member_mean(
int i)
946 for (sum=0.0,n=0,
q=members.head();
q != 0;
q=
q->next())
951 dist = (
j < i ? wgn_DistMatrix(i,
j) : wgn_DistMatrix(
j,i));
957 return ( n == 0 ? 0.0 : sum/n );
960void WImpurity::cumulate(
const float pv,
double count)
964 if (wgn_dataset.ftype(wgn_predictee) == wndt_cluster)
969 else if (wgn_dataset.ftype(wgn_predictee) == wndt_ols)
974 else if (wgn_dataset.ftype(wgn_predictee) == wndt_vector)
980 member_counts.
append((
float)count);
982 else if (wgn_dataset.ftype(wgn_predictee) == wndt_trajectory)
987 else if (wgn_dataset.ftype(wgn_predictee) >= wndt_class)
990 p.init(&wgn_discretes[wgn_dataset.ftype(wgn_predictee)]);
992 p.cumulate((
int)
pv,count);
994 else if (wgn_dataset.ftype(wgn_predictee) == wndt_binary)
997 a.cumulate((
int)
pv,count);
999 else if (wgn_dataset.ftype(wgn_predictee) == wndt_float)
1002 a.cumulate(
pv,count);
1006 wagon_error(
"WImpurity: cannot cumulate EST_Val type");
1015 if (
imp.t == wnim_float)
1016 s <<
"(" <<
imp.a.stddev() <<
" " <<
imp.a.mean() <<
")";
1017 else if (
imp.t == wnim_vector)
1021 imp.vector_impurity();
1022 if (wgn_vertex_output ==
"mean")
1030 b.cumulate(wgn_VertexTrack.
a(
imp.members.item(p),
j),
imp.member_counts.item(
countp));
1033 s <<
"(" << b.
mean() <<
" ";
1037 s <<
"0.001" <<
")";
1046 double best = WGN_HUGE_VAL;
1056 for (p=
imp.members.head(); p != 0; p=p->next())
1058 cs[
j] += wgn_VertexTrack.
a(
imp.members.item(p),
j);
1062 for (p=
imp.members.head(); p != 0; p=p->next())
1065 if (wgn_VertexFeats.
a(0,
j) > 0.0)
1067 d = (wgn_VertexTrack.
a(
imp.members.item(p),
j)-
cs[
j].mean())
1082 s << wgn_VertexTrack.
a(
bestp,
j);
1086 s <<
cs[
j].stddev();
1097 s <<
imp.a.mean() <<
")";
1099 else if (
imp.t == wnim_trajectory)
1102 imp.trajectory_impurity();
1103 for (i=0; i<
imp.l; i++)
1108 s <<
"(" <<
imp.trajectory[i][
j].mean() <<
" "
1109 <<
imp.trajectory[i][
j].stddev() <<
" " <<
")";
1115 s <<
imp.a.mean() <<
")";
1117 else if (
imp.t == wnim_cluster)
1121 for (p=
imp.members.head(); p != 0; p=p->next())
1124 s <<
"(" <<
imp.members.item(p) <<
" " <<
1125 imp.cluster_member_mean(
imp.members.item(p)) <<
")";
1131 s <<
imp.a.mean() <<
")";
1133 else if (
imp.t == wnim_ols)
1149 printf(
"no robust ols\n");
1156 for (i=0; i<
coeffsl.num_rows(); i++)
1167 s <<
") " <<
cor <<
")";
1169 else if (
imp.t == wnim_class)
1176 for (i=
imp.p.item_start(); !
imp.p.item_end(i); i=
imp.p.item_next(i))
1178 imp.p.item_prob(i,name,prob);
1179 s <<
"(" << name <<
" " << prob <<
") ";
1181 s <<
imp.p.most_probable(&prob) <<
")";
1184 s <<
"([WImpurity unset])";
const EST_String & most_probable(double *prob=NULL) const
Return the most probable member of the distribution.
double samples(void) const
Total number of example found.
double entropy(void) const
int matches(const char *e, int pos=0) const
Exactly match this string?
double stddev(void) const
standard deviation of currently cummulated values
double variance(void) const
variance of currently cummulated values
double mean(void) const
mean of currently cummulated values
void reset(void)
reset internal values
double samples(void)
number of samples in set
T & item(const EST_Litem *p)
void append(const T &item)
add item onto end of list
INLINE const T & a_no_check(int row, int col) const
const access with no bounds check, care recommend
void resize(int n, int set=1)
resize vector
void resize(int n, int set=1)
float & a(int i, int c=0)
int num_channels() const
return number of channels in track
const int Int(void) const
const float Float(void) const