// aisstream/main.go — captured from a repository web view (361 lines, 8.5 KiB, Go).

package main
import (
"database/sql"
"encoding/json"
"fmt"
"log"
"os"
"os/signal"
"slices"
"sort"
"syscall"
"time"
"github.com/gorilla/websocket"
_ "github.com/lib/pq"
)
// SubscriptionMessage is the JSON document written to the aisstream.io
// WebSocket immediately after the handshake to select what the stream
// delivers.
//
// NOTE(review): the public aisstream.io documentation names the message-type
// filter "FilterMessageTypes" (string values such as "PositionReport") and
// lists "BoundingBoxes" as required. MessageType (integer AIS message IDs) is
// kept for compatibility — confirm the server actually honours it. The two
// documented fields are optional (omitempty), so a value that leaves them
// unset marshals to exactly the same JSON as before they were added.
type SubscriptionMessage struct {
	APIKey          string   `json:"APIKey"`
	MessageType     []int    `json:"MessageType"`
	FiltersShipMMSI []string `json:"FiltersShipMMSI"`
	// FilterMessageTypes restricts the stream to the named message types,
	// e.g. "PositionReport". Omitted from the JSON when empty.
	FilterMessageTypes []string `json:"FilterMessageTypes,omitempty"`
	// BoundingBoxes lists areas of interest as pairs of corner points.
	// Presumably each corner is [latitude, longitude] as in the aisstream.io
	// examples (e.g. {{-90, -180}, {90, 180}} for the whole world) — verify.
	// Omitted from the JSON when empty.
	BoundingBoxes [][2][2]float64 `json:"BoundingBoxes,omitempty"`
}
// Decoding types for messages read from the aisstream.io WebSocket.
// encoding/json drops JSON keys that have no matching field, so only the
// fields declared here survive decoding.
type (
	// Position pairs a type string with a coordinate list (GeoJSON-like shape).
	// NOTE(review): not referenced anywhere in this file; possibly unused.
	Position struct {
		Type        string    `json:"type"`
		Coordinates []float64 `json:"coordinates"`
	}

	// Geometry has exactly the same shape and tags as Position.
	// NOTE(review): not referenced anywhere in this file; possibly unused.
	Geometry struct {
		Type        string    `json:"type"`
		Coordinates []float64 `json:"coordinates"`
	}

	// MetaData is the "MetaData" object of an AisStreamMessage. The key
	// casing is mixed (MMSI/ShipName vs latitude/longitude/time_utc), as the
	// tags show. main only reads ShipName from it; TimeUtc is kept as the raw
	// string and is not parsed here.
	MetaData struct {
		MMSI      int     `json:"MMSI"`
		ShipName  string  `json:"ShipName"`
		Latitude  float64 `json:"latitude"`
		Longitude float64 `json:"longitude"`
		TimeUtc   string  `json:"time_utc"`
	}

	// PositionReport is the PositionReport payload. main stores UserID as the
	// vessel's MMSI together with Latitude/Longitude, Cog, Sog, TrueHeading
	// and NavigationalStatus; RateOfTurn is decoded but never stored.
	PositionReport struct {
		Cog                float64 `json:"Cog"`
		NavigationalStatus int     `json:"NavigationalStatus"`
		RateOfTurn         float64 `json:"RateOfTurn"`
		Sog                float64 `json:"Sog"`
		TrueHeading        int     `json:"TrueHeading"`
		UserID             int     `json:"UserID"`
		Latitude           float64 `json:"Latitude"`
		Longitude          float64 `json:"Longitude"`
	}

	// Message is the "Message" object, which appears to be keyed by message
	// type; only the PositionReport variant is declared. The pointer stays nil
	// when that key is absent.
	Message struct {
		PositionReport *PositionReport `json:"PositionReport"`
	}

	// AisStreamMessage is the top-level object of each stream message.
	// main only processes it when MessageType is "PositionReport".
	AisStreamMessage struct {
		MessageType string   `json:"MessageType"`
		MetaData    MetaData `json:"MetaData"`
		Message     Message  `json:"Message"`
	}
)
// getEnv returns the value of the environment variable key. A variable that
// is unset or set to the empty string is treated as missing, and the process
// exits via log.Fatalf.
func getEnv(key string) string {
	if v := os.Getenv(key); v != "" {
		return v
	}
	log.Fatalf("required env var %s not set", key)
	return "" // not reached: log.Fatalf terminates the process
}
// connectDB opens a PostgreSQL handle (lib/pq driver) configured from the
// environment and waits until the server answers a ping.
//
// Required: POSTGRES_HOST, POSTGRES_PORT, POSTGRES_USER, POSTGRES_PASSWORD,
// POSTGRES_DB (getEnv exits the process if any is missing or empty).
// Optional: POSTGRES_SSLMODE, defaulting to "disable".
//
// The ping is attempted up to 30 times, 2 seconds apart. If the database
// never answers, the process exits via log.Fatalf with the last ping error.
// The returned *sql.DB is therefore never nil.
func connectDB() *sql.DB {
	sslmode := os.Getenv("POSTGRES_SSLMODE")
	if sslmode == "" {
		sslmode = "disable"
	}
	// Every value is quoted so that a password (or any other value) containing
	// spaces, single quotes or backslashes cannot corrupt the key=value DSN.
	// Go evaluates the getEnv calls left to right, so a missing variable is
	// still reported in the same order as before.
	dsn := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s",
		quoteDSNValue(getEnv("POSTGRES_HOST")),
		quoteDSNValue(getEnv("POSTGRES_PORT")),
		quoteDSNValue(getEnv("POSTGRES_USER")),
		quoteDSNValue(getEnv("POSTGRES_PASSWORD")),
		quoteDSNValue(getEnv("POSTGRES_DB")),
		quoteDSNValue(sslmode),
	)
	db, err := sql.Open("postgres", dsn)
	if err != nil {
		log.Fatalf("failed to open db: %v", err)
	}
	const maxAttempts = 30
	var pingErr error
	for i := 0; i < maxAttempts; i++ {
		if pingErr = db.Ping(); pingErr == nil {
			log.Println("connected to postgres")
			return db
		}
		log.Printf("waiting for postgres (attempt %d/%d): %v", i+1, maxAttempts, pingErr)
		time.Sleep(2 * time.Second)
	}
	log.Fatalf("could not connect to postgres after %d attempts: %v", maxAttempts, pingErr)
	return nil // not reached: log.Fatalf terminates the process
}

