From d0fe0606639365fc59f69924f2f75cd1ea03da96 Mon Sep 17 00:00:00 2001 From: Aaron Raddon Date: Sun, 13 May 2018 15:24:27 -0700 Subject: [PATCH 1/4] refactor projection --- exec/projection.go | 6 +++ plan/plan.go | 3 ++ plan/planner_select.go | 90 ++++++++++++++++++++++++------------------ plan/projection.go | 28 ++++++++----- rel/sql.go | 1 + testutil/testsuite.go | 6 ++- 6 files changed, 86 insertions(+), 48 deletions(-) diff --git a/exec/projection.go b/exec/projection.go index 90048399..01b3e204 100644 --- a/exec/projection.go +++ b/exec/projection.go @@ -123,6 +123,9 @@ func (m *Projection) projectionEvaluator(isFinal bool) MessageHandler { // If we have a projection, use that as col count if m.p.Proj != nil { colCt = len(m.p.Proj.Columns) + if len(m.p.Proj.Columns) == 0 { + u.Errorf("crap %+v", m.p.Proj) + } } rowCt := 0 @@ -175,6 +178,9 @@ func (m *Projection) projectionEvaluator(isFinal bool) MessageHandler { } if col.Star { starRow := mt.Values() + if colCt != len(starRow) { + u.Warnf("wtf wrong count %v %v", colCt, len(starRow)) + } //u.Infof("star row: %#v", starRow) if len(columns) > 1 { // select *, myvar, 1 diff --git a/plan/plan.go b/plan/plan.go index 54789c20..ccdd4be5 100644 --- a/plan/plan.go +++ b/plan/plan.go @@ -821,6 +821,9 @@ func (m *Projection) ToPb() (*PlanPb, error) { if err != nil { return nil, err } + if m.Proj == nil { + u.WarnT(10) + } ppbptr := m.Proj.ToPB() ppcpy := *ppbptr ppcpy.Final = m.Final diff --git a/plan/planner_select.go b/plan/planner_select.go index 8b0a0035..e4b2d8f7 100644 --- a/plan/planner_select.go +++ b/plan/planner_select.go @@ -34,7 +34,8 @@ func (m *PlannerDefault) WalkSelect(p *Select) error { return m.WalkLiteralQuery(p) - } else if len(p.Stmt.From) == 1 { + } + /*else if len(p.Stmt.From) == 1 { p.Stmt.From[0].Source = p.Stmt // TODO: move to a Finalize() in query parser/planner @@ -55,40 +56,47 @@ func (m *PlannerDefault) WalkSelect(p *Select) error { } } else { - - var prevSource *Source - var prevTask Task - - for i, from := range p.Stmt.From { - - // Need to rewrite the From statement to ensure all fields necessary to support - // joins, wheres, etc exist but is standalone query - from.Rewrite(p.Stmt) - srcPlan, err := NewSource(m.Ctx, from, false) - if err != nil { - return nil - } - err = m.Planner.WalkSourceSelect(srcPlan) - if err != nil { - u.Errorf("Could not visitsubselect %v %s", err, from) - return err - } - - // now fold into previous task - if i != 0 { - from.Seekable = true - // fold this source into previous - curMergeTask := NewJoinMerge(prevTask, srcPlan, prevSource.Stmt, srcPlan.Stmt) - prevTask = curMergeTask - } else { - prevTask = srcPlan - } - prevSource = srcPlan - //u.Debugf("got task: %T", lastSource) + */ + var prevSource *Source + var rootTask Task + + for i, from := range p.Stmt.From { + + // Need to rewrite the From statement to ensure all fields necessary to support + // joins, wheres, etc exist but is standalone query + u.Debugf("from.Source: %s", p.Stmt) + u.Debugf("from: %s", from.String()) + from.Rewrite(p.Stmt) + u.Debugf("from-rewrite: %s", from.String()) + u.Debugf("from.Source: %s", from.Source.String()) + sourceTask, err := NewSource(m.Ctx, from, false) + if err != nil { + return nil + } + // if len(p.Stmt.From) == 1 { + // p.From = []*Source{sourceTask} + // } + p.From = append(p.From, sourceTask) + err = m.Planner.WalkSourceSelect(sourceTask) + if err != nil { + u.Errorf("Could not visitsubselect %v %s", err, from) + return err } - p.Add(prevTask) + // now fold into previous task + if i != 0 { + from.Seekable = true + // fold this source into previous + rootTask = NewJoinMerge(rootTask, sourceTask, prevSource.Stmt, sourceTask.Stmt) + //rootTask = curMergeTask + } else { + rootTask = sourceTask + } + prevSource = sourceTask + //u.Debugf("got task: %T", lastSource) } + p.Add(rootTask) + u.Infof("did we accidentally mutate the original statement? \n\t%s", p.Stmt) if p.Stmt.Where != nil { switch { @@ -118,6 +126,7 @@ func (m *PlannerDefault) WalkSelect(p *Select) error { p.Add(NewOrder(p.Stmt)) } + u.Debugf("needsFinalProject?%v", needsFinalProject) if needsFinalProject { err := m.WalkProjectionFinal(p) if err != nil { @@ -125,7 +134,6 @@ func (m *PlannerDefault) WalkSelect(p *Select) error { } } -finalProjection: if m.Ctx.Projection == nil { proj, err := NewProjectionFinal(m.Ctx, p) //u.Infof("Projection: %T:%p %T:%p", proj, proj, proj.Proj, proj.Proj) @@ -146,10 +154,12 @@ func (m *PlannerDefault) WalkProjectionFinal(p *Select) error { proj, err := NewProjectionFinal(m.Ctx, p) //u.Infof("Projection: %T:%p %T:%p", proj, proj, proj.Proj, proj.Proj) if err != nil { + u.Warnf("could not build projection err=%v for %s", err, p.Stmt) return err } p.Add(proj) if m.Ctx.Projection == nil { + u.Infof("set projection") m.Ctx.Projection = proj } else { // Not entirely sure we should be over-writing the projection? @@ -236,12 +246,15 @@ func (m *PlannerDefault) WalkSourceSelect(p *Source) error { } // Add a Non-Final Projection to choose the columns for results - if !p.Final { - err := m.WalkProjectionSource(p) - if err != nil { - return err + /* + if !p.Final { + u.Warnf("!final wtf %s", p.Stmt.String()) + err := m.WalkProjectionSource(p) + if err != nil { + return err + } } - } + */ } if needsJoinKey { @@ -256,6 +269,7 @@ func (m *PlannerDefault) WalkSourceSelect(p *Source) error { func (m *PlannerDefault) WalkProjectionSource(p *Source) error { // Add a Non-Final Projection to choose the columns for results //u.Debugf("exec.projection: %p job.proj: %p added %s", p, m.Ctx.Projection, p.Stmt.String()) + u.Infof("------------- Source Projection") proj := NewProjectionInProcess(p.Stmt.Source) //u.Debugf("source projection: %p added %s", proj, p.Stmt.Source.String()) p.Add(proj) diff --git a/plan/projection.go b/plan/projection.go index e76bbb60..6d628e27 100644 --- a/plan/projection.go +++ b/plan/projection.go @@ -12,13 +12,14 @@ import ( "github.com/araddon/qlbridge/value" ) -// A static projection has already had its column/types defined -// and doesn't need to use internal schema to find it, often internal SHOW/DESCRIBE +// NewProjectionStatic create A static projection for literal query. +// IT has already had its column/types defined and doesn't need to use internal +// schema to find it, often internal SHOW/DESCRIBE. func NewProjectionStatic(proj *rel.Projection) *Projection { return &Projection{Proj: proj, PlanBase: NewPlanBase(false)} } -// Final Projections project final select columns for result-writing +// NewProjectionFinal project final select columns for result-writing func NewProjectionFinal(ctx *Context, p *Select) (*Projection, error) { s := &Projection{ P: p, @@ -27,10 +28,13 @@ func NewProjectionFinal(ctx *Context, p *Select) (*Projection, error) { Final: true, } var err error + u.Debugf("NewProjectionFinal") if len(p.Stmt.From) == 0 { + u.Warnf("literal projection") err = s.loadLiteralProjection(ctx) } else if len(p.From) == 1 && p.From[0].Proj != nil { s.Proj = p.From[0].Proj + u.Warnf("used the projection from From[0] %#v", s.Proj.Columns) } else { err = s.loadFinal(ctx, true) } @@ -39,6 +43,9 @@ func NewProjectionFinal(ctx *Context, p *Select) (*Projection, error) { } return s, nil } + +// NewProjectionInProcess create a projection for a non-final +// projection for source. func NewProjectionInProcess(stmt *rel.SqlSelect) *Projection { s := &Projection{ Stmt: stmt, @@ -89,11 +96,11 @@ func (m *Projection) loadLiteralProjection(ctx *Context) error { func (m *Projection) loadFinal(ctx *Context, isFinal bool) error { - //u.Debugf("creating plan.Projection final %s", m.Stmt.String()) + u.Debugf("creating plan.Projection final %s", m.Stmt.String()) m.Proj = rel.NewProjection() - for _, from := range m.Stmt.From { + for fromi, from := range m.Stmt.From { fromName := strings.ToLower(from.SourceName()) tbl, err := ctx.Schema.Table(fromName) @@ -108,7 +115,7 @@ func (m *Projection) loadFinal(ctx *Context, isFinal bool) error { //u.Debugf("getting cols? %v cols=%v", from.ColumnPositions()) for _, col := range from.Source.Columns { //_, right, _ := col.LeftRight() - //u.Infof("col %s", col) + u.Infof("%d from:%s col %s", fromi, from.Name, col) if col.Star { for _, f := range tbl.Fields { m.Proj.AddColumnShort(f.Name, f.ValueType()) @@ -117,16 +124,18 @@ func (m *Projection) loadFinal(ctx *Context, isFinal bool) error { if schemaCol, ok := tbl.FieldMap[col.SourceField]; ok { if isFinal { if col.InFinalProjection() { - //u.Debugf("in plan final %s", col.As) + u.Debugf("in plan final %s", col.As) m.Proj.AddColumnShort(col.As, schemaCol.ValueType()) + } else { + u.Warnf("not in plan final %v", col.As) } } else { - //u.Debugf("not final %s", col.As) + u.Debugf("not final %s", col.As) m.Proj.AddColumnShort(col.As, schemaCol.ValueType()) } //u.Debugf("projection: %p add col: %v %v", m.Proj, col.As, schemaCol.Type.String()) } else { - //u.Infof("schema col not found: final?%v col: %#v InFinal?%v", isFinal, col, col.InFinalProjection()) + u.Infof("schema col not found: final?%v col: %#v InFinal?%v", isFinal, col, col.InFinalProjection()) if isFinal { if col.InFinalProjection() { m.Proj.AddColumnShort(col.As, value.StringType) @@ -148,6 +157,7 @@ func (m *Projection) loadFinal(ctx *Context, isFinal bool) error { func projectionForSourcePlan(plan *Source) error { plan.Proj = rel.NewProjection() + u.Errorf("projection from source") // u.Debugf("created plan.Proj *rel.Projection %p", plan.Proj) // Not all Execution run-times support schema. ie, csv files and other "ad-hoc" structures diff --git a/rel/sql.go b/rel/sql.go index f6706853..7b8bf282 100644 --- a/rel/sql.go +++ b/rel/sql.go @@ -1294,6 +1294,7 @@ func (m *SqlSelect) CountStar() bool { return false } +// Rewrite rewrite this query and all its sources. func (m *SqlSelect) Rewrite() { for _, f := range m.From { f.Rewrite(m) diff --git a/testutil/testsuite.go b/testutil/testsuite.go index c22393ad..5add4a4b 100644 --- a/testutil/testsuite.go +++ b/testutil/testsuite.go @@ -60,7 +60,7 @@ func RunTestSuite(t TestingT) { TestSelect(t, "SELECT user_id FROM users WHERE (`users.user_id` != NULL)", [][]driver.Value{{"hT2impsabc345c"}, {"9Ip1aKbeZe2njCDM"}, {"hT2impsOPUREcVPc"}}, ) - TestSelect(t, "SELECT email FROM users WHERE interests != NULL)", + TestSelect(t, "SELECT email FROM users WHERE interests != NULL", [][]driver.Value{{"aaron@email.com"}, {"bob@email.com"}}, ) TestSelect(t, "SELECT email FROM users WHERE (`users`.`email` like \"%aaron%\");", @@ -123,6 +123,10 @@ func RunTestSuite(t TestingT) { // RunSimpleSuite run the normal DML SQL test suite. func RunSimpleSuite(t TestingT) { + TestSelect(t, "SELECT email FROM users WHERE interests != NULL)", + [][]driver.Value{{"aaron@email.com"}, {"bob@email.com"}}, + ) + return // // Function in select projected columns that needs to be late evaluated. // // "select json.jmespath(body,\"name\") AS name FROM article WHERE `author` = \"aaron\";", // TestSelect(t, "select json.jmespath(json_data,\"name\") AS name FROM users WHERE `email` = \"aaron@email.com\";", From bbf91f698be06c01d4b391ef630c492f96f48cc1 Mon Sep 17 00:00:00 2001 From: Aaron Raddon Date: Sun, 20 May 2018 18:31:29 -0700 Subject: [PATCH 2/4] sql rewrite --- datasource/mockcsv/mockcsv.go | 2 +- datasource/sqlite/conn.go | 10 +- exec/projection.go | 21 +- plan/planner_select.go | 6 +- plan/projection.go | 12 +- rel/sql.go | 17 +- rel/sql_rewrite.go | 384 +++++++++++++++++++++------------- rel/sql_rewrite_test.go | 246 ++++++++++++++++++++++ rel/sql_test.go | 217 ------------------- schema/apply_schema.go | 4 +- schema/datasource.go | 13 +- schema/source_features.go | 37 ++++ testutil/testsuite.go | 11 +- 13 files changed, 592 insertions(+), 388 deletions(-) create mode 100644 schema/source_features.go diff --git a/datasource/mockcsv/mockcsv.go b/datasource/mockcsv/mockcsv.go index 43b2ba57..4d714f42 100644 --- a/datasource/mockcsv/mockcsv.go +++ b/datasource/mockcsv/mockcsv.go @@ -139,7 +139,7 @@ func (m *Source) loadTable(tableName string) error { //u.Debugf("mockcsv:%p load mockcsv: %q data:%v", m, tableName, csvRaw) csvSource, _ := datasource.NewCsvSource(tableName, 0, sr, make(<-chan bool, 1)) ds := membtree.NewStaticData(tableName) - u.Infof("loaded columns table=%q cols=%v", tableName, csvSource.Columns()) + //u.Infof("loaded columns table=%q cols=%v", tableName, csvSource.Columns()) ds.SetColumns(csvSource.Columns()) m.tables[tableName] = ds diff --git a/datasource/sqlite/conn.go b/datasource/sqlite/conn.go index 8b79cffd..6665a15a 100644 --- a/datasource/sqlite/conn.go +++ b/datasource/sqlite/conn.go @@ -234,11 +234,11 @@ func (m *qryconn) WalkSourceSelect(planner plan.Planner, p *plan.Source) (plan.T sqlSelect := p.Stmt.Source u.Infof("original %s", sqlSelect.String()) - p.Stmt.Source = nil - p.Stmt.Rewrite(sqlSelect) - sqlSelect = p.Stmt.Source - u.Infof("original after From(source) rewrite %s", sqlSelect.String()) - sqlSelect.RewriteAsRawSelect() + //p.Stmt.Source = nil + //p.Stmt.Rewrite(sqlSelect) + //sqlSelect = p.Stmt.Source + //u.Infof("original after From(source) rewrite %s", sqlSelect.String()) + //sqlSelect.RewriteAsRawSelect() m.cols = sqlSelect.Columns.UnAliasedFieldNames() m.colidx = sqlSelect.ColIndexes() diff --git a/exec/projection.go b/exec/projection.go index 01b3e204..4b4e4a4c 100644 --- a/exec/projection.go +++ b/exec/projection.go @@ -122,10 +122,21 @@ func (m *Projection) projectionEvaluator(isFinal bool) MessageHandler { colCt := len(columns) // If we have a projection, use that as col count if m.p.Proj != nil { - colCt = len(m.p.Proj.Columns) + if len(m.p.Proj.Columns) > colCt { + colCt = len(m.p.Proj.Columns) + } else { + u.Warnf("wtf less? %v vs %v", colCt, len(m.p.Proj.Columns)) + } + if len(m.p.Proj.Columns) == 0 { u.Errorf("crap %+v", m.p.Proj) } + for i, col := range m.p.Proj.Columns { + u.Debugf("%d %+v", i, col) + } + for i, col := range columns { + u.Debugf("%d %+v", i, col) + } } rowCt := 0 @@ -148,11 +159,15 @@ func (m *Projection) projectionEvaluator(isFinal bool) MessageHandler { mt, ctx.Session, }, mt.Ts()) - //u.Debugf("about to project: %#v", mt) + u.Debugf("about to project: colCt:%d message:%#v", colCt, mt) colIdx := -1 for _, col := range columns { colIdx += 1 - //u.Debugf("%d colidx:%v sidx: %v pidx:%v key:%q Expr:%v", colIdx, col.Index, col.SourceIndex, col.ParentIndex, col.Key(), col.Expr) + u.Debugf("%d colidx:%v sidx: %v pidx:%v star=%v key:%q Expr:%v", colIdx, col.Index, col.SourceIndex, col.ParentIndex, col.Star, col.Key(), col.Expr) + if len(row) <= colIdx { + row = append(row, nil) + u.Warnf("wtf wrong count %v %v", colIdx, len(row)) + } if isFinal && col.ParentIndex < 0 { continue diff --git a/plan/planner_select.go b/plan/planner_select.go index e4b2d8f7..3b3c0646 100644 --- a/plan/planner_select.go +++ b/plan/planner_select.go @@ -59,7 +59,7 @@ func (m *PlannerDefault) WalkSelect(p *Select) error { */ var prevSource *Source var rootTask Task - + isFinal := len(p.Stmt.From) == 1 for i, from := range p.Stmt.From { // Need to rewrite the From statement to ensure all fields necessary to support @@ -69,7 +69,7 @@ func (m *PlannerDefault) WalkSelect(p *Select) error { from.Rewrite(p.Stmt) u.Debugf("from-rewrite: %s", from.String()) u.Debugf("from.Source: %s", from.Source.String()) - sourceTask, err := NewSource(m.Ctx, from, false) + sourceTask, err := NewSource(m.Ctx, from, isFinal) if err != nil { return nil } @@ -219,6 +219,7 @@ func (m *PlannerDefault) WalkSourceSelect(p *Source) error { // Can do our own planning t, err := sourcePlanner.WalkSourceSelect(m.Planner, p) if err != nil { + u.Warnf("could not source plan %v", err) return err } if t != nil { @@ -229,6 +230,7 @@ func (m *PlannerDefault) WalkSourceSelect(p *Source) error { if schemaCols, ok := p.Conn.(schema.ConnColumns); ok { if err := buildColIndex(schemaCols, p); err != nil { + u.Warnf("could not build index %v", err) return err } } else { diff --git a/plan/projection.go b/plan/projection.go index 6d628e27..c1a86162 100644 --- a/plan/projection.go +++ b/plan/projection.go @@ -157,7 +157,8 @@ func (m *Projection) loadFinal(ctx *Context, isFinal bool) error { func projectionForSourcePlan(plan *Source) error { plan.Proj = rel.NewProjection() - u.Errorf("projection from source") + u.WarnT(9) + u.Errorf("projection. tbl?%v plan.Final?%v source: %s", plan.Tbl != nil, plan.Final, plan.Stmt.Source) // u.Debugf("created plan.Proj *rel.Projection %p", plan.Proj) // Not all Execution run-times support schema. ie, csv files and other "ad-hoc" structures @@ -166,7 +167,7 @@ func projectionForSourcePlan(plan *Source) error { for _, col := range plan.Stmt.Source.Columns { - //u.Debugf("col: %v star?%v", col, col.Star) + u.Debugf("%2d col: %v star?%v inFinal?%v", len(plan.Proj.Columns), col, col.Star, col.InFinalProjection()) if plan.Tbl == nil { if plan.Final { if col.InFinalProjection() { @@ -181,7 +182,7 @@ func projectionForSourcePlan(plan *Source) error { //u.Infof("col add %v for %s", schemaCol.Type.String(), col) plan.Proj.AddColumn(col, schemaCol.ValueType()) } else { - //u.Infof("not in final? %#v", col) + u.Infof("not in final? %#v", col) } } else { plan.Proj.AddColumn(col, schemaCol.ValueType()) @@ -191,7 +192,7 @@ func projectionForSourcePlan(plan *Source) error { if plan.Tbl == nil { u.Warnf("no table?? %v", plan) } else { - //u.Infof("star cols? %v fields: %v", plan.Tbl.FieldPositions, plan.Tbl.Fields) + u.Infof("star cols? %v fields: %v", plan.Tbl.FieldPositions, plan.Tbl.Fields) for _, f := range plan.Tbl.Fields { //u.Infof(" add col %v %+v", f.Name, f) plan.Proj.AddColumnShort(f.Name, f.ValueType()) @@ -229,6 +230,9 @@ func projectionForSourcePlan(plan *Source) error { } } + for _, c := range plan.Proj.Columns { + u.Debugf("col %+v", c) + } //u.Infof("plan.Projection %p cols: %d", plan.Proj, len(plan.Proj.Columns)) return nil } diff --git a/rel/sql.go b/rel/sql.go index 881cb175..9b0f4b5c 100644 --- a/rel/sql.go +++ b/rel/sql.go @@ -1274,7 +1274,7 @@ func (m *SqlSelect) AddColumn(colArg Column) error { return nil } -// Is this a select count(*) FROM ... query? +// CountStar Is this a select count(*) FROM ... query? func (m *SqlSelect) CountStar() bool { if len(m.Columns) != 1 { return false @@ -1295,20 +1295,27 @@ func (m *SqlSelect) CountStar() bool { } // Rewrite take current SqlSelect statement and re-write it -func (m *SqlSelect) Rewrite() { +func (m *SqlSelect) Rewrite() error { for _, f := range m.From { - f.Rewrite(m) + if _, err := f.Rewrite(m); err != nil { + return err + } } + return nil } // RewriteAsRawSelect We are removing Column Aliases "user_id as uid" // as well as functions - used when we are going to defer projection, aggs func (m *SqlSelect) RewriteAsRawSelect() { - RewriteSelect(m) + rewriteSelectStatement(m) } func (m *SqlSource) IsLiteral() bool { return len(m.Name) == 0 } func (m *SqlSource) Keyword() lex.TokenType { return m.Op } + +// SourceName return the sourcename for this sqlselect source, if sub-query +// get name of FROM [name] else if join get name. Corrects for namespacing +// to only get non-namedspaced name. func (m *SqlSource) SourceName() string { if m == nil { return "" @@ -1427,7 +1434,7 @@ func (m *SqlSource) BuildColIndex(colNames []string) error { // Rewrite this Source to act as a stand-alone query to backend // @parentStmt = the parent statement that this a partial source to -func (m *SqlSource) Rewrite(parentStmt *SqlSelect) *SqlSelect { +func (m *SqlSource) Rewrite(parentStmt *SqlSelect) (*SqlSelect, error) { return RewriteSqlSource(m, parentStmt) } diff --git a/rel/sql_rewrite.go b/rel/sql_rewrite.go index c78d6ea2..855b109d 100644 --- a/rel/sql_rewrite.go +++ b/rel/sql_rewrite.go @@ -1,33 +1,85 @@ package rel import ( + fmt "fmt" "strings" u "github.com/araddon/gou" "github.com/araddon/qlbridge/expr" "github.com/araddon/qlbridge/lex" + "github.com/araddon/qlbridge/schema" ) -// RewriteSelect We are removing Column Aliases "user_id as uid" +type ( + rewriteSelect struct { + sel *SqlSelect + cols map[string]bool + matchSource string + features *schema.DataSourceFeatures + result *RewriteSelectResult + } + // RewriteSelectResult describes the result of a re-write statement to + // tell the planner which poly-fill features are needed based on re-write. + RewriteSelectResult struct { + NeedsProjection bool + NeedsWhere bool + NeedsGroupBy bool + } +) + +func newRewriteSelect(sel *SqlSelect) *rewriteSelect { + rw := &rewriteSelect{ + sel: sel, + cols: make(map[string]bool), + features: schema.FeaturesDefault(), + result: &RewriteSelectResult{}, + } + return rw +} + +// ReWriteStatement given SqlStatement +func ReWriteStatement(input SqlStatement) error { + switch stmt := input.(type) { + case *SqlSelect: + return rewriteSelectStatement(stmt) + default: + return fmt.Errorf("Rewrite not implemented for %T", input) + } +} + +// rewriteSelectStatement We are removing Column Aliases "user_id as uid" // as well as functions - used when we are going to defer projection, aggs -func RewriteSelect(m *SqlSelect) { - originalCols := m.Columns - m.Columns = make(Columns, 0, len(originalCols)+5) - rewriteIntoProjection(m, originalCols) - rewriteIntoProjection(m, m.GroupBy) - if m.Where != nil { - colsToAdd := expr.FindAllIdentityField(m.Where.Expr) - addIntoProjection(m, colsToAdd) +func rewriteSelectStatement(sel *SqlSelect) error { + rw := newRewriteSelect(sel) + + originalCols := sel.Columns + sel.Columns = make(Columns, 0, len(originalCols)+5) + if err := rw.intoProjection(sel, originalCols); err != nil { + return err + } + if err := rw.intoProjection(sel, sel.GroupBy); err != nil { + return err } - rewriteIntoProjection(m, m.OrderBy) + if sel.Where != nil { + cols := expr.FindAllIdentityField(sel.Where.Expr) + for _, col := range cols { + nc := NewColumn(col) + nc.ParentIndex = -1 + rw.addColumn(*nc) + } + } + if err := rw.intoProjection(sel, sel.OrderBy); err != nil { + return err + } + return nil } -// RewriteSqlSource this Source to act as a stand-alone query to backend +// RewriteSqlSource this SqlSource to act as a stand-alone query to backend // @parentStmt = the parent statement that this a partial source to -func RewriteSqlSource(m *SqlSource, parentStmt *SqlSelect) *SqlSelect { +func RewriteSqlSource(source *SqlSource, parentStmt *SqlSelect) (*SqlSelect, error) { - if m.Source != nil { - return m.Source + if source.Source != nil { + return source.Source, nil } // Rewrite this SqlSource for the given parent, ie // 1) find the column names we need to request from source including those used in join/where @@ -36,129 +88,173 @@ func RewriteSqlSource(m *SqlSource, parentStmt *SqlSelect) *SqlSelect { // sides should be aliased towards the left-hand join portion // 4) if we need different sort for our join algo? - newCols := make(Columns, 0) - if !parentStmt.Star { - for idx, col := range parentStmt.Columns { - left, _, hasLeft := col.LeftRight() - if !hasLeft { - // Was not left/right qualified, so use as is? or is this an error? - // what is official sql grammar on this? - newCol := col.Copy() - newCol.ParentIndex = idx - newCol.Index = len(newCols) - newCols = append(newCols, newCol) + sql2 := &SqlSelect{Columns: make(Columns, 0), Star: parentStmt.Star} + rw := newRewriteSelect(sql2) + rw.matchSource = source.Alias + originalCols := parentStmt.Columns - } else if hasLeft && left == m.Alias { - newCol := col.CopyRewrite(m.Alias) - newCol.ParentIndex = idx - newCol.SourceIndex = len(newCols) - newCol.Index = len(newCols) - newCols = append(newCols, newCol) - } - } + if err := rw.intoProjection(sql2, originalCols); err != nil { + return nil, err } - + //u.Debugf("after into projection: %s", sql2.Columns) // TODO: // - rewrite the Sort // - rewrite the group-by - sql2 := &SqlSelect{Columns: newCols, Star: parentStmt.Star} - m.joinNodes = make([]expr.Node, 0) - if m.SubQuery != nil { - if len(m.SubQuery.From) != 1 { - u.Errorf("Not supported, nested subQuery %v", m.SubQuery.String()) + + source.joinNodes = make([]expr.Node, 0) + if source.SubQuery != nil { + if len(source.SubQuery.From) != 1 { + u.Errorf("Not supported, nested subQuery %v", source.SubQuery.String()) } else { - sql2.From = append(sql2.From, &SqlSource{Name: m.SubQuery.From[0].Name}) + sql2.From = append(sql2.From, &SqlSource{Name: source.SubQuery.From[0].Name}) } } else { - sql2.From = append(sql2.From, &SqlSource{Name: m.Name}) + sql2.From = append(sql2.From, &SqlSource{Name: source.Name}) } for _, from := range parentStmt.From { // We need to check each participant in the Join for possible // columns which need to be re-written - sql2.Columns = columnsFromJoin(m, from.JoinExpr, sql2.Columns) + rw.columnsFromExpression(source, from.JoinExpr) // We also need to create an expression used for evaluating // the values of Join "Keys" if from.JoinExpr != nil { - joinNodesForFrom(parentStmt, m, from.JoinExpr, 0) + rw.joinNodesForFrom(parentStmt, source, from.JoinExpr, 0) } } + //u.Debugf("after FROM: %s", sql2.Columns) if parentStmt.Where != nil { - node, cols := rewriteWhere(parentStmt, m, parentStmt.Where.Expr, make(Columns, 0)) + node := rw.rewriteWhere(parentStmt, source, parentStmt.Where.Expr) if node != nil { sql2.Where = &SqlWhere{Expr: node} } - if len(cols) > 0 { - parentIdx := len(parentStmt.Columns) - for _, col := range cols { - col.Index = len(sql2.Columns) - col.ParentIndex = parentIdx - parentIdx++ - sql2.Columns = append(sql2.Columns, col) + /* + if len(cols) > 0 { + parentIdx := len(parentStmt.Columns) + for _, col := range cols { + col.Index = len(sql2.Columns) + col.ParentIndex = parentIdx + parentIdx++ + sql2.Columns = append(sql2.Columns, col) + } } - } + */ } - m.Source = sql2 - m.cols = sql2.UnAliasedColumns() - return sql2 + //u.Debugf("after WHERE: %s", sql2.Columns) + source.Source = sql2 + source.cols = sql2.UnAliasedColumns() + return sql2, nil } -func rewriteIntoProjection(sel *SqlSelect, m Columns) { - if len(m) == 0 { +func (m *rewriteSelect) addColumn(col Column) { + col.Index = len(m.sel.Columns) + if col.Star { + if _, found := m.cols["*"]; found { + //u.Debugf("dupe %+v", col) + return + } + m.cols["*"] = true + m.sel.AddColumn(col) return } - colsToAdd := make([]string, 0) - for _, c := range m { - // u.Infof("source=%-15s as=%-15s exprT:%T expr=%s star:%v", c.As, c.SourceField, c.Expr, c.Expr, c.Star) + if _, found := m.cols[col.SourceField]; found { + //u.Debugf("dupe %+v", col) + return + } + + //u.Infof("adding col %+v", col) + m.cols[col.SourceField] = true + m.sel.AddColumn(col) +} +func (m *rewriteSelect) intoProjection(sel *SqlSelect, cols Columns) error { + if len(cols) == 0 { + return nil + } + /* + if !parentStmt.Star { + for idx, col := range parentStmt.Columns { + left, _, hasLeft := col.LeftRight() + if !hasLeft { + // Was not left/right qualified, so use as is? or is this an error? + // what is official sql grammar on this? + newCol := col.Copy() + newCol.ParentIndex = idx + newCol.Index = len(newCols) + newCols = append(newCols, newCol) + + } else if hasLeft && left == m.Alias { + newCol := col.CopyRewrite(m.Alias) + newCol.ParentIndex = idx + newCol.SourceIndex = len(newCols) + newCol.Index = len(newCols) + newCols = append(newCols, newCol) + } + } + } + */ + for i, c := range cols { + left, _, hasLeft := c.LeftRight() + if !hasLeft { + // ?? + } else if hasLeft && left == m.matchSource { + // ok + c = c.CopyRewrite(m.matchSource) + } else { + //u.Warnf("no.... %v", c) + continue + } + + //u.Infof("as=%-15s source=%-15s exprT:%T expr=%s star:%v", c.As, c.SourceField, c.Expr, c.Expr, c.Star) switch n := c.Expr.(type) { case *expr.IdentityNode: - colsToAdd = append(colsToAdd, c.SourceField) + nc := NewColumn(strings.ToLower(c.SourceField)) + nc.ParentIndex = i + nc.Expr = n + m.addColumn(*nc) case *expr.FuncNode: - + // TODO: use features. idents := expr.FindAllIdentities(n) for _, in := range idents { - _, r, _ := in.LeftRight() - colsToAdd = append(colsToAdd, r) + _, right, _ := in.LeftRight() + nc := NewColumn(strings.ToLower(right)) + nc.ParentIndex = i + nc.Expr = in + m.addColumn(*nc) } - + case *expr.NumberNode, *expr.NullNode, *expr.StringNode: + // literals + nc := NewColumn(strings.ToLower(n.String())) + nc.ParentIndex = i + nc.Expr = n + m.addColumn(*nc) case nil: if c.Star { - colsToAdd = append(colsToAdd, "*") + nc := c.Copy() + m.addColumn(*nc) } else { u.Warnf("unhandled column? %T %s", n, n) } - default: u.Warnf("unhandled column? %T %s", n, n) } } - addIntoProjection(sel, colsToAdd) -} -func addIntoProjection(sel *SqlSelect, newCols []string) { - notExists := make(map[string]bool) - for _, colName := range newCols { - colName = strings.ToLower(colName) - found := false - for _, c := range sel.Columns { - if c.SourceField == colName { - // already in projection - found = true - break - } - } - if !found { - notExists[colName] = true - if colName == "*" { - sel.AddColumn(Column{Star: true}) - } else { - nc := NewColumn(colName) - sel.AddColumn(*nc) - } - } - } + return nil } -func rewriteWhere(stmt *SqlSelect, from *SqlSource, node expr.Node, cols Columns) (expr.Node, Columns) { + +// func (m *rewriteSelect) addIntoProjection(sel *SqlSelect, colsToAdd map[string]int) { +// for colName, idx := range colsToAdd { +// colName = strings.ToLower(colName) +// if colName == "*" { +// m.addColumn(Column{Star: true, ParentIndex: idx}) +// } else { +// nc := NewColumn(colName) +// nc.ParentIndex = idx +// m.addColumn(*nc) +// } +// } +// } +func (m *rewriteSelect) rewriteWhere(stmt *SqlSelect, from *SqlSource, node expr.Node) expr.Node { //u.Debugf("rewrite where %s", node) switch nt := node.(type) { case *expr.IdentityNode: @@ -166,42 +262,43 @@ func rewriteWhere(stmt *SqlSelect, from *SqlSource, node expr.Node, cols Columns //u.Debugf("rewriteWhere from.Name:%v l:%v r:%v", from.alias, left, right) if left == from.alias { in := expr.IdentityNode{Text: right} - cols = append(cols, NewColumn(right)) - //u.Warnf("nice, found it! in = %v cols:%d", in, len(cols)) - return &in, cols + nc := *NewColumn(right) + nc.ParentIndex = -1 + m.addColumn(nc) + return &in } else { //u.Warnf("what to do? source:%v %v", from.alias, nt.String()) } } else { //u.Debugf("returning original: %s", nt) - return node, cols + return node } case *expr.NumberNode, *expr.NullNode, *expr.StringNode: - return nt, cols + return nt case *expr.BinaryNode: //u.Infof("binaryNode T:%v", nt.Operator.T.String()) switch nt.Operator.T { case lex.TokenAnd, lex.TokenLogicAnd, lex.TokenLogicOr: var n1, n2 expr.Node - n1, cols = rewriteWhere(stmt, from, nt.Args[0], cols) - n2, cols = rewriteWhere(stmt, from, nt.Args[1], cols) + n1 = m.rewriteWhere(stmt, from, nt.Args[0]) + n2 = m.rewriteWhere(stmt, from, nt.Args[1]) if n1 != nil && n2 != nil { - return &expr.BinaryNode{Operator: nt.Operator, Args: []expr.Node{n1, n2}}, cols + return &expr.BinaryNode{Operator: nt.Operator, Args: []expr.Node{n1, n2}} } else if n1 != nil { - return n1, cols + return n1 } else if n2 != nil { - return n2, cols + return n2 } else { //u.Warnf("n1=%#v n2=%#v %#v", n1, n2, nt) } case lex.TokenEqual, lex.TokenEqualEqual, lex.TokenGT, lex.TokenGE, lex.TokenLE, lex.TokenNE: var n1, n2 expr.Node - n1, cols = rewriteWhere(stmt, from, nt.Args[0], cols) - n2, cols = rewriteWhere(stmt, from, nt.Args[1], cols) + n1 = m.rewriteWhere(stmt, from, nt.Args[0]) + n2 = m.rewriteWhere(stmt, from, nt.Args[1]) //u.Debugf("n1=%#v n2=%#v %#v", n1, n2, nt) if n1 != nil && n2 != nil { - return &expr.BinaryNode{Operator: nt.Operator, Args: []expr.Node{n1, n2}}, cols + return &expr.BinaryNode{Operator: nt.Operator, Args: []expr.Node{n1, n2}} // } else if n1 != nil { // return n1 // } else if n2 != nil { @@ -212,14 +309,25 @@ func rewriteWhere(stmt *SqlSelect, from *SqlSource, node expr.Node, cols Columns default: //u.Warnf("un-implemented op: %#v", nt) } + case *expr.FuncNode: + // TODO: use features. + idents := expr.FindAllIdentities(nt) + for _, in := range idents { + _, right, _ := in.LeftRight() + nc := *NewColumn(right) + nc.ParentIndex = -1 + nc.Expr = in + m.addColumn(nc) + } + default: u.Warnf("%T node types are not suppored yet for where rewrite", node) } //u.Warnf("nil?? %T %s %#v", node, node, node) - return nil, cols + return nil } -func joinNodesForFrom(stmt *SqlSelect, from *SqlSource, node expr.Node, depth int) expr.Node { +func (m *rewriteSelect) joinNodesForFrom(stmt *SqlSelect, from *SqlSource, node expr.Node, depth int) expr.Node { switch nt := node.(type) { case *expr.IdentityNode: @@ -266,8 +374,8 @@ func joinNodesForFrom(stmt *SqlSelect, from *SqlSource, node expr.Node, depth in //u.Infof("%v binaryNode %v", depth, nt.String()) switch nt.Operator.T { case lex.TokenAnd, lex.TokenLogicAnd, lex.TokenLogicOr: - n1 := joinNodesForFrom(stmt, from, nt.Args[0], depth+1) - n2 := joinNodesForFrom(stmt, from, nt.Args[1], depth+1) + n1 := m.joinNodesForFrom(stmt, from, nt.Args[0], depth+1) + n2 := m.joinNodesForFrom(stmt, from, nt.Args[1], depth+1) if n1 != nil && n2 != nil { //u.Debugf("%d neither nil: n1=%v n2=%v %q", depth, n1, n2, nt.String()) @@ -282,8 +390,8 @@ func joinNodesForFrom(stmt *SqlSelect, from *SqlSource, node expr.Node, depth in //u.Warnf("%d n1=%#v n2=%#v %#v", depth, n1, n2, nt) } case lex.TokenEqual, lex.TokenEqualEqual, lex.TokenGT, lex.TokenGE, lex.TokenLE, lex.TokenNE: - n1 := joinNodesForFrom(stmt, from, nt.Args[0], depth+1) - n2 := joinNodesForFrom(stmt, from, nt.Args[1], depth+1) + n1 := m.joinNodesForFrom(stmt, from, nt.Args[0], depth+1) + n2 := m.joinNodesForFrom(stmt, from, nt.Args[1], depth+1) if n1 != nil && n2 != nil { //u.Debugf("%d neither nil: n1=%v n2=%v %q", depth, n1, n2, nt.String()) @@ -316,54 +424,42 @@ func joinNodesForFrom(stmt *SqlSelect, from *SqlSource, node expr.Node, depth in return nil } -// We need to find all columns used in the given Node (where/join expression) -// to ensure we have those columns in projection for sub-queries -func columnsFromJoin(from *SqlSource, node expr.Node, cols Columns) Columns { +// We need to find all columns used in the given Node (where or join expression) +// to ensure we have those columns in projection. +func (m *rewriteSelect) columnsFromExpression(from *SqlSource, node expr.Node) error { if node == nil { - return cols + return nil } //u.Debugf("columnsFromJoin() T:%T node=%q", node, node.String()) switch nt := node.(type) { case *expr.IdentityNode: if left, right, ok := nt.LeftRight(); ok { - //u.Debugf("from.Name:%v AS %v Joinnode l:%v r:%v %#v", from.Name, from.alias, left, right, nt) - //u.Warnf("check cols against join expr arg: %#v", nt) - if left == from.alias { - found := false - for _, col := range cols { - colLeft, colRight, _ := col.LeftRight() - //u.Debugf("left='%s' colLeft='%s' right='%s' %#v", left, colLeft, colRight, col) - //u.Debugf("col: From %s AS '%s' '%s'.'%s' JoinExpr: '%v'.'%v' col:%#v", from.Name, from.alias, colLeft, colRight, left, right, col) - if left == colLeft || colRight == right { - found = true - //u.Infof("columnsFromJoin from.Name:%v l:%v r:%v", from.alias, left, right) - } else { - //u.Warnf("not? from.Name:%v l:%v r:%v col: P:%p %#v", from.alias, left, right, col, col) - } - } - if !found { - //u.Debugf("columnsFromJoin from.Name:%v l:%v r:%v", from.alias, left, right) - newCol := &Column{As: right, SourceField: right, Expr: &expr.IdentityNode{Text: right}} - newCol.Index = len(cols) - newCol.ParentIndex = -1 // if -1, we don't need in parent index - cols = append(cols, newCol) - //u.Warnf("added col %s idx:%d pidx:%v", right, newCol.Index, newCol.Index) - } + if left != from.alias { + return nil + } + if _, found := m.cols[strings.ToLower(right)]; found { + return nil } + + newCol := Column{As: right, SourceField: right, Expr: &expr.IdentityNode{Text: right}} + newCol.ParentIndex = -1 // if -1, we don't need in parent projection + m.addColumn(newCol) + //u.Warnf("added col %s idx:%d pidx:%v", right, newCol.Index, newCol.Index) } + case *expr.FuncNode: //u.Warnf("columnsFromJoin func node: %s", nt.String()) for _, arg := range nt.Args { - cols = columnsFromJoin(from, arg, cols) + m.columnsFromExpression(from, arg) } case *expr.BinaryNode: switch nt.Operator.T { case lex.TokenAnd, lex.TokenLogicAnd, lex.TokenLogicOr: - cols = columnsFromJoin(from, nt.Args[0], cols) - cols = columnsFromJoin(from, nt.Args[1], cols) + m.columnsFromExpression(from, nt.Args[0]) + m.columnsFromExpression(from, nt.Args[1]) case lex.TokenEqual, lex.TokenEqualEqual: - cols = columnsFromJoin(from, nt.Args[0], cols) - cols = columnsFromJoin(from, nt.Args[1], cols) + m.columnsFromExpression(from, nt.Args[0]) + m.columnsFromExpression(from, nt.Args[1]) default: u.Warnf("un-implemented op: %v", nt.Operator) } @@ -371,7 +467,7 @@ func columnsFromJoin(from *SqlSource, node expr.Node, cols Columns) Columns { u.LogTracef(u.INFO, "whoops") u.Warnf("%T node types are not suppored yet for join rewrite %s", node, from.String()) } - return cols + return nil } // Remove any aliases diff --git a/rel/sql_rewrite_test.go b/rel/sql_rewrite_test.go index 0e2a2bd7..dad896c8 100644 --- a/rel/sql_rewrite_test.go +++ b/rel/sql_rewrite_test.go @@ -1 +1,247 @@ package rel_test + +import ( + "strings" + "testing" + + u "github.com/araddon/gou" + "github.com/stretchr/testify/assert" + + "github.com/araddon/qlbridge/rel" + "github.com/araddon/qlbridge/schema" +) + +func parseFeatures(t testing.TB, f *schema.DataSourceFeatures, q string) *rel.SqlSelect { + stmt, err := rel.ParseSqlSelect(q) + assert.Equal(t, nil, err, "expected no error but got %v for %s", err, q) + assert.NotEqual(t, nil, stmt) + err = stmt.Rewrite() + assert.Equal(t, nil, err) + return stmt +} +func parse(t testing.TB, q string) *rel.SqlSelect { + return parseFeatures(t, schema.FeaturesDefault(), q) +} + +func TestSqlSelectReWrite(t *testing.T) { + ss := parse(t, "SELECT user_id FROM users WHERE (`users.user_id` != NULL)") + assert.Equal(t, 1, len(ss.From[0].Source.Columns)) + ss = parse(t, `select exists(email), email FROM users WHERE yy(reg_date) > 10;`) + assert.Equal(t, 2, len(ss.From[0].Source.Columns)) +} + +func TestSqlRewriteTemp(t *testing.T) { + + s := `SELECT u.user_id, o.item_id, u.reg_date, u.email, o.price, o.order_date FROM users AS u + INNER JOIN ( + SELECT price, order_date, user_id from ORDERS + WHERE user_id IS NOT NULL AND price > 10 + ) AS o + ON u.user_id = o.user_id + ` + sql := parseOrPanic(t, s).(*rel.SqlSelect) + assert.True(t, len(sql.Columns) == 6, "has 6 cols: %v", len(sql.Columns)) + assert.True(t, len(sql.From) == 2, "has 2 sources: %v", len(sql.From)) + + assert.True(t, sql.String() == `SELECT u.user_id, o.item_id, u.reg_date, u.email, o.price, o.order_date FROM users AS u + INNER JOIN ( + SELECT price, order_date, user_id FROM ORDERS WHERE user_id != NULL AND price > 10 + ) AS o ON u.user_id = o.user_id`, "Wrong Full SQL?: '%v'", sql.String()) +} + +func TestSqlRewrite(t *testing.T) { + t.Parallel() + /* + SQL Re-writing is to take select statement with multiple sources (joins, sub-select) + and rewrite these sub-statements/sources into standalone statements + and prepare the column name, index mapping + + - Do we want to send the columns fully aliased? ie + SELECT name AS u.name, email as u.email, user_id as u.user_id FROM users + */ + s := `SELECT u.name, o.item_id, u.email, o.price + FROM users AS u INNER JOIN orders AS o + ON u.user_id = o.user_id;` + sql := parseOrPanic(t, s).(*rel.SqlSelect) + err := sql.Finalize() + assert.True(t, err == nil, "no error: %v", err) + assert.True(t, len(sql.Columns) == 4, "has 4 cols: %v", len(sql.Columns)) + assert.True(t, len(sql.From) == 2, "has 2 sources: %v", len(sql.From)) + + // Test the Left/Right column level parsing + // TODO: This field should not be u.name? sourcefield should be name right? as = u.name? + col, _ := sql.Columns.ByName("u.name") + assert.True(t, col.As == "u.name", "col.As=%s", col.As) + left, right, ok := col.LeftRight() + //u.Debugf("left=%v right=%v ok%v", left, right, ok) + assert.True(t, left == "u" && right == "name" && ok == true) + + rw1, _ := sql.From[0].Rewrite(sql) + assert.True(t, rw1 != nil, "should not be nil:") + assert.True(t, len(rw1.Columns) == 3, "has 3 cols: %v", rw1.Columns.String()) + //u.Infof("SQL?: '%v'", rw1.String()) + assert.Equal(t, rw1.String(), "SELECT name, email, user_id FROM users", "%v", rw1.String()) + + rw1, _ = sql.From[1].Rewrite(sql) + assert.True(t, rw1 != nil, "should not be nil:") + assert.True(t, len(rw1.Columns) == 3, "has 3 cols: %v", rw1.Columns.String()) + //u.Infof("SQL?: '%v'", rw1.String()) + assert.True(t, rw1.String() == "SELECT item_id, price, user_id FROM orders", "%v", rw1.String()) + + // Do we change? + //assert.Equal(t, sql.Columns.FieldNames(), []string{"user_id", "email", "item_id", "price"}) + + s = `SELECT u.name, u.email, b.title + FROM users AS u INNER JOIN blog AS b + ON u.name = b.author;` + sql = parseOrPanic(t, s).(*rel.SqlSelect) + assert.True(t, len(sql.Columns) == 3, "has 3 cols: %v", len(sql.Columns)) + assert.True(t, len(sql.From) == 2, "has 2 sources: %v", len(sql.From)) + rw1, _ = sql.From[0].Rewrite(sql) + assert.True(t, rw1 != nil, "should not be nil:") + assert.True(t, len(rw1.Columns) == 2, "has 2 cols: %v", rw1.Columns.String()) + //u.Infof("SQL?: '%v'", rw1.String()) + assert.True(t, rw1.String() == "SELECT name, email FROM users", "%v", rw1.String()) + jn := sql.From[0].JoinNodes() + assert.True(t, len(jn) == 1, "%v", jn) + assert.True(t, jn[0].String() == "name", "wanted 1 node %v", jn[0].String()) + cols := sql.From[0].UnAliasedColumns() + assert.True(t, len(cols) == 2, "Should have 2: %#v", cols) + //u.Infof("cols: %#v", cols) + rw1, _ = sql.From[1].Rewrite(sql) + assert.True(t, rw1 != nil, "should not be nil:") + assert.True(t, len(rw1.Columns) == 2, "has 2 cols: %v", rw1.Columns.String()) + // TODO: verify that we can rewrite sql for aliases + // jn, _ = sql.From[1].JoinValueExpr() + // assert.True(t, jn.String() == "name", "%v", jn.String()) + // u.Infof("SQL?: '%v'", rw1.String()) + // assert.True(t, rw1.String() == "SELECT title, author as name FROM blog", "%v", rw1.String()) + + s = `SELECT u.name, u.email, b.title + FROM users AS u INNER JOIN blog AS b + ON tolower(u.author) = b.author;` + sql = parseOrPanic(t, s).(*rel.SqlSelect) + sql.Rewrite() + selu := sql.From[0].Source + assert.True(t, len(selu.Columns) == 3, "user 3 cols: %v", selu.Columns.String()) + assert.True(t, selu.String() == "SELECT name, email, author FROM users", "%v", selu.String()) + jn = sql.From[0].JoinNodes() + assert.True(t, len(jn) == 1, "wanted 1 node but got fromP: %p %v", sql.From[0], jn) + assert.True(t, jn[0].String() == "tolower(author)", "wanted 1 node %v", jn[0].String()) + cols = sql.From[0].UnAliasedColumns() + assert.True(t, len(cols) == 3, "Should have 3: %#v", cols) + + // Now lets try compound join keys + s = `SELECT u.name, u.email, b.title + FROM users AS u INNER JOIN blog AS b + ON u.name = b.author and tolower(u.alias) = b.alias;` + sql = parseOrPanic(t, s).(*rel.SqlSelect) + sql.Rewrite() + assert.True(t, len(sql.Columns) == 3, "has 3 cols: %v", len(sql.Columns)) + assert.True(t, len(sql.From) == 2, "has 2 sources: %v", len(sql.From)) + rw1 = sql.From[0].Source + assert.True(t, rw1 != nil, "should not be nil:") + assert.True(t, len(rw1.Columns) == 3, "has 3 cols: %v", rw1.Columns.String()) + //u.Infof("SQL?: '%v'", rw1.String()) + assert.True(t, rw1.String() == "SELECT name, email, alias FROM users", "%v", rw1.String()) + jn = sql.From[0].JoinNodes() + assert.True(t, len(jn) == 2, "wanted 2 join nodes but %v", len(jn)) + assert.True(t, jn[0].String() == "name", `want "name" %v`, jn[0].String()) + assert.True(t, jn[1].String() == "tolower(alias)", `want "tolower(alias)" but got %q`, jn[1].String()) + cols = sql.From[0].UnAliasedColumns() + assert.True(t, len(cols) == 3, "Should have 3: %#v", cols) + //u.Infof("cols: %#v", cols) + rw1 = sql.From[1].Source + assert.True(t, rw1 != nil, "should not be nil:") + assert.True(t, len(rw1.Columns) == 3, "has 3 cols: %v", rw1.Columns.String()) + + // This test, is looking at these aspects of rewrite + // 1 the dotted notation of 'repostory.name' ensuring we have removed the p. + // 2 where clause + s = ` + SELECT + p.actor, p.repository.name, a.title + FROM article AS a + INNER JOIN github_push AS p + ON p.actor = a.author + WHERE p.follow_ct > 20 AND a.email IS NOT NULL + ` + sql = parseOrPanic(t, s).(*rel.SqlSelect) + assert.True(t, len(sql.Columns) == 3, "has 3 cols: %v", len(sql.Columns)) + assert.True(t, len(sql.From) == 2, "has 2 sources: %v", len(sql.From)) + + rw0, _ := sql.From[0].Rewrite(sql) + rw1, _ = sql.From[1].Rewrite(sql) + assert.True(t, rw0 != nil, "should not be nil:") + assert.True(t, len(rw0.Columns) == 3, "has 3 cols: %v", rw0.String()) + assert.True(t, len(sql.From[0].Source.Columns) == 3, "has 3 cols? %s", sql.From[0].Source) + assert.True(t, rw0.String() == "SELECT title, author, email FROM article WHERE email != NULL", "Wrong SQL 0: %v", rw0.String()) + assert.True(t, rw1 != nil, "should not be nil:") + assert.True(t, len(rw1.Columns) == 3, "has 3 cols: %v", rw1.Columns.String()) + assert.True(t, len(sql.From[1].Source.Columns) == 3, "has 3 cols? %s", sql.From[1].Source) + assert.True(t, rw1.String() == "SELECT actor, `repository.name`, follow_ct FROM github_push WHERE follow_ct > 20", "Wrong SQL 1: %v", rw1.String()) + + // Original should still be the same + parts := strings.Split(sql.String(), "\n") + for _, p := range parts { + u.Debugf("----%v----", p) + } + assert.True(t, parts[0] == "SELECT p.actor, p.`repository.name`, a.title FROM article AS a", "Wrong Full SQL?: '%v'", parts[0]) + assert.True(t, parts[1] == ` INNER JOIN github_push AS p ON p.actor = a.author WHERE p.follow_ct > 20 AND a.email != NULL`, "Wrong Full SQL?: '%v'", parts[1]) + assert.True(t, sql.String() == `SELECT p.actor, p.`+"`repository.name`"+`, a.title FROM article AS a + INNER JOIN github_push AS p ON p.actor = a.author WHERE p.follow_ct > 20 AND a.email != NULL`, "Wrong Full SQL?: '%v'", sql.String()) + + s = `SELECT u.user_id, o.item_id, u.reg_date, u.email, o.price, o.order_date FROM users AS u + INNER JOIN ( + SELECT price, order_date, user_id from ORDERS + WHERE user_id IS NOT NULL AND price > 10 + ) AS o + ON u.user_id = o.user_id + ` + sql = parseOrPanic(t, s).(*rel.SqlSelect) + assert.True(t, len(sql.Columns) == 6, "has 6 cols: %v", len(sql.Columns)) + assert.True(t, len(sql.From) == 2, "has 2 sources: %v", len(sql.From)) + + assert.True(t, sql.String() == `SELECT u.user_id, o.item_id, u.reg_date, u.email, o.price, o.order_date FROM users AS u + INNER JOIN ( + SELECT price, order_date, user_id FROM ORDERS WHERE user_id != NULL AND price > 10 + ) AS o ON u.user_id = o.user_id`, "Wrong Full SQL?: '%v'", sql.String()) + + // Rewrite to remove functions, and aliasing to send all fields needed down to source + // used when we are going to poly-fill + s = `SELECT count AS ct, name as nm, todate(myfield) AS mydate FROM user` + sql = parseOrPanic(t, s).(*rel.SqlSelect) + sql.RewriteAsRawSelect() + assert.True(t, sql.String() == `SELECT count, name, myfield FROM user`, "Wrong rewrite SQL?: '%v'", sql.String()) + + // Now ensure a group by, and where columns + s = `SELECT name as nm, todate(myfield) AS mydate FROM user WHERE created > todate("2016-01-01") GROUP BY referral;` + sql = parseOrPanic(t, s).(*rel.SqlSelect) + sql.RewriteAsRawSelect() + assert.True(t, sql.String() == `SELECT name, myfield, referral, created FROM user WHERE created > todate("2016-01-01") GROUP BY referral`, "Wrong rewrite SQL?: '%v'", sql.String()) + + //assert.True(t, sql.From[1].Name == "ORDERS", "orders? %q", sql.From[1].Name) + // sql.From[0].Rewrite(sql) + // sql.From[1].Rewrite(sql) + // assert.True(t, sql.From[0].Source.String() == `SELECT user_id, reg_date, email FROM users`, "Wrong Full SQL?: '%v'", sql.From[0].Source.String()) + // assert.True(t, sql.From[1].Source.String() == `SELECT item_id, price, order_date, user_id FROM ORDERS`, "Wrong Full SQL?: '%v'", sql.From[1].Source.String()) + + // s = `SELECT aa.*, + // bb.meal + // FROM table1 aa + // INNER JOIN table2 bb + // ON aa.tableseat = bb.tableseat AND + // aa.weddingtable = bb.weddingtable + // INNER JOIN + // ( + // SELECT a.tableSeat + // FROM table1 a + // INNER JOIN table2 b + // ON a.tableseat = b.tableseat AND + // a.weddingtable = b.weddingtable + // WHERE b.meal IN ('chicken', 'steak') + // GROUP by a.tableSeat + // HAVING COUNT(DISTINCT b.Meal) = 2 + // ) c ON aa.tableseat = c.tableSeat + // ` +} diff --git a/rel/sql_test.go b/rel/sql_test.go index e6db444c..f6de2b9e 100644 --- a/rel/sql_test.go +++ b/rel/sql_test.go @@ -3,7 +3,6 @@ package rel_test import ( "fmt" "reflect" - "strings" "testing" u "github.com/araddon/gou" @@ -146,222 +145,6 @@ func compareNode(t *testing.T, n1, n2 expr.Node) { assert.True(t, rv1.Kind() == rv2.Kind(), "kinds match: %T %T", n1, n2) } -func TestSqlRewriteTemp(t *testing.T) { - - s := `SELECT u.user_id, o.item_id, u.reg_date, u.email, o.price, o.order_date FROM users AS u - INNER JOIN ( - SELECT price, order_date, user_id from ORDERS - WHERE user_id IS NOT NULL AND price > 10 - ) AS o - ON u.user_id = o.user_id - ` - sql := parseOrPanic(t, s).(*rel.SqlSelect) - assert.True(t, len(sql.Columns) == 6, "has 6 cols: %v", len(sql.Columns)) - assert.True(t, len(sql.From) == 2, "has 2 sources: %v", len(sql.From)) - - assert.True(t, sql.String() == `SELECT u.user_id, o.item_id, u.reg_date, u.email, o.price, o.order_date FROM users AS u - INNER JOIN ( - SELECT price, order_date, user_id FROM ORDERS WHERE user_id != NULL AND price > 10 - ) AS o ON u.user_id = o.user_id`, "Wrong Full SQL?: '%v'", sql.String()) -} - -func TestSqlRewrite(t *testing.T) { - t.Parallel() - /* - SQL Re-writing is to take select statement with multiple sources (joins, sub-select) - and rewrite these sub-statements/sources into standalone statements - and prepare the column name, index mapping - - - Do we want to send the columns fully aliased? ie - SELECT name AS u.name, email as u.email, user_id as u.user_id FROM users - */ - s := `SELECT u.name, o.item_id, u.email, o.price - FROM users AS u INNER JOIN orders AS o - ON u.user_id = o.user_id;` - sql := parseOrPanic(t, s).(*rel.SqlSelect) - err := sql.Finalize() - assert.True(t, err == nil, "no error: %v", err) - assert.True(t, len(sql.Columns) == 4, "has 4 cols: %v", len(sql.Columns)) - assert.True(t, len(sql.From) == 2, "has 2 sources: %v", len(sql.From)) - - // Test the Left/Right column level parsing - // TODO: This field should not be u.name? sourcefield should be name right? as = u.name? - col, _ := sql.Columns.ByName("u.name") - assert.True(t, col.As == "u.name", "col.As=%s", col.As) - left, right, ok := col.LeftRight() - //u.Debugf("left=%v right=%v ok%v", left, right, ok) - assert.True(t, left == "u" && right == "name" && ok == true) - - rw1 := sql.From[0].Rewrite(sql) - assert.True(t, rw1 != nil, "should not be nil:") - assert.True(t, len(rw1.Columns) == 3, "has 3 cols: %v", rw1.Columns.String()) - //u.Infof("SQL?: '%v'", rw1.String()) - assert.Equal(t, rw1.String(), "SELECT name, email, user_id FROM users", "%v", rw1.String()) - - rw1 = sql.From[1].Rewrite(sql) - assert.True(t, rw1 != nil, "should not be nil:") - assert.True(t, len(rw1.Columns) == 3, "has 3 cols: %v", rw1.Columns.String()) - //u.Infof("SQL?: '%v'", rw1.String()) - assert.True(t, rw1.String() == "SELECT item_id, price, user_id FROM orders", "%v", rw1.String()) - - // Do we change? - //assert.Equal(t, sql.Columns.FieldNames(), []string{"user_id", "email", "item_id", "price"}) - - s = `SELECT u.name, u.email, b.title - FROM users AS u INNER JOIN blog AS b - ON u.name = b.author;` - sql = parseOrPanic(t, s).(*rel.SqlSelect) - assert.True(t, len(sql.Columns) == 3, "has 3 cols: %v", len(sql.Columns)) - assert.True(t, len(sql.From) == 2, "has 2 sources: %v", len(sql.From)) - rw1 = sql.From[0].Rewrite(sql) - assert.True(t, rw1 != nil, "should not be nil:") - assert.True(t, len(rw1.Columns) == 2, "has 2 cols: %v", rw1.Columns.String()) - //u.Infof("SQL?: '%v'", rw1.String()) - assert.True(t, rw1.String() == "SELECT name, email FROM users", "%v", rw1.String()) - jn := sql.From[0].JoinNodes() - assert.True(t, len(jn) == 1, "%v", jn) - assert.True(t, jn[0].String() == "name", "wanted 1 node %v", jn[0].String()) - cols := sql.From[0].UnAliasedColumns() - assert.True(t, len(cols) == 2, "Should have 2: %#v", cols) - //u.Infof("cols: %#v", cols) - rw1 = sql.From[1].Rewrite(sql) - assert.True(t, rw1 != nil, "should not be nil:") - assert.True(t, len(rw1.Columns) == 2, "has 2 cols: %v", rw1.Columns.String()) - // TODO: verify that we can rewrite sql for aliases - // jn, _ = sql.From[1].JoinValueExpr() - // assert.True(t, jn.String() == "name", "%v", jn.String()) - // u.Infof("SQL?: '%v'", rw1.String()) - // assert.True(t, rw1.String() == "SELECT title, author as name FROM blog", "%v", rw1.String()) - - s = `SELECT u.name, u.email, b.title - FROM users AS u INNER JOIN blog AS b - ON tolower(u.author) = b.author;` - sql = parseOrPanic(t, s).(*rel.SqlSelect) - sql.Rewrite() - selu := sql.From[0].Source - assert.True(t, len(selu.Columns) == 3, "user 3 cols: %v", selu.Columns.String()) - assert.True(t, selu.String() == "SELECT name, email, author FROM users", "%v", selu.String()) - jn = sql.From[0].JoinNodes() - assert.True(t, len(jn) == 1, "wanted 1 node but got fromP: %p %v", sql.From[0], jn) - assert.True(t, jn[0].String() == "tolower(author)", "wanted 1 node %v", jn[0].String()) - cols = sql.From[0].UnAliasedColumns() - assert.True(t, len(cols) == 3, "Should have 3: %#v", cols) - - // Now lets try compound join keys - s = `SELECT u.name, u.email, b.title - FROM users AS u INNER JOIN blog AS b - ON u.name = b.author and tolower(u.alias) = b.alias;` - sql = parseOrPanic(t, s).(*rel.SqlSelect) - sql.Rewrite() - assert.True(t, len(sql.Columns) == 3, "has 3 cols: %v", len(sql.Columns)) - assert.True(t, len(sql.From) == 2, "has 2 sources: %v", len(sql.From)) - rw1 = sql.From[0].Source - assert.True(t, rw1 != nil, "should not be nil:") - assert.True(t, len(rw1.Columns) == 3, "has 3 cols: %v", rw1.Columns.String()) - //u.Infof("SQL?: '%v'", rw1.String()) - assert.True(t, rw1.String() == "SELECT name, email, alias FROM users", "%v", rw1.String()) - jn = sql.From[0].JoinNodes() - assert.True(t, len(jn) == 2, "wanted 2 join nodes but %v", len(jn)) - assert.True(t, jn[0].String() == "name", `want "name" %v`, jn[0].String()) - assert.True(t, jn[1].String() == "tolower(alias)", `want "tolower(alias)" but got %q`, jn[1].String()) - cols = sql.From[0].UnAliasedColumns() - assert.True(t, len(cols) == 3, "Should have 3: %#v", cols) - //u.Infof("cols: %#v", cols) - rw1 = sql.From[1].Source - assert.True(t, rw1 != nil, "should not be nil:") - assert.True(t, len(rw1.Columns) == 3, "has 3 cols: %v", rw1.Columns.String()) - - // This test, is looking at these aspects of rewrite - // 1 the dotted notation of 'repostory.name' ensuring we have removed the p. - // 2 where clause - s = ` - SELECT - p.actor, p.repository.name, a.title - FROM article AS a - INNER JOIN github_push AS p - ON p.actor = a.author - WHERE p.follow_ct > 20 AND a.email IS NOT NULL - ` - sql = parseOrPanic(t, s).(*rel.SqlSelect) - assert.True(t, len(sql.Columns) == 3, "has 3 cols: %v", len(sql.Columns)) - assert.True(t, len(sql.From) == 2, "has 2 sources: %v", len(sql.From)) - - rw0 := sql.From[0].Rewrite(sql) - rw1 = sql.From[1].Rewrite(sql) - assert.True(t, rw0 != nil, "should not be nil:") - assert.True(t, len(rw0.Columns) == 3, "has 3 cols: %v", rw0.String()) - assert.True(t, len(sql.From[0].Source.Columns) == 3, "has 3 cols? %s", sql.From[0].Source) - assert.True(t, rw0.String() == "SELECT title, author, email FROM article WHERE email != NULL", "Wrong SQL 0: %v", rw0.String()) - assert.True(t, rw1 != nil, "should not be nil:") - assert.True(t, len(rw1.Columns) == 3, "has 3 cols: %v", rw1.Columns.String()) - assert.True(t, len(sql.From[1].Source.Columns) == 3, "has 3 cols? %s", sql.From[1].Source) - assert.True(t, rw1.String() == "SELECT actor, `repository.name`, follow_ct FROM github_push WHERE follow_ct > 20", "Wrong SQL 1: %v", rw1.String()) - - // Original should still be the same - parts := strings.Split(sql.String(), "\n") - for _, p := range parts { - u.Debugf("----%v----", p) - } - assert.True(t, parts[0] == "SELECT p.actor, p.`repository.name`, a.title FROM article AS a", "Wrong Full SQL?: '%v'", parts[0]) - assert.True(t, parts[1] == ` INNER JOIN github_push AS p ON p.actor = a.author WHERE p.follow_ct > 20 AND a.email != NULL`, "Wrong Full SQL?: '%v'", parts[1]) - assert.True(t, sql.String() == `SELECT p.actor, p.`+"`repository.name`"+`, a.title FROM article AS a - INNER JOIN github_push AS p ON p.actor = a.author WHERE p.follow_ct > 20 AND a.email != NULL`, "Wrong Full SQL?: '%v'", sql.String()) - - s = `SELECT u.user_id, o.item_id, u.reg_date, u.email, o.price, o.order_date FROM users AS u - INNER JOIN ( - SELECT price, order_date, user_id from ORDERS - WHERE user_id IS NOT NULL AND price > 10 - ) AS o - ON u.user_id = o.user_id - ` - sql = parseOrPanic(t, s).(*rel.SqlSelect) - assert.True(t, len(sql.Columns) == 6, "has 6 cols: %v", len(sql.Columns)) - assert.True(t, len(sql.From) == 2, "has 2 sources: %v", len(sql.From)) - - assert.True(t, sql.String() == `SELECT u.user_id, o.item_id, u.reg_date, u.email, o.price, o.order_date FROM users AS u - INNER JOIN ( - SELECT price, order_date, user_id FROM ORDERS WHERE user_id != NULL AND price > 10 - ) AS o ON u.user_id = o.user_id`, "Wrong Full SQL?: '%v'", sql.String()) - - // Rewrite to remove functions, and aliasing to send all fields needed down to source - // used when we are going to poly-fill - s = `SELECT count AS ct, name as nm, todate(myfield) AS mydate FROM user` - sql = parseOrPanic(t, s).(*rel.SqlSelect) - sql.RewriteAsRawSelect() - assert.True(t, sql.String() == `SELECT count, name, myfield FROM user`, "Wrong rewrite SQL?: '%v'", sql.String()) - - // Now ensure a group by, and where columns - s = `SELECT name as nm, todate(myfield) AS mydate FROM user WHERE created > todate("2016-01-01") GROUP BY referral;` - sql = parseOrPanic(t, s).(*rel.SqlSelect) - sql.RewriteAsRawSelect() - assert.True(t, sql.String() == `SELECT name, myfield, referral, created FROM user WHERE created > todate("2016-01-01") GROUP BY referral`, "Wrong rewrite SQL?: '%v'", sql.String()) - - //assert.True(t, sql.From[1].Name == "ORDERS", "orders? %q", sql.From[1].Name) - // sql.From[0].Rewrite(sql) - // sql.From[1].Rewrite(sql) - // assert.True(t, sql.From[0].Source.String() == `SELECT user_id, reg_date, email FROM users`, "Wrong Full SQL?: '%v'", sql.From[0].Source.String()) - // assert.True(t, sql.From[1].Source.String() == `SELECT item_id, price, order_date, user_id FROM ORDERS`, "Wrong Full SQL?: '%v'", sql.From[1].Source.String()) - - // s = `SELECT aa.*, - // bb.meal - // FROM table1 aa - // INNER JOIN table2 bb - // ON aa.tableseat = bb.tableseat AND - // aa.weddingtable = bb.weddingtable - // INNER JOIN - // ( - // SELECT a.tableSeat - // FROM table1 a - // INNER JOIN table2 b - // ON a.tableseat = b.tableseat AND - // a.weddingtable = b.weddingtable - // WHERE b.meal IN ('chicken', 'steak') - // GROUP by a.tableSeat - // HAVING COUNT(DISTINCT b.Meal) = 2 - // ) c ON aa.tableseat = c.tableSeat - // ` -} - func TestSqlFingerPrinting(t *testing.T) { t.Parallel() // Fingerprinting allows the select statement to have a cached plan regardless diff --git a/schema/apply_schema.go b/schema/apply_schema.go index 66f36afb..b8f6827b 100644 --- a/schema/apply_schema.go +++ b/schema/apply_schema.go @@ -64,7 +64,7 @@ func (m *InMemApplyer) AddOrUpdateOnSchema(s *Schema, v interface{}) error { // Find the type of operation being updated. switch v := v.(type) { case *Table: - u.Debugf("%p:%s InfoSchema P:%p adding table %q", s, s.Name, s.InfoSchema, v.Name) + //u.Debugf("%p:%s InfoSchema P:%p adding table %q", s, s.Name, s.InfoSchema, v.Name) s.InfoSchema.DS.Init() // Wipe out cache, it is invalid s.mu.Lock() s.addTable(v) @@ -72,7 +72,7 @@ func (m *InMemApplyer) AddOrUpdateOnSchema(s *Schema, v interface{}) error { s.InfoSchema.refreshSchemaUnlocked() case *Schema: - u.Debugf("%p:%s InfoSchema P:%p adding schema %q s==v?%v", s, s.Name, s.InfoSchema, v.Name, s == v) + //u.Debugf("%p:%s InfoSchema P:%p adding schema %q s==v?%v", s, s.Name, s.InfoSchema, v.Name, s == v) if s == v { // s==v means schema has been updated m.reg.mu.Lock() diff --git a/schema/datasource.go b/schema/datasource.go index 6437e5d8..002a6b7a 100644 --- a/schema/datasource.go +++ b/schema/datasource.go @@ -32,15 +32,15 @@ type ( // Close() // Source interface { - // Init provides opportunity for those sources that require/ no configuration and - // introspect schema from their environment time to load pre-schema discovery + // Init provides opportunity for those sources that require no configuration and + // introspect schema from their environment time to load pre-schema discovery. Init() // Setup optional interface for getting the Schema injected during creation/starup. // Since the Source is a singleton, stateful manager, it has a startup/shutdown process. Setup(*Schema) error // Close this source, ensure connections, underlying resources are closed. Close() error - // Open create a connection (not thread safe) to this source. + // Open create a connection to this source (the connection is not thread safe). Open(source string) (Conn, error) // Tables is a list of table names provided by this source. Tables() []string @@ -67,6 +67,13 @@ type ( // Underlying data type of column Column(col string) (value.ValueType, bool) } + + // SourceFeatures is optional interface allowing a source to declare its features so the + // planner can be more accurate. + SourceFeatures interface { + // Features describes the features of a datasource. + Features() *DataSourceFeatures + } ) type ( diff --git a/schema/source_features.go b/schema/source_features.go new file mode 100644 index 00000000..eaeea685 --- /dev/null +++ b/schema/source_features.go @@ -0,0 +1,37 @@ +package schema + +type ( + // DataSourceFeatures describes the features of a datasource. + DataSourceFeatures struct { + aggFuncs map[string]struct{} + projectionFuncs map[string]*FuncFeature + GroupBy bool + Having bool + Partitionable bool + } + + // FuncFeature describes the features of a function from datasource. + FuncFeature struct { + // Name of the function in underlying source. + Name string + // QLBName is the QLBridge name + QLBName string + } +) + +// FeaturesDefault is list of datasource features. +func FeaturesDefault() *DataSourceFeatures { + return &DataSourceFeatures{} +} + +// HasAgg does this datasource support Agg function (count(*), sum(*)) etc, these func's +// can be pushed down to underlying engine as part of GroupBy query. +func (m *DataSourceFeatures) HasAgg(name string) bool { + return false +} + +// HasProjectionFunc does this datasource support projection function tolower(field) +// can be pushed down to underlying engine as part of projection. +func (m *DataSourceFeatures) HasProjectionFunc(name string) (string, bool) { + return "", false +} diff --git a/testutil/testsuite.go b/testutil/testsuite.go index 5add4a4b..10c7d7dc 100644 --- a/testutil/testsuite.go +++ b/testutil/testsuite.go @@ -33,6 +33,10 @@ func RunDDLTests(t TestingT) { // RunTestSuite run the normal DML SQL test suite. func RunTestSuite(t TestingT) { + // TestSelect(t, `select exists(email), email FROM users WHERE yy(reg_date) > 10;`, + // [][]driver.Value{{true, "aaron@email.com"}}, + // ) + // return // Literal Queries TestSelect(t, `select 1;`, [][]driver.Value{{int64(1)}}, @@ -123,9 +127,12 @@ func RunTestSuite(t TestingT) { // RunSimpleSuite run the normal DML SQL test suite. func RunSimpleSuite(t TestingT) { - TestSelect(t, "SELECT email FROM users WHERE interests != NULL)", - [][]driver.Value{{"aaron@email.com"}, {"bob@email.com"}}, + TestSelect(t, "SELECT user_id FROM users WHERE (`users.user_id` != NULL)", + [][]driver.Value{{"hT2impsabc345c"}, {"9Ip1aKbeZe2njCDM"}, {"hT2impsOPUREcVPc"}}, ) + // TestSelect(t, "SELECT *, user_id as uid FROM users WHERE (`users.user_id` != NULL)", + // [][]driver.Value{{"hT2impsabc345c"}, {"9Ip1aKbeZe2njCDM"}, {"hT2impsOPUREcVPc"}}, + // ) return // // Function in select projected columns that needs to be late evaluated. // // "select json.jmespath(body,\"name\") AS name FROM article WHERE `author` = \"aaron\";", From e8d975c9fb530ea8e3b1a33b260d431d4b651e78 Mon Sep 17 00:00:00 2001 From: Aaron Raddon Date: Sun, 10 Jun 2018 15:13:06 -0700 Subject: [PATCH 3/4] wip --- datasource/context.go | 4 ++ datasource/datatypes.go | 1 + datasource/files/filesource_test.go | 1 + datasource/schemadb.go | 1 + exec/join.go | 6 +-- exec/projection.go | 27 ++++++---- exec/where.go | 8 +-- plan/plan.go | 14 +++-- plan/planner_select.go | 7 +++ plan/projection.go | 80 +++++++++-------------------- rel/sql.go | 43 +++++++++++----- rel/sql_rewrite.go | 9 ++-- schema/schema.go | 5 +- vm/vm.go | 2 +- 14 files changed, 110 insertions(+), 98 deletions(-) diff --git a/datasource/context.go b/datasource/context.go index 7b2bf902..ea9e4f1d 100644 --- a/datasource/context.go +++ b/datasource/context.go @@ -116,6 +116,7 @@ func (m *SqlDriverMessageMap) Values() []driver.Value { return m.Vals } func (m *SqlDriverMessageMap) SetRow(row []driver.Value) { m.Vals = row } func (m *SqlDriverMessageMap) Ts() time.Time { return time.Time{} } func (m *SqlDriverMessageMap) Get(key string) (value.Value, bool) { + key = strings.ToLower(key) if idx, ok := m.ColIndex[key]; ok { return value.NewValue(m.Vals[idx]), true } @@ -229,6 +230,9 @@ func NewNestedContextReadWriter(readers []expr.ContextReader, writer expr.Contex func (n *NestedContextReader) Get(key string) (value.Value, bool) { for _, r := range n.readers { + if r == nil { + continue + } val, ok := r.Get(key) if ok && val != nil { return val, ok diff --git a/datasource/datatypes.go b/datasource/datatypes.go index b43d77d9..ed7eb499 100644 --- a/datasource/datatypes.go +++ b/datasource/datatypes.go @@ -63,6 +63,7 @@ func (m *TimeValue) Time() time.Time { func (m *TimeValue) Scan(src interface{}) error { + u.Debugf("time %T %v", src, src) var t time.Time var dstr string switch val := src.(type) { diff --git a/datasource/files/filesource_test.go b/datasource/files/filesource_test.go index 6c5a6edd..ae5d3f71 100644 --- a/datasource/files/filesource_test.go +++ b/datasource/files/filesource_test.go @@ -70,6 +70,7 @@ func TestFileList(t *testing.T) { {"testjson"}, }, ) + return testutil.TestSqlSelect(t, "testcsvs", `show tables;`, [][]driver.Value{ {"appearances"}, diff --git a/datasource/schemadb.go b/datasource/schemadb.go index 4aa34968..71eaa40a 100644 --- a/datasource/schemadb.go +++ b/datasource/schemadb.go @@ -132,6 +132,7 @@ func (m *SchemaDb) Open(schemaObjectName string) (schema.Conn, error) { case "engines", "procedures", "functions", "indexes": return &SchemaSource{db: m, tbl: tbl, rows: nil}, nil default: + u.Warnf("here") return &SchemaSource{db: m, tbl: tbl, rows: tbl.AsRows()}, nil } diff --git a/exec/join.go b/exec/join.go index 7b7bdba4..14358b36 100644 --- a/exec/join.go +++ b/exec/join.go @@ -242,7 +242,7 @@ func (m *JoinMerge) Run() error { //u.Debugf("msgsct: %v msgs:%#v", len(msgs), msgs) for _, msg := range msgs { //outCh <- datasource.NewUrlValuesMsg(i, msg) - //u.Debugf("i:%d msg:%#v", i, msg) + u.Warnf("i:%d msg:%#v", i, msg) msg.IdVal = i i++ outCh <- msg @@ -289,8 +289,8 @@ func (m *JoinMerge) valIndexing(valOut, valSource []driver.Value, cols []*rel.Co if col.Index < 0 || col.Index >= len(valSource) { u.Errorf("source index out of range? idx:%v of %d source: %#v \n\tcol=%#v", col.Index, len(valSource), valSource, col) } - //u.Infof("found: si=%v pi:%v idx:%d as=%v vals:%v len(out):%v", col.SourceIndex, col.ParentIndex, col.Index, col.As, valSource, len(valOut)) - valOut[col.ParentIndex] = valSource[col.Index] + //u.Infof("found: si=%v pi:%v idx:%d as=%v val:%v len(out):%v", col.SourceIndex, col.ParentIndex, col.Index, col.As, valSource[col.Index], len(valOut)) + valOut[col.ParentIndex] = valSource[col.SourceIndex] } return valOut } diff --git a/exec/projection.go b/exec/projection.go index 4b4e4a4c..e30ab6b7 100644 --- a/exec/projection.go +++ b/exec/projection.go @@ -124,7 +124,7 @@ func (m *Projection) projectionEvaluator(isFinal bool) MessageHandler { if m.p.Proj != nil { if len(m.p.Proj.Columns) > colCt { colCt = len(m.p.Proj.Columns) - } else { + } else if len(m.p.Proj.Columns) != colCt { u.Warnf("wtf less? %v vs %v", colCt, len(m.p.Proj.Columns)) } @@ -132,10 +132,10 @@ func (m *Projection) projectionEvaluator(isFinal bool) MessageHandler { u.Errorf("crap %+v", m.p.Proj) } for i, col := range m.p.Proj.Columns { - u.Debugf("%d %+v", i, col) + u.Debugf("%d %#v", i, col) } for i, col := range columns { - u.Debugf("%d %+v", i, col) + u.Debugf("%d %#v", i, col) } } @@ -153,12 +153,18 @@ func (m *Projection) projectionEvaluator(isFinal bool) MessageHandler { var outMsg schema.Message switch mt := msg.(type) { case *datasource.SqlDriverMessageMap: + var rdr expr.ContextReader // use our custom write context for example purposes row := make([]driver.Value, colCt) - rdr := datasource.NewNestedContextReader([]expr.ContextReader{ - mt, - ctx.Session, - }, mt.Ts()) + if ctx.Session == nil { + rdr = mt + } else { + rdr = datasource.NewNestedContextReader([]expr.ContextReader{ + mt, + ctx.Session, + }, mt.Ts()) + } + u.Debugf("about to project: colCt:%d message:%#v", colCt, mt) colIdx := -1 for _, col := range columns { @@ -238,14 +244,15 @@ func (m *Projection) projectionEvaluator(isFinal bool) MessageHandler { //u.Infof("mt: %T mt %#v", mt, mt) row[colIdx] = nil //v.Value() } else { - //u.Debugf("%d:%d row:%d evaled: %v val=%v", colIdx, colCt, len(row), col, v.Value()) + u.Debugf("%d:%d row:%d evaled: %v val=%v", colIdx, colCt, len(row), col, v.Value()) //writeContext.Put(col, mt, v) row[colIdx] = v.Value() + } } } - //u.Infof("row: %#v", row) - //u.Infof("row cols: %v", colIndex) + u.Infof("row: %#v", row) + u.Infof("row cols: %v", colIndex) outMsg = datasource.NewSqlDriverMessageMap(0, row, colIndex) case expr.ContextReader: diff --git a/exec/where.go b/exec/where.go index 3da1577d..2965d4f0 100644 --- a/exec/where.go +++ b/exec/where.go @@ -87,7 +87,7 @@ func NewHaving(ctx *plan.Context, p *plan.Having) *Where { func whereFilter(filter expr.Node, task TaskRunner, cols map[string]int) MessageHandler { out := task.MessageOut() - //u.Debugf("prepare filter %s", filter) + u.Debugf("WHERE prepare filter %s", filter) return func(ctx *plan.Context, msg schema.Message) bool { var filterValue value.Value @@ -102,7 +102,7 @@ func whereFilter(filter expr.Node, task TaskRunner, cols map[string]int) Message case *datasource.SqlDriverMessageMap: filterValue, ok = vm.Eval(mt, filter) if !ok { - u.Warnf("wtf %s %#v", filter, mt) + //u.Warnf("wtf %s %#v", filter, mt) } //u.Debugf("WHERE: result:%v T:%T \n\trow:%#v \n\tvals:%#v", filterValue, msg, mt, mt.Values()) //u.Debugf("cols: %#v", cols) @@ -125,7 +125,7 @@ func whereFilter(filter expr.Node, task TaskRunner, cols map[string]int) Message switch valTyped := filterValue.(type) { case value.BoolValue: if valTyped.Val() == false { - //u.Debugf("Filtering out: T:%T v:%#v", valTyped, valTyped) + u.Debugf("Filtering out: T:%T v:%#v \n\t%#v", valTyped, valTyped, msg) return true } case nil: @@ -136,7 +136,7 @@ func whereFilter(filter expr.Node, task TaskRunner, cols map[string]int) Message } } - //u.Debugf("about to send from where to forward: %#v", msg) + u.Debugf("about to send from where to forward: %#v", msg) select { case out <- msg: return true diff --git a/plan/plan.go b/plan/plan.go index ccdd4be5..5a22646c 100644 --- a/plan/plan.go +++ b/plan/plan.go @@ -749,7 +749,7 @@ func (m *Source) serializeToPb() error { return nil } func (m *Source) load() error { - // u.Debugf("source load schema=%s from=%s %#v", m.ctx.Schema.Name, m.Stmt.SourceName(), m.Stmt) + u.Debugf("source load schema=%s from=%s %#v", m.ctx.Schema.Name, m.Stmt.SourceName(), m.Stmt) if m.Stmt == nil { return nil } @@ -867,13 +867,17 @@ func NewJoinMerge(l, r Task, lf, rf *rel.SqlSource) *JoinMerge { // Build an index of source to destination column indexing for _, col := range lf.Source.Columns { //u.Debugf("left col: idx=%d key=%q as=%q col=%v parentidx=%v", len(m.colIndex), col.Key(), col.As, col.String(), col.ParentIndex) - m.ColIndex[lf.Alias+"."+col.Key()] = col.ParentIndex - //u.Debugf("left colIndex: %15q : idx:%d sidx:%d pidx:%d", m.leftStmt.Alias+"."+col.Key(), col.Index, col.SourceIndex, col.ParentIndex) + if col.ParentIndex >= 0 { + m.ColIndex[lf.Alias+"."+col.Key()] = col.ParentIndex + } + u.Debugf("left colIndex: %15q : idx:%d sidx:%d pidx:%d", lf.Alias+"."+col.Key(), col.Index, col.SourceIndex, col.ParentIndex) } for _, col := range rf.Source.Columns { //u.Debugf("right col: idx=%d key=%q as=%q col=%v", len(m.colIndex), col.Key(), col.As, col.String()) - m.ColIndex[rf.Alias+"."+col.Key()] = col.ParentIndex - //u.Debugf("right colIndex: %15q : idx:%d sidx:%d pidx:%d", m.rightStmt.Alias+"."+col.Key(), col.Index, col.SourceIndex, col.ParentIndex) + if col.ParentIndex >= 0 { + m.ColIndex[rf.Alias+"."+col.Key()] = col.ParentIndex + } + u.Debugf("right colIndex: %15q : idx:%d sidx:%d pidx:%d", rf.Alias+"."+col.Key(), col.Index, col.SourceIndex, col.ParentIndex) } return m diff --git a/plan/planner_select.go b/plan/planner_select.go index 3b3c0646..c7e6570e 100644 --- a/plan/planner_select.go +++ b/plan/planner_select.go @@ -83,6 +83,12 @@ func (m *PlannerDefault) WalkSelect(p *Select) error { return err } + var curSource Task = sourceTask + if from.Source.Where != nil { + u.Errorf("got a WHERE") + curSource.Add(NewWhere(from.Source)) + } + // now fold into previous task if i != 0 { from.Seekable = true @@ -229,6 +235,7 @@ func (m *PlannerDefault) WalkSourceSelect(p *Source) error { } else { if schemaCols, ok := p.Conn.(schema.ConnColumns); ok { + u.Debugf("schemaCols: %#v cols=%v", p.Conn, schemaCols.Columns()) if err := buildColIndex(schemaCols, p); err != nil { u.Warnf("could not build index %v", err) return err diff --git a/plan/projection.go b/plan/projection.go index c1a86162..03db14f8 100644 --- a/plan/projection.go +++ b/plan/projection.go @@ -74,14 +74,12 @@ func (m *Projection) loadLiteralProjection(ctx *Context) error { case *expr.NumberNode: // number? if et.IsInt { - proj.AddColumnShort(as, value.IntType) + proj.AddColumnShort(as, value.IntType, true) } else { - proj.AddColumnShort(as, value.NumberType) + proj.AddColumnShort(as, value.NumberType, true) } - //u.Infof("number? %#v", et) default: - //u.Infof("type? %#v", et) - proj.AddColumnShort(as, value.StringType) + proj.AddColumnShort(as, value.StringType, true) } } @@ -118,76 +116,45 @@ func (m *Projection) loadFinal(ctx *Context, isFinal bool) error { u.Infof("%d from:%s col %s", fromi, from.Name, col) if col.Star { for _, f := range tbl.Fields { - m.Proj.AddColumnShort(f.Name, f.ValueType()) + m.Proj.AddColumnShort(f.Name, f.ValueType(), true) } } else { if schemaCol, ok := tbl.FieldMap[col.SourceField]; ok { - if isFinal { - if col.InFinalProjection() { - u.Debugf("in plan final %s", col.As) - m.Proj.AddColumnShort(col.As, schemaCol.ValueType()) - } else { - u.Warnf("not in plan final %v", col.As) - } - } else { - u.Debugf("not final %s", col.As) - m.Proj.AddColumnShort(col.As, schemaCol.ValueType()) - } - //u.Debugf("projection: %p add col: %v %v", m.Proj, col.As, schemaCol.Type.String()) + m.Proj.AddColumnShort(col.As, schemaCol.ValueType(), col.InFinalProjection()) } else { u.Infof("schema col not found: final?%v col: %#v InFinal?%v", isFinal, col, col.InFinalProjection()) - if isFinal { - if col.InFinalProjection() { - m.Proj.AddColumnShort(col.As, value.StringType) - } else { - u.Warnf("not adding to projection? %s", col) - } - } else { - m.Proj.AddColumnShort(col.As, value.StringType) - } + m.Proj.AddColumnShort(col.As, value.StringType, col.InFinalProjection()) } } - } } } + + for i, col := range m.Proj.Columns { + u.Debugf("%d %#v", i, col) + } return nil } func projectionForSourcePlan(plan *Source) error { plan.Proj = rel.NewProjection() - u.WarnT(9) - u.Errorf("projection. tbl?%v plan.Final?%v source: %s", plan.Tbl != nil, plan.Final, plan.Stmt.Source) + u.Infof("projection. tbl?%v plan.Final?%v source: %s", plan.Tbl != nil, plan.Final, plan.Stmt.Source) - // u.Debugf("created plan.Proj *rel.Projection %p", plan.Proj) // Not all Execution run-times support schema. ie, csv files and other "ad-hoc" structures // do not have to have pre-defined data in advance, in which case the schema output // will not be deterministic on the sql []driver.values for _, col := range plan.Stmt.Source.Columns { - u.Debugf("%2d col: %v star?%v inFinal?%v", len(plan.Proj.Columns), col, col.Star, col.InFinalProjection()) + u.Debugf("%2d col: %#v star?%v inFinal?%v", len(plan.Proj.Columns), col, col.Star, col.InFinalProjection()) if plan.Tbl == nil { - if plan.Final { - if col.InFinalProjection() { - plan.Proj.AddColumn(col, value.StringType) - } - } else { - plan.Proj.AddColumn(col, value.StringType) - } + plan.Proj.AddColumn(col, value.StringType, col.InFinalProjection()) + } else if schemaCol, ok := plan.Tbl.FieldMap[col.SourceField]; ok { - if plan.Final { - if col.InFinalProjection() { - //u.Infof("col add %v for %s", schemaCol.Type.String(), col) - plan.Proj.AddColumn(col, schemaCol.ValueType()) - } else { - u.Infof("not in final? %#v", col) - } - } else { - plan.Proj.AddColumn(col, schemaCol.ValueType()) - } - //u.Debugf("projection: %p add col: %v %v", plan.Proj, col.As, schemaCol.Type.String()) + + plan.Proj.AddColumn(col, schemaCol.ValueType(), col.InFinalProjection()) + } else if col.Star { if plan.Tbl == nil { u.Warnf("no table?? %v", plan) @@ -195,14 +162,15 @@ func projectionForSourcePlan(plan *Source) error { u.Infof("star cols? %v fields: %v", plan.Tbl.FieldPositions, plan.Tbl.Fields) for _, f := range plan.Tbl.Fields { //u.Infof(" add col %v %+v", f.Name, f) - plan.Proj.AddColumnShort(f.Name, f.ValueType()) + plan.Proj.AddColumnShort(f.Name, f.ValueType(), true) } } } else { + u.Warnf("WTF %#v", plan.Tbl.FieldMap) if col.Expr != nil && strings.ToLower(col.Expr.String()) == "count(*)" { //u.Warnf("count(*) as=%v", col.As) - plan.Proj.AddColumn(col, value.IntType) + plan.Proj.AddColumn(col, value.IntType, true) } else if col.Expr != nil { // A column was included in projection that does not exist in source. // TODO: Should we allow sources to have settings that specify wether @@ -210,16 +178,16 @@ func projectionForSourcePlan(plan *Source) error { // this is fine switch nt := col.Expr.(type) { case *expr.IdentityNode, *expr.StringNode: - plan.Proj.AddColumnShort(col.As, value.StringType) + plan.Proj.AddColumnShort(col.As, value.StringType, col.InFinalProjection()) case *expr.NumberNode: if nt.IsInt { - plan.Proj.AddColumnShort(col.As, value.IntType) + plan.Proj.AddColumnShort(col.As, value.IntType, col.InFinalProjection()) } else { - plan.Proj.AddColumnShort(col.As, value.NumberType) + plan.Proj.AddColumnShort(col.As, value.NumberType, col.InFinalProjection()) } case *expr.FuncNode, *expr.BinaryNode: // Probably not string? - plan.Proj.AddColumnShort(col.As, value.StringType) + plan.Proj.AddColumnShort(col.As, value.StringType, col.InFinalProjection()) default: u.Warnf("schema col not found: SourceField=%q vals=%#v", col.SourceField, col) } diff --git a/rel/sql.go b/rel/sql.go index 9b0f4b5c..b748857a 100644 --- a/rel/sql.go +++ b/rel/sql.go @@ -321,13 +321,16 @@ func NewSqlDialect() expr.DialectWriter { func NewProjection() *Projection { return &Projection{Columns: make(ResultColumns, 0), colNames: make(map[string]struct{})} } -func NewResultColumn(as string, ordinal int, col *Column, valtype value.ValueType) *ResultColumn { - rc := ResultColumn{Name: as, As: as, ColPos: ordinal, Col: col, Type: valtype} + +// NewResultColumn create a new column describing a result column, may be final or intermediate. +func NewResultColumn(as string, ordinal int, col *Column, valtype value.ValueType, final bool) *ResultColumn { + rc := ResultColumn{Name: as, As: as, ColPos: ordinal, Col: col, Type: valtype, Final: final} if col != nil { rc.Name = col.SourceField } return &rc } + func NewSqlSelect() *SqlSelect { req := &SqlSelect{} req.Columns = make(Columns, 0) @@ -390,10 +393,13 @@ func NewColumnValue(tok lex.Token) *Column { } } func NewColumn(col string) *Column { + l, r, _ := expr.LeftRight(col) + u.Debugf("col=%q l=%q r=%q", col, l, r) return &Column{ - As: col, - SourceField: col, - Expr: &expr.IdentityNode{Text: col}, + As: col, + SourceField: col, + SourceOriginal: col, + Expr: &expr.IdentityNode{Text: col}, } } @@ -470,22 +476,23 @@ func resultColumnToPb(m *ResultColumn) *ResultColumnPb { return s } -func (m *Projection) AddColumnShort(colName string, vt value.ValueType) { +func (m *Projection) AddColumnShort(colName string, vt value.ValueType, final bool) { //colName = strings.ToLower(colName) // if _, exists := m.colNames[colName]; exists { // return // } //u.Infof("adding column %s to %v", colName, m.colNames) //m.colNames[colName] = struct{}{} - m.Columns = append(m.Columns, NewResultColumn(colName, len(m.Columns), nil, vt)) + m.Columns = append(m.Columns, NewResultColumn(colName, len(m.Columns), nil, vt, final)) } -func (m *Projection) AddColumn(col *Column, vt value.ValueType) { +func (m *Projection) AddColumn(col *Column, vt value.ValueType, final bool) { //colName := strings.ToLower(col.As) // if _, exists := m.colNames[colName]; exists { // return // } //m.colNames[colName] = struct{}{} - m.Columns = append(m.Columns, NewResultColumn(col.As, len(m.Columns), col, vt)) + u.Debugf("AddColumn %#v", col) + m.Columns = append(m.Columns, NewResultColumn(col.As, len(m.Columns), col, vt, final)) } func (m *Projection) Equal(s *Projection) bool { if m == nil && s == nil { @@ -633,6 +640,11 @@ func (m Columns) Equal(cols Columns) bool { } func (m *Column) Key() string { + // if m.right == "" && m.As == "" { + // u.Warnf("WTF no col info %#v", m) + // u.WarnT(10) + // } + // u.Debugf("Key(): left=%q right=%q As=%q", m.left, m.right, m.As) if m.left != "" { return m.right } @@ -1252,7 +1264,7 @@ func (m *SqlSelect) ColIndexes() map[string]int { cols := make(map[string]int, len(m.Columns)) for i, col := range m.Columns { //u.Debugf("aliasing: key():%-15q As:%-15q %-15q", col.Key(), col.As, col.String()) - cols[col.Key()] = i + cols[strings.ToLower(col.Key())] = i } return cols } @@ -1405,17 +1417,20 @@ func (m *SqlSource) BuildColIndex(colNames []string) error { } starDelta := 0 // how many columns were added due to * for _, col := range m.Source.Columns { + // if col.Key() == "" { + // u.Errorf("WTF no key? %#v", col) + // } if col.Star { starStart := len(m.colIndex) - for colIdx := range colNames { - m.colIndex[col.Key()] = colIdx + starStart + for colIdx, colName := range colNames { + m.colIndex[strings.ToLower(colName)] = colIdx + starStart } starDelta = len(colNames) } else { found := false for colIdx, colName := range colNames { _, colName, _ = expr.LeftRight(colName) - //u.Debugf("col.Key():%v sourceField:%v colName:%v", col.Key(), col.SourceField, colName) + u.Debugf("col.Key():%v sourceField:%v colName:%v", col.Key(), col.SourceField, colName) if colName == col.Key() || col.SourceField == colName { //&& //u.Debugf("build col: idx=%d key=%-15q as=%-15q col=%-15s sourcidx:%d", len(m.colIndex), col.Key(), col.As, col.String(), colIdx) m.colIndex[col.Key()] = colIdx + starDelta @@ -1425,6 +1440,8 @@ func (m *SqlSource) BuildColIndex(colNames []string) error { } } if !found && !col.IsLiteralOrFunc() { + u.Errorf("could not find col? %s", col) + u.WarnT(10) return fmt.Errorf("Missing Column in source: %q", col.String()) } } diff --git a/rel/sql_rewrite.go b/rel/sql_rewrite.go index 855b109d..a4d97fd2 100644 --- a/rel/sql_rewrite.go +++ b/rel/sql_rewrite.go @@ -144,6 +144,7 @@ func RewriteSqlSource(source *SqlSource, parentStmt *SqlSelect) (*SqlSelect, err } //u.Debugf("after WHERE: %s", sql2.Columns) source.Source = sql2 + u.Debugf("rewritten source: %s", sql2) source.cols = sql2.UnAliasedColumns() return sql2, nil } @@ -163,7 +164,7 @@ func (m *rewriteSelect) addColumn(col Column) { return } - //u.Infof("adding col %+v", col) + u.Infof("adding col %#v", col) m.cols[col.SourceField] = true m.sel.AddColumn(col) } @@ -205,12 +206,12 @@ func (m *rewriteSelect) intoProjection(sel *SqlSelect, cols Columns) error { continue } - //u.Infof("as=%-15s source=%-15s exprT:%T expr=%s star:%v", c.As, c.SourceField, c.Expr, c.Expr, c.Star) + u.Infof("as=%-15s source=%-15s exprT:%T expr=%s star:%v", c.As, c.SourceField, c.Expr, c.Expr, c.Star) switch n := c.Expr.(type) { case *expr.IdentityNode: nc := NewColumn(strings.ToLower(c.SourceField)) nc.ParentIndex = i - nc.Expr = n + nc.Expr = &expr.IdentityNode{Text: strings.ToLower(n.Text)} m.addColumn(*nc) case *expr.FuncNode: // TODO: use features. @@ -255,7 +256,7 @@ func (m *rewriteSelect) intoProjection(sel *SqlSelect, cols Columns) error { // } // } func (m *rewriteSelect) rewriteWhere(stmt *SqlSelect, from *SqlSource, node expr.Node) expr.Node { - //u.Debugf("rewrite where %s", node) + u.Debugf("rewrite where %s", node) switch nt := node.(type) { case *expr.IdentityNode: if left, right, hasLeft := nt.LeftRight(); hasLeft { diff --git a/schema/schema.go b/schema/schema.go index b6adb0eb..3f5a2a20 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -275,6 +275,7 @@ func (m *Schema) SchemaForTable(tableName string) (*Schema, error) { // We always lower-case table names tableName = strings.ToLower(tableName) + u.Warnf("Schema for table schema.Name=%q", m.Name) if m.Name == "schema" { return m, nil } @@ -525,7 +526,7 @@ func (m *Table) AddField(fld *Field) { fld.idx = uint64(len(m.Fields)) m.Fields = append(m.Fields, fld) } - m.FieldMap[fld.Name] = fld + m.FieldMap[strings.ToLower(fld.Name)] = fld } // AddFieldType describe and register a new column @@ -550,7 +551,7 @@ func (m *Table) Column(col string) (value.ValueType, bool) { func (m *Table) SetColumns(cols []string) { m.FieldPositions = make(map[string]int, len(cols)) for idx, col := range cols { - //col = strings.ToLower(col) + col = strings.ToLower(col) m.FieldPositions[col] = idx cols[idx] = col } diff --git a/vm/vm.go b/vm/vm.go index b6a7b406..c408aeb8 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -574,7 +574,7 @@ func evalBinary(ctx expr.EvalContext, node *expr.BinaryNode, depth int) (value.V return value.NewBoolValue(false), true } // Should we evaluate strings that are non-nil to be = true? - u.Debugf("not handled: boolean %v %T=%v expr: %s", node.Operator, at.Value(), at.Val(), node.String()) + //u.Debugf("not handled: boolean %v %T=%v expr: %s", node.Operator, at.Value(), at.Val(), node.String()) return nil, false case value.Map: switch node.Operator.T { From 11cdb32a418b8a20f5db6d68f151be741285ef2e Mon Sep 17 00:00:00 2001 From: Aaron Raddon Date: Wed, 18 Jul 2018 09:34:41 -0700 Subject: [PATCH 4/4] Sql Rewrite work --- exec/source.go | 1 + plan/plan.go | 15 +++++++++++++- plan/planner_select.go | 15 +++++++++----- rel/sql.go | 24 ++++++++++++++------- rel/sql_rewrite.go | 47 ++++++++++++++++++++++-------------------- 5 files changed, 67 insertions(+), 35 deletions(-) diff --git a/exec/source.go b/exec/source.go index 01d3e314..451f50fc 100644 --- a/exec/source.go +++ b/exec/source.go @@ -133,6 +133,7 @@ func (m *Source) Run() error { for item := m.Scanner.Next(); item != nil; item = m.Scanner.Next() { + u.Debugf("source msg %#v", item) select { case <-sigChan: return nil diff --git a/plan/plan.go b/plan/plan.go index 5a22646c..87659383 100644 --- a/plan/plan.go +++ b/plan/plan.go @@ -879,7 +879,20 @@ func NewJoinMerge(l, r Task, lf, rf *rel.SqlSource) *JoinMerge { } u.Debugf("right colIndex: %15q : idx:%d sidx:%d pidx:%d", rf.Alias+"."+col.Key(), col.Index, col.SourceIndex, col.ParentIndex) } - + for _, col := range lf.Source.Columns { + //u.Debugf("left col: idx=%d key=%q as=%q col=%v parentidx=%v", len(m.colIndex), col.Key(), col.As, col.String(), col.ParentIndex) + if col.ParentIndex < 0 { + m.ColIndex[lf.Alias+"."+col.Key()] = len(m.ColIndex) + } + u.Debugf("left colIndex: %15q : idx:%d sidx:%d pidx:%d", lf.Alias+"."+col.Key(), col.Index, col.SourceIndex, len(m.ColIndex)-1) + } + for _, col := range rf.Source.Columns { + //u.Debugf("right col: idx=%d key=%q as=%q col=%v", len(m.colIndex), col.Key(), col.As, col.String()) + if col.ParentIndex < 0 { + m.ColIndex[rf.Alias+"."+col.Key()] = len(m.ColIndex) + } + u.Debugf("right colIndex: %15q : idx:%d sidx:%d pidx:%d", rf.Alias+"."+col.Key(), col.Index, col.SourceIndex, len(m.ColIndex)-1) + } return m } diff --git a/plan/planner_select.go b/plan/planner_select.go index c7e6570e..960c3a6d 100644 --- a/plan/planner_select.go +++ b/plan/planner_select.go @@ -64,11 +64,16 @@ func (m *PlannerDefault) WalkSelect(p *Select) error { // Need to rewrite the From statement to ensure all fields necessary to support // joins, wheres, etc exist but is standalone query - u.Debugf("from.Source: %s", p.Stmt) - u.Debugf("from: %s", from.String()) - from.Rewrite(p.Stmt) - u.Debugf("from-rewrite: %s", from.String()) - u.Debugf("from.Source: %s", from.Source.String()) + if len(p.Stmt.From) == 1 { + from.Source = p.Stmt + } else { + u.Debugf("from.Source: %s", p.Stmt) + u.Debugf("from: %s", from.String()) + from.Rewrite(p.Stmt) + u.Debugf("from-rewrite: %s", from.String()) + u.Debugf("from.Source: %s", from.Source.String()) + } + sourceTask, err := NewSource(m.Ctx, from, isFinal) if err != nil { return nil diff --git a/rel/sql.go b/rel/sql.go index b748857a..3f4abff5 100644 --- a/rel/sql.go +++ b/rel/sql.go @@ -236,7 +236,7 @@ type ( } // Columns List of Columns in SELECT [columns] Columns []*Column - // Column represents the Column as expressed in a [SELECT] + // Column represents the Column(s) as expressed in a [SELECT COLUMNS] // expression Column struct { sourceQuoteByte byte // quote mark? [ or ` etc @@ -1450,9 +1450,17 @@ func (m *SqlSource) BuildColIndex(colNames []string) error { } // Rewrite this Source to act as a stand-alone query to backend -// @parentStmt = the parent statement that this a partial source to +// @parentStmt = the parent statement. IE, source is a partial (join, from where-in) source in a +// multi-source SELECT statement. We are re-writing to allow the sources to be independent. func (m *SqlSource) Rewrite(parentStmt *SqlSelect) (*SqlSelect, error) { - return RewriteSqlSource(m, parentStmt) + sql2, err := RewriteSqlSource(m, parentStmt) + if err != nil { + return nil, err + } + m.Source = sql2 + u.Debugf("rewritten source: %s", sql2) + m.cols = sql2.UnAliasedColumns() + return sql2, err } func (m *SqlSource) findFromAliases() (string, string) { @@ -1477,8 +1485,8 @@ func (m *SqlSource) findFromAliases() (string, string) { return from1, from2 } -// Get a list of Un-Aliased Columns, ie columns with column -// names that have NOT yet been aliased +// UnAliasedColumns Get a list of Un-Aliased Columns, ie columns with column +// names that have NOT yet been aliased func (m *SqlSource) UnAliasedColumns() map[string]*Column { //u.Warnf("un-aliased %d", len(m.Source.Columns)) if len(m.cols) > 0 || m.Source != nil && len(m.Source.Columns) == 0 { @@ -1498,7 +1506,7 @@ func (m *SqlSource) UnAliasedColumns() map[string]*Column { return cols } -// Get a list of Column names to position +// ColumnPositions Get a list of Column names to position in array of columns. func (m *SqlSource) ColumnPositions() map[string]int { if len(m.colIndex) > 0 { return m.colIndex @@ -1520,7 +1528,7 @@ func (m *SqlSource) ColumnPositions() map[string]int { return m.colIndex } -// We need to be able to rewrite statements to convert a stmt such as: +// JoinNodes We need to be able to rewrite statements to convert a stmt such as: // // FROM users AS u // INNER JOIN orders AS o @@ -1542,6 +1550,8 @@ func (m *SqlSource) ColumnPositions() map[string]int { func (m *SqlSource) JoinNodes() []expr.Node { return m.joinNodes } + +// Finalize the source. func (m *SqlSource) Finalize() error { if m.final { return nil diff --git a/rel/sql_rewrite.go b/rel/sql_rewrite.go index a4d97fd2..6215e8f9 100644 --- a/rel/sql_rewrite.go +++ b/rel/sql_rewrite.go @@ -1,10 +1,11 @@ package rel import ( - fmt "fmt" + "fmt" "strings" u "github.com/araddon/gou" + "github.com/araddon/qlbridge/expr" "github.com/araddon/qlbridge/lex" "github.com/araddon/qlbridge/schema" @@ -54,24 +55,21 @@ func rewriteSelectStatement(sel *SqlSelect) error { originalCols := sel.Columns sel.Columns = make(Columns, 0, len(originalCols)+5) - if err := rw.intoProjection(sel, originalCols); err != nil { + if err := rw.intoProjection(sel, originalCols, true); err != nil { return err } - if err := rw.intoProjection(sel, sel.GroupBy); err != nil { + if err := rw.intoProjection(sel, sel.GroupBy, false); err != nil { return err } if sel.Where != nil { cols := expr.FindAllIdentityField(sel.Where.Expr) for _, col := range cols { nc := NewColumn(col) - nc.ParentIndex = -1 + nc.ParentIndex = -1 // ie, NOT in final rw.addColumn(*nc) } } - if err := rw.intoProjection(sel, sel.OrderBy); err != nil { - return err - } - return nil + return rw.intoProjection(sel, sel.OrderBy, false) } // RewriteSqlSource this SqlSource to act as a stand-alone query to backend @@ -93,7 +91,7 @@ func RewriteSqlSource(source *SqlSource, parentStmt *SqlSelect) (*SqlSelect, err rw.matchSource = source.Alias originalCols := parentStmt.Columns - if err := rw.intoProjection(sql2, originalCols); err != nil { + if err := rw.intoProjection(sql2, originalCols, true); err != nil { return nil, err } //u.Debugf("after into projection: %s", sql2.Columns) @@ -143,9 +141,6 @@ func RewriteSqlSource(source *SqlSource, parentStmt *SqlSelect) (*SqlSelect, err */ } //u.Debugf("after WHERE: %s", sql2.Columns) - source.Source = sql2 - u.Debugf("rewritten source: %s", sql2) - source.cols = sql2.UnAliasedColumns() return sql2, nil } func (m *rewriteSelect) addColumn(col Column) { @@ -168,7 +163,7 @@ func (m *rewriteSelect) addColumn(col Column) { m.cols[col.SourceField] = true m.sel.AddColumn(col) } -func (m *rewriteSelect) intoProjection(sel *SqlSelect, cols Columns) error { +func (m *rewriteSelect) intoProjection(sel *SqlSelect, cols Columns, final bool) error { if len(cols) == 0 { return nil } @@ -198,6 +193,7 @@ func (m *rewriteSelect) intoProjection(sel *SqlSelect, cols Columns) error { left, _, hasLeft := c.LeftRight() if !hasLeft { // ?? + u.Warnf("is this possible no left? %#v", c) } else if hasLeft && left == m.matchSource { // ok c = c.CopyRewrite(m.matchSource) @@ -206,39 +202,46 @@ func (m *rewriteSelect) intoProjection(sel *SqlSelect, cols Columns) error { continue } + parentIndex := i + if !final { + parentIndex = -1 + } + u.Infof("as=%-15s source=%-15s exprT:%T expr=%s star:%v", c.As, c.SourceField, c.Expr, c.Expr, c.Star) + + var nc *Column switch n := c.Expr.(type) { case *expr.IdentityNode: - nc := NewColumn(strings.ToLower(c.SourceField)) - nc.ParentIndex = i + nc = NewColumn(strings.ToLower(c.SourceField)) nc.Expr = &expr.IdentityNode{Text: strings.ToLower(n.Text)} - m.addColumn(*nc) case *expr.FuncNode: - // TODO: use features. + // TODO: use features to rewrite this. ie, idents := expr.FindAllIdentities(n) for _, in := range idents { _, right, _ := in.LeftRight() nc := NewColumn(strings.ToLower(right)) - nc.ParentIndex = i + nc.ParentIndex = parentIndex nc.Expr = in m.addColumn(*nc) } case *expr.NumberNode, *expr.NullNode, *expr.StringNode: // literals nc := NewColumn(strings.ToLower(n.String())) - nc.ParentIndex = i nc.Expr = n - m.addColumn(*nc) case nil: if c.Star { - nc := c.Copy() - m.addColumn(*nc) + nc = c.Copy() } else { u.Warnf("unhandled column? %T %s", n, n) } default: u.Warnf("unhandled column? %T %s", n, n) } + + if nc != nil { + nc.ParentIndex = parentIndex + m.addColumn(*nc) + } } return nil }