go-gorp/gorp

Upsert()?

mcandre opened this issue · 1 comment

Could there be a method for performing an update-or-insert as needed?

That's it, and it's working:

// Upsert inserts the row held by i or, when the insert conflicts on the
// columns selected by conflictFilter, updates the existing row instead. It is
// executed as a single PostgreSQL "insert ... on conflict (...) do update"
// statement; other dialects are rejected with an error.
//
// i is expected to be a pointer to a value whose table is mapped in m (it is
// resolved through tableForPointer).
//
// updateFilter selects which columns are overwritten when the row already
// exists; nil selects all of them (auto-increment and transient columns are
// never overwritten). conflictFilter selects the columns named in the ON
// CONFLICT target and must not be nil.
func (m *DbMap) Upsert(i interface{}, updateFilter ColumnFilter, conflictFilter ColumnFilter) error {
	return upsert(m, m, i, updateFilter, conflictFilter)
}
// upsert implements DbMap.Upsert. It writes ptr through exec with a single
// PostgreSQL "insert ... on conflict (...) do update" statement built by
// bindUpsert.
//
// The statement decides on its own whether a row is inserted or updated, and
// that outcome is not reported back here. So both sets of hooks run:
// PreInsert then PreUpdate before the statement, and PostInsert then
// PostUpdate after it. The first hook error aborts the upsert and is
// returned unchanged.
//
// NOTE(review): the result of Exec is discarded, so no database-generated
// auto-increment key is read back into ptr — confirm callers do not rely on
// it.
func upsert(m *DbMap, exec SqlExecutor, ptr interface{}, updateFilter ColumnFilter, conflictFilter ColumnFilter) error {
	// ON CONFLICT is PostgreSQL syntax. Accept the dialect both as a value and
	// as a pointer, since either form satisfies the Dialect interface.
	switch m.Dialect.(type) {
	case PostgresDialect, *PostgresDialect:
	default:
		return fmt.Errorf("gorp: upsert is not supported by dialect %T", m.Dialect)
	}

	table, elem, err := m.tableForPointer(ptr, false)
	if err != nil {
		return err
	}

	hooks := elem.Addr().Interface()

	// The row may end up inserted or updated, so run both pre-hooks.
	if v, ok := hooks.(HasPreInsert); ok {
		if err := v.PreInsert(exec); err != nil {
			return err
		}
	}
	if v, ok := hooks.(HasPreUpdate); ok {
		if err := v.PreUpdate(exec); err != nil {
			return err
		}
	}

	bi, err := table.bindUpsert(elem, updateFilter, conflictFilter)
	if err != nil {
		return err
	}
	if _, err := exec.Exec(bi.query, bi.args...); err != nil {
		return err
	}

	// Likewise for the post-hooks.
	if v, ok := hooks.(HasPostInsert); ok {
		if err := v.PostInsert(exec); err != nil {
			return err
		}
	}
	if v, ok := hooks.(HasPostUpdate); ok {
		if err := v.PostUpdate(exec); err != nil {
			return err
		}
	}

	return nil
}

// bindUpsert builds the PostgreSQL upsert statement for the row held in elem
// and binds its arguments:
//
//	insert into <table> (<columns>) values (<values>)
//	on conflict (<conflict columns>) do update set <col>=<bind>, ...
//
// colFilter selects the columns overwritten when the row already exists; nil
// selects every column. Auto-increment and transient columns are never
// overwritten. When no column is selected, the statement ends in "do nothing"
// instead, since an empty SET list is not valid SQL.
//
// conflictFilter selects the ON CONFLICT target columns. It must not be nil,
// and it must match at least one non-transient column; otherwise an error is
// returned.
//
// The statement is rebuilt on every call rather than cached once per table.
// It depends on both filters, which may differ from call to call, and a
// cached plan would silently reuse whatever filters the first call passed.
func (t *TableMap) bindUpsert(elem reflect.Value, colFilter ColumnFilter, conflictFilter ColumnFilter) (bindInstance, error) {
	if colFilter == nil {
		colFilter = acceptAllFilter
	}
	if conflictFilter == nil {
		return bindInstance{}, fmt.Errorf("conflictFilter can't be nil")
	}

	dialect := t.dbmap.Dialect
	plan := &bindPlan{autoIncrIdx: -1}

	// Insert part: the column list goes into s, the matching value list into
	// values. Bind variables are numbered in the same order in which their
	// fields are appended to plan.argFields.
	var s, values bytes.Buffer
	fmt.Fprintf(&s, "insert into %s (", dialect.QuotedTableForQuery(t.SchemaName, t.TableName))

	bound := 0
	first := true
	for idx, col := range t.Columns {
		if col.isAutoIncr && dialect.AutoIncrBindValue() == "" {
			// The dialect has no placeholder for auto-increment values, so
			// the column is left out of the insert entirely.
			plan.autoIncrIdx = idx
			plan.autoIncrFieldName = col.fieldName
			continue
		}
		if col.Transient {
			continue
		}
		if !first {
			s.WriteString(",")
			values.WriteString(",")
		}
		first = false
		s.WriteString(dialect.QuoteField(col.ColumnName))

		switch {
		case col.isAutoIncr:
			values.WriteString(dialect.AutoIncrBindValue())
			plan.autoIncrIdx = idx
			plan.autoIncrFieldName = col.fieldName
		case col.DefaultValue != "":
			// A declared default is written verbatim into the SQL.
			values.WriteString(col.DefaultValue)
		default:
			values.WriteString(dialect.BindVar(bound))
			if col == t.version {
				// Bound through the versFieldConst marker instead of the field
				// name; presumably createBindInstance turns it into the new
				// version number.
				plan.versField = col.fieldName
				plan.argFields = append(plan.argFields, versFieldConst)
			} else {
				plan.argFields = append(plan.argFields, col.fieldName)
			}
			bound++
		}
	}
	s.WriteString(") values (")
	s.Write(values.Bytes())
	s.WriteString(")")

	// Conflict target. Transient columns are excluded from the insert above,
	// so they cannot be part of the target either.
	s.WriteString(" on conflict (")
	matched := 0
	for _, col := range t.Columns {
		if col.Transient || !conflictFilter(col) {
			continue
		}
		if matched > 0 {
			s.WriteString(", ")
		}
		s.WriteString(dialect.QuoteField(col.ColumnName))
		matched++
	}
	if matched == 0 {
		return bindInstance{}, fmt.Errorf("conflictFilter is incorrect: no column of table %s matched", t.TableName)
	}
	s.WriteString(")")

	// Update part. Its bind variables continue the numbering of the insert
	// part, in step with the order in which plan.argFields keeps growing.
	updated := 0
	for _, col := range t.Columns {
		if col.isAutoIncr || col.Transient || !colFilter(col) {
			continue
		}
		if updated == 0 {
			s.WriteString(" do update set ")
		} else {
			s.WriteString(", ")
		}
		s.WriteString(dialect.QuoteField(col.ColumnName))
		s.WriteString("=")
		s.WriteString(dialect.BindVar(bound + updated))
		if col == t.version {
			plan.versField = col.fieldName
			plan.argFields = append(plan.argFields, versFieldConst)
		} else {
			plan.argFields = append(plan.argFields, col.fieldName)
		}
		updated++
	}
	if updated == 0 {
		// Nothing to overwrite: keep the existing row untouched.
		s.WriteString(" do nothing")
	}

	s.WriteString(dialect.QuerySuffix())
	plan.query = s.String()

	return plan.createBindInstance(elem, t.dbmap.TypeConverter)
}