// quoteDSNValue wraps v in single quotes for a libpq/lib/pq key=value
// connection string, backslash-escaping any single quote or backslash inside
// it. Working byte-wise is safe for UTF-8 input because ' and \ are ASCII and
// never occur inside a multi-byte sequence.
func quoteDSNValue(v string) string {
	buf := make([]byte, 0, len(v)+2)
	buf = append(buf, '\'')
	for i := 0; i < len(v); i++ {
		if c := v[i]; c == '\'' || c == '\\' {
			buf = append(buf, '\\')
		}
		buf = append(buf, v[i])
	}
	return string(append(buf, '\''))
}
// createTables creates, if they do not exist yet, the two tables this
// service writes plus two indexes on aisstream_pos:
//
//   - aisstream_pos: one row per received position report (BIGSERIAL id;
//     indexed on received_at and mmsi; pruned by purgeOldPositions).
//   - aisstream_last: one row per MMSI (mmsi is the primary key), overwritten
//     by upsertLast.
//
// All statements are sent in a single Exec call. Any error is fatal
// (log.Fatalf).
func createTables(db *sql.DB) {
	const schema = `
CREATE TABLE IF NOT EXISTS aisstream_pos (
id BIGSERIAL PRIMARY KEY,
mmsi INTEGER NOT NULL,
ship_name VARCHAR,
latitude DOUBLE PRECISION NOT NULL,
longitude DOUBLE PRECISION NOT NULL,
cog DOUBLE PRECISION,
sog DOUBLE PRECISION,
heading INTEGER,
nav_status INTEGER,
received_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_aisstream_pos_received_at ON aisstream_pos (received_at);
CREATE INDEX IF NOT EXISTS idx_aisstream_pos_mmsi ON aisstream_pos (mmsi);
CREATE TABLE IF NOT EXISTS aisstream_last (
mmsi INTEGER PRIMARY KEY,
ship_name VARCHAR,
latitude DOUBLE PRECISION NOT NULL,
longitude DOUBLE PRECISION NOT NULL,
cog DOUBLE PRECISION,
sog DOUBLE PRECISION,
heading INTEGER,
nav_status INTEGER,
received_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
`
	if _, err := db.Exec(schema); err != nil {
		log.Fatalf("failed to create tables: %v", err)
	}
	log.Println("tables ready")
}
// loadMMSIs returns the feed_id of every track_devices row whose device_type
// is 'AIS' and whose feed_id is not NULL. main uses these values as the MMSI
// filter of the aisstream.io subscription. Rows come back in database order
// (the query has no ORDER BY).
//
// Any query, scan or iteration error is fatal (log.Fatalf), so a returned
// slice is always the complete result set; it may be empty (nil).
func loadMMSIs(db *sql.DB) []string {
	rows, err := db.Query("SELECT feed_id FROM track_devices WHERE device_type = 'AIS' AND feed_id IS NOT NULL")
	if err != nil {
		log.Fatalf("failed to query track_devices: %v", err)
	}
	defer rows.Close()
	var mmsis []string
	for rows.Next() {
		var mmsi string
		if err := rows.Scan(&mmsi); err != nil {
			log.Fatalf("failed to scan mmsi: %v", err)
		}
		mmsis = append(mmsis, mmsi)
	}
	// rows.Next reports false both at the end of the result set and on error.
	// Without this check, a connection failure mid-read would silently yield
	// a truncated MMSI list (and thus a too-narrow subscription).
	if err := rows.Err(); err != nil {
		log.Fatalf("failed to read track_devices rows: %v", err)
	}
	log.Printf("loaded %d AIS MMSIs", len(mmsis))
	return mmsis
}
// purgeOldPositions never returns: every 10 minutes it deletes aisstream_pos
// rows whose received_at is more than 24 hours old. The first purge runs 10
// minutes after start. A failed DELETE is logged and retried on the next
// tick. main runs it in its own goroutine.
func purgeOldPositions(db *sql.DB) {
	const purgeSQL = "DELETE FROM aisstream_pos WHERE received_at < NOW() - INTERVAL '24 hours'"
	tick := time.NewTicker(10 * time.Minute)
	defer tick.Stop()
	for {
		<-tick.C
		result, err := db.Exec(purgeSQL)
		if err != nil {
			log.Printf("purge error: %v", err)
			continue
		}
		// The affected-row count only feeds the log line, so its error is ignored.
		if deleted, _ := result.RowsAffected(); deleted > 0 {
			log.Printf("purged %d old positions", deleted)
		}
	}
}
// insertPosition appends one position report to the aisstream_pos history
// table. received_at is not supplied and so takes its column default, NOW().
// A failed insert is logged; the caller is not told about it.
func insertPosition(db *sql.DB, mmsi int, shipName string, lat, lon, cog, sog float64, heading, navStatus int) {
	const stmt = `
INSERT INTO aisstream_pos (mmsi, ship_name, latitude, longitude, cog, sog, heading, nav_status)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`
	args := []any{mmsi, shipName, lat, lon, cog, sog, heading, navStatus}
	if _, err := db.Exec(stmt, args...); err != nil {
		log.Printf("insert pos error: %v", err)
	}
}
// upsertLast keeps aisstream_last holding the most recently received report
// per MMSI. An unseen MMSI gets a new row. Otherwise every column of the
// existing row is overwritten, including received_at, which is set to NOW().
// A failed upsert is logged; the caller is not told about it.
func upsertLast(db *sql.DB, mmsi int, shipName string, lat, lon, cog, sog float64, heading, navStatus int) {
	const stmt = `
INSERT INTO aisstream_last (mmsi, ship_name, latitude, longitude, cog, sog, heading, nav_status, received_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, NOW())
ON CONFLICT (mmsi) DO UPDATE SET
ship_name = EXCLUDED.ship_name,
latitude = EXCLUDED.latitude,
longitude = EXCLUDED.longitude,
cog = EXCLUDED.cog,
sog = EXCLUDED.sog,
heading = EXCLUDED.heading,
nav_status = EXCLUDED.nav_status,
received_at = EXCLUDED.received_at`
	_, err := db.Exec(stmt, mmsi, shipName, lat, lon, cog, sog, heading, navStatus)
	if err == nil {
		return
	}
	log.Printf("upsert last error: %v", err)
}
// connectWebSocket dials the aisstream.io stream and sends the subscription
// for the given MMSIs, requesting AIS message types 1, 2 and 3 (position
// reports). It retries forever, pausing 10 seconds after a failed dial or a
// failed subscription write, and returns only once the subscription has been
// written. On a failed write the half-open connection is closed before
// retrying.
//
// NOTE(review): aisstream.io's published subscription format uses
// "FilterMessageTypes" (e.g. ["PositionReport"]) and a "BoundingBoxes" field;
// check that the integer MessageType filter sent here is honoured. main
// filters on MessageType == "PositionReport" regardless.
func connectWebSocket(apiKey string, mmsis []string) *websocket.Conn {
	const retryDelay = 10 * time.Second
	// The subscription never changes between attempts, so encode it once.
	// Marshal cannot fail for a struct of strings, ints and slices of them,
	// hence the discarded error.
	payload, _ := json.Marshal(SubscriptionMessage{
		APIKey:          apiKey,
		MessageType:     []int{1, 2, 3},
		FiltersShipMMSI: mmsis,
	})
	for {
		conn, _, err := websocket.DefaultDialer.Dial("wss://stream.aisstream.io/v0/stream", nil)
		if err != nil {
			log.Printf("websocket dial error: %v, retrying in 10s...", err)
			time.Sleep(retryDelay)
			continue
		}
		if err = conn.WriteMessage(websocket.TextMessage, payload); err == nil {
			log.Printf("subscribed to aisstream with %d MMSIs", len(mmsis))
			return conn
		}
		log.Printf("subscribe error: %v, retrying in 10s...", err)
		conn.Close()
		time.Sleep(retryDelay)
	}
}
// mmsiEqual reports whether a and b contain the same MMSIs with the same
// multiplicities, ignoring order. Neither input is modified: the comparison
// sorts private copies.
func mmsiEqual(a, b []string) bool {
	if len(a) != len(b) {
		return false
	}
	left, right := slices.Clone(a), slices.Clone(b)
	sort.Strings(left)
	sort.Strings(right)
	for i := range left {
		if left[i] != right[i] {
			return false
		}
	}
	return true
}
// main runs the AIS ingester. It prepares the database, subscribes to the
// aisstream.io WebSocket for the MMSIs listed in track_devices, and writes
// every PositionReport to aisstream_pos (history) and aisstream_last (latest
// per MMSI). It never returns normally: it exits on SIGINT/SIGTERM, or via
// log.Fatal during startup.
func main() {
	log.SetFlags(log.Ldate | log.Ltime | log.Lmsgprefix)
	log.SetPrefix("[aisstream] ")
	db := connectDB()
	defer db.Close()
	createTables(db)
	mmsis := loadMMSIs(db)
	if len(mmsis) == 0 {
		log.Fatal("no AIS devices found in track_devices")
	}
	apiKey := getEnv("AISSTREAM_API_KEY")
	go purgeOldPositions(db)

	// The signal handler is installed before the first connect because
	// connectWebSocket retries forever, and a SIGTERM received meanwhile must
	// still stop the process. The handler deliberately does not touch the
	// WebSocket: main reassigns that variable on every reconnect, so reading
	// it here would be a data race. The OS releases the socket on exit.
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
	go func() {
		<-sigCh
		log.Println("shutting down...")
		db.Close()
		os.Exit(0)
	}()

	reconnectCh := make(chan []string, 1)
	go watchMMSIs(db, mmsis, reconnectCh)

	// gen identifies the live connection. Each reader tags what it sends with
	// the gen it was started for, so the error an old reader reports after we
	// close its socket is not mistaken for a failure of the new connection
	// (which would otherwise trigger a second, needless reconnect).
	msgCh := make(chan wsRead, 64)
	gen := 0
	ws := connectWebSocket(apiKey, mmsis)
	go readLoop(ws, gen, msgCh)

	reconnect := func() {
		ws.Close()
		gen++
		ws = connectWebSocket(apiKey, mmsis)
		go readLoop(ws, gen, msgCh)
	}

	var count int64
	for {
		select {
		case newMMSIs := <-reconnectCh:
			mmsis = newMMSIs
			reconnect()
		case r := <-msgCh:
			if r.err != nil {
				if r.gen != gen {
					continue // error from a connection we already replaced
				}
				log.Println("reconnecting after read error...")
				reconnect()
				continue
			}
			// Payloads still buffered from a replaced connection are valid
			// data and are stored like any other.
			if storePacket(db, r.payload) {
				count++
				if count%100 == 0 {
					log.Printf("processed %d positions", count)
				}
			}
		}
	}
}

