package main import ( "database/sql" "encoding/json" "fmt" "log" "os" "os/signal" "slices" "sort" "syscall" "time" "github.com/gorilla/websocket" _ "github.com/lib/pq" ) type SubscriptionMessage struct { APIKey string `json:"APIKey"` MessageType []int `json:"MessageType"` FiltersShipMMSI []string `json:"FiltersShipMMSI"` } type Position struct { Type string `json:"type"` Coordinates []float64 `json:"coordinates"` } type Geometry struct { Type string `json:"type"` Coordinates []float64 `json:"coordinates"` } type MetaData struct { MMSI int `json:"MMSI"` ShipName string `json:"ShipName"` Latitude float64 `json:"latitude"` Longitude float64 `json:"longitude"` TimeUtc string `json:"time_utc"` } type PositionReport struct { Cog float64 `json:"Cog"` NavigationalStatus int `json:"NavigationalStatus"` RateOfTurn float64 `json:"RateOfTurn"` Sog float64 `json:"Sog"` TrueHeading int `json:"TrueHeading"` UserID int `json:"UserID"` Latitude float64 `json:"Latitude"` Longitude float64 `json:"Longitude"` } type Message struct { PositionReport *PositionReport `json:"PositionReport"` } type AisStreamMessage struct { MessageType string `json:"MessageType"` MetaData MetaData `json:"MetaData"` Message Message `json:"Message"` } func getEnv(key string) string { val := os.Getenv(key) if val == "" { log.Fatalf("required env var %s not set", key) } return val } func connectDB() *sql.DB { dsn := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=disable", getEnv("POSTGRES_HOST"), getEnv("POSTGRES_PORT"), getEnv("POSTGRES_USER"), getEnv("POSTGRES_PASSWORD"), getEnv("POSTGRES_DB"), ) db, err := sql.Open("postgres", dsn) if err != nil { log.Fatalf("failed to open db: %v", err) } for i := 0; i < 30; i++ { if err := db.Ping(); err == nil { log.Println("connected to postgres") return db } log.Printf("waiting for postgres (attempt %d/30)...", i+1) time.Sleep(2 * time.Second) } log.Fatal("could not connect to postgres after 30 attempts") return nil } func createTables(db *sql.DB) { _, err := db.Exec(` CREATE TABLE IF NOT EXISTS aisstream_pos ( id BIGSERIAL PRIMARY KEY, mmsi INTEGER NOT NULL, ship_name VARCHAR, latitude DOUBLE PRECISION NOT NULL, longitude DOUBLE PRECISION NOT NULL, cog DOUBLE PRECISION, sog DOUBLE PRECISION, heading INTEGER, nav_status INTEGER, received_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ); CREATE INDEX IF NOT EXISTS idx_aisstream_pos_received_at ON aisstream_pos (received_at); CREATE INDEX IF NOT EXISTS idx_aisstream_pos_mmsi ON aisstream_pos (mmsi); CREATE TABLE IF NOT EXISTS aisstream_last ( mmsi INTEGER PRIMARY KEY, ship_name VARCHAR, latitude DOUBLE PRECISION NOT NULL, longitude DOUBLE PRECISION NOT NULL, cog DOUBLE PRECISION, sog DOUBLE PRECISION, heading INTEGER, nav_status INTEGER, received_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ); `) if err != nil { log.Fatalf("failed to create tables: %v", err) } log.Println("tables ready") } func loadMMSIs(db *sql.DB) []string { rows, err := db.Query("SELECT feed_id FROM track_devices WHERE device_type = 'AIS' AND feed_id IS NOT NULL") if err != nil { log.Fatalf("failed to query track_devices: %v", err) } defer rows.Close() var mmsis []string for rows.Next() { var mmsi string if err := rows.Scan(&mmsi); err != nil { log.Fatalf("failed to scan mmsi: %v", err) } mmsis = append(mmsis, mmsi) } log.Printf("loaded %d AIS MMSIs", len(mmsis)) return mmsis } func purgeOldPositions(db *sql.DB) { ticker := time.NewTicker(10 * time.Minute) defer ticker.Stop() for range ticker.C { res, err := db.Exec("DELETE FROM aisstream_pos WHERE received_at < NOW() - INTERVAL '24 hours'") if err != nil { log.Printf("purge error: %v", err) continue } n, _ := res.RowsAffected() if n > 0 { log.Printf("purged %d old positions", n) } } } func insertPosition(db *sql.DB, mmsi int, shipName string, lat, lon, cog, sog float64, heading, navStatus int) { _, err := db.Exec(` INSERT INTO aisstream_pos (mmsi, ship_name, latitude, longitude, cog, sog, heading, nav_status) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`, mmsi, shipName, lat, lon, cog, sog, heading, navStatus, ) if err != nil { log.Printf("insert pos error: %v", err) } } func upsertLast(db *sql.DB, mmsi int, shipName string, lat, lon, cog, sog float64, heading, navStatus int) { _, err := db.Exec(` INSERT INTO aisstream_last (mmsi, ship_name, latitude, longitude, cog, sog, heading, nav_status, received_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, NOW()) ON CONFLICT (mmsi) DO UPDATE SET ship_name = EXCLUDED.ship_name, latitude = EXCLUDED.latitude, longitude = EXCLUDED.longitude, cog = EXCLUDED.cog, sog = EXCLUDED.sog, heading = EXCLUDED.heading, nav_status = EXCLUDED.nav_status, received_at = EXCLUDED.received_at`, mmsi, shipName, lat, lon, cog, sog, heading, navStatus, ) if err != nil { log.Printf("upsert last error: %v", err) } } func connectWebSocket(apiKey string, mmsis []string) *websocket.Conn { for { ws, _, err := websocket.DefaultDialer.Dial("wss://stream.aisstream.io/v0/stream", nil) if err != nil { log.Printf("websocket dial error: %v, retrying in 10s...", err) time.Sleep(10 * time.Second) continue } // MessageType 1,2,3 are position reports sub := SubscriptionMessage{ APIKey: apiKey, MessageType: []int{1, 2, 3}, FiltersShipMMSI: mmsis, } msg, _ := json.Marshal(sub) if err := ws.WriteMessage(websocket.TextMessage, msg); err != nil { log.Printf("subscribe error: %v, retrying in 10s...", err) ws.Close() time.Sleep(10 * time.Second) continue } log.Printf("subscribed to aisstream with %d MMSIs", len(mmsis)) return ws } } func mmsiEqual(a, b []string) bool { if len(a) != len(b) { return false } sa := make([]string, len(a)) sb := make([]string, len(b)) copy(sa, a) copy(sb, b) sort.Strings(sa) sort.Strings(sb) return slices.Equal(sa, sb) } func main() { log.SetFlags(log.Ldate | log.Ltime | log.Lmsgprefix) log.SetPrefix("[aisstream] ") db := connectDB() defer db.Close() createTables(db) mmsis := loadMMSIs(db) if len(mmsis) == 0 { log.Fatal("no AIS devices found in track_devices") } apiKey := getEnv("AISSTREAM_API_KEY") go purgeOldPositions(db) reconnectCh := make(chan []string, 1) go func() { ticker := time.NewTicker(5 * time.Minute) defer ticker.Stop() for range ticker.C { updated := loadMMSIs(db) if len(updated) > 0 && !mmsiEqual(mmsis, updated) { log.Printf("MMSI list changed (%d -> %d), reconnecting...", len(mmsis), len(updated)) reconnectCh <- updated } } }() sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) ws := connectWebSocket(apiKey, mmsis) defer ws.Close() go func() { <-sigCh log.Println("shutting down...") ws.Close() db.Close() os.Exit(0) }() msgCh := make(chan []byte, 64) go func() { for { _, p, err := ws.ReadMessage() if err != nil { log.Printf("read error: %v", err) msgCh <- nil return } msgCh <- p } }() var count int64 for { select { case newMMSIs := <-reconnectCh: mmsis = newMMSIs ws.Close() ws = connectWebSocket(apiKey, mmsis) go func() { for { _, p, err := ws.ReadMessage() if err != nil { log.Printf("read error: %v", err) msgCh <- nil return } msgCh <- p } }() case p := <-msgCh: if p == nil { log.Println("reconnecting after read error...") ws.Close() ws = connectWebSocket(apiKey, mmsis) go func() { for { _, p, err := ws.ReadMessage() if err != nil { log.Printf("read error: %v", err) msgCh <- nil return } msgCh <- p } }() continue } var packet AisStreamMessage if err := json.Unmarshal(p, &packet); err != nil { log.Printf("unmarshal error: %v", err) continue } if packet.MessageType != "PositionReport" { continue } pr := packet.Message.PositionReport if pr == nil { continue } insertPosition(db, pr.UserID, packet.MetaData.ShipName, pr.Latitude, pr.Longitude, pr.Cog, pr.Sog, pr.TrueHeading, pr.NavigationalStatus) upsertLast(db, pr.UserID, packet.MetaData.ShipName, pr.Latitude, pr.Longitude, pr.Cog, pr.Sog, pr.TrueHeading, pr.NavigationalStatus) count++ if count%100 == 0 { log.Printf("processed %d positions", count) } } } }