1
0
Fork 0
mirror of https://github.com/Eggbertx/gochan.git synced 2025-09-04 10:06:24 -07:00

Fixed referer checking, added more testing for server and serverutil packages

This commit is contained in:
Eggbertx 2025-02-10 21:41:21 -08:00
parent d61e3c62c2
commit 5a7695e98f
12 changed files with 375 additions and 48 deletions

View file

@ -20,6 +20,7 @@
"DBprefix": "gc_",
"_DBprefix_info": "The prefix automataically applied to tables when the database is being provisioned and queried",
"CheckRequestReferer": true,
"Lockdown": false,
"LockdownMessage": "This imageboard has temporarily disabled posting. We apologize for the inconvenience",
"Modboard": "staff",

View file

@ -41,11 +41,15 @@ type GochanConfig struct {
// ValidateValues checks to make sure that the configuration options are usable
// (e.g., ListenIP is a valid IP address, Port isn't a negative number, etc)
func (gcfg *GochanConfig) ValidateValues() error {
// if net.ParseIP(gcfg.ListenIP) == nil {
// return &InvalidValueError{Field: "ListenIP", Value: gcfg.ListenIP}
// }
changed := false
if gcfg.SiteDomain == "" {
return &InvalidValueError{Field: "SiteDomain", Value: gcfg.SiteDomain, Details: "must be set"}
}
if strings.Contains(gcfg.SiteDomain, " ") || strings.Contains(gcfg.SiteDomain, "://") {
return &InvalidValueError{Field: "SiteDomain", Value: gcfg.SiteDomain, Details: "must be a host (port optional)"}
}
_, err := durationutil.ParseLongerDuration(gcfg.CookieMaxAge)
if errors.Is(err, durationutil.ErrInvalidDurationString) {
return &InvalidValueError{Field: "CookieMaxAge", Value: gcfg.CookieMaxAge, Details: err.Error() + cookieMaxAgeEx}
@ -162,10 +166,11 @@ type SystemCriticalConfig struct {
SQLConfig
Verbose bool `json:"DebugMode"`
RandomSeed string
Version *GochanVersion `json:"-"`
TimeZone int `json:"-"`
CheckRequestReferer bool
Verbose bool `json:"DebugMode"`
RandomSeed string
Version *GochanVersion `json:"-"`
TimeZone int `json:"-"`
}
// SiteConfig contains information about the site/community, e.g. the name of the site, the slogan (if set),

View file

@ -10,6 +10,7 @@ var (
DBMaxIdleConnections: DefaultSQLMaxConns,
DBConnMaxLifetimeMin: DefaultSQLConnMaxLifetimeMin,
},
CheckRequestReferer: true,
},
SiteConfig: SiteConfig{
FirstPage: []string{"index.html", "firstrun.html", "1.html"},

View file

@ -18,8 +18,24 @@ func LValueToInterface(l *lua.LState, v lua.LValue) any {
case lua.LTUserData:
l.Push(v)
return l.CheckUserData(l.GetTop()).Value
case lua.LTTable:
t := v.(*lua.LTable)
tableLength := t.Len()
if tableLength > 0 {
// Array
arr := make([]any, tableLength)
for i := 1; i <= tableLength; i++ {
arr[i-1] = LValueToInterface(l, t.RawGetInt(i))
}
return arr
}
m := make(map[string]any)
t.ForEach(func(k, v lua.LValue) {
m[k.String()] = LValueToInterface(l, v)
})
return m
default:
l.RaiseError("Incompatible Lua type")
l.ArgError(2, "Incompatible Lua type")
}
return nil
}

View file

@ -5,9 +5,10 @@ import (
"context"
"database/sql"
"errors"
"net"
"net/http"
"regexp"
"strconv"
"strings"
"time"
"github.com/Eggbertx/durationutil"
@ -22,8 +23,6 @@ import (
)
var (
chopPortNumRegex = regexp.MustCompile(`(.+|\w+):(\d+)$`)
ErrSpambot = errors.New("request looks like a spambot")
ErrBadCredentials = errors.New("invalid username or password")
ErrUnableToCreateSession = errors.New("unable to create login session")
ErrInvalidSession = errors.New("invalid staff session")
@ -37,26 +36,42 @@ var (
func createSession(key, username, password string, request *http.Request, writer http.ResponseWriter) error {
domain := request.Host
errEv := gcutil.LogError(nil).
Str("staff", username).
Str("IP", gcutil.GetRealIP(request))
defer errEv.Discard()
infoEv, errEv := gcutil.LogRequest(request)
defer func() {
infoEv.Discard()
errEv.Discard()
}()
domain = chopPortNumRegex.Split(domain, -1)[0]
if !serverutil.ValidReferer(request) {
gcutil.LogWarning().
Str("staff", username).
Str("IP", gcutil.GetRealIP(request)).
Str("remoteAddr", request.Response.Request.RemoteAddr).
Msg("Rejected login from possible spambot")
return ErrSpambot
if strings.Contains(domain, ":") {
domain, _, err := net.SplitHostPort(domain)
if err != nil {
errEv.Err(err).Caller().Str("host", domain).Send()
return server.NewServerError("Invalid request host", http.StatusBadRequest)
}
}
refererResult, err := serverutil.CheckReferer(request)
if err != nil {
errEv.Err(err).Caller().
Str("staff", username).
Str("referer", request.Referer()).
Msg("Error checking referer")
return err
}
if refererResult != serverutil.InternalReferer {
gcutil.LogWarning().
Int("refererResult", int(refererResult)).
Str("referer", request.Referer()).
Str("siteDomain", config.GetSystemCriticalConfig().SiteDomain).
Str("staff", username).
Msg("Rejected login from possible spambot")
return serverutil.ErrSpambot
}
staff, err := gcsql.GetStaffByUsername(username, true)
if err != nil {
if err != sql.ErrNoRows {
errEv.Err(err).Caller().
Str("remoteAddr", request.RemoteAddr).
Msg("Unrecognized username")
}
return ErrBadCredentials

View file

@ -253,7 +253,13 @@ func MakePost(writer http.ResponseWriter, request *http.Request) {
infoEv, errEv := gcutil.LogRequest(request)
if !serverutil.ValidReferer(request) {
refererResult, err := serverutil.CheckReferer(request)
if err != nil {
errEv.Err(err).Caller().Send()
server.ServeError(writer, "Error checking referer", wantsJSON, nil)
return
}
if refererResult != serverutil.InternalReferer {
// post has no referrer, or has a referrer from a different domain, probably a spambot
gcutil.LogWarning().
Str("spam", "badReferer").

View file

@ -1,28 +1,48 @@
package serverutil
import (
"errors"
"net/http"
"net/url"
"strings"
"github.com/gochan-org/gochan/pkg/config"
"github.com/gochan-org/gochan/pkg/gcutil"
"github.com/rs/zerolog"
)
// ValidReferer checks to make sure that the incoming request is from the same domain (or if debug mode is enabled)
func ValidReferer(request *http.Request, errEv ...*zerolog.Event) bool {
var (
ErrSpambot = errors.New("request looks like spam")
)
const (
// NoReferer is returned when the request has no referer
NoReferer RefererResult = iota
// InvalidReferer is returned when the referer not a valid URL
InvalidReferer
// InternalReferer is returned when the request came from the same site as the server
InternalReferer
// ExternalReferer is returned when the request came from another site. It may or may not be be spam, depending on the context
ExternalReferer
)
type RefererResult int
// CheckReferer checks to make sure that the incoming request is from the same domain
func CheckReferer(request *http.Request) (RefererResult, error) {
referer := request.Referer()
if referer == "" {
return NoReferer, nil
}
rURL, err := url.ParseRequestURI(referer)
if err != nil {
var ev *zerolog.Event
if len(errEv) == 1 {
ev = gcutil.LogError(err).Caller()
} else {
ev = errEv[0].Err(err).Caller()
}
ev.Str("referer", referer).Msg("Error parsing referer URL")
return false
return InvalidReferer, err
}
return strings.Index(rURL.Path, config.GetSystemCriticalConfig().WebRoot) == 0
systemCriticalConfig := config.GetSystemCriticalConfig()
siteURLBase := url.URL{
Host: systemCriticalConfig.SiteDomain,
}
var result RefererResult = ExternalReferer
if rURL.Host == siteURLBase.Host {
result = InternalReferer
}
return result, nil
}

View file

@ -0,0 +1,69 @@
package serverutil
import (
"net/http"
"testing"
"github.com/gochan-org/gochan/pkg/config"
"github.com/stretchr/testify/assert"
)
var (
checkRefererTestCases = []checkRefererTestCase{
{
desc: "Internal referer",
referer: "http://gochan.org",
siteDomain: "gochan.org",
expectedResult: InternalReferer,
},
{
desc: "External referer",
referer: "http://somesketchysite.com",
siteDomain: "gochan.com",
expectedResult: ExternalReferer,
},
{
desc: "No referer",
siteDomain: "gochan.org",
expectedResult: NoReferer,
},
{
desc: "Internal referer with port",
referer: "http://127.0.0.1:8080",
siteDomain: "127.0.0.1:8080",
expectedResult: InternalReferer,
},
{
desc: "Internal referer with port, IPv6",
referer: "http://[::1]:8080",
siteDomain: "[::1]:8080",
expectedResult: InternalReferer,
},
}
)
type checkRefererTestCase struct {
desc string
referer string
siteDomain string
expectedResult RefererResult
}
func TestCheckReferer(t *testing.T) {
config.SetVersion("4.0.0")
systemCriticalConfig := config.GetSystemCriticalConfig()
req, err := http.NewRequest("GET", "http://gochan.org", nil)
if !assert.NoError(t, err) {
t.FailNow()
}
for _, tC := range checkRefererTestCases {
t.Run(tC.desc, func(t *testing.T) {
systemCriticalConfig.SiteDomain = tC.siteDomain
config.SetSystemCriticalConfig(systemCriticalConfig)
req.Header.Set("Referer", tC.referer)
result, err := CheckReferer(req)
assert.NoError(t, err)
assert.Equal(t, tC.expectedResult, result)
})
}
}

View file

@ -14,16 +14,28 @@ func PreloadModule(l *lua.LState) int {
l.SetFuncs(t, map[string]lua.LGFunction{
"minify_template": func(l *lua.LState) int {
tmplUD := l.CheckUserData(1)
tmpl := tmplUD.Value.(*template.Template)
dataTable := l.CheckTable(2)
data := map[string]any{}
dataTable.ForEach(func(l1, l2 lua.LValue) {
data[l1.String()] = luautil.LValueToInterface(l, l2)
})
tmplLV := l.CheckAny(1)
var tmpl *template.Template
var tmplStr string
switch tmplLV.Type() {
case lua.LTUserData:
tmpl = tmplLV.(*lua.LUserData).Value.(*template.Template)
case lua.LTString:
tmplStr = tmplLV.String()
default:
l.ArgError(1, "expected string or template")
return 0
}
data := luautil.LValueToInterface(l, l.CheckAny(2))
writer := l.CheckUserData(3).Value.(io.Writer)
mediaType := l.CheckString(4)
err := MinifyTemplate(tmpl, data, writer, mediaType)
var err error
switch tmplLV.Type() {
case lua.LTString:
err = MinifyTemplate(tmplStr, data, writer, mediaType)
case lua.LTUserData:
err = MinifyTemplate(tmpl, data, writer, mediaType)
}
l.Push(luar.New(l, err))
return 1
},

View file

@ -0,0 +1,87 @@
package serverutil
import (
"bytes"
"testing"
"text/template"
"github.com/gochan-org/gochan/pkg/config"
"github.com/gochan-org/gochan/pkg/gctemplates"
"github.com/gochan-org/gochan/pkg/gcutil/testutil"
"github.com/stretchr/testify/assert"
lua "github.com/yuin/gopher-lua"
luar "layeh.com/gopher-luar"
)
var (
buf bytes.Buffer
data = map[string]any{
"logText": "text goes here",
}
luaStringTemplateTestCases = []luaTemplateTestCase[string]{
{
desc: "minify HTML",
template: gctemplates.ErrorPage,
data: data,
luaScript: `local serverutil = require("serverutil")
return serverutil.minify_template("manage_viewlog.html", data, buf, "text/html")`,
expectString: `<textarea class=viewlog rows=24 spellcheck=false readonly>text goes here</textarea>`,
},
{
desc: "minify HTML with nil data",
template: gctemplates.ManageViewLog,
luaScript: `local serverutil = require("serverutil")
return serverutil.minify_template("manage_viewlog.html", nil, buf, "text/html")`,
expectString: `<textarea class=viewlog rows=24 spellcheck=false readonly></textarea>`,
},
{
desc: "error, unrecognized template name",
template: "invalid_template",
luaScript: `local serverutil = require("serverutil")
return serverutil.minify_template("invalid_template", nil, buf, "text/html")`,
expectError: true,
},
}
)
type luaTemplateTestCase[T string | template.Template] struct {
desc string
template T
data any
expectError bool
expectString string
luaScript string
}
func TestLuaTemplates(t *testing.T) {
testutil.GoToGochanRoot(t)
config.SetTestTemplateDir("templates")
siteConfig := config.GetSiteConfig()
siteConfig.MinifyJS = true
config.SetSiteConfig(siteConfig)
for _, tC := range luaStringTemplateTestCases {
t.Run(tC.desc, func(t *testing.T) {
buf.Reset()
l := lua.NewState()
defer l.Close()
l.PreloadModule("serverutil", PreloadModule)
l.SetGlobal("buf", luar.New(l, &buf))
l.SetGlobal("data", luar.New(l, tC.data))
if !assert.NoError(t, l.DoString(tC.luaScript)) {
t.FailNow()
}
errLV := l.Get(-1)
var err error
if errLV.Type() != lua.LTNil {
err = errLV.(*lua.LUserData).Value.(error)
}
if tC.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tC.expectString, buf.String())
}
})
}
}

View file

@ -20,5 +20,8 @@ func DeleteCookie(writer http.ResponseWriter, request *http.Request, cookieName
func IsRequestingJSON(request *http.Request) bool {
jsonField := request.FormValue("json")
if jsonField == "" {
jsonField = request.PostFormValue("json")
}
return jsonField == "1" || jsonField == "true"
}

View file

@ -0,0 +1,92 @@
package serverutil
import (
"net/http"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
var (
isRequestingJSONTestCases = []isRequestingJSONTestCase{
{
val: "1",
exp: true,
},
{
val: "on",
exp: false,
},
{
val: "true",
exp: true,
},
{
val: "yes",
exp: false,
},
}
)
type isRequestingJSONTestCase struct {
val string
exp bool
}
func TestIsRequestingJSON(t *testing.T) {
req, _ := http.NewRequest("GET", "http://localhost:8080", nil)
assert.False(t, IsRequestingJSON(req))
for _, tc := range isRequestingJSONTestCases {
t.Run("GET "+tc.val, func(t *testing.T) {
req.Form.Set("json", tc.val)
assert.Equal(t, tc.exp, IsRequestingJSON(req))
req.Form.Del("json")
})
req.Method = "POST"
req.PostFormValue("_")
t.Run("POST "+tc.val, func(t *testing.T) {
req.PostForm.Set("json", tc.val)
assert.Equal(t, tc.exp, IsRequestingJSON(req))
req.PostForm.Del("json")
})
}
}
type testResponseWriter struct {
header http.Header
status int
}
func (w *testResponseWriter) Header() http.Header {
return w.header
}
func (w *testResponseWriter) Write([]byte) (int, error) {
return 0, nil
}
func (w *testResponseWriter) WriteHeader(s int) {
w.status = s
}
func TestDeleteCookie(t *testing.T) {
req, _ := http.NewRequest("GET", "http://localhost:8080", nil)
writer := testResponseWriter{
header: make(http.Header),
}
assert.False(t, DeleteCookie(&writer, req, "test"))
cookie := &http.Cookie{
Name: "test",
Value: "test",
MaxAge: 90,
Expires: time.Now().Add(7 * 24 * time.Hour),
}
req.AddCookie(cookie)
assert.True(t, DeleteCookie(&writer, req, "test"))
cookieExpireStr := writer.header.Get("Set-Cookie")
ct, err := time.ParseInLocation(time.RFC1123, cookieExpireStr[strings.Index(cookieExpireStr, "Expires=")+8:], time.Local)
assert.NoError(t, err)
assert.True(t, ct.Before(time.Now().Add(-7*time.Hour)))
}