package internal import ( "fmt" "io" "time" "git.sr.ht/~erock/wish/cms/db" "git.sr.ht/~erock/wish/cms/util" "git.sr.ht/~erock/wish/send/utils" "github.com/gliderlabs/ssh" ) type Opener struct { entry *utils.FileEntry } func (o *Opener) Open(name string) (io.Reader, error) { return o.entry.Reader, nil } type DbHandler struct { User *db.User DBPool db.DB Cfg *ConfigSite } func NewDbHandler(dbpool db.DB, cfg *ConfigSite) *DbHandler { return &DbHandler{ DBPool: dbpool, Cfg: cfg, } } func (h *DbHandler) Validate(s ssh.Session) error { var err error key, err := util.KeyText(s) if err != nil { return fmt.Errorf("key not found") } user, err := h.DBPool.FindUserForKey(s.User(), key) if err != nil { return err } if user.Name == "" { return fmt.Errorf("must have username set") } h.User = user return nil } func (h *DbHandler) Write(s ssh.Session, entry *utils.FileEntry) (string, error) { logger := h.Cfg.Logger userID := h.User.ID filename := entry.Name title := filename var err error post, err := h.DBPool.FindPostWithFilename(filename, userID) if err != nil { logger.Debug("unable to load post, continuing:", err) } user, err := h.DBPool.FindUser(userID) if err != nil { return "", fmt.Errorf("error for %s: %v", filename, err) } var text string if b, err := io.ReadAll(entry.Reader); err == nil { text = string(b) } if !IsTextFile(text, entry.Filepath) { logger.Errorf("WARNING: (%s) invalid file, the contents must be plain text, skipping", entry.Name) return "", fmt.Errorf("WARNING: (%s) invalid file, the contents must be plain text, skipping", entry.Name) } // if the file is empty we remove it from our database if len(text) == 0 { // skip empty files from being added to db if post == nil { logger.Infof("(%s) is empty, skipping record", filename) return "", nil } err := h.DBPool.RemovePosts([]string{post.ID}) logger.Infof("(%s) is empty, removing record", filename) if err != nil { logger.Errorf("error for %s: %v", filename, err) return "", fmt.Errorf("error for %s: %v", filename, err) } } else if post == nil { publishAt := time.Now() logger.Infof("(%s) not found, adding record", filename) _, err = h.DBPool.InsertPost(userID, filename, title, text, "", &publishAt) if err != nil { logger.Errorf("error for %s: %v", filename, err) return "", fmt.Errorf("error for %s: %v", filename, err) } } else { publishAt := post.PublishAt if text == post.Text { logger.Infof("(%s) found, but text is identical, skipping", filename) return h.Cfg.PostURL(user.Name, filename), nil } logger.Infof("(%s) found, updating record", filename) _, err = h.DBPool.UpdatePost(post.ID, title, text, "", publishAt) if err != nil { logger.Errorf("error for %s: %v", filename, err) return "", fmt.Errorf("error for %s: %v", filename, err) } } return h.Cfg.PostURL(user.Name, filename), nil }