// wsRead is one result forwarded by readLoop: either a message payload or the
// read error that ended the connection, tagged with that connection's
// generation number.
type wsRead struct {
	gen     int
	payload []byte
	err     error
}

// messageReader is the subset of *websocket.Conn that readLoop needs.
type messageReader interface {
	ReadMessage() (messageType int, p []byte, err error)
}

// readLoop forwards every message read from ws to out, tagged with gen. It
// returns after forwarding the first read error, including the error caused
// by main closing ws during a reconnect.
func readLoop(ws messageReader, gen int, out chan<- wsRead) {
	for {
		_, p, err := ws.ReadMessage()
		if err != nil {
			log.Printf("read error: %v", err)
			out <- wsRead{gen: gen, err: err}
			return
		}
		out <- wsRead{gen: gen, payload: p}
	}
}

// watchMMSIs re-reads the AIS MMSI list every 5 minutes and sends it on
// reconnectCh whenever it differs (as a multiset) from the list currently
// subscribed. It keeps its own copy of that list, so it never reads main's
// variables. A failed reload is logged and retried on the next tick instead of
// terminating the process. An empty result is ignored.
func watchMMSIs(db *sql.DB, current []string, reconnectCh chan<- []string) {
	ticker := time.NewTicker(5 * time.Minute)
	defer ticker.Stop()
	for range ticker.C {
		updated, err := reloadMMSIs(db)
		if err != nil {
			log.Printf("reload MMSIs error: %v", err)
			continue
		}
		if len(updated) > 0 && !mmsiEqual(current, updated) {
			log.Printf("MMSI list changed (%d -> %d), reconnecting...", len(current), len(updated))
			reconnectCh <- updated
			current = updated
		}
	}
}

