120 lines
2.9 KiB
Go
120 lines
2.9 KiB
Go
|
package internal
|
||
|
|
||
|
import (
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"time"
|
||
|
|
||
|
"git.sr.ht/~erock/wish/cms/db"
|
||
|
"git.sr.ht/~erock/wish/cms/util"
|
||
|
"git.sr.ht/~erock/wish/send/utils"
|
||
|
"github.com/gliderlabs/ssh"
|
||
|
)
|
||
|
|
||
|
type Opener struct {
|
||
|
entry *utils.FileEntry
|
||
|
}
|
||
|
|
||
|
func (o *Opener) Open(name string) (io.Reader, error) {
|
||
|
return o.entry.Reader, nil
|
||
|
}
|
||
|
|
||
|
type DbHandler struct {
|
||
|
User *db.User
|
||
|
DBPool db.DB
|
||
|
Cfg *ConfigSite
|
||
|
}
|
||
|
|
||
|
func NewDbHandler(dbpool db.DB, cfg *ConfigSite) *DbHandler {
|
||
|
return &DbHandler{
|
||
|
DBPool: dbpool,
|
||
|
Cfg: cfg,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (h *DbHandler) Validate(s ssh.Session) error {
|
||
|
var err error
|
||
|
key, err := util.KeyText(s)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("key not found")
|
||
|
}
|
||
|
|
||
|
user, err := h.DBPool.FindUserForKey(s.User(), key)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
if user.Name == "" {
|
||
|
return fmt.Errorf("must have username set")
|
||
|
}
|
||
|
|
||
|
h.User = user
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (h *DbHandler) Write(s ssh.Session, entry *utils.FileEntry) (string, error) {
|
||
|
logger := h.Cfg.Logger
|
||
|
userID := h.User.ID
|
||
|
filename := entry.Name
|
||
|
title := filename
|
||
|
var err error
|
||
|
post, err := h.DBPool.FindPostWithFilename(filename, userID)
|
||
|
if err != nil {
|
||
|
logger.Debug("unable to load post, continuing:", err)
|
||
|
}
|
||
|
|
||
|
user, err := h.DBPool.FindUser(userID)
|
||
|
if err != nil {
|
||
|
return "", fmt.Errorf("error for %s: %v", filename, err)
|
||
|
}
|
||
|
|
||
|
var text string
|
||
|
if b, err := io.ReadAll(entry.Reader); err == nil {
|
||
|
text = string(b)
|
||
|
}
|
||
|
|
||
|
if !IsTextFile(text, entry.Filepath) {
|
||
|
logger.Errorf("WARNING: (%s) invalid file, the contents must be plain text, skipping", entry.Name)
|
||
|
return "", fmt.Errorf("WARNING: (%s) invalid file, the contents must be plain text, skipping", entry.Name)
|
||
|
}
|
||
|
|
||
|
// if the file is empty we remove it from our database
|
||
|
if len(text) == 0 {
|
||
|
// skip empty files from being added to db
|
||
|
if post == nil {
|
||
|
logger.Infof("(%s) is empty, skipping record", filename)
|
||
|
return "", nil
|
||
|
}
|
||
|
|
||
|
err := h.DBPool.RemovePosts([]string{post.ID})
|
||
|
logger.Infof("(%s) is empty, removing record", filename)
|
||
|
if err != nil {
|
||
|
logger.Errorf("error for %s: %v", filename, err)
|
||
|
return "", fmt.Errorf("error for %s: %v", filename, err)
|
||
|
}
|
||
|
} else if post == nil {
|
||
|
publishAt := time.Now()
|
||
|
logger.Infof("(%s) not found, adding record", filename)
|
||
|
_, err = h.DBPool.InsertPost(userID, filename, title, text, "", &publishAt)
|
||
|
if err != nil {
|
||
|
logger.Errorf("error for %s: %v", filename, err)
|
||
|
return "", fmt.Errorf("error for %s: %v", filename, err)
|
||
|
}
|
||
|
} else {
|
||
|
publishAt := post.PublishAt
|
||
|
if text == post.Text {
|
||
|
logger.Infof("(%s) found, but text is identical, skipping", filename)
|
||
|
return h.Cfg.PostURL(user.Name, filename), nil
|
||
|
}
|
||
|
|
||
|
logger.Infof("(%s) found, updating record", filename)
|
||
|
_, err = h.DBPool.UpdatePost(post.ID, title, text, "", publishAt)
|
||
|
if err != nil {
|
||
|
logger.Errorf("error for %s: %v", filename, err)
|
||
|
return "", fmt.Errorf("error for %s: %v", filename, err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return h.Cfg.PostURL(user.Name, filename), nil
|
||
|
}
|