// reloadMMSIs runs the same query as loadMMSIs but returns errors instead of
// exiting, for use by the long-running refresh loop.
func reloadMMSIs(db *sql.DB) ([]string, error) {
	rows, err := db.Query("SELECT feed_id FROM track_devices WHERE device_type = 'AIS' AND feed_id IS NOT NULL")
	if err != nil {
		return nil, fmt.Errorf("query track_devices: %w", err)
	}
	defer rows.Close()
	var mmsis []string
	for rows.Next() {
		var mmsi string
		if err := rows.Scan(&mmsi); err != nil {
			return nil, fmt.Errorf("scan mmsi: %w", err)
		}
		mmsis = append(mmsis, mmsi)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("read track_devices rows: %w", err)
	}
	log.Printf("loaded %d AIS MMSIs", len(mmsis))
	return mmsis, nil
}

// storePacket decodes one raw stream message and, if it is a PositionReport,
// writes it through insertPosition and upsertLast, using the report's UserID
// as the MMSI. It returns true when a report reached those writers. Their
// database errors are logged inside them and do not change the result.
func storePacket(db *sql.DB, p []byte) bool {
	var packet AisStreamMessage
	if err := json.Unmarshal(p, &packet); err != nil {
		log.Printf("unmarshal error: %v", err)
		return false
	}
	if packet.MessageType != "PositionReport" {
		return false
	}
	pr := packet.Message.PositionReport
	if pr == nil {
		return false
	}
	insertPosition(db, pr.UserID, packet.MetaData.ShipName,
		pr.Latitude, pr.Longitude, pr.Cog, pr.Sog, pr.TrueHeading, pr.NavigationalStatus)
	upsertLast(db, pr.UserID, packet.MetaData.ShipName,
		pr.Latitude, pr.Longitude, pr.Cog, pr.Sog, pr.TrueHeading, pr.NavigationalStatus)
	return